source: etherws/trunk/etherws.py @ 211

Revision 211, 30.8 KB checked in by atzm, 12 years ago (diff)
  • shutport -> setport
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import urllib2
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import tornado
54import websocket
55
56from tornado.web import Application, RequestHandler
57from tornado.websocket import WebSocketHandler
58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
62
63
64class DebugMixIn(object):
65    def dprintf(self, msg, func=lambda: ()):
66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
68            sys.stderr.write(prefix + (msg % func()))
69
70
71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
78
79    @property
80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
99        return 0
100
101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
104
105
106class FDB(DebugMixIn):
107    class Entry(object):
108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
121    def __init__(self, ageout, debug=False):
122        self._ageout = ageout
123        self._debug = debug
124        self._table = {}
125
126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
129        self._table[vid][mac] = self.Entry(port, self._ageout)
130
131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
137
138    def _get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
142            return None
143
144        if not entry.agedout:
145            return entry
146
147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
150
151    def each(self):
152        for vid in sorted(self._table.iterkeys()):
153            for mac in sorted(self._table[vid].iterkeys()):
154                entry = self._get_entry(vid, mac)
155                if entry:
156                    yield (vid, mac, entry)
157
158    def lookup(self, frame):
159        mac = frame.dst_mac
160        vid = frame.vid
161        entry = self._get_entry(vid, mac)
162        return getattr(entry, 'port', None)
163
164    def learn(self, port, frame):
165        mac = frame.src_mac
166        vid = frame.vid
167        self._set_entry(vid, mac, port)
168        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
169                     lambda: (port.number, vid, mac.encode('hex')))
170
171    def delete(self, port):
172        for vid, mac, entry in self.each():
173            if entry.port.number == port.number:
174                self._del_entry(vid, mac)
175                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
176                             lambda: (port.number, vid, mac.encode('hex')))
177
178
179class SwitchingHub(DebugMixIn):
180    class Port(object):
181        def __init__(self, number, interface):
182            self.number = number
183            self.interface = interface
184            self.tx = 0
185            self.rx = 0
186            self.shut = False
187
188        @staticmethod
189        def cmp_by_number(x, y):
190            return cmp(x.number, y.number)
191
192    def __init__(self, fdb, debug=False):
193        self.fdb = fdb
194        self._debug = debug
195        self._table = {}
196        self._next = 1
197
198    @property
199    def portlist(self):
200        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
201
202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
206        try:
207            self._set_privattr('portnum', interface, self._next)  # XXX
208            self._table[self._next] = self.Port(self._next, interface)
209            return self._next
210        finally:
211            self._next += 1
212
213    def unregister_port(self, interface):
214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
216        self.fdb.delete(self._table[portnum])
217        del self._table[portnum]
218
219    def send(self, dst_interfaces, frame):
220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
223        ports = sorted(ports, cmp=self.Port.cmp_by_number)
224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
237        port = self._table[self._get_privattr('portnum', src_interface)]
238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
244        try:
245            if not frame.src_multicast:
246                self.fdb.learn(src_port, frame)
247
248            if not frame.dst_multicast:
249                dst_port = self.fdb.lookup(frame)
250
251                if dst_port:
252                    self.send([dst_port.interface], frame)
253                    return
254
255            ports = set(self.portlist) - set([src_port])
256            self.send((p.interface for p in ports), frame)
257
258        except:  # ex. received invalid frame
259            traceback.print_exc()
260
261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
263
264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
274class Htpasswd(object):
275    def __init__(self, path):
276        self._path = path
277        self._stat = None
278        self._data = {}
279
280    def auth(self, name, passwd):
281        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
282        return self._data.get(name) == passwd
283
284    def load(self):
285        old_stat = self._stat
286
287        with open(self._path) as fp:
288            fileno = fp.fileno()
289            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
290            self._stat = os.fstat(fileno)
291
292            unchanged = old_stat and \
293                        old_stat.st_ino == self._stat.st_ino and \
294                        old_stat.st_dev == self._stat.st_dev and \
295                        old_stat.st_mtime == self._stat.st_mtime
296
297            if not unchanged:
298                self._data = self._parse(fp)
299
300        return self
301
302    def _parse(self, fp):
303        data = {}
304        for line in fp:
305            line = line.strip()
306            if 0 <= line.find(':'):
307                name, passwd = line.split(':', 1)
308                if passwd.startswith('{SHA}'):
309                    data[name] = passwd[5:]
310        return data
311
312
313class BasicAuthMixIn(object):
314    def _execute(self, transforms, *args, **kwargs):
315        def do_execute():
316            sp = super(BasicAuthMixIn, self)
317            return sp._execute(transforms, *args, **kwargs)
318
319        def auth_required():
320            stream = getattr(self, 'stream', self.request.connection.stream)
321            stream.write(tornado.escape.utf8(
322                'HTTP/1.1 401 Authorization Required\r\n'
323                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
324            ))
325            stream.close()
326
327        try:
328            if not self._htpasswd:
329                return do_execute()
330
331            creds = self.request.headers.get('Authorization')
332
333            if not creds or not creds.startswith('Basic '):
334                return auth_required()
335
336            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
337
338            if self._htpasswd.load().auth(name, passwd):
339                return do_execute()
340        except:
341            traceback.print_exc()
342
343        return auth_required()
344
345
346class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
347    IFTYPE = 'server'
348
349    def __init__(self, app, req, switch, htpasswd=None, debug=False):
350        super(EtherWebSocketHandler, self).__init__(app, req)
351        self._switch = switch
352        self._htpasswd = htpasswd
353        self._debug = debug
354
355    @property
356    def target(self):
357        return ':'.join(str(e) for e in self.request.connection.address)
358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
373class TapHandler(DebugMixIn):
374    IFTYPE = 'tap'
375    READ_SIZE = 65535
376
377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
384    @property
385    def target(self):
386        if self.closed:
387            return self._dev
388        return self._tap.name
389
390    @property
391    def closed(self):
392        return not self._tap
393
394    def open(self):
395        if not self.closed:
396            raise ValueError('Already opened')
397        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
398        self._tap.up()
399        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
400        return self._switch.register_port(self)
401
402    def close(self):
403        if self.closed:
404            raise ValueError('I/O operation on closed tap')
405        self._switch.unregister_port(self)
406        self._ioloop.remove_handler(self.fileno())
407        self._tap.close()
408        self._tap = None
409
410    def fileno(self):
411        if self.closed:
412            raise ValueError('I/O operation on closed tap')
413        return self._tap.fileno()
414
415    def write_message(self, message, binary=False):
416        if self.closed:
417            raise ValueError('I/O operation on closed tap')
418        self._tap.write(message)
419
420    def __call__(self, fd, events):
421        try:
422            self._switch.receive(self, EthernetFrame(self._read()))
423            return
424        except:
425            traceback.print_exc()
426        self.close()
427
428    def _read(self):
429        if self.closed:
430            raise ValueError('I/O operation on closed tap')
431        buf = []
432        while True:
433            buf.append(self._tap.read(self.READ_SIZE))
434            if len(buf[-1]) < self.READ_SIZE:
435                break
436        return ''.join(buf)
437
438
439class EtherWebSocketClient(DebugMixIn):
440    IFTYPE = 'client'
441
442    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
443        self._ioloop = ioloop
444        self._switch = switch
445        self._url = url
446        self._ssl = ssl_
447        self._debug = debug
448        self._sock = None
449        self._options = {}
450
451        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
452            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
453            auth = ['Authorization: Basic %s' % token]
454            self._options['header'] = auth
455
456    @property
457    def target(self):
458        return self._url
459
460    @property
461    def closed(self):
462        return not self._sock
463
464    def open(self):
465        sslwrap = websocket._SSLSocketWrapper
466
467        if not self.closed:
468            raise websocket.WebSocketException('Already opened')
469
470        if self._ssl:
471            websocket._SSLSocketWrapper = self._ssl
472
473        try:
474            self._sock = websocket.WebSocket()
475            self._sock.connect(self._url, **self._options)
476            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
477            return self._switch.register_port(self)
478        finally:
479            websocket._SSLSocketWrapper = sslwrap
480            self.dprintf('connected: %s\n', lambda: self._url)
481
482    def close(self):
483        if self.closed:
484            raise websocket.WebSocketException('Already closed')
485        self._switch.unregister_port(self)
486        self._ioloop.remove_handler(self.fileno())
487        self._sock.close()
488        self._sock = None
489        self.dprintf('disconnected: %s\n', lambda: self._url)
490
491    def fileno(self):
492        if self.closed:
493            raise websocket.WebSocketException('Closed socket')
494        return self._sock.io_sock.fileno()
495
496    def write_message(self, message, binary=False):
497        if self.closed:
498            raise websocket.WebSocketException('Closed socket')
499        if binary:
500            flag = websocket.ABNF.OPCODE_BINARY
501        else:
502            flag = websocket.ABNF.OPCODE_TEXT
503        self._sock.send(message, flag)
504
505    def __call__(self, fd, events):
506        try:
507            data = self._sock.recv()
508            if data is not None:
509                self._switch.receive(self, EthernetFrame(data))
510                return
511        except:
512            traceback.print_exc()
513        self.close()
514
515
516class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
517    NAMESPACE = 'etherws.control'
518    IFTYPES = {
519        TapHandler.IFTYPE:           TapHandler,
520        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
521    }
522
523    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
524        super(EtherWebSocketControlHandler, self).__init__(app, req)
525        self._ioloop = ioloop
526        self._switch = switch
527        self._htpasswd = htpasswd
528        self._debug = debug
529
530    def post(self):
531        try:
532            request = json.loads(self.request.body)
533        except Exception as e:
534            return self._jsonrpc_response(error={
535                'code':    0 - 32700,
536                'message': 'Parse error',
537                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
538            })
539
540        try:
541            id_ = request.get('id')
542            params = request.get('params')
543            version = request['jsonrpc']
544            method = request['method']
545            if version != '2.0':
546                raise ValueError('Invalid JSON-RPC version: %s' % version)
547        except Exception as e:
548            return self._jsonrpc_response(id_=id_, error={
549                'code':    0 - 32600,
550                'message': 'Invalid Request',
551                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
552            })
553
554        try:
555            if not method.startswith(self.NAMESPACE + '.'):
556                raise ValueError('Invalid method namespace: %s' % method)
557            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
558            handler = getattr(self, handler)
559        except Exception as e:
560            return self._jsonrpc_response(id_=id_, error={
561                'code':    0 - 32601,
562                'message': 'Method not found',
563                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
564            })
565
566        try:
567            return self._jsonrpc_response(id_=id_, result=handler(params))
568        except Exception as e:
569            traceback.print_exc()
570            return self._jsonrpc_response(id_=id_, error={
571                'code':    0 - 32602,
572                'message': 'Invalid params',
573                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
574            })
575
576    def handle_listFdb(self, params):
577        list_ = []
578        for vid, mac, entry in self._switch.fdb.each():
579            list_.append({
580                'vid':  vid,
581                'mac':  EthernetFrame.format_mac(mac),
582                'port': entry.port.number,
583                'age':  int(entry.age),
584            })
585        return {'entries': list_}
586
587    def handle_listPort(self, params):
588        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
589
590    def handle_addPort(self, params):
591        type_ = params['type']
592        target = params['target']
593        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
594        cls = self.IFTYPES[type_]
595        interface = cls(self._ioloop, self._switch, target, **opts)
596        portnum = interface.open()
597        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
598
599    def handle_setPort(self, params):
600        port = self._switch.get_port(int(params['port']))
601        shut = params.get('shut')
602        if shut is not None:
603            port.shut = bool(shut)
604        return {'entries': [self._portstat(port)]}
605
606    def handle_delPort(self, params):
607        port = self._switch.get_port(int(params['port']))
608        port.interface.close()
609        return {'entries': [self._portstat(port)]}
610
611    def _optparse_tap(self, opt):
612        return {'debug': self._debug}
613
614    def _optparse_client(self, opt):
615        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
616        if opt.get('insecure'):
617            args = {}
618        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
619        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
620        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
621
622    def _jsonrpc_response(self, id_=None, result=None, error=None):
623        res = {'jsonrpc': '2.0', 'id': id_}
624        if result:
625            res['result'] = result
626        if error:
627            res['error'] = error
628        self.finish(res)
629
630    @staticmethod
631    def _portstat(port):
632        return {
633            'port':   port.number,
634            'type':   port.interface.IFTYPE,
635            'target': port.interface.target,
636            'tx':     port.tx,
637            'rx':     port.rx,
638            'shut':   port.shut,
639        }
640
641
642def _print_error(error):
643    print(%s (%s)' % (error['message'], error['code']))
644    print('    %s' % error['data'])
645
646
647def _start_sw(args):
648    def daemonize(nochdir=False, noclose=False):
649        if os.fork() > 0:
650            sys.exit(0)
651
652        os.setsid()
653
654        if os.fork() > 0:
655            sys.exit(0)
656
657        if not nochdir:
658            os.chdir('/')
659
660        if not noclose:
661            os.umask(0)
662            sys.stdin.close()
663            sys.stdout.close()
664            sys.stderr.close()
665            os.close(0)
666            os.close(1)
667            os.close(2)
668            sys.stdin = open(os.devnull)
669            sys.stdout = open(os.devnull, 'a')
670            sys.stderr = open(os.devnull, 'a')
671
672    def checkabspath(ns, path):
673        val = getattr(ns, path, '')
674        if not val.startswith('/'):
675            raise ValueError('Invalid %: %s' % (path, val))
676
677    def getsslopt(ns, key, cert):
678        kval = getattr(ns, key, None)
679        cval = getattr(ns, cert, None)
680        if kval and cval:
681            return {'keyfile': kval, 'certfile': cval}
682        elif kval or cval:
683            raise ValueError('Both %s and %s are required' % (key, cert))
684        return None
685
686    def setrealpath(ns, *keys):
687        for k in keys:
688            v = getattr(ns, k, None)
689            if v is not None:
690                v = os.path.realpath(v)
691                open(v).close()  # check readable
692                setattr(ns, k, v)
693
694    def setport(ns, port, isssl):
695        val = getattr(ns, port, None)
696        if val is None:
697            if isssl:
698                return setattr(ns, port, 443)
699            return setattr(ns, port, 80)
700        if not (0 <= val <= 65535):
701            raise ValueError('Invalid %s: %s' % (port, val))
702
703    def sethtpasswd(ns, htpasswd):
704        val = getattr(ns, htpasswd, None)
705        if val:
706            return setattr(ns, htpasswd, Htpasswd(val))
707
708    #if args.debug:
709    #    websocket.enableTrace(True)
710
711    if args.ageout <= 0:
712        raise ValueError('Invalid ageout: %s' % args.ageout)
713
714    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
715    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
716
717    checkabspath(args, 'path')
718    checkabspath(args, 'ctlpath')
719
720    sslopt = getsslopt(args, 'sslkey', 'sslcert')
721    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
722
723    setport(args, 'port', sslopt)
724    setport(args, 'ctlport', ctlsslopt)
725
726    sethtpasswd(args, 'htpasswd')
727    sethtpasswd(args, 'ctlhtpasswd')
728
729    ioloop = IOLoop.instance()
730    fdb = FDB(ageout=args.ageout, debug=args.debug)
731    switch = SwitchingHub(fdb, debug=args.debug)
732
733    if args.port == args.ctlport and args.host == args.ctlhost:
734        if args.path == args.ctlpath:
735            raise ValueError('Same path/ctlpath on same host')
736        if args.sslkey != args.ctlsslkey:
737            raise ValueError('Different sslkey/ctlsslkey on same host')
738        if args.sslcert != args.ctlsslcert:
739            raise ValueError('Different sslcert/ctlsslcert on same host')
740
741        app = Application([
742            (args.path, EtherWebSocketHandler, {
743                'switch':   switch,
744                'htpasswd': args.htpasswd,
745                'debug':    args.debug,
746            }),
747            (args.ctlpath, EtherWebSocketControlHandler, {
748                'ioloop':   ioloop,
749                'switch':   switch,
750                'htpasswd': args.ctlhtpasswd,
751                'debug':    args.debug,
752            }),
753        ])
754        server = HTTPServer(app, ssl_options=sslopt)
755        server.listen(args.port, address=args.host)
756
757    else:
758        app = Application([(args.path, EtherWebSocketHandler, {
759            'switch':   switch,
760            'htpasswd': args.htpasswd,
761            'debug':    args.debug,
762        })])
763        server = HTTPServer(app, ssl_options=sslopt)
764        server.listen(args.port, address=args.host)
765
766        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
767            'ioloop':   ioloop,
768            'switch':   switch,
769            'htpasswd': args.ctlhtpasswd,
770            'debug':    args.debug,
771        })])
772        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
773        ctlserver.listen(args.ctlport, address=args.ctlhost)
774
775    if not args.foreground:
776        daemonize()
777
778    ioloop.start()
779
780
781def _start_ctl(args):
782    def request(args, method, params=None, id_=0):
783        req = urllib2.Request(args.ctlurl)
784        req.add_header('Content-type', 'application/json')
785        if args.ctluser:
786            if not args.ctlpasswd:
787                args.ctlpasswd = getpass.getpass('Control Password: ')
788            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
789            req.add_header('Authorization', 'Basic %s' % token)
790        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
791        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
792        if params is not None:
793            data['params'] = params
794        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
795
796    def maxlen(dict_, key, min_):
797        if not dict_:
798            return min_
799        max_ = max(len(str(r[key])) for r in dict_)
800        return min_ if max_ < min_ else max_
801
802    def print_portlist(result):
803        pmax = maxlen(result, 'port', 4)
804        ymax = maxlen(result, 'type', 4)
805        smax = maxlen(result, 'shut', 5)
806        rmax = maxlen(result, 'rx', 2)
807        tmax = maxlen(result, 'tx', 2)
808        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
809              (pmax, ymax, smax, rmax, tmax)
810        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
811        for r in result:
812            shut = 'shut' if r['shut'] else 'up'
813            print(fmt %
814                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
815
816    def handle_ctl_addport(args):
817        opts = {
818            'user':     getattr(args, 'user', None),
819            'passwd':   getattr(args, 'passwd', None),
820            'cacerts':  getattr(args, 'cacerts', None),
821            'insecure': getattr(args, 'insecure', None),
822        }
823        if args.iftype == EtherWebSocketClient.IFTYPE:
824            if not args.target.startswith('ws://') and \
825               not args.target.startswith('wss://'):
826                raise ValueError('Invalid target URL scheme: %s' % args.target)
827            if not opts['user'] and opts['passwd']:
828                raise ValueError('Authentication required but username empty')
829            if opts['user'] and not opts['passwd']:
830                opts['passwd'] = getpass.getpass('Client Password: ')
831        result = request(args, 'addPort', {
832            'type':    args.iftype,
833            'target':  args.target,
834            'options': opts,
835        })
836        if 'error' in result:
837            _print_error(result['error'])
838        else:
839            print_portlist(result['result']['entries'])
840
841    def handle_ctl_setport(args):
842        if args.port <= 0:
843            raise ValueError('Invalid port: %d' % args.port)
844        req = {'port': args.port}
845        shut = getattr(args, 'shut', None)
846        if shut is not None:
847            req['shut'] = bool(shut)
848        result = request(args, 'setPort', req)
849        if 'error' in result:
850            _print_error(result['error'])
851        else:
852            print_portlist(result['result']['entries'])
853
854    def handle_ctl_delport(args):
855        if args.port <= 0:
856            raise ValueError('Invalid port: %d' % args.port)
857        result = request(args, 'delPort', {'port': args.port})
858        if 'error' in result:
859            _print_error(result['error'])
860        else:
861            print_portlist(result['result']['entries'])
862
863    def handle_ctl_listport(args):
864        result = request(args, 'listPort')
865        if 'error' in result:
866            _print_error(result['error'])
867        else:
868            print_portlist(result['result']['entries'])
869
870    def handle_ctl_listfdb(args):
871        result = request(args, 'listFdb')
872        if 'error' in result:
873            return _print_error(result['error'])
874        result = result['result']['entries']
875        pmax = maxlen(result, 'port', 4)
876        vmax = maxlen(result, 'vid', 4)
877        mmax = maxlen(result, 'mac', 3)
878        amax = maxlen(result, 'age', 3)
879        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
880        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
881        for r in result:
882            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
883
884    locals()['handle_ctl_' + args.control_method](args)
885
886
887def _main():
888    parser = argparse.ArgumentParser()
889    subcommand = parser.add_subparsers(dest='subcommand')
890
891    # - sw
892    parser_sw = subcommand.add_parser('sw')
893
894    parser_sw.add_argument('--debug', action='store_true', default=False)
895    parser_sw.add_argument('--foreground', action='store_true', default=False)
896    parser_sw.add_argument('--ageout', type=int, default=300)
897
898    parser_sw.add_argument('--path', default='/')
899    parser_sw.add_argument('--host', default='')
900    parser_sw.add_argument('--port', type=int)
901    parser_sw.add_argument('--htpasswd')
902    parser_sw.add_argument('--sslkey')
903    parser_sw.add_argument('--sslcert')
904
905    parser_sw.add_argument('--ctlpath', default='/ctl')
906    parser_sw.add_argument('--ctlhost', default='')
907    parser_sw.add_argument('--ctlport', type=int)
908    parser_sw.add_argument('--ctlhtpasswd')
909    parser_sw.add_argument('--ctlsslkey')
910    parser_sw.add_argument('--ctlsslcert')
911
912    # - ctl
913    parser_ctl = subcommand.add_parser('ctl')
914    parser_ctl.add_argument('--ctlurl', default='http://localhost/ctl')
915    parser_ctl.add_argument('--ctluser')
916    parser_ctl.add_argument('--ctlpasswd')
917
918    control_method = parser_ctl.add_subparsers(dest='control_method')
919
920    # -- ctl addport
921    parser_ctl_addport = control_method.add_parser('addport')
922    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
923
924    # --- ctl addport tap
925    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE)
926    parser_ctl_addport_tap.add_argument('target')
927
928    # --- ctl addport client
929    parser_ctl_addport_client = iftype.add_parser(EtherWebSocketClient.IFTYPE)
930    parser_ctl_addport_client.add_argument('target')
931    parser_ctl_addport_client.add_argument('--user')
932    parser_ctl_addport_client.add_argument('--passwd')
933    parser_ctl_addport_client.add_argument('--cacerts')
934    parser_ctl_addport_client.add_argument(
935        '--insecure', action='store_true', default=False)
936
937    # -- ctl setport
938    parser_ctl_setport = control_method.add_parser('setport')
939    parser_ctl_setport.add_argument('port', type=int)
940    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1))
941
942    # -- ctl delport
943    parser_ctl_delport = control_method.add_parser('delport')
944    parser_ctl_delport.add_argument('port', type=int)
945
946    # -- ctl listport
947    parser_ctl_listport = control_method.add_parser('listport')
948
949    # -- ctl listfdb
950    parser_ctl_listfdb = control_method.add_parser('listfdb')
951
952    # -- go
953    args = parser.parse_args()
954
955    try:
956        globals()['_start_' + args.subcommand](args)
957    except Exception as e:
958        _print_error({
959            'code':    0 - 32603,
960            'message': 'Internal error',
961            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
962        })
963
964
965if __name__ == '__main__':
966    _main()
Note: See TracBrowser for help on using the repository browser.