source: etherws/trunk/etherws.py @ 181

Revision 181, 16.7 KB checked in by atzm, 12 years ago (diff)
  • select client ssl handler per instance, not class
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import fcntl
48import base64
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return -1
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            del self._dict[vid][mac]
121            if not self._dict[vid]:
122                del self._dict[vid]
123            self.dprintf('aged out: [%d] %s\n',
124                         lambda: (vid, mac.encode('hex')))
125            return None
126
127        return entry['port']
128
129    def learn(self, port, frame):
130        mac = frame.src_mac
131        vid = frame.vid
132
133        if vid not in self._dict:
134            self._dict[vid] = {}
135
136        self._dict[vid][mac] = {'time': time.time(), 'port': port}
137        self.dprintf('learned: [%d] %s\n',
138                     lambda: (vid, mac.encode('hex')))
139
140    def delete(self, port):
141        for vid in self._dict.keys():
142            for mac in self._dict[vid].keys():
143                if self._dict[vid][mac]['port'] is port:
144                    del self._dict[vid][mac]
145                    self.dprintf('deleted: [%d] %s\n',
146                                 lambda: (vid, mac.encode('hex')))
147            if not self._dict[vid]:
148                del self._dict[vid]
149
150
151class SwitchingHub(DebugMixIn):
152    def __init__(self, fdb, debug=False):
153        self._fdb = fdb
154        self._debug = debug
155        self._ports = []
156
157    def register_port(self, port):
158        self._ports.append(port)
159
160    def unregister_port(self, port):
161        self._fdb.delete(port)
162        self._ports.remove(port)
163
164    def forward(self, src_port, frame):
165        try:
166            if not frame.src_multicast:
167                self._fdb.learn(src_port, frame)
168
169            if not frame.dst_multicast:
170                dst_port = self._fdb.lookup(frame)
171
172                if dst_port:
173                    self._unicast(frame, dst_port)
174                    return
175
176            self._broadcast(frame, src_port)
177
178        except:  # ex. received invalid frame
179            traceback.print_exc()
180
181    def _unicast(self, frame, port):
182        port.write_message(frame.data, True)
183        self.dprintf('sent unicast: [%d] %s -> %s\n',
184                     lambda: (frame.vid,
185                              frame.src_mac.encode('hex'),
186                              frame.dst_mac.encode('hex')))
187
188    def _broadcast(self, frame, *except_ports):
189        ports = self._ports[:]
190        for port in except_ports:
191            ports.remove(port)
192        for port in ports:
193            port.write_message(frame.data, True)
194        self.dprintf('sent broadcast: [%d] %s -> %s\n',
195                     lambda: (frame.vid,
196                              frame.src_mac.encode('hex'),
197                              frame.dst_mac.encode('hex')))
198
199
200class Htpasswd(object):
201    def __init__(self, path):
202        self._path = path
203        self._stat = None
204        self._data = {}
205
206    def auth(self, name, passwd):
207        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
208        return self._data.get(name) == passwd
209
210    def load(self):
211        old_stat = self._stat
212
213        with open(self._path) as fp:
214            fileno = fp.fileno()
215            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
216            self._stat = os.fstat(fileno)
217
218            unchanged = old_stat and \
219                        old_stat.st_ino == self._stat.st_ino and \
220                        old_stat.st_dev == self._stat.st_dev and \
221                        old_stat.st_mtime == self._stat.st_mtime
222
223            if not unchanged:
224                self._data = self._parse(fp)
225
226        return self
227
228    def _parse(self, fp):
229        data = {}
230        for line in fp:
231            line = line.strip()
232            if 0 <= line.find(':'):
233                name, passwd = line.split(':', 1)
234                if passwd.startswith('{SHA}'):
235                    data[name] = passwd[5:]
236        return data
237
238
239class TapHandler(DebugMixIn):
240    READ_SIZE = 65535
241
242    def __init__(self, ioloop, switch, dev, debug=False):
243        self._ioloop = ioloop
244        self._switch = switch
245        self._dev = dev
246        self._debug = debug
247        self._tap = None
248
249    @property
250    def closed(self):
251        return not self._tap
252
253    def open(self):
254        if not self.closed:
255            raise ValueError('already opened')
256        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
257        self._tap.up()
258        self._switch.register_port(self)
259        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
260
261    def close(self):
262        if self.closed:
263            raise ValueError('I/O operation on closed tap')
264        self._ioloop.remove_handler(self.fileno())
265        self._switch.unregister_port(self)
266        self._tap.close()
267        self._tap = None
268
269    def fileno(self):
270        if self.closed:
271            raise ValueError('I/O operation on closed tap')
272        return self._tap.fileno()
273
274    def write_message(self, message, binary=False):
275        if self.closed:
276            raise ValueError('I/O operation on closed tap')
277        self._tap.write(message)
278
279    def __call__(self, fd, events):
280        try:
281            self._switch.forward(self, EthernetFrame(self._read()))
282            return
283        except:
284            traceback.print_exc()
285        self.close()
286
287    def _read(self):
288        if self.closed:
289            raise ValueError('I/O operation on closed tap')
290        buf = []
291        while True:
292            buf.append(self._tap.read(self.READ_SIZE))
293            if len(buf[-1]) < self.READ_SIZE:
294                break
295        return ''.join(buf)
296
297
298class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
299    def __init__(self, app, req, switch, htpasswd=None, debug=False):
300        super(EtherWebSocketHandler, self).__init__(app, req)
301        self._switch = switch
302        self._htpasswd = htpasswd
303        self._debug = debug
304
305        if self._htpasswd:
306            self._htpasswd = Htpasswd(self._htpasswd)
307
308    def open(self):
309        self._switch.register_port(self)
310        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
311
312    def on_message(self, message):
313        self._switch.forward(self, EthernetFrame(message))
314
315    def on_close(self):
316        self._switch.unregister_port(self)
317        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
318
319    def _execute(self, transforms, *args, **kwargs):
320        def do_execute():
321            sp = super(EtherWebSocketHandler, self)
322            return sp._execute(transforms, *args, **kwargs)
323
324        def auth_required():
325            self.stream.write(tornado.escape.utf8(
326                'HTTP/1.1 401 Authorization Required\r\n'
327                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
328            ))
329            self.stream.close()
330
331        try:
332            if not self._htpasswd:
333                return do_execute()
334
335            creds = self.request.headers.get('Authorization')
336
337            if not creds or not creds.startswith('Basic '):
338                return auth_required()
339
340            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
341
342            if self._htpasswd.load().auth(name, passwd):
343                return do_execute()
344        except:
345            traceback.print_exc()
346
347        return auth_required()
348
349
350class EtherWebSocketClient(DebugMixIn):
351    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
352        self._ioloop = ioloop
353        self._switch = switch
354        self._url = url
355        self._ssl = ssl_
356        self._debug = debug
357        self._sock = None
358        self._options = {}
359
360        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
361            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
362            auth = ['Authorization: Basic %s' % token]
363            self._options['header'] = auth
364
365    @property
366    def closed(self):
367        return not self._sock
368
369    def open(self):
370        sslwrap = websocket._SSLSocketWrapper
371
372        if not self.closed:
373            raise websocket.WebSocketException('already opened')
374
375        if self._ssl:
376            websocket._SSLSocketWrapper = self._ssl
377
378        try:
379            self._sock = websocket.WebSocket()
380            self._sock.connect(self._url, **self._options)
381            self._switch.register_port(self)
382            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
383            self.dprintf('connected: %s\n', lambda: self._url)
384        finally:
385            websocket._SSLSocketWrapper = sslwrap
386
387    def close(self):
388        if self.closed:
389            raise websocket.WebSocketException('already closed')
390        self._ioloop.remove_handler(self.fileno())
391        self._switch.unregister_port(self)
392        self._sock.close()
393        self._sock = None
394        self.dprintf('disconnected: %s\n', lambda: self._url)
395
396    def fileno(self):
397        if self.closed:
398            raise websocket.WebSocketException('closed socket')
399        return self._sock.io_sock.fileno()
400
401    def write_message(self, message, binary=False):
402        if self.closed:
403            raise websocket.WebSocketException('closed socket')
404        if binary:
405            flag = websocket.ABNF.OPCODE_BINARY
406        else:
407            flag = websocket.ABNF.OPCODE_TEXT
408        self._sock.send(message, flag)
409
410    def __call__(self, fd, events):
411        try:
412            data = self._sock.recv()
413            if data is not None:
414                self._switch.forward(self, EthernetFrame(data))
415                return
416        except:
417            traceback.print_exc()
418        self.close()
419
420
421def daemonize(nochdir=False, noclose=False):
422    if os.fork() > 0:
423        sys.exit(0)
424
425    os.setsid()
426
427    if os.fork() > 0:
428        sys.exit(0)
429
430    if not nochdir:
431        os.chdir('/')
432
433    if not noclose:
434        os.umask(0)
435        sys.stdin.close()
436        sys.stdout.close()
437        sys.stderr.close()
438        os.close(0)
439        os.close(1)
440        os.close(2)
441        sys.stdin = open(os.devnull)
442        sys.stdout = open(os.devnull, 'a')
443        sys.stderr = open(os.devnull, 'a')
444
445
446def realpath(ns, *keys):
447    for k in keys:
448        v = getattr(ns, k, None)
449        if v is not None:
450            v = os.path.realpath(v)
451            setattr(ns, k, v)
452            open(v).close()  # check readable
453    return ns
454
455
456def ssl_wrapper(insecure, ca_certs):
457    args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
458    if insecure:
459        args = {}
460    return lambda sock: ssl.wrap_socket(sock, **args)
461
462
463def server_main(args):
464    realpath(args, 'keyfile', 'certfile', 'htpasswd')
465
466    if args.keyfile and args.certfile:
467        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
468    elif args.keyfile or args.certfile:
469        raise ValueError('both keyfile and certfile are required')
470    else:
471        ssl_options = None
472
473    if args.port is None:
474        if ssl_options:
475            args.port = 443
476        else:
477            args.port = 80
478    elif not (0 <= args.port <= 65535):
479        raise ValueError('invalid port: %s' % args.port)
480
481    if args.ageout <= 0:
482        raise ValueError('invalid ageout: %s' % args.ageout)
483
484    ioloop = tornado.ioloop.IOLoop.instance()
485    fdb = FDB(ageout=args.ageout, debug=args.debug)
486    sw = SwitchingHub(fdb, debug=args.debug)
487
488    harg = {'switch': sw, 'htpasswd': args.htpasswd, 'debug': args.debug}
489    serv = (args.path, EtherWebSocketHandler, harg)
490    app = tornado.web.Application([serv])
491    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
492    server.listen(args.port, address=args.address)
493
494    for dev in args.device:
495        tap = TapHandler(ioloop, sw, dev, debug=args.debug)
496        tap.open()
497
498    if not args.foreground:
499        daemonize()
500
501    ioloop.start()
502
503
504def client_main(args):
505    realpath(args, 'cacerts')
506
507    if args.debug:
508        websocket.enableTrace(True)
509
510    if args.ageout <= 0:
511        raise ValueError('invalid ageout: %s' % args.ageout)
512
513    if args.user and args.passwd is None:
514        args.passwd = getpass.getpass()
515
516    ssl_ = ssl_wrapper(args.insecure, args.cacerts)
517    cred = {'user': args.user, 'passwd': args.passwd}
518    ioloop = tornado.ioloop.IOLoop.instance()
519    fdb = FDB(ageout=args.ageout, debug=args.debug)
520    sw = SwitchingHub(fdb, debug=args.debug)
521
522    for uri in args.uri:
523        client = EtherWebSocketClient(ioloop, sw, uri, ssl_, cred, args.debug)
524        client.open()
525
526    for dev in args.device:
527        tap = TapHandler(ioloop, sw, dev, debug=args.debug)
528        tap.open()
529
530    if not args.foreground:
531        daemonize()
532
533    ioloop.start()
534
535
536def main():
537    parser = argparse.ArgumentParser()
538    parser.add_argument('--device', action='append', default=[])
539    parser.add_argument('--ageout', action='store', type=int, default=300)
540    parser.add_argument('--foreground', action='store_true', default=False)
541    parser.add_argument('--debug', action='store_true', default=False)
542
543    subparsers = parser.add_subparsers(dest='subcommand')
544
545    parser_s = subparsers.add_parser('server')
546    parser_s.add_argument('--address', action='store', default='')
547    parser_s.add_argument('--port', action='store', type=int)
548    parser_s.add_argument('--path', action='store', default='/')
549    parser_s.add_argument('--htpasswd', action='store')
550    parser_s.add_argument('--keyfile', action='store')
551    parser_s.add_argument('--certfile', action='store')
552
553    parser_c = subparsers.add_parser('client')
554    parser_c.add_argument('--uri', action='append', default=[])
555    parser_c.add_argument('--insecure', action='store_true', default=False)
556    parser_c.add_argument('--cacerts', action='store')
557    parser_c.add_argument('--user', action='store')
558    parser_c.add_argument('--passwd', action='store')
559
560    args = parser.parse_args()
561
562    if args.subcommand == 'server':
563        server_main(args)
564    elif args.subcommand == 'client':
565        client_main(args)
566
567
568if __name__ == '__main__':
569    main()
Note: See TracBrowser for help on using the repository browser.