Changeset 183


Ignore:
Timestamp:
07/29/12 04:11:13 (12 years ago)
Author:
atzm
Message:
  • global change: enables remote control
Location:
etherws/trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • etherws/trunk/etherws.py

    r182 r183  
    88#   - python-pytun-0.2 
    99#   - websocket-client-0.7.0 
    10 #   - tornado-2.2.1 
     10#   - tornado-2.3 
    1111# 
    1212# todo: 
     
    4545import ssl 
    4646import time 
     47import json 
    4748import fcntl 
    4849import base64 
     
    5354 
    5455import websocket 
    55 import tornado.web 
    56 import tornado.ioloop 
    57 import tornado.httpserver 
    58  
     56 
     57from tornado.web import Application, RequestHandler 
    5958from tornado.websocket import WebSocketHandler 
     59from tornado.httpserver import HTTPServer 
     60from tornado.ioloop import IOLoop 
     61 
    6062from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI 
    6163 
     
    9698        if self.tagged: 
    9799            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff 
    98         return -1 
     100        return 0 
    99101 
    100102 
     
    118120 
    119121        if time.time() - entry['time'] > self._ageout: 
     122            port = self._dict[vid][mac]['port'] 
    120123            del self._dict[vid][mac] 
    121124            if not self._dict[vid]: 
    122125                del self._dict[vid] 
    123             self.dprintf('aged out: [%d] %s\n', 
    124                          lambda: (vid, mac.encode('hex'))) 
     126            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n', 
     127                         lambda: (port.number, vid, mac.encode('hex'))) 
    125128            return None 
    126129 
     
    135138 
    136139        self._dict[vid][mac] = {'time': time.time(), 'port': port} 
    137         self.dprintf('learned: [%d] %s\n', 
    138                      lambda: (vid, mac.encode('hex'))) 
     140        self.dprintf('learned: port:%d; vid:%d; mac:%s\n', 
     141                     lambda: (port.number, vid, mac.encode('hex'))) 
    139142 
    140143    def delete(self, port): 
    141144        for vid in self._dict.keys(): 
    142145            for mac in self._dict[vid].keys(): 
    143                 if self._dict[vid][mac]['port'] is port: 
     146                if self._dict[vid][mac]['port'].number == port.number: 
    144147                    del self._dict[vid][mac] 
    145                     self.dprintf('deleted: [%d] %s\n', 
    146                                  lambda: (vid, mac.encode('hex'))) 
     148                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n', 
     149                                 lambda: (port.number, vid, mac.encode('hex'))) 
    147150            if not self._dict[vid]: 
    148151                del self._dict[vid] 
     152 
     153 
     154class SwitchPort(object): 
     155    def __init__(self, number, interface): 
     156        self.number = number 
     157        self.interface = interface 
     158        self.tx = 0 
     159        self.rx = 0 
     160        self.shut = False 
     161 
     162    @staticmethod 
     163    def cmp_by_number(x, y): 
     164        return cmp(x.number, y.number) 
    149165 
    150166 
     
    153169        self._fdb = fdb 
    154170        self._debug = debug 
    155         self._ports = [] 
    156  
    157     def register_port(self, port): 
    158         self._ports.append(port) 
    159  
    160     def unregister_port(self, port): 
    161         self._fdb.delete(port) 
    162         self._ports.remove(port) 
    163  
    164     def forward(self, src_port, frame): 
     171        self._table = {} 
     172        self._next = 1 
     173 
     174    @property 
     175    def portlist(self): 
     176        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number) 
     177 
     178    def shut_port(self, portnum, flag=True): 
     179        self._table[portnum].shut = flag 
     180 
     181    def get_port(self, portnum): 
     182        return self._table[portnum] 
     183 
     184    def register_port(self, interface): 
     185        interface._switch_portnum = self._next  # XXX 
     186        self._table[self._next] = SwitchPort(self._next, interface) 
     187        self._next += 1 
     188 
     189    def unregister_port(self, interface): 
     190        self._fdb.delete(self._table[interface._switch_portnum]) 
     191        del self._table[interface._switch_portnum] 
     192        del interface._switch_portnum 
     193 
     194    def send(self, dst_interfaces, frame): 
     195        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces 
     196                        if not self._table[i._switch_portnum].shut), 
     197                       cmp=SwitchPort.cmp_by_number) 
     198 
     199        for p in ports: 
     200            p.interface.write_message(frame.data, True) 
     201            p.tx += 1 
     202 
     203        if ports: 
     204            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n', 
     205                         lambda: (','.join(str(p.number) for p in ports), 
     206                                  frame.vid, 
     207                                  frame.src_mac.encode('hex'), 
     208                                  frame.dst_mac.encode('hex'))) 
     209 
     210    def receive(self, src_interface, frame): 
     211        port = self._table[src_interface._switch_portnum] 
     212 
     213        if not port.shut: 
     214            port.rx += 1 
     215            self._forward(port, frame) 
     216 
     217    def _forward(self, src_port, frame): 
    165218        try: 
    166219            if not frame.src_multicast: 
     
    171224 
    172225                if dst_port: 
    173                     self._unicast(frame, dst_port) 
     226                    self.send([dst_port.interface], frame) 
    174227                    return 
    175228 
    176             self._broadcast(frame, src_port) 
     229            ports = set(self._table.itervalues()) - set([src_port]) 
     230            self.send((p.interface for p in ports), frame) 
    177231 
    178232        except:  # ex. received invalid frame 
    179233            traceback.print_exc() 
    180  
    181     def _unicast(self, frame, port): 
    182         port.write_message(frame.data, True) 
    183         self.dprintf('sent unicast: [%d] %s -> %s\n', 
    184                      lambda: (frame.vid, 
    185                               frame.src_mac.encode('hex'), 
    186                               frame.dst_mac.encode('hex'))) 
    187  
    188     def _broadcast(self, frame, *except_ports): 
    189         ports = self._ports[:] 
    190         for port in except_ports: 
    191             ports.remove(port) 
    192         for port in ports: 
    193             port.write_message(frame.data, True) 
    194         self.dprintf('sent broadcast: [%d] %s -> %s\n', 
    195                      lambda: (frame.vid, 
    196                               frame.src_mac.encode('hex'), 
    197                               frame.dst_mac.encode('hex'))) 
    198234 
    199235 
     
    283319        return not self._tap 
    284320 
     321    def get_type(self): 
     322        return 'tap' 
     323 
     324    def get_name(self): 
     325        if self.closed: 
     326            return self._dev 
     327        return self._tap.name 
     328 
    285329    def open(self): 
    286330        if not self.closed: 
     
    311355    def __call__(self, fd, events): 
    312356        try: 
    313             self._switch.forward(self, EthernetFrame(self._read())) 
     357            self._switch.receive(self, EthernetFrame(self._read())) 
    314358            return 
    315359        except: 
     
    338382            self._htpasswd = Htpasswd(self._htpasswd) 
    339383 
     384    def get_type(self): 
     385        return 'server' 
     386 
     387    def get_name(self): 
     388        return self.request.remote_ip 
     389 
    340390    def open(self): 
    341391        self._switch.register_port(self) 
     
    343393 
    344394    def on_message(self, message): 
    345         self._switch.forward(self, EthernetFrame(message)) 
     395        self._switch.receive(self, EthernetFrame(message)) 
    346396 
    347397    def on_close(self): 
     
    369419        return not self._sock 
    370420 
     421    def get_type(self): 
     422        return 'client' 
     423 
     424    def get_name(self): 
     425        return self._url 
     426 
    371427    def open(self): 
    372428        sslwrap = websocket._SSLSocketWrapper 
     
    414470            data = self._sock.recv() 
    415471            if data is not None: 
    416                 self._switch.forward(self, EthernetFrame(data)) 
     472                self._switch.receive(self, EthernetFrame(data)) 
    417473                return 
    418474        except: 
    419475            traceback.print_exc() 
    420476        self.close() 
     477 
     478 
     479class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler): 
     480    NAMESPACE = 'etherws.control' 
     481 
     482    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False): 
     483        super(EtherWebSocketControlHandler, self).__init__(app, req) 
     484        self._ioloop = ioloop 
     485        self._switch = switch 
     486        self._htpasswd = htpasswd 
     487        self._debug = debug 
     488 
     489    def post(self): 
     490        id_ = None 
     491 
     492        try: 
     493            req = json.loads(self.request.body) 
     494            method = req['method'] 
     495            params = req['params'] 
     496            id_ = req.get('id') 
     497 
     498            if not method.startswith(self.NAMESPACE + '.'): 
     499                raise ValueError('invalid method: %s' % method) 
     500 
     501            if not isinstance(params, list): 
     502                raise ValueError('invalid params: %s' % params) 
     503 
     504            handler = 'handle_' + method[len(self.NAMESPACE) + 1:] 
     505            result = getattr(self, handler)(params) 
     506            self.finish({'result': result, 'error': None, 'id': id_}) 
     507 
     508        except Exception as e: 
     509            traceback.print_exc() 
     510            self.finish({'result': None, 'error': str(e), 'id': id_}) 
     511 
     512    def handle_listPort(self, params): 
     513        list_ = [] 
     514        for port in self._switch.portlist: 
     515            list_.append({ 
     516                'port': port.number, 
     517                'type': port.interface.get_type(), 
     518                'name': port.interface.get_name(), 
     519                'tx':   port.tx, 
     520                'rx':   port.rx, 
     521                'shut': port.shut, 
     522            }) 
     523        return {'portlist': list_} 
     524 
     525    def handle_addPort(self, params): 
     526        for p in params: 
     527            getattr(self, '_openport_' + p['type'])(p) 
     528        return self.handle_listPort(params) 
     529 
     530    def handle_delPort(self, params): 
     531        for p in params: 
     532            self._switch.get_port(int(p['port'])).interface.close() 
     533        return self.handle_listPort(params) 
     534 
     535    def handle_shutPort(self, params): 
     536        for p in params: 
     537            self._switch.shut_port(int(p['port']), bool(p['flag'])) 
     538        return self.handle_listPort(params) 
     539 
     540    def _openport_tap(self, p): 
     541        dev = p['device'] 
     542        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug) 
     543        tap.open() 
     544 
     545    def _openport_client(self, p): 
     546        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts')) 
     547        cred = {'user': p.get('user'), 'passwd': p.get('passwd')} 
     548        url = p['url'] 
     549        client = EtherWebSocketClient(self._ioloop, self._switch, 
     550                                      url, ssl_, cred, self._debug) 
     551        client.open() 
     552 
     553    @staticmethod 
     554    def _ssl_wrapper(insecure, ca_certs): 
     555        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs} 
     556        if insecure: 
     557            args = {} 
     558        return lambda sock: ssl.wrap_socket(sock, **args) 
    421559 
    422560 
     
    446584 
    447585 
    448 def realpath(ns, *keys): 
    449     for k in keys: 
    450         v = getattr(ns, k, None) 
    451         if v is not None: 
    452             v = os.path.realpath(v) 
    453             setattr(ns, k, v) 
    454             open(v).close()  # check readable 
    455     return ns 
    456  
    457  
    458 def ssl_wrapper(insecure, ca_certs): 
    459     args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs} 
    460     if insecure: 
    461         args = {} 
    462     return lambda sock: ssl.wrap_socket(sock, **args) 
    463  
    464  
    465 def server_main(args): 
    466     realpath(args, 'keyfile', 'certfile', 'htpasswd') 
    467  
    468     if args.keyfile and args.certfile: 
    469         ssl_options = {'keyfile': args.keyfile, 'certfile': args.certfile} 
    470     elif args.keyfile or args.certfile: 
    471         raise ValueError('both keyfile and certfile are required') 
     586def main(): 
     587    def realpath(ns, *keys): 
     588        for k in keys: 
     589            v = getattr(ns, k, None) 
     590            if v is not None: 
     591                v = os.path.realpath(v) 
     592                open(v).close()  # check readable 
     593                setattr(ns, k, v) 
     594        return ns 
     595 
     596    parser = argparse.ArgumentParser() 
     597 
     598    parser.add_argument('--debug', action='store_true', default=False) 
     599    parser.add_argument('--foreground', action='store_true', default=False) 
     600    parser.add_argument('--ageout', action='store', type=int, default=300) 
     601 
     602    parser.add_argument('--path', action='store', default='/') 
     603    parser.add_argument('--address', action='store', default='') 
     604    parser.add_argument('--port', action='store', type=int) 
     605    parser.add_argument('--htpasswd', action='store') 
     606    parser.add_argument('--sslkey', action='store') 
     607    parser.add_argument('--sslcert', action='store') 
     608 
     609    parser.add_argument('--ctlpath', action='store', default='/ctl') 
     610    parser.add_argument('--ctladdress', action='store', default='127.0.0.1') 
     611    parser.add_argument('--ctlport', action='store', type=int, default=7867) 
     612 
     613    args = realpath(parser.parse_args(), 'htpasswd', 'sslkey', 'sslcert') 
     614 
     615    #if args.debug: 
     616    #    websocket.enableTrace(True) 
     617 
     618    if args.ageout <= 0: 
     619        raise ValueError('invalid ageout: %s' % args.ageout) 
     620 
     621    if not args.path.startswith('/'): 
     622        raise ValueError('invalid path: %s' % args.path) 
     623 
     624    if not args.ctlpath.startswith('/'): 
     625        raise ValueError('invalid ctlpath: %s' % args.ctlpath) 
     626 
     627    if args.sslkey and args.sslcert: 
     628        sslopt = {'keyfile': args.sslkey, 'certfile': args.sslcert} 
     629    elif args.sslkey or args.sslcert: 
     630        raise ValueError('both sslkey and sslcert are required') 
    472631    else: 
    473         ssl_options = None 
     632        sslopt = None 
    474633 
    475634    if args.port is None: 
    476         if ssl_options: 
     635        if sslopt: 
    477636            args.port = 443 
    478637        else: 
     
    481640        raise ValueError('invalid port: %s' % args.port) 
    482641 
    483     if args.ageout <= 0: 
    484         raise ValueError('invalid ageout: %s' % args.ageout) 
    485  
    486     ioloop = tornado.ioloop.IOLoop.instance() 
     642    if not (0 <= args.ctlport <= 65535): 
     643        raise ValueError('invalid ctlport: %s' % args.ctlport) 
     644 
     645    if args.htpasswd: 
     646        args.htpasswd = Htpasswd(args.htpasswd) 
     647 
     648    ioloop = IOLoop.instance() 
    487649    fdb = FDB(ageout=args.ageout, debug=args.debug) 
    488     sw = SwitchingHub(fdb, debug=args.debug) 
    489  
    490     harg = {'switch': sw, 'htpasswd': args.htpasswd, 'debug': args.debug} 
    491     serv = (args.path, EtherWebSocketHandler, harg) 
    492     app = tornado.web.Application([serv]) 
    493     server = tornado.httpserver.HTTPServer(app, ssl_options=ssl_options) 
     650    switch = SwitchingHub(fdb, debug=args.debug) 
     651 
     652    app = Application([(args.path, EtherWebSocketHandler, { 
     653        'switch':   switch, 
     654        'htpasswd': args.htpasswd, 
     655        'debug':    args.debug, 
     656    })]) 
     657    server = HTTPServer(app, ssl_options=sslopt) 
    494658    server.listen(args.port, address=args.address) 
    495659 
    496     for dev in args.device: 
    497         tap = TapHandler(ioloop, sw, dev, debug=args.debug) 
    498         tap.open() 
     660    ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, { 
     661        'ioloop':   ioloop, 
     662        'switch':   switch, 
     663        'htpasswd': None, 
     664        'debug':    args.debug, 
     665    })]) 
     666    ctlserver = HTTPServer(ctl) 
     667    ctlserver.listen(args.ctlport, address=args.ctladdress) 
    499668 
    500669    if not args.foreground: 
     
    504673 
    505674 
    506 def client_main(args): 
    507     realpath(args, 'cacerts') 
    508  
    509     if args.debug: 
    510         websocket.enableTrace(True) 
    511  
    512     if args.ageout <= 0: 
    513         raise ValueError('invalid ageout: %s' % args.ageout) 
    514  
    515     if args.user and args.passwd is None: 
    516         args.passwd = getpass.getpass() 
    517  
    518     ssl_ = ssl_wrapper(args.insecure, args.cacerts) 
    519     cred = {'user': args.user, 'passwd': args.passwd} 
    520     ioloop = tornado.ioloop.IOLoop.instance() 
    521     fdb = FDB(ageout=args.ageout, debug=args.debug) 
    522     sw = SwitchingHub(fdb, debug=args.debug) 
    523  
    524     for uri in args.uri: 
    525         client = EtherWebSocketClient(ioloop, sw, uri, ssl_, cred, args.debug) 
    526         client.open() 
    527  
    528     for dev in args.device: 
    529         tap = TapHandler(ioloop, sw, dev, debug=args.debug) 
    530         tap.open() 
    531  
    532     if not args.foreground: 
    533         daemonize() 
    534  
    535     ioloop.start() 
    536  
    537  
    538 def main(): 
    539     parser = argparse.ArgumentParser() 
    540     parser.add_argument('--device', action='append', default=[]) 
    541     parser.add_argument('--ageout', action='store', type=int, default=300) 
    542     parser.add_argument('--foreground', action='store_true', default=False) 
    543     parser.add_argument('--debug', action='store_true', default=False) 
    544  
    545     subparsers = parser.add_subparsers(dest='subcommand') 
    546  
    547     parser_s = subparsers.add_parser('server') 
    548     parser_s.add_argument('--address', action='store', default='') 
    549     parser_s.add_argument('--port', action='store', type=int) 
    550     parser_s.add_argument('--path', action='store', default='/') 
    551     parser_s.add_argument('--htpasswd', action='store') 
    552     parser_s.add_argument('--keyfile', action='store') 
    553     parser_s.add_argument('--certfile', action='store') 
    554  
    555     parser_c = subparsers.add_parser('client') 
    556     parser_c.add_argument('--uri', action='append', default=[]) 
    557     parser_c.add_argument('--insecure', action='store_true', default=False) 
    558     parser_c.add_argument('--cacerts', action='store') 
    559     parser_c.add_argument('--user', action='store') 
    560     parser_c.add_argument('--passwd', action='store') 
    561  
    562     args = parser.parse_args() 
    563  
    564     if args.subcommand == 'server': 
    565         server_main(args) 
    566     elif args.subcommand == 'client': 
    567         client_main(args) 
    568  
    569  
    570675if __name__ == '__main__': 
    571676    main() 
  • etherws/trunk/setup.py

    r170 r183  
    5050        'python-pytun>=0.2', 
    5151        'websocket-client>=0.7.0', 
    52         'tornado>=2.2.1', 
     52        'tornado>=2.3', 
    5353    ], 
    5454    classifiers=[ 
Note: See TracChangeset for help on using the changeset viewer.