source: etherws/trunk/etherws.py @ 197

Revision 197, 26.1 KB checked in by atzm, 12 years ago (diff)
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#   - PyYAML-3.10
12#
13# ===========================================================================
14# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
15# All rights reserved.
16#
17# Redistribution and use in source and binary forms, with or without
18# modification, are permitted provided that the following conditions are met:
19#
20# 1. Redistributions of source code must retain the above copyright notice,
21#    this list of conditions and the following disclaimer.
22# 2. Redistributions in binary form must reproduce the above copyright
23#    notice, this list of conditions and the following disclaimer in the
24#    documentation and/or other materials provided with the distribution.
25#
26# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36# POSSIBILITY OF SUCH DAMAGE.
37# ===========================================================================
38#
39# $Id$
40
41import os
42import sys
43import ssl
44import time
45import json
46import yaml
47import fcntl
48import base64
49import urllib2
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import tornado
56import websocket
57
58from tornado.web import Application, RequestHandler
59from tornado.websocket import WebSocketHandler
60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
64
65
66class DebugMixIn(object):
67    def dprintf(self, msg, func=lambda: ()):
68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
70            sys.stderr.write(prefix + (msg % func()))
71
72
73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
80
81    @property
82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
101        return 0
102
103
104class FDB(DebugMixIn):
105    class Entry(object):
106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
119    def __init__(self, ageout, debug=False):
120        self._ageout = ageout
121        self._debug = debug
122        self._table = {}
123
124    def _set_entry(self, vid, mac, port):
125        if vid not in self._table:
126            self._table[vid] = {}
127        self._table[vid][mac] = self.Entry(port, self._ageout)
128
129    def _del_entry(self, vid, mac):
130        if vid in self._table:
131            if mac in self._table[vid]:
132                del self._table[vid][mac]
133            if not self._table[vid]:
134                del self._table[vid]
135
136    def get_entry(self, vid, mac):
137        try:
138            entry = self._table[vid][mac]
139        except KeyError:
140            return None
141
142        if not entry.agedout:
143            return entry
144
145        self._del_entry(vid, mac)
146        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
147                     lambda: (entry.port.number, vid, mac.encode('hex')))
148
149    def get_vid_list(self):
150        return self._table.keys()
151
152    def get_mac_list(self, vid):
153        return self._table[vid].keys()
154
155    def lookup(self, frame):
156        mac = frame.dst_mac
157        vid = frame.vid
158        entry = self.get_entry(vid, mac)
159        return getattr(entry, 'port', None)
160
161    def learn(self, port, frame):
162        mac = frame.src_mac
163        vid = frame.vid
164        self._set_entry(vid, mac, port)
165        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
166                     lambda: (port.number, vid, mac.encode('hex')))
167
168    def delete(self, port):
169        for vid in self.get_vid_list():
170            for mac in self.get_mac_list(vid):
171                entry = self.get_entry(vid, mac)
172                if entry and entry.port.number == port.number:
173                    self._del_entry(vid, mac)
174                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
175                                 lambda: (port.number, vid, mac.encode('hex')))
176
177
178class SwitchingHub(DebugMixIn):
179    class Port(object):
180        def __init__(self, number, interface):
181            self.number = number
182            self.interface = interface
183            self.tx = 0
184            self.rx = 0
185            self.shut = False
186
187        @staticmethod
188        def cmp_by_number(x, y):
189            return cmp(x.number, y.number)
190
191    def __init__(self, fdb, debug=False):
192        self.fdb = fdb
193        self._debug = debug
194        self._table = {}
195        self._next = 1
196
197    @property
198    def portlist(self):
199        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
200
201    def get_port(self, portnum):
202        return self._table[portnum]
203
204    def register_port(self, interface):
205        try:
206            self._set_privattr('portnum', interface, self._next)  # XXX
207            self._table[self._next] = self.Port(self._next, interface)
208            return self._next
209        finally:
210            self._next += 1
211
212    def unregister_port(self, interface):
213        portnum = self._get_privattr('portnum', interface)
214        self._del_privattr('portnum', interface)
215        self.fdb.delete(self._table[portnum])
216        del self._table[portnum]
217
218    def send(self, dst_interfaces, frame):
219        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
220        ports = (self._table[n] for n in portnums)
221        ports = (p for p in ports if not p.shut)
222        ports = sorted(ports, cmp=self.Port.cmp_by_number)
223
224        for p in ports:
225            p.interface.write_message(frame.data, True)
226            p.tx += 1
227
228        if ports:
229            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
230                         lambda: (','.join(str(p.number) for p in ports),
231                                  frame.vid,
232                                  frame.src_mac.encode('hex'),
233                                  frame.dst_mac.encode('hex')))
234
235    def receive(self, src_interface, frame):
236        port = self._table[self._get_privattr('portnum', src_interface)]
237
238        if not port.shut:
239            port.rx += 1
240            self._forward(port, frame)
241
242    def _forward(self, src_port, frame):
243        try:
244            if not frame.src_multicast:
245                self.fdb.learn(src_port, frame)
246
247            if not frame.dst_multicast:
248                dst_port = self.fdb.lookup(frame)
249
250                if dst_port:
251                    self.send([dst_port.interface], frame)
252                    return
253
254            ports = set(self.portlist) - set([src_port])
255            self.send((p.interface for p in ports), frame)
256
257        except:  # ex. received invalid frame
258            traceback.print_exc()
259
260    def _privattr(self, name):
261        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
262
263    def _set_privattr(self, name, obj, value):
264        return setattr(obj, self._privattr(name), value)
265
266    def _get_privattr(self, name, obj, defaults=None):
267        return getattr(obj, self._privattr(name), defaults)
268
269    def _del_privattr(self, name, obj):
270        return delattr(obj, self._privattr(name))
271
272
273class Htpasswd(object):
274    def __init__(self, path):
275        self._path = path
276        self._stat = None
277        self._data = {}
278
279    def auth(self, name, passwd):
280        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
281        return self._data.get(name) == passwd
282
283    def load(self):
284        old_stat = self._stat
285
286        with open(self._path) as fp:
287            fileno = fp.fileno()
288            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
289            self._stat = os.fstat(fileno)
290
291            unchanged = old_stat and \
292                        old_stat.st_ino == self._stat.st_ino and \
293                        old_stat.st_dev == self._stat.st_dev and \
294                        old_stat.st_mtime == self._stat.st_mtime
295
296            if not unchanged:
297                self._data = self._parse(fp)
298
299        return self
300
301    def _parse(self, fp):
302        data = {}
303        for line in fp:
304            line = line.strip()
305            if 0 <= line.find(':'):
306                name, passwd = line.split(':', 1)
307                if passwd.startswith('{SHA}'):
308                    data[name] = passwd[5:]
309        return data
310
311
312class BasicAuthMixIn(object):
313    def _execute(self, transforms, *args, **kwargs):
314        def do_execute():
315            sp = super(BasicAuthMixIn, self)
316            return sp._execute(transforms, *args, **kwargs)
317
318        def auth_required():
319            stream = getattr(self, 'stream', self.request.connection.stream)
320            stream.write(tornado.escape.utf8(
321                'HTTP/1.1 401 Authorization Required\r\n'
322                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
323            ))
324            stream.close()
325
326        try:
327            if not self._htpasswd:
328                return do_execute()
329
330            creds = self.request.headers.get('Authorization')
331
332            if not creds or not creds.startswith('Basic '):
333                return auth_required()
334
335            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
336
337            if self._htpasswd.load().auth(name, passwd):
338                return do_execute()
339        except:
340            traceback.print_exc()
341
342        return auth_required()
343
344
345class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
346    IFTYPE = 'server'
347
348    def __init__(self, app, req, switch, htpasswd=None, debug=False):
349        super(EtherWebSocketHandler, self).__init__(app, req)
350        self._switch = switch
351        self._htpasswd = htpasswd
352        self._debug = debug
353
354    def get_target(self):
355        return self.request.remote_ip
356
357    def open(self):
358        try:
359            return self._switch.register_port(self)
360        finally:
361            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
362
363    def on_message(self, message):
364        self._switch.receive(self, EthernetFrame(message))
365
366    def on_close(self):
367        self._switch.unregister_port(self)
368        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
369
370
371class TapHandler(DebugMixIn):
372    IFTYPE = 'tap'
373    READ_SIZE = 65535
374
375    def __init__(self, ioloop, switch, dev, debug=False):
376        self._ioloop = ioloop
377        self._switch = switch
378        self._dev = dev
379        self._debug = debug
380        self._tap = None
381
382    def get_target(self):
383        if self.closed:
384            return self._dev
385        return self._tap.name
386
387    @property
388    def closed(self):
389        return not self._tap
390
391    def open(self):
392        if not self.closed:
393            raise ValueError('already opened')
394        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
395        self._tap.up()
396        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
397        return self._switch.register_port(self)
398
399    def close(self):
400        if self.closed:
401            raise ValueError('I/O operation on closed tap')
402        self._switch.unregister_port(self)
403        self._ioloop.remove_handler(self.fileno())
404        self._tap.close()
405        self._tap = None
406
407    def fileno(self):
408        if self.closed:
409            raise ValueError('I/O operation on closed tap')
410        return self._tap.fileno()
411
412    def write_message(self, message, binary=False):
413        if self.closed:
414            raise ValueError('I/O operation on closed tap')
415        self._tap.write(message)
416
417    def __call__(self, fd, events):
418        try:
419            self._switch.receive(self, EthernetFrame(self._read()))
420            return
421        except:
422            traceback.print_exc()
423        self.close()
424
425    def _read(self):
426        if self.closed:
427            raise ValueError('I/O operation on closed tap')
428        buf = []
429        while True:
430            buf.append(self._tap.read(self.READ_SIZE))
431            if len(buf[-1]) < self.READ_SIZE:
432                break
433        return ''.join(buf)
434
435
436class EtherWebSocketClient(DebugMixIn):
437    IFTYPE = 'client'
438
439    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
440        self._ioloop = ioloop
441        self._switch = switch
442        self._url = url
443        self._ssl = ssl_
444        self._debug = debug
445        self._sock = None
446        self._options = {}
447
448        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
449            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
450            auth = ['Authorization: Basic %s' % token]
451            self._options['header'] = auth
452
453    def get_target(self):
454        return self._url
455
456    @property
457    def closed(self):
458        return not self._sock
459
460    def open(self):
461        sslwrap = websocket._SSLSocketWrapper
462
463        if not self.closed:
464            raise websocket.WebSocketException('already opened')
465
466        if self._ssl:
467            websocket._SSLSocketWrapper = self._ssl
468
469        try:
470            self._sock = websocket.WebSocket()
471            self._sock.connect(self._url, **self._options)
472            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
473            return self._switch.register_port(self)
474        finally:
475            websocket._SSLSocketWrapper = sslwrap
476            self.dprintf('connected: %s\n', lambda: self._url)
477
478    def close(self):
479        if self.closed:
480            raise websocket.WebSocketException('already closed')
481        self._switch.unregister_port(self)
482        self._ioloop.remove_handler(self.fileno())
483        self._sock.close()
484        self._sock = None
485        self.dprintf('disconnected: %s\n', lambda: self._url)
486
487    def fileno(self):
488        if self.closed:
489            raise websocket.WebSocketException('closed socket')
490        return self._sock.io_sock.fileno()
491
492    def write_message(self, message, binary=False):
493        if self.closed:
494            raise websocket.WebSocketException('closed socket')
495        if binary:
496            flag = websocket.ABNF.OPCODE_BINARY
497        else:
498            flag = websocket.ABNF.OPCODE_TEXT
499        self._sock.send(message, flag)
500
501    def __call__(self, fd, events):
502        try:
503            data = self._sock.recv()
504            if data is not None:
505                self._switch.receive(self, EthernetFrame(data))
506                return
507        except:
508            traceback.print_exc()
509        self.close()
510
511
512class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
513    NAMESPACE = 'etherws.control'
514    IFTYPES = {
515        TapHandler.IFTYPE:           TapHandler,
516        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
517    }
518
519    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
520        super(EtherWebSocketControlHandler, self).__init__(app, req)
521        self._ioloop = ioloop
522        self._switch = switch
523        self._htpasswd = htpasswd
524        self._debug = debug
525
526    def post(self):
527        id_ = None
528
529        try:
530            req = json.loads(self.request.body)
531            method = req['method']
532            params = req['params']
533            id_ = req.get('id')
534
535            if not method.startswith(self.NAMESPACE + '.'):
536                raise ValueError('invalid method: %s' % method)
537
538            if not isinstance(params, list):
539                raise ValueError('invalid params: %s' % params)
540
541            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
542            result = getattr(self, handler)(params)
543            self.finish({'result': result, 'error': None, 'id': id_})
544
545        except Exception as e:
546            traceback.print_exc()
547            msg = '%s: %s' % (e.__class__.__name__, str(e))
548            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
549
550    def handle_listPort(self, params):
551        list_ = [self._portstat(p) for p in self._switch.portlist]
552        return {'portlist': list_}
553
554    def handle_addPort(self, params):
555        list_ = []
556        for p in params:
557            type_ = p['type']
558            target = p['target']
559            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
560            klass = self.IFTYPES[type_]
561            interface = klass(self._ioloop, self._switch, target, **options)
562            portnum = interface.open()
563            list_.append(self._portstat(self._switch.get_port(portnum)))
564        return {'portlist': list_}
565
566    def handle_delPort(self, params):
567        list_ = []
568        for p in params:
569            port = self._switch.get_port(int(p['port']))
570            list_.append(self._portstat(port))
571            port.interface.close()
572        return {'portlist': list_}
573
574    def handle_shutPort(self, params):
575        list_ = []
576        for p in params:
577            port = self._switch.get_port(int(p['port']))
578            port.shut = bool(p['shut'])
579            list_.append(self._portstat(port))
580        return {'portlist': list_}
581
582    def _optparse_tap(self, opt):
583        return {'debug': self._debug}
584
585    def _optparse_client(self, opt):
586        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
587        if opt.get('insecure'):
588            args = {}
589        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
590        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
591        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
592
593    @staticmethod
594    def _portstat(port):
595        return {
596            'port':   port.number,
597            'type':   port.interface.IFTYPE,
598            'target': port.interface.get_target(),
599            'tx':     port.tx,
600            'rx':     port.rx,
601            'shut':   port.shut,
602        }
603
604
605def start_sw(args):
606    def daemonize(nochdir=False, noclose=False):
607        if os.fork() > 0:
608            sys.exit(0)
609
610        os.setsid()
611
612        if os.fork() > 0:
613            sys.exit(0)
614
615        if not nochdir:
616            os.chdir('/')
617
618        if not noclose:
619            os.umask(0)
620            sys.stdin.close()
621            sys.stdout.close()
622            sys.stderr.close()
623            os.close(0)
624            os.close(1)
625            os.close(2)
626            sys.stdin = open(os.devnull)
627            sys.stdout = open(os.devnull, 'a')
628            sys.stderr = open(os.devnull, 'a')
629
630    def checkabspath(ns, path):
631        val = getattr(ns, path, '')
632        if not val.startswith('/'):
633            raise ValueError('invalid %: %s' % (path, val))
634
635    def getsslopt(ns, key, cert):
636        kval = getattr(ns, key, None)
637        cval = getattr(ns, cert, None)
638        if kval and cval:
639            return {'keyfile': kval, 'certfile': cval}
640        elif kval or cval:
641            raise ValueError('both %s and %s are required' % (key, cert))
642        return None
643
644    def setrealpath(ns, *keys):
645        for k in keys:
646            v = getattr(ns, k, None)
647            if v is not None:
648                v = os.path.realpath(v)
649                open(v).close()  # check readable
650                setattr(ns, k, v)
651
652    def setport(ns, port, isssl):
653        val = getattr(ns, port, None)
654        if val is None:
655            if isssl:
656                return setattr(ns, port, 443)
657            return setattr(ns, port, 80)
658        if not (0 <= val <= 65535):
659            raise ValueError('invalid %s: %s' % (port, val))
660
661    def sethtpasswd(ns, htpasswd):
662        val = getattr(ns, htpasswd, None)
663        if val:
664            return setattr(ns, htpasswd, Htpasswd(val))
665
666    #if args.debug:
667    #    websocket.enableTrace(True)
668
669    if args.ageout <= 0:
670        raise ValueError('invalid ageout: %s' % args.ageout)
671
672    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
673    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
674
675    checkabspath(args, 'path')
676    checkabspath(args, 'ctlpath')
677
678    sslopt = getsslopt(args, 'sslkey', 'sslcert')
679    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
680
681    setport(args, 'port', sslopt)
682    setport(args, 'ctlport', ctlsslopt)
683
684    sethtpasswd(args, 'htpasswd')
685    sethtpasswd(args, 'ctlhtpasswd')
686
687    ioloop = IOLoop.instance()
688    fdb = FDB(ageout=args.ageout, debug=args.debug)
689    switch = SwitchingHub(fdb, debug=args.debug)
690
691    if args.port == args.ctlport and args.host == args.ctlhost:
692        if args.path == args.ctlpath:
693            raise ValueError('same path/ctlpath on same host')
694        if args.sslkey != args.ctlsslkey:
695            raise ValueError('different sslkey/ctlsslkey on same host')
696        if args.sslcert != args.ctlsslcert:
697            raise ValueError('different sslcert/ctlsslcert on same host')
698
699        app = Application([
700            (args.path, EtherWebSocketHandler, {
701                'switch':   switch,
702                'htpasswd': args.htpasswd,
703                'debug':    args.debug,
704            }),
705            (args.ctlpath, EtherWebSocketControlHandler, {
706                'ioloop':   ioloop,
707                'switch':   switch,
708                'htpasswd': args.ctlhtpasswd,
709                'debug':    args.debug,
710            }),
711        ])
712        server = HTTPServer(app, ssl_options=sslopt)
713        server.listen(args.port, address=args.host)
714
715    else:
716        app = Application([(args.path, EtherWebSocketHandler, {
717            'switch':   switch,
718            'htpasswd': args.htpasswd,
719            'debug':    args.debug,
720        })])
721        server = HTTPServer(app, ssl_options=sslopt)
722        server.listen(args.port, address=args.host)
723
724        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
725            'ioloop':   ioloop,
726            'switch':   switch,
727            'htpasswd': args.ctlhtpasswd,
728            'debug':    args.debug,
729        })])
730        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
731        ctlserver.listen(args.ctlport, address=args.ctlhost)
732
733    if not args.foreground:
734        daemonize()
735
736    ioloop.start()
737
738
739def start_ctl(args):
740    def request(args, method, params):
741        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
742        data = json.dumps({'method': method, 'params': params})
743        req = urllib2.Request(args.ctlurl)
744        req.add_header('Content-type', 'application/json')
745        if args.ctluser:
746            if not args.ctlpasswd:
747                args.ctlpasswd = getpass.getpass()
748            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
749            req.add_header('Authorization', 'Basic %s' % token)
750        return json.loads(urllib2.urlopen(req, data).read())
751
752    def handle_ctl_addport(args):
753        params = [{
754            'type':    args.type,
755            'target':  args.target,
756            'options': {
757                'insecure': args.insecure,
758                'cacerts':  args.cacerts,
759                'user':     args.user,
760                'passwd':   args.passwd,
761            }
762        }]
763        return request(args, 'addPort', params)
764
765    def handle_ctl_shutport(args):
766        if args.port <= 0:
767            raise ValueError('invalid port: %d' % args.port)
768        params = [{'port': args.port, 'shut': args.no}]
769        return request(args, 'shutPort', params)
770
771    def handle_ctl_delport(args):
772        if args.port <= 0:
773            raise ValueError('invalid port: %d' % args.port)
774        params = [{'port': args.port}]
775        return request(args, 'delPort', params)
776
777    def handle_ctl_listport(args):
778        return request(args, 'listPort', [])
779
780    res = locals()['handle_ctl_' + args.control_method](args)
781
782    if res['error']:
783        print(res['error']['message'])
784    else:
785        print(yaml.safe_dump(res['result']['portlist']).strip())
786
787
788def main():
789    parser = argparse.ArgumentParser()
790    subcommand = parser.add_subparsers(dest='subcommand')
791
792    # -- sw command parser
793    parser_s = subcommand.add_parser('sw')
794
795    parser_s.add_argument('--debug', action='store_true', default=False)
796    parser_s.add_argument('--foreground', action='store_true', default=False)
797    parser_s.add_argument('--ageout', type=int, default=300)
798
799    parser_s.add_argument('--path', default='/')
800    parser_s.add_argument('--host', default='')
801    parser_s.add_argument('--port', type=int)
802    parser_s.add_argument('--htpasswd')
803    parser_s.add_argument('--sslkey')
804    parser_s.add_argument('--sslcert')
805
806    parser_s.add_argument('--ctlpath', default='/ctl')
807    parser_s.add_argument('--ctlhost', default='')
808    parser_s.add_argument('--ctlport', type=int)
809    parser_s.add_argument('--ctlhtpasswd')
810    parser_s.add_argument('--ctlsslkey')
811    parser_s.add_argument('--ctlsslcert')
812
813    # -- ctl command parser
814    parser_c = subcommand.add_parser('ctl')
815    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
816    parser_c.add_argument('--ctluser')
817    parser_c.add_argument('--ctlpasswd')
818
819    control_method = parser_c.add_subparsers(dest='control_method')
820
821    parser_c_ap = control_method.add_parser('addport')
822    parser_c_ap.add_argument(
823        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
824    parser_c_ap.add_argument('target')
825    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
826    parser_c_ap.add_argument('--cacerts')
827    parser_c_ap.add_argument('--user')
828    parser_c_ap.add_argument('--passwd')
829
830    parser_c_sp = control_method.add_parser('shutport')
831    parser_c_sp.add_argument('port', type=int)
832    parser_c_sp.add_argument('--no', action='store_false', default=True)
833
834    parser_c_dp = control_method.add_parser('delport')
835    parser_c_dp.add_argument('port', type=int)
836
837    parser_c_lp = control_method.add_parser('listport')
838
839    # -- go
840    args = parser.parse_args()
841    globals()['start_' + args.subcommand](args)
842
843
844if __name__ == '__main__':
845    main()
Note: See TracBrowser for help on using the repository browser.