source: etherws/trunk/etherws.py @ 174

Revision 174, 16.3 KB checked in by atzm, 12 years ago (diff)
  • dynamic htpasswd loading support
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[150]47import base64
48import hashlib
[151]49import getpass
[133]50import argparse
[165]51import traceback
[133]52
53import websocket
[160]54import tornado.web
[133]55import tornado.ioloop
56import tornado.websocket
[160]57import tornado.httpserver
[133]58
[166]59from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]60
[166]61
[160]62class DebugMixIn(object):
[166]63    def dprintf(self, msg, func=lambda: ()):
[160]64        if self._debug:
65            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]66            sys.stderr.write(prefix + (msg % func()))
[160]67
68
[164]69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
[173]73    @staticmethod
74    def multicast(mac):
75        return ord(mac[0]) & 1
[164]76
77    @property
78    def dst_mac(self):
79        return self.data[:6]
80
81    @property
82    def src_mac(self):
83        return self.data[6:12]
84
85    @property
86    def tagged(self):
87        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
88
89    @property
90    def vid(self):
91        if self.tagged:
92            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
93        return -1
94
95
[166]96class FDB(DebugMixIn):
[167]97    def __init__(self, ageout, debug=False):
[164]98        self._ageout = ageout
99        self._debug = debug
[166]100        self._dict = {}
[164]101
102    def lookup(self, frame):
103        mac = frame.dst_mac
104        vid = frame.vid
105
[166]106        group = self._dict.get(vid, None)
[164]107        if not group:
108            return None
109
110        entry = group.get(mac, None)
111        if not entry:
112            return None
113
114        if time.time() - entry['time'] > self._ageout:
[166]115            del self._dict[vid][mac]
116            if not self._dict[vid]:
117                del self._dict[vid]
[164]118            self.dprintf('aged out: [%d] %s\n',
119                         lambda: (vid, mac.encode('hex')))
120            return None
121
122        return entry['port']
123
[166]124    def learn(self, port, frame):
125        mac = frame.src_mac
126        vid = frame.vid
127
128        if vid not in self._dict:
129            self._dict[vid] = {}
130
131        self._dict[vid][mac] = {'time': time.time(), 'port': port}
132        self.dprintf('learned: [%d] %s\n',
133                     lambda: (vid, mac.encode('hex')))
134
[164]135    def delete(self, port):
[166]136        for vid in self._dict.keys():
137            for mac in self._dict[vid].keys():
138                if self._dict[vid][mac]['port'] is port:
139                    del self._dict[vid][mac]
[164]140                    self.dprintf('deleted: [%d] %s\n',
141                                 lambda: (vid, mac.encode('hex')))
[166]142            if not self._dict[vid]:
143                del self._dict[vid]
[164]144
145
[166]146class SwitchingHub(DebugMixIn):
147    def __init__(self, fdb, debug=False):
148        self._fdb = fdb
[133]149        self._debug = debug
[166]150        self._ports = []
[133]151
[166]152    def register_port(self, port):
153        self._ports.append(port)
[133]154
[166]155    def unregister_port(self, port):
156        self._fdb.delete(port)
157        self._ports.remove(port)
[133]158
[166]159    def forward(self, src_port, frame):
160        try:
[173]161            if not frame.multicast(frame.src_mac):
[172]162                self._fdb.learn(src_port, frame)
[133]163
[173]164            if not frame.multicast(frame.dst_mac):
[166]165                dst_port = self._fdb.lookup(frame)
[164]166
[166]167                if dst_port:
168                    self._unicast(frame, dst_port)
169                    return
[133]170
[166]171            self._broadcast(frame, src_port)
[162]172
[166]173        except:  # ex. received invalid frame
174            traceback.print_exc()
[133]175
[166]176    def _unicast(self, frame, port):
177        port.write_message(frame.data, True)
178        self.dprintf('sent unicast: [%d] %s -> %s\n',
179                     lambda: (frame.vid,
180                              frame.src_mac.encode('hex'),
181                              frame.dst_mac.encode('hex')))
[164]182
[166]183    def _broadcast(self, frame, *except_ports):
184        ports = self._ports[:]
185        for port in except_ports:
186            ports.remove(port)
187        for port in ports:
188            port.write_message(frame.data, True)
[164]189        self.dprintf('sent broadcast: [%d] %s -> %s\n',
190                     lambda: (frame.vid,
191                              frame.src_mac.encode('hex'),
192                              frame.dst_mac.encode('hex')))
193
[166]194
195class TapHandler(DebugMixIn):
196    READ_SIZE = 65535
197
198    def __init__(self, switch, dev, debug=False):
199        self._switch = switch
200        self._dev = dev
201        self._debug = debug
202        self._tap = None
203
204    @property
205    def closed(self):
206        return not self._tap
207
208    def open(self):
209        if not self.closed:
210            raise ValueError('already opened')
211        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
212        self._tap.up()
213        self._switch.register_port(self)
214
215    def close(self):
216        if self.closed:
217            raise ValueError('I/O operation on closed tap')
218        self._switch.unregister_port(self)
219        self._tap.close()
220        self._tap = None
221
222    def fileno(self):
223        if self.closed:
224            raise ValueError('I/O operation on closed tap')
225        return self._tap.fileno()
226
227    def write_message(self, message, binary=False):
228        if self.closed:
229            raise ValueError('I/O operation on closed tap')
230        self._tap.write(message)
231
[138]232    def __call__(self, fd, events):
[166]233        try:
234            self._switch.forward(self, EthernetFrame(self._read()))
235            return
236        except:
237            traceback.print_exc()
[174]238        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
[166]239
240    def _read(self):
241        if self.closed:
242            raise ValueError('I/O operation on closed tap')
[162]243        buf = []
244        while True:
[166]245            buf.append(self._tap.read(self.READ_SIZE))
246            if len(buf[-1]) < self.READ_SIZE:
[162]247                break
[166]248        return ''.join(buf)
[162]249
250
[160]251class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
[166]252    def __init__(self, app, req, switch, debug=False):
[160]253        super(EtherWebSocketHandler, self).__init__(app, req)
[166]254        self._switch = switch
[133]255        self._debug = debug
256
257    def open(self):
[166]258        self._switch.register_port(self)
[164]259        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]260
261    def on_message(self, message):
[166]262        self._switch.forward(self, EthernetFrame(message))
[133]263
264    def on_close(self):
[166]265        self._switch.unregister_port(self)
[164]266        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]267
268
[160]269class EtherWebSocketClient(DebugMixIn):
[174]270    def __init__(self, switch, url, cred=None, debug=False):
[166]271        self._switch = switch
[151]272        self._url = url
[160]273        self._debug = debug
[166]274        self._sock = None
[151]275        self._options = {}
276
[174]277        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
278            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]279            auth = ['Authorization: Basic %s' % token]
280            self._options['header'] = auth
281
[160]282    @property
283    def closed(self):
284        return not self._sock
285
[151]286    def open(self):
[160]287        if not self.closed:
288            raise websocket.WebSocketException('already opened')
[151]289        self._sock = websocket.WebSocket()
290        self._sock.connect(self._url, **self._options)
[166]291        self._switch.register_port(self)
[164]292        self.dprintf('connected: %s\n', lambda: self._url)
[151]293
294    def close(self):
[160]295        if self.closed:
296            raise websocket.WebSocketException('already closed')
[166]297        self._switch.unregister_port(self)
[151]298        self._sock.close()
299        self._sock = None
[164]300        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]301
[165]302    def fileno(self):
303        if self.closed:
304            raise websocket.WebSocketException('closed socket')
305        return self._sock.io_sock.fileno()
306
[151]307    def write_message(self, message, binary=False):
[160]308        if self.closed:
309            raise websocket.WebSocketException('closed socket')
[151]310        if binary:
311            flag = websocket.ABNF.OPCODE_BINARY
[160]312        else:
313            flag = websocket.ABNF.OPCODE_TEXT
[151]314        self._sock.send(message, flag)
315
[165]316    def __call__(self, fd, events):
[151]317        try:
[165]318            data = self._sock.recv()
319            if data is not None:
[166]320                self._switch.forward(self, EthernetFrame(data))
[165]321                return
322        except:
323            traceback.print_exc()
[174]324        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
[151]325
326
[174]327class Htpasswd(object):
328    def __init__(self, path):
329        self._path = path
330        self._stat = None
331        self._data = {}
332
333    def auth(self, name, passwd):
334        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
335        return self._data.get(name) == passwd
336
337    def load(self):
338        old_stat = self._stat
339
340        with open(self._path) as fp:
341            self._stat = os.fstat(fp.fileno())
342
343            unchanged = old_stat and \
344                        old_stat.st_ino == self._stat.st_ino and \
345                        old_stat.st_dev == self._stat.st_dev and \
346                        old_stat.st_mtime == self._stat.st_mtime
347
348            if not unchanged:
349                self._data = self._parse(fp)
350
351    def _parse(self, fp):
352        data = {}
353        for line in fp:
354            line = line.strip()
355            if 0 <= line.find(':'):
356                name, passwd = line.split(':', 1)
357                if passwd.startswith('{SHA}'):
358                    data[name] = passwd[5:]
359        return data
360
361
362def wrap_basic_auth(handler_class, htpasswd_path):
363    if not htpasswd_path:
364        return handler_class
365
366    old_execute = handler_class._execute
367    htpasswd = Htpasswd(htpasswd_path)
368
369    def execute(self, transforms, *args, **kwargs):
370        def auth_required():
371            self.stream.write(tornado.escape.utf8(
372                'HTTP/1.1 401 Authorization Required\r\n'
373                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
374            ))
375            self.stream.close()
376
377        creds = self.request.headers.get('Authorization')
378
379        if not creds or not creds.startswith('Basic '):
380            return auth_required()
381
382        try:
383            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
384            htpasswd.load()
385
386            if not htpasswd.auth(name, passwd):
387                return auth_required()
388
389            return old_execute(self, transforms, *args, **kwargs)
390
391        except:
392            return auth_required()
393
394    handler_class._execute = execute
395    return handler_class
396
397
[134]398def daemonize(nochdir=False, noclose=False):
399    if os.fork() > 0:
400        sys.exit(0)
401
402    os.setsid()
403
404    if os.fork() > 0:
405        sys.exit(0)
406
407    if not nochdir:
408        os.chdir('/')
409
410    if not noclose:
411        os.umask(0)
412        sys.stdin.close()
413        sys.stdout.close()
414        sys.stderr.close()
415        os.close(0)
416        os.close(1)
417        os.close(2)
418        sys.stdin = open(os.devnull)
419        sys.stdout = open(os.devnull, 'a')
420        sys.stderr = open(os.devnull, 'a')
421
422
[160]423def realpath(ns, *keys):
424    for k in keys:
425        v = getattr(ns, k, None)
426        if v is not None:
427            v = os.path.realpath(v)
428            setattr(ns, k, v)
429            open(v).close()  # check readable
430    return ns
431
432
[133]433def server_main(args):
[160]434    realpath(args, 'keyfile', 'certfile', 'htpasswd')
[143]435
[160]436    if args.keyfile and args.certfile:
437        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
438    elif args.keyfile or args.certfile:
[143]439        raise ValueError('both keyfile and certfile are required')
[160]440    else:
[143]441        ssl_options = None
442
[160]443    if args.port is None:
[143]444        if ssl_options:
445            args.port = 443
446        else:
447            args.port = 80
[160]448    elif not (0 <= args.port <= 65535):
449        raise ValueError('invalid port: %s' % args.port)
[143]450
[167]451    if args.ageout <= 0:
452        raise ValueError('invalid ageout: %s' % args.ageout)
453
454    ioloop = tornado.ioloop.IOLoop.instance()
455    fdb = FDB(ageout=args.ageout, debug=args.debug)
456    switch = SwitchingHub(fdb, debug=args.debug)
457
[174]458    handler = wrap_basic_auth(EtherWebSocketHandler, args.htpasswd)
459    srv = (args.path, handler, {'switch': switch, 'debug': args.debug})
460    app = tornado.web.Application([srv])
[143]461    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
[133]462    server.listen(args.port, address=args.address)
463
[174]464    for dev in args.device:
465        tap = TapHandler(switch, dev, debug=args.debug)
[167]466        tap.open()
467        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
[151]468
469    if not args.foreground:
470        daemonize()
471
[138]472    ioloop.start()
[133]473
474
475def client_main(args):
[160]476    realpath(args, 'cacerts')
477
[133]478    if args.debug:
479        websocket.enableTrace(True)
480
[174]481    if args.insecure:
[156]482        websocket._SSLSocketWrapper = \
[174]483            lambda s: ssl.wrap_socket(s)
484    else:
485        websocket._SSLSocketWrapper = \
[156]486            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
487                                      ca_certs=args.cacerts)
488
[174]489    if args.ageout <= 0:
490        raise ValueError('invalid ageout: %s' % args.ageout)
491
[160]492    if args.user and args.passwd is None:
493        args.passwd = getpass.getpass()
[143]494
[174]495    cred = {'user': args.user, 'passwd': args.passwd}
[167]496    ioloop = tornado.ioloop.IOLoop.instance()
497    fdb = FDB(ageout=args.ageout, debug=args.debug)
[166]498    switch = SwitchingHub(fdb, debug=args.debug)
[167]499
[174]500    for uri in args.uri:
501        client = EtherWebSocketClient(switch, uri, cred, args.debug)
[168]502        client.open()
503        ioloop.add_handler(client.fileno(), client, ioloop.READ)
504
[174]505    for dev in args.device:
506        tap = TapHandler(switch, dev, debug=args.debug)
[167]507        tap.open()
508        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
509
[151]510    if not args.foreground:
511        daemonize()
512
[165]513    ioloop.start()
[133]514
[138]515
[133]516def main():
517    parser = argparse.ArgumentParser()
[167]518    parser.add_argument('--device', action='append', default=[])
519    parser.add_argument('--ageout', action='store', type=int, default=300)
[133]520    parser.add_argument('--foreground', action='store_true', default=False)
521    parser.add_argument('--debug', action='store_true', default=False)
522
523    subparsers = parser.add_subparsers(dest='subcommand')
524
[158]525    parser_s = subparsers.add_parser('server')
526    parser_s.add_argument('--address', action='store', default='')
527    parser_s.add_argument('--port', action='store', type=int)
528    parser_s.add_argument('--path', action='store', default='/')
529    parser_s.add_argument('--htpasswd', action='store')
530    parser_s.add_argument('--keyfile', action='store')
531    parser_s.add_argument('--certfile', action='store')
[133]532
[158]533    parser_c = subparsers.add_parser('client')
[168]534    parser_c.add_argument('--uri', action='append', default=[])
[158]535    parser_c.add_argument('--insecure', action='store_true', default=False)
536    parser_c.add_argument('--cacerts', action='store')
537    parser_c.add_argument('--user', action='store')
[160]538    parser_c.add_argument('--passwd', action='store')
[133]539
540    args = parser.parse_args()
541
542    if args.subcommand == 'server':
543        server_main(args)
544    elif args.subcommand == 'client':
545        client_main(args)
546
547
548if __name__ == '__main__':
549    main()
Note: See TracBrowser for help on using the repository browser.