source: etherws/trunk/etherws.py @ 165

Revision 165, 13.8 KB checked in by atzm, 12 years ago (diff)
  • kill multi-thread
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[150]47import base64
48import hashlib
[151]49import getpass
[133]50import argparse
[165]51import traceback
[133]52
53import pytun
54import websocket
[160]55import tornado.web
[133]56import tornado.ioloop
57import tornado.websocket
[160]58import tornado.httpserver
[133]59
60
[160]61class DebugMixIn(object):
[164]62    def dprintf(self, msg, func):
[160]63        if self._debug:
64            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]65            sys.stderr.write(prefix + (msg % func()))
[160]66
67
[164]68class EthernetFrame(object):
69    def __init__(self, data):
70        self.data = data
71
72    @property
73    def multicast(self):
74        return ord(self.data[0]) & 1
75
76    @property
77    def dst_mac(self):
78        return self.data[:6]
79
80    @property
81    def src_mac(self):
82        return self.data[6:12]
83
84    @property
85    def tagged(self):
86        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
87
88    @property
89    def vid(self):
90        if self.tagged:
91            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
92        return -1
93
94
95class SwitchingTable(DebugMixIn):
96    def __init__(self, ageout=300, debug=False):
97        self._ageout = ageout
98        self._debug = debug
99        self._table = {}
100
101    def learn(self, frame, port):
102        mac = frame.src_mac
103        vid = frame.vid
104
105        if vid not in self._table:
106            self._table[vid] = {}
107
108        self._table[vid][mac] = {'time': time.time(), 'port': port}
109        self.dprintf('learned: [%d] %s\n',
110                     lambda: (vid, mac.encode('hex')))
111
112    def lookup(self, frame):
113        mac = frame.dst_mac
114        vid = frame.vid
115
116        group = self._table.get(vid, None)
117        if not group:
118            return None
119
120        entry = group.get(mac, None)
121        if not entry:
122            return None
123
124        if time.time() - entry['time'] > self._ageout:
125            del self._table[vid][mac]
126            if not self._table[vid]:
127                del self._table[vid]
128            self.dprintf('aged out: [%d] %s\n',
129                         lambda: (vid, mac.encode('hex')))
130            return None
131
132        return entry['port']
133
134    def delete(self, port):
135        for vid in self._table.keys():
136            for mac in self._table[vid].keys():
137                if self._table[vid][mac]['port'] is port:
138                    del self._table[vid][mac]
139                    self.dprintf('deleted: [%d] %s\n',
140                                 lambda: (vid, mac.encode('hex')))
141            if not self._table[vid]:
142                del self._table[vid]
143
144
[160]145class TapHandler(DebugMixIn):
[162]146    READ_SIZE = 65535
147
[133]148    def __init__(self, dev, debug=False):
149        self._debug = debug
[138]150        self._clients = []
[164]151        self._table = SwitchingTable(debug=debug)
[133]152        self._tap = pytun.TunTapDevice(dev, pytun.IFF_TAP | pytun.IFF_NO_PI)
[138]153        self._tap.up()
[164]154        self.register_client(self)
[133]155
[138]156    def fileno(self):
[164]157        return self._tap.fileno()
[133]158
159    def register_client(self, client):
[162]160        self._clients.append(client)
[133]161
162    def unregister_client(self, client):
[164]163        self._table.delete(client)
[162]164        self._clients.remove(client)
[133]165
[164]166    def write_message(self, message, binary=False):
[165]167        self._tap.write(message)
[164]168
[133]169    def write(self, caller, message):
[164]170        frame = EthernetFrame(message)
[133]171
[165]172        self._table.learn(frame, caller)
[162]173
[164]174        if not frame.multicast:
[165]175            dst = self._table.lookup(frame)
[133]176
[164]177            if dst:
178                dst.write_message(frame.data, True)
179                self.dprintf('sent unicast: [%d] %s -> %s\n',
180                             lambda: (frame.vid,
181                                      frame.src_mac.encode('hex'),
182                                      frame.dst_mac.encode('hex')))
183                return
184
185        clients = self._clients[:]
186        clients.remove(caller)
187
[162]188        for c in clients:
[164]189            c.write_message(frame.data, True)
[133]190
[164]191        self.dprintf('sent broadcast: [%d] %s -> %s\n',
192                     lambda: (frame.vid,
193                              frame.src_mac.encode('hex'),
194                              frame.dst_mac.encode('hex')))
195
[138]196    def __call__(self, fd, events):
[162]197        buf = []
[133]198
[162]199        while True:
[165]200            data = self._tap.read(self.READ_SIZE)
[135]201
[162]202            if data:
203                buf.append(data)
204
205            if len(data) < self.READ_SIZE:
206                break
207
208        self.write(self, ''.join(buf))
209
210
[160]211class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
[133]212    def __init__(self, app, req, tap, debug=False):
[160]213        super(EtherWebSocketHandler, self).__init__(app, req)
[133]214        self._tap = tap
215        self._debug = debug
216
217    def open(self):
218        self._tap.register_client(self)
[164]219        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]220
221    def on_message(self, message):
[139]222        self._tap.write(self, message)
[133]223
224    def on_close(self):
225        self._tap.unregister_client(self)
[164]226        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]227
228
[160]229class EtherWebSocketClient(DebugMixIn):
230    def __init__(self, tap, url, user=None, passwd=None, debug=False):
[151]231        self._sock = None
232        self._tap = tap
233        self._url = url
[160]234        self._debug = debug
[151]235        self._options = {}
236
237        if user and passwd:
238            token = base64.b64encode('%s:%s' % (user, passwd))
239            auth = ['Authorization: Basic %s' % token]
240            self._options['header'] = auth
241
[160]242    @property
243    def closed(self):
244        return not self._sock
245
[151]246    def open(self):
[160]247        if not self.closed:
248            raise websocket.WebSocketException('already opened')
[151]249        self._sock = websocket.WebSocket()
250        self._sock.connect(self._url, **self._options)
[164]251        self.dprintf('connected: %s\n', lambda: self._url)
[151]252
253    def close(self):
[160]254        if self.closed:
255            raise websocket.WebSocketException('already closed')
[151]256        self._sock.close()
257        self._sock = None
[164]258        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]259
[165]260    def fileno(self):
261        if self.closed:
262            raise websocket.WebSocketException('closed socket')
263        return self._sock.io_sock.fileno()
264
[151]265    def write_message(self, message, binary=False):
[160]266        if self.closed:
267            raise websocket.WebSocketException('closed socket')
[151]268        if binary:
269            flag = websocket.ABNF.OPCODE_BINARY
[160]270        else:
271            flag = websocket.ABNF.OPCODE_TEXT
[151]272        self._sock.send(message, flag)
273
[165]274    def __call__(self, fd, events):
[151]275        try:
[165]276            data = self._sock.recv()
277            if data is not None:
[151]278                self._tap.write(self, data)
[165]279                return
280        except:
281            traceback.print_exc()
282        tornado.ioloop.IOLoop.instance().stop()
[151]283
284
[134]285def daemonize(nochdir=False, noclose=False):
286    if os.fork() > 0:
287        sys.exit(0)
288
289    os.setsid()
290
291    if os.fork() > 0:
292        sys.exit(0)
293
294    if not nochdir:
295        os.chdir('/')
296
297    if not noclose:
298        os.umask(0)
299        sys.stdin.close()
300        sys.stdout.close()
301        sys.stderr.close()
302        os.close(0)
303        os.close(1)
304        os.close(2)
305        sys.stdin = open(os.devnull)
306        sys.stdout = open(os.devnull, 'a')
307        sys.stderr = open(os.devnull, 'a')
308
309
[160]310def realpath(ns, *keys):
311    for k in keys:
312        v = getattr(ns, k, None)
313        if v is not None:
314            v = os.path.realpath(v)
315            setattr(ns, k, v)
316            open(v).close()  # check readable
317    return ns
318
319
[133]320def server_main(args):
[160]321    def wrap_basic_auth(cls, users):
322        o_exec = cls._execute
323
[150]324        if not users:
325            return cls
326
[160]327        def execute(self, transforms, *args, **kwargs):
[150]328            def auth_required():
329                self.stream.write(tornado.escape.utf8(
330                    'HTTP/1.1 401 Authorization Required\r\n'
331                    'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
332                ))
333                self.stream.close()
334
[160]335            creds = self.request.headers.get('Authorization')
[150]336
[160]337            if not creds or not creds.startswith('Basic '):
338                return auth_required()
[150]339
[160]340            try:
341                name, passwd = base64.b64decode(creds[6:]).split(':', 1)
[150]342                passwd = base64.b64encode(hashlib.sha1(passwd).digest())
343
344                if name not in users or users[name] != passwd:
345                    return auth_required()
346
[160]347                return o_exec(self, transforms, *args, **kwargs)
[150]348
349            except:
350                return auth_required()
351
[160]352        cls._execute = execute
[150]353        return cls
354
355    def load_htpasswd(path):
356        users = {}
357        try:
358            with open(path) as fp:
359                for line in fp:
360                    line = line.strip()
361                    if 0 <= line.find(':'):
[160]362                        name, passwd = line.split(':', 1)
[150]363                        if passwd.startswith('{SHA}'):
364                            users[name] = passwd[5:]
[156]365            if not users:
[160]366                raise ValueError('no valid users found')
[150]367        except TypeError:
368            pass
369        return users
370
[160]371    realpath(args, 'keyfile', 'certfile', 'htpasswd')
[143]372
[160]373    if args.keyfile and args.certfile:
374        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
375    elif args.keyfile or args.certfile:
[143]376        raise ValueError('both keyfile and certfile are required')
[160]377    else:
[143]378        ssl_options = None
379
[160]380    if args.port is None:
[143]381        if ssl_options:
382            args.port = 443
383        else:
384            args.port = 80
[160]385    elif not (0 <= args.port <= 65535):
386        raise ValueError('invalid port: %s' % args.port)
[143]387
[160]388    handler = wrap_basic_auth(EtherWebSocketHandler,
389                              load_htpasswd(args.htpasswd))
390
[138]391    tap = TapHandler(args.device, debug=args.debug)
[133]392    app = tornado.web.Application([
[150]393        (args.path, handler, {'tap': tap, 'debug': args.debug}),
[133]394    ])
[143]395    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
[133]396    server.listen(args.port, address=args.address)
397
[138]398    ioloop = tornado.ioloop.IOLoop.instance()
399    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
[151]400
401    if not args.foreground:
402        daemonize()
403
[138]404    ioloop.start()
[133]405
406
407def client_main(args):
[160]408    realpath(args, 'cacerts')
409
[133]410    if args.debug:
411        websocket.enableTrace(True)
412
[156]413    if not args.insecure:
414        websocket._SSLSocketWrapper = \
415            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
416                                      ca_certs=args.cacerts)
417
[160]418    if args.user and args.passwd is None:
419        args.passwd = getpass.getpass()
[143]420
[138]421    tap = TapHandler(args.device, debug=args.debug)
[160]422    client = EtherWebSocketClient(tap, args.uri,
423                                  args.user, args.passwd, args.debug)
[151]424
[133]425    tap.register_client(client)
[151]426    client.open()
[133]427
[138]428    ioloop = tornado.ioloop.IOLoop.instance()
429    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
[165]430    ioloop.add_handler(client.fileno(), client, ioloop.READ)
[151]431
432    if not args.foreground:
433        daemonize()
434
[165]435    ioloop.start()
[133]436
[138]437
[133]438def main():
439    parser = argparse.ArgumentParser()
440    parser.add_argument('--device', action='store', default='ethws%d')
441    parser.add_argument('--foreground', action='store_true', default=False)
442    parser.add_argument('--debug', action='store_true', default=False)
443
444    subparsers = parser.add_subparsers(dest='subcommand')
445
[158]446    parser_s = subparsers.add_parser('server')
447    parser_s.add_argument('--address', action='store', default='')
448    parser_s.add_argument('--port', action='store', type=int)
449    parser_s.add_argument('--path', action='store', default='/')
450    parser_s.add_argument('--htpasswd', action='store')
451    parser_s.add_argument('--keyfile', action='store')
452    parser_s.add_argument('--certfile', action='store')
[133]453
[158]454    parser_c = subparsers.add_parser('client')
455    parser_c.add_argument('--uri', action='store', required=True)
456    parser_c.add_argument('--insecure', action='store_true', default=False)
457    parser_c.add_argument('--cacerts', action='store')
458    parser_c.add_argument('--user', action='store')
[160]459    parser_c.add_argument('--passwd', action='store')
[133]460
461    args = parser.parse_args()
462
463    if args.subcommand == 'server':
464        server_main(args)
465    elif args.subcommand == 'client':
466        client_main(args)
467
468
469if __name__ == '__main__':
470    main()
Note: See TracBrowser for help on using the repository browser.