source: etherws/trunk/etherws.py @ 165

Revision 165, 13.8 KB checked in by atzm, 12 years ago (diff)
  • kill multi-thread
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import base64
48import hashlib
49import getpass
50import argparse
51import traceback
52
53import pytun
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60
61class DebugMixIn(object):
62    def dprintf(self, msg, func):
63        if self._debug:
64            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
65            sys.stderr.write(prefix + (msg % func()))
66
67
68class EthernetFrame(object):
69    def __init__(self, data):
70        self.data = data
71
72    @property
73    def multicast(self):
74        return ord(self.data[0]) & 1
75
76    @property
77    def dst_mac(self):
78        return self.data[:6]
79
80    @property
81    def src_mac(self):
82        return self.data[6:12]
83
84    @property
85    def tagged(self):
86        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
87
88    @property
89    def vid(self):
90        if self.tagged:
91            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
92        return -1
93
94
95class SwitchingTable(DebugMixIn):
96    def __init__(self, ageout=300, debug=False):
97        self._ageout = ageout
98        self._debug = debug
99        self._table = {}
100
101    def learn(self, frame, port):
102        mac = frame.src_mac
103        vid = frame.vid
104
105        if vid not in self._table:
106            self._table[vid] = {}
107
108        self._table[vid][mac] = {'time': time.time(), 'port': port}
109        self.dprintf('learned: [%d] %s\n',
110                     lambda: (vid, mac.encode('hex')))
111
112    def lookup(self, frame):
113        mac = frame.dst_mac
114        vid = frame.vid
115
116        group = self._table.get(vid, None)
117        if not group:
118            return None
119
120        entry = group.get(mac, None)
121        if not entry:
122            return None
123
124        if time.time() - entry['time'] > self._ageout:
125            del self._table[vid][mac]
126            if not self._table[vid]:
127                del self._table[vid]
128            self.dprintf('aged out: [%d] %s\n',
129                         lambda: (vid, mac.encode('hex')))
130            return None
131
132        return entry['port']
133
134    def delete(self, port):
135        for vid in self._table.keys():
136            for mac in self._table[vid].keys():
137                if self._table[vid][mac]['port'] is port:
138                    del self._table[vid][mac]
139                    self.dprintf('deleted: [%d] %s\n',
140                                 lambda: (vid, mac.encode('hex')))
141            if not self._table[vid]:
142                del self._table[vid]
143
144
145class TapHandler(DebugMixIn):
146    READ_SIZE = 65535
147
148    def __init__(self, dev, debug=False):
149        self._debug = debug
150        self._clients = []
151        self._table = SwitchingTable(debug=debug)
152        self._tap = pytun.TunTapDevice(dev, pytun.IFF_TAP | pytun.IFF_NO_PI)
153        self._tap.up()
154        self.register_client(self)
155
156    def fileno(self):
157        return self._tap.fileno()
158
159    def register_client(self, client):
160        self._clients.append(client)
161
162    def unregister_client(self, client):
163        self._table.delete(client)
164        self._clients.remove(client)
165
166    def write_message(self, message, binary=False):
167        self._tap.write(message)
168
169    def write(self, caller, message):
170        frame = EthernetFrame(message)
171
172        self._table.learn(frame, caller)
173
174        if not frame.multicast:
175            dst = self._table.lookup(frame)
176
177            if dst:
178                dst.write_message(frame.data, True)
179                self.dprintf('sent unicast: [%d] %s -> %s\n',
180                             lambda: (frame.vid,
181                                      frame.src_mac.encode('hex'),
182                                      frame.dst_mac.encode('hex')))
183                return
184
185        clients = self._clients[:]
186        clients.remove(caller)
187
188        for c in clients:
189            c.write_message(frame.data, True)
190
191        self.dprintf('sent broadcast: [%d] %s -> %s\n',
192                     lambda: (frame.vid,
193                              frame.src_mac.encode('hex'),
194                              frame.dst_mac.encode('hex')))
195
196    def __call__(self, fd, events):
197        buf = []
198
199        while True:
200            data = self._tap.read(self.READ_SIZE)
201
202            if data:
203                buf.append(data)
204
205            if len(data) < self.READ_SIZE:
206                break
207
208        self.write(self, ''.join(buf))
209
210
211class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
212    def __init__(self, app, req, tap, debug=False):
213        super(EtherWebSocketHandler, self).__init__(app, req)
214        self._tap = tap
215        self._debug = debug
216
217    def open(self):
218        self._tap.register_client(self)
219        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
220
221    def on_message(self, message):
222        self._tap.write(self, message)
223
224    def on_close(self):
225        self._tap.unregister_client(self)
226        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
227
228
229class EtherWebSocketClient(DebugMixIn):
230    def __init__(self, tap, url, user=None, passwd=None, debug=False):
231        self._sock = None
232        self._tap = tap
233        self._url = url
234        self._debug = debug
235        self._options = {}
236
237        if user and passwd:
238            token = base64.b64encode('%s:%s' % (user, passwd))
239            auth = ['Authorization: Basic %s' % token]
240            self._options['header'] = auth
241
242    @property
243    def closed(self):
244        return not self._sock
245
246    def open(self):
247        if not self.closed:
248            raise websocket.WebSocketException('already opened')
249        self._sock = websocket.WebSocket()
250        self._sock.connect(self._url, **self._options)
251        self.dprintf('connected: %s\n', lambda: self._url)
252
253    def close(self):
254        if self.closed:
255            raise websocket.WebSocketException('already closed')
256        self._sock.close()
257        self._sock = None
258        self.dprintf('disconnected: %s\n', lambda: self._url)
259
260    def fileno(self):
261        if self.closed:
262            raise websocket.WebSocketException('closed socket')
263        return self._sock.io_sock.fileno()
264
265    def write_message(self, message, binary=False):
266        if self.closed:
267            raise websocket.WebSocketException('closed socket')
268        if binary:
269            flag = websocket.ABNF.OPCODE_BINARY
270        else:
271            flag = websocket.ABNF.OPCODE_TEXT
272        self._sock.send(message, flag)
273
274    def __call__(self, fd, events):
275        try:
276            data = self._sock.recv()
277            if data is not None:
278                self._tap.write(self, data)
279                return
280        except:
281            traceback.print_exc()
282        tornado.ioloop.IOLoop.instance().stop()
283
284
285def daemonize(nochdir=False, noclose=False):
286    if os.fork() > 0:
287        sys.exit(0)
288
289    os.setsid()
290
291    if os.fork() > 0:
292        sys.exit(0)
293
294    if not nochdir:
295        os.chdir('/')
296
297    if not noclose:
298        os.umask(0)
299        sys.stdin.close()
300        sys.stdout.close()
301        sys.stderr.close()
302        os.close(0)
303        os.close(1)
304        os.close(2)
305        sys.stdin = open(os.devnull)
306        sys.stdout = open(os.devnull, 'a')
307        sys.stderr = open(os.devnull, 'a')
308
309
310def realpath(ns, *keys):
311    for k in keys:
312        v = getattr(ns, k, None)
313        if v is not None:
314            v = os.path.realpath(v)
315            setattr(ns, k, v)
316            open(v).close()  # check readable
317    return ns
318
319
320def server_main(args):
321    def wrap_basic_auth(cls, users):
322        o_exec = cls._execute
323
324        if not users:
325            return cls
326
327        def execute(self, transforms, *args, **kwargs):
328            def auth_required():
329                self.stream.write(tornado.escape.utf8(
330                    'HTTP/1.1 401 Authorization Required\r\n'
331                    'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
332                ))
333                self.stream.close()
334
335            creds = self.request.headers.get('Authorization')
336
337            if not creds or not creds.startswith('Basic '):
338                return auth_required()
339
340            try:
341                name, passwd = base64.b64decode(creds[6:]).split(':', 1)
342                passwd = base64.b64encode(hashlib.sha1(passwd).digest())
343
344                if name not in users or users[name] != passwd:
345                    return auth_required()
346
347                return o_exec(self, transforms, *args, **kwargs)
348
349            except:
350                return auth_required()
351
352        cls._execute = execute
353        return cls
354
355    def load_htpasswd(path):
356        users = {}
357        try:
358            with open(path) as fp:
359                for line in fp:
360                    line = line.strip()
361                    if 0 <= line.find(':'):
362                        name, passwd = line.split(':', 1)
363                        if passwd.startswith('{SHA}'):
364                            users[name] = passwd[5:]
365            if not users:
366                raise ValueError('no valid users found')
367        except TypeError:
368            pass
369        return users
370
371    realpath(args, 'keyfile', 'certfile', 'htpasswd')
372
373    if args.keyfile and args.certfile:
374        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
375    elif args.keyfile or args.certfile:
376        raise ValueError('both keyfile and certfile are required')
377    else:
378        ssl_options = None
379
380    if args.port is None:
381        if ssl_options:
382            args.port = 443
383        else:
384            args.port = 80
385    elif not (0 <= args.port <= 65535):
386        raise ValueError('invalid port: %s' % args.port)
387
388    handler = wrap_basic_auth(EtherWebSocketHandler,
389                              load_htpasswd(args.htpasswd))
390
391    tap = TapHandler(args.device, debug=args.debug)
392    app = tornado.web.Application([
393        (args.path, handler, {'tap': tap, 'debug': args.debug}),
394    ])
395    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
396    server.listen(args.port, address=args.address)
397
398    ioloop = tornado.ioloop.IOLoop.instance()
399    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
400
401    if not args.foreground:
402        daemonize()
403
404    ioloop.start()
405
406
407def client_main(args):
408    realpath(args, 'cacerts')
409
410    if args.debug:
411        websocket.enableTrace(True)
412
413    if not args.insecure:
414        websocket._SSLSocketWrapper = \
415            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
416                                      ca_certs=args.cacerts)
417
418    if args.user and args.passwd is None:
419        args.passwd = getpass.getpass()
420
421    tap = TapHandler(args.device, debug=args.debug)
422    client = EtherWebSocketClient(tap, args.uri,
423                                  args.user, args.passwd, args.debug)
424
425    tap.register_client(client)
426    client.open()
427
428    ioloop = tornado.ioloop.IOLoop.instance()
429    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
430    ioloop.add_handler(client.fileno(), client, ioloop.READ)
431
432    if not args.foreground:
433        daemonize()
434
435    ioloop.start()
436
437
438def main():
439    parser = argparse.ArgumentParser()
440    parser.add_argument('--device', action='store', default='ethws%d')
441    parser.add_argument('--foreground', action='store_true', default=False)
442    parser.add_argument('--debug', action='store_true', default=False)
443
444    subparsers = parser.add_subparsers(dest='subcommand')
445
446    parser_s = subparsers.add_parser('server')
447    parser_s.add_argument('--address', action='store', default='')
448    parser_s.add_argument('--port', action='store', type=int)
449    parser_s.add_argument('--path', action='store', default='/')
450    parser_s.add_argument('--htpasswd', action='store')
451    parser_s.add_argument('--keyfile', action='store')
452    parser_s.add_argument('--certfile', action='store')
453
454    parser_c = subparsers.add_parser('client')
455    parser_c.add_argument('--uri', action='store', required=True)
456    parser_c.add_argument('--insecure', action='store_true', default=False)
457    parser_c.add_argument('--cacerts', action='store')
458    parser_c.add_argument('--user', action='store')
459    parser_c.add_argument('--passwd', action='store')
460
461    args = parser.parse_args()
462
463    if args.subcommand == 'server':
464        server_main(args)
465    elif args.subcommand == 'client':
466        client_main(args)
467
468
469if __name__ == '__main__':
470    main()
Note: See TracBrowser for help on using the repository browser.