source: etherws/trunk/etherws.py @ 183

Revision 183, 20.2 KB checked in by atzm, 12 years ago (diff)
  • global change: enables remote control
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[183]47import json
[175]48import fcntl
[150]49import base64
50import hashlib
[151]51import getpass
[133]52import argparse
[165]53import traceback
[133]54
55import websocket
56
[183]57from tornado.web import Application, RequestHandler
[182]58from tornado.websocket import WebSocketHandler
[183]59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
[166]62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]63
[166]64
[160]65class DebugMixIn(object):
[166]66    def dprintf(self, msg, func=lambda: ()):
[160]67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]69            sys.stderr.write(prefix + (msg % func()))
[160]70
71
[164]72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
[176]76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
[164]79
80    @property
[176]81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
[164]85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]100        return 0
[164]101
102
[166]103class FDB(DebugMixIn):
[167]104    def __init__(self, ageout, debug=False):
[164]105        self._ageout = ageout
106        self._debug = debug
[166]107        self._dict = {}
[164]108
109    def lookup(self, frame):
110        mac = frame.dst_mac
111        vid = frame.vid
112
[177]113        group = self._dict.get(vid)
[164]114        if not group:
115            return None
116
[177]117        entry = group.get(mac)
[164]118        if not entry:
119            return None
120
121        if time.time() - entry['time'] > self._ageout:
[183]122            port = self._dict[vid][mac]['port']
[166]123            del self._dict[vid][mac]
124            if not self._dict[vid]:
125                del self._dict[vid]
[183]126            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
127                         lambda: (port.number, vid, mac.encode('hex')))
[164]128            return None
129
130        return entry['port']
131
[166]132    def learn(self, port, frame):
133        mac = frame.src_mac
134        vid = frame.vid
135
136        if vid not in self._dict:
137            self._dict[vid] = {}
138
139        self._dict[vid][mac] = {'time': time.time(), 'port': port}
[183]140        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
141                     lambda: (port.number, vid, mac.encode('hex')))
[166]142
[164]143    def delete(self, port):
[166]144        for vid in self._dict.keys():
145            for mac in self._dict[vid].keys():
[183]146                if self._dict[vid][mac]['port'].number == port.number:
[166]147                    del self._dict[vid][mac]
[183]148                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
149                                 lambda: (port.number, vid, mac.encode('hex')))
[166]150            if not self._dict[vid]:
151                del self._dict[vid]
[164]152
153
[183]154class SwitchPort(object):
155    def __init__(self, number, interface):
156        self.number = number
157        self.interface = interface
158        self.tx = 0
159        self.rx = 0
160        self.shut = False
161
162    @staticmethod
163    def cmp_by_number(x, y):
164        return cmp(x.number, y.number)
165
166
[166]167class SwitchingHub(DebugMixIn):
168    def __init__(self, fdb, debug=False):
169        self._fdb = fdb
[133]170        self._debug = debug
[183]171        self._table = {}
172        self._next = 1
[133]173
[183]174    @property
175    def portlist(self):
176        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
[133]177
[183]178    def shut_port(self, portnum, flag=True):
179        self._table[portnum].shut = flag
[133]180
[183]181    def get_port(self, portnum):
182        return self._table[portnum]
183
184    def register_port(self, interface):
185        interface._switch_portnum = self._next  # XXX
186        self._table[self._next] = SwitchPort(self._next, interface)
187        self._next += 1
188
189    def unregister_port(self, interface):
190        self._fdb.delete(self._table[interface._switch_portnum])
191        del self._table[interface._switch_portnum]
192        del interface._switch_portnum
193
194    def send(self, dst_interfaces, frame):
195        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
196                        if not self._table[i._switch_portnum].shut),
197                       cmp=SwitchPort.cmp_by_number)
198
199        for p in ports:
200            p.interface.write_message(frame.data, True)
201            p.tx += 1
202
203        if ports:
204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
205                         lambda: (','.join(str(p.number) for p in ports),
206                                  frame.vid,
207                                  frame.src_mac.encode('hex'),
208                                  frame.dst_mac.encode('hex')))
209
210    def receive(self, src_interface, frame):
211        port = self._table[src_interface._switch_portnum]
212
213        if not port.shut:
214            port.rx += 1
215            self._forward(port, frame)
216
217    def _forward(self, src_port, frame):
[166]218        try:
[176]219            if not frame.src_multicast:
[172]220                self._fdb.learn(src_port, frame)
[133]221
[176]222            if not frame.dst_multicast:
[166]223                dst_port = self._fdb.lookup(frame)
[164]224
[166]225                if dst_port:
[183]226                    self.send([dst_port.interface], frame)
[166]227                    return
[133]228
[183]229            ports = set(self._table.itervalues()) - set([src_port])
230            self.send((p.interface for p in ports), frame)
[162]231
[166]232        except:  # ex. received invalid frame
233            traceback.print_exc()
[133]234
[164]235
[179]236class Htpasswd(object):
237    def __init__(self, path):
238        self._path = path
239        self._stat = None
240        self._data = {}
241
242    def auth(self, name, passwd):
243        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
244        return self._data.get(name) == passwd
245
246    def load(self):
247        old_stat = self._stat
248
249        with open(self._path) as fp:
250            fileno = fp.fileno()
251            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
252            self._stat = os.fstat(fileno)
253
254            unchanged = old_stat and \
255                        old_stat.st_ino == self._stat.st_ino and \
256                        old_stat.st_dev == self._stat.st_dev and \
257                        old_stat.st_mtime == self._stat.st_mtime
258
259            if not unchanged:
260                self._data = self._parse(fp)
261
262        return self
263
264    def _parse(self, fp):
265        data = {}
266        for line in fp:
267            line = line.strip()
268            if 0 <= line.find(':'):
269                name, passwd = line.split(':', 1)
270                if passwd.startswith('{SHA}'):
271                    data[name] = passwd[5:]
272        return data
273
274
[182]275class BasicAuthMixIn(object):
276    def _execute(self, transforms, *args, **kwargs):
277        def do_execute():
278            sp = super(BasicAuthMixIn, self)
279            return sp._execute(transforms, *args, **kwargs)
280
281        def auth_required():
282            self.stream.write(tornado.escape.utf8(
283                'HTTP/1.1 401 Authorization Required\r\n'
284                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
285            ))
286            self.stream.close()
287
288        try:
289            if not self._htpasswd:
290                return do_execute()
291
292            creds = self.request.headers.get('Authorization')
293
294            if not creds or not creds.startswith('Basic '):
295                return auth_required()
296
297            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
298
299            if self._htpasswd.load().auth(name, passwd):
300                return do_execute()
301        except:
302            traceback.print_exc()
303
304        return auth_required()
305
306
[166]307class TapHandler(DebugMixIn):
308    READ_SIZE = 65535
309
[178]310    def __init__(self, ioloop, switch, dev, debug=False):
311        self._ioloop = ioloop
[166]312        self._switch = switch
313        self._dev = dev
314        self._debug = debug
315        self._tap = None
316
317    @property
318    def closed(self):
319        return not self._tap
320
[183]321    def get_type(self):
322        return 'tap'
323
324    def get_name(self):
325        if self.closed:
326            return self._dev
327        return self._tap.name
328
[166]329    def open(self):
330        if not self.closed:
331            raise ValueError('already opened')
332        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
333        self._tap.up()
334        self._switch.register_port(self)
[178]335        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[166]336
337    def close(self):
338        if self.closed:
339            raise ValueError('I/O operation on closed tap')
[178]340        self._ioloop.remove_handler(self.fileno())
[166]341        self._switch.unregister_port(self)
342        self._tap.close()
343        self._tap = None
344
345    def fileno(self):
346        if self.closed:
347            raise ValueError('I/O operation on closed tap')
348        return self._tap.fileno()
349
350    def write_message(self, message, binary=False):
351        if self.closed:
352            raise ValueError('I/O operation on closed tap')
353        self._tap.write(message)
354
[138]355    def __call__(self, fd, events):
[166]356        try:
[183]357            self._switch.receive(self, EthernetFrame(self._read()))
[166]358            return
359        except:
360            traceback.print_exc()
[178]361        self.close()
[166]362
363    def _read(self):
364        if self.closed:
365            raise ValueError('I/O operation on closed tap')
[162]366        buf = []
367        while True:
[166]368            buf.append(self._tap.read(self.READ_SIZE))
369            if len(buf[-1]) < self.READ_SIZE:
[162]370                break
[166]371        return ''.join(buf)
[162]372
373
[182]374class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[179]375    def __init__(self, app, req, switch, htpasswd=None, debug=False):
[160]376        super(EtherWebSocketHandler, self).__init__(app, req)
[166]377        self._switch = switch
[179]378        self._htpasswd = htpasswd
[133]379        self._debug = debug
380
[179]381        if self._htpasswd:
382            self._htpasswd = Htpasswd(self._htpasswd)
383
[183]384    def get_type(self):
385        return 'server'
386
387    def get_name(self):
388        return self.request.remote_ip
389
[133]390    def open(self):
[166]391        self._switch.register_port(self)
[164]392        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]393
394    def on_message(self, message):
[183]395        self._switch.receive(self, EthernetFrame(message))
[133]396
397    def on_close(self):
[166]398        self._switch.unregister_port(self)
[164]399        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]400
401
[160]402class EtherWebSocketClient(DebugMixIn):
[181]403    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]404        self._ioloop = ioloop
[166]405        self._switch = switch
[151]406        self._url = url
[181]407        self._ssl = ssl_
[160]408        self._debug = debug
[166]409        self._sock = None
[151]410        self._options = {}
411
[174]412        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
413            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]414            auth = ['Authorization: Basic %s' % token]
415            self._options['header'] = auth
416
[160]417    @property
418    def closed(self):
419        return not self._sock
420
[183]421    def get_type(self):
422        return 'client'
423
424    def get_name(self):
425        return self._url
426
[151]427    def open(self):
[181]428        sslwrap = websocket._SSLSocketWrapper
429
[160]430        if not self.closed:
431            raise websocket.WebSocketException('already opened')
[151]432
[181]433        if self._ssl:
434            websocket._SSLSocketWrapper = self._ssl
435
436        try:
437            self._sock = websocket.WebSocket()
438            self._sock.connect(self._url, **self._options)
439            self._switch.register_port(self)
440            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
441            self.dprintf('connected: %s\n', lambda: self._url)
442        finally:
443            websocket._SSLSocketWrapper = sslwrap
444
[151]445    def close(self):
[160]446        if self.closed:
447            raise websocket.WebSocketException('already closed')
[178]448        self._ioloop.remove_handler(self.fileno())
[166]449        self._switch.unregister_port(self)
[151]450        self._sock.close()
451        self._sock = None
[164]452        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]453
[165]454    def fileno(self):
455        if self.closed:
456            raise websocket.WebSocketException('closed socket')
457        return self._sock.io_sock.fileno()
458
[151]459    def write_message(self, message, binary=False):
[160]460        if self.closed:
461            raise websocket.WebSocketException('closed socket')
[151]462        if binary:
463            flag = websocket.ABNF.OPCODE_BINARY
[160]464        else:
465            flag = websocket.ABNF.OPCODE_TEXT
[151]466        self._sock.send(message, flag)
467
[165]468    def __call__(self, fd, events):
[151]469        try:
[165]470            data = self._sock.recv()
471            if data is not None:
[183]472                self._switch.receive(self, EthernetFrame(data))
[165]473                return
474        except:
475            traceback.print_exc()
[178]476        self.close()
[151]477
478
[183]479class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
480    NAMESPACE = 'etherws.control'
481
482    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
483        super(EtherWebSocketControlHandler, self).__init__(app, req)
484        self._ioloop = ioloop
485        self._switch = switch
486        self._htpasswd = htpasswd
487        self._debug = debug
488
489    def post(self):
490        id_ = None
491
492        try:
493            req = json.loads(self.request.body)
494            method = req['method']
495            params = req['params']
496            id_ = req.get('id')
497
498            if not method.startswith(self.NAMESPACE + '.'):
499                raise ValueError('invalid method: %s' % method)
500
501            if not isinstance(params, list):
502                raise ValueError('invalid params: %s' % params)
503
504            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
505            result = getattr(self, handler)(params)
506            self.finish({'result': result, 'error': None, 'id': id_})
507
508        except Exception as e:
509            traceback.print_exc()
510            self.finish({'result': None, 'error': str(e), 'id': id_})
511
512    def handle_listPort(self, params):
513        list_ = []
514        for port in self._switch.portlist:
515            list_.append({
516                'port': port.number,
517                'type': port.interface.get_type(),
518                'name': port.interface.get_name(),
519                'tx':   port.tx,
520                'rx':   port.rx,
521                'shut': port.shut,
522            })
523        return {'portlist': list_}
524
525    def handle_addPort(self, params):
526        for p in params:
527            getattr(self, '_openport_' + p['type'])(p)
528        return self.handle_listPort(params)
529
530    def handle_delPort(self, params):
531        for p in params:
532            self._switch.get_port(int(p['port'])).interface.close()
533        return self.handle_listPort(params)
534
535    def handle_shutPort(self, params):
536        for p in params:
537            self._switch.shut_port(int(p['port']), bool(p['flag']))
538        return self.handle_listPort(params)
539
540    def _openport_tap(self, p):
541        dev = p['device']
542        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug)
543        tap.open()
544
545    def _openport_client(self, p):
546        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts'))
547        cred = {'user': p.get('user'), 'passwd': p.get('passwd')}
548        url = p['url']
549        client = EtherWebSocketClient(self._ioloop, self._switch,
550                                      url, ssl_, cred, self._debug)
551        client.open()
552
553    @staticmethod
554    def _ssl_wrapper(insecure, ca_certs):
555        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
556        if insecure:
557            args = {}
558        return lambda sock: ssl.wrap_socket(sock, **args)
559
560
[134]561def daemonize(nochdir=False, noclose=False):
562    if os.fork() > 0:
563        sys.exit(0)
564
565    os.setsid()
566
567    if os.fork() > 0:
568        sys.exit(0)
569
570    if not nochdir:
571        os.chdir('/')
572
573    if not noclose:
574        os.umask(0)
575        sys.stdin.close()
576        sys.stdout.close()
577        sys.stderr.close()
578        os.close(0)
579        os.close(1)
580        os.close(2)
581        sys.stdin = open(os.devnull)
582        sys.stdout = open(os.devnull, 'a')
583        sys.stderr = open(os.devnull, 'a')
584
585
[183]586def main():
587    def realpath(ns, *keys):
588        for k in keys:
589            v = getattr(ns, k, None)
590            if v is not None:
591                v = os.path.realpath(v)
592                open(v).close()  # check readable
593                setattr(ns, k, v)
594        return ns
[160]595
[183]596    parser = argparse.ArgumentParser()
[160]597
[183]598    parser.add_argument('--debug', action='store_true', default=False)
599    parser.add_argument('--foreground', action='store_true', default=False)
600    parser.add_argument('--ageout', action='store', type=int, default=300)
[180]601
[183]602    parser.add_argument('--path', action='store', default='/')
603    parser.add_argument('--address', action='store', default='')
604    parser.add_argument('--port', action='store', type=int)
605    parser.add_argument('--htpasswd', action='store')
606    parser.add_argument('--sslkey', action='store')
607    parser.add_argument('--sslcert', action='store')
[180]608
[183]609    parser.add_argument('--ctlpath', action='store', default='/ctl')
610    parser.add_argument('--ctladdress', action='store', default='127.0.0.1')
611    parser.add_argument('--ctlport', action='store', type=int, default=7867)
[143]612
[183]613    args = realpath(parser.parse_args(), 'htpasswd', 'sslkey', 'sslcert')
614
615    #if args.debug:
616    #    websocket.enableTrace(True)
617
618    if args.ageout <= 0:
619        raise ValueError('invalid ageout: %s' % args.ageout)
620
621    if not args.path.startswith('/'):
622        raise ValueError('invalid path: %s' % args.path)
623
624    if not args.ctlpath.startswith('/'):
625        raise ValueError('invalid ctlpath: %s' % args.ctlpath)
626
627    if args.sslkey and args.sslcert:
628        sslopt = {'keyfile': args.sslkey, 'certfile': args.sslcert}
629    elif args.sslkey or args.sslcert:
630        raise ValueError('both sslkey and sslcert are required')
[160]631    else:
[183]632        sslopt = None
[143]633
[160]634    if args.port is None:
[183]635        if sslopt:
[143]636            args.port = 443
637        else:
638            args.port = 80
[160]639    elif not (0 <= args.port <= 65535):
640        raise ValueError('invalid port: %s' % args.port)
[143]641
[183]642    if not (0 <= args.ctlport <= 65535):
643        raise ValueError('invalid ctlport: %s' % args.ctlport)
[167]644
[183]645    if args.htpasswd:
646        args.htpasswd = Htpasswd(args.htpasswd)
647
648    ioloop = IOLoop.instance()
[167]649    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]650    switch = SwitchingHub(fdb, debug=args.debug)
[167]651
[183]652    app = Application([(args.path, EtherWebSocketHandler, {
653        'switch':   switch,
654        'htpasswd': args.htpasswd,
655        'debug':    args.debug,
656    })])
657    server = HTTPServer(app, ssl_options=sslopt)
[133]658    server.listen(args.port, address=args.address)
659
[183]660    ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
661        'ioloop':   ioloop,
662        'switch':   switch,
663        'htpasswd': None,
664        'debug':    args.debug,
665    })])
666    ctlserver = HTTPServer(ctl)
667    ctlserver.listen(args.ctlport, address=args.ctladdress)
[151]668
669    if not args.foreground:
670        daemonize()
671
[138]672    ioloop.start()
[133]673
674
675if __name__ == '__main__':
676    main()
Note: See TracBrowser for help on using the repository browser.