source: etherws/trunk/etherws.py @ 189

Revision 189, 23.0 KB checked in by atzm, 12 years ago (diff)
  • fix error messages
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
42import ssl
43import time
44import json
45import fcntl
46import base64
47import hashlib
48import getpass
49import argparse
50import traceback
51
52import tornado
53import websocket
54
55from tornado.web import Application, RequestHandler
56from tornado.websocket import WebSocketHandler
57from tornado.httpserver import HTTPServer
58from tornado.ioloop import IOLoop
59
60from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
61
62
63class DebugMixIn(object):
64    def dprintf(self, msg, func=lambda: ()):
65        if self._debug:
66            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
67            sys.stderr.write(prefix + (msg % func()))
68
69
70class EthernetFrame(object):
71    def __init__(self, data):
72        self.data = data
73
74    @property
75    def dst_multicast(self):
76        return ord(self.data[0]) & 1
77
78    @property
79    def src_multicast(self):
80        return ord(self.data[6]) & 1
81
82    @property
83    def dst_mac(self):
84        return self.data[:6]
85
86    @property
87    def src_mac(self):
88        return self.data[6:12]
89
90    @property
91    def tagged(self):
92        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
93
94    @property
95    def vid(self):
96        if self.tagged:
97            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
98        return 0
99
100
101class FDB(DebugMixIn):
102    def __init__(self, ageout, debug=False):
103        self._ageout = ageout
104        self._debug = debug
105        self._dict = {}
106
107    def lookup(self, frame):
108        mac = frame.dst_mac
109        vid = frame.vid
110
111        group = self._dict.get(vid)
112        if not group:
113            return None
114
115        entry = group.get(mac)
116        if not entry:
117            return None
118
119        if time.time() - entry['time'] > self._ageout:
120            port = self._dict[vid][mac]['port']
121            del self._dict[vid][mac]
122            if not self._dict[vid]:
123                del self._dict[vid]
124            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
125                         lambda: (port.number, vid, mac.encode('hex')))
126            return None
127
128        return entry['port']
129
130    def learn(self, port, frame):
131        mac = frame.src_mac
132        vid = frame.vid
133
134        if vid not in self._dict:
135            self._dict[vid] = {}
136
137        self._dict[vid][mac] = {'time': time.time(), 'port': port}
138        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
139                     lambda: (port.number, vid, mac.encode('hex')))
140
141    def delete(self, port):
142        for vid in self._dict.keys():
143            for mac in self._dict[vid].keys():
144                if self._dict[vid][mac]['port'].number == port.number:
145                    del self._dict[vid][mac]
146                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
147                                 lambda: (port.number, vid, mac.encode('hex')))
148            if not self._dict[vid]:
149                del self._dict[vid]
150
151
152class SwitchPort(object):
153    def __init__(self, number, interface):
154        self.number = number
155        self.interface = interface
156        self.tx = 0
157        self.rx = 0
158        self.shut = False
159
160    @staticmethod
161    def cmp_by_number(x, y):
162        return cmp(x.number, y.number)
163
164
165class SwitchingHub(DebugMixIn):
166    def __init__(self, fdb, debug=False):
167        self._fdb = fdb
168        self._debug = debug
169        self._table = {}
170        self._next = 1
171
172    @property
173    def portlist(self):
174        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
175
176    def get_port(self, portnum):
177        return self._table[portnum]
178
179    def register_port(self, interface):
180        try:
181            self._set_privattr('portnum', interface, self._next)  # XXX
182            self._table[self._next] = SwitchPort(self._next, interface)
183            return self._next
184        finally:
185            self._next += 1
186
187    def unregister_port(self, interface):
188        portnum = self._get_privattr('portnum', interface)
189        self._del_privattr('portnum', interface)
190        self._fdb.delete(self._table[portnum])
191        del self._table[portnum]
192
193    def send(self, dst_interfaces, frame):
194        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
195        ports = (self._table[n] for n in portnums)
196        ports = (p for p in ports if not p.shut)
197        ports = sorted(ports, cmp=SwitchPort.cmp_by_number)
198
199        for p in ports:
200            p.interface.write_message(frame.data, True)
201            p.tx += 1
202
203        if ports:
204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
205                         lambda: (','.join(str(p.number) for p in ports),
206                                  frame.vid,
207                                  frame.src_mac.encode('hex'),
208                                  frame.dst_mac.encode('hex')))
209
210    def receive(self, src_interface, frame):
211        port = self._table[self._get_privattr('portnum', src_interface)]
212
213        if not port.shut:
214            port.rx += 1
215            self._forward(port, frame)
216
217    def _forward(self, src_port, frame):
218        try:
219            if not frame.src_multicast:
220                self._fdb.learn(src_port, frame)
221
222            if not frame.dst_multicast:
223                dst_port = self._fdb.lookup(frame)
224
225                if dst_port:
226                    self.send([dst_port.interface], frame)
227                    return
228
229            ports = set(self.portlist) - set([src_port])
230            self.send((p.interface for p in ports), frame)
231
232        except:  # ex. received invalid frame
233            traceback.print_exc()
234
235    def _privattr(self, name):
236        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
237
238    def _set_privattr(self, name, obj, value):
239        return setattr(obj, self._privattr(name), value)
240
241    def _get_privattr(self, name, obj, defaults=None):
242        return getattr(obj, self._privattr(name), defaults)
243
244    def _del_privattr(self, name, obj):
245        return delattr(obj, self._privattr(name))
246
247
248class Htpasswd(object):
249    def __init__(self, path):
250        self._path = path
251        self._stat = None
252        self._data = {}
253
254    def auth(self, name, passwd):
255        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
256        return self._data.get(name) == passwd
257
258    def load(self):
259        old_stat = self._stat
260
261        with open(self._path) as fp:
262            fileno = fp.fileno()
263            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
264            self._stat = os.fstat(fileno)
265
266            unchanged = old_stat and \
267                        old_stat.st_ino == self._stat.st_ino and \
268                        old_stat.st_dev == self._stat.st_dev and \
269                        old_stat.st_mtime == self._stat.st_mtime
270
271            if not unchanged:
272                self._data = self._parse(fp)
273
274        return self
275
276    def _parse(self, fp):
277        data = {}
278        for line in fp:
279            line = line.strip()
280            if 0 <= line.find(':'):
281                name, passwd = line.split(':', 1)
282                if passwd.startswith('{SHA}'):
283                    data[name] = passwd[5:]
284        return data
285
286
287class BasicAuthMixIn(object):
288    def _execute(self, transforms, *args, **kwargs):
289        def do_execute():
290            sp = super(BasicAuthMixIn, self)
291            return sp._execute(transforms, *args, **kwargs)
292
293        def auth_required():
294            stream = getattr(self, 'stream', self.request.connection.stream)
295            stream.write(tornado.escape.utf8(
296                'HTTP/1.1 401 Authorization Required\r\n'
297                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
298            ))
299            stream.close()
300
301        try:
302            if not self._htpasswd:
303                return do_execute()
304
305            creds = self.request.headers.get('Authorization')
306
307            if not creds or not creds.startswith('Basic '):
308                return auth_required()
309
310            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
311
312            if self._htpasswd.load().auth(name, passwd):
313                return do_execute()
314        except:
315            traceback.print_exc()
316
317        return auth_required()
318
319
320class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
321    def __init__(self, app, req, switch, htpasswd=None, debug=False):
322        super(EtherWebSocketHandler, self).__init__(app, req)
323        self._switch = switch
324        self._htpasswd = htpasswd
325        self._debug = debug
326
327    @classmethod
328    def get_type(cls):
329        return 'server'
330
331    def get_target(self):
332        return self.request.remote_ip
333
334    def open(self):
335        try:
336            return self._switch.register_port(self)
337        finally:
338            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
339
340    def on_message(self, message):
341        self._switch.receive(self, EthernetFrame(message))
342
343    def on_close(self):
344        self._switch.unregister_port(self)
345        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
346
347
348class TapHandler(DebugMixIn):
349    READ_SIZE = 65535
350
351    def __init__(self, ioloop, switch, dev, debug=False):
352        self._ioloop = ioloop
353        self._switch = switch
354        self._dev = dev
355        self._debug = debug
356        self._tap = None
357
358    @classmethod
359    def get_type(cls):
360        return 'tap'
361
362    def get_target(self):
363        if self.closed:
364            return self._dev
365        return self._tap.name
366
367    @property
368    def closed(self):
369        return not self._tap
370
371    def open(self):
372        if not self.closed:
373            raise ValueError('already opened')
374        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
375        self._tap.up()
376        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
377        return self._switch.register_port(self)
378
379    def close(self):
380        if self.closed:
381            raise ValueError('I/O operation on closed tap')
382        self._switch.unregister_port(self)
383        self._ioloop.remove_handler(self.fileno())
384        self._tap.close()
385        self._tap = None
386
387    def fileno(self):
388        if self.closed:
389            raise ValueError('I/O operation on closed tap')
390        return self._tap.fileno()
391
392    def write_message(self, message, binary=False):
393        if self.closed:
394            raise ValueError('I/O operation on closed tap')
395        self._tap.write(message)
396
397    def __call__(self, fd, events):
398        try:
399            self._switch.receive(self, EthernetFrame(self._read()))
400            return
401        except:
402            traceback.print_exc()
403        self.close()
404
405    def _read(self):
406        if self.closed:
407            raise ValueError('I/O operation on closed tap')
408        buf = []
409        while True:
410            buf.append(self._tap.read(self.READ_SIZE))
411            if len(buf[-1]) < self.READ_SIZE:
412                break
413        return ''.join(buf)
414
415
416class EtherWebSocketClient(DebugMixIn):
417    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
418        self._ioloop = ioloop
419        self._switch = switch
420        self._url = url
421        self._ssl = ssl_
422        self._debug = debug
423        self._sock = None
424        self._options = {}
425
426        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
427            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
428            auth = ['Authorization: Basic %s' % token]
429            self._options['header'] = auth
430
431    @classmethod
432    def get_type(cls):
433        return 'client'
434
435    def get_target(self):
436        return self._url
437
438    @property
439    def closed(self):
440        return not self._sock
441
442    def open(self):
443        sslwrap = websocket._SSLSocketWrapper
444
445        if not self.closed:
446            raise websocket.WebSocketException('already opened')
447
448        if self._ssl:
449            websocket._SSLSocketWrapper = self._ssl
450
451        try:
452            self._sock = websocket.WebSocket()
453            self._sock.connect(self._url, **self._options)
454            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
455            return self._switch.register_port(self)
456        finally:
457            websocket._SSLSocketWrapper = sslwrap
458            self.dprintf('connected: %s\n', lambda: self._url)
459
460    def close(self):
461        if self.closed:
462            raise websocket.WebSocketException('already closed')
463        self._switch.unregister_port(self)
464        self._ioloop.remove_handler(self.fileno())
465        self._sock.close()
466        self._sock = None
467        self.dprintf('disconnected: %s\n', lambda: self._url)
468
469    def fileno(self):
470        if self.closed:
471            raise websocket.WebSocketException('closed socket')
472        return self._sock.io_sock.fileno()
473
474    def write_message(self, message, binary=False):
475        if self.closed:
476            raise websocket.WebSocketException('closed socket')
477        if binary:
478            flag = websocket.ABNF.OPCODE_BINARY
479        else:
480            flag = websocket.ABNF.OPCODE_TEXT
481        self._sock.send(message, flag)
482
483    def __call__(self, fd, events):
484        try:
485            data = self._sock.recv()
486            if data is not None:
487                self._switch.receive(self, EthernetFrame(data))
488                return
489        except:
490            traceback.print_exc()
491        self.close()
492
493
494class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
495    NAMESPACE = 'etherws.control'
496    INTERFACES = {
497        TapHandler.get_type():           TapHandler,
498        EtherWebSocketClient.get_type(): EtherWebSocketClient,
499    }
500
501    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
502        super(EtherWebSocketControlHandler, self).__init__(app, req)
503        self._ioloop = ioloop
504        self._switch = switch
505        self._htpasswd = htpasswd
506        self._debug = debug
507
508    def post(self):
509        id_ = None
510
511        try:
512            req = json.loads(self.request.body)
513            method = req['method']
514            params = req['params']
515            id_ = req.get('id')
516
517            if not method.startswith(self.NAMESPACE + '.'):
518                raise ValueError('invalid method: %s' % method)
519
520            if not isinstance(params, list):
521                raise ValueError('invalid params: %s' % params)
522
523            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
524            result = getattr(self, handler)(params)
525            self.finish({'result': result, 'error': None, 'id': id_})
526
527        except Exception as e:
528            traceback.print_exc()
529            self.finish({'result': None, 'error': str(e), 'id': id_})
530
531    def handle_listPort(self, params):
532        list_ = [self._portstat(p) for p in self._switch.portlist]
533        return {'portlist': list_}
534
535    def handle_addPort(self, params):
536        list_ = []
537        for p in params:
538            type_ = p['type']
539            target = p['target']
540            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
541            klass = self.INTERFACES[type_]
542            interface = klass(self._ioloop, self._switch, target, **options)
543            portnum = interface.open()
544            list_.append(self._portstat(self._switch.get_port(portnum)))
545        return {'portlist': list_}
546
547    def handle_delPort(self, params):
548        list_ = []
549        for p in params:
550            port = self._switch.get_port(int(p['port']))
551            list_.append(self._portstat(port))
552            port.interface.close()
553        return {'portlist': list_}
554
555    def handle_shutPort(self, params):
556        list_ = []
557        for p in params:
558            port = self._switch.get_port(int(p['port']))
559            port.shut = bool(p['flag'])
560            list_.append(self._portstat(port))
561        return {'portlist': list_}
562
563    def _optparse_tap(self, opt):
564        return {'debug': self._debug}
565
566    def _optparse_client(self, opt):
567        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
568        if opt.get('insecure'):
569            args = {}
570        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
571        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
572        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
573
574    @staticmethod
575    def _portstat(port):
576        return {
577            'port':   port.number,
578            'type':   port.interface.get_type(),
579            'target': port.interface.get_target(),
580            'tx':     port.tx,
581            'rx':     port.rx,
582            'shut':   port.shut,
583        }
584
585
586def start_switch(args):
587    def daemonize(nochdir=False, noclose=False):
588        if os.fork() > 0:
589            sys.exit(0)
590
591        os.setsid()
592
593        if os.fork() > 0:
594            sys.exit(0)
595
596        if not nochdir:
597            os.chdir('/')
598
599        if not noclose:
600            os.umask(0)
601            sys.stdin.close()
602            sys.stdout.close()
603            sys.stderr.close()
604            os.close(0)
605            os.close(1)
606            os.close(2)
607            sys.stdin = open(os.devnull)
608            sys.stdout = open(os.devnull, 'a')
609            sys.stderr = open(os.devnull, 'a')
610
611    def checkabspath(ns, path):
612        val = getattr(ns, path, '')
613        if not val.startswith('/'):
614            raise ValueError('invalid %: %s' % (path, val))
615
616    def getsslopt(ns, key, cert):
617        kval = getattr(ns, key, None)
618        cval = getattr(ns, cert, None)
619        if kval and cval:
620            return {'keyfile': kval, 'certfile': cval}
621        elif kval or cval:
622            raise ValueError('both %s and %s are required' % (key, cert))
623        return None
624
625    def setrealpath(ns, *keys):
626        for k in keys:
627            v = getattr(ns, k, None)
628            if v is not None:
629                v = os.path.realpath(v)
630                open(v).close()  # check readable
631                setattr(ns, k, v)
632
633    def setport(ns, port, isssl):
634        val = getattr(ns, port, None)
635        if val is None:
636            if isssl:
637                return setattr(ns, port, 443)
638            return setattr(ns, port, 80)
639        if not (0 <= val <= 65535):
640            raise ValueError('invalid %s: %s' % (port, val))
641
642    def sethtpasswd(ns, htpasswd):
643        val = getattr(ns, htpasswd, None)
644        if val:
645            return setattr(ns, htpasswd, Htpasswd(val))
646
647    #if args.debug:
648    #    websocket.enableTrace(True)
649
650    if args.ageout <= 0:
651        raise ValueError('invalid ageout: %s' % args.ageout)
652
653    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
654    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
655
656    checkabspath(args, 'path')
657    checkabspath(args, 'ctlpath')
658
659    sslopt = getsslopt(args, 'sslkey', 'sslcert')
660    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
661
662    setport(args, 'port', sslopt)
663    setport(args, 'ctlport', ctlsslopt)
664
665    sethtpasswd(args, 'htpasswd')
666    sethtpasswd(args, 'ctlhtpasswd')
667
668    ioloop = IOLoop.instance()
669    fdb = FDB(ageout=args.ageout, debug=args.debug)
670    switch = SwitchingHub(fdb, debug=args.debug)
671
672    if args.port == args.ctlport and args.host == args.ctlhost:
673        if args.path == args.ctlpath:
674            raise ValueError('same path/ctlpath on same host')
675        if args.sslkey != args.ctlsslkey:
676            raise ValueError('different sslkey/ctlsslkey on same host')
677        if args.sslcert != args.ctlsslcert:
678            raise ValueError('different sslcert/ctlsslcert on same host')
679
680        app = Application([
681            (args.path, EtherWebSocketHandler, {
682                'switch':   switch,
683                'htpasswd': args.htpasswd,
684                'debug':    args.debug,
685            }),
686            (args.ctlpath, EtherWebSocketControlHandler, {
687                'ioloop':   ioloop,
688                'switch':   switch,
689                'htpasswd': args.ctlhtpasswd,
690                'debug':    args.debug,
691            }),
692        ])
693        server = HTTPServer(app, ssl_options=sslopt)
694        server.listen(args.port, address=args.host)
695
696    else:
697        app = Application([(args.path, EtherWebSocketHandler, {
698            'switch':   switch,
699            'htpasswd': args.htpasswd,
700            'debug':    args.debug,
701        })])
702        server = HTTPServer(app, ssl_options=sslopt)
703        server.listen(args.port, address=args.host)
704
705        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
706            'ioloop':   ioloop,
707            'switch':   switch,
708            'htpasswd': args.ctlhtpasswd,
709            'debug':    args.debug,
710        })])
711        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
712        ctlserver.listen(args.ctlport, address=args.ctlhost)
713
714    if not args.foreground:
715        daemonize()
716
717    ioloop.start()
718
719
720def main():
721    parser = argparse.ArgumentParser()
722    subparsers = parser.add_subparsers(dest='subcommand')
723    parser_s = subparsers.add_parser('switch')
724    parser_c = subparsers.add_parser('control')
725
726    parser_s.add_argument('--debug', action='store_true', default=False)
727    parser_s.add_argument('--foreground', action='store_true', default=False)
728    parser_s.add_argument('--ageout', action='store', type=int, default=300)
729
730    parser_s.add_argument('--path', action='store', default='/')
731    parser_s.add_argument('--host', action='store', default='')
732    parser_s.add_argument('--port', action='store', type=int)
733    parser_s.add_argument('--htpasswd', action='store')
734    parser_s.add_argument('--sslkey', action='store')
735    parser_s.add_argument('--sslcert', action='store')
736
737    parser_s.add_argument('--ctlpath', action='store', default='/ctl')
738    parser_s.add_argument('--ctlhost', action='store', default='')
739    parser_s.add_argument('--ctlport', action='store', type=int)
740    parser_s.add_argument('--ctlhtpasswd', action='store')
741    parser_s.add_argument('--ctlsslkey', action='store')
742    parser_s.add_argument('--ctlsslcert', action='store')
743
744    args = parser.parse_args()
745
746    globals()['start_' + args.subcommand](args)
747
748
749if __name__ == '__main__':
750    main()
Note: See TracBrowser for help on using the repository browser.