source: etherws/trunk/etherws.py @ 186

Revision 186, 22.6 KB checked in by atzm, 12 years ago (diff)
  • API seiri
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import hashlib
48import getpass
49import argparse
50import traceback
51
52import tornado
53import websocket
54
55from tornado.web import Application, RequestHandler
56from tornado.websocket import WebSocketHandler
57from tornado.httpserver import HTTPServer
58from tornado.ioloop import IOLoop
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return 0
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            port = self._dict[vid][mac]['port']
121            del self._dict[vid][mac]
122            if not self._dict[vid]:
123                del self._dict[vid]
124            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
125                         lambda: (port.number, vid, mac.encode('hex')))
126            return None
127
128        return entry['port']
129
130    def learn(self, port, frame):
131        mac = frame.src_mac
132        vid = frame.vid
133
134        if vid not in self._dict:
135            self._dict[vid] = {}
136
137        self._dict[vid][mac] = {'time': time.time(), 'port': port}
138        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
139                     lambda: (port.number, vid, mac.encode('hex')))
140
141    def delete(self, port):
142        for vid in self._dict.keys():
143            for mac in self._dict[vid].keys():
144                if self._dict[vid][mac]['port'].number == port.number:
145                    del self._dict[vid][mac]
146                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
147                                 lambda: (port.number, vid, mac.encode('hex')))
148            if not self._dict[vid]:
149                del self._dict[vid]
150
151
152class SwitchPort(object):
153    def __init__(self, number, interface):
154        self.number = number
155        self.interface = interface
156        self.tx = 0
157        self.rx = 0
158        self.shut = False
159
160    @staticmethod
161    def cmp_by_number(x, y):
162        return cmp(x.number, y.number)
163
164
165class SwitchingHub(DebugMixIn):
166    def __init__(self, fdb, debug=False):
167        self._fdb = fdb
168        self._debug = debug
169        self._table = {}
170        self._next = 1
171
172    @property
173    def portlist(self):
174        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
175
176    def get_port(self, portnum):
177        return self._table[portnum]
178
179    def register_port(self, interface):
180        try:
181            interface._switch_portnum = self._next  # XXX
182            self._table[self._next] = SwitchPort(self._next, interface)
183            return self._next
184        finally:
185            self._next += 1
186
187    def unregister_port(self, interface):
188        self._fdb.delete(self._table[interface._switch_portnum])
189        del self._table[interface._switch_portnum]
190        del interface._switch_portnum
191
192    def send(self, dst_interfaces, frame):
193        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
194                        if not self._table[i._switch_portnum].shut),
195                       cmp=SwitchPort.cmp_by_number)
196
197        for p in ports:
198            p.interface.write_message(frame.data, True)
199            p.tx += 1
200
201        if ports:
202            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
203                         lambda: (','.join(str(p.number) for p in ports),
204                                  frame.vid,
205                                  frame.src_mac.encode('hex'),
206                                  frame.dst_mac.encode('hex')))
207
208    def receive(self, src_interface, frame):
209        port = self._table[src_interface._switch_portnum]
210
211        if not port.shut:
212            port.rx += 1
213            self._forward(port, frame)
214
215    def _forward(self, src_port, frame):
216        try:
217            if not frame.src_multicast:
218                self._fdb.learn(src_port, frame)
219
220            if not frame.dst_multicast:
221                dst_port = self._fdb.lookup(frame)
222
223                if dst_port:
224                    self.send([dst_port.interface], frame)
225                    return
226
227            ports = set(self._table.itervalues()) - set([src_port])
228            self.send((p.interface for p in ports), frame)
229
230        except:  # ex. received invalid frame
231            traceback.print_exc()
232
233
234class Htpasswd(object):
235    def __init__(self, path):
236        self._path = path
237        self._stat = None
238        self._data = {}
239
240    def auth(self, name, passwd):
241        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
242        return self._data.get(name) == passwd
243
244    def load(self):
245        old_stat = self._stat
246
247        with open(self._path) as fp:
248            fileno = fp.fileno()
249            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
250            self._stat = os.fstat(fileno)
251
252            unchanged = old_stat and \
253                        old_stat.st_ino == self._stat.st_ino and \
254                        old_stat.st_dev == self._stat.st_dev and \
255                        old_stat.st_mtime == self._stat.st_mtime
256
257            if not unchanged:
258                self._data = self._parse(fp)
259
260        return self
261
262    def _parse(self, fp):
263        data = {}
264        for line in fp:
265            line = line.strip()
266            if 0 <= line.find(':'):
267                name, passwd = line.split(':', 1)
268                if passwd.startswith('{SHA}'):
269                    data[name] = passwd[5:]
270        return data
271
272
273class BasicAuthMixIn(object):
274    def _execute(self, transforms, *args, **kwargs):
275        def do_execute():
276            sp = super(BasicAuthMixIn, self)
277            return sp._execute(transforms, *args, **kwargs)
278
279        def auth_required():
280            stream = getattr(self, 'stream', self.request.connection.stream)
281            stream.write(tornado.escape.utf8(
282                'HTTP/1.1 401 Authorization Required\r\n'
283                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
284            ))
285            stream.close()
286
287        try:
288            if not self._htpasswd:
289                return do_execute()
290
291            creds = self.request.headers.get('Authorization')
292
293            if not creds or not creds.startswith('Basic '):
294                return auth_required()
295
296            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
297
298            if self._htpasswd.load().auth(name, passwd):
299                return do_execute()
300        except:
301            traceback.print_exc()
302
303        return auth_required()
304
305
306class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
307    def __init__(self, app, req, switch, htpasswd=None, debug=False):
308        super(EtherWebSocketHandler, self).__init__(app, req)
309        self._switch = switch
310        self._htpasswd = htpasswd
311        self._debug = debug
312
313    @classmethod
314    def get_type(cls):
315        return 'server'
316
317    def get_target(self):
318        return self.request.remote_ip
319
320    def open(self):
321        try:
322            return self._switch.register_port(self)
323        finally:
324            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
325
326    def on_message(self, message):
327        self._switch.receive(self, EthernetFrame(message))
328
329    def on_close(self):
330        self._switch.unregister_port(self)
331        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
332
333
334class TapHandler(DebugMixIn):
335    READ_SIZE = 65535
336
337    def __init__(self, ioloop, switch, dev, debug=False):
338        self._ioloop = ioloop
339        self._switch = switch
340        self._dev = dev
341        self._debug = debug
342        self._tap = None
343
344    @classmethod
345    def get_type(cls):
346        return 'tap'
347
348    def get_target(self):
349        if self.closed:
350            return self._dev
351        return self._tap.name
352
353    @property
354    def closed(self):
355        return not self._tap
356
357    def open(self):
358        if not self.closed:
359            raise ValueError('already opened')
360        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
361        self._tap.up()
362        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
363        return self._switch.register_port(self)
364
365    def close(self):
366        if self.closed:
367            raise ValueError('I/O operation on closed tap')
368        self._switch.unregister_port(self)
369        self._ioloop.remove_handler(self.fileno())
370        self._tap.close()
371        self._tap = None
372
373    def fileno(self):
374        if self.closed:
375            raise ValueError('I/O operation on closed tap')
376        return self._tap.fileno()
377
378    def write_message(self, message, binary=False):
379        if self.closed:
380            raise ValueError('I/O operation on closed tap')
381        self._tap.write(message)
382
383    def __call__(self, fd, events):
384        try:
385            self._switch.receive(self, EthernetFrame(self._read()))
386            return
387        except:
388            traceback.print_exc()
389        self.close()
390
391    def _read(self):
392        if self.closed:
393            raise ValueError('I/O operation on closed tap')
394        buf = []
395        while True:
396            buf.append(self._tap.read(self.READ_SIZE))
397            if len(buf[-1]) < self.READ_SIZE:
398                break
399        return ''.join(buf)
400
401
402class EtherWebSocketClient(DebugMixIn):
403    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
404        self._ioloop = ioloop
405        self._switch = switch
406        self._url = url
407        self._ssl = ssl_
408        self._debug = debug
409        self._sock = None
410        self._options = {}
411
412        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
413            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
414            auth = ['Authorization: Basic %s' % token]
415            self._options['header'] = auth
416
417    @classmethod
418    def get_type(cls):
419        return 'client'
420
421    def get_target(self):
422        return self._url
423
424    @property
425    def closed(self):
426        return not self._sock
427
428    def open(self):
429        sslwrap = websocket._SSLSocketWrapper
430
431        if not self.closed:
432            raise websocket.WebSocketException('already opened')
433
434        if self._ssl:
435            websocket._SSLSocketWrapper = self._ssl
436
437        try:
438            self._sock = websocket.WebSocket()
439            self._sock.connect(self._url, **self._options)
440            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
441            return self._switch.register_port(self)
442        finally:
443            websocket._SSLSocketWrapper = sslwrap
444            self.dprintf('connected: %s\n', lambda: self._url)
445
446    def close(self):
447        if self.closed:
448            raise websocket.WebSocketException('already closed')
449        self._switch.unregister_port(self)
450        self._ioloop.remove_handler(self.fileno())
451        self._sock.close()
452        self._sock = None
453        self.dprintf('disconnected: %s\n', lambda: self._url)
454
455    def fileno(self):
456        if self.closed:
457            raise websocket.WebSocketException('closed socket')
458        return self._sock.io_sock.fileno()
459
460    def write_message(self, message, binary=False):
461        if self.closed:
462            raise websocket.WebSocketException('closed socket')
463        if binary:
464            flag = websocket.ABNF.OPCODE_BINARY
465        else:
466            flag = websocket.ABNF.OPCODE_TEXT
467        self._sock.send(message, flag)
468
469    def __call__(self, fd, events):
470        try:
471            data = self._sock.recv()
472            if data is not None:
473                self._switch.receive(self, EthernetFrame(data))
474                return
475        except:
476            traceback.print_exc()
477        self.close()
478
479
480class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
481    NAMESPACE = 'etherws.control'
482    INTERFACES = {
483        TapHandler.get_type():           TapHandler,
484        EtherWebSocketClient.get_type(): EtherWebSocketClient,
485    }
486
487    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
488        super(EtherWebSocketControlHandler, self).__init__(app, req)
489        self._ioloop = ioloop
490        self._switch = switch
491        self._htpasswd = htpasswd
492        self._debug = debug
493
494    def post(self):
495        id_ = None
496
497        try:
498            req = json.loads(self.request.body)
499            method = req['method']
500            params = req['params']
501            id_ = req.get('id')
502
503            if not method.startswith(self.NAMESPACE + '.'):
504                raise ValueError('invalid method: %s' % method)
505
506            if not isinstance(params, list):
507                raise ValueError('invalid params: %s' % params)
508
509            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
510            result = getattr(self, handler)(params)
511            self.finish({'result': result, 'error': None, 'id': id_})
512
513        except Exception as e:
514            traceback.print_exc()
515            self.finish({'result': None, 'error': str(e), 'id': id_})
516
517    def handle_listPort(self, params):
518        list_ = []
519        for port in self._switch.portlist:
520            list_.append(self._portstat(port))
521        return {'portlist': list_}
522
523    def handle_addPort(self, params):
524        list_ = []
525        for p in params:
526            type_ = p['type']
527            target = p['target']
528            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
529            klass = self.INTERFACES[type_]
530            interface = klass(self._ioloop, self._switch, target, **options)
531            portnum = interface.open()
532            list_.append(self._portstat(self._switch.get_port(portnum)))
533        return {'portlist': list_}
534
535    def handle_delPort(self, params):
536        list_ = []
537        for p in params:
538            port = self._switch.get_port(int(p['port']))
539            list_.append(self._portstat(port))
540            port.interface.close()
541        return {'portlist': list_}
542
543    def handle_shutPort(self, params):
544        list_ = []
545        for p in params:
546            port = self._switch.get_port(int(p['port']))
547            port.shut = bool(p['flag'])
548            list_.append(self._portstat(port))
549        return {'portlist': list_}
550
551    def _optparse_tap(self, opt):
552        return {'debug': self._debug}
553
554    def _optparse_client(self, opt):
555        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
556        if opt.get('insecure'):
557            args = {}
558        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
559        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
560        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
561
562    @staticmethod
563    def _portstat(port):
564        return {
565            'port':   port.number,
566            'type':   port.interface.get_type(),
567            'target': port.interface.get_target(),
568            'tx':     port.tx,
569            'rx':     port.rx,
570            'shut':   port.shut,
571        }
572
573
574def start_switch(args):
575    def daemonize(nochdir=False, noclose=False):
576        if os.fork() > 0:
577            sys.exit(0)
578
579        os.setsid()
580
581        if os.fork() > 0:
582            sys.exit(0)
583
584        if not nochdir:
585            os.chdir('/')
586
587        if not noclose:
588            os.umask(0)
589            sys.stdin.close()
590            sys.stdout.close()
591            sys.stderr.close()
592            os.close(0)
593            os.close(1)
594            os.close(2)
595            sys.stdin = open(os.devnull)
596            sys.stdout = open(os.devnull, 'a')
597            sys.stderr = open(os.devnull, 'a')
598
599    def checkabspath(ns, path):
600        val = getattr(ns, path, '')
601        if not val.startswith('/'):
602            raise ValueError('invalid %: %s' % (path, val))
603
604    def getsslopt(ns, key, cert):
605        kval = getattr(ns, key, None)
606        cval = getattr(ns, cert, None)
607        if kval and cval:
608            return {'keyfile': kval, 'certfile': cval}
609        elif kval or cval:
610            raise ValueError('both %s and %s are required' % (key, cert))
611        return None
612
613    def setrealpath(ns, *keys):
614        for k in keys:
615            v = getattr(ns, k, None)
616            if v is not None:
617                v = os.path.realpath(v)
618                open(v).close()  # check readable
619                setattr(ns, k, v)
620
621    def setport(ns, port, isssl):
622        val = getattr(ns, port, None)
623        if val is None:
624            if isssl:
625                return setattr(ns, port, 443)
626            return setattr(ns, port, 80)
627        if not (0 <= val <= 65535):
628            raise ValueError('invalid %s: %s' % (port, val))
629
630    def sethtpasswd(ns, htpasswd):
631        val = getattr(ns, htpasswd, None)
632        if val:
633            return setattr(ns, htpasswd, Htpasswd(val))
634
635    #if args.debug:
636    #    websocket.enableTrace(True)
637
638    if args.ageout <= 0:
639        raise ValueError('invalid ageout: %s' % args.ageout)
640
641    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
642    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
643
644    checkabspath(args, 'path')
645    checkabspath(args, 'ctlpath')
646
647    sslopt = getsslopt(args, 'sslkey', 'sslcert')
648    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
649
650    setport(args, 'port', sslopt)
651    setport(args, 'ctlport', ctlsslopt)
652
653    sethtpasswd(args, 'htpasswd')
654    sethtpasswd(args, 'ctlhtpasswd')
655
656    ioloop = IOLoop.instance()
657    fdb = FDB(ageout=args.ageout, debug=args.debug)
658    switch = SwitchingHub(fdb, debug=args.debug)
659
660    if args.port == args.ctlport and args.host == args.ctlhost:
661        if args.path == args.ctlpath:
662            raise ValueError('same path/ctlpath on same host')
663        if args.sslkey != args.ctlsslkey:
664            raise ValueError('differ sslkey/ctlsslkey on same host')
665        if args.sslcert != args.ctlsslcert:
666            raise ValueError('differ sslcert/ctlsslcert on same host')
667
668        app = Application([
669            (args.path, EtherWebSocketHandler, {
670                'switch':   switch,
671                'htpasswd': args.htpasswd,
672                'debug':    args.debug,
673            }),
674            (args.ctlpath, EtherWebSocketControlHandler, {
675                'ioloop':   ioloop,
676                'switch':   switch,
677                'htpasswd': args.ctlhtpasswd,
678                'debug':    args.debug,
679            }),
680        ])
681        server = HTTPServer(app, ssl_options=sslopt)
682        server.listen(args.port, address=args.host)
683
684    else:
685        app = Application([(args.path, EtherWebSocketHandler, {
686            'switch':   switch,
687            'htpasswd': args.htpasswd,
688            'debug':    args.debug,
689        })])
690        server = HTTPServer(app, ssl_options=sslopt)
691        server.listen(args.port, address=args.host)
692
693        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
694            'ioloop':   ioloop,
695            'switch':   switch,
696            'htpasswd': args.ctlhtpasswd,
697            'debug':    args.debug,
698        })])
699        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
700        ctlserver.listen(args.ctlport, address=args.ctlhost)
701
702    if not args.foreground:
703        daemonize()
704
705    ioloop.start()
706
707
708def main():
709    parser = argparse.ArgumentParser()
710    subparsers = parser.add_subparsers(dest='subcommand')
711    parser_s = subparsers.add_parser('switch')
712    parser_c = subparsers.add_parser('control')
713
714    parser_s.add_argument('--debug', action='store_true', default=False)
715    parser_s.add_argument('--foreground', action='store_true', default=False)
716    parser_s.add_argument('--ageout', action='store', type=int, default=300)
717
718    parser_s.add_argument('--path', action='store', default='/')
719    parser_s.add_argument('--host', action='store', default='')
720    parser_s.add_argument('--port', action='store', type=int)
721    parser_s.add_argument('--htpasswd', action='store')
722    parser_s.add_argument('--sslkey', action='store')
723    parser_s.add_argument('--sslcert', action='store')
724
725    parser_s.add_argument('--ctlpath', action='store', default='/ctl')
726    parser_s.add_argument('--ctlhost', action='store', default='')
727    parser_s.add_argument('--ctlport', action='store', type=int)
728    parser_s.add_argument('--ctlhtpasswd', action='store')
729    parser_s.add_argument('--ctlsslkey', action='store')
730    parser_s.add_argument('--ctlsslcert', action='store')
731
732    args = parser.parse_args()
733
734    globals()['start_' + args.subcommand](args)
735
736
737if __name__ == '__main__':
738    main()
Note: See TracBrowser for help on using the repository browser.