source: etherws/trunk/etherws.py @ 202

Revision 202, 29.6 KB checked in by atzm, 12 years ago (diff)
  • support JSON-RPC 2.0
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def get_vid_list(self):
152        return sorted(self._table.iterkeys())
153
154    def get_mac_list(self, vid):
155        return sorted(self._table[vid].iterkeys())
156
157    def lookup(self, frame):
158        mac = frame.dst_mac
159        vid = frame.vid
160        entry = self.get_entry(vid, mac)
161        return getattr(entry, 'port', None)
162
163    def learn(self, port, frame):
164        mac = frame.src_mac
165        vid = frame.vid
166        self._set_entry(vid, mac, port)
167        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
168                     lambda: (port.number, vid, mac.encode('hex')))
169
170    def delete(self, port):
171        for vid in self.get_vid_list():
172            for mac in self.get_mac_list(vid):
173                entry = self.get_entry(vid, mac)
174                if entry and entry.port.number == port.number:
175                    self._del_entry(vid, mac)
176                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                                 lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    def get_target(self):
357        return self.request.remote_ip
358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
373class TapHandler(DebugMixIn):
374    IFTYPE = 'tap'
375    READ_SIZE = 65535
376
377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
384    def get_target(self):
385        if self.closed:
386            return self._dev
387        return self._tap.name
388
389    @property
390    def closed(self):
391        return not self._tap
392
393    def open(self):
394        if not self.closed:
395            raise ValueError('Already opened')
396        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
397        self._tap.up()
398        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
399        return self._switch.register_port(self)
400
401    def close(self):
402        if self.closed:
403            raise ValueError('I/O operation on closed tap')
404        self._switch.unregister_port(self)
405        self._ioloop.remove_handler(self.fileno())
406        self._tap.close()
407        self._tap = None
408
409    def fileno(self):
410        if self.closed:
411            raise ValueError('I/O operation on closed tap')
412        return self._tap.fileno()
413
414    def write_message(self, message, binary=False):
415        if self.closed:
416            raise ValueError('I/O operation on closed tap')
417        self._tap.write(message)
418
419    def __call__(self, fd, events):
420        try:
421            self._switch.receive(self, EthernetFrame(self._read()))
422            return
423        except:
424            traceback.print_exc()
425        self.close()
426
427    def _read(self):
428        if self.closed:
429            raise ValueError('I/O operation on closed tap')
430        buf = []
431        while True:
432            buf.append(self._tap.read(self.READ_SIZE))
433            if len(buf[-1]) < self.READ_SIZE:
434                break
435        return ''.join(buf)
436
437
438class EtherWebSocketClient(DebugMixIn):
439    IFTYPE = 'client'
440
441    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
442        self._ioloop = ioloop
443        self._switch = switch
444        self._url = url
445        self._ssl = ssl_
446        self._debug = debug
447        self._sock = None
448        self._options = {}
449
450        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
451            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
452            auth = ['Authorization: Basic %s' % token]
453            self._options['header'] = auth
454
455    def get_target(self):
456        return self._url
457
458    @property
459    def closed(self):
460        return not self._sock
461
462    def open(self):
463        sslwrap = websocket._SSLSocketWrapper
464
465        if not self.closed:
466            raise websocket.WebSocketException('Already opened')
467
468        if self._ssl:
469            websocket._SSLSocketWrapper = self._ssl
470
471        try:
472            self._sock = websocket.WebSocket()
473            self._sock.connect(self._url, **self._options)
474            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
475            return self._switch.register_port(self)
476        finally:
477            websocket._SSLSocketWrapper = sslwrap
478            self.dprintf('connected: %s\n', lambda: self._url)
479
480    def close(self):
481        if self.closed:
482            raise websocket.WebSocketException('Already closed')
483        self._switch.unregister_port(self)
484        self._ioloop.remove_handler(self.fileno())
485        self._sock.close()
486        self._sock = None
487        self.dprintf('disconnected: %s\n', lambda: self._url)
488
489    def fileno(self):
490        if self.closed:
491            raise websocket.WebSocketException('Closed socket')
492        return self._sock.io_sock.fileno()
493
494    def write_message(self, message, binary=False):
495        if self.closed:
496            raise websocket.WebSocketException('Closed socket')
497        if binary:
498            flag = websocket.ABNF.OPCODE_BINARY
499        else:
500            flag = websocket.ABNF.OPCODE_TEXT
501        self._sock.send(message, flag)
502
503    def __call__(self, fd, events):
504        try:
505            data = self._sock.recv()
506            if data is not None:
507                self._switch.receive(self, EthernetFrame(data))
508                return
509        except:
510            traceback.print_exc()
511        self.close()
512
513
514class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
515    NAMESPACE = 'etherws.control'
516    IFTYPES = {
517        TapHandler.IFTYPE:           TapHandler,
518        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
519    }
520
521    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
522        super(EtherWebSocketControlHandler, self).__init__(app, req)
523        self._ioloop = ioloop
524        self._switch = switch
525        self._htpasswd = htpasswd
526        self._debug = debug
527
528    def post(self):
529        try:
530            request = json.loads(self.request.body)
531        except Exception as e:
532            return self._jsonrpc_response(error={
533                'code':    0 - 32700,
534                'message': 'Parse error',
535                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
536            })
537
538        try:
539            id_ = request.get('id')
540            params = request.get('params')
541            version = request['jsonrpc']
542            method = request['method']
543            if version != '2.0':
544                raise ValueError('Invalid JSON-RPC version: %s' % version)
545        except Exception as e:
546            return self._jsonrpc_response(id_=id_, error={
547                'code':    0 - 32600,
548                'message': 'Invalid Request',
549                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
550            })
551
552        try:
553            if not method.startswith(self.NAMESPACE + '.'):
554                raise ValueError('Invalid method namespace: %s' % method)
555            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
556            handler = getattr(self, handler)
557        except Exception as e:
558            return self._jsonrpc_response(id_=id_, error={
559                'code':    0 - 32601,
560                'message': 'Method not found',
561                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
562            })
563
564        try:
565            return self._jsonrpc_response(id_=id_, result=handler(params))
566        except Exception as e:
567            traceback.print_exc()
568            return self._jsonrpc_response(id_=id_, error={
569                'code':    0 - 32602,
570                'message': 'Invalid params',
571                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
572            })
573
574    def handle_listFdb(self, params):
575        list_ = []
576        for vid in self._switch.fdb.get_vid_list():
577            for mac in self._switch.fdb.get_mac_list(vid):
578                entry = self._switch.fdb.get_entry(vid, mac)
579                if entry:
580                    list_.append({
581                        'vid':  vid,
582                        'mac':  EthernetFrame.format_mac(mac),
583                        'port': entry.port.number,
584                        'age':  int(entry.age),
585                    })
586        return {'entries': list_}
587
588    def handle_listPort(self, params):
589        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
590
591    def handle_addPort(self, params):
592        type_ = params['type']
593        target = params['target']
594        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
595        cls = self.IFTYPES[type_]
596        interface = cls(self._ioloop, self._switch, target, **opts)
597        portnum = interface.open()
598        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
599
600    def handle_delPort(self, params):
601        port = self._switch.get_port(int(params['port']))
602        port.interface.close()
603        return {'entries': [self._portstat(port)]}
604
605    def handle_shutPort(self, params):
606        port = self._switch.get_port(int(params['port']))
607        port.shut = bool(params['shut'])
608        return {'entries': [self._portstat(port)]}
609
610    def _optparse_tap(self, opt):
611        return {'debug': self._debug}
612
613    def _optparse_client(self, opt):
614        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
615        if opt.get('insecure'):
616            args = {}
617        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
618        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
619        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
620
621    def _jsonrpc_response(self, id_=None, result=None, error=None):
622        res = {'jsonrpc': '2.0', 'id': id_}
623        if result:
624            res['result'] = result
625        if error:
626            res['error'] = error
627        self.finish(res)
628
629    @staticmethod
630    def _portstat(port):
631        return {
632            'port':   port.number,
633            'type':   port.interface.IFTYPE,
634            'target': port.interface.get_target(),
635            'tx':     port.tx,
636            'rx':     port.rx,
637            'shut':   port.shut,
638        }
639
640
641def start_sw(args):
642    def daemonize(nochdir=False, noclose=False):
643        if os.fork() > 0:
644            sys.exit(0)
645
646        os.setsid()
647
648        if os.fork() > 0:
649            sys.exit(0)
650
651        if not nochdir:
652            os.chdir('/')
653
654        if not noclose:
655            os.umask(0)
656            sys.stdin.close()
657            sys.stdout.close()
658            sys.stderr.close()
659            os.close(0)
660            os.close(1)
661            os.close(2)
662            sys.stdin = open(os.devnull)
663            sys.stdout = open(os.devnull, 'a')
664            sys.stderr = open(os.devnull, 'a')
665
666    def checkabspath(ns, path):
667        val = getattr(ns, path, '')
668        if not val.startswith('/'):
669            raise ValueError('Invalid %: %s' % (path, val))
670
671    def getsslopt(ns, key, cert):
672        kval = getattr(ns, key, None)
673        cval = getattr(ns, cert, None)
674        if kval and cval:
675            return {'keyfile': kval, 'certfile': cval}
676        elif kval or cval:
677            raise ValueError('Both %s and %s are required' % (key, cert))
678        return None
679
680    def setrealpath(ns, *keys):
681        for k in keys:
682            v = getattr(ns, k, None)
683            if v is not None:
684                v = os.path.realpath(v)
685                open(v).close()  # check readable
686                setattr(ns, k, v)
687
688    def setport(ns, port, isssl):
689        val = getattr(ns, port, None)
690        if val is None:
691            if isssl:
692                return setattr(ns, port, 443)
693            return setattr(ns, port, 80)
694        if not (0 <= val <= 65535):
695            raise ValueError('Invalid %s: %s' % (port, val))
696
697    def sethtpasswd(ns, htpasswd):
698        val = getattr(ns, htpasswd, None)
699        if val:
700            return setattr(ns, htpasswd, Htpasswd(val))
701
702    #if args.debug:
703    #    websocket.enableTrace(True)
704
705    if args.ageout <= 0:
706        raise ValueError('Invalid ageout: %s' % args.ageout)
707
708    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
709    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
710
711    checkabspath(args, 'path')
712    checkabspath(args, 'ctlpath')
713
714    sslopt = getsslopt(args, 'sslkey', 'sslcert')
715    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
716
717    setport(args, 'port', sslopt)
718    setport(args, 'ctlport', ctlsslopt)
719
720    sethtpasswd(args, 'htpasswd')
721    sethtpasswd(args, 'ctlhtpasswd')
722
723    ioloop = IOLoop.instance()
724    fdb = FDB(ageout=args.ageout, debug=args.debug)
725    switch = SwitchingHub(fdb, debug=args.debug)
726
727    if args.port == args.ctlport and args.host == args.ctlhost:
728        if args.path == args.ctlpath:
729            raise ValueError('Same path/ctlpath on same host')
730        if args.sslkey != args.ctlsslkey:
731            raise ValueError('Different sslkey/ctlsslkey on same host')
732        if args.sslcert != args.ctlsslcert:
733            raise ValueError('Different sslcert/ctlsslcert on same host')
734
735        app = Application([
736            (args.path, EtherWebSocketHandler, {
737                'switch':   switch,
738                'htpasswd': args.htpasswd,
739                'debug':    args.debug,
740            }),
741            (args.ctlpath, EtherWebSocketControlHandler, {
742                'ioloop':   ioloop,
743                'switch':   switch,
744                'htpasswd': args.ctlhtpasswd,
745                'debug':    args.debug,
746            }),
747        ])
748        server = HTTPServer(app, ssl_options=sslopt)
749        server.listen(args.port, address=args.host)
750
751    else:
752        app = Application([(args.path, EtherWebSocketHandler, {
753            'switch':   switch,
754            'htpasswd': args.htpasswd,
755            'debug':    args.debug,
756        })])
757        server = HTTPServer(app, ssl_options=sslopt)
758        server.listen(args.port, address=args.host)
759
760        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
761            'ioloop':   ioloop,
762            'switch':   switch,
763            'htpasswd': args.ctlhtpasswd,
764            'debug':    args.debug,
765        })])
766        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
767        ctlserver.listen(args.ctlport, address=args.ctlhost)
768
769    if not args.foreground:
770        daemonize()
771
772    ioloop.start()
773
774
775def start_ctl(args):
776    def request(args, method, params=None, id_=0):
777        req = urllib2.Request(args.ctlurl)
778        req.add_header('Content-type', 'application/json')
779        if args.ctluser:
780            if not args.ctlpasswd:
781                args.ctlpasswd = getpass.getpass()
782            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
783            req.add_header('Authorization', 'Basic %s' % token)
784        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
785        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
786        if params is not None:
787            data['params'] = params
788        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
789
790    def maxlen(dict_, key, min_):
791        if not dict_:
792            return min_
793        max_ = max(len(str(r[key])) for r in dict_)
794        return min_ if max_ < min_ else max_
795
796    def print_portlist(result):
797        pmax = maxlen(result, 'port', 4)
798        ymax = maxlen(result, 'type', 4)
799        smax = maxlen(result, 'shut', 5)
800        rmax = maxlen(result, 'rx', 2)
801        tmax = maxlen(result, 'tx', 2)
802        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
803              (pmax, ymax, smax, rmax, tmax)
804        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
805        for r in result:
806            shut = 'shut' if r['shut'] else 'up'
807            print(fmt %
808                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
809
810    def print_error(error):
811        print(%s (%s)' % (error['message'], error['code']))
812        print('    %s' % error['data'])
813
814    def handle_ctl_addport(args):
815        result = request(args, 'addPort', {
816            'type':    args.type,
817            'target':  args.target,
818            'options': {
819                'insecure': args.insecure,
820                'cacerts':  args.cacerts,
821                'user':     args.user,
822                'passwd':   args.passwd,
823            },
824        })
825        if 'error' in result:
826            print_error(result['error'])
827        else:
828            print_portlist(result['result']['entries'])
829
830    def handle_ctl_shutport(args):
831        if args.port <= 0:
832            raise ValueError('Invalid port: %d' % args.port)
833        result = request(args, 'shutPort', {
834            'port': args.port,
835            'shut': args.no,
836        })
837        if 'error' in result:
838            print_error(result['error'])
839        else:
840            print_portlist(result['result']['entries'])
841
842    def handle_ctl_delport(args):
843        if args.port <= 0:
844            raise ValueError('Invalid port: %d' % args.port)
845        result = request(args, 'delPort', {'port': args.port})
846        if 'error' in result:
847            print_error(result['error'])
848        else:
849            print_portlist(result['result']['entries'])
850
851    def handle_ctl_listport(args):
852        result = request(args, 'listPort')
853        if 'error' in result:
854            print_error(result['error'])
855        else:
856            print_portlist(result['result']['entries'])
857
858    def handle_ctl_listfdb(args):
859        result = request(args, 'listFdb')
860        if 'error' in result:
861            return print_error(result['error'])
862        result = result['result']['entries']
863        pmax = maxlen(result, 'port', 4)
864        vmax = maxlen(result, 'vid', 4)
865        mmax = maxlen(result, 'mac', 3)
866        amax = maxlen(result, 'age', 3)
867        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
868        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
869        for r in result:
870            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
871
872    locals()['handle_ctl_' + args.control_method](args)
873
874
875def main():
876    parser = argparse.ArgumentParser()
877    subcommand = parser.add_subparsers(dest='subcommand')
878
879    # -- sw command parser
880    parser_s = subcommand.add_parser('sw')
881
882    parser_s.add_argument('--debug', action='store_true', default=False)
883    parser_s.add_argument('--foreground', action='store_true', default=False)
884    parser_s.add_argument('--ageout', type=int, default=300)
885
886    parser_s.add_argument('--path', default='/')
887    parser_s.add_argument('--host', default='')
888    parser_s.add_argument('--port', type=int)
889    parser_s.add_argument('--htpasswd')
890    parser_s.add_argument('--sslkey')
891    parser_s.add_argument('--sslcert')
892
893    parser_s.add_argument('--ctlpath', default='/ctl')
894    parser_s.add_argument('--ctlhost', default='')
895    parser_s.add_argument('--ctlport', type=int)
896    parser_s.add_argument('--ctlhtpasswd')
897    parser_s.add_argument('--ctlsslkey')
898    parser_s.add_argument('--ctlsslcert')
899
900    # -- ctl command parser
901    parser_c = subcommand.add_parser('ctl')
902    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
903    parser_c.add_argument('--ctluser')
904    parser_c.add_argument('--ctlpasswd')
905
906    control_method = parser_c.add_subparsers(dest='control_method')
907
908    parser_c_ap = control_method.add_parser('addport')
909    parser_c_ap.add_argument(
910        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
911    parser_c_ap.add_argument('target')
912    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
913    parser_c_ap.add_argument('--cacerts')
914    parser_c_ap.add_argument('--user')
915    parser_c_ap.add_argument('--passwd')
916
917    parser_c_sp = control_method.add_parser('shutport')
918    parser_c_sp.add_argument('port', type=int)
919    parser_c_sp.add_argument('--no', action='store_false', default=True)
920
921    parser_c_dp = control_method.add_parser('delport')
922    parser_c_dp.add_argument('port', type=int)
923
924    parser_c_lp = control_method.add_parser('listport')
925
926    parser_c_lf = control_method.add_parser('listfdb')
927
928    # -- go
929    args = parser.parse_args()
930    globals()['start_' + args.subcommand](args)
931
932
933if __name__ == '__main__':
934    main()
Note: See TracBrowser for help on using the repository browser.