source: etherws/trunk/etherws.py @ 258

Revision 258, 42.3 KB checked in by atzm, 11 years ago (diff)

tsume

  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import tornado
56import websocket
57
58from tornado.web import Application, RequestHandler
59from tornado.websocket import WebSocketHandler
60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
64
65
66class DebugMixIn(object):
67    def dprintf(self, msg, func=lambda: ()):
68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
70            sys.stderr.write(prefix + (msg % func()))
71
72
73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
80
81    @property
82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
101        return 0
102
103    @staticmethod
104    def format_mac(mac, sep=':'):
105        return sep.join(b.encode('hex') for b in mac)
106
107
108class FDB(DebugMixIn):
109    class Entry(object):
110        def __init__(self, port, ageout):
111            self.port = port
112            self._time = time.time()
113            self._ageout = ageout
114
115        @property
116        def age(self):
117            return time.time() - self._time
118
119        @property
120        def agedout(self):
121            return self.age > self._ageout
122
123    def __init__(self, ageout, debug):
124        self._ageout = ageout
125        self._debug = debug
126        self._table = {}
127
128    def _set_entry(self, vid, mac, port):
129        if vid not in self._table:
130            self._table[vid] = {}
131        self._table[vid][mac] = self.Entry(port, self._ageout)
132
133    def _del_entry(self, vid, mac):
134        if vid in self._table:
135            if mac in self._table[vid]:
136                del self._table[vid][mac]
137            if not self._table[vid]:
138                del self._table[vid]
139
140    def _get_entry(self, vid, mac):
141        try:
142            entry = self._table[vid][mac]
143        except KeyError:
144            return None
145
146        if not entry.agedout:
147            return entry
148
149        self._del_entry(vid, mac)
150        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
151                     lambda: (entry.port.number, vid, mac.encode('hex')))
152
153    def each(self):
154        for vid in sorted(self._table.iterkeys()):
155            for mac in sorted(self._table[vid].iterkeys()):
156                entry = self._get_entry(vid, mac)
157                if entry:
158                    yield (vid, mac, entry)
159
160    def lookup(self, frame):
161        mac = frame.dst_mac
162        vid = frame.vid
163        entry = self._get_entry(vid, mac)
164        return getattr(entry, 'port', None)
165
166    def learn(self, port, frame):
167        mac = frame.src_mac
168        vid = frame.vid
169        self._set_entry(vid, mac, port)
170        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
171                     lambda: (port.number, vid, mac.encode('hex')))
172
173    def delete(self, port):
174        for vid, mac, entry in self.each():
175            if entry.port.number == port.number:
176                self._del_entry(vid, mac)
177                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
178                             lambda: (port.number, vid, mac.encode('hex')))
179
180
181class SwitchingHub(DebugMixIn):
182    class Port(object):
183        def __init__(self, number, interface):
184            self.number = number
185            self.interface = interface
186            self.tx = 0
187            self.rx = 0
188            self.shut = False
189
190        @staticmethod
191        def cmp_by_number(x, y):
192            return cmp(x.number, y.number)
193
194    def __init__(self, fdb, debug):
195        self.fdb = fdb
196        self._debug = debug
197        self._table = {}
198        self._next = 1
199
200    @property
201    def portlist(self):
202        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
203
204    def get_port(self, portnum):
205        return self._table[portnum]
206
207    def register_port(self, interface):
208        try:
209            self._set_privattr('portnum', interface, self._next)  # XXX
210            self._table[self._next] = self.Port(self._next, interface)
211            return self._next
212        finally:
213            self._next += 1
214
215    def unregister_port(self, interface):
216        portnum = self._get_privattr('portnum', interface)
217        self._del_privattr('portnum', interface)
218        self.fdb.delete(self._table[portnum])
219        del self._table[portnum]
220
221    def send(self, dst_interfaces, frame):
222        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
223        ports = (self._table[n] for n in portnums)
224        ports = (p for p in ports if not p.shut)
225        ports = sorted(ports, cmp=self.Port.cmp_by_number)
226
227        for p in ports:
228            p.interface.write_message(frame.data, True)
229            p.tx += 1
230
231        if ports:
232            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
233                         lambda: (','.join(str(p.number) for p in ports),
234                                  frame.vid,
235                                  frame.src_mac.encode('hex'),
236                                  frame.dst_mac.encode('hex')))
237
238    def receive(self, src_interface, frame):
239        port = self._table[self._get_privattr('portnum', src_interface)]
240
241        if not port.shut:
242            port.rx += 1
243            self._forward(port, frame)
244
245    def _forward(self, src_port, frame):
246        try:
247            if not frame.src_multicast:
248                self.fdb.learn(src_port, frame)
249
250            if not frame.dst_multicast:
251                dst_port = self.fdb.lookup(frame)
252
253                if dst_port:
254                    self.send([dst_port.interface], frame)
255                    return
256
257            ports = set(self.portlist) - set([src_port])
258            self.send((p.interface for p in ports), frame)
259
260        except:  # ex. received invalid frame
261            traceback.print_exc()
262
263    def _privattr(self, name):
264        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
265
266    def _set_privattr(self, name, obj, value):
267        return setattr(obj, self._privattr(name), value)
268
269    def _get_privattr(self, name, obj, defaults=None):
270        return getattr(obj, self._privattr(name), defaults)
271
272    def _del_privattr(self, name, obj):
273        return delattr(obj, self._privattr(name))
274
275
276class NetworkInterface(object):
277    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
278    SIOCSIFADDR = 0x8916
279    SIOCGIFNETMASK = 0x891b
280    SIOCSIFNETMASK = 0x891c
281    SIOCGIFMTU = 0x8921
282    SIOCSIFMTU = 0x8922
283
284    def __init__(self, ifname):
285        self._ifname = struct.pack('16s', str(ifname)[:15])
286
287    def _ioctl(self, req, data):
288        try:
289            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
290            return fcntl.ioctl(sock.fileno(), req, data)
291        finally:
292            sock.close()
293
294    @property
295    def address(self):
296        try:
297            data = struct.pack('16s18x', self._ifname)
298            ret = self._ioctl(self.SIOCGIFADDR, data)
299            return socket.inet_ntoa(ret[20:24])
300        except IOError:
301            return ''
302
303    @property
304    def netmask(self):
305        try:
306            data = struct.pack('16s18x', self._ifname)
307            ret = self._ioctl(self.SIOCGIFNETMASK, data)
308            return socket.inet_ntoa(ret[20:24])
309        except IOError:
310            return ''
311
312    @property
313    def mtu(self):
314        try:
315            data = struct.pack('16s18x', self._ifname)
316            ret = self._ioctl(self.SIOCGIFMTU, data)
317            return struct.unpack('i', ret[16:20])[0]
318        except IOError:
319            return 0
320
321    @address.setter
322    def address(self, addr):
323        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
324                           socket.inet_aton(addr))
325        self._ioctl(self.SIOCSIFADDR, data)
326
327    @netmask.setter
328    def netmask(self, addr):
329        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
330                           socket.inet_aton(addr))
331        self._ioctl(self.SIOCSIFNETMASK, data)
332
333    @mtu.setter
334    def mtu(self, mtu):
335        data = struct.pack('16si12x', self._ifname, mtu)
336        self._ioctl(self.SIOCSIFMTU, data)
337
338
339class Htpasswd(object):
340    def __init__(self, path):
341        self._path = path
342        self._stat = None
343        self._data = {}
344
345    def auth(self, name, passwd):
346        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
347        return self._data.get(name) == passwd
348
349    def load(self):
350        old_stat = self._stat
351
352        with open(self._path) as fp:
353            fileno = fp.fileno()
354            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
355            self._stat = os.fstat(fileno)
356
357            unchanged = old_stat and \
358                        old_stat.st_ino == self._stat.st_ino and \
359                        old_stat.st_dev == self._stat.st_dev and \
360                        old_stat.st_mtime == self._stat.st_mtime
361
362            if not unchanged:
363                self._data = self._parse(fp)
364
365        return self
366
367    def _parse(self, fp):
368        data = {}
369        for line in fp:
370            line = line.strip()
371            if 0 <= line.find(':'):
372                name, passwd = line.split(':', 1)
373                if passwd.startswith('{SHA}'):
374                    data[name] = passwd[5:]
375        return data
376
377
378class BasicAuthMixIn(object):
379    def _execute(self, transforms, *args, **kwargs):
380        def do_execute():
381            sp = super(BasicAuthMixIn, self)
382            return sp._execute(transforms, *args, **kwargs)
383
384        def auth_required():
385            stream = getattr(self, 'stream', self.request.connection.stream)
386            stream.write(tornado.escape.utf8(
387                'HTTP/1.1 401 Authorization Required\r\n'
388                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
389            ))
390            stream.close()
391
392        try:
393            if not self._htpasswd:
394                return do_execute()
395
396            creds = self.request.headers.get('Authorization')
397
398            if not creds or not creds.startswith('Basic '):
399                return auth_required()
400
401            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
402
403            if self._htpasswd.load().auth(name, passwd):
404                return do_execute()
405        except:
406            traceback.print_exc()
407
408        return auth_required()
409
410
411class ServerHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
412    IFTYPE = 'server'
413    IFOP_ALLOWED = False
414
415    def __init__(self, app, req, switch, htpasswd, debug):
416        super(ServerHandler, self).__init__(app, req)
417        self._switch = switch
418        self._htpasswd = htpasswd
419        self._debug = debug
420
421    @property
422    def target(self):
423        return ':'.join(str(e) for e in self.request.connection.address)
424
425    def open(self):
426        try:
427            return self._switch.register_port(self)
428        finally:
429            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
430
431    def on_message(self, message):
432        self._switch.receive(self, EthernetFrame(message))
433
434    def on_close(self):
435        self._switch.unregister_port(self)
436        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
437
438
439class BaseClientHandler(DebugMixIn):
440    IFTYPE = 'baseclient'
441    IFOP_ALLOWED = False
442
443    def __init__(self, ioloop, switch, target, debug, *args, **kwargs):
444        self._ioloop = ioloop
445        self._switch = switch
446        self._target = target
447        self._debug = debug
448        self._args = args
449        self._kwargs = kwargs
450        self._device = None
451
452    @property
453    def address(self):
454        raise NotImplementedError('unsupported')
455
456    @property
457    def netmask(self):
458        raise NotImplementedError('unsupported')
459
460    @property
461    def mtu(self):
462        raise NotImplementedError('unsupported')
463
464    @address.setter
465    def address(self, address):
466        raise NotImplementedError('unsupported')
467
468    @netmask.setter
469    def netmask(self, netmask):
470        raise NotImplementedError('unsupported')
471
472    @mtu.setter
473    def mtu(self, mtu):
474        raise NotImplementedError('unsupported')
475
476    def open(self):
477        raise NotImplementedError('unsupported')
478
479    def write_message(self, message, binary=False):
480        raise NotImplementedError('unsupported')
481
482    def read(self):
483        raise NotImplementedError('unsupported')
484
485    @property
486    def target(self):
487        return self._target
488
489    @property
490    def device(self):
491        return self._device
492
493    @property
494    def closed(self):
495        return not self.device
496
497    def close(self):
498        if self.closed:
499            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
500        self.leave_switch()
501        self.unregister_device()
502        self.dprintf('disconnected: %s\n', lambda: self.target)
503
504    def register_device(self, device):
505        self._device = device
506
507    def unregister_device(self):
508        self._device.close()
509        self._device = None
510
511    def fileno(self):
512        if self.closed:
513            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
514        return self.device.fileno()
515
516    def __call__(self, fd, events):
517        try:
518            data = self.read()
519            if data is not None:
520                self._switch.receive(self, EthernetFrame(data))
521                return
522        except:
523            traceback.print_exc()
524        self.close()
525
526    def join_switch(self):
527        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
528        return self._switch.register_port(self)
529
530    def leave_switch(self):
531        self._switch.unregister_port(self)
532        self._ioloop.remove_handler(self.fileno())
533
534
535class NetdevHandler(BaseClientHandler):
536    IFTYPE = 'netdev'
537    IFOP_ALLOWED = True
538    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
539
540    @property
541    def address(self):
542        if self.closed:
543            raise ValueError('I/O operation on closed netdev')
544        return NetworkInterface(self.target).address
545
546    @property
547    def netmask(self):
548        if self.closed:
549            raise ValueError('I/O operation on closed netdev')
550        return NetworkInterface(self.target).netmask
551
552    @property
553    def mtu(self):
554        if self.closed:
555            raise ValueError('I/O operation on closed netdev')
556        return NetworkInterface(self.target).mtu
557
558    @address.setter
559    def address(self, address):
560        if self.closed:
561            raise ValueError('I/O operation on closed netdev')
562        NetworkInterface(self.target).address = address
563
564    @netmask.setter
565    def netmask(self, netmask):
566        if self.closed:
567            raise ValueError('I/O operation on closed netdev')
568        NetworkInterface(self.target).netmask = netmask
569
570    @mtu.setter
571    def mtu(self, mtu):
572        if self.closed:
573            raise ValueError('I/O operation on closed netdev')
574        NetworkInterface(self.target).mtu = mtu
575
576    def open(self):
577        if not self.closed:
578            raise ValueError('Already opened')
579        self.register_device(socket.socket(
580            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
581        self.device.bind((self.target, self.ETH_P_ALL))
582        self.dprintf('connected: %s\n', lambda: self.target)
583        return self.join_switch()
584
585    def write_message(self, message, binary=False):
586        if self.closed:
587            raise ValueError('I/O operation on closed netdev')
588        self.device.sendall(message)
589
590    def read(self):
591        if self.closed:
592            raise ValueError('I/O operation on closed netdev')
593        buf = []
594        while True:
595            buf.append(self.device.recv(65535))
596            if len(buf[-1]) < 65535:
597                break
598        return ''.join(buf)
599
600
601class TapHandler(BaseClientHandler):
602    IFTYPE = 'tap'
603    IFOP_ALLOWED = True
604
605    @property
606    def address(self):
607        if self.closed:
608            raise ValueError('I/O operation on closed tap')
609        try:
610            return self.device.addr
611        except:
612            return ''
613
614    @property
615    def netmask(self):
616        if self.closed:
617            raise ValueError('I/O operation on closed tap')
618        try:
619            return self.device.netmask
620        except:
621            return ''
622
623    @property
624    def mtu(self):
625        if self.closed:
626            raise ValueError('I/O operation on closed tap')
627        return self.device.mtu
628
629    @address.setter
630    def address(self, address):
631        if self.closed:
632            raise ValueError('I/O operation on closed tap')
633        self.device.addr = address
634
635    @netmask.setter
636    def netmask(self, netmask):
637        if self.closed:
638            raise ValueError('I/O operation on closed tap')
639        self.device.netmask = netmask
640
641    @mtu.setter
642    def mtu(self, mtu):
643        if self.closed:
644            raise ValueError('I/O operation on closed tap')
645        self.device.mtu = mtu
646
647    @property
648    def target(self):
649        if self.closed:
650            return self._target
651        return self.device.name
652
653    def open(self):
654        if not self.closed:
655            raise ValueError('Already opened')
656        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
657        self.device.up()
658        self.dprintf('connected: %s\n', lambda: self.target)
659        return self.join_switch()
660
661    def write_message(self, message, binary=False):
662        if self.closed:
663            raise ValueError('I/O operation on closed tap')
664        self.device.write(message)
665
666    def read(self):
667        if self.closed:
668            raise ValueError('I/O operation on closed tap')
669        buf = []
670        while True:
671            buf.append(self.device.read(65535))
672            if len(buf[-1]) < 65535:
673                break
674        return ''.join(buf)
675
676
677class ClientHandler(BaseClientHandler):
678    IFTYPE = 'client'
679    IFOP_ALLOWED = False
680
681    def __init__(self, *args, **kwargs):
682        super(ClientHandler, self).__init__(*args, **kwargs)
683
684        self._ssl = kwargs.get('ssl', False)
685        self._options = {}
686
687        cred = kwargs.get('cred', None)
688
689        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
690            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
691            auth = ['Authorization: Basic %s' % token]
692            self._options['header'] = auth
693
694    def open(self):
695        sslwrap = websocket._SSLSocketWrapper
696
697        if not self.closed:
698            raise websocket.WebSocketException('Already opened')
699
700        if self._ssl:
701            websocket._SSLSocketWrapper = self._ssl
702
703        # XXX: may be blocked
704        try:
705            self.register_device(websocket.WebSocket())
706            self.device.connect(self.target, **self._options)
707            self.dprintf('connected: %s\n', lambda: self.target)
708            return self.join_switch()
709        finally:
710            websocket._SSLSocketWrapper = sslwrap
711
712    def fileno(self):
713        if self.closed:
714            raise websocket.WebSocketException('Closed socket')
715        return self.device.io_sock.fileno()
716
717    def write_message(self, message, binary=False):
718        if self.closed:
719            raise websocket.WebSocketException('Closed socket')
720        if binary:
721            flag = websocket.ABNF.OPCODE_BINARY
722        else:
723            flag = websocket.ABNF.OPCODE_TEXT
724        self.device.send(message, flag)
725
726    def read(self):
727        if self.closed:
728            raise websocket.WebSocketException('Closed socket')
729        return self.device.recv()
730
731
732class ControlServerHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
733    NAMESPACE = 'etherws.control'
734    IFTYPES = {
735        NetdevHandler.IFTYPE: NetdevHandler,
736        TapHandler.IFTYPE:    TapHandler,
737        ClientHandler.IFTYPE: ClientHandler,
738    }
739
740    def __init__(self, app, req, ioloop, switch, htpasswd, debug):
741        super(ControlServerHandler, self).__init__(app, req)
742        self._ioloop = ioloop
743        self._switch = switch
744        self._htpasswd = htpasswd
745        self._debug = debug
746
747    def post(self):
748        try:
749            request = json.loads(self.request.body)
750        except Exception as e:
751            return self._jsonrpc_response(error={
752                'code':    0 - 32700,
753                'message': 'Parse error',
754                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
755            })
756
757        try:
758            id_ = request.get('id')
759            params = request.get('params')
760            version = request['jsonrpc']
761            method = request['method']
762            if version != '2.0':
763                raise ValueError('Invalid JSON-RPC version: %s' % version)
764        except Exception as e:
765            return self._jsonrpc_response(id_=id_, error={
766                'code':    0 - 32600,
767                'message': 'Invalid Request',
768                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
769            })
770
771        try:
772            if not method.startswith(self.NAMESPACE + '.'):
773                raise ValueError('Invalid method namespace: %s' % method)
774            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
775            handler = getattr(self, handler)
776        except Exception as e:
777            return self._jsonrpc_response(id_=id_, error={
778                'code':    0 - 32601,
779                'message': 'Method not found',
780                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
781            })
782
783        try:
784            return self._jsonrpc_response(id_=id_, result=handler(params))
785        except Exception as e:
786            traceback.print_exc()
787            return self._jsonrpc_response(id_=id_, error={
788                'code':    0 - 32602,
789                'message': 'Invalid params',
790                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
791            })
792
793    def handle_listFdb(self, params):
794        list_ = []
795        for vid, mac, entry in self._switch.fdb.each():
796            list_.append({
797                'vid':  vid,
798                'mac':  EthernetFrame.format_mac(mac),
799                'port': entry.port.number,
800                'age':  int(entry.age),
801            })
802        return {'entries': list_}
803
804    def handle_listPort(self, params):
805        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
806
807    def handle_addPort(self, params):
808        type_ = params['type']
809        target = params['target']
810        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
811        cls = self.IFTYPES[type_]
812        interface = cls(self._ioloop, self._switch, target, self._debug, **opt)
813        portnum = interface.open()
814        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
815
816    def handle_setPort(self, params):
817        port = self._switch.get_port(int(params['port']))
818        shut = params.get('shut')
819        if shut is not None:
820            port.shut = bool(shut)
821        return {'entries': [self._portstat(port)]}
822
823    def handle_delPort(self, params):
824        port = self._switch.get_port(int(params['port']))
825        port.interface.close()
826        return {'entries': [self._portstat(port)]}
827
828    def handle_setInterface(self, params):
829        portnum = int(params['port'])
830        port = self._switch.get_port(portnum)
831        address = params.get('address')
832        netmask = params.get('netmask')
833        mtu = params.get('mtu')
834        if not port.interface.IFOP_ALLOWED:
835            raise ValueError('Port %d has unsupported interface: %s' %
836                             (portnum, port.interface.IFTYPE))
837        if address is not None:
838            port.interface.address = address
839        if netmask is not None:
840            port.interface.netmask = netmask
841        if mtu is not None:
842            port.interface.mtu = mtu
843        return {'entries': [self._ifstat(port)]}
844
845    def handle_listInterface(self, params):
846        return {'entries': [self._ifstat(p) for p in self._switch.portlist
847                            if p.interface.IFOP_ALLOWED]}
848
849    def _optparse_netdev(self, opt):
850        return {}
851
852    def _optparse_tap(self, opt):
853        return {}
854
855    def _optparse_client(self, opt):
856        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
857        if opt.get('insecure'):
858            args = {}
859        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
860        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
861        return {'ssl': ssl_, 'cred': cred}
862
863    def _jsonrpc_response(self, id_=None, result=None, error=None):
864        res = {'jsonrpc': '2.0', 'id': id_}
865        if result:
866            res['result'] = result
867        if error:
868            res['error'] = error
869        self.finish(res)
870
871    @staticmethod
872    def _portstat(port):
873        return {
874            'port':   port.number,
875            'type':   port.interface.IFTYPE,
876            'target': port.interface.target,
877            'tx':     port.tx,
878            'rx':     port.rx,
879            'shut':   port.shut,
880        }
881
882    @staticmethod
883    def _ifstat(port):
884        return {
885            'port':    port.number,
886            'type':    port.interface.IFTYPE,
887            'target':  port.interface.target,
888            'address': port.interface.address,
889            'netmask': port.interface.netmask,
890            'mtu':     port.interface.mtu,
891        }
892
893
894def _print_error(error):
895    print(%s (%s)' % (error['message'], error['code']))
896    print('    %s' % error['data'])
897
898
899def _start_sw(args):
900    def daemonize(nochdir=False, noclose=False):
901        if os.fork() > 0:
902            sys.exit(0)
903
904        os.setsid()
905
906        if os.fork() > 0:
907            sys.exit(0)
908
909        if not nochdir:
910            os.chdir('/')
911
912        if not noclose:
913            os.umask(0)
914            sys.stdin.close()
915            sys.stdout.close()
916            sys.stderr.close()
917            os.close(0)
918            os.close(1)
919            os.close(2)
920            sys.stdin = open(os.devnull)
921            sys.stdout = open(os.devnull, 'a')
922            sys.stderr = open(os.devnull, 'a')
923
924    def checkabspath(ns, path):
925        val = getattr(ns, path, '')
926        if not val.startswith('/'):
927            raise ValueError('Invalid %: %s' % (path, val))
928
929    def getsslopt(ns, key, cert):
930        kval = getattr(ns, key, None)
931        cval = getattr(ns, cert, None)
932        if kval and cval:
933            return {'keyfile': kval, 'certfile': cval}
934        elif kval or cval:
935            raise ValueError('Both %s and %s are required' % (key, cert))
936        return None
937
938    def setrealpath(ns, *keys):
939        for k in keys:
940            v = getattr(ns, k, None)
941            if v is not None:
942                v = os.path.realpath(v)
943                open(v).close()  # check readable
944                setattr(ns, k, v)
945
946    def setport(ns, port, isssl):
947        val = getattr(ns, port, None)
948        if val is None:
949            if isssl:
950                return setattr(ns, port, 443)
951            return setattr(ns, port, 80)
952        if not (0 <= val <= 65535):
953            raise ValueError('Invalid %s: %s' % (port, val))
954
955    def sethtpasswd(ns, htpasswd):
956        val = getattr(ns, htpasswd, None)
957        if val:
958            return setattr(ns, htpasswd, Htpasswd(val))
959
960    #if args.debug:
961    #    websocket.enableTrace(True)
962
963    if args.ageout <= 0:
964        raise ValueError('Invalid ageout: %s' % args.ageout)
965
966    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
967    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
968
969    checkabspath(args, 'path')
970    checkabspath(args, 'ctlpath')
971
972    sslopt = getsslopt(args, 'sslkey', 'sslcert')
973    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
974
975    setport(args, 'port', sslopt)
976    setport(args, 'ctlport', ctlsslopt)
977
978    sethtpasswd(args, 'htpasswd')
979    sethtpasswd(args, 'ctlhtpasswd')
980
981    ioloop = IOLoop.instance()
982    fdb = FDB(args.ageout, args.debug)
983    switch = SwitchingHub(fdb, args.debug)
984
985    if args.port == args.ctlport and args.host == args.ctlhost:
986        if args.path == args.ctlpath:
987            raise ValueError('Same path/ctlpath on same host')
988        if args.sslkey != args.ctlsslkey:
989            raise ValueError('Different sslkey/ctlsslkey on same host')
990        if args.sslcert != args.ctlsslcert:
991            raise ValueError('Different sslcert/ctlsslcert on same host')
992
993        app = Application([
994            (args.path, ServerHandler, {
995                'switch':   switch,
996                'htpasswd': args.htpasswd,
997                'debug':    args.debug,
998            }),
999            (args.ctlpath, ControlServerHandler, {
1000                'ioloop':   ioloop,
1001                'switch':   switch,
1002                'htpasswd': args.ctlhtpasswd,
1003                'debug':    args.debug,
1004            }),
1005        ])
1006        server = HTTPServer(app, ssl_options=sslopt)
1007        server.listen(args.port, address=args.host)
1008
1009    else:
1010        app = Application([(args.path, ServerHandler, {
1011            'switch':   switch,
1012            'htpasswd': args.htpasswd,
1013            'debug':    args.debug,
1014        })])
1015        server = HTTPServer(app, ssl_options=sslopt)
1016        server.listen(args.port, address=args.host)
1017
1018        ctl = Application([(args.ctlpath, ControlServerHandler, {
1019            'ioloop':   ioloop,
1020            'switch':   switch,
1021            'htpasswd': args.ctlhtpasswd,
1022            'debug':    args.debug,
1023        })])
1024        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1025        ctlserver.listen(args.ctlport, address=args.ctlhost)
1026
1027    if not args.foreground:
1028        daemonize()
1029
1030    ioloop.start()
1031
1032
1033def _start_ctl(args):
1034    def request(args, method, params=None, id_=0):
1035        req = urllib2.Request(args.ctlurl)
1036        req.add_header('Content-type', 'application/json')
1037        if args.ctluser:
1038            if not args.ctlpasswd:
1039                args.ctlpasswd = getpass.getpass('Control Password: ')
1040            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1041            req.add_header('Authorization', 'Basic %s' % token)
1042        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1043        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1044        if params is not None:
1045            data['params'] = params
1046        return json.loads(urllib2.urlopen(req, json.dumps(data)).read())
1047
1048    def print_table(rows):
1049        cols = zip(*rows)
1050        maxlen = [0] * len(cols)
1051        for i in xrange(len(cols)):
1052            maxlen[i] = max(len(str(c)) for c in cols[i])
1053        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1054        fmt = '  ' + fmt
1055        for row in rows:
1056            print(fmt % tuple(row))
1057
1058    def print_portlist(result):
1059        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1060        for r in result:
1061            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1062                         r['rx'], r['tx'], r['target']])
1063        print_table(rows)
1064
1065    def print_iflist(result):
1066        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1067        for r in result:
1068            rows.append([r['port'], r['type'], r['address'],
1069                         r['netmask'], r['mtu'], r['target']])
1070        print_table(rows)
1071
1072    def handle_ctl_addport(args):
1073        opts = {
1074            'user':     getattr(args, 'user', None),
1075            'passwd':   getattr(args, 'passwd', None),
1076            'cacerts':  getattr(args, 'cacerts', None),
1077            'insecure': getattr(args, 'insecure', None),
1078        }
1079        if args.iftype == ClientHandler.IFTYPE:
1080            if not args.target.startswith('ws://') and \
1081               not args.target.startswith('wss://'):
1082                raise ValueError('Invalid target URL scheme: %s' % args.target)
1083            if not opts['user'] and opts['passwd']:
1084                raise ValueError('Authentication required but username empty')
1085            if opts['user'] and not opts['passwd']:
1086                opts['passwd'] = getpass.getpass('Client Password: ')
1087        result = request(args, 'addPort', {
1088            'type':    args.iftype,
1089            'target':  args.target,
1090            'options': opts,
1091        })
1092        if 'error' in result:
1093            _print_error(result['error'])
1094        else:
1095            print_portlist(result['result']['entries'])
1096
1097    def handle_ctl_setport(args):
1098        if args.port <= 0:
1099            raise ValueError('Invalid port: %d' % args.port)
1100        req = {'port': args.port}
1101        shut = getattr(args, 'shut', None)
1102        if shut is not None:
1103            req['shut'] = bool(shut)
1104        result = request(args, 'setPort', req)
1105        if 'error' in result:
1106            _print_error(result['error'])
1107        else:
1108            print_portlist(result['result']['entries'])
1109
1110    def handle_ctl_delport(args):
1111        if args.port <= 0:
1112            raise ValueError('Invalid port: %d' % args.port)
1113        result = request(args, 'delPort', {'port': args.port})
1114        if 'error' in result:
1115            _print_error(result['error'])
1116        else:
1117            print_portlist(result['result']['entries'])
1118
1119    def handle_ctl_listport(args):
1120        result = request(args, 'listPort')
1121        if 'error' in result:
1122            _print_error(result['error'])
1123        else:
1124            print_portlist(result['result']['entries'])
1125
1126    def handle_ctl_setif(args):
1127        if args.port <= 0:
1128            raise ValueError('Invalid port: %d' % args.port)
1129        req = {'port': args.port}
1130        address = getattr(args, 'address', None)
1131        netmask = getattr(args, 'netmask', None)
1132        mtu = getattr(args, 'mtu', None)
1133        if address is not None:
1134            if address:
1135                socket.inet_aton(address)  # validate
1136            req['address'] = address
1137        if netmask is not None:
1138            if netmask:
1139                socket.inet_aton(netmask)  # validate
1140            req['netmask'] = netmask
1141        if mtu is not None:
1142            if mtu < 576:
1143                raise ValueError('Invalid MTU: %d' % mtu)
1144            req['mtu'] = mtu
1145        result = request(args, 'setInterface', req)
1146        if 'error' in result:
1147            _print_error(result['error'])
1148        else:
1149            print_iflist(result['result']['entries'])
1150
1151    def handle_ctl_listif(args):
1152        result = request(args, 'listInterface')
1153        if 'error' in result:
1154            _print_error(result['error'])
1155        else:
1156            print_iflist(result['result']['entries'])
1157
1158    def handle_ctl_listfdb(args):
1159        result = request(args, 'listFdb')
1160        if 'error' in result:
1161            return _print_error(result['error'])
1162        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1163        for r in result['result']['entries']:
1164            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1165        print_table(rows)
1166
1167    locals()['handle_ctl_' + args.control_method](args)
1168
1169
1170def _main():
1171    parser = argparse.ArgumentParser()
1172    subcommand = parser.add_subparsers(dest='subcommand')
1173
1174    # - sw
1175    parser_sw = subcommand.add_parser('sw',
1176                                      help='start virtual switch')
1177
1178    parser_sw.add_argument('--debug', action='store_true', default=False,
1179                           help='run as debug mode')
1180    parser_sw.add_argument('--foreground', action='store_true', default=False,
1181                           help='run as foreground mode')
1182    parser_sw.add_argument('--ageout', type=int, default=300,
1183                           help='FDB ageout time (sec)')
1184
1185    parser_sw.add_argument('--path', default='/',
1186                           help='http(s) path to serve WebSocket')
1187    parser_sw.add_argument('--host', default='',
1188                           help='listen address to serve WebSocket')
1189    parser_sw.add_argument('--port', type=int,
1190                           help='listen port to serve WebSocket')
1191    parser_sw.add_argument('--htpasswd',
1192                           help='path to htpasswd file to auth WebSocket')
1193    parser_sw.add_argument('--sslkey',
1194                           help='path to SSL private key for WebSocket')
1195    parser_sw.add_argument('--sslcert',
1196                           help='path to SSL certificate for WebSocket')
1197
1198    parser_sw.add_argument('--ctlpath', default='/ctl',
1199                           help='http(s) path to serve control API')
1200    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1201                           help='listen address to serve control API')
1202    parser_sw.add_argument('--ctlport', type=int, default=7867,
1203                           help='listen port to serve control API')
1204    parser_sw.add_argument('--ctlhtpasswd',
1205                           help='path to htpasswd file to auth control API')
1206    parser_sw.add_argument('--ctlsslkey',
1207                           help='path to SSL private key for control API')
1208    parser_sw.add_argument('--ctlsslcert',
1209                           help='path to SSL certificate for control API')
1210
1211    # - ctl
1212    parser_ctl = subcommand.add_parser('ctl',
1213                                       help='control virtual switch')
1214    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1215                            help='URL to control API')
1216    parser_ctl.add_argument('--ctluser',
1217                            help='username to auth control API')
1218    parser_ctl.add_argument('--ctlpasswd',
1219                            help='password to auth control API')
1220
1221    control_method = parser_ctl.add_subparsers(dest='control_method')
1222
1223    # -- ctl addport
1224    parser_ctl_addport = control_method.add_parser('addport',
1225                                                   help='create and add port')
1226    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1227
1228    # --- ctl addport netdev
1229    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1230                                                  help='Network device')
1231    parser_ctl_addport_netdev.add_argument('target',
1232                                           help='device name to add interface')
1233
1234    # --- ctl addport tap
1235    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1236                                               help='TAP device')
1237    parser_ctl_addport_tap.add_argument('target',
1238                                        help='device name to create interface')
1239
1240    # --- ctl addport client
1241    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1242                                                  help='WebSocket client')
1243    parser_ctl_addport_client.add_argument('target',
1244                                           help='URL to connect WebSocket')
1245    parser_ctl_addport_client.add_argument('--user',
1246                                           help='username to auth WebSocket')
1247    parser_ctl_addport_client.add_argument('--passwd',
1248                                           help='password to auth WebSocket')
1249    parser_ctl_addport_client.add_argument('--cacerts',
1250                                           help='path to CA certificate')
1251    parser_ctl_addport_client.add_argument(
1252        '--insecure', action='store_true', default=False,
1253        help='do not verify server certificate')
1254
1255    # -- ctl setport
1256    parser_ctl_setport = control_method.add_parser('setport',
1257                                                   help='set port config')
1258    parser_ctl_setport.add_argument('port', type=int,
1259                                    help='port number to set config')
1260    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1261                                    help='set shutdown state')
1262
1263    # -- ctl delport
1264    parser_ctl_delport = control_method.add_parser('delport',
1265                                                   help='delete port')
1266    parser_ctl_delport.add_argument('port', type=int,
1267                                    help='port number to delete')
1268
1269    # -- ctl listport
1270    parser_ctl_listport = control_method.add_parser('listport',
1271                                                    help='show port list')
1272
1273    # -- ctl setif
1274    parser_ctl_setif = control_method.add_parser('setif',
1275                                                 help='set interface config')
1276    parser_ctl_setif.add_argument('port', type=int,
1277                                  help='port number to set config')
1278    parser_ctl_setif.add_argument('--address',
1279                                  help='IPv4 address to set interface')
1280    parser_ctl_setif.add_argument('--netmask',
1281                                  help='IPv4 netmask to set interface')
1282    parser_ctl_setif.add_argument('--mtu', type=int,
1283                                  help='MTU to set interface')
1284
1285    # -- ctl listif
1286    parser_ctl_listif = control_method.add_parser('listif',
1287                                                  help='show interface list')
1288
1289    # -- ctl listfdb
1290    parser_ctl_listfdb = control_method.add_parser('listfdb',
1291                                                   help='show FDB entries')
1292
1293    # -- go
1294    args = parser.parse_args()
1295
1296    try:
1297        globals()['_start_' + args.subcommand](args)
1298    except Exception as e:
1299        _print_error({
1300            'code':    0 - 32603,
1301            'message': 'Internal error',
1302            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1303        })
1304
1305
1306if __name__ == '__main__':
1307    _main()
Note: See TracBrowser for help on using the repository browser.