source: etherws/trunk/etherws.py @ 194

Revision 194, 25.5 KB checked in by atzm, 12 years ago (diff)
  • depends on PyYAML
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[186]4#                          Ethernet over WebSocket
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[133]11#
12# ===========================================================================
13# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
14# All rights reserved.
15#
16# Redistribution and use in source and binary forms, with or without
17# modification, are permitted provided that the following conditions are met:
18#
19# 1. Redistributions of source code must retain the above copyright notice,
20#    this list of conditions and the following disclaimer.
21# 2. Redistributions in binary form must reproduce the above copyright
22#    notice, this list of conditions and the following disclaimer in the
23#    documentation and/or other materials provided with the distribution.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36# ===========================================================================
37#
38# $Id$
39
40import os
41import sys
[156]42import ssl
[160]43import time
[183]44import json
[194]45import yaml
[175]46import fcntl
[150]47import base64
[190]48import urllib2
[150]49import hashlib
[151]50import getpass
[133]51import argparse
[165]52import traceback
[133]53
[185]54import tornado
[133]55import websocket
56
[183]57from tornado.web import Application, RequestHandler
[182]58from tornado.websocket import WebSocketHandler
[183]59from tornado.httpserver import HTTPServer
60from tornado.ioloop import IOLoop
61
[166]62from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]63
[166]64
[160]65class DebugMixIn(object):
[166]66    def dprintf(self, msg, func=lambda: ()):
[160]67        if self._debug:
68            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]69            sys.stderr.write(prefix + (msg % func()))
[160]70
71
[164]72class EthernetFrame(object):
73    def __init__(self, data):
74        self.data = data
75
[176]76    @property
77    def dst_multicast(self):
78        return ord(self.data[0]) & 1
[164]79
80    @property
[176]81    def src_multicast(self):
82        return ord(self.data[6]) & 1
83
84    @property
[164]85    def dst_mac(self):
86        return self.data[:6]
87
88    @property
89    def src_mac(self):
90        return self.data[6:12]
91
92    @property
93    def tagged(self):
94        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
95
96    @property
97    def vid(self):
98        if self.tagged:
99            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]100        return 0
[164]101
102
[166]103class FDB(DebugMixIn):
[167]104    def __init__(self, ageout, debug=False):
[164]105        self._ageout = ageout
106        self._debug = debug
[166]107        self._dict = {}
[164]108
109    def lookup(self, frame):
110        mac = frame.dst_mac
111        vid = frame.vid
112
[177]113        group = self._dict.get(vid)
[164]114        if not group:
115            return None
116
[177]117        entry = group.get(mac)
[164]118        if not entry:
119            return None
120
121        if time.time() - entry['time'] > self._ageout:
[183]122            port = self._dict[vid][mac]['port']
[166]123            del self._dict[vid][mac]
124            if not self._dict[vid]:
125                del self._dict[vid]
[183]126            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
127                         lambda: (port.number, vid, mac.encode('hex')))
[164]128            return None
129
130        return entry['port']
131
[166]132    def learn(self, port, frame):
133        mac = frame.src_mac
134        vid = frame.vid
135
136        if vid not in self._dict:
137            self._dict[vid] = {}
138
139        self._dict[vid][mac] = {'time': time.time(), 'port': port}
[183]140        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
141                     lambda: (port.number, vid, mac.encode('hex')))
[166]142
[164]143    def delete(self, port):
[166]144        for vid in self._dict.keys():
145            for mac in self._dict[vid].keys():
[183]146                if self._dict[vid][mac]['port'].number == port.number:
[166]147                    del self._dict[vid][mac]
[183]148                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
149                                 lambda: (port.number, vid, mac.encode('hex')))
[166]150            if not self._dict[vid]:
151                del self._dict[vid]
[164]152
153
[183]154class SwitchPort(object):
155    def __init__(self, number, interface):
156        self.number = number
157        self.interface = interface
158        self.tx = 0
159        self.rx = 0
160        self.shut = False
161
162    @staticmethod
163    def cmp_by_number(x, y):
164        return cmp(x.number, y.number)
165
166
[166]167class SwitchingHub(DebugMixIn):
168    def __init__(self, fdb, debug=False):
169        self._fdb = fdb
[133]170        self._debug = debug
[183]171        self._table = {}
172        self._next = 1
[133]173
[183]174    @property
175    def portlist(self):
176        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)
[133]177
[183]178    def get_port(self, portnum):
179        return self._table[portnum]
180
181    def register_port(self, interface):
[186]182        try:
[187]183            self._set_privattr('portnum', interface, self._next)  # XXX
[186]184            self._table[self._next] = SwitchPort(self._next, interface)
185            return self._next
186        finally:
187            self._next += 1
[183]188
189    def unregister_port(self, interface):
[187]190        portnum = self._get_privattr('portnum', interface)
191        self._del_privattr('portnum', interface)
192        self._fdb.delete(self._table[portnum])
193        del self._table[portnum]
[183]194
195    def send(self, dst_interfaces, frame):
[187]196        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
197        ports = (self._table[n] for n in portnums)
198        ports = (p for p in ports if not p.shut)
199        ports = sorted(ports, cmp=SwitchPort.cmp_by_number)
[183]200
201        for p in ports:
202            p.interface.write_message(frame.data, True)
203            p.tx += 1
204
205        if ports:
206            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
207                         lambda: (','.join(str(p.number) for p in ports),
208                                  frame.vid,
209                                  frame.src_mac.encode('hex'),
210                                  frame.dst_mac.encode('hex')))
211
212    def receive(self, src_interface, frame):
[187]213        port = self._table[self._get_privattr('portnum', src_interface)]
[183]214
215        if not port.shut:
216            port.rx += 1
217            self._forward(port, frame)
218
219    def _forward(self, src_port, frame):
[166]220        try:
[176]221            if not frame.src_multicast:
[172]222                self._fdb.learn(src_port, frame)
[133]223
[176]224            if not frame.dst_multicast:
[166]225                dst_port = self._fdb.lookup(frame)
[164]226
[166]227                if dst_port:
[183]228                    self.send([dst_port.interface], frame)
[166]229                    return
[133]230
[187]231            ports = set(self.portlist) - set([src_port])
[183]232            self.send((p.interface for p in ports), frame)
[162]233
[166]234        except:  # ex. received invalid frame
235            traceback.print_exc()
[133]236
[187]237    def _privattr(self, name):
238        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
[164]239
[187]240    def _set_privattr(self, name, obj, value):
241        return setattr(obj, self._privattr(name), value)
242
243    def _get_privattr(self, name, obj, defaults=None):
244        return getattr(obj, self._privattr(name), defaults)
245
246    def _del_privattr(self, name, obj):
247        return delattr(obj, self._privattr(name))
248
249
[179]250class Htpasswd(object):
251    def __init__(self, path):
252        self._path = path
253        self._stat = None
254        self._data = {}
255
256    def auth(self, name, passwd):
257        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
258        return self._data.get(name) == passwd
259
260    def load(self):
261        old_stat = self._stat
262
263        with open(self._path) as fp:
264            fileno = fp.fileno()
265            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
266            self._stat = os.fstat(fileno)
267
268            unchanged = old_stat and \
269                        old_stat.st_ino == self._stat.st_ino and \
270                        old_stat.st_dev == self._stat.st_dev and \
271                        old_stat.st_mtime == self._stat.st_mtime
272
273            if not unchanged:
274                self._data = self._parse(fp)
275
276        return self
277
278    def _parse(self, fp):
279        data = {}
280        for line in fp:
281            line = line.strip()
282            if 0 <= line.find(':'):
283                name, passwd = line.split(':', 1)
284                if passwd.startswith('{SHA}'):
285                    data[name] = passwd[5:]
286        return data
287
288
[182]289class BasicAuthMixIn(object):
290    def _execute(self, transforms, *args, **kwargs):
291        def do_execute():
292            sp = super(BasicAuthMixIn, self)
293            return sp._execute(transforms, *args, **kwargs)
294
295        def auth_required():
[185]296            stream = getattr(self, 'stream', self.request.connection.stream)
297            stream.write(tornado.escape.utf8(
[182]298                'HTTP/1.1 401 Authorization Required\r\n'
299                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
300            ))
[185]301            stream.close()
[182]302
303        try:
304            if not self._htpasswd:
305                return do_execute()
306
307            creds = self.request.headers.get('Authorization')
308
309            if not creds or not creds.startswith('Basic '):
310                return auth_required()
311
312            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
313
314            if self._htpasswd.load().auth(name, passwd):
315                return do_execute()
316        except:
317            traceback.print_exc()
318
319        return auth_required()
320
321
[186]322class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[191]323    IFTYPE = 'server'
324
[186]325    def __init__(self, app, req, switch, htpasswd=None, debug=False):
326        super(EtherWebSocketHandler, self).__init__(app, req)
327        self._switch = switch
328        self._htpasswd = htpasswd
329        self._debug = debug
330
331    def get_target(self):
332        return self.request.remote_ip
333
334    def open(self):
335        try:
336            return self._switch.register_port(self)
337        finally:
338            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
339
340    def on_message(self, message):
341        self._switch.receive(self, EthernetFrame(message))
342
343    def on_close(self):
344        self._switch.unregister_port(self)
345        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
346
347
[166]348class TapHandler(DebugMixIn):
[191]349    IFTYPE = 'tap'
[166]350    READ_SIZE = 65535
351
[178]352    def __init__(self, ioloop, switch, dev, debug=False):
353        self._ioloop = ioloop
[166]354        self._switch = switch
355        self._dev = dev
356        self._debug = debug
357        self._tap = None
358
[186]359    def get_target(self):
[183]360        if self.closed:
361            return self._dev
362        return self._tap.name
363
[186]364    @property
365    def closed(self):
366        return not self._tap
367
[166]368    def open(self):
369        if not self.closed:
370            raise ValueError('already opened')
371        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
372        self._tap.up()
[178]373        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]374        return self._switch.register_port(self)
[166]375
376    def close(self):
377        if self.closed:
378            raise ValueError('I/O operation on closed tap')
[186]379        self._switch.unregister_port(self)
[178]380        self._ioloop.remove_handler(self.fileno())
[166]381        self._tap.close()
382        self._tap = None
383
384    def fileno(self):
385        if self.closed:
386            raise ValueError('I/O operation on closed tap')
387        return self._tap.fileno()
388
389    def write_message(self, message, binary=False):
390        if self.closed:
391            raise ValueError('I/O operation on closed tap')
392        self._tap.write(message)
393
[138]394    def __call__(self, fd, events):
[166]395        try:
[183]396            self._switch.receive(self, EthernetFrame(self._read()))
[166]397            return
398        except:
399            traceback.print_exc()
[178]400        self.close()
[166]401
402    def _read(self):
403        if self.closed:
404            raise ValueError('I/O operation on closed tap')
[162]405        buf = []
406        while True:
[166]407            buf.append(self._tap.read(self.READ_SIZE))
408            if len(buf[-1]) < self.READ_SIZE:
[162]409                break
[166]410        return ''.join(buf)
[162]411
412
[160]413class EtherWebSocketClient(DebugMixIn):
[191]414    IFTYPE = 'client'
415
[181]416    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]417        self._ioloop = ioloop
[166]418        self._switch = switch
[151]419        self._url = url
[181]420        self._ssl = ssl_
[160]421        self._debug = debug
[166]422        self._sock = None
[151]423        self._options = {}
424
[174]425        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
426            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]427            auth = ['Authorization: Basic %s' % token]
428            self._options['header'] = auth
429
[186]430    def get_target(self):
431        return self._url
432
[160]433    @property
434    def closed(self):
435        return not self._sock
436
[151]437    def open(self):
[181]438        sslwrap = websocket._SSLSocketWrapper
439
[160]440        if not self.closed:
441            raise websocket.WebSocketException('already opened')
[151]442
[181]443        if self._ssl:
444            websocket._SSLSocketWrapper = self._ssl
445
446        try:
447            self._sock = websocket.WebSocket()
448            self._sock.connect(self._url, **self._options)
449            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]450            return self._switch.register_port(self)
[181]451        finally:
452            websocket._SSLSocketWrapper = sslwrap
[186]453            self.dprintf('connected: %s\n', lambda: self._url)
[181]454
[151]455    def close(self):
[160]456        if self.closed:
457            raise websocket.WebSocketException('already closed')
[186]458        self._switch.unregister_port(self)
[178]459        self._ioloop.remove_handler(self.fileno())
[151]460        self._sock.close()
461        self._sock = None
[164]462        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]463
[165]464    def fileno(self):
465        if self.closed:
466            raise websocket.WebSocketException('closed socket')
467        return self._sock.io_sock.fileno()
468
[151]469    def write_message(self, message, binary=False):
[160]470        if self.closed:
471            raise websocket.WebSocketException('closed socket')
[151]472        if binary:
473            flag = websocket.ABNF.OPCODE_BINARY
[160]474        else:
475            flag = websocket.ABNF.OPCODE_TEXT
[151]476        self._sock.send(message, flag)
477
[165]478    def __call__(self, fd, events):
[151]479        try:
[165]480            data = self._sock.recv()
481            if data is not None:
[183]482                self._switch.receive(self, EthernetFrame(data))
[165]483                return
484        except:
485            traceback.print_exc()
[178]486        self.close()
[151]487
488
[183]489class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
490    NAMESPACE = 'etherws.control'
[191]491    IFTYPES = {
492        TapHandler.IFTYPE:           TapHandler,
493        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
[186]494    }
[183]495
496    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
497        super(EtherWebSocketControlHandler, self).__init__(app, req)
498        self._ioloop = ioloop
499        self._switch = switch
500        self._htpasswd = htpasswd
501        self._debug = debug
502
503    def post(self):
504        id_ = None
505
506        try:
507            req = json.loads(self.request.body)
508            method = req['method']
509            params = req['params']
510            id_ = req.get('id')
511
512            if not method.startswith(self.NAMESPACE + '.'):
513                raise ValueError('invalid method: %s' % method)
514
515            if not isinstance(params, list):
516                raise ValueError('invalid params: %s' % params)
517
518            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
519            result = getattr(self, handler)(params)
520            self.finish({'result': result, 'error': None, 'id': id_})
521
522        except Exception as e:
523            traceback.print_exc()
[192]524            msg = '%s: %s' % (e.__class__.__name__, str(e))
525            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
[183]526
527    def handle_listPort(self, params):
[188]528        list_ = [self._portstat(p) for p in self._switch.portlist]
[183]529        return {'portlist': list_}
530
531    def handle_addPort(self, params):
[186]532        list_ = []
[183]533        for p in params:
[186]534            type_ = p['type']
535            target = p['target']
536            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
[191]537            klass = self.IFTYPES[type_]
[186]538            interface = klass(self._ioloop, self._switch, target, **options)
539            portnum = interface.open()
540            list_.append(self._portstat(self._switch.get_port(portnum)))
541        return {'portlist': list_}
[183]542
543    def handle_delPort(self, params):
[186]544        list_ = []
[183]545        for p in params:
[186]546            port = self._switch.get_port(int(p['port']))
547            list_.append(self._portstat(port))
548            port.interface.close()
549        return {'portlist': list_}
[183]550
551    def handle_shutPort(self, params):
[186]552        list_ = []
[183]553        for p in params:
[186]554            port = self._switch.get_port(int(p['port']))
[190]555            port.shut = bool(p['shut'])
[186]556            list_.append(self._portstat(port))
557        return {'portlist': list_}
[183]558
[186]559    def _optparse_tap(self, opt):
560        return {'debug': self._debug}
[183]561
[186]562    def _optparse_client(self, opt):
563        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
564        if opt.get('insecure'):
565            args = {}
566        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
567        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
568        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
[183]569
570    @staticmethod
[186]571    def _portstat(port):
572        return {
573            'port':   port.number,
[191]574            'type':   port.interface.IFTYPE,
[186]575            'target': port.interface.get_target(),
576            'tx':     port.tx,
577            'rx':     port.rx,
578            'shut':   port.shut,
579        }
[183]580
581
[190]582def start_sw(args):
[186]583    def daemonize(nochdir=False, noclose=False):
584        if os.fork() > 0:
585            sys.exit(0)
[134]586
[186]587        os.setsid()
[134]588
[186]589        if os.fork() > 0:
590            sys.exit(0)
[134]591
[186]592        if not nochdir:
593            os.chdir('/')
[134]594
[186]595        if not noclose:
596            os.umask(0)
597            sys.stdin.close()
598            sys.stdout.close()
599            sys.stderr.close()
600            os.close(0)
601            os.close(1)
602            os.close(2)
603            sys.stdin = open(os.devnull)
604            sys.stdout = open(os.devnull, 'a')
605            sys.stderr = open(os.devnull, 'a')
[134]606
[186]607    def checkabspath(ns, path):
[184]608        val = getattr(ns, path, '')
609        if not val.startswith('/'):
610            raise ValueError('invalid %: %s' % (path, val))
611
612    def getsslopt(ns, key, cert):
613        kval = getattr(ns, key, None)
614        cval = getattr(ns, cert, None)
615        if kval and cval:
616            return {'keyfile': kval, 'certfile': cval}
617        elif kval or cval:
618            raise ValueError('both %s and %s are required' % (key, cert))
619        return None
620
[186]621    def setrealpath(ns, *keys):
622        for k in keys:
623            v = getattr(ns, k, None)
624            if v is not None:
625                v = os.path.realpath(v)
626                open(v).close()  # check readable
627                setattr(ns, k, v)
628
[184]629    def setport(ns, port, isssl):
630        val = getattr(ns, port, None)
631        if val is None:
632            if isssl:
633                return setattr(ns, port, 443)
634            return setattr(ns, port, 80)
635        if not (0 <= val <= 65535):
636            raise ValueError('invalid %s: %s' % (port, val))
637
638    def sethtpasswd(ns, htpasswd):
639        val = getattr(ns, htpasswd, None)
640        if val:
641            return setattr(ns, htpasswd, Htpasswd(val))
642
[183]643    #if args.debug:
644    #    websocket.enableTrace(True)
645
646    if args.ageout <= 0:
647        raise ValueError('invalid ageout: %s' % args.ageout)
648
[186]649    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
650    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]651
[186]652    checkabspath(args, 'path')
653    checkabspath(args, 'ctlpath')
[183]654
[184]655    sslopt = getsslopt(args, 'sslkey', 'sslcert')
656    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]657
[184]658    setport(args, 'port', sslopt)
659    setport(args, 'ctlport', ctlsslopt)
[143]660
[184]661    sethtpasswd(args, 'htpasswd')
662    sethtpasswd(args, 'ctlhtpasswd')
[167]663
[183]664    ioloop = IOLoop.instance()
[167]665    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]666    switch = SwitchingHub(fdb, debug=args.debug)
[167]667
[184]668    if args.port == args.ctlport and args.host == args.ctlhost:
669        if args.path == args.ctlpath:
670            raise ValueError('same path/ctlpath on same host')
671        if args.sslkey != args.ctlsslkey:
[189]672            raise ValueError('different sslkey/ctlsslkey on same host')
[184]673        if args.sslcert != args.ctlsslcert:
[189]674            raise ValueError('different sslcert/ctlsslcert on same host')
[133]675
[184]676        app = Application([
677            (args.path, EtherWebSocketHandler, {
678                'switch':   switch,
679                'htpasswd': args.htpasswd,
680                'debug':    args.debug,
681            }),
682            (args.ctlpath, EtherWebSocketControlHandler, {
683                'ioloop':   ioloop,
684                'switch':   switch,
685                'htpasswd': args.ctlhtpasswd,
686                'debug':    args.debug,
687            }),
688        ])
689        server = HTTPServer(app, ssl_options=sslopt)
690        server.listen(args.port, address=args.host)
[151]691
[184]692    else:
693        app = Application([(args.path, EtherWebSocketHandler, {
694            'switch':   switch,
695            'htpasswd': args.htpasswd,
696            'debug':    args.debug,
697        })])
698        server = HTTPServer(app, ssl_options=sslopt)
699        server.listen(args.port, address=args.host)
700
701        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
702            'ioloop':   ioloop,
703            'switch':   switch,
704            'htpasswd': args.ctlhtpasswd,
705            'debug':    args.debug,
706        })])
707        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
708        ctlserver.listen(args.ctlport, address=args.ctlhost)
709
[151]710    if not args.foreground:
711        daemonize()
712
[138]713    ioloop.start()
[133]714
715
[190]716def start_ctl(args):
717    def request(args, method, params):
718        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
719        data = json.dumps({'method': method, 'params': params})
720        req = urllib2.Request(args.ctlurl)
721        req.add_header('Content-type', 'application/json')
722        if args.ctluser:
723            if not args.ctlpasswd:
724                args.ctlpasswd = getpass.getpass()
725            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
726            req.add_header('Authorization', 'Basic %s' % token)
727        return json.loads(urllib2.urlopen(req, data).read())
728
729    def handle_ctl_addport(args):
[191]730        params = [{
731            'type':    args.type,
732            'target':  args.target,
733            'options': {
[193]734                'insecure': args.insecure,
[191]735                'cacerts':  args.cacerts,
736                'user':     args.user,
737                'passwd':   args.passwd,
738            }
739        }]
[192]740        return request(args, 'addPort', params)
[190]741
742    def handle_ctl_shutport(args):
743        if args.port <= 0:
744            raise ValueError('invalid port: %d' % args.port)
[192]745        params = [{'port': args.port, 'shut': args.no}]
746        return request(args, 'shutPort', params)
[190]747
748    def handle_ctl_delport(args):
749        if args.port <= 0:
750            raise ValueError('invalid port: %d' % args.port)
[192]751        params = [{'port': args.port}]
752        return request(args, 'delPort', params)
[190]753
754    def handle_ctl_listport(args):
[192]755        return request(args, 'listPort', [])
[190]756
[192]757    res = locals()['handle_ctl_' + args.control_method](args)
[190]758
[192]759    if res['error']:
760        print(res['error']['message'])
761    else:
762        print(yaml.safe_dump(res['result']['portlist']).strip())
[190]763
[192]764
[186]765def main():
766    parser = argparse.ArgumentParser()
[190]767    subcommand = parser.add_subparsers(dest='subcommand')
[186]768
[190]769    # -- sw command parser
770    parser_s = subcommand.add_parser('sw')
771
[186]772    parser_s.add_argument('--debug', action='store_true', default=False)
773    parser_s.add_argument('--foreground', action='store_true', default=False)
[190]774    parser_s.add_argument('--ageout', type=int, default=300)
[186]775
[190]776    parser_s.add_argument('--path', default='/')
777    parser_s.add_argument('--host', default='')
778    parser_s.add_argument('--port', type=int)
779    parser_s.add_argument('--htpasswd')
780    parser_s.add_argument('--sslkey')
781    parser_s.add_argument('--sslcert')
[186]782
[190]783    parser_s.add_argument('--ctlpath', default='/ctl')
784    parser_s.add_argument('--ctlhost', default='')
785    parser_s.add_argument('--ctlport', type=int)
786    parser_s.add_argument('--ctlhtpasswd')
787    parser_s.add_argument('--ctlsslkey')
788    parser_s.add_argument('--ctlsslcert')
[186]789
[190]790    # -- ctl command parser
791    parser_c = subcommand.add_parser('ctl')
792    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
793    parser_c.add_argument('--ctluser')
794    parser_c.add_argument('--ctlpasswd')
795
796    control_method = parser_c.add_subparsers(dest='control_method')
797
798    parser_c_ap = control_method.add_parser('addport')
[191]799    parser_c_ap.add_argument(
800        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
[190]801    parser_c_ap.add_argument('target')
802    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
803    parser_c_ap.add_argument('--cacerts')
804    parser_c_ap.add_argument('--user')
805    parser_c_ap.add_argument('--passwd')
806
807    parser_c_sp = control_method.add_parser('shutport')
808    parser_c_sp.add_argument('port', type=int)
809    parser_c_sp.add_argument('--no', action='store_false', default=True)
810
811    parser_c_dp = control_method.add_parser('delport')
812    parser_c_dp.add_argument('port', type=int)
813
814    parser_c_lp = control_method.add_parser('listport')
815
816    # -- go
[186]817    args = parser.parse_args()
818    globals()['start_' + args.subcommand](args)
819
820
[133]821if __name__ == '__main__':
822    main()
Note: See TracBrowser for help on using the repository browser.