source: etherws/trunk/etherws.py @ 280

Revision 280, 44.4 KB checked in by atzm, 9 years ago (diff)
  • logging support
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.14.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54
55import logging
56import logging.config
57
58import tornado
59import websocket
60
61from tornado.web import Application, RequestHandler
62from tornado.websocket import WebSocketHandler
63from tornado.httpserver import HTTPServer
64from tornado.ioloop import IOLoop
65
66from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @property
74    def dst_multicast(self):
75        return ord(self.data[0]) & 1
76
77    @property
78    def src_multicast(self):
79        return ord(self.data[6]) & 1
80
81    @property
82    def dst_mac(self):
83        return self.data[:6]
84
85    @property
86    def src_mac(self):
87        return self.data[6:12]
88
89    @property
90    def tagged(self):
91        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
92
93    @property
94    def vid(self):
95        if self.tagged:
96            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
97        return 0
98
99    @staticmethod
100    def format_mac(mac, sep=':'):
101        return sep.join(b.encode('hex') for b in mac)
102
103
104class FDB(object):
105    class Entry(object):
106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
119    def __init__(self, ageout, logger, debug):
120        self._ageout = ageout
121        self._logger = logger
122        self._debug = debug
123        self._table = {}
124
125    def _set_entry(self, vid, mac, port):
126        if vid not in self._table:
127            self._table[vid] = {}
128        self._table[vid][mac] = self.Entry(port, self._ageout)
129
130    def _del_entry(self, vid, mac):
131        if vid in self._table:
132            if mac in self._table[vid]:
133                del self._table[vid][mac]
134            if not self._table[vid]:
135                del self._table[vid]
136
137    def _get_entry(self, vid, mac):
138        try:
139            entry = self._table[vid][mac]
140        except KeyError:
141            return None
142
143        if not entry.agedout:
144            return entry
145
146        self._del_entry(vid, mac)
147
148        if self._debug:
149            self._logger.debug('fdb aged out: port:%d; vid:%d; mac:%s',
150                               entry.port.number, vid, mac.encode('hex'))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169
170        if self._debug:
171            self._logger.debug('fdb learned: port:%d; vid:%d; mac:%s',
172                               port.number, vid, mac.encode('hex'))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178
179                if self._debug:
180                    self._logger.debug('fdb deleted: port:%d; vid:%d; mac:%s',
181                                       port.number, vid, mac.encode('hex'))
182
183
184class SwitchingHub(object):
185    class Port(object):
186        def __init__(self, number, interface):
187            self.number = number
188            self.interface = interface
189            self.tx = 0
190            self.rx = 0
191            self.shut = False
192
193    def __init__(self, fdb, logger, debug):
194        self.fdb = fdb
195        self._logger = logger
196        self._debug = debug
197        self._table = {}
198        self._next = 1
199
200    @property
201    def portlist(self):
202        return sorted(self._table.itervalues(),
203                      key=operator.attrgetter('number'))
204
205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
209        try:
210            self._set_privattr('portnum', interface, self._next)  # XXX
211            self._table[self._next] = self.Port(self._next, interface)
212            return self._next
213        finally:
214            self._next += 1
215
216    def unregister_port(self, interface):
217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
219        self.fdb.delete(self._table[portnum])
220        del self._table[portnum]
221
222    def send(self, dst_interfaces, frame):
223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
226        ports = sorted(ports, key=operator.attrgetter('number'))
227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if self._debug and ports:
233            self._logger.debug('sent: port:%s; vid:%d; %s -> %s',
234                               ','.join(str(p.number) for p in ports),
235                               frame.vid,
236                               frame.src_mac.encode('hex'),
237                               frame.dst_mac.encode('hex'))
238
239    def receive(self, src_interface, frame):
240        port = self._table[self._get_privattr('portnum', src_interface)]
241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
247        try:
248            if not frame.src_multicast:
249                self.fdb.learn(src_port, frame)
250
251            if not frame.dst_multicast:
252                dst_port = self.fdb.lookup(frame)
253
254                if dst_port:
255                    self.send([dst_port.interface], frame)
256                    return
257
258            ports = set(self.portlist) - set([src_port])
259            self.send((p.interface for p in ports), frame)
260
261        except:  # ex. received invalid frame
262            self._logger.exception(
263                'caught error while forwarding frame: port:%d; frame:%s',
264                src_port.number, frame.data.encode('hex'))
265
266    def _privattr(self, name):
267        return '_%s_%s_%s' % (type(self).__name__, id(self), name)
268
269    def _set_privattr(self, name, obj, value):
270        return setattr(obj, self._privattr(name), value)
271
272    def _get_privattr(self, name, obj, defaults=None):
273        return getattr(obj, self._privattr(name), defaults)
274
275    def _del_privattr(self, name, obj):
276        return delattr(obj, self._privattr(name))
277
278
279class NetworkInterface(object):
280    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
281    SIOCSIFADDR = 0x8916
282    SIOCGIFNETMASK = 0x891b
283    SIOCSIFNETMASK = 0x891c
284    SIOCGIFMTU = 0x8921
285    SIOCSIFMTU = 0x8922
286
287    def __init__(self, ifname):
288        self._ifname = struct.pack('16s', str(ifname)[:15])
289
290    def _ioctl(self, req, data):
291        try:
292            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
293            return fcntl.ioctl(sock.fileno(), req, data)
294        finally:
295            sock.close()
296
297    @property
298    def address(self):
299        try:
300            data = struct.pack('16s18x', self._ifname)
301            ret = self._ioctl(self.SIOCGIFADDR, data)
302            return socket.inet_ntoa(ret[20:24])
303        except IOError:
304            return ''
305
306    @property
307    def netmask(self):
308        try:
309            data = struct.pack('16s18x', self._ifname)
310            ret = self._ioctl(self.SIOCGIFNETMASK, data)
311            return socket.inet_ntoa(ret[20:24])
312        except IOError:
313            return ''
314
315    @property
316    def mtu(self):
317        try:
318            data = struct.pack('16s18x', self._ifname)
319            ret = self._ioctl(self.SIOCGIFMTU, data)
320            return struct.unpack('i', ret[16:20])[0]
321        except IOError:
322            return 0
323
324    @address.setter
325    def address(self, addr):
326        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
327                           socket.inet_aton(addr))
328        self._ioctl(self.SIOCSIFADDR, data)
329
330    @netmask.setter
331    def netmask(self, addr):
332        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
333                           socket.inet_aton(addr))
334        self._ioctl(self.SIOCSIFNETMASK, data)
335
336    @mtu.setter
337    def mtu(self, mtu):
338        data = struct.pack('16si12x', self._ifname, mtu)
339        self._ioctl(self.SIOCSIFMTU, data)
340
341
342class Htpasswd(object):
343    def __init__(self, path):
344        self._path = path
345        self._stat = None
346        self._data = {}
347
348    def auth(self, name, passwd):
349        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
350        return self._data.get(name) == passwd
351
352    def load(self):
353        old_stat = self._stat
354
355        with open(self._path) as fp:
356            fileno = fp.fileno()
357            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
358            self._stat = os.fstat(fileno)
359
360            unchanged = old_stat and \
361                old_stat.st_ino == self._stat.st_ino and \
362                old_stat.st_dev == self._stat.st_dev and \
363                old_stat.st_mtime == self._stat.st_mtime
364
365            if not unchanged:
366                self._data = self._parse(fp)
367
368        return self
369
370    def _parse(self, fp):
371        data = {}
372        for line in fp:
373            line = line.strip()
374            if 0 <= line.find(':'):
375                name, passwd = line.split(':', 1)
376                if passwd.startswith('{SHA}'):
377                    data[name] = passwd[5:]
378        return data
379
380
381class BasicAuthMixIn(object):
382    def _execute(self, transforms, *args, **kwargs):
383        def do_execute():
384            sp = super(BasicAuthMixIn, self)
385            return sp._execute(transforms, *args, **kwargs)
386
387        def auth_required():
388            stream = getattr(self, 'stream', self.request.connection.stream)
389            stream.write(tornado.escape.utf8(
390                'HTTP/1.1 401 Authorization Required\r\n'
391                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
392            ))
393            stream.close()
394
395        try:
396            if not self._htpasswd:
397                return do_execute()
398
399            creds = self.request.headers.get('Authorization')
400
401            if not creds or not creds.startswith('Basic '):
402                return auth_required()
403
404            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
405
406            if self._htpasswd.load().auth(name, passwd):
407                return do_execute()
408        except:
409            self._logger.exception('caught error while doing authorization')
410
411        return auth_required()
412
413
414class ServerHandler(BasicAuthMixIn, WebSocketHandler):
415    IFTYPE = 'server'
416    IFOP_ALLOWED = False
417
418    def __init__(self, app, req, switch, htpasswd, logger, debug):
419        super(ServerHandler, self).__init__(app, req)
420        self._switch = switch
421        self._htpasswd = htpasswd
422        self._logger = logger
423        self._debug = debug
424
425    @property
426    def target(self):
427        conn = self.request.connection
428        if hasattr(conn, 'address'):
429            return ':'.join(str(e) for e in conn.address[:2])
430        if hasattr(conn, 'context'):
431            return ':'.join(str(e) for e in conn.context.address[:2])
432        return str(conn)
433
434    def open(self):
435        try:
436            return self._switch.register_port(self)
437        finally:
438            self._logger.info('connected: %s', self.request.remote_ip)
439
440    def on_message(self, message):
441        self._switch.receive(self, EthernetFrame(message))
442
443    def on_close(self):
444        self._switch.unregister_port(self)
445        self._logger.info('disconnected: %s', self.request.remote_ip)
446
447
448class BaseClientHandler(object):
449    IFTYPE = 'baseclient'
450    IFOP_ALLOWED = False
451
452    def __init__(self, ioloop, switch, target, logger, debug, *args, **kwargs):
453        self._ioloop = ioloop
454        self._switch = switch
455        self._target = target
456        self._logger = logger
457        self._debug = debug
458        self._args = args
459        self._kwargs = kwargs
460        self._device = None
461
462    @property
463    def address(self):
464        raise NotImplementedError('unsupported')
465
466    @property
467    def netmask(self):
468        raise NotImplementedError('unsupported')
469
470    @property
471    def mtu(self):
472        raise NotImplementedError('unsupported')
473
474    @address.setter
475    def address(self, address):
476        raise NotImplementedError('unsupported')
477
478    @netmask.setter
479    def netmask(self, netmask):
480        raise NotImplementedError('unsupported')
481
482    @mtu.setter
483    def mtu(self, mtu):
484        raise NotImplementedError('unsupported')
485
486    def open(self):
487        raise NotImplementedError('unsupported')
488
489    def write_message(self, message, binary=False):
490        raise NotImplementedError('unsupported')
491
492    def read(self):
493        raise NotImplementedError('unsupported')
494
495    @property
496    def target(self):
497        return self._target
498
499    @property
500    def logger(self):
501        return self._logger
502
503    @property
504    def device(self):
505        return self._device
506
507    @property
508    def closed(self):
509        return not self.device
510
511    def close(self):
512        if self.closed:
513            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
514        self.leave_switch()
515        self.unregister_device()
516        self.logger.info('disconnected: %s', self.target)
517
518    def register_device(self, device):
519        self._device = device
520
521    def unregister_device(self):
522        self._device.close()
523        self._device = None
524
525    def fileno(self):
526        if self.closed:
527            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
528        return self.device.fileno()
529
530    def __call__(self, fd, events):
531        try:
532            data = self.read()
533            if data is not None:
534                self._switch.receive(self, EthernetFrame(data))
535                return
536        except:
537            self.logger.exception('caught error while receiving frame')
538        self.close()
539
540    def join_switch(self):
541        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
542        return self._switch.register_port(self)
543
544    def leave_switch(self):
545        self._switch.unregister_port(self)
546        self._ioloop.remove_handler(self.fileno())
547
548
549class NetdevHandler(BaseClientHandler):
550    IFTYPE = 'netdev'
551    IFOP_ALLOWED = True
552    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
553
554    @property
555    def address(self):
556        if self.closed:
557            raise ValueError('I/O operation on closed netdev')
558        return NetworkInterface(self.target).address
559
560    @property
561    def netmask(self):
562        if self.closed:
563            raise ValueError('I/O operation on closed netdev')
564        return NetworkInterface(self.target).netmask
565
566    @property
567    def mtu(self):
568        if self.closed:
569            raise ValueError('I/O operation on closed netdev')
570        return NetworkInterface(self.target).mtu
571
572    @address.setter
573    def address(self, address):
574        if self.closed:
575            raise ValueError('I/O operation on closed netdev')
576        NetworkInterface(self.target).address = address
577
578    @netmask.setter
579    def netmask(self, netmask):
580        if self.closed:
581            raise ValueError('I/O operation on closed netdev')
582        NetworkInterface(self.target).netmask = netmask
583
584    @mtu.setter
585    def mtu(self, mtu):
586        if self.closed:
587            raise ValueError('I/O operation on closed netdev')
588        NetworkInterface(self.target).mtu = mtu
589
590    def open(self):
591        if not self.closed:
592            raise ValueError('Already opened')
593        self.register_device(socket.socket(
594            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
595        self.device.bind((self.target, self.ETH_P_ALL))
596        self.logger.info('connected: %s', self.target)
597        return self.join_switch()
598
599    def write_message(self, message, binary=False):
600        if self.closed:
601            raise ValueError('I/O operation on closed netdev')
602        self.device.sendall(message)
603
604    def read(self):
605        if self.closed:
606            raise ValueError('I/O operation on closed netdev')
607        buf = []
608        while True:
609            buf.append(self.device.recv(65535))
610            if len(buf[-1]) < 65535:
611                break
612        return ''.join(buf)
613
614
615class TapHandler(BaseClientHandler):
616    IFTYPE = 'tap'
617    IFOP_ALLOWED = True
618
619    @property
620    def address(self):
621        if self.closed:
622            raise ValueError('I/O operation on closed tap')
623        try:
624            return self.device.addr
625        except:
626            return ''
627
628    @property
629    def netmask(self):
630        if self.closed:
631            raise ValueError('I/O operation on closed tap')
632        try:
633            return self.device.netmask
634        except:
635            return ''
636
637    @property
638    def mtu(self):
639        if self.closed:
640            raise ValueError('I/O operation on closed tap')
641        return self.device.mtu
642
643    @address.setter
644    def address(self, address):
645        if self.closed:
646            raise ValueError('I/O operation on closed tap')
647        self.device.addr = address
648
649    @netmask.setter
650    def netmask(self, netmask):
651        if self.closed:
652            raise ValueError('I/O operation on closed tap')
653        self.device.netmask = netmask
654
655    @mtu.setter
656    def mtu(self, mtu):
657        if self.closed:
658            raise ValueError('I/O operation on closed tap')
659        self.device.mtu = mtu
660
661    @property
662    def target(self):
663        if self.closed:
664            return self._target
665        return self.device.name
666
667    def open(self):
668        if not self.closed:
669            raise ValueError('Already opened')
670        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
671        self.device.up()
672        self.logger.info('connected: %s', self.target)
673        return self.join_switch()
674
675    def write_message(self, message, binary=False):
676        if self.closed:
677            raise ValueError('I/O operation on closed tap')
678        self.device.write(message)
679
680    def read(self):
681        if self.closed:
682            raise ValueError('I/O operation on closed tap')
683        buf = []
684        while True:
685            buf.append(self.device.read(65535))
686            if len(buf[-1]) < 65535:
687                break
688        return ''.join(buf)
689
690
691class ClientHandler(BaseClientHandler):
692    IFTYPE = 'client'
693    IFOP_ALLOWED = False
694
695    def __init__(self, *args, **kwargs):
696        super(ClientHandler, self).__init__(*args, **kwargs)
697
698        self._sslopt = kwargs.get('sslopt', {})
699        self._options = self._getproxy()
700
701        cred = kwargs.get('cred', None)
702
703        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
704            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
705            auth = ['Authorization: Basic %s' % token]
706            self._options['header'] = auth
707
708    def open(self):
709        if not self.closed:
710            raise websocket.WebSocketException('Already opened')
711
712        # XXX: may be blocked
713        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
714        self.device.connect(self.target, **self._options)
715        self.logger.info('connected: %s', self.target)
716        return self.join_switch()
717
718    def write_message(self, message, binary=False):
719        if self.closed:
720            raise websocket.WebSocketException('Closed socket')
721        if binary:
722            flag = websocket.ABNF.OPCODE_BINARY
723        else:
724            flag = websocket.ABNF.OPCODE_TEXT
725        self.device.send(message, flag)
726
727    def read(self):
728        if self.closed:
729            raise websocket.WebSocketException('Closed socket')
730        return self.device.recv()
731
732    def _getproxy(self):
733        scheme, remain = urllib2.splittype(self.target)
734
735        proxy = urllib2.getproxies().get(scheme.replace('ws', 'http'))
736        if not proxy:
737            return {}
738
739        host = urllib2.splitport(urllib2.splithost(remain)[0])[0]
740        if urllib2.proxy_bypass(host):
741            return {}
742
743        hostport = urllib2.splithost(urllib2.splittype(proxy)[1])[0]
744        host, port = urllib2.splitport(hostport)
745        port = int(port) if port else 0
746        return {'http_proxy_host': host, 'http_proxy_port': port}
747
748
749class ControlServerHandler(BasicAuthMixIn, RequestHandler):
750    NAMESPACE = 'etherws.control'
751    IFTYPES = {
752        NetdevHandler.IFTYPE: NetdevHandler,
753        TapHandler.IFTYPE:    TapHandler,
754        ClientHandler.IFTYPE: ClientHandler,
755    }
756
757    def __init__(self, app, req, ioloop, switch, htpasswd, logger, debug):
758        super(ControlServerHandler, self).__init__(app, req)
759        self._ioloop = ioloop
760        self._switch = switch
761        self._htpasswd = htpasswd
762        self._logger = logger
763        self._debug = debug
764
765    def post(self):
766        try:
767            request = json.loads(self.request.body)
768        except Exception as e:
769            return self._jsonrpc_response(error={
770                'code':    0 - 32700,
771                'message': 'Parse error',
772                'data':    '%s: %s' % (type(e).__name__, str(e)),
773            })
774
775        try:
776            id_ = request.get('id')
777            params = request.get('params')
778            version = request['jsonrpc']
779            method = request['method']
780            if version != '2.0':
781                raise ValueError('Invalid JSON-RPC version: %s' % version)
782        except Exception as e:
783            return self._jsonrpc_response(id_=id_, error={
784                'code':    0 - 32600,
785                'message': 'Invalid Request',
786                'data':    '%s: %s' % (type(e).__name__, str(e)),
787            })
788
789        try:
790            if not method.startswith(self.NAMESPACE + '.'):
791                raise ValueError('Invalid method namespace: %s' % method)
792            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
793            handler = getattr(self, handler)
794        except Exception as e:
795            return self._jsonrpc_response(id_=id_, error={
796                'code':    0 - 32601,
797                'message': 'Method not found',
798                'data':    '%s: %s' % (type(e).__name__, str(e)),
799            })
800
801        try:
802            return self._jsonrpc_response(id_=id_, result=handler(params))
803        except Exception as e:
804            self._logger.exception('caught error while processing request')
805            return self._jsonrpc_response(id_=id_, error={
806                'code':    0 - 32602,
807                'message': 'Invalid params',
808                'data':     '%s: %s' % (type(e).__name__, str(e)),
809            })
810
811    def handle_listFdb(self, params):
812        list_ = []
813        for vid, mac, entry in self._switch.fdb.each():
814            list_.append({
815                'vid':  vid,
816                'mac':  EthernetFrame.format_mac(mac),
817                'port': entry.port.number,
818                'age':  int(entry.age),
819            })
820        return {'entries': list_}
821
822    def handle_listPort(self, params):
823        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
824
825    def handle_addPort(self, params):
826        type_ = params['type']
827        target = params['target']
828        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
829        cls = self.IFTYPES[type_]
830        interface = cls(self._ioloop, self._switch, target,
831                        self._logger, self._debug, **opt)
832        portnum = interface.open()
833        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
834
835    def handle_setPort(self, params):
836        port = self._switch.get_port(int(params['port']))
837        shut = params.get('shut')
838        if shut is not None:
839            port.shut = bool(shut)
840        return {'entries': [self._portstat(port)]}
841
842    def handle_delPort(self, params):
843        port = self._switch.get_port(int(params['port']))
844        port.interface.close()
845        return {'entries': [self._portstat(port)]}
846
847    def handle_setInterface(self, params):
848        portnum = int(params['port'])
849        port = self._switch.get_port(portnum)
850        address = params.get('address')
851        netmask = params.get('netmask')
852        mtu = params.get('mtu')
853        if not port.interface.IFOP_ALLOWED:
854            raise ValueError('Port %d has unsupported interface: %s' %
855                             (portnum, port.interface.IFTYPE))
856        if address is not None:
857            port.interface.address = address
858        if netmask is not None:
859            port.interface.netmask = netmask
860        if mtu is not None:
861            port.interface.mtu = mtu
862        return {'entries': [self._ifstat(port)]}
863
864    def handle_listInterface(self, params):
865        return {'entries': [self._ifstat(p) for p in self._switch.portlist
866                            if p.interface.IFOP_ALLOWED]}
867
868    def _optparse_netdev(self, opt):
869        return {}
870
871    def _optparse_tap(self, opt):
872        return {}
873
874    def _optparse_client(self, opt):
875        if opt.get('insecure'):
876            sslopt = {'cert_reqs': ssl.CERT_NONE}
877        else:
878            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
879                      'ca_certs':  opt.get('cacerts')}
880        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
881        return {'sslopt': sslopt, 'cred': cred}
882
883    def _jsonrpc_response(self, id_=None, result=None, error=None):
884        res = {'jsonrpc': '2.0', 'id': id_}
885        if result:
886            res['result'] = result
887        if error:
888            res['error'] = error
889        self.finish(res)
890
891    @staticmethod
892    def _portstat(port):
893        return {
894            'port':   port.number,
895            'type':   port.interface.IFTYPE,
896            'target': port.interface.target,
897            'tx':     port.tx,
898            'rx':     port.rx,
899            'shut':   port.shut,
900        }
901
902    @staticmethod
903    def _ifstat(port):
904        return {
905            'port':    port.number,
906            'type':    port.interface.IFTYPE,
907            'target':  port.interface.target,
908            'address': port.interface.address,
909            'netmask': port.interface.netmask,
910            'mtu':     port.interface.mtu,
911        }
912
913
914def _print_error(error):
915    print(%s (%s)' % (error['message'], error['code']))
916    print('    %s' % error['data'])
917
918
919def _start_sw(args):
920    def daemonize(nochdir=False, noclose=False):
921        if os.fork() > 0:
922            sys.exit(0)
923
924        os.setsid()
925
926        if os.fork() > 0:
927            sys.exit(0)
928
929        if not nochdir:
930            os.chdir('/')
931
932        if not noclose:
933            os.umask(0)
934            sys.stdin.close()
935            sys.stdout.close()
936            sys.stderr.close()
937            os.close(0)
938            os.close(1)
939            os.close(2)
940            sys.stdin = open(os.devnull)
941            sys.stdout = open(os.devnull, 'a')
942            sys.stderr = open(os.devnull, 'a')
943
944    def checkabspath(ns, path):
945        val = getattr(ns, path, '')
946        if not val.startswith('/'):
947            raise ValueError('Invalid %: %s' % (path, val))
948
949    def getsslopt(ns, key, cert):
950        kval = getattr(ns, key, None)
951        cval = getattr(ns, cert, None)
952        if kval and cval:
953            return {'keyfile': kval, 'certfile': cval}
954        elif kval or cval:
955            raise ValueError('Both %s and %s are required' % (key, cert))
956        return None
957
958    def setrealpath(ns, *keys):
959        for k in keys:
960            v = getattr(ns, k, None)
961            if v is not None:
962                v = os.path.realpath(v)
963                open(v).close()  # check readable
964                setattr(ns, k, v)
965
966    def setport(ns, port, isssl):
967        val = getattr(ns, port, None)
968        if val is None:
969            if isssl:
970                return setattr(ns, port, 443)
971            return setattr(ns, port, 80)
972        if not (0 <= val <= 65535):
973            raise ValueError('Invalid %s: %s' % (port, val))
974
975    def sethtpasswd(ns, htpasswd):
976        val = getattr(ns, htpasswd, None)
977        if val:
978            return setattr(ns, htpasswd, Htpasswd(val))
979
980    # if args.debug:
981    #     websocket.enableTrace(True)
982
983    if args.ageout <= 0:
984        raise ValueError('Invalid ageout: %s' % args.ageout)
985
986    setrealpath(args, 'logconf')
987    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
988    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
989
990    checkabspath(args, 'path')
991    checkabspath(args, 'ctlpath')
992
993    sslopt = getsslopt(args, 'sslkey', 'sslcert')
994    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
995
996    setport(args, 'port', sslopt)
997    setport(args, 'ctlport', ctlsslopt)
998
999    sethtpasswd(args, 'htpasswd')
1000    sethtpasswd(args, 'ctlhtpasswd')
1001
1002    if args.logconf:
1003        logging.config.fileConfig(args.logconf)
1004        logger = logging.getLogger('etherws')
1005    else:
1006        logger = logging.getLogger('etherws')
1007        logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
1008
1009    ioloop = IOLoop.instance()
1010    fdb = FDB(args.ageout, logger, args.debug)
1011    switch = SwitchingHub(fdb, logger, args.debug)
1012
1013    if args.port == args.ctlport and args.host == args.ctlhost:
1014        if args.path == args.ctlpath:
1015            raise ValueError('Same path/ctlpath on same host')
1016        if args.sslkey != args.ctlsslkey:
1017            raise ValueError('Different sslkey/ctlsslkey on same host')
1018        if args.sslcert != args.ctlsslcert:
1019            raise ValueError('Different sslcert/ctlsslcert on same host')
1020
1021        app = Application([
1022            (args.path, ServerHandler, {
1023                'switch':   switch,
1024                'htpasswd': args.htpasswd,
1025                'logger':   logger,
1026                'debug':    args.debug,
1027            }),
1028            (args.ctlpath, ControlServerHandler, {
1029                'ioloop':   ioloop,
1030                'switch':   switch,
1031                'htpasswd': args.ctlhtpasswd,
1032                'logger':   logger,
1033                'debug':    args.debug,
1034            }),
1035        ])
1036        server = HTTPServer(app, ssl_options=sslopt)
1037        server.listen(args.port, address=args.host)
1038
1039    else:
1040        app = Application([(args.path, ServerHandler, {
1041            'switch':   switch,
1042            'htpasswd': args.htpasswd,
1043            'logger':   logger,
1044            'debug':    args.debug,
1045        })])
1046        server = HTTPServer(app, ssl_options=sslopt)
1047        server.listen(args.port, address=args.host)
1048
1049        ctl = Application([(args.ctlpath, ControlServerHandler, {
1050            'ioloop':   ioloop,
1051            'switch':   switch,
1052            'htpasswd': args.ctlhtpasswd,
1053            'logger':   logger,
1054            'debug':    args.debug,
1055        })])
1056        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1057        ctlserver.listen(args.ctlport, address=args.ctlhost)
1058
1059    if not args.foreground:
1060        daemonize()
1061
1062    ioloop.start()
1063
1064
1065def _start_ctl(args):
1066    def have_ssl_cert_verification():
1067        return 'context' in urllib2.urlopen.__code__.co_varnames
1068
1069    def request(args, method, params=None, id_=0):
1070        req = urllib2.Request(args.ctlurl)
1071        req.add_header('Content-type', 'application/json')
1072        if args.ctluser:
1073            if not args.ctlpasswd:
1074                args.ctlpasswd = getpass.getpass('Control Password: ')
1075            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1076            req.add_header('Authorization', 'Basic %s' % token)
1077        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1078        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1079        if params is not None:
1080            data['params'] = params
1081        if have_ssl_cert_verification():
1082            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1083                                             cafile=args.ctlsslcert)
1084            if args.ctlinsecure:
1085                ctx.check_hostname = False
1086                ctx.verify_mode = ssl.CERT_NONE
1087            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1088        elif args.ctlsslcert:
1089            raise EnvironmentError('do not support certificate verification')
1090        else:
1091            fp = urllib2.urlopen(req, json.dumps(data))
1092        return json.loads(fp.read())
1093
1094    def print_table(rows):
1095        cols = zip(*rows)
1096        maxlen = [0] * len(cols)
1097        for i in xrange(len(cols)):
1098            maxlen[i] = max(len(str(c)) for c in cols[i])
1099        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1100        fmt = '  ' + fmt
1101        for row in rows:
1102            print(fmt % tuple(row))
1103
1104    def print_portlist(result):
1105        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1106        for r in result:
1107            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1108                         r['rx'], r['tx'], r['target']])
1109        print_table(rows)
1110
1111    def print_iflist(result):
1112        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1113        for r in result:
1114            rows.append([r['port'], r['type'], r['address'],
1115                         r['netmask'], r['mtu'], r['target']])
1116        print_table(rows)
1117
1118    def handle_ctl_addport(args):
1119        opts = {
1120            'user':     getattr(args, 'user', None),
1121            'passwd':   getattr(args, 'passwd', None),
1122            'cacerts':  getattr(args, 'cacerts', None),
1123            'insecure': getattr(args, 'insecure', None),
1124        }
1125        if args.iftype == ClientHandler.IFTYPE:
1126            if not args.target.startswith('ws://') and \
1127               not args.target.startswith('wss://'):
1128                raise ValueError('Invalid target URL scheme: %s' % args.target)
1129            if not opts['user'] and opts['passwd']:
1130                raise ValueError('Authentication required but username empty')
1131            if opts['user'] and not opts['passwd']:
1132                opts['passwd'] = getpass.getpass('Client Password: ')
1133        result = request(args, 'addPort', {
1134            'type':    args.iftype,
1135            'target':  args.target,
1136            'options': opts,
1137        })
1138        if 'error' in result:
1139            _print_error(result['error'])
1140        else:
1141            print_portlist(result['result']['entries'])
1142
1143    def handle_ctl_setport(args):
1144        if args.port <= 0:
1145            raise ValueError('Invalid port: %d' % args.port)
1146        req = {'port': args.port}
1147        shut = getattr(args, 'shut', None)
1148        if shut is not None:
1149            req['shut'] = bool(shut)
1150        result = request(args, 'setPort', req)
1151        if 'error' in result:
1152            _print_error(result['error'])
1153        else:
1154            print_portlist(result['result']['entries'])
1155
1156    def handle_ctl_delport(args):
1157        if args.port <= 0:
1158            raise ValueError('Invalid port: %d' % args.port)
1159        result = request(args, 'delPort', {'port': args.port})
1160        if 'error' in result:
1161            _print_error(result['error'])
1162        else:
1163            print_portlist(result['result']['entries'])
1164
1165    def handle_ctl_listport(args):
1166        result = request(args, 'listPort')
1167        if 'error' in result:
1168            _print_error(result['error'])
1169        else:
1170            print_portlist(result['result']['entries'])
1171
1172    def handle_ctl_setif(args):
1173        if args.port <= 0:
1174            raise ValueError('Invalid port: %d' % args.port)
1175        req = {'port': args.port}
1176        address = getattr(args, 'address', None)
1177        netmask = getattr(args, 'netmask', None)
1178        mtu = getattr(args, 'mtu', None)
1179        if address is not None:
1180            if address:
1181                socket.inet_aton(address)  # validate
1182            req['address'] = address
1183        if netmask is not None:
1184            if netmask:
1185                socket.inet_aton(netmask)  # validate
1186            req['netmask'] = netmask
1187        if mtu is not None:
1188            if mtu < 576:
1189                raise ValueError('Invalid MTU: %d' % mtu)
1190            req['mtu'] = mtu
1191        result = request(args, 'setInterface', req)
1192        if 'error' in result:
1193            _print_error(result['error'])
1194        else:
1195            print_iflist(result['result']['entries'])
1196
1197    def handle_ctl_listif(args):
1198        result = request(args, 'listInterface')
1199        if 'error' in result:
1200            _print_error(result['error'])
1201        else:
1202            print_iflist(result['result']['entries'])
1203
1204    def handle_ctl_listfdb(args):
1205        result = request(args, 'listFdb')
1206        if 'error' in result:
1207            return _print_error(result['error'])
1208        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1209        for r in result['result']['entries']:
1210            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1211        print_table(rows)
1212
1213    locals()['handle_ctl_' + args.control_method](args)
1214
1215
1216def _main():
1217    parser = argparse.ArgumentParser()
1218    subcommand = parser.add_subparsers(dest='subcommand')
1219
1220    # - sw
1221    parser_sw = subcommand.add_parser('sw',
1222                                      help='start virtual switch')
1223
1224    parser_sw.add_argument('--debug', action='store_true', default=False,
1225                           help='run as debug mode')
1226    parser_sw.add_argument('--logconf',
1227                           help='path to logging config')
1228    parser_sw.add_argument('--foreground', action='store_true', default=False,
1229                           help='run as foreground mode')
1230    parser_sw.add_argument('--ageout', type=int, default=300,
1231                           help='FDB ageout time (sec)')
1232
1233    parser_sw.add_argument('--path', default='/',
1234                           help='http(s) path to serve WebSocket')
1235    parser_sw.add_argument('--host', default='',
1236                           help='listen address to serve WebSocket')
1237    parser_sw.add_argument('--port', type=int,
1238                           help='listen port to serve WebSocket')
1239    parser_sw.add_argument('--htpasswd',
1240                           help='path to htpasswd file to auth WebSocket')
1241    parser_sw.add_argument('--sslkey',
1242                           help='path to SSL private key for WebSocket')
1243    parser_sw.add_argument('--sslcert',
1244                           help='path to SSL certificate for WebSocket')
1245
1246    parser_sw.add_argument('--ctlpath', default='/ctl',
1247                           help='http(s) path to serve control API')
1248    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1249                           help='listen address to serve control API')
1250    parser_sw.add_argument('--ctlport', type=int, default=7867,
1251                           help='listen port to serve control API')
1252    parser_sw.add_argument('--ctlhtpasswd',
1253                           help='path to htpasswd file to auth control API')
1254    parser_sw.add_argument('--ctlsslkey',
1255                           help='path to SSL private key for control API')
1256    parser_sw.add_argument('--ctlsslcert',
1257                           help='path to SSL certificate for control API')
1258
1259    # - ctl
1260    parser_ctl = subcommand.add_parser('ctl',
1261                                       help='control virtual switch')
1262    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1263                            help='URL to control API')
1264    parser_ctl.add_argument('--ctluser',
1265                            help='username to auth control API')
1266    parser_ctl.add_argument('--ctlpasswd',
1267                            help='password to auth control API')
1268    parser_ctl.add_argument('--ctlsslcert',
1269                            help='path to SSL certificate for control API')
1270    parser_ctl.add_argument(
1271        '--ctlinsecure', action='store_true', default=False,
1272        help='do not verify control API certificate')
1273
1274    control_method = parser_ctl.add_subparsers(dest='control_method')
1275
1276    # -- ctl addport
1277    parser_ctl_addport = control_method.add_parser('addport',
1278                                                   help='create and add port')
1279    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1280
1281    # --- ctl addport netdev
1282    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1283                                                  help='Network device')
1284    parser_ctl_addport_netdev.add_argument('target',
1285                                           help='device name to add interface')
1286
1287    # --- ctl addport tap
1288    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1289                                               help='TAP device')
1290    parser_ctl_addport_tap.add_argument('target',
1291                                        help='device name to create interface')
1292
1293    # --- ctl addport client
1294    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1295                                                  help='WebSocket client')
1296    parser_ctl_addport_client.add_argument('target',
1297                                           help='URL to connect WebSocket')
1298    parser_ctl_addport_client.add_argument('--user',
1299                                           help='username to auth WebSocket')
1300    parser_ctl_addport_client.add_argument('--passwd',
1301                                           help='password to auth WebSocket')
1302    parser_ctl_addport_client.add_argument('--cacerts',
1303                                           help='path to CA certificate')
1304    parser_ctl_addport_client.add_argument(
1305        '--insecure', action='store_true', default=False,
1306        help='do not verify server certificate')
1307
1308    # -- ctl setport
1309    parser_ctl_setport = control_method.add_parser('setport',
1310                                                   help='set port config')
1311    parser_ctl_setport.add_argument('port', type=int,
1312                                    help='port number to set config')
1313    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1314                                    help='set shutdown state')
1315
1316    # -- ctl delport
1317    parser_ctl_delport = control_method.add_parser('delport',
1318                                                   help='delete port')
1319    parser_ctl_delport.add_argument('port', type=int,
1320                                    help='port number to delete')
1321
1322    # -- ctl listport
1323    parser_ctl_listport = control_method.add_parser('listport',
1324                                                    help='show port list')
1325
1326    # -- ctl setif
1327    parser_ctl_setif = control_method.add_parser('setif',
1328                                                 help='set interface config')
1329    parser_ctl_setif.add_argument('port', type=int,
1330                                  help='port number to set config')
1331    parser_ctl_setif.add_argument('--address',
1332                                  help='IPv4 address to set interface')
1333    parser_ctl_setif.add_argument('--netmask',
1334                                  help='IPv4 netmask to set interface')
1335    parser_ctl_setif.add_argument('--mtu', type=int,
1336                                  help='MTU to set interface')
1337
1338    # -- ctl listif
1339    parser_ctl_listif = control_method.add_parser('listif',
1340                                                  help='show interface list')
1341
1342    # -- ctl listfdb
1343    parser_ctl_listfdb = control_method.add_parser('listfdb',
1344                                                   help='show FDB entries')
1345
1346    # -- go
1347    args = parser.parse_args()
1348
1349    try:
1350        globals()['_start_' + args.subcommand](args)
1351    except Exception as e:
1352        _print_error({
1353            'code':    0 - 32603,
1354            'message': 'Internal error',
1355            'data':    '%s: %s' % (type(e).__name__, str(e)),
1356        })
1357
1358
1359if __name__ == '__main__':
1360    _main()
Note: See TracBrowser for help on using the repository browser.