source: etherws/trunk/etherws.py @ 253

Revision 253, 40.9 KB checked in by atzm, 11 years ago (diff)

refactoring

  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import urllib2
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import tornado
55import websocket
56
57from tornado.web import Application, RequestHandler
58from tornado.websocket import WebSocketHandler
59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
63
64
65class DebugMixIn(object):
66    def dprintf(self, msg, func=lambda: ()):
67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
69            sys.stderr.write(prefix + (msg % func()))
70
71
72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
79
80    @property
81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
100        return 0
101
102    @staticmethod
103    def format_mac(mac, sep=':'):
104        return sep.join(b.encode('hex') for b in mac)
105
106
107class FDB(DebugMixIn):
108    class Entry(object):
109        def __init__(self, port, ageout):
110            self.port = port
111            self._time = time.time()
112            self._ageout = ageout
113
114        @property
115        def age(self):
116            return time.time() - self._time
117
118        @property
119        def agedout(self):
120            return self.age > self._ageout
121
122    def __init__(self, ageout, debug):
123        self._ageout = ageout
124        self._debug = debug
125        self._table = {}
126
127    def _set_entry(self, vid, mac, port):
128        if vid not in self._table:
129            self._table[vid] = {}
130        self._table[vid][mac] = self.Entry(port, self._ageout)
131
132    def _del_entry(self, vid, mac):
133        if vid in self._table:
134            if mac in self._table[vid]:
135                del self._table[vid][mac]
136            if not self._table[vid]:
137                del self._table[vid]
138
139    def _get_entry(self, vid, mac):
140        try:
141            entry = self._table[vid][mac]
142        except KeyError:
143            return None
144
145        if not entry.agedout:
146            return entry
147
148        self._del_entry(vid, mac)
149        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
150                     lambda: (entry.port.number, vid, mac.encode('hex')))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
170                     lambda: (port.number, vid, mac.encode('hex')))
171
172    def delete(self, port):
173        for vid, mac, entry in self.each():
174            if entry.port.number == port.number:
175                self._del_entry(vid, mac)
176                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
177                             lambda: (port.number, vid, mac.encode('hex')))
178
179
180class SwitchingHub(DebugMixIn):
181    class Port(object):
182        def __init__(self, number, interface):
183            self.number = number
184            self.interface = interface
185            self.tx = 0
186            self.rx = 0
187            self.shut = False
188
189        @staticmethod
190        def cmp_by_number(x, y):
191            return cmp(x.number, y.number)
192
193    def __init__(self, fdb, debug):
194        self.fdb = fdb
195        self._debug = debug
196        self._table = {}
197        self._next = 1
198
199    @property
200    def portlist(self):
201        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
202
203    def get_port(self, portnum):
204        return self._table[portnum]
205
206    def register_port(self, interface):
207        try:
208            self._set_privattr('portnum', interface, self._next)  # XXX
209            self._table[self._next] = self.Port(self._next, interface)
210            return self._next
211        finally:
212            self._next += 1
213
214    def unregister_port(self, interface):
215        portnum = self._get_privattr('portnum', interface)
216        self._del_privattr('portnum', interface)
217        self.fdb.delete(self._table[portnum])
218        del self._table[portnum]
219
220    def send(self, dst_interfaces, frame):
221        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
222        ports = (self._table[n] for n in portnums)
223        ports = (p for p in ports if not p.shut)
224        ports = sorted(ports, cmp=self.Port.cmp_by_number)
225
226        for p in ports:
227            p.interface.write_message(frame.data, True)
228            p.tx += 1
229
230        if ports:
231            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
232                         lambda: (','.join(str(p.number) for p in ports),
233                                  frame.vid,
234                                  frame.src_mac.encode('hex'),
235                                  frame.dst_mac.encode('hex')))
236
237    def receive(self, src_interface, frame):
238        port = self._table[self._get_privattr('portnum', src_interface)]
239
240        if not port.shut:
241            port.rx += 1
242            self._forward(port, frame)
243
244    def _forward(self, src_port, frame):
245        try:
246            if not frame.src_multicast:
247                self.fdb.learn(src_port, frame)
248
249            if not frame.dst_multicast:
250                dst_port = self.fdb.lookup(frame)
251
252                if dst_port:
253                    self.send([dst_port.interface], frame)
254                    return
255
256            ports = set(self.portlist) - set([src_port])
257            self.send((p.interface for p in ports), frame)
258
259        except:  # ex. received invalid frame
260            traceback.print_exc()
261
262    def _privattr(self, name):
263        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
264
265    def _set_privattr(self, name, obj, value):
266        return setattr(obj, self._privattr(name), value)
267
268    def _get_privattr(self, name, obj, defaults=None):
269        return getattr(obj, self._privattr(name), defaults)
270
271    def _del_privattr(self, name, obj):
272        return delattr(obj, self._privattr(name))
273
274
275class Htpasswd(object):
276    def __init__(self, path):
277        self._path = path
278        self._stat = None
279        self._data = {}
280
281    def auth(self, name, passwd):
282        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
283        return self._data.get(name) == passwd
284
285    def load(self):
286        old_stat = self._stat
287
288        with open(self._path) as fp:
289            fileno = fp.fileno()
290            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
291            self._stat = os.fstat(fileno)
292
293            unchanged = old_stat and \
294                        old_stat.st_ino == self._stat.st_ino and \
295                        old_stat.st_dev == self._stat.st_dev and \
296                        old_stat.st_mtime == self._stat.st_mtime
297
298            if not unchanged:
299                self._data = self._parse(fp)
300
301        return self
302
303    def _parse(self, fp):
304        data = {}
305        for line in fp:
306            line = line.strip()
307            if 0 <= line.find(':'):
308                name, passwd = line.split(':', 1)
309                if passwd.startswith('{SHA}'):
310                    data[name] = passwd[5:]
311        return data
312
313
314class BasicAuthMixIn(object):
315    def _execute(self, transforms, *args, **kwargs):
316        def do_execute():
317            sp = super(BasicAuthMixIn, self)
318            return sp._execute(transforms, *args, **kwargs)
319
320        def auth_required():
321            stream = getattr(self, 'stream', self.request.connection.stream)
322            stream.write(tornado.escape.utf8(
323                'HTTP/1.1 401 Authorization Required\r\n'
324                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
325            ))
326            stream.close()
327
328        try:
329            if not self._htpasswd:
330                return do_execute()
331
332            creds = self.request.headers.get('Authorization')
333
334            if not creds or not creds.startswith('Basic '):
335                return auth_required()
336
337            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
338
339            if self._htpasswd.load().auth(name, passwd):
340                return do_execute()
341        except:
342            traceback.print_exc()
343
344        return auth_required()
345
346
347class ServerHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
348    IFTYPE = 'server'
349    IFOP_ALLOWED = False
350
351    def __init__(self, app, req, switch, htpasswd, debug):
352        super(ServerHandler, self).__init__(app, req)
353        self._switch = switch
354        self._htpasswd = htpasswd
355        self._debug = debug
356
357    @property
358    def target(self):
359        return ':'.join(str(e) for e in self.request.connection.address)
360
361    def open(self):
362        try:
363            return self._switch.register_port(self)
364        finally:
365            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
366
367    def on_message(self, message):
368        self._switch.receive(self, EthernetFrame(message))
369
370    def on_close(self):
371        self._switch.unregister_port(self)
372        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
373
374
375class BaseClientHandler(DebugMixIn):
376    IFTYPE = 'baseclient'
377    IFOP_ALLOWED = False
378
379    def __init__(self, ioloop, switch, target, debug, *args, **kwargs):
380        self._ioloop = ioloop
381        self._switch = switch
382        self._target = target
383        self._debug = debug
384        self._args = args
385        self._kwargs = kwargs
386        self._device = None
387
388    @property
389    def address(self):
390        raise NotImplementedError('unsupported')
391
392    @property
393    def netmask(self):
394        raise NotImplementedError('unsupported')
395
396    @property
397    def mtu(self):
398        raise NotImplementedError('unsupported')
399
400    @address.setter
401    def address(self, address):
402        raise NotImplementedError('unsupported')
403
404    @netmask.setter
405    def netmask(self, netmask):
406        raise NotImplementedError('unsupported')
407
408    @mtu.setter
409    def mtu(self, mtu):
410        raise NotImplementedError('unsupported')
411
412    def open(self):
413        raise NotImplementedError('unsupported')
414
415    def write_message(self, message, binary=False):
416        raise NotImplementedError('unsupported')
417
418    def read(self):
419        raise NotImplementedError('unsupported')
420
421    @property
422    def target(self):
423        return self._target
424
425    @property
426    def device(self):
427        return self._device
428
429    @property
430    def closed(self):
431        return not self.device
432
433    def close(self):
434        if self.closed:
435            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
436        self.leave_switch()
437        self.unregister_device()
438        self.dprintf('disconnected: %s\n', lambda: self.target)
439
440    def register_device(self, device):
441        self._device = device
442
443    def unregister_device(self):
444        self._device.close()
445        self._device = None
446
447    def fileno(self):
448        if self.closed:
449            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
450        return self.device.fileno()
451
452    def __call__(self, fd, events):
453        try:
454            data = self.read()
455            if data is not None:
456                self._switch.receive(self, EthernetFrame(data))
457                return
458        except:
459            traceback.print_exc()
460        self.close()
461
462    def join_switch(self):
463        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
464        return self._switch.register_port(self)
465
466    def leave_switch(self):
467        self._switch.unregister_port(self)
468        self._ioloop.remove_handler(self.fileno())
469
470
471class NetdevHandler(BaseClientHandler):
472    IFTYPE = 'netdev'
473    IFOP_ALLOWED = True
474    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
475
476    @property
477    def address(self):
478        if self.closed:
479            raise ValueError('I/O operation on closed netdev')
480        return ''
481
482    @property
483    def netmask(self):
484        if self.closed:
485            raise ValueError('I/O operation on closed netdev')
486        return ''
487
488    @property
489    def mtu(self):
490        if self.closed:
491            raise ValueError('I/O operation on closed netdev')
492        return ''
493
494    @address.setter
495    def address(self, address):
496        if self.closed:
497            raise ValueError('I/O operation on closed netdev')
498        raise NotImplementedError('unsupported')
499
500    @netmask.setter
501    def netmask(self, netmask):
502        if self.closed:
503            raise ValueError('I/O operation on closed netdev')
504        raise NotImplementedError('unsupported')
505
506    @mtu.setter
507    def mtu(self, mtu):
508        if self.closed:
509            raise ValueError('I/O operation on closed netdev')
510        raise NotImplementedError('unsupported')
511
512    def open(self):
513        if not self.closed:
514            raise ValueError('Already opened')
515        self.register_device(socket.socket(
516            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
517        self.device.bind((self.target, self.ETH_P_ALL))
518        return self.join_switch()
519
520    def write_message(self, message, binary=False):
521        if self.closed:
522            raise ValueError('I/O operation on closed netdev')
523        self.device.sendall(message)
524
525    def read(self):
526        if self.closed:
527            raise ValueError('I/O operation on closed netdev')
528        buf = []
529        while True:
530            buf.append(self.device.recv(65535))
531            if len(buf[-1]) < 65535:
532                break
533        return ''.join(buf)
534
535
536class TapHandler(BaseClientHandler):
537    IFTYPE = 'tap'
538    IFOP_ALLOWED = True
539
540    @property
541    def address(self):
542        if self.closed:
543            raise ValueError('I/O operation on closed tap')
544        try:
545            return self.device.addr
546        except:
547            return ''
548
549    @property
550    def netmask(self):
551        if self.closed:
552            raise ValueError('I/O operation on closed tap')
553        try:
554            return self.device.netmask
555        except:
556            return ''
557
558    @property
559    def mtu(self):
560        if self.closed:
561            raise ValueError('I/O operation on closed tap')
562        return self.device.mtu
563
564    @address.setter
565    def address(self, address):
566        if self.closed:
567            raise ValueError('I/O operation on closed tap')
568        self.device.addr = address
569
570    @netmask.setter
571    def netmask(self, netmask):
572        if self.closed:
573            raise ValueError('I/O operation on closed tap')
574        self.device.netmask = netmask
575
576    @mtu.setter
577    def mtu(self, mtu):
578        if self.closed:
579            raise ValueError('I/O operation on closed tap')
580        self.device.mtu = mtu
581
582    @property
583    def target(self):
584        if self.closed:
585            return self._target
586        return self.device.name
587
588    def open(self):
589        if not self.closed:
590            raise ValueError('Already opened')
591        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
592        self.device.up()
593        return self.join_switch()
594
595    def write_message(self, message, binary=False):
596        if self.closed:
597            raise ValueError('I/O operation on closed tap')
598        self.device.write(message)
599
600    def read(self):
601        if self.closed:
602            raise ValueError('I/O operation on closed tap')
603        buf = []
604        while True:
605            buf.append(self.device.read(65535))
606            if len(buf[-1]) < 65535:
607                break
608        return ''.join(buf)
609
610
611class ClientHandler(BaseClientHandler):
612    IFTYPE = 'client'
613    IFOP_ALLOWED = False
614
615    def __init__(self, *args, **kwargs):
616        super(ClientHandler, self).__init__(*args, **kwargs)
617
618        self._ssl = kwargs.get('ssl', False)
619        self._options = {}
620
621        cred = kwargs.get('cred', None)
622
623        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
624            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
625            auth = ['Authorization: Basic %s' % token]
626            self._options['header'] = auth
627
628    def open(self):
629        sslwrap = websocket._SSLSocketWrapper
630
631        if not self.closed:
632            raise websocket.WebSocketException('Already opened')
633
634        if self._ssl:
635            websocket._SSLSocketWrapper = self._ssl
636
637        # XXX: may be blocked
638        try:
639            self.register_device(websocket.WebSocket())
640            self.device.connect(self.target, **self._options)
641            self.dprintf('connected: %s\n', lambda: self.target)
642            return self.join_switch()
643        finally:
644            websocket._SSLSocketWrapper = sslwrap
645
646    def fileno(self):
647        if self.closed:
648            raise websocket.WebSocketException('Closed socket')
649        return self.device.io_sock.fileno()
650
651    def write_message(self, message, binary=False):
652        if self.closed:
653            raise websocket.WebSocketException('Closed socket')
654        if binary:
655            flag = websocket.ABNF.OPCODE_BINARY
656        else:
657            flag = websocket.ABNF.OPCODE_TEXT
658        self.device.send(message, flag)
659
660    def read(self):
661        if self.closed:
662            raise websocket.WebSocketException('Closed socket')
663        return self.device.recv()
664
665
666class ControlServerHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
667    NAMESPACE = 'etherws.control'
668    IFTYPES = {
669        NetdevHandler.IFTYPE: NetdevHandler,
670        TapHandler.IFTYPE:    TapHandler,
671        ClientHandler.IFTYPE: ClientHandler,
672    }
673
674    def __init__(self, app, req, ioloop, switch, htpasswd, debug):
675        super(ControlServerHandler, self).__init__(app, req)
676        self._ioloop = ioloop
677        self._switch = switch
678        self._htpasswd = htpasswd
679        self._debug = debug
680
681    def post(self):
682        try:
683            request = json.loads(self.request.body)
684        except Exception as e:
685            return self._jsonrpc_response(error={
686                'code':    0 - 32700,
687                'message': 'Parse error',
688                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
689            })
690
691        try:
692            id_ = request.get('id')
693            params = request.get('params')
694            version = request['jsonrpc']
695            method = request['method']
696            if version != '2.0':
697                raise ValueError('Invalid JSON-RPC version: %s' % version)
698        except Exception as e:
699            return self._jsonrpc_response(id_=id_, error={
700                'code':    0 - 32600,
701                'message': 'Invalid Request',
702                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
703            })
704
705        try:
706            if not method.startswith(self.NAMESPACE + '.'):
707                raise ValueError('Invalid method namespace: %s' % method)
708            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
709            handler = getattr(self, handler)
710        except Exception as e:
711            return self._jsonrpc_response(id_=id_, error={
712                'code':    0 - 32601,
713                'message': 'Method not found',
714                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
715            })
716
717        try:
718            return self._jsonrpc_response(id_=id_, result=handler(params))
719        except Exception as e:
720            traceback.print_exc()
721            return self._jsonrpc_response(id_=id_, error={
722                'code':    0 - 32602,
723                'message': 'Invalid params',
724                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
725            })
726
727    def handle_listFdb(self, params):
728        list_ = []
729        for vid, mac, entry in self._switch.fdb.each():
730            list_.append({
731                'vid':  vid,
732                'mac':  EthernetFrame.format_mac(mac),
733                'port': entry.port.number,
734                'age':  int(entry.age),
735            })
736        return {'entries': list_}
737
738    def handle_listPort(self, params):
739        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
740
741    def handle_addPort(self, params):
742        type_ = params['type']
743        target = params['target']
744        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
745        cls = self.IFTYPES[type_]
746        interface = cls(self._ioloop, self._switch, target, self._debug, **opt)
747        portnum = interface.open()
748        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
749
750    def handle_setPort(self, params):
751        port = self._switch.get_port(int(params['port']))
752        shut = params.get('shut')
753        if shut is not None:
754            port.shut = bool(shut)
755        return {'entries': [self._portstat(port)]}
756
757    def handle_delPort(self, params):
758        port = self._switch.get_port(int(params['port']))
759        port.interface.close()
760        return {'entries': [self._portstat(port)]}
761
762    def handle_setInterface(self, params):
763        portnum = int(params['port'])
764        port = self._switch.get_port(portnum)
765        address = params.get('address')
766        netmask = params.get('netmask')
767        mtu = params.get('mtu')
768        if not port.interface.IFOP_ALLOWED:
769            raise ValueError('Port %d has unsupported interface: %s' %
770                             (portnum, port.interface.IFTYPE))
771        if address is not None:
772            port.interface.address = address
773        if netmask is not None:
774            port.interface.netmask = netmask
775        if mtu is not None:
776            port.interface.mtu = mtu
777        return {'entries': [self._ifstat(port)]}
778
779    def handle_listInterface(self, params):
780        return {'entries': [self._ifstat(p) for p in self._switch.portlist
781                            if p.interface.IFOP_ALLOWED]}
782
783    def _optparse_netdev(self, opt):
784        return {}
785
786    def _optparse_tap(self, opt):
787        return {}
788
789    def _optparse_client(self, opt):
790        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
791        if opt.get('insecure'):
792            args = {}
793        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
794        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
795        return {'ssl': ssl_, 'cred': cred}
796
797    def _jsonrpc_response(self, id_=None, result=None, error=None):
798        res = {'jsonrpc': '2.0', 'id': id_}
799        if result:
800            res['result'] = result
801        if error:
802            res['error'] = error
803        self.finish(res)
804
805    @staticmethod
806    def _portstat(port):
807        return {
808            'port':   port.number,
809            'type':   port.interface.IFTYPE,
810            'target': port.interface.target,
811            'tx':     port.tx,
812            'rx':     port.rx,
813            'shut':   port.shut,
814        }
815
816    @staticmethod
817    def _ifstat(port):
818        return {
819            'port':    port.number,
820            'type':    port.interface.IFTYPE,
821            'target':  port.interface.target,
822            'address': port.interface.address,
823            'netmask': port.interface.netmask,
824            'mtu':     port.interface.mtu,
825        }
826
827
828def _print_error(error):
829    print(%s (%s)' % (error['message'], error['code']))
830    print('    %s' % error['data'])
831
832
833def _start_sw(args):
834    def daemonize(nochdir=False, noclose=False):
835        if os.fork() > 0:
836            sys.exit(0)
837
838        os.setsid()
839
840        if os.fork() > 0:
841            sys.exit(0)
842
843        if not nochdir:
844            os.chdir('/')
845
846        if not noclose:
847            os.umask(0)
848            sys.stdin.close()
849            sys.stdout.close()
850            sys.stderr.close()
851            os.close(0)
852            os.close(1)
853            os.close(2)
854            sys.stdin = open(os.devnull)
855            sys.stdout = open(os.devnull, 'a')
856            sys.stderr = open(os.devnull, 'a')
857
858    def checkabspath(ns, path):
859        val = getattr(ns, path, '')
860        if not val.startswith('/'):
861            raise ValueError('Invalid %: %s' % (path, val))
862
863    def getsslopt(ns, key, cert):
864        kval = getattr(ns, key, None)
865        cval = getattr(ns, cert, None)
866        if kval and cval:
867            return {'keyfile': kval, 'certfile': cval}
868        elif kval or cval:
869            raise ValueError('Both %s and %s are required' % (key, cert))
870        return None
871
872    def setrealpath(ns, *keys):
873        for k in keys:
874            v = getattr(ns, k, None)
875            if v is not None:
876                v = os.path.realpath(v)
877                open(v).close()  # check readable
878                setattr(ns, k, v)
879
880    def setport(ns, port, isssl):
881        val = getattr(ns, port, None)
882        if val is None:
883            if isssl:
884                return setattr(ns, port, 443)
885            return setattr(ns, port, 80)
886        if not (0 <= val <= 65535):
887            raise ValueError('Invalid %s: %s' % (port, val))
888
889    def sethtpasswd(ns, htpasswd):
890        val = getattr(ns, htpasswd, None)
891        if val:
892            return setattr(ns, htpasswd, Htpasswd(val))
893
894    #if args.debug:
895    #    websocket.enableTrace(True)
896
897    if args.ageout <= 0:
898        raise ValueError('Invalid ageout: %s' % args.ageout)
899
900    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
901    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
902
903    checkabspath(args, 'path')
904    checkabspath(args, 'ctlpath')
905
906    sslopt = getsslopt(args, 'sslkey', 'sslcert')
907    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
908
909    setport(args, 'port', sslopt)
910    setport(args, 'ctlport', ctlsslopt)
911
912    sethtpasswd(args, 'htpasswd')
913    sethtpasswd(args, 'ctlhtpasswd')
914
915    ioloop = IOLoop.instance()
916    fdb = FDB(args.ageout, args.debug)
917    switch = SwitchingHub(fdb, args.debug)
918
919    if args.port == args.ctlport and args.host == args.ctlhost:
920        if args.path == args.ctlpath:
921            raise ValueError('Same path/ctlpath on same host')
922        if args.sslkey != args.ctlsslkey:
923            raise ValueError('Different sslkey/ctlsslkey on same host')
924        if args.sslcert != args.ctlsslcert:
925            raise ValueError('Different sslcert/ctlsslcert on same host')
926
927        app = Application([
928            (args.path, ServerHandler, {
929                'switch':   switch,
930                'htpasswd': args.htpasswd,
931                'debug':    args.debug,
932            }),
933            (args.ctlpath, ControlServerHandler, {
934                'ioloop':   ioloop,
935                'switch':   switch,
936                'htpasswd': args.ctlhtpasswd,
937                'debug':    args.debug,
938            }),
939        ])
940        server = HTTPServer(app, ssl_options=sslopt)
941        server.listen(args.port, address=args.host)
942
943    else:
944        app = Application([(args.path, ServerHandler, {
945            'switch':   switch,
946            'htpasswd': args.htpasswd,
947            'debug':    args.debug,
948        })])
949        server = HTTPServer(app, ssl_options=sslopt)
950        server.listen(args.port, address=args.host)
951
952        ctl = Application([(args.ctlpath, ControlServerHandler, {
953            'ioloop':   ioloop,
954            'switch':   switch,
955            'htpasswd': args.ctlhtpasswd,
956            'debug':    args.debug,
957        })])
958        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
959        ctlserver.listen(args.ctlport, address=args.ctlhost)
960
961    if not args.foreground:
962        daemonize()
963
964    ioloop.start()
965
966
967def _start_ctl(args):
968    def request(args, method, params=None, id_=0):
969        req = urllib2.Request(args.ctlurl)
970        req.add_header('Content-type', 'application/json')
971        if args.ctluser:
972            if not args.ctlpasswd:
973                args.ctlpasswd = getpass.getpass('Control Password: ')
974            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
975            req.add_header('Authorization', 'Basic %s' % token)
976        method = '.'.join([ControlServerHandler.NAMESPACE, method])
977        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
978        if params is not None:
979            data['params'] = params
980        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
981
982    def maxlen(dict_, key, min_):
983        if not dict_:
984            return min_
985        max_ = max(len(str(r[key])) for r in dict_)
986        return min_ if max_ < min_ else max_
987
988    def print_portlist(result):
989        pmax = maxlen(result, 'port', 4)
990        ymax = maxlen(result, 'type', 4)
991        smax = maxlen(result, 'shut', 5)
992        rmax = maxlen(result, 'rx', 2)
993        tmax = maxlen(result, 'tx', 2)
994        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
995              (pmax, ymax, smax, rmax, tmax)
996        print(fmt % ('Port', 'Type', 'State', 'RX', 'TX', 'Target'))
997        for r in result:
998            shut = 'shut' if r['shut'] else 'up'
999            print(fmt %
1000                  (r['port'], r['type'], shut, r['rx'], r['tx'], r['target']))
1001
1002    def print_iflist(result):
1003        pmax = maxlen(result, 'port', 4)
1004        tmax = maxlen(result, 'type', 4)
1005        amax = maxlen(result, 'address', 7)
1006        nmax = maxlen(result, 'netmask', 7)
1007        mmax = maxlen(result, 'mtu', 3)
1008        fmt = %%%d%%%d%%%d%%%d%%%d%%s' % \
1009              (pmax, tmax, amax, nmax, mmax)
1010        print(fmt % ('Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target'))
1011        for r in result:
1012            print(fmt % (r['port'], r['type'],
1013                         r['address'], r['netmask'], r['mtu'], r['target']))
1014
1015    def handle_ctl_addport(args):
1016        opts = {
1017            'user':     getattr(args, 'user', None),
1018            'passwd':   getattr(args, 'passwd', None),
1019            'cacerts':  getattr(args, 'cacerts', None),
1020            'insecure': getattr(args, 'insecure', None),
1021        }
1022        if args.iftype == ClientHandler.IFTYPE:
1023            if not args.target.startswith('ws://') and \
1024               not args.target.startswith('wss://'):
1025                raise ValueError('Invalid target URL scheme: %s' % args.target)
1026            if not opts['user'] and opts['passwd']:
1027                raise ValueError('Authentication required but username empty')
1028            if opts['user'] and not opts['passwd']:
1029                opts['passwd'] = getpass.getpass('Client Password: ')
1030        result = request(args, 'addPort', {
1031            'type':    args.iftype,
1032            'target':  args.target,
1033            'options': opts,
1034        })
1035        if 'error' in result:
1036            _print_error(result['error'])
1037        else:
1038            print_portlist(result['result']['entries'])
1039
1040    def handle_ctl_setport(args):
1041        if args.port <= 0:
1042            raise ValueError('Invalid port: %d' % args.port)
1043        req = {'port': args.port}
1044        shut = getattr(args, 'shut', None)
1045        if shut is not None:
1046            req['shut'] = bool(shut)
1047        result = request(args, 'setPort', req)
1048        if 'error' in result:
1049            _print_error(result['error'])
1050        else:
1051            print_portlist(result['result']['entries'])
1052
1053    def handle_ctl_delport(args):
1054        if args.port <= 0:
1055            raise ValueError('Invalid port: %d' % args.port)
1056        result = request(args, 'delPort', {'port': args.port})
1057        if 'error' in result:
1058            _print_error(result['error'])
1059        else:
1060            print_portlist(result['result']['entries'])
1061
1062    def handle_ctl_listport(args):
1063        result = request(args, 'listPort')
1064        if 'error' in result:
1065            _print_error(result['error'])
1066        else:
1067            print_portlist(result['result']['entries'])
1068
1069    def handle_ctl_setif(args):
1070        if args.port <= 0:
1071            raise ValueError('Invalid port: %d' % args.port)
1072        req = {'port': args.port}
1073        address = getattr(args, 'address', None)
1074        netmask = getattr(args, 'netmask', None)
1075        mtu = getattr(args, 'mtu', None)
1076        if address is not None:
1077            if address:
1078                socket.inet_aton(address)  # validate
1079            req['address'] = address
1080        if netmask is not None:
1081            if netmask:
1082                socket.inet_aton(netmask)  # validate
1083            req['netmask'] = netmask
1084        if mtu is not None:
1085            if mtu < 576:
1086                raise ValueError('Invalid MTU: %d' % mtu)
1087            req['mtu'] = mtu
1088        result = request(args, 'setInterface', req)
1089        if 'error' in result:
1090            _print_error(result['error'])
1091        else:
1092            print_iflist(result['result']['entries'])
1093
1094    def handle_ctl_listif(args):
1095        result = request(args, 'listInterface')
1096        if 'error' in result:
1097            _print_error(result['error'])
1098        else:
1099            print_iflist(result['result']['entries'])
1100
1101    def handle_ctl_listfdb(args):
1102        result = request(args, 'listFdb')
1103        if 'error' in result:
1104            return _print_error(result['error'])
1105        result = result['result']['entries']
1106        pmax = maxlen(result, 'port', 4)
1107        vmax = maxlen(result, 'vid', 4)
1108        mmax = maxlen(result, 'mac', 3)
1109        amax = maxlen(result, 'age', 3)
1110        fmt = %%%d%%%d%%-%d%%%ds' % (pmax, vmax, mmax, amax)
1111        print(fmt % ('Port', 'VLAN', 'MAC', 'Age'))
1112        for r in result:
1113            print(fmt % (r['port'], r['vid'], r['mac'], r['age']))
1114
1115    locals()['handle_ctl_' + args.control_method](args)
1116
1117
1118def _main():
1119    parser = argparse.ArgumentParser()
1120    subcommand = parser.add_subparsers(dest='subcommand')
1121
1122    # - sw
1123    parser_sw = subcommand.add_parser('sw',
1124                                      help='start virtual switch')
1125
1126    parser_sw.add_argument('--debug', action='store_true', default=False,
1127                           help='run as debug mode')
1128    parser_sw.add_argument('--foreground', action='store_true', default=False,
1129                           help='run as foreground mode')
1130    parser_sw.add_argument('--ageout', type=int, default=300,
1131                           help='FDB ageout time (sec)')
1132
1133    parser_sw.add_argument('--path', default='/',
1134                           help='http(s) path to serve WebSocket')
1135    parser_sw.add_argument('--host', default='',
1136                           help='listen address to serve WebSocket')
1137    parser_sw.add_argument('--port', type=int,
1138                           help='listen port to serve WebSocket')
1139    parser_sw.add_argument('--htpasswd',
1140                           help='path to htpasswd file to auth WebSocket')
1141    parser_sw.add_argument('--sslkey',
1142                           help='path to SSL private key for WebSocket')
1143    parser_sw.add_argument('--sslcert',
1144                           help='path to SSL certificate for WebSocket')
1145
1146    parser_sw.add_argument('--ctlpath', default='/ctl',
1147                           help='http(s) path to serve control API')
1148    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1149                           help='listen address to serve control API')
1150    parser_sw.add_argument('--ctlport', type=int, default=7867,
1151                           help='listen port to serve control API')
1152    parser_sw.add_argument('--ctlhtpasswd',
1153                           help='path to htpasswd file to auth control API')
1154    parser_sw.add_argument('--ctlsslkey',
1155                           help='path to SSL private key for control API')
1156    parser_sw.add_argument('--ctlsslcert',
1157                           help='path to SSL certificate for control API')
1158
1159    # - ctl
1160    parser_ctl = subcommand.add_parser('ctl',
1161                                       help='control virtual switch')
1162    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1163                            help='URL to control API')
1164    parser_ctl.add_argument('--ctluser',
1165                            help='username to auth control API')
1166    parser_ctl.add_argument('--ctlpasswd',
1167                            help='password to auth control API')
1168
1169    control_method = parser_ctl.add_subparsers(dest='control_method')
1170
1171    # -- ctl addport
1172    parser_ctl_addport = control_method.add_parser('addport',
1173                                                   help='create and add port')
1174    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1175
1176    # --- ctl addport netdev
1177    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1178                                                  help='netdev')
1179    parser_ctl_addport_netdev.add_argument('target',
1180                                           help='device name to add interface')
1181
1182    # --- ctl addport tap
1183    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1184                                               help='TAP device')
1185    parser_ctl_addport_tap.add_argument('target',
1186                                        help='device name to create interface')
1187
1188    # --- ctl addport client
1189    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1190                                                  help='WebSocket client')
1191    parser_ctl_addport_client.add_argument('target',
1192                                           help='URL to connect WebSocket')
1193    parser_ctl_addport_client.add_argument('--user',
1194                                           help='username to auth WebSocket')
1195    parser_ctl_addport_client.add_argument('--passwd',
1196                                           help='password to auth WebSocket')
1197    parser_ctl_addport_client.add_argument('--cacerts',
1198                                           help='path to CA certificate')
1199    parser_ctl_addport_client.add_argument(
1200        '--insecure', action='store_true', default=False,
1201        help='do not verify server certificate')
1202
1203    # -- ctl setport
1204    parser_ctl_setport = control_method.add_parser('setport',
1205                                                   help='set port config')
1206    parser_ctl_setport.add_argument('port', type=int,
1207                                    help='port number to set config')
1208    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1209                                    help='set shutdown state')
1210
1211    # -- ctl delport
1212    parser_ctl_delport = control_method.add_parser('delport',
1213                                                   help='delete port')
1214    parser_ctl_delport.add_argument('port', type=int,
1215                                    help='port number to delete')
1216
1217    # -- ctl listport
1218    parser_ctl_listport = control_method.add_parser('listport',
1219                                                    help='show port list')
1220
1221    # -- ctl setif
1222    parser_ctl_setif = control_method.add_parser('setif',
1223                                                 help='set interface config')
1224    parser_ctl_setif.add_argument('port', type=int,
1225                                  help='port number to set config')
1226    parser_ctl_setif.add_argument('--address',
1227                                  help='IPv4 address to set interface')
1228    parser_ctl_setif.add_argument('--netmask',
1229                                  help='IPv4 netmask to set interface')
1230    parser_ctl_setif.add_argument('--mtu', type=int,
1231                                  help='MTU to set interface')
1232
1233    # -- ctl listif
1234    parser_ctl_listif = control_method.add_parser('listif',
1235                                                  help='show interface list')
1236
1237    # -- ctl listfdb
1238    parser_ctl_listfdb = control_method.add_parser('listfdb',
1239                                                   help='show FDB entries')
1240
1241    # -- go
1242    args = parser.parse_args()
1243
1244    try:
1245        globals()['_start_' + args.subcommand](args)
1246    except Exception as e:
1247        _print_error({
1248            'code':    0 - 32603,
1249            'message': 'Internal error',
1250            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1251        })
1252
1253
1254if __name__ == '__main__':
1255    _main()
Note: See TracBrowser for help on using the repository browser.