source: etherws/trunk/etherws.py @ 185

Revision 185, 21.8 KB checked in by atzm, 12 years ago (diff)
  • fixed a bug, raise exception when controller uses authentication
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import json
48import fcntl
49import base64
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import tornado
56import websocket
57
58from tornado.web import Application, RequestHandler
59from tornado.websocket import WebSocketHandler
60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
64
65
66class DebugMixIn(object):
67    def dprintf(self, msg, func=lambda: ()):
68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
70            sys.stderr.write(prefix + (msg % func()))
71
72
73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
80
81    @property
82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
101        return 0
102
103
104class FDB(DebugMixIn):
105    def __init__(self, ageout, debug=False):
106        self._ageout = ageout
107        self._debug = debug
108        self._dict = {}
109
110    def lookup(self, frame):
111        mac = frame.dst_mac
112        vid = frame.vid
113
114        group = self._dict.get(vid)
115        if not group:
116            return None
117
118        entry = group.get(mac)
119        if not entry:
120            return None
121
122        if time.time() - entry['time'] > self._ageout:
123            port = self._dict[vid][mac]['port']
124            del self._dict[vid][mac]
125            if not self._dict[vid]:
126                del self._dict[vid]
127            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
128                         lambda: (port.number, vid, mac.encode('hex')))
129            return None
130
131        return entry['port']
132
133    def learn(self, port, frame):
134        mac = frame.src_mac
135        vid = frame.vid
136
137        if vid not in self._dict:
138            self._dict[vid] = {}
139
140        self._dict[vid][mac] = {'time': time.time(), 'port': port}
141        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
142                     lambda: (port.number, vid, mac.encode('hex')))
143
144    def delete(self, port):
145        for vid in self._dict.keys():
146            for mac in self._dict[vid].keys():
147                if self._dict[vid][mac]['port'].number == port.number:
148                    del self._dict[vid][mac]
149                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
150                                 lambda: (port.number, vid, mac.encode('hex')))
151            if not self._dict[vid]:
152                del self._dict[vid]
153
154
155class SwitchPort(object):
156    def __init__(self, number, interface):
157        self.number = number
158        self.interface = interface
159        self.tx = 0
160        self.rx = 0
161        self.shut = False
162
163    @staticmethod
164    def cmp_by_number(x, y):
165        return cmp(x.number, y.number)
166
167
168class SwitchingHub(DebugMixIn):
169    def __init__(self, fdb, debug=False):
170        self._fdb = fdb
171        self._debug = debug
172        self._table = {}
173        self._next = 1
174
175    @property
176    def portlist(self):
177        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
178
179    def shut_port(self, portnum, flag=True):
180        self._table[portnum].shut = flag
181
182    def get_port(self, portnum):
183        return self._table[portnum]
184
185    def register_port(self, interface):
186        interface._switch_portnum = self._next  # XXX
187        self._table[self._next] = SwitchPort(self._next, interface)
188        self._next += 1
189
190    def unregister_port(self, interface):
191        self._fdb.delete(self._table[interface._switch_portnum])
192        del self._table[interface._switch_portnum]
193        del interface._switch_portnum
194
195    def send(self, dst_interfaces, frame):
196        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
197                        if not self._table[i._switch_portnum].shut),
198                       cmp=SwitchPort.cmp_by_number)
199
200        for p in ports:
201            p.interface.write_message(frame.data, True)
202            p.tx += 1
203
204        if ports:
205            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
206                         lambda: (','.join(str(p.number) for p in ports),
207                                  frame.vid,
208                                  frame.src_mac.encode('hex'),
209                                  frame.dst_mac.encode('hex')))
210
211    def receive(self, src_interface, frame):
212        port = self._table[src_interface._switch_portnum]
213
214        if not port.shut:
215            port.rx += 1
216            self._forward(port, frame)
217
218    def _forward(self, src_port, frame):
219        try:
220            if not frame.src_multicast:
221                self._fdb.learn(src_port, frame)
222
223            if not frame.dst_multicast:
224                dst_port = self._fdb.lookup(frame)
225
226                if dst_port:
227                    self.send([dst_port.interface], frame)
228                    return
229
230            ports = set(self._table.itervalues()) - set([src_port])
231            self.send((p.interface for p in ports), frame)
232
233        except:  # ex. received invalid frame
234            traceback.print_exc()
235
236
237class Htpasswd(object):
238    def __init__(self, path):
239        self._path = path
240        self._stat = None
241        self._data = {}
242
243    def auth(self, name, passwd):
244        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
245        return self._data.get(name) == passwd
246
247    def load(self):
248        old_stat = self._stat
249
250        with open(self._path) as fp:
251            fileno = fp.fileno()
252            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
253            self._stat = os.fstat(fileno)
254
255            unchanged = old_stat and \
256                        old_stat.st_ino == self._stat.st_ino and \
257                        old_stat.st_dev == self._stat.st_dev and \
258                        old_stat.st_mtime == self._stat.st_mtime
259
260            if not unchanged:
261                self._data = self._parse(fp)
262
263        return self
264
265    def _parse(self, fp):
266        data = {}
267        for line in fp:
268            line = line.strip()
269            if 0 <= line.find(':'):
270                name, passwd = line.split(':', 1)
271                if passwd.startswith('{SHA}'):
272                    data[name] = passwd[5:]
273        return data
274
275
276class BasicAuthMixIn(object):
277    def _execute(self, transforms, *args, **kwargs):
278        def do_execute():
279            sp = super(BasicAuthMixIn, self)
280            return sp._execute(transforms, *args, **kwargs)
281
282        def auth_required():
283            stream = getattr(self, 'stream', self.request.connection.stream)
284            stream.write(tornado.escape.utf8(
285                'HTTP/1.1 401 Authorization Required\r\n'
286                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
287            ))
288            stream.close()
289
290        try:
291            if not self._htpasswd:
292                return do_execute()
293
294            creds = self.request.headers.get('Authorization')
295
296            if not creds or not creds.startswith('Basic '):
297                return auth_required()
298
299            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
300
301            if self._htpasswd.load().auth(name, passwd):
302                return do_execute()
303        except:
304            traceback.print_exc()
305
306        return auth_required()
307
308
309class TapHandler(DebugMixIn):
310    READ_SIZE = 65535
311
312    def __init__(self, ioloop, switch, dev, debug=False):
313        self._ioloop = ioloop
314        self._switch = switch
315        self._dev = dev
316        self._debug = debug
317        self._tap = None
318
319    @property
320    def closed(self):
321        return not self._tap
322
323    def get_type(self):
324        return 'tap'
325
326    def get_name(self):
327        if self.closed:
328            return self._dev
329        return self._tap.name
330
331    def open(self):
332        if not self.closed:
333            raise ValueError('already opened')
334        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
335        self._tap.up()
336        self._switch.register_port(self)
337        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
338
339    def close(self):
340        if self.closed:
341            raise ValueError('I/O operation on closed tap')
342        self._ioloop.remove_handler(self.fileno())
343        self._switch.unregister_port(self)
344        self._tap.close()
345        self._tap = None
346
347    def fileno(self):
348        if self.closed:
349            raise ValueError('I/O operation on closed tap')
350        return self._tap.fileno()
351
352    def write_message(self, message, binary=False):
353        if self.closed:
354            raise ValueError('I/O operation on closed tap')
355        self._tap.write(message)
356
357    def __call__(self, fd, events):
358        try:
359            self._switch.receive(self, EthernetFrame(self._read()))
360            return
361        except:
362            traceback.print_exc()
363        self.close()
364
365    def _read(self):
366        if self.closed:
367            raise ValueError('I/O operation on closed tap')
368        buf = []
369        while True:
370            buf.append(self._tap.read(self.READ_SIZE))
371            if len(buf[-1]) < self.READ_SIZE:
372                break
373        return ''.join(buf)
374
375
376class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
377    def __init__(self, app, req, switch, htpasswd=None, debug=False):
378        super(EtherWebSocketHandler, self).__init__(app, req)
379        self._switch = switch
380        self._htpasswd = htpasswd
381        self._debug = debug
382
383    def get_type(self):
384        return 'server'
385
386    def get_name(self):
387        return self.request.remote_ip
388
389    def open(self):
390        self._switch.register_port(self)
391        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
392
393    def on_message(self, message):
394        self._switch.receive(self, EthernetFrame(message))
395
396    def on_close(self):
397        self._switch.unregister_port(self)
398        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
399
400
401class EtherWebSocketClient(DebugMixIn):
402    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
403        self._ioloop = ioloop
404        self._switch = switch
405        self._url = url
406        self._ssl = ssl_
407        self._debug = debug
408        self._sock = None
409        self._options = {}
410
411        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
412            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
413            auth = ['Authorization: Basic %s' % token]
414            self._options['header'] = auth
415
416    @property
417    def closed(self):
418        return not self._sock
419
420    def get_type(self):
421        return 'client'
422
423    def get_name(self):
424        return self._url
425
426    def open(self):
427        sslwrap = websocket._SSLSocketWrapper
428
429        if not self.closed:
430            raise websocket.WebSocketException('already opened')
431
432        if self._ssl:
433            websocket._SSLSocketWrapper = self._ssl
434
435        try:
436            self._sock = websocket.WebSocket()
437            self._sock.connect(self._url, **self._options)
438            self._switch.register_port(self)
439            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
440            self.dprintf('connected: %s\n', lambda: self._url)
441        finally:
442            websocket._SSLSocketWrapper = sslwrap
443
444    def close(self):
445        if self.closed:
446            raise websocket.WebSocketException('already closed')
447        self._ioloop.remove_handler(self.fileno())
448        self._switch.unregister_port(self)
449        self._sock.close()
450        self._sock = None
451        self.dprintf('disconnected: %s\n', lambda: self._url)
452
453    def fileno(self):
454        if self.closed:
455            raise websocket.WebSocketException('closed socket')
456        return self._sock.io_sock.fileno()
457
458    def write_message(self, message, binary=False):
459        if self.closed:
460            raise websocket.WebSocketException('closed socket')
461        if binary:
462            flag = websocket.ABNF.OPCODE_BINARY
463        else:
464            flag = websocket.ABNF.OPCODE_TEXT
465        self._sock.send(message, flag)
466
467    def __call__(self, fd, events):
468        try:
469            data = self._sock.recv()
470            if data is not None:
471                self._switch.receive(self, EthernetFrame(data))
472                return
473        except:
474            traceback.print_exc()
475        self.close()
476
477
478class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
479    NAMESPACE = 'etherws.control'
480
481    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
482        super(EtherWebSocketControlHandler, self).__init__(app, req)
483        self._ioloop = ioloop
484        self._switch = switch
485        self._htpasswd = htpasswd
486        self._debug = debug
487
488    def post(self):
489        id_ = None
490
491        try:
492            req = json.loads(self.request.body)
493            method = req['method']
494            params = req['params']
495            id_ = req.get('id')
496
497            if not method.startswith(self.NAMESPACE + '.'):
498                raise ValueError('invalid method: %s' % method)
499
500            if not isinstance(params, list):
501                raise ValueError('invalid params: %s' % params)
502
503            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
504            result = getattr(self, handler)(params)
505            self.finish({'result': result, 'error': None, 'id': id_})
506
507        except Exception as e:
508            traceback.print_exc()
509            self.finish({'result': None, 'error': str(e), 'id': id_})
510
511    def handle_listPort(self, params):
512        list_ = []
513        for port in self._switch.portlist:
514            list_.append({
515                'port': port.number,
516                'type': port.interface.get_type(),
517                'name': port.interface.get_name(),
518                'tx':   port.tx,
519                'rx':   port.rx,
520                'shut': port.shut,
521            })
522        return {'portlist': list_}
523
524    def handle_addPort(self, params):
525        for p in params:
526            getattr(self, '_openport_' + p['type'])(p)
527        return self.handle_listPort(params)
528
529    def handle_delPort(self, params):
530        for p in params:
531            self._switch.get_port(int(p['port'])).interface.close()
532        return self.handle_listPort(params)
533
534    def handle_shutPort(self, params):
535        for p in params:
536            self._switch.shut_port(int(p['port']), bool(p['flag']))
537        return self.handle_listPort(params)
538
539    def _openport_tap(self, p):
540        dev = p['device']
541        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug)
542        tap.open()
543
544    def _openport_client(self, p):
545        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts'))
546        cred = {'user': p.get('user'), 'passwd': p.get('passwd')}
547        url = p['url']
548        client = EtherWebSocketClient(self._ioloop, self._switch,
549                                      url, ssl_, cred, self._debug)
550        client.open()
551
552    @staticmethod
553    def _ssl_wrapper(insecure, ca_certs):
554        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
555        if insecure:
556            args = {}
557        return lambda sock: ssl.wrap_socket(sock, **args)
558
559
560def daemonize(nochdir=False, noclose=False):
561    if os.fork() > 0:
562        sys.exit(0)
563
564    os.setsid()
565
566    if os.fork() > 0:
567        sys.exit(0)
568
569    if not nochdir:
570        os.chdir('/')
571
572    if not noclose:
573        os.umask(0)
574        sys.stdin.close()
575        sys.stdout.close()
576        sys.stderr.close()
577        os.close(0)
578        os.close(1)
579        os.close(2)
580        sys.stdin = open(os.devnull)
581        sys.stdout = open(os.devnull, 'a')
582        sys.stderr = open(os.devnull, 'a')
583
584
585def main():
586    def realpath(ns, *keys):
587        for k in keys:
588            v = getattr(ns, k, None)
589            if v is not None:
590                v = os.path.realpath(v)
591                open(v).close()  # check readable
592                setattr(ns, k, v)
593
594    def checkpath(ns, path):
595        val = getattr(ns, path, '')
596        if not val.startswith('/'):
597            raise ValueError('invalid %: %s' % (path, val))
598
599    def getsslopt(ns, key, cert):
600        kval = getattr(ns, key, None)
601        cval = getattr(ns, cert, None)
602        if kval and cval:
603            return {'keyfile': kval, 'certfile': cval}
604        elif kval or cval:
605            raise ValueError('both %s and %s are required' % (key, cert))
606        return None
607
608    def setport(ns, port, isssl):
609        val = getattr(ns, port, None)
610        if val is None:
611            if isssl:
612                return setattr(ns, port, 443)
613            return setattr(ns, port, 80)
614        if not (0 <= val <= 65535):
615            raise ValueError('invalid %s: %s' % (port, val))
616
617    def sethtpasswd(ns, htpasswd):
618        val = getattr(ns, htpasswd, None)
619        if val:
620            return setattr(ns, htpasswd, Htpasswd(val))
621
622    parser = argparse.ArgumentParser()
623
624    parser.add_argument('--debug', action='store_true', default=False)
625    parser.add_argument('--foreground', action='store_true', default=False)
626    parser.add_argument('--ageout', action='store', type=int, default=300)
627
628    parser.add_argument('--path', action='store', default='/')
629    parser.add_argument('--host', action='store', default='')
630    parser.add_argument('--port', action='store', type=int)
631    parser.add_argument('--htpasswd', action='store')
632    parser.add_argument('--sslkey', action='store')
633    parser.add_argument('--sslcert', action='store')
634
635    parser.add_argument('--ctlpath', action='store', default='/ctl')
636    parser.add_argument('--ctlhost', action='store', default='')
637    parser.add_argument('--ctlport', action='store', type=int)
638    parser.add_argument('--ctlhtpasswd', action='store')
639    parser.add_argument('--ctlsslkey', action='store')
640    parser.add_argument('--ctlsslcert', action='store')
641
642    args = parser.parse_args()
643
644    #if args.debug:
645    #    websocket.enableTrace(True)
646
647    if args.ageout <= 0:
648        raise ValueError('invalid ageout: %s' % args.ageout)
649
650    realpath(args, 'htpasswd', 'sslkey', 'sslcert')
651    realpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
652
653    checkpath(args, 'path')
654    checkpath(args, 'ctlpath')
655
656    sslopt = getsslopt(args, 'sslkey', 'sslcert')
657    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
658
659    setport(args, 'port', sslopt)
660    setport(args, 'ctlport', ctlsslopt)
661
662    sethtpasswd(args, 'htpasswd')
663    sethtpasswd(args, 'ctlhtpasswd')
664
665    ioloop = IOLoop.instance()
666    fdb = FDB(ageout=args.ageout, debug=args.debug)
667    switch = SwitchingHub(fdb, debug=args.debug)
668
669    if args.port == args.ctlport and args.host == args.ctlhost:
670        if args.path == args.ctlpath:
671            raise ValueError('same path/ctlpath on same host')
672        if args.sslkey != args.ctlsslkey:
673            raise ValueError('differ sslkey/ctlsslkey on same host')
674        if args.sslcert != args.ctlsslcert:
675            raise ValueError('differ sslcert/ctlsslcert on same host')
676
677        app = Application([
678            (args.path, EtherWebSocketHandler, {
679                'switch':   switch,
680                'htpasswd': args.htpasswd,
681                'debug':    args.debug,
682            }),
683            (args.ctlpath, EtherWebSocketControlHandler, {
684                'ioloop':   ioloop,
685                'switch':   switch,
686                'htpasswd': args.ctlhtpasswd,
687                'debug':    args.debug,
688            }),
689        ])
690        server = HTTPServer(app, ssl_options=sslopt)
691        server.listen(args.port, address=args.host)
692
693    else:
694        app = Application([(args.path, EtherWebSocketHandler, {
695            'switch':   switch,
696            'htpasswd': args.htpasswd,
697            'debug':    args.debug,
698        })])
699        server = HTTPServer(app, ssl_options=sslopt)
700        server.listen(args.port, address=args.host)
701
702        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
703            'ioloop':   ioloop,
704            'switch':   switch,
705            'htpasswd': args.ctlhtpasswd,
706            'debug':    args.debug,
707        })])
708        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
709        ctlserver.listen(args.ctlport, address=args.ctlhost)
710
711    if not args.foreground:
712        daemonize()
713
714    ioloop.start()
715
716
717if __name__ == '__main__':
718    main()
Note: See TracBrowser for help on using the repository browser.