source: etherws/trunk/etherws.py @ 198

Revision 198, 26.8 KB checked in by atzm, 12 years ago (diff)
  • add ctl command "listfdb"
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#   - PyYAML-3.10
12#
13# ===========================================================================
14# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
15# All rights reserved.
16#
17# Redistribution and use in source and binary forms, with or without
18# modification, are permitted provided that the following conditions are met:
19#
20# 1. Redistributions of source code must retain the above copyright notice,
21#    this list of conditions and the following disclaimer.
22# 2. Redistributions in binary form must reproduce the above copyright
23#    notice, this list of conditions and the following disclaimer in the
24#    documentation and/or other materials provided with the distribution.
25#
26# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36# POSSIBILITY OF SUCH DAMAGE.
37# ===========================================================================
38#
39# $Id$
40
41import os
42import sys
43import ssl
44import time
45import json
46import yaml
47import fcntl
48import base64
49import urllib2
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import tornado
56import websocket
57
58from tornado.web import Application, RequestHandler
59from tornado.websocket import WebSocketHandler
60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
64
65
66class DebugMixIn(object):
67    def dprintf(self, msg, func=lambda: ()):
68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
70            sys.stderr.write(prefix + (msg % func()))
71
72
73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
80
81    @property
82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
101        return 0
102
103    @staticmethod
104    def format_mac(mac, sep=':'):
105        return sep.join(b.encode('hex') for b in mac)
106
107
108class FDB(DebugMixIn):
109    class Entry(object):
110        def __init__(self, port, ageout):
111            self.port = port
112            self._time = time.time()
113            self._ageout = ageout
114
115        @property
116        def age(self):
117            return time.time() - self._time
118
119        @property
120        def agedout(self):
121            return self.age > self._ageout
122
123    def __init__(self, ageout, debug=False):
124        self._ageout = ageout
125        self._debug = debug
126        self._table = {}
127
128    def _set_entry(self, vid, mac, port):
129        if vid not in self._table:
130            self._table[vid] = {}
131        self._table[vid][mac] = self.Entry(port, self._ageout)
132
133    def _del_entry(self, vid, mac):
134        if vid in self._table:
135            if mac in self._table[vid]:
136                del self._table[vid][mac]
137            if not self._table[vid]:
138                del self._table[vid]
139
140    def get_entry(self, vid, mac):
141        try:
142            entry = self._table[vid][mac]
143        except KeyError:
144            return None
145
146        if not entry.agedout:
147            return entry
148
149        self._del_entry(vid, mac)
150        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
151                     lambda: (entry.port.number, vid, mac.encode('hex')))
152
153    def get_vid_list(self):
154        return sorted(self._table.iterkeys())
155
156    def get_mac_list(self, vid):
157        return sorted(self._table[vid].iterkeys())
158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self.get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
168        self._set_entry(vid, mac, port)
169        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
170                     lambda: (port.number, vid, mac.encode('hex')))
171
172    def delete(self, port):
173        for vid in self.get_vid_list():
174            for mac in self.get_mac_list(vid):
175                entry = self.get_entry(vid, mac)
176                if entry and entry.port.number == port.number:
177                    self._del_entry(vid, mac)
178                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
179                                 lambda: (port.number, vid, mac.encode('hex')))
180
181
182class SwitchingHub(DebugMixIn):
183    class Port(object):
184        def __init__(self, number, interface):
185            self.number = number
186            self.interface = interface
187            self.tx = 0
188            self.rx = 0
189            self.shut = False
190
191        @staticmethod
192        def cmp_by_number(x, y):
193            return cmp(x.number, y.number)
194
195    def __init__(self, fdb, debug=False):
196        self.fdb = fdb
197        self._debug = debug
198        self._table = {}
199        self._next = 1
200
201    @property
202    def portlist(self):
203        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
204
205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
209        try:
210            self._set_privattr('portnum', interface, self._next)  # XXX
211            self._table[self._next] = self.Port(self._next, interface)
212            return self._next
213        finally:
214            self._next += 1
215
216    def unregister_port(self, interface):
217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
219        self.fdb.delete(self._table[portnum])
220        del self._table[portnum]
221
222    def send(self, dst_interfaces, frame):
223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
226        ports = sorted(ports, cmp=self.Port.cmp_by_number)
227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if ports:
233            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
234                         lambda: (','.join(str(p.number) for p in ports),
235                                  frame.vid,
236                                  frame.src_mac.encode('hex'),
237                                  frame.dst_mac.encode('hex')))
238
239    def receive(self, src_interface, frame):
240        port = self._table[self._get_privattr('portnum', src_interface)]
241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
247        try:
248            if not frame.src_multicast:
249                self.fdb.learn(src_port, frame)
250
251            if not frame.dst_multicast:
252                dst_port = self.fdb.lookup(frame)
253
254                if dst_port:
255                    self.send([dst_port.interface], frame)
256                    return
257
258            ports = set(self.portlist) - set([src_port])
259            self.send((p.interface for p in ports), frame)
260
261        except:  # ex. received invalid frame
262            traceback.print_exc()
263
264    def _privattr(self, name):
265        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
266
267    def _set_privattr(self, name, obj, value):
268        return setattr(obj, self._privattr(name), value)
269
270    def _get_privattr(self, name, obj, defaults=None):
271        return getattr(obj, self._privattr(name), defaults)
272
273    def _del_privattr(self, name, obj):
274        return delattr(obj, self._privattr(name))
275
276
277class Htpasswd(object):
278    def __init__(self, path):
279        self._path = path
280        self._stat = None
281        self._data = {}
282
283    def auth(self, name, passwd):
284        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
285        return self._data.get(name) == passwd
286
287    def load(self):
288        old_stat = self._stat
289
290        with open(self._path) as fp:
291            fileno = fp.fileno()
292            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
293            self._stat = os.fstat(fileno)
294
295            unchanged = old_stat and \
296                        old_stat.st_ino == self._stat.st_ino and \
297                        old_stat.st_dev == self._stat.st_dev and \
298                        old_stat.st_mtime == self._stat.st_mtime
299
300            if not unchanged:
301                self._data = self._parse(fp)
302
303        return self
304
305    def _parse(self, fp):
306        data = {}
307        for line in fp:
308            line = line.strip()
309            if 0 <= line.find(':'):
310                name, passwd = line.split(':', 1)
311                if passwd.startswith('{SHA}'):
312                    data[name] = passwd[5:]
313        return data
314
315
316class BasicAuthMixIn(object):
317    def _execute(self, transforms, *args, **kwargs):
318        def do_execute():
319            sp = super(BasicAuthMixIn, self)
320            return sp._execute(transforms, *args, **kwargs)
321
322        def auth_required():
323            stream = getattr(self, 'stream', self.request.connection.stream)
324            stream.write(tornado.escape.utf8(
325                'HTTP/1.1 401 Authorization Required\r\n'
326                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
327            ))
328            stream.close()
329
330        try:
331            if not self._htpasswd:
332                return do_execute()
333
334            creds = self.request.headers.get('Authorization')
335
336            if not creds or not creds.startswith('Basic '):
337                return auth_required()
338
339            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
340
341            if self._htpasswd.load().auth(name, passwd):
342                return do_execute()
343        except:
344            traceback.print_exc()
345
346        return auth_required()
347
348
349class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
350    IFTYPE = 'server'
351
352    def __init__(self, app, req, switch, htpasswd=None, debug=False):
353        super(EtherWebSocketHandler, self).__init__(app, req)
354        self._switch = switch
355        self._htpasswd = htpasswd
356        self._debug = debug
357
358    def get_target(self):
359        return self.request.remote_ip
360
361    def open(self):
362        try:
363            return self._switch.register_port(self)
364        finally:
365            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
366
367    def on_message(self, message):
368        self._switch.receive(self, EthernetFrame(message))
369
370    def on_close(self):
371        self._switch.unregister_port(self)
372        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
373
374
375class TapHandler(DebugMixIn):
376    IFTYPE = 'tap'
377    READ_SIZE = 65535
378
379    def __init__(self, ioloop, switch, dev, debug=False):
380        self._ioloop = ioloop
381        self._switch = switch
382        self._dev = dev
383        self._debug = debug
384        self._tap = None
385
386    def get_target(self):
387        if self.closed:
388            return self._dev
389        return self._tap.name
390
391    @property
392    def closed(self):
393        return not self._tap
394
395    def open(self):
396        if not self.closed:
397            raise ValueError('already opened')
398        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
399        self._tap.up()
400        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
401        return self._switch.register_port(self)
402
403    def close(self):
404        if self.closed:
405            raise ValueError('I/O operation on closed tap')
406        self._switch.unregister_port(self)
407        self._ioloop.remove_handler(self.fileno())
408        self._tap.close()
409        self._tap = None
410
411    def fileno(self):
412        if self.closed:
413            raise ValueError('I/O operation on closed tap')
414        return self._tap.fileno()
415
416    def write_message(self, message, binary=False):
417        if self.closed:
418            raise ValueError('I/O operation on closed tap')
419        self._tap.write(message)
420
421    def __call__(self, fd, events):
422        try:
423            self._switch.receive(self, EthernetFrame(self._read()))
424            return
425        except:
426            traceback.print_exc()
427        self.close()
428
429    def _read(self):
430        if self.closed:
431            raise ValueError('I/O operation on closed tap')
432        buf = []
433        while True:
434            buf.append(self._tap.read(self.READ_SIZE))
435            if len(buf[-1]) < self.READ_SIZE:
436                break
437        return ''.join(buf)
438
439
440class EtherWebSocketClient(DebugMixIn):
441    IFTYPE = 'client'
442
443    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
444        self._ioloop = ioloop
445        self._switch = switch
446        self._url = url
447        self._ssl = ssl_
448        self._debug = debug
449        self._sock = None
450        self._options = {}
451
452        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
453            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
454            auth = ['Authorization: Basic %s' % token]
455            self._options['header'] = auth
456
457    def get_target(self):
458        return self._url
459
460    @property
461    def closed(self):
462        return not self._sock
463
464    def open(self):
465        sslwrap = websocket._SSLSocketWrapper
466
467        if not self.closed:
468            raise websocket.WebSocketException('already opened')
469
470        if self._ssl:
471            websocket._SSLSocketWrapper = self._ssl
472
473        try:
474            self._sock = websocket.WebSocket()
475            self._sock.connect(self._url, **self._options)
476            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
477            return self._switch.register_port(self)
478        finally:
479            websocket._SSLSocketWrapper = sslwrap
480            self.dprintf('connected: %s\n', lambda: self._url)
481
482    def close(self):
483        if self.closed:
484            raise websocket.WebSocketException('already closed')
485        self._switch.unregister_port(self)
486        self._ioloop.remove_handler(self.fileno())
487        self._sock.close()
488        self._sock = None
489        self.dprintf('disconnected: %s\n', lambda: self._url)
490
491    def fileno(self):
492        if self.closed:
493            raise websocket.WebSocketException('closed socket')
494        return self._sock.io_sock.fileno()
495
496    def write_message(self, message, binary=False):
497        if self.closed:
498            raise websocket.WebSocketException('closed socket')
499        if binary:
500            flag = websocket.ABNF.OPCODE_BINARY
501        else:
502            flag = websocket.ABNF.OPCODE_TEXT
503        self._sock.send(message, flag)
504
505    def __call__(self, fd, events):
506        try:
507            data = self._sock.recv()
508            if data is not None:
509                self._switch.receive(self, EthernetFrame(data))
510                return
511        except:
512            traceback.print_exc()
513        self.close()
514
515
516class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
517    NAMESPACE = 'etherws.control'
518    IFTYPES = {
519        TapHandler.IFTYPE:           TapHandler,
520        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
521    }
522
523    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
524        super(EtherWebSocketControlHandler, self).__init__(app, req)
525        self._ioloop = ioloop
526        self._switch = switch
527        self._htpasswd = htpasswd
528        self._debug = debug
529
530    def post(self):
531        id_ = None
532
533        try:
534            req = json.loads(self.request.body)
535            method = req['method']
536            params = req['params']
537            id_ = req.get('id')
538
539            if not method.startswith(self.NAMESPACE + '.'):
540                raise ValueError('invalid method: %s' % method)
541
542            if not isinstance(params, list):
543                raise ValueError('invalid params: %s' % params)
544
545            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
546            result = getattr(self, handler)(params)
547            self.finish({'result': result, 'error': None, 'id': id_})
548
549        except Exception as e:
550            traceback.print_exc()
551            msg = '%s: %s' % (e.__class__.__name__, str(e))
552            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
553
554    def handle_listFdb(self, params):
555        list_ = []
556        for vid in self._switch.fdb.get_vid_list():
557            for mac in self._switch.fdb.get_mac_list(vid):
558                entry = self._switch.fdb.get_entry(vid, mac)
559                if entry:
560                    mac = EthernetFrame.format_mac(mac)
561                    list_.append([vid, mac, entry.port.number, int(entry.age)])
562        return {'result': list_}
563
564    def handle_listPort(self, params):
565        list_ = [self._portstat(p) for p in self._switch.portlist]
566        return {'result': list_}
567
568    def handle_addPort(self, params):
569        list_ = []
570        for p in params:
571            type_ = p['type']
572            target = p['target']
573            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
574            klass = self.IFTYPES[type_]
575            interface = klass(self._ioloop, self._switch, target, **options)
576            portnum = interface.open()
577            list_.append(self._portstat(self._switch.get_port(portnum)))
578        return {'result': list_}
579
580    def handle_delPort(self, params):
581        list_ = []
582        for p in params:
583            port = self._switch.get_port(int(p['port']))
584            list_.append(self._portstat(port))
585            port.interface.close()
586        return {'result': list_}
587
588    def handle_shutPort(self, params):
589        list_ = []
590        for p in params:
591            port = self._switch.get_port(int(p['port']))
592            port.shut = bool(p['shut'])
593            list_.append(self._portstat(port))
594        return {'result': list_}
595
596    def _optparse_tap(self, opt):
597        return {'debug': self._debug}
598
599    def _optparse_client(self, opt):
600        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
601        if opt.get('insecure'):
602            args = {}
603        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
604        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
605        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
606
607    @staticmethod
608    def _portstat(port):
609        return {
610            'port':   port.number,
611            'type':   port.interface.IFTYPE,
612            'target': port.interface.get_target(),
613            'tx':     port.tx,
614            'rx':     port.rx,
615            'shut':   port.shut,
616        }
617
618
619def start_sw(args):
620    def daemonize(nochdir=False, noclose=False):
621        if os.fork() > 0:
622            sys.exit(0)
623
624        os.setsid()
625
626        if os.fork() > 0:
627            sys.exit(0)
628
629        if not nochdir:
630            os.chdir('/')
631
632        if not noclose:
633            os.umask(0)
634            sys.stdin.close()
635            sys.stdout.close()
636            sys.stderr.close()
637            os.close(0)
638            os.close(1)
639            os.close(2)
640            sys.stdin = open(os.devnull)
641            sys.stdout = open(os.devnull, 'a')
642            sys.stderr = open(os.devnull, 'a')
643
644    def checkabspath(ns, path):
645        val = getattr(ns, path, '')
646        if not val.startswith('/'):
647            raise ValueError('invalid %: %s' % (path, val))
648
649    def getsslopt(ns, key, cert):
650        kval = getattr(ns, key, None)
651        cval = getattr(ns, cert, None)
652        if kval and cval:
653            return {'keyfile': kval, 'certfile': cval}
654        elif kval or cval:
655            raise ValueError('both %s and %s are required' % (key, cert))
656        return None
657
658    def setrealpath(ns, *keys):
659        for k in keys:
660            v = getattr(ns, k, None)
661            if v is not None:
662                v = os.path.realpath(v)
663                open(v).close()  # check readable
664                setattr(ns, k, v)
665
666    def setport(ns, port, isssl):
667        val = getattr(ns, port, None)
668        if val is None:
669            if isssl:
670                return setattr(ns, port, 443)
671            return setattr(ns, port, 80)
672        if not (0 <= val <= 65535):
673            raise ValueError('invalid %s: %s' % (port, val))
674
675    def sethtpasswd(ns, htpasswd):
676        val = getattr(ns, htpasswd, None)
677        if val:
678            return setattr(ns, htpasswd, Htpasswd(val))
679
680    #if args.debug:
681    #    websocket.enableTrace(True)
682
683    if args.ageout <= 0:
684        raise ValueError('invalid ageout: %s' % args.ageout)
685
686    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
687    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
688
689    checkabspath(args, 'path')
690    checkabspath(args, 'ctlpath')
691
692    sslopt = getsslopt(args, 'sslkey', 'sslcert')
693    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
694
695    setport(args, 'port', sslopt)
696    setport(args, 'ctlport', ctlsslopt)
697
698    sethtpasswd(args, 'htpasswd')
699    sethtpasswd(args, 'ctlhtpasswd')
700
701    ioloop = IOLoop.instance()
702    fdb = FDB(ageout=args.ageout, debug=args.debug)
703    switch = SwitchingHub(fdb, debug=args.debug)
704
705    if args.port == args.ctlport and args.host == args.ctlhost:
706        if args.path == args.ctlpath:
707            raise ValueError('same path/ctlpath on same host')
708        if args.sslkey != args.ctlsslkey:
709            raise ValueError('different sslkey/ctlsslkey on same host')
710        if args.sslcert != args.ctlsslcert:
711            raise ValueError('different sslcert/ctlsslcert on same host')
712
713        app = Application([
714            (args.path, EtherWebSocketHandler, {
715                'switch':   switch,
716                'htpasswd': args.htpasswd,
717                'debug':    args.debug,
718            }),
719            (args.ctlpath, EtherWebSocketControlHandler, {
720                'ioloop':   ioloop,
721                'switch':   switch,
722                'htpasswd': args.ctlhtpasswd,
723                'debug':    args.debug,
724            }),
725        ])
726        server = HTTPServer(app, ssl_options=sslopt)
727        server.listen(args.port, address=args.host)
728
729    else:
730        app = Application([(args.path, EtherWebSocketHandler, {
731            'switch':   switch,
732            'htpasswd': args.htpasswd,
733            'debug':    args.debug,
734        })])
735        server = HTTPServer(app, ssl_options=sslopt)
736        server.listen(args.port, address=args.host)
737
738        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
739            'ioloop':   ioloop,
740            'switch':   switch,
741            'htpasswd': args.ctlhtpasswd,
742            'debug':    args.debug,
743        })])
744        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
745        ctlserver.listen(args.ctlport, address=args.ctlhost)
746
747    if not args.foreground:
748        daemonize()
749
750    ioloop.start()
751
752
753def start_ctl(args):
754    def request(args, method, params):
755        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
756        data = json.dumps({'method': method, 'params': params})
757        req = urllib2.Request(args.ctlurl)
758        req.add_header('Content-type', 'application/json')
759        if args.ctluser:
760            if not args.ctlpasswd:
761                args.ctlpasswd = getpass.getpass()
762            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
763            req.add_header('Authorization', 'Basic %s' % token)
764        return json.loads(urllib2.urlopen(req, data).read())
765
766    def handle_ctl_addport(args):
767        params = [{
768            'type':    args.type,
769            'target':  args.target,
770            'options': {
771                'insecure': args.insecure,
772                'cacerts':  args.cacerts,
773                'user':     args.user,
774                'passwd':   args.passwd,
775            }
776        }]
777        return request(args, 'addPort', params)
778
779    def handle_ctl_shutport(args):
780        if args.port <= 0:
781            raise ValueError('invalid port: %d' % args.port)
782        params = [{'port': args.port, 'shut': args.no}]
783        return request(args, 'shutPort', params)
784
785    def handle_ctl_delport(args):
786        if args.port <= 0:
787            raise ValueError('invalid port: %d' % args.port)
788        params = [{'port': args.port}]
789        return request(args, 'delPort', params)
790
791    def handle_ctl_listport(args):
792        return request(args, 'listPort', [])
793
794    def handle_ctl_listfdb(args):
795        return request(args, 'listFdb', [])
796
797    res = locals()['handle_ctl_' + args.control_method](args)
798
799    if res['error']:
800        print(res['error']['message'])
801    else:
802        print(yaml.safe_dump(res['result']).strip())
803
804
805def main():
806    parser = argparse.ArgumentParser()
807    subcommand = parser.add_subparsers(dest='subcommand')
808
809    # -- sw command parser
810    parser_s = subcommand.add_parser('sw')
811
812    parser_s.add_argument('--debug', action='store_true', default=False)
813    parser_s.add_argument('--foreground', action='store_true', default=False)
814    parser_s.add_argument('--ageout', type=int, default=300)
815
816    parser_s.add_argument('--path', default='/')
817    parser_s.add_argument('--host', default='')
818    parser_s.add_argument('--port', type=int)
819    parser_s.add_argument('--htpasswd')
820    parser_s.add_argument('--sslkey')
821    parser_s.add_argument('--sslcert')
822
823    parser_s.add_argument('--ctlpath', default='/ctl')
824    parser_s.add_argument('--ctlhost', default='')
825    parser_s.add_argument('--ctlport', type=int)
826    parser_s.add_argument('--ctlhtpasswd')
827    parser_s.add_argument('--ctlsslkey')
828    parser_s.add_argument('--ctlsslcert')
829
830    # -- ctl command parser
831    parser_c = subcommand.add_parser('ctl')
832    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
833    parser_c.add_argument('--ctluser')
834    parser_c.add_argument('--ctlpasswd')
835
836    control_method = parser_c.add_subparsers(dest='control_method')
837
838    parser_c_ap = control_method.add_parser('addport')
839    parser_c_ap.add_argument(
840        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
841    parser_c_ap.add_argument('target')
842    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
843    parser_c_ap.add_argument('--cacerts')
844    parser_c_ap.add_argument('--user')
845    parser_c_ap.add_argument('--passwd')
846
847    parser_c_sp = control_method.add_parser('shutport')
848    parser_c_sp.add_argument('port', type=int)
849    parser_c_sp.add_argument('--no', action='store_false', default=True)
850
851    parser_c_dp = control_method.add_parser('delport')
852    parser_c_dp.add_argument('port', type=int)
853
854    parser_c_lp = control_method.add_parser('listport')
855
856    parser_c_lf = control_method.add_parser('listfdb')
857
858    # -- go
859    args = parser.parse_args()
860    globals()['start_' + args.subcommand](args)
861
862
863if __name__ == '__main__':
864    main()
Note: See TracBrowser for help on using the repository browser.