source: etherws/trunk/etherws.py @ 178

Revision 178, 16.5 KB checked in by atzm, 12 years ago (diff)
  • fix error handling
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import fcntl
48import base64
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return -1
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            del self._dict[vid][mac]
121            if not self._dict[vid]:
122                del self._dict[vid]
123            self.dprintf('aged out: [%d] %s\n',
124                         lambda: (vid, mac.encode('hex')))
125            return None
126
127        return entry['port']
128
129    def learn(self, port, frame):
130        mac = frame.src_mac
131        vid = frame.vid
132
133        if vid not in self._dict:
134            self._dict[vid] = {}
135
136        self._dict[vid][mac] = {'time': time.time(), 'port': port}
137        self.dprintf('learned: [%d] %s\n',
138                     lambda: (vid, mac.encode('hex')))
139
140    def delete(self, port):
141        for vid in self._dict.keys():
142            for mac in self._dict[vid].keys():
143                if self._dict[vid][mac]['port'] is port:
144                    del self._dict[vid][mac]
145                    self.dprintf('deleted: [%d] %s\n',
146                                 lambda: (vid, mac.encode('hex')))
147            if not self._dict[vid]:
148                del self._dict[vid]
149
150
151class SwitchingHub(DebugMixIn):
152    def __init__(self, fdb, debug=False):
153        self._fdb = fdb
154        self._debug = debug
155        self._ports = []
156
157    def register_port(self, port):
158        self._ports.append(port)
159
160    def unregister_port(self, port):
161        self._fdb.delete(port)
162        self._ports.remove(port)
163
164    def forward(self, src_port, frame):
165        try:
166            if not frame.src_multicast:
167                self._fdb.learn(src_port, frame)
168
169            if not frame.dst_multicast:
170                dst_port = self._fdb.lookup(frame)
171
172                if dst_port:
173                    self._unicast(frame, dst_port)
174                    return
175
176            self._broadcast(frame, src_port)
177
178        except:  # ex. received invalid frame
179            traceback.print_exc()
180
181    def _unicast(self, frame, port):
182        port.write_message(frame.data, True)
183        self.dprintf('sent unicast: [%d] %s -> %s\n',
184                     lambda: (frame.vid,
185                              frame.src_mac.encode('hex'),
186                              frame.dst_mac.encode('hex')))
187
188    def _broadcast(self, frame, *except_ports):
189        ports = self._ports[:]
190        for port in except_ports:
191            ports.remove(port)
192        for port in ports:
193            port.write_message(frame.data, True)
194        self.dprintf('sent broadcast: [%d] %s -> %s\n',
195                     lambda: (frame.vid,
196                              frame.src_mac.encode('hex'),
197                              frame.dst_mac.encode('hex')))
198
199
200class TapHandler(DebugMixIn):
201    READ_SIZE = 65535
202
203    def __init__(self, ioloop, switch, dev, debug=False):
204        self._ioloop = ioloop
205        self._switch = switch
206        self._dev = dev
207        self._debug = debug
208        self._tap = None
209
210    @property
211    def closed(self):
212        return not self._tap
213
214    def open(self):
215        if not self.closed:
216            raise ValueError('already opened')
217        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
218        self._tap.up()
219        self._switch.register_port(self)
220        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
221
222    def close(self):
223        if self.closed:
224            raise ValueError('I/O operation on closed tap')
225        self._ioloop.remove_handler(self.fileno())
226        self._switch.unregister_port(self)
227        self._tap.close()
228        self._tap = None
229
230    def fileno(self):
231        if self.closed:
232            raise ValueError('I/O operation on closed tap')
233        return self._tap.fileno()
234
235    def write_message(self, message, binary=False):
236        if self.closed:
237            raise ValueError('I/O operation on closed tap')
238        self._tap.write(message)
239
240    def __call__(self, fd, events):
241        try:
242            self._switch.forward(self, EthernetFrame(self._read()))
243            return
244        except:
245            traceback.print_exc()
246        self.close()
247
248    def _read(self):
249        if self.closed:
250            raise ValueError('I/O operation on closed tap')
251        buf = []
252        while True:
253            buf.append(self._tap.read(self.READ_SIZE))
254            if len(buf[-1]) < self.READ_SIZE:
255                break
256        return ''.join(buf)
257
258
259class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
260    def __init__(self, app, req, switch, debug=False):
261        super(EtherWebSocketHandler, self).__init__(app, req)
262        self._switch = switch
263        self._debug = debug
264
265    def open(self):
266        self._switch.register_port(self)
267        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
268
269    def on_message(self, message):
270        self._switch.forward(self, EthernetFrame(message))
271
272    def on_close(self):
273        self._switch.unregister_port(self)
274        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
275
276
277class EtherWebSocketClient(DebugMixIn):
278    def __init__(self, ioloop, switch, url, cred=None, debug=False):
279        self._ioloop = ioloop
280        self._switch = switch
281        self._url = url
282        self._debug = debug
283        self._sock = None
284        self._options = {}
285
286        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
287            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
288            auth = ['Authorization: Basic %s' % token]
289            self._options['header'] = auth
290
291    @property
292    def closed(self):
293        return not self._sock
294
295    def open(self):
296        if not self.closed:
297            raise websocket.WebSocketException('already opened')
298        self._sock = websocket.WebSocket()
299        self._sock.connect(self._url, **self._options)
300        self._switch.register_port(self)
301        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
302        self.dprintf('connected: %s\n', lambda: self._url)
303
304    def close(self):
305        if self.closed:
306            raise websocket.WebSocketException('already closed')
307        self._ioloop.remove_handler(self.fileno())
308        self._switch.unregister_port(self)
309        self._sock.close()
310        self._sock = None
311        self.dprintf('disconnected: %s\n', lambda: self._url)
312
313    def fileno(self):
314        if self.closed:
315            raise websocket.WebSocketException('closed socket')
316        return self._sock.io_sock.fileno()
317
318    def write_message(self, message, binary=False):
319        if self.closed:
320            raise websocket.WebSocketException('closed socket')
321        if binary:
322            flag = websocket.ABNF.OPCODE_BINARY
323        else:
324            flag = websocket.ABNF.OPCODE_TEXT
325        self._sock.send(message, flag)
326
327    def __call__(self, fd, events):
328        try:
329            data = self._sock.recv()
330            if data is not None:
331                self._switch.forward(self, EthernetFrame(data))
332                return
333        except:
334            traceback.print_exc()
335        self.close()
336
337
338class Htpasswd(object):
339    def __init__(self, path):
340        self._path = path
341        self._stat = None
342        self._data = {}
343
344    def auth(self, name, passwd):
345        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
346        return self._data.get(name) == passwd
347
348    def load(self):
349        old_stat = self._stat
350
351        with open(self._path) as fp:
352            fileno = fp.fileno()
353            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
354            self._stat = os.fstat(fileno)
355
356            unchanged = old_stat and \
357                        old_stat.st_ino == self._stat.st_ino and \
358                        old_stat.st_dev == self._stat.st_dev and \
359                        old_stat.st_mtime == self._stat.st_mtime
360
361            if not unchanged:
362                self._data = self._parse(fp)
363
364    def _parse(self, fp):
365        data = {}
366        for line in fp:
367            line = line.strip()
368            if 0 <= line.find(':'):
369                name, passwd = line.split(':', 1)
370                if passwd.startswith('{SHA}'):
371                    data[name] = passwd[5:]
372        return data
373
374
375def wrap_basic_auth(handler_class, htpasswd_path):
376    if not htpasswd_path:
377        return handler_class
378
379    old_execute = handler_class._execute
380    htpasswd = Htpasswd(htpasswd_path)
381
382    def execute(self, transforms, *args, **kwargs):
383        def auth_required():
384            self.stream.write(tornado.escape.utf8(
385                'HTTP/1.1 401 Authorization Required\r\n'
386                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
387            ))
388            self.stream.close()
389
390        creds = self.request.headers.get('Authorization')
391
392        if not creds or not creds.startswith('Basic '):
393            return auth_required()
394
395        try:
396            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
397            htpasswd.load()
398
399            if not htpasswd.auth(name, passwd):
400                return auth_required()
401
402            return old_execute(self, transforms, *args, **kwargs)
403
404        except:
405            return auth_required()
406
407    handler_class._execute = execute
408    return handler_class
409
410
411def daemonize(nochdir=False, noclose=False):
412    if os.fork() > 0:
413        sys.exit(0)
414
415    os.setsid()
416
417    if os.fork() > 0:
418        sys.exit(0)
419
420    if not nochdir:
421        os.chdir('/')
422
423    if not noclose:
424        os.umask(0)
425        sys.stdin.close()
426        sys.stdout.close()
427        sys.stderr.close()
428        os.close(0)
429        os.close(1)
430        os.close(2)
431        sys.stdin = open(os.devnull)
432        sys.stdout = open(os.devnull, 'a')
433        sys.stderr = open(os.devnull, 'a')
434
435
436def realpath(ns, *keys):
437    for k in keys:
438        v = getattr(ns, k, None)
439        if v is not None:
440            v = os.path.realpath(v)
441            setattr(ns, k, v)
442            open(v).close()  # check readable
443    return ns
444
445
446def server_main(args):
447    realpath(args, 'keyfile', 'certfile', 'htpasswd')
448
449    if args.keyfile and args.certfile:
450        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
451    elif args.keyfile or args.certfile:
452        raise ValueError('both keyfile and certfile are required')
453    else:
454        ssl_options = None
455
456    if args.port is None:
457        if ssl_options:
458            args.port = 443
459        else:
460            args.port = 80
461    elif not (0 <= args.port <= 65535):
462        raise ValueError('invalid port: %s' % args.port)
463
464    if args.ageout <= 0:
465        raise ValueError('invalid ageout: %s' % args.ageout)
466
467    ioloop = tornado.ioloop.IOLoop.instance()
468    fdb = FDB(ageout=args.ageout, debug=args.debug)
469    switch = SwitchingHub(fdb, debug=args.debug)
470
471    handler = wrap_basic_auth(EtherWebSocketHandler, args.htpasswd)
472    srv = (args.path, handler, {'switch': switch, 'debug': args.debug})
473    app = tornado.web.Application([srv])
474    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
475    server.listen(args.port, address=args.address)
476
477    for dev in args.device:
478        tap = TapHandler(ioloop, switch, dev, debug=args.debug)
479        tap.open()
480
481    if not args.foreground:
482        daemonize()
483
484    ioloop.start()
485
486
487def client_main(args):
488    realpath(args, 'cacerts')
489
490    if args.debug:
491        websocket.enableTrace(True)
492
493    if args.insecure:
494        websocket._SSLSocketWrapper = \
495            lambda s: ssl.wrap_socket(s)
496    else:
497        websocket._SSLSocketWrapper = \
498            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
499                                      ca_certs=args.cacerts)
500
501    if args.ageout <= 0:
502        raise ValueError('invalid ageout: %s' % args.ageout)
503
504    if args.user and args.passwd is None:
505        args.passwd = getpass.getpass()
506
507    cred = {'user': args.user, 'passwd': args.passwd}
508    ioloop = tornado.ioloop.IOLoop.instance()
509    fdb = FDB(ageout=args.ageout, debug=args.debug)
510    switch = SwitchingHub(fdb, debug=args.debug)
511
512    for uri in args.uri:
513        client = EtherWebSocketClient(ioloop, switch, uri, cred, args.debug)
514        client.open()
515
516    for dev in args.device:
517        tap = TapHandler(ioloop, switch, dev, debug=args.debug)
518        tap.open()
519
520    if not args.foreground:
521        daemonize()
522
523    ioloop.start()
524
525
526def main():
527    parser = argparse.ArgumentParser()
528    parser.add_argument('--device', action='append', default=[])
529    parser.add_argument('--ageout', action='store', type=int, default=300)
530    parser.add_argument('--foreground', action='store_true', default=False)
531    parser.add_argument('--debug', action='store_true', default=False)
532
533    subparsers = parser.add_subparsers(dest='subcommand')
534
535    parser_s = subparsers.add_parser('server')
536    parser_s.add_argument('--address', action='store', default='')
537    parser_s.add_argument('--port', action='store', type=int)
538    parser_s.add_argument('--path', action='store', default='/')
539    parser_s.add_argument('--htpasswd', action='store')
540    parser_s.add_argument('--keyfile', action='store')
541    parser_s.add_argument('--certfile', action='store')
542
543    parser_c = subparsers.add_parser('client')
544    parser_c.add_argument('--uri', action='append', default=[])
545    parser_c.add_argument('--insecure', action='store_true', default=False)
546    parser_c.add_argument('--cacerts', action='store')
547    parser_c.add_argument('--user', action='store')
548    parser_c.add_argument('--passwd', action='store')
549
550    args = parser.parse_args()
551
552    if args.subcommand == 'server':
553        server_main(args)
554    elif args.subcommand == 'client':
555        client_main(args)
556
557
558if __name__ == '__main__':
559    main()
Note: See TracBrowser for help on using the repository browser.