source: etherws/trunk/etherws.py @ 183

Revision 183, 20.2 KB checked in by atzm, 12 years ago (diff)
  • global change: enables remote control
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import json
48import fcntl
49import base64
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import websocket
56
57from tornado.web import Application, RequestHandler
58from tornado.websocket import WebSocketHandler
59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
63
64
65class DebugMixIn(object):
66    def dprintf(self, msg, func=lambda: ()):
67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
69            sys.stderr.write(prefix + (msg % func()))
70
71
72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
79
80    @property
81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
100        return 0
101
102
103class FDB(DebugMixIn):
104    def __init__(self, ageout, debug=False):
105        self._ageout = ageout
106        self._debug = debug
107        self._dict = {}
108
109    def lookup(self, frame):
110        mac = frame.dst_mac
111        vid = frame.vid
112
113        group = self._dict.get(vid)
114        if not group:
115            return None
116
117        entry = group.get(mac)
118        if not entry:
119            return None
120
121        if time.time() - entry['time'] > self._ageout:
122            port = self._dict[vid][mac]['port']
123            del self._dict[vid][mac]
124            if not self._dict[vid]:
125                del self._dict[vid]
126            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
127                         lambda: (port.number, vid, mac.encode('hex')))
128            return None
129
130        return entry['port']
131
132    def learn(self, port, frame):
133        mac = frame.src_mac
134        vid = frame.vid
135
136        if vid not in self._dict:
137            self._dict[vid] = {}
138
139        self._dict[vid][mac] = {'time': time.time(), 'port': port}
140        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
141                     lambda: (port.number, vid, mac.encode('hex')))
142
143    def delete(self, port):
144        for vid in self._dict.keys():
145            for mac in self._dict[vid].keys():
146                if self._dict[vid][mac]['port'].number == port.number:
147                    del self._dict[vid][mac]
148                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
149                                 lambda: (port.number, vid, mac.encode('hex')))
150            if not self._dict[vid]:
151                del self._dict[vid]
152
153
154class SwitchPort(object):
155    def __init__(self, number, interface):
156        self.number = number
157        self.interface = interface
158        self.tx = 0
159        self.rx = 0
160        self.shut = False
161
162    @staticmethod
163    def cmp_by_number(x, y):
164        return cmp(x.number, y.number)
165
166
167class SwitchingHub(DebugMixIn):
168    def __init__(self, fdb, debug=False):
169        self._fdb = fdb
170        self._debug = debug
171        self._table = {}
172        self._next = 1
173
174    @property
175    def portlist(self):
176        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
177
178    def shut_port(self, portnum, flag=True):
179        self._table[portnum].shut = flag
180
181    def get_port(self, portnum):
182        return self._table[portnum]
183
184    def register_port(self, interface):
185        interface._switch_portnum = self._next  # XXX
186        self._table[self._next] = SwitchPort(self._next, interface)
187        self._next += 1
188
189    def unregister_port(self, interface):
190        self._fdb.delete(self._table[interface._switch_portnum])
191        del self._table[interface._switch_portnum]
192        del interface._switch_portnum
193
194    def send(self, dst_interfaces, frame):
195        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
196                        if not self._table[i._switch_portnum].shut),
197                       cmp=SwitchPort.cmp_by_number)
198
199        for p in ports:
200            p.interface.write_message(frame.data, True)
201            p.tx += 1
202
203        if ports:
204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
205                         lambda: (','.join(str(p.number) for p in ports),
206                                  frame.vid,
207                                  frame.src_mac.encode('hex'),
208                                  frame.dst_mac.encode('hex')))
209
210    def receive(self, src_interface, frame):
211        port = self._table[src_interface._switch_portnum]
212
213        if not port.shut:
214            port.rx += 1
215            self._forward(port, frame)
216
217    def _forward(self, src_port, frame):
218        try:
219            if not frame.src_multicast:
220                self._fdb.learn(src_port, frame)
221
222            if not frame.dst_multicast:
223                dst_port = self._fdb.lookup(frame)
224
225                if dst_port:
226                    self.send([dst_port.interface], frame)
227                    return
228
229            ports = set(self._table.itervalues()) - set([src_port])
230            self.send((p.interface for p in ports), frame)
231
232        except:  # ex. received invalid frame
233            traceback.print_exc()
234
235
236class Htpasswd(object):
237    def __init__(self, path):
238        self._path = path
239        self._stat = None
240        self._data = {}
241
242    def auth(self, name, passwd):
243        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
244        return self._data.get(name) == passwd
245
246    def load(self):
247        old_stat = self._stat
248
249        with open(self._path) as fp:
250            fileno = fp.fileno()
251            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
252            self._stat = os.fstat(fileno)
253
254            unchanged = old_stat and \
255                        old_stat.st_ino == self._stat.st_ino and \
256                        old_stat.st_dev == self._stat.st_dev and \
257                        old_stat.st_mtime == self._stat.st_mtime
258
259            if not unchanged:
260                self._data = self._parse(fp)
261
262        return self
263
264    def _parse(self, fp):
265        data = {}
266        for line in fp:
267            line = line.strip()
268            if 0 <= line.find(':'):
269                name, passwd = line.split(':', 1)
270                if passwd.startswith('{SHA}'):
271                    data[name] = passwd[5:]
272        return data
273
274
275class BasicAuthMixIn(object):
276    def _execute(self, transforms, *args, **kwargs):
277        def do_execute():
278            sp = super(BasicAuthMixIn, self)
279            return sp._execute(transforms, *args, **kwargs)
280
281        def auth_required():
282            self.stream.write(tornado.escape.utf8(
283                'HTTP/1.1 401 Authorization Required\r\n'
284                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
285            ))
286            self.stream.close()
287
288        try:
289            if not self._htpasswd:
290                return do_execute()
291
292            creds = self.request.headers.get('Authorization')
293
294            if not creds or not creds.startswith('Basic '):
295                return auth_required()
296
297            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
298
299            if self._htpasswd.load().auth(name, passwd):
300                return do_execute()
301        except:
302            traceback.print_exc()
303
304        return auth_required()
305
306
307class TapHandler(DebugMixIn):
308    READ_SIZE = 65535
309
310    def __init__(self, ioloop, switch, dev, debug=False):
311        self._ioloop = ioloop
312        self._switch = switch
313        self._dev = dev
314        self._debug = debug
315        self._tap = None
316
317    @property
318    def closed(self):
319        return not self._tap
320
321    def get_type(self):
322        return 'tap'
323
324    def get_name(self):
325        if self.closed:
326            return self._dev
327        return self._tap.name
328
329    def open(self):
330        if not self.closed:
331            raise ValueError('already opened')
332        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
333        self._tap.up()
334        self._switch.register_port(self)
335        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
336
337    def close(self):
338        if self.closed:
339            raise ValueError('I/O operation on closed tap')
340        self._ioloop.remove_handler(self.fileno())
341        self._switch.unregister_port(self)
342        self._tap.close()
343        self._tap = None
344
345    def fileno(self):
346        if self.closed:
347            raise ValueError('I/O operation on closed tap')
348        return self._tap.fileno()
349
350    def write_message(self, message, binary=False):
351        if self.closed:
352            raise ValueError('I/O operation on closed tap')
353        self._tap.write(message)
354
355    def __call__(self, fd, events):
356        try:
357            self._switch.receive(self, EthernetFrame(self._read()))
358            return
359        except:
360            traceback.print_exc()
361        self.close()
362
363    def _read(self):
364        if self.closed:
365            raise ValueError('I/O operation on closed tap')
366        buf = []
367        while True:
368            buf.append(self._tap.read(self.READ_SIZE))
369            if len(buf[-1]) < self.READ_SIZE:
370                break
371        return ''.join(buf)
372
373
374class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
375    def __init__(self, app, req, switch, htpasswd=None, debug=False):
376        super(EtherWebSocketHandler, self).__init__(app, req)
377        self._switch = switch
378        self._htpasswd = htpasswd
379        self._debug = debug
380
381        if self._htpasswd:
382            self._htpasswd = Htpasswd(self._htpasswd)
383
384    def get_type(self):
385        return 'server'
386
387    def get_name(self):
388        return self.request.remote_ip
389
390    def open(self):
391        self._switch.register_port(self)
392        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
393
394    def on_message(self, message):
395        self._switch.receive(self, EthernetFrame(message))
396
397    def on_close(self):
398        self._switch.unregister_port(self)
399        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
400
401
402class EtherWebSocketClient(DebugMixIn):
403    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
404        self._ioloop = ioloop
405        self._switch = switch
406        self._url = url
407        self._ssl = ssl_
408        self._debug = debug
409        self._sock = None
410        self._options = {}
411
412        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
413            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
414            auth = ['Authorization: Basic %s' % token]
415            self._options['header'] = auth
416
417    @property
418    def closed(self):
419        return not self._sock
420
421    def get_type(self):
422        return 'client'
423
424    def get_name(self):
425        return self._url
426
427    def open(self):
428        sslwrap = websocket._SSLSocketWrapper
429
430        if not self.closed:
431            raise websocket.WebSocketException('already opened')
432
433        if self._ssl:
434            websocket._SSLSocketWrapper = self._ssl
435
436        try:
437            self._sock = websocket.WebSocket()
438            self._sock.connect(self._url, **self._options)
439            self._switch.register_port(self)
440            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
441            self.dprintf('connected: %s\n', lambda: self._url)
442        finally:
443            websocket._SSLSocketWrapper = sslwrap
444
445    def close(self):
446        if self.closed:
447            raise websocket.WebSocketException('already closed')
448        self._ioloop.remove_handler(self.fileno())
449        self._switch.unregister_port(self)
450        self._sock.close()
451        self._sock = None
452        self.dprintf('disconnected: %s\n', lambda: self._url)
453
454    def fileno(self):
455        if self.closed:
456            raise websocket.WebSocketException('closed socket')
457        return self._sock.io_sock.fileno()
458
459    def write_message(self, message, binary=False):
460        if self.closed:
461            raise websocket.WebSocketException('closed socket')
462        if binary:
463            flag = websocket.ABNF.OPCODE_BINARY
464        else:
465            flag = websocket.ABNF.OPCODE_TEXT
466        self._sock.send(message, flag)
467
468    def __call__(self, fd, events):
469        try:
470            data = self._sock.recv()
471            if data is not None:
472                self._switch.receive(self, EthernetFrame(data))
473                return
474        except:
475            traceback.print_exc()
476        self.close()
477
478
479class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
480    NAMESPACE = 'etherws.control'
481
482    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
483        super(EtherWebSocketControlHandler, self).__init__(app, req)
484        self._ioloop = ioloop
485        self._switch = switch
486        self._htpasswd = htpasswd
487        self._debug = debug
488
489    def post(self):
490        id_ = None
491
492        try:
493            req = json.loads(self.request.body)
494            method = req['method']
495            params = req['params']
496            id_ = req.get('id')
497
498            if not method.startswith(self.NAMESPACE + '.'):
499                raise ValueError('invalid method: %s' % method)
500
501            if not isinstance(params, list):
502                raise ValueError('invalid params: %s' % params)
503
504            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
505            result = getattr(self, handler)(params)
506            self.finish({'result': result, 'error': None, 'id': id_})
507
508        except Exception as e:
509            traceback.print_exc()
510            self.finish({'result': None, 'error': str(e), 'id': id_})
511
512    def handle_listPort(self, params):
513        list_ = []
514        for port in self._switch.portlist:
515            list_.append({
516                'port': port.number,
517                'type': port.interface.get_type(),
518                'name': port.interface.get_name(),
519                'tx':   port.tx,
520                'rx':   port.rx,
521                'shut': port.shut,
522            })
523        return {'portlist': list_}
524
525    def handle_addPort(self, params):
526        for p in params:
527            getattr(self, '_openport_' + p['type'])(p)
528        return self.handle_listPort(params)
529
530    def handle_delPort(self, params):
531        for p in params:
532            self._switch.get_port(int(p['port'])).interface.close()
533        return self.handle_listPort(params)
534
535    def handle_shutPort(self, params):
536        for p in params:
537            self._switch.shut_port(int(p['port']), bool(p['flag']))
538        return self.handle_listPort(params)
539
540    def _openport_tap(self, p):
541        dev = p['device']
542        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug)
543        tap.open()
544
545    def _openport_client(self, p):
546        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts'))
547        cred = {'user': p.get('user'), 'passwd': p.get('passwd')}
548        url = p['url']
549        client = EtherWebSocketClient(self._ioloop, self._switch,
550                                      url, ssl_, cred, self._debug)
551        client.open()
552
553    @staticmethod
554    def _ssl_wrapper(insecure, ca_certs):
555        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
556        if insecure:
557            args = {}
558        return lambda sock: ssl.wrap_socket(sock, **args)
559
560
561def daemonize(nochdir=False, noclose=False):
562    if os.fork() > 0:
563        sys.exit(0)
564
565    os.setsid()
566
567    if os.fork() > 0:
568        sys.exit(0)
569
570    if not nochdir:
571        os.chdir('/')
572
573    if not noclose:
574        os.umask(0)
575        sys.stdin.close()
576        sys.stdout.close()
577        sys.stderr.close()
578        os.close(0)
579        os.close(1)
580        os.close(2)
581        sys.stdin = open(os.devnull)
582        sys.stdout = open(os.devnull, 'a')
583        sys.stderr = open(os.devnull, 'a')
584
585
586def main():
587    def realpath(ns, *keys):
588        for k in keys:
589            v = getattr(ns, k, None)
590            if v is not None:
591                v = os.path.realpath(v)
592                open(v).close()  # check readable
593                setattr(ns, k, v)
594        return ns
595
596    parser = argparse.ArgumentParser()
597
598    parser.add_argument('--debug', action='store_true', default=False)
599    parser.add_argument('--foreground', action='store_true', default=False)
600    parser.add_argument('--ageout', action='store', type=int, default=300)
601
602    parser.add_argument('--path', action='store', default='/')
603    parser.add_argument('--address', action='store', default='')
604    parser.add_argument('--port', action='store', type=int)
605    parser.add_argument('--htpasswd', action='store')
606    parser.add_argument('--sslkey', action='store')
607    parser.add_argument('--sslcert', action='store')
608
609    parser.add_argument('--ctlpath', action='store', default='/ctl')
610    parser.add_argument('--ctladdress', action='store', default='127.0.0.1')
611    parser.add_argument('--ctlport', action='store', type=int, default=7867)
612
613    args = realpath(parser.parse_args(), 'htpasswd', 'sslkey', 'sslcert')
614
615    #if args.debug:
616    #    websocket.enableTrace(True)
617
618    if args.ageout <= 0:
619        raise ValueError('invalid ageout: %s' % args.ageout)
620
621    if not args.path.startswith('/'):
622        raise ValueError('invalid path: %s' % args.path)
623
624    if not args.ctlpath.startswith('/'):
625        raise ValueError('invalid ctlpath: %s' % args.ctlpath)
626
627    if args.sslkey and args.sslcert:
628        sslopt = {'keyfile': args.sslkey, 'certfile': args.sslcert}
629    elif args.sslkey or args.sslcert:
630        raise ValueError('both sslkey and sslcert are required')
631    else:
632        sslopt = None
633
634    if args.port is None:
635        if sslopt:
636            args.port = 443
637        else:
638            args.port = 80
639    elif not (0 <= args.port <= 65535):
640        raise ValueError('invalid port: %s' % args.port)
641
642    if not (0 <= args.ctlport <= 65535):
643        raise ValueError('invalid ctlport: %s' % args.ctlport)
644
645    if args.htpasswd:
646        args.htpasswd = Htpasswd(args.htpasswd)
647
648    ioloop = IOLoop.instance()
649    fdb = FDB(ageout=args.ageout, debug=args.debug)
650    switch = SwitchingHub(fdb, debug=args.debug)
651
652    app = Application([(args.path, EtherWebSocketHandler, {
653        'switch':   switch,
654        'htpasswd': args.htpasswd,
655        'debug':    args.debug,
656    })])
657    server = HTTPServer(app, ssl_options=sslopt)
658    server.listen(args.port, address=args.address)
659
660    ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
661        'ioloop':   ioloop,
662        'switch':   switch,
663        'htpasswd': None,
664        'debug':    args.debug,
665    })])
666    ctlserver = HTTPServer(ctl)
667    ctlserver.listen(args.ctlport, address=args.ctladdress)
668
669    if not args.foreground:
670        daemonize()
671
672    ioloop.start()
673
674
675if __name__ == '__main__':
676    main()
Note: See TracBrowser for help on using the repository browser.