source: etherws/trunk/etherws.py @ 175

Revision 175, 16.4 KB checked in by atzm, 12 years ago (diff)
  • acquire shared lock while loading htpasswd
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[175]47import fcntl
[150]48import base64
49import hashlib
[151]50import getpass
[133]51import argparse
[165]52import traceback
[133]53
54import websocket
[160]55import tornado.web
[133]56import tornado.ioloop
57import tornado.websocket
[160]58import tornado.httpserver
[133]59
[166]60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]61
[166]62
[160]63class DebugMixIn(object):
[166]64    def dprintf(self, msg, func=lambda: ()):
[160]65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]67            sys.stderr.write(prefix + (msg % func()))
[160]68
69
[164]70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
[173]74    @staticmethod
75    def multicast(mac):
76        return ord(mac[0]) & 1
[164]77
78    @property
79    def dst_mac(self):
80        return self.data[:6]
81
82    @property
83    def src_mac(self):
84        return self.data[6:12]
85
86    @property
87    def tagged(self):
88        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
89
90    @property
91    def vid(self):
92        if self.tagged:
93            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
94        return -1
95
96
[166]97class FDB(DebugMixIn):
[167]98    def __init__(self, ageout, debug=False):
[164]99        self._ageout = ageout
100        self._debug = debug
[166]101        self._dict = {}
[164]102
103    def lookup(self, frame):
104        mac = frame.dst_mac
105        vid = frame.vid
106
[166]107        group = self._dict.get(vid, None)
[164]108        if not group:
109            return None
110
111        entry = group.get(mac, None)
112        if not entry:
113            return None
114
115        if time.time() - entry['time'] > self._ageout:
[166]116            del self._dict[vid][mac]
117            if not self._dict[vid]:
118                del self._dict[vid]
[164]119            self.dprintf('aged out: [%d] %s\n',
120                         lambda: (vid, mac.encode('hex')))
121            return None
122
123        return entry['port']
124
[166]125    def learn(self, port, frame):
126        mac = frame.src_mac
127        vid = frame.vid
128
129        if vid not in self._dict:
130            self._dict[vid] = {}
131
132        self._dict[vid][mac] = {'time': time.time(), 'port': port}
133        self.dprintf('learned: [%d] %s\n',
134                     lambda: (vid, mac.encode('hex')))
135
[164]136    def delete(self, port):
[166]137        for vid in self._dict.keys():
138            for mac in self._dict[vid].keys():
139                if self._dict[vid][mac]['port'] is port:
140                    del self._dict[vid][mac]
[164]141                    self.dprintf('deleted: [%d] %s\n',
142                                 lambda: (vid, mac.encode('hex')))
[166]143            if not self._dict[vid]:
144                del self._dict[vid]
[164]145
146
[166]147class SwitchingHub(DebugMixIn):
148    def __init__(self, fdb, debug=False):
149        self._fdb = fdb
[133]150        self._debug = debug
[166]151        self._ports = []
[133]152
[166]153    def register_port(self, port):
154        self._ports.append(port)
[133]155
[166]156    def unregister_port(self, port):
157        self._fdb.delete(port)
158        self._ports.remove(port)
[133]159
[166]160    def forward(self, src_port, frame):
161        try:
[173]162            if not frame.multicast(frame.src_mac):
[172]163                self._fdb.learn(src_port, frame)
[133]164
[173]165            if not frame.multicast(frame.dst_mac):
[166]166                dst_port = self._fdb.lookup(frame)
[164]167
[166]168                if dst_port:
169                    self._unicast(frame, dst_port)
170                    return
[133]171
[166]172            self._broadcast(frame, src_port)
[162]173
[166]174        except:  # ex. received invalid frame
175            traceback.print_exc()
[133]176
[166]177    def _unicast(self, frame, port):
178        port.write_message(frame.data, True)
179        self.dprintf('sent unicast: [%d] %s -> %s\n',
180                     lambda: (frame.vid,
181                              frame.src_mac.encode('hex'),
182                              frame.dst_mac.encode('hex')))
[164]183
[166]184    def _broadcast(self, frame, *except_ports):
185        ports = self._ports[:]
186        for port in except_ports:
187            ports.remove(port)
188        for port in ports:
189            port.write_message(frame.data, True)
[164]190        self.dprintf('sent broadcast: [%d] %s -> %s\n',
191                     lambda: (frame.vid,
192                              frame.src_mac.encode('hex'),
193                              frame.dst_mac.encode('hex')))
194
[166]195
196class TapHandler(DebugMixIn):
197    READ_SIZE = 65535
198
199    def __init__(self, switch, dev, debug=False):
200        self._switch = switch
201        self._dev = dev
202        self._debug = debug
203        self._tap = None
204
205    @property
206    def closed(self):
207        return not self._tap
208
209    def open(self):
210        if not self.closed:
211            raise ValueError('already opened')
212        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
213        self._tap.up()
214        self._switch.register_port(self)
215
216    def close(self):
217        if self.closed:
218            raise ValueError('I/O operation on closed tap')
219        self._switch.unregister_port(self)
220        self._tap.close()
221        self._tap = None
222
223    def fileno(self):
224        if self.closed:
225            raise ValueError('I/O operation on closed tap')
226        return self._tap.fileno()
227
228    def write_message(self, message, binary=False):
229        if self.closed:
230            raise ValueError('I/O operation on closed tap')
231        self._tap.write(message)
232
[138]233    def __call__(self, fd, events):
[166]234        try:
235            self._switch.forward(self, EthernetFrame(self._read()))
236            return
237        except:
238            traceback.print_exc()
[174]239        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
[166]240
241    def _read(self):
242        if self.closed:
243            raise ValueError('I/O operation on closed tap')
[162]244        buf = []
245        while True:
[166]246            buf.append(self._tap.read(self.READ_SIZE))
247            if len(buf[-1]) < self.READ_SIZE:
[162]248                break
[166]249        return ''.join(buf)
[162]250
251
[160]252class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
[166]253    def __init__(self, app, req, switch, debug=False):
[160]254        super(EtherWebSocketHandler, self).__init__(app, req)
[166]255        self._switch = switch
[133]256        self._debug = debug
257
258    def open(self):
[166]259        self._switch.register_port(self)
[164]260        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]261
262    def on_message(self, message):
[166]263        self._switch.forward(self, EthernetFrame(message))
[133]264
265    def on_close(self):
[166]266        self._switch.unregister_port(self)
[164]267        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]268
269
[160]270class EtherWebSocketClient(DebugMixIn):
[174]271    def __init__(self, switch, url, cred=None, debug=False):
[166]272        self._switch = switch
[151]273        self._url = url
[160]274        self._debug = debug
[166]275        self._sock = None
[151]276        self._options = {}
277
[174]278        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
279            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]280            auth = ['Authorization: Basic %s' % token]
281            self._options['header'] = auth
282
[160]283    @property
284    def closed(self):
285        return not self._sock
286
[151]287    def open(self):
[160]288        if not self.closed:
289            raise websocket.WebSocketException('already opened')
[151]290        self._sock = websocket.WebSocket()
291        self._sock.connect(self._url, **self._options)
[166]292        self._switch.register_port(self)
[164]293        self.dprintf('connected: %s\n', lambda: self._url)
[151]294
295    def close(self):
[160]296        if self.closed:
297            raise websocket.WebSocketException('already closed')
[166]298        self._switch.unregister_port(self)
[151]299        self._sock.close()
300        self._sock = None
[164]301        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]302
[165]303    def fileno(self):
304        if self.closed:
305            raise websocket.WebSocketException('closed socket')
306        return self._sock.io_sock.fileno()
307
[151]308    def write_message(self, message, binary=False):
[160]309        if self.closed:
310            raise websocket.WebSocketException('closed socket')
[151]311        if binary:
312            flag = websocket.ABNF.OPCODE_BINARY
[160]313        else:
314            flag = websocket.ABNF.OPCODE_TEXT
[151]315        self._sock.send(message, flag)
316
[165]317    def __call__(self, fd, events):
[151]318        try:
[165]319            data = self._sock.recv()
320            if data is not None:
[166]321                self._switch.forward(self, EthernetFrame(data))
[165]322                return
323        except:
324            traceback.print_exc()
[174]325        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
[151]326
327
[174]328class Htpasswd(object):
329    def __init__(self, path):
330        self._path = path
331        self._stat = None
332        self._data = {}
333
334    def auth(self, name, passwd):
335        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
336        return self._data.get(name) == passwd
337
338    def load(self):
339        old_stat = self._stat
340
341        with open(self._path) as fp:
[175]342            fileno = fp.fileno()
343            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
344            self._stat = os.fstat(fileno)
[174]345
346            unchanged = old_stat and \
347                        old_stat.st_ino == self._stat.st_ino and \
348                        old_stat.st_dev == self._stat.st_dev and \
349                        old_stat.st_mtime == self._stat.st_mtime
350
351            if not unchanged:
352                self._data = self._parse(fp)
353
354    def _parse(self, fp):
355        data = {}
356        for line in fp:
357            line = line.strip()
358            if 0 <= line.find(':'):
359                name, passwd = line.split(':', 1)
360                if passwd.startswith('{SHA}'):
361                    data[name] = passwd[5:]
362        return data
363
364
365def wrap_basic_auth(handler_class, htpasswd_path):
366    if not htpasswd_path:
367        return handler_class
368
369    old_execute = handler_class._execute
370    htpasswd = Htpasswd(htpasswd_path)
371
372    def execute(self, transforms, *args, **kwargs):
373        def auth_required():
374            self.stream.write(tornado.escape.utf8(
375                'HTTP/1.1 401 Authorization Required\r\n'
376                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
377            ))
378            self.stream.close()
379
380        creds = self.request.headers.get('Authorization')
381
382        if not creds or not creds.startswith('Basic '):
383            return auth_required()
384
385        try:
386            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
387            htpasswd.load()
388
389            if not htpasswd.auth(name, passwd):
390                return auth_required()
391
392            return old_execute(self, transforms, *args, **kwargs)
393
394        except:
395            return auth_required()
396
397    handler_class._execute = execute
398    return handler_class
399
400
[134]401def daemonize(nochdir=False, noclose=False):
402    if os.fork() > 0:
403        sys.exit(0)
404
405    os.setsid()
406
407    if os.fork() > 0:
408        sys.exit(0)
409
410    if not nochdir:
411        os.chdir('/')
412
413    if not noclose:
414        os.umask(0)
415        sys.stdin.close()
416        sys.stdout.close()
417        sys.stderr.close()
418        os.close(0)
419        os.close(1)
420        os.close(2)
421        sys.stdin = open(os.devnull)
422        sys.stdout = open(os.devnull, 'a')
423        sys.stderr = open(os.devnull, 'a')
424
425
[160]426def realpath(ns, *keys):
427    for k in keys:
428        v = getattr(ns, k, None)
429        if v is not None:
430            v = os.path.realpath(v)
431            setattr(ns, k, v)
432            open(v).close()  # check readable
433    return ns
434
435
[133]436def server_main(args):
[160]437    realpath(args, 'keyfile', 'certfile', 'htpasswd')
[143]438
[160]439    if args.keyfile and args.certfile:
440        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
441    elif args.keyfile or args.certfile:
[143]442        raise ValueError('both keyfile and certfile are required')
[160]443    else:
[143]444        ssl_options = None
445
[160]446    if args.port is None:
[143]447        if ssl_options:
448            args.port = 443
449        else:
450            args.port = 80
[160]451    elif not (0 <= args.port <= 65535):
452        raise ValueError('invalid port: %s' % args.port)
[143]453
[167]454    if args.ageout <= 0:
455        raise ValueError('invalid ageout: %s' % args.ageout)
456
457    ioloop = tornado.ioloop.IOLoop.instance()
458    fdb = FDB(ageout=args.ageout, debug=args.debug)
459    switch = SwitchingHub(fdb, debug=args.debug)
460
[174]461    handler = wrap_basic_auth(EtherWebSocketHandler, args.htpasswd)
462    srv = (args.path, handler, {'switch': switch, 'debug': args.debug})
463    app = tornado.web.Application([srv])
[143]464    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
[133]465    server.listen(args.port, address=args.address)
466
[174]467    for dev in args.device:
468        tap = TapHandler(switch, dev, debug=args.debug)
[167]469        tap.open()
470        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
[151]471
472    if not args.foreground:
473        daemonize()
474
[138]475    ioloop.start()
[133]476
477
478def client_main(args):
[160]479    realpath(args, 'cacerts')
480
[133]481    if args.debug:
482        websocket.enableTrace(True)
483
[174]484    if args.insecure:
[156]485        websocket._SSLSocketWrapper = \
[174]486            lambda s: ssl.wrap_socket(s)
487    else:
488        websocket._SSLSocketWrapper = \
[156]489            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
490                                      ca_certs=args.cacerts)
491
[174]492    if args.ageout <= 0:
493        raise ValueError('invalid ageout: %s' % args.ageout)
494
[160]495    if args.user and args.passwd is None:
496        args.passwd = getpass.getpass()
[143]497
[174]498    cred = {'user': args.user, 'passwd': args.passwd}
[167]499    ioloop = tornado.ioloop.IOLoop.instance()
500    fdb = FDB(ageout=args.ageout, debug=args.debug)
[166]501    switch = SwitchingHub(fdb, debug=args.debug)
[167]502
[174]503    for uri in args.uri:
504        client = EtherWebSocketClient(switch, uri, cred, args.debug)
[168]505        client.open()
506        ioloop.add_handler(client.fileno(), client, ioloop.READ)
507
[174]508    for dev in args.device:
509        tap = TapHandler(switch, dev, debug=args.debug)
[167]510        tap.open()
511        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
512
[151]513    if not args.foreground:
514        daemonize()
515
[165]516    ioloop.start()
[133]517
[138]518
[133]519def main():
520    parser = argparse.ArgumentParser()
[167]521    parser.add_argument('--device', action='append', default=[])
522    parser.add_argument('--ageout', action='store', type=int, default=300)
[133]523    parser.add_argument('--foreground', action='store_true', default=False)
524    parser.add_argument('--debug', action='store_true', default=False)
525
526    subparsers = parser.add_subparsers(dest='subcommand')
527
[158]528    parser_s = subparsers.add_parser('server')
529    parser_s.add_argument('--address', action='store', default='')
530    parser_s.add_argument('--port', action='store', type=int)
531    parser_s.add_argument('--path', action='store', default='/')
532    parser_s.add_argument('--htpasswd', action='store')
533    parser_s.add_argument('--keyfile', action='store')
534    parser_s.add_argument('--certfile', action='store')
[133]535
[158]536    parser_c = subparsers.add_parser('client')
[168]537    parser_c.add_argument('--uri', action='append', default=[])
[158]538    parser_c.add_argument('--insecure', action='store_true', default=False)
539    parser_c.add_argument('--cacerts', action='store')
540    parser_c.add_argument('--user', action='store')
[160]541    parser_c.add_argument('--passwd', action='store')
[133]542
543    args = parser.parse_args()
544
545    if args.subcommand == 'server':
546        server_main(args)
547    elif args.subcommand == 'client':
548        client_main(args)
549
550
551if __name__ == '__main__':
552    main()
Note: See TracBrowser for help on using the repository browser.