source: etherws/trunk/etherws.py @ 174

Revision 174, 16.3 KB checked in by atzm, 12 years ago (diff)
  • dynamic htpasswd loading support
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import base64
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import websocket
54import tornado.web
55import tornado.ioloop
56import tornado.websocket
57import tornado.httpserver
58
59from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
60
61
62class DebugMixIn(object):
63    def dprintf(self, msg, func=lambda: ()):
64        if self._debug:
65            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
66            sys.stderr.write(prefix + (msg % func()))
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @staticmethod
74    def multicast(mac):
75        return ord(mac[0]) & 1
76
77    @property
78    def dst_mac(self):
79        return self.data[:6]
80
81    @property
82    def src_mac(self):
83        return self.data[6:12]
84
85    @property
86    def tagged(self):
87        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
88
89    @property
90    def vid(self):
91        if self.tagged:
92            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
93        return -1
94
95
96class FDB(DebugMixIn):
97    def __init__(self, ageout, debug=False):
98        self._ageout = ageout
99        self._debug = debug
100        self._dict = {}
101
102    def lookup(self, frame):
103        mac = frame.dst_mac
104        vid = frame.vid
105
106        group = self._dict.get(vid, None)
107        if not group:
108            return None
109
110        entry = group.get(mac, None)
111        if not entry:
112            return None
113
114        if time.time() - entry['time'] > self._ageout:
115            del self._dict[vid][mac]
116            if not self._dict[vid]:
117                del self._dict[vid]
118            self.dprintf('aged out: [%d] %s\n',
119                         lambda: (vid, mac.encode('hex')))
120            return None
121
122        return entry['port']
123
124    def learn(self, port, frame):
125        mac = frame.src_mac
126        vid = frame.vid
127
128        if vid not in self._dict:
129            self._dict[vid] = {}
130
131        self._dict[vid][mac] = {'time': time.time(), 'port': port}
132        self.dprintf('learned: [%d] %s\n',
133                     lambda: (vid, mac.encode('hex')))
134
135    def delete(self, port):
136        for vid in self._dict.keys():
137            for mac in self._dict[vid].keys():
138                if self._dict[vid][mac]['port'] is port:
139                    del self._dict[vid][mac]
140                    self.dprintf('deleted: [%d] %s\n',
141                                 lambda: (vid, mac.encode('hex')))
142            if not self._dict[vid]:
143                del self._dict[vid]
144
145
146class SwitchingHub(DebugMixIn):
147    def __init__(self, fdb, debug=False):
148        self._fdb = fdb
149        self._debug = debug
150        self._ports = []
151
152    def register_port(self, port):
153        self._ports.append(port)
154
155    def unregister_port(self, port):
156        self._fdb.delete(port)
157        self._ports.remove(port)
158
159    def forward(self, src_port, frame):
160        try:
161            if not frame.multicast(frame.src_mac):
162                self._fdb.learn(src_port, frame)
163
164            if not frame.multicast(frame.dst_mac):
165                dst_port = self._fdb.lookup(frame)
166
167                if dst_port:
168                    self._unicast(frame, dst_port)
169                    return
170
171            self._broadcast(frame, src_port)
172
173        except:  # ex. received invalid frame
174            traceback.print_exc()
175
176    def _unicast(self, frame, port):
177        port.write_message(frame.data, True)
178        self.dprintf('sent unicast: [%d] %s -> %s\n',
179                     lambda: (frame.vid,
180                              frame.src_mac.encode('hex'),
181                              frame.dst_mac.encode('hex')))
182
183    def _broadcast(self, frame, *except_ports):
184        ports = self._ports[:]
185        for port in except_ports:
186            ports.remove(port)
187        for port in ports:
188            port.write_message(frame.data, True)
189        self.dprintf('sent broadcast: [%d] %s -> %s\n',
190                     lambda: (frame.vid,
191                              frame.src_mac.encode('hex'),
192                              frame.dst_mac.encode('hex')))
193
194
195class TapHandler(DebugMixIn):
196    READ_SIZE = 65535
197
198    def __init__(self, switch, dev, debug=False):
199        self._switch = switch
200        self._dev = dev
201        self._debug = debug
202        self._tap = None
203
204    @property
205    def closed(self):
206        return not self._tap
207
208    def open(self):
209        if not self.closed:
210            raise ValueError('already opened')
211        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
212        self._tap.up()
213        self._switch.register_port(self)
214
215    def close(self):
216        if self.closed:
217            raise ValueError('I/O operation on closed tap')
218        self._switch.unregister_port(self)
219        self._tap.close()
220        self._tap = None
221
222    def fileno(self):
223        if self.closed:
224            raise ValueError('I/O operation on closed tap')
225        return self._tap.fileno()
226
227    def write_message(self, message, binary=False):
228        if self.closed:
229            raise ValueError('I/O operation on closed tap')
230        self._tap.write(message)
231
232    def __call__(self, fd, events):
233        try:
234            self._switch.forward(self, EthernetFrame(self._read()))
235            return
236        except:
237            traceback.print_exc()
238        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
239
240    def _read(self):
241        if self.closed:
242            raise ValueError('I/O operation on closed tap')
243        buf = []
244        while True:
245            buf.append(self._tap.read(self.READ_SIZE))
246            if len(buf[-1]) < self.READ_SIZE:
247                break
248        return ''.join(buf)
249
250
251class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
252    def __init__(self, app, req, switch, debug=False):
253        super(EtherWebSocketHandler, self).__init__(app, req)
254        self._switch = switch
255        self._debug = debug
256
257    def open(self):
258        self._switch.register_port(self)
259        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
260
261    def on_message(self, message):
262        self._switch.forward(self, EthernetFrame(message))
263
264    def on_close(self):
265        self._switch.unregister_port(self)
266        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
267
268
269class EtherWebSocketClient(DebugMixIn):
270    def __init__(self, switch, url, cred=None, debug=False):
271        self._switch = switch
272        self._url = url
273        self._debug = debug
274        self._sock = None
275        self._options = {}
276
277        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
278            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
279            auth = ['Authorization: Basic %s' % token]
280            self._options['header'] = auth
281
282    @property
283    def closed(self):
284        return not self._sock
285
286    def open(self):
287        if not self.closed:
288            raise websocket.WebSocketException('already opened')
289        self._sock = websocket.WebSocket()
290        self._sock.connect(self._url, **self._options)
291        self._switch.register_port(self)
292        self.dprintf('connected: %s\n', lambda: self._url)
293
294    def close(self):
295        if self.closed:
296            raise websocket.WebSocketException('already closed')
297        self._switch.unregister_port(self)
298        self._sock.close()
299        self._sock = None
300        self.dprintf('disconnected: %s\n', lambda: self._url)
301
302    def fileno(self):
303        if self.closed:
304            raise websocket.WebSocketException('closed socket')
305        return self._sock.io_sock.fileno()
306
307    def write_message(self, message, binary=False):
308        if self.closed:
309            raise websocket.WebSocketException('closed socket')
310        if binary:
311            flag = websocket.ABNF.OPCODE_BINARY
312        else:
313            flag = websocket.ABNF.OPCODE_TEXT
314        self._sock.send(message, flag)
315
316    def __call__(self, fd, events):
317        try:
318            data = self._sock.recv()
319            if data is not None:
320                self._switch.forward(self, EthernetFrame(data))
321                return
322        except:
323            traceback.print_exc()
324        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
325
326
327class Htpasswd(object):
328    def __init__(self, path):
329        self._path = path
330        self._stat = None
331        self._data = {}
332
333    def auth(self, name, passwd):
334        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
335        return self._data.get(name) == passwd
336
337    def load(self):
338        old_stat = self._stat
339
340        with open(self._path) as fp:
341            self._stat = os.fstat(fp.fileno())
342
343            unchanged = old_stat and \
344                        old_stat.st_ino == self._stat.st_ino and \
345                        old_stat.st_dev == self._stat.st_dev and \
346                        old_stat.st_mtime == self._stat.st_mtime
347
348            if not unchanged:
349                self._data = self._parse(fp)
350
351    def _parse(self, fp):
352        data = {}
353        for line in fp:
354            line = line.strip()
355            if 0 <= line.find(':'):
356                name, passwd = line.split(':', 1)
357                if passwd.startswith('{SHA}'):
358                    data[name] = passwd[5:]
359        return data
360
361
362def wrap_basic_auth(handler_class, htpasswd_path):
363    if not htpasswd_path:
364        return handler_class
365
366    old_execute = handler_class._execute
367    htpasswd = Htpasswd(htpasswd_path)
368
369    def execute(self, transforms, *args, **kwargs):
370        def auth_required():
371            self.stream.write(tornado.escape.utf8(
372                'HTTP/1.1 401 Authorization Required\r\n'
373                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
374            ))
375            self.stream.close()
376
377        creds = self.request.headers.get('Authorization')
378
379        if not creds or not creds.startswith('Basic '):
380            return auth_required()
381
382        try:
383            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
384            htpasswd.load()
385
386            if not htpasswd.auth(name, passwd):
387                return auth_required()
388
389            return old_execute(self, transforms, *args, **kwargs)
390
391        except:
392            return auth_required()
393
394    handler_class._execute = execute
395    return handler_class
396
397
398def daemonize(nochdir=False, noclose=False):
399    if os.fork() > 0:
400        sys.exit(0)
401
402    os.setsid()
403
404    if os.fork() > 0:
405        sys.exit(0)
406
407    if not nochdir:
408        os.chdir('/')
409
410    if not noclose:
411        os.umask(0)
412        sys.stdin.close()
413        sys.stdout.close()
414        sys.stderr.close()
415        os.close(0)
416        os.close(1)
417        os.close(2)
418        sys.stdin = open(os.devnull)
419        sys.stdout = open(os.devnull, 'a')
420        sys.stderr = open(os.devnull, 'a')
421
422
423def realpath(ns, *keys):
424    for k in keys:
425        v = getattr(ns, k, None)
426        if v is not None:
427            v = os.path.realpath(v)
428            setattr(ns, k, v)
429            open(v).close()  # check readable
430    return ns
431
432
433def server_main(args):
434    realpath(args, 'keyfile', 'certfile', 'htpasswd')
435
436    if args.keyfile and args.certfile:
437        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
438    elif args.keyfile or args.certfile:
439        raise ValueError('both keyfile and certfile are required')
440    else:
441        ssl_options = None
442
443    if args.port is None:
444        if ssl_options:
445            args.port = 443
446        else:
447            args.port = 80
448    elif not (0 <= args.port <= 65535):
449        raise ValueError('invalid port: %s' % args.port)
450
451    if args.ageout <= 0:
452        raise ValueError('invalid ageout: %s' % args.ageout)
453
454    ioloop = tornado.ioloop.IOLoop.instance()
455    fdb = FDB(ageout=args.ageout, debug=args.debug)
456    switch = SwitchingHub(fdb, debug=args.debug)
457
458    handler = wrap_basic_auth(EtherWebSocketHandler, args.htpasswd)
459    srv = (args.path, handler, {'switch': switch, 'debug': args.debug})
460    app = tornado.web.Application([srv])
461    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
462    server.listen(args.port, address=args.address)
463
464    for dev in args.device:
465        tap = TapHandler(switch, dev, debug=args.debug)
466        tap.open()
467        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
468
469    if not args.foreground:
470        daemonize()
471
472    ioloop.start()
473
474
475def client_main(args):
476    realpath(args, 'cacerts')
477
478    if args.debug:
479        websocket.enableTrace(True)
480
481    if args.insecure:
482        websocket._SSLSocketWrapper = \
483            lambda s: ssl.wrap_socket(s)
484    else:
485        websocket._SSLSocketWrapper = \
486            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
487                                      ca_certs=args.cacerts)
488
489    if args.ageout <= 0:
490        raise ValueError('invalid ageout: %s' % args.ageout)
491
492    if args.user and args.passwd is None:
493        args.passwd = getpass.getpass()
494
495    cred = {'user': args.user, 'passwd': args.passwd}
496    ioloop = tornado.ioloop.IOLoop.instance()
497    fdb = FDB(ageout=args.ageout, debug=args.debug)
498    switch = SwitchingHub(fdb, debug=args.debug)
499
500    for uri in args.uri:
501        client = EtherWebSocketClient(switch, uri, cred, args.debug)
502        client.open()
503        ioloop.add_handler(client.fileno(), client, ioloop.READ)
504
505    for dev in args.device:
506        tap = TapHandler(switch, dev, debug=args.debug)
507        tap.open()
508        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
509
510    if not args.foreground:
511        daemonize()
512
513    ioloop.start()
514
515
516def main():
517    parser = argparse.ArgumentParser()
518    parser.add_argument('--device', action='append', default=[])
519    parser.add_argument('--ageout', action='store', type=int, default=300)
520    parser.add_argument('--foreground', action='store_true', default=False)
521    parser.add_argument('--debug', action='store_true', default=False)
522
523    subparsers = parser.add_subparsers(dest='subcommand')
524
525    parser_s = subparsers.add_parser('server')
526    parser_s.add_argument('--address', action='store', default='')
527    parser_s.add_argument('--port', action='store', type=int)
528    parser_s.add_argument('--path', action='store', default='/')
529    parser_s.add_argument('--htpasswd', action='store')
530    parser_s.add_argument('--keyfile', action='store')
531    parser_s.add_argument('--certfile', action='store')
532
533    parser_c = subparsers.add_parser('client')
534    parser_c.add_argument('--uri', action='append', default=[])
535    parser_c.add_argument('--insecure', action='store_true', default=False)
536    parser_c.add_argument('--cacerts', action='store')
537    parser_c.add_argument('--user', action='store')
538    parser_c.add_argument('--passwd', action='store')
539
540    args = parser.parse_args()
541
542    if args.subcommand == 'server':
543        server_main(args)
544    elif args.subcommand == 'client':
545        client_main(args)
546
547
548if __name__ == '__main__':
549    main()
Note: See TracBrowser for help on using the repository browser.