source: etherws/trunk/etherws.py @ 202

Revision 202, 29.6 KB checked in by atzm, 12 years ago (diff)
  • support JSON-RPC 2.0
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[186]4#                          Ethernet over WebSocket
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[133]11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
[156]42import ssl
[160]43import time
[183]44import json
[175]45import fcntl
[150]46import base64
[190]47import urllib2
[150]48import hashlib
[151]49import getpass
[133]50import argparse
[165]51import traceback
[133]52
[185]53import tornado
[133]54import websocket
55
[183]56from tornado.web import Application, RequestHandler
[182]57from tornado.websocket import WebSocketHandler
[183]58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
[166]61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]62
[166]63
[160]64class DebugMixIn(object):
[166]65    def dprintf(self, msg, func=lambda: ()):
[160]66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]68            sys.stderr.write(prefix + (msg % func()))
[160]69
70
[164]71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
[176]75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
[164]78
79    @property
[176]80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
[164]84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]99        return 0
[164]100
[198]101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
[164]104
[198]105
[166]106class FDB(DebugMixIn):
[196]107    class Entry(object):
[195]108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
[167]121    def __init__(self, ageout, debug=False):
[164]122        self._ageout = ageout
123        self._debug = debug
[195]124        self._table = {}
[164]125
[195]126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
[196]129        self._table[vid][mac] = self.Entry(port, self._ageout)
[164]130
[195]131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
[164]137
[195]138    def get_entry(self, vid, mac):
139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
[164]142            return None
143
[195]144        if not entry.agedout:
145            return entry
[164]146
[195]147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
[164]150
[195]151    def get_vid_list(self):
[198]152        return sorted(self._table.iterkeys())
[195]153
154    def get_mac_list(self, vid):
[198]155        return sorted(self._table[vid].iterkeys())
[195]156
157    def lookup(self, frame):
158        mac = frame.dst_mac
159        vid = frame.vid
160        entry = self.get_entry(vid, mac)
161        return getattr(entry, 'port', None)
162
[166]163    def learn(self, port, frame):
164        mac = frame.src_mac
165        vid = frame.vid
[195]166        self._set_entry(vid, mac, port)
[183]167        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
168                     lambda: (port.number, vid, mac.encode('hex')))
[166]169
[164]170    def delete(self, port):
[195]171        for vid in self.get_vid_list():
172            for mac in self.get_mac_list(vid):
173                entry = self.get_entry(vid, mac)
174                if entry and entry.port.number == port.number:
175                    self._del_entry(vid, mac)
[183]176                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                                 lambda: (port.number, vid, mac.encode('hex')))
[164]178
179
[195]180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
[183]188
[195]189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
[183]192
[166]193    def __init__(self, fdb, debug=False):
[197]194        self.fdb = fdb
[133]195        self._debug = debug
[183]196        self._table = {}
197        self._next = 1
[133]198
[183]199    @property
200    def portlist(self):
[195]201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
[133]202
[183]203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
[186]207        try:
[187]208            self._set_privattr('portnum', interface, self._next)  # XXX
[195]209            self._table[self._next] = self.Port(self._next, interface)
[186]210            return self._next
211        finally:
212            self._next += 1
[183]213
214    def unregister_port(self, interface):
[187]215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
[197]217        self.fdb.delete(self._table[portnum])
[187]218        del self._table[portnum]
[183]219
220    def send(self, dst_interfaces, frame):
[187]221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
[195]224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
[183]225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
[187]238        port = self._table[self._get_privattr('portnum', src_interface)]
[183]239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
[166]245        try:
[176]246            if not frame.src_multicast:
[197]247                self.fdb.learn(src_port, frame)
[133]248
[176]249            if not frame.dst_multicast:
[197]250                dst_port = self.fdb.lookup(frame)
[164]251
[166]252                if dst_port:
[183]253                    self.send([dst_port.interface], frame)
[166]254                    return
[133]255
[187]256            ports = set(self.portlist) - set([src_port])
[183]257            self.send((p.interface for p in ports), frame)
[162]258
[166]259        except:  # ex. received invalid frame
260            traceback.print_exc()
[133]261
[187]262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
[164]264
[187]265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
[179]275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
[182]314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
[185]321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
[182]323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
[185]326            stream.close()
[182]327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
[186]347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[191]348    IFTYPE = 'server'
349
[186]350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    def get_target(self):
357        return self.request.remote_ip
358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
[166]373class TapHandler(DebugMixIn):
[191]374    IFTYPE = 'tap'
[166]375    READ_SIZE = 65535
376
[178]377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
[166]379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
[186]384    def get_target(self):
[183]385        if self.closed:
386            return self._dev
387        return self._tap.name
388
[186]389    @property
390    def closed(self):
391        return not self._tap
392
[166]393    def open(self):
394        if not self.closed:
[202]395            raise ValueError('Already opened')
[166]396        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
397        self._tap.up()
[178]398        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]399        return self._switch.register_port(self)
[166]400
401    def close(self):
402        if self.closed:
403            raise ValueError('I/O operation on closed tap')
[186]404        self._switch.unregister_port(self)
[178]405        self._ioloop.remove_handler(self.fileno())
[166]406        self._tap.close()
407        self._tap = None
408
409    def fileno(self):
410        if self.closed:
411            raise ValueError('I/O operation on closed tap')
412        return self._tap.fileno()
413
414    def write_message(self, message, binary=False):
415        if self.closed:
416            raise ValueError('I/O operation on closed tap')
417        self._tap.write(message)
418
[138]419    def __call__(self, fd, events):
[166]420        try:
[183]421            self._switch.receive(self, EthernetFrame(self._read()))
[166]422            return
423        except:
424            traceback.print_exc()
[178]425        self.close()
[166]426
427    def _read(self):
428        if self.closed:
429            raise ValueError('I/O operation on closed tap')
[162]430        buf = []
431        while True:
[166]432            buf.append(self._tap.read(self.READ_SIZE))
433            if len(buf[-1]) < self.READ_SIZE:
[162]434                break
[166]435        return ''.join(buf)
[162]436
437
[160]438class EtherWebSocketClient(DebugMixIn):
[191]439    IFTYPE = 'client'
440
[181]441    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]442        self._ioloop = ioloop
[166]443        self._switch = switch
[151]444        self._url = url
[181]445        self._ssl = ssl_
[160]446        self._debug = debug
[166]447        self._sock = None
[151]448        self._options = {}
449
[174]450        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
451            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]452            auth = ['Authorization: Basic %s' % token]
453            self._options['header'] = auth
454
[186]455    def get_target(self):
456        return self._url
457
[160]458    @property
459    def closed(self):
460        return not self._sock
461
[151]462    def open(self):
[181]463        sslwrap = websocket._SSLSocketWrapper
464
[160]465        if not self.closed:
[202]466            raise websocket.WebSocketException('Already opened')
[151]467
[181]468        if self._ssl:
469            websocket._SSLSocketWrapper = self._ssl
470
471        try:
472            self._sock = websocket.WebSocket()
473            self._sock.connect(self._url, **self._options)
474            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]475            return self._switch.register_port(self)
[181]476        finally:
477            websocket._SSLSocketWrapper = sslwrap
[186]478            self.dprintf('connected: %s\n', lambda: self._url)
[181]479
[151]480    def close(self):
[160]481        if self.closed:
[202]482            raise websocket.WebSocketException('Already closed')
[186]483        self._switch.unregister_port(self)
[178]484        self._ioloop.remove_handler(self.fileno())
[151]485        self._sock.close()
486        self._sock = None
[164]487        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]488
[165]489    def fileno(self):
490        if self.closed:
[202]491            raise websocket.WebSocketException('Closed socket')
[165]492        return self._sock.io_sock.fileno()
493
[151]494    def write_message(self, message, binary=False):
[160]495        if self.closed:
[202]496            raise websocket.WebSocketException('Closed socket')
[151]497        if binary:
498            flag = websocket.ABNF.OPCODE_BINARY
[160]499        else:
500            flag = websocket.ABNF.OPCODE_TEXT
[151]501        self._sock.send(message, flag)
502
[165]503    def __call__(self, fd, events):
[151]504        try:
[165]505            data = self._sock.recv()
506            if data is not None:
[183]507                self._switch.receive(self, EthernetFrame(data))
[165]508                return
509        except:
510            traceback.print_exc()
[178]511        self.close()
[151]512
513
[183]514class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
515    NAMESPACE = 'etherws.control'
[191]516    IFTYPES = {
517        TapHandler.IFTYPE:           TapHandler,
518        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
[186]519    }
[183]520
521    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
522        super(EtherWebSocketControlHandler, self).__init__(app, req)
523        self._ioloop = ioloop
524        self._switch = switch
525        self._htpasswd = htpasswd
526        self._debug = debug
527
528    def post(self):
[202]529        try:
530            request = json.loads(self.request.body)
531        except Exception as e:
532            return self._jsonrpc_response(error={
533                'code':    0 - 32700,
534                'message': 'Parse error',
535                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
536            })
[183]537
538        try:
[202]539            id_ = request.get('id')
540            params = request.get('params')
541            version = request['jsonrpc']
542            method = request['method']
543            if version != '2.0':
544                raise ValueError('Invalid JSON-RPC version: %s' % version)
545        except Exception as e:
546            return self._jsonrpc_response(id_=id_, error={
547                'code':    0 - 32600,
548                'message': 'Invalid Request',
549                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
550            })
[183]551
[202]552        try:
[183]553            if not method.startswith(self.NAMESPACE + '.'):
[202]554                raise ValueError('Invalid method namespace: %s' % method)
[183]555            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
[202]556            handler = getattr(self, handler)
557        except Exception as e:
558            return self._jsonrpc_response(id_=id_, error={
559                'code':    0 - 32601,
560                'message': 'Method not found',
561                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
562            })
[183]563
[202]564        try:
565            return self._jsonrpc_response(id_=id_, result=handler(params))
[183]566        except Exception as e:
567            traceback.print_exc()
[202]568            return self._jsonrpc_response(id_=id_, error={
569                'code':    0 - 32602,
570                'message': 'Invalid params',
571                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
572            })
[183]573
[198]574    def handle_listFdb(self, params):
575        list_ = []
576        for vid in self._switch.fdb.get_vid_list():
577            for mac in self._switch.fdb.get_mac_list(vid):
578                entry = self._switch.fdb.get_entry(vid, mac)
579                if entry:
[199]580                    list_.append({
581                        'vid':  vid,
[200]582                        'mac':  EthernetFrame.format_mac(mac),
[199]583                        'port': entry.port.number,
584                        'age':  int(entry.age),
585                    })
586        return {'entries': list_}
[198]587
[183]588    def handle_listPort(self, params):
[202]589        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
[183]590
591    def handle_addPort(self, params):
[202]592        type_ = params['type']
593        target = params['target']
594        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
595        cls = self.IFTYPES[type_]
596        interface = cls(self._ioloop, self._switch, target, **opts)
597        portnum = interface.open()
598        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
[183]599
600    def handle_delPort(self, params):
[202]601        port = self._switch.get_port(int(params['port']))
602        port.interface.close()
603        return {'entries': [self._portstat(port)]}
[183]604
605    def handle_shutPort(self, params):
[202]606        port = self._switch.get_port(int(params['port']))
607        port.shut = bool(params['shut'])
608        return {'entries': [self._portstat(port)]}
[183]609
[186]610    def _optparse_tap(self, opt):
611        return {'debug': self._debug}
[183]612
[186]613    def _optparse_client(self, opt):
614        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
615        if opt.get('insecure'):
616            args = {}
617        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
618        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
619        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
[183]620
[202]621    def _jsonrpc_response(self, id_=None, result=None, error=None):
622        res = {'jsonrpc': '2.0', 'id': id_}
623        if result:
624            res['result'] = result
625        if error:
626            res['error'] = error
627        self.finish(res)
628
[183]629    @staticmethod
[186]630    def _portstat(port):
631        return {
632            'port':   port.number,
[191]633            'type':   port.interface.IFTYPE,
[186]634            'target': port.interface.get_target(),
635            'tx':     port.tx,
636            'rx':     port.rx,
637            'shut':   port.shut,
638        }
[183]639
640
[190]641def start_sw(args):
[186]642    def daemonize(nochdir=False, noclose=False):
643        if os.fork() > 0:
644            sys.exit(0)
[134]645
[186]646        os.setsid()
[134]647
[186]648        if os.fork() > 0:
649            sys.exit(0)
[134]650
[186]651        if not nochdir:
652            os.chdir('/')
[134]653
[186]654        if not noclose:
655            os.umask(0)
656            sys.stdin.close()
657            sys.stdout.close()
658            sys.stderr.close()
659            os.close(0)
660            os.close(1)
661            os.close(2)
662            sys.stdin = open(os.devnull)
663            sys.stdout = open(os.devnull, 'a')
664            sys.stderr = open(os.devnull, 'a')
[134]665
[186]666    def checkabspath(ns, path):
[184]667        val = getattr(ns, path, '')
668        if not val.startswith('/'):
[202]669            raise ValueError('Invalid %: %s' % (path, val))
[184]670
671    def getsslopt(ns, key, cert):
672        kval = getattr(ns, key, None)
673        cval = getattr(ns, cert, None)
674        if kval and cval:
675            return {'keyfile': kval, 'certfile': cval}
676        elif kval or cval:
[202]677            raise ValueError('Both %s and %s are required' % (key, cert))
[184]678        return None
679
[186]680    def setrealpath(ns, *keys):
681        for k in keys:
682            v = getattr(ns, k, None)
683            if v is not None:
684                v = os.path.realpath(v)
685                open(v).close()  # check readable
686                setattr(ns, k, v)
687
[184]688    def setport(ns, port, isssl):
689        val = getattr(ns, port, None)
690        if val is None:
691            if isssl:
692                return setattr(ns, port, 443)
693            return setattr(ns, port, 80)
694        if not (0 <= val <= 65535):
[202]695            raise ValueError('Invalid %s: %s' % (port, val))
[184]696
697    def sethtpasswd(ns, htpasswd):
698        val = getattr(ns, htpasswd, None)
699        if val:
700            return setattr(ns, htpasswd, Htpasswd(val))
701
[183]702    #if args.debug:
703    #    websocket.enableTrace(True)
704
705    if args.ageout <= 0:
[202]706        raise ValueError('Invalid ageout: %s' % args.ageout)
[183]707
[186]708    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
709    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]710
[186]711    checkabspath(args, 'path')
712    checkabspath(args, 'ctlpath')
[183]713
[184]714    sslopt = getsslopt(args, 'sslkey', 'sslcert')
715    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]716
[184]717    setport(args, 'port', sslopt)
718    setport(args, 'ctlport', ctlsslopt)
[143]719
[184]720    sethtpasswd(args, 'htpasswd')
721    sethtpasswd(args, 'ctlhtpasswd')
[167]722
[183]723    ioloop = IOLoop.instance()
[167]724    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]725    switch = SwitchingHub(fdb, debug=args.debug)
[167]726
[184]727    if args.port == args.ctlport and args.host == args.ctlhost:
728        if args.path == args.ctlpath:
[202]729            raise ValueError('Same path/ctlpath on same host')
[184]730        if args.sslkey != args.ctlsslkey:
[202]731            raise ValueError('Different sslkey/ctlsslkey on same host')
[184]732        if args.sslcert != args.ctlsslcert:
[202]733            raise ValueError('Different sslcert/ctlsslcert on same host')
[133]734
[184]735        app = Application([
736            (args.path, EtherWebSocketHandler, {
737                'switch':   switch,
738                'htpasswd': args.htpasswd,
739                'debug':    args.debug,
740            }),
741            (args.ctlpath, EtherWebSocketControlHandler, {
742                'ioloop':   ioloop,
743                'switch':   switch,
744                'htpasswd': args.ctlhtpasswd,
745                'debug':    args.debug,
746            }),
747        ])
748        server = HTTPServer(app, ssl_options=sslopt)
749        server.listen(args.port, address=args.host)
[151]750
[184]751    else:
752        app = Application([(args.path, EtherWebSocketHandler, {
753            'switch':   switch,
754            'htpasswd': args.htpasswd,
755            'debug':    args.debug,
756        })])
757        server = HTTPServer(app, ssl_options=sslopt)
758        server.listen(args.port, address=args.host)
759
760        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
761            'ioloop':   ioloop,
762            'switch':   switch,
763            'htpasswd': args.ctlhtpasswd,
764            'debug':    args.debug,
765        })])
766        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
767        ctlserver.listen(args.ctlport, address=args.ctlhost)
768
[151]769    if not args.foreground:
770        daemonize()
771
[138]772    ioloop.start()
[133]773
774
[190]775def start_ctl(args):
[202]776    def request(args, method, params=None, id_=0):
[190]777        req = urllib2.Request(args.ctlurl)
778        req.add_header('Content-type', 'application/json')
779        if args.ctluser:
780            if not args.ctlpasswd:
781                args.ctlpasswd = getpass.getpass()
782            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
783            req.add_header('Authorization', 'Basic %s' % token)
[202]784        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
785        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
786        if params is not None:
787            data['params'] = params
788        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
[190]789
[199]790    def maxlen(dict_, key, min_):
[201]791        if not dict_:
792            return min_
[199]793        max_ = max(len(str(r[key])) for r in dict_)
794        return min_ if max_ < min_ else max_
795
796    def print_portlist(result):
797        pmax = maxlen(result, 'port', 4)
798        ymax = maxlen(result, 'type', 4)
799        smax = maxlen(result, 'shut', 5)
800        rmax = maxlen(result, 'rx', 2)
801        tmax = maxlen(result, 'tx', 2)
802        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
803              (pmax, ymax, smax, rmax, tmax)
804        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
805        for r in result:
806            shut = 'shut' if r['shut'] else 'up'
807            print(fmt %
808                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
809
[202]810    def print_error(error):
811        print(%s (%s)' % (error['message'], error['code']))
812        print('    %s' % error['data'])
813
[190]814    def handle_ctl_addport(args):
[202]815        result = request(args, 'addPort', {
[191]816            'type':    args.type,
817            'target':  args.target,
818            'options': {
[193]819                'insecure': args.insecure,
[191]820                'cacerts':  args.cacerts,
821                'user':     args.user,
822                'passwd':   args.passwd,
[202]823            },
824        })
825        if 'error' in result:
826            print_error(result['error'])
[199]827        else:
828            print_portlist(result['result']['entries'])
[190]829
830    def handle_ctl_shutport(args):
831        if args.port <= 0:
[202]832            raise ValueError('Invalid port: %d' % args.port)
833        result = request(args, 'shutPort', {
834            'port': args.port,
835            'shut': args.no,
836        })
837        if 'error' in result:
838            print_error(result['error'])
[199]839        else:
840            print_portlist(result['result']['entries'])
[190]841
842    def handle_ctl_delport(args):
843        if args.port <= 0:
[202]844            raise ValueError('Invalid port: %d' % args.port)
845        result = request(args, 'delPort', {'port': args.port})
846        if 'error' in result:
847            print_error(result['error'])
[199]848        else:
849            print_portlist(result['result']['entries'])
[190]850
851    def handle_ctl_listport(args):
[202]852        result = request(args, 'listPort')
853        if 'error' in result:
854            print_error(result['error'])
[199]855        else:
856            print_portlist(result['result']['entries'])
[190]857
[198]858    def handle_ctl_listfdb(args):
[202]859        result = request(args, 'listFdb')
860        if 'error' in result:
861            return print_error(result['error'])
[199]862        result = result['result']['entries']
[201]863        pmax = maxlen(result, 'port', 4)
[199]864        vmax = maxlen(result, 'vid', 4)
865        mmax = maxlen(result, 'mac', 3)
866        amax = maxlen(result, 'age', 3)
[201]867        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
868        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
[199]869        for r in result:
[201]870            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
[198]871
[199]872    locals()['handle_ctl_' + args.control_method](args)
[190]873
874
[186]875def main():
876    parser = argparse.ArgumentParser()
[190]877    subcommand = parser.add_subparsers(dest='subcommand')
[186]878
[190]879    # -- sw command parser
880    parser_s = subcommand.add_parser('sw')
881
[186]882    parser_s.add_argument('--debug', action='store_true', default=False)
883    parser_s.add_argument('--foreground', action='store_true', default=False)
[190]884    parser_s.add_argument('--ageout', type=int, default=300)
[186]885
[190]886    parser_s.add_argument('--path', default='/')
887    parser_s.add_argument('--host', default='')
888    parser_s.add_argument('--port', type=int)
889    parser_s.add_argument('--htpasswd')
890    parser_s.add_argument('--sslkey')
891    parser_s.add_argument('--sslcert')
[186]892
[190]893    parser_s.add_argument('--ctlpath', default='/ctl')
894    parser_s.add_argument('--ctlhost', default='')
895    parser_s.add_argument('--ctlport', type=int)
896    parser_s.add_argument('--ctlhtpasswd')
897    parser_s.add_argument('--ctlsslkey')
898    parser_s.add_argument('--ctlsslcert')
[186]899
[190]900    # -- ctl command parser
901    parser_c = subcommand.add_parser('ctl')
902    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
903    parser_c.add_argument('--ctluser')
904    parser_c.add_argument('--ctlpasswd')
905
906    control_method = parser_c.add_subparsers(dest='control_method')
907
908    parser_c_ap = control_method.add_parser('addport')
[191]909    parser_c_ap.add_argument(
910        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
[190]911    parser_c_ap.add_argument('target')
912    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
913    parser_c_ap.add_argument('--cacerts')
914    parser_c_ap.add_argument('--user')
915    parser_c_ap.add_argument('--passwd')
916
917    parser_c_sp = control_method.add_parser('shutport')
918    parser_c_sp.add_argument('port', type=int)
919    parser_c_sp.add_argument('--no', action='store_false', default=True)
920
921    parser_c_dp = control_method.add_parser('delport')
922    parser_c_dp.add_argument('port', type=int)
923
924    parser_c_lp = control_method.add_parser('listport')
925
[198]926    parser_c_lf = control_method.add_parser('listfdb')
927
[190]928    # -- go
[186]929    args = parser.parse_args()
930    globals()['start_' + args.subcommand](args)
931
932
[133]933if __name__ == '__main__':
934    main()
Note: See TracBrowser for help on using the repository browser.