source: etherws/trunk/etherws.py @ 187

Revision 187, 23.1 KB checked in by atzm, 12 years ago (diff)
  • refactoring
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import hashlib
48import getpass
49import argparse
50import traceback
51
52import tornado
53import websocket
54
55from tornado.web import Application, RequestHandler
56from tornado.websocket import WebSocketHandler
57from tornado.httpserver import HTTPServer
58from tornado.ioloop import IOLoop
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return 0
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            port = self._dict[vid][mac]['port']
121            del self._dict[vid][mac]
122            if not self._dict[vid]:
123                del self._dict[vid]
124            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
125                         lambda: (port.number, vid, mac.encode('hex')))
126            return None
127
128        return entry['port']
129
130    def learn(self, port, frame):
131        mac = frame.src_mac
132        vid = frame.vid
133
134        if vid not in self._dict:
135            self._dict[vid] = {}
136
137        self._dict[vid][mac] = {'time': time.time(), 'port': port}
138        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
139                     lambda: (port.number, vid, mac.encode('hex')))
140
141    def delete(self, port):
142        for vid in self._dict.keys():
143            for mac in self._dict[vid].keys():
144                if self._dict[vid][mac]['port'].number == port.number:
145                    del self._dict[vid][mac]
146                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
147                                 lambda: (port.number, vid, mac.encode('hex')))
148            if not self._dict[vid]:
149                del self._dict[vid]
150
151
152class SwitchPort(object):
153    def __init__(self, number, interface):
154        self.number = number
155        self.interface = interface
156        self.tx = 0
157        self.rx = 0
158        self.shut = False
159
160    @staticmethod
161    def cmp_by_number(x, y):
162        return cmp(x.number, y.number)
163
164
165class SwitchingHub(DebugMixIn):
166    def __init__(self, fdb, debug=False):
167        self._fdb = fdb
168        self._debug = debug
169        self._table = {}
170        self._next = 1
171
172    @property
173    def portlist(self):
174        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
175
176    def get_port(self, portnum):
177        return self._table[portnum]
178
179    def register_port(self, interface):
180        try:
181            self._set_privattr('portnum', interface, self._next)  # XXX
182            self._table[self._next] = SwitchPort(self._next, interface)
183            return self._next
184        finally:
185            self._next += 1
186
187    def unregister_port(self, interface):
188        portnum = self._get_privattr('portnum', interface)
189        self._del_privattr('portnum', interface)
190        self._fdb.delete(self._table[portnum])
191        del self._table[portnum]
192
193    def send(self, dst_interfaces, frame):
194        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
195        ports = (self._table[n] for n in portnums)
196        ports = (p for p in ports if not p.shut)
197        ports = sorted(ports, cmp=SwitchPort.cmp_by_number)
198
199        for p in ports:
200            p.interface.write_message(frame.data, True)
201            p.tx += 1
202
203        if ports:
204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
205                         lambda: (','.join(str(p.number) for p in ports),
206                                  frame.vid,
207                                  frame.src_mac.encode('hex'),
208                                  frame.dst_mac.encode('hex')))
209
210    def receive(self, src_interface, frame):
211        port = self._table[self._get_privattr('portnum', src_interface)]
212
213        if not port.shut:
214            port.rx += 1
215            self._forward(port, frame)
216
217    def _forward(self, src_port, frame):
218        try:
219            if not frame.src_multicast:
220                self._fdb.learn(src_port, frame)
221
222            if not frame.dst_multicast:
223                dst_port = self._fdb.lookup(frame)
224
225                if dst_port:
226                    self.send([dst_port.interface], frame)
227                    return
228
229            ports = set(self.portlist) - set([src_port])
230            self.send((p.interface for p in ports), frame)
231
232        except:  # ex. received invalid frame
233            traceback.print_exc()
234
235    def _privattr(self, name):
236        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
237
238    def _set_privattr(self, name, obj, value):
239        return setattr(obj, self._privattr(name), value)
240
241    def _get_privattr(self, name, obj, defaults=None):
242        return getattr(obj, self._privattr(name), defaults)
243
244    def _del_privattr(self, name, obj):
245        return delattr(obj, self._privattr(name))
246
247
248class Htpasswd(object):
249    def __init__(self, path):
250        self._path = path
251        self._stat = None
252        self._data = {}
253
254    def auth(self, name, passwd):
255        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
256        return self._data.get(name) == passwd
257
258    def load(self):
259        old_stat = self._stat
260
261        with open(self._path) as fp:
262            fileno = fp.fileno()
263            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
264            self._stat = os.fstat(fileno)
265
266            unchanged = old_stat and \
267                        old_stat.st_ino == self._stat.st_ino and \
268                        old_stat.st_dev == self._stat.st_dev and \
269                        old_stat.st_mtime == self._stat.st_mtime
270
271            if not unchanged:
272                self._data = self._parse(fp)
273
274        return self
275
276    def _parse(self, fp):
277        data = {}
278        for line in fp:
279            line = line.strip()
280            if 0 <= line.find(':'):
281                name, passwd = line.split(':', 1)
282                if passwd.startswith('{SHA}'):
283                    data[name] = passwd[5:]
284        return data
285
286
287class BasicAuthMixIn(object):
288    def _execute(self, transforms, *args, **kwargs):
289        def do_execute():
290            sp = super(BasicAuthMixIn, self)
291            return sp._execute(transforms, *args, **kwargs)
292
293        def auth_required():
294            stream = getattr(self, 'stream', self.request.connection.stream)
295            stream.write(tornado.escape.utf8(
296                'HTTP/1.1 401 Authorization Required\r\n'
297                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
298            ))
299            stream.close()
300
301        try:
302            if not self._htpasswd:
303                return do_execute()
304
305            creds = self.request.headers.get('Authorization')
306
307            if not creds or not creds.startswith('Basic '):
308                return auth_required()
309
310            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
311
312            if self._htpasswd.load().auth(name, passwd):
313                return do_execute()
314        except:
315            traceback.print_exc()
316
317        return auth_required()
318
319
320class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
321    def __init__(self, app, req, switch, htpasswd=None, debug=False):
322        super(EtherWebSocketHandler, self).__init__(app, req)
323        self._switch = switch
324        self._htpasswd = htpasswd
325        self._debug = debug
326
327    @classmethod
328    def get_type(cls):
329        return 'server'
330
331    def get_target(self):
332        return self.request.remote_ip
333
334    def open(self):
335        try:
336            return self._switch.register_port(self)
337        finally:
338            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
339
340    def on_message(self, message):
341        self._switch.receive(self, EthernetFrame(message))
342
343    def on_close(self):
344        self._switch.unregister_port(self)
345        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
346
347
348class TapHandler(DebugMixIn):
349    READ_SIZE = 65535
350
351    def __init__(self, ioloop, switch, dev, debug=False):
352        self._ioloop = ioloop
353        self._switch = switch
354        self._dev = dev
355        self._debug = debug
356        self._tap = None
357
358    @classmethod
359    def get_type(cls):
360        return 'tap'
361
362    def get_target(self):
363        if self.closed:
364            return self._dev
365        return self._tap.name
366
367    @property
368    def closed(self):
369        return not self._tap
370
371    def open(self):
372        if not self.closed:
373            raise ValueError('already opened')
374        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
375        self._tap.up()
376        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
377        return self._switch.register_port(self)
378
379    def close(self):
380        if self.closed:
381            raise ValueError('I/O operation on closed tap')
382        self._switch.unregister_port(self)
383        self._ioloop.remove_handler(self.fileno())
384        self._tap.close()
385        self._tap = None
386
387    def fileno(self):
388        if self.closed:
389            raise ValueError('I/O operation on closed tap')
390        return self._tap.fileno()
391
392    def write_message(self, message, binary=False):
393        if self.closed:
394            raise ValueError('I/O operation on closed tap')
395        self._tap.write(message)
396
397    def __call__(self, fd, events):
398        try:
399            self._switch.receive(self, EthernetFrame(self._read()))
400            return
401        except:
402            traceback.print_exc()
403        self.close()
404
405    def _read(self):
406        if self.closed:
407            raise ValueError('I/O operation on closed tap')
408        buf = []
409        while True:
410            buf.append(self._tap.read(self.READ_SIZE))
411            if len(buf[-1]) < self.READ_SIZE:
412                break
413        return ''.join(buf)
414
415
416class EtherWebSocketClient(DebugMixIn):
417    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
418        self._ioloop = ioloop
419        self._switch = switch
420        self._url = url
421        self._ssl = ssl_
422        self._debug = debug
423        self._sock = None
424        self._options = {}
425
426        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
427            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
428            auth = ['Authorization: Basic %s' % token]
429            self._options['header'] = auth
430
431    @classmethod
432    def get_type(cls):
433        return 'client'
434
435    def get_target(self):
436        return self._url
437
438    @property
439    def closed(self):
440        return not self._sock
441
442    def open(self):
443        sslwrap = websocket._SSLSocketWrapper
444
445        if not self.closed:
446            raise websocket.WebSocketException('already opened')
447
448        if self._ssl:
449            websocket._SSLSocketWrapper = self._ssl
450
451        try:
452            self._sock = websocket.WebSocket()
453            self._sock.connect(self._url, **self._options)
454            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
455            return self._switch.register_port(self)
456        finally:
457            websocket._SSLSocketWrapper = sslwrap
458            self.dprintf('connected: %s\n', lambda: self._url)
459
460    def close(self):
461        if self.closed:
462            raise websocket.WebSocketException('already closed')
463        self._switch.unregister_port(self)
464        self._ioloop.remove_handler(self.fileno())
465        self._sock.close()
466        self._sock = None
467        self.dprintf('disconnected: %s\n', lambda: self._url)
468
469    def fileno(self):
470        if self.closed:
471            raise websocket.WebSocketException('closed socket')
472        return self._sock.io_sock.fileno()
473
474    def write_message(self, message, binary=False):
475        if self.closed:
476            raise websocket.WebSocketException('closed socket')
477        if binary:
478            flag = websocket.ABNF.OPCODE_BINARY
479        else:
480            flag = websocket.ABNF.OPCODE_TEXT
481        self._sock.send(message, flag)
482
483    def __call__(self, fd, events):
484        try:
485            data = self._sock.recv()
486            if data is not None:
487                self._switch.receive(self, EthernetFrame(data))
488                return
489        except:
490            traceback.print_exc()
491        self.close()
492
493
494class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
495    NAMESPACE = 'etherws.control'
496    INTERFACES = {
497        TapHandler.get_type():           TapHandler,
498        EtherWebSocketClient.get_type(): EtherWebSocketClient,
499    }
500
501    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
502        super(EtherWebSocketControlHandler, self).__init__(app, req)
503        self._ioloop = ioloop
504        self._switch = switch
505        self._htpasswd = htpasswd
506        self._debug = debug
507
508    def post(self):
509        id_ = None
510
511        try:
512            req = json.loads(self.request.body)
513            method = req['method']
514            params = req['params']
515            id_ = req.get('id')
516
517            if not method.startswith(self.NAMESPACE + '.'):
518                raise ValueError('invalid method: %s' % method)
519
520            if not isinstance(params, list):
521                raise ValueError('invalid params: %s' % params)
522
523            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
524            result = getattr(self, handler)(params)
525            self.finish({'result': result, 'error': None, 'id': id_})
526
527        except Exception as e:
528            traceback.print_exc()
529            self.finish({'result': None, 'error': str(e), 'id': id_})
530
531    def handle_listPort(self, params):
532        list_ = []
533        for port in self._switch.portlist:
534            list_.append(self._portstat(port))
535        return {'portlist': list_}
536
537    def handle_addPort(self, params):
538        list_ = []
539        for p in params:
540            type_ = p['type']
541            target = p['target']
542            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
543            klass = self.INTERFACES[type_]
544            interface = klass(self._ioloop, self._switch, target, **options)
545            portnum = interface.open()
546            list_.append(self._portstat(self._switch.get_port(portnum)))
547        return {'portlist': list_}
548
549    def handle_delPort(self, params):
550        list_ = []
551        for p in params:
552            port = self._switch.get_port(int(p['port']))
553            list_.append(self._portstat(port))
554            port.interface.close()
555        return {'portlist': list_}
556
557    def handle_shutPort(self, params):
558        list_ = []
559        for p in params:
560            port = self._switch.get_port(int(p['port']))
561            port.shut = bool(p['flag'])
562            list_.append(self._portstat(port))
563        return {'portlist': list_}
564
565    def _optparse_tap(self, opt):
566        return {'debug': self._debug}
567
568    def _optparse_client(self, opt):
569        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
570        if opt.get('insecure'):
571            args = {}
572        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
573        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
574        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
575
576    @staticmethod
577    def _portstat(port):
578        return {
579            'port':   port.number,
580            'type':   port.interface.get_type(),
581            'target': port.interface.get_target(),
582            'tx':     port.tx,
583            'rx':     port.rx,
584            'shut':   port.shut,
585        }
586
587
588def start_switch(args):
589    def daemonize(nochdir=False, noclose=False):
590        if os.fork() > 0:
591            sys.exit(0)
592
593        os.setsid()
594
595        if os.fork() > 0:
596            sys.exit(0)
597
598        if not nochdir:
599            os.chdir('/')
600
601        if not noclose:
602            os.umask(0)
603            sys.stdin.close()
604            sys.stdout.close()
605            sys.stderr.close()
606            os.close(0)
607            os.close(1)
608            os.close(2)
609            sys.stdin = open(os.devnull)
610            sys.stdout = open(os.devnull, 'a')
611            sys.stderr = open(os.devnull, 'a')
612
613    def checkabspath(ns, path):
614        val = getattr(ns, path, '')
615        if not val.startswith('/'):
616            raise ValueError('invalid %: %s' % (path, val))
617
618    def getsslopt(ns, key, cert):
619        kval = getattr(ns, key, None)
620        cval = getattr(ns, cert, None)
621        if kval and cval:
622            return {'keyfile': kval, 'certfile': cval}
623        elif kval or cval:
624            raise ValueError('both %s and %s are required' % (key, cert))
625        return None
626
627    def setrealpath(ns, *keys):
628        for k in keys:
629            v = getattr(ns, k, None)
630            if v is not None:
631                v = os.path.realpath(v)
632                open(v).close()  # check readable
633                setattr(ns, k, v)
634
635    def setport(ns, port, isssl):
636        val = getattr(ns, port, None)
637        if val is None:
638            if isssl:
639                return setattr(ns, port, 443)
640            return setattr(ns, port, 80)
641        if not (0 <= val <= 65535):
642            raise ValueError('invalid %s: %s' % (port, val))
643
644    def sethtpasswd(ns, htpasswd):
645        val = getattr(ns, htpasswd, None)
646        if val:
647            return setattr(ns, htpasswd, Htpasswd(val))
648
649    #if args.debug:
650    #    websocket.enableTrace(True)
651
652    if args.ageout <= 0:
653        raise ValueError('invalid ageout: %s' % args.ageout)
654
655    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
656    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
657
658    checkabspath(args, 'path')
659    checkabspath(args, 'ctlpath')
660
661    sslopt = getsslopt(args, 'sslkey', 'sslcert')
662    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
663
664    setport(args, 'port', sslopt)
665    setport(args, 'ctlport', ctlsslopt)
666
667    sethtpasswd(args, 'htpasswd')
668    sethtpasswd(args, 'ctlhtpasswd')
669
670    ioloop = IOLoop.instance()
671    fdb = FDB(ageout=args.ageout, debug=args.debug)
672    switch = SwitchingHub(fdb, debug=args.debug)
673
674    if args.port == args.ctlport and args.host == args.ctlhost:
675        if args.path == args.ctlpath:
676            raise ValueError('same path/ctlpath on same host')
677        if args.sslkey != args.ctlsslkey:
678            raise ValueError('differ sslkey/ctlsslkey on same host')
679        if args.sslcert != args.ctlsslcert:
680            raise ValueError('differ sslcert/ctlsslcert on same host')
681
682        app = Application([
683            (args.path, EtherWebSocketHandler, {
684                'switch':   switch,
685                'htpasswd': args.htpasswd,
686                'debug':    args.debug,
687            }),
688            (args.ctlpath, EtherWebSocketControlHandler, {
689                'ioloop':   ioloop,
690                'switch':   switch,
691                'htpasswd': args.ctlhtpasswd,
692                'debug':    args.debug,
693            }),
694        ])
695        server = HTTPServer(app, ssl_options=sslopt)
696        server.listen(args.port, address=args.host)
697
698    else:
699        app = Application([(args.path, EtherWebSocketHandler, {
700            'switch':   switch,
701            'htpasswd': args.htpasswd,
702            'debug':    args.debug,
703        })])
704        server = HTTPServer(app, ssl_options=sslopt)
705        server.listen(args.port, address=args.host)
706
707        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
708            'ioloop':   ioloop,
709            'switch':   switch,
710            'htpasswd': args.ctlhtpasswd,
711            'debug':    args.debug,
712        })])
713        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
714        ctlserver.listen(args.ctlport, address=args.ctlhost)
715
716    if not args.foreground:
717        daemonize()
718
719    ioloop.start()
720
721
722def main():
723    parser = argparse.ArgumentParser()
724    subparsers = parser.add_subparsers(dest='subcommand')
725    parser_s = subparsers.add_parser('switch')
726    parser_c = subparsers.add_parser('control')
727
728    parser_s.add_argument('--debug', action='store_true', default=False)
729    parser_s.add_argument('--foreground', action='store_true', default=False)
730    parser_s.add_argument('--ageout', action='store', type=int, default=300)
731
732    parser_s.add_argument('--path', action='store', default='/')
733    parser_s.add_argument('--host', action='store', default='')
734    parser_s.add_argument('--port', action='store', type=int)
735    parser_s.add_argument('--htpasswd', action='store')
736    parser_s.add_argument('--sslkey', action='store')
737    parser_s.add_argument('--sslcert', action='store')
738
739    parser_s.add_argument('--ctlpath', action='store', default='/ctl')
740    parser_s.add_argument('--ctlhost', action='store', default='')
741    parser_s.add_argument('--ctlport', action='store', type=int)
742    parser_s.add_argument('--ctlhtpasswd', action='store')
743    parser_s.add_argument('--ctlsslkey', action='store')
744    parser_s.add_argument('--ctlsslcert', action='store')
745
746    args = parser.parse_args()
747
748    globals()['start_' + args.subcommand](args)
749
750
751if __name__ == '__main__':
752    main()
Note: See TracBrowser for help on using the repository browser.