source: etherws/trunk/etherws.py @ 184

Revision 184, 21.7 KB checked in by atzm, 12 years ago (diff)
  • enables controller options
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[183]47import json
[175]48import fcntl
[150]49import base64
50import hashlib
[151]51import getpass
[133]52import argparse
[165]53import traceback
[133]54
55import websocket
56
[183]57from tornado.web import Application, RequestHandler
[182]58from tornado.websocket import WebSocketHandler
[183]59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
[166]62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]63
[166]64
[160]65class DebugMixIn(object):
[166]66    def dprintf(self, msg, func=lambda: ()):
[160]67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]69            sys.stderr.write(prefix + (msg % func()))
[160]70
71
[164]72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
[176]76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
[164]79
80    @property
[176]81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
[164]85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]100        return 0
[164]101
102
[166]103class FDB(DebugMixIn):
[167]104    def __init__(self, ageout, debug=False):
[164]105        self._ageout = ageout
106        self._debug = debug
[166]107        self._dict = {}
[164]108
109    def lookup(self, frame):
110        mac = frame.dst_mac
111        vid = frame.vid
112
[177]113        group = self._dict.get(vid)
[164]114        if not group:
115            return None
116
[177]117        entry = group.get(mac)
[164]118        if not entry:
119            return None
120
121        if time.time() - entry['time'] > self._ageout:
[183]122            port = self._dict[vid][mac]['port']
[166]123            del self._dict[vid][mac]
124            if not self._dict[vid]:
125                del self._dict[vid]
[183]126            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
127                         lambda: (port.number, vid, mac.encode('hex')))
[164]128            return None
129
130        return entry['port']
131
[166]132    def learn(self, port, frame):
133        mac = frame.src_mac
134        vid = frame.vid
135
136        if vid not in self._dict:
137            self._dict[vid] = {}
138
139        self._dict[vid][mac] = {'time': time.time(), 'port': port}
[183]140        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
141                     lambda: (port.number, vid, mac.encode('hex')))
[166]142
[164]143    def delete(self, port):
[166]144        for vid in self._dict.keys():
145            for mac in self._dict[vid].keys():
[183]146                if self._dict[vid][mac]['port'].number == port.number:
[166]147                    del self._dict[vid][mac]
[183]148                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
149                                 lambda: (port.number, vid, mac.encode('hex')))
[166]150            if not self._dict[vid]:
151                del self._dict[vid]
[164]152
153
[183]154class SwitchPort(object):
155    def __init__(self, number, interface):
156        self.number = number
157        self.interface = interface
158        self.tx = 0
159        self.rx = 0
160        self.shut = False
161
162    @staticmethod
163    def cmp_by_number(x, y):
164        return cmp(x.number, y.number)
165
166
[166]167class SwitchingHub(DebugMixIn):
168    def __init__(self, fdb, debug=False):
169        self._fdb = fdb
[133]170        self._debug = debug
[183]171        self._table = {}
172        self._next = 1
[133]173
[183]174    @property
175    def portlist(self):
176        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
[133]177
[183]178    def shut_port(self, portnum, flag=True):
179        self._table[portnum].shut = flag
[133]180
[183]181    def get_port(self, portnum):
182        return self._table[portnum]
183
184    def register_port(self, interface):
185        interface._switch_portnum = self._next  # XXX
186        self._table[self._next] = SwitchPort(self._next, interface)
187        self._next += 1
188
189    def unregister_port(self, interface):
190        self._fdb.delete(self._table[interface._switch_portnum])
191        del self._table[interface._switch_portnum]
192        del interface._switch_portnum
193
194    def send(self, dst_interfaces, frame):
195        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
196                        if not self._table[i._switch_portnum].shut),
197                       cmp=SwitchPort.cmp_by_number)
198
199        for p in ports:
200            p.interface.write_message(frame.data, True)
201            p.tx += 1
202
203        if ports:
204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
205                         lambda: (','.join(str(p.number) for p in ports),
206                                  frame.vid,
207                                  frame.src_mac.encode('hex'),
208                                  frame.dst_mac.encode('hex')))
209
210    def receive(self, src_interface, frame):
211        port = self._table[src_interface._switch_portnum]
212
213        if not port.shut:
214            port.rx += 1
215            self._forward(port, frame)
216
217    def _forward(self, src_port, frame):
[166]218        try:
[176]219            if not frame.src_multicast:
[172]220                self._fdb.learn(src_port, frame)
[133]221
[176]222            if not frame.dst_multicast:
[166]223                dst_port = self._fdb.lookup(frame)
[164]224
[166]225                if dst_port:
[183]226                    self.send([dst_port.interface], frame)
[166]227                    return
[133]228
[183]229            ports = set(self._table.itervalues()) - set([src_port])
230            self.send((p.interface for p in ports), frame)
[162]231
[166]232        except:  # ex. received invalid frame
233            traceback.print_exc()
[133]234
[164]235
[179]236class Htpasswd(object):
237    def __init__(self, path):
238        self._path = path
239        self._stat = None
240        self._data = {}
241
242    def auth(self, name, passwd):
243        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
244        return self._data.get(name) == passwd
245
246    def load(self):
247        old_stat = self._stat
248
249        with open(self._path) as fp:
250            fileno = fp.fileno()
251            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
252            self._stat = os.fstat(fileno)
253
254            unchanged = old_stat and \
255                        old_stat.st_ino == self._stat.st_ino and \
256                        old_stat.st_dev == self._stat.st_dev and \
257                        old_stat.st_mtime == self._stat.st_mtime
258
259            if not unchanged:
260                self._data = self._parse(fp)
261
262        return self
263
264    def _parse(self, fp):
265        data = {}
266        for line in fp:
267            line = line.strip()
268            if 0 <= line.find(':'):
269                name, passwd = line.split(':', 1)
270                if passwd.startswith('{SHA}'):
271                    data[name] = passwd[5:]
272        return data
273
274
[182]275class BasicAuthMixIn(object):
276    def _execute(self, transforms, *args, **kwargs):
277        def do_execute():
278            sp = super(BasicAuthMixIn, self)
279            return sp._execute(transforms, *args, **kwargs)
280
281        def auth_required():
282            self.stream.write(tornado.escape.utf8(
283                'HTTP/1.1 401 Authorization Required\r\n'
284                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
285            ))
286            self.stream.close()
287
288        try:
289            if not self._htpasswd:
290                return do_execute()
291
292            creds = self.request.headers.get('Authorization')
293
294            if not creds or not creds.startswith('Basic '):
295                return auth_required()
296
297            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
298
299            if self._htpasswd.load().auth(name, passwd):
300                return do_execute()
301        except:
302            traceback.print_exc()
303
304        return auth_required()
305
306
[166]307class TapHandler(DebugMixIn):
308    READ_SIZE = 65535
309
[178]310    def __init__(self, ioloop, switch, dev, debug=False):
311        self._ioloop = ioloop
[166]312        self._switch = switch
313        self._dev = dev
314        self._debug = debug
315        self._tap = None
316
317    @property
318    def closed(self):
319        return not self._tap
320
[183]321    def get_type(self):
322        return 'tap'
323
324    def get_name(self):
325        if self.closed:
326            return self._dev
327        return self._tap.name
328
[166]329    def open(self):
330        if not self.closed:
331            raise ValueError('already opened')
332        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
333        self._tap.up()
334        self._switch.register_port(self)
[178]335        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[166]336
337    def close(self):
338        if self.closed:
339            raise ValueError('I/O operation on closed tap')
[178]340        self._ioloop.remove_handler(self.fileno())
[166]341        self._switch.unregister_port(self)
342        self._tap.close()
343        self._tap = None
344
345    def fileno(self):
346        if self.closed:
347            raise ValueError('I/O operation on closed tap')
348        return self._tap.fileno()
349
350    def write_message(self, message, binary=False):
351        if self.closed:
352            raise ValueError('I/O operation on closed tap')
353        self._tap.write(message)
354
[138]355    def __call__(self, fd, events):
[166]356        try:
[183]357            self._switch.receive(self, EthernetFrame(self._read()))
[166]358            return
359        except:
360            traceback.print_exc()
[178]361        self.close()
[166]362
363    def _read(self):
364        if self.closed:
365            raise ValueError('I/O operation on closed tap')
[162]366        buf = []
367        while True:
[166]368            buf.append(self._tap.read(self.READ_SIZE))
369            if len(buf[-1]) < self.READ_SIZE:
[162]370                break
[166]371        return ''.join(buf)
[162]372
373
[182]374class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[179]375    def __init__(self, app, req, switch, htpasswd=None, debug=False):
[160]376        super(EtherWebSocketHandler, self).__init__(app, req)
[166]377        self._switch = switch
[179]378        self._htpasswd = htpasswd
[133]379        self._debug = debug
380
[183]381    def get_type(self):
382        return 'server'
383
384    def get_name(self):
385        return self.request.remote_ip
386
[133]387    def open(self):
[166]388        self._switch.register_port(self)
[164]389        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]390
391    def on_message(self, message):
[183]392        self._switch.receive(self, EthernetFrame(message))
[133]393
394    def on_close(self):
[166]395        self._switch.unregister_port(self)
[164]396        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]397
398
[160]399class EtherWebSocketClient(DebugMixIn):
[181]400    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]401        self._ioloop = ioloop
[166]402        self._switch = switch
[151]403        self._url = url
[181]404        self._ssl = ssl_
[160]405        self._debug = debug
[166]406        self._sock = None
[151]407        self._options = {}
408
[174]409        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
410            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]411            auth = ['Authorization: Basic %s' % token]
412            self._options['header'] = auth
413
[160]414    @property
415    def closed(self):
416        return not self._sock
417
[183]418    def get_type(self):
419        return 'client'
420
421    def get_name(self):
422        return self._url
423
[151]424    def open(self):
[181]425        sslwrap = websocket._SSLSocketWrapper
426
[160]427        if not self.closed:
428            raise websocket.WebSocketException('already opened')
[151]429
[181]430        if self._ssl:
431            websocket._SSLSocketWrapper = self._ssl
432
433        try:
434            self._sock = websocket.WebSocket()
435            self._sock.connect(self._url, **self._options)
436            self._switch.register_port(self)
437            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
438            self.dprintf('connected: %s\n', lambda: self._url)
439        finally:
440            websocket._SSLSocketWrapper = sslwrap
441
[151]442    def close(self):
[160]443        if self.closed:
444            raise websocket.WebSocketException('already closed')
[178]445        self._ioloop.remove_handler(self.fileno())
[166]446        self._switch.unregister_port(self)
[151]447        self._sock.close()
448        self._sock = None
[164]449        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]450
[165]451    def fileno(self):
452        if self.closed:
453            raise websocket.WebSocketException('closed socket')
454        return self._sock.io_sock.fileno()
455
[151]456    def write_message(self, message, binary=False):
[160]457        if self.closed:
458            raise websocket.WebSocketException('closed socket')
[151]459        if binary:
460            flag = websocket.ABNF.OPCODE_BINARY
[160]461        else:
462            flag = websocket.ABNF.OPCODE_TEXT
[151]463        self._sock.send(message, flag)
464
[165]465    def __call__(self, fd, events):
[151]466        try:
[165]467            data = self._sock.recv()
468            if data is not None:
[183]469                self._switch.receive(self, EthernetFrame(data))
[165]470                return
471        except:
472            traceback.print_exc()
[178]473        self.close()
[151]474
475
[183]476class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
477    NAMESPACE = 'etherws.control'
478
479    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
480        super(EtherWebSocketControlHandler, self).__init__(app, req)
481        self._ioloop = ioloop
482        self._switch = switch
483        self._htpasswd = htpasswd
484        self._debug = debug
485
486    def post(self):
487        id_ = None
488
489        try:
490            req = json.loads(self.request.body)
491            method = req['method']
492            params = req['params']
493            id_ = req.get('id')
494
495            if not method.startswith(self.NAMESPACE + '.'):
496                raise ValueError('invalid method: %s' % method)
497
498            if not isinstance(params, list):
499                raise ValueError('invalid params: %s' % params)
500
501            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
502            result = getattr(self, handler)(params)
503            self.finish({'result': result, 'error': None, 'id': id_})
504
505        except Exception as e:
506            traceback.print_exc()
507            self.finish({'result': None, 'error': str(e), 'id': id_})
508
509    def handle_listPort(self, params):
510        list_ = []
511        for port in self._switch.portlist:
512            list_.append({
513                'port': port.number,
514                'type': port.interface.get_type(),
515                'name': port.interface.get_name(),
516                'tx':   port.tx,
517                'rx':   port.rx,
518                'shut': port.shut,
519            })
520        return {'portlist': list_}
521
522    def handle_addPort(self, params):
523        for p in params:
524            getattr(self, '_openport_' + p['type'])(p)
525        return self.handle_listPort(params)
526
527    def handle_delPort(self, params):
528        for p in params:
529            self._switch.get_port(int(p['port'])).interface.close()
530        return self.handle_listPort(params)
531
532    def handle_shutPort(self, params):
533        for p in params:
534            self._switch.shut_port(int(p['port']), bool(p['flag']))
535        return self.handle_listPort(params)
536
537    def _openport_tap(self, p):
538        dev = p['device']
539        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug)
540        tap.open()
541
542    def _openport_client(self, p):
543        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts'))
544        cred = {'user': p.get('user'), 'passwd': p.get('passwd')}
545        url = p['url']
546        client = EtherWebSocketClient(self._ioloop, self._switch,
547                                      url, ssl_, cred, self._debug)
548        client.open()
549
550    @staticmethod
551    def _ssl_wrapper(insecure, ca_certs):
552        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
553        if insecure:
554            args = {}
555        return lambda sock: ssl.wrap_socket(sock, **args)
556
557
[134]558def daemonize(nochdir=False, noclose=False):
559    if os.fork() > 0:
560        sys.exit(0)
561
562    os.setsid()
563
564    if os.fork() > 0:
565        sys.exit(0)
566
567    if not nochdir:
568        os.chdir('/')
569
570    if not noclose:
571        os.umask(0)
572        sys.stdin.close()
573        sys.stdout.close()
574        sys.stderr.close()
575        os.close(0)
576        os.close(1)
577        os.close(2)
578        sys.stdin = open(os.devnull)
579        sys.stdout = open(os.devnull, 'a')
580        sys.stderr = open(os.devnull, 'a')
581
582
[183]583def main():
584    def realpath(ns, *keys):
585        for k in keys:
586            v = getattr(ns, k, None)
587            if v is not None:
588                v = os.path.realpath(v)
589                open(v).close()  # check readable
590                setattr(ns, k, v)
[160]591
[184]592    def checkpath(ns, path):
593        val = getattr(ns, path, '')
594        if not val.startswith('/'):
595            raise ValueError('invalid %: %s' % (path, val))
596
597    def getsslopt(ns, key, cert):
598        kval = getattr(ns, key, None)
599        cval = getattr(ns, cert, None)
600        if kval and cval:
601            return {'keyfile': kval, 'certfile': cval}
602        elif kval or cval:
603            raise ValueError('both %s and %s are required' % (key, cert))
604        return None
605
606    def setport(ns, port, isssl):
607        val = getattr(ns, port, None)
608        if val is None:
609            if isssl:
610                return setattr(ns, port, 443)
611            return setattr(ns, port, 80)
612        if not (0 <= val <= 65535):
613            raise ValueError('invalid %s: %s' % (port, val))
614
615    def sethtpasswd(ns, htpasswd):
616        val = getattr(ns, htpasswd, None)
617        if val:
618            return setattr(ns, htpasswd, Htpasswd(val))
619
[183]620    parser = argparse.ArgumentParser()
[160]621
[183]622    parser.add_argument('--debug', action='store_true', default=False)
623    parser.add_argument('--foreground', action='store_true', default=False)
624    parser.add_argument('--ageout', action='store', type=int, default=300)
[180]625
[183]626    parser.add_argument('--path', action='store', default='/')
[184]627    parser.add_argument('--host', action='store', default='')
[183]628    parser.add_argument('--port', action='store', type=int)
629    parser.add_argument('--htpasswd', action='store')
630    parser.add_argument('--sslkey', action='store')
631    parser.add_argument('--sslcert', action='store')
[180]632
[183]633    parser.add_argument('--ctlpath', action='store', default='/ctl')
[184]634    parser.add_argument('--ctlhost', action='store', default='')
635    parser.add_argument('--ctlport', action='store', type=int)
636    parser.add_argument('--ctlhtpasswd', action='store')
637    parser.add_argument('--ctlsslkey', action='store')
638    parser.add_argument('--ctlsslcert', action='store')
[143]639
[184]640    args = parser.parse_args()
[183]641
642    #if args.debug:
643    #    websocket.enableTrace(True)
644
645    if args.ageout <= 0:
646        raise ValueError('invalid ageout: %s' % args.ageout)
647
[184]648    realpath(args, 'htpasswd', 'sslkey', 'sslcert')
649    realpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]650
[184]651    checkpath(args, 'path')
652    checkpath(args, 'ctlpath')
[183]653
[184]654    sslopt = getsslopt(args, 'sslkey', 'sslcert')
655    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]656
[184]657    setport(args, 'port', sslopt)
658    setport(args, 'ctlport', ctlsslopt)
[143]659
[184]660    sethtpasswd(args, 'htpasswd')
661    sethtpasswd(args, 'ctlhtpasswd')
[167]662
[183]663    ioloop = IOLoop.instance()
[167]664    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]665    switch = SwitchingHub(fdb, debug=args.debug)
[167]666
[184]667    if args.port == args.ctlport and args.host == args.ctlhost:
668        if args.path == args.ctlpath:
669            raise ValueError('same path/ctlpath on same host')
670        if args.sslkey != args.ctlsslkey:
671            raise ValueError('differ sslkey/ctlsslkey on same host')
672        if args.sslcert != args.ctlsslcert:
673            raise ValueError('differ sslcert/ctlsslcert on same host')
[133]674
[184]675        app = Application([
676            (args.path, EtherWebSocketHandler, {
677                'switch':   switch,
678                'htpasswd': args.htpasswd,
679                'debug':    args.debug,
680            }),
681            (args.ctlpath, EtherWebSocketControlHandler, {
682                'ioloop':   ioloop,
683                'switch':   switch,
684                'htpasswd': args.ctlhtpasswd,
685                'debug':    args.debug,
686            }),
687        ])
688        server = HTTPServer(app, ssl_options=sslopt)
689        server.listen(args.port, address=args.host)
[151]690
[184]691    else:
692        app = Application([(args.path, EtherWebSocketHandler, {
693            'switch':   switch,
694            'htpasswd': args.htpasswd,
695            'debug':    args.debug,
696        })])
697        server = HTTPServer(app, ssl_options=sslopt)
698        server.listen(args.port, address=args.host)
699
700        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
701            'ioloop':   ioloop,
702            'switch':   switch,
703            'htpasswd': args.ctlhtpasswd,
704            'debug':    args.debug,
705        })])
706        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
707        ctlserver.listen(args.ctlport, address=args.ctlhost)
708
[151]709    if not args.foreground:
710        daemonize()
711
[138]712    ioloop.start()
[133]713
714
715if __name__ == '__main__':
716    main()
Note: See TracBrowser for help on using the repository browser.