source: etherws/trunk/etherws.py @ 251

Revision 251, 40.3 KB checked in by atzm, 11 years ago (diff)

add netdev support

  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import urllib2
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import tornado
55import websocket
56
57from tornado.web import Application, RequestHandler
58from tornado.websocket import WebSocketHandler
59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
63
64
65class DebugMixIn(object):
66    def dprintf(self, msg, func=lambda: ()):
67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
69            sys.stderr.write(prefix + (msg % func()))
70
71
72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
79
80    @property
81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
100        return 0
101
102    @staticmethod
103    def format_mac(mac, sep=':'):
104        return sep.join(b.encode('hex') for b in mac)
105
106
107class FDB(DebugMixIn):
108    class Entry(object):
109        def __init__(self, port, ageout):
110            self.port = port
111            self._time = time.time()
112            self._ageout = ageout
113
114        @property
115        def age(self):
116            return time.time() - self._time
117
118        @property
119        def agedout(self):
120            return self.age > self._ageout
121
122    def __init__(self, ageout, debug=False):
123        self._ageout = ageout
124        self._debug = debug
125        self._table = {}
126
127    def _set_entry(self, vid, mac, port):
128        if vid not in self._table:
129            self._table[vid] = {}
130        self._table[vid][mac] = self.Entry(port, self._ageout)
131
132    def _del_entry(self, vid, mac):
133        if vid in self._table:
134            if mac in self._table[vid]:
135                del self._table[vid][mac]
136            if not self._table[vid]:
137                del self._table[vid]
138
139    def _get_entry(self, vid, mac):
140        try:
141            entry = self._table[vid][mac]
142        except KeyError:
143            return None
144
145        if not entry.agedout:
146            return entry
147
148        self._del_entry(vid, mac)
149        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
150                     lambda: (entry.port.number, vid, mac.encode('hex')))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
170                     lambda: (port.number, vid, mac.encode('hex')))
171
172    def delete(self, port):
173        for vid, mac, entry in self.each():
174            if entry.port.number == port.number:
175                self._del_entry(vid, mac)
176                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                             lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    @property
357    def target(self):
358        return ':'.join(str(e) for e in self.request.connection.address)
359
360    def open(self):
361        try:
362            return self._switch.register_port(self)
363        finally:
364            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
365
366    def on_message(self, message):
367        self._switch.receive(self, EthernetFrame(message))
368
369    def on_close(self):
370        self._switch.unregister_port(self)
371        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
372
373
374class NetdevHandler(DebugMixIn):
375    IFTYPE = 'netdev'
376    READ_SIZE = 65535
377    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
378
379    def __init__(self, ioloop, switch, dev, debug=False):
380        self._ioloop = ioloop
381        self._switch = switch
382        self._dev = dev
383        self._debug = debug
384        self._sock = None
385
386    @property
387    def target(self):
388        return self._dev
389
390    @property
391    def closed(self):
392        return not self._sock
393
394    def open(self):
395        if not self.closed:
396            raise ValueError('Already opened')
397        self._sock = socket.socket(
398            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL))
399        self._sock.bind((self._dev, self.ETH_P_ALL))
400        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
401        return self._switch.register_port(self)
402
403    def close(self):
404        if self.closed:
405            raise ValueError('I/O operation on closed netdev')
406        self._switch.unregister_port(self)
407        self._ioloop.remove_handler(self.fileno())
408        self._sock.close()
409        self._sock = None
410
411    def fileno(self):
412        if self.closed:
413            raise ValueError('I/O operation on closed netdev')
414        return self._sock.fileno()
415
416    def write_message(self, message, binary=False):
417        if self.closed:
418            raise ValueError('I/O operation on closed netdev')
419        self._sock.sendall(message)
420
421    def __call__(self, fd, events):
422        try:
423            self._switch.receive(self, EthernetFrame(self._read()))
424            return
425        except:
426            traceback.print_exc()
427        self.close()
428
429    def _read(self):
430        if self.closed:
431            raise ValueError('I/O operation on closed netdev')
432        buf = []
433        while True:
434            buf.append(self._sock.recv(self.READ_SIZE))
435            if len(buf[-1]) < self.READ_SIZE:
436                break
437        return ''.join(buf)
438
439
440class TapHandler(DebugMixIn):
441    IFTYPE = 'tap'
442    READ_SIZE = 65535
443
444    def __init__(self, ioloop, switch, dev, debug=False):
445        self._ioloop = ioloop
446        self._switch = switch
447        self._dev = dev
448        self._debug = debug
449        self._tap = None
450
451    @property
452    def target(self):
453        if self.closed:
454            return self._dev
455        return self._tap.name
456
457    @property
458    def closed(self):
459        return not self._tap
460
461    @property
462    def address(self):
463        if self.closed:
464            raise ValueError('I/O operation on closed tap')
465        try:
466            return self._tap.addr
467        except:
468            return ''
469
470    @property
471    def netmask(self):
472        if self.closed:
473            raise ValueError('I/O operation on closed tap')
474        try:
475            return self._tap.netmask
476        except:
477            return ''
478
479    @property
480    def mtu(self):
481        if self.closed:
482            raise ValueError('I/O operation on closed tap')
483        return self._tap.mtu
484
485    @address.setter
486    def address(self, address):
487        if self.closed:
488            raise ValueError('I/O operation on closed tap')
489        self._tap.addr = address
490
491    @netmask.setter
492    def netmask(self, netmask):
493        if self.closed:
494            raise ValueError('I/O operation on closed tap')
495        self._tap.netmask = netmask
496
497    @mtu.setter
498    def mtu(self, mtu):
499        if self.closed:
500            raise ValueError('I/O operation on closed tap')
501        self._tap.mtu = mtu
502
503    def open(self):
504        if not self.closed:
505            raise ValueError('Already opened')
506        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
507        self._tap.up()
508        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
509        return self._switch.register_port(self)
510
511    def close(self):
512        if self.closed:
513            raise ValueError('I/O operation on closed tap')
514        self._switch.unregister_port(self)
515        self._ioloop.remove_handler(self.fileno())
516        self._tap.close()
517        self._tap = None
518
519    def fileno(self):
520        if self.closed:
521            raise ValueError('I/O operation on closed tap')
522        return self._tap.fileno()
523
524    def write_message(self, message, binary=False):
525        if self.closed:
526            raise ValueError('I/O operation on closed tap')
527        self._tap.write(message)
528
529    def __call__(self, fd, events):
530        try:
531            self._switch.receive(self, EthernetFrame(self._read()))
532            return
533        except:
534            traceback.print_exc()
535        self.close()
536
537    def _read(self):
538        if self.closed:
539            raise ValueError('I/O operation on closed tap')
540        buf = []
541        while True:
542            buf.append(self._tap.read(self.READ_SIZE))
543            if len(buf[-1]) < self.READ_SIZE:
544                break
545        return ''.join(buf)
546
547
548class EtherWebSocketClient(DebugMixIn):
549    IFTYPE = 'client'
550
551    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
552        self._ioloop = ioloop
553        self._switch = switch
554        self._url = url
555        self._ssl = ssl_
556        self._debug = debug
557        self._sock = None
558        self._options = {}
559
560        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
561            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
562            auth = ['Authorization: Basic %s' % token]
563            self._options['header'] = auth
564
565    @property
566    def target(self):
567        return self._url
568
569    @property
570    def closed(self):
571        return not self._sock
572
573    def open(self):
574        sslwrap = websocket._SSLSocketWrapper
575
576        if not self.closed:
577            raise websocket.WebSocketException('Already opened')
578
579        if self._ssl:
580            websocket._SSLSocketWrapper = self._ssl
581
582        # XXX: may be blocked
583        try:
584            self._sock = websocket.WebSocket()
585            self._sock.connect(self._url, **self._options)
586            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
587            self.dprintf('connected: %s\n', lambda: self._url)
588            return self._switch.register_port(self)
589        finally:
590            websocket._SSLSocketWrapper = sslwrap
591
592    def close(self):
593        if self.closed:
594            raise websocket.WebSocketException('Already closed')
595        self._switch.unregister_port(self)
596        self._ioloop.remove_handler(self.fileno())
597        self._sock.close()
598        self._sock = None
599        self.dprintf('disconnected: %s\n', lambda: self._url)
600
601    def fileno(self):
602        if self.closed:
603            raise websocket.WebSocketException('Closed socket')
604        return self._sock.io_sock.fileno()
605
606    def write_message(self, message, binary=False):
607        if self.closed:
608            raise websocket.WebSocketException('Closed socket')
609        if binary:
610            flag = websocket.ABNF.OPCODE_BINARY
611        else:
612            flag = websocket.ABNF.OPCODE_TEXT
613        self._sock.send(message, flag)
614
615    def __call__(self, fd, events):
616        try:
617            data = self._sock.recv()
618            if data is not None:
619                self._switch.receive(self, EthernetFrame(data))
620                return
621        except:
622            traceback.print_exc()
623        self.close()
624
625
626class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
627    NAMESPACE = 'etherws.control'
628    IFTYPES = {
629        NetdevHandler.IFTYPE:        NetdevHandler,
630        TapHandler.IFTYPE:           TapHandler,
631        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
632    }
633
634    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
635        super(EtherWebSocketControlHandler, self).__init__(app, req)
636        self._ioloop = ioloop
637        self._switch = switch
638        self._htpasswd = htpasswd
639        self._debug = debug
640
641    def post(self):
642        try:
643            request = json.loads(self.request.body)
644        except Exception as e:
645            return self._jsonrpc_response(error={
646                'code':    0 - 32700,
647                'message': 'Parse error',
648                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
649            })
650
651        try:
652            id_ = request.get('id')
653            params = request.get('params')
654            version = request['jsonrpc']
655            method = request['method']
656            if version != '2.0':
657                raise ValueError('Invalid JSON-RPC version: %s' % version)
658        except Exception as e:
659            return self._jsonrpc_response(id_=id_, error={
660                'code':    0 - 32600,
661                'message': 'Invalid Request',
662                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
663            })
664
665        try:
666            if not method.startswith(self.NAMESPACE + '.'):
667                raise ValueError('Invalid method namespace: %s' % method)
668            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
669            handler = getattr(self, handler)
670        except Exception as e:
671            return self._jsonrpc_response(id_=id_, error={
672                'code':    0 - 32601,
673                'message': 'Method not found',
674                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
675            })
676
677        try:
678            return self._jsonrpc_response(id_=id_, result=handler(params))
679        except Exception as e:
680            traceback.print_exc()
681            return self._jsonrpc_response(id_=id_, error={
682                'code':    0 - 32602,
683                'message': 'Invalid params',
684                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
685            })
686
687    def handle_listFdb(self, params):
688        list_ = []
689        for vid, mac, entry in self._switch.fdb.each():
690            list_.append({
691                'vid':  vid,
692                'mac':  EthernetFrame.format_mac(mac),
693                'port': entry.port.number,
694                'age':  int(entry.age),
695            })
696        return {'entries': list_}
697
698    def handle_listPort(self, params):
699        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
700
701    def handle_addPort(self, params):
702        type_ = params['type']
703        target = params['target']
704        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
705        cls = self.IFTYPES[type_]
706        interface = cls(self._ioloop, self._switch, target, **opts)
707        portnum = interface.open()
708        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
709
710    def handle_setPort(self, params):
711        port = self._switch.get_port(int(params['port']))
712        shut = params.get('shut')
713        if shut is not None:
714            port.shut = bool(shut)
715        return {'entries': [self._portstat(port)]}
716
717    def handle_delPort(self, params):
718        port = self._switch.get_port(int(params['port']))
719        port.interface.close()
720        return {'entries': [self._portstat(port)]}
721
722    def handle_setInterface(self, params):
723        portnum = int(params['port'])
724        port = self._switch.get_port(portnum)
725        address = params.get('address')
726        netmask = params.get('netmask')
727        mtu = params.get('mtu')
728        if not isinstance(port.interface, TapHandler):
729            raise ValueError('Port %d has unsupported interface: %s' %
730                             (portnum, port.interface.IFTYPE))
731        if address is not None:
732            port.interface.address = address
733        if netmask is not None:
734            port.interface.netmask = netmask
735        if mtu is not None:
736            port.interface.mtu = mtu
737        return {'entries': [self._ifstat(port)]}
738
739    def handle_listInterface(self, params):
740        return {'entries': [self._ifstat(p) for p in self._switch.portlist
741                            if isinstance(p.interface, TapHandler)]}
742
743    def _optparse_netdev(self, opt):
744        return {'debug': self._debug}
745
746    def _optparse_tap(self, opt):
747        return {'debug': self._debug}
748
749    def _optparse_client(self, opt):
750        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
751        if opt.get('insecure'):
752            args = {}
753        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
754        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
755        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
756
757    def _jsonrpc_response(self, id_=None, result=None, error=None):
758        res = {'jsonrpc': '2.0', 'id': id_}
759        if result:
760            res['result'] = result
761        if error:
762            res['error'] = error
763        self.finish(res)
764
765    @staticmethod
766    def _portstat(port):
767        return {
768            'port':   port.number,
769            'type':   port.interface.IFTYPE,
770            'target': port.interface.target,
771            'tx':     port.tx,
772            'rx':     port.rx,
773            'shut':   port.shut,
774        }
775
776    @staticmethod
777    def _ifstat(port):
778        return {
779            'port':    port.number,
780            'type':    port.interface.IFTYPE,
781            'target':  port.interface.target,
782            'address': port.interface.address,
783            'netmask': port.interface.netmask,
784            'mtu':     port.interface.mtu,
785        }
786
787
788def _print_error(error):
789    print(%s (%s)' % (error['message'], error['code']))
790    print('    %s' % error['data'])
791
792
793def _start_sw(args):
794    def daemonize(nochdir=False, noclose=False):
795        if os.fork() > 0:
796            sys.exit(0)
797
798        os.setsid()
799
800        if os.fork() > 0:
801            sys.exit(0)
802
803        if not nochdir:
804            os.chdir('/')
805
806        if not noclose:
807            os.umask(0)
808            sys.stdin.close()
809            sys.stdout.close()
810            sys.stderr.close()
811            os.close(0)
812            os.close(1)
813            os.close(2)
814            sys.stdin = open(os.devnull)
815            sys.stdout = open(os.devnull, 'a')
816            sys.stderr = open(os.devnull, 'a')
817
818    def checkabspath(ns, path):
819        val = getattr(ns, path, '')
820        if not val.startswith('/'):
821            raise ValueError('Invalid %: %s' % (path, val))
822
823    def getsslopt(ns, key, cert):
824        kval = getattr(ns, key, None)
825        cval = getattr(ns, cert, None)
826        if kval and cval:
827            return {'keyfile': kval, 'certfile': cval}
828        elif kval or cval:
829            raise ValueError('Both %s and %s are required' % (key, cert))
830        return None
831
832    def setrealpath(ns, *keys):
833        for k in keys:
834            v = getattr(ns, k, None)
835            if v is not None:
836                v = os.path.realpath(v)
837                open(v).close()  # check readable
838                setattr(ns, k, v)
839
840    def setport(ns, port, isssl):
841        val = getattr(ns, port, None)
842        if val is None:
843            if isssl:
844                return setattr(ns, port, 443)
845            return setattr(ns, port, 80)
846        if not (0 <= val <= 65535):
847            raise ValueError('Invalid %s: %s' % (port, val))
848
849    def sethtpasswd(ns, htpasswd):
850        val = getattr(ns, htpasswd, None)
851        if val:
852            return setattr(ns, htpasswd, Htpasswd(val))
853
854    #if args.debug:
855    #    websocket.enableTrace(True)
856
857    if args.ageout <= 0:
858        raise ValueError('Invalid ageout: %s' % args.ageout)
859
860    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
861    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
862
863    checkabspath(args, 'path')
864    checkabspath(args, 'ctlpath')
865
866    sslopt = getsslopt(args, 'sslkey', 'sslcert')
867    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
868
869    setport(args, 'port', sslopt)
870    setport(args, 'ctlport', ctlsslopt)
871
872    sethtpasswd(args, 'htpasswd')
873    sethtpasswd(args, 'ctlhtpasswd')
874
875    ioloop = IOLoop.instance()
876    fdb = FDB(ageout=args.ageout, debug=args.debug)
877    switch = SwitchingHub(fdb, debug=args.debug)
878
879    if args.port == args.ctlport and args.host == args.ctlhost:
880        if args.path == args.ctlpath:
881            raise ValueError('Same path/ctlpath on same host')
882        if args.sslkey != args.ctlsslkey:
883            raise ValueError('Different sslkey/ctlsslkey on same host')
884        if args.sslcert != args.ctlsslcert:
885            raise ValueError('Different sslcert/ctlsslcert on same host')
886
887        app = Application([
888            (args.path, EtherWebSocketHandler, {
889                'switch':   switch,
890                'htpasswd': args.htpasswd,
891                'debug':    args.debug,
892            }),
893            (args.ctlpath, EtherWebSocketControlHandler, {
894                'ioloop':   ioloop,
895                'switch':   switch,
896                'htpasswd': args.ctlhtpasswd,
897                'debug':    args.debug,
898            }),
899        ])
900        server = HTTPServer(app, ssl_options=sslopt)
901        server.listen(args.port, address=args.host)
902
903    else:
904        app = Application([(args.path, EtherWebSocketHandler, {
905            'switch':   switch,
906            'htpasswd': args.htpasswd,
907            'debug':    args.debug,
908        })])
909        server = HTTPServer(app, ssl_options=sslopt)
910        server.listen(args.port, address=args.host)
911
912        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
913            'ioloop':   ioloop,
914            'switch':   switch,
915            'htpasswd': args.ctlhtpasswd,
916            'debug':    args.debug,
917        })])
918        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
919        ctlserver.listen(args.ctlport, address=args.ctlhost)
920
921    if not args.foreground:
922        daemonize()
923
924    ioloop.start()
925
926
927def _start_ctl(args):
928    def request(args, method, params=None, id_=0):
929        req = urllib2.Request(args.ctlurl)
930        req.add_header('Content-type', 'application/json')
931        if args.ctluser:
932            if not args.ctlpasswd:
933                args.ctlpasswd = getpass.getpass('Control Password: ')
934            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
935            req.add_header('Authorization', 'Basic %s' % token)
936        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
937        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
938        if params is not None:
939            data['params'] = params
940        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
941
942    def maxlen(dict_, key, min_):
943        if not dict_:
944            return min_
945        max_ = max(len(str(r[key])) for r in dict_)
946        return min_ if max_ < min_ else max_
947
948    def print_portlist(result):
949        pmax = maxlen(result, 'port', 4)
950        ymax = maxlen(result, 'type', 4)
951        smax = maxlen(result, 'shut', 5)
952        rmax = maxlen(result, 'rx', 2)
953        tmax = maxlen(result, 'tx', 2)
954        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
955              (pmax, ymax, smax, rmax, tmax)
956        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
957        for r in result:
958            shut = 'shut' if r['shut'] else 'up'
959            print(fmt %
960                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
961
962    def print_iflist(result):
963        pmax = maxlen(result, 'port', 4)
964        tmax = maxlen(result, 'type', 4)
965        amax = maxlen(result, 'address', 7)
966        nmax = maxlen(result, 'netmask', 7)
967        mmax = maxlen(result, 'mtu', 3)
968        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
969              (pmax, tmax, amax, nmax, mmax)
970        print(fmt % ('Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target'))
971        for r in result:
972            print(fmt % (r['port'], r['type'],
973                         r['address'], r['netmask'], r['mtu'], r['target']))
974
975    def handle_ctl_addport(args):
976        opts = {
977            'user':     getattr(args, 'user', None),
978            'passwd':   getattr(args, 'passwd', None),
979            'cacerts':  getattr(args, 'cacerts', None),
980            'insecure': getattr(args, 'insecure', None),
981        }
982        if args.iftype == EtherWebSocketClient.IFTYPE:
983            if not args.target.startswith('ws://') and \
984               not args.target.startswith('wss://'):
985                raise ValueError('Invalid target URL scheme: %s' % args.target)
986            if not opts['user'] and opts['passwd']:
987                raise ValueError('Authentication required but username empty')
988            if opts['user'] and not opts['passwd']:
989                opts['passwd'] = getpass.getpass('Client Password: ')
990        result = request(args, 'addPort', {
991            'type':    args.iftype,
992            'target':  args.target,
993            'options': opts,
994        })
995        if 'error' in result:
996            _print_error(result['error'])
997        else:
998            print_portlist(result['result']['entries'])
999
1000    def handle_ctl_setport(args):
1001        if args.port <= 0:
1002            raise ValueError('Invalid port: %d' % args.port)
1003        req = {'port': args.port}
1004        shut = getattr(args, 'shut', None)
1005        if shut is not None:
1006            req['shut'] = bool(shut)
1007        result = request(args, 'setPort', req)
1008        if 'error' in result:
1009            _print_error(result['error'])
1010        else:
1011            print_portlist(result['result']['entries'])
1012
1013    def handle_ctl_delport(args):
1014        if args.port <= 0:
1015            raise ValueError('Invalid port: %d' % args.port)
1016        result = request(args, 'delPort', {'port': args.port})
1017        if 'error' in result:
1018            _print_error(result['error'])
1019        else:
1020            print_portlist(result['result']['entries'])
1021
1022    def handle_ctl_listport(args):
1023        result = request(args, 'listPort')
1024        if 'error' in result:
1025            _print_error(result['error'])
1026        else:
1027            print_portlist(result['result']['entries'])
1028
1029    def handle_ctl_setif(args):
1030        if args.port <= 0:
1031            raise ValueError('Invalid port: %d' % args.port)
1032        req = {'port': args.port}
1033        address = getattr(args, 'address', None)
1034        netmask = getattr(args, 'netmask', None)
1035        mtu = getattr(args, 'mtu', None)
1036        if address is not None:
1037            if address:
1038                socket.inet_aton(address)  # validate
1039            req['address'] = address
1040        if netmask is not None:
1041            if netmask:
1042                socket.inet_aton(netmask)  # validate
1043            req['netmask'] = netmask
1044        if mtu is not None:
1045            if mtu < 576:
1046                raise ValueError('Invalid MTU: %d' % mtu)
1047            req['mtu'] = mtu
1048        result = request(args, 'setInterface', req)
1049        if 'error' in result:
1050            _print_error(result['error'])
1051        else:
1052            print_iflist(result['result']['entries'])
1053
1054    def handle_ctl_listif(args):
1055        result = request(args, 'listInterface')
1056        if 'error' in result:
1057            _print_error(result['error'])
1058        else:
1059            print_iflist(result['result']['entries'])
1060
1061    def handle_ctl_listfdb(args):
1062        result = request(args, 'listFdb')
1063        if 'error' in result:
1064            return _print_error(result['error'])
1065        result = result['result']['entries']
1066        pmax = maxlen(result, 'port', 4)
1067        vmax = maxlen(result, 'vid', 4)
1068        mmax = maxlen(result, 'mac', 3)
1069        amax = maxlen(result, 'age', 3)
1070        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
1071        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
1072        for r in result:
1073            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
1074
1075    locals()['handle_ctl_' + args.control_method](args)
1076
1077
1078def _main():
1079    parser = argparse.ArgumentParser()
1080    subcommand = parser.add_subparsers(dest='subcommand')
1081
1082    # - sw
1083    parser_sw = subcommand.add_parser('sw',
1084                                      help='start virtual switch')
1085
1086    parser_sw.add_argument('--debug', action='store_true', default=False,
1087                           help='run as debug mode')
1088    parser_sw.add_argument('--foreground', action='store_true', default=False,
1089                           help='run as foreground mode')
1090    parser_sw.add_argument('--ageout', type=int, default=300,
1091                           help='FDB ageout time (sec)')
1092
1093    parser_sw.add_argument('--path', default='/',
1094                           help='http(s) path to serve WebSocket')
1095    parser_sw.add_argument('--host', default='',
1096                           help='listen address to serve WebSocket')
1097    parser_sw.add_argument('--port', type=int,
1098                           help='listen port to serve WebSocket')
1099    parser_sw.add_argument('--htpasswd',
1100                           help='path to htpasswd file to auth WebSocket')
1101    parser_sw.add_argument('--sslkey',
1102                           help='path to SSL private key for WebSocket')
1103    parser_sw.add_argument('--sslcert',
1104                           help='path to SSL certificate for WebSocket')
1105
1106    parser_sw.add_argument('--ctlpath', default='/ctl',
1107                           help='http(s) path to serve control API')
1108    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1109                           help='listen address to serve control API')
1110    parser_sw.add_argument('--ctlport', type=int, default=7867,
1111                           help='listen port to serve control API')
1112    parser_sw.add_argument('--ctlhtpasswd',
1113                           help='path to htpasswd file to auth control API')
1114    parser_sw.add_argument('--ctlsslkey',
1115                           help='path to SSL private key for control API')
1116    parser_sw.add_argument('--ctlsslcert',
1117                           help='path to SSL certificate for control API')
1118
1119    # - ctl
1120    parser_ctl = subcommand.add_parser('ctl',
1121                                       help='control virtual switch')
1122    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1123                            help='URL to control API')
1124    parser_ctl.add_argument('--ctluser',
1125                            help='username to auth control API')
1126    parser_ctl.add_argument('--ctlpasswd',
1127                            help='password to auth control API')
1128
1129    control_method = parser_ctl.add_subparsers(dest='control_method')
1130
1131    # -- ctl addport
1132    parser_ctl_addport = control_method.add_parser('addport',
1133                                                   help='create and add port')
1134    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1135
1136    # --- ctl addport netdev
1137    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1138                                                  help='netdev')
1139    parser_ctl_addport_netdev.add_argument('target',
1140                                           help='device name to add interface')
1141
1142    # --- ctl addport tap
1143    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1144                                               help='TAP device')
1145    parser_ctl_addport_tap.add_argument('target',
1146                                        help='device name to create interface')
1147
1148    # --- ctl addport client
1149    parser_ctl_addport_client = iftype.add_parser(EtherWebSocketClient.IFTYPE,
1150                                                  help='WebSocket client')
1151    parser_ctl_addport_client.add_argument('target',
1152                                           help='URL to connect WebSocket')
1153    parser_ctl_addport_client.add_argument('--user',
1154                                           help='username to auth WebSocket')
1155    parser_ctl_addport_client.add_argument('--passwd',
1156                                           help='password to auth WebSocket')
1157    parser_ctl_addport_client.add_argument('--cacerts',
1158                                           help='path to CA certificate')
1159    parser_ctl_addport_client.add_argument(
1160        '--insecure', action='store_true', default=False,
1161        help='do not verify server certificate')
1162
1163    # -- ctl setport
1164    parser_ctl_setport = control_method.add_parser('setport',
1165                                                   help='set port config')
1166    parser_ctl_setport.add_argument('port', type=int,
1167                                    help='port number to set config')
1168    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1169                                    help='set shutdown state')
1170
1171    # -- ctl delport
1172    parser_ctl_delport = control_method.add_parser('delport',
1173                                                   help='delete port')
1174    parser_ctl_delport.add_argument('port', type=int,
1175                                    help='port number to delete')
1176
1177    # -- ctl listport
1178    parser_ctl_listport = control_method.add_parser('listport',
1179                                                    help='show port list')
1180
1181    # -- ctl setif
1182    parser_ctl_setif = control_method.add_parser('setif',
1183                                                 help='set interface config')
1184    parser_ctl_setif.add_argument('port', type=int,
1185                                  help='port number to set config')
1186    parser_ctl_setif.add_argument('--address',
1187                                  help='IPv4 address to set interface')
1188    parser_ctl_setif.add_argument('--netmask',
1189                                  help='IPv4 netmask to set interface')
1190    parser_ctl_setif.add_argument('--mtu', type=int,
1191                                  help='MTU to set interface')
1192
1193    # -- ctl listif
1194    parser_ctl_listif = control_method.add_parser('listif',
1195                                                  help='show interface list')
1196
1197    # -- ctl listfdb
1198    parser_ctl_listfdb = control_method.add_parser('listfdb',
1199                                                   help='show FDB entries')
1200
1201    # -- go
1202    args = parser.parse_args()
1203
1204    try:
1205        globals()['_start_' + args.subcommand](args)
1206    except Exception as e:
1207        _print_error({
1208            'code':    0 - 32603,
1209            'message': 'Internal error',
1210            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1211        })
1212
1213
1214if __name__ == '__main__':
1215    _main()
Note: See TracBrowser for help on using the repository browser.