source: etherws/trunk/etherws.py @ 281

Revision 281, 44.5 KB checked in by atzm, 9 years ago (diff)
  • refactoring
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.14.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54
55import logging
56import logging.config
57
58import tornado
59import websocket
60
61from tornado.web import Application, RequestHandler
62from tornado.websocket import WebSocketHandler
63from tornado.httpserver import HTTPServer
64from tornado.ioloop import IOLoop
65
66from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @property
74    def dst_multicast(self):
75        return ord(self.data[0]) & 1
76
77    @property
78    def src_multicast(self):
79        return ord(self.data[6]) & 1
80
81    @property
82    def dst_mac(self):
83        return self.data[:6]
84
85    @property
86    def src_mac(self):
87        return self.data[6:12]
88
89    @property
90    def tagged(self):
91        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
92
93    @property
94    def vid(self):
95        if self.tagged:
96            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
97        return 0
98
99    @staticmethod
100    def format_mac(mac, sep=':'):
101        return sep.join(b.encode('hex') for b in mac)
102
103
104class FDB(object):
105    class Entry(object):
106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
119    def __init__(self, ageout, logger, debug):
120        self._ageout = ageout
121        self._logger = logger
122        self._debug = debug
123        self._table = {}
124
125    def _set_entry(self, vid, mac, port):
126        if vid not in self._table:
127            self._table[vid] = {}
128        self._table[vid][mac] = self.Entry(port, self._ageout)
129
130    def _del_entry(self, vid, mac):
131        if vid in self._table:
132            if mac in self._table[vid]:
133                del self._table[vid][mac]
134            if not self._table[vid]:
135                del self._table[vid]
136
137    def _get_entry(self, vid, mac):
138        try:
139            entry = self._table[vid][mac]
140        except KeyError:
141            return None
142
143        if not entry.agedout:
144            return entry
145
146        self._del_entry(vid, mac)
147
148        if self._debug:
149            self._logger.debug('fdb aged out: port:%d; vid:%d; mac:%s',
150                               entry.port.number, vid, mac.encode('hex'))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169
170        if self._debug:
171            self._logger.debug('fdb learned: port:%d; vid:%d; mac:%s',
172                               port.number, vid, mac.encode('hex'))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178
179                if self._debug:
180                    self._logger.debug('fdb deleted: port:%d; vid:%d; mac:%s',
181                                       port.number, vid, mac.encode('hex'))
182
183
184class SwitchingHub(object):
185    class Port(object):
186        def __init__(self, number, interface):
187            self.number = number
188            self.interface = interface
189            self.tx = 0
190            self.rx = 0
191            self.shut = False
192
193    def __init__(self, fdb, logger, debug):
194        self.fdb = fdb
195        self._logger = logger
196        self._debug = debug
197        self._table = {}
198        self._next = 1
199
200    @property
201    def portlist(self):
202        return sorted(self._table.itervalues(),
203                      key=operator.attrgetter('number'))
204
205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
209        try:
210            self._set_privattr('portnum', interface, self._next)  # XXX
211            self._table[self._next] = self.Port(self._next, interface)
212            return self._next
213        finally:
214            self._next += 1
215
216    def unregister_port(self, interface):
217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
219        self.fdb.delete(self._table[portnum])
220        del self._table[portnum]
221
222    def send(self, dst_interfaces, frame):
223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
226        ports = sorted(ports, key=operator.attrgetter('number'))
227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if self._debug and ports:
233            self._logger.debug('sent: port:%s; vid:%d; %s -> %s',
234                               ','.join(str(p.number) for p in ports),
235                               frame.vid,
236                               frame.src_mac.encode('hex'),
237                               frame.dst_mac.encode('hex'))
238
239    def receive(self, src_interface, frame):
240        port = self._table[self._get_privattr('portnum', src_interface)]
241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
247        try:
248            if not frame.src_multicast:
249                self.fdb.learn(src_port, frame)
250
251            if not frame.dst_multicast:
252                dst_port = self.fdb.lookup(frame)
253
254                if dst_port:
255                    self.send([dst_port.interface], frame)
256                    return
257
258            ports = set(self.portlist) - set([src_port])
259            self.send((p.interface for p in ports), frame)
260
261        except:  # ex. received invalid frame
262            self._logger.exception(
263                'caught error while forwarding frame: port:%d; frame:%s',
264                src_port.number, frame.data.encode('hex'))
265
266    def _privattr(self, name):
267        return '_%s_%s_%s' % (type(self).__name__, id(self), name)
268
269    def _set_privattr(self, name, obj, value):
270        return setattr(obj, self._privattr(name), value)
271
272    def _get_privattr(self, name, obj, defaults=None):
273        return getattr(obj, self._privattr(name), defaults)
274
275    def _del_privattr(self, name, obj):
276        return delattr(obj, self._privattr(name))
277
278
279class NetworkInterface(object):
280    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
281    SIOCSIFADDR = 0x8916
282    SIOCGIFNETMASK = 0x891b
283    SIOCSIFNETMASK = 0x891c
284    SIOCGIFMTU = 0x8921
285    SIOCSIFMTU = 0x8922
286
287    def __init__(self, ifname):
288        self._ifname = struct.pack('16s', str(ifname)[:15])
289
290    def _ioctl(self, req, data):
291        try:
292            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
293            return fcntl.ioctl(sock.fileno(), req, data)
294        finally:
295            sock.close()
296
297    @property
298    def address(self):
299        try:
300            data = struct.pack('16s18x', self._ifname)
301            ret = self._ioctl(self.SIOCGIFADDR, data)
302            return socket.inet_ntoa(ret[20:24])
303        except IOError:
304            return ''
305
306    @property
307    def netmask(self):
308        try:
309            data = struct.pack('16s18x', self._ifname)
310            ret = self._ioctl(self.SIOCGIFNETMASK, data)
311            return socket.inet_ntoa(ret[20:24])
312        except IOError:
313            return ''
314
315    @property
316    def mtu(self):
317        try:
318            data = struct.pack('16s18x', self._ifname)
319            ret = self._ioctl(self.SIOCGIFMTU, data)
320            return struct.unpack('i', ret[16:20])[0]
321        except IOError:
322            return 0
323
324    @address.setter
325    def address(self, addr):
326        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
327                           socket.inet_aton(addr))
328        self._ioctl(self.SIOCSIFADDR, data)
329
330    @netmask.setter
331    def netmask(self, addr):
332        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
333                           socket.inet_aton(addr))
334        self._ioctl(self.SIOCSIFNETMASK, data)
335
336    @mtu.setter
337    def mtu(self, mtu):
338        data = struct.pack('16si12x', self._ifname, mtu)
339        self._ioctl(self.SIOCSIFMTU, data)
340
341
342class Htpasswd(object):
343    def __init__(self, path):
344        self._path = path
345        self._stat = None
346        self._data = {}
347
348    def auth(self, name, passwd):
349        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
350        return self._data.get(name) == passwd
351
352    def load(self):
353        old_stat = self._stat
354
355        with open(self._path) as fp:
356            fileno = fp.fileno()
357            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
358            self._stat = os.fstat(fileno)
359
360            unchanged = old_stat and \
361                old_stat.st_ino == self._stat.st_ino and \
362                old_stat.st_dev == self._stat.st_dev and \
363                old_stat.st_mtime == self._stat.st_mtime
364
365            if not unchanged:
366                self._data = self._parse(fp)
367
368        return self
369
370    def _parse(self, fp):
371        data = {}
372        for line in fp:
373            line = line.strip()
374            if 0 <= line.find(':'):
375                name, passwd = line.split(':', 1)
376                if passwd.startswith('{SHA}'):
377                    data[name] = passwd[5:]
378        return data
379
380
381class BasicAuthMixIn(object):
382    def _execute(self, transforms, *args, **kwargs):
383        def do_execute():
384            sp = super(BasicAuthMixIn, self)
385            return sp._execute(transforms, *args, **kwargs)
386
387        def auth_required():
388            stream = getattr(self, 'stream', self.request.connection.stream)
389            stream.write(tornado.escape.utf8(
390                'HTTP/1.1 401 Authorization Required\r\n'
391                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
392            ))
393            stream.close()
394
395        try:
396            if not self._htpasswd:
397                return do_execute()
398
399            creds = self.request.headers.get('Authorization')
400
401            if not creds or not creds.startswith('Basic '):
402                return auth_required()
403
404            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
405
406            if self._htpasswd.load().auth(name, passwd):
407                return do_execute()
408        except:
409            self._logger.exception('caught error while doing authorization')
410
411        return auth_required()
412
413
414class ServerHandler(BasicAuthMixIn, WebSocketHandler):
415    IFTYPE = 'server'
416    IFOP_ALLOWED = False
417
418    def __init__(self, app, req, switch, htpasswd, logger, debug):
419        super(ServerHandler, self).__init__(app, req)
420        self._switch = switch
421        self._htpasswd = htpasswd
422        self._logger = logger
423        self._debug = debug
424
425    @property
426    def target(self):
427        conn = self.request.connection
428        if hasattr(conn, 'address'):
429            return ':'.join(str(e) for e in conn.address[:2])
430        if hasattr(conn, 'context'):
431            return ':'.join(str(e) for e in conn.context.address[:2])
432        return str(conn)
433
434    def open(self):
435        try:
436            return self._switch.register_port(self)
437        finally:
438            self._logger.info('connected: %s', self.request.remote_ip)
439
440    def on_message(self, message):
441        self._switch.receive(self, EthernetFrame(message))
442
443    def on_close(self):
444        self._switch.unregister_port(self)
445        self._logger.info('disconnected: %s', self.request.remote_ip)
446
447
448class BaseClientHandler(object):
449    IFTYPE = 'baseclient'
450    IFOP_ALLOWED = False
451
452    def __init__(self, ioloop, switch, target, logger, debug, *args, **kwargs):
453        self._ioloop = ioloop
454        self._switch = switch
455        self._target = target
456        self._logger = logger
457        self._debug = debug
458        self._args = args
459        self._kwargs = kwargs
460        self._device = None
461
462    @property
463    def address(self):
464        raise NotImplementedError('unsupported')
465
466    @property
467    def netmask(self):
468        raise NotImplementedError('unsupported')
469
470    @property
471    def mtu(self):
472        raise NotImplementedError('unsupported')
473
474    @address.setter
475    def address(self, address):
476        raise NotImplementedError('unsupported')
477
478    @netmask.setter
479    def netmask(self, netmask):
480        raise NotImplementedError('unsupported')
481
482    @mtu.setter
483    def mtu(self, mtu):
484        raise NotImplementedError('unsupported')
485
486    def open(self):
487        raise NotImplementedError('unsupported')
488
489    def write_message(self, message, binary=False):
490        raise NotImplementedError('unsupported')
491
492    def read(self):
493        raise NotImplementedError('unsupported')
494
495    @property
496    def target(self):
497        return self._target
498
499    @property
500    def logger(self):
501        return self._logger
502
503    @property
504    def device(self):
505        return self._device
506
507    @property
508    def closed(self):
509        return not self.device
510
511    def close(self):
512        if self.closed:
513            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
514        self.leave_switch()
515        self.unregister_device()
516        self.logger.info('disconnected: %s', self.target)
517
518    def register_device(self, device):
519        self._device = device
520
521    def unregister_device(self):
522        self._device.close()
523        self._device = None
524
525    def fileno(self):
526        if self.closed:
527            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
528        return self.device.fileno()
529
530    def __call__(self, fd, events):
531        try:
532            data = self.read()
533            if data is not None:
534                self._switch.receive(self, EthernetFrame(data))
535                return
536        except:
537            self.logger.exception('caught error while receiving frame')
538        self.close()
539
540    def join_switch(self):
541        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
542        return self._switch.register_port(self)
543
544    def leave_switch(self):
545        self._switch.unregister_port(self)
546        self._ioloop.remove_handler(self.fileno())
547
548
549class NetdevHandler(BaseClientHandler):
550    IFTYPE = 'netdev'
551    IFOP_ALLOWED = True
552    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
553
554    @property
555    def address(self):
556        if self.closed:
557            raise ValueError('I/O operation on closed netdev')
558        return NetworkInterface(self.target).address
559
560    @property
561    def netmask(self):
562        if self.closed:
563            raise ValueError('I/O operation on closed netdev')
564        return NetworkInterface(self.target).netmask
565
566    @property
567    def mtu(self):
568        if self.closed:
569            raise ValueError('I/O operation on closed netdev')
570        return NetworkInterface(self.target).mtu
571
572    @address.setter
573    def address(self, address):
574        if self.closed:
575            raise ValueError('I/O operation on closed netdev')
576        NetworkInterface(self.target).address = address
577
578    @netmask.setter
579    def netmask(self, netmask):
580        if self.closed:
581            raise ValueError('I/O operation on closed netdev')
582        NetworkInterface(self.target).netmask = netmask
583
584    @mtu.setter
585    def mtu(self, mtu):
586        if self.closed:
587            raise ValueError('I/O operation on closed netdev')
588        NetworkInterface(self.target).mtu = mtu
589
590    def open(self):
591        if not self.closed:
592            raise ValueError('Already opened')
593        self.register_device(socket.socket(
594            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
595        self.device.bind((self.target, self.ETH_P_ALL))
596        self.logger.info('connected: %s', self.target)
597        return self.join_switch()
598
599    def write_message(self, message, binary=False):
600        if self.closed:
601            raise ValueError('I/O operation on closed netdev')
602        self.device.sendall(message)
603
604    def read(self):
605        if self.closed:
606            raise ValueError('I/O operation on closed netdev')
607        buf = []
608        while True:
609            buf.append(self.device.recv(65535))
610            if len(buf[-1]) < 65535:
611                break
612        return ''.join(buf)
613
614
615class TapHandler(BaseClientHandler):
616    IFTYPE = 'tap'
617    IFOP_ALLOWED = True
618
619    @property
620    def address(self):
621        if self.closed:
622            raise ValueError('I/O operation on closed tap')
623        try:
624            return self.device.addr
625        except:
626            return ''
627
628    @property
629    def netmask(self):
630        if self.closed:
631            raise ValueError('I/O operation on closed tap')
632        try:
633            return self.device.netmask
634        except:
635            return ''
636
637    @property
638    def mtu(self):
639        if self.closed:
640            raise ValueError('I/O operation on closed tap')
641        return self.device.mtu
642
643    @address.setter
644    def address(self, address):
645        if self.closed:
646            raise ValueError('I/O operation on closed tap')
647        self.device.addr = address
648
649    @netmask.setter
650    def netmask(self, netmask):
651        if self.closed:
652            raise ValueError('I/O operation on closed tap')
653        self.device.netmask = netmask
654
655    @mtu.setter
656    def mtu(self, mtu):
657        if self.closed:
658            raise ValueError('I/O operation on closed tap')
659        self.device.mtu = mtu
660
661    @property
662    def target(self):
663        if self.closed:
664            return self._target
665        return self.device.name
666
667    def open(self):
668        if not self.closed:
669            raise ValueError('Already opened')
670        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
671        self.device.up()
672        self.logger.info('connected: %s', self.target)
673        return self.join_switch()
674
675    def write_message(self, message, binary=False):
676        if self.closed:
677            raise ValueError('I/O operation on closed tap')
678        self.device.write(message)
679
680    def read(self):
681        if self.closed:
682            raise ValueError('I/O operation on closed tap')
683        buf = []
684        while True:
685            buf.append(self.device.read(65535))
686            if len(buf[-1]) < 65535:
687                break
688        return ''.join(buf)
689
690
691class ClientHandler(BaseClientHandler):
692    IFTYPE = 'client'
693    IFOP_ALLOWED = False
694
695    def __init__(self, *args, **kwargs):
696        super(ClientHandler, self).__init__(*args, **kwargs)
697        self._sslopt = self._kwargs.get('sslopt', {})
698        self._options = dict(self.authopts, **self.proxyopts)
699
700    def open(self):
701        if not self.closed:
702            raise websocket.WebSocketException('Already opened')
703
704        # XXX: may be blocked
705        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
706        self.device.connect(self.target, **self._options)
707        self.logger.info('connected: %s', self.target)
708        return self.join_switch()
709
710    def write_message(self, message, binary=False):
711        if self.closed:
712            raise websocket.WebSocketException('Closed socket')
713        if binary:
714            flag = websocket.ABNF.OPCODE_BINARY
715        else:
716            flag = websocket.ABNF.OPCODE_TEXT
717        self.device.send(message, flag)
718
719    def read(self):
720        if self.closed:
721            raise websocket.WebSocketException('Closed socket')
722        return self.device.recv()
723
724    @property
725    def authopts(self):
726        cred = self._kwargs.get('cred')
727
728        if not isinstance(cred, dict):
729            return {}
730        if 'user' not in cred:
731            return {}
732        if 'passwd' not in cred:
733            return {}
734
735        token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
736        return {'header': ['Authorization: Basic %s' % token]}
737
738    @property
739    def proxyopts(self):
740        scheme, remain = urllib2.splittype(self.target)
741
742        proxy = urllib2.getproxies().get(scheme.replace('ws', 'http'))
743        if not proxy:
744            return {}
745
746        host = urllib2.splitport(urllib2.splithost(remain)[0])[0]
747        if urllib2.proxy_bypass(host):
748            return {}
749
750        hostport = urllib2.splithost(urllib2.splittype(proxy)[1])[0]
751        host, port = urllib2.splitport(hostport)
752        port = int(port) if port else 0
753        return {'http_proxy_host': host, 'http_proxy_port': port}
754
755
756class ControlServerHandler(BasicAuthMixIn, RequestHandler):
757    NAMESPACE = 'etherws.control'
758    IFTYPES = {
759        NetdevHandler.IFTYPE: NetdevHandler,
760        TapHandler.IFTYPE:    TapHandler,
761        ClientHandler.IFTYPE: ClientHandler,
762    }
763
764    def __init__(self, app, req, ioloop, switch, htpasswd, logger, debug):
765        super(ControlServerHandler, self).__init__(app, req)
766        self._ioloop = ioloop
767        self._switch = switch
768        self._htpasswd = htpasswd
769        self._logger = logger
770        self._debug = debug
771
772    def post(self):
773        try:
774            request = json.loads(self.request.body)
775        except Exception as e:
776            return self._jsonrpc_response(error={
777                'code':    0 - 32700,
778                'message': 'Parse error',
779                'data':    '%s: %s' % (type(e).__name__, str(e)),
780            })
781
782        try:
783            id_ = request.get('id')
784            params = request.get('params')
785            version = request['jsonrpc']
786            method = request['method']
787            if version != '2.0':
788                raise ValueError('Invalid JSON-RPC version: %s' % version)
789        except Exception as e:
790            return self._jsonrpc_response(id_=id_, error={
791                'code':    0 - 32600,
792                'message': 'Invalid Request',
793                'data':    '%s: %s' % (type(e).__name__, str(e)),
794            })
795
796        try:
797            if not method.startswith(self.NAMESPACE + '.'):
798                raise ValueError('Invalid method namespace: %s' % method)
799            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
800            handler = getattr(self, handler)
801        except Exception as e:
802            return self._jsonrpc_response(id_=id_, error={
803                'code':    0 - 32601,
804                'message': 'Method not found',
805                'data':    '%s: %s' % (type(e).__name__, str(e)),
806            })
807
808        try:
809            return self._jsonrpc_response(id_=id_, result=handler(params))
810        except Exception as e:
811            self._logger.exception('caught error while processing request')
812            return self._jsonrpc_response(id_=id_, error={
813                'code':    0 - 32602,
814                'message': 'Invalid params',
815                'data':     '%s: %s' % (type(e).__name__, str(e)),
816            })
817
818    def handle_listFdb(self, params):
819        list_ = []
820        for vid, mac, entry in self._switch.fdb.each():
821            list_.append({
822                'vid':  vid,
823                'mac':  EthernetFrame.format_mac(mac),
824                'port': entry.port.number,
825                'age':  int(entry.age),
826            })
827        return {'entries': list_}
828
829    def handle_listPort(self, params):
830        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
831
832    def handle_addPort(self, params):
833        type_ = params['type']
834        target = params['target']
835        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
836        cls = self.IFTYPES[type_]
837        interface = cls(self._ioloop, self._switch, target,
838                        self._logger, self._debug, **opt)
839        portnum = interface.open()
840        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
841
842    def handle_setPort(self, params):
843        port = self._switch.get_port(int(params['port']))
844        shut = params.get('shut')
845        if shut is not None:
846            port.shut = bool(shut)
847        return {'entries': [self._portstat(port)]}
848
849    def handle_delPort(self, params):
850        port = self._switch.get_port(int(params['port']))
851        port.interface.close()
852        return {'entries': [self._portstat(port)]}
853
854    def handle_setInterface(self, params):
855        portnum = int(params['port'])
856        port = self._switch.get_port(portnum)
857        address = params.get('address')
858        netmask = params.get('netmask')
859        mtu = params.get('mtu')
860        if not port.interface.IFOP_ALLOWED:
861            raise ValueError('Port %d has unsupported interface: %s' %
862                             (portnum, port.interface.IFTYPE))
863        if address is not None:
864            port.interface.address = address
865        if netmask is not None:
866            port.interface.netmask = netmask
867        if mtu is not None:
868            port.interface.mtu = mtu
869        return {'entries': [self._ifstat(port)]}
870
871    def handle_listInterface(self, params):
872        return {'entries': [self._ifstat(p) for p in self._switch.portlist
873                            if p.interface.IFOP_ALLOWED]}
874
875    def _optparse_netdev(self, opt):
876        return {}
877
878    def _optparse_tap(self, opt):
879        return {}
880
881    def _optparse_client(self, opt):
882        if opt.get('insecure'):
883            sslopt = {'cert_reqs': ssl.CERT_NONE}
884        else:
885            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
886                      'ca_certs':  opt.get('cacerts')}
887        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
888        return {'sslopt': sslopt, 'cred': cred}
889
890    def _jsonrpc_response(self, id_=None, result=None, error=None):
891        res = {'jsonrpc': '2.0', 'id': id_}
892        if result:
893            res['result'] = result
894        if error:
895            res['error'] = error
896        self.finish(res)
897
898    @staticmethod
899    def _portstat(port):
900        return {
901            'port':   port.number,
902            'type':   port.interface.IFTYPE,
903            'target': port.interface.target,
904            'tx':     port.tx,
905            'rx':     port.rx,
906            'shut':   port.shut,
907        }
908
909    @staticmethod
910    def _ifstat(port):
911        return {
912            'port':    port.number,
913            'type':    port.interface.IFTYPE,
914            'target':  port.interface.target,
915            'address': port.interface.address,
916            'netmask': port.interface.netmask,
917            'mtu':     port.interface.mtu,
918        }
919
920
921def _print_error(error):
922    print(%s (%s)' % (error['message'], error['code']))
923    print('    %s' % error['data'])
924
925
926def _start_sw(args):
927    def daemonize(nochdir=False, noclose=False):
928        if os.fork() > 0:
929            sys.exit(0)
930
931        os.setsid()
932
933        if os.fork() > 0:
934            sys.exit(0)
935
936        if not nochdir:
937            os.chdir('/')
938
939        if not noclose:
940            os.umask(0)
941            sys.stdin.close()
942            sys.stdout.close()
943            sys.stderr.close()
944            os.close(0)
945            os.close(1)
946            os.close(2)
947            sys.stdin = open(os.devnull)
948            sys.stdout = open(os.devnull, 'a')
949            sys.stderr = open(os.devnull, 'a')
950
951    def checkabspath(ns, path):
952        val = getattr(ns, path, '')
953        if not val.startswith('/'):
954            raise ValueError('Invalid %: %s' % (path, val))
955
956    def getsslopt(ns, key, cert):
957        kval = getattr(ns, key, None)
958        cval = getattr(ns, cert, None)
959        if kval and cval:
960            return {'keyfile': kval, 'certfile': cval}
961        elif kval or cval:
962            raise ValueError('Both %s and %s are required' % (key, cert))
963        return None
964
965    def setrealpath(ns, *keys):
966        for k in keys:
967            v = getattr(ns, k, None)
968            if v is not None:
969                v = os.path.realpath(v)
970                open(v).close()  # check readable
971                setattr(ns, k, v)
972
973    def setport(ns, port, isssl):
974        val = getattr(ns, port, None)
975        if val is None:
976            if isssl:
977                return setattr(ns, port, 443)
978            return setattr(ns, port, 80)
979        if not (0 <= val <= 65535):
980            raise ValueError('Invalid %s: %s' % (port, val))
981
982    def sethtpasswd(ns, htpasswd):
983        val = getattr(ns, htpasswd, None)
984        if val:
985            return setattr(ns, htpasswd, Htpasswd(val))
986
987    # if args.debug:
988    #     websocket.enableTrace(True)
989
990    if args.ageout <= 0:
991        raise ValueError('Invalid ageout: %s' % args.ageout)
992
993    setrealpath(args, 'logconf')
994    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
995    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
996
997    checkabspath(args, 'path')
998    checkabspath(args, 'ctlpath')
999
1000    sslopt = getsslopt(args, 'sslkey', 'sslcert')
1001    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
1002
1003    setport(args, 'port', sslopt)
1004    setport(args, 'ctlport', ctlsslopt)
1005
1006    sethtpasswd(args, 'htpasswd')
1007    sethtpasswd(args, 'ctlhtpasswd')
1008
1009    if args.logconf:
1010        logging.config.fileConfig(args.logconf)
1011        logger = logging.getLogger('etherws')
1012    else:
1013        logger = logging.getLogger('etherws')
1014        logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
1015
1016    ioloop = IOLoop.instance()
1017    fdb = FDB(args.ageout, logger, args.debug)
1018    switch = SwitchingHub(fdb, logger, args.debug)
1019
1020    if args.port == args.ctlport and args.host == args.ctlhost:
1021        if args.path == args.ctlpath:
1022            raise ValueError('Same path/ctlpath on same host')
1023        if args.sslkey != args.ctlsslkey:
1024            raise ValueError('Different sslkey/ctlsslkey on same host')
1025        if args.sslcert != args.ctlsslcert:
1026            raise ValueError('Different sslcert/ctlsslcert on same host')
1027
1028        app = Application([
1029            (args.path, ServerHandler, {
1030                'switch':   switch,
1031                'htpasswd': args.htpasswd,
1032                'logger':   logger,
1033                'debug':    args.debug,
1034            }),
1035            (args.ctlpath, ControlServerHandler, {
1036                'ioloop':   ioloop,
1037                'switch':   switch,
1038                'htpasswd': args.ctlhtpasswd,
1039                'logger':   logger,
1040                'debug':    args.debug,
1041            }),
1042        ])
1043        server = HTTPServer(app, ssl_options=sslopt)
1044        server.listen(args.port, address=args.host)
1045
1046    else:
1047        app = Application([(args.path, ServerHandler, {
1048            'switch':   switch,
1049            'htpasswd': args.htpasswd,
1050            'logger':   logger,
1051            'debug':    args.debug,
1052        })])
1053        server = HTTPServer(app, ssl_options=sslopt)
1054        server.listen(args.port, address=args.host)
1055
1056        ctl = Application([(args.ctlpath, ControlServerHandler, {
1057            'ioloop':   ioloop,
1058            'switch':   switch,
1059            'htpasswd': args.ctlhtpasswd,
1060            'logger':   logger,
1061            'debug':    args.debug,
1062        })])
1063        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1064        ctlserver.listen(args.ctlport, address=args.ctlhost)
1065
1066    if not args.foreground:
1067        daemonize()
1068
1069    ioloop.start()
1070
1071
1072def _start_ctl(args):
1073    def have_ssl_cert_verification():
1074        return 'context' in urllib2.urlopen.__code__.co_varnames
1075
1076    def request(args, method, params=None, id_=0):
1077        req = urllib2.Request(args.ctlurl)
1078        req.add_header('Content-type', 'application/json')
1079        if args.ctluser:
1080            if not args.ctlpasswd:
1081                args.ctlpasswd = getpass.getpass('Control Password: ')
1082            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1083            req.add_header('Authorization', 'Basic %s' % token)
1084        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1085        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1086        if params is not None:
1087            data['params'] = params
1088        if have_ssl_cert_verification():
1089            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1090                                             cafile=args.ctlsslcert)
1091            if args.ctlinsecure:
1092                ctx.check_hostname = False
1093                ctx.verify_mode = ssl.CERT_NONE
1094            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1095        elif args.ctlsslcert:
1096            raise EnvironmentError('do not support certificate verification')
1097        else:
1098            fp = urllib2.urlopen(req, json.dumps(data))
1099        return json.loads(fp.read())
1100
1101    def print_table(rows):
1102        cols = zip(*rows)
1103        maxlen = [0] * len(cols)
1104        for i in xrange(len(cols)):
1105            maxlen[i] = max(len(str(c)) for c in cols[i])
1106        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1107        fmt = '  ' + fmt
1108        for row in rows:
1109            print(fmt % tuple(row))
1110
1111    def print_portlist(result):
1112        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1113        for r in result:
1114            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1115                         r['rx'], r['tx'], r['target']])
1116        print_table(rows)
1117
1118    def print_iflist(result):
1119        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1120        for r in result:
1121            rows.append([r['port'], r['type'], r['address'],
1122                         r['netmask'], r['mtu'], r['target']])
1123        print_table(rows)
1124
1125    def handle_ctl_addport(args):
1126        opts = {
1127            'user':     getattr(args, 'user', None),
1128            'passwd':   getattr(args, 'passwd', None),
1129            'cacerts':  getattr(args, 'cacerts', None),
1130            'insecure': getattr(args, 'insecure', None),
1131        }
1132        if args.iftype == ClientHandler.IFTYPE:
1133            if not args.target.startswith('ws://') and \
1134               not args.target.startswith('wss://'):
1135                raise ValueError('Invalid target URL scheme: %s' % args.target)
1136            if not opts['user'] and opts['passwd']:
1137                raise ValueError('Authentication required but username empty')
1138            if opts['user'] and not opts['passwd']:
1139                opts['passwd'] = getpass.getpass('Client Password: ')
1140        result = request(args, 'addPort', {
1141            'type':    args.iftype,
1142            'target':  args.target,
1143            'options': opts,
1144        })
1145        if 'error' in result:
1146            _print_error(result['error'])
1147        else:
1148            print_portlist(result['result']['entries'])
1149
1150    def handle_ctl_setport(args):
1151        if args.port <= 0:
1152            raise ValueError('Invalid port: %d' % args.port)
1153        req = {'port': args.port}
1154        shut = getattr(args, 'shut', None)
1155        if shut is not None:
1156            req['shut'] = bool(shut)
1157        result = request(args, 'setPort', req)
1158        if 'error' in result:
1159            _print_error(result['error'])
1160        else:
1161            print_portlist(result['result']['entries'])
1162
1163    def handle_ctl_delport(args):
1164        if args.port <= 0:
1165            raise ValueError('Invalid port: %d' % args.port)
1166        result = request(args, 'delPort', {'port': args.port})
1167        if 'error' in result:
1168            _print_error(result['error'])
1169        else:
1170            print_portlist(result['result']['entries'])
1171
1172    def handle_ctl_listport(args):
1173        result = request(args, 'listPort')
1174        if 'error' in result:
1175            _print_error(result['error'])
1176        else:
1177            print_portlist(result['result']['entries'])
1178
1179    def handle_ctl_setif(args):
1180        if args.port <= 0:
1181            raise ValueError('Invalid port: %d' % args.port)
1182        req = {'port': args.port}
1183        address = getattr(args, 'address', None)
1184        netmask = getattr(args, 'netmask', None)
1185        mtu = getattr(args, 'mtu', None)
1186        if address is not None:
1187            if address:
1188                socket.inet_aton(address)  # validate
1189            req['address'] = address
1190        if netmask is not None:
1191            if netmask:
1192                socket.inet_aton(netmask)  # validate
1193            req['netmask'] = netmask
1194        if mtu is not None:
1195            if mtu < 576:
1196                raise ValueError('Invalid MTU: %d' % mtu)
1197            req['mtu'] = mtu
1198        result = request(args, 'setInterface', req)
1199        if 'error' in result:
1200            _print_error(result['error'])
1201        else:
1202            print_iflist(result['result']['entries'])
1203
1204    def handle_ctl_listif(args):
1205        result = request(args, 'listInterface')
1206        if 'error' in result:
1207            _print_error(result['error'])
1208        else:
1209            print_iflist(result['result']['entries'])
1210
1211    def handle_ctl_listfdb(args):
1212        result = request(args, 'listFdb')
1213        if 'error' in result:
1214            return _print_error(result['error'])
1215        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1216        for r in result['result']['entries']:
1217            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1218        print_table(rows)
1219
1220    locals()['handle_ctl_' + args.control_method](args)
1221
1222
1223def _main():
1224    parser = argparse.ArgumentParser()
1225    subcommand = parser.add_subparsers(dest='subcommand')
1226
1227    # - sw
1228    parser_sw = subcommand.add_parser('sw',
1229                                      help='start virtual switch')
1230
1231    parser_sw.add_argument('--debug', action='store_true', default=False,
1232                           help='run as debug mode')
1233    parser_sw.add_argument('--logconf',
1234                           help='path to logging config')
1235    parser_sw.add_argument('--foreground', action='store_true', default=False,
1236                           help='run as foreground mode')
1237    parser_sw.add_argument('--ageout', type=int, default=300,
1238                           help='FDB ageout time (sec)')
1239
1240    parser_sw.add_argument('--path', default='/',
1241                           help='http(s) path to serve WebSocket')
1242    parser_sw.add_argument('--host', default='',
1243                           help='listen address to serve WebSocket')
1244    parser_sw.add_argument('--port', type=int,
1245                           help='listen port to serve WebSocket')
1246    parser_sw.add_argument('--htpasswd',
1247                           help='path to htpasswd file to auth WebSocket')
1248    parser_sw.add_argument('--sslkey',
1249                           help='path to SSL private key for WebSocket')
1250    parser_sw.add_argument('--sslcert',
1251                           help='path to SSL certificate for WebSocket')
1252
1253    parser_sw.add_argument('--ctlpath', default='/ctl',
1254                           help='http(s) path to serve control API')
1255    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1256                           help='listen address to serve control API')
1257    parser_sw.add_argument('--ctlport', type=int, default=7867,
1258                           help='listen port to serve control API')
1259    parser_sw.add_argument('--ctlhtpasswd',
1260                           help='path to htpasswd file to auth control API')
1261    parser_sw.add_argument('--ctlsslkey',
1262                           help='path to SSL private key for control API')
1263    parser_sw.add_argument('--ctlsslcert',
1264                           help='path to SSL certificate for control API')
1265
1266    # - ctl
1267    parser_ctl = subcommand.add_parser('ctl',
1268                                       help='control virtual switch')
1269    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1270                            help='URL to control API')
1271    parser_ctl.add_argument('--ctluser',
1272                            help='username to auth control API')
1273    parser_ctl.add_argument('--ctlpasswd',
1274                            help='password to auth control API')
1275    parser_ctl.add_argument('--ctlsslcert',
1276                            help='path to SSL certificate for control API')
1277    parser_ctl.add_argument(
1278        '--ctlinsecure', action='store_true', default=False,
1279        help='do not verify control API certificate')
1280
1281    control_method = parser_ctl.add_subparsers(dest='control_method')
1282
1283    # -- ctl addport
1284    parser_ctl_addport = control_method.add_parser('addport',
1285                                                   help='create and add port')
1286    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1287
1288    # --- ctl addport netdev
1289    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1290                                                  help='Network device')
1291    parser_ctl_addport_netdev.add_argument('target',
1292                                           help='device name to add interface')
1293
1294    # --- ctl addport tap
1295    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1296                                               help='TAP device')
1297    parser_ctl_addport_tap.add_argument('target',
1298                                        help='device name to create interface')
1299
1300    # --- ctl addport client
1301    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1302                                                  help='WebSocket client')
1303    parser_ctl_addport_client.add_argument('target',
1304                                           help='URL to connect WebSocket')
1305    parser_ctl_addport_client.add_argument('--user',
1306                                           help='username to auth WebSocket')
1307    parser_ctl_addport_client.add_argument('--passwd',
1308                                           help='password to auth WebSocket')
1309    parser_ctl_addport_client.add_argument('--cacerts',
1310                                           help='path to CA certificate')
1311    parser_ctl_addport_client.add_argument(
1312        '--insecure', action='store_true', default=False,
1313        help='do not verify server certificate')
1314
1315    # -- ctl setport
1316    parser_ctl_setport = control_method.add_parser('setport',
1317                                                   help='set port config')
1318    parser_ctl_setport.add_argument('port', type=int,
1319                                    help='port number to set config')
1320    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1321                                    help='set shutdown state')
1322
1323    # -- ctl delport
1324    parser_ctl_delport = control_method.add_parser('delport',
1325                                                   help='delete port')
1326    parser_ctl_delport.add_argument('port', type=int,
1327                                    help='port number to delete')
1328
1329    # -- ctl listport
1330    parser_ctl_listport = control_method.add_parser('listport',
1331                                                    help='show port list')
1332
1333    # -- ctl setif
1334    parser_ctl_setif = control_method.add_parser('setif',
1335                                                 help='set interface config')
1336    parser_ctl_setif.add_argument('port', type=int,
1337                                  help='port number to set config')
1338    parser_ctl_setif.add_argument('--address',
1339                                  help='IPv4 address to set interface')
1340    parser_ctl_setif.add_argument('--netmask',
1341                                  help='IPv4 netmask to set interface')
1342    parser_ctl_setif.add_argument('--mtu', type=int,
1343                                  help='MTU to set interface')
1344
1345    # -- ctl listif
1346    parser_ctl_listif = control_method.add_parser('listif',
1347                                                  help='show interface list')
1348
1349    # -- ctl listfdb
1350    parser_ctl_listfdb = control_method.add_parser('listfdb',
1351                                                   help='show FDB entries')
1352
1353    # -- go
1354    args = parser.parse_args()
1355
1356    try:
1357        globals()['_start_' + args.subcommand](args)
1358    except Exception as e:
1359        _print_error({
1360            'code':    0 - 32603,
1361            'message': 'Internal error',
1362            'data':    '%s: %s' % (type(e).__name__, str(e)),
1363        })
1364
1365
1366if __name__ == '__main__':
1367    _main()
Note: See TracBrowser for help on using the repository browser.