source: etherws/trunk/etherws.py @ 203

Revision 203, 29.7 KB checked in by atzm, 12 years ago (diff)
  • identify server target strictly
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def get_vid_list(self):
152        return sorted(self._table.iterkeys())
153
154    def get_mac_list(self, vid):
155        return sorted(self._table[vid].iterkeys())
156
157    def lookup(self, frame):
158        mac = frame.dst_mac
159        vid = frame.vid
160        entry = self.get_entry(vid, mac)
161        return getattr(entry, 'port', None)
162
163    def learn(self, port, frame):
164        mac = frame.src_mac
165        vid = frame.vid
166        self._set_entry(vid, mac, port)
167        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
168                     lambda: (port.number, vid, mac.encode('hex')))
169
170    def delete(self, port):
171        for vid in self.get_vid_list():
172            for mac in self.get_mac_list(vid):
173                entry = self.get_entry(vid, mac)
174                if entry and entry.port.number == port.number:
175                    self._del_entry(vid, mac)
176                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                                 lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    @property
357    def target(self):
358        return ':'.join(str(e) for e in self.request.connection.address)
359
360    def open(self):
361        try:
362            return self._switch.register_port(self)
363        finally:
364            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
365
366    def on_message(self, message):
367        self._switch.receive(self, EthernetFrame(message))
368
369    def on_close(self):
370        self._switch.unregister_port(self)
371        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
372
373
374class TapHandler(DebugMixIn):
375    IFTYPE = 'tap'
376    READ_SIZE = 65535
377
378    def __init__(self, ioloop, switch, dev, debug=False):
379        self._ioloop = ioloop
380        self._switch = switch
381        self._dev = dev
382        self._debug = debug
383        self._tap = None
384
385    @property
386    def target(self):
387        if self.closed:
388            return self._dev
389        return self._tap.name
390
391    @property
392    def closed(self):
393        return not self._tap
394
395    def open(self):
396        if not self.closed:
397            raise ValueError('Already opened')
398        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
399        self._tap.up()
400        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
401        return self._switch.register_port(self)
402
403    def close(self):
404        if self.closed:
405            raise ValueError('I/O operation on closed tap')
406        self._switch.unregister_port(self)
407        self._ioloop.remove_handler(self.fileno())
408        self._tap.close()
409        self._tap = None
410
411    def fileno(self):
412        if self.closed:
413            raise ValueError('I/O operation on closed tap')
414        return self._tap.fileno()
415
416    def write_message(self, message, binary=False):
417        if self.closed:
418            raise ValueError('I/O operation on closed tap')
419        self._tap.write(message)
420
421    def __call__(self, fd, events):
422        try:
423            self._switch.receive(self, EthernetFrame(self._read()))
424            return
425        except:
426            traceback.print_exc()
427        self.close()
428
429    def _read(self):
430        if self.closed:
431            raise ValueError('I/O operation on closed tap')
432        buf = []
433        while True:
434            buf.append(self._tap.read(self.READ_SIZE))
435            if len(buf[-1]) < self.READ_SIZE:
436                break
437        return ''.join(buf)
438
439
440class EtherWebSocketClient(DebugMixIn):
441    IFTYPE = 'client'
442
443    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
444        self._ioloop = ioloop
445        self._switch = switch
446        self._url = url
447        self._ssl = ssl_
448        self._debug = debug
449        self._sock = None
450        self._options = {}
451
452        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
453            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
454            auth = ['Authorization: Basic %s' % token]
455            self._options['header'] = auth
456
457    @property
458    def target(self):
459        return self._url
460
461    @property
462    def closed(self):
463        return not self._sock
464
465    def open(self):
466        sslwrap = websocket._SSLSocketWrapper
467
468        if not self.closed:
469            raise websocket.WebSocketException('Already opened')
470
471        if self._ssl:
472            websocket._SSLSocketWrapper = self._ssl
473
474        try:
475            self._sock = websocket.WebSocket()
476            self._sock.connect(self._url, **self._options)
477            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
478            return self._switch.register_port(self)
479        finally:
480            websocket._SSLSocketWrapper = sslwrap
481            self.dprintf('connected: %s\n', lambda: self._url)
482
483    def close(self):
484        if self.closed:
485            raise websocket.WebSocketException('Already closed')
486        self._switch.unregister_port(self)
487        self._ioloop.remove_handler(self.fileno())
488        self._sock.close()
489        self._sock = None
490        self.dprintf('disconnected: %s\n', lambda: self._url)
491
492    def fileno(self):
493        if self.closed:
494            raise websocket.WebSocketException('Closed socket')
495        return self._sock.io_sock.fileno()
496
497    def write_message(self, message, binary=False):
498        if self.closed:
499            raise websocket.WebSocketException('Closed socket')
500        if binary:
501            flag = websocket.ABNF.OPCODE_BINARY
502        else:
503            flag = websocket.ABNF.OPCODE_TEXT
504        self._sock.send(message, flag)
505
506    def __call__(self, fd, events):
507        try:
508            data = self._sock.recv()
509            if data is not None:
510                self._switch.receive(self, EthernetFrame(data))
511                return
512        except:
513            traceback.print_exc()
514        self.close()
515
516
517class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
518    NAMESPACE = 'etherws.control'
519    IFTYPES = {
520        TapHandler.IFTYPE:           TapHandler,
521        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
522    }
523
524    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
525        super(EtherWebSocketControlHandler, self).__init__(app, req)
526        self._ioloop = ioloop
527        self._switch = switch
528        self._htpasswd = htpasswd
529        self._debug = debug
530
531    def post(self):
532        try:
533            request = json.loads(self.request.body)
534        except Exception as e:
535            return self._jsonrpc_response(error={
536                'code':    0 - 32700,
537                'message': 'Parse error',
538                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
539            })
540
541        try:
542            id_ = request.get('id')
543            params = request.get('params')
544            version = request['jsonrpc']
545            method = request['method']
546            if version != '2.0':
547                raise ValueError('Invalid JSON-RPC version: %s' % version)
548        except Exception as e:
549            return self._jsonrpc_response(id_=id_, error={
550                'code':    0 - 32600,
551                'message': 'Invalid Request',
552                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
553            })
554
555        try:
556            if not method.startswith(self.NAMESPACE + '.'):
557                raise ValueError('Invalid method namespace: %s' % method)
558            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
559            handler = getattr(self, handler)
560        except Exception as e:
561            return self._jsonrpc_response(id_=id_, error={
562                'code':    0 - 32601,
563                'message': 'Method not found',
564                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
565            })
566
567        try:
568            return self._jsonrpc_response(id_=id_, result=handler(params))
569        except Exception as e:
570            traceback.print_exc()
571            return self._jsonrpc_response(id_=id_, error={
572                'code':    0 - 32602,
573                'message': 'Invalid params',
574                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
575            })
576
577    def handle_listFdb(self, params):
578        list_ = []
579        for vid in self._switch.fdb.get_vid_list():
580            for mac in self._switch.fdb.get_mac_list(vid):
581                entry = self._switch.fdb.get_entry(vid, mac)
582                if entry:
583                    list_.append({
584                        'vid':  vid,
585                        'mac':  EthernetFrame.format_mac(mac),
586                        'port': entry.port.number,
587                        'age':  int(entry.age),
588                    })
589        return {'entries': list_}
590
591    def handle_listPort(self, params):
592        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
593
594    def handle_addPort(self, params):
595        type_ = params['type']
596        target = params['target']
597        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
598        cls = self.IFTYPES[type_]
599        interface = cls(self._ioloop, self._switch, target, **opts)
600        portnum = interface.open()
601        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
602
603    def handle_delPort(self, params):
604        port = self._switch.get_port(int(params['port']))
605        port.interface.close()
606        return {'entries': [self._portstat(port)]}
607
608    def handle_shutPort(self, params):
609        port = self._switch.get_port(int(params['port']))
610        port.shut = bool(params['shut'])
611        return {'entries': [self._portstat(port)]}
612
613    def _optparse_tap(self, opt):
614        return {'debug': self._debug}
615
616    def _optparse_client(self, opt):
617        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
618        if opt.get('insecure'):
619            args = {}
620        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
621        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
622        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
623
624    def _jsonrpc_response(self, id_=None, result=None, error=None):
625        res = {'jsonrpc': '2.0', 'id': id_}
626        if result:
627            res['result'] = result
628        if error:
629            res['error'] = error
630        self.finish(res)
631
632    @staticmethod
633    def _portstat(port):
634        return {
635            'port':   port.number,
636            'type':   port.interface.IFTYPE,
637            'target': port.interface.target,
638            'tx':     port.tx,
639            'rx':     port.rx,
640            'shut':   port.shut,
641        }
642
643
644def start_sw(args):
645    def daemonize(nochdir=False, noclose=False):
646        if os.fork() > 0:
647            sys.exit(0)
648
649        os.setsid()
650
651        if os.fork() > 0:
652            sys.exit(0)
653
654        if not nochdir:
655            os.chdir('/')
656
657        if not noclose:
658            os.umask(0)
659            sys.stdin.close()
660            sys.stdout.close()
661            sys.stderr.close()
662            os.close(0)
663            os.close(1)
664            os.close(2)
665            sys.stdin = open(os.devnull)
666            sys.stdout = open(os.devnull, 'a')
667            sys.stderr = open(os.devnull, 'a')
668
669    def checkabspath(ns, path):
670        val = getattr(ns, path, '')
671        if not val.startswith('/'):
672            raise ValueError('Invalid %: %s' % (path, val))
673
674    def getsslopt(ns, key, cert):
675        kval = getattr(ns, key, None)
676        cval = getattr(ns, cert, None)
677        if kval and cval:
678            return {'keyfile': kval, 'certfile': cval}
679        elif kval or cval:
680            raise ValueError('Both %s and %s are required' % (key, cert))
681        return None
682
683    def setrealpath(ns, *keys):
684        for k in keys:
685            v = getattr(ns, k, None)
686            if v is not None:
687                v = os.path.realpath(v)
688                open(v).close()  # check readable
689                setattr(ns, k, v)
690
691    def setport(ns, port, isssl):
692        val = getattr(ns, port, None)
693        if val is None:
694            if isssl:
695                return setattr(ns, port, 443)
696            return setattr(ns, port, 80)
697        if not (0 <= val <= 65535):
698            raise ValueError('Invalid %s: %s' % (port, val))
699
700    def sethtpasswd(ns, htpasswd):
701        val = getattr(ns, htpasswd, None)
702        if val:
703            return setattr(ns, htpasswd, Htpasswd(val))
704
705    #if args.debug:
706    #    websocket.enableTrace(True)
707
708    if args.ageout <= 0:
709        raise ValueError('Invalid ageout: %s' % args.ageout)
710
711    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
712    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
713
714    checkabspath(args, 'path')
715    checkabspath(args, 'ctlpath')
716
717    sslopt = getsslopt(args, 'sslkey', 'sslcert')
718    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
719
720    setport(args, 'port', sslopt)
721    setport(args, 'ctlport', ctlsslopt)
722
723    sethtpasswd(args, 'htpasswd')
724    sethtpasswd(args, 'ctlhtpasswd')
725
726    ioloop = IOLoop.instance()
727    fdb = FDB(ageout=args.ageout, debug=args.debug)
728    switch = SwitchingHub(fdb, debug=args.debug)
729
730    if args.port == args.ctlport and args.host == args.ctlhost:
731        if args.path == args.ctlpath:
732            raise ValueError('Same path/ctlpath on same host')
733        if args.sslkey != args.ctlsslkey:
734            raise ValueError('Different sslkey/ctlsslkey on same host')
735        if args.sslcert != args.ctlsslcert:
736            raise ValueError('Different sslcert/ctlsslcert on same host')
737
738        app = Application([
739            (args.path, EtherWebSocketHandler, {
740                'switch':   switch,
741                'htpasswd': args.htpasswd,
742                'debug':    args.debug,
743            }),
744            (args.ctlpath, EtherWebSocketControlHandler, {
745                'ioloop':   ioloop,
746                'switch':   switch,
747                'htpasswd': args.ctlhtpasswd,
748                'debug':    args.debug,
749            }),
750        ])
751        server = HTTPServer(app, ssl_options=sslopt)
752        server.listen(args.port, address=args.host)
753
754    else:
755        app = Application([(args.path, EtherWebSocketHandler, {
756            'switch':   switch,
757            'htpasswd': args.htpasswd,
758            'debug':    args.debug,
759        })])
760        server = HTTPServer(app, ssl_options=sslopt)
761        server.listen(args.port, address=args.host)
762
763        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
764            'ioloop':   ioloop,
765            'switch':   switch,
766            'htpasswd': args.ctlhtpasswd,
767            'debug':    args.debug,
768        })])
769        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
770        ctlserver.listen(args.ctlport, address=args.ctlhost)
771
772    if not args.foreground:
773        daemonize()
774
775    ioloop.start()
776
777
778def start_ctl(args):
779    def request(args, method, params=None, id_=0):
780        req = urllib2.Request(args.ctlurl)
781        req.add_header('Content-type', 'application/json')
782        if args.ctluser:
783            if not args.ctlpasswd:
784                args.ctlpasswd = getpass.getpass()
785            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
786            req.add_header('Authorization', 'Basic %s' % token)
787        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
788        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
789        if params is not None:
790            data['params'] = params
791        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
792
793    def maxlen(dict_, key, min_):
794        if not dict_:
795            return min_
796        max_ = max(len(str(r[key])) for r in dict_)
797        return min_ if max_ < min_ else max_
798
799    def print_portlist(result):
800        pmax = maxlen(result, 'port', 4)
801        ymax = maxlen(result, 'type', 4)
802        smax = maxlen(result, 'shut', 5)
803        rmax = maxlen(result, 'rx', 2)
804        tmax = maxlen(result, 'tx', 2)
805        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
806              (pmax, ymax, smax, rmax, tmax)
807        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
808        for r in result:
809            shut = 'shut' if r['shut'] else 'up'
810            print(fmt %
811                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
812
813    def print_error(error):
814        print(%s (%s)' % (error['message'], error['code']))
815        print('    %s' % error['data'])
816
817    def handle_ctl_addport(args):
818        result = request(args, 'addPort', {
819            'type':    args.type,
820            'target':  args.target,
821            'options': {
822                'insecure': args.insecure,
823                'cacerts':  args.cacerts,
824                'user':     args.user,
825                'passwd':   args.passwd,
826            },
827        })
828        if 'error' in result:
829            print_error(result['error'])
830        else:
831            print_portlist(result['result']['entries'])
832
833    def handle_ctl_shutport(args):
834        if args.port <= 0:
835            raise ValueError('Invalid port: %d' % args.port)
836        result = request(args, 'shutPort', {
837            'port': args.port,
838            'shut': args.no,
839        })
840        if 'error' in result:
841            print_error(result['error'])
842        else:
843            print_portlist(result['result']['entries'])
844
845    def handle_ctl_delport(args):
846        if args.port <= 0:
847            raise ValueError('Invalid port: %d' % args.port)
848        result = request(args, 'delPort', {'port': args.port})
849        if 'error' in result:
850            print_error(result['error'])
851        else:
852            print_portlist(result['result']['entries'])
853
854    def handle_ctl_listport(args):
855        result = request(args, 'listPort')
856        if 'error' in result:
857            print_error(result['error'])
858        else:
859            print_portlist(result['result']['entries'])
860
861    def handle_ctl_listfdb(args):
862        result = request(args, 'listFdb')
863        if 'error' in result:
864            return print_error(result['error'])
865        result = result['result']['entries']
866        pmax = maxlen(result, 'port', 4)
867        vmax = maxlen(result, 'vid', 4)
868        mmax = maxlen(result, 'mac', 3)
869        amax = maxlen(result, 'age', 3)
870        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
871        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
872        for r in result:
873            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
874
875    locals()['handle_ctl_' + args.control_method](args)
876
877
878def main():
879    parser = argparse.ArgumentParser()
880    subcommand = parser.add_subparsers(dest='subcommand')
881
882    # -- sw command parser
883    parser_s = subcommand.add_parser('sw')
884
885    parser_s.add_argument('--debug', action='store_true', default=False)
886    parser_s.add_argument('--foreground', action='store_true', default=False)
887    parser_s.add_argument('--ageout', type=int, default=300)
888
889    parser_s.add_argument('--path', default='/')
890    parser_s.add_argument('--host', default='')
891    parser_s.add_argument('--port', type=int)
892    parser_s.add_argument('--htpasswd')
893    parser_s.add_argument('--sslkey')
894    parser_s.add_argument('--sslcert')
895
896    parser_s.add_argument('--ctlpath', default='/ctl')
897    parser_s.add_argument('--ctlhost', default='')
898    parser_s.add_argument('--ctlport', type=int)
899    parser_s.add_argument('--ctlhtpasswd')
900    parser_s.add_argument('--ctlsslkey')
901    parser_s.add_argument('--ctlsslcert')
902
903    # -- ctl command parser
904    parser_c = subcommand.add_parser('ctl')
905    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
906    parser_c.add_argument('--ctluser')
907    parser_c.add_argument('--ctlpasswd')
908
909    control_method = parser_c.add_subparsers(dest='control_method')
910
911    parser_c_ap = control_method.add_parser('addport')
912    parser_c_ap.add_argument(
913        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
914    parser_c_ap.add_argument('target')
915    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
916    parser_c_ap.add_argument('--cacerts')
917    parser_c_ap.add_argument('--user')
918    parser_c_ap.add_argument('--passwd')
919
920    parser_c_sp = control_method.add_parser('shutport')
921    parser_c_sp.add_argument('port', type=int)
922    parser_c_sp.add_argument('--no', action='store_false', default=True)
923
924    parser_c_dp = control_method.add_parser('delport')
925    parser_c_dp.add_argument('port', type=int)
926
927    parser_c_lp = control_method.add_parser('listport')
928
929    parser_c_lf = control_method.add_parser('listfdb')
930
931    # -- go
932    args = parser.parse_args()
933    globals()['start_' + args.subcommand](args)
934
935
936if __name__ == '__main__':
937    main()
Note: See TracBrowser for help on using the repository browser.