source: etherws/trunk/etherws.py @ 287

Revision 287, 44.9 KB checked in by atzm, 8 years ago (diff)
  • add --noorigin option for sw
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.14.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54
55import logging
56import logging.config
57
58import tornado
59import websocket
60
61from tornado.web import Application, RequestHandler
62from tornado.websocket import WebSocketHandler
63from tornado.httpserver import HTTPServer
64from tornado.ioloop import IOLoop
65
66from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @property
74    def dst_multicast(self):
75        return ord(self.data[0]) & 1
76
77    @property
78    def src_multicast(self):
79        return ord(self.data[6]) & 1
80
81    @property
82    def dst_mac(self):
83        return self.data[:6]
84
85    @property
86    def src_mac(self):
87        return self.data[6:12]
88
89    @property
90    def tagged(self):
91        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
92
93    @property
94    def vid(self):
95        if self.tagged:
96            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
97        return 0
98
99    @staticmethod
100    def format_mac(mac, sep=':'):
101        return sep.join(b.encode('hex') for b in mac)
102
103
104class FDB(object):
105    class Entry(object):
106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
119    def __init__(self, ageout, logger, debug):
120        self._ageout = ageout
121        self._logger = logger
122        self._debug = debug
123        self._table = {}
124
125    def _set_entry(self, vid, mac, port):
126        if vid not in self._table:
127            self._table[vid] = {}
128        self._table[vid][mac] = self.Entry(port, self._ageout)
129
130    def _del_entry(self, vid, mac):
131        if vid in self._table:
132            if mac in self._table[vid]:
133                del self._table[vid][mac]
134            if not self._table[vid]:
135                del self._table[vid]
136
137    def _get_entry(self, vid, mac):
138        try:
139            entry = self._table[vid][mac]
140        except KeyError:
141            return None
142
143        if not entry.agedout:
144            return entry
145
146        self._del_entry(vid, mac)
147
148        if self._debug:
149            self._logger.debug('fdb aged out: port:%d; vid:%d; mac:%s',
150                               entry.port.number, vid, mac.encode('hex'))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169
170        if self._debug:
171            self._logger.debug('fdb learned: port:%d; vid:%d; mac:%s',
172                               port.number, vid, mac.encode('hex'))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178
179                if self._debug:
180                    self._logger.debug('fdb deleted: port:%d; vid:%d; mac:%s',
181                                       port.number, vid, mac.encode('hex'))
182
183
184class SwitchingHub(object):
185    class Port(object):
186        def __init__(self, number, interface):
187            self.number = number
188            self.interface = interface
189            self.tx = 0
190            self.rx = 0
191            self.shut = False
192
193    def __init__(self, fdb, logger, debug):
194        self.fdb = fdb
195        self._logger = logger
196        self._debug = debug
197        self._table = {}
198        self._next = 1
199
200    @property
201    def portlist(self):
202        return sorted(self._table.itervalues(),
203                      key=operator.attrgetter('number'))
204
205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
209        try:
210            self._set_privattr('portnum', interface, self._next)  # XXX
211            self._table[self._next] = self.Port(self._next, interface)
212            return self._next
213        finally:
214            self._next += 1
215
216    def unregister_port(self, interface):
217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
219        self.fdb.delete(self._table[portnum])
220        del self._table[portnum]
221
222    def send(self, dst_interfaces, frame):
223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
226        ports = sorted(ports, key=operator.attrgetter('number'))
227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if self._debug and ports:
233            self._logger.debug('sent: port:%s; vid:%d; %s -> %s',
234                               ','.join(str(p.number) for p in ports),
235                               frame.vid,
236                               frame.src_mac.encode('hex'),
237                               frame.dst_mac.encode('hex'))
238
239    def receive(self, src_interface, frame):
240        port = self._table[self._get_privattr('portnum', src_interface)]
241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
247        try:
248            if not frame.src_multicast:
249                self.fdb.learn(src_port, frame)
250
251            if not frame.dst_multicast:
252                dst_port = self.fdb.lookup(frame)
253
254                if dst_port:
255                    self.send([dst_port.interface], frame)
256                    return
257
258            ports = set(self.portlist) - set([src_port])
259            self.send((p.interface for p in ports), frame)
260
261        except:  # ex. received invalid frame
262            self._logger.exception(
263                'caught error while forwarding frame: port:%d; frame:%s',
264                src_port.number, frame.data.encode('hex'))
265
266    def _privattr(self, name):
267        return '_%s_%s_%s' % (type(self).__name__, id(self), name)
268
269    def _set_privattr(self, name, obj, value):
270        return setattr(obj, self._privattr(name), value)
271
272    def _get_privattr(self, name, obj, defaults=None):
273        return getattr(obj, self._privattr(name), defaults)
274
275    def _del_privattr(self, name, obj):
276        return delattr(obj, self._privattr(name))
277
278
279class NetworkInterface(object):
280    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
281    SIOCSIFADDR = 0x8916
282    SIOCGIFNETMASK = 0x891b
283    SIOCSIFNETMASK = 0x891c
284    SIOCGIFMTU = 0x8921
285    SIOCSIFMTU = 0x8922
286
287    def __init__(self, ifname):
288        self._ifname = struct.pack('16s', str(ifname)[:15])
289
290    def _ioctl(self, req, data):
291        try:
292            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
293            return fcntl.ioctl(sock.fileno(), req, data)
294        finally:
295            sock.close()
296
297    @property
298    def address(self):
299        try:
300            data = struct.pack('16s18x', self._ifname)
301            ret = self._ioctl(self.SIOCGIFADDR, data)
302            return socket.inet_ntoa(ret[20:24])
303        except IOError:
304            return ''
305
306    @property
307    def netmask(self):
308        try:
309            data = struct.pack('16s18x', self._ifname)
310            ret = self._ioctl(self.SIOCGIFNETMASK, data)
311            return socket.inet_ntoa(ret[20:24])
312        except IOError:
313            return ''
314
315    @property
316    def mtu(self):
317        try:
318            data = struct.pack('16s18x', self._ifname)
319            ret = self._ioctl(self.SIOCGIFMTU, data)
320            return struct.unpack('i', ret[16:20])[0]
321        except IOError:
322            return 0
323
324    @address.setter
325    def address(self, addr):
326        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
327                           socket.inet_aton(addr))
328        self._ioctl(self.SIOCSIFADDR, data)
329
330    @netmask.setter
331    def netmask(self, addr):
332        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
333                           socket.inet_aton(addr))
334        self._ioctl(self.SIOCSIFNETMASK, data)
335
336    @mtu.setter
337    def mtu(self, mtu):
338        data = struct.pack('16si12x', self._ifname, mtu)
339        self._ioctl(self.SIOCSIFMTU, data)
340
341
342class Htpasswd(object):
343    def __init__(self, path):
344        self._path = path
345        self._stat = None
346        self._data = {}
347
348    def auth(self, name, passwd):
349        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
350        return self._data.get(name) == passwd
351
352    def load(self):
353        old_stat = self._stat
354
355        with open(self._path) as fp:
356            fileno = fp.fileno()
357            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
358            self._stat = os.fstat(fileno)
359
360            unchanged = old_stat and \
361                old_stat.st_ino == self._stat.st_ino and \
362                old_stat.st_dev == self._stat.st_dev and \
363                old_stat.st_mtime == self._stat.st_mtime
364
365            if not unchanged:
366                self._data = self._parse(fp)
367
368        return self
369
370    def _parse(self, fp):
371        data = {}
372        for line in fp:
373            line = line.strip()
374            if 0 <= line.find(':'):
375                name, passwd = line.split(':', 1)
376                if passwd.startswith('{SHA}'):
377                    data[name] = passwd[5:]
378        return data
379
380
381class BasicAuthMixIn(object):
382    def _execute(self, transforms, *args, **kwargs):
383        def do_execute():
384            sp = super(BasicAuthMixIn, self)
385            return sp._execute(transforms, *args, **kwargs)
386
387        def auth_required():
388            stream = getattr(self, 'stream', self.request.connection.stream)
389            stream.write(tornado.escape.utf8(
390                'HTTP/1.1 401 Authorization Required\r\n'
391                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
392            ))
393            stream.close()
394
395        try:
396            if not self._htpasswd:
397                return do_execute()
398
399            creds = self.request.headers.get('Authorization')
400
401            if not creds or not creds.startswith('Basic '):
402                return auth_required()
403
404            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
405
406            if self._htpasswd.load().auth(name, passwd):
407                return do_execute()
408        except:
409            self._logger.exception('caught error while doing authorization')
410
411        return auth_required()
412
413
414class ServerHandler(BasicAuthMixIn, WebSocketHandler):
415    IFTYPE = 'server'
416    IFOP_ALLOWED = False
417
418    def __init__(self, app, req, switch, htpasswd, noorigin, logger, debug):
419        super(ServerHandler, self).__init__(app, req)
420        self._switch = switch
421        self._htpasswd = htpasswd
422        self._noorigin = noorigin
423        self._logger = logger
424        self._debug = debug
425
426    @property
427    def target(self):
428        conn = self.request.connection
429        if hasattr(conn, 'address'):
430            return ':'.join(str(e) for e in conn.address[:2])
431        if hasattr(conn, 'context'):
432            return ':'.join(str(e) for e in conn.context.address[:2])
433        return str(conn)
434
435    def open(self):
436        try:
437            return self._switch.register_port(self)
438        finally:
439            self._logger.info('connected: %s', self.request.remote_ip)
440
441    def on_message(self, message):
442        self._switch.receive(self, EthernetFrame(message))
443
444    def on_close(self):
445        self._switch.unregister_port(self)
446        self._logger.info('disconnected: %s', self.request.remote_ip)
447
448    def check_origin(self, origin):
449        if self._noorigin:
450            return True
451        return super(ServerHandler, self).check_origin(origin)
452
453
454class BaseClientHandler(object):
455    IFTYPE = 'baseclient'
456    IFOP_ALLOWED = False
457
458    def __init__(self, ioloop, switch, target, logger, debug, *args, **kwargs):
459        self._ioloop = ioloop
460        self._switch = switch
461        self._target = target
462        self._logger = logger
463        self._debug = debug
464        self._args = args
465        self._kwargs = kwargs
466        self._device = None
467
468    @property
469    def address(self):
470        raise NotImplementedError('unsupported')
471
472    @property
473    def netmask(self):
474        raise NotImplementedError('unsupported')
475
476    @property
477    def mtu(self):
478        raise NotImplementedError('unsupported')
479
480    @address.setter
481    def address(self, address):
482        raise NotImplementedError('unsupported')
483
484    @netmask.setter
485    def netmask(self, netmask):
486        raise NotImplementedError('unsupported')
487
488    @mtu.setter
489    def mtu(self, mtu):
490        raise NotImplementedError('unsupported')
491
492    def open(self):
493        raise NotImplementedError('unsupported')
494
495    def write_message(self, message, binary=False):
496        raise NotImplementedError('unsupported')
497
498    def read(self):
499        raise NotImplementedError('unsupported')
500
501    @property
502    def target(self):
503        return self._target
504
505    @property
506    def logger(self):
507        return self._logger
508
509    @property
510    def device(self):
511        return self._device
512
513    @property
514    def closed(self):
515        return not self.device
516
517    def close(self):
518        if self.closed:
519            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
520        self.leave_switch()
521        self.unregister_device()
522        self.logger.info('disconnected: %s', self.target)
523
524    def register_device(self, device):
525        self._device = device
526
527    def unregister_device(self):
528        self._device.close()
529        self._device = None
530
531    def fileno(self):
532        if self.closed:
533            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
534        return self.device.fileno()
535
536    def __call__(self, fd, events):
537        try:
538            data = self.read()
539            if data is not None:
540                self._switch.receive(self, EthernetFrame(data))
541                return
542        except:
543            self.logger.exception('caught error while receiving frame')
544        self.close()
545
546    def join_switch(self):
547        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
548        return self._switch.register_port(self)
549
550    def leave_switch(self):
551        self._switch.unregister_port(self)
552        self._ioloop.remove_handler(self.fileno())
553
554
555class NetdevHandler(BaseClientHandler):
556    IFTYPE = 'netdev'
557    IFOP_ALLOWED = True
558    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
559
560    @property
561    def address(self):
562        if self.closed:
563            raise ValueError('I/O operation on closed netdev')
564        return NetworkInterface(self.target).address
565
566    @property
567    def netmask(self):
568        if self.closed:
569            raise ValueError('I/O operation on closed netdev')
570        return NetworkInterface(self.target).netmask
571
572    @property
573    def mtu(self):
574        if self.closed:
575            raise ValueError('I/O operation on closed netdev')
576        return NetworkInterface(self.target).mtu
577
578    @address.setter
579    def address(self, address):
580        if self.closed:
581            raise ValueError('I/O operation on closed netdev')
582        NetworkInterface(self.target).address = address
583
584    @netmask.setter
585    def netmask(self, netmask):
586        if self.closed:
587            raise ValueError('I/O operation on closed netdev')
588        NetworkInterface(self.target).netmask = netmask
589
590    @mtu.setter
591    def mtu(self, mtu):
592        if self.closed:
593            raise ValueError('I/O operation on closed netdev')
594        NetworkInterface(self.target).mtu = mtu
595
596    def open(self):
597        if not self.closed:
598            raise ValueError('Already opened')
599        self.register_device(socket.socket(
600            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
601        self.device.bind((self.target, self.ETH_P_ALL))
602        self.logger.info('connected: %s', self.target)
603        return self.join_switch()
604
605    def write_message(self, message, binary=False):
606        if self.closed:
607            raise ValueError('I/O operation on closed netdev')
608        self.device.sendall(message)
609
610    def read(self):
611        if self.closed:
612            raise ValueError('I/O operation on closed netdev')
613        buf = []
614        while True:
615            buf.append(self.device.recv(65535))
616            if len(buf[-1]) < 65535:
617                break
618        return ''.join(buf)
619
620
621class TapHandler(BaseClientHandler):
622    IFTYPE = 'tap'
623    IFOP_ALLOWED = True
624
625    @property
626    def address(self):
627        if self.closed:
628            raise ValueError('I/O operation on closed tap')
629        try:
630            return self.device.addr
631        except:
632            return ''
633
634    @property
635    def netmask(self):
636        if self.closed:
637            raise ValueError('I/O operation on closed tap')
638        try:
639            return self.device.netmask
640        except:
641            return ''
642
643    @property
644    def mtu(self):
645        if self.closed:
646            raise ValueError('I/O operation on closed tap')
647        return self.device.mtu
648
649    @address.setter
650    def address(self, address):
651        if self.closed:
652            raise ValueError('I/O operation on closed tap')
653        self.device.addr = address
654
655    @netmask.setter
656    def netmask(self, netmask):
657        if self.closed:
658            raise ValueError('I/O operation on closed tap')
659        self.device.netmask = netmask
660
661    @mtu.setter
662    def mtu(self, mtu):
663        if self.closed:
664            raise ValueError('I/O operation on closed tap')
665        self.device.mtu = mtu
666
667    @property
668    def target(self):
669        if self.closed:
670            return self._target
671        return self.device.name
672
673    def open(self):
674        if not self.closed:
675            raise ValueError('Already opened')
676        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
677        self.device.up()
678        self.logger.info('connected: %s', self.target)
679        return self.join_switch()
680
681    def write_message(self, message, binary=False):
682        if self.closed:
683            raise ValueError('I/O operation on closed tap')
684        self.device.write(message)
685
686    def read(self):
687        if self.closed:
688            raise ValueError('I/O operation on closed tap')
689        buf = []
690        while True:
691            buf.append(self.device.read(65535))
692            if len(buf[-1]) < 65535:
693                break
694        return ''.join(buf)
695
696
697class ClientHandler(BaseClientHandler):
698    IFTYPE = 'client'
699    IFOP_ALLOWED = False
700
701    def __init__(self, *args, **kwargs):
702        super(ClientHandler, self).__init__(*args, **kwargs)
703        self._sslopt = self._kwargs.get('sslopt', {})
704        self._options = dict(self.authopts, **self.proxyopts)
705
706    def open(self):
707        if not self.closed:
708            raise websocket.WebSocketException('Already opened')
709
710        # XXX: may be blocked
711        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
712        self.device.connect(self.target, **self._options)
713        self.logger.info('connected: %s', self.target)
714        return self.join_switch()
715
716    def write_message(self, message, binary=False):
717        if self.closed:
718            raise websocket.WebSocketException('Closed socket')
719        if binary:
720            flag = websocket.ABNF.OPCODE_BINARY
721        else:
722            flag = websocket.ABNF.OPCODE_TEXT
723        self.device.send(message, flag)
724
725    def read(self):
726        if self.closed:
727            raise websocket.WebSocketException('Closed socket')
728        return self.device.recv()
729
730    @property
731    def authopts(self):
732        cred = self._kwargs.get('cred')
733
734        if not isinstance(cred, dict):
735            return {}
736        if cred.get('user') is None:
737            return {}
738        if cred.get('passwd') is None:
739            return {}
740
741        token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
742        return {'header': ['Authorization: Basic %s' % token]}
743
744    @property
745    def proxyopts(self):
746        scheme, remain = urllib2.splittype(self.target)
747
748        proxy = urllib2.getproxies().get(scheme.replace('ws', 'http'))
749        if not proxy:
750            return {}
751
752        host = urllib2.splitport(urllib2.splithost(remain)[0])[0]
753        if urllib2.proxy_bypass(host):
754            return {}
755
756        hostport = urllib2.splithost(urllib2.splittype(proxy)[1])[0]
757        host, port = urllib2.splitport(hostport)
758        port = int(port) if port else 0
759        return {'http_proxy_host': host, 'http_proxy_port': port}
760
761
762class ControlServerHandler(BasicAuthMixIn, RequestHandler):
763    NAMESPACE = 'etherws.control'
764    IFTYPES = {
765        NetdevHandler.IFTYPE: NetdevHandler,
766        TapHandler.IFTYPE:    TapHandler,
767        ClientHandler.IFTYPE: ClientHandler,
768    }
769
770    def __init__(self, app, req, ioloop, switch, htpasswd, logger, debug):
771        super(ControlServerHandler, self).__init__(app, req)
772        self._ioloop = ioloop
773        self._switch = switch
774        self._htpasswd = htpasswd
775        self._logger = logger
776        self._debug = debug
777
778    def post(self):
779        try:
780            request = json.loads(self.request.body)
781        except Exception as e:
782            return self._jsonrpc_response(error={
783                'code':    0 - 32700,
784                'message': 'Parse error',
785                'data':    '%s: %s' % (type(e).__name__, str(e)),
786            })
787
788        try:
789            id_ = request.get('id')
790            params = request.get('params')
791            version = request['jsonrpc']
792            method = request['method']
793            if version != '2.0':
794                raise ValueError('Invalid JSON-RPC version: %s' % version)
795        except Exception as e:
796            return self._jsonrpc_response(id_=id_, error={
797                'code':    0 - 32600,
798                'message': 'Invalid Request',
799                'data':    '%s: %s' % (type(e).__name__, str(e)),
800            })
801
802        try:
803            if not method.startswith(self.NAMESPACE + '.'):
804                raise ValueError('Invalid method namespace: %s' % method)
805            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
806            handler = getattr(self, handler)
807        except Exception as e:
808            return self._jsonrpc_response(id_=id_, error={
809                'code':    0 - 32601,
810                'message': 'Method not found',
811                'data':    '%s: %s' % (type(e).__name__, str(e)),
812            })
813
814        try:
815            return self._jsonrpc_response(id_=id_, result=handler(params))
816        except Exception as e:
817            self._logger.exception('caught error while processing request')
818            return self._jsonrpc_response(id_=id_, error={
819                'code':    0 - 32602,
820                'message': 'Invalid params',
821                'data':     '%s: %s' % (type(e).__name__, str(e)),
822            })
823
824    def handle_listFdb(self, params):
825        list_ = []
826        for vid, mac, entry in self._switch.fdb.each():
827            list_.append({
828                'vid':  vid,
829                'mac':  EthernetFrame.format_mac(mac),
830                'port': entry.port.number,
831                'age':  int(entry.age),
832            })
833        return {'entries': list_}
834
835    def handle_listPort(self, params):
836        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
837
838    def handle_addPort(self, params):
839        type_ = params['type']
840        target = params['target']
841        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
842        cls = self.IFTYPES[type_]
843        interface = cls(self._ioloop, self._switch, target,
844                        self._logger, self._debug, **opt)
845        portnum = interface.open()
846        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
847
848    def handle_setPort(self, params):
849        port = self._switch.get_port(int(params['port']))
850        shut = params.get('shut')
851        if shut is not None:
852            port.shut = bool(shut)
853        return {'entries': [self._portstat(port)]}
854
855    def handle_delPort(self, params):
856        port = self._switch.get_port(int(params['port']))
857        port.interface.close()
858        return {'entries': [self._portstat(port)]}
859
860    def handle_setInterface(self, params):
861        portnum = int(params['port'])
862        port = self._switch.get_port(portnum)
863        address = params.get('address')
864        netmask = params.get('netmask')
865        mtu = params.get('mtu')
866        if not port.interface.IFOP_ALLOWED:
867            raise ValueError('Port %d has unsupported interface: %s' %
868                             (portnum, port.interface.IFTYPE))
869        if address is not None:
870            port.interface.address = address
871        if netmask is not None:
872            port.interface.netmask = netmask
873        if mtu is not None:
874            port.interface.mtu = mtu
875        return {'entries': [self._ifstat(port)]}
876
877    def handle_listInterface(self, params):
878        return {'entries': [self._ifstat(p) for p in self._switch.portlist
879                            if p.interface.IFOP_ALLOWED]}
880
881    def _optparse_netdev(self, opt):
882        return {}
883
884    def _optparse_tap(self, opt):
885        return {}
886
887    def _optparse_client(self, opt):
888        if opt.get('insecure'):
889            sslopt = {'cert_reqs': ssl.CERT_NONE}
890        else:
891            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
892                      'ca_certs':  opt.get('cacerts')}
893        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
894        return {'sslopt': sslopt, 'cred': cred}
895
896    def _jsonrpc_response(self, id_=None, result=None, error=None):
897        res = {'jsonrpc': '2.0', 'id': id_}
898        if result:
899            res['result'] = result
900        if error:
901            res['error'] = error
902        self.finish(res)
903
904    @staticmethod
905    def _portstat(port):
906        return {
907            'port':   port.number,
908            'type':   port.interface.IFTYPE,
909            'target': port.interface.target,
910            'tx':     port.tx,
911            'rx':     port.rx,
912            'shut':   port.shut,
913        }
914
915    @staticmethod
916    def _ifstat(port):
917        return {
918            'port':    port.number,
919            'type':    port.interface.IFTYPE,
920            'target':  port.interface.target,
921            'address': port.interface.address,
922            'netmask': port.interface.netmask,
923            'mtu':     port.interface.mtu,
924        }
925
926
927def _print_error(error):
928    print(%s (%s)' % (error['message'], error['code']))
929    print('    %s' % error['data'])
930
931
932def _start_sw(args):
933    def daemonize(nochdir=False, noclose=False):
934        if os.fork() > 0:
935            sys.exit(0)
936
937        os.setsid()
938
939        if os.fork() > 0:
940            sys.exit(0)
941
942        if not nochdir:
943            os.chdir('/')
944
945        if not noclose:
946            os.umask(0)
947            sys.stdin.close()
948            sys.stdout.close()
949            sys.stderr.close()
950            os.close(0)
951            os.close(1)
952            os.close(2)
953            sys.stdin = open(os.devnull)
954            sys.stdout = open(os.devnull, 'a')
955            sys.stderr = open(os.devnull, 'a')
956
957    def checkabspath(ns, path):
958        val = getattr(ns, path, '')
959        if not val.startswith('/'):
960            raise ValueError('Invalid %: %s' % (path, val))
961
962    def getsslopt(ns, key, cert):
963        kval = getattr(ns, key, None)
964        cval = getattr(ns, cert, None)
965        if kval and cval:
966            return {'keyfile': kval, 'certfile': cval}
967        elif kval or cval:
968            raise ValueError('Both %s and %s are required' % (key, cert))
969        return None
970
971    def setrealpath(ns, *keys):
972        for k in keys:
973            v = getattr(ns, k, None)
974            if v is not None:
975                v = os.path.realpath(v)
976                open(v).close()  # check readable
977                setattr(ns, k, v)
978
979    def setport(ns, port, isssl):
980        val = getattr(ns, port, None)
981        if val is None:
982            if isssl:
983                return setattr(ns, port, 443)
984            return setattr(ns, port, 80)
985        if not (0 <= val <= 65535):
986            raise ValueError('Invalid %s: %s' % (port, val))
987
988    def sethtpasswd(ns, htpasswd):
989        val = getattr(ns, htpasswd, None)
990        if val:
991            return setattr(ns, htpasswd, Htpasswd(val))
992
993    # if args.debug:
994    #     websocket.enableTrace(True)
995
996    if args.ageout <= 0:
997        raise ValueError('Invalid ageout: %s' % args.ageout)
998
999    setrealpath(args, 'logconf')
1000    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
1001    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
1002
1003    checkabspath(args, 'path')
1004    checkabspath(args, 'ctlpath')
1005
1006    sslopt = getsslopt(args, 'sslkey', 'sslcert')
1007    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
1008
1009    setport(args, 'port', sslopt)
1010    setport(args, 'ctlport', ctlsslopt)
1011
1012    sethtpasswd(args, 'htpasswd')
1013    sethtpasswd(args, 'ctlhtpasswd')
1014
1015    if args.logconf:
1016        logging.config.fileConfig(args.logconf)
1017        logger = logging.getLogger('etherws')
1018    else:
1019        logger = logging.getLogger('etherws')
1020        logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
1021
1022    ioloop = IOLoop.instance()
1023    fdb = FDB(args.ageout, logger, args.debug)
1024    switch = SwitchingHub(fdb, logger, args.debug)
1025
1026    if args.port == args.ctlport and args.host == args.ctlhost:
1027        if args.path == args.ctlpath:
1028            raise ValueError('Same path/ctlpath on same host')
1029        if args.sslkey != args.ctlsslkey:
1030            raise ValueError('Different sslkey/ctlsslkey on same host')
1031        if args.sslcert != args.ctlsslcert:
1032            raise ValueError('Different sslcert/ctlsslcert on same host')
1033
1034        app = Application([
1035            (args.path, ServerHandler, {
1036                'switch':   switch,
1037                'htpasswd': args.htpasswd,
1038                'noorigin': args.noorigin,
1039                'logger':   logger,
1040                'debug':    args.debug,
1041            }),
1042            (args.ctlpath, ControlServerHandler, {
1043                'ioloop':   ioloop,
1044                'switch':   switch,
1045                'htpasswd': args.ctlhtpasswd,
1046                'logger':   logger,
1047                'debug':    args.debug,
1048            }),
1049        ])
1050        server = HTTPServer(app, ssl_options=sslopt)
1051        server.listen(args.port, address=args.host)
1052
1053    else:
1054        app = Application([(args.path, ServerHandler, {
1055            'switch':   switch,
1056            'htpasswd': args.htpasswd,
1057            'noorigin': args.noorigin,
1058            'logger':   logger,
1059            'debug':    args.debug,
1060        })])
1061        server = HTTPServer(app, ssl_options=sslopt)
1062        server.listen(args.port, address=args.host)
1063
1064        ctl = Application([(args.ctlpath, ControlServerHandler, {
1065            'ioloop':   ioloop,
1066            'switch':   switch,
1067            'htpasswd': args.ctlhtpasswd,
1068            'logger':   logger,
1069            'debug':    args.debug,
1070        })])
1071        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1072        ctlserver.listen(args.ctlport, address=args.ctlhost)
1073
1074    if not args.foreground:
1075        daemonize()
1076
1077    ioloop.start()
1078
1079
1080def _start_ctl(args):
1081    def have_ssl_cert_verification():
1082        return 'context' in urllib2.urlopen.__code__.co_varnames
1083
1084    def request(args, method, params=None, id_=0):
1085        req = urllib2.Request(args.ctlurl)
1086        req.add_header('Content-type', 'application/json')
1087        if args.ctluser:
1088            if not args.ctlpasswd:
1089                args.ctlpasswd = getpass.getpass('Control Password: ')
1090            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1091            req.add_header('Authorization', 'Basic %s' % token)
1092        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1093        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1094        if params is not None:
1095            data['params'] = params
1096        if have_ssl_cert_verification():
1097            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1098                                             cafile=args.ctlsslcert)
1099            if args.ctlinsecure:
1100                ctx.check_hostname = False
1101                ctx.verify_mode = ssl.CERT_NONE
1102            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1103        elif args.ctlsslcert:
1104            raise EnvironmentError('do not support certificate verification')
1105        else:
1106            fp = urllib2.urlopen(req, json.dumps(data))
1107        return json.loads(fp.read())
1108
1109    def print_table(rows):
1110        cols = zip(*rows)
1111        maxlen = [0] * len(cols)
1112        for i in xrange(len(cols)):
1113            maxlen[i] = max(len(str(c)) for c in cols[i])
1114        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1115        fmt = '  ' + fmt
1116        for row in rows:
1117            print(fmt % tuple(row))
1118
1119    def print_portlist(result):
1120        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1121        for r in result:
1122            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1123                         r['rx'], r['tx'], r['target']])
1124        print_table(rows)
1125
1126    def print_iflist(result):
1127        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1128        for r in result:
1129            rows.append([r['port'], r['type'], r['address'],
1130                         r['netmask'], r['mtu'], r['target']])
1131        print_table(rows)
1132
1133    def handle_ctl_addport(args):
1134        opts = {
1135            'user':     getattr(args, 'user', None),
1136            'passwd':   getattr(args, 'passwd', None),
1137            'cacerts':  getattr(args, 'cacerts', None),
1138            'insecure': getattr(args, 'insecure', None),
1139        }
1140        if args.iftype == ClientHandler.IFTYPE:
1141            if not args.target.startswith('ws://') and \
1142               not args.target.startswith('wss://'):
1143                raise ValueError('Invalid target URL scheme: %s' % args.target)
1144            if not opts['user'] and opts['passwd']:
1145                raise ValueError('Authentication required but username empty')
1146            if opts['user'] and not opts['passwd']:
1147                opts['passwd'] = getpass.getpass('Client Password: ')
1148        result = request(args, 'addPort', {
1149            'type':    args.iftype,
1150            'target':  args.target,
1151            'options': opts,
1152        })
1153        if 'error' in result:
1154            _print_error(result['error'])
1155        else:
1156            print_portlist(result['result']['entries'])
1157
1158    def handle_ctl_setport(args):
1159        if args.port <= 0:
1160            raise ValueError('Invalid port: %d' % args.port)
1161        req = {'port': args.port}
1162        shut = getattr(args, 'shut', None)
1163        if shut is not None:
1164            req['shut'] = bool(shut)
1165        result = request(args, 'setPort', req)
1166        if 'error' in result:
1167            _print_error(result['error'])
1168        else:
1169            print_portlist(result['result']['entries'])
1170
1171    def handle_ctl_delport(args):
1172        if args.port <= 0:
1173            raise ValueError('Invalid port: %d' % args.port)
1174        result = request(args, 'delPort', {'port': args.port})
1175        if 'error' in result:
1176            _print_error(result['error'])
1177        else:
1178            print_portlist(result['result']['entries'])
1179
1180    def handle_ctl_listport(args):
1181        result = request(args, 'listPort')
1182        if 'error' in result:
1183            _print_error(result['error'])
1184        else:
1185            print_portlist(result['result']['entries'])
1186
1187    def handle_ctl_setif(args):
1188        if args.port <= 0:
1189            raise ValueError('Invalid port: %d' % args.port)
1190        req = {'port': args.port}
1191        address = getattr(args, 'address', None)
1192        netmask = getattr(args, 'netmask', None)
1193        mtu = getattr(args, 'mtu', None)
1194        if address is not None:
1195            if address:
1196                socket.inet_aton(address)  # validate
1197            req['address'] = address
1198        if netmask is not None:
1199            if netmask:
1200                socket.inet_aton(netmask)  # validate
1201            req['netmask'] = netmask
1202        if mtu is not None:
1203            if mtu < 576:
1204                raise ValueError('Invalid MTU: %d' % mtu)
1205            req['mtu'] = mtu
1206        result = request(args, 'setInterface', req)
1207        if 'error' in result:
1208            _print_error(result['error'])
1209        else:
1210            print_iflist(result['result']['entries'])
1211
1212    def handle_ctl_listif(args):
1213        result = request(args, 'listInterface')
1214        if 'error' in result:
1215            _print_error(result['error'])
1216        else:
1217            print_iflist(result['result']['entries'])
1218
1219    def handle_ctl_listfdb(args):
1220        result = request(args, 'listFdb')
1221        if 'error' in result:
1222            return _print_error(result['error'])
1223        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1224        for r in result['result']['entries']:
1225            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1226        print_table(rows)
1227
1228    locals()['handle_ctl_' + args.control_method](args)
1229
1230
1231def _main():
1232    parser = argparse.ArgumentParser()
1233    subcommand = parser.add_subparsers(dest='subcommand')
1234
1235    # - sw
1236    parser_sw = subcommand.add_parser('sw',
1237                                      help='start virtual switch')
1238
1239    parser_sw.add_argument('--debug', action='store_true', default=False,
1240                           help='run as debug mode')
1241    parser_sw.add_argument('--logconf',
1242                           help='path to logging config')
1243    parser_sw.add_argument('--foreground', action='store_true', default=False,
1244                           help='run as foreground mode')
1245    parser_sw.add_argument('--ageout', type=int, default=300,
1246                           help='FDB ageout time (sec)')
1247
1248    parser_sw.add_argument('--path', default='/',
1249                           help='http(s) path to serve WebSocket')
1250    parser_sw.add_argument('--host', default='',
1251                           help='listen address to serve WebSocket')
1252    parser_sw.add_argument('--port', type=int,
1253                           help='listen port to serve WebSocket')
1254    parser_sw.add_argument('--htpasswd',
1255                           help='path to htpasswd file to auth WebSocket')
1256    parser_sw.add_argument('--sslkey',
1257                           help='path to SSL private key for WebSocket')
1258    parser_sw.add_argument('--sslcert',
1259                           help='path to SSL certificate for WebSocket')
1260
1261    parser_sw.add_argument('--ctlpath', default='/ctl',
1262                           help='http(s) path to serve control API')
1263    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1264                           help='listen address to serve control API')
1265    parser_sw.add_argument('--ctlport', type=int, default=7867,
1266                           help='listen port to serve control API')
1267    parser_sw.add_argument('--ctlhtpasswd',
1268                           help='path to htpasswd file to auth control API')
1269    parser_sw.add_argument('--ctlsslkey',
1270                           help='path to SSL private key for control API')
1271    parser_sw.add_argument('--ctlsslcert',
1272                           help='path to SSL certificate for control API')
1273
1274    parser_sw.add_argument('--noorigin', action='store_true', default=False,
1275                           help='do not check origin header')
1276
1277    # - ctl
1278    parser_ctl = subcommand.add_parser('ctl',
1279                                       help='control virtual switch')
1280    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1281                            help='URL to control API')
1282    parser_ctl.add_argument('--ctluser',
1283                            help='username to auth control API')
1284    parser_ctl.add_argument('--ctlpasswd',
1285                            help='password to auth control API')
1286    parser_ctl.add_argument('--ctlsslcert',
1287                            help='path to SSL certificate for control API')
1288    parser_ctl.add_argument(
1289        '--ctlinsecure', action='store_true', default=False,
1290        help='do not verify control API certificate')
1291
1292    control_method = parser_ctl.add_subparsers(dest='control_method')
1293
1294    # -- ctl addport
1295    parser_ctl_addport = control_method.add_parser('addport',
1296                                                   help='create and add port')
1297    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1298
1299    # --- ctl addport netdev
1300    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1301                                                  help='Network device')
1302    parser_ctl_addport_netdev.add_argument('target',
1303                                           help='device name to add interface')
1304
1305    # --- ctl addport tap
1306    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1307                                               help='TAP device')
1308    parser_ctl_addport_tap.add_argument('target',
1309                                        help='device name to create interface')
1310
1311    # --- ctl addport client
1312    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1313                                                  help='WebSocket client')
1314    parser_ctl_addport_client.add_argument('target',
1315                                           help='URL to connect WebSocket')
1316    parser_ctl_addport_client.add_argument('--user',
1317                                           help='username to auth WebSocket')
1318    parser_ctl_addport_client.add_argument('--passwd',
1319                                           help='password to auth WebSocket')
1320    parser_ctl_addport_client.add_argument('--cacerts',
1321                                           help='path to CA certificate')
1322    parser_ctl_addport_client.add_argument(
1323        '--insecure', action='store_true', default=False,
1324        help='do not verify server certificate')
1325
1326    # -- ctl setport
1327    parser_ctl_setport = control_method.add_parser('setport',
1328                                                   help='set port config')
1329    parser_ctl_setport.add_argument('port', type=int,
1330                                    help='port number to set config')
1331    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1332                                    help='set shutdown state')
1333
1334    # -- ctl delport
1335    parser_ctl_delport = control_method.add_parser('delport',
1336                                                   help='delete port')
1337    parser_ctl_delport.add_argument('port', type=int,
1338                                    help='port number to delete')
1339
1340    # -- ctl listport
1341    parser_ctl_listport = control_method.add_parser('listport',
1342                                                    help='show port list')
1343
1344    # -- ctl setif
1345    parser_ctl_setif = control_method.add_parser('setif',
1346                                                 help='set interface config')
1347    parser_ctl_setif.add_argument('port', type=int,
1348                                  help='port number to set config')
1349    parser_ctl_setif.add_argument('--address',
1350                                  help='IPv4 address to set interface')
1351    parser_ctl_setif.add_argument('--netmask',
1352                                  help='IPv4 netmask to set interface')
1353    parser_ctl_setif.add_argument('--mtu', type=int,
1354                                  help='MTU to set interface')
1355
1356    # -- ctl listif
1357    parser_ctl_listif = control_method.add_parser('listif',
1358                                                  help='show interface list')
1359
1360    # -- ctl listfdb
1361    parser_ctl_listfdb = control_method.add_parser('listfdb',
1362                                                   help='show FDB entries')
1363
1364    # -- go
1365    args = parser.parse_args()
1366
1367    try:
1368        globals()['_start_' + args.subcommand](args)
1369    except Exception as e:
1370        _print_error({
1371            'code':    0 - 32603,
1372            'message': 'Internal error',
1373            'data':    '%s: %s' % (type(e).__name__, str(e)),
1374        })
1375
1376
1377if __name__ == '__main__':
1378    _main()
Note: See TracBrowser for help on using the repository browser.