source: etherws/trunk/etherws.py @ 184

Revision 184, 21.7 KB checked in by atzm, 12 years ago (diff)
  • enables controller options
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#              Ethernet over WebSocket tunneling server/client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# todo:
13#   - servant mode support (like typical p2p software)
14#
15# ===========================================================================
16# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
17# All rights reserved.
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are met:
21#
22# 1. Redistributions of source code must retain the above copyright notice,
23#    this list of conditions and the following disclaimer.
24# 2. Redistributions in binary form must reproduce the above copyright
25#    notice, this list of conditions and the following disclaimer in the
26#    documentation and/or other materials provided with the distribution.
27#
28# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38# POSSIBILITY OF SUCH DAMAGE.
39# ===========================================================================
40#
41# $Id$
42
43import os
44import sys
45import ssl
46import time
47import json
48import fcntl
49import base64
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import websocket
56
57from tornado.web import Application, RequestHandler
58from tornado.websocket import WebSocketHandler
59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
63
64
65class DebugMixIn(object):
66    def dprintf(self, msg, func=lambda: ()):
67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
69            sys.stderr.write(prefix + (msg % func()))
70
71
72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
79
80    @property
81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
100        return 0
101
102
103class FDB(DebugMixIn):
104    def __init__(self, ageout, debug=False):
105        self._ageout = ageout
106        self._debug = debug
107        self._dict = {}
108
109    def lookup(self, frame):
110        mac = frame.dst_mac
111        vid = frame.vid
112
113        group = self._dict.get(vid)
114        if not group:
115            return None
116
117        entry = group.get(mac)
118        if not entry:
119            return None
120
121        if time.time() - entry['time'] > self._ageout:
122            port = self._dict[vid][mac]['port']
123            del self._dict[vid][mac]
124            if not self._dict[vid]:
125                del self._dict[vid]
126            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
127                         lambda: (port.number, vid, mac.encode('hex')))
128            return None
129
130        return entry['port']
131
132    def learn(self, port, frame):
133        mac = frame.src_mac
134        vid = frame.vid
135
136        if vid not in self._dict:
137            self._dict[vid] = {}
138
139        self._dict[vid][mac] = {'time': time.time(), 'port': port}
140        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
141                     lambda: (port.number, vid, mac.encode('hex')))
142
143    def delete(self, port):
144        for vid in self._dict.keys():
145            for mac in self._dict[vid].keys():
146                if self._dict[vid][mac]['port'].number == port.number:
147                    del self._dict[vid][mac]
148                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
149                                 lambda: (port.number, vid, mac.encode('hex')))
150            if not self._dict[vid]:
151                del self._dict[vid]
152
153
154class SwitchPort(object):
155    def __init__(self, number, interface):
156        self.number = number
157        self.interface = interface
158        self.tx = 0
159        self.rx = 0
160        self.shut = False
161
162    @staticmethod
163    def cmp_by_number(x, y):
164        return cmp(x.number, y.number)
165
166
167class SwitchingHub(DebugMixIn):
168    def __init__(self, fdb, debug=False):
169        self._fdb = fdb
170        self._debug = debug
171        self._table = {}
172        self._next = 1
173
174    @property
175    def portlist(self):
176        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
177
178    def shut_port(self, portnum, flag=True):
179        self._table[portnum].shut = flag
180
181    def get_port(self, portnum):
182        return self._table[portnum]
183
184    def register_port(self, interface):
185        interface._switch_portnum = self._next  # XXX
186        self._table[self._next] = SwitchPort(self._next, interface)
187        self._next += 1
188
189    def unregister_port(self, interface):
190        self._fdb.delete(self._table[interface._switch_portnum])
191        del self._table[interface._switch_portnum]
192        del interface._switch_portnum
193
194    def send(self, dst_interfaces, frame):
195        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
196                        if not self._table[i._switch_portnum].shut),
197                       cmp=SwitchPort.cmp_by_number)
198
199        for p in ports:
200            p.interface.write_message(frame.data, True)
201            p.tx += 1
202
203        if ports:
204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
205                         lambda: (','.join(str(p.number) for p in ports),
206                                  frame.vid,
207                                  frame.src_mac.encode('hex'),
208                                  frame.dst_mac.encode('hex')))
209
210    def receive(self, src_interface, frame):
211        port = self._table[src_interface._switch_portnum]
212
213        if not port.shut:
214            port.rx += 1
215            self._forward(port, frame)
216
217    def _forward(self, src_port, frame):
218        try:
219            if not frame.src_multicast:
220                self._fdb.learn(src_port, frame)
221
222            if not frame.dst_multicast:
223                dst_port = self._fdb.lookup(frame)
224
225                if dst_port:
226                    self.send([dst_port.interface], frame)
227                    return
228
229            ports = set(self._table.itervalues()) - set([src_port])
230            self.send((p.interface for p in ports), frame)
231
232        except:  # ex. received invalid frame
233            traceback.print_exc()
234
235
236class Htpasswd(object):
237    def __init__(self, path):
238        self._path = path
239        self._stat = None
240        self._data = {}
241
242    def auth(self, name, passwd):
243        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
244        return self._data.get(name) == passwd
245
246    def load(self):
247        old_stat = self._stat
248
249        with open(self._path) as fp:
250            fileno = fp.fileno()
251            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
252            self._stat = os.fstat(fileno)
253
254            unchanged = old_stat and \
255                        old_stat.st_ino == self._stat.st_ino and \
256                        old_stat.st_dev == self._stat.st_dev and \
257                        old_stat.st_mtime == self._stat.st_mtime
258
259            if not unchanged:
260                self._data = self._parse(fp)
261
262        return self
263
264    def _parse(self, fp):
265        data = {}
266        for line in fp:
267            line = line.strip()
268            if 0 <= line.find(':'):
269                name, passwd = line.split(':', 1)
270                if passwd.startswith('{SHA}'):
271                    data[name] = passwd[5:]
272        return data
273
274
275class BasicAuthMixIn(object):
276    def _execute(self, transforms, *args, **kwargs):
277        def do_execute():
278            sp = super(BasicAuthMixIn, self)
279            return sp._execute(transforms, *args, **kwargs)
280
281        def auth_required():
282            self.stream.write(tornado.escape.utf8(
283                'HTTP/1.1 401 Authorization Required\r\n'
284                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
285            ))
286            self.stream.close()
287
288        try:
289            if not self._htpasswd:
290                return do_execute()
291
292            creds = self.request.headers.get('Authorization')
293
294            if not creds or not creds.startswith('Basic '):
295                return auth_required()
296
297            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
298
299            if self._htpasswd.load().auth(name, passwd):
300                return do_execute()
301        except:
302            traceback.print_exc()
303
304        return auth_required()
305
306
307class TapHandler(DebugMixIn):
308    READ_SIZE = 65535
309
310    def __init__(self, ioloop, switch, dev, debug=False):
311        self._ioloop = ioloop
312        self._switch = switch
313        self._dev = dev
314        self._debug = debug
315        self._tap = None
316
317    @property
318    def closed(self):
319        return not self._tap
320
321    def get_type(self):
322        return 'tap'
323
324    def get_name(self):
325        if self.closed:
326            return self._dev
327        return self._tap.name
328
329    def open(self):
330        if not self.closed:
331            raise ValueError('already opened')
332        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
333        self._tap.up()
334        self._switch.register_port(self)
335        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
336
337    def close(self):
338        if self.closed:
339            raise ValueError('I/O operation on closed tap')
340        self._ioloop.remove_handler(self.fileno())
341        self._switch.unregister_port(self)
342        self._tap.close()
343        self._tap = None
344
345    def fileno(self):
346        if self.closed:
347            raise ValueError('I/O operation on closed tap')
348        return self._tap.fileno()
349
350    def write_message(self, message, binary=False):
351        if self.closed:
352            raise ValueError('I/O operation on closed tap')
353        self._tap.write(message)
354
355    def __call__(self, fd, events):
356        try:
357            self._switch.receive(self, EthernetFrame(self._read()))
358            return
359        except:
360            traceback.print_exc()
361        self.close()
362
363    def _read(self):
364        if self.closed:
365            raise ValueError('I/O operation on closed tap')
366        buf = []
367        while True:
368            buf.append(self._tap.read(self.READ_SIZE))
369            if len(buf[-1]) < self.READ_SIZE:
370                break
371        return ''.join(buf)
372
373
374class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
375    def __init__(self, app, req, switch, htpasswd=None, debug=False):
376        super(EtherWebSocketHandler, self).__init__(app, req)
377        self._switch = switch
378        self._htpasswd = htpasswd
379        self._debug = debug
380
381    def get_type(self):
382        return 'server'
383
384    def get_name(self):
385        return self.request.remote_ip
386
387    def open(self):
388        self._switch.register_port(self)
389        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
390
391    def on_message(self, message):
392        self._switch.receive(self, EthernetFrame(message))
393
394    def on_close(self):
395        self._switch.unregister_port(self)
396        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
397
398
399class EtherWebSocketClient(DebugMixIn):
400    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
401        self._ioloop = ioloop
402        self._switch = switch
403        self._url = url
404        self._ssl = ssl_
405        self._debug = debug
406        self._sock = None
407        self._options = {}
408
409        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
410            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
411            auth = ['Authorization: Basic %s' % token]
412            self._options['header'] = auth
413
414    @property
415    def closed(self):
416        return not self._sock
417
418    def get_type(self):
419        return 'client'
420
421    def get_name(self):
422        return self._url
423
424    def open(self):
425        sslwrap = websocket._SSLSocketWrapper
426
427        if not self.closed:
428            raise websocket.WebSocketException('already opened')
429
430        if self._ssl:
431            websocket._SSLSocketWrapper = self._ssl
432
433        try:
434            self._sock = websocket.WebSocket()
435            self._sock.connect(self._url, **self._options)
436            self._switch.register_port(self)
437            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
438            self.dprintf('connected: %s\n', lambda: self._url)
439        finally:
440            websocket._SSLSocketWrapper = sslwrap
441
442    def close(self):
443        if self.closed:
444            raise websocket.WebSocketException('already closed')
445        self._ioloop.remove_handler(self.fileno())
446        self._switch.unregister_port(self)
447        self._sock.close()
448        self._sock = None
449        self.dprintf('disconnected: %s\n', lambda: self._url)
450
451    def fileno(self):
452        if self.closed:
453            raise websocket.WebSocketException('closed socket')
454        return self._sock.io_sock.fileno()
455
456    def write_message(self, message, binary=False):
457        if self.closed:
458            raise websocket.WebSocketException('closed socket')
459        if binary:
460            flag = websocket.ABNF.OPCODE_BINARY
461        else:
462            flag = websocket.ABNF.OPCODE_TEXT
463        self._sock.send(message, flag)
464
465    def __call__(self, fd, events):
466        try:
467            data = self._sock.recv()
468            if data is not None:
469                self._switch.receive(self, EthernetFrame(data))
470                return
471        except:
472            traceback.print_exc()
473        self.close()
474
475
476class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
477    NAMESPACE = 'etherws.control'
478
479    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
480        super(EtherWebSocketControlHandler, self).__init__(app, req)
481        self._ioloop = ioloop
482        self._switch = switch
483        self._htpasswd = htpasswd
484        self._debug = debug
485
486    def post(self):
487        id_ = None
488
489        try:
490            req = json.loads(self.request.body)
491            method = req['method']
492            params = req['params']
493            id_ = req.get('id')
494
495            if not method.startswith(self.NAMESPACE + '.'):
496                raise ValueError('invalid method: %s' % method)
497
498            if not isinstance(params, list):
499                raise ValueError('invalid params: %s' % params)
500
501            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
502            result = getattr(self, handler)(params)
503            self.finish({'result': result, 'error': None, 'id': id_})
504
505        except Exception as e:
506            traceback.print_exc()
507            self.finish({'result': None, 'error': str(e), 'id': id_})
508
509    def handle_listPort(self, params):
510        list_ = []
511        for port in self._switch.portlist:
512            list_.append({
513                'port': port.number,
514                'type': port.interface.get_type(),
515                'name': port.interface.get_name(),
516                'tx':   port.tx,
517                'rx':   port.rx,
518                'shut': port.shut,
519            })
520        return {'portlist': list_}
521
522    def handle_addPort(self, params):
523        for p in params:
524            getattr(self, '_openport_' + p['type'])(p)
525        return self.handle_listPort(params)
526
527    def handle_delPort(self, params):
528        for p in params:
529            self._switch.get_port(int(p['port'])).interface.close()
530        return self.handle_listPort(params)
531
532    def handle_shutPort(self, params):
533        for p in params:
534            self._switch.shut_port(int(p['port']), bool(p['flag']))
535        return self.handle_listPort(params)
536
537    def _openport_tap(self, p):
538        dev = p['device']
539        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug)
540        tap.open()
541
542    def _openport_client(self, p):
543        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts'))
544        cred = {'user': p.get('user'), 'passwd': p.get('passwd')}
545        url = p['url']
546        client = EtherWebSocketClient(self._ioloop, self._switch,
547                                      url, ssl_, cred, self._debug)
548        client.open()
549
550    @staticmethod
551    def _ssl_wrapper(insecure, ca_certs):
552        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
553        if insecure:
554            args = {}
555        return lambda sock: ssl.wrap_socket(sock, **args)
556
557
558def daemonize(nochdir=False, noclose=False):
559    if os.fork() > 0:
560        sys.exit(0)
561
562    os.setsid()
563
564    if os.fork() > 0:
565        sys.exit(0)
566
567    if not nochdir:
568        os.chdir('/')
569
570    if not noclose:
571        os.umask(0)
572        sys.stdin.close()
573        sys.stdout.close()
574        sys.stderr.close()
575        os.close(0)
576        os.close(1)
577        os.close(2)
578        sys.stdin = open(os.devnull)
579        sys.stdout = open(os.devnull, 'a')
580        sys.stderr = open(os.devnull, 'a')
581
582
583def main():
584    def realpath(ns, *keys):
585        for k in keys:
586            v = getattr(ns, k, None)
587            if v is not None:
588                v = os.path.realpath(v)
589                open(v).close()  # check readable
590                setattr(ns, k, v)
591
592    def checkpath(ns, path):
593        val = getattr(ns, path, '')
594        if not val.startswith('/'):
595            raise ValueError('invalid %: %s' % (path, val))
596
597    def getsslopt(ns, key, cert):
598        kval = getattr(ns, key, None)
599        cval = getattr(ns, cert, None)
600        if kval and cval:
601            return {'keyfile': kval, 'certfile': cval}
602        elif kval or cval:
603            raise ValueError('both %s and %s are required' % (key, cert))
604        return None
605
606    def setport(ns, port, isssl):
607        val = getattr(ns, port, None)
608        if val is None:
609            if isssl:
610                return setattr(ns, port, 443)
611            return setattr(ns, port, 80)
612        if not (0 <= val <= 65535):
613            raise ValueError('invalid %s: %s' % (port, val))
614
615    def sethtpasswd(ns, htpasswd):
616        val = getattr(ns, htpasswd, None)
617        if val:
618            return setattr(ns, htpasswd, Htpasswd(val))
619
620    parser = argparse.ArgumentParser()
621
622    parser.add_argument('--debug', action='store_true', default=False)
623    parser.add_argument('--foreground', action='store_true', default=False)
624    parser.add_argument('--ageout', action='store', type=int, default=300)
625
626    parser.add_argument('--path', action='store', default='/')
627    parser.add_argument('--host', action='store', default='')
628    parser.add_argument('--port', action='store', type=int)
629    parser.add_argument('--htpasswd', action='store')
630    parser.add_argument('--sslkey', action='store')
631    parser.add_argument('--sslcert', action='store')
632
633    parser.add_argument('--ctlpath', action='store', default='/ctl')
634    parser.add_argument('--ctlhost', action='store', default='')
635    parser.add_argument('--ctlport', action='store', type=int)
636    parser.add_argument('--ctlhtpasswd', action='store')
637    parser.add_argument('--ctlsslkey', action='store')
638    parser.add_argument('--ctlsslcert', action='store')
639
640    args = parser.parse_args()
641
642    #if args.debug:
643    #    websocket.enableTrace(True)
644
645    if args.ageout <= 0:
646        raise ValueError('invalid ageout: %s' % args.ageout)
647
648    realpath(args, 'htpasswd', 'sslkey', 'sslcert')
649    realpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
650
651    checkpath(args, 'path')
652    checkpath(args, 'ctlpath')
653
654    sslopt = getsslopt(args, 'sslkey', 'sslcert')
655    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
656
657    setport(args, 'port', sslopt)
658    setport(args, 'ctlport', ctlsslopt)
659
660    sethtpasswd(args, 'htpasswd')
661    sethtpasswd(args, 'ctlhtpasswd')
662
663    ioloop = IOLoop.instance()
664    fdb = FDB(ageout=args.ageout, debug=args.debug)
665    switch = SwitchingHub(fdb, debug=args.debug)
666
667    if args.port == args.ctlport and args.host == args.ctlhost:
668        if args.path == args.ctlpath:
669            raise ValueError('same path/ctlpath on same host')
670        if args.sslkey != args.ctlsslkey:
671            raise ValueError('differ sslkey/ctlsslkey on same host')
672        if args.sslcert != args.ctlsslcert:
673            raise ValueError('differ sslcert/ctlsslcert on same host')
674
675        app = Application([
676            (args.path, EtherWebSocketHandler, {
677                'switch':   switch,
678                'htpasswd': args.htpasswd,
679                'debug':    args.debug,
680            }),
681            (args.ctlpath, EtherWebSocketControlHandler, {
682                'ioloop':   ioloop,
683                'switch':   switch,
684                'htpasswd': args.ctlhtpasswd,
685                'debug':    args.debug,
686            }),
687        ])
688        server = HTTPServer(app, ssl_options=sslopt)
689        server.listen(args.port, address=args.host)
690
691    else:
692        app = Application([(args.path, EtherWebSocketHandler, {
693            'switch':   switch,
694            'htpasswd': args.htpasswd,
695            'debug':    args.debug,
696        })])
697        server = HTTPServer(app, ssl_options=sslopt)
698        server.listen(args.port, address=args.host)
699
700        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
701            'ioloop':   ioloop,
702            'switch':   switch,
703            'htpasswd': args.ctlhtpasswd,
704            'debug':    args.debug,
705        })])
706        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
707        ctlserver.listen(args.ctlport, address=args.ctlhost)
708
709    if not args.foreground:
710        daemonize()
711
712    ioloop.start()
713
714
715if __name__ == '__main__':
716    main()
Note: See TracBrowser for help on using the repository browser.