source: etherws/trunk/etherws.py @ 195

Revision 195, 26.2 KB checked in by atzm, 12 years ago (diff)
  • refactoring
  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                          Ethernet over WebSocket
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - websocket-client-0.7.0
10#   - tornado-2.3
11#   - PyYAML-3.10
12#
13# ===========================================================================
14# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
15# All rights reserved.
16#
17# Redistribution and use in source and binary forms, with or without
18# modification, are permitted provided that the following conditions are met:
19#
20# 1. Redistributions of source code must retain the above copyright notice,
21#    this list of conditions and the following disclaimer.
22# 2. Redistributions in binary form must reproduce the above copyright
23#    notice, this list of conditions and the following disclaimer in the
24#    documentation and/or other materials provided with the distribution.
25#
26# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36# POSSIBILITY OF SUCH DAMAGE.
37# ===========================================================================
38#
39# $Id$
40
41import os
42import sys
43import ssl
44import time
45import json
46import yaml
47import fcntl
48import base64
49import urllib2
50import hashlib
51import getpass
52import argparse
53import traceback
54
55import tornado
56import websocket
57
58from tornado.web import Application, RequestHandler
59from tornado.websocket import WebSocketHandler
60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
64
65
66class DebugMixIn(object):
67    def dprintf(self, msg, func=lambda: ()):
68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
70            sys.stderr.write(prefix + (msg % func()))
71
72
73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
80
81    @property
82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
101        return 0
102
103
104class FDB(DebugMixIn):
105    class FDBEntry(object):
106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
119    def __init__(self, ageout, debug=False):
120        self._ageout = ageout
121        self._debug = debug
122        self._table = {}
123
124    def _set_entry(self, vid, mac, port):
125        if vid not in self._table:
126            self._table[vid] = {}
127        self._table[vid][mac] = self.FDBEntry(port, self._ageout)
128
129    def _del_entry(self, vid, mac):
130        if vid in self._table:
131            if mac in self._table[vid]:
132                del self._table[vid][mac]
133            if not self._table[vid]:
134                del self._table[vid]
135
136    def get_entry(self, vid, mac):
137        try:
138            entry = self._table[vid][mac]
139        except KeyError:
140            return None
141
142        if not entry.agedout:
143            return entry
144
145        self._del_entry(vid, mac)
146        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
147                     lambda: (entry.port.number, vid, mac.encode('hex')))
148
149    def get_vid_list(self):
150        return self._table.keys()
151
152    def get_mac_list(self, vid):
153        return self._table[vid].keys()
154
155    def lookup(self, frame):
156        mac = frame.dst_mac
157        vid = frame.vid
158        entry = self.get_entry(vid, mac)
159        return getattr(entry, 'port', None)
160
161    def learn(self, port, frame):
162        mac = frame.src_mac
163        vid = frame.vid
164        self._set_entry(vid, mac, port)
165        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
166                     lambda: (port.number, vid, mac.encode('hex')))
167
168    def delete(self, port):
169        for vid in self.get_vid_list():
170            for mac in self.get_mac_list(vid):
171                entry = self.get_entry(vid, mac)
172                if entry and entry.port.number == port.number:
173                    self._del_entry(vid, mac)
174                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
175                                 lambda: (port.number, vid, mac.encode('hex')))
176
177
178class SwitchingHub(DebugMixIn):
179    class Port(object):
180        def __init__(self, number, interface):
181            self.number = number
182            self.interface = interface
183            self.tx = 0
184            self.rx = 0
185            self.shut = False
186
187        @staticmethod
188        def cmp_by_number(x, y):
189            return cmp(x.number, y.number)
190
191    def __init__(self, fdb, debug=False):
192        self._fdb = fdb
193        self._debug = debug
194        self._table = {}
195        self._next = 1
196
197    @property
198    def portlist(self):
199        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
200
201    @property
202    def fdb(self):
203        return self._fdb
204
205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
209        try:
210            self._set_privattr('portnum', interface, self._next)  # XXX
211            self._table[self._next] = self.Port(self._next, interface)
212            return self._next
213        finally:
214            self._next += 1
215
216    def unregister_port(self, interface):
217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
219        self._fdb.delete(self._table[portnum])
220        del self._table[portnum]
221
222    def send(self, dst_interfaces, frame):
223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
226        ports = sorted(ports, cmp=self.Port.cmp_by_number)
227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if ports:
233            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
234                         lambda: (','.join(str(p.number) for p in ports),
235                                  frame.vid,
236                                  frame.src_mac.encode('hex'),
237                                  frame.dst_mac.encode('hex')))
238
239    def receive(self, src_interface, frame):
240        port = self._table[self._get_privattr('portnum', src_interface)]
241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
247        try:
248            if not frame.src_multicast:
249                self._fdb.learn(src_port, frame)
250
251            if not frame.dst_multicast:
252                dst_port = self._fdb.lookup(frame)
253
254                if dst_port:
255                    self.send([dst_port.interface], frame)
256                    return
257
258            ports = set(self.portlist) - set([src_port])
259            self.send((p.interface for p in ports), frame)
260
261        except:  # ex. received invalid frame
262            traceback.print_exc()
263
264    def _privattr(self, name):
265        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
266
267    def _set_privattr(self, name, obj, value):
268        return setattr(obj, self._privattr(name), value)
269
270    def _get_privattr(self, name, obj, defaults=None):
271        return getattr(obj, self._privattr(name), defaults)
272
273    def _del_privattr(self, name, obj):
274        return delattr(obj, self._privattr(name))
275
276
277class Htpasswd(object):
278    def __init__(self, path):
279        self._path = path
280        self._stat = None
281        self._data = {}
282
283    def auth(self, name, passwd):
284        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
285        return self._data.get(name) == passwd
286
287    def load(self):
288        old_stat = self._stat
289
290        with open(self._path) as fp:
291            fileno = fp.fileno()
292            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
293            self._stat = os.fstat(fileno)
294
295            unchanged = old_stat and \
296                        old_stat.st_ino == self._stat.st_ino and \
297                        old_stat.st_dev == self._stat.st_dev and \
298                        old_stat.st_mtime == self._stat.st_mtime
299
300            if not unchanged:
301                self._data = self._parse(fp)
302
303        return self
304
305    def _parse(self, fp):
306        data = {}
307        for line in fp:
308            line = line.strip()
309            if 0 <= line.find(':'):
310                name, passwd = line.split(':', 1)
311                if passwd.startswith('{SHA}'):
312                    data[name] = passwd[5:]
313        return data
314
315
316class BasicAuthMixIn(object):
317    def _execute(self, transforms, *args, **kwargs):
318        def do_execute():
319            sp = super(BasicAuthMixIn, self)
320            return sp._execute(transforms, *args, **kwargs)
321
322        def auth_required():
323            stream = getattr(self, 'stream', self.request.connection.stream)
324            stream.write(tornado.escape.utf8(
325                'HTTP/1.1 401 Authorization Required\r\n'
326                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
327            ))
328            stream.close()
329
330        try:
331            if not self._htpasswd:
332                return do_execute()
333
334            creds = self.request.headers.get('Authorization')
335
336            if not creds or not creds.startswith('Basic '):
337                return auth_required()
338
339            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
340
341            if self._htpasswd.load().auth(name, passwd):
342                return do_execute()
343        except:
344            traceback.print_exc()
345
346        return auth_required()
347
348
349class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
350    IFTYPE = 'server'
351
352    def __init__(self, app, req, switch, htpasswd=None, debug=False):
353        super(EtherWebSocketHandler, self).__init__(app, req)
354        self._switch = switch
355        self._htpasswd = htpasswd
356        self._debug = debug
357
358    def get_target(self):
359        return self.request.remote_ip
360
361    def open(self):
362        try:
363            return self._switch.register_port(self)
364        finally:
365            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
366
367    def on_message(self, message):
368        self._switch.receive(self, EthernetFrame(message))
369
370    def on_close(self):
371        self._switch.unregister_port(self)
372        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
373
374
375class TapHandler(DebugMixIn):
376    IFTYPE = 'tap'
377    READ_SIZE = 65535
378
379    def __init__(self, ioloop, switch, dev, debug=False):
380        self._ioloop = ioloop
381        self._switch = switch
382        self._dev = dev
383        self._debug = debug
384        self._tap = None
385
386    def get_target(self):
387        if self.closed:
388            return self._dev
389        return self._tap.name
390
391    @property
392    def closed(self):
393        return not self._tap
394
395    def open(self):
396        if not self.closed:
397            raise ValueError('already opened')
398        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
399        self._tap.up()
400        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
401        return self._switch.register_port(self)
402
403    def close(self):
404        if self.closed:
405            raise ValueError('I/O operation on closed tap')
406        self._switch.unregister_port(self)
407        self._ioloop.remove_handler(self.fileno())
408        self._tap.close()
409        self._tap = None
410
411    def fileno(self):
412        if self.closed:
413            raise ValueError('I/O operation on closed tap')
414        return self._tap.fileno()
415
416    def write_message(self, message, binary=False):
417        if self.closed:
418            raise ValueError('I/O operation on closed tap')
419        self._tap.write(message)
420
421    def __call__(self, fd, events):
422        try:
423            self._switch.receive(self, EthernetFrame(self._read()))
424            return
425        except:
426            traceback.print_exc()
427        self.close()
428
429    def _read(self):
430        if self.closed:
431            raise ValueError('I/O operation on closed tap')
432        buf = []
433        while True:
434            buf.append(self._tap.read(self.READ_SIZE))
435            if len(buf[-1]) < self.READ_SIZE:
436                break
437        return ''.join(buf)
438
439
440class EtherWebSocketClient(DebugMixIn):
441    IFTYPE = 'client'
442
443    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
444        self._ioloop = ioloop
445        self._switch = switch
446        self._url = url
447        self._ssl = ssl_
448        self._debug = debug
449        self._sock = None
450        self._options = {}
451
452        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
453            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
454            auth = ['Authorization: Basic %s' % token]
455            self._options['header'] = auth
456
457    def get_target(self):
458        return self._url
459
460    @property
461    def closed(self):
462        return not self._sock
463
464    def open(self):
465        sslwrap = websocket._SSLSocketWrapper
466
467        if not self.closed:
468            raise websocket.WebSocketException('already opened')
469
470        if self._ssl:
471            websocket._SSLSocketWrapper = self._ssl
472
473        try:
474            self._sock = websocket.WebSocket()
475            self._sock.connect(self._url, **self._options)
476            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
477            return self._switch.register_port(self)
478        finally:
479            websocket._SSLSocketWrapper = sslwrap
480            self.dprintf('connected: %s\n', lambda: self._url)
481
482    def close(self):
483        if self.closed:
484            raise websocket.WebSocketException('already closed')
485        self._switch.unregister_port(self)
486        self._ioloop.remove_handler(self.fileno())
487        self._sock.close()
488        self._sock = None
489        self.dprintf('disconnected: %s\n', lambda: self._url)
490
491    def fileno(self):
492        if self.closed:
493            raise websocket.WebSocketException('closed socket')
494        return self._sock.io_sock.fileno()
495
496    def write_message(self, message, binary=False):
497        if self.closed:
498            raise websocket.WebSocketException('closed socket')
499        if binary:
500            flag = websocket.ABNF.OPCODE_BINARY
501        else:
502            flag = websocket.ABNF.OPCODE_TEXT
503        self._sock.send(message, flag)
504
505    def __call__(self, fd, events):
506        try:
507            data = self._sock.recv()
508            if data is not None:
509                self._switch.receive(self, EthernetFrame(data))
510                return
511        except:
512            traceback.print_exc()
513        self.close()
514
515
516class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
517    NAMESPACE = 'etherws.control'
518    IFTYPES = {
519        TapHandler.IFTYPE:           TapHandler,
520        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
521    }
522
523    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
524        super(EtherWebSocketControlHandler, self).__init__(app, req)
525        self._ioloop = ioloop
526        self._switch = switch
527        self._htpasswd = htpasswd
528        self._debug = debug
529
530    def post(self):
531        id_ = None
532
533        try:
534            req = json.loads(self.request.body)
535            method = req['method']
536            params = req['params']
537            id_ = req.get('id')
538
539            if not method.startswith(self.NAMESPACE + '.'):
540                raise ValueError('invalid method: %s' % method)
541
542            if not isinstance(params, list):
543                raise ValueError('invalid params: %s' % params)
544
545            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
546            result = getattr(self, handler)(params)
547            self.finish({'result': result, 'error': None, 'id': id_})
548
549        except Exception as e:
550            traceback.print_exc()
551            msg = '%s: %s' % (e.__class__.__name__, str(e))
552            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
553
554    def handle_listPort(self, params):
555        list_ = [self._portstat(p) for p in self._switch.portlist]
556        return {'portlist': list_}
557
558    def handle_addPort(self, params):
559        list_ = []
560        for p in params:
561            type_ = p['type']
562            target = p['target']
563            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
564            klass = self.IFTYPES[type_]
565            interface = klass(self._ioloop, self._switch, target, **options)
566            portnum = interface.open()
567            list_.append(self._portstat(self._switch.get_port(portnum)))
568        return {'portlist': list_}
569
570    def handle_delPort(self, params):
571        list_ = []
572        for p in params:
573            port = self._switch.get_port(int(p['port']))
574            list_.append(self._portstat(port))
575            port.interface.close()
576        return {'portlist': list_}
577
578    def handle_shutPort(self, params):
579        list_ = []
580        for p in params:
581            port = self._switch.get_port(int(p['port']))
582            port.shut = bool(p['shut'])
583            list_.append(self._portstat(port))
584        return {'portlist': list_}
585
586    def _optparse_tap(self, opt):
587        return {'debug': self._debug}
588
589    def _optparse_client(self, opt):
590        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
591        if opt.get('insecure'):
592            args = {}
593        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
594        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
595        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
596
597    @staticmethod
598    def _portstat(port):
599        return {
600            'port':   port.number,
601            'type':   port.interface.IFTYPE,
602            'target': port.interface.get_target(),
603            'tx':     port.tx,
604            'rx':     port.rx,
605            'shut':   port.shut,
606        }
607
608
609def start_sw(args):
610    def daemonize(nochdir=False, noclose=False):
611        if os.fork() > 0:
612            sys.exit(0)
613
614        os.setsid()
615
616        if os.fork() > 0:
617            sys.exit(0)
618
619        if not nochdir:
620            os.chdir('/')
621
622        if not noclose:
623            os.umask(0)
624            sys.stdin.close()
625            sys.stdout.close()
626            sys.stderr.close()
627            os.close(0)
628            os.close(1)
629            os.close(2)
630            sys.stdin = open(os.devnull)
631            sys.stdout = open(os.devnull, 'a')
632            sys.stderr = open(os.devnull, 'a')
633
634    def checkabspath(ns, path):
635        val = getattr(ns, path, '')
636        if not val.startswith('/'):
637            raise ValueError('invalid %: %s' % (path, val))
638
639    def getsslopt(ns, key, cert):
640        kval = getattr(ns, key, None)
641        cval = getattr(ns, cert, None)
642        if kval and cval:
643            return {'keyfile': kval, 'certfile': cval}
644        elif kval or cval:
645            raise ValueError('both %s and %s are required' % (key, cert))
646        return None
647
648    def setrealpath(ns, *keys):
649        for k in keys:
650            v = getattr(ns, k, None)
651            if v is not None:
652                v = os.path.realpath(v)
653                open(v).close()  # check readable
654                setattr(ns, k, v)
655
656    def setport(ns, port, isssl):
657        val = getattr(ns, port, None)
658        if val is None:
659            if isssl:
660                return setattr(ns, port, 443)
661            return setattr(ns, port, 80)
662        if not (0 <= val <= 65535):
663            raise ValueError('invalid %s: %s' % (port, val))
664
665    def sethtpasswd(ns, htpasswd):
666        val = getattr(ns, htpasswd, None)
667        if val:
668            return setattr(ns, htpasswd, Htpasswd(val))
669
670    #if args.debug:
671    #    websocket.enableTrace(True)
672
673    if args.ageout <= 0:
674        raise ValueError('invalid ageout: %s' % args.ageout)
675
676    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
677    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
678
679    checkabspath(args, 'path')
680    checkabspath(args, 'ctlpath')
681
682    sslopt = getsslopt(args, 'sslkey', 'sslcert')
683    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
684
685    setport(args, 'port', sslopt)
686    setport(args, 'ctlport', ctlsslopt)
687
688    sethtpasswd(args, 'htpasswd')
689    sethtpasswd(args, 'ctlhtpasswd')
690
691    ioloop = IOLoop.instance()
692    fdb = FDB(ageout=args.ageout, debug=args.debug)
693    switch = SwitchingHub(fdb, debug=args.debug)
694
695    if args.port == args.ctlport and args.host == args.ctlhost:
696        if args.path == args.ctlpath:
697            raise ValueError('same path/ctlpath on same host')
698        if args.sslkey != args.ctlsslkey:
699            raise ValueError('different sslkey/ctlsslkey on same host')
700        if args.sslcert != args.ctlsslcert:
701            raise ValueError('different sslcert/ctlsslcert on same host')
702
703        app = Application([
704            (args.path, EtherWebSocketHandler, {
705                'switch':   switch,
706                'htpasswd': args.htpasswd,
707                'debug':    args.debug,
708            }),
709            (args.ctlpath, EtherWebSocketControlHandler, {
710                'ioloop':   ioloop,
711                'switch':   switch,
712                'htpasswd': args.ctlhtpasswd,
713                'debug':    args.debug,
714            }),
715        ])
716        server = HTTPServer(app, ssl_options=sslopt)
717        server.listen(args.port, address=args.host)
718
719    else:
720        app = Application([(args.path, EtherWebSocketHandler, {
721            'switch':   switch,
722            'htpasswd': args.htpasswd,
723            'debug':    args.debug,
724        })])
725        server = HTTPServer(app, ssl_options=sslopt)
726        server.listen(args.port, address=args.host)
727
728        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
729            'ioloop':   ioloop,
730            'switch':   switch,
731            'htpasswd': args.ctlhtpasswd,
732            'debug':    args.debug,
733        })])
734        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
735        ctlserver.listen(args.ctlport, address=args.ctlhost)
736
737    if not args.foreground:
738        daemonize()
739
740    ioloop.start()
741
742
743def start_ctl(args):
744    def request(args, method, params):
745        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
746        data = json.dumps({'method': method, 'params': params})
747        req = urllib2.Request(args.ctlurl)
748        req.add_header('Content-type', 'application/json')
749        if args.ctluser:
750            if not args.ctlpasswd:
751                args.ctlpasswd = getpass.getpass()
752            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
753            req.add_header('Authorization', 'Basic %s' % token)
754        return json.loads(urllib2.urlopen(req, data).read())
755
756    def handle_ctl_addport(args):
757        params = [{
758            'type':    args.type,
759            'target':  args.target,
760            'options': {
761                'insecure': args.insecure,
762                'cacerts':  args.cacerts,
763                'user':     args.user,
764                'passwd':   args.passwd,
765            }
766        }]
767        return request(args, 'addPort', params)
768
769    def handle_ctl_shutport(args):
770        if args.port <= 0:
771            raise ValueError('invalid port: %d' % args.port)
772        params = [{'port': args.port, 'shut': args.no}]
773        return request(args, 'shutPort', params)
774
775    def handle_ctl_delport(args):
776        if args.port <= 0:
777            raise ValueError('invalid port: %d' % args.port)
778        params = [{'port': args.port}]
779        return request(args, 'delPort', params)
780
781    def handle_ctl_listport(args):
782        return request(args, 'listPort', [])
783
784    res = locals()['handle_ctl_' + args.control_method](args)
785
786    if res['error']:
787        print(res['error']['message'])
788    else:
789        print(yaml.safe_dump(res['result']['portlist']).strip())
790
791
792def main():
793    parser = argparse.ArgumentParser()
794    subcommand = parser.add_subparsers(dest='subcommand')
795
796    # -- sw command parser
797    parser_s = subcommand.add_parser('sw')
798
799    parser_s.add_argument('--debug', action='store_true', default=False)
800    parser_s.add_argument('--foreground', action='store_true', default=False)
801    parser_s.add_argument('--ageout', type=int, default=300)
802
803    parser_s.add_argument('--path', default='/')
804    parser_s.add_argument('--host', default='')
805    parser_s.add_argument('--port', type=int)
806    parser_s.add_argument('--htpasswd')
807    parser_s.add_argument('--sslkey')
808    parser_s.add_argument('--sslcert')
809
810    parser_s.add_argument('--ctlpath', default='/ctl')
811    parser_s.add_argument('--ctlhost', default='')
812    parser_s.add_argument('--ctlport', type=int)
813    parser_s.add_argument('--ctlhtpasswd')
814    parser_s.add_argument('--ctlsslkey')
815    parser_s.add_argument('--ctlsslcert')
816
817    # -- ctl command parser
818    parser_c = subcommand.add_parser('ctl')
819    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
820    parser_c.add_argument('--ctluser')
821    parser_c.add_argument('--ctlpasswd')
822
823    control_method = parser_c.add_subparsers(dest='control_method')
824
825    parser_c_ap = control_method.add_parser('addport')
826    parser_c_ap.add_argument(
827        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
828    parser_c_ap.add_argument('target')
829    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
830    parser_c_ap.add_argument('--cacerts')
831    parser_c_ap.add_argument('--user')
832    parser_c_ap.add_argument('--passwd')
833
834    parser_c_sp = control_method.add_parser('shutport')
835    parser_c_sp.add_argument('port', type=int)
836    parser_c_sp.add_argument('--no', action='store_false', default=True)
837
838    parser_c_dp = control_method.add_parser('delport')
839    parser_c_dp.add_argument('port', type=int)
840
841    parser_c_lp = control_method.add_parser('listport')
842
843    # -- go
844    args = parser.parse_args()
845    globals()['start_' + args.subcommand](args)
846
847
848if __name__ == '__main__':
849    main()
Note: See TracBrowser for help on using the repository browser.