source: etherws/trunk/etherws.py @ 179

Revision 179, 16.5 KB checked in by atzm, 12 years ago (diff)
  • handle authentication per instance, not class
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import fcntl
48import base64
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return -1
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            del self._dict[vid][mac]
121            if not self._dict[vid]:
122                del self._dict[vid]
123            self.dprintf('aged out: [%d] %s\n',
124                         lambda: (vid, mac.encode('hex')))
125            return None
126
127        return entry['port']
128
129    def learn(self, port, frame):
130        mac = frame.src_mac
131        vid = frame.vid
132
133        if vid not in self._dict:
134            self._dict[vid] = {}
135
136        self._dict[vid][mac] = {'time': time.time(), 'port': port}
137        self.dprintf('learned: [%d] %s\n',
138                     lambda: (vid, mac.encode('hex')))
139
140    def delete(self, port):
141        for vid in self._dict.keys():
142            for mac in self._dict[vid].keys():
143                if self._dict[vid][mac]['port'] is port:
144                    del self._dict[vid][mac]
145                    self.dprintf('deleted: [%d] %s\n',
146                                 lambda: (vid, mac.encode('hex')))
147            if not self._dict[vid]:
148                del self._dict[vid]
149
150
151class SwitchingHub(DebugMixIn):
152    def __init__(self, fdb, debug=False):
153        self._fdb = fdb
154        self._debug = debug
155        self._ports = []
156
157    def register_port(self, port):
158        self._ports.append(port)
159
160    def unregister_port(self, port):
161        self._fdb.delete(port)
162        self._ports.remove(port)
163
164    def forward(self, src_port, frame):
165        try:
166            if not frame.src_multicast:
167                self._fdb.learn(src_port, frame)
168
169            if not frame.dst_multicast:
170                dst_port = self._fdb.lookup(frame)
171
172                if dst_port:
173                    self._unicast(frame, dst_port)
174                    return
175
176            self._broadcast(frame, src_port)
177
178        except:  # ex. received invalid frame
179            traceback.print_exc()
180
181    def _unicast(self, frame, port):
182        port.write_message(frame.data, True)
183        self.dprintf('sent unicast: [%d] %s -> %s\n',
184                     lambda: (frame.vid,
185                              frame.src_mac.encode('hex'),
186                              frame.dst_mac.encode('hex')))
187
188    def _broadcast(self, frame, *except_ports):
189        ports = self._ports[:]
190        for port in except_ports:
191            ports.remove(port)
192        for port in ports:
193            port.write_message(frame.data, True)
194        self.dprintf('sent broadcast: [%d] %s -> %s\n',
195                     lambda: (frame.vid,
196                              frame.src_mac.encode('hex'),
197                              frame.dst_mac.encode('hex')))
198
199
200class Htpasswd(object):
201    def __init__(self, path):
202        self._path = path
203        self._stat = None
204        self._data = {}
205
206    def auth(self, name, passwd):
207        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
208        return self._data.get(name) == passwd
209
210    def load(self):
211        old_stat = self._stat
212
213        with open(self._path) as fp:
214            fileno = fp.fileno()
215            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
216            self._stat = os.fstat(fileno)
217
218            unchanged = old_stat and \
219                        old_stat.st_ino == self._stat.st_ino and \
220                        old_stat.st_dev == self._stat.st_dev and \
221                        old_stat.st_mtime == self._stat.st_mtime
222
223            if not unchanged:
224                self._data = self._parse(fp)
225
226        return self
227
228    def _parse(self, fp):
229        data = {}
230        for line in fp:
231            line = line.strip()
232            if 0 <= line.find(':'):
233                name, passwd = line.split(':', 1)
234                if passwd.startswith('{SHA}'):
235                    data[name] = passwd[5:]
236        return data
237
238
239class TapHandler(DebugMixIn):
240    READ_SIZE = 65535
241
242    def __init__(self, ioloop, switch, dev, debug=False):
243        self._ioloop = ioloop
244        self._switch = switch
245        self._dev = dev
246        self._debug = debug
247        self._tap = None
248
249    @property
250    def closed(self):
251        return not self._tap
252
253    def open(self):
254        if not self.closed:
255            raise ValueError('already opened')
256        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
257        self._tap.up()
258        self._switch.register_port(self)
259        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
260
261    def close(self):
262        if self.closed:
263            raise ValueError('I/O operation on closed tap')
264        self._ioloop.remove_handler(self.fileno())
265        self._switch.unregister_port(self)
266        self._tap.close()
267        self._tap = None
268
269    def fileno(self):
270        if self.closed:
271            raise ValueError('I/O operation on closed tap')
272        return self._tap.fileno()
273
274    def write_message(self, message, binary=False):
275        if self.closed:
276            raise ValueError('I/O operation on closed tap')
277        self._tap.write(message)
278
279    def __call__(self, fd, events):
280        try:
281            self._switch.forward(self, EthernetFrame(self._read()))
282            return
283        except:
284            traceback.print_exc()
285        self.close()
286
287    def _read(self):
288        if self.closed:
289            raise ValueError('I/O operation on closed tap')
290        buf = []
291        while True:
292            buf.append(self._tap.read(self.READ_SIZE))
293            if len(buf[-1]) < self.READ_SIZE:
294                break
295        return ''.join(buf)
296
297
298class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
299    def __init__(self, app, req, switch, htpasswd=None, debug=False):
300        super(EtherWebSocketHandler, self).__init__(app, req)
301        self._switch = switch
302        self._htpasswd = htpasswd
303        self._debug = debug
304
305        if self._htpasswd:
306            self._htpasswd = Htpasswd(self._htpasswd)
307
308    def open(self):
309        self._switch.register_port(self)
310        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
311
312    def on_message(self, message):
313        self._switch.forward(self, EthernetFrame(message))
314
315    def on_close(self):
316        self._switch.unregister_port(self)
317        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
318
319    def _execute(self, transforms, *args, **kwargs):
320        def do_execute():
321            sp = super(EtherWebSocketHandler, self)
322            return sp._execute(transforms, *args, **kwargs)
323
324        def auth_required():
325            self.stream.write(tornado.escape.utf8(
326                'HTTP/1.1 401 Authorization Required\r\n'
327                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
328            ))
329            self.stream.close()
330
331        try:
332            if not self._htpasswd:
333                return do_execute()
334
335            creds = self.request.headers.get('Authorization')
336
337            if not creds or not creds.startswith('Basic '):
338                return auth_required()
339
340            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
341
342            if self._htpasswd.load().auth(name, passwd):
343                return do_execute()
344        except:
345            traceback.print_exc()
346
347        return auth_required()
348
349
350class EtherWebSocketClient(DebugMixIn):
351    def __init__(self, ioloop, switch, url, cred=None, debug=False):
352        self._ioloop = ioloop
353        self._switch = switch
354        self._url = url
355        self._debug = debug
356        self._sock = None
357        self._options = {}
358
359        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
360            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
361            auth = ['Authorization: Basic %s' % token]
362            self._options['header'] = auth
363
364    @property
365    def closed(self):
366        return not self._sock
367
368    def open(self):
369        if not self.closed:
370            raise websocket.WebSocketException('already opened')
371        self._sock = websocket.WebSocket()
372        self._sock.connect(self._url, **self._options)
373        self._switch.register_port(self)
374        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
375        self.dprintf('connected: %s\n', lambda: self._url)
376
377    def close(self):
378        if self.closed:
379            raise websocket.WebSocketException('already closed')
380        self._ioloop.remove_handler(self.fileno())
381        self._switch.unregister_port(self)
382        self._sock.close()
383        self._sock = None
384        self.dprintf('disconnected: %s\n', lambda: self._url)
385
386    def fileno(self):
387        if self.closed:
388            raise websocket.WebSocketException('closed socket')
389        return self._sock.io_sock.fileno()
390
391    def write_message(self, message, binary=False):
392        if self.closed:
393            raise websocket.WebSocketException('closed socket')
394        if binary:
395            flag = websocket.ABNF.OPCODE_BINARY
396        else:
397            flag = websocket.ABNF.OPCODE_TEXT
398        self._sock.send(message, flag)
399
400    def __call__(self, fd, events):
401        try:
402            data = self._sock.recv()
403            if data is not None:
404                self._switch.forward(self, EthernetFrame(data))
405                return
406        except:
407            traceback.print_exc()
408        self.close()
409
410
411def daemonize(nochdir=False, noclose=False):
412    if os.fork() > 0:
413        sys.exit(0)
414
415    os.setsid()
416
417    if os.fork() > 0:
418        sys.exit(0)
419
420    if not nochdir:
421        os.chdir('/')
422
423    if not noclose:
424        os.umask(0)
425        sys.stdin.close()
426        sys.stdout.close()
427        sys.stderr.close()
428        os.close(0)
429        os.close(1)
430        os.close(2)
431        sys.stdin = open(os.devnull)
432        sys.stdout = open(os.devnull, 'a')
433        sys.stderr = open(os.devnull, 'a')
434
435
436def realpath(ns, *keys):
437    for k in keys:
438        v = getattr(ns, k, None)
439        if v is not None:
440            v = os.path.realpath(v)
441            setattr(ns, k, v)
442            open(v).close()  # check readable
443    return ns
444
445
446def server_main(args):
447    realpath(args, 'keyfile', 'certfile', 'htpasswd')
448
449    if args.keyfile and args.certfile:
450        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
451    elif args.keyfile or args.certfile:
452        raise ValueError('both keyfile and certfile are required')
453    else:
454        ssl_options = None
455
456    if args.port is None:
457        if ssl_options:
458            args.port = 443
459        else:
460            args.port = 80
461    elif not (0 <= args.port <= 65535):
462        raise ValueError('invalid port: %s' % args.port)
463
464    if args.ageout <= 0:
465        raise ValueError('invalid ageout: %s' % args.ageout)
466
467    ioloop = tornado.ioloop.IOLoop.instance()
468    fdb = FDB(ageout=args.ageout, debug=args.debug)
469    switch = SwitchingHub(fdb, debug=args.debug)
470
471    harg = {'switch': switch, 'htpasswd': args.htpasswd, 'debug': args.debug}
472    serv = (args.path, EtherWebSocketHandler, harg)
473    app = tornado.web.Application([serv])
474    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
475    server.listen(args.port, address=args.address)
476
477    for dev in args.device:
478        tap = TapHandler(ioloop, switch, dev, debug=args.debug)
479        tap.open()
480
481    if not args.foreground:
482        daemonize()
483
484    ioloop.start()
485
486
487def client_main(args):
488    realpath(args, 'cacerts')
489
490    if args.debug:
491        websocket.enableTrace(True)
492
493    if args.insecure:
494        websocket._SSLSocketWrapper = \
495            lambda s: ssl.wrap_socket(s)
496    else:
497        websocket._SSLSocketWrapper = \
498            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
499                                      ca_certs=args.cacerts)
500
501    if args.ageout <= 0:
502        raise ValueError('invalid ageout: %s' % args.ageout)
503
504    if args.user and args.passwd is None:
505        args.passwd = getpass.getpass()
506
507    cred = {'user': args.user, 'passwd': args.passwd}
508    ioloop = tornado.ioloop.IOLoop.instance()
509    fdb = FDB(ageout=args.ageout, debug=args.debug)
510    switch = SwitchingHub(fdb, debug=args.debug)
511
512    for uri in args.uri:
513        client = EtherWebSocketClient(ioloop, switch, uri, cred, args.debug)
514        client.open()
515
516    for dev in args.device:
517        tap = TapHandler(ioloop, switch, dev, debug=args.debug)
518        tap.open()
519
520    if not args.foreground:
521        daemonize()
522
523    ioloop.start()
524
525
526def main():
527    parser = argparse.ArgumentParser()
528    parser.add_argument('--device', action='append', default=[])
529    parser.add_argument('--ageout', action='store', type=int, default=300)
530    parser.add_argument('--foreground', action='store_true', default=False)
531    parser.add_argument('--debug', action='store_true', default=False)
532
533    subparsers = parser.add_subparsers(dest='subcommand')
534
535    parser_s = subparsers.add_parser('server')
536    parser_s.add_argument('--address', action='store', default='')
537    parser_s.add_argument('--port', action='store', type=int)
538    parser_s.add_argument('--path', action='store', default='/')
539    parser_s.add_argument('--htpasswd', action='store')
540    parser_s.add_argument('--keyfile', action='store')
541    parser_s.add_argument('--certfile', action='store')
542
543    parser_c = subparsers.add_parser('client')
544    parser_c.add_argument('--uri', action='append', default=[])
545    parser_c.add_argument('--insecure', action='store_true', default=False)
546    parser_c.add_argument('--cacerts', action='store')
547    parser_c.add_argument('--user', action='store')
548    parser_c.add_argument('--passwd', action='store')
549
550    args = parser.parse_args()
551
552    if args.subcommand == 'server':
553        server_main(args)
554    elif args.subcommand == 'client':
555        client_main(args)
556
557
558if __name__ == '__main__':
559    main()
Note: See TracBrowser for help on using the repository browser.