source: etherws/trunk/etherws.py @ 164

Revision 164, 13.9 KB checked in by atzm, 12 years ago (diff)
  • switching support
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import base64
48import hashlib
49import getpass
50import argparse
51import threading
52
53import pytun
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60
61class DebugMixIn(object):
62    def dprintf(self, msg, func):
63        if self._debug:
64            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
65            sys.stderr.write(prefix + (msg % func()))
66
67
68class EthernetFrame(object):
69    def __init__(self, data):
70        self.data = data
71
72    @property
73    def multicast(self):
74        return ord(self.data[0]) & 1
75
76    @property
77    def dst_mac(self):
78        return self.data[:6]
79
80    @property
81    def src_mac(self):
82        return self.data[6:12]
83
84    @property
85    def tagged(self):
86        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
87
88    @property
89    def vid(self):
90        if self.tagged:
91            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
92        return -1
93
94
95class SwitchingTable(DebugMixIn):
96    def __init__(self, ageout=300, debug=False):
97        self._ageout = ageout
98        self._debug = debug
99        self._table = {}
100
101    def learn(self, frame, port):
102        mac = frame.src_mac
103        vid = frame.vid
104
105        if vid not in self._table:
106            self._table[vid] = {}
107
108        self._table[vid][mac] = {'time': time.time(), 'port': port}
109        self.dprintf('learned: [%d] %s\n',
110                     lambda: (vid, mac.encode('hex')))
111
112    def lookup(self, frame):
113        mac = frame.dst_mac
114        vid = frame.vid
115
116        group = self._table.get(vid, None)
117        if not group:
118            return None
119
120        entry = group.get(mac, None)
121        if not entry:
122            return None
123
124        if time.time() - entry['time'] > self._ageout:
125            del self._table[vid][mac]
126            if not self._table[vid]:
127                del self._table[vid]
128            self.dprintf('aged out: [%d] %s\n',
129                         lambda: (vid, mac.encode('hex')))
130            return None
131
132        return entry['port']
133
134    def delete(self, port):
135        for vid in self._table.keys():
136            for mac in self._table[vid].keys():
137                if self._table[vid][mac]['port'] is port:
138                    del self._table[vid][mac]
139                    self.dprintf('deleted: [%d] %s\n',
140                                 lambda: (vid, mac.encode('hex')))
141            if not self._table[vid]:
142                del self._table[vid]
143
144
145class TapHandler(DebugMixIn):
146    READ_SIZE = 65535
147
148    def __init__(self, dev, debug=False):
149        self._debug = debug
150        self._clients = []
151        self._table = SwitchingTable(debug=debug)
152        self._tablelock = threading.Lock()
153        self._tap = pytun.TunTapDevice(dev, pytun.IFF_TAP | pytun.IFF_NO_PI)
154        self._tap.up()
155        self._taplock = threading.Lock()
156        self.register_client(self)
157
158    def fileno(self):
159        return self._tap.fileno()
160
161    def register_client(self, client):
162        self._clients.append(client)
163
164    def unregister_client(self, client):
165        self._table.delete(client)
166        self._clients.remove(client)
167
168    # synchronized methods
169    def write_message(self, message, binary=False):
170        with self._taplock:
171            self._tap.write(message)
172
173    def write(self, caller, message):
174        frame = EthernetFrame(message)
175
176        with self._tablelock:
177            self._table.learn(frame, caller)
178
179        if not frame.multicast:
180            with self._tablelock:
181                dst = self._table.lookup(frame)
182
183            if dst:
184                dst.write_message(frame.data, True)
185                self.dprintf('sent unicast: [%d] %s -> %s\n',
186                             lambda: (frame.vid,
187                                      frame.src_mac.encode('hex'),
188                                      frame.dst_mac.encode('hex')))
189                return
190
191        clients = self._clients[:]
192        clients.remove(caller)
193
194        for c in clients:
195            c.write_message(frame.data, True)
196
197        self.dprintf('sent broadcast: [%d] %s -> %s\n',
198                     lambda: (frame.vid,
199                              frame.src_mac.encode('hex'),
200                              frame.dst_mac.encode('hex')))
201
202    def __call__(self, fd, events):
203        buf = []
204
205        while True:
206            with self._taplock:
207                data = self._tap.read(self.READ_SIZE)
208
209            if data:
210                buf.append(data)
211
212            if len(data) < self.READ_SIZE:
213                break
214
215        self.write(self, ''.join(buf))
216
217
218class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
219    def __init__(self, app, req, tap, debug=False):
220        super(EtherWebSocketHandler, self).__init__(app, req)
221        self._tap = tap
222        self._debug = debug
223
224    def open(self):
225        self._tap.register_client(self)
226        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
227
228    def on_message(self, message):
229        self._tap.write(self, message)
230
231    def on_close(self):
232        self._tap.unregister_client(self)
233        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
234
235
236class EtherWebSocketClient(DebugMixIn):
237    def __init__(self, tap, url, user=None, passwd=None, debug=False):
238        self._sock = None
239        self._tap = tap
240        self._url = url
241        self._debug = debug
242        self._options = {}
243
244        if user and passwd:
245            token = base64.b64encode('%s:%s' % (user, passwd))
246            auth = ['Authorization: Basic %s' % token]
247            self._options['header'] = auth
248
249    @property
250    def closed(self):
251        return not self._sock
252
253    def open(self):
254        if not self.closed:
255            raise websocket.WebSocketException('already opened')
256        self._sock = websocket.WebSocket()
257        self._sock.connect(self._url, **self._options)
258        self.dprintf('connected: %s\n', lambda: self._url)
259
260    def close(self):
261        if self.closed:
262            raise websocket.WebSocketException('already closed')
263        self._sock.close()
264        self._sock = None
265        self.dprintf('disconnected: %s\n', lambda: self._url)
266
267    def write_message(self, message, binary=False):
268        if self.closed:
269            raise websocket.WebSocketException('closed socket')
270        if binary:
271            flag = websocket.ABNF.OPCODE_BINARY
272        else:
273            flag = websocket.ABNF.OPCODE_TEXT
274        self._sock.send(message, flag)
275
276    def run_forever(self):
277        try:
278            if self.closed:
279                self.open()
280            while True:
281                data = self._sock.recv()
282                if data is None:
283                    break
284                self._tap.write(self, data)
285        finally:
286            self.close()
287
288
289def daemonize(nochdir=False, noclose=False):
290    if os.fork() > 0:
291        sys.exit(0)
292
293    os.setsid()
294
295    if os.fork() > 0:
296        sys.exit(0)
297
298    if not nochdir:
299        os.chdir('/')
300
301    if not noclose:
302        os.umask(0)
303        sys.stdin.close()
304        sys.stdout.close()
305        sys.stderr.close()
306        os.close(0)
307        os.close(1)
308        os.close(2)
309        sys.stdin = open(os.devnull)
310        sys.stdout = open(os.devnull, 'a')
311        sys.stderr = open(os.devnull, 'a')
312
313
314def realpath(ns, *keys):
315    for k in keys:
316        v = getattr(ns, k, None)
317        if v is not None:
318            v = os.path.realpath(v)
319            setattr(ns, k, v)
320            open(v).close()  # check readable
321    return ns
322
323
324def server_main(args):
325    def wrap_basic_auth(cls, users):
326        o_exec = cls._execute
327
328        if not users:
329            return cls
330
331        def execute(self, transforms, *args, **kwargs):
332            def auth_required():
333                self.stream.write(tornado.escape.utf8(
334                    'HTTP/1.1 401 Authorization Required\r\n'
335                    'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
336                ))
337                self.stream.close()
338
339            creds = self.request.headers.get('Authorization')
340
341            if not creds or not creds.startswith('Basic '):
342                return auth_required()
343
344            try:
345                name, passwd = base64.b64decode(creds[6:]).split(':', 1)
346                passwd = base64.b64encode(hashlib.sha1(passwd).digest())
347
348                if name not in users or users[name] != passwd:
349                    return auth_required()
350
351                return o_exec(self, transforms, *args, **kwargs)
352
353            except:
354                return auth_required()
355
356        cls._execute = execute
357        return cls
358
359    def load_htpasswd(path):
360        users = {}
361        try:
362            with open(path) as fp:
363                for line in fp:
364                    line = line.strip()
365                    if 0 <= line.find(':'):
366                        name, passwd = line.split(':', 1)
367                        if passwd.startswith('{SHA}'):
368                            users[name] = passwd[5:]
369            if not users:
370                raise ValueError('no valid users found')
371        except TypeError:
372            pass
373        return users
374
375    realpath(args, 'keyfile', 'certfile', 'htpasswd')
376
377    if args.keyfile and args.certfile:
378        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
379    elif args.keyfile or args.certfile:
380        raise ValueError('both keyfile and certfile are required')
381    else:
382        ssl_options = None
383
384    if args.port is None:
385        if ssl_options:
386            args.port = 443
387        else:
388            args.port = 80
389    elif not (0 <= args.port <= 65535):
390        raise ValueError('invalid port: %s' % args.port)
391
392    handler = wrap_basic_auth(EtherWebSocketHandler,
393                              load_htpasswd(args.htpasswd))
394
395    tap = TapHandler(args.device, debug=args.debug)
396    app = tornado.web.Application([
397        (args.path, handler, {'tap': tap, 'debug': args.debug}),
398    ])
399    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
400    server.listen(args.port, address=args.address)
401
402    ioloop = tornado.ioloop.IOLoop.instance()
403    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
404
405    if not args.foreground:
406        daemonize()
407
408    ioloop.start()
409
410
411def client_main(args):
412    realpath(args, 'cacerts')
413
414    if args.debug:
415        websocket.enableTrace(True)
416
417    if not args.insecure:
418        websocket._SSLSocketWrapper = \
419            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
420                                      ca_certs=args.cacerts)
421
422    if args.user and args.passwd is None:
423        args.passwd = getpass.getpass()
424
425    tap = TapHandler(args.device, debug=args.debug)
426    client = EtherWebSocketClient(tap, args.uri,
427                                  args.user, args.passwd, args.debug)
428
429    tap.register_client(client)
430    client.open()
431
432    ioloop = tornado.ioloop.IOLoop.instance()
433    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
434
435    t = threading.Thread(target=ioloop.start)
436    t.setDaemon(True)
437
438    if not args.foreground:
439        daemonize()
440
441    t.start()
442    client.run_forever()
443
444
445def main():
446    parser = argparse.ArgumentParser()
447    parser.add_argument('--device', action='store', default='ethws%d')
448    parser.add_argument('--foreground', action='store_true', default=False)
449    parser.add_argument('--debug', action='store_true', default=False)
450
451    subparsers = parser.add_subparsers(dest='subcommand')
452
453    parser_s = subparsers.add_parser('server')
454    parser_s.add_argument('--address', action='store', default='')
455    parser_s.add_argument('--port', action='store', type=int)
456    parser_s.add_argument('--path', action='store', default='/')
457    parser_s.add_argument('--htpasswd', action='store')
458    parser_s.add_argument('--keyfile', action='store')
459    parser_s.add_argument('--certfile', action='store')
460
461    parser_c = subparsers.add_parser('client')
462    parser_c.add_argument('--uri', action='store', required=True)
463    parser_c.add_argument('--insecure', action='store_true', default=False)
464    parser_c.add_argument('--cacerts', action='store')
465    parser_c.add_argument('--user', action='store')
466    parser_c.add_argument('--passwd', action='store')
467
468    args = parser.parse_args()
469
470    if args.subcommand == 'server':
471        server_main(args)
472    elif args.subcommand == 'client':
473        client_main(args)
474
475
476if __name__ == '__main__':
477    main()
Note: See TracBrowser for help on using the repository browser.