source: etherws/trunk/etherws.py @ 199

Revision 199, 28.6 KB checked in by atzm, 12 years ago (diff)
  • kill depends on yaml
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def get_vid_list(self):
152        return sorted(self._table.iterkeys())
153
154    def get_mac_list(self, vid):
155        return sorted(self._table[vid].iterkeys())
156
157    def lookup(self, frame):
158        mac = frame.dst_mac
159        vid = frame.vid
160        entry = self.get_entry(vid, mac)
161        return getattr(entry, 'port', None)
162
163    def learn(self, port, frame):
164        mac = frame.src_mac
165        vid = frame.vid
166        self._set_entry(vid, mac, port)
167        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
168                     lambda: (port.number, vid, mac.encode('hex')))
169
170    def delete(self, port):
171        for vid in self.get_vid_list():
172            for mac in self.get_mac_list(vid):
173                entry = self.get_entry(vid, mac)
174                if entry and entry.port.number == port.number:
175                    self._del_entry(vid, mac)
176                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                                 lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    def get_target(self):
357        return self.request.remote_ip
358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
373class TapHandler(DebugMixIn):
374    IFTYPE = 'tap'
375    READ_SIZE = 65535
376
377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
384    def get_target(self):
385        if self.closed:
386            return self._dev
387        return self._tap.name
388
389    @property
390    def closed(self):
391        return not self._tap
392
393    def open(self):
394        if not self.closed:
395            raise ValueError('already opened')
396        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
397        self._tap.up()
398        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
399        return self._switch.register_port(self)
400
401    def close(self):
402        if self.closed:
403            raise ValueError('I/O operation on closed tap')
404        self._switch.unregister_port(self)
405        self._ioloop.remove_handler(self.fileno())
406        self._tap.close()
407        self._tap = None
408
409    def fileno(self):
410        if self.closed:
411            raise ValueError('I/O operation on closed tap')
412        return self._tap.fileno()
413
414    def write_message(self, message, binary=False):
415        if self.closed:
416            raise ValueError('I/O operation on closed tap')
417        self._tap.write(message)
418
419    def __call__(self, fd, events):
420        try:
421            self._switch.receive(self, EthernetFrame(self._read()))
422            return
423        except:
424            traceback.print_exc()
425        self.close()
426
427    def _read(self):
428        if self.closed:
429            raise ValueError('I/O operation on closed tap')
430        buf = []
431        while True:
432            buf.append(self._tap.read(self.READ_SIZE))
433            if len(buf[-1]) < self.READ_SIZE:
434                break
435        return ''.join(buf)
436
437
438class EtherWebSocketClient(DebugMixIn):
439    IFTYPE = 'client'
440
441    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
442        self._ioloop = ioloop
443        self._switch = switch
444        self._url = url
445        self._ssl = ssl_
446        self._debug = debug
447        self._sock = None
448        self._options = {}
449
450        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
451            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
452            auth = ['Authorization: Basic %s' % token]
453            self._options['header'] = auth
454
455    def get_target(self):
456        return self._url
457
458    @property
459    def closed(self):
460        return not self._sock
461
462    def open(self):
463        sslwrap = websocket._SSLSocketWrapper
464
465        if not self.closed:
466            raise websocket.WebSocketException('already opened')
467
468        if self._ssl:
469            websocket._SSLSocketWrapper = self._ssl
470
471        try:
472            self._sock = websocket.WebSocket()
473            self._sock.connect(self._url, **self._options)
474            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
475            return self._switch.register_port(self)
476        finally:
477            websocket._SSLSocketWrapper = sslwrap
478            self.dprintf('connected: %s\n', lambda: self._url)
479
480    def close(self):
481        if self.closed:
482            raise websocket.WebSocketException('already closed')
483        self._switch.unregister_port(self)
484        self._ioloop.remove_handler(self.fileno())
485        self._sock.close()
486        self._sock = None
487        self.dprintf('disconnected: %s\n', lambda: self._url)
488
489    def fileno(self):
490        if self.closed:
491            raise websocket.WebSocketException('closed socket')
492        return self._sock.io_sock.fileno()
493
494    def write_message(self, message, binary=False):
495        if self.closed:
496            raise websocket.WebSocketException('closed socket')
497        if binary:
498            flag = websocket.ABNF.OPCODE_BINARY
499        else:
500            flag = websocket.ABNF.OPCODE_TEXT
501        self._sock.send(message, flag)
502
503    def __call__(self, fd, events):
504        try:
505            data = self._sock.recv()
506            if data is not None:
507                self._switch.receive(self, EthernetFrame(data))
508                return
509        except:
510            traceback.print_exc()
511        self.close()
512
513
514class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
515    NAMESPACE = 'etherws.control'
516    IFTYPES = {
517        TapHandler.IFTYPE:           TapHandler,
518        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
519    }
520
521    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
522        super(EtherWebSocketControlHandler, self).__init__(app, req)
523        self._ioloop = ioloop
524        self._switch = switch
525        self._htpasswd = htpasswd
526        self._debug = debug
527
528    def post(self):
529        id_ = None
530
531        try:
532            req = json.loads(self.request.body)
533            method = req['method']
534            params = req['params']
535            id_ = req.get('id')
536
537            if not method.startswith(self.NAMESPACE + '.'):
538                raise ValueError('invalid method: %s' % method)
539
540            if not isinstance(params, list):
541                raise ValueError('invalid params: %s' % params)
542
543            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
544            result = getattr(self, handler)(params)
545            self.finish({'result': result, 'error': None, 'id': id_})
546
547        except Exception as e:
548            traceback.print_exc()
549            msg = '%s: %s' % (e.__class__.__name__, str(e))
550            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
551
552    def handle_listFdb(self, params):
553        list_ = []
554        for vid in self._switch.fdb.get_vid_list():
555            for mac in self._switch.fdb.get_mac_list(vid):
556                entry = self._switch.fdb.get_entry(vid, mac)
557                if entry:
558                    mac = EthernetFrame.format_mac(mac)
559                    list_.append({
560                        'vid':  vid,
561                        'mac':  mac,
562                        'port': entry.port.number,
563                        'age':  int(entry.age),
564                    })
565        return {'entries': list_}
566
567    def handle_listPort(self, params):
568        list_ = [self._portstat(p) for p in self._switch.portlist]
569        return {'entries': list_}
570
571    def handle_addPort(self, params):
572        list_ = []
573        for p in params:
574            type_ = p['type']
575            target = p['target']
576            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
577            klass = self.IFTYPES[type_]
578            interface = klass(self._ioloop, self._switch, target, **options)
579            portnum = interface.open()
580            list_.append(self._portstat(self._switch.get_port(portnum)))
581        return {'entries': list_}
582
583    def handle_delPort(self, params):
584        list_ = []
585        for p in params:
586            port = self._switch.get_port(int(p['port']))
587            list_.append(self._portstat(port))
588            port.interface.close()
589        return {'entries': list_}
590
591    def handle_shutPort(self, params):
592        list_ = []
593        for p in params:
594            port = self._switch.get_port(int(p['port']))
595            port.shut = bool(p['shut'])
596            list_.append(self._portstat(port))
597        return {'entries': list_}
598
599    def _optparse_tap(self, opt):
600        return {'debug': self._debug}
601
602    def _optparse_client(self, opt):
603        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
604        if opt.get('insecure'):
605            args = {}
606        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
607        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
608        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
609
610    @staticmethod
611    def _portstat(port):
612        return {
613            'port':   port.number,
614            'type':   port.interface.IFTYPE,
615            'target': port.interface.get_target(),
616            'tx':     port.tx,
617            'rx':     port.rx,
618            'shut':   port.shut,
619        }
620
621
622def start_sw(args):
623    def daemonize(nochdir=False, noclose=False):
624        if os.fork() > 0:
625            sys.exit(0)
626
627        os.setsid()
628
629        if os.fork() > 0:
630            sys.exit(0)
631
632        if not nochdir:
633            os.chdir('/')
634
635        if not noclose:
636            os.umask(0)
637            sys.stdin.close()
638            sys.stdout.close()
639            sys.stderr.close()
640            os.close(0)
641            os.close(1)
642            os.close(2)
643            sys.stdin = open(os.devnull)
644            sys.stdout = open(os.devnull, 'a')
645            sys.stderr = open(os.devnull, 'a')
646
647    def checkabspath(ns, path):
648        val = getattr(ns, path, '')
649        if not val.startswith('/'):
650            raise ValueError('invalid %: %s' % (path, val))
651
652    def getsslopt(ns, key, cert):
653        kval = getattr(ns, key, None)
654        cval = getattr(ns, cert, None)
655        if kval and cval:
656            return {'keyfile': kval, 'certfile': cval}
657        elif kval or cval:
658            raise ValueError('both %s and %s are required' % (key, cert))
659        return None
660
661    def setrealpath(ns, *keys):
662        for k in keys:
663            v = getattr(ns, k, None)
664            if v is not None:
665                v = os.path.realpath(v)
666                open(v).close()  # check readable
667                setattr(ns, k, v)
668
669    def setport(ns, port, isssl):
670        val = getattr(ns, port, None)
671        if val is None:
672            if isssl:
673                return setattr(ns, port, 443)
674            return setattr(ns, port, 80)
675        if not (0 <= val <= 65535):
676            raise ValueError('invalid %s: %s' % (port, val))
677
678    def sethtpasswd(ns, htpasswd):
679        val = getattr(ns, htpasswd, None)
680        if val:
681            return setattr(ns, htpasswd, Htpasswd(val))
682
683    #if args.debug:
684    #    websocket.enableTrace(True)
685
686    if args.ageout <= 0:
687        raise ValueError('invalid ageout: %s' % args.ageout)
688
689    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
690    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
691
692    checkabspath(args, 'path')
693    checkabspath(args, 'ctlpath')
694
695    sslopt = getsslopt(args, 'sslkey', 'sslcert')
696    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
697
698    setport(args, 'port', sslopt)
699    setport(args, 'ctlport', ctlsslopt)
700
701    sethtpasswd(args, 'htpasswd')
702    sethtpasswd(args, 'ctlhtpasswd')
703
704    ioloop = IOLoop.instance()
705    fdb = FDB(ageout=args.ageout, debug=args.debug)
706    switch = SwitchingHub(fdb, debug=args.debug)
707
708    if args.port == args.ctlport and args.host == args.ctlhost:
709        if args.path == args.ctlpath:
710            raise ValueError('same path/ctlpath on same host')
711        if args.sslkey != args.ctlsslkey:
712            raise ValueError('different sslkey/ctlsslkey on same host')
713        if args.sslcert != args.ctlsslcert:
714            raise ValueError('different sslcert/ctlsslcert on same host')
715
716        app = Application([
717            (args.path, EtherWebSocketHandler, {
718                'switch':   switch,
719                'htpasswd': args.htpasswd,
720                'debug':    args.debug,
721            }),
722            (args.ctlpath, EtherWebSocketControlHandler, {
723                'ioloop':   ioloop,
724                'switch':   switch,
725                'htpasswd': args.ctlhtpasswd,
726                'debug':    args.debug,
727            }),
728        ])
729        server = HTTPServer(app, ssl_options=sslopt)
730        server.listen(args.port, address=args.host)
731
732    else:
733        app = Application([(args.path, EtherWebSocketHandler, {
734            'switch':   switch,
735            'htpasswd': args.htpasswd,
736            'debug':    args.debug,
737        })])
738        server = HTTPServer(app, ssl_options=sslopt)
739        server.listen(args.port, address=args.host)
740
741        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
742            'ioloop':   ioloop,
743            'switch':   switch,
744            'htpasswd': args.ctlhtpasswd,
745            'debug':    args.debug,
746        })])
747        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
748        ctlserver.listen(args.ctlport, address=args.ctlhost)
749
750    if not args.foreground:
751        daemonize()
752
753    ioloop.start()
754
755
756def start_ctl(args):
757    def request(args, method, params):
758        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
759        data = json.dumps({'method': method, 'params': params})
760        req = urllib2.Request(args.ctlurl)
761        req.add_header('Content-type', 'application/json')
762        if args.ctluser:
763            if not args.ctlpasswd:
764                args.ctlpasswd = getpass.getpass()
765            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
766            req.add_header('Authorization', 'Basic %s' % token)
767        return json.loads(urllib2.urlopen(req, data).read())
768
769    def maxlen(dict_, key, min_):
770        max_ = max(len(str(r[key])) for r in dict_)
771        return min_ if max_ < min_ else max_
772
773    def print_portlist(result):
774        pmax = maxlen(result, 'port', 4)
775        ymax = maxlen(result, 'type', 4)
776        smax = maxlen(result, 'shut', 5)
777        rmax = maxlen(result, 'rx', 2)
778        tmax = maxlen(result, 'tx', 2)
779        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
780              (pmax, ymax, smax, rmax, tmax)
781        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
782        for r in result:
783            shut = 'shut' if r['shut'] else 'up'
784            print(fmt %
785                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
786
787    def handle_ctl_addport(args):
788        params = [{
789            'type':    args.type,
790            'target':  args.target,
791            'options': {
792                'insecure': args.insecure,
793                'cacerts':  args.cacerts,
794                'user':     args.user,
795                'passwd':   args.passwd,
796            }
797        }]
798        result = request(args, 'addPort', params)
799        if result['error']:
800            print(result['error']['message'])
801        else:
802            print_portlist(result['result']['entries'])
803
804    def handle_ctl_shutport(args):
805        if args.port <= 0:
806            raise ValueError('invalid port: %d' % args.port)
807        params = [{'port': args.port, 'shut': args.no}]
808        result = request(args, 'shutPort', params)
809        if result['error']:
810            print(result['error']['message'])
811        else:
812            print_portlist(result['result']['entries'])
813
814    def handle_ctl_delport(args):
815        if args.port <= 0:
816            raise ValueError('invalid port: %d' % args.port)
817        params = [{'port': args.port}]
818        result = request(args, 'delPort', params)
819        if result['error']:
820            print(result['error']['message'])
821        else:
822            print_portlist(result['result']['entries'])
823
824    def handle_ctl_listport(args):
825        result = request(args, 'listPort', [])
826        if result['error']:
827            print(result['error']['message'])
828        else:
829            print_portlist(result['result']['entries'])
830
831    def handle_ctl_listfdb(args):
832        result = request(args, 'listFdb', [])
833        if result['error']:
834            print(result['error']['message'])
835            return
836        result = result['result']['entries']
837        vmax = maxlen(result, 'vid', 4)
838        mmax = maxlen(result, 'mac', 3)
839        pmax = maxlen(result, 'port', 4)
840        amax = maxlen(result, 'age', 3)
841        fmt = %%%d%%%d%%%d%%%ds' % (vmax, mmax, pmax, amax)
842        print(fmt % ('VLAN', 'MAC', 'Port', 'Age'))
843        for r in result:
844            print(fmt % (r['vid'], r['mac'], r['port'], r['age']))
845
846    locals()['handle_ctl_' + args.control_method](args)
847
848
849def main():
850    parser = argparse.ArgumentParser()
851    subcommand = parser.add_subparsers(dest='subcommand')
852
853    # -- sw command parser
854    parser_s = subcommand.add_parser('sw')
855
856    parser_s.add_argument('--debug', action='store_true', default=False)
857    parser_s.add_argument('--foreground', action='store_true', default=False)
858    parser_s.add_argument('--ageout', type=int, default=300)
859
860    parser_s.add_argument('--path', default='/')
861    parser_s.add_argument('--host', default='')
862    parser_s.add_argument('--port', type=int)
863    parser_s.add_argument('--htpasswd')
864    parser_s.add_argument('--sslkey')
865    parser_s.add_argument('--sslcert')
866
867    parser_s.add_argument('--ctlpath', default='/ctl')
868    parser_s.add_argument('--ctlhost', default='')
869    parser_s.add_argument('--ctlport', type=int)
870    parser_s.add_argument('--ctlhtpasswd')
871    parser_s.add_argument('--ctlsslkey')
872    parser_s.add_argument('--ctlsslcert')
873
874    # -- ctl command parser
875    parser_c = subcommand.add_parser('ctl')
876    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
877    parser_c.add_argument('--ctluser')
878    parser_c.add_argument('--ctlpasswd')
879
880    control_method = parser_c.add_subparsers(dest='control_method')
881
882    parser_c_ap = control_method.add_parser('addport')
883    parser_c_ap.add_argument(
884        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
885    parser_c_ap.add_argument('target')
886    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
887    parser_c_ap.add_argument('--cacerts')
888    parser_c_ap.add_argument('--user')
889    parser_c_ap.add_argument('--passwd')
890
891    parser_c_sp = control_method.add_parser('shutport')
892    parser_c_sp.add_argument('port', type=int)
893    parser_c_sp.add_argument('--no', action='store_false', default=True)
894
895    parser_c_dp = control_method.add_parser('delport')
896    parser_c_dp.add_argument('port', type=int)
897
898    parser_c_lp = control_method.add_parser('listport')
899
900    parser_c_lf = control_method.add_parser('listfdb')
901
902    # -- go
903    args = parser.parse_args()
904    globals()['start_' + args.subcommand](args)
905
906
907if __name__ == '__main__':
908    main()
Note: See TracBrowser for help on using the repository browser.