source: etherws/trunk/etherws.py @ 182

Revision 182, 16.8 KB checked in by atzm, 12 years ago (diff)
  • mix-in'ize basic auth handler
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[175]47import fcntl
[150]48import base64
49import hashlib
[151]50import getpass
[133]51import argparse
[165]52import traceback
[133]53
54import websocket
[160]55import tornado.web
[133]56import tornado.ioloop
[160]57import tornado.httpserver
[133]58
[182]59from tornado.websocket import WebSocketHandler
[166]60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]61
[166]62
[160]63class DebugMixIn(object):
[166]64    def dprintf(self, msg, func=lambda: ()):
[160]65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]67            sys.stderr.write(prefix + (msg % func()))
[160]68
69
[164]70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
[176]74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
[164]77
78    @property
[176]79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
[164]83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return -1
99
100
[166]101class FDB(DebugMixIn):
[167]102    def __init__(self, ageout, debug=False):
[164]103        self._ageout = ageout
104        self._debug = debug
[166]105        self._dict = {}
[164]106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
[177]111        group = self._dict.get(vid)
[164]112        if not group:
113            return None
114
[177]115        entry = group.get(mac)
[164]116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
[166]120            del self._dict[vid][mac]
121            if not self._dict[vid]:
122                del self._dict[vid]
[164]123            self.dprintf('aged out: [%d] %s\n',
124                         lambda: (vid, mac.encode('hex')))
125            return None
126
127        return entry['port']
128
[166]129    def learn(self, port, frame):
130        mac = frame.src_mac
131        vid = frame.vid
132
133        if vid not in self._dict:
134            self._dict[vid] = {}
135
136        self._dict[vid][mac] = {'time': time.time(), 'port': port}
137        self.dprintf('learned: [%d] %s\n',
138                     lambda: (vid, mac.encode('hex')))
139
[164]140    def delete(self, port):
[166]141        for vid in self._dict.keys():
142            for mac in self._dict[vid].keys():
143                if self._dict[vid][mac]['port'] is port:
144                    del self._dict[vid][mac]
[164]145                    self.dprintf('deleted: [%d] %s\n',
146                                 lambda: (vid, mac.encode('hex')))
[166]147            if not self._dict[vid]:
148                del self._dict[vid]
[164]149
150
[166]151class SwitchingHub(DebugMixIn):
152    def __init__(self, fdb, debug=False):
153        self._fdb = fdb
[133]154        self._debug = debug
[166]155        self._ports = []
[133]156
[166]157    def register_port(self, port):
158        self._ports.append(port)
[133]159
[166]160    def unregister_port(self, port):
161        self._fdb.delete(port)
162        self._ports.remove(port)
[133]163
[166]164    def forward(self, src_port, frame):
165        try:
[176]166            if not frame.src_multicast:
[172]167                self._fdb.learn(src_port, frame)
[133]168
[176]169            if not frame.dst_multicast:
[166]170                dst_port = self._fdb.lookup(frame)
[164]171
[166]172                if dst_port:
173                    self._unicast(frame, dst_port)
174                    return
[133]175
[166]176            self._broadcast(frame, src_port)
[162]177
[166]178        except:  # ex. received invalid frame
179            traceback.print_exc()
[133]180
[166]181    def _unicast(self, frame, port):
182        port.write_message(frame.data, True)
183        self.dprintf('sent unicast: [%d] %s -> %s\n',
184                     lambda: (frame.vid,
185                              frame.src_mac.encode('hex'),
186                              frame.dst_mac.encode('hex')))
[164]187
[166]188    def _broadcast(self, frame, *except_ports):
189        ports = self._ports[:]
190        for port in except_ports:
191            ports.remove(port)
192        for port in ports:
193            port.write_message(frame.data, True)
[164]194        self.dprintf('sent broadcast: [%d] %s -> %s\n',
195                     lambda: (frame.vid,
196                              frame.src_mac.encode('hex'),
197                              frame.dst_mac.encode('hex')))
198
[166]199
[179]200class Htpasswd(object):
201    def __init__(self, path):
202        self._path = path
203        self._stat = None
204        self._data = {}
205
206    def auth(self, name, passwd):
207        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
208        return self._data.get(name) == passwd
209
210    def load(self):
211        old_stat = self._stat
212
213        with open(self._path) as fp:
214            fileno = fp.fileno()
215            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
216            self._stat = os.fstat(fileno)
217
218            unchanged = old_stat and \
219                        old_stat.st_ino == self._stat.st_ino and \
220                        old_stat.st_dev == self._stat.st_dev and \
221                        old_stat.st_mtime == self._stat.st_mtime
222
223            if not unchanged:
224                self._data = self._parse(fp)
225
226        return self
227
228    def _parse(self, fp):
229        data = {}
230        for line in fp:
231            line = line.strip()
232            if 0 <= line.find(':'):
233                name, passwd = line.split(':', 1)
234                if passwd.startswith('{SHA}'):
235                    data[name] = passwd[5:]
236        return data
237
238
[182]239class BasicAuthMixIn(object):
240    def _execute(self, transforms, *args, **kwargs):
241        def do_execute():
242            sp = super(BasicAuthMixIn, self)
243            return sp._execute(transforms, *args, **kwargs)
244
245        def auth_required():
246            self.stream.write(tornado.escape.utf8(
247                'HTTP/1.1 401 Authorization Required\r\n'
248                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
249            ))
250            self.stream.close()
251
252        try:
253            if not self._htpasswd:
254                return do_execute()
255
256            creds = self.request.headers.get('Authorization')
257
258            if not creds or not creds.startswith('Basic '):
259                return auth_required()
260
261            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
262
263            if self._htpasswd.load().auth(name, passwd):
264                return do_execute()
265        except:
266            traceback.print_exc()
267
268        return auth_required()
269
270
[166]271class TapHandler(DebugMixIn):
272    READ_SIZE = 65535
273
[178]274    def __init__(self, ioloop, switch, dev, debug=False):
275        self._ioloop = ioloop
[166]276        self._switch = switch
277        self._dev = dev
278        self._debug = debug
279        self._tap = None
280
281    @property
282    def closed(self):
283        return not self._tap
284
285    def open(self):
286        if not self.closed:
287            raise ValueError('already opened')
288        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
289        self._tap.up()
290        self._switch.register_port(self)
[178]291        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[166]292
293    def close(self):
294        if self.closed:
295            raise ValueError('I/O operation on closed tap')
[178]296        self._ioloop.remove_handler(self.fileno())
[166]297        self._switch.unregister_port(self)
298        self._tap.close()
299        self._tap = None
300
301    def fileno(self):
302        if self.closed:
303            raise ValueError('I/O operation on closed tap')
304        return self._tap.fileno()
305
306    def write_message(self, message, binary=False):
307        if self.closed:
308            raise ValueError('I/O operation on closed tap')
309        self._tap.write(message)
310
[138]311    def __call__(self, fd, events):
[166]312        try:
313            self._switch.forward(self, EthernetFrame(self._read()))
314            return
315        except:
316            traceback.print_exc()
[178]317        self.close()
[166]318
319    def _read(self):
320        if self.closed:
321            raise ValueError('I/O operation on closed tap')
[162]322        buf = []
323        while True:
[166]324            buf.append(self._tap.read(self.READ_SIZE))
325            if len(buf[-1]) < self.READ_SIZE:
[162]326                break
[166]327        return ''.join(buf)
[162]328
329
[182]330class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[179]331    def __init__(self, app, req, switch, htpasswd=None, debug=False):
[160]332        super(EtherWebSocketHandler, self).__init__(app, req)
[166]333        self._switch = switch
[179]334        self._htpasswd = htpasswd
[133]335        self._debug = debug
336
[179]337        if self._htpasswd:
338            self._htpasswd = Htpasswd(self._htpasswd)
339
[133]340    def open(self):
[166]341        self._switch.register_port(self)
[164]342        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]343
344    def on_message(self, message):
[166]345        self._switch.forward(self, EthernetFrame(message))
[133]346
347    def on_close(self):
[166]348        self._switch.unregister_port(self)
[164]349        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]350
351
[160]352class EtherWebSocketClient(DebugMixIn):
[181]353    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]354        self._ioloop = ioloop
[166]355        self._switch = switch
[151]356        self._url = url
[181]357        self._ssl = ssl_
[160]358        self._debug = debug
[166]359        self._sock = None
[151]360        self._options = {}
361
[174]362        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
363            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]364            auth = ['Authorization: Basic %s' % token]
365            self._options['header'] = auth
366
[160]367    @property
368    def closed(self):
369        return not self._sock
370
[151]371    def open(self):
[181]372        sslwrap = websocket._SSLSocketWrapper
373
[160]374        if not self.closed:
375            raise websocket.WebSocketException('already opened')
[151]376
[181]377        if self._ssl:
378            websocket._SSLSocketWrapper = self._ssl
379
380        try:
381            self._sock = websocket.WebSocket()
382            self._sock.connect(self._url, **self._options)
383            self._switch.register_port(self)
384            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
385            self.dprintf('connected: %s\n', lambda: self._url)
386        finally:
387            websocket._SSLSocketWrapper = sslwrap
388
[151]389    def close(self):
[160]390        if self.closed:
391            raise websocket.WebSocketException('already closed')
[178]392        self._ioloop.remove_handler(self.fileno())
[166]393        self._switch.unregister_port(self)
[151]394        self._sock.close()
395        self._sock = None
[164]396        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]397
[165]398    def fileno(self):
399        if self.closed:
400            raise websocket.WebSocketException('closed socket')
401        return self._sock.io_sock.fileno()
402
[151]403    def write_message(self, message, binary=False):
[160]404        if self.closed:
405            raise websocket.WebSocketException('closed socket')
[151]406        if binary:
407            flag = websocket.ABNF.OPCODE_BINARY
[160]408        else:
409            flag = websocket.ABNF.OPCODE_TEXT
[151]410        self._sock.send(message, flag)
411
[165]412    def __call__(self, fd, events):
[151]413        try:
[165]414            data = self._sock.recv()
415            if data is not None:
[166]416                self._switch.forward(self, EthernetFrame(data))
[165]417                return
418        except:
419            traceback.print_exc()
[178]420        self.close()
[151]421
422
[134]423def daemonize(nochdir=False, noclose=False):
424    if os.fork() > 0:
425        sys.exit(0)
426
427    os.setsid()
428
429    if os.fork() > 0:
430        sys.exit(0)
431
432    if not nochdir:
433        os.chdir('/')
434
435    if not noclose:
436        os.umask(0)
437        sys.stdin.close()
438        sys.stdout.close()
439        sys.stderr.close()
440        os.close(0)
441        os.close(1)
442        os.close(2)
443        sys.stdin = open(os.devnull)
444        sys.stdout = open(os.devnull, 'a')
445        sys.stderr = open(os.devnull, 'a')
446
447
[160]448def realpath(ns, *keys):
449    for k in keys:
450        v = getattr(ns, k, None)
451        if v is not None:
452            v = os.path.realpath(v)
453            setattr(ns, k, v)
454            open(v).close()  # check readable
455    return ns
456
457
[180]458def ssl_wrapper(insecure, ca_certs):
459    args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
460    if insecure:
461        args = {}
462    return lambda sock: ssl.wrap_socket(sock, **args)
463
464
[133]465def server_main(args):
[160]466    realpath(args, 'keyfile', 'certfile', 'htpasswd')
[143]467
[160]468    if args.keyfile and args.certfile:
469        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
470    elif args.keyfile or args.certfile:
[143]471        raise ValueError('both keyfile and certfile are required')
[160]472    else:
[143]473        ssl_options = None
474
[160]475    if args.port is None:
[143]476        if ssl_options:
477            args.port = 443
478        else:
479            args.port = 80
[160]480    elif not (0 <= args.port <= 65535):
481        raise ValueError('invalid port: %s' % args.port)
[143]482
[167]483    if args.ageout <= 0:
484        raise ValueError('invalid ageout: %s' % args.ageout)
485
486    ioloop = tornado.ioloop.IOLoop.instance()
487    fdb = FDB(ageout=args.ageout, debug=args.debug)
[181]488    sw = SwitchingHub(fdb, debug=args.debug)
[167]489
[181]490    harg = {'switch': sw, 'htpasswd': args.htpasswd, 'debug': args.debug}
[179]491    serv = (args.path, EtherWebSocketHandler, harg)
492    app = tornado.web.Application([serv])
[143]493    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
[133]494    server.listen(args.port, address=args.address)
495
[174]496    for dev in args.device:
[181]497        tap = TapHandler(ioloop, sw, dev, debug=args.debug)
[167]498        tap.open()
[151]499
500    if not args.foreground:
501        daemonize()
502
[138]503    ioloop.start()
[133]504
505
506def client_main(args):
[160]507    realpath(args, 'cacerts')
508
[133]509    if args.debug:
510        websocket.enableTrace(True)
511
[174]512    if args.ageout <= 0:
513        raise ValueError('invalid ageout: %s' % args.ageout)
514
[160]515    if args.user and args.passwd is None:
516        args.passwd = getpass.getpass()
[143]517
[181]518    ssl_ = ssl_wrapper(args.insecure, args.cacerts)
[174]519    cred = {'user': args.user, 'passwd': args.passwd}
[167]520    ioloop = tornado.ioloop.IOLoop.instance()
521    fdb = FDB(ageout=args.ageout, debug=args.debug)
[181]522    sw = SwitchingHub(fdb, debug=args.debug)
[167]523
[174]524    for uri in args.uri:
[181]525        client = EtherWebSocketClient(ioloop, sw, uri, ssl_, cred, args.debug)
[168]526        client.open()
527
[174]528    for dev in args.device:
[181]529        tap = TapHandler(ioloop, sw, dev, debug=args.debug)
[167]530        tap.open()
531
[151]532    if not args.foreground:
533        daemonize()
534
[165]535    ioloop.start()
[133]536
[138]537
[133]538def main():
539    parser = argparse.ArgumentParser()
[167]540    parser.add_argument('--device', action='append', default=[])
541    parser.add_argument('--ageout', action='store', type=int, default=300)
[133]542    parser.add_argument('--foreground', action='store_true', default=False)
543    parser.add_argument('--debug', action='store_true', default=False)
544
545    subparsers = parser.add_subparsers(dest='subcommand')
546
[158]547    parser_s = subparsers.add_parser('server')
548    parser_s.add_argument('--address', action='store', default='')
549    parser_s.add_argument('--port', action='store', type=int)
550    parser_s.add_argument('--path', action='store', default='/')
551    parser_s.add_argument('--htpasswd', action='store')
552    parser_s.add_argument('--keyfile', action='store')
553    parser_s.add_argument('--certfile', action='store')
[133]554
[158]555    parser_c = subparsers.add_parser('client')
[168]556    parser_c.add_argument('--uri', action='append', default=[])
[158]557    parser_c.add_argument('--insecure', action='store_true', default=False)
558    parser_c.add_argument('--cacerts', action='store')
559    parser_c.add_argument('--user', action='store')
[160]560    parser_c.add_argument('--passwd', action='store')
[133]561
562    args = parser.parse_args()
563
564    if args.subcommand == 'server':
565        server_main(args)
566    elif args.subcommand == 'client':
567        client_main(args)
568
569
570if __name__ == '__main__':
571    main()
Note: See TracBrowser for help on using the repository browser.