source: etherws/trunk/etherws.py @ 164

Revision 164, 13.9 KB checked in by atzm, 12 years ago (diff)
  • switching support
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[141]4#              Ethernet over WebSocket tunneling server/client
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
[133]11#
[140]12# todo:
[143]13#   - servant mode support (like typical p2p software)
[140]14#
[133]15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
[156]45import ssl
[160]46import time
[150]47import base64
48import hashlib
[151]49import getpass
[133]50import argparse
51import threading
52
53import pytun
54import websocket
[160]55import tornado.web
[133]56import tornado.ioloop
57import tornado.websocket
[160]58import tornado.httpserver
[133]59
60
[160]61class DebugMixIn(object):
[164]62    def dprintf(self, msg, func):
[160]63        if self._debug:
64            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]65            sys.stderr.write(prefix + (msg % func()))
[160]66
67
[164]68class EthernetFrame(object):
69    def __init__(self, data):
70        self.data = data
71
72    @property
73    def multicast(self):
74        return ord(self.data[0]) & 1
75
76    @property
77    def dst_mac(self):
78        return self.data[:6]
79
80    @property
81    def src_mac(self):
82        return self.data[6:12]
83
84    @property
85    def tagged(self):
86        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
87
88    @property
89    def vid(self):
90        if self.tagged:
91            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
92        return -1
93
94
95class SwitchingTable(DebugMixIn):
96    def __init__(self, ageout=300, debug=False):
97        self._ageout = ageout
98        self._debug = debug
99        self._table = {}
100
101    def learn(self, frame, port):
102        mac = frame.src_mac
103        vid = frame.vid
104
105        if vid not in self._table:
106            self._table[vid] = {}
107
108        self._table[vid][mac] = {'time': time.time(), 'port': port}
109        self.dprintf('learned: [%d] %s\n',
110                     lambda: (vid, mac.encode('hex')))
111
112    def lookup(self, frame):
113        mac = frame.dst_mac
114        vid = frame.vid
115
116        group = self._table.get(vid, None)
117        if not group:
118            return None
119
120        entry = group.get(mac, None)
121        if not entry:
122            return None
123
124        if time.time() - entry['time'] > self._ageout:
125            del self._table[vid][mac]
126            if not self._table[vid]:
127                del self._table[vid]
128            self.dprintf('aged out: [%d] %s\n',
129                         lambda: (vid, mac.encode('hex')))
130            return None
131
132        return entry['port']
133
134    def delete(self, port):
135        for vid in self._table.keys():
136            for mac in self._table[vid].keys():
137                if self._table[vid][mac]['port'] is port:
138                    del self._table[vid][mac]
139                    self.dprintf('deleted: [%d] %s\n',
140                                 lambda: (vid, mac.encode('hex')))
141            if not self._table[vid]:
142                del self._table[vid]
143
144
[160]145class TapHandler(DebugMixIn):
[162]146    READ_SIZE = 65535
147
[133]148    def __init__(self, dev, debug=False):
149        self._debug = debug
[138]150        self._clients = []
[164]151        self._table = SwitchingTable(debug=debug)
152        self._tablelock = threading.Lock()
[133]153        self._tap = pytun.TunTapDevice(dev, pytun.IFF_TAP | pytun.IFF_NO_PI)
[138]154        self._tap.up()
[162]155        self._taplock = threading.Lock()
[164]156        self.register_client(self)
[133]157
[138]158    def fileno(self):
[164]159        return self._tap.fileno()
[133]160
161    def register_client(self, client):
[162]162        self._clients.append(client)
[133]163
164    def unregister_client(self, client):
[164]165        self._table.delete(client)
[162]166        self._clients.remove(client)
[133]167
[164]168    # synchronized methods
169    def write_message(self, message, binary=False):
170        with self._taplock:
171            self._tap.write(message)
172
[133]173    def write(self, caller, message):
[164]174        frame = EthernetFrame(message)
[133]175
[164]176        with self._tablelock:
177            self._table.learn(frame, caller)
[162]178
[164]179        if not frame.multicast:
180            with self._tablelock:
181                dst = self._table.lookup(frame)
[133]182
[164]183            if dst:
184                dst.write_message(frame.data, True)
185                self.dprintf('sent unicast: [%d] %s -> %s\n',
186                             lambda: (frame.vid,
187                                      frame.src_mac.encode('hex'),
188                                      frame.dst_mac.encode('hex')))
189                return
190
191        clients = self._clients[:]
192        clients.remove(caller)
193
[162]194        for c in clients:
[164]195            c.write_message(frame.data, True)
[133]196
[164]197        self.dprintf('sent broadcast: [%d] %s -> %s\n',
198                     lambda: (frame.vid,
199                              frame.src_mac.encode('hex'),
200                              frame.dst_mac.encode('hex')))
201
[138]202    def __call__(self, fd, events):
[162]203        buf = []
[133]204
[162]205        while True:
206            with self._taplock:
207                data = self._tap.read(self.READ_SIZE)
[135]208
[162]209            if data:
210                buf.append(data)
211
212            if len(data) < self.READ_SIZE:
213                break
214
215        self.write(self, ''.join(buf))
216
217
[160]218class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
[133]219    def __init__(self, app, req, tap, debug=False):
[160]220        super(EtherWebSocketHandler, self).__init__(app, req)
[133]221        self._tap = tap
222        self._debug = debug
223
224    def open(self):
225        self._tap.register_client(self)
[164]226        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
[133]227
228    def on_message(self, message):
[139]229        self._tap.write(self, message)
[133]230
231    def on_close(self):
232        self._tap.unregister_client(self)
[164]233        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
[133]234
235
[160]236class EtherWebSocketClient(DebugMixIn):
237    def __init__(self, tap, url, user=None, passwd=None, debug=False):
[151]238        self._sock = None
239        self._tap = tap
240        self._url = url
[160]241        self._debug = debug
[151]242        self._options = {}
243
244        if user and passwd:
245            token = base64.b64encode('%s:%s' % (user, passwd))
246            auth = ['Authorization: Basic %s' % token]
247            self._options['header'] = auth
248
[160]249    @property
250    def closed(self):
251        return not self._sock
252
[151]253    def open(self):
[160]254        if not self.closed:
255            raise websocket.WebSocketException('already opened')
[151]256        self._sock = websocket.WebSocket()
257        self._sock.connect(self._url, **self._options)
[164]258        self.dprintf('connected: %s\n', lambda: self._url)
[151]259
260    def close(self):
[160]261        if self.closed:
262            raise websocket.WebSocketException('already closed')
[151]263        self._sock.close()
264        self._sock = None
[164]265        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]266
267    def write_message(self, message, binary=False):
[160]268        if self.closed:
269            raise websocket.WebSocketException('closed socket')
[151]270        if binary:
271            flag = websocket.ABNF.OPCODE_BINARY
[160]272        else:
273            flag = websocket.ABNF.OPCODE_TEXT
[151]274        self._sock.send(message, flag)
275
276    def run_forever(self):
277        try:
[160]278            if self.closed:
[151]279                self.open()
280            while True:
281                data = self._sock.recv()
282                if data is None:
283                    break
284                self._tap.write(self, data)
285        finally:
286            self.close()
287
288
[134]289def daemonize(nochdir=False, noclose=False):
290    if os.fork() > 0:
291        sys.exit(0)
292
293    os.setsid()
294
295    if os.fork() > 0:
296        sys.exit(0)
297
298    if not nochdir:
299        os.chdir('/')
300
301    if not noclose:
302        os.umask(0)
303        sys.stdin.close()
304        sys.stdout.close()
305        sys.stderr.close()
306        os.close(0)
307        os.close(1)
308        os.close(2)
309        sys.stdin = open(os.devnull)
310        sys.stdout = open(os.devnull, 'a')
311        sys.stderr = open(os.devnull, 'a')
312
313
[160]314def realpath(ns, *keys):
315    for k in keys:
316        v = getattr(ns, k, None)
317        if v is not None:
318            v = os.path.realpath(v)
319            setattr(ns, k, v)
320            open(v).close()  # check readable
321    return ns
322
323
[133]324def server_main(args):
[160]325    def wrap_basic_auth(cls, users):
326        o_exec = cls._execute
327
[150]328        if not users:
329            return cls
330
[160]331        def execute(self, transforms, *args, **kwargs):
[150]332            def auth_required():
333                self.stream.write(tornado.escape.utf8(
334                    'HTTP/1.1 401 Authorization Required\r\n'
335                    'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
336                ))
337                self.stream.close()
338
[160]339            creds = self.request.headers.get('Authorization')
[150]340
[160]341            if not creds or not creds.startswith('Basic '):
342                return auth_required()
[150]343
[160]344            try:
345                name, passwd = base64.b64decode(creds[6:]).split(':', 1)
[150]346                passwd = base64.b64encode(hashlib.sha1(passwd).digest())
347
348                if name not in users or users[name] != passwd:
349                    return auth_required()
350
[160]351                return o_exec(self, transforms, *args, **kwargs)
[150]352
353            except:
354                return auth_required()
355
[160]356        cls._execute = execute
[150]357        return cls
358
359    def load_htpasswd(path):
360        users = {}
361        try:
362            with open(path) as fp:
363                for line in fp:
364                    line = line.strip()
365                    if 0 <= line.find(':'):
[160]366                        name, passwd = line.split(':', 1)
[150]367                        if passwd.startswith('{SHA}'):
368                            users[name] = passwd[5:]
[156]369            if not users:
[160]370                raise ValueError('no valid users found')
[150]371        except TypeError:
372            pass
373        return users
374
[160]375    realpath(args, 'keyfile', 'certfile', 'htpasswd')
[143]376
[160]377    if args.keyfile and args.certfile:
378        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
379    elif args.keyfile or args.certfile:
[143]380        raise ValueError('both keyfile and certfile are required')
[160]381    else:
[143]382        ssl_options = None
383
[160]384    if args.port is None:
[143]385        if ssl_options:
386            args.port = 443
387        else:
388            args.port = 80
[160]389    elif not (0 <= args.port <= 65535):
390        raise ValueError('invalid port: %s' % args.port)
[143]391
[160]392    handler = wrap_basic_auth(EtherWebSocketHandler,
393                              load_htpasswd(args.htpasswd))
394
[138]395    tap = TapHandler(args.device, debug=args.debug)
[133]396    app = tornado.web.Application([
[150]397        (args.path, handler, {'tap': tap, 'debug': args.debug}),
[133]398    ])
[143]399    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
[133]400    server.listen(args.port, address=args.address)
401
[138]402    ioloop = tornado.ioloop.IOLoop.instance()
403    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
[151]404
405    if not args.foreground:
406        daemonize()
407
[138]408    ioloop.start()
[133]409
410
411def client_main(args):
[160]412    realpath(args, 'cacerts')
413
[133]414    if args.debug:
415        websocket.enableTrace(True)
416
[156]417    if not args.insecure:
418        websocket._SSLSocketWrapper = \
419            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
420                                      ca_certs=args.cacerts)
421
[160]422    if args.user and args.passwd is None:
423        args.passwd = getpass.getpass()
[143]424
[138]425    tap = TapHandler(args.device, debug=args.debug)
[160]426    client = EtherWebSocketClient(tap, args.uri,
427                                  args.user, args.passwd, args.debug)
[151]428
[133]429    tap.register_client(client)
[151]430    client.open()
[133]431
[138]432    ioloop = tornado.ioloop.IOLoop.instance()
433    ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
[151]434
[160]435    t = threading.Thread(target=ioloop.start)
436    t.setDaemon(True)
437
[151]438    if not args.foreground:
439        daemonize()
440
441    t.start()
[160]442    client.run_forever()
[133]443
[138]444
[133]445def main():
446    parser = argparse.ArgumentParser()
447    parser.add_argument('--device', action='store', default='ethws%d')
448    parser.add_argument('--foreground', action='store_true', default=False)
449    parser.add_argument('--debug', action='store_true', default=False)
450
451    subparsers = parser.add_subparsers(dest='subcommand')
452
[158]453    parser_s = subparsers.add_parser('server')
454    parser_s.add_argument('--address', action='store', default='')
455    parser_s.add_argument('--port', action='store', type=int)
456    parser_s.add_argument('--path', action='store', default='/')
457    parser_s.add_argument('--htpasswd', action='store')
458    parser_s.add_argument('--keyfile', action='store')
459    parser_s.add_argument('--certfile', action='store')
[133]460
[158]461    parser_c = subparsers.add_parser('client')
462    parser_c.add_argument('--uri', action='store', required=True)
463    parser_c.add_argument('--insecure', action='store_true', default=False)
464    parser_c.add_argument('--cacerts', action='store')
465    parser_c.add_argument('--user', action='store')
[160]466    parser_c.add_argument('--passwd', action='store')
[133]467
468    args = parser.parse_args()
469
470    if args.subcommand == 'server':
471        server_main(args)
472    elif args.subcommand == 'client':
473        client_main(args)
474
475
476if __name__ == '__main__':
477    main()
Note: See TracBrowser for help on using the repository browser.