source: etherws/trunk/etherws.py @ 214

Revision 214, 35.1 KB checked in by atzm, 12 years ago (diff)
  • change default ctl bind address
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import urllib2
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import tornado
55import websocket
56
57from tornado.web import Application, RequestHandler
58from tornado.websocket import WebSocketHandler
59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
63
64
65class DebugMixIn(object):
66    def dprintf(self, msg, func=lambda: ()):
67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
69            sys.stderr.write(prefix + (msg % func()))
70
71
72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
79
80    @property
81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
100        return 0
101
102    @staticmethod
103    def format_mac(mac, sep=':'):
104        return sep.join(b.encode('hex') for b in mac)
105
106
107class FDB(DebugMixIn):
108    class Entry(object):
109        def __init__(self, port, ageout):
110            self.port = port
111            self._time = time.time()
112            self._ageout = ageout
113
114        @property
115        def age(self):
116            return time.time() - self._time
117
118        @property
119        def agedout(self):
120            return self.age > self._ageout
121
122    def __init__(self, ageout, debug=False):
123        self._ageout = ageout
124        self._debug = debug
125        self._table = {}
126
127    def _set_entry(self, vid, mac, port):
128        if vid not in self._table:
129            self._table[vid] = {}
130        self._table[vid][mac] = self.Entry(port, self._ageout)
131
132    def _del_entry(self, vid, mac):
133        if vid in self._table:
134            if mac in self._table[vid]:
135                del self._table[vid][mac]
136            if not self._table[vid]:
137                del self._table[vid]
138
139    def _get_entry(self, vid, mac):
140        try:
141            entry = self._table[vid][mac]
142        except KeyError:
143            return None
144
145        if not entry.agedout:
146            return entry
147
148        self._del_entry(vid, mac)
149        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
150                     lambda: (entry.port.number, vid, mac.encode('hex')))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
170                     lambda: (port.number, vid, mac.encode('hex')))
171
172    def delete(self, port):
173        for vid, mac, entry in self.each():
174            if entry.port.number == port.number:
175                self._del_entry(vid, mac)
176                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                             lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug=False):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349
350    def __init__(self, app, req, switch, htpasswd=None, debug=False):
351        super(EtherWebSocketHandler, self).__init__(app, req)
352        self._switch = switch
353        self._htpasswd = htpasswd
354        self._debug = debug
355
356    @property
357    def target(self):
358        return ':'.join(str(e) for e in self.request.connection.address)
359
360    def open(self):
361        try:
362            return self._switch.register_port(self)
363        finally:
364            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
365
366    def on_message(self, message):
367        self._switch.receive(self, EthernetFrame(message))
368
369    def on_close(self):
370        self._switch.unregister_port(self)
371        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
372
373
374class TapHandler(DebugMixIn):
375    IFTYPE = 'tap'
376    READ_SIZE = 65535
377
378    def __init__(self, ioloop, switch, dev, debug=False):
379        self._ioloop = ioloop
380        self._switch = switch
381        self._dev = dev
382        self._debug = debug
383        self._tap = None
384
385    @property
386    def target(self):
387        if self.closed:
388            return self._dev
389        return self._tap.name
390
391    @property
392    def closed(self):
393        return not self._tap
394
395    @property
396    def address(self):
397        if self.closed:
398            raise ValueError('I/O operation on closed tap')
399        try:
400            return self._tap.addr
401        except:
402            return ''
403
404    @property
405    def netmask(self):
406        if self.closed:
407            raise ValueError('I/O operation on closed tap')
408        try:
409            return self._tap.netmask
410        except:
411            return ''
412
413    @property
414    def mtu(self):
415        if self.closed:
416            raise ValueError('I/O operation on closed tap')
417        return self._tap.mtu
418
419    @address.setter
420    def address(self, address):
421        if self.closed:
422            raise ValueError('I/O operation on closed tap')
423        self._tap.addr = address
424
425    @netmask.setter
426    def netmask(self, netmask):
427        if self.closed:
428            raise ValueError('I/O operation on closed tap')
429        self._tap.netmask = netmask
430
431    @mtu.setter
432    def mtu(self, mtu):
433        if self.closed:
434            raise ValueError('I/O operation on closed tap')
435        self._tap.mtu = mtu
436
437    def open(self):
438        if not self.closed:
439            raise ValueError('Already opened')
440        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
441        self._tap.up()
442        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
443        return self._switch.register_port(self)
444
445    def close(self):
446        if self.closed:
447            raise ValueError('I/O operation on closed tap')
448        self._switch.unregister_port(self)
449        self._ioloop.remove_handler(self.fileno())
450        self._tap.close()
451        self._tap = None
452
453    def fileno(self):
454        if self.closed:
455            raise ValueError('I/O operation on closed tap')
456        return self._tap.fileno()
457
458    def write_message(self, message, binary=False):
459        if self.closed:
460            raise ValueError('I/O operation on closed tap')
461        self._tap.write(message)
462
463    def __call__(self, fd, events):
464        try:
465            self._switch.receive(self, EthernetFrame(self._read()))
466            return
467        except:
468            traceback.print_exc()
469        self.close()
470
471    def _read(self):
472        if self.closed:
473            raise ValueError('I/O operation on closed tap')
474        buf = []
475        while True:
476            buf.append(self._tap.read(self.READ_SIZE))
477            if len(buf[-1]) < self.READ_SIZE:
478                break
479        return ''.join(buf)
480
481
482class EtherWebSocketClient(DebugMixIn):
483    IFTYPE = 'client'
484
485    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
486        self._ioloop = ioloop
487        self._switch = switch
488        self._url = url
489        self._ssl = ssl_
490        self._debug = debug
491        self._sock = None
492        self._options = {}
493
494        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
495            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
496            auth = ['Authorization: Basic %s' % token]
497            self._options['header'] = auth
498
499    @property
500    def target(self):
501        return self._url
502
503    @property
504    def closed(self):
505        return not self._sock
506
507    def open(self):
508        sslwrap = websocket._SSLSocketWrapper
509
510        if not self.closed:
511            raise websocket.WebSocketException('Already opened')
512
513        if self._ssl:
514            websocket._SSLSocketWrapper = self._ssl
515
516        try:
517            self._sock = websocket.WebSocket()
518            self._sock.connect(self._url, **self._options)
519            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
520            self.dprintf('connected: %s\n', lambda: self._url)
521            return self._switch.register_port(self)
522        finally:
523            websocket._SSLSocketWrapper = sslwrap
524
525    def close(self):
526        if self.closed:
527            raise websocket.WebSocketException('Already closed')
528        self._switch.unregister_port(self)
529        self._ioloop.remove_handler(self.fileno())
530        self._sock.close()
531        self._sock = None
532        self.dprintf('disconnected: %s\n', lambda: self._url)
533
534    def fileno(self):
535        if self.closed:
536            raise websocket.WebSocketException('Closed socket')
537        return self._sock.io_sock.fileno()
538
539    def write_message(self, message, binary=False):
540        if self.closed:
541            raise websocket.WebSocketException('Closed socket')
542        if binary:
543            flag = websocket.ABNF.OPCODE_BINARY
544        else:
545            flag = websocket.ABNF.OPCODE_TEXT
546        self._sock.send(message, flag)
547
548    def __call__(self, fd, events):
549        try:
550            data = self._sock.recv()
551            if data is not None:
552                self._switch.receive(self, EthernetFrame(data))
553                return
554        except:
555            traceback.print_exc()
556        self.close()
557
558
559class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
560    NAMESPACE = 'etherws.control'
561    IFTYPES = {
562        TapHandler.IFTYPE:           TapHandler,
563        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
564    }
565
566    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
567        super(EtherWebSocketControlHandler, self).__init__(app, req)
568        self._ioloop = ioloop
569        self._switch = switch
570        self._htpasswd = htpasswd
571        self._debug = debug
572
573    def post(self):
574        try:
575            request = json.loads(self.request.body)
576        except Exception as e:
577            return self._jsonrpc_response(error={
578                'code':    0 - 32700,
579                'message': 'Parse error',
580                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
581            })
582
583        try:
584            id_ = request.get('id')
585            params = request.get('params')
586            version = request['jsonrpc']
587            method = request['method']
588            if version != '2.0':
589                raise ValueError('Invalid JSON-RPC version: %s' % version)
590        except Exception as e:
591            return self._jsonrpc_response(id_=id_, error={
592                'code':    0 - 32600,
593                'message': 'Invalid Request',
594                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
595            })
596
597        try:
598            if not method.startswith(self.NAMESPACE + '.'):
599                raise ValueError('Invalid method namespace: %s' % method)
600            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
601            handler = getattr(self, handler)
602        except Exception as e:
603            return self._jsonrpc_response(id_=id_, error={
604                'code':    0 - 32601,
605                'message': 'Method not found',
606                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
607            })
608
609        try:
610            return self._jsonrpc_response(id_=id_, result=handler(params))
611        except Exception as e:
612            traceback.print_exc()
613            return self._jsonrpc_response(id_=id_, error={
614                'code':    0 - 32602,
615                'message': 'Invalid params',
616                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
617            })
618
619    def handle_listFdb(self, params):
620        list_ = []
621        for vid, mac, entry in self._switch.fdb.each():
622            list_.append({
623                'vid':  vid,
624                'mac':  EthernetFrame.format_mac(mac),
625                'port': entry.port.number,
626                'age':  int(entry.age),
627            })
628        return {'entries': list_}
629
630    def handle_listPort(self, params):
631        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
632
633    def handle_addPort(self, params):
634        type_ = params['type']
635        target = params['target']
636        opts = getattr(self, '_optparse_' + type_)(params.get('options', {}))
637        cls = self.IFTYPES[type_]
638        interface = cls(self._ioloop, self._switch, target, **opts)
639        portnum = interface.open()
640        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
641
642    def handle_setPort(self, params):
643        port = self._switch.get_port(int(params['port']))
644        shut = params.get('shut')
645        if shut is not None:
646            port.shut = bool(shut)
647        return {'entries': [self._portstat(port)]}
648
649    def handle_delPort(self, params):
650        port = self._switch.get_port(int(params['port']))
651        port.interface.close()
652        return {'entries': [self._portstat(port)]}
653
654    def handle_setInterface(self, params):
655        portnum = int(params['port'])
656        port = self._switch.get_port(portnum)
657        address = params.get('address')
658        netmask = params.get('netmask')
659        mtu = params.get('mtu')
660        if not isinstance(port.interface, TapHandler):
661            raise ValueError('Port %d has unsupported interface: %s' %
662                             (portnum, port.interface.IFTYPE))
663        if address is not None:
664            port.interface.address = address
665        if netmask is not None:
666            port.interface.netmask = netmask
667        if mtu is not None:
668            port.interface.mtu = mtu
669        return {'entries': [self._ifstat(port)]}
670
671    def handle_listInterface(self, params):
672        return {'entries': [self._ifstat(p) for p in self._switch.portlist
673                            if isinstance(p.interface, TapHandler)]}
674
675    def _optparse_tap(self, opt):
676        return {'debug': self._debug}
677
678    def _optparse_client(self, opt):
679        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
680        if opt.get('insecure'):
681            args = {}
682        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
683        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
684        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
685
686    def _jsonrpc_response(self, id_=None, result=None, error=None):
687        res = {'jsonrpc': '2.0', 'id': id_}
688        if result:
689            res['result'] = result
690        if error:
691            res['error'] = error
692        self.finish(res)
693
694    @staticmethod
695    def _portstat(port):
696        return {
697            'port':   port.number,
698            'type':   port.interface.IFTYPE,
699            'target': port.interface.target,
700            'tx':     port.tx,
701            'rx':     port.rx,
702            'shut':   port.shut,
703        }
704
705    @staticmethod
706    def _ifstat(port):
707        return {
708            'port':    port.number,
709            'type':    port.interface.IFTYPE,
710            'target':  port.interface.target,
711            'address': port.interface.address,
712            'netmask': port.interface.netmask,
713            'mtu':     port.interface.mtu,
714        }
715
716
717def _print_error(error):
718    print(%s (%s)' % (error['message'], error['code']))
719    print('    %s' % error['data'])
720
721
722def _start_sw(args):
723    def daemonize(nochdir=False, noclose=False):
724        if os.fork() > 0:
725            sys.exit(0)
726
727        os.setsid()
728
729        if os.fork() > 0:
730            sys.exit(0)
731
732        if not nochdir:
733            os.chdir('/')
734
735        if not noclose:
736            os.umask(0)
737            sys.stdin.close()
738            sys.stdout.close()
739            sys.stderr.close()
740            os.close(0)
741            os.close(1)
742            os.close(2)
743            sys.stdin = open(os.devnull)
744            sys.stdout = open(os.devnull, 'a')
745            sys.stderr = open(os.devnull, 'a')
746
747    def checkabspath(ns, path):
748        val = getattr(ns, path, '')
749        if not val.startswith('/'):
750            raise ValueError('Invalid %: %s' % (path, val))
751
752    def getsslopt(ns, key, cert):
753        kval = getattr(ns, key, None)
754        cval = getattr(ns, cert, None)
755        if kval and cval:
756            return {'keyfile': kval, 'certfile': cval}
757        elif kval or cval:
758            raise ValueError('Both %s and %s are required' % (key, cert))
759        return None
760
761    def setrealpath(ns, *keys):
762        for k in keys:
763            v = getattr(ns, k, None)
764            if v is not None:
765                v = os.path.realpath(v)
766                open(v).close()  # check readable
767                setattr(ns, k, v)
768
769    def setport(ns, port, isssl):
770        val = getattr(ns, port, None)
771        if val is None:
772            if isssl:
773                return setattr(ns, port, 443)
774            return setattr(ns, port, 80)
775        if not (0 <= val <= 65535):
776            raise ValueError('Invalid %s: %s' % (port, val))
777
778    def sethtpasswd(ns, htpasswd):
779        val = getattr(ns, htpasswd, None)
780        if val:
781            return setattr(ns, htpasswd, Htpasswd(val))
782
783    #if args.debug:
784    #    websocket.enableTrace(True)
785
786    if args.ageout <= 0:
787        raise ValueError('Invalid ageout: %s' % args.ageout)
788
789    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
790    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
791
792    checkabspath(args, 'path')
793    checkabspath(args, 'ctlpath')
794
795    sslopt = getsslopt(args, 'sslkey', 'sslcert')
796    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
797
798    setport(args, 'port', sslopt)
799    setport(args, 'ctlport', ctlsslopt)
800
801    sethtpasswd(args, 'htpasswd')
802    sethtpasswd(args, 'ctlhtpasswd')
803
804    ioloop = IOLoop.instance()
805    fdb = FDB(ageout=args.ageout, debug=args.debug)
806    switch = SwitchingHub(fdb, debug=args.debug)
807
808    if args.port == args.ctlport and args.host == args.ctlhost:
809        if args.path == args.ctlpath:
810            raise ValueError('Same path/ctlpath on same host')
811        if args.sslkey != args.ctlsslkey:
812            raise ValueError('Different sslkey/ctlsslkey on same host')
813        if args.sslcert != args.ctlsslcert:
814            raise ValueError('Different sslcert/ctlsslcert on same host')
815
816        app = Application([
817            (args.path, EtherWebSocketHandler, {
818                'switch':   switch,
819                'htpasswd': args.htpasswd,
820                'debug':    args.debug,
821            }),
822            (args.ctlpath, EtherWebSocketControlHandler, {
823                'ioloop':   ioloop,
824                'switch':   switch,
825                'htpasswd': args.ctlhtpasswd,
826                'debug':    args.debug,
827            }),
828        ])
829        server = HTTPServer(app, ssl_options=sslopt)
830        server.listen(args.port, address=args.host)
831
832    else:
833        app = Application([(args.path, EtherWebSocketHandler, {
834            'switch':   switch,
835            'htpasswd': args.htpasswd,
836            'debug':    args.debug,
837        })])
838        server = HTTPServer(app, ssl_options=sslopt)
839        server.listen(args.port, address=args.host)
840
841        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
842            'ioloop':   ioloop,
843            'switch':   switch,
844            'htpasswd': args.ctlhtpasswd,
845            'debug':    args.debug,
846        })])
847        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
848        ctlserver.listen(args.ctlport, address=args.ctlhost)
849
850    if not args.foreground:
851        daemonize()
852
853    ioloop.start()
854
855
856def _start_ctl(args):
857    def request(args, method, params=None, id_=0):
858        req = urllib2.Request(args.ctlurl)
859        req.add_header('Content-type', 'application/json')
860        if args.ctluser:
861            if not args.ctlpasswd:
862                args.ctlpasswd = getpass.getpass('Control Password: ')
863            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
864            req.add_header('Authorization', 'Basic %s' % token)
865        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
866        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
867        if params is not None:
868            data['params'] = params
869        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
870
871    def maxlen(dict_, key, min_):
872        if not dict_:
873            return min_
874        max_ = max(len(str(r[key])) for r in dict_)
875        return min_ if max_ < min_ else max_
876
877    def print_portlist(result):
878        pmax = maxlen(result, 'port', 4)
879        ymax = maxlen(result, 'type', 4)
880        smax = maxlen(result, 'shut', 5)
881        rmax = maxlen(result, 'rx', 2)
882        tmax = maxlen(result, 'tx', 2)
883        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
884              (pmax, ymax, smax, rmax, tmax)
885        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
886        for r in result:
887            shut = 'shut' if r['shut'] else 'up'
888            print(fmt %
889                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
890
891    def print_iflist(result):
892        pmax = maxlen(result, 'port', 4)
893        tmax = maxlen(result, 'type', 4)
894        amax = maxlen(result, 'address', 7)
895        nmax = maxlen(result, 'netmask', 7)
896        mmax = maxlen(result, 'mtu', 3)
897        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
898              (pmax, tmax, amax, nmax, mmax)
899        print(fmt % ('Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target'))
900        for r in result:
901            print(fmt % (r['port'], r['type'],
902                         r['address'], r['netmask'], r['mtu'], r['target']))
903
904    def handle_ctl_addport(args):
905        opts = {
906            'user':     getattr(args, 'user', None),
907            'passwd':   getattr(args, 'passwd', None),
908            'cacerts':  getattr(args, 'cacerts', None),
909            'insecure': getattr(args, 'insecure', None),
910        }
911        if args.iftype == EtherWebSocketClient.IFTYPE:
912            if not args.target.startswith('ws://') and \
913               not args.target.startswith('wss://'):
914                raise ValueError('Invalid target URL scheme: %s' % args.target)
915            if not opts['user'] and opts['passwd']:
916                raise ValueError('Authentication required but username empty')
917            if opts['user'] and not opts['passwd']:
918                opts['passwd'] = getpass.getpass('Client Password: ')
919        result = request(args, 'addPort', {
920            'type':    args.iftype,
921            'target':  args.target,
922            'options': opts,
923        })
924        if 'error' in result:
925            _print_error(result['error'])
926        else:
927            print_portlist(result['result']['entries'])
928
929    def handle_ctl_setport(args):
930        if args.port <= 0:
931            raise ValueError('Invalid port: %d' % args.port)
932        req = {'port': args.port}
933        shut = getattr(args, 'shut', None)
934        if shut is not None:
935            req['shut'] = bool(shut)
936        result = request(args, 'setPort', req)
937        if 'error' in result:
938            _print_error(result['error'])
939        else:
940            print_portlist(result['result']['entries'])
941
942    def handle_ctl_delport(args):
943        if args.port <= 0:
944            raise ValueError('Invalid port: %d' % args.port)
945        result = request(args, 'delPort', {'port': args.port})
946        if 'error' in result:
947            _print_error(result['error'])
948        else:
949            print_portlist(result['result']['entries'])
950
951    def handle_ctl_listport(args):
952        result = request(args, 'listPort')
953        if 'error' in result:
954            _print_error(result['error'])
955        else:
956            print_portlist(result['result']['entries'])
957
958    def handle_ctl_setif(args):
959        if args.port <= 0:
960            raise ValueError('Invalid port: %d' % args.port)
961        req = {'port': args.port}
962        address = getattr(args, 'address', None)
963        netmask = getattr(args, 'netmask', None)
964        mtu = getattr(args, 'mtu', None)
965        if address is not None:
966            if address:
967                socket.inet_aton(address)  # validate
968            req['address'] = address
969        if netmask is not None:
970            if netmask:
971                socket.inet_aton(netmask)  # validate
972            req['netmask'] = netmask
973        if mtu is not None:
974            if mtu < 576:
975                raise ValueError('Invalid MTU: %d' % mtu)
976            req['mtu'] = mtu
977        result = request(args, 'setInterface', req)
978        if 'error' in result:
979            _print_error(result['error'])
980        else:
981            print_iflist(result['result']['entries'])
982
983    def handle_ctl_listif(args):
984        result = request(args, 'listInterface')
985        if 'error' in result:
986            _print_error(result['error'])
987        else:
988            print_iflist(result['result']['entries'])
989
990    def handle_ctl_listfdb(args):
991        result = request(args, 'listFdb')
992        if 'error' in result:
993            return _print_error(result['error'])
994        result = result['result']['entries']
995        pmax = maxlen(result, 'port', 4)
996        vmax = maxlen(result, 'vid', 4)
997        mmax = maxlen(result, 'mac', 3)
998        amax = maxlen(result, 'age', 3)
999        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
1000        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
1001        for r in result:
1002            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
1003
1004    locals()['handle_ctl_' + args.control_method](args)
1005
1006
1007def _main():
1008    parser = argparse.ArgumentParser()
1009    subcommand = parser.add_subparsers(dest='subcommand')
1010
1011    # - sw
1012    parser_sw = subcommand.add_parser('sw')
1013
1014    parser_sw.add_argument('--debug', action='store_true', default=False)
1015    parser_sw.add_argument('--foreground', action='store_true', default=False)
1016    parser_sw.add_argument('--ageout', type=int, default=300)
1017
1018    parser_sw.add_argument('--path', default='/')
1019    parser_sw.add_argument('--host', default='')
1020    parser_sw.add_argument('--port', type=int)
1021    parser_sw.add_argument('--htpasswd')
1022    parser_sw.add_argument('--sslkey')
1023    parser_sw.add_argument('--sslcert')
1024
1025    parser_sw.add_argument('--ctlpath', default='/ctl')
1026    parser_sw.add_argument('--ctlhost', default='127.0.0.1')
1027    parser_sw.add_argument('--ctlport', type=int, default=7867)
1028    parser_sw.add_argument('--ctlhtpasswd')
1029    parser_sw.add_argument('--ctlsslkey')
1030    parser_sw.add_argument('--ctlsslcert')
1031
1032    # - ctl
1033    parser_ctl = subcommand.add_parser('ctl')
1034    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl')
1035    parser_ctl.add_argument('--ctluser')
1036    parser_ctl.add_argument('--ctlpasswd')
1037
1038    control_method = parser_ctl.add_subparsers(dest='control_method')
1039
1040    # -- ctl addport
1041    parser_ctl_addport = control_method.add_parser('addport')
1042    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1043
1044    # --- ctl addport tap
1045    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE)
1046    parser_ctl_addport_tap.add_argument('target')
1047
1048    # --- ctl addport client
1049    parser_ctl_addport_client = iftype.add_parser(EtherWebSocketClient.IFTYPE)
1050    parser_ctl_addport_client.add_argument('target')
1051    parser_ctl_addport_client.add_argument('--user')
1052    parser_ctl_addport_client.add_argument('--passwd')
1053    parser_ctl_addport_client.add_argument('--cacerts')
1054    parser_ctl_addport_client.add_argument(
1055        '--insecure', action='store_true', default=False)
1056
1057    # -- ctl setport
1058    parser_ctl_setport = control_method.add_parser('setport')
1059    parser_ctl_setport.add_argument('port', type=int)
1060    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1))
1061
1062    # -- ctl delport
1063    parser_ctl_delport = control_method.add_parser('delport')
1064    parser_ctl_delport.add_argument('port', type=int)
1065
1066    # -- ctl listport
1067    parser_ctl_listport = control_method.add_parser('listport')
1068
1069    # -- ctl setif
1070    parser_ctl_setif = control_method.add_parser('setif')
1071    parser_ctl_setif.add_argument('port', type=int)
1072    parser_ctl_setif.add_argument('--address')
1073    parser_ctl_setif.add_argument('--netmask')
1074    parser_ctl_setif.add_argument('--mtu', type=int)
1075
1076    # -- ctl listif
1077    parser_ctl_listif = control_method.add_parser('listif')
1078
1079    # -- ctl listfdb
1080    parser_ctl_listfdb = control_method.add_parser('listfdb')
1081
1082    # -- go
1083    args = parser.parse_args()
1084
1085    try:
1086        globals()['_start_' + args.subcommand](args)
1087    except Exception as e:
1088        _print_error({
1089            'code':    0 - 32603,
1090            'message': 'Internal error',
1091            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1092        })
1093
1094
1095if __name__ == '__main__':
1096    _main()
Note: See TracBrowser for help on using the repository browser.