source: etherws/trunk/etherws.py @ 201

Revision 201, 28.6 KB checked in by atzm, 12 years ago (diff)
  • fixed a bug, raise exception when result is empty
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def get_vid_list(self):
152        return sorted(self._table.iterkeys())
153
154    def get_mac_list(self, vid):
155        return sorted(self._table[vid].iterkeys())
156
157    def lookup(self, frame):
158        mac = frame.dst_mac
159        vid = frame.vid
160        entry = self.get_entry(vid, mac)
161        return getattr(entry, 'port', None)
162
163    def learn(self, port, frame):
164        mac = frame.src_mac
165        vid = frame.vid
166        self._set_entry(vid, mac, port)
167        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
168                     lambda: (port.number, vid, mac.encode('hex')))
169
170    def delete(self, port):
171        for vid in self.get_vid_list():
172            for mac in self.get_mac_list(vid):
173                entry = self.get_entry(vid, mac)
174                if entry and entry.port.number == port.number:
175                    self._del_entry(vid, mac)
176                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                                 lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    def get_target(self):
357        return self.request.remote_ip
358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
373class TapHandler(DebugMixIn):
374    IFTYPE = 'tap'
375    READ_SIZE = 65535
376
377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
384    def get_target(self):
385        if self.closed:
386            return self._dev
387        return self._tap.name
388
389    @property
390    def closed(self):
391        return not self._tap
392
393    def open(self):
394        if not self.closed:
395            raise ValueError('already opened')
396        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
397        self._tap.up()
398        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
399        return self._switch.register_port(self)
400
401    def close(self):
402        if self.closed:
403            raise ValueError('I/O operation on closed tap')
404        self._switch.unregister_port(self)
405        self._ioloop.remove_handler(self.fileno())
406        self._tap.close()
407        self._tap = None
408
409    def fileno(self):
410        if self.closed:
411            raise ValueError('I/O operation on closed tap')
412        return self._tap.fileno()
413
414    def write_message(self, message, binary=False):
415        if self.closed:
416            raise ValueError('I/O operation on closed tap')
417        self._tap.write(message)
418
419    def __call__(self, fd, events):
420        try:
421            self._switch.receive(self, EthernetFrame(self._read()))
422            return
423        except:
424            traceback.print_exc()
425        self.close()
426
427    def _read(self):
428        if self.closed:
429            raise ValueError('I/O operation on closed tap')
430        buf = []
431        while True:
432            buf.append(self._tap.read(self.READ_SIZE))
433            if len(buf[-1]) < self.READ_SIZE:
434                break
435        return ''.join(buf)
436
437
438class EtherWebSocketClient(DebugMixIn):
439    IFTYPE = 'client'
440
441    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
442        self._ioloop = ioloop
443        self._switch = switch
444        self._url = url
445        self._ssl = ssl_
446        self._debug = debug
447        self._sock = None
448        self._options = {}
449
450        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
451            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
452            auth = ['Authorization: Basic %s' % token]
453            self._options['header'] = auth
454
455    def get_target(self):
456        return self._url
457
458    @property
459    def closed(self):
460        return not self._sock
461
462    def open(self):
463        sslwrap = websocket._SSLSocketWrapper
464
465        if not self.closed:
466            raise websocket.WebSocketException('already opened')
467
468        if self._ssl:
469            websocket._SSLSocketWrapper = self._ssl
470
471        try:
472            self._sock = websocket.WebSocket()
473            self._sock.connect(self._url, **self._options)
474            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
475            return self._switch.register_port(self)
476        finally:
477            websocket._SSLSocketWrapper = sslwrap
478            self.dprintf('connected: %s\n', lambda: self._url)
479
480    def close(self):
481        if self.closed:
482            raise websocket.WebSocketException('already closed')
483        self._switch.unregister_port(self)
484        self._ioloop.remove_handler(self.fileno())
485        self._sock.close()
486        self._sock = None
487        self.dprintf('disconnected: %s\n', lambda: self._url)
488
489    def fileno(self):
490        if self.closed:
491            raise websocket.WebSocketException('closed socket')
492        return self._sock.io_sock.fileno()
493
494    def write_message(self, message, binary=False):
495        if self.closed:
496            raise websocket.WebSocketException('closed socket')
497        if binary:
498            flag = websocket.ABNF.OPCODE_BINARY
499        else:
500            flag = websocket.ABNF.OPCODE_TEXT
501        self._sock.send(message, flag)
502
503    def __call__(self, fd, events):
504        try:
505            data = self._sock.recv()
506            if data is not None:
507                self._switch.receive(self, EthernetFrame(data))
508                return
509        except:
510            traceback.print_exc()
511        self.close()
512
513
514class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
515    NAMESPACE = 'etherws.control'
516    IFTYPES = {
517        TapHandler.IFTYPE:           TapHandler,
518        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
519    }
520
521    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
522        super(EtherWebSocketControlHandler, self).__init__(app, req)
523        self._ioloop = ioloop
524        self._switch = switch
525        self._htpasswd = htpasswd
526        self._debug = debug
527
528    def post(self):
529        id_ = None
530
531        try:
532            req = json.loads(self.request.body)
533            method = req['method']
534            params = req['params']
535            id_ = req.get('id')
536
537            if not method.startswith(self.NAMESPACE + '.'):
538                raise ValueError('invalid method: %s' % method)
539
540            if not isinstance(params, list):
541                raise ValueError('invalid params: %s' % params)
542
543            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
544            result = getattr(self, handler)(params)
545            self.finish({'result': result, 'error': None, 'id': id_})
546
547        except Exception as e:
548            traceback.print_exc()
549            msg = '%s: %s' % (e.__class__.__name__, str(e))
550            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
551
552    def handle_listFdb(self, params):
553        list_ = []
554        for vid in self._switch.fdb.get_vid_list():
555            for mac in self._switch.fdb.get_mac_list(vid):
556                entry = self._switch.fdb.get_entry(vid, mac)
557                if entry:
558                    list_.append({
559                        'vid':  vid,
560                        'mac':  EthernetFrame.format_mac(mac),
561                        'port': entry.port.number,
562                        'age':  int(entry.age),
563                    })
564        return {'entries': list_}
565
566    def handle_listPort(self, params):
567        list_ = [self._portstat(p) for p in self._switch.portlist]
568        return {'entries': list_}
569
570    def handle_addPort(self, params):
571        list_ = []
572        for p in params:
573            type_ = p['type']
574            target = p['target']
575            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
576            klass = self.IFTYPES[type_]
577            interface = klass(self._ioloop, self._switch, target, **options)
578            portnum = interface.open()
579            list_.append(self._portstat(self._switch.get_port(portnum)))
580        return {'entries': list_}
581
582    def handle_delPort(self, params):
583        list_ = []
584        for p in params:
585            port = self._switch.get_port(int(p['port']))
586            list_.append(self._portstat(port))
587            port.interface.close()
588        return {'entries': list_}
589
590    def handle_shutPort(self, params):
591        list_ = []
592        for p in params:
593            port = self._switch.get_port(int(p['port']))
594            port.shut = bool(p['shut'])
595            list_.append(self._portstat(port))
596        return {'entries': list_}
597
598    def _optparse_tap(self, opt):
599        return {'debug': self._debug}
600
601    def _optparse_client(self, opt):
602        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
603        if opt.get('insecure'):
604            args = {}
605        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
606        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
607        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
608
609    @staticmethod
610    def _portstat(port):
611        return {
612            'port':   port.number,
613            'type':   port.interface.IFTYPE,
614            'target': port.interface.get_target(),
615            'tx':     port.tx,
616            'rx':     port.rx,
617            'shut':   port.shut,
618        }
619
620
621def start_sw(args):
622    def daemonize(nochdir=False, noclose=False):
623        if os.fork() > 0:
624            sys.exit(0)
625
626        os.setsid()
627
628        if os.fork() > 0:
629            sys.exit(0)
630
631        if not nochdir:
632            os.chdir('/')
633
634        if not noclose:
635            os.umask(0)
636            sys.stdin.close()
637            sys.stdout.close()
638            sys.stderr.close()
639            os.close(0)
640            os.close(1)
641            os.close(2)
642            sys.stdin = open(os.devnull)
643            sys.stdout = open(os.devnull, 'a')
644            sys.stderr = open(os.devnull, 'a')
645
646    def checkabspath(ns, path):
647        val = getattr(ns, path, '')
648        if not val.startswith('/'):
649            raise ValueError('invalid %: %s' % (path, val))
650
651    def getsslopt(ns, key, cert):
652        kval = getattr(ns, key, None)
653        cval = getattr(ns, cert, None)
654        if kval and cval:
655            return {'keyfile': kval, 'certfile': cval}
656        elif kval or cval:
657            raise ValueError('both %s and %s are required' % (key, cert))
658        return None
659
660    def setrealpath(ns, *keys):
661        for k in keys:
662            v = getattr(ns, k, None)
663            if v is not None:
664                v = os.path.realpath(v)
665                open(v).close()  # check readable
666                setattr(ns, k, v)
667
668    def setport(ns, port, isssl):
669        val = getattr(ns, port, None)
670        if val is None:
671            if isssl:
672                return setattr(ns, port, 443)
673            return setattr(ns, port, 80)
674        if not (0 <= val <= 65535):
675            raise ValueError('invalid %s: %s' % (port, val))
676
677    def sethtpasswd(ns, htpasswd):
678        val = getattr(ns, htpasswd, None)
679        if val:
680            return setattr(ns, htpasswd, Htpasswd(val))
681
682    #if args.debug:
683    #    websocket.enableTrace(True)
684
685    if args.ageout <= 0:
686        raise ValueError('invalid ageout: %s' % args.ageout)
687
688    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
689    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
690
691    checkabspath(args, 'path')
692    checkabspath(args, 'ctlpath')
693
694    sslopt = getsslopt(args, 'sslkey', 'sslcert')
695    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
696
697    setport(args, 'port', sslopt)
698    setport(args, 'ctlport', ctlsslopt)
699
700    sethtpasswd(args, 'htpasswd')
701    sethtpasswd(args, 'ctlhtpasswd')
702
703    ioloop = IOLoop.instance()
704    fdb = FDB(ageout=args.ageout, debug=args.debug)
705    switch = SwitchingHub(fdb, debug=args.debug)
706
707    if args.port == args.ctlport and args.host == args.ctlhost:
708        if args.path == args.ctlpath:
709            raise ValueError('same path/ctlpath on same host')
710        if args.sslkey != args.ctlsslkey:
711            raise ValueError('different sslkey/ctlsslkey on same host')
712        if args.sslcert != args.ctlsslcert:
713            raise ValueError('different sslcert/ctlsslcert on same host')
714
715        app = Application([
716            (args.path, EtherWebSocketHandler, {
717                'switch':   switch,
718                'htpasswd': args.htpasswd,
719                'debug':    args.debug,
720            }),
721            (args.ctlpath, EtherWebSocketControlHandler, {
722                'ioloop':   ioloop,
723                'switch':   switch,
724                'htpasswd': args.ctlhtpasswd,
725                'debug':    args.debug,
726            }),
727        ])
728        server = HTTPServer(app, ssl_options=sslopt)
729        server.listen(args.port, address=args.host)
730
731    else:
732        app = Application([(args.path, EtherWebSocketHandler, {
733            'switch':   switch,
734            'htpasswd': args.htpasswd,
735            'debug':    args.debug,
736        })])
737        server = HTTPServer(app, ssl_options=sslopt)
738        server.listen(args.port, address=args.host)
739
740        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
741            'ioloop':   ioloop,
742            'switch':   switch,
743            'htpasswd': args.ctlhtpasswd,
744            'debug':    args.debug,
745        })])
746        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
747        ctlserver.listen(args.ctlport, address=args.ctlhost)
748
749    if not args.foreground:
750        daemonize()
751
752    ioloop.start()
753
754
755def start_ctl(args):
756    def request(args, method, params):
757        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
758        data = json.dumps({'method': method, 'params': params})
759        req = urllib2.Request(args.ctlurl)
760        req.add_header('Content-type', 'application/json')
761        if args.ctluser:
762            if not args.ctlpasswd:
763                args.ctlpasswd = getpass.getpass()
764            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
765            req.add_header('Authorization', 'Basic %s' % token)
766        return json.loads(urllib2.urlopen(req, data).read())
767
768    def maxlen(dict_, key, min_):
769        if not dict_:
770            return min_
771        max_ = max(len(str(r[key])) for r in dict_)
772        return min_ if max_ < min_ else max_
773
774    def print_portlist(result):
775        pmax = maxlen(result, 'port', 4)
776        ymax = maxlen(result, 'type', 4)
777        smax = maxlen(result, 'shut', 5)
778        rmax = maxlen(result, 'rx', 2)
779        tmax = maxlen(result, 'tx', 2)
780        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
781              (pmax, ymax, smax, rmax, tmax)
782        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
783        for r in result:
784            shut = 'shut' if r['shut'] else 'up'
785            print(fmt %
786                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
787
788    def handle_ctl_addport(args):
789        params = [{
790            'type':    args.type,
791            'target':  args.target,
792            'options': {
793                'insecure': args.insecure,
794                'cacerts':  args.cacerts,
795                'user':     args.user,
796                'passwd':   args.passwd,
797            }
798        }]
799        result = request(args, 'addPort', params)
800        if result['error']:
801            print(result['error']['message'])
802        else:
803            print_portlist(result['result']['entries'])
804
805    def handle_ctl_shutport(args):
806        if args.port <= 0:
807            raise ValueError('invalid port: %d' % args.port)
808        params = [{'port': args.port, 'shut': args.no}]
809        result = request(args, 'shutPort', params)
810        if result['error']:
811            print(result['error']['message'])
812        else:
813            print_portlist(result['result']['entries'])
814
815    def handle_ctl_delport(args):
816        if args.port <= 0:
817            raise ValueError('invalid port: %d' % args.port)
818        params = [{'port': args.port}]
819        result = request(args, 'delPort', params)
820        if result['error']:
821            print(result['error']['message'])
822        else:
823            print_portlist(result['result']['entries'])
824
825    def handle_ctl_listport(args):
826        result = request(args, 'listPort', [])
827        if result['error']:
828            print(result['error']['message'])
829        else:
830            print_portlist(result['result']['entries'])
831
832    def handle_ctl_listfdb(args):
833        result = request(args, 'listFdb', [])
834        if result['error']:
835            print(result['error']['message'])
836            return
837        result = result['result']['entries']
838        pmax = maxlen(result, 'port', 4)
839        vmax = maxlen(result, 'vid', 4)
840        mmax = maxlen(result, 'mac', 3)
841        amax = maxlen(result, 'age', 3)
842        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
843        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
844        for r in result:
845            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
846
847    locals()['handle_ctl_' + args.control_method](args)
848
849
850def main():
851    parser = argparse.ArgumentParser()
852    subcommand = parser.add_subparsers(dest='subcommand')
853
854    # -- sw command parser
855    parser_s = subcommand.add_parser('sw')
856
857    parser_s.add_argument('--debug', action='store_true', default=False)
858    parser_s.add_argument('--foreground', action='store_true', default=False)
859    parser_s.add_argument('--ageout', type=int, default=300)
860
861    parser_s.add_argument('--path', default='/')
862    parser_s.add_argument('--host', default='')
863    parser_s.add_argument('--port', type=int)
864    parser_s.add_argument('--htpasswd')
865    parser_s.add_argument('--sslkey')
866    parser_s.add_argument('--sslcert')
867
868    parser_s.add_argument('--ctlpath', default='/ctl')
869    parser_s.add_argument('--ctlhost', default='')
870    parser_s.add_argument('--ctlport', type=int)
871    parser_s.add_argument('--ctlhtpasswd')
872    parser_s.add_argument('--ctlsslkey')
873    parser_s.add_argument('--ctlsslcert')
874
875    # -- ctl command parser
876    parser_c = subcommand.add_parser('ctl')
877    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
878    parser_c.add_argument('--ctluser')
879    parser_c.add_argument('--ctlpasswd')
880
881    control_method = parser_c.add_subparsers(dest='control_method')
882
883    parser_c_ap = control_method.add_parser('addport')
884    parser_c_ap.add_argument(
885        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
886    parser_c_ap.add_argument('target')
887    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
888    parser_c_ap.add_argument('--cacerts')
889    parser_c_ap.add_argument('--user')
890    parser_c_ap.add_argument('--passwd')
891
892    parser_c_sp = control_method.add_parser('shutport')
893    parser_c_sp.add_argument('port', type=int)
894    parser_c_sp.add_argument('--no', action='store_false', default=True)
895
896    parser_c_dp = control_method.add_parser('delport')
897    parser_c_dp.add_argument('port', type=int)
898
899    parser_c_lp = control_method.add_parser('listport')
900
901    parser_c_lf = control_method.add_parser('listfdb')
902
903    # -- go
904    args = parser.parse_args()
905    globals()['start_' + args.subcommand](args)
906
907
908if __name__ == '__main__':
909    main()
Note: See TracBrowser for help on using the repository browser.