source: etherws/trunk/etherws.py @ 278

Revision 278, 43.5 KB checked in by atzm, 9 years ago (diff)
  • fix comment
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.14.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54import traceback
55
56import tornado
57import websocket
58
59from tornado.web import Application, RequestHandler
60from tornado.websocket import WebSocketHandler
61from tornado.httpserver import HTTPServer
62from tornado.ioloop import IOLoop
63
64from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
65
66
67class DebugMixIn(object):
68    def dprintf(self, msg, func=lambda: ()):
69        if self._debug:
70            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
71            sys.stderr.write(prefix + (msg % func()))
72
73
74class EthernetFrame(object):
75    def __init__(self, data):
76        self.data = data
77
78    @property
79    def dst_multicast(self):
80        return ord(self.data[0]) & 1
81
82    @property
83    def src_multicast(self):
84        return ord(self.data[6]) & 1
85
86    @property
87    def dst_mac(self):
88        return self.data[:6]
89
90    @property
91    def src_mac(self):
92        return self.data[6:12]
93
94    @property
95    def tagged(self):
96        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
97
98    @property
99    def vid(self):
100        if self.tagged:
101            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
102        return 0
103
104    @staticmethod
105    def format_mac(mac, sep=':'):
106        return sep.join(b.encode('hex') for b in mac)
107
108
109class FDB(DebugMixIn):
110    class Entry(object):
111        def __init__(self, port, ageout):
112            self.port = port
113            self._time = time.time()
114            self._ageout = ageout
115
116        @property
117        def age(self):
118            return time.time() - self._time
119
120        @property
121        def agedout(self):
122            return self.age > self._ageout
123
124    def __init__(self, ageout, debug):
125        self._ageout = ageout
126        self._debug = debug
127        self._table = {}
128
129    def _set_entry(self, vid, mac, port):
130        if vid not in self._table:
131            self._table[vid] = {}
132        self._table[vid][mac] = self.Entry(port, self._ageout)
133
134    def _del_entry(self, vid, mac):
135        if vid in self._table:
136            if mac in self._table[vid]:
137                del self._table[vid][mac]
138            if not self._table[vid]:
139                del self._table[vid]
140
141    def _get_entry(self, vid, mac):
142        try:
143            entry = self._table[vid][mac]
144        except KeyError:
145            return None
146
147        if not entry.agedout:
148            return entry
149
150        self._del_entry(vid, mac)
151        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
152                     lambda: (entry.port.number, vid, mac.encode('hex')))
153
154    def each(self):
155        for vid in sorted(self._table.iterkeys()):
156            for mac in sorted(self._table[vid].iterkeys()):
157                entry = self._get_entry(vid, mac)
158                if entry:
159                    yield (vid, mac, entry)
160
161    def lookup(self, frame):
162        mac = frame.dst_mac
163        vid = frame.vid
164        entry = self._get_entry(vid, mac)
165        return getattr(entry, 'port', None)
166
167    def learn(self, port, frame):
168        mac = frame.src_mac
169        vid = frame.vid
170        self._set_entry(vid, mac, port)
171        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
172                     lambda: (port.number, vid, mac.encode('hex')))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
179                             lambda: (port.number, vid, mac.encode('hex')))
180
181
182class SwitchingHub(DebugMixIn):
183    class Port(object):
184        def __init__(self, number, interface):
185            self.number = number
186            self.interface = interface
187            self.tx = 0
188            self.rx = 0
189            self.shut = False
190
191    def __init__(self, fdb, debug):
192        self.fdb = fdb
193        self._debug = debug
194        self._table = {}
195        self._next = 1
196
197    @property
198    def portlist(self):
199        return sorted(self._table.itervalues(),
200                      key=operator.attrgetter('number'))
201
202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
206        try:
207            self._set_privattr('portnum', interface, self._next)  # XXX
208            self._table[self._next] = self.Port(self._next, interface)
209            return self._next
210        finally:
211            self._next += 1
212
213    def unregister_port(self, interface):
214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
216        self.fdb.delete(self._table[portnum])
217        del self._table[portnum]
218
219    def send(self, dst_interfaces, frame):
220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
223        ports = sorted(ports, key=operator.attrgetter('number'))
224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
237        port = self._table[self._get_privattr('portnum', src_interface)]
238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
244        try:
245            if not frame.src_multicast:
246                self.fdb.learn(src_port, frame)
247
248            if not frame.dst_multicast:
249                dst_port = self.fdb.lookup(frame)
250
251                if dst_port:
252                    self.send([dst_port.interface], frame)
253                    return
254
255            ports = set(self.portlist) - set([src_port])
256            self.send((p.interface for p in ports), frame)
257
258        except:  # ex. received invalid frame
259            traceback.print_exc()
260
261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
263
264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
274class NetworkInterface(object):
275    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
276    SIOCSIFADDR = 0x8916
277    SIOCGIFNETMASK = 0x891b
278    SIOCSIFNETMASK = 0x891c
279    SIOCGIFMTU = 0x8921
280    SIOCSIFMTU = 0x8922
281
282    def __init__(self, ifname):
283        self._ifname = struct.pack('16s', str(ifname)[:15])
284
285    def _ioctl(self, req, data):
286        try:
287            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
288            return fcntl.ioctl(sock.fileno(), req, data)
289        finally:
290            sock.close()
291
292    @property
293    def address(self):
294        try:
295            data = struct.pack('16s18x', self._ifname)
296            ret = self._ioctl(self.SIOCGIFADDR, data)
297            return socket.inet_ntoa(ret[20:24])
298        except IOError:
299            return ''
300
301    @property
302    def netmask(self):
303        try:
304            data = struct.pack('16s18x', self._ifname)
305            ret = self._ioctl(self.SIOCGIFNETMASK, data)
306            return socket.inet_ntoa(ret[20:24])
307        except IOError:
308            return ''
309
310    @property
311    def mtu(self):
312        try:
313            data = struct.pack('16s18x', self._ifname)
314            ret = self._ioctl(self.SIOCGIFMTU, data)
315            return struct.unpack('i', ret[16:20])[0]
316        except IOError:
317            return 0
318
319    @address.setter
320    def address(self, addr):
321        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
322                           socket.inet_aton(addr))
323        self._ioctl(self.SIOCSIFADDR, data)
324
325    @netmask.setter
326    def netmask(self, addr):
327        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
328                           socket.inet_aton(addr))
329        self._ioctl(self.SIOCSIFNETMASK, data)
330
331    @mtu.setter
332    def mtu(self, mtu):
333        data = struct.pack('16si12x', self._ifname, mtu)
334        self._ioctl(self.SIOCSIFMTU, data)
335
336
337class Htpasswd(object):
338    def __init__(self, path):
339        self._path = path
340        self._stat = None
341        self._data = {}
342
343    def auth(self, name, passwd):
344        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
345        return self._data.get(name) == passwd
346
347    def load(self):
348        old_stat = self._stat
349
350        with open(self._path) as fp:
351            fileno = fp.fileno()
352            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
353            self._stat = os.fstat(fileno)
354
355            unchanged = old_stat and \
356                old_stat.st_ino == self._stat.st_ino and \
357                old_stat.st_dev == self._stat.st_dev and \
358                old_stat.st_mtime == self._stat.st_mtime
359
360            if not unchanged:
361                self._data = self._parse(fp)
362
363        return self
364
365    def _parse(self, fp):
366        data = {}
367        for line in fp:
368            line = line.strip()
369            if 0 <= line.find(':'):
370                name, passwd = line.split(':', 1)
371                if passwd.startswith('{SHA}'):
372                    data[name] = passwd[5:]
373        return data
374
375
376class BasicAuthMixIn(object):
377    def _execute(self, transforms, *args, **kwargs):
378        def do_execute():
379            sp = super(BasicAuthMixIn, self)
380            return sp._execute(transforms, *args, **kwargs)
381
382        def auth_required():
383            stream = getattr(self, 'stream', self.request.connection.stream)
384            stream.write(tornado.escape.utf8(
385                'HTTP/1.1 401 Authorization Required\r\n'
386                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
387            ))
388            stream.close()
389
390        try:
391            if not self._htpasswd:
392                return do_execute()
393
394            creds = self.request.headers.get('Authorization')
395
396            if not creds or not creds.startswith('Basic '):
397                return auth_required()
398
399            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
400
401            if self._htpasswd.load().auth(name, passwd):
402                return do_execute()
403        except:
404            traceback.print_exc()
405
406        return auth_required()
407
408
409class ServerHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
410    IFTYPE = 'server'
411    IFOP_ALLOWED = False
412
413    def __init__(self, app, req, switch, htpasswd, debug):
414        super(ServerHandler, self).__init__(app, req)
415        self._switch = switch
416        self._htpasswd = htpasswd
417        self._debug = debug
418
419    @property
420    def target(self):
421        conn = self.request.connection
422        if hasattr(conn, 'address'):
423            return ':'.join(str(e) for e in conn.address)
424        if hasattr(conn, 'context'):
425            return ':'.join(str(e) for e in conn.context.address)
426        return str(conn)
427
428    def open(self):
429        try:
430            return self._switch.register_port(self)
431        finally:
432            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
433
434    def on_message(self, message):
435        self._switch.receive(self, EthernetFrame(message))
436
437    def on_close(self):
438        self._switch.unregister_port(self)
439        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
440
441
442class BaseClientHandler(DebugMixIn):
443    IFTYPE = 'baseclient'
444    IFOP_ALLOWED = False
445
446    def __init__(self, ioloop, switch, target, debug, *args, **kwargs):
447        self._ioloop = ioloop
448        self._switch = switch
449        self._target = target
450        self._debug = debug
451        self._args = args
452        self._kwargs = kwargs
453        self._device = None
454
455    @property
456    def address(self):
457        raise NotImplementedError('unsupported')
458
459    @property
460    def netmask(self):
461        raise NotImplementedError('unsupported')
462
463    @property
464    def mtu(self):
465        raise NotImplementedError('unsupported')
466
467    @address.setter
468    def address(self, address):
469        raise NotImplementedError('unsupported')
470
471    @netmask.setter
472    def netmask(self, netmask):
473        raise NotImplementedError('unsupported')
474
475    @mtu.setter
476    def mtu(self, mtu):
477        raise NotImplementedError('unsupported')
478
479    def open(self):
480        raise NotImplementedError('unsupported')
481
482    def write_message(self, message, binary=False):
483        raise NotImplementedError('unsupported')
484
485    def read(self):
486        raise NotImplementedError('unsupported')
487
488    @property
489    def target(self):
490        return self._target
491
492    @property
493    def device(self):
494        return self._device
495
496    @property
497    def closed(self):
498        return not self.device
499
500    def close(self):
501        if self.closed:
502            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
503        self.leave_switch()
504        self.unregister_device()
505        self.dprintf('disconnected: %s\n', lambda: self.target)
506
507    def register_device(self, device):
508        self._device = device
509
510    def unregister_device(self):
511        self._device.close()
512        self._device = None
513
514    def fileno(self):
515        if self.closed:
516            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
517        return self.device.fileno()
518
519    def __call__(self, fd, events):
520        try:
521            data = self.read()
522            if data is not None:
523                self._switch.receive(self, EthernetFrame(data))
524                return
525        except:
526            traceback.print_exc()
527        self.close()
528
529    def join_switch(self):
530        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
531        return self._switch.register_port(self)
532
533    def leave_switch(self):
534        self._switch.unregister_port(self)
535        self._ioloop.remove_handler(self.fileno())
536
537
538class NetdevHandler(BaseClientHandler):
539    IFTYPE = 'netdev'
540    IFOP_ALLOWED = True
541    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
542
543    @property
544    def address(self):
545        if self.closed:
546            raise ValueError('I/O operation on closed netdev')
547        return NetworkInterface(self.target).address
548
549    @property
550    def netmask(self):
551        if self.closed:
552            raise ValueError('I/O operation on closed netdev')
553        return NetworkInterface(self.target).netmask
554
555    @property
556    def mtu(self):
557        if self.closed:
558            raise ValueError('I/O operation on closed netdev')
559        return NetworkInterface(self.target).mtu
560
561    @address.setter
562    def address(self, address):
563        if self.closed:
564            raise ValueError('I/O operation on closed netdev')
565        NetworkInterface(self.target).address = address
566
567    @netmask.setter
568    def netmask(self, netmask):
569        if self.closed:
570            raise ValueError('I/O operation on closed netdev')
571        NetworkInterface(self.target).netmask = netmask
572
573    @mtu.setter
574    def mtu(self, mtu):
575        if self.closed:
576            raise ValueError('I/O operation on closed netdev')
577        NetworkInterface(self.target).mtu = mtu
578
579    def open(self):
580        if not self.closed:
581            raise ValueError('Already opened')
582        self.register_device(socket.socket(
583            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
584        self.device.bind((self.target, self.ETH_P_ALL))
585        self.dprintf('connected: %s\n', lambda: self.target)
586        return self.join_switch()
587
588    def write_message(self, message, binary=False):
589        if self.closed:
590            raise ValueError('I/O operation on closed netdev')
591        self.device.sendall(message)
592
593    def read(self):
594        if self.closed:
595            raise ValueError('I/O operation on closed netdev')
596        buf = []
597        while True:
598            buf.append(self.device.recv(65535))
599            if len(buf[-1]) < 65535:
600                break
601        return ''.join(buf)
602
603
604class TapHandler(BaseClientHandler):
605    IFTYPE = 'tap'
606    IFOP_ALLOWED = True
607
608    @property
609    def address(self):
610        if self.closed:
611            raise ValueError('I/O operation on closed tap')
612        try:
613            return self.device.addr
614        except:
615            return ''
616
617    @property
618    def netmask(self):
619        if self.closed:
620            raise ValueError('I/O operation on closed tap')
621        try:
622            return self.device.netmask
623        except:
624            return ''
625
626    @property
627    def mtu(self):
628        if self.closed:
629            raise ValueError('I/O operation on closed tap')
630        return self.device.mtu
631
632    @address.setter
633    def address(self, address):
634        if self.closed:
635            raise ValueError('I/O operation on closed tap')
636        self.device.addr = address
637
638    @netmask.setter
639    def netmask(self, netmask):
640        if self.closed:
641            raise ValueError('I/O operation on closed tap')
642        self.device.netmask = netmask
643
644    @mtu.setter
645    def mtu(self, mtu):
646        if self.closed:
647            raise ValueError('I/O operation on closed tap')
648        self.device.mtu = mtu
649
650    @property
651    def target(self):
652        if self.closed:
653            return self._target
654        return self.device.name
655
656    def open(self):
657        if not self.closed:
658            raise ValueError('Already opened')
659        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
660        self.device.up()
661        self.dprintf('connected: %s\n', lambda: self.target)
662        return self.join_switch()
663
664    def write_message(self, message, binary=False):
665        if self.closed:
666            raise ValueError('I/O operation on closed tap')
667        self.device.write(message)
668
669    def read(self):
670        if self.closed:
671            raise ValueError('I/O operation on closed tap')
672        buf = []
673        while True:
674            buf.append(self.device.read(65535))
675            if len(buf[-1]) < 65535:
676                break
677        return ''.join(buf)
678
679
680class ClientHandler(BaseClientHandler):
681    IFTYPE = 'client'
682    IFOP_ALLOWED = False
683
684    def __init__(self, *args, **kwargs):
685        super(ClientHandler, self).__init__(*args, **kwargs)
686
687        self._sslopt = kwargs.get('sslopt', {})
688        self._options = self._getproxy()
689
690        cred = kwargs.get('cred', None)
691
692        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
693            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
694            auth = ['Authorization: Basic %s' % token]
695            self._options['header'] = auth
696
697    def open(self):
698        if not self.closed:
699            raise websocket.WebSocketException('Already opened')
700
701        # XXX: may be blocked
702        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
703        self.device.connect(self.target, **self._options)
704        self.dprintf('connected: %s\n', lambda: self.target)
705        return self.join_switch()
706
707    def write_message(self, message, binary=False):
708        if self.closed:
709            raise websocket.WebSocketException('Closed socket')
710        if binary:
711            flag = websocket.ABNF.OPCODE_BINARY
712        else:
713            flag = websocket.ABNF.OPCODE_TEXT
714        self.device.send(message, flag)
715
716    def read(self):
717        if self.closed:
718            raise websocket.WebSocketException('Closed socket')
719        return self.device.recv()
720
721    def _getproxy(self):
722        scheme, remain = urllib2.splittype(self.target)
723
724        proxy = urllib2.getproxies().get(scheme.replace('ws', 'http'))
725        if not proxy:
726            return {}
727
728        host = urllib2.splitport(urllib2.splithost(remain)[0])[0]
729        if urllib2.proxy_bypass(host):
730            return {}
731
732        hostport = urllib2.splithost(urllib2.splittype(proxy)[1])[0]
733        host, port = urllib2.splitport(hostport)
734        port = int(port) if port else 0
735        return {'http_proxy_host': host, 'http_proxy_port': port}
736
737
738class ControlServerHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
739    NAMESPACE = 'etherws.control'
740    IFTYPES = {
741        NetdevHandler.IFTYPE: NetdevHandler,
742        TapHandler.IFTYPE:    TapHandler,
743        ClientHandler.IFTYPE: ClientHandler,
744    }
745
746    def __init__(self, app, req, ioloop, switch, htpasswd, debug):
747        super(ControlServerHandler, self).__init__(app, req)
748        self._ioloop = ioloop
749        self._switch = switch
750        self._htpasswd = htpasswd
751        self._debug = debug
752
753    def post(self):
754        try:
755            request = json.loads(self.request.body)
756        except Exception as e:
757            return self._jsonrpc_response(error={
758                'code':    0 - 32700,
759                'message': 'Parse error',
760                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
761            })
762
763        try:
764            id_ = request.get('id')
765            params = request.get('params')
766            version = request['jsonrpc']
767            method = request['method']
768            if version != '2.0':
769                raise ValueError('Invalid JSON-RPC version: %s' % version)
770        except Exception as e:
771            return self._jsonrpc_response(id_=id_, error={
772                'code':    0 - 32600,
773                'message': 'Invalid Request',
774                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
775            })
776
777        try:
778            if not method.startswith(self.NAMESPACE + '.'):
779                raise ValueError('Invalid method namespace: %s' % method)
780            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
781            handler = getattr(self, handler)
782        except Exception as e:
783            return self._jsonrpc_response(id_=id_, error={
784                'code':    0 - 32601,
785                'message': 'Method not found',
786                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
787            })
788
789        try:
790            return self._jsonrpc_response(id_=id_, result=handler(params))
791        except Exception as e:
792            traceback.print_exc()
793            return self._jsonrpc_response(id_=id_, error={
794                'code':    0 - 32602,
795                'message': 'Invalid params',
796                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
797            })
798
799    def handle_listFdb(self, params):
800        list_ = []
801        for vid, mac, entry in self._switch.fdb.each():
802            list_.append({
803                'vid':  vid,
804                'mac':  EthernetFrame.format_mac(mac),
805                'port': entry.port.number,
806                'age':  int(entry.age),
807            })
808        return {'entries': list_}
809
810    def handle_listPort(self, params):
811        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
812
813    def handle_addPort(self, params):
814        type_ = params['type']
815        target = params['target']
816        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
817        cls = self.IFTYPES[type_]
818        interface = cls(self._ioloop, self._switch, target, self._debug, **opt)
819        portnum = interface.open()
820        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
821
822    def handle_setPort(self, params):
823        port = self._switch.get_port(int(params['port']))
824        shut = params.get('shut')
825        if shut is not None:
826            port.shut = bool(shut)
827        return {'entries': [self._portstat(port)]}
828
829    def handle_delPort(self, params):
830        port = self._switch.get_port(int(params['port']))
831        port.interface.close()
832        return {'entries': [self._portstat(port)]}
833
834    def handle_setInterface(self, params):
835        portnum = int(params['port'])
836        port = self._switch.get_port(portnum)
837        address = params.get('address')
838        netmask = params.get('netmask')
839        mtu = params.get('mtu')
840        if not port.interface.IFOP_ALLOWED:
841            raise ValueError('Port %d has unsupported interface: %s' %
842                             (portnum, port.interface.IFTYPE))
843        if address is not None:
844            port.interface.address = address
845        if netmask is not None:
846            port.interface.netmask = netmask
847        if mtu is not None:
848            port.interface.mtu = mtu
849        return {'entries': [self._ifstat(port)]}
850
851    def handle_listInterface(self, params):
852        return {'entries': [self._ifstat(p) for p in self._switch.portlist
853                            if p.interface.IFOP_ALLOWED]}
854
855    def _optparse_netdev(self, opt):
856        return {}
857
858    def _optparse_tap(self, opt):
859        return {}
860
861    def _optparse_client(self, opt):
862        if opt.get('insecure'):
863            sslopt = {'cert_reqs': ssl.CERT_NONE}
864        else:
865            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
866                      'ca_certs':  opt.get('cacerts')}
867        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
868        return {'sslopt': sslopt, 'cred': cred}
869
870    def _jsonrpc_response(self, id_=None, result=None, error=None):
871        res = {'jsonrpc': '2.0', 'id': id_}
872        if result:
873            res['result'] = result
874        if error:
875            res['error'] = error
876        self.finish(res)
877
878    @staticmethod
879    def _portstat(port):
880        return {
881            'port':   port.number,
882            'type':   port.interface.IFTYPE,
883            'target': port.interface.target,
884            'tx':     port.tx,
885            'rx':     port.rx,
886            'shut':   port.shut,
887        }
888
889    @staticmethod
890    def _ifstat(port):
891        return {
892            'port':    port.number,
893            'type':    port.interface.IFTYPE,
894            'target':  port.interface.target,
895            'address': port.interface.address,
896            'netmask': port.interface.netmask,
897            'mtu':     port.interface.mtu,
898        }
899
900
901def _print_error(error):
902    print(%s (%s)' % (error['message'], error['code']))
903    print('    %s' % error['data'])
904
905
906def _start_sw(args):
907    def daemonize(nochdir=False, noclose=False):
908        if os.fork() > 0:
909            sys.exit(0)
910
911        os.setsid()
912
913        if os.fork() > 0:
914            sys.exit(0)
915
916        if not nochdir:
917            os.chdir('/')
918
919        if not noclose:
920            os.umask(0)
921            sys.stdin.close()
922            sys.stdout.close()
923            sys.stderr.close()
924            os.close(0)
925            os.close(1)
926            os.close(2)
927            sys.stdin = open(os.devnull)
928            sys.stdout = open(os.devnull, 'a')
929            sys.stderr = open(os.devnull, 'a')
930
931    def checkabspath(ns, path):
932        val = getattr(ns, path, '')
933        if not val.startswith('/'):
934            raise ValueError('Invalid %: %s' % (path, val))
935
936    def getsslopt(ns, key, cert):
937        kval = getattr(ns, key, None)
938        cval = getattr(ns, cert, None)
939        if kval and cval:
940            return {'keyfile': kval, 'certfile': cval}
941        elif kval or cval:
942            raise ValueError('Both %s and %s are required' % (key, cert))
943        return None
944
945    def setrealpath(ns, *keys):
946        for k in keys:
947            v = getattr(ns, k, None)
948            if v is not None:
949                v = os.path.realpath(v)
950                open(v).close()  # check readable
951                setattr(ns, k, v)
952
953    def setport(ns, port, isssl):
954        val = getattr(ns, port, None)
955        if val is None:
956            if isssl:
957                return setattr(ns, port, 443)
958            return setattr(ns, port, 80)
959        if not (0 <= val <= 65535):
960            raise ValueError('Invalid %s: %s' % (port, val))
961
962    def sethtpasswd(ns, htpasswd):
963        val = getattr(ns, htpasswd, None)
964        if val:
965            return setattr(ns, htpasswd, Htpasswd(val))
966
967    # if args.debug:
968    #     websocket.enableTrace(True)
969
970    if args.ageout <= 0:
971        raise ValueError('Invalid ageout: %s' % args.ageout)
972
973    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
974    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
975
976    checkabspath(args, 'path')
977    checkabspath(args, 'ctlpath')
978
979    sslopt = getsslopt(args, 'sslkey', 'sslcert')
980    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
981
982    setport(args, 'port', sslopt)
983    setport(args, 'ctlport', ctlsslopt)
984
985    sethtpasswd(args, 'htpasswd')
986    sethtpasswd(args, 'ctlhtpasswd')
987
988    ioloop = IOLoop.instance()
989    fdb = FDB(args.ageout, args.debug)
990    switch = SwitchingHub(fdb, args.debug)
991
992    if args.port == args.ctlport and args.host == args.ctlhost:
993        if args.path == args.ctlpath:
994            raise ValueError('Same path/ctlpath on same host')
995        if args.sslkey != args.ctlsslkey:
996            raise ValueError('Different sslkey/ctlsslkey on same host')
997        if args.sslcert != args.ctlsslcert:
998            raise ValueError('Different sslcert/ctlsslcert on same host')
999
1000        app = Application([
1001            (args.path, ServerHandler, {
1002                'switch':   switch,
1003                'htpasswd': args.htpasswd,
1004                'debug':    args.debug,
1005            }),
1006            (args.ctlpath, ControlServerHandler, {
1007                'ioloop':   ioloop,
1008                'switch':   switch,
1009                'htpasswd': args.ctlhtpasswd,
1010                'debug':    args.debug,
1011            }),
1012        ])
1013        server = HTTPServer(app, ssl_options=sslopt)
1014        server.listen(args.port, address=args.host)
1015
1016    else:
1017        app = Application([(args.path, ServerHandler, {
1018            'switch':   switch,
1019            'htpasswd': args.htpasswd,
1020            'debug':    args.debug,
1021        })])
1022        server = HTTPServer(app, ssl_options=sslopt)
1023        server.listen(args.port, address=args.host)
1024
1025        ctl = Application([(args.ctlpath, ControlServerHandler, {
1026            'ioloop':   ioloop,
1027            'switch':   switch,
1028            'htpasswd': args.ctlhtpasswd,
1029            'debug':    args.debug,
1030        })])
1031        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1032        ctlserver.listen(args.ctlport, address=args.ctlhost)
1033
1034    if not args.foreground:
1035        daemonize()
1036
1037    ioloop.start()
1038
1039
1040def _start_ctl(args):
1041    def have_ssl_cert_verification():
1042        return 'context' in urllib2.urlopen.__code__.co_varnames
1043
1044    def request(args, method, params=None, id_=0):
1045        req = urllib2.Request(args.ctlurl)
1046        req.add_header('Content-type', 'application/json')
1047        if args.ctluser:
1048            if not args.ctlpasswd:
1049                args.ctlpasswd = getpass.getpass('Control Password: ')
1050            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1051            req.add_header('Authorization', 'Basic %s' % token)
1052        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1053        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1054        if params is not None:
1055            data['params'] = params
1056        if have_ssl_cert_verification():
1057            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1058                                             cafile=args.ctlsslcert)
1059            if args.ctlinsecure:
1060                ctx.check_hostname = False
1061                ctx.verify_mode = ssl.CERT_NONE
1062            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1063        elif args.ctlsslcert:
1064            raise EnvironmentError('do not support certificate verification')
1065        else:
1066            fp = urllib2.urlopen(req, json.dumps(data))
1067        return json.loads(fp.read())
1068
1069    def print_table(rows):
1070        cols = zip(*rows)
1071        maxlen = [0] * len(cols)
1072        for i in xrange(len(cols)):
1073            maxlen[i] = max(len(str(c)) for c in cols[i])
1074        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1075        fmt = '  ' + fmt
1076        for row in rows:
1077            print(fmt % tuple(row))
1078
1079    def print_portlist(result):
1080        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1081        for r in result:
1082            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1083                         r['rx'], r['tx'], r['target']])
1084        print_table(rows)
1085
1086    def print_iflist(result):
1087        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1088        for r in result:
1089            rows.append([r['port'], r['type'], r['address'],
1090                         r['netmask'], r['mtu'], r['target']])
1091        print_table(rows)
1092
1093    def handle_ctl_addport(args):
1094        opts = {
1095            'user':     getattr(args, 'user', None),
1096            'passwd':   getattr(args, 'passwd', None),
1097            'cacerts':  getattr(args, 'cacerts', None),
1098            'insecure': getattr(args, 'insecure', None),
1099        }
1100        if args.iftype == ClientHandler.IFTYPE:
1101            if not args.target.startswith('ws://') and \
1102               not args.target.startswith('wss://'):
1103                raise ValueError('Invalid target URL scheme: %s' % args.target)
1104            if not opts['user'] and opts['passwd']:
1105                raise ValueError('Authentication required but username empty')
1106            if opts['user'] and not opts['passwd']:
1107                opts['passwd'] = getpass.getpass('Client Password: ')
1108        result = request(args, 'addPort', {
1109            'type':    args.iftype,
1110            'target':  args.target,
1111            'options': opts,
1112        })
1113        if 'error' in result:
1114            _print_error(result['error'])
1115        else:
1116            print_portlist(result['result']['entries'])
1117
1118    def handle_ctl_setport(args):
1119        if args.port <= 0:
1120            raise ValueError('Invalid port: %d' % args.port)
1121        req = {'port': args.port}
1122        shut = getattr(args, 'shut', None)
1123        if shut is not None:
1124            req['shut'] = bool(shut)
1125        result = request(args, 'setPort', req)
1126        if 'error' in result:
1127            _print_error(result['error'])
1128        else:
1129            print_portlist(result['result']['entries'])
1130
1131    def handle_ctl_delport(args):
1132        if args.port <= 0:
1133            raise ValueError('Invalid port: %d' % args.port)
1134        result = request(args, 'delPort', {'port': args.port})
1135        if 'error' in result:
1136            _print_error(result['error'])
1137        else:
1138            print_portlist(result['result']['entries'])
1139
1140    def handle_ctl_listport(args):
1141        result = request(args, 'listPort')
1142        if 'error' in result:
1143            _print_error(result['error'])
1144        else:
1145            print_portlist(result['result']['entries'])
1146
1147    def handle_ctl_setif(args):
1148        if args.port <= 0:
1149            raise ValueError('Invalid port: %d' % args.port)
1150        req = {'port': args.port}
1151        address = getattr(args, 'address', None)
1152        netmask = getattr(args, 'netmask', None)
1153        mtu = getattr(args, 'mtu', None)
1154        if address is not None:
1155            if address:
1156                socket.inet_aton(address)  # validate
1157            req['address'] = address
1158        if netmask is not None:
1159            if netmask:
1160                socket.inet_aton(netmask)  # validate
1161            req['netmask'] = netmask
1162        if mtu is not None:
1163            if mtu < 576:
1164                raise ValueError('Invalid MTU: %d' % mtu)
1165            req['mtu'] = mtu
1166        result = request(args, 'setInterface', req)
1167        if 'error' in result:
1168            _print_error(result['error'])
1169        else:
1170            print_iflist(result['result']['entries'])
1171
1172    def handle_ctl_listif(args):
1173        result = request(args, 'listInterface')
1174        if 'error' in result:
1175            _print_error(result['error'])
1176        else:
1177            print_iflist(result['result']['entries'])
1178
1179    def handle_ctl_listfdb(args):
1180        result = request(args, 'listFdb')
1181        if 'error' in result:
1182            return _print_error(result['error'])
1183        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1184        for r in result['result']['entries']:
1185            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1186        print_table(rows)
1187
1188    locals()['handle_ctl_' + args.control_method](args)
1189
1190
1191def _main():
1192    parser = argparse.ArgumentParser()
1193    subcommand = parser.add_subparsers(dest='subcommand')
1194
1195    # - sw
1196    parser_sw = subcommand.add_parser('sw',
1197                                      help='start virtual switch')
1198
1199    parser_sw.add_argument('--debug', action='store_true', default=False,
1200                           help='run as debug mode')
1201    parser_sw.add_argument('--foreground', action='store_true', default=False,
1202                           help='run as foreground mode')
1203    parser_sw.add_argument('--ageout', type=int, default=300,
1204                           help='FDB ageout time (sec)')
1205
1206    parser_sw.add_argument('--path', default='/',
1207                           help='http(s) path to serve WebSocket')
1208    parser_sw.add_argument('--host', default='',
1209                           help='listen address to serve WebSocket')
1210    parser_sw.add_argument('--port', type=int,
1211                           help='listen port to serve WebSocket')
1212    parser_sw.add_argument('--htpasswd',
1213                           help='path to htpasswd file to auth WebSocket')
1214    parser_sw.add_argument('--sslkey',
1215                           help='path to SSL private key for WebSocket')
1216    parser_sw.add_argument('--sslcert',
1217                           help='path to SSL certificate for WebSocket')
1218
1219    parser_sw.add_argument('--ctlpath', default='/ctl',
1220                           help='http(s) path to serve control API')
1221    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1222                           help='listen address to serve control API')
1223    parser_sw.add_argument('--ctlport', type=int, default=7867,
1224                           help='listen port to serve control API')
1225    parser_sw.add_argument('--ctlhtpasswd',
1226                           help='path to htpasswd file to auth control API')
1227    parser_sw.add_argument('--ctlsslkey',
1228                           help='path to SSL private key for control API')
1229    parser_sw.add_argument('--ctlsslcert',
1230                           help='path to SSL certificate for control API')
1231
1232    # - ctl
1233    parser_ctl = subcommand.add_parser('ctl',
1234                                       help='control virtual switch')
1235    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1236                            help='URL to control API')
1237    parser_ctl.add_argument('--ctluser',
1238                            help='username to auth control API')
1239    parser_ctl.add_argument('--ctlpasswd',
1240                            help='password to auth control API')
1241    parser_ctl.add_argument('--ctlsslcert',
1242                            help='path to SSL certificate for control API')
1243    parser_ctl.add_argument(
1244        '--ctlinsecure', action='store_true', default=False,
1245        help='do not verify control API certificate')
1246
1247    control_method = parser_ctl.add_subparsers(dest='control_method')
1248
1249    # -- ctl addport
1250    parser_ctl_addport = control_method.add_parser('addport',
1251                                                   help='create and add port')
1252    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1253
1254    # --- ctl addport netdev
1255    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1256                                                  help='Network device')
1257    parser_ctl_addport_netdev.add_argument('target',
1258                                           help='device name to add interface')
1259
1260    # --- ctl addport tap
1261    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1262                                               help='TAP device')
1263    parser_ctl_addport_tap.add_argument('target',
1264                                        help='device name to create interface')
1265
1266    # --- ctl addport client
1267    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1268                                                  help='WebSocket client')
1269    parser_ctl_addport_client.add_argument('target',
1270                                           help='URL to connect WebSocket')
1271    parser_ctl_addport_client.add_argument('--user',
1272                                           help='username to auth WebSocket')
1273    parser_ctl_addport_client.add_argument('--passwd',
1274                                           help='password to auth WebSocket')
1275    parser_ctl_addport_client.add_argument('--cacerts',
1276                                           help='path to CA certificate')
1277    parser_ctl_addport_client.add_argument(
1278        '--insecure', action='store_true', default=False,
1279        help='do not verify server certificate')
1280
1281    # -- ctl setport
1282    parser_ctl_setport = control_method.add_parser('setport',
1283                                                   help='set port config')
1284    parser_ctl_setport.add_argument('port', type=int,
1285                                    help='port number to set config')
1286    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1287                                    help='set shutdown state')
1288
1289    # -- ctl delport
1290    parser_ctl_delport = control_method.add_parser('delport',
1291                                                   help='delete port')
1292    parser_ctl_delport.add_argument('port', type=int,
1293                                    help='port number to delete')
1294
1295    # -- ctl listport
1296    parser_ctl_listport = control_method.add_parser('listport',
1297                                                    help='show port list')
1298
1299    # -- ctl setif
1300    parser_ctl_setif = control_method.add_parser('setif',
1301                                                 help='set interface config')
1302    parser_ctl_setif.add_argument('port', type=int,
1303                                  help='port number to set config')
1304    parser_ctl_setif.add_argument('--address',
1305                                  help='IPv4 address to set interface')
1306    parser_ctl_setif.add_argument('--netmask',
1307                                  help='IPv4 netmask to set interface')
1308    parser_ctl_setif.add_argument('--mtu', type=int,
1309                                  help='MTU to set interface')
1310
1311    # -- ctl listif
1312    parser_ctl_listif = control_method.add_parser('listif',
1313                                                  help='show interface list')
1314
1315    # -- ctl listfdb
1316    parser_ctl_listfdb = control_method.add_parser('listfdb',
1317                                                   help='show FDB entries')
1318
1319    # -- go
1320    args = parser.parse_args()
1321
1322    try:
1323        globals()['_start_' + args.subcommand](args)
1324    except Exception as e:
1325        _print_error({
1326            'code':    0 - 32603,
1327            'message': 'Internal error',
1328            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1329        })
1330
1331
1332if __name__ == '__main__':
1333    _main()
Note: See TracBrowser for help on using the repository browser.