source: etherws/trunk/etherws.py @ 175

Revision 175, 16.4 KB checked in by atzm, 12 years ago (diff)
  • acquire shared lock while loading htpasswd
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.2.1
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import fcntl
48import base64
49import hashlib
50import getpass
51import argparse
52import traceback
53
54import websocket
55import tornado.web
56import tornado.ioloop
57import tornado.websocket
58import tornado.httpserver
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @staticmethod
75    def multicast(mac):
76        return ord(mac[0]) & 1
77
78    @property
79    def dst_mac(self):
80        return self.data[:6]
81
82    @property
83    def src_mac(self):
84        return self.data[6:12]
85
86    @property
87    def tagged(self):
88        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
89
90    @property
91    def vid(self):
92        if self.tagged:
93            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
94        return -1
95
96
97class FDB(DebugMixIn):
98    def __init__(self, ageout, debug=False):
99        self._ageout = ageout
100        self._debug = debug
101        self._dict = {}
102
103    def lookup(self, frame):
104        mac = frame.dst_mac
105        vid = frame.vid
106
107        group = self._dict.get(vid, None)
108        if not group:
109            return None
110
111        entry = group.get(mac, None)
112        if not entry:
113            return None
114
115        if time.time() - entry['time'] > self._ageout:
116            del self._dict[vid][mac]
117            if not self._dict[vid]:
118                del self._dict[vid]
119            self.dprintf('aged out: [%d] %s\n',
120                         lambda: (vid, mac.encode('hex')))
121            return None
122
123        return entry['port']
124
125    def learn(self, port, frame):
126        mac = frame.src_mac
127        vid = frame.vid
128
129        if vid not in self._dict:
130            self._dict[vid] = {}
131
132        self._dict[vid][mac] = {'time': time.time(), 'port': port}
133        self.dprintf('learned: [%d] %s\n',
134                     lambda: (vid, mac.encode('hex')))
135
136    def delete(self, port):
137        for vid in self._dict.keys():
138            for mac in self._dict[vid].keys():
139                if self._dict[vid][mac]['port'] is port:
140                    del self._dict[vid][mac]
141                    self.dprintf('deleted: [%d] %s\n',
142                                 lambda: (vid, mac.encode('hex')))
143            if not self._dict[vid]:
144                del self._dict[vid]
145
146
147class SwitchingHub(DebugMixIn):
148    def __init__(self, fdb, debug=False):
149        self._fdb = fdb
150        self._debug = debug
151        self._ports = []
152
153    def register_port(self, port):
154        self._ports.append(port)
155
156    def unregister_port(self, port):
157        self._fdb.delete(port)
158        self._ports.remove(port)
159
160    def forward(self, src_port, frame):
161        try:
162            if not frame.multicast(frame.src_mac):
163                self._fdb.learn(src_port, frame)
164
165            if not frame.multicast(frame.dst_mac):
166                dst_port = self._fdb.lookup(frame)
167
168                if dst_port:
169                    self._unicast(frame, dst_port)
170                    return
171
172            self._broadcast(frame, src_port)
173
174        except:  # ex. received invalid frame
175            traceback.print_exc()
176
177    def _unicast(self, frame, port):
178        port.write_message(frame.data, True)
179        self.dprintf('sent unicast: [%d] %s -> %s\n',
180                     lambda: (frame.vid,
181                              frame.src_mac.encode('hex'),
182                              frame.dst_mac.encode('hex')))
183
184    def _broadcast(self, frame, *except_ports):
185        ports = self._ports[:]
186        for port in except_ports:
187            ports.remove(port)
188        for port in ports:
189            port.write_message(frame.data, True)
190        self.dprintf('sent broadcast: [%d] %s -> %s\n',
191                     lambda: (frame.vid,
192                              frame.src_mac.encode('hex'),
193                              frame.dst_mac.encode('hex')))
194
195
196class TapHandler(DebugMixIn):
197    READ_SIZE = 65535
198
199    def __init__(self, switch, dev, debug=False):
200        self._switch = switch
201        self._dev = dev
202        self._debug = debug
203        self._tap = None
204
205    @property
206    def closed(self):
207        return not self._tap
208
209    def open(self):
210        if not self.closed:
211            raise ValueError('already opened')
212        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
213        self._tap.up()
214        self._switch.register_port(self)
215
216    def close(self):
217        if self.closed:
218            raise ValueError('I/O operation on closed tap')
219        self._switch.unregister_port(self)
220        self._tap.close()
221        self._tap = None
222
223    def fileno(self):
224        if self.closed:
225            raise ValueError('I/O operation on closed tap')
226        return self._tap.fileno()
227
228    def write_message(self, message, binary=False):
229        if self.closed:
230            raise ValueError('I/O operation on closed tap')
231        self._tap.write(message)
232
233    def __call__(self, fd, events):
234        try:
235            self._switch.forward(self, EthernetFrame(self._read()))
236            return
237        except:
238            traceback.print_exc()
239        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
240
241    def _read(self):
242        if self.closed:
243            raise ValueError('I/O operation on closed tap')
244        buf = []
245        while True:
246            buf.append(self._tap.read(self.READ_SIZE))
247            if len(buf[-1]) < self.READ_SIZE:
248                break
249        return ''.join(buf)
250
251
252class EtherWebSocketHandler(tornado.websocket.WebSocketHandler, DebugMixIn):
253    def __init__(self, app, req, switch, debug=False):
254        super(EtherWebSocketHandler, self).__init__(app, req)
255        self._switch = switch
256        self._debug = debug
257
258    def open(self):
259        self._switch.register_port(self)
260        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
261
262    def on_message(self, message):
263        self._switch.forward(self, EthernetFrame(message))
264
265    def on_close(self):
266        self._switch.unregister_port(self)
267        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
268
269
270class EtherWebSocketClient(DebugMixIn):
271    def __init__(self, switch, url, cred=None, debug=False):
272        self._switch = switch
273        self._url = url
274        self._debug = debug
275        self._sock = None
276        self._options = {}
277
278        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
279            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
280            auth = ['Authorization: Basic %s' % token]
281            self._options['header'] = auth
282
283    @property
284    def closed(self):
285        return not self._sock
286
287    def open(self):
288        if not self.closed:
289            raise websocket.WebSocketException('already opened')
290        self._sock = websocket.WebSocket()
291        self._sock.connect(self._url, **self._options)
292        self._switch.register_port(self)
293        self.dprintf('connected: %s\n', lambda: self._url)
294
295    def close(self):
296        if self.closed:
297            raise websocket.WebSocketException('already closed')
298        self._switch.unregister_port(self)
299        self._sock.close()
300        self._sock = None
301        self.dprintf('disconnected: %s\n', lambda: self._url)
302
303    def fileno(self):
304        if self.closed:
305            raise websocket.WebSocketException('closed socket')
306        return self._sock.io_sock.fileno()
307
308    def write_message(self, message, binary=False):
309        if self.closed:
310            raise websocket.WebSocketException('closed socket')
311        if binary:
312            flag = websocket.ABNF.OPCODE_BINARY
313        else:
314            flag = websocket.ABNF.OPCODE_TEXT
315        self._sock.send(message, flag)
316
317    def __call__(self, fd, events):
318        try:
319            data = self._sock.recv()
320            if data is not None:
321                self._switch.forward(self, EthernetFrame(data))
322                return
323        except:
324            traceback.print_exc()
325        tornado.ioloop.IOLoop.instance().stop()  # XXX: should unregister fd
326
327
328class Htpasswd(object):
329    def __init__(self, path):
330        self._path = path
331        self._stat = None
332        self._data = {}
333
334    def auth(self, name, passwd):
335        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
336        return self._data.get(name) == passwd
337
338    def load(self):
339        old_stat = self._stat
340
341        with open(self._path) as fp:
342            fileno = fp.fileno()
343            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
344            self._stat = os.fstat(fileno)
345
346            unchanged = old_stat and \
347                        old_stat.st_ino == self._stat.st_ino and \
348                        old_stat.st_dev == self._stat.st_dev and \
349                        old_stat.st_mtime == self._stat.st_mtime
350
351            if not unchanged:
352                self._data = self._parse(fp)
353
354    def _parse(self, fp):
355        data = {}
356        for line in fp:
357            line = line.strip()
358            if 0 <= line.find(':'):
359                name, passwd = line.split(':', 1)
360                if passwd.startswith('{SHA}'):
361                    data[name] = passwd[5:]
362        return data
363
364
365def wrap_basic_auth(handler_class, htpasswd_path):
366    if not htpasswd_path:
367        return handler_class
368
369    old_execute = handler_class._execute
370    htpasswd = Htpasswd(htpasswd_path)
371
372    def execute(self, transforms, *args, **kwargs):
373        def auth_required():
374            self.stream.write(tornado.escape.utf8(
375                'HTTP/1.1 401 Authorization Required\r\n'
376                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
377            ))
378            self.stream.close()
379
380        creds = self.request.headers.get('Authorization')
381
382        if not creds or not creds.startswith('Basic '):
383            return auth_required()
384
385        try:
386            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
387            htpasswd.load()
388
389            if not htpasswd.auth(name, passwd):
390                return auth_required()
391
392            return old_execute(self, transforms, *args, **kwargs)
393
394        except:
395            return auth_required()
396
397    handler_class._execute = execute
398    return handler_class
399
400
401def daemonize(nochdir=False, noclose=False):
402    if os.fork() > 0:
403        sys.exit(0)
404
405    os.setsid()
406
407    if os.fork() > 0:
408        sys.exit(0)
409
410    if not nochdir:
411        os.chdir('/')
412
413    if not noclose:
414        os.umask(0)
415        sys.stdin.close()
416        sys.stdout.close()
417        sys.stderr.close()
418        os.close(0)
419        os.close(1)
420        os.close(2)
421        sys.stdin = open(os.devnull)
422        sys.stdout = open(os.devnull, 'a')
423        sys.stderr = open(os.devnull, 'a')
424
425
426def realpath(ns, *keys):
427    for k in keys:
428        v = getattr(ns, k, None)
429        if v is not None:
430            v = os.path.realpath(v)
431            setattr(ns, k, v)
432            open(v).close()  # check readable
433    return ns
434
435
436def server_main(args):
437    realpath(args, 'keyfile', 'certfile', 'htpasswd')
438
439    if args.keyfile and args.certfile:
440        ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile}
441    elif args.keyfile or args.certfile:
442        raise ValueError('both keyfile and certfile are required')
443    else:
444        ssl_options = None
445
446    if args.port is None:
447        if ssl_options:
448            args.port = 443
449        else:
450            args.port = 80
451    elif not (0 <= args.port <= 65535):
452        raise ValueError('invalid port: %s' % args.port)
453
454    if args.ageout <= 0:
455        raise ValueError('invalid ageout: %s' % args.ageout)
456
457    ioloop = tornado.ioloop.IOLoop.instance()
458    fdb = FDB(ageout=args.ageout, debug=args.debug)
459    switch = SwitchingHub(fdb, debug=args.debug)
460
461    handler = wrap_basic_auth(EtherWebSocketHandler, args.htpasswd)
462    srv = (args.path, handler, {'switch': switch, 'debug': args.debug})
463    app = tornado.web.Application([srv])
464    server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options)
465    server.listen(args.port, address=args.address)
466
467    for dev in args.device:
468        tap = TapHandler(switch, dev, debug=args.debug)
469        tap.open()
470        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
471
472    if not args.foreground:
473        daemonize()
474
475    ioloop.start()
476
477
478def client_main(args):
479    realpath(args, 'cacerts')
480
481    if args.debug:
482        websocket.enableTrace(True)
483
484    if args.insecure:
485        websocket._SSLSocketWrapper = \
486            lambda s: ssl.wrap_socket(s)
487    else:
488        websocket._SSLSocketWrapper = \
489            lambda s: ssl.wrap_socket(s, cert_reqs=ssl.CERT_REQUIRED,
490                                      ca_certs=args.cacerts)
491
492    if args.ageout <= 0:
493        raise ValueError('invalid ageout: %s' % args.ageout)
494
495    if args.user and args.passwd is None:
496        args.passwd = getpass.getpass()
497
498    cred = {'user': args.user, 'passwd': args.passwd}
499    ioloop = tornado.ioloop.IOLoop.instance()
500    fdb = FDB(ageout=args.ageout, debug=args.debug)
501    switch = SwitchingHub(fdb, debug=args.debug)
502
503    for uri in args.uri:
504        client = EtherWebSocketClient(switch, uri, cred, args.debug)
505        client.open()
506        ioloop.add_handler(client.fileno(), client, ioloop.READ)
507
508    for dev in args.device:
509        tap = TapHandler(switch, dev, debug=args.debug)
510        tap.open()
511        ioloop.add_handler(tap.fileno(), tap, ioloop.READ)
512
513    if not args.foreground:
514        daemonize()
515
516    ioloop.start()
517
518
519def main():
520    parser = argparse.ArgumentParser()
521    parser.add_argument('--device', action='append', default=[])
522    parser.add_argument('--ageout', action='store', type=int, default=300)
523    parser.add_argument('--foreground', action='store_true', default=False)
524    parser.add_argument('--debug', action='store_true', default=False)
525
526    subparsers = parser.add_subparsers(dest='subcommand')
527
528    parser_s = subparsers.add_parser('server')
529    parser_s.add_argument('--address', action='store', default='')
530    parser_s.add_argument('--port', action='store', type=int)
531    parser_s.add_argument('--path', action='store', default='/')
532    parser_s.add_argument('--htpasswd', action='store')
533    parser_s.add_argument('--keyfile', action='store')
534    parser_s.add_argument('--certfile', action='store')
535
536    parser_c = subparsers.add_parser('client')
537    parser_c.add_argument('--uri', action='append', default=[])
538    parser_c.add_argument('--insecure', action='store_true', default=False)
539    parser_c.add_argument('--cacerts', action='store')
540    parser_c.add_argument('--user', action='store')
541    parser_c.add_argument('--passwd', action='store')
542
543    args = parser.parse_args()
544
545    if args.subcommand == 'server':
546        server_main(args)
547    elif args.subcommand == 'client':
548        client_main(args)
549
550
551if __name__ == '__main__':
552    main()
Note: See TracBrowser for help on using the repository browser.