source: etherws/trunk/etherws.py @ 172

Revision 172, 15.8 KB checked in by atzm, 12 years ago (diff)
  • do not learn multicast src address (IEEE Std. 802.1D compatible)
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import base64
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import websocket
54import tornado.web
55import tornado.ioloop
56import tornado.websocket
57import tornado.httpserver
58
59from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
60
61
62class DebugMixIn(object):
63    def dprintf(self, msg, func=lambda: ()):
64        if self._debug:
65            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
66            sys.stderr.write(prefix + (msg % func()))
67
68
69class EthernetFrame(object):
70    def __init__(self, data):
71        self.data = data
72
73    @property
74    def multicast(self):
75        return ord(self.data[0]) & 1
76
77    @property
78    def dst_mac(self):
79        return self.data[:6]
80
81    @property
82    def src_mac(self):
83        return self.data[6:12]
84
85    @property
86    def tagged(self):
87        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
88
89    @property
90    def vid(self):
91        if self.tagged:
92            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
93        return -1
94
95
96class FDB(DebugMixIn):
97    def __init__(self, ageout, debug=False):
98        self._ageout = ageout
99        self._debug = debug
100        self._dict = {}
101
102    def lookup(self, frame):
103        mac = frame.dst_mac
104        vid = frame.vid
105
106        group = self._dict.get(vid, None)
107        if not group:
108            return None
109
110        entry = group.get(mac, None)
111        if not entry:
112            return None
113
114        if time.time() - entry['time'] > self._ageout:
115            del self._dict[vid][mac]
116            if not self._dict[vid]:
117                del self._dict[vid]
118            self.dprintf('aged out: [%d] %s\n',
119                         lambda: (vid, mac.encode('hex')))
120            return None
121
122        return entry['port']
123
124    def learn(self, port, frame):
125        mac = frame.src_mac
126        vid = frame.vid
127
128        if vid not in self._dict:
129            self._dict[vid] = {}
130
131        self._dict[vid][mac] = {'time': time.time(), 'port': port}
132        self.dprintf('learned: [%d] %s\n',
133                     lambda: (vid, mac.encode('hex')))
134
135    def delete(self, port):
136        for vid in self._dict.keys():
137            for mac in self._dict[vid].keys():
138                if self._dict[vid][mac]['port'] is port:
139                    del self._dict[vid][mac]
140                    self.dprintf('deleted: [%d] %s\n',
141                                 lambda: (vid, mac.encode('hex')))
142            if not self._dict[vid]:
143                del self._dict[vid]
144
145
146class SwitchingHub(DebugMixIn):
147    def __init__(self, fdb, debug=False):
148        self._fdb = fdb
149        self._debug = debug
150        self._ports = []
151
152    def register_port(self, port):
153        self._ports.append(port)
154
155    def unregister_port(self, port):
156        self._fdb.delete(port)
157        self._ports.remove(port)
158
159    def forward(self, src_port, frame):
160        try:
161            if not frame.multicast:
162                self._fdb.learn(src_port, frame)
163
164                dst_port = self._fdb.lookup(frame)
165
166                if dst_port:
167                    self._unicast(frame, dst_port)
168                    return
169
170            self._broadcast(frame, src_port)
171
172        except:  # ex. received invalid frame
173            traceback.print_exc()
174
175    def _unicast(self, frame, port):
176        port.write_message(frame.data, True)
177        self.dprintf('sent unicast: [%d] %s -> %s\n',
178                     lambda: (frame.vid,
179                              frame.src_mac.encode('hex'),
180                              frame.dst_mac.encode('hex')))
181
182    def _broadcast(self, frame, *except_ports):
183        ports = self._ports[:]
184        for port in except_ports:
185            ports.remove(port)
186        for port in ports:
187            port.write_message(frame.data, True)
188        self.dprintf('sent broadcast: [%d] %s -> %s\n',
189                     lambda: (frame.vid,
190                              frame.src_mac.encode('hex'),
191                              frame.dst_mac.encode('hex')))
192
193
194class TapHandler(DebugMixIn):
195    READ_SIZE = 65535
196
197    def __init__(self, switch, dev, debug=False):
198        self._switch = switch
199        self._dev = dev
200        self._debug = debug
201        self._tap = None
202
203    @property
204    def closed(self):
205        return not self._tap
206
207    def open(self):
208        if not self.closed:
209            raise ValueError('already opened')
210        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
211        self._tap.up()
212        self._switch.register_port(self)
213
214    def close(self):
215        if self.closed:
216            raise ValueError('I/O operation on closed tap')
217        self._switch.unregister_port(self)
218        self._tap.close()
219        self._tap = None
220
221    def fileno(self):
222        if self.closed:
223            raise ValueError('I/O operation on closed tap')
224        return self._tap.fileno()
225
226    def write_message(self, message, binary=False):
227        if self.closed:
228            raise ValueError('I/O operation on closed tap')
229        self._tap.write(message)
230
231    def __call__(self, fd, events):
232        try:
233            self._switch.forward(self, EthernetFrame(self._read()))
234            return
235        except:
236            traceback.print_exc()
237        tornado.ioloop.IOLoop.instance().stop()
238
239    def _read(self):
240        if self.closed:
241            raise ValueError('I/O operation on closed tap')
242        buf = []
243        while True:
244            buf.append(self._tap.read(self.READ_SIZE))
245            if len(buf[-1]) < self.READ_SIZE:
246                break
247        return ''.join(buf)
248
249
250class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
251    def __init__(self, app, req, switch, debug=False):
252        super(EtherWebSocketHandler, self).__init__(app, req)
253        self._switch = switch
254        self._debug = debug
255
256    def open(self):
257        self._switch.register_port(self)
258        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
259
260    def on_message(self, message):
261        self._switch.forward(self, EthernetFrame(message))
262
263    def on_close(self):
264        self._switch.unregister_port(self)
265        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
266
267
268class EtherWebSocketClient(DebugMixIn):
269    def __init__(self, switch, url, user=None, passwd=None, debug=False):
270        self._switch = switch
271        self._url = url
272        self._debug = debug
273        self._sock = None
274        self._options = {}
275
276        if user and passwd:
277            token = base64.b64encode('%s:%s' % (user, passwd))
278            auth = ['Authorization: Basic %s' % token]
279            self._options['header'] = auth
280
281    @property
282    def closed(self):
283        return not self._sock
284
285    def open(self):
286        if not self.closed:
287            raise websocket.WebSocketException('already opened')
288        self._sock = websocket.WebSocket()
289        self._sock.connect(self._url, **self._options)
290        self._switch.register_port(self)
291        self.dprintf('connected: %s\n', lambda: self._url)
292
293    def close(self):
294        if self.closed:
295            raise websocket.WebSocketException('already closed')
296        self._switch.unregister_port(self)
297        self._sock.close()
298        self._sock = None
299        self.dprintf('disconnected: %s\n', lambda: self._url)
300
301    def fileno(self):
302        if self.closed:
303            raise websocket.WebSocketException('closed socket')
304        return self._sock.io_sock.fileno()
305
306    def write_message(self, message, binary=False):
307        if self.closed:
308            raise websocket.WebSocketException('closed socket')
309        if binary:
310            flag = websocket.ABNF.OPCODE_BINARY
311        else:
312            flag = websocket.ABNF.OPCODE_TEXT
313        self._sock.send(message, flag)
314
315    def __call__(self, fd, events):
316        try:
317            data = self._sock.recv()
318            if data is not None:
319                self._switch.forward(self, EthernetFrame(data))
320                return
321        except:
322            traceback.print_exc()
323        tornado.ioloop.IOLoop.instance().stop()
324
325
326def daemonize(nochdir=False, noclose=False):
327    if os.fork() > 0:
328        sys.exit(0)
329
330    os.setsid()
331
332    if os.fork() > 0:
333        sys.exit(0)
334
335    if not nochdir:
336        os.chdir('/')
337
338    if not noclose:
339        os.umask(0)
340        sys.stdin.close()
341        sys.stdout.close()
342        sys.stderr.close()
343        os.close(0)
344        os.close(1)
345        os.close(2)
346        sys.stdin = open(os.devnull)
347        sys.stdout = open(os.devnull, 'a')
348        sys.stderr = open(os.devnull, 'a')
349
350
351def realpath(ns, *keys):
352    for k in keys:
353        v = getattr(ns, k, None)
354        if v is not None:
355            v = os.path.realpath(v)
356            setattr(ns, k, v)
357            open(v).close()  # check readable
358    return ns
359
360
361def server_main(args):
362    def wrap_basic_auth(cls, users):
363        o_exec = cls._execute
364
365        if not users:
366            return cls
367
368        def execute(self, transforms, *args, **kwargs):
369            def auth_required():
370                self.stream.write(tornado.escape.utf8(
371                    'HTTP/1.1 401 Authorization Required\r\n'
372                    'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
373                ))
374                self.stream.close()
375
376            creds = self.request.headers.get('Authorization')
377
378            if not creds or not creds.startswith('Basic '):
379                return auth_required()
380
381            try:
382                name, passwd = base64.b64decode(creds[6:]).split(':', 1)
383                passwd = base64.b64encode(hashlib.sha1(passwd).digest())
384
385                if name not in users or users[name] != passwd:
386                    return auth_required()
387
388                return o_exec(self, transforms, *args, **kwargs)
389
390            except:
391                return auth_required()
392
393        cls._execute = execute
394        return cls
395
396    def load_htpasswd(path):
397        users = {}
398        try:
399            with open(path) as fp:
400                for line in fp:
401                    line = line.strip()
402                    if 0 <= line.find(':'):
403                        name, passwd = line.split(':', 1)
404                        if passwd.startswith('{SHA}'):
405                            users[name] = passwd[5:]
406            if not users:
407                raise ValueError('no valid users found')
408        except TypeError:
409            pass
410        return users
411
412    realpath(args, 'keyfile', 'certfile', 'htpasswd')
413
414    if args.keyfile and args.certfile:
415        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
416    elif args.keyfile or args.certfile:
417        raise ValueError('both keyfile and certfile are required')
418    else:
419        ssl_options = None
420
421    if args.port is None:
422        if ssl_options:
423            args.port = 443
424        else:
425            args.port = 80
426    elif not (0 <= args.port <= 65535):
427        raise ValueError('invalid port: %s' % args.port)
428
429    if args.ageout <= 0:
430        raise ValueError('invalid ageout: %s' % args.ageout)
431
432    ioloop = tornado.ioloop.IOLoop.instance()
433    fdb = FDB(ageout=args.ageout, debug=args.debug)
434    switch = SwitchingHub(fdb, debug=args.debug)
435    taps = [TapHandler(switch, dev, debug=args.debug) for dev in args.device]
436
437    handler = wrap_basic_auth(EtherWebSocketHandler,
438                              load_htpasswd(args.htpasswd))
439    app = tornado.web.Application([
440        (args.path, handler, {'switch': switch, 'debug': args.debug}),
441    ])
442    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
443    server.listen(args.port, address=args.address)
444
445    for tap in taps:
446        tap.open()
447        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
448
449    if not args.foreground:
450        daemonize()
451
452    ioloop.start()
453
454
455def client_main(args):
456    realpath(args, 'cacerts')
457
458    if args.debug:
459        websocket.enableTrace(True)
460
461    if not args.insecure:
462        websocket._SSLSocketWrapper = \
463            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
464                                      ca_certs=args.cacerts)
465    else:
466        websocket._SSLSocketWrapper = \
467            lambda s: ssl.wrap_socket(s)
468
469    if args.user and args.passwd is None:
470        args.passwd = getpass.getpass()
471
472    if args.ageout <= 0:
473        raise ValueError('invalid ageout: %s' % args.ageout)
474
475    ioloop = tornado.ioloop.IOLoop.instance()
476    fdb = FDB(ageout=args.ageout, debug=args.debug)
477    switch = SwitchingHub(fdb, debug=args.debug)
478    taps = [TapHandler(switch, dev, debug=args.debug) for dev in args.device]
479
480    clients = [EtherWebSocketClient(switch, uri,
481                                    args.user, args.passwd, args.debug)
482               for uri in args.uri]
483
484    for client in clients:
485        client.open()
486        ioloop.add_handler(client.fileno(), client, ioloop.READ)
487
488    for tap in taps:
489        tap.open()
490        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
491
492    if not args.foreground:
493        daemonize()
494
495    ioloop.start()
496
497
498def main():
499    parser = argparse.ArgumentParser()
500    parser.add_argument('--device', action='append', default=[])
501    parser.add_argument('--ageout', action='store', type=int, default=300)
502    parser.add_argument('--foreground', action='store_true', default=False)
503    parser.add_argument('--debug', action='store_true', default=False)
504
505    subparsers = parser.add_subparsers(dest='subcommand')
506
507    parser_s = subparsers.add_parser('server')
508    parser_s.add_argument('--address', action='store', default='')
509    parser_s.add_argument('--port', action='store', type=int)
510    parser_s.add_argument('--path', action='store', default='/')
511    parser_s.add_argument('--htpasswd', action='store')
512    parser_s.add_argument('--keyfile', action='store')
513    parser_s.add_argument('--certfile', action='store')
514
515    parser_c = subparsers.add_parser('client')
516    parser_c.add_argument('--uri', action='append', default=[])
517    parser_c.add_argument('--insecure', action='store_true', default=False)
518    parser_c.add_argument('--cacerts', action='store')
519    parser_c.add_argument('--user', action='store')
520    parser_c.add_argument('--passwd', action='store')
521
522    args = parser.parse_args()
523
524    if args.subcommand == 'server':
525        server_main(args)
526    elif args.subcommand == 'client':
527        client_main(args)
528
529
530if __name__ == '__main__':
531    main()
Note: See TracBrowser for help on using the repository browser.