source: etherws/trunk/etherws.py @ 267

Revision 267, 41.9 KB checked in by atzm, 11 years ago (diff)

fixed pep8 warnings

  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.12.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2013, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54import traceback
55
56import tornado
57import websocket
58
59from tornado.web import Application, RequestHandler
60from tornado.websocket import WebSocketHandler
61from tornado.httpserver import HTTPServer
62from tornado.ioloop import IOLoop
63
64from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
65
66
67class DebugMixIn(object):
68    def dprintf(self, msg, func=lambda: ()):
69        if self._debug:
70            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
71            sys.stderr.write(prefix + (msg % func()))
72
73
74class EthernetFrame(object):
75    def __init__(self, data):
76        self.data = data
77
78    @property
79    def dst_multicast(self):
80        return ord(self.data[0]) & 1
81
82    @property
83    def src_multicast(self):
84        return ord(self.data[6]) & 1
85
86    @property
87    def dst_mac(self):
88        return self.data[:6]
89
90    @property
91    def src_mac(self):
92        return self.data[6:12]
93
94    @property
95    def tagged(self):
96        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
97
98    @property
99    def vid(self):
100        if self.tagged:
101            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
102        return 0
103
104    @staticmethod
105    def format_mac(mac, sep=':'):
106        return sep.join(b.encode('hex') for b in mac)
107
108
109class FDB(DebugMixIn):
110    class Entry(object):
111        def __init__(self, port, ageout):
112            self.port = port
113            self._time = time.time()
114            self._ageout = ageout
115
116        @property
117        def age(self):
118            return time.time() - self._time
119
120        @property
121        def agedout(self):
122            return self.age > self._ageout
123
124    def __init__(self, ageout, debug):
125        self._ageout = ageout
126        self._debug = debug
127        self._table = {}
128
129    def _set_entry(self, vid, mac, port):
130        if vid not in self._table:
131            self._table[vid] = {}
132        self._table[vid][mac] = self.Entry(port, self._ageout)
133
134    def _del_entry(self, vid, mac):
135        if vid in self._table:
136            if mac in self._table[vid]:
137                del self._table[vid][mac]
138            if not self._table[vid]:
139                del self._table[vid]
140
141    def _get_entry(self, vid, mac):
142        try:
143            entry = self._table[vid][mac]
144        except KeyError:
145            return None
146
147        if not entry.agedout:
148            return entry
149
150        self._del_entry(vid, mac)
151        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
152                     lambda: (entry.port.number, vid, mac.encode('hex')))
153
154    def each(self):
155        for vid in sorted(self._table.iterkeys()):
156            for mac in sorted(self._table[vid].iterkeys()):
157                entry = self._get_entry(vid, mac)
158                if entry:
159                    yield (vid, mac, entry)
160
161    def lookup(self, frame):
162        mac = frame.dst_mac
163        vid = frame.vid
164        entry = self._get_entry(vid, mac)
165        return getattr(entry, 'port', None)
166
167    def learn(self, port, frame):
168        mac = frame.src_mac
169        vid = frame.vid
170        self._set_entry(vid, mac, port)
171        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
172                     lambda: (port.number, vid, mac.encode('hex')))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
179                             lambda: (port.number, vid, mac.encode('hex')))
180
181
182class SwitchingHub(DebugMixIn):
183    class Port(object):
184        def __init__(self, number, interface):
185            self.number = number
186            self.interface = interface
187            self.tx = 0
188            self.rx = 0
189            self.shut = False
190
191    def __init__(self, fdb, debug):
192        self.fdb = fdb
193        self._debug = debug
194        self._table = {}
195        self._next = 1
196
197    @property
198    def portlist(self):
199        return sorted(self._table.itervalues(),
200                      key=operator.attrgetter('number'))
201
202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
206        try:
207            self._set_privattr('portnum', interface, self._next)  # XXX
208            self._table[self._next] = self.Port(self._next, interface)
209            return self._next
210        finally:
211            self._next += 1
212
213    def unregister_port(self, interface):
214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
216        self.fdb.delete(self._table[portnum])
217        del self._table[portnum]
218
219    def send(self, dst_interfaces, frame):
220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
223        ports = sorted(ports, key=operator.attrgetter('number'))
224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
237        port = self._table[self._get_privattr('portnum', src_interface)]
238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
244        try:
245            if not frame.src_multicast:
246                self.fdb.learn(src_port, frame)
247
248            if not frame.dst_multicast:
249                dst_port = self.fdb.lookup(frame)
250
251                if dst_port:
252                    self.send([dst_port.interface], frame)
253                    return
254
255            ports = set(self.portlist) - set([src_port])
256            self.send((p.interface for p in ports), frame)
257
258        except:  # ex. received invalid frame
259            traceback.print_exc()
260
261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
263
264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
274class NetworkInterface(object):
275    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
276    SIOCSIFADDR = 0x8916
277    SIOCGIFNETMASK = 0x891b
278    SIOCSIFNETMASK = 0x891c
279    SIOCGIFMTU = 0x8921
280    SIOCSIFMTU = 0x8922
281
282    def __init__(self, ifname):
283        self._ifname = struct.pack('16s', str(ifname)[:15])
284
285    def _ioctl(self, req, data):
286        try:
287            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
288            return fcntl.ioctl(sock.fileno(), req, data)
289        finally:
290            sock.close()
291
292    @property
293    def address(self):
294        try:
295            data = struct.pack('16s18x', self._ifname)
296            ret = self._ioctl(self.SIOCGIFADDR, data)
297            return socket.inet_ntoa(ret[20:24])
298        except IOError:
299            return ''
300
301    @property
302    def netmask(self):
303        try:
304            data = struct.pack('16s18x', self._ifname)
305            ret = self._ioctl(self.SIOCGIFNETMASK, data)
306            return socket.inet_ntoa(ret[20:24])
307        except IOError:
308            return ''
309
310    @property
311    def mtu(self):
312        try:
313            data = struct.pack('16s18x', self._ifname)
314            ret = self._ioctl(self.SIOCGIFMTU, data)
315            return struct.unpack('i', ret[16:20])[0]
316        except IOError:
317            return 0
318
319    @address.setter
320    def address(self, addr):
321        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
322                           socket.inet_aton(addr))
323        self._ioctl(self.SIOCSIFADDR, data)
324
325    @netmask.setter
326    def netmask(self, addr):
327        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
328                           socket.inet_aton(addr))
329        self._ioctl(self.SIOCSIFNETMASK, data)
330
331    @mtu.setter
332    def mtu(self, mtu):
333        data = struct.pack('16si12x', self._ifname, mtu)
334        self._ioctl(self.SIOCSIFMTU, data)
335
336
337class Htpasswd(object):
338    def __init__(self, path):
339        self._path = path
340        self._stat = None
341        self._data = {}
342
343    def auth(self, name, passwd):
344        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
345        return self._data.get(name) == passwd
346
347    def load(self):
348        old_stat = self._stat
349
350        with open(self._path) as fp:
351            fileno = fp.fileno()
352            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
353            self._stat = os.fstat(fileno)
354
355            unchanged = old_stat and \
356                old_stat.st_ino == self._stat.st_ino and \
357                old_stat.st_dev == self._stat.st_dev and \
358                old_stat.st_mtime == self._stat.st_mtime
359
360            if not unchanged:
361                self._data = self._parse(fp)
362
363        return self
364
365    def _parse(self, fp):
366        data = {}
367        for line in fp:
368            line = line.strip()
369            if 0 <= line.find(':'):
370                name, passwd = line.split(':', 1)
371                if passwd.startswith('{SHA}'):
372                    data[name] = passwd[5:]
373        return data
374
375
376class BasicAuthMixIn(object):
377    def _execute(self, transforms, *args, **kwargs):
378        def do_execute():
379            sp = super(BasicAuthMixIn, self)
380            return sp._execute(transforms, *args, **kwargs)
381
382        def auth_required():
383            stream = getattr(self, 'stream', self.request.connection.stream)
384            stream.write(tornado.escape.utf8(
385                'HTTP/1.1 401 Authorization Required\r\n'
386                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
387            ))
388            stream.close()
389
390        try:
391            if not self._htpasswd:
392                return do_execute()
393
394            creds = self.request.headers.get('Authorization')
395
396            if not creds or not creds.startswith('Basic '):
397                return auth_required()
398
399            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
400
401            if self._htpasswd.load().auth(name, passwd):
402                return do_execute()
403        except:
404            traceback.print_exc()
405
406        return auth_required()
407
408
409class ServerHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
410    IFTYPE = 'server'
411    IFOP_ALLOWED = False
412
413    def __init__(self, app, req, switch, htpasswd, debug):
414        super(ServerHandler, self).__init__(app, req)
415        self._switch = switch
416        self._htpasswd = htpasswd
417        self._debug = debug
418
419    @property
420    def target(self):
421        return ':'.join(str(e) for e in self.request.connection.address)
422
423    def open(self):
424        try:
425            return self._switch.register_port(self)
426        finally:
427            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
428
429    def on_message(self, message):
430        self._switch.receive(self, EthernetFrame(message))
431
432    def on_close(self):
433        self._switch.unregister_port(self)
434        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
435
436
437class BaseClientHandler(DebugMixIn):
438    IFTYPE = 'baseclient'
439    IFOP_ALLOWED = False
440
441    def __init__(self, ioloop, switch, target, debug, *args, **kwargs):
442        self._ioloop = ioloop
443        self._switch = switch
444        self._target = target
445        self._debug = debug
446        self._args = args
447        self._kwargs = kwargs
448        self._device = None
449
450    @property
451    def address(self):
452        raise NotImplementedError('unsupported')
453
454    @property
455    def netmask(self):
456        raise NotImplementedError('unsupported')
457
458    @property
459    def mtu(self):
460        raise NotImplementedError('unsupported')
461
462    @address.setter
463    def address(self, address):
464        raise NotImplementedError('unsupported')
465
466    @netmask.setter
467    def netmask(self, netmask):
468        raise NotImplementedError('unsupported')
469
470    @mtu.setter
471    def mtu(self, mtu):
472        raise NotImplementedError('unsupported')
473
474    def open(self):
475        raise NotImplementedError('unsupported')
476
477    def write_message(self, message, binary=False):
478        raise NotImplementedError('unsupported')
479
480    def read(self):
481        raise NotImplementedError('unsupported')
482
483    @property
484    def target(self):
485        return self._target
486
487    @property
488    def device(self):
489        return self._device
490
491    @property
492    def closed(self):
493        return not self.device
494
495    def close(self):
496        if self.closed:
497            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
498        self.leave_switch()
499        self.unregister_device()
500        self.dprintf('disconnected: %s\n', lambda: self.target)
501
502    def register_device(self, device):
503        self._device = device
504
505    def unregister_device(self):
506        self._device.close()
507        self._device = None
508
509    def fileno(self):
510        if self.closed:
511            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
512        return self.device.fileno()
513
514    def __call__(self, fd, events):
515        try:
516            data = self.read()
517            if data is not None:
518                self._switch.receive(self, EthernetFrame(data))
519                return
520        except:
521            traceback.print_exc()
522        self.close()
523
524    def join_switch(self):
525        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
526        return self._switch.register_port(self)
527
528    def leave_switch(self):
529        self._switch.unregister_port(self)
530        self._ioloop.remove_handler(self.fileno())
531
532
533class NetdevHandler(BaseClientHandler):
534    IFTYPE = 'netdev'
535    IFOP_ALLOWED = True
536    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
537
538    @property
539    def address(self):
540        if self.closed:
541            raise ValueError('I/O operation on closed netdev')
542        return NetworkInterface(self.target).address
543
544    @property
545    def netmask(self):
546        if self.closed:
547            raise ValueError('I/O operation on closed netdev')
548        return NetworkInterface(self.target).netmask
549
550    @property
551    def mtu(self):
552        if self.closed:
553            raise ValueError('I/O operation on closed netdev')
554        return NetworkInterface(self.target).mtu
555
556    @address.setter
557    def address(self, address):
558        if self.closed:
559            raise ValueError('I/O operation on closed netdev')
560        NetworkInterface(self.target).address = address
561
562    @netmask.setter
563    def netmask(self, netmask):
564        if self.closed:
565            raise ValueError('I/O operation on closed netdev')
566        NetworkInterface(self.target).netmask = netmask
567
568    @mtu.setter
569    def mtu(self, mtu):
570        if self.closed:
571            raise ValueError('I/O operation on closed netdev')
572        NetworkInterface(self.target).mtu = mtu
573
574    def open(self):
575        if not self.closed:
576            raise ValueError('Already opened')
577        self.register_device(socket.socket(
578            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
579        self.device.bind((self.target, self.ETH_P_ALL))
580        self.dprintf('connected: %s\n', lambda: self.target)
581        return self.join_switch()
582
583    def write_message(self, message, binary=False):
584        if self.closed:
585            raise ValueError('I/O operation on closed netdev')
586        self.device.sendall(message)
587
588    def read(self):
589        if self.closed:
590            raise ValueError('I/O operation on closed netdev')
591        buf = []
592        while True:
593            buf.append(self.device.recv(65535))
594            if len(buf[-1]) < 65535:
595                break
596        return ''.join(buf)
597
598
599class TapHandler(BaseClientHandler):
600    IFTYPE = 'tap'
601    IFOP_ALLOWED = True
602
603    @property
604    def address(self):
605        if self.closed:
606            raise ValueError('I/O operation on closed tap')
607        try:
608            return self.device.addr
609        except:
610            return ''
611
612    @property
613    def netmask(self):
614        if self.closed:
615            raise ValueError('I/O operation on closed tap')
616        try:
617            return self.device.netmask
618        except:
619            return ''
620
621    @property
622    def mtu(self):
623        if self.closed:
624            raise ValueError('I/O operation on closed tap')
625        return self.device.mtu
626
627    @address.setter
628    def address(self, address):
629        if self.closed:
630            raise ValueError('I/O operation on closed tap')
631        self.device.addr = address
632
633    @netmask.setter
634    def netmask(self, netmask):
635        if self.closed:
636            raise ValueError('I/O operation on closed tap')
637        self.device.netmask = netmask
638
639    @mtu.setter
640    def mtu(self, mtu):
641        if self.closed:
642            raise ValueError('I/O operation on closed tap')
643        self.device.mtu = mtu
644
645    @property
646    def target(self):
647        if self.closed:
648            return self._target
649        return self.device.name
650
651    def open(self):
652        if not self.closed:
653            raise ValueError('Already opened')
654        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
655        self.device.up()
656        self.dprintf('connected: %s\n', lambda: self.target)
657        return self.join_switch()
658
659    def write_message(self, message, binary=False):
660        if self.closed:
661            raise ValueError('I/O operation on closed tap')
662        self.device.write(message)
663
664    def read(self):
665        if self.closed:
666            raise ValueError('I/O operation on closed tap')
667        buf = []
668        while True:
669            buf.append(self.device.read(65535))
670            if len(buf[-1]) < 65535:
671                break
672        return ''.join(buf)
673
674
675class ClientHandler(BaseClientHandler):
676    IFTYPE = 'client'
677    IFOP_ALLOWED = False
678
679    def __init__(self, *args, **kwargs):
680        super(ClientHandler, self).__init__(*args, **kwargs)
681
682        self._sslopt = kwargs.get('sslopt', {})
683        self._options = {}
684
685        cred = kwargs.get('cred', None)
686
687        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
688            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
689            auth = ['Authorization: Basic %s' % token]
690            self._options['header'] = auth
691
692    def open(self):
693        if not self.closed:
694            raise websocket.WebSocketException('Already opened')
695
696        # XXX: may be blocked
697        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
698        self.device.connect(self.target, **self._options)
699        self.dprintf('connected: %s\n', lambda: self.target)
700        return self.join_switch()
701
702    def write_message(self, message, binary=False):
703        if self.closed:
704            raise websocket.WebSocketException('Closed socket')
705        if binary:
706            flag = websocket.ABNF.OPCODE_BINARY
707        else:
708            flag = websocket.ABNF.OPCODE_TEXT
709        self.device.send(message, flag)
710
711    def read(self):
712        if self.closed:
713            raise websocket.WebSocketException('Closed socket')
714        return self.device.recv()
715
716
717class ControlServerHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
718    NAMESPACE = 'etherws.control'
719    IFTYPES = {
720        NetdevHandler.IFTYPE: NetdevHandler,
721        TapHandler.IFTYPE:    TapHandler,
722        ClientHandler.IFTYPE: ClientHandler,
723    }
724
725    def __init__(self, app, req, ioloop, switch, htpasswd, debug):
726        super(ControlServerHandler, self).__init__(app, req)
727        self._ioloop = ioloop
728        self._switch = switch
729        self._htpasswd = htpasswd
730        self._debug = debug
731
732    def post(self):
733        try:
734            request = json.loads(self.request.body)
735        except Exception as e:
736            return self._jsonrpc_response(error={
737                'code':    0 - 32700,
738                'message': 'Parse error',
739                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
740            })
741
742        try:
743            id_ = request.get('id')
744            params = request.get('params')
745            version = request['jsonrpc']
746            method = request['method']
747            if version != '2.0':
748                raise ValueError('Invalid JSON-RPC version: %s' % version)
749        except Exception as e:
750            return self._jsonrpc_response(id_=id_, error={
751                'code':    0 - 32600,
752                'message': 'Invalid Request',
753                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
754            })
755
756        try:
757            if not method.startswith(self.NAMESPACE + '.'):
758                raise ValueError('Invalid method namespace: %s' % method)
759            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
760            handler = getattr(self, handler)
761        except Exception as e:
762            return self._jsonrpc_response(id_=id_, error={
763                'code':    0 - 32601,
764                'message': 'Method not found',
765                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
766            })
767
768        try:
769            return self._jsonrpc_response(id_=id_, result=handler(params))
770        except Exception as e:
771            traceback.print_exc()
772            return self._jsonrpc_response(id_=id_, error={
773                'code':    0 - 32602,
774                'message': 'Invalid params',
775                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
776            })
777
778    def handle_listFdb(self, params):
779        list_ = []
780        for vid, mac, entry in self._switch.fdb.each():
781            list_.append({
782                'vid':  vid,
783                'mac':  EthernetFrame.format_mac(mac),
784                'port': entry.port.number,
785                'age':  int(entry.age),
786            })
787        return {'entries': list_}
788
789    def handle_listPort(self, params):
790        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
791
792    def handle_addPort(self, params):
793        type_ = params['type']
794        target = params['target']
795        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
796        cls = self.IFTYPES[type_]
797        interface = cls(self._ioloop, self._switch, target, self._debug, **opt)
798        portnum = interface.open()
799        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
800
801    def handle_setPort(self, params):
802        port = self._switch.get_port(int(params['port']))
803        shut = params.get('shut')
804        if shut is not None:
805            port.shut = bool(shut)
806        return {'entries': [self._portstat(port)]}
807
808    def handle_delPort(self, params):
809        port = self._switch.get_port(int(params['port']))
810        port.interface.close()
811        return {'entries': [self._portstat(port)]}
812
813    def handle_setInterface(self, params):
814        portnum = int(params['port'])
815        port = self._switch.get_port(portnum)
816        address = params.get('address')
817        netmask = params.get('netmask')
818        mtu = params.get('mtu')
819        if not port.interface.IFOP_ALLOWED:
820            raise ValueError('Port %d has unsupported interface: %s' %
821                             (portnum, port.interface.IFTYPE))
822        if address is not None:
823            port.interface.address = address
824        if netmask is not None:
825            port.interface.netmask = netmask
826        if mtu is not None:
827            port.interface.mtu = mtu
828        return {'entries': [self._ifstat(port)]}
829
830    def handle_listInterface(self, params):
831        return {'entries': [self._ifstat(p) for p in self._switch.portlist
832                            if p.interface.IFOP_ALLOWED]}
833
834    def _optparse_netdev(self, opt):
835        return {}
836
837    def _optparse_tap(self, opt):
838        return {}
839
840    def _optparse_client(self, opt):
841        if opt.get('insecure'):
842            sslopt = {}
843        else:
844            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
845                      'ca_certs':  opt.get('cacerts')}
846        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
847        return {'sslopt': sslopt, 'cred': cred}
848
849    def _jsonrpc_response(self, id_=None, result=None, error=None):
850        res = {'jsonrpc': '2.0', 'id': id_}
851        if result:
852            res['result'] = result
853        if error:
854            res['error'] = error
855        self.finish(res)
856
857    @staticmethod
858    def _portstat(port):
859        return {
860            'port':   port.number,
861            'type':   port.interface.IFTYPE,
862            'target': port.interface.target,
863            'tx':     port.tx,
864            'rx':     port.rx,
865            'shut':   port.shut,
866        }
867
868    @staticmethod
869    def _ifstat(port):
870        return {
871            'port':    port.number,
872            'type':    port.interface.IFTYPE,
873            'target':  port.interface.target,
874            'address': port.interface.address,
875            'netmask': port.interface.netmask,
876            'mtu':     port.interface.mtu,
877        }
878
879
880def _print_error(error):
881    print(%s (%s)' % (error['message'], error['code']))
882    print('    %s' % error['data'])
883
884
885def _start_sw(args):
886    def daemonize(nochdir=False, noclose=False):
887        if os.fork() > 0:
888            sys.exit(0)
889
890        os.setsid()
891
892        if os.fork() > 0:
893            sys.exit(0)
894
895        if not nochdir:
896            os.chdir('/')
897
898        if not noclose:
899            os.umask(0)
900            sys.stdin.close()
901            sys.stdout.close()
902            sys.stderr.close()
903            os.close(0)
904            os.close(1)
905            os.close(2)
906            sys.stdin = open(os.devnull)
907            sys.stdout = open(os.devnull, 'a')
908            sys.stderr = open(os.devnull, 'a')
909
910    def checkabspath(ns, path):
911        val = getattr(ns, path, '')
912        if not val.startswith('/'):
913            raise ValueError('Invalid %: %s' % (path, val))
914
915    def getsslopt(ns, key, cert):
916        kval = getattr(ns, key, None)
917        cval = getattr(ns, cert, None)
918        if kval and cval:
919            return {'keyfile': kval, 'certfile': cval}
920        elif kval or cval:
921            raise ValueError('Both %s and %s are required' % (key, cert))
922        return None
923
924    def setrealpath(ns, *keys):
925        for k in keys:
926            v = getattr(ns, k, None)
927            if v is not None:
928                v = os.path.realpath(v)
929                open(v).close()  # check readable
930                setattr(ns, k, v)
931
932    def setport(ns, port, isssl):
933        val = getattr(ns, port, None)
934        if val is None:
935            if isssl:
936                return setattr(ns, port, 443)
937            return setattr(ns, port, 80)
938        if not (0 <= val <= 65535):
939            raise ValueError('Invalid %s: %s' % (port, val))
940
941    def sethtpasswd(ns, htpasswd):
942        val = getattr(ns, htpasswd, None)
943        if val:
944            return setattr(ns, htpasswd, Htpasswd(val))
945
946    #if args.debug:
947    #    websocket.enableTrace(True)
948
949    if args.ageout <= 0:
950        raise ValueError('Invalid ageout: %s' % args.ageout)
951
952    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
953    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
954
955    checkabspath(args, 'path')
956    checkabspath(args, 'ctlpath')
957
958    sslopt = getsslopt(args, 'sslkey', 'sslcert')
959    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
960
961    setport(args, 'port', sslopt)
962    setport(args, 'ctlport', ctlsslopt)
963
964    sethtpasswd(args, 'htpasswd')
965    sethtpasswd(args, 'ctlhtpasswd')
966
967    ioloop = IOLoop.instance()
968    fdb = FDB(args.ageout, args.debug)
969    switch = SwitchingHub(fdb, args.debug)
970
971    if args.port == args.ctlport and args.host == args.ctlhost:
972        if args.path == args.ctlpath:
973            raise ValueError('Same path/ctlpath on same host')
974        if args.sslkey != args.ctlsslkey:
975            raise ValueError('Different sslkey/ctlsslkey on same host')
976        if args.sslcert != args.ctlsslcert:
977            raise ValueError('Different sslcert/ctlsslcert on same host')
978
979        app = Application([
980            (args.path, ServerHandler, {
981                'switch':   switch,
982                'htpasswd': args.htpasswd,
983                'debug':    args.debug,
984            }),
985            (args.ctlpath, ControlServerHandler, {
986                'ioloop':   ioloop,
987                'switch':   switch,
988                'htpasswd': args.ctlhtpasswd,
989                'debug':    args.debug,
990            }),
991        ])
992        server = HTTPServer(app, ssl_options=sslopt)
993        server.listen(args.port, address=args.host)
994
995    else:
996        app = Application([(args.path, ServerHandler, {
997            'switch':   switch,
998            'htpasswd': args.htpasswd,
999            'debug':    args.debug,
1000        })])
1001        server = HTTPServer(app, ssl_options=sslopt)
1002        server.listen(args.port, address=args.host)
1003
1004        ctl = Application([(args.ctlpath, ControlServerHandler, {
1005            'ioloop':   ioloop,
1006            'switch':   switch,
1007            'htpasswd': args.ctlhtpasswd,
1008            'debug':    args.debug,
1009        })])
1010        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1011        ctlserver.listen(args.ctlport, address=args.ctlhost)
1012
1013    if not args.foreground:
1014        daemonize()
1015
1016    ioloop.start()
1017
1018
1019def _start_ctl(args):
1020    def request(args, method, params=None, id_=0):
1021        req = urllib2.Request(args.ctlurl)
1022        req.add_header('Content-type', 'application/json')
1023        if args.ctluser:
1024            if not args.ctlpasswd:
1025                args.ctlpasswd = getpass.getpass('Control Password: ')
1026            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1027            req.add_header('Authorization', 'Basic %s' % token)
1028        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1029        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1030        if params is not None:
1031            data['params'] = params
1032        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
1033
1034    def print_table(rows):
1035        cols = zip(*rows)
1036        maxlen = [0] * len(cols)
1037        for i in xrange(len(cols)):
1038            maxlen[i] = max(len(str(c)) for c in cols[i])
1039        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1040        fmt = '  ' + fmt
1041        for row in rows:
1042            print(fmt % tuple(row))
1043
1044    def print_portlist(result):
1045        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1046        for r in result:
1047            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1048                         r['rx'], r['tx'], r['target']])
1049        print_table(rows)
1050
1051    def print_iflist(result):
1052        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1053        for r in result:
1054            rows.append([r['port'], r['type'], r['address'],
1055                         r['netmask'], r['mtu'], r['target']])
1056        print_table(rows)
1057
1058    def handle_ctl_addport(args):
1059        opts = {
1060            'user':     getattr(args, 'user', None),
1061            'passwd':   getattr(args, 'passwd', None),
1062            'cacerts':  getattr(args, 'cacerts', None),
1063            'insecure': getattr(args, 'insecure', None),
1064        }
1065        if args.iftype == ClientHandler.IFTYPE:
1066            if not args.target.startswith('ws://') and \
1067               not args.target.startswith('wss://'):
1068                raise ValueError('Invalid target URL scheme: %s' % args.target)
1069            if not opts['user'] and opts['passwd']:
1070                raise ValueError('Authentication required but username empty')
1071            if opts['user'] and not opts['passwd']:
1072                opts['passwd'] = getpass.getpass('Client Password: ')
1073        result = request(args, 'addPort', {
1074            'type':    args.iftype,
1075            'target':  args.target,
1076            'options': opts,
1077        })
1078        if 'error' in result:
1079            _print_error(result['error'])
1080        else:
1081            print_portlist(result['result']['entries'])
1082
1083    def handle_ctl_setport(args):
1084        if args.port <= 0:
1085            raise ValueError('Invalid port: %d' % args.port)
1086        req = {'port': args.port}
1087        shut = getattr(args, 'shut', None)
1088        if shut is not None:
1089            req['shut'] = bool(shut)
1090        result = request(args, 'setPort', req)
1091        if 'error' in result:
1092            _print_error(result['error'])
1093        else:
1094            print_portlist(result['result']['entries'])
1095
1096    def handle_ctl_delport(args):
1097        if args.port <= 0:
1098            raise ValueError('Invalid port: %d' % args.port)
1099        result = request(args, 'delPort', {'port': args.port})
1100        if 'error' in result:
1101            _print_error(result['error'])
1102        else:
1103            print_portlist(result['result']['entries'])
1104
1105    def handle_ctl_listport(args):
1106        result = request(args, 'listPort')
1107        if 'error' in result:
1108            _print_error(result['error'])
1109        else:
1110            print_portlist(result['result']['entries'])
1111
1112    def handle_ctl_setif(args):
1113        if args.port <= 0:
1114            raise ValueError('Invalid port: %d' % args.port)
1115        req = {'port': args.port}
1116        address = getattr(args, 'address', None)
1117        netmask = getattr(args, 'netmask', None)
1118        mtu = getattr(args, 'mtu', None)
1119        if address is not None:
1120            if address:
1121                socket.inet_aton(address)  # validate
1122            req['address'] = address
1123        if netmask is not None:
1124            if netmask:
1125                socket.inet_aton(netmask)  # validate
1126            req['netmask'] = netmask
1127        if mtu is not None:
1128            if mtu < 576:
1129                raise ValueError('Invalid MTU: %d' % mtu)
1130            req['mtu'] = mtu
1131        result = request(args, 'setInterface', req)
1132        if 'error' in result:
1133            _print_error(result['error'])
1134        else:
1135            print_iflist(result['result']['entries'])
1136
1137    def handle_ctl_listif(args):
1138        result = request(args, 'listInterface')
1139        if 'error' in result:
1140            _print_error(result['error'])
1141        else:
1142            print_iflist(result['result']['entries'])
1143
1144    def handle_ctl_listfdb(args):
1145        result = request(args, 'listFdb')
1146        if 'error' in result:
1147            return _print_error(result['error'])
1148        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1149        for r in result['result']['entries']:
1150            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1151        print_table(rows)
1152
1153    locals()['handle_ctl_' + args.control_method](args)
1154
1155
1156def _main():
1157    parser = argparse.ArgumentParser()
1158    subcommand = parser.add_subparsers(dest='subcommand')
1159
1160    # - sw
1161    parser_sw = subcommand.add_parser('sw',
1162                                      help='start virtual switch')
1163
1164    parser_sw.add_argument('--debug', action='store_true', default=False,
1165                           help='run as debug mode')
1166    parser_sw.add_argument('--foreground', action='store_true', default=False,
1167                           help='run as foreground mode')
1168    parser_sw.add_argument('--ageout', type=int, default=300,
1169                           help='FDB ageout time (sec)')
1170
1171    parser_sw.add_argument('--path', default='/',
1172                           help='http(s) path to serve WebSocket')
1173    parser_sw.add_argument('--host', default='',
1174                           help='listen address to serve WebSocket')
1175    parser_sw.add_argument('--port', type=int,
1176                           help='listen port to serve WebSocket')
1177    parser_sw.add_argument('--htpasswd',
1178                           help='path to htpasswd file to auth WebSocket')
1179    parser_sw.add_argument('--sslkey',
1180                           help='path to SSL private key for WebSocket')
1181    parser_sw.add_argument('--sslcert',
1182                           help='path to SSL certificate for WebSocket')
1183
1184    parser_sw.add_argument('--ctlpath', default='/ctl',
1185                           help='http(s) path to serve control API')
1186    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1187                           help='listen address to serve control API')
1188    parser_sw.add_argument('--ctlport', type=int, default=7867,
1189                           help='listen port to serve control API')
1190    parser_sw.add_argument('--ctlhtpasswd',
1191                           help='path to htpasswd file to auth control API')
1192    parser_sw.add_argument('--ctlsslkey',
1193                           help='path to SSL private key for control API')
1194    parser_sw.add_argument('--ctlsslcert',
1195                           help='path to SSL certificate for control API')
1196
1197    # - ctl
1198    parser_ctl = subcommand.add_parser('ctl',
1199                                       help='control virtual switch')
1200    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1201                            help='URL to control API')
1202    parser_ctl.add_argument('--ctluser',
1203                            help='username to auth control API')
1204    parser_ctl.add_argument('--ctlpasswd',
1205                            help='password to auth control API')
1206
1207    control_method = parser_ctl.add_subparsers(dest='control_method')
1208
1209    # -- ctl addport
1210    parser_ctl_addport = control_method.add_parser('addport',
1211                                                   help='create and add port')
1212    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1213
1214    # --- ctl addport netdev
1215    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1216                                                  help='Network device')
1217    parser_ctl_addport_netdev.add_argument('target',
1218                                           help='device name to add interface')
1219
1220    # --- ctl addport tap
1221    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1222                                               help='TAP device')
1223    parser_ctl_addport_tap.add_argument('target',
1224                                        help='device name to create interface')
1225
1226    # --- ctl addport client
1227    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1228                                                  help='WebSocket client')
1229    parser_ctl_addport_client.add_argument('target',
1230                                           help='URL to connect WebSocket')
1231    parser_ctl_addport_client.add_argument('--user',
1232                                           help='username to auth WebSocket')
1233    parser_ctl_addport_client.add_argument('--passwd',
1234                                           help='password to auth WebSocket')
1235    parser_ctl_addport_client.add_argument('--cacerts',
1236                                           help='path to CA certificate')
1237    parser_ctl_addport_client.add_argument(
1238        '--insecure', action='store_true', default=False,
1239        help='do not verify server certificate')
1240
1241    # -- ctl setport
1242    parser_ctl_setport = control_method.add_parser('setport',
1243                                                   help='set port config')
1244    parser_ctl_setport.add_argument('port', type=int,
1245                                    help='port number to set config')
1246    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1247                                    help='set shutdown state')
1248
1249    # -- ctl delport
1250    parser_ctl_delport = control_method.add_parser('delport',
1251                                                   help='delete port')
1252    parser_ctl_delport.add_argument('port', type=int,
1253                                    help='port number to delete')
1254
1255    # -- ctl listport
1256    parser_ctl_listport = control_method.add_parser('listport',
1257                                                    help='show port list')
1258
1259    # -- ctl setif
1260    parser_ctl_setif = control_method.add_parser('setif',
1261                                                 help='set interface config')
1262    parser_ctl_setif.add_argument('port', type=int,
1263                                  help='port number to set config')
1264    parser_ctl_setif.add_argument('--address',
1265                                  help='IPv4 address to set interface')
1266    parser_ctl_setif.add_argument('--netmask',
1267                                  help='IPv4 netmask to set interface')
1268    parser_ctl_setif.add_argument('--mtu', type=int,
1269                                  help='MTU to set interface')
1270
1271    # -- ctl listif
1272    parser_ctl_listif = control_method.add_parser('listif',
1273                                                  help='show interface list')
1274
1275    # -- ctl listfdb
1276    parser_ctl_listfdb = control_method.add_parser('listfdb',
1277                                                   help='show FDB entries')
1278
1279    # -- go
1280    args = parser.parse_args()
1281
1282    try:
1283        globals()['_start_' + args.subcommand](args)
1284    except Exception as e:
1285        _print_error({
1286            'code':    0 - 32603,
1287            'message': 'Internal error',
1288            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1289        })
1290
1291
1292if __name__ == '__main__':
1293    _main()
Note: See TracBrowser for help on using the repository browser.