source: etherws/trunk/etherws.py @ 182

Revision 182, 16.8 KB checked in by atzm, 12 years ago (diff)
  • mix-in'ize basic auth handler
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import fcntl
48import base64
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.httpserver
58
59from tornado.websocket import WebSocketHandler
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return -1
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            del self._dict[vid][mac]
121            if not self._dict[vid]:
122                del self._dict[vid]
123            self.dprintf('aged out: [%d] %s\n',
124                         lambda: (vid, mac.encode('hex')))
125            return None
126
127        return entry['port']
128
129    def learn(self, port, frame):
130        mac = frame.src_mac
131        vid = frame.vid
132
133        if vid not in self._dict:
134            self._dict[vid] = {}
135
136        self._dict[vid][mac] = {'time': time.time(), 'port': port}
137        self.dprintf('learned: [%d] %s\n',
138                     lambda: (vid, mac.encode('hex')))
139
140    def delete(self, port):
141        for vid in self._dict.keys():
142            for mac in self._dict[vid].keys():
143                if self._dict[vid][mac]['port'] is port:
144                    del self._dict[vid][mac]
145                    self.dprintf('deleted: [%d] %s\n',
146                                 lambda: (vid, mac.encode('hex')))
147            if not self._dict[vid]:
148                del self._dict[vid]
149
150
151class SwitchingHub(DebugMixIn):
152    def __init__(self, fdb, debug=False):
153        self._fdb = fdb
154        self._debug = debug
155        self._ports = []
156
157    def register_port(self, port):
158        self._ports.append(port)
159
160    def unregister_port(self, port):
161        self._fdb.delete(port)
162        self._ports.remove(port)
163
164    def forward(self, src_port, frame):
165        try:
166            if not frame.src_multicast:
167                self._fdb.learn(src_port, frame)
168
169            if not frame.dst_multicast:
170                dst_port = self._fdb.lookup(frame)
171
172                if dst_port:
173                    self._unicast(frame, dst_port)
174                    return
175
176            self._broadcast(frame, src_port)
177
178        except:  # ex. received invalid frame
179            traceback.print_exc()
180
181    def _unicast(self, frame, port):
182        port.write_message(frame.data, True)
183        self.dprintf('sent unicast: [%d] %s -> %s\n',
184                     lambda: (frame.vid,
185                              frame.src_mac.encode('hex'),
186                              frame.dst_mac.encode('hex')))
187
188    def _broadcast(self, frame, *except_ports):
189        ports = self._ports[:]
190        for port in except_ports:
191            ports.remove(port)
192        for port in ports:
193            port.write_message(frame.data, True)
194        self.dprintf('sent broadcast: [%d] %s -> %s\n',
195                     lambda: (frame.vid,
196                              frame.src_mac.encode('hex'),
197                              frame.dst_mac.encode('hex')))
198
199
200class Htpasswd(object):
201    def __init__(self, path):
202        self._path = path
203        self._stat = None
204        self._data = {}
205
206    def auth(self, name, passwd):
207        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
208        return self._data.get(name) == passwd
209
210    def load(self):
211        old_stat = self._stat
212
213        with open(self._path) as fp:
214            fileno = fp.fileno()
215            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
216            self._stat = os.fstat(fileno)
217
218            unchanged = old_stat and \
219                        old_stat.st_ino == self._stat.st_ino and \
220                        old_stat.st_dev == self._stat.st_dev and \
221                        old_stat.st_mtime == self._stat.st_mtime
222
223            if not unchanged:
224                self._data = self._parse(fp)
225
226        return self
227
228    def _parse(self, fp):
229        data = {}
230        for line in fp:
231            line = line.strip()
232            if 0 <= line.find(':'):
233                name, passwd = line.split(':', 1)
234                if passwd.startswith('{SHA}'):
235                    data[name] = passwd[5:]
236        return data
237
238
239class BasicAuthMixIn(object):
240    def _execute(self, transforms, *args, **kwargs):
241        def do_execute():
242            sp = super(BasicAuthMixIn, self)
243            return sp._execute(transforms, *args, **kwargs)
244
245        def auth_required():
246            self.stream.write(tornado.escape.utf8(
247                'HTTP/1.1 401 Authorization Required\r\n'
248                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
249            ))
250            self.stream.close()
251
252        try:
253            if not self._htpasswd:
254                return do_execute()
255
256            creds = self.request.headers.get('Authorization')
257
258            if not creds or not creds.startswith('Basic '):
259                return auth_required()
260
261            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
262
263            if self._htpasswd.load().auth(name, passwd):
264                return do_execute()
265        except:
266            traceback.print_exc()
267
268        return auth_required()
269
270
271class TapHandler(DebugMixIn):
272    READ_SIZE = 65535
273
274    def __init__(self, ioloop, switch, dev, debug=False):
275        self._ioloop = ioloop
276        self._switch = switch
277        self._dev = dev
278        self._debug = debug
279        self._tap = None
280
281    @property
282    def closed(self):
283        return not self._tap
284
285    def open(self):
286        if not self.closed:
287            raise ValueError('already opened')
288        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
289        self._tap.up()
290        self._switch.register_port(self)
291        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
292
293    def close(self):
294        if self.closed:
295            raise ValueError('I/O operation on closed tap')
296        self._ioloop.remove_handler(self.fileno())
297        self._switch.unregister_port(self)
298        self._tap.close()
299        self._tap = None
300
301    def fileno(self):
302        if self.closed:
303            raise ValueError('I/O operation on closed tap')
304        return self._tap.fileno()
305
306    def write_message(self, message, binary=False):
307        if self.closed:
308            raise ValueError('I/O operation on closed tap')
309        self._tap.write(message)
310
311    def __call__(self, fd, events):
312        try:
313            self._switch.forward(self, EthernetFrame(self._read()))
314            return
315        except:
316            traceback.print_exc()
317        self.close()
318
319    def _read(self):
320        if self.closed:
321            raise ValueError('I/O operation on closed tap')
322        buf = []
323        while True:
324            buf.append(self._tap.read(self.READ_SIZE))
325            if len(buf[-1]) < self.READ_SIZE:
326                break
327        return ''.join(buf)
328
329
330class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
331    def __init__(self, app, req, switch, htpasswd=None, debug=False):
332        super(EtherWebSocketHandler, self).__init__(app, req)
333        self._switch = switch
334        self._htpasswd = htpasswd
335        self._debug = debug
336
337        if self._htpasswd:
338            self._htpasswd = Htpasswd(self._htpasswd)
339
340    def open(self):
341        self._switch.register_port(self)
342        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
343
344    def on_message(self, message):
345        self._switch.forward(self, EthernetFrame(message))
346
347    def on_close(self):
348        self._switch.unregister_port(self)
349        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
350
351
352class EtherWebSocketClient(DebugMixIn):
353    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
354        self._ioloop = ioloop
355        self._switch = switch
356        self._url = url
357        self._ssl = ssl_
358        self._debug = debug
359        self._sock = None
360        self._options = {}
361
362        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
363            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
364            auth = ['Authorization: Basic %s' % token]
365            self._options['header'] = auth
366
367    @property
368    def closed(self):
369        return not self._sock
370
371    def open(self):
372        sslwrap = websocket._SSLSocketWrapper
373
374        if not self.closed:
375            raise websocket.WebSocketException('already opened')
376
377        if self._ssl:
378            websocket._SSLSocketWrapper = self._ssl
379
380        try:
381            self._sock = websocket.WebSocket()
382            self._sock.connect(self._url, **self._options)
383            self._switch.register_port(self)
384            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
385            self.dprintf('connected: %s\n', lambda: self._url)
386        finally:
387            websocket._SSLSocketWrapper = sslwrap
388
389    def close(self):
390        if self.closed:
391            raise websocket.WebSocketException('already closed')
392        self._ioloop.remove_handler(self.fileno())
393        self._switch.unregister_port(self)
394        self._sock.close()
395        self._sock = None
396        self.dprintf('disconnected: %s\n', lambda: self._url)
397
398    def fileno(self):
399        if self.closed:
400            raise websocket.WebSocketException('closed socket')
401        return self._sock.io_sock.fileno()
402
403    def write_message(self, message, binary=False):
404        if self.closed:
405            raise websocket.WebSocketException('closed socket')
406        if binary:
407            flag = websocket.ABNF.OPCODE_BINARY
408        else:
409            flag = websocket.ABNF.OPCODE_TEXT
410        self._sock.send(message, flag)
411
412    def __call__(self, fd, events):
413        try:
414            data = self._sock.recv()
415            if data is not None:
416                self._switch.forward(self, EthernetFrame(data))
417                return
418        except:
419            traceback.print_exc()
420        self.close()
421
422
423def daemonize(nochdir=False, noclose=False):
424    if os.fork() > 0:
425        sys.exit(0)
426
427    os.setsid()
428
429    if os.fork() > 0:
430        sys.exit(0)
431
432    if not nochdir:
433        os.chdir('/')
434
435    if not noclose:
436        os.umask(0)
437        sys.stdin.close()
438        sys.stdout.close()
439        sys.stderr.close()
440        os.close(0)
441        os.close(1)
442        os.close(2)
443        sys.stdin = open(os.devnull)
444        sys.stdout = open(os.devnull, 'a')
445        sys.stderr = open(os.devnull, 'a')
446
447
448def realpath(ns, *keys):
449    for k in keys:
450        v = getattr(ns, k, None)
451        if v is not None:
452            v = os.path.realpath(v)
453            setattr(ns, k, v)
454            open(v).close()  # check readable
455    return ns
456
457
458def ssl_wrapper(insecure, ca_certs):
459    args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
460    if insecure:
461        args = {}
462    return lambda sock: ssl.wrap_socket(sock, **args)
463
464
465def server_main(args):
466    realpath(args, 'keyfile', 'certfile', 'htpasswd')
467
468    if args.keyfile and args.certfile:
469        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
470    elif args.keyfile or args.certfile:
471        raise ValueError('both keyfile and certfile are required')
472    else:
473        ssl_options = None
474
475    if args.port is None:
476        if ssl_options:
477            args.port = 443
478        else:
479            args.port = 80
480    elif not (0 <= args.port <= 65535):
481        raise ValueError('invalid port: %s' % args.port)
482
483    if args.ageout <= 0:
484        raise ValueError('invalid ageout: %s' % args.ageout)
485
486    ioloop = tornado.ioloop.IOLoop.instance()
487    fdb = FDB(ageout=args.ageout, debug=args.debug)
488    sw = SwitchingHub(fdb, debug=args.debug)
489
490    harg = {'switch': sw, 'htpasswd': args.htpasswd, 'debug': args.debug}
491    serv = (args.path, EtherWebSocketHandler, harg)
492    app = tornado.web.Application([serv])
493    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
494    server.listen(args.port, address=args.address)
495
496    for dev in args.device:
497        tap = TapHandler(ioloop, sw, dev, debug=args.debug)
498        tap.open()
499
500    if not args.foreground:
501        daemonize()
502
503    ioloop.start()
504
505
506def client_main(args):
507    realpath(args, 'cacerts')
508
509    if args.debug:
510        websocket.enableTrace(True)
511
512    if args.ageout <= 0:
513        raise ValueError('invalid ageout: %s' % args.ageout)
514
515    if args.user and args.passwd is None:
516        args.passwd = getpass.getpass()
517
518    ssl_ = ssl_wrapper(args.insecure, args.cacerts)
519    cred = {'user': args.user, 'passwd': args.passwd}
520    ioloop = tornado.ioloop.IOLoop.instance()
521    fdb = FDB(ageout=args.ageout, debug=args.debug)
522    sw = SwitchingHub(fdb, debug=args.debug)
523
524    for uri in args.uri:
525        client = EtherWebSocketClient(ioloop, sw, uri, ssl_, cred, args.debug)
526        client.open()
527
528    for dev in args.device:
529        tap = TapHandler(ioloop, sw, dev, debug=args.debug)
530        tap.open()
531
532    if not args.foreground:
533        daemonize()
534
535    ioloop.start()
536
537
538def main():
539    parser = argparse.ArgumentParser()
540    parser.add_argument('--device', action='append', default=[])
541    parser.add_argument('--ageout', action='store', type=int, default=300)
542    parser.add_argument('--foreground', action='store_true', default=False)
543    parser.add_argument('--debug', action='store_true', default=False)
544
545    subparsers = parser.add_subparsers(dest='subcommand')
546
547    parser_s = subparsers.add_parser('server')
548    parser_s.add_argument('--address', action='store', default='')
549    parser_s.add_argument('--port', action='store', type=int)
550    parser_s.add_argument('--path', action='store', default='/')
551    parser_s.add_argument('--htpasswd', action='store')
552    parser_s.add_argument('--keyfile', action='store')
553    parser_s.add_argument('--certfile', action='store')
554
555    parser_c = subparsers.add_parser('client')
556    parser_c.add_argument('--uri', action='append', default=[])
557    parser_c.add_argument('--insecure', action='store_true', default=False)
558    parser_c.add_argument('--cacerts', action='store')
559    parser_c.add_argument('--user', action='store')
560    parser_c.add_argument('--passwd', action='store')
561
562    args = parser.parse_args()
563
564    if args.subcommand == 'server':
565        server_main(args)
566    elif args.subcommand == 'client':
567        client_main(args)
568
569
570if __name__ == '__main__':
571    main()
Note: See TracBrowser for help on using the repository browser.