source: etherws/trunk/etherws.py @ 276

Revision 276, 42.9 KB checked in by atzm, 9 years ago (diff)

fix a listport bug on tornado 4.0.x

  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.12.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54import traceback
55
56import tornado
57import websocket
58
59from tornado.web import Application, RequestHandler
60from tornado.websocket import WebSocketHandler
61from tornado.httpserver import HTTPServer
62from tornado.ioloop import IOLoop
63
64from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
65
66
67class DebugMixIn(object):
68    def dprintf(self, msg, func=lambda: ()):
69        if self._debug:
70            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
71            sys.stderr.write(prefix + (msg % func()))
72
73
74class EthernetFrame(object):
75    def __init__(self, data):
76        self.data = data
77
78    @property
79    def dst_multicast(self):
80        return ord(self.data[0]) & 1
81
82    @property
83    def src_multicast(self):
84        return ord(self.data[6]) & 1
85
86    @property
87    def dst_mac(self):
88        return self.data[:6]
89
90    @property
91    def src_mac(self):
92        return self.data[6:12]
93
94    @property
95    def tagged(self):
96        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
97
98    @property
99    def vid(self):
100        if self.tagged:
101            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
102        return 0
103
104    @staticmethod
105    def format_mac(mac, sep=':'):
106        return sep.join(b.encode('hex') for b in mac)
107
108
109class FDB(DebugMixIn):
110    class Entry(object):
111        def __init__(self, port, ageout):
112            self.port = port
113            self._time = time.time()
114            self._ageout = ageout
115
116        @property
117        def age(self):
118            return time.time() - self._time
119
120        @property
121        def agedout(self):
122            return self.age > self._ageout
123
124    def __init__(self, ageout, debug):
125        self._ageout = ageout
126        self._debug = debug
127        self._table = {}
128
129    def _set_entry(self, vid, mac, port):
130        if vid not in self._table:
131            self._table[vid] = {}
132        self._table[vid][mac] = self.Entry(port, self._ageout)
133
134    def _del_entry(self, vid, mac):
135        if vid in self._table:
136            if mac in self._table[vid]:
137                del self._table[vid][mac]
138            if not self._table[vid]:
139                del self._table[vid]
140
141    def _get_entry(self, vid, mac):
142        try:
143            entry = self._table[vid][mac]
144        except KeyError:
145            return None
146
147        if not entry.agedout:
148            return entry
149
150        self._del_entry(vid, mac)
151        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
152                     lambda: (entry.port.number, vid, mac.encode('hex')))
153
154    def each(self):
155        for vid in sorted(self._table.iterkeys()):
156            for mac in sorted(self._table[vid].iterkeys()):
157                entry = self._get_entry(vid, mac)
158                if entry:
159                    yield (vid, mac, entry)
160
161    def lookup(self, frame):
162        mac = frame.dst_mac
163        vid = frame.vid
164        entry = self._get_entry(vid, mac)
165        return getattr(entry, 'port', None)
166
167    def learn(self, port, frame):
168        mac = frame.src_mac
169        vid = frame.vid
170        self._set_entry(vid, mac, port)
171        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
172                     lambda: (port.number, vid, mac.encode('hex')))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
179                             lambda: (port.number, vid, mac.encode('hex')))
180
181
182class SwitchingHub(DebugMixIn):
183    class Port(object):
184        def __init__(self, number, interface):
185            self.number = number
186            self.interface = interface
187            self.tx = 0
188            self.rx = 0
189            self.shut = False
190
191    def __init__(self, fdb, debug):
192        self.fdb = fdb
193        self._debug = debug
194        self._table = {}
195        self._next = 1
196
197    @property
198    def portlist(self):
199        return sorted(self._table.itervalues(),
200                      key=operator.attrgetter('number'))
201
202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
206        try:
207            self._set_privattr('portnum', interface, self._next)  # XXX
208            self._table[self._next] = self.Port(self._next, interface)
209            return self._next
210        finally:
211            self._next += 1
212
213    def unregister_port(self, interface):
214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
216        self.fdb.delete(self._table[portnum])
217        del self._table[portnum]
218
219    def send(self, dst_interfaces, frame):
220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
223        ports = sorted(ports, key=operator.attrgetter('number'))
224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
237        port = self._table[self._get_privattr('portnum', src_interface)]
238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
244        try:
245            if not frame.src_multicast:
246                self.fdb.learn(src_port, frame)
247
248            if not frame.dst_multicast:
249                dst_port = self.fdb.lookup(frame)
250
251                if dst_port:
252                    self.send([dst_port.interface], frame)
253                    return
254
255            ports = set(self.portlist) - set([src_port])
256            self.send((p.interface for p in ports), frame)
257
258        except:  # ex. received invalid frame
259            traceback.print_exc()
260
261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
263
264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
274class NetworkInterface(object):
275    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
276    SIOCSIFADDR = 0x8916
277    SIOCGIFNETMASK = 0x891b
278    SIOCSIFNETMASK = 0x891c
279    SIOCGIFMTU = 0x8921
280    SIOCSIFMTU = 0x8922
281
282    def __init__(self, ifname):
283        self._ifname = struct.pack('16s', str(ifname)[:15])
284
285    def _ioctl(self, req, data):
286        try:
287            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
288            return fcntl.ioctl(sock.fileno(), req, data)
289        finally:
290            sock.close()
291
292    @property
293    def address(self):
294        try:
295            data = struct.pack('16s18x', self._ifname)
296            ret = self._ioctl(self.SIOCGIFADDR, data)
297            return socket.inet_ntoa(ret[20:24])
298        except IOError:
299            return ''
300
301    @property
302    def netmask(self):
303        try:
304            data = struct.pack('16s18x', self._ifname)
305            ret = self._ioctl(self.SIOCGIFNETMASK, data)
306            return socket.inet_ntoa(ret[20:24])
307        except IOError:
308            return ''
309
310    @property
311    def mtu(self):
312        try:
313            data = struct.pack('16s18x', self._ifname)
314            ret = self._ioctl(self.SIOCGIFMTU, data)
315            return struct.unpack('i', ret[16:20])[0]
316        except IOError:
317            return 0
318
319    @address.setter
320    def address(self, addr):
321        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
322                           socket.inet_aton(addr))
323        self._ioctl(self.SIOCSIFADDR, data)
324
325    @netmask.setter
326    def netmask(self, addr):
327        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
328                           socket.inet_aton(addr))
329        self._ioctl(self.SIOCSIFNETMASK, data)
330
331    @mtu.setter
332    def mtu(self, mtu):
333        data = struct.pack('16si12x', self._ifname, mtu)
334        self._ioctl(self.SIOCSIFMTU, data)
335
336
337class Htpasswd(object):
338    def __init__(self, path):
339        self._path = path
340        self._stat = None
341        self._data = {}
342
343    def auth(self, name, passwd):
344        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
345        return self._data.get(name) == passwd
346
347    def load(self):
348        old_stat = self._stat
349
350        with open(self._path) as fp:
351            fileno = fp.fileno()
352            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
353            self._stat = os.fstat(fileno)
354
355            unchanged = old_stat and \
356                old_stat.st_ino == self._stat.st_ino and \
357                old_stat.st_dev == self._stat.st_dev and \
358                old_stat.st_mtime == self._stat.st_mtime
359
360            if not unchanged:
361                self._data = self._parse(fp)
362
363        return self
364
365    def _parse(self, fp):
366        data = {}
367        for line in fp:
368            line = line.strip()
369            if 0 <= line.find(':'):
370                name, passwd = line.split(':', 1)
371                if passwd.startswith('{SHA}'):
372                    data[name] = passwd[5:]
373        return data
374
375
376class BasicAuthMixIn(object):
377    def _execute(self, transforms, *args, **kwargs):
378        def do_execute():
379            sp = super(BasicAuthMixIn, self)
380            return sp._execute(transforms, *args, **kwargs)
381
382        def auth_required():
383            stream = getattr(self, 'stream', self.request.connection.stream)
384            stream.write(tornado.escape.utf8(
385                'HTTP/1.1 401 Authorization Required\r\n'
386                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
387            ))
388            stream.close()
389
390        try:
391            if not self._htpasswd:
392                return do_execute()
393
394            creds = self.request.headers.get('Authorization')
395
396            if not creds or not creds.startswith('Basic '):
397                return auth_required()
398
399            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
400
401            if self._htpasswd.load().auth(name, passwd):
402                return do_execute()
403        except:
404            traceback.print_exc()
405
406        return auth_required()
407
408
409class ServerHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
410    IFTYPE = 'server'
411    IFOP_ALLOWED = False
412
413    def __init__(self, app, req, switch, htpasswd, debug):
414        super(ServerHandler, self).__init__(app, req)
415        self._switch = switch
416        self._htpasswd = htpasswd
417        self._debug = debug
418
419    @property
420    def target(self):
421        conn = self.request.connection
422        if hasattr(conn, 'address'):
423            return ':'.join(str(e) for e in conn.address)
424        if hasattr(conn, 'context'):
425            return ':'.join(str(e) for e in conn.context.address)
426        return str(conn)
427
428    def open(self):
429        try:
430            return self._switch.register_port(self)
431        finally:
432            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
433
434    def on_message(self, message):
435        self._switch.receive(self, EthernetFrame(message))
436
437    def on_close(self):
438        self._switch.unregister_port(self)
439        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
440
441
442class BaseClientHandler(DebugMixIn):
443    IFTYPE = 'baseclient'
444    IFOP_ALLOWED = False
445
446    def __init__(self, ioloop, switch, target, debug, *args, **kwargs):
447        self._ioloop = ioloop
448        self._switch = switch
449        self._target = target
450        self._debug = debug
451        self._args = args
452        self._kwargs = kwargs
453        self._device = None
454
455    @property
456    def address(self):
457        raise NotImplementedError('unsupported')
458
459    @property
460    def netmask(self):
461        raise NotImplementedError('unsupported')
462
463    @property
464    def mtu(self):
465        raise NotImplementedError('unsupported')
466
467    @address.setter
468    def address(self, address):
469        raise NotImplementedError('unsupported')
470
471    @netmask.setter
472    def netmask(self, netmask):
473        raise NotImplementedError('unsupported')
474
475    @mtu.setter
476    def mtu(self, mtu):
477        raise NotImplementedError('unsupported')
478
479    def open(self):
480        raise NotImplementedError('unsupported')
481
482    def write_message(self, message, binary=False):
483        raise NotImplementedError('unsupported')
484
485    def read(self):
486        raise NotImplementedError('unsupported')
487
488    @property
489    def target(self):
490        return self._target
491
492    @property
493    def device(self):
494        return self._device
495
496    @property
497    def closed(self):
498        return not self.device
499
500    def close(self):
501        if self.closed:
502            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
503        self.leave_switch()
504        self.unregister_device()
505        self.dprintf('disconnected: %s\n', lambda: self.target)
506
507    def register_device(self, device):
508        self._device = device
509
510    def unregister_device(self):
511        self._device.close()
512        self._device = None
513
514    def fileno(self):
515        if self.closed:
516            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
517        return self.device.fileno()
518
519    def __call__(self, fd, events):
520        try:
521            data = self.read()
522            if data is not None:
523                self._switch.receive(self, EthernetFrame(data))
524                return
525        except:
526            traceback.print_exc()
527        self.close()
528
529    def join_switch(self):
530        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
531        return self._switch.register_port(self)
532
533    def leave_switch(self):
534        self._switch.unregister_port(self)
535        self._ioloop.remove_handler(self.fileno())
536
537
538class NetdevHandler(BaseClientHandler):
539    IFTYPE = 'netdev'
540    IFOP_ALLOWED = True
541    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
542
543    @property
544    def address(self):
545        if self.closed:
546            raise ValueError('I/O operation on closed netdev')
547        return NetworkInterface(self.target).address
548
549    @property
550    def netmask(self):
551        if self.closed:
552            raise ValueError('I/O operation on closed netdev')
553        return NetworkInterface(self.target).netmask
554
555    @property
556    def mtu(self):
557        if self.closed:
558            raise ValueError('I/O operation on closed netdev')
559        return NetworkInterface(self.target).mtu
560
561    @address.setter
562    def address(self, address):
563        if self.closed:
564            raise ValueError('I/O operation on closed netdev')
565        NetworkInterface(self.target).address = address
566
567    @netmask.setter
568    def netmask(self, netmask):
569        if self.closed:
570            raise ValueError('I/O operation on closed netdev')
571        NetworkInterface(self.target).netmask = netmask
572
573    @mtu.setter
574    def mtu(self, mtu):
575        if self.closed:
576            raise ValueError('I/O operation on closed netdev')
577        NetworkInterface(self.target).mtu = mtu
578
579    def open(self):
580        if not self.closed:
581            raise ValueError('Already opened')
582        self.register_device(socket.socket(
583            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
584        self.device.bind((self.target, self.ETH_P_ALL))
585        self.dprintf('connected: %s\n', lambda: self.target)
586        return self.join_switch()
587
588    def write_message(self, message, binary=False):
589        if self.closed:
590            raise ValueError('I/O operation on closed netdev')
591        self.device.sendall(message)
592
593    def read(self):
594        if self.closed:
595            raise ValueError('I/O operation on closed netdev')
596        buf = []
597        while True:
598            buf.append(self.device.recv(65535))
599            if len(buf[-1]) < 65535:
600                break
601        return ''.join(buf)
602
603
604class TapHandler(BaseClientHandler):
605    IFTYPE = 'tap'
606    IFOP_ALLOWED = True
607
608    @property
609    def address(self):
610        if self.closed:
611            raise ValueError('I/O operation on closed tap')
612        try:
613            return self.device.addr
614        except:
615            return ''
616
617    @property
618    def netmask(self):
619        if self.closed:
620            raise ValueError('I/O operation on closed tap')
621        try:
622            return self.device.netmask
623        except:
624            return ''
625
626    @property
627    def mtu(self):
628        if self.closed:
629            raise ValueError('I/O operation on closed tap')
630        return self.device.mtu
631
632    @address.setter
633    def address(self, address):
634        if self.closed:
635            raise ValueError('I/O operation on closed tap')
636        self.device.addr = address
637
638    @netmask.setter
639    def netmask(self, netmask):
640        if self.closed:
641            raise ValueError('I/O operation on closed tap')
642        self.device.netmask = netmask
643
644    @mtu.setter
645    def mtu(self, mtu):
646        if self.closed:
647            raise ValueError('I/O operation on closed tap')
648        self.device.mtu = mtu
649
650    @property
651    def target(self):
652        if self.closed:
653            return self._target
654        return self.device.name
655
656    def open(self):
657        if not self.closed:
658            raise ValueError('Already opened')
659        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
660        self.device.up()
661        self.dprintf('connected: %s\n', lambda: self.target)
662        return self.join_switch()
663
664    def write_message(self, message, binary=False):
665        if self.closed:
666            raise ValueError('I/O operation on closed tap')
667        self.device.write(message)
668
669    def read(self):
670        if self.closed:
671            raise ValueError('I/O operation on closed tap')
672        buf = []
673        while True:
674            buf.append(self.device.read(65535))
675            if len(buf[-1]) < 65535:
676                break
677        return ''.join(buf)
678
679
680class ClientHandler(BaseClientHandler):
681    IFTYPE = 'client'
682    IFOP_ALLOWED = False
683
684    def __init__(self, *args, **kwargs):
685        super(ClientHandler, self).__init__(*args, **kwargs)
686
687        self._sslopt = kwargs.get('sslopt', {})
688        self._options = {}
689
690        cred = kwargs.get('cred', None)
691
692        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
693            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
694            auth = ['Authorization: Basic %s' % token]
695            self._options['header'] = auth
696
697    def open(self):
698        if not self.closed:
699            raise websocket.WebSocketException('Already opened')
700
701        # XXX: may be blocked
702        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
703        self.device.connect(self.target, **self._options)
704        self.dprintf('connected: %s\n', lambda: self.target)
705        return self.join_switch()
706
707    def write_message(self, message, binary=False):
708        if self.closed:
709            raise websocket.WebSocketException('Closed socket')
710        if binary:
711            flag = websocket.ABNF.OPCODE_BINARY
712        else:
713            flag = websocket.ABNF.OPCODE_TEXT
714        self.device.send(message, flag)
715
716    def read(self):
717        if self.closed:
718            raise websocket.WebSocketException('Closed socket')
719        return self.device.recv()
720
721
722class ControlServerHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
723    NAMESPACE = 'etherws.control'
724    IFTYPES = {
725        NetdevHandler.IFTYPE: NetdevHandler,
726        TapHandler.IFTYPE:    TapHandler,
727        ClientHandler.IFTYPE: ClientHandler,
728    }
729
730    def __init__(self, app, req, ioloop, switch, htpasswd, debug):
731        super(ControlServerHandler, self).__init__(app, req)
732        self._ioloop = ioloop
733        self._switch = switch
734        self._htpasswd = htpasswd
735        self._debug = debug
736
737    def post(self):
738        try:
739            request = json.loads(self.request.body)
740        except Exception as e:
741            return self._jsonrpc_response(error={
742                'code':    0 - 32700,
743                'message': 'Parse error',
744                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
745            })
746
747        try:
748            id_ = request.get('id')
749            params = request.get('params')
750            version = request['jsonrpc']
751            method = request['method']
752            if version != '2.0':
753                raise ValueError('Invalid JSON-RPC version: %s' % version)
754        except Exception as e:
755            return self._jsonrpc_response(id_=id_, error={
756                'code':    0 - 32600,
757                'message': 'Invalid Request',
758                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
759            })
760
761        try:
762            if not method.startswith(self.NAMESPACE + '.'):
763                raise ValueError('Invalid method namespace: %s' % method)
764            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
765            handler = getattr(self, handler)
766        except Exception as e:
767            return self._jsonrpc_response(id_=id_, error={
768                'code':    0 - 32601,
769                'message': 'Method not found',
770                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
771            })
772
773        try:
774            return self._jsonrpc_response(id_=id_, result=handler(params))
775        except Exception as e:
776            traceback.print_exc()
777            return self._jsonrpc_response(id_=id_, error={
778                'code':    0 - 32602,
779                'message': 'Invalid params',
780                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
781            })
782
783    def handle_listFdb(self, params):
784        list_ = []
785        for vid, mac, entry in self._switch.fdb.each():
786            list_.append({
787                'vid':  vid,
788                'mac':  EthernetFrame.format_mac(mac),
789                'port': entry.port.number,
790                'age':  int(entry.age),
791            })
792        return {'entries': list_}
793
794    def handle_listPort(self, params):
795        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
796
797    def handle_addPort(self, params):
798        type_ = params['type']
799        target = params['target']
800        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
801        cls = self.IFTYPES[type_]
802        interface = cls(self._ioloop, self._switch, target, self._debug, **opt)
803        portnum = interface.open()
804        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
805
806    def handle_setPort(self, params):
807        port = self._switch.get_port(int(params['port']))
808        shut = params.get('shut')
809        if shut is not None:
810            port.shut = bool(shut)
811        return {'entries': [self._portstat(port)]}
812
813    def handle_delPort(self, params):
814        port = self._switch.get_port(int(params['port']))
815        port.interface.close()
816        return {'entries': [self._portstat(port)]}
817
818    def handle_setInterface(self, params):
819        portnum = int(params['port'])
820        port = self._switch.get_port(portnum)
821        address = params.get('address')
822        netmask = params.get('netmask')
823        mtu = params.get('mtu')
824        if not port.interface.IFOP_ALLOWED:
825            raise ValueError('Port %d has unsupported interface: %s' %
826                             (portnum, port.interface.IFTYPE))
827        if address is not None:
828            port.interface.address = address
829        if netmask is not None:
830            port.interface.netmask = netmask
831        if mtu is not None:
832            port.interface.mtu = mtu
833        return {'entries': [self._ifstat(port)]}
834
835    def handle_listInterface(self, params):
836        return {'entries': [self._ifstat(p) for p in self._switch.portlist
837                            if p.interface.IFOP_ALLOWED]}
838
839    def _optparse_netdev(self, opt):
840        return {}
841
842    def _optparse_tap(self, opt):
843        return {}
844
845    def _optparse_client(self, opt):
846        if opt.get('insecure'):
847            sslopt = {'cert_reqs': ssl.CERT_NONE}
848        else:
849            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
850                      'ca_certs':  opt.get('cacerts')}
851        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
852        return {'sslopt': sslopt, 'cred': cred}
853
854    def _jsonrpc_response(self, id_=None, result=None, error=None):
855        res = {'jsonrpc': '2.0', 'id': id_}
856        if result:
857            res['result'] = result
858        if error:
859            res['error'] = error
860        self.finish(res)
861
862    @staticmethod
863    def _portstat(port):
864        return {
865            'port':   port.number,
866            'type':   port.interface.IFTYPE,
867            'target': port.interface.target,
868            'tx':     port.tx,
869            'rx':     port.rx,
870            'shut':   port.shut,
871        }
872
873    @staticmethod
874    def _ifstat(port):
875        return {
876            'port':    port.number,
877            'type':    port.interface.IFTYPE,
878            'target':  port.interface.target,
879            'address': port.interface.address,
880            'netmask': port.interface.netmask,
881            'mtu':     port.interface.mtu,
882        }
883
884
885def _print_error(error):
886    print(%s (%s)' % (error['message'], error['code']))
887    print('    %s' % error['data'])
888
889
890def _start_sw(args):
891    def daemonize(nochdir=False, noclose=False):
892        if os.fork() > 0:
893            sys.exit(0)
894
895        os.setsid()
896
897        if os.fork() > 0:
898            sys.exit(0)
899
900        if not nochdir:
901            os.chdir('/')
902
903        if not noclose:
904            os.umask(0)
905            sys.stdin.close()
906            sys.stdout.close()
907            sys.stderr.close()
908            os.close(0)
909            os.close(1)
910            os.close(2)
911            sys.stdin = open(os.devnull)
912            sys.stdout = open(os.devnull, 'a')
913            sys.stderr = open(os.devnull, 'a')
914
915    def checkabspath(ns, path):
916        val = getattr(ns, path, '')
917        if not val.startswith('/'):
918            raise ValueError('Invalid %: %s' % (path, val))
919
920    def getsslopt(ns, key, cert):
921        kval = getattr(ns, key, None)
922        cval = getattr(ns, cert, None)
923        if kval and cval:
924            return {'keyfile': kval, 'certfile': cval}
925        elif kval or cval:
926            raise ValueError('Both %s and %s are required' % (key, cert))
927        return None
928
929    def setrealpath(ns, *keys):
930        for k in keys:
931            v = getattr(ns, k, None)
932            if v is not None:
933                v = os.path.realpath(v)
934                open(v).close()  # check readable
935                setattr(ns, k, v)
936
937    def setport(ns, port, isssl):
938        val = getattr(ns, port, None)
939        if val is None:
940            if isssl:
941                return setattr(ns, port, 443)
942            return setattr(ns, port, 80)
943        if not (0 <= val <= 65535):
944            raise ValueError('Invalid %s: %s' % (port, val))
945
946    def sethtpasswd(ns, htpasswd):
947        val = getattr(ns, htpasswd, None)
948        if val:
949            return setattr(ns, htpasswd, Htpasswd(val))
950
951    # if args.debug:
952    #     websocket.enableTrace(True)
953
954    if args.ageout <= 0:
955        raise ValueError('Invalid ageout: %s' % args.ageout)
956
957    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
958    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
959
960    checkabspath(args, 'path')
961    checkabspath(args, 'ctlpath')
962
963    sslopt = getsslopt(args, 'sslkey', 'sslcert')
964    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
965
966    setport(args, 'port', sslopt)
967    setport(args, 'ctlport', ctlsslopt)
968
969    sethtpasswd(args, 'htpasswd')
970    sethtpasswd(args, 'ctlhtpasswd')
971
972    ioloop = IOLoop.instance()
973    fdb = FDB(args.ageout, args.debug)
974    switch = SwitchingHub(fdb, args.debug)
975
976    if args.port == args.ctlport and args.host == args.ctlhost:
977        if args.path == args.ctlpath:
978            raise ValueError('Same path/ctlpath on same host')
979        if args.sslkey != args.ctlsslkey:
980            raise ValueError('Different sslkey/ctlsslkey on same host')
981        if args.sslcert != args.ctlsslcert:
982            raise ValueError('Different sslcert/ctlsslcert on same host')
983
984        app = Application([
985            (args.path, ServerHandler, {
986                'switch':   switch,
987                'htpasswd': args.htpasswd,
988                'debug':    args.debug,
989            }),
990            (args.ctlpath, ControlServerHandler, {
991                'ioloop':   ioloop,
992                'switch':   switch,
993                'htpasswd': args.ctlhtpasswd,
994                'debug':    args.debug,
995            }),
996        ])
997        server = HTTPServer(app, ssl_options=sslopt)
998        server.listen(args.port, address=args.host)
999
1000    else:
1001        app = Application([(args.path, ServerHandler, {
1002            'switch':   switch,
1003            'htpasswd': args.htpasswd,
1004            'debug':    args.debug,
1005        })])
1006        server = HTTPServer(app, ssl_options=sslopt)
1007        server.listen(args.port, address=args.host)
1008
1009        ctl = Application([(args.ctlpath, ControlServerHandler, {
1010            'ioloop':   ioloop,
1011            'switch':   switch,
1012            'htpasswd': args.ctlhtpasswd,
1013            'debug':    args.debug,
1014        })])
1015        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1016        ctlserver.listen(args.ctlport, address=args.ctlhost)
1017
1018    if not args.foreground:
1019        daemonize()
1020
1021    ioloop.start()
1022
1023
1024def _start_ctl(args):
1025    def have_ssl_cert_verification():
1026        return 'context' in urllib2.urlopen.__code__.co_varnames
1027
1028    def request(args, method, params=None, id_=0):
1029        req = urllib2.Request(args.ctlurl)
1030        req.add_header('Content-type', 'application/json')
1031        if args.ctluser:
1032            if not args.ctlpasswd:
1033                args.ctlpasswd = getpass.getpass('Control Password: ')
1034            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1035            req.add_header('Authorization', 'Basic %s' % token)
1036        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1037        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1038        if params is not None:
1039            data['params'] = params
1040        if have_ssl_cert_verification():
1041            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1042                                             cafile=args.ctlsslcert)
1043            if args.ctlinsecure:
1044                ctx.check_hostname = False
1045                ctx.verify_mode = ssl.CERT_NONE
1046            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1047        elif args.ctlsslcert:
1048            raise EnvironmentError('do not support certificate verification')
1049        else:
1050            fp = urllib2.urlopen(req, json.dumps(data))
1051        return json.loads(fp.read())
1052
1053    def print_table(rows):
1054        cols = zip(*rows)
1055        maxlen = [0] * len(cols)
1056        for i in xrange(len(cols)):
1057            maxlen[i] = max(len(str(c)) for c in cols[i])
1058        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1059        fmt = '  ' + fmt
1060        for row in rows:
1061            print(fmt % tuple(row))
1062
1063    def print_portlist(result):
1064        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1065        for r in result:
1066            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1067                         r['rx'], r['tx'], r['target']])
1068        print_table(rows)
1069
1070    def print_iflist(result):
1071        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1072        for r in result:
1073            rows.append([r['port'], r['type'], r['address'],
1074                         r['netmask'], r['mtu'], r['target']])
1075        print_table(rows)
1076
1077    def handle_ctl_addport(args):
1078        opts = {
1079            'user':     getattr(args, 'user', None),
1080            'passwd':   getattr(args, 'passwd', None),
1081            'cacerts':  getattr(args, 'cacerts', None),
1082            'insecure': getattr(args, 'insecure', None),
1083        }
1084        if args.iftype == ClientHandler.IFTYPE:
1085            if not args.target.startswith('ws://') and \
1086               not args.target.startswith('wss://'):
1087                raise ValueError('Invalid target URL scheme: %s' % args.target)
1088            if not opts['user'] and opts['passwd']:
1089                raise ValueError('Authentication required but username empty')
1090            if opts['user'] and not opts['passwd']:
1091                opts['passwd'] = getpass.getpass('Client Password: ')
1092        result = request(args, 'addPort', {
1093            'type':    args.iftype,
1094            'target':  args.target,
1095            'options': opts,
1096        })
1097        if 'error' in result:
1098            _print_error(result['error'])
1099        else:
1100            print_portlist(result['result']['entries'])
1101
1102    def handle_ctl_setport(args):
1103        if args.port <= 0:
1104            raise ValueError('Invalid port: %d' % args.port)
1105        req = {'port': args.port}
1106        shut = getattr(args, 'shut', None)
1107        if shut is not None:
1108            req['shut'] = bool(shut)
1109        result = request(args, 'setPort', req)
1110        if 'error' in result:
1111            _print_error(result['error'])
1112        else:
1113            print_portlist(result['result']['entries'])
1114
1115    def handle_ctl_delport(args):
1116        if args.port <= 0:
1117            raise ValueError('Invalid port: %d' % args.port)
1118        result = request(args, 'delPort', {'port': args.port})
1119        if 'error' in result:
1120            _print_error(result['error'])
1121        else:
1122            print_portlist(result['result']['entries'])
1123
1124    def handle_ctl_listport(args):
1125        result = request(args, 'listPort')
1126        if 'error' in result:
1127            _print_error(result['error'])
1128        else:
1129            print_portlist(result['result']['entries'])
1130
1131    def handle_ctl_setif(args):
1132        if args.port <= 0:
1133            raise ValueError('Invalid port: %d' % args.port)
1134        req = {'port': args.port}
1135        address = getattr(args, 'address', None)
1136        netmask = getattr(args, 'netmask', None)
1137        mtu = getattr(args, 'mtu', None)
1138        if address is not None:
1139            if address:
1140                socket.inet_aton(address)  # validate
1141            req['address'] = address
1142        if netmask is not None:
1143            if netmask:
1144                socket.inet_aton(netmask)  # validate
1145            req['netmask'] = netmask
1146        if mtu is not None:
1147            if mtu < 576:
1148                raise ValueError('Invalid MTU: %d' % mtu)
1149            req['mtu'] = mtu
1150        result = request(args, 'setInterface', req)
1151        if 'error' in result:
1152            _print_error(result['error'])
1153        else:
1154            print_iflist(result['result']['entries'])
1155
1156    def handle_ctl_listif(args):
1157        result = request(args, 'listInterface')
1158        if 'error' in result:
1159            _print_error(result['error'])
1160        else:
1161            print_iflist(result['result']['entries'])
1162
1163    def handle_ctl_listfdb(args):
1164        result = request(args, 'listFdb')
1165        if 'error' in result:
1166            return _print_error(result['error'])
1167        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1168        for r in result['result']['entries']:
1169            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1170        print_table(rows)
1171
1172    locals()['handle_ctl_' + args.control_method](args)
1173
1174
1175def _main():
1176    parser = argparse.ArgumentParser()
1177    subcommand = parser.add_subparsers(dest='subcommand')
1178
1179    # - sw
1180    parser_sw = subcommand.add_parser('sw',
1181                                      help='start virtual switch')
1182
1183    parser_sw.add_argument('--debug', action='store_true', default=False,
1184                           help='run as debug mode')
1185    parser_sw.add_argument('--foreground', action='store_true', default=False,
1186                           help='run as foreground mode')
1187    parser_sw.add_argument('--ageout', type=int, default=300,
1188                           help='FDB ageout time (sec)')
1189
1190    parser_sw.add_argument('--path', default='/',
1191                           help='http(s) path to serve WebSocket')
1192    parser_sw.add_argument('--host', default='',
1193                           help='listen address to serve WebSocket')
1194    parser_sw.add_argument('--port', type=int,
1195                           help='listen port to serve WebSocket')
1196    parser_sw.add_argument('--htpasswd',
1197                           help='path to htpasswd file to auth WebSocket')
1198    parser_sw.add_argument('--sslkey',
1199                           help='path to SSL private key for WebSocket')
1200    parser_sw.add_argument('--sslcert',
1201                           help='path to SSL certificate for WebSocket')
1202
1203    parser_sw.add_argument('--ctlpath', default='/ctl',
1204                           help='http(s) path to serve control API')
1205    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1206                           help='listen address to serve control API')
1207    parser_sw.add_argument('--ctlport', type=int, default=7867,
1208                           help='listen port to serve control API')
1209    parser_sw.add_argument('--ctlhtpasswd',
1210                           help='path to htpasswd file to auth control API')
1211    parser_sw.add_argument('--ctlsslkey',
1212                           help='path to SSL private key for control API')
1213    parser_sw.add_argument('--ctlsslcert',
1214                           help='path to SSL certificate for control API')
1215
1216    # - ctl
1217    parser_ctl = subcommand.add_parser('ctl',
1218                                       help='control virtual switch')
1219    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1220                            help='URL to control API')
1221    parser_ctl.add_argument('--ctluser',
1222                            help='username to auth control API')
1223    parser_ctl.add_argument('--ctlpasswd',
1224                            help='password to auth control API')
1225    parser_ctl.add_argument('--ctlsslcert',
1226                            help='path to SSL certificate for control API')
1227    parser_ctl.add_argument(
1228        '--ctlinsecure', action='store_true', default=False,
1229        help='do not verify control API certificate')
1230
1231    control_method = parser_ctl.add_subparsers(dest='control_method')
1232
1233    # -- ctl addport
1234    parser_ctl_addport = control_method.add_parser('addport',
1235                                                   help='create and add port')
1236    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1237
1238    # --- ctl addport netdev
1239    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1240                                                  help='Network device')
1241    parser_ctl_addport_netdev.add_argument('target',
1242                                           help='device name to add interface')
1243
1244    # --- ctl addport tap
1245    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1246                                               help='TAP device')
1247    parser_ctl_addport_tap.add_argument('target',
1248                                        help='device name to create interface')
1249
1250    # --- ctl addport client
1251    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1252                                                  help='WebSocket client')
1253    parser_ctl_addport_client.add_argument('target',
1254                                           help='URL to connect WebSocket')
1255    parser_ctl_addport_client.add_argument('--user',
1256                                           help='username to auth WebSocket')
1257    parser_ctl_addport_client.add_argument('--passwd',
1258                                           help='password to auth WebSocket')
1259    parser_ctl_addport_client.add_argument('--cacerts',
1260                                           help='path to CA certificate')
1261    parser_ctl_addport_client.add_argument(
1262        '--insecure', action='store_true', default=False,
1263        help='do not verify server certificate')
1264
1265    # -- ctl setport
1266    parser_ctl_setport = control_method.add_parser('setport',
1267                                                   help='set port config')
1268    parser_ctl_setport.add_argument('port', type=int,
1269                                    help='port number to set config')
1270    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1271                                    help='set shutdown state')
1272
1273    # -- ctl delport
1274    parser_ctl_delport = control_method.add_parser('delport',
1275                                                   help='delete port')
1276    parser_ctl_delport.add_argument('port', type=int,
1277                                    help='port number to delete')
1278
1279    # -- ctl listport
1280    parser_ctl_listport = control_method.add_parser('listport',
1281                                                    help='show port list')
1282
1283    # -- ctl setif
1284    parser_ctl_setif = control_method.add_parser('setif',
1285                                                 help='set interface config')
1286    parser_ctl_setif.add_argument('port', type=int,
1287                                  help='port number to set config')
1288    parser_ctl_setif.add_argument('--address',
1289                                  help='IPv4 address to set interface')
1290    parser_ctl_setif.add_argument('--netmask',
1291                                  help='IPv4 netmask to set interface')
1292    parser_ctl_setif.add_argument('--mtu', type=int,
1293                                  help='MTU to set interface')
1294
1295    # -- ctl listif
1296    parser_ctl_listif = control_method.add_parser('listif',
1297                                                  help='show interface list')
1298
1299    # -- ctl listfdb
1300    parser_ctl_listfdb = control_method.add_parser('listfdb',
1301                                                   help='show FDB entries')
1302
1303    # -- go
1304    args = parser.parse_args()
1305
1306    try:
1307        globals()['_start_' + args.subcommand](args)
1308    except Exception as e:
1309        _print_error({
1310            'code':    0 - 32603,
1311            'message': 'Internal error',
1312            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1313        })
1314
1315
1316if __name__ == '__main__':
1317    _main()
Note: See TracBrowser for help on using the repository browser.