source: etherws/trunk/etherws.py @ 176

Revision 176, 16.4 KB checked in by atzm, 12 years ago (diff)
  • reduce string copy
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import fcntl
48import base64
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return -1
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid, None)
112        if not group:
113            return None
114
115        entry = group.get(mac, None)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            del self._dict[vid][mac]
121            if not self._dict[vid]:
122                del self._dict[vid]
123            self.dprintf('aged out: [%d] %s\n',
124                         lambda: (vid, mac.encode('hex')))
125            return None
126
127        return entry['port']
128
129    def learn(self, port, frame):
130        mac = frame.src_mac
131        vid = frame.vid
132
133        if vid not in self._dict:
134            self._dict[vid] = {}
135
136        self._dict[vid][mac] = {'time': time.time(), 'port': port}
137        self.dprintf('learned: [%d] %s\n',
138                     lambda: (vid, mac.encode('hex')))
139
140    def delete(self, port):
141        for vid in self._dict.keys():
142            for mac in self._dict[vid].keys():
143                if self._dict[vid][mac]['port'] is port:
144                    del self._dict[vid][mac]
145                    self.dprintf('deleted: [%d] %s\n',
146                                 lambda: (vid, mac.encode('hex')))
147            if not self._dict[vid]:
148                del self._dict[vid]
149
150
151class SwitchingHub(DebugMixIn):
152    def __init__(self, fdb, debug=False):
153        self._fdb = fdb
154        self._debug = debug
155        self._ports = []
156
157    def register_port(self, port):
158        self._ports.append(port)
159
160    def unregister_port(self, port):
161        self._fdb.delete(port)
162        self._ports.remove(port)
163
164    def forward(self, src_port, frame):
165        try:
166            if not frame.src_multicast:
167                self._fdb.learn(src_port, frame)
168
169            if not frame.dst_multicast:
170                dst_port = self._fdb.lookup(frame)
171
172                if dst_port:
173                    self._unicast(frame, dst_port)
174                    return
175
176            self._broadcast(frame, src_port)
177
178        except:  # ex. received invalid frame
179            traceback.print_exc()
180
181    def _unicast(self, frame, port):
182        port.write_message(frame.data, True)
183        self.dprintf('sent unicast: [%d] %s -> %s\n',
184                     lambda: (frame.vid,
185                              frame.src_mac.encode('hex'),
186                              frame.dst_mac.encode('hex')))
187
188    def _broadcast(self, frame, *except_ports):
189        ports = self._ports[:]
190        for port in except_ports:
191            ports.remove(port)
192        for port in ports:
193            port.write_message(frame.data, True)
194        self.dprintf('sent broadcast: [%d] %s -> %s\n',
195                     lambda: (frame.vid,
196                              frame.src_mac.encode('hex'),
197                              frame.dst_mac.encode('hex')))
198
199
200class TapHandler(DebugMixIn):
201    READ_SIZE = 65535
202
203    def __init__(self, switch, dev, debug=False):
204        self._switch = switch
205        self._dev = dev
206        self._debug = debug
207        self._tap = None
208
209    @property
210    def closed(self):
211        return not self._tap
212
213    def open(self):
214        if not self.closed:
215            raise ValueError('already opened')
216        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
217        self._tap.up()
218        self._switch.register_port(self)
219
220    def close(self):
221        if self.closed:
222            raise ValueError('I/O operation on closed tap')
223        self._switch.unregister_port(self)
224        self._tap.close()
225        self._tap = None
226
227    def fileno(self):
228        if self.closed:
229            raise ValueError('I/O operation on closed tap')
230        return self._tap.fileno()
231
232    def write_message(self, message, binary=False):
233        if self.closed:
234            raise ValueError('I/O operation on closed tap')
235        self._tap.write(message)
236
237    def __call__(self, fd, events):
238        try:
239            self._switch.forward(self, EthernetFrame(self._read()))
240            return
241        except:
242            traceback.print_exc()
243        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
244
245    def _read(self):
246        if self.closed:
247            raise ValueError('I/O operation on closed tap')
248        buf = []
249        while True:
250            buf.append(self._tap.read(self.READ_SIZE))
251            if len(buf[-1]) < self.READ_SIZE:
252                break
253        return ''.join(buf)
254
255
256class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
257    def __init__(self, app, req, switch, debug=False):
258        super(EtherWebSocketHandler, self).__init__(app, req)
259        self._switch = switch
260        self._debug = debug
261
262    def open(self):
263        self._switch.register_port(self)
264        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
265
266    def on_message(self, message):
267        self._switch.forward(self, EthernetFrame(message))
268
269    def on_close(self):
270        self._switch.unregister_port(self)
271        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
272
273
274class EtherWebSocketClient(DebugMixIn):
275    def __init__(self, switch, url, cred=None, debug=False):
276        self._switch = switch
277        self._url = url
278        self._debug = debug
279        self._sock = None
280        self._options = {}
281
282        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
283            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
284            auth = ['Authorization: Basic %s' % token]
285            self._options['header'] = auth
286
287    @property
288    def closed(self):
289        return not self._sock
290
291    def open(self):
292        if not self.closed:
293            raise websocket.WebSocketException('already opened')
294        self._sock = websocket.WebSocket()
295        self._sock.connect(self._url, **self._options)
296        self._switch.register_port(self)
297        self.dprintf('connected: %s\n', lambda: self._url)
298
299    def close(self):
300        if self.closed:
301            raise websocket.WebSocketException('already closed')
302        self._switch.unregister_port(self)
303        self._sock.close()
304        self._sock = None
305        self.dprintf('disconnected: %s\n', lambda: self._url)
306
307    def fileno(self):
308        if self.closed:
309            raise websocket.WebSocketException('closed socket')
310        return self._sock.io_sock.fileno()
311
312    def write_message(self, message, binary=False):
313        if self.closed:
314            raise websocket.WebSocketException('closed socket')
315        if binary:
316            flag = websocket.ABNF.OPCODE_BINARY
317        else:
318            flag = websocket.ABNF.OPCODE_TEXT
319        self._sock.send(message, flag)
320
321    def __call__(self, fd, events):
322        try:
323            data = self._sock.recv()
324            if data is not None:
325                self._switch.forward(self, EthernetFrame(data))
326                return
327        except:
328            traceback.print_exc()
329        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
330
331
332class Htpasswd(object):
333    def __init__(self, path):
334        self._path = path
335        self._stat = None
336        self._data = {}
337
338    def auth(self, name, passwd):
339        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
340        return self._data.get(name) == passwd
341
342    def load(self):
343        old_stat = self._stat
344
345        with open(self._path) as fp:
346            fileno = fp.fileno()
347            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
348            self._stat = os.fstat(fileno)
349
350            unchanged = old_stat and \
351                        old_stat.st_ino == self._stat.st_ino and \
352                        old_stat.st_dev == self._stat.st_dev and \
353                        old_stat.st_mtime == self._stat.st_mtime
354
355            if not unchanged:
356                self._data = self._parse(fp)
357
358    def _parse(self, fp):
359        data = {}
360        for line in fp:
361            line = line.strip()
362            if 0 <= line.find(':'):
363                name, passwd = line.split(':', 1)
364                if passwd.startswith('{SHA}'):
365                    data[name] = passwd[5:]
366        return data
367
368
369def wrap_basic_auth(handler_class, htpasswd_path):
370    if not htpasswd_path:
371        return handler_class
372
373    old_execute = handler_class._execute
374    htpasswd = Htpasswd(htpasswd_path)
375
376    def execute(self, transforms, *args, **kwargs):
377        def auth_required():
378            self.stream.write(tornado.escape.utf8(
379                'HTTP/1.1 401 Authorization Required\r\n'
380                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
381            ))
382            self.stream.close()
383
384        creds = self.request.headers.get('Authorization')
385
386        if not creds or not creds.startswith('Basic '):
387            return auth_required()
388
389        try:
390            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
391            htpasswd.load()
392
393            if not htpasswd.auth(name, passwd):
394                return auth_required()
395
396            return old_execute(self, transforms, *args, **kwargs)
397
398        except:
399            return auth_required()
400
401    handler_class._execute = execute
402    return handler_class
403
404
405def daemonize(nochdir=False, noclose=False):
406    if os.fork() > 0:
407        sys.exit(0)
408
409    os.setsid()
410
411    if os.fork() > 0:
412        sys.exit(0)
413
414    if not nochdir:
415        os.chdir('/')
416
417    if not noclose:
418        os.umask(0)
419        sys.stdin.close()
420        sys.stdout.close()
421        sys.stderr.close()
422        os.close(0)
423        os.close(1)
424        os.close(2)
425        sys.stdin = open(os.devnull)
426        sys.stdout = open(os.devnull, 'a')
427        sys.stderr = open(os.devnull, 'a')
428
429
430def realpath(ns, *keys):
431    for k in keys:
432        v = getattr(ns, k, None)
433        if v is not None:
434            v = os.path.realpath(v)
435            setattr(ns, k, v)
436            open(v).close()  # check readable
437    return ns
438
439
440def server_main(args):
441    realpath(args, 'keyfile', 'certfile', 'htpasswd')
442
443    if args.keyfile and args.certfile:
444        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
445    elif args.keyfile or args.certfile:
446        raise ValueError('both keyfile and certfile are required')
447    else:
448        ssl_options = None
449
450    if args.port is None:
451        if ssl_options:
452            args.port = 443
453        else:
454            args.port = 80
455    elif not (0 <= args.port <= 65535):
456        raise ValueError('invalid port: %s' % args.port)
457
458    if args.ageout <= 0:
459        raise ValueError('invalid ageout: %s' % args.ageout)
460
461    ioloop = tornado.ioloop.IOLoop.instance()
462    fdb = FDB(ageout=args.ageout, debug=args.debug)
463    switch = SwitchingHub(fdb, debug=args.debug)
464
465    handler = wrap_basic_auth(EtherWebSocketHandler, args.htpasswd)
466    srv = (args.path, handler, {'switch': switch, 'debug': args.debug})
467    app = tornado.web.Application([srv])
468    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
469    server.listen(args.port, address=args.address)
470
471    for dev in args.device:
472        tap = TapHandler(switch, dev, debug=args.debug)
473        tap.open()
474        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
475
476    if not args.foreground:
477        daemonize()
478
479    ioloop.start()
480
481
482def client_main(args):
483    realpath(args, 'cacerts')
484
485    if args.debug:
486        websocket.enableTrace(True)
487
488    if args.insecure:
489        websocket._SSLSocketWrapper = \
490            lambda s: ssl.wrap_socket(s)
491    else:
492        websocket._SSLSocketWrapper = \
493            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
494                                      ca_certs=args.cacerts)
495
496    if args.ageout <= 0:
497        raise ValueError('invalid ageout: %s' % args.ageout)
498
499    if args.user and args.passwd is None:
500        args.passwd = getpass.getpass()
501
502    cred = {'user': args.user, 'passwd': args.passwd}
503    ioloop = tornado.ioloop.IOLoop.instance()
504    fdb = FDB(ageout=args.ageout, debug=args.debug)
505    switch = SwitchingHub(fdb, debug=args.debug)
506
507    for uri in args.uri:
508        client = EtherWebSocketClient(switch, uri, cred, args.debug)
509        client.open()
510        ioloop.add_handler(client.fileno(), client, ioloop.READ)
511
512    for dev in args.device:
513        tap = TapHandler(switch, dev, debug=args.debug)
514        tap.open()
515        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
516
517    if not args.foreground:
518        daemonize()
519
520    ioloop.start()
521
522
523def main():
524    parser = argparse.ArgumentParser()
525    parser.add_argument('--device', action='append', default=[])
526    parser.add_argument('--ageout', action='store', type=int, default=300)
527    parser.add_argument('--foreground', action='store_true', default=False)
528    parser.add_argument('--debug', action='store_true', default=False)
529
530    subparsers = parser.add_subparsers(dest='subcommand')
531
532    parser_s = subparsers.add_parser('server')
533    parser_s.add_argument('--address', action='store', default='')
534    parser_s.add_argument('--port', action='store', type=int)
535    parser_s.add_argument('--path', action='store', default='/')
536    parser_s.add_argument('--htpasswd', action='store')
537    parser_s.add_argument('--keyfile', action='store')
538    parser_s.add_argument('--certfile', action='store')
539
540    parser_c = subparsers.add_parser('client')
541    parser_c.add_argument('--uri', action='append', default=[])
542    parser_c.add_argument('--insecure', action='store_true', default=False)
543    parser_c.add_argument('--cacerts', action='store')
544    parser_c.add_argument('--user', action='store')
545    parser_c.add_argument('--passwd', action='store')
546
547    args = parser.parse_args()
548
549    if args.subcommand == 'server':
550        server_main(args)
551    elif args.subcommand == 'client':
552        client_main(args)
553
554
555if __name__ == '__main__':
556    main()
Note: See TracBrowser for help on using the repository browser.