source: etherws/trunk/etherws.py @ 278

Revision 278, 43.5 KB checked in by atzm, 9 years ago (diff)
  • fix comment
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[186]4#                          Ethernet over WebSocket
[133]5#
6# depends on:
[265]7#   - python-2.7.5
8#   - python-pytun-2.1
[278]9#   - websocket-client-0.14.0
[265]10#   - tornado-2.4
[133]11#
12# ===========================================================================
[276]13# Copyright (c) 2012-2015, Atzm WATANABE <atzm@atzm.org>
[133]14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
[156]42import ssl
[160]43import time
[183]44import json
[175]45import fcntl
[150]46import base64
[212]47import socket
[256]48import struct
[190]49import urllib2
[150]50import hashlib
[151]51import getpass
[266]52import operator
[133]53import argparse
[165]54import traceback
[133]55
[185]56import tornado
[133]57import websocket
58
[183]59from tornado.web import Application, RequestHandler
[182]60from tornado.websocket import WebSocketHandler
[183]61from tornado.httpserver import HTTPServer
62from tornado.ioloop import IOLoop
63
[166]64from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]65
[166]66
[160]67class DebugMixIn(object):
[166]68    def dprintf(self, msg, func=lambda: ()):
[160]69        if self._debug:
70            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]71            sys.stderr.write(prefix + (msg % func()))
[160]72
73
[164]74class EthernetFrame(object):
75    def __init__(self, data):
76        self.data = data
77
[176]78    @property
79    def dst_multicast(self):
80        return ord(self.data[0]) & 1
[164]81
82    @property
[176]83    def src_multicast(self):
84        return ord(self.data[6]) & 1
85
86    @property
[164]87    def dst_mac(self):
88        return self.data[:6]
89
90    @property
91    def src_mac(self):
92        return self.data[6:12]
93
94    @property
95    def tagged(self):
96        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
97
98    @property
99    def vid(self):
100        if self.tagged:
101            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]102        return 0
[164]103
[198]104    @staticmethod
105    def format_mac(mac, sep=':'):
106        return sep.join(b.encode('hex') for b in mac)
[164]107
[198]108
[166]109class FDB(DebugMixIn):
[196]110    class Entry(object):
[195]111        def __init__(self, port, ageout):
112            self.port = port
113            self._time = time.time()
114            self._ageout = ageout
115
116        @property
117        def age(self):
118            return time.time() - self._time
119
120        @property
121        def agedout(self):
122            return self.age > self._ageout
123
[253]124    def __init__(self, ageout, debug):
[164]125        self._ageout = ageout
126        self._debug = debug
[195]127        self._table = {}
[164]128
[195]129    def _set_entry(self, vid, mac, port):
130        if vid not in self._table:
131            self._table[vid] = {}
[196]132        self._table[vid][mac] = self.Entry(port, self._ageout)
[164]133
[195]134    def _del_entry(self, vid, mac):
135        if vid in self._table:
136            if mac in self._table[vid]:
137                del self._table[vid][mac]
138            if not self._table[vid]:
139                del self._table[vid]
[164]140
[207]141    def _get_entry(self, vid, mac):
[195]142        try:
143            entry = self._table[vid][mac]
144        except KeyError:
[164]145            return None
146
[195]147        if not entry.agedout:
148            return entry
[164]149
[195]150        self._del_entry(vid, mac)
151        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
152                     lambda: (entry.port.number, vid, mac.encode('hex')))
[164]153
[208]154    def each(self):
[207]155        for vid in sorted(self._table.iterkeys()):
156            for mac in sorted(self._table[vid].iterkeys()):
157                entry = self._get_entry(vid, mac)
158                if entry:
159                    yield (vid, mac, entry)
[195]160
161    def lookup(self, frame):
162        mac = frame.dst_mac
163        vid = frame.vid
[207]164        entry = self._get_entry(vid, mac)
[195]165        return getattr(entry, 'port', None)
166
[166]167    def learn(self, port, frame):
168        mac = frame.src_mac
169        vid = frame.vid
[195]170        self._set_entry(vid, mac, port)
[183]171        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
172                     lambda: (port.number, vid, mac.encode('hex')))
[166]173
[164]174    def delete(self, port):
[208]175        for vid, mac, entry in self.each():
[207]176            if entry.port.number == port.number:
177                self._del_entry(vid, mac)
178                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
179                             lambda: (port.number, vid, mac.encode('hex')))
[164]180
181
[195]182class SwitchingHub(DebugMixIn):
183    class Port(object):
184        def __init__(self, number, interface):
185            self.number = number
186            self.interface = interface
187            self.tx = 0
188            self.rx = 0
189            self.shut = False
[183]190
[253]191    def __init__(self, fdb, debug):
[197]192        self.fdb = fdb
[133]193        self._debug = debug
[183]194        self._table = {}
195        self._next = 1
[133]196
[183]197    @property
198    def portlist(self):
[266]199        return sorted(self._table.itervalues(),
200                      key=operator.attrgetter('number'))
[133]201
[183]202    def get_port(self, portnum):
203        return self._table[portnum]
204
205    def register_port(self, interface):
[186]206        try:
[187]207            self._set_privattr('portnum', interface, self._next)  # XXX
[195]208            self._table[self._next] = self.Port(self._next, interface)
[186]209            return self._next
210        finally:
211            self._next += 1
[183]212
213    def unregister_port(self, interface):
[187]214        portnum = self._get_privattr('portnum', interface)
215        self._del_privattr('portnum', interface)
[197]216        self.fdb.delete(self._table[portnum])
[187]217        del self._table[portnum]
[183]218
219    def send(self, dst_interfaces, frame):
[187]220        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
221        ports = (self._table[n] for n in portnums)
222        ports = (p for p in ports if not p.shut)
[266]223        ports = sorted(ports, key=operator.attrgetter('number'))
[183]224
225        for p in ports:
226            p.interface.write_message(frame.data, True)
227            p.tx += 1
228
229        if ports:
230            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
231                         lambda: (','.join(str(p.number) for p in ports),
232                                  frame.vid,
233                                  frame.src_mac.encode('hex'),
234                                  frame.dst_mac.encode('hex')))
235
236    def receive(self, src_interface, frame):
[187]237        port = self._table[self._get_privattr('portnum', src_interface)]
[183]238
239        if not port.shut:
240            port.rx += 1
241            self._forward(port, frame)
242
243    def _forward(self, src_port, frame):
[166]244        try:
[176]245            if not frame.src_multicast:
[197]246                self.fdb.learn(src_port, frame)
[133]247
[176]248            if not frame.dst_multicast:
[197]249                dst_port = self.fdb.lookup(frame)
[164]250
[166]251                if dst_port:
[183]252                    self.send([dst_port.interface], frame)
[166]253                    return
[133]254
[187]255            ports = set(self.portlist) - set([src_port])
[183]256            self.send((p.interface for p in ports), frame)
[162]257
[166]258        except:  # ex. received invalid frame
259            traceback.print_exc()
[133]260
[187]261    def _privattr(self, name):
262        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
[164]263
[187]264    def _set_privattr(self, name, obj, value):
265        return setattr(obj, self._privattr(name), value)
266
267    def _get_privattr(self, name, obj, defaults=None):
268        return getattr(obj, self._privattr(name), defaults)
269
270    def _del_privattr(self, name, obj):
271        return delattr(obj, self._privattr(name))
272
273
[256]274class NetworkInterface(object):
275    SIOCGIFADDR = 0x8915  # from <linux/sockios.h>
276    SIOCSIFADDR = 0x8916
277    SIOCGIFNETMASK = 0x891b
278    SIOCSIFNETMASK = 0x891c
279    SIOCGIFMTU = 0x8921
280    SIOCSIFMTU = 0x8922
281
282    def __init__(self, ifname):
283        self._ifname = struct.pack('16s', str(ifname)[:15])
284
285    def _ioctl(self, req, data):
286        try:
287            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
288            return fcntl.ioctl(sock.fileno(), req, data)
289        finally:
290            sock.close()
291
292    @property
293    def address(self):
294        try:
295            data = struct.pack('16s18x', self._ifname)
296            ret = self._ioctl(self.SIOCGIFADDR, data)
297            return socket.inet_ntoa(ret[20:24])
298        except IOError:
299            return ''
300
301    @property
302    def netmask(self):
303        try:
304            data = struct.pack('16s18x', self._ifname)
305            ret = self._ioctl(self.SIOCGIFNETMASK, data)
306            return socket.inet_ntoa(ret[20:24])
307        except IOError:
308            return ''
309
310    @property
311    def mtu(self):
312        try:
313            data = struct.pack('16s18x', self._ifname)
314            ret = self._ioctl(self.SIOCGIFMTU, data)
315            return struct.unpack('i', ret[16:20])[0]
316        except IOError:
317            return 0
318
319    @address.setter
320    def address(self, addr):
321        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
322                           socket.inet_aton(addr))
323        self._ioctl(self.SIOCSIFADDR, data)
324
325    @netmask.setter
326    def netmask(self, addr):
327        data = struct.pack('16si4s10x', self._ifname, socket.AF_INET,
328                           socket.inet_aton(addr))
329        self._ioctl(self.SIOCSIFNETMASK, data)
330
331    @mtu.setter
332    def mtu(self, mtu):
333        data = struct.pack('16si12x', self._ifname, mtu)
334        self._ioctl(self.SIOCSIFMTU, data)
335
336
[179]337class Htpasswd(object):
338    def __init__(self, path):
339        self._path = path
340        self._stat = None
341        self._data = {}
342
343    def auth(self, name, passwd):
344        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
345        return self._data.get(name) == passwd
346
347    def load(self):
348        old_stat = self._stat
349
350        with open(self._path) as fp:
351            fileno = fp.fileno()
352            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
353            self._stat = os.fstat(fileno)
354
355            unchanged = old_stat and \
[267]356                old_stat.st_ino == self._stat.st_ino and \
357                old_stat.st_dev == self._stat.st_dev and \
358                old_stat.st_mtime == self._stat.st_mtime
[179]359
360            if not unchanged:
361                self._data = self._parse(fp)
362
363        return self
364
365    def _parse(self, fp):
366        data = {}
367        for line in fp:
368            line = line.strip()
369            if 0 <= line.find(':'):
370                name, passwd = line.split(':', 1)
371                if passwd.startswith('{SHA}'):
372                    data[name] = passwd[5:]
373        return data
374
375
[182]376class BasicAuthMixIn(object):
377    def _execute(self, transforms, *args, **kwargs):
378        def do_execute():
379            sp = super(BasicAuthMixIn, self)
380            return sp._execute(transforms, *args, **kwargs)
381
382        def auth_required():
[185]383            stream = getattr(self, 'stream', self.request.connection.stream)
384            stream.write(tornado.escape.utf8(
[182]385                'HTTP/1.1 401 Authorization Required\r\n'
386                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
387            ))
[185]388            stream.close()
[182]389
390        try:
391            if not self._htpasswd:
392                return do_execute()
393
394            creds = self.request.headers.get('Authorization')
395
396            if not creds or not creds.startswith('Basic '):
397                return auth_required()
398
399            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
400
401            if self._htpasswd.load().auth(name, passwd):
402                return do_execute()
403        except:
404            traceback.print_exc()
405
406        return auth_required()
407
408
[253]409class ServerHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[191]410    IFTYPE = 'server'
[253]411    IFOP_ALLOWED = False
[191]412
[253]413    def __init__(self, app, req, switch, htpasswd, debug):
414        super(ServerHandler, self).__init__(app, req)
[186]415        self._switch = switch
416        self._htpasswd = htpasswd
417        self._debug = debug
418
[203]419    @property
420    def target(self):
[276]421        conn = self.request.connection
422        if hasattr(conn, 'address'):
423            return ':'.join(str(e) for e in conn.address)
424        if hasattr(conn, 'context'):
425            return ':'.join(str(e) for e in conn.context.address)
426        return str(conn)
[186]427
428    def open(self):
429        try:
430            return self._switch.register_port(self)
431        finally:
432            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
433
434    def on_message(self, message):
435        self._switch.receive(self, EthernetFrame(message))
436
437    def on_close(self):
438        self._switch.unregister_port(self)
439        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
440
441
[253]442class BaseClientHandler(DebugMixIn):
443    IFTYPE = 'baseclient'
444    IFOP_ALLOWED = False
[251]445
[253]446    def __init__(self, ioloop, switch, target, debug, *args, **kwargs):
[251]447        self._ioloop = ioloop
448        self._switch = switch
[253]449        self._target = target
[251]450        self._debug = debug
[253]451        self._args = args
452        self._kwargs = kwargs
453        self._device = None
[251]454
455    @property
[253]456    def address(self):
457        raise NotImplementedError('unsupported')
458
459    @property
460    def netmask(self):
461        raise NotImplementedError('unsupported')
462
463    @property
464    def mtu(self):
465        raise NotImplementedError('unsupported')
466
467    @address.setter
468    def address(self, address):
469        raise NotImplementedError('unsupported')
470
471    @netmask.setter
472    def netmask(self, netmask):
473        raise NotImplementedError('unsupported')
474
475    @mtu.setter
476    def mtu(self, mtu):
477        raise NotImplementedError('unsupported')
478
479    def open(self):
480        raise NotImplementedError('unsupported')
481
482    def write_message(self, message, binary=False):
483        raise NotImplementedError('unsupported')
484
485    def read(self):
486        raise NotImplementedError('unsupported')
487
488    @property
[251]489    def target(self):
[253]490        return self._target
[251]491
492    @property
[253]493    def device(self):
494        return self._device
495
496    @property
[251]497    def closed(self):
[253]498        return not self.device
[251]499
[253]500    def close(self):
501        if self.closed:
502            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
503        self.leave_switch()
504        self.unregister_device()
505        self.dprintf('disconnected: %s\n', lambda: self.target)
506
507    def register_device(self, device):
508        self._device = device
509
510    def unregister_device(self):
511        self._device.close()
512        self._device = None
513
514    def fileno(self):
515        if self.closed:
516            raise ValueError('I/O operation on closed %s' % self.IFTYPE)
517        return self.device.fileno()
518
519    def __call__(self, fd, events):
520        try:
521            data = self.read()
522            if data is not None:
523                self._switch.receive(self, EthernetFrame(data))
524                return
525        except:
526            traceback.print_exc()
527        self.close()
528
529    def join_switch(self):
530        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
531        return self._switch.register_port(self)
532
533    def leave_switch(self):
534        self._switch.unregister_port(self)
535        self._ioloop.remove_handler(self.fileno())
536
537
538class NetdevHandler(BaseClientHandler):
539    IFTYPE = 'netdev'
540    IFOP_ALLOWED = True
541    ETH_P_ALL = 0x0003  # from <linux/if_ether.h>
542
[252]543    @property
544    def address(self):
545        if self.closed:
546            raise ValueError('I/O operation on closed netdev')
[256]547        return NetworkInterface(self.target).address
[252]548
549    @property
550    def netmask(self):
551        if self.closed:
552            raise ValueError('I/O operation on closed netdev')
[256]553        return NetworkInterface(self.target).netmask
[252]554
555    @property
556    def mtu(self):
557        if self.closed:
558            raise ValueError('I/O operation on closed netdev')
[256]559        return NetworkInterface(self.target).mtu
[252]560
561    @address.setter
562    def address(self, address):
563        if self.closed:
564            raise ValueError('I/O operation on closed netdev')
[256]565        NetworkInterface(self.target).address = address
[252]566
567    @netmask.setter
568    def netmask(self, netmask):
569        if self.closed:
570            raise ValueError('I/O operation on closed netdev')
[256]571        NetworkInterface(self.target).netmask = netmask
[252]572
573    @mtu.setter
574    def mtu(self, mtu):
575        if self.closed:
576            raise ValueError('I/O operation on closed netdev')
[256]577        NetworkInterface(self.target).mtu = mtu
[252]578
[251]579    def open(self):
580        if not self.closed:
581            raise ValueError('Already opened')
[253]582        self.register_device(socket.socket(
583            socket.PF_PACKET, socket.SOCK_RAW, socket.htons(self.ETH_P_ALL)))
584        self.device.bind((self.target, self.ETH_P_ALL))
[254]585        self.dprintf('connected: %s\n', lambda: self.target)
[253]586        return self.join_switch()
[251]587
588    def write_message(self, message, binary=False):
589        if self.closed:
590            raise ValueError('I/O operation on closed netdev')
[253]591        self.device.sendall(message)
[251]592
[253]593    def read(self):
[251]594        if self.closed:
595            raise ValueError('I/O operation on closed netdev')
596        buf = []
597        while True:
[253]598            buf.append(self.device.recv(65535))
599            if len(buf[-1]) < 65535:
[251]600                break
601        return ''.join(buf)
602
603
[253]604class TapHandler(BaseClientHandler):
[191]605    IFTYPE = 'tap'
[253]606    IFOP_ALLOWED = True
[166]607
[203]608    @property
[212]609    def address(self):
610        if self.closed:
611            raise ValueError('I/O operation on closed tap')
612        try:
[253]613            return self.device.addr
[212]614        except:
615            return ''
616
617    @property
618    def netmask(self):
619        if self.closed:
620            raise ValueError('I/O operation on closed tap')
621        try:
[253]622            return self.device.netmask
[212]623        except:
624            return ''
625
626    @property
627    def mtu(self):
628        if self.closed:
629            raise ValueError('I/O operation on closed tap')
[253]630        return self.device.mtu
[212]631
632    @address.setter
633    def address(self, address):
634        if self.closed:
635            raise ValueError('I/O operation on closed tap')
[253]636        self.device.addr = address
[212]637
638    @netmask.setter
639    def netmask(self, netmask):
640        if self.closed:
641            raise ValueError('I/O operation on closed tap')
[253]642        self.device.netmask = netmask
[212]643
644    @mtu.setter
645    def mtu(self, mtu):
646        if self.closed:
647            raise ValueError('I/O operation on closed tap')
[253]648        self.device.mtu = mtu
[212]649
[253]650    @property
651    def target(self):
652        if self.closed:
653            return self._target
654        return self.device.name
655
[166]656    def open(self):
657        if not self.closed:
[202]658            raise ValueError('Already opened')
[253]659        self.register_device(TunTapDevice(self.target, IFF_TAP | IFF_NO_PI))
660        self.device.up()
[254]661        self.dprintf('connected: %s\n', lambda: self.target)
[253]662        return self.join_switch()
[166]663
664    def write_message(self, message, binary=False):
665        if self.closed:
666            raise ValueError('I/O operation on closed tap')
[253]667        self.device.write(message)
[166]668
[253]669    def read(self):
[166]670        if self.closed:
671            raise ValueError('I/O operation on closed tap')
[162]672        buf = []
673        while True:
[253]674            buf.append(self.device.read(65535))
675            if len(buf[-1]) < 65535:
[162]676                break
[166]677        return ''.join(buf)
[162]678
679
[253]680class ClientHandler(BaseClientHandler):
[191]681    IFTYPE = 'client'
[253]682    IFOP_ALLOWED = False
[191]683
[253]684    def __init__(self, *args, **kwargs):
685        super(ClientHandler, self).__init__(*args, **kwargs)
686
[265]687        self._sslopt = kwargs.get('sslopt', {})
[277]688        self._options = self._getproxy()
[151]689
[253]690        cred = kwargs.get('cred', None)
691
[174]692        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
693            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]694            auth = ['Authorization: Basic %s' % token]
695            self._options['header'] = auth
696
697    def open(self):
[160]698        if not self.closed:
[202]699            raise websocket.WebSocketException('Already opened')
[151]700
[251]701        # XXX: may be blocked
[265]702        self.register_device(websocket.WebSocket(sslopt=self._sslopt))
703        self.device.connect(self.target, **self._options)
704        self.dprintf('connected: %s\n', lambda: self.target)
705        return self.join_switch()
[181]706
[151]707    def write_message(self, message, binary=False):
[160]708        if self.closed:
[202]709            raise websocket.WebSocketException('Closed socket')
[151]710        if binary:
711            flag = websocket.ABNF.OPCODE_BINARY
[160]712        else:
713            flag = websocket.ABNF.OPCODE_TEXT
[253]714        self.device.send(message, flag)
[151]715
[253]716    def read(self):
717        if self.closed:
718            raise websocket.WebSocketException('Closed socket')
719        return self.device.recv()
[151]720
[277]721    def _getproxy(self):
722        scheme, remain = urllib2.splittype(self.target)
[151]723
[277]724        proxy = urllib2.getproxies().get(scheme.replace('ws', 'http'))
725        if not proxy:
726            return {}
727
728        host = urllib2.splitport(urllib2.splithost(remain)[0])[0]
729        if urllib2.proxy_bypass(host):
730            return {}
731
732        hostport = urllib2.splithost(urllib2.splittype(proxy)[1])[0]
733        host, port = urllib2.splitport(hostport)
734        port = int(port) if port else 0
735        return {'http_proxy_host': host, 'http_proxy_port': port}
736
737
[253]738class ControlServerHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
[183]739    NAMESPACE = 'etherws.control'
[191]740    IFTYPES = {
[253]741        NetdevHandler.IFTYPE: NetdevHandler,
742        TapHandler.IFTYPE:    TapHandler,
743        ClientHandler.IFTYPE: ClientHandler,
[186]744    }
[183]745
[253]746    def __init__(self, app, req, ioloop, switch, htpasswd, debug):
747        super(ControlServerHandler, self).__init__(app, req)
[183]748        self._ioloop = ioloop
749        self._switch = switch
750        self._htpasswd = htpasswd
751        self._debug = debug
752
753    def post(self):
[202]754        try:
755            request = json.loads(self.request.body)
756        except Exception as e:
757            return self._jsonrpc_response(error={
758                'code':    0 - 32700,
759                'message': 'Parse error',
760                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
761            })
[183]762
763        try:
[202]764            id_ = request.get('id')
765            params = request.get('params')
766            version = request['jsonrpc']
767            method = request['method']
768            if version != '2.0':
769                raise ValueError('Invalid JSON-RPC version: %s' % version)
770        except Exception as e:
771            return self._jsonrpc_response(id_=id_, error={
772                'code':    0 - 32600,
773                'message': 'Invalid Request',
774                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
775            })
[183]776
[202]777        try:
[183]778            if not method.startswith(self.NAMESPACE + '.'):
[202]779                raise ValueError('Invalid method namespace: %s' % method)
[183]780            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
[202]781            handler = getattr(self, handler)
782        except Exception as e:
783            return self._jsonrpc_response(id_=id_, error={
784                'code':    0 - 32601,
785                'message': 'Method not found',
786                'data':    '%s: %s' % (e.__class__.__name__, str(e)),
787            })
[183]788
[202]789        try:
790            return self._jsonrpc_response(id_=id_, result=handler(params))
[183]791        except Exception as e:
792            traceback.print_exc()
[202]793            return self._jsonrpc_response(id_=id_, error={
794                'code':    0 - 32602,
795                'message': 'Invalid params',
796                'data':     '%s: %s' % (e.__class__.__name__, str(e)),
797            })
[183]798
[198]799    def handle_listFdb(self, params):
800        list_ = []
[208]801        for vid, mac, entry in self._switch.fdb.each():
[207]802            list_.append({
803                'vid':  vid,
804                'mac':  EthernetFrame.format_mac(mac),
805                'port': entry.port.number,
806                'age':  int(entry.age),
807            })
[199]808        return {'entries': list_}
[198]809
[183]810    def handle_listPort(self, params):
[202]811        return {'entries': [self._portstat(p) for p in self._switch.portlist]}
[183]812
813    def handle_addPort(self, params):
[202]814        type_ = params['type']
815        target = params['target']
[253]816        opt = getattr(self, '_optparse_' + type_)(params.get('options', {}))
[202]817        cls = self.IFTYPES[type_]
[253]818        interface = cls(self._ioloop, self._switch, target, self._debug, **opt)
[202]819        portnum = interface.open()
820        return {'entries': [self._portstat(self._switch.get_port(portnum))]}
[183]821
[211]822    def handle_setPort(self, params):
[202]823        port = self._switch.get_port(int(params['port']))
[211]824        shut = params.get('shut')
825        if shut is not None:
826            port.shut = bool(shut)
[202]827        return {'entries': [self._portstat(port)]}
[183]828
[211]829    def handle_delPort(self, params):
[202]830        port = self._switch.get_port(int(params['port']))
[211]831        port.interface.close()
[202]832        return {'entries': [self._portstat(port)]}
[183]833
[212]834    def handle_setInterface(self, params):
835        portnum = int(params['port'])
836        port = self._switch.get_port(portnum)
837        address = params.get('address')
838        netmask = params.get('netmask')
839        mtu = params.get('mtu')
[253]840        if not port.interface.IFOP_ALLOWED:
[212]841            raise ValueError('Port %d has unsupported interface: %s' %
842                             (portnum, port.interface.IFTYPE))
843        if address is not None:
844            port.interface.address = address
845        if netmask is not None:
846            port.interface.netmask = netmask
847        if mtu is not None:
848            port.interface.mtu = mtu
849        return {'entries': [self._ifstat(port)]}
850
851    def handle_listInterface(self, params):
852        return {'entries': [self._ifstat(p) for p in self._switch.portlist
[253]853                            if p.interface.IFOP_ALLOWED]}
[212]854
[251]855    def _optparse_netdev(self, opt):
[253]856        return {}
[251]857
[186]858    def _optparse_tap(self, opt):
[253]859        return {}
[183]860
[186]861    def _optparse_client(self, opt):
862        if opt.get('insecure'):
[273]863            sslopt = {'cert_reqs': ssl.CERT_NONE}
[265]864        else:
865            sslopt = {'cert_reqs': ssl.CERT_REQUIRED,
866                      'ca_certs':  opt.get('cacerts')}
[186]867        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
[265]868        return {'sslopt': sslopt, 'cred': cred}
[183]869
[202]870    def _jsonrpc_response(self, id_=None, result=None, error=None):
871        res = {'jsonrpc': '2.0', 'id': id_}
872        if result:
873            res['result'] = result
874        if error:
875            res['error'] = error
876        self.finish(res)
877
[183]878    @staticmethod
[186]879    def _portstat(port):
880        return {
881            'port':   port.number,
[191]882            'type':   port.interface.IFTYPE,
[203]883            'target': port.interface.target,
[186]884            'tx':     port.tx,
885            'rx':     port.rx,
886            'shut':   port.shut,
887        }
[183]888
[212]889    @staticmethod
890    def _ifstat(port):
891        return {
892            'port':    port.number,
893            'type':    port.interface.IFTYPE,
894            'target':  port.interface.target,
895            'address': port.interface.address,
896            'netmask': port.interface.netmask,
897            'mtu':     port.interface.mtu,
898        }
[183]899
[212]900
[206]901def _print_error(error):
[205]902    print(%s (%s)' % (error['message'], error['code']))
903    print('    %s' % error['data'])
904
905
[206]906def _start_sw(args):
[186]907    def daemonize(nochdir=False, noclose=False):
908        if os.fork() > 0:
909            sys.exit(0)
[134]910
[186]911        os.setsid()
[134]912
[186]913        if os.fork() > 0:
914            sys.exit(0)
[134]915
[186]916        if not nochdir:
917            os.chdir('/')
[134]918
[186]919        if not noclose:
920            os.umask(0)
921            sys.stdin.close()
922            sys.stdout.close()
923            sys.stderr.close()
924            os.close(0)
925            os.close(1)
926            os.close(2)
927            sys.stdin = open(os.devnull)
928            sys.stdout = open(os.devnull, 'a')
929            sys.stderr = open(os.devnull, 'a')
[134]930
[186]931    def checkabspath(ns, path):
[184]932        val = getattr(ns, path, '')
933        if not val.startswith('/'):
[202]934            raise ValueError('Invalid %: %s' % (path, val))
[184]935
936    def getsslopt(ns, key, cert):
937        kval = getattr(ns, key, None)
938        cval = getattr(ns, cert, None)
939        if kval and cval:
940            return {'keyfile': kval, 'certfile': cval}
941        elif kval or cval:
[202]942            raise ValueError('Both %s and %s are required' % (key, cert))
[184]943        return None
944
[186]945    def setrealpath(ns, *keys):
946        for k in keys:
947            v = getattr(ns, k, None)
948            if v is not None:
949                v = os.path.realpath(v)
950                open(v).close()  # check readable
951                setattr(ns, k, v)
952
[184]953    def setport(ns, port, isssl):
954        val = getattr(ns, port, None)
955        if val is None:
956            if isssl:
957                return setattr(ns, port, 443)
958            return setattr(ns, port, 80)
959        if not (0 <= val <= 65535):
[202]960            raise ValueError('Invalid %s: %s' % (port, val))
[184]961
962    def sethtpasswd(ns, htpasswd):
963        val = getattr(ns, htpasswd, None)
964        if val:
965            return setattr(ns, htpasswd, Htpasswd(val))
966
[272]967    # if args.debug:
968    #     websocket.enableTrace(True)
[183]969
970    if args.ageout <= 0:
[202]971        raise ValueError('Invalid ageout: %s' % args.ageout)
[183]972
[186]973    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
974    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]975
[186]976    checkabspath(args, 'path')
977    checkabspath(args, 'ctlpath')
[183]978
[184]979    sslopt = getsslopt(args, 'sslkey', 'sslcert')
980    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]981
[184]982    setport(args, 'port', sslopt)
983    setport(args, 'ctlport', ctlsslopt)
[143]984
[184]985    sethtpasswd(args, 'htpasswd')
986    sethtpasswd(args, 'ctlhtpasswd')
[167]987
[183]988    ioloop = IOLoop.instance()
[253]989    fdb = FDB(args.ageout, args.debug)
990    switch = SwitchingHub(fdb, args.debug)
[167]991
[184]992    if args.port == args.ctlport and args.host == args.ctlhost:
993        if args.path == args.ctlpath:
[202]994            raise ValueError('Same path/ctlpath on same host')
[184]995        if args.sslkey != args.ctlsslkey:
[202]996            raise ValueError('Different sslkey/ctlsslkey on same host')
[184]997        if args.sslcert != args.ctlsslcert:
[202]998            raise ValueError('Different sslcert/ctlsslcert on same host')
[133]999
[184]1000        app = Application([
[253]1001            (args.path, ServerHandler, {
[184]1002                'switch':   switch,
1003                'htpasswd': args.htpasswd,
1004                'debug':    args.debug,
1005            }),
[253]1006            (args.ctlpath, ControlServerHandler, {
[184]1007                'ioloop':   ioloop,
1008                'switch':   switch,
1009                'htpasswd': args.ctlhtpasswd,
1010                'debug':    args.debug,
1011            }),
1012        ])
1013        server = HTTPServer(app, ssl_options=sslopt)
1014        server.listen(args.port, address=args.host)
[151]1015
[184]1016    else:
[253]1017        app = Application([(args.path, ServerHandler, {
[184]1018            'switch':   switch,
1019            'htpasswd': args.htpasswd,
1020            'debug':    args.debug,
1021        })])
1022        server = HTTPServer(app, ssl_options=sslopt)
1023        server.listen(args.port, address=args.host)
1024
[253]1025        ctl = Application([(args.ctlpath, ControlServerHandler, {
[184]1026            'ioloop':   ioloop,
1027            'switch':   switch,
1028            'htpasswd': args.ctlhtpasswd,
1029            'debug':    args.debug,
1030        })])
1031        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
1032        ctlserver.listen(args.ctlport, address=args.ctlhost)
1033
[151]1034    if not args.foreground:
1035        daemonize()
1036
[138]1037    ioloop.start()
[133]1038
1039
[206]1040def _start_ctl(args):
[272]1041    def have_ssl_cert_verification():
1042        return 'context' in urllib2.urlopen.__code__.co_varnames
1043
[202]1044    def request(args, method, params=None, id_=0):
[190]1045        req = urllib2.Request(args.ctlurl)
1046        req.add_header('Content-type', 'application/json')
1047        if args.ctluser:
1048            if not args.ctlpasswd:
[209]1049                args.ctlpasswd = getpass.getpass('Control Password: ')
[190]1050            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
1051            req.add_header('Authorization', 'Basic %s' % token)
[253]1052        method = '.'.join([ControlServerHandler.NAMESPACE, method])
[202]1053        data = {'jsonrpc': '2.0', 'method': method, 'id': id_}
1054        if params is not None:
1055            data['params'] = params
[272]1056        if have_ssl_cert_verification():
1057            ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH,
1058                                             cafile=args.ctlsslcert)
1059            if args.ctlinsecure:
1060                ctx.check_hostname = False
1061                ctx.verify_mode = ssl.CERT_NONE
1062            fp = urllib2.urlopen(req, json.dumps(data), context=ctx)
1063        elif args.ctlsslcert:
1064            raise EnvironmentError('do not support certificate verification')
1065        else:
1066            fp = urllib2.urlopen(req, json.dumps(data))
1067        return json.loads(fp.read())
[190]1068
[255]1069    def print_table(rows):
1070        cols = zip(*rows)
1071        maxlen = [0] * len(cols)
1072        for i in xrange(len(cols)):
1073            maxlen[i] = max(len(str(c)) for c in cols[i])
1074        fmt = '  '.join(['%%-%ds' % maxlen[i] for i in xrange(len(cols))])
1075        fmt = '  ' + fmt
1076        for row in rows:
1077            print(fmt % tuple(row))
1078
[199]1079    def print_portlist(result):
[255]1080        rows = [['Port', 'Type', 'State', 'RX', 'TX', 'Target']]
[199]1081        for r in result:
[255]1082            rows.append([r['port'], r['type'], 'shut' if r['shut'] else 'up',
1083                         r['rx'], r['tx'], r['target']])
1084        print_table(rows)
[199]1085
[212]1086    def print_iflist(result):
[255]1087        rows = [['Port', 'Type', 'Address', 'Netmask', 'MTU', 'Target']]
[212]1088        for r in result:
[255]1089            rows.append([r['port'], r['type'], r['address'],
1090                         r['netmask'], r['mtu'], r['target']])
1091        print_table(rows)
[212]1092
[190]1093    def handle_ctl_addport(args):
[205]1094        opts = {
1095            'user':     getattr(args, 'user', None),
1096            'passwd':   getattr(args, 'passwd', None),
1097            'cacerts':  getattr(args, 'cacerts', None),
1098            'insecure': getattr(args, 'insecure', None),
1099        }
[253]1100        if args.iftype == ClientHandler.IFTYPE:
[205]1101            if not args.target.startswith('ws://') and \
1102               not args.target.startswith('wss://'):
1103                raise ValueError('Invalid target URL scheme: %s' % args.target)
[210]1104            if not opts['user'] and opts['passwd']:
1105                raise ValueError('Authentication required but username empty')
1106            if opts['user'] and not opts['passwd']:
1107                opts['passwd'] = getpass.getpass('Client Password: ')
[202]1108        result = request(args, 'addPort', {
[204]1109            'type':    args.iftype,
[191]1110            'target':  args.target,
[205]1111            'options': opts,
[202]1112        })
1113        if 'error' in result:
[206]1114            _print_error(result['error'])
[199]1115        else:
1116            print_portlist(result['result']['entries'])
[190]1117
[211]1118    def handle_ctl_setport(args):
[190]1119        if args.port <= 0:
[202]1120            raise ValueError('Invalid port: %d' % args.port)
[211]1121        req = {'port': args.port}
1122        shut = getattr(args, 'shut', None)
1123        if shut is not None:
1124            req['shut'] = bool(shut)
1125        result = request(args, 'setPort', req)
[202]1126        if 'error' in result:
[206]1127            _print_error(result['error'])
[199]1128        else:
1129            print_portlist(result['result']['entries'])
[190]1130
1131    def handle_ctl_delport(args):
1132        if args.port <= 0:
[202]1133            raise ValueError('Invalid port: %d' % args.port)
1134        result = request(args, 'delPort', {'port': args.port})
1135        if 'error' in result:
[206]1136            _print_error(result['error'])
[199]1137        else:
1138            print_portlist(result['result']['entries'])
[190]1139
1140    def handle_ctl_listport(args):
[202]1141        result = request(args, 'listPort')
1142        if 'error' in result:
[206]1143            _print_error(result['error'])
[199]1144        else:
1145            print_portlist(result['result']['entries'])
[190]1146
[212]1147    def handle_ctl_setif(args):
1148        if args.port <= 0:
1149            raise ValueError('Invalid port: %d' % args.port)
1150        req = {'port': args.port}
1151        address = getattr(args, 'address', None)
1152        netmask = getattr(args, 'netmask', None)
1153        mtu = getattr(args, 'mtu', None)
1154        if address is not None:
1155            if address:
1156                socket.inet_aton(address)  # validate
1157            req['address'] = address
1158        if netmask is not None:
1159            if netmask:
1160                socket.inet_aton(netmask)  # validate
1161            req['netmask'] = netmask
1162        if mtu is not None:
1163            if mtu < 576:
1164                raise ValueError('Invalid MTU: %d' % mtu)
1165            req['mtu'] = mtu
1166        result = request(args, 'setInterface', req)
1167        if 'error' in result:
1168            _print_error(result['error'])
1169        else:
1170            print_iflist(result['result']['entries'])
1171
1172    def handle_ctl_listif(args):
1173        result = request(args, 'listInterface')
1174        if 'error' in result:
1175            _print_error(result['error'])
1176        else:
1177            print_iflist(result['result']['entries'])
1178
[198]1179    def handle_ctl_listfdb(args):
[202]1180        result = request(args, 'listFdb')
1181        if 'error' in result:
[206]1182            return _print_error(result['error'])
[255]1183        rows = [['Port', 'VLAN', 'MAC', 'Age']]
1184        for r in result['result']['entries']:
1185            rows.append([r['port'], r['vid'], r['mac'], r['age']])
1186        print_table(rows)
1187
[199]1188    locals()['handle_ctl_' + args.control_method](args)
[190]1189
1190
[206]1191def _main():
[186]1192    parser = argparse.ArgumentParser()
[190]1193    subcommand = parser.add_subparsers(dest='subcommand')
[186]1194
[204]1195    # - sw
[215]1196    parser_sw = subcommand.add_parser('sw',
1197                                      help='start virtual switch')
[190]1198
[215]1199    parser_sw.add_argument('--debug', action='store_true', default=False,
1200                           help='run as debug mode')
1201    parser_sw.add_argument('--foreground', action='store_true', default=False,
1202                           help='run as foreground mode')
1203    parser_sw.add_argument('--ageout', type=int, default=300,
1204                           help='FDB ageout time (sec)')
[186]1205
[215]1206    parser_sw.add_argument('--path', default='/',
1207                           help='http(s) path to serve WebSocket')
1208    parser_sw.add_argument('--host', default='',
1209                           help='listen address to serve WebSocket')
1210    parser_sw.add_argument('--port', type=int,
1211                           help='listen port to serve WebSocket')
1212    parser_sw.add_argument('--htpasswd',
1213                           help='path to htpasswd file to auth WebSocket')
1214    parser_sw.add_argument('--sslkey',
1215                           help='path to SSL private key for WebSocket')
1216    parser_sw.add_argument('--sslcert',
1217                           help='path to SSL certificate for WebSocket')
[186]1218
[215]1219    parser_sw.add_argument('--ctlpath', default='/ctl',
1220                           help='http(s) path to serve control API')
1221    parser_sw.add_argument('--ctlhost', default='127.0.0.1',
1222                           help='listen address to serve control API')
1223    parser_sw.add_argument('--ctlport', type=int, default=7867,
1224                           help='listen port to serve control API')
1225    parser_sw.add_argument('--ctlhtpasswd',
1226                           help='path to htpasswd file to auth control API')
1227    parser_sw.add_argument('--ctlsslkey',
1228                           help='path to SSL private key for control API')
1229    parser_sw.add_argument('--ctlsslcert',
1230                           help='path to SSL certificate for control API')
[186]1231
[204]1232    # - ctl
[215]1233    parser_ctl = subcommand.add_parser('ctl',
1234                                       help='control virtual switch')
1235    parser_ctl.add_argument('--ctlurl', default='http://127.0.0.1:7867/ctl',
1236                            help='URL to control API')
1237    parser_ctl.add_argument('--ctluser',
1238                            help='username to auth control API')
1239    parser_ctl.add_argument('--ctlpasswd',
1240                            help='password to auth control API')
[272]1241    parser_ctl.add_argument('--ctlsslcert',
1242                            help='path to SSL certificate for control API')
1243    parser_ctl.add_argument(
1244        '--ctlinsecure', action='store_true', default=False,
1245        help='do not verify control API certificate')
[190]1246
[204]1247    control_method = parser_ctl.add_subparsers(dest='control_method')
[190]1248
[204]1249    # -- ctl addport
[215]1250    parser_ctl_addport = control_method.add_parser('addport',
1251                                                   help='create and add port')
[204]1252    iftype = parser_ctl_addport.add_subparsers(dest='iftype')
[190]1253
[251]1254    # --- ctl addport netdev
1255    parser_ctl_addport_netdev = iftype.add_parser(NetdevHandler.IFTYPE,
[257]1256                                                  help='Network device')
[251]1257    parser_ctl_addport_netdev.add_argument('target',
1258                                           help='device name to add interface')
1259
[204]1260    # --- ctl addport tap
[215]1261    parser_ctl_addport_tap = iftype.add_parser(TapHandler.IFTYPE,
1262                                               help='TAP device')
1263    parser_ctl_addport_tap.add_argument('target',
1264                                        help='device name to create interface')
[190]1265
[204]1266    # --- ctl addport client
[253]1267    parser_ctl_addport_client = iftype.add_parser(ClientHandler.IFTYPE,
[215]1268                                                  help='WebSocket client')
1269    parser_ctl_addport_client.add_argument('target',
1270                                           help='URL to connect WebSocket')
1271    parser_ctl_addport_client.add_argument('--user',
1272                                           help='username to auth WebSocket')
1273    parser_ctl_addport_client.add_argument('--passwd',
1274                                           help='password to auth WebSocket')
1275    parser_ctl_addport_client.add_argument('--cacerts',
1276                                           help='path to CA certificate')
[204]1277    parser_ctl_addport_client.add_argument(
[215]1278        '--insecure', action='store_true', default=False,
1279        help='do not verify server certificate')
[190]1280
[211]1281    # -- ctl setport
[215]1282    parser_ctl_setport = control_method.add_parser('setport',
1283                                                   help='set port config')
1284    parser_ctl_setport.add_argument('port', type=int,
1285                                    help='port number to set config')
1286    parser_ctl_setport.add_argument('--shut', type=int, choices=(0, 1),
1287                                    help='set shutdown state')
[190]1288
[204]1289    # -- ctl delport
[215]1290    parser_ctl_delport = control_method.add_parser('delport',
1291                                                   help='delete port')
1292    parser_ctl_delport.add_argument('port', type=int,
1293                                    help='port number to delete')
[198]1294
[204]1295    # -- ctl listport
[215]1296    parser_ctl_listport = control_method.add_parser('listport',
1297                                                    help='show port list')
[204]1298
[212]1299    # -- ctl setif
[215]1300    parser_ctl_setif = control_method.add_parser('setif',
1301                                                 help='set interface config')
1302    parser_ctl_setif.add_argument('port', type=int,
1303                                  help='port number to set config')
1304    parser_ctl_setif.add_argument('--address',
1305                                  help='IPv4 address to set interface')
1306    parser_ctl_setif.add_argument('--netmask',
1307                                  help='IPv4 netmask to set interface')
1308    parser_ctl_setif.add_argument('--mtu', type=int,
1309                                  help='MTU to set interface')
[212]1310
1311    # -- ctl listif
[215]1312    parser_ctl_listif = control_method.add_parser('listif',
1313                                                  help='show interface list')
[212]1314
[204]1315    # -- ctl listfdb
[215]1316    parser_ctl_listfdb = control_method.add_parser('listfdb',
1317                                                   help='show FDB entries')
[204]1318
[190]1319    # -- go
[186]1320    args = parser.parse_args()
1321
[205]1322    try:
[206]1323        globals()['_start_' + args.subcommand](args)
[205]1324    except Exception as e:
[206]1325        _print_error({
[205]1326            'code':    0 - 32603,
1327            'message': 'Internal error',
1328            'data':    '%s: %s' % (e.__class__.__name__, str(e)),
1329        })
[186]1330
[205]1331
[133]1332if __name__ == '__main__':
[206]1333    _main()
Note: See TracBrowser for help on using the repository browser.