source: etherws/trunk/etherws.py @ 209

Revision 209, 30.7 KB checked in by atzm, 12 years ago (diff)
  • fixed behavior when "addport client" password was empty
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[186]4#                          Ethernet over WebSocket
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[133]11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
[156]42import ssl
[160]43import time
[183]44import json
[175]45import fcntl
[150]46import base64
[190]47import urllib2
[150]48import hashlib
[151]49import getpass
[133]50import argparse
[165]51import traceback
[133]52
[185]53import tornado
[133]54import websocket
55
[183]56from tornado.web import Application, RequestHandler
[182]57from tornado.websocket import WebSocketHandler
[183]58from tornado.httpserver import HTTPServer
59from tornado.ioloop import IOLoop
60
[166]61from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]62
[166]63
[160]64class DebugMixIn(object):
[166]65    def dprintf(self, msg, func=lambda: ()):
[160]66        if self._debug:
67            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]68            sys.stderr.write(prefix + (msg % func()))
[160]69
70
[164]71class EthernetFrame(object):
72    def __init__(self, data):
73        self.data = data
74
[176]75    @property
76    def dst_multicast(self):
77        return ord(self.data[0]) & 1
[164]78
79    @property
[176]80    def src_multicast(self):
81        return ord(self.data[6]) & 1
82
83    @property
[164]84    def dst_mac(self):
85        return self.data[:6]
86
87    @property
88    def src_mac(self):
89        return self.data[6:12]
90
91    @property
92    def tagged(self):
93        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
94
95    @property
96    def vid(self):
97        if self.tagged:
98            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]99        return 0
[164]100
[198]101    @staticmethod
102    def format_mac(mac, sep=':'):
103        return sep.join(b.encode('hex') for b in mac)
[164]104
[198]105
[166]106class FDB(DebugMixIn):
[196]107    class Entry(object):
[195]108        def __init__(self, port, ageout):
109            self.port = port
110            self._time = time.time()
111            self._ageout = ageout
112
113        @property
114        def age(self):
115            return time.time() - self._time
116
117        @property
118        def agedout(self):
119            return self.age > self._ageout
120
[167]121    def __init__(self, ageout, debug=False):
[164]122        self._ageout = ageout
123        self._debug = debug
[195]124        self._table = {}
[164]125
[195]126    def _set_entry(self, vid, mac, port):
127        if vid not in self._table:
128            self._table[vid] = {}
[196]129        self._table[vid][mac] = self.Entry(port, self._ageout)
[164]130
[195]131    def _del_entry(self, vid, mac):
132        if vid in self._table:
133            if mac in self._table[vid]:
134                del self._table[vid][mac]
135            if not self._table[vid]:
136                del self._table[vid]
[164]137
[207]138    def _get_entry(self, vid, mac):
[195]139        try:
140            entry = self._table[vid][mac]
141        except KeyError:
[164]142            return None
143
[195]144        if not entry.agedout:
145            return entry
[164]146
[195]147        self._del_entry(vid, mac)
148        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
149                     lambda: (entry.port.number, vid, mac.encode('hex')))
[164]150
[208]151    def each(self):
[207]152        for vid in sorted(self._table.iterkeys()):
153            for mac in sorted(self._table[vid].iterkeys()):
154                entry = self._get_entry(vid, mac)
155                if entry:
156                    yield (vid, mac, entry)
[195]157
158    def lookup(self, frame):
159        mac = frame.dst_mac
160        vid = frame.vid
[207]161        entry = self._get_entry(vid, mac)
[195]162        return getattr(entry, 'port', None)
163
[166]164    def learn(self, port, frame):
165        mac = frame.src_mac
166        vid = frame.vid
[195]167        self._set_entry(vid, mac, port)
[183]168        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
169                     lambda: (port.number, vid, mac.encode('hex')))
[166]170
[164]171    def delete(self, port):
[208]172        for vid, mac, entry in self.each():
[207]173            if entry.port.number == port.number:
174                self._del_entry(vid, mac)
175                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
176                             lambda: (port.number, vid, mac.encode('hex')))
[164]177
178
[195]179class SwitchingHub(DebugMixIn):
180    class Port(object):
181        def __init__(self, number, interface):
182            self.number = number
183            self.interface = interface
184            self.tx = 0
185            self.rx = 0
186            self.shut = False
[183]187
[195]188        @staticmethod
189        def cmp_by_number(x, y):
190            return cmp(x.number, y.number)
[183]191
[166]192    def __init__(self, fdb, debug=False):
[197]193        self.fdb = fdb
[133]194        self._debug = debug
[183]195        self._table = {}
196        self._next = 1
[133]197
[183]198    @property
199    def portlist(self):
[195]200        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
[133]201
[183]202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
[186]206        try:
[187]207            self._set_privattr('portnum', interface, self._next)  # XXX
[195]208            self._table[self._next] = self.Port(self._next, interface)
[186]209            return self._next
210        finally:
211            self._next += 1
[183]212
213    def unregister_port(self, interface):
[187]214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
[197]216        self.fdb.delete(self._table[portnum])
[187]217        del self._table[portnum]
[183]218
219    def send(self, dst_interfaces, frame):
[187]220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
[195]223        ports = sorted(ports, cmp=self.Port.cmp_by_number)
[183]224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
[187]237        port = self._table[self._get_privattr('portnum', src_interface)]
[183]238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
[166]244        try:
[176]245            if not frame.src_multicast:
[197]246                self.fdb.learn(src_port, frame)
[133]247
[176]248            if not frame.dst_multicast:
[197]249                dst_port = self.fdb.lookup(frame)
[164]250
[166]251                if dst_port:
[183]252                    self.send([dst_port.interface], frame)
[166]253                    return
[133]254
[187]255            ports = set(self.portlist) - set([src_port])
[183]256            self.send((p.interface for p in ports), frame)
[162]257
[166]258        except:  # ex. received invalid frame
259            traceback.print_exc()
[133]260
[187]261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
[164]263
[187]264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
[179]274class Htpasswd(object):
275    def __init__(self, path):
276        self._path = path
277        self._stat = None
278        self._data = {}
279
280    def auth(self, name, passwd):
281        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
282        return self._data.get(name) == passwd
283
284    def load(self):
285        old_stat = self._stat
286
287        with open(self._path) as fp:
288            fileno = fp.fileno()
289            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
290            self._stat = os.fstat(fileno)
291
292            unchanged = old_stat and \
293                        old_stat.st_ino == self._stat.st_ino and \
294                        old_stat.st_dev == self._stat.st_dev and \
295                        old_stat.st_mtime == self._stat.st_mtime
296
297            if not unchanged:
298                self._data = self._parse(fp)
299
300        return self
301
302    def _parse(self, fp):
303        data = {}
304        for line in fp:
305            line = line.strip()
306            if 0 <= line.find(':'):
307                name, passwd = line.split(':', 1)
308                if passwd.startswith('{SHA}'):
309                    data[name] = passwd[5:]
310        return data
311
312
[182]313class BasicAuthMixIn(object):
314    def _execute(self, transforms, *args, **kwargs):
315        def do_execute():
316            sp = super(BasicAuthMixIn, self)
317            return sp._execute(transforms, *args, **kwargs)
318
319        def auth_required():
[185]320            stream = getattr(self, 'stream', self.request.connection.stream)
321            stream.write(tornado.escape.utf8(
[182]322                'HTTP/1.1 401 Authorization Required\r\n'
323                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
324            ))
[185]325            stream.close()
[182]326
327        try:
328            if not self._htpasswd:
329                return do_execute()
330
331            creds = self.request.headers.get('Authorization')
332
333            if not creds or not creds.startswith('Basic '):
334                return auth_required()
335
336            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
337
338            if self._htpasswd.load().auth(name, passwd):
339                return do_execute()
340        except:
341            traceback.print_exc()
342
343        return auth_required()
344
345
[186]346class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[191]347    IFTYPE = 'server'
348
[186]349    def __init__(self, app, req, switch, htpasswd=None, debug=False):
350        super(EtherWebSocketHandler, self).__init__(app, req)
351        self._switch = switch
352        self._htpasswd = htpasswd
353        self._debug = debug
354
[203]355    @property
356    def target(self):
357        return ':'.join(str(e) for e in self.request.connection.address)
[186]358
359    def open(self):
360        try:
361            return self._switch.register_port(self)
362        finally:
363            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
364
365    def on_message(self, message):
366        self._switch.receive(self, EthernetFrame(message))
367
368    def on_close(self):
369        self._switch.unregister_port(self)
370        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
371
372
[166]373class TapHandler(DebugMixIn):
[191]374    IFTYPE = 'tap'
[166]375    READ_SIZE = 65535
376
[178]377    def __init__(self, ioloop, switch, dev, debug=False):
378        self._ioloop = ioloop
[166]379        self._switch = switch
380        self._dev = dev
381        self._debug = debug
382        self._tap = None
383
[203]384    @property
385    def target(self):
[183]386        if self.closed:
387            return self._dev
388        return self._tap.name
389
[186]390    @property
391    def closed(self):
392        return not self._tap
393
[166]394    def open(self):
395        if not self.closed:
[202]396            raise ValueError('Already opened')
[166]397        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
398        self._tap.up()
[178]399        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]400        return self._switch.register_port(self)
[166]401
402    def close(self):
403        if self.closed:
404            raise ValueError('I/O operation on closed tap')
[186]405        self._switch.unregister_port(self)
[178]406        self._ioloop.remove_handler(self.fileno())
[166]407        self._tap.close()
408        self._tap = None
409
410    def fileno(self):
411        if self.closed:
412            raise ValueError('I/O operation on closed tap')
413        return self._tap.fileno()
414
415    def write_message(self, message, binary=False):
416        if self.closed:
417            raise ValueError('I/O operation on closed tap')
418        self._tap.write(message)
419
[138]420    def __call__(self, fd, events):
[166]421        try:
[183]422            self._switch.receive(self, EthernetFrame(self._read()))
[166]423            return
424        except:
425            traceback.print_exc()
[178]426        self.close()
[166]427
428    def _read(self):
429        if self.closed:
430            raise ValueError('I/O operation on closed tap')
[162]431        buf = []
432        while True:
[166]433            buf.append(self._tap.read(self.READ_SIZE))
434            if len(buf[-1]) < self.READ_SIZE:
[162]435                break
[166]436        return ''.join(buf)
[162]437
438
[160]439class EtherWebSocketClient(DebugMixIn):
[191]440    IFTYPE = 'client'
441
[181]442    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]443        self._ioloop = ioloop
[166]444        self._switch = switch
[151]445        self._url = url
[181]446        self._ssl = ssl_
[160]447        self._debug = debug
[166]448        self._sock = None
[151]449        self._options = {}
450
[174]451        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
452            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]453            auth = ['Authorization: Basic %s' % token]
454            self._options['header'] = auth
455
[203]456    @property
457    def target(self):
[186]458        return self._url
459
[160]460    @property
461    def closed(self):
462        return not self._sock
463
[151]464    def open(self):
[181]465        sslwrap = websocket._SSLSocketWrapper
466
[160]467        if not self.closed:
[202]468            raise websocket.WebSocketException('Already opened')
[151]469
[181]470        if self._ssl:
471            websocket._SSLSocketWrapper = self._ssl
472
473        try:
474            self._sock = websocket.WebSocket()
475            self._sock.connect(self._url, **self._options)
476            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]477            return self._switch.register_port(self)
[181]478        finally:
479            websocket._SSLSocketWrapper = sslwrap
[186]480            self.dprintf('connected: %s\n', lambda: self._url)
[181]481
[151]482    def close(self):
[160]483        if self.closed:
[202]484            raise websocket.WebSocketException('Already closed')
[186]485        self._switch.unregister_port(self)
[178]486        self._ioloop.remove_handler(self.fileno())
[151]487        self._sock.close()
488        self._sock = None
[164]489        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]490
[165]491    def fileno(self):
492        if self.closed:
[202]493            raise websocket.WebSocketException('Closed socket')
[165]494        return self._sock.io_sock.fileno()
495
[151]496    def write_message(self, message, binary=False):
[160]497        if self.closed:
[202]498            raise websocket.WebSocketException('Closed socket')
[151]499        if binary:
500            flag = websocket.ABNF.OPCODE_BINARY
[160]501        else:
502            flag = websocket.ABNF.OPCODE_TEXT
[151]503        self._sock.send(message, flag)
504
[165]505    def __call__(self, fd, events):
[151]506        try:
[165]507            data = self._sock.recv()
508            if data is not None:
[183]509                self._switch.receive(self, EthernetFrame(data))
[165]510                return
511        except:
512            traceback.print_exc()
[178]513        self.close()
[151]514
515
[183]516class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
517    NAMESPACE = 'etherws.control'
[191]518    IFTYPES = {
519        TapHandler.IFTYPE:           TapHandler,
520        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
[186]521    }
[183]522
523    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
524        super(EtherWebSocketControlHandler, self).__init__(app, req)
525        self._ioloop = ioloop
526        self._switch = switch
527        self._htpasswd = htpasswd
528        self._debug = debug
529
530    def post(self):
[202]531        try:
532            request = json.loads(self.request.body)
533        except Exception as e:
534            return self._jsonrpc_response(error={
535                'code':    0 - 32700,
536                'message': 'Parse error',
537                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
538            })
[183]539
540        try:
[202]541            id_ = request.get('id')
542            params = request.get('params')
543            version = request['jsonrpc']
544            method = request['method']
545            if version != '2.0':
546                raise ValueError('Invalid JSON-RPC version: %s' % version)
547        except Exception as e:
548            return self._jsonrpc_response(id_=id_, error={
549                'code':    0 - 32600,
550                'message': 'Invalid Request',
551                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
552            })
[183]553
[202]554        try:
[183]555            if not method.startswith(self.NAMESPACE + '.'):
[202]556                raise ValueError('Invalid method namespace: %s' % method)
[183]557            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
[202]558            handler = getattr(self, handler)
559        except Exception as e:
560            return self._jsonrpc_response(id_=id_, error={
561                'code':    0 - 32601,
562                'message': 'Method not found',
563                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
564            })
[183]565
[202]566        try:
567            return self._jsonrpc_response(id_=id_, result=handler(params))
[183]568        except Exception as e:
569            traceback.print_exc()
[202]570            return self._jsonrpc_response(id_=id_, error={
571                'code':    0 - 32602,
572                'message': 'Invalid params',
573                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
574            })
[183]575
[198]576    def handle_listFdb(self, params):
577        list_ = []
[208]578        for vid, mac, entry in self._switch.fdb.each():
[207]579            list_.append({
580                'vid':  vid,
581                'mac':  EthernetFrame.format_mac(mac),
582                'port': entry.port.number,
583                'age':  int(entry.age),
584            })
[199]585        return {'entries': list_}
[198]586
[183]587    def handle_listPort(self, params):
[202]588        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
[183]589
590    def handle_addPort(self, params):
[202]591        type_ = params['type']
592        target = params['target']
593        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
594        cls = self.IFTYPES[type_]
595        interface = cls(self._ioloop, self._switch, target, **opts)
596        portnum = interface.open()
597        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
[183]598
599    def handle_delPort(self, params):
[202]600        port = self._switch.get_port(int(params['port']))
601        port.interface.close()
602        return {'entries': [self._portstat(port)]}
[183]603
604    def handle_shutPort(self, params):
[202]605        port = self._switch.get_port(int(params['port']))
606        port.shut = bool(params['shut'])
607        return {'entries': [self._portstat(port)]}
[183]608
[186]609    def _optparse_tap(self, opt):
610        return {'debug': self._debug}
[183]611
[186]612    def _optparse_client(self, opt):
613        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
614        if opt.get('insecure'):
615            args = {}
616        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
617        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
618        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
[183]619
[202]620    def _jsonrpc_response(self, id_=None, result=None, error=None):
621        res = {'jsonrpc': '2.0', 'id': id_}
622        if result:
623            res['result'] = result
624        if error:
625            res['error'] = error
626        self.finish(res)
627
[183]628    @staticmethod
[186]629    def _portstat(port):
630        return {
631            'port':   port.number,
[191]632            'type':   port.interface.IFTYPE,
[203]633            'target': port.interface.target,
[186]634            'tx':     port.tx,
635            'rx':     port.rx,
636            'shut':   port.shut,
637        }
[183]638
639
[206]640def _print_error(error):
[205]641    print(%s (%s)' % (error['message'], error['code']))
642    print('    %s' % error['data'])
643
644
[206]645def _start_sw(args):
[186]646    def daemonize(nochdir=False, noclose=False):
647        if os.fork() > 0:
648            sys.exit(0)
[134]649
[186]650        os.setsid()
[134]651
[186]652        if os.fork() > 0:
653            sys.exit(0)
[134]654
[186]655        if not nochdir:
656            os.chdir('/')
[134]657
[186]658        if not noclose:
659            os.umask(0)
660            sys.stdin.close()
661            sys.stdout.close()
662            sys.stderr.close()
663            os.close(0)
664            os.close(1)
665            os.close(2)
666            sys.stdin = open(os.devnull)
667            sys.stdout = open(os.devnull, 'a')
668            sys.stderr = open(os.devnull, 'a')
[134]669
[186]670    def checkabspath(ns, path):
[184]671        val = getattr(ns, path, '')
672        if not val.startswith('/'):
[202]673            raise ValueError('Invalid %: %s' % (path, val))
[184]674
675    def getsslopt(ns, key, cert):
676        kval = getattr(ns, key, None)
677        cval = getattr(ns, cert, None)
678        if kval and cval:
679            return {'keyfile': kval, 'certfile': cval}
680        elif kval or cval:
[202]681            raise ValueError('Both %s and %s are required' % (key, cert))
[184]682        return None
683
[186]684    def setrealpath(ns, *keys):
685        for k in keys:
686            v = getattr(ns, k, None)
687            if v is not None:
688                v = os.path.realpath(v)
689                open(v).close()  # check readable
690                setattr(ns, k, v)
691
[184]692    def setport(ns, port, isssl):
693        val = getattr(ns, port, None)
694        if val is None:
695            if isssl:
696                return setattr(ns, port, 443)
697            return setattr(ns, port, 80)
698        if not (0 <= val <= 65535):
[202]699            raise ValueError('Invalid %s: %s' % (port, val))
[184]700
701    def sethtpasswd(ns, htpasswd):
702        val = getattr(ns, htpasswd, None)
703        if val:
704            return setattr(ns, htpasswd, Htpasswd(val))
705
[183]706    #if args.debug:
707    #    websocket.enableTrace(True)
708
709    if args.ageout <= 0:
[202]710        raise ValueError('Invalid ageout: %s' % args.ageout)
[183]711
[186]712    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
713    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]714
[186]715    checkabspath(args, 'path')
716    checkabspath(args, 'ctlpath')
[183]717
[184]718    sslopt = getsslopt(args, 'sslkey', 'sslcert')
719    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]720
[184]721    setport(args, 'port', sslopt)
722    setport(args, 'ctlport', ctlsslopt)
[143]723
[184]724    sethtpasswd(args, 'htpasswd')
725    sethtpasswd(args, 'ctlhtpasswd')
[167]726
[183]727    ioloop = IOLoop.instance()
[167]728    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]729    switch = SwitchingHub(fdb, debug=args.debug)
[167]730
[184]731    if args.port == args.ctlport and args.host == args.ctlhost:
732        if args.path == args.ctlpath:
[202]733            raise ValueError('Same path/ctlpath on same host')
[184]734        if args.sslkey != args.ctlsslkey:
[202]735            raise ValueError('Different sslkey/ctlsslkey on same host')
[184]736        if args.sslcert != args.ctlsslcert:
[202]737            raise ValueError('Different sslcert/ctlsslcert on same host')
[133]738
[184]739        app = Application([
740            (args.path, EtherWebSocketHandler, {
741                'switch':   switch,
742                'htpasswd': args.htpasswd,
743                'debug':    args.debug,
744            }),
745            (args.ctlpath, EtherWebSocketControlHandler, {
746                'ioloop':   ioloop,
747                'switch':   switch,
748                'htpasswd': args.ctlhtpasswd,
749                'debug':    args.debug,
750            }),
751        ])
752        server = HTTPServer(app, ssl_options=sslopt)
753        server.listen(args.port, address=args.host)
[151]754
[184]755    else:
756        app = Application([(args.path, EtherWebSocketHandler, {
757            'switch':   switch,
758            'htpasswd': args.htpasswd,
759            'debug':    args.debug,
760        })])
761        server = HTTPServer(app, ssl_options=sslopt)
762        server.listen(args.port, address=args.host)
763
764        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
765            'ioloop':   ioloop,
766            'switch':   switch,
767            'htpasswd': args.ctlhtpasswd,
768            'debug':    args.debug,
769        })])
770        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
771        ctlserver.listen(args.ctlport, address=args.ctlhost)
772
[151]773    if not args.foreground:
774        daemonize()
775
[138]776    ioloop.start()
[133]777
778
[206]779def _start_ctl(args):
[202]780    def request(args, method, params=None, id_=0):
[190]781        req = urllib2.Request(args.ctlurl)
782        req.add_header('Content-type', 'application/json')
783        if args.ctluser:
784            if not args.ctlpasswd:
[209]785                args.ctlpasswd = getpass.getpass('Control Password: ')
[190]786            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
787            req.add_header('Authorization', 'Basic %s' % token)
[202]788        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
789        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
790        if params is not None:
791            data['params'] = params
792        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
[190]793
[199]794    def maxlen(dict_, key, min_):
[201]795        if not dict_:
796            return min_
[199]797        max_ = max(len(str(r[key])) for r in dict_)
798        return min_ if max_ < min_ else max_
799
800    def print_portlist(result):
801        pmax = maxlen(result, 'port', 4)
802        ymax = maxlen(result, 'type', 4)
803        smax = maxlen(result, 'shut', 5)
804        rmax = maxlen(result, 'rx', 2)
805        tmax = maxlen(result, 'tx', 2)
806        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
807              (pmax, ymax, smax, rmax, tmax)
808        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
809        for r in result:
810            shut = 'shut' if r['shut'] else 'up'
811            print(fmt %
812                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
813
[190]814    def handle_ctl_addport(args):
[205]815        opts = {
816            'user':     getattr(args, 'user', None),
817            'passwd':   getattr(args, 'passwd', None),
818            'cacerts':  getattr(args, 'cacerts', None),
819            'insecure': getattr(args, 'insecure', None),
820        }
821        if args.iftype == EtherWebSocketClient.IFTYPE:
822            if not args.target.startswith('ws://') and \
823               not args.target.startswith('wss://'):
824                raise ValueError('Invalid target URL scheme: %s' % args.target)
825        if not opts['user'] and opts['passwd']:
826            raise ValueError('Authentication required but username empty')
[209]827        if opts['user'] and not opts['passwd']:
828            opts['passwd'] = getpass.getpass('Client Password: ')
[202]829        result = request(args, 'addPort', {
[204]830            'type':    args.iftype,
[191]831            'target':  args.target,
[205]832            'options': opts,
[202]833        })
834        if 'error' in result:
[206]835            _print_error(result['error'])
[199]836        else:
837            print_portlist(result['result']['entries'])
[190]838
839    def handle_ctl_shutport(args):
840        if args.port <= 0:
[202]841            raise ValueError('Invalid port: %d' % args.port)
842        result = request(args, 'shutPort', {
843            'port': args.port,
844            'shut': args.no,
845        })
846        if 'error' in result:
[206]847            _print_error(result['error'])
[199]848        else:
849            print_portlist(result['result']['entries'])
[190]850
851    def handle_ctl_delport(args):
852        if args.port <= 0:
[202]853            raise ValueError('Invalid port: %d' % args.port)
854        result = request(args, 'delPort', {'port': args.port})
855        if 'error' in result:
[206]856            _print_error(result['error'])
[199]857        else:
858            print_portlist(result['result']['entries'])
[190]859
860    def handle_ctl_listport(args):
[202]861        result = request(args, 'listPort')
862        if 'error' in result:
[206]863            _print_error(result['error'])
[199]864        else:
865            print_portlist(result['result']['entries'])
[190]866
[198]867    def handle_ctl_listfdb(args):
[202]868        result = request(args, 'listFdb')
869        if 'error' in result:
[206]870            return _print_error(result['error'])
[199]871        result = result['result']['entries']
[201]872        pmax = maxlen(result, 'port', 4)
[199]873        vmax = maxlen(result, 'vid', 4)
874        mmax = maxlen(result, 'mac', 3)
875        amax = maxlen(result, 'age', 3)
[201]876        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
877        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
[199]878        for r in result:
[201]879            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
[198]880
[199]881    locals()['handle_ctl_' + args.control_method](args)
[190]882
883
[206]884def _main():
[186]885    parser = argparse.ArgumentParser()
[190]886    subcommand = parser.add_subparsers(dest='subcommand')
[186]887
[204]888    # - sw
889    parser_sw = subcommand.add_parser('sw')
[190]890
[204]891    parser_sw.add_argument('--debug', action='store_true', default=False)
892    parser_sw.add_argument('--foreground', action='store_true', default=False)
893    parser_sw.add_argument('--ageout', type=int, default=300)
[186]894
[204]895    parser_sw.add_argument('--path', default='/')
896    parser_sw.add_argument('--host', default='')
897    parser_sw.add_argument('--port', type=int)
898    parser_sw.add_argument('--htpasswd')
899    parser_sw.add_argument('--sslkey')
900    parser_sw.add_argument('--sslcert')
[186]901
[204]902    parser_sw.add_argument('--ctlpath', default='/ctl')
903    parser_sw.add_argument('--ctlhost', default='')
904    parser_sw.add_argument('--ctlport', type=int)
905    parser_sw.add_argument('--ctlhtpasswd')
906    parser_sw.add_argument('--ctlsslkey')
907    parser_sw.add_argument('--ctlsslcert')
[186]908
[204]909    # - ctl
910    parser_ctl = subcommand.add_parser('ctl')
911    parser_ctl.add_argument('--ctlurl', default='http://localhost/ctl')
912    parser_ctl.add_argument('--ctluser')
913    parser_ctl.add_argument('--ctlpasswd')
[190]914
[204]915    control_method = parser_ctl.add_subparsers(dest='control_method')
[190]916
[204]917    # -- ctl addport
918    parser_ctl_addport = control_method.add_parser('addport')
919    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
[190]920
[204]921    # --- ctl addport tap
922    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE)
923    parser_ctl_addport_tap.add_argument('target')
[190]924
[204]925    # --- ctl addport client
926    parser_ctl_addport_client = iftype.add_parser(EtherWebSocketClient.IFTYPE)
927    parser_ctl_addport_client.add_argument('target')
928    parser_ctl_addport_client.add_argument('--user')
929    parser_ctl_addport_client.add_argument('--passwd')
930    parser_ctl_addport_client.add_argument('--cacerts')
931    parser_ctl_addport_client.add_argument(
932        '--insecure', action='store_true', default=False)
[190]933
[204]934    # -- ctl shutport
935    parser_ctl_shutport = control_method.add_parser('shutport')
936    parser_ctl_shutport.add_argument('port', type=int)
937    parser_ctl_shutport.add_argument(
938        '--no', action='store_false', default=True)
[190]939
[204]940    # -- ctl delport
941    parser_ctl_delport = control_method.add_parser('delport')
942    parser_ctl_delport.add_argument('port', type=int)
[198]943
[204]944    # -- ctl listport
945    parser_ctl_listport = control_method.add_parser('listport')
946
947    # -- ctl listfdb
948    parser_ctl_listfdb = control_method.add_parser('listfdb')
949
[190]950    # -- go
[186]951    args = parser.parse_args()
952
[205]953    try:
[206]954        globals()['_start_' + args.subcommand](args)
[205]955    except Exception as e:
[206]956        _print_error({
[205]957            'code':    0 - 32603,
958            'message': 'Internal error',
959            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
960        })
[186]961
[205]962
[133]963if __name__ == '__main__':
[206]964    _main()
Note: See TracBrowser for help on using the repository browser.