source: etherws/trunk/etherws.py @ 205

Revision 205, 30.9 KB checked in by atzm, 12 years ago (diff)
  • add error handling
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def get_vid_list(self):
152        return sorted(self._table.iterkeys())
153
154    def get_mac_list(self, vid):
155        return sorted(self._table[vid].iterkeys())
156
157    def lookup(self, frame):
158        mac = frame.dst_mac
159        vid = frame.vid
160        entry = self.get_entry(vid, mac)
161        return getattr(entry, 'port', None)
162
163    def learn(self, port, frame):
164        mac = frame.src_mac
165        vid = frame.vid
166        self._set_entry(vid, mac, port)
167        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
168                     lambda: (port.number, vid, mac.encode('hex')))
169
170    def delete(self, port):
171        for vid in self.get_vid_list():
172            for mac in self.get_mac_list(vid):
173                entry = self.get_entry(vid, mac)
174                if entry and entry.port.number == port.number:
175                    self._del_entry(vid, mac)
176                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                                 lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    @property
357    def target(self):
358        return ':'.join(str(e) for e in self.request.connection.address)
359
360    def open(self):
361        try:
362            return self._switch.register_port(self)
363        finally:
364            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
365
366    def on_message(self, message):
367        self._switch.receive(self, EthernetFrame(message))
368
369    def on_close(self):
370        self._switch.unregister_port(self)
371        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
372
373
374class TapHandler(DebugMixIn):
375    IFTYPE = 'tap'
376    READ_SIZE = 65535
377
378    def __init__(self, ioloop, switch, dev, debug=False):
379        self._ioloop = ioloop
380        self._switch = switch
381        self._dev = dev
382        self._debug = debug
383        self._tap = None
384
385    @property
386    def target(self):
387        if self.closed:
388            return self._dev
389        return self._tap.name
390
391    @property
392    def closed(self):
393        return not self._tap
394
395    def open(self):
396        if not self.closed:
397            raise ValueError('Already opened')
398        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
399        self._tap.up()
400        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
401        return self._switch.register_port(self)
402
403    def close(self):
404        if self.closed:
405            raise ValueError('I/O operation on closed tap')
406        self._switch.unregister_port(self)
407        self._ioloop.remove_handler(self.fileno())
408        self._tap.close()
409        self._tap = None
410
411    def fileno(self):
412        if self.closed:
413            raise ValueError('I/O operation on closed tap')
414        return self._tap.fileno()
415
416    def write_message(self, message, binary=False):
417        if self.closed:
418            raise ValueError('I/O operation on closed tap')
419        self._tap.write(message)
420
421    def __call__(self, fd, events):
422        try:
423            self._switch.receive(self, EthernetFrame(self._read()))
424            return
425        except:
426            traceback.print_exc()
427        self.close()
428
429    def _read(self):
430        if self.closed:
431            raise ValueError('I/O operation on closed tap')
432        buf = []
433        while True:
434            buf.append(self._tap.read(self.READ_SIZE))
435            if len(buf[-1]) < self.READ_SIZE:
436                break
437        return ''.join(buf)
438
439
440class EtherWebSocketClient(DebugMixIn):
441    IFTYPE = 'client'
442
443    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
444        self._ioloop = ioloop
445        self._switch = switch
446        self._url = url
447        self._ssl = ssl_
448        self._debug = debug
449        self._sock = None
450        self._options = {}
451
452        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
453            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
454            auth = ['Authorization: Basic %s' % token]
455            self._options['header'] = auth
456
457    @property
458    def target(self):
459        return self._url
460
461    @property
462    def closed(self):
463        return not self._sock
464
465    def open(self):
466        sslwrap = websocket._SSLSocketWrapper
467
468        if not self.closed:
469            raise websocket.WebSocketException('Already opened')
470
471        if self._ssl:
472            websocket._SSLSocketWrapper = self._ssl
473
474        try:
475            self._sock = websocket.WebSocket()
476            self._sock.connect(self._url, **self._options)
477            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
478            return self._switch.register_port(self)
479        finally:
480            websocket._SSLSocketWrapper = sslwrap
481            self.dprintf('connected: %s\n', lambda: self._url)
482
483    def close(self):
484        if self.closed:
485            raise websocket.WebSocketException('Already closed')
486        self._switch.unregister_port(self)
487        self._ioloop.remove_handler(self.fileno())
488        self._sock.close()
489        self._sock = None
490        self.dprintf('disconnected: %s\n', lambda: self._url)
491
492    def fileno(self):
493        if self.closed:
494            raise websocket.WebSocketException('Closed socket')
495        return self._sock.io_sock.fileno()
496
497    def write_message(self, message, binary=False):
498        if self.closed:
499            raise websocket.WebSocketException('Closed socket')
500        if binary:
501            flag = websocket.ABNF.OPCODE_BINARY
502        else:
503            flag = websocket.ABNF.OPCODE_TEXT
504        self._sock.send(message, flag)
505
506    def __call__(self, fd, events):
507        try:
508            data = self._sock.recv()
509            if data is not None:
510                self._switch.receive(self, EthernetFrame(data))
511                return
512        except:
513            traceback.print_exc()
514        self.close()
515
516
517class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
518    NAMESPACE = 'etherws.control'
519    IFTYPES = {
520        TapHandler.IFTYPE:           TapHandler,
521        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
522    }
523
524    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
525        super(EtherWebSocketControlHandler, self).__init__(app, req)
526        self._ioloop = ioloop
527        self._switch = switch
528        self._htpasswd = htpasswd
529        self._debug = debug
530
531    def post(self):
532        try:
533            request = json.loads(self.request.body)
534        except Exception as e:
535            return self._jsonrpc_response(error={
536                'code':    0 - 32700,
537                'message': 'Parse error',
538                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
539            })
540
541        try:
542            id_ = request.get('id')
543            params = request.get('params')
544            version = request['jsonrpc']
545            method = request['method']
546            if version != '2.0':
547                raise ValueError('Invalid JSON-RPC version: %s' % version)
548        except Exception as e:
549            return self._jsonrpc_response(id_=id_, error={
550                'code':    0 - 32600,
551                'message': 'Invalid Request',
552                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
553            })
554
555        try:
556            if not method.startswith(self.NAMESPACE + '.'):
557                raise ValueError('Invalid method namespace: %s' % method)
558            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
559            handler = getattr(self, handler)
560        except Exception as e:
561            return self._jsonrpc_response(id_=id_, error={
562                'code':    0 - 32601,
563                'message': 'Method not found',
564                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
565            })
566
567        try:
568            return self._jsonrpc_response(id_=id_, result=handler(params))
569        except Exception as e:
570            traceback.print_exc()
571            return self._jsonrpc_response(id_=id_, error={
572                'code':    0 - 32602,
573                'message': 'Invalid params',
574                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
575            })
576
577    def handle_listFdb(self, params):
578        list_ = []
579        for vid in self._switch.fdb.get_vid_list():
580            for mac in self._switch.fdb.get_mac_list(vid):
581                entry = self._switch.fdb.get_entry(vid, mac)
582                if entry:
583                    list_.append({
584                        'vid':  vid,
585                        'mac':  EthernetFrame.format_mac(mac),
586                        'port': entry.port.number,
587                        'age':  int(entry.age),
588                    })
589        return {'entries': list_}
590
591    def handle_listPort(self, params):
592        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
593
594    def handle_addPort(self, params):
595        type_ = params['type']
596        target = params['target']
597        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
598        cls = self.IFTYPES[type_]
599        interface = cls(self._ioloop, self._switch, target, **opts)
600        portnum = interface.open()
601        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
602
603    def handle_delPort(self, params):
604        port = self._switch.get_port(int(params['port']))
605        port.interface.close()
606        return {'entries': [self._portstat(port)]}
607
608    def handle_shutPort(self, params):
609        port = self._switch.get_port(int(params['port']))
610        port.shut = bool(params['shut'])
611        return {'entries': [self._portstat(port)]}
612
613    def _optparse_tap(self, opt):
614        return {'debug': self._debug}
615
616    def _optparse_client(self, opt):
617        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
618        if opt.get('insecure'):
619            args = {}
620        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
621        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
622        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
623
624    def _jsonrpc_response(self, id_=None, result=None, error=None):
625        res = {'jsonrpc': '2.0', 'id': id_}
626        if result:
627            res['result'] = result
628        if error:
629            res['error'] = error
630        self.finish(res)
631
632    @staticmethod
633    def _portstat(port):
634        return {
635            'port':   port.number,
636            'type':   port.interface.IFTYPE,
637            'target': port.interface.target,
638            'tx':     port.tx,
639            'rx':     port.rx,
640            'shut':   port.shut,
641        }
642
643
644def print_error(error):
645    print(%s (%s)' % (error['message'], error['code']))
646    print('    %s' % error['data'])
647
648
649def start_sw(args):
650    def daemonize(nochdir=False, noclose=False):
651        if os.fork() > 0:
652            sys.exit(0)
653
654        os.setsid()
655
656        if os.fork() > 0:
657            sys.exit(0)
658
659        if not nochdir:
660            os.chdir('/')
661
662        if not noclose:
663            os.umask(0)
664            sys.stdin.close()
665            sys.stdout.close()
666            sys.stderr.close()
667            os.close(0)
668            os.close(1)
669            os.close(2)
670            sys.stdin = open(os.devnull)
671            sys.stdout = open(os.devnull, 'a')
672            sys.stderr = open(os.devnull, 'a')
673
674    def checkabspath(ns, path):
675        val = getattr(ns, path, '')
676        if not val.startswith('/'):
677            raise ValueError('Invalid %: %s' % (path, val))
678
679    def getsslopt(ns, key, cert):
680        kval = getattr(ns, key, None)
681        cval = getattr(ns, cert, None)
682        if kval and cval:
683            return {'keyfile': kval, 'certfile': cval}
684        elif kval or cval:
685            raise ValueError('Both %s and %s are required' % (key, cert))
686        return None
687
688    def setrealpath(ns, *keys):
689        for k in keys:
690            v = getattr(ns, k, None)
691            if v is not None:
692                v = os.path.realpath(v)
693                open(v).close()  # check readable
694                setattr(ns, k, v)
695
696    def setport(ns, port, isssl):
697        val = getattr(ns, port, None)
698        if val is None:
699            if isssl:
700                return setattr(ns, port, 443)
701            return setattr(ns, port, 80)
702        if not (0 <= val <= 65535):
703            raise ValueError('Invalid %s: %s' % (port, val))
704
705    def sethtpasswd(ns, htpasswd):
706        val = getattr(ns, htpasswd, None)
707        if val:
708            return setattr(ns, htpasswd, Htpasswd(val))
709
710    #if args.debug:
711    #    websocket.enableTrace(True)
712
713    if args.ageout <= 0:
714        raise ValueError('Invalid ageout: %s' % args.ageout)
715
716    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
717    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
718
719    checkabspath(args, 'path')
720    checkabspath(args, 'ctlpath')
721
722    sslopt = getsslopt(args, 'sslkey', 'sslcert')
723    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
724
725    setport(args, 'port', sslopt)
726    setport(args, 'ctlport', ctlsslopt)
727
728    sethtpasswd(args, 'htpasswd')
729    sethtpasswd(args, 'ctlhtpasswd')
730
731    ioloop = IOLoop.instance()
732    fdb = FDB(ageout=args.ageout, debug=args.debug)
733    switch = SwitchingHub(fdb, debug=args.debug)
734
735    if args.port == args.ctlport and args.host == args.ctlhost:
736        if args.path == args.ctlpath:
737            raise ValueError('Same path/ctlpath on same host')
738        if args.sslkey != args.ctlsslkey:
739            raise ValueError('Different sslkey/ctlsslkey on same host')
740        if args.sslcert != args.ctlsslcert:
741            raise ValueError('Different sslcert/ctlsslcert on same host')
742
743        app = Application([
744            (args.path, EtherWebSocketHandler, {
745                'switch':   switch,
746                'htpasswd': args.htpasswd,
747                'debug':    args.debug,
748            }),
749            (args.ctlpath, EtherWebSocketControlHandler, {
750                'ioloop':   ioloop,
751                'switch':   switch,
752                'htpasswd': args.ctlhtpasswd,
753                'debug':    args.debug,
754            }),
755        ])
756        server = HTTPServer(app, ssl_options=sslopt)
757        server.listen(args.port, address=args.host)
758
759    else:
760        app = Application([(args.path, EtherWebSocketHandler, {
761            'switch':   switch,
762            'htpasswd': args.htpasswd,
763            'debug':    args.debug,
764        })])
765        server = HTTPServer(app, ssl_options=sslopt)
766        server.listen(args.port, address=args.host)
767
768        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
769            'ioloop':   ioloop,
770            'switch':   switch,
771            'htpasswd': args.ctlhtpasswd,
772            'debug':    args.debug,
773        })])
774        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
775        ctlserver.listen(args.ctlport, address=args.ctlhost)
776
777    if not args.foreground:
778        daemonize()
779
780    ioloop.start()
781
782
783def start_ctl(args):
784    def request(args, method, params=None, id_=0):
785        req = urllib2.Request(args.ctlurl)
786        req.add_header('Content-type', 'application/json')
787        if args.ctluser:
788            if not args.ctlpasswd:
789                args.ctlpasswd = getpass.getpass()
790            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
791            req.add_header('Authorization', 'Basic %s' % token)
792        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
793        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
794        if params is not None:
795            data['params'] = params
796        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
797
798    def maxlen(dict_, key, min_):
799        if not dict_:
800            return min_
801        max_ = max(len(str(r[key])) for r in dict_)
802        return min_ if max_ < min_ else max_
803
804    def print_portlist(result):
805        pmax = maxlen(result, 'port', 4)
806        ymax = maxlen(result, 'type', 4)
807        smax = maxlen(result, 'shut', 5)
808        rmax = maxlen(result, 'rx', 2)
809        tmax = maxlen(result, 'tx', 2)
810        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
811              (pmax, ymax, smax, rmax, tmax)
812        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
813        for r in result:
814            shut = 'shut' if r['shut'] else 'up'
815            print(fmt %
816                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
817
818    def handle_ctl_addport(args):
819        opts = {
820            'user':     getattr(args, 'user', None),
821            'passwd':   getattr(args, 'passwd', None),
822            'cacerts':  getattr(args, 'cacerts', None),
823            'insecure': getattr(args, 'insecure', None),
824        }
825        if args.iftype == EtherWebSocketClient.IFTYPE:
826            if not args.target.startswith('ws://') and \
827               not args.target.startswith('wss://'):
828                raise ValueError('Invalid target URL scheme: %s' % args.target)
829        if opts['user'] and not opts['passwd']:
830            raise ValueError('Authentication required but password empty')
831        if not opts['user'] and opts['passwd']:
832            raise ValueError('Authentication required but username empty')
833        result = request(args, 'addPort', {
834            'type':    args.iftype,
835            'target':  args.target,
836            'options': opts,
837        })
838        if 'error' in result:
839            print_error(result['error'])
840        else:
841            print_portlist(result['result']['entries'])
842
843    def handle_ctl_shutport(args):
844        if args.port <= 0:
845            raise ValueError('Invalid port: %d' % args.port)
846        result = request(args, 'shutPort', {
847            'port': args.port,
848            'shut': args.no,
849        })
850        if 'error' in result:
851            print_error(result['error'])
852        else:
853            print_portlist(result['result']['entries'])
854
855    def handle_ctl_delport(args):
856        if args.port <= 0:
857            raise ValueError('Invalid port: %d' % args.port)
858        result = request(args, 'delPort', {'port': args.port})
859        if 'error' in result:
860            print_error(result['error'])
861        else:
862            print_portlist(result['result']['entries'])
863
864    def handle_ctl_listport(args):
865        result = request(args, 'listPort')
866        if 'error' in result:
867            print_error(result['error'])
868        else:
869            print_portlist(result['result']['entries'])
870
871    def handle_ctl_listfdb(args):
872        result = request(args, 'listFdb')
873        if 'error' in result:
874            return print_error(result['error'])
875        result = result['result']['entries']
876        pmax = maxlen(result, 'port', 4)
877        vmax = maxlen(result, 'vid', 4)
878        mmax = maxlen(result, 'mac', 3)
879        amax = maxlen(result, 'age', 3)
880        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
881        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
882        for r in result:
883            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
884
885    locals()['handle_ctl_' + args.control_method](args)
886
887
888def main():
889    parser = argparse.ArgumentParser()
890    subcommand = parser.add_subparsers(dest='subcommand')
891
892    # - sw
893    parser_sw = subcommand.add_parser('sw')
894
895    parser_sw.add_argument('--debug', action='store_true', default=False)
896    parser_sw.add_argument('--foreground', action='store_true', default=False)
897    parser_sw.add_argument('--ageout', type=int, default=300)
898
899    parser_sw.add_argument('--path', default='/')
900    parser_sw.add_argument('--host', default='')
901    parser_sw.add_argument('--port', type=int)
902    parser_sw.add_argument('--htpasswd')
903    parser_sw.add_argument('--sslkey')
904    parser_sw.add_argument('--sslcert')
905
906    parser_sw.add_argument('--ctlpath', default='/ctl')
907    parser_sw.add_argument('--ctlhost', default='')
908    parser_sw.add_argument('--ctlport', type=int)
909    parser_sw.add_argument('--ctlhtpasswd')
910    parser_sw.add_argument('--ctlsslkey')
911    parser_sw.add_argument('--ctlsslcert')
912
913    # - ctl
914    parser_ctl = subcommand.add_parser('ctl')
915    parser_ctl.add_argument('--ctlurl', default='http://localhost/ctl')
916    parser_ctl.add_argument('--ctluser')
917    parser_ctl.add_argument('--ctlpasswd')
918
919    control_method = parser_ctl.add_subparsers(dest='control_method')
920
921    # -- ctl addport
922    parser_ctl_addport = control_method.add_parser('addport')
923    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
924
925    # --- ctl addport tap
926    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE)
927    parser_ctl_addport_tap.add_argument('target')
928
929    # --- ctl addport client
930    parser_ctl_addport_client = iftype.add_parser(EtherWebSocketClient.IFTYPE)
931    parser_ctl_addport_client.add_argument('target')
932    parser_ctl_addport_client.add_argument('--user')
933    parser_ctl_addport_client.add_argument('--passwd')
934    parser_ctl_addport_client.add_argument('--cacerts')
935    parser_ctl_addport_client.add_argument(
936        '--insecure', action='store_true', default=False)
937
938    # -- ctl shutport
939    parser_ctl_shutport = control_method.add_parser('shutport')
940    parser_ctl_shutport.add_argument('port', type=int)
941    parser_ctl_shutport.add_argument(
942        '--no', action='store_false', default=True)
943
944    # -- ctl delport
945    parser_ctl_delport = control_method.add_parser('delport')
946    parser_ctl_delport.add_argument('port', type=int)
947
948    # -- ctl listport
949    parser_ctl_listport = control_method.add_parser('listport')
950
951    # -- ctl listfdb
952    parser_ctl_listfdb = control_method.add_parser('listfdb')
953
954    # -- go
955    args = parser.parse_args()
956
957    try:
958        globals()['start_' + args.subcommand](args)
959    except Exception as e:
960        print_error({
961            'code':    0 - 32603,
962            'message': 'Internal error',
963            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
964        })
965
966
967if __name__ == '__main__':
968    main()
Note: See TracBrowser for help on using the repository browser.