source: etherws/trunk/etherws.py @ 288

Revision 288, 45.0 KB checked in by atzm, 8 years ago (diff)
  • fix exception on tornado 4.x
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.5
8#   - python-pytun-2.1
9#   - websocket-client-0.14.0
10#   - tornado-2.4
11#
12# ===========================================================================
13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import socket
48import struct
49import urllib2
50import hashlib
51import getpass
52import operator
53import argparse
54
55import logging
56import logging.config
57
58import tornado
59import websocket
60
61from tornado.web import Application, RequestHandler
62from tornado.websocket import WebSocketHandler
63from tornado.httpserver import HTTPServer
64from tornado.ioloop import IOLoop
65
66from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @property
74    def dst_multicast(self):
75        return ord(self.data[0]) & 1
76
77    @property
78    def src_multicast(self):
79        return ord(self.data[6]) & 1
80
81    @property
82    def dst_mac(self):
83        return self.data[:6]
84
85    @property
86    def src_mac(self):
87        return self.data[6:12]
88
89    @property
90    def tagged(self):
91        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
92
93    @property
94    def vid(self):
95        if self.tagged:
96            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
97        return 0
98
99    @staticmethod
100    def format_mac(mac, sep=':'):
101        return sep.join(b.encode('hex') for b in mac)
102
103
104class FDB(object):
105    class Entry(object):
106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
119    def __init__(self, ageout, logger, debug):
120        self._ageout = ageout
121        self._logger = logger
122        self._debug = debug
123        self._table = {}
124
125    def _set_entry(self, vid, mac, port):
126        if vid not in self._table:
127            self._table[vid] = {}
128        self._table[vid][mac] = self.Entry(port, self._ageout)
129
130    def _del_entry(self, vid, mac):
131        if vid in self._table:
132            if mac in self._table[vid]:
133                del self._table[vid][mac]
134            if not self._table[vid]:
135                del self._table[vid]
136
137    def _get_entry(self, vid, mac):
138        try:
139            entry = self._table[vid][mac]
140        except KeyError:
141            return None
142
143        if not entry.agedout:
144            return entry
145
146        self._del_entry(vid, mac)
147
148        if self._debug:
149            self._logger.debug('fdb aged out: port:%d; vid:%d; mac:%s',
150                               entry.port.number, vid, mac.encode('hex'))
151
152    def each(self):
153        for vid in sorted(self._table.iterkeys()):
154            for mac in sorted(self._table[vid].iterkeys()):
155                entry = self._get_entry(vid, mac)
156                if entry:
157                    yield (vid, mac, entry)
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self._get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169
170        if self._debug:
171            self._logger.debug('fdb learned: port:%d; vid:%d; mac:%s',
172                               port.number, vid, mac.encode('hex'))
173
174    def delete(self, port):
175        for vid, mac, entry in self.each():
176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178
179                if self._debug:
180                    self._logger.debug('fdb deleted: port:%d; vid:%d; mac:%s',
181                                       port.number, vid, mac.encode('hex'))
182
183
184class SwitchingHub(object):
185    class Port(object):
186        def __init__(self, number, interface):
187            self.number = number
188            self.interface = interface
189            self.tx = 0
190            self.rx = 0
191            self.shut = False
192
193    def __init__(self, fdb, logger, debug):
194        self.fdb = fdb
195        self._logger = logger
196        self._debug = debug
197        self._table = {}
198        self._next = 1
199
200    @property
201    def portlist(self):
202        return sorted(self._table.itervalues(),
203                      key=operator.attrgetter('number'))
204
205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
209        try:
210            self._set_privattr('portnum', interface, self._next)  # XXX
211            self._table[self._next] = self.Port(self._next, interface)
212            return self._next
213        finally:
214            self._next += 1
215
216    def unregister_port(self, interface):
217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
219        self.fdb.delete(self._table[portnum])
220        del self._table[portnum]
221
222    def send(self, dst_interfaces, frame):
223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
226        ports = sorted(ports, key=operator.attrgetter('number'))
227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if self._debug and ports:
233            self._logger.debug('sent: port:%s; vid:%d; %s -> %s',
234                               ','.join(str(p.number) for p in ports),
235                               frame.vid,
236                               frame.src_mac.encode('hex'),
237                               frame.dst_mac.encode('hex'))
238
239    def receive(self, src_interface, frame):
240        port = self._table[self._get_privattr('portnum', src_interface)]
241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
247        try:
248            if not frame.src_multicast:
249                self.fdb.learn(src_port, frame)
250
251            if not frame.dst_multicast:
252                dst_port = self.fdb.lookup(frame)
253
254                if dst_port:
255                    self.send([dst_port.interface], frame)
256                    return
257
258            ports = set(self.portlist) - set([src_port])
259            self.send((p.interface for p in ports), frame)
260
261        except:  # ex. received invalid frame
262            self._logger.exception(
263                'caught error while forwarding frame: port:%d; frame:%s',
264                src_port.number, frame.data.encode('hex'))
265
266    def _privattr(self, name):
267        return '_%s_%s_%s' % (type(self).__name__, id(self), name)
268
269    def _set_privattr(self, name, obj, value):
270        return setattr(obj, self._privattr(name), value)
271
272    def _get_privattr(self, name, obj, defaults=None):
273        return getattr(obj, self._privattr(name), defaults)
274
275    def _del_privattr(self, name, obj):
276        return delattr(obj, self._privattr(name))
277
278
279class NetworkInterface(object):
280    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
281    SIOCSIFADDR = 0x8916
282    SIOCGIFNETMASK = 0x891b
283    SIOCSIFNETMASK = 0x891c
284    SIOCGIFMTU = 0x8921
285    SIOCSIFMTU = 0x8922
286
287    def __init__(self, ifname):
288        self._ifname = struct.pack('16s', str(ifname)[:15])
289
290    def _ioctl(self, req, data):
291        try:
292            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
293            return fcntl.ioctl(sock.fileno(), req, data)
294        finally:
295            sock.close()
296
297    @property
298    def address(self):
299        try:
300            data = struct.pack('16s18x', self._ifname)
301            ret = self._ioctl(self.SIOCGIFADDR, data)
302            return socket.inet_ntoa(ret[20:24])
303        except IOError:
304            return ''
305
306    @property
307    def netmask(self):
308        try:
309            data = struct.pack('16s18x', self._ifname)
310            ret = self._ioctl(self.SIOCGIFNETMASK, data)
311            return socket.inet_ntoa(ret[20:24])
312        except IOError:
313            return ''
314
315    @property
316    def mtu(self):
317        try:
318            data = struct.pack('16s18x', self._ifname)
319            ret = self._ioctl(self.SIOCGIFMTU, data)
320            return struct.unpack('i', ret[16:20])[0]
321        except IOError:
322            return 0
323
324    @address.setter
325    def address(self, addr):
326        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
327                           socket.inet_aton(addr))
328        self._ioctl(self.SIOCSIFADDR, data)
329
330    @netmask.setter
331    def netmask(self, addr):
332        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
333                           socket.inet_aton(addr))
334        self._ioctl(self.SIOCSIFNETMASK, data)
335
336    @mtu.setter
337    def mtu(self, mtu):
338        data = struct.pack('16si12x', self._ifname, mtu)
339        self._ioctl(self.SIOCSIFMTU, data)
340
341
342class Htpasswd(object):
343    def __init__(self, path):
344        self._path = path
345        self._stat = None
346        self._data = {}
347
348    def auth(self, name, passwd):
349        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
350        return self._data.get(name) == passwd
351
352    def load(self):
353        old_stat = self._stat
354
355        with open(self._path) as fp:
356            fileno = fp.fileno()
357            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
358            self._stat = os.fstat(fileno)
359
360            unchanged = old_stat and \
361                old_stat.st_ino == self._stat.st_ino and \
362                old_stat.st_dev == self._stat.st_dev and \
363                old_stat.st_mtime == self._stat.st_mtime
364
365            if not unchanged:
366                self._data = self._parse(fp)
367
368        return self
369
370    def _parse(self, fp):
371        data = {}
372        for line in fp:
373            line = line.strip()
374            if 0 <= line.find(':'):
375                name, passwd = line.split(':', 1)
376                if passwd.startswith('{SHA}'):
377                    data[name] = passwd[5:]
378        return data
379
380
381class BasicAuthMixIn(object):
382    def _execute(self, transforms, *args, **kwargs):
383        def do_execute():
384            sp = super(BasicAuthMixIn, self)
385            return sp._execute(transforms, *args, **kwargs)
386
387        def auth_required():
388            stream = getattr(self, 'stream', None)
389            if stream is None:
390                stream = self.request.connection.stream
391
392            stream.write(tornado.escape.utf8(
393                'HTTP/1.1 401 Authorization Required\r\n'
394                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
395            ))
396            stream.close()
397
398        try:
399            if not self._htpasswd:
400                return do_execute()
401
402            creds = self.request.headers.get('Authorization')
403
404            if not creds or not creds.startswith('Basic '):
405                return auth_required()
406
407            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
408
409            if self._htpasswd.load().auth(name, passwd):
410                return do_execute()
411        except:
412            self._logger.exception('caught error while doing authorization')
413
414        return auth_required()
415
416
417class ServerHandler(BasicAuthMixIn, WebSocketHandler):
418    IFTYPE = 'server'
419    IFOP_ALLOWED = False
420
421    def __init__(self, app, req, switch, htpasswd, noorigin, logger, debug):
422        super(ServerHandler, self).__init__(app, req)
423        self._switch = switch
424        self._htpasswd = htpasswd
425        self._noorigin = noorigin
426        self._logger = logger
427        self._debug = debug
428
429    @property
430    def target(self):
431        conn = self.request.connection
432        if hasattr(conn, 'address'):
433            return ':'.join(str(e) for e in conn.address[:2])
434        if hasattr(conn, 'context'):
435            return ':'.join(str(e) for e in conn.context.address[:2])
436        return str(conn)
437
438    def open(self):
439        try:
440            return self._switch.register_port(self)
441        finally:
442            self._logger.info('connected: %s', self.request.remote_ip)
443
444    def on_message(self, message):
445        self._switch.receive(self, EthernetFrame(message))
446
447    def on_close(self):
448        self._switch.unregister_port(self)
449        self._logger.info('disconnected: %s', self.request.remote_ip)
450
451    def check_origin(self, origin):
452        if self._noorigin:
453            return True
454        return super(ServerHandler, self).check_origin(origin)
455
456
457class BaseClientHandler(object):
458    IFTYPE = 'baseclient'
459    IFOP_ALLOWED = False
460
461    def __init__(self, ioloop, switch, target, logger, debug, *args, **kwargs):
462        self._ioloop = ioloop
463        self._switch = switch
464        self._target = target
465        self._logger = logger
466        self._debug = debug
467        self._args = args
468        self._kwargs = kwargs
469        self._device = None
470
471    @property
472    def address(self):
473        raise NotImplementedError('unsupported')
474
475    @property
476    def netmask(self):
477        raise NotImplementedError('unsupported')
478
479    @property
480    def mtu(self):
481        raise NotImplementedError('unsupported')
482
483    @address.setter
484    def address(self, address):
485        raise NotImplementedError('unsupported')
486
487    @netmask.setter
488    def netmask(self, netmask):
489        raise NotImplementedError('unsupported')
490
491    @mtu.setter
492    def mtu(self, mtu):
493        raise NotImplementedError('unsupported')
494
495    def open(self):
496        raise NotImplementedError('unsupported')
497
498    def write_message(self, message, binary=False):
499        raise NotImplementedError('unsupported')
500
501    def read(self):
502        raise NotImplementedError('unsupported')
503
504    @property
505    def target(self):
506        return self._target
507
508    @property
509    def logger(self):
510        return self._logger
511
512    @property
513    def device(self):
514        return self._device
515
516    @property
517    def closed(self):
518        return not self.device
519
520    def close(self):
521        if self.closed:
522            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
523        self.leave_switch()
524        self.unregister_device()
525        self.logger.info('disconnected: %s', self.target)
526
527    def register_device(self, device):
528        self._device = device
529
530    def unregister_device(self):
531        self._device.close()
532        self._device = None
533
534    def fileno(self):
535        if self.closed:
536            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
537        return self.device.fileno()
538
539    def __call__(self, fd, events):
540        try:
541            data = self.read()
542            if data is not None:
543                self._switch.receive(self, EthernetFrame(data))
544                return
545        except:
546            self.logger.exception('caught error while receiving frame')
547        self.close()
548
549    def join_switch(self):
550        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
551        return self._switch.register_port(self)
552
553    def leave_switch(self):
554        self._switch.unregister_port(self)
555        self._ioloop.remove_handler(self.fileno())
556
557
558class NetdevHandler(BaseClientHandler):
559    IFTYPE = 'netdev'
560    IFOP_ALLOWED = True
561    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
562
563    @property
564    def address(self):
565        if self.closed:
566            raise ValueError('I/O operation on closed netdev')
567        return NetworkInterface(self.target).address
568
569    @property
570    def netmask(self):
571        if self.closed:
572            raise ValueError('I/O operation on closed netdev')
573        return NetworkInterface(self.target).netmask
574
575    @property
576    def mtu(self):
577        if self.closed:
578            raise ValueError('I/O operation on closed netdev')
579        return NetworkInterface(self.target).mtu
580
581    @address.setter
582    def address(self, address):
583        if self.closed:
584            raise ValueError('I/O operation on closed netdev')
585        NetworkInterface(self.target).address = address
586
587    @netmask.setter
588    def netmask(self, netmask):
589        if self.closed:
590            raise ValueError('I/O operation on closed netdev')
591        NetworkInterface(self.target).netmask = netmask
592
593    @mtu.setter
594    def mtu(self, mtu):
595        if self.closed:
596            raise ValueError('I/O operation on closed netdev')
597        NetworkInterface(self.target).mtu = mtu
598
599    def open(self):
600        if not self.closed:
601            raise ValueError('Already opened')
602        self.register_device(socket.socket(
603            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
604        self.device.bind((self.target, self.ETH_P_ALL))
605        self.logger.info('connected: %s', self.target)
606        return self.join_switch()
607
608    def write_message(self, message, binary=False):
609        if self.closed:
610            raise ValueError('I/O operation on closed netdev')
611        self.device.sendall(message)
612
613    def read(self):
614        if self.closed:
615            raise ValueError('I/O operation on closed netdev')
616        buf = []
617        while True:
618            buf.append(self.device.recv(65535))
619            if len(buf[-1]) < 65535:
620                break
621        return ''.join(buf)
622
623
624class TapHandler(BaseClientHandler):
625    IFTYPE = 'tap'
626    IFOP_ALLOWED = True
627
628    @property
629    def address(self):
630        if self.closed:
631            raise ValueError('I/O operation on closed tap')
632        try:
633            return self.device.addr
634        except:
635            return ''
636
637    @property
638    def netmask(self):
639        if self.closed:
640            raise ValueError('I/O operation on closed tap')
641        try:
642            return self.device.netmask
643        except:
644            return ''
645
646    @property
647    def mtu(self):
648        if self.closed:
649            raise ValueError('I/O operation on closed tap')
650        return self.device.mtu
651
652    @address.setter
653    def address(self, address):
654        if self.closed:
655            raise ValueError('I/O operation on closed tap')
656        self.device.addr = address
657
658    @netmask.setter
659    def netmask(self, netmask):
660        if self.closed:
661            raise ValueError('I/O operation on closed tap')
662        self.device.netmask = netmask
663
664    @mtu.setter
665    def mtu(self, mtu):
666        if self.closed:
667            raise ValueError('I/O operation on closed tap')
668        self.device.mtu = mtu
669
670    @property
671    def target(self):
672        if self.closed:
673            return self._target
674        return self.device.name
675
676    def open(self):
677        if not self.closed:
678            raise ValueError('Already opened')
679        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
680        self.device.up()
681        self.logger.info('connected: %s', self.target)
682        return self.join_switch()
683
684    def write_message(self, message, binary=False):
685        if self.closed:
686            raise ValueError('I/O operation on closed tap')
687        self.device.write(message)
688
689    def read(self):
690        if self.closed:
691            raise ValueError('I/O operation on closed tap')
692        buf = []
693        while True:
694            buf.append(self.device.read(65535))
695            if len(buf[-1]) < 65535:
696                break
697        return ''.join(buf)
698
699
700class ClientHandler(BaseClientHandler):
701    IFTYPE = 'client'
702    IFOP_ALLOWED = False
703
704    def __init__(self, *args, **kwargs):
705        super(ClientHandler, self).__init__(*args, **kwargs)
706        self._sslopt = self._kwargs.get('sslopt', {})
707        self._options = dict(self.authopts, **self.proxyopts)
708
709    def open(self):
710        if not self.closed:
711            raise websocket.WebSocketException('Already opened')
712
713        # XXX: may be blocked
714        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
715        self.device.connect(self.target, **self._options)
716        self.logger.info('connected: %s', self.target)
717        return self.join_switch()
718
719    def write_message(self, message, binary=False):
720        if self.closed:
721            raise websocket.WebSocketException('Closed socket')
722        if binary:
723            flag = websocket.ABNF.OPCODE_BINARY
724        else:
725            flag = websocket.ABNF.OPCODE_TEXT
726        self.device.send(message, flag)
727
728    def read(self):
729        if self.closed:
730            raise websocket.WebSocketException('Closed socket')
731        return self.device.recv()
732
733    @property
734    def authopts(self):
735        cred = self._kwargs.get('cred')
736
737        if not isinstance(cred, dict):
738            return {}
739        if cred.get('user') is None:
740            return {}
741        if cred.get('passwd') is None:
742            return {}
743
744        token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
745        return {'header': ['Authorization: Basic %s' % token]}
746
747    @property
748    def proxyopts(self):
749        scheme, remain = urllib2.splittype(self.target)
750
751        proxy = urllib2.getproxies().get(scheme.replace('ws', 'http'))
752        if not proxy:
753            return {}
754
755        host = urllib2.splitport(urllib2.splithost(remain)[0])[0]
756        if urllib2.proxy_bypass(host):
757            return {}
758
759        hostport = urllib2.splithost(urllib2.splittype(proxy)[1])[0]
760        host, port = urllib2.splitport(hostport)
761        port = int(port) if port else 0
762        return {'http_proxy_host': host, 'http_proxy_port': port}
763
764
765class ControlServerHandler(BasicAuthMixIn, RequestHandler):
766    NAMESPACE = 'etherws.control'
767    IFTYPES = {
768        NetdevHandler.IFTYPE: NetdevHandler,
769        TapHandler.IFTYPE:    TapHandler,
770        ClientHandler.IFTYPE: ClientHandler,
771    }
772
773    def __init__(self, app, req, ioloop, switch, htpasswd, logger, debug):
774        super(ControlServerHandler, self).__init__(app, req)
775        self._ioloop = ioloop
776        self._switch = switch
777        self._htpasswd = htpasswd
778        self._logger = logger
779        self._debug = debug
780
781    def post(self):
782        try:
783            request = json.loads(self.request.body)
784        except Exception as e:
785            return self._jsonrpc_response(error={
786                'code':    0 - 32700,
787                'message': 'Parse error',
788                'data':    '%s: %s' % (type(e).__name__, str(e)),
789            })
790
791        try:
792            id_ = request.get('id')
793            params = request.get('params')
794            version = request['jsonrpc']
795            method = request['method']
796            if version != '2.0':
797                raise ValueError('Invalid JSON-RPC version: %s' % version)
798        except Exception as e:
799            return self._jsonrpc_response(id_=id_, error={
800                'code':    0 - 32600,
801                'message': 'Invalid Request',
802                'data':    '%s: %s' % (type(e).__name__, str(e)),
803            })
804
805        try:
806            if not method.startswith(self.NAMESPACE + '.'):
807                raise ValueError('Invalid method namespace: %s' % method)
808            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
809            handler = getattr(self, handler)
810        except Exception as e:
811            return self._jsonrpc_response(id_=id_, error={
812                'code':    0 - 32601,
813                'message': 'Method not found',
814                'data':    '%s: %s' % (type(e).__name__, str(e)),
815            })
816
817        try:
818            return self._jsonrpc_response(id_=id_, result=handler(params))
819        except Exception as e:
820            self._logger.exception('caught error while processing request')
821            return self._jsonrpc_response(id_=id_, error={
822                'code':    0 - 32602,
823                'message': 'Invalid params',
824                'data':     '%s: %s' % (type(e).__name__, str(e)),
825            })
826
827    def handle_listFdb(self, params):
828        list_ = []
829        for vid, mac, entry in self._switch.fdb.each():
830            list_.append({
831                'vid':  vid,
832                'mac':  EthernetFrame.format_mac(mac),
833                'port': entry.port.number,
834                'age':  int(entry.age),
835            })
836        return {'entries': list_}
837
838    def handle_listPort(self, params):
839        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
840
841    def handle_addPort(self, params):
842        type_ = params['type']
843        target = params['target']
844        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
845        cls = self.IFTYPES[type_]
846        interface = cls(self._ioloop, self._switch, target,
847                        self._logger, self._debug, **opt)
848        portnum = interface.open()
849        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
850
851    def handle_setPort(self, params):
852        port = self._switch.get_port(int(params['port']))
853        shut = params.get('shut')
854        if shut is not None:
855            port.shut = bool(shut)
856        return {'entries': [self._portstat(port)]}
857
858    def handle_delPort(self, params):
859        port = self._switch.get_port(int(params['port']))
860        port.interface.close()
861        return {'entries': [self._portstat(port)]}
862
863    def handle_setInterface(self, params):
864        portnum = int(params['port'])
865        port = self._switch.get_port(portnum)
866        address = params.get('address')
867        netmask = params.get('netmask')
868        mtu = params.get('mtu')
869        if not port.interface.IFOP_ALLOWED:
870            raise ValueError('Port %d has unsupported interface: %s' %
871                             (portnum, port.interface.IFTYPE))
872        if address is not None:
873            port.interface.address = address
874        if netmask is not None:
875            port.interface.netmask = netmask
876        if mtu is not None:
877            port.interface.mtu = mtu
878        return {'entries': [self._ifstat(port)]}
879
880    def handle_listInterface(self, params):
881        return {'entries': [self._ifstat(p) for p in self._switch.portlist
882                            if p.interface.IFOP_ALLOWED]}
883
884    def _optparse_netdev(self, opt):
885        return {}
886
887    def _optparse_tap(self, opt):
888        return {}
889
890    def _optparse_client(self, opt):
891        if opt.get('insecure'):
892            sslopt = {'cert_reqs': ssl.CERT_NONE}
893        else:
894            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
895                      'ca_certs':  opt.get('cacerts')}
896        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
897        return {'sslopt': sslopt, 'cred': cred}
898
899    def _jsonrpc_response(self, id_=None, result=None, error=None):
900        res = {'jsonrpc': '2.0', 'id': id_}
901        if result:
902            res['result'] = result
903        if error:
904            res['error'] = error
905        self.finish(res)
906
907    @staticmethod
908    def _portstat(port):
909        return {
910            'port':   port.number,
911            'type':   port.interface.IFTYPE,
912            'target': port.interface.target,
913            'tx':     port.tx,
914            'rx':     port.rx,
915            'shut':   port.shut,
916        }
917
918    @staticmethod
919    def _ifstat(port):
920        return {
921            'port':    port.number,
922            'type':    port.interface.IFTYPE,
923            'target':  port.interface.target,
924            'address': port.interface.address,
925            'netmask': port.interface.netmask,
926            'mtu':     port.interface.mtu,
927        }
928
929
930def _print_error(error):
931    print(%s (%s)' % (error['message'], error['code']))
932    print('    %s' % error['data'])
933
934
935def _start_sw(args):
936    def daemonize(nochdir=False, noclose=False):
937        if os.fork() > 0:
938            sys.exit(0)
939
940        os.setsid()
941
942        if os.fork() > 0:
943            sys.exit(0)
944
945        if not nochdir:
946            os.chdir('/')
947
948        if not noclose:
949            os.umask(0)
950            sys.stdin.close()
951            sys.stdout.close()
952            sys.stderr.close()
953            os.close(0)
954            os.close(1)
955            os.close(2)
956            sys.stdin = open(os.devnull)
957            sys.stdout = open(os.devnull, 'a')
958            sys.stderr = open(os.devnull, 'a')
959
960    def checkabspath(ns, path):
961        val = getattr(ns, path, '')
962        if not val.startswith('/'):
963            raise ValueError('Invalid %: %s' % (path, val))
964
965    def getsslopt(ns, key, cert):
966        kval = getattr(ns, key, None)
967        cval = getattr(ns, cert, None)
968        if kval and cval:
969            return {'keyfile': kval, 'certfile': cval}
970        elif kval or cval:
971            raise ValueError('Both %s and %s are required' % (key, cert))
972        return None
973
974    def setrealpath(ns, *keys):
975        for k in keys:
976            v = getattr(ns, k, None)
977            if v is not None:
978                v = os.path.realpath(v)
979                open(v).close()  # check readable
980                setattr(ns, k, v)
981
982    def setport(ns, port, isssl):
983        val = getattr(ns, port, None)
984        if val is None:
985            if isssl:
986                return setattr(ns, port, 443)
987            return setattr(ns, port, 80)
988        if not (0 <= val <= 65535):
989            raise ValueError('Invalid %s: %s' % (port, val))
990
991    def sethtpasswd(ns, htpasswd):
992        val = getattr(ns, htpasswd, None)
993        if val:
994            return setattr(ns, htpasswd, Htpasswd(val))
995
996    # if args.debug:
997    #     websocket.enableTrace(True)
998
999    if args.ageout <= 0:
1000        raise ValueError('Invalid ageout: %s' % args.ageout)
1001
1002    setrealpath(args, 'logconf')
1003    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
1004    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
1005
1006    checkabspath(args, 'path')
1007    checkabspath(args, 'ctlpath')
1008
1009    sslopt = getsslopt(args, 'sslkey', 'sslcert')
1010    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
1011
1012    setport(args, 'port', sslopt)
1013    setport(args, 'ctlport', ctlsslopt)
1014
1015    sethtpasswd(args, 'htpasswd')
1016    sethtpasswd(args, 'ctlhtpasswd')
1017
1018    if args.logconf:
1019        logging.config.fileConfig(args.logconf)
1020        logger = logging.getLogger('etherws')
1021    else:
1022        logger = logging.getLogger('etherws')
1023        logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
1024
1025    ioloop = IOLoop.instance()
1026    fdb = FDB(args.ageout, logger, args.debug)
1027    switch = SwitchingHub(fdb, logger, args.debug)
1028
1029    if args.port == args.ctlport and args.host == args.ctlhost:
1030        if args.path == args.ctlpath:
1031            raise ValueError('Same path/ctlpath on same host')
1032        if args.sslkey != args.ctlsslkey:
1033            raise ValueError('Different sslkey/ctlsslkey on same host')
1034        if args.sslcert != args.ctlsslcert:
1035            raise ValueError('Different sslcert/ctlsslcert on same host')
1036
1037        app = Application([
1038            (args.path, ServerHandler, {
1039                'switch':   switch,
1040                'htpasswd': args.htpasswd,
1041                'noorigin': args.noorigin,
1042                'logger':   logger,
1043                'debug':    args.debug,
1044            }),
1045            (args.ctlpath, ControlServerHandler, {
1046                'ioloop':   ioloop,
1047                'switch':   switch,
1048                'htpasswd': args.ctlhtpasswd,
1049                'logger':   logger,
1050                'debug':    args.debug,
1051            }),
1052        ])
1053        server = HTTPServer(app, ssl_options=sslopt)
1054        server.listen(args.port, address=args.host)
1055
1056    else:
1057        app = Application([(args.path, ServerHandler, {
1058            'switch':   switch,
1059            'htpasswd': args.htpasswd,
1060            'noorigin': args.noorigin,
1061            'logger':   logger,
1062            'debug':    args.debug,
1063        })])
1064        server = HTTPServer(app, ssl_options=sslopt)
1065        server.listen(args.port, address=args.host)
1066
1067        ctl = Application([(args.ctlpath, ControlServerHandler, {
1068            'ioloop':   ioloop,
1069            'switch':   switch,
1070            'htpasswd': args.ctlhtpasswd,
1071            'logger':   logger,
1072            'debug':    args.debug,
1073        })])
1074        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1075        ctlserver.listen(args.ctlport, address=args.ctlhost)
1076
1077    if not args.foreground:
1078        daemonize()
1079
1080    ioloop.start()
1081
1082
1083def _start_ctl(args):
1084    def have_ssl_cert_verification():
1085        return 'context' in urllib2.urlopen.__code__.co_varnames
1086
1087    def request(args, method, params=None, id_=0):
1088        req = urllib2.Request(args.ctlurl)
1089        req.add_header('Content-type', 'application/json')
1090        if args.ctluser:
1091            if not args.ctlpasswd:
1092                args.ctlpasswd = getpass.getpass('Control Password: ')
1093            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1094            req.add_header('Authorization', 'Basic %s' % token)
1095        method = '.'.join([ControlServerHandler.NAMESPACE, method])
1096        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1097        if params is not None:
1098            data['params'] = params
1099        if have_ssl_cert_verification():
1100            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1101                                             cafile=args.ctlsslcert)
1102            if args.ctlinsecure:
1103                ctx.check_hostname = False
1104                ctx.verify_mode = ssl.CERT_NONE
1105            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1106        elif args.ctlsslcert:
1107            raise EnvironmentError('do not support certificate verification')
1108        else:
1109            fp = urllib2.urlopen(req, json.dumps(data))
1110        return json.loads(fp.read())
1111
1112    def print_table(rows):
1113        cols = zip(*rows)
1114        maxlen = [0] * len(cols)
1115        for i in xrange(len(cols)):
1116            maxlen[i] = max(len(str(c)) for c in cols[i])
1117        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1118        fmt = '  ' + fmt
1119        for row in rows:
1120            print(fmt % tuple(row))
1121
1122    def print_portlist(result):
1123        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
1124        for r in result:
1125            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1126                         r['rx'], r['tx'], r['target']])
1127        print_table(rows)
1128
1129    def print_iflist(result):
1130        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
1131        for r in result:
1132            rows.append([r['port'], r['type'], r['address'],
1133                         r['netmask'], r['mtu'], r['target']])
1134        print_table(rows)
1135
1136    def handle_ctl_addport(args):
1137        opts = {
1138            'user':     getattr(args, 'user', None),
1139            'passwd':   getattr(args, 'passwd', None),
1140            'cacerts':  getattr(args, 'cacerts', None),
1141            'insecure': getattr(args, 'insecure', None),
1142        }
1143        if args.iftype == ClientHandler.IFTYPE:
1144            if not args.target.startswith('ws://') and \
1145               not args.target.startswith('wss://'):
1146                raise ValueError('Invalid target URL scheme: %s' % args.target)
1147            if not opts['user'] and opts['passwd']:
1148                raise ValueError('Authentication required but username empty')
1149            if opts['user'] and not opts['passwd']:
1150                opts['passwd'] = getpass.getpass('Client Password: ')
1151        result = request(args, 'addPort', {
1152            'type':    args.iftype,
1153            'target':  args.target,
1154            'options': opts,
1155        })
1156        if 'error' in result:
1157            _print_error(result['error'])
1158        else:
1159            print_portlist(result['result']['entries'])
1160
1161    def handle_ctl_setport(args):
1162        if args.port <= 0:
1163            raise ValueError('Invalid port: %d' % args.port)
1164        req = {'port': args.port}
1165        shut = getattr(args, 'shut', None)
1166        if shut is not None:
1167            req['shut'] = bool(shut)
1168        result = request(args, 'setPort', req)
1169        if 'error' in result:
1170            _print_error(result['error'])
1171        else:
1172            print_portlist(result['result']['entries'])
1173
1174    def handle_ctl_delport(args):
1175        if args.port <= 0:
1176            raise ValueError('Invalid port: %d' % args.port)
1177        result = request(args, 'delPort', {'port': args.port})
1178        if 'error' in result:
1179            _print_error(result['error'])
1180        else:
1181            print_portlist(result['result']['entries'])
1182
1183    def handle_ctl_listport(args):
1184        result = request(args, 'listPort')
1185        if 'error' in result:
1186            _print_error(result['error'])
1187        else:
1188            print_portlist(result['result']['entries'])
1189
1190    def handle_ctl_setif(args):
1191        if args.port <= 0:
1192            raise ValueError('Invalid port: %d' % args.port)
1193        req = {'port': args.port}
1194        address = getattr(args, 'address', None)
1195        netmask = getattr(args, 'netmask', None)
1196        mtu = getattr(args, 'mtu', None)
1197        if address is not None:
1198            if address:
1199                socket.inet_aton(address)  # validate
1200            req['address'] = address
1201        if netmask is not None:
1202            if netmask:
1203                socket.inet_aton(netmask)  # validate
1204            req['netmask'] = netmask
1205        if mtu is not None:
1206            if mtu < 576:
1207                raise ValueError('Invalid MTU: %d' % mtu)
1208            req['mtu'] = mtu
1209        result = request(args, 'setInterface', req)
1210        if 'error' in result:
1211            _print_error(result['error'])
1212        else:
1213            print_iflist(result['result']['entries'])
1214
1215    def handle_ctl_listif(args):
1216        result = request(args, 'listInterface')
1217        if 'error' in result:
1218            _print_error(result['error'])
1219        else:
1220            print_iflist(result['result']['entries'])
1221
1222    def handle_ctl_listfdb(args):
1223        result = request(args, 'listFdb')
1224        if 'error' in result:
1225            return _print_error(result['error'])
1226        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1227        for r in result['result']['entries']:
1228            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1229        print_table(rows)
1230
1231    locals()['handle_ctl_' + args.control_method](args)
1232
1233
1234def _main():
1235    parser = argparse.ArgumentParser()
1236    subcommand = parser.add_subparsers(dest='subcommand')
1237
1238    # - sw
1239    parser_sw = subcommand.add_parser('sw',
1240                                      help='start virtual switch')
1241
1242    parser_sw.add_argument('--debug', action='store_true', default=False,
1243                           help='run as debug mode')
1244    parser_sw.add_argument('--logconf',
1245                           help='path to logging config')
1246    parser_sw.add_argument('--foreground', action='store_true', default=False,
1247                           help='run as foreground mode')
1248    parser_sw.add_argument('--ageout', type=int, default=300,
1249                           help='FDB ageout time (sec)')
1250
1251    parser_sw.add_argument('--path', default='/',
1252                           help='http(s) path to serve WebSocket')
1253    parser_sw.add_argument('--host', default='',
1254                           help='listen address to serve WebSocket')
1255    parser_sw.add_argument('--port', type=int,
1256                           help='listen port to serve WebSocket')
1257    parser_sw.add_argument('--htpasswd',
1258                           help='path to htpasswd file to auth WebSocket')
1259    parser_sw.add_argument('--sslkey',
1260                           help='path to SSL private key for WebSocket')
1261    parser_sw.add_argument('--sslcert',
1262                           help='path to SSL certificate for WebSocket')
1263
1264    parser_sw.add_argument('--ctlpath', default='/ctl',
1265                           help='http(s) path to serve control API')
1266    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1267                           help='listen address to serve control API')
1268    parser_sw.add_argument('--ctlport', type=int, default=7867,
1269                           help='listen port to serve control API')
1270    parser_sw.add_argument('--ctlhtpasswd',
1271                           help='path to htpasswd file to auth control API')
1272    parser_sw.add_argument('--ctlsslkey',
1273                           help='path to SSL private key for control API')
1274    parser_sw.add_argument('--ctlsslcert',
1275                           help='path to SSL certificate for control API')
1276
1277    parser_sw.add_argument('--noorigin', action='store_true', default=False,
1278                           help='do not check origin header')
1279
1280    # - ctl
1281    parser_ctl = subcommand.add_parser('ctl',
1282                                       help='control virtual switch')
1283    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1284                            help='URL to control API')
1285    parser_ctl.add_argument('--ctluser',
1286                            help='username to auth control API')
1287    parser_ctl.add_argument('--ctlpasswd',
1288                            help='password to auth control API')
1289    parser_ctl.add_argument('--ctlsslcert',
1290                            help='path to SSL certificate for control API')
1291    parser_ctl.add_argument(
1292        '--ctlinsecure', action='store_true', default=False,
1293        help='do not verify control API certificate')
1294
1295    control_method = parser_ctl.add_subparsers(dest='control_method')
1296
1297    # -- ctl addport
1298    parser_ctl_addport = control_method.add_parser('addport',
1299                                                   help='create and add port')
1300    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
1301
1302    # --- ctl addport netdev
1303    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
1304                                                  help='Network device')
1305    parser_ctl_addport_netdev.add_argument('target',
1306                                           help='device name to add interface')
1307
1308    # --- ctl addport tap
1309    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1310                                               help='TAP device')
1311    parser_ctl_addport_tap.add_argument('target',
1312                                        help='device name to create interface')
1313
1314    # --- ctl addport client
1315    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
1316                                                  help='WebSocket client')
1317    parser_ctl_addport_client.add_argument('target',
1318                                           help='URL to connect WebSocket')
1319    parser_ctl_addport_client.add_argument('--user',
1320                                           help='username to auth WebSocket')
1321    parser_ctl_addport_client.add_argument('--passwd',
1322                                           help='password to auth WebSocket')
1323    parser_ctl_addport_client.add_argument('--cacerts',
1324                                           help='path to CA certificate')
1325    parser_ctl_addport_client.add_argument(
1326        '--insecure', action='store_true', default=False,
1327        help='do not verify server certificate')
1328
1329    # -- ctl setport
1330    parser_ctl_setport = control_method.add_parser('setport',
1331                                                   help='set port config')
1332    parser_ctl_setport.add_argument('port', type=int,
1333                                    help='port number to set config')
1334    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1335                                    help='set shutdown state')
1336
1337    # -- ctl delport
1338    parser_ctl_delport = control_method.add_parser('delport',
1339                                                   help='delete port')
1340    parser_ctl_delport.add_argument('port', type=int,
1341                                    help='port number to delete')
1342
1343    # -- ctl listport
1344    parser_ctl_listport = control_method.add_parser('listport',
1345                                                    help='show port list')
1346
1347    # -- ctl setif
1348    parser_ctl_setif = control_method.add_parser('setif',
1349                                                 help='set interface config')
1350    parser_ctl_setif.add_argument('port', type=int,
1351                                  help='port number to set config')
1352    parser_ctl_setif.add_argument('--address',
1353                                  help='IPv4 address to set interface')
1354    parser_ctl_setif.add_argument('--netmask',
1355                                  help='IPv4 netmask to set interface')
1356    parser_ctl_setif.add_argument('--mtu', type=int,
1357                                  help='MTU to set interface')
1358
1359    # -- ctl listif
1360    parser_ctl_listif = control_method.add_parser('listif',
1361                                                  help='show interface list')
1362
1363    # -- ctl listfdb
1364    parser_ctl_listfdb = control_method.add_parser('listfdb',
1365                                                   help='show FDB entries')
1366
1367    # -- go
1368    args = parser.parse_args()
1369
1370    try:
1371        globals()['_start_' + args.subcommand](args)
1372    except Exception as e:
1373        _print_error({
1374            'code':    0 - 32603,
1375            'message': 'Internal error',
1376            'data':    '%s: %s' % (type(e).__name__, str(e)),
1377        })
1378
1379
1380if __name__ == '__main__':
1381    _main()
Note: See TracBrowser for help on using the repository browser.