source: etherws/trunk/etherws.py @ 173

Revision 173, 15.8 KB checked in by atzm, 12 years ago (diff)
  • fixed a trivial bug in r172
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import base64
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import websocket
54import tornado.web
55import tornado.ioloop
56import tornado.websocket
57import tornado.httpserver
58
59from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
60
61
62class DebugMixIn(object):
63    def dprintf(self, msg, func=lambda: ()):
64        if self._debug:
65            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
66            sys.stderr.write(prefix + (msg % func()))
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @staticmethod
74    def multicast(mac):
75        return ord(mac[0]) & 1
76
77    @property
78    def dst_mac(self):
79        return self.data[:6]
80
81    @property
82    def src_mac(self):
83        return self.data[6:12]
84
85    @property
86    def tagged(self):
87        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
88
89    @property
90    def vid(self):
91        if self.tagged:
92            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
93        return -1
94
95
96class FDB(DebugMixIn):
97    def __init__(self, ageout, debug=False):
98        self._ageout = ageout
99        self._debug = debug
100        self._dict = {}
101
102    def lookup(self, frame):
103        mac = frame.dst_mac
104        vid = frame.vid
105
106        group = self._dict.get(vid, None)
107        if not group:
108            return None
109
110        entry = group.get(mac, None)
111        if not entry:
112            return None
113
114        if time.time() - entry['time'] > self._ageout:
115            del self._dict[vid][mac]
116            if not self._dict[vid]:
117                del self._dict[vid]
118            self.dprintf('aged out: [%d] %s\n',
119                         lambda: (vid, mac.encode('hex')))
120            return None
121
122        return entry['port']
123
124    def learn(self, port, frame):
125        mac = frame.src_mac
126        vid = frame.vid
127
128        if vid not in self._dict:
129            self._dict[vid] = {}
130
131        self._dict[vid][mac] = {'time': time.time(), 'port': port}
132        self.dprintf('learned: [%d] %s\n',
133                     lambda: (vid, mac.encode('hex')))
134
135    def delete(self, port):
136        for vid in self._dict.keys():
137            for mac in self._dict[vid].keys():
138                if self._dict[vid][mac]['port'] is port:
139                    del self._dict[vid][mac]
140                    self.dprintf('deleted: [%d] %s\n',
141                                 lambda: (vid, mac.encode('hex')))
142            if not self._dict[vid]:
143                del self._dict[vid]
144
145
146class SwitchingHub(DebugMixIn):
147    def __init__(self, fdb, debug=False):
148        self._fdb = fdb
149        self._debug = debug
150        self._ports = []
151
152    def register_port(self, port):
153        self._ports.append(port)
154
155    def unregister_port(self, port):
156        self._fdb.delete(port)
157        self._ports.remove(port)
158
159    def forward(self, src_port, frame):
160        try:
161            if not frame.multicast(frame.src_mac):
162                self._fdb.learn(src_port, frame)
163
164            if not frame.multicast(frame.dst_mac):
165                dst_port = self._fdb.lookup(frame)
166
167                if dst_port:
168                    self._unicast(frame, dst_port)
169                    return
170
171            self._broadcast(frame, src_port)
172
173        except:  # ex. received invalid frame
174            traceback.print_exc()
175
176    def _unicast(self, frame, port):
177        port.write_message(frame.data, True)
178        self.dprintf('sent unicast: [%d] %s -> %s\n',
179                     lambda: (frame.vid,
180                              frame.src_mac.encode('hex'),
181                              frame.dst_mac.encode('hex')))
182
183    def _broadcast(self, frame, *except_ports):
184        ports = self._ports[:]
185        for port in except_ports:
186            ports.remove(port)
187        for port in ports:
188            port.write_message(frame.data, True)
189        self.dprintf('sent broadcast: [%d] %s -> %s\n',
190                     lambda: (frame.vid,
191                              frame.src_mac.encode('hex'),
192                              frame.dst_mac.encode('hex')))
193
194
195class TapHandler(DebugMixIn):
196    READ_SIZE = 65535
197
198    def __init__(self, switch, dev, debug=False):
199        self._switch = switch
200        self._dev = dev
201        self._debug = debug
202        self._tap = None
203
204    @property
205    def closed(self):
206        return not self._tap
207
208    def open(self):
209        if not self.closed:
210            raise ValueError('already opened')
211        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
212        self._tap.up()
213        self._switch.register_port(self)
214
215    def close(self):
216        if self.closed:
217            raise ValueError('I/O operation on closed tap')
218        self._switch.unregister_port(self)
219        self._tap.close()
220        self._tap = None
221
222    def fileno(self):
223        if self.closed:
224            raise ValueError('I/O operation on closed tap')
225        return self._tap.fileno()
226
227    def write_message(self, message, binary=False):
228        if self.closed:
229            raise ValueError('I/O operation on closed tap')
230        self._tap.write(message)
231
232    def __call__(self, fd, events):
233        try:
234            self._switch.forward(self, EthernetFrame(self._read()))
235            return
236        except:
237            traceback.print_exc()
238        tornado.ioloop.IOLoop.instance().stop()
239
240    def _read(self):
241        if self.closed:
242            raise ValueError('I/O operation on closed tap')
243        buf = []
244        while True:
245            buf.append(self._tap.read(self.READ_SIZE))
246            if len(buf[-1]) < self.READ_SIZE:
247                break
248        return ''.join(buf)
249
250
251class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
252    def __init__(self, app, req, switch, debug=False):
253        super(EtherWebSocketHandler, self).__init__(app, req)
254        self._switch = switch
255        self._debug = debug
256
257    def open(self):
258        self._switch.register_port(self)
259        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
260
261    def on_message(self, message):
262        self._switch.forward(self, EthernetFrame(message))
263
264    def on_close(self):
265        self._switch.unregister_port(self)
266        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
267
268
269class EtherWebSocketClient(DebugMixIn):
270    def __init__(self, switch, url, user=None, passwd=None, debug=False):
271        self._switch = switch
272        self._url = url
273        self._debug = debug
274        self._sock = None
275        self._options = {}
276
277        if user and passwd:
278            token = base64.b64encode('%s:%s' % (user, passwd))
279            auth = ['Authorization: Basic %s' % token]
280            self._options['header'] = auth
281
282    @property
283    def closed(self):
284        return not self._sock
285
286    def open(self):
287        if not self.closed:
288            raise websocket.WebSocketException('already opened')
289        self._sock = websocket.WebSocket()
290        self._sock.connect(self._url, **self._options)
291        self._switch.register_port(self)
292        self.dprintf('connected: %s\n', lambda: self._url)
293
294    def close(self):
295        if self.closed:
296            raise websocket.WebSocketException('already closed')
297        self._switch.unregister_port(self)
298        self._sock.close()
299        self._sock = None
300        self.dprintf('disconnected: %s\n', lambda: self._url)
301
302    def fileno(self):
303        if self.closed:
304            raise websocket.WebSocketException('closed socket')
305        return self._sock.io_sock.fileno()
306
307    def write_message(self, message, binary=False):
308        if self.closed:
309            raise websocket.WebSocketException('closed socket')
310        if binary:
311            flag = websocket.ABNF.OPCODE_BINARY
312        else:
313            flag = websocket.ABNF.OPCODE_TEXT
314        self._sock.send(message, flag)
315
316    def __call__(self, fd, events):
317        try:
318            data = self._sock.recv()
319            if data is not None:
320                self._switch.forward(self, EthernetFrame(data))
321                return
322        except:
323            traceback.print_exc()
324        tornado.ioloop.IOLoop.instance().stop()
325
326
327def daemonize(nochdir=False, noclose=False):
328    if os.fork() > 0:
329        sys.exit(0)
330
331    os.setsid()
332
333    if os.fork() > 0:
334        sys.exit(0)
335
336    if not nochdir:
337        os.chdir('/')
338
339    if not noclose:
340        os.umask(0)
341        sys.stdin.close()
342        sys.stdout.close()
343        sys.stderr.close()
344        os.close(0)
345        os.close(1)
346        os.close(2)
347        sys.stdin = open(os.devnull)
348        sys.stdout = open(os.devnull, 'a')
349        sys.stderr = open(os.devnull, 'a')
350
351
352def realpath(ns, *keys):
353    for k in keys:
354        v = getattr(ns, k, None)
355        if v is not None:
356            v = os.path.realpath(v)
357            setattr(ns, k, v)
358            open(v).close()  # check readable
359    return ns
360
361
362def server_main(args):
363    def wrap_basic_auth(cls, users):
364        o_exec = cls._execute
365
366        if not users:
367            return cls
368
369        def execute(self, transforms, *args, **kwargs):
370            def auth_required():
371                self.stream.write(tornado.escape.utf8(
372                    'HTTP/1.1 401 Authorization Required\r\n'
373                    'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
374                ))
375                self.stream.close()
376
377            creds = self.request.headers.get('Authorization')
378
379            if not creds or not creds.startswith('Basic '):
380                return auth_required()
381
382            try:
383                name, passwd = base64.b64decode(creds[6:]).split(':', 1)
384                passwd = base64.b64encode(hashlib.sha1(passwd).digest())
385
386                if name not in users or users[name] != passwd:
387                    return auth_required()
388
389                return o_exec(self, transforms, *args, **kwargs)
390
391            except:
392                return auth_required()
393
394        cls._execute = execute
395        return cls
396
397    def load_htpasswd(path):
398        users = {}
399        try:
400            with open(path) as fp:
401                for line in fp:
402                    line = line.strip()
403                    if 0 <= line.find(':'):
404                        name, passwd = line.split(':', 1)
405                        if passwd.startswith('{SHA}'):
406                            users[name] = passwd[5:]
407            if not users:
408                raise ValueError('no valid users found')
409        except TypeError:
410            pass
411        return users
412
413    realpath(args, 'keyfile', 'certfile', 'htpasswd')
414
415    if args.keyfile and args.certfile:
416        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
417    elif args.keyfile or args.certfile:
418        raise ValueError('both keyfile and certfile are required')
419    else:
420        ssl_options = None
421
422    if args.port is None:
423        if ssl_options:
424            args.port = 443
425        else:
426            args.port = 80
427    elif not (0 <= args.port <= 65535):
428        raise ValueError('invalid port: %s' % args.port)
429
430    if args.ageout <= 0:
431        raise ValueError('invalid ageout: %s' % args.ageout)
432
433    ioloop = tornado.ioloop.IOLoop.instance()
434    fdb = FDB(ageout=args.ageout, debug=args.debug)
435    switch = SwitchingHub(fdb, debug=args.debug)
436    taps = [TapHandler(switch, dev, debug=args.debug) for dev in args.device]
437
438    handler = wrap_basic_auth(EtherWebSocketHandler,
439                              load_htpasswd(args.htpasswd))
440    app = tornado.web.Application([
441        (args.path, handler, {'switch': switch, 'debug': args.debug}),
442    ])
443    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
444    server.listen(args.port, address=args.address)
445
446    for tap in taps:
447        tap.open()
448        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
449
450    if not args.foreground:
451        daemonize()
452
453    ioloop.start()
454
455
456def client_main(args):
457    realpath(args, 'cacerts')
458
459    if args.debug:
460        websocket.enableTrace(True)
461
462    if not args.insecure:
463        websocket._SSLSocketWrapper = \
464            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
465                                      ca_certs=args.cacerts)
466    else:
467        websocket._SSLSocketWrapper = \
468            lambda s: ssl.wrap_socket(s)
469
470    if args.user and args.passwd is None:
471        args.passwd = getpass.getpass()
472
473    if args.ageout <= 0:
474        raise ValueError('invalid ageout: %s' % args.ageout)
475
476    ioloop = tornado.ioloop.IOLoop.instance()
477    fdb = FDB(ageout=args.ageout, debug=args.debug)
478    switch = SwitchingHub(fdb, debug=args.debug)
479    taps = [TapHandler(switch, dev, debug=args.debug) for dev in args.device]
480
481    clients = [EtherWebSocketClient(switch, uri,
482                                    args.user, args.passwd, args.debug)
483               for uri in args.uri]
484
485    for client in clients:
486        client.open()
487        ioloop.add_handler(client.fileno(), client, ioloop.READ)
488
489    for tap in taps:
490        tap.open()
491        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
492
493    if not args.foreground:
494        daemonize()
495
496    ioloop.start()
497
498
499def main():
500    parser = argparse.ArgumentParser()
501    parser.add_argument('--device', action='append', default=[])
502    parser.add_argument('--ageout', action='store', type=int, default=300)
503    parser.add_argument('--foreground', action='store_true', default=False)
504    parser.add_argument('--debug', action='store_true', default=False)
505
506    subparsers = parser.add_subparsers(dest='subcommand')
507
508    parser_s = subparsers.add_parser('server')
509    parser_s.add_argument('--address', action='store', default='')
510    parser_s.add_argument('--port', action='store', type=int)
511    parser_s.add_argument('--path', action='store', default='/')
512    parser_s.add_argument('--htpasswd', action='store')
513    parser_s.add_argument('--keyfile', action='store')
514    parser_s.add_argument('--certfile', action='store')
515
516    parser_c = subparsers.add_parser('client')
517    parser_c.add_argument('--uri', action='append', default=[])
518    parser_c.add_argument('--insecure', action='store_true', default=False)
519    parser_c.add_argument('--cacerts', action='store')
520    parser_c.add_argument('--user', action='store')
521    parser_c.add_argument('--passwd', action='store')
522
523    args = parser.parse_args()
524
525    if args.subcommand == 'server':
526        server_main(args)
527    elif args.subcommand == 'client':
528        client_main(args)
529
530
531if __name__ == '__main__':
532    main()
Note: See TracBrowser for help on using the repository browser.