source: etherws/trunk/etherws.py @ 207

Revision 207, 30.7 KB checked in by atzm, 12 years ago (diff)
  • refactoring
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def _get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def foreach(self):
152        for vid in sorted(self._table.iterkeys()):
153            for mac in sorted(self._table[vid].iterkeys()):
154                entry = self._get_entry(vid, mac)
155                if entry:
156                    yield (vid, mac, entry)
157
158    def lookup(self, frame):
159        mac = frame.dst_mac
160        vid = frame.vid
161        entry = self._get_entry(vid, mac)
162        return getattr(entry, 'port', None)
163
164    def learn(self, port, frame):
165        mac = frame.src_mac
166        vid = frame.vid
167        self._set_entry(vid, mac, port)
168        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
169                     lambda: (port.number, vid, mac.encode('hex')))
170
171    def delete(self, port):
172        for vid, mac, entry in self.foreach():
173            if entry.port.number == port.number:
174                self._del_entry(vid, mac)
175                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
176                             lambda: (port.number, vid, mac.encode('hex')))
177
178
179class SwitchingHub(DebugMixIn):
180    class Port(object):
181        def __init__(self, number, interface):
182            self.number = number
183            self.interface = interface
184            self.tx = 0
185            self.rx = 0
186            self.shut = False
187
188        @staticmethod
189        def cmp_by_number(x, y):
190            return cmp(x.number, y.number)
191
192    def __init__(self, fdb, debug=False):
193        self.fdb = fdb
194        self._debug = debug
195        self._table = {}
196        self._next = 1
197
198    @property
199    def portlist(self):
200        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
201
202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
206        try:
207            self._set_privattr('portnum', interface, self._next)  # XXX
208            self._table[self._next] = self.Port(self._next, interface)
209            return self._next
210        finally:
211            self._next += 1
212
213    def unregister_port(self, interface):
214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
216        self.fdb.delete(self._table[portnum])
217        del self._table[portnum]
218
219    def send(self, dst_interfaces, frame):
220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
223        ports = sorted(ports, cmp=self.Port.cmp_by_number)
224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
237        port = self._table[self._get_privattr('portnum', src_interface)]
238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
244        try:
245            if not frame.src_multicast:
246                self.fdb.learn(src_port, frame)
247
248            if not frame.dst_multicast:
249                dst_port = self.fdb.lookup(frame)
250
251                if dst_port:
252                    self.send([dst_port.interface], frame)
253                    return
254
255            ports = set(self.portlist) - set([src_port])
256            self.send((p.interface for p in ports), frame)
257
258        except:  # ex. received invalid frame
259            traceback.print_exc()
260
261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
263
264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
274class Htpasswd(object):
275    def __init__(self, path):
276        self._path = path
277        self._stat = None
278        self._data = {}
279
280    def auth(self, name, passwd):
281        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
282        return self._data.get(name) == passwd
283
284    def load(self):
285        old_stat = self._stat
286
287        with open(self._path) as fp:
288            fileno = fp.fileno()
289            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
290            self._stat = os.fstat(fileno)
291
292            unchanged = old_stat and \
293                        old_stat.st_ino == self._stat.st_ino and \
294                        old_stat.st_dev == self._stat.st_dev and \
295                        old_stat.st_mtime == self._stat.st_mtime
296
297            if not unchanged:
298                self._data = self._parse(fp)
299
300        return self
301
302    def _parse(self, fp):
303        data = {}
304        for line in fp:
305            line = line.strip()
306            if 0 <= line.find(':'):
307                name, passwd = line.split(':', 1)
308                if passwd.startswith('{SHA}'):
309                    data[name] = passwd[5:]
310        return data
311
312
313class BasicAuthMixIn(object):
314    def _execute(self, transforms, *args, **kwargs):
315        def do_execute():
316            sp = super(BasicAuthMixIn, self)
317            return sp._execute(transforms, *args, **kwargs)
318
319        def auth_required():
320            stream = getattr(self, 'stream', self.request.connection.stream)
321            stream.write(tornado.escape.utf8(
322                'HTTP/1.1 401 Authorization Required\r\n'
323                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
324            ))
325            stream.close()
326
327        try:
328            if not self._htpasswd:
329                return do_execute()
330
331            creds = self.request.headers.get('Authorization')
332
333            if not creds or not creds.startswith('Basic '):
334                return auth_required()
335
336            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
337
338            if self._htpasswd.load().auth(name, passwd):
339                return do_execute()
340        except:
341            traceback.print_exc()
342
343        return auth_required()
344
345
346class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
347    IFTYPE = 'server'
348
349    def __init__(self, app, req, switch, htpasswd=None, debug=False):
350        super(EtherWebSocketHandler, self).__init__(app, req)
351        self._switch = switch
352        self._htpasswd = htpasswd
353        self._debug = debug
354
355    @property
356    def target(self):
357        return ':'.join(str(e) for e in self.request.connection.address)
358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
373class TapHandler(DebugMixIn):
374    IFTYPE = 'tap'
375    READ_SIZE = 65535
376
377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
384    @property
385    def target(self):
386        if self.closed:
387            return self._dev
388        return self._tap.name
389
390    @property
391    def closed(self):
392        return not self._tap
393
394    def open(self):
395        if not self.closed:
396            raise ValueError('Already opened')
397        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
398        self._tap.up()
399        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
400        return self._switch.register_port(self)
401
402    def close(self):
403        if self.closed:
404            raise ValueError('I/O operation on closed tap')
405        self._switch.unregister_port(self)
406        self._ioloop.remove_handler(self.fileno())
407        self._tap.close()
408        self._tap = None
409
410    def fileno(self):
411        if self.closed:
412            raise ValueError('I/O operation on closed tap')
413        return self._tap.fileno()
414
415    def write_message(self, message, binary=False):
416        if self.closed:
417            raise ValueError('I/O operation on closed tap')
418        self._tap.write(message)
419
420    def __call__(self, fd, events):
421        try:
422            self._switch.receive(self, EthernetFrame(self._read()))
423            return
424        except:
425            traceback.print_exc()
426        self.close()
427
428    def _read(self):
429        if self.closed:
430            raise ValueError('I/O operation on closed tap')
431        buf = []
432        while True:
433            buf.append(self._tap.read(self.READ_SIZE))
434            if len(buf[-1]) < self.READ_SIZE:
435                break
436        return ''.join(buf)
437
438
439class EtherWebSocketClient(DebugMixIn):
440    IFTYPE = 'client'
441
442    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
443        self._ioloop = ioloop
444        self._switch = switch
445        self._url = url
446        self._ssl = ssl_
447        self._debug = debug
448        self._sock = None
449        self._options = {}
450
451        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
452            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
453            auth = ['Authorization: Basic %s' % token]
454            self._options['header'] = auth
455
456    @property
457    def target(self):
458        return self._url
459
460    @property
461    def closed(self):
462        return not self._sock
463
464    def open(self):
465        sslwrap = websocket._SSLSocketWrapper
466
467        if not self.closed:
468            raise websocket.WebSocketException('Already opened')
469
470        if self._ssl:
471            websocket._SSLSocketWrapper = self._ssl
472
473        try:
474            self._sock = websocket.WebSocket()
475            self._sock.connect(self._url, **self._options)
476            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
477            return self._switch.register_port(self)
478        finally:
479            websocket._SSLSocketWrapper = sslwrap
480            self.dprintf('connected: %s\n', lambda: self._url)
481
482    def close(self):
483        if self.closed:
484            raise websocket.WebSocketException('Already closed')
485        self._switch.unregister_port(self)
486        self._ioloop.remove_handler(self.fileno())
487        self._sock.close()
488        self._sock = None
489        self.dprintf('disconnected: %s\n', lambda: self._url)
490
491    def fileno(self):
492        if self.closed:
493            raise websocket.WebSocketException('Closed socket')
494        return self._sock.io_sock.fileno()
495
496    def write_message(self, message, binary=False):
497        if self.closed:
498            raise websocket.WebSocketException('Closed socket')
499        if binary:
500            flag = websocket.ABNF.OPCODE_BINARY
501        else:
502            flag = websocket.ABNF.OPCODE_TEXT
503        self._sock.send(message, flag)
504
505    def __call__(self, fd, events):
506        try:
507            data = self._sock.recv()
508            if data is not None:
509                self._switch.receive(self, EthernetFrame(data))
510                return
511        except:
512            traceback.print_exc()
513        self.close()
514
515
516class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
517    NAMESPACE = 'etherws.control'
518    IFTYPES = {
519        TapHandler.IFTYPE:           TapHandler,
520        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
521    }
522
523    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
524        super(EtherWebSocketControlHandler, self).__init__(app, req)
525        self._ioloop = ioloop
526        self._switch = switch
527        self._htpasswd = htpasswd
528        self._debug = debug
529
530    def post(self):
531        try:
532            request = json.loads(self.request.body)
533        except Exception as e:
534            return self._jsonrpc_response(error={
535                'code':    0 - 32700,
536                'message': 'Parse error',
537                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
538            })
539
540        try:
541            id_ = request.get('id')
542            params = request.get('params')
543            version = request['jsonrpc']
544            method = request['method']
545            if version != '2.0':
546                raise ValueError('Invalid JSON-RPC version: %s' % version)
547        except Exception as e:
548            return self._jsonrpc_response(id_=id_, error={
549                'code':    0 - 32600,
550                'message': 'Invalid Request',
551                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
552            })
553
554        try:
555            if not method.startswith(self.NAMESPACE + '.'):
556                raise ValueError('Invalid method namespace: %s' % method)
557            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
558            handler = getattr(self, handler)
559        except Exception as e:
560            return self._jsonrpc_response(id_=id_, error={
561                'code':    0 - 32601,
562                'message': 'Method not found',
563                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
564            })
565
566        try:
567            return self._jsonrpc_response(id_=id_, result=handler(params))
568        except Exception as e:
569            traceback.print_exc()
570            return self._jsonrpc_response(id_=id_, error={
571                'code':    0 - 32602,
572                'message': 'Invalid params',
573                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
574            })
575
576    def handle_listFdb(self, params):
577        list_ = []
578        for vid, mac, entry in self._switch.fdb.foreach():
579            list_.append({
580                'vid':  vid,
581                'mac':  EthernetFrame.format_mac(mac),
582                'port': entry.port.number,
583                'age':  int(entry.age),
584            })
585        return {'entries': list_}
586
587    def handle_listPort(self, params):
588        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
589
590    def handle_addPort(self, params):
591        type_ = params['type']
592        target = params['target']
593        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
594        cls = self.IFTYPES[type_]
595        interface = cls(self._ioloop, self._switch, target, **opts)
596        portnum = interface.open()
597        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
598
599    def handle_delPort(self, params):
600        port = self._switch.get_port(int(params['port']))
601        port.interface.close()
602        return {'entries': [self._portstat(port)]}
603
604    def handle_shutPort(self, params):
605        port = self._switch.get_port(int(params['port']))
606        port.shut = bool(params['shut'])
607        return {'entries': [self._portstat(port)]}
608
609    def _optparse_tap(self, opt):
610        return {'debug': self._debug}
611
612    def _optparse_client(self, opt):
613        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
614        if opt.get('insecure'):
615            args = {}
616        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
617        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
618        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
619
620    def _jsonrpc_response(self, id_=None, result=None, error=None):
621        res = {'jsonrpc': '2.0', 'id': id_}
622        if result:
623            res['result'] = result
624        if error:
625            res['error'] = error
626        self.finish(res)
627
628    @staticmethod
629    def _portstat(port):
630        return {
631            'port':   port.number,
632            'type':   port.interface.IFTYPE,
633            'target': port.interface.target,
634            'tx':     port.tx,
635            'rx':     port.rx,
636            'shut':   port.shut,
637        }
638
639
640def _print_error(error):
641    print(%s (%s)' % (error['message'], error['code']))
642    print('    %s' % error['data'])
643
644
645def _start_sw(args):
646    def daemonize(nochdir=False, noclose=False):
647        if os.fork() > 0:
648            sys.exit(0)
649
650        os.setsid()
651
652        if os.fork() > 0:
653            sys.exit(0)
654
655        if not nochdir:
656            os.chdir('/')
657
658        if not noclose:
659            os.umask(0)
660            sys.stdin.close()
661            sys.stdout.close()
662            sys.stderr.close()
663            os.close(0)
664            os.close(1)
665            os.close(2)
666            sys.stdin = open(os.devnull)
667            sys.stdout = open(os.devnull, 'a')
668            sys.stderr = open(os.devnull, 'a')
669
670    def checkabspath(ns, path):
671        val = getattr(ns, path, '')
672        if not val.startswith('/'):
673            raise ValueError('Invalid %: %s' % (path, val))
674
675    def getsslopt(ns, key, cert):
676        kval = getattr(ns, key, None)
677        cval = getattr(ns, cert, None)
678        if kval and cval:
679            return {'keyfile': kval, 'certfile': cval}
680        elif kval or cval:
681            raise ValueError('Both %s and %s are required' % (key, cert))
682        return None
683
684    def setrealpath(ns, *keys):
685        for k in keys:
686            v = getattr(ns, k, None)
687            if v is not None:
688                v = os.path.realpath(v)
689                open(v).close()  # check readable
690                setattr(ns, k, v)
691
692    def setport(ns, port, isssl):
693        val = getattr(ns, port, None)
694        if val is None:
695            if isssl:
696                return setattr(ns, port, 443)
697            return setattr(ns, port, 80)
698        if not (0 <= val <= 65535):
699            raise ValueError('Invalid %s: %s' % (port, val))
700
701    def sethtpasswd(ns, htpasswd):
702        val = getattr(ns, htpasswd, None)
703        if val:
704            return setattr(ns, htpasswd, Htpasswd(val))
705
706    #if args.debug:
707    #    websocket.enableTrace(True)
708
709    if args.ageout <= 0:
710        raise ValueError('Invalid ageout: %s' % args.ageout)
711
712    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
713    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
714
715    checkabspath(args, 'path')
716    checkabspath(args, 'ctlpath')
717
718    sslopt = getsslopt(args, 'sslkey', 'sslcert')
719    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
720
721    setport(args, 'port', sslopt)
722    setport(args, 'ctlport', ctlsslopt)
723
724    sethtpasswd(args, 'htpasswd')
725    sethtpasswd(args, 'ctlhtpasswd')
726
727    ioloop = IOLoop.instance()
728    fdb = FDB(ageout=args.ageout, debug=args.debug)
729    switch = SwitchingHub(fdb, debug=args.debug)
730
731    if args.port == args.ctlport and args.host == args.ctlhost:
732        if args.path == args.ctlpath:
733            raise ValueError('Same path/ctlpath on same host')
734        if args.sslkey != args.ctlsslkey:
735            raise ValueError('Different sslkey/ctlsslkey on same host')
736        if args.sslcert != args.ctlsslcert:
737            raise ValueError('Different sslcert/ctlsslcert on same host')
738
739        app = Application([
740            (args.path, EtherWebSocketHandler, {
741                'switch':   switch,
742                'htpasswd': args.htpasswd,
743                'debug':    args.debug,
744            }),
745            (args.ctlpath, EtherWebSocketControlHandler, {
746                'ioloop':   ioloop,
747                'switch':   switch,
748                'htpasswd': args.ctlhtpasswd,
749                'debug':    args.debug,
750            }),
751        ])
752        server = HTTPServer(app, ssl_options=sslopt)
753        server.listen(args.port, address=args.host)
754
755    else:
756        app = Application([(args.path, EtherWebSocketHandler, {
757            'switch':   switch,
758            'htpasswd': args.htpasswd,
759            'debug':    args.debug,
760        })])
761        server = HTTPServer(app, ssl_options=sslopt)
762        server.listen(args.port, address=args.host)
763
764        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
765            'ioloop':   ioloop,
766            'switch':   switch,
767            'htpasswd': args.ctlhtpasswd,
768            'debug':    args.debug,
769        })])
770        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
771        ctlserver.listen(args.ctlport, address=args.ctlhost)
772
773    if not args.foreground:
774        daemonize()
775
776    ioloop.start()
777
778
779def _start_ctl(args):
780    def request(args, method, params=None, id_=0):
781        req = urllib2.Request(args.ctlurl)
782        req.add_header('Content-type', 'application/json')
783        if args.ctluser:
784            if not args.ctlpasswd:
785                args.ctlpasswd = getpass.getpass()
786            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
787            req.add_header('Authorization', 'Basic %s' % token)
788        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
789        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
790        if params is not None:
791            data['params'] = params
792        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
793
794    def maxlen(dict_, key, min_):
795        if not dict_:
796            return min_
797        max_ = max(len(str(r[key])) for r in dict_)
798        return min_ if max_ < min_ else max_
799
800    def print_portlist(result):
801        pmax = maxlen(result, 'port', 4)
802        ymax = maxlen(result, 'type', 4)
803        smax = maxlen(result, 'shut', 5)
804        rmax = maxlen(result, 'rx', 2)
805        tmax = maxlen(result, 'tx', 2)
806        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
807              (pmax, ymax, smax, rmax, tmax)
808        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
809        for r in result:
810            shut = 'shut' if r['shut'] else 'up'
811            print(fmt %
812                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
813
814    def handle_ctl_addport(args):
815        opts = {
816            'user':     getattr(args, 'user', None),
817            'passwd':   getattr(args, 'passwd', None),
818            'cacerts':  getattr(args, 'cacerts', None),
819            'insecure': getattr(args, 'insecure', None),
820        }
821        if args.iftype == EtherWebSocketClient.IFTYPE:
822            if not args.target.startswith('ws://') and \
823               not args.target.startswith('wss://'):
824                raise ValueError('Invalid target URL scheme: %s' % args.target)
825        if opts['user'] and not opts['passwd']:
826            raise ValueError('Authentication required but password empty')
827        if not opts['user'] and opts['passwd']:
828            raise ValueError('Authentication required but username empty')
829        result = request(args, 'addPort', {
830            'type':    args.iftype,
831            'target':  args.target,
832            'options': opts,
833        })
834        if 'error' in result:
835            _print_error(result['error'])
836        else:
837            print_portlist(result['result']['entries'])
838
839    def handle_ctl_shutport(args):
840        if args.port <= 0:
841            raise ValueError('Invalid port: %d' % args.port)
842        result = request(args, 'shutPort', {
843            'port': args.port,
844            'shut': args.no,
845        })
846        if 'error' in result:
847            _print_error(result['error'])
848        else:
849            print_portlist(result['result']['entries'])
850
851    def handle_ctl_delport(args):
852        if args.port <= 0:
853            raise ValueError('Invalid port: %d' % args.port)
854        result = request(args, 'delPort', {'port': args.port})
855        if 'error' in result:
856            _print_error(result['error'])
857        else:
858            print_portlist(result['result']['entries'])
859
860    def handle_ctl_listport(args):
861        result = request(args, 'listPort')
862        if 'error' in result:
863            _print_error(result['error'])
864        else:
865            print_portlist(result['result']['entries'])
866
867    def handle_ctl_listfdb(args):
868        result = request(args, 'listFdb')
869        if 'error' in result:
870            return _print_error(result['error'])
871        result = result['result']['entries']
872        pmax = maxlen(result, 'port', 4)
873        vmax = maxlen(result, 'vid', 4)
874        mmax = maxlen(result, 'mac', 3)
875        amax = maxlen(result, 'age', 3)
876        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
877        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
878        for r in result:
879            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
880
881    locals()['handle_ctl_' + args.control_method](args)
882
883
884def _main():
885    parser = argparse.ArgumentParser()
886    subcommand = parser.add_subparsers(dest='subcommand')
887
888    # - sw
889    parser_sw = subcommand.add_parser('sw')
890
891    parser_sw.add_argument('--debug', action='store_true', default=False)
892    parser_sw.add_argument('--foreground', action='store_true', default=False)
893    parser_sw.add_argument('--ageout', type=int, default=300)
894
895    parser_sw.add_argument('--path', default='/')
896    parser_sw.add_argument('--host', default='')
897    parser_sw.add_argument('--port', type=int)
898    parser_sw.add_argument('--htpasswd')
899    parser_sw.add_argument('--sslkey')
900    parser_sw.add_argument('--sslcert')
901
902    parser_sw.add_argument('--ctlpath', default='/ctl')
903    parser_sw.add_argument('--ctlhost', default='')
904    parser_sw.add_argument('--ctlport', type=int)
905    parser_sw.add_argument('--ctlhtpasswd')
906    parser_sw.add_argument('--ctlsslkey')
907    parser_sw.add_argument('--ctlsslcert')
908
909    # - ctl
910    parser_ctl = subcommand.add_parser('ctl')
911    parser_ctl.add_argument('--ctlurl', default='http://localhost/ctl')
912    parser_ctl.add_argument('--ctluser')
913    parser_ctl.add_argument('--ctlpasswd')
914
915    control_method = parser_ctl.add_subparsers(dest='control_method')
916
917    # -- ctl addport
918    parser_ctl_addport = control_method.add_parser('addport')
919    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
920
921    # --- ctl addport tap
922    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE)
923    parser_ctl_addport_tap.add_argument('target')
924
925    # --- ctl addport client
926    parser_ctl_addport_client = iftype.add_parser(EtherWebSocketClient.IFTYPE)
927    parser_ctl_addport_client.add_argument('target')
928    parser_ctl_addport_client.add_argument('--user')
929    parser_ctl_addport_client.add_argument('--passwd')
930    parser_ctl_addport_client.add_argument('--cacerts')
931    parser_ctl_addport_client.add_argument(
932        '--insecure', action='store_true', default=False)
933
934    # -- ctl shutport
935    parser_ctl_shutport = control_method.add_parser('shutport')
936    parser_ctl_shutport.add_argument('port', type=int)
937    parser_ctl_shutport.add_argument(
938        '--no', action='store_false', default=True)
939
940    # -- ctl delport
941    parser_ctl_delport = control_method.add_parser('delport')
942    parser_ctl_delport.add_argument('port', type=int)
943
944    # -- ctl listport
945    parser_ctl_listport = control_method.add_parser('listport')
946
947    # -- ctl listfdb
948    parser_ctl_listfdb = control_method.add_parser('listfdb')
949
950    # -- go
951    args = parser.parse_args()
952
953    try:
954        globals()['_start_' + args.subcommand](args)
955    except Exception as e:
956        _print_error({
957            'code':    0 - 32603,
958            'message': 'Internal error',
959            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
960        })
961
962
963if __name__ == '__main__':
964    _main()
Note: See TracBrowser for help on using the repository browser.