source: etherws/trunk/etherws.py @ 197

Revision 197, 26.1 KB checked in by atzm, 12 years ago (diff)
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[186]4#                          Ethernet over WebSocket
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[195]11#   - PyYAML-3.10
[133]12#
13# ===========================================================================
14# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
15# All rights reserved.
16#
17# Redistribution and use in source and binary forms, with or without
18# modification, are permitted provided that the following conditions are met:
19#
20# 1. Redistributions of source code must retain the above copyright notice,
21#    this list of conditions and the following disclaimer.
22# 2. Redistributions in binary form must reproduce the above copyright
23#    notice, this list of conditions and the following disclaimer in the
24#    documentation and/or other materials provided with the distribution.
25#
26# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36# POSSIBILITY OF SUCH DAMAGE.
37# ===========================================================================
38#
39# $Id$
40
41import os
42import sys
[156]43import ssl
[160]44import time
[183]45import json
[194]46import yaml
[175]47import fcntl
[150]48import base64
[190]49import urllib2
[150]50import hashlib
[151]51import getpass
[133]52import argparse
[165]53import traceback
[133]54
[185]55import tornado
[133]56import websocket
57
[183]58from tornado.web import Application, RequestHandler
[182]59from tornado.websocket import WebSocketHandler
[183]60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
[166]63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]64
[166]65
[160]66class DebugMixIn(object):
[166]67    def dprintf(self, msg, func=lambda: ()):
[160]68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]70            sys.stderr.write(prefix + (msg % func()))
[160]71
72
[164]73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
[176]77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
[164]80
81    @property
[176]82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
[164]86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]101        return 0
[164]102
103
[166]104class FDB(DebugMixIn):
[196]105    class Entry(object):
[195]106        def __init__(self, port, ageout):
107            self.port = port
108            self._time = time.time()
109            self._ageout = ageout
110
111        @property
112        def age(self):
113            return time.time() - self._time
114
115        @property
116        def agedout(self):
117            return self.age > self._ageout
118
[167]119    def __init__(self, ageout, debug=False):
[164]120        self._ageout = ageout
121        self._debug = debug
[195]122        self._table = {}
[164]123
[195]124    def _set_entry(self, vid, mac, port):
125        if vid not in self._table:
126            self._table[vid] = {}
[196]127        self._table[vid][mac] = self.Entry(port, self._ageout)
[164]128
[195]129    def _del_entry(self, vid, mac):
130        if vid in self._table:
131            if mac in self._table[vid]:
132                del self._table[vid][mac]
133            if not self._table[vid]:
134                del self._table[vid]
[164]135
[195]136    def get_entry(self, vid, mac):
137        try:
138            entry = self._table[vid][mac]
139        except KeyError:
[164]140            return None
141
[195]142        if not entry.agedout:
143            return entry
[164]144
[195]145        self._del_entry(vid, mac)
146        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
147                     lambda: (entry.port.number, vid, mac.encode('hex')))
[164]148
[195]149    def get_vid_list(self):
150        return self._table.keys()
151
152    def get_mac_list(self, vid):
153        return self._table[vid].keys()
154
155    def lookup(self, frame):
156        mac = frame.dst_mac
157        vid = frame.vid
158        entry = self.get_entry(vid, mac)
159        return getattr(entry, 'port', None)
160
[166]161    def learn(self, port, frame):
162        mac = frame.src_mac
163        vid = frame.vid
[195]164        self._set_entry(vid, mac, port)
[183]165        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
166                     lambda: (port.number, vid, mac.encode('hex')))
[166]167
[164]168    def delete(self, port):
[195]169        for vid in self.get_vid_list():
170            for mac in self.get_mac_list(vid):
171                entry = self.get_entry(vid, mac)
172                if entry and entry.port.number == port.number:
173                    self._del_entry(vid, mac)
[183]174                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
175                                 lambda: (port.number, vid, mac.encode('hex')))
[164]176
177
[195]178class SwitchingHub(DebugMixIn):
179    class Port(object):
180        def __init__(self, number, interface):
181            self.number = number
182            self.interface = interface
183            self.tx = 0
184            self.rx = 0
185            self.shut = False
[183]186
[195]187        @staticmethod
188        def cmp_by_number(x, y):
189            return cmp(x.number, y.number)
[183]190
[166]191    def __init__(self, fdb, debug=False):
[197]192        self.fdb = fdb
[133]193        self._debug = debug
[183]194        self._table = {}
195        self._next = 1
[133]196
[183]197    @property
198    def portlist(self):
[195]199        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
[133]200
[183]201    def get_port(self, portnum):
202        return self._table[portnum]
203
204    def register_port(self, interface):
[186]205        try:
[187]206            self._set_privattr('portnum', interface, self._next)  # XXX
[195]207            self._table[self._next] = self.Port(self._next, interface)
[186]208            return self._next
209        finally:
210            self._next += 1
[183]211
212    def unregister_port(self, interface):
[187]213        portnum = self._get_privattr('portnum', interface)
214        self._del_privattr('portnum', interface)
[197]215        self.fdb.delete(self._table[portnum])
[187]216        del self._table[portnum]
[183]217
218    def send(self, dst_interfaces, frame):
[187]219        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
220        ports = (self._table[n] for n in portnums)
221        ports = (p for p in ports if not p.shut)
[195]222        ports = sorted(ports, cmp=self.Port.cmp_by_number)
[183]223
224        for p in ports:
225            p.interface.write_message(frame.data, True)
226            p.tx += 1
227
228        if ports:
229            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
230                         lambda: (','.join(str(p.number) for p in ports),
231                                  frame.vid,
232                                  frame.src_mac.encode('hex'),
233                                  frame.dst_mac.encode('hex')))
234
235    def receive(self, src_interface, frame):
[187]236        port = self._table[self._get_privattr('portnum', src_interface)]
[183]237
238        if not port.shut:
239            port.rx += 1
240            self._forward(port, frame)
241
242    def _forward(self, src_port, frame):
[166]243        try:
[176]244            if not frame.src_multicast:
[197]245                self.fdb.learn(src_port, frame)
[133]246
[176]247            if not frame.dst_multicast:
[197]248                dst_port = self.fdb.lookup(frame)
[164]249
[166]250                if dst_port:
[183]251                    self.send([dst_port.interface], frame)
[166]252                    return
[133]253
[187]254            ports = set(self.portlist) - set([src_port])
[183]255            self.send((p.interface for p in ports), frame)
[162]256
[166]257        except:  # ex. received invalid frame
258            traceback.print_exc()
[133]259
[187]260    def _privattr(self, name):
261        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
[164]262
[187]263    def _set_privattr(self, name, obj, value):
264        return setattr(obj, self._privattr(name), value)
265
266    def _get_privattr(self, name, obj, defaults=None):
267        return getattr(obj, self._privattr(name), defaults)
268
269    def _del_privattr(self, name, obj):
270        return delattr(obj, self._privattr(name))
271
272
[179]273class Htpasswd(object):
274    def __init__(self, path):
275        self._path = path
276        self._stat = None
277        self._data = {}
278
279    def auth(self, name, passwd):
280        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
281        return self._data.get(name) == passwd
282
283    def load(self):
284        old_stat = self._stat
285
286        with open(self._path) as fp:
287            fileno = fp.fileno()
288            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
289            self._stat = os.fstat(fileno)
290
291            unchanged = old_stat and \
292                        old_stat.st_ino == self._stat.st_ino and \
293                        old_stat.st_dev == self._stat.st_dev and \
294                        old_stat.st_mtime == self._stat.st_mtime
295
296            if not unchanged:
297                self._data = self._parse(fp)
298
299        return self
300
301    def _parse(self, fp):
302        data = {}
303        for line in fp:
304            line = line.strip()
305            if 0 <= line.find(':'):
306                name, passwd = line.split(':', 1)
307                if passwd.startswith('{SHA}'):
308                    data[name] = passwd[5:]
309        return data
310
311
[182]312class BasicAuthMixIn(object):
313    def _execute(self, transforms, *args, **kwargs):
314        def do_execute():
315            sp = super(BasicAuthMixIn, self)
316            return sp._execute(transforms, *args, **kwargs)
317
318        def auth_required():
[185]319            stream = getattr(self, 'stream', self.request.connection.stream)
320            stream.write(tornado.escape.utf8(
[182]321                'HTTP/1.1 401 Authorization Required\r\n'
322                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
323            ))
[185]324            stream.close()
[182]325
326        try:
327            if not self._htpasswd:
328                return do_execute()
329
330            creds = self.request.headers.get('Authorization')
331
332            if not creds or not creds.startswith('Basic '):
333                return auth_required()
334
335            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
336
337            if self._htpasswd.load().auth(name, passwd):
338                return do_execute()
339        except:
340            traceback.print_exc()
341
342        return auth_required()
343
344
[186]345class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[191]346    IFTYPE = 'server'
347
[186]348    def __init__(self, app, req, switch, htpasswd=None, debug=False):
349        super(EtherWebSocketHandler, self).__init__(app, req)
350        self._switch = switch
351        self._htpasswd = htpasswd
352        self._debug = debug
353
354    def get_target(self):
355        return self.request.remote_ip
356
357    def open(self):
358        try:
359            return self._switch.register_port(self)
360        finally:
361            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
362
363    def on_message(self, message):
364        self._switch.receive(self, EthernetFrame(message))
365
366    def on_close(self):
367        self._switch.unregister_port(self)
368        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
369
370
[166]371class TapHandler(DebugMixIn):
[191]372    IFTYPE = 'tap'
[166]373    READ_SIZE = 65535
374
[178]375    def __init__(self, ioloop, switch, dev, debug=False):
376        self._ioloop = ioloop
[166]377        self._switch = switch
378        self._dev = dev
379        self._debug = debug
380        self._tap = None
381
[186]382    def get_target(self):
[183]383        if self.closed:
384            return self._dev
385        return self._tap.name
386
[186]387    @property
388    def closed(self):
389        return not self._tap
390
[166]391    def open(self):
392        if not self.closed:
393            raise ValueError('already opened')
394        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
395        self._tap.up()
[178]396        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]397        return self._switch.register_port(self)
[166]398
399    def close(self):
400        if self.closed:
401            raise ValueError('I/O operation on closed tap')
[186]402        self._switch.unregister_port(self)
[178]403        self._ioloop.remove_handler(self.fileno())
[166]404        self._tap.close()
405        self._tap = None
406
407    def fileno(self):
408        if self.closed:
409            raise ValueError('I/O operation on closed tap')
410        return self._tap.fileno()
411
412    def write_message(self, message, binary=False):
413        if self.closed:
414            raise ValueError('I/O operation on closed tap')
415        self._tap.write(message)
416
[138]417    def __call__(self, fd, events):
[166]418        try:
[183]419            self._switch.receive(self, EthernetFrame(self._read()))
[166]420            return
421        except:
422            traceback.print_exc()
[178]423        self.close()
[166]424
425    def _read(self):
426        if self.closed:
427            raise ValueError('I/O operation on closed tap')
[162]428        buf = []
429        while True:
[166]430            buf.append(self._tap.read(self.READ_SIZE))
431            if len(buf[-1]) < self.READ_SIZE:
[162]432                break
[166]433        return ''.join(buf)
[162]434
435
[160]436class EtherWebSocketClient(DebugMixIn):
[191]437    IFTYPE = 'client'
438
[181]439    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]440        self._ioloop = ioloop
[166]441        self._switch = switch
[151]442        self._url = url
[181]443        self._ssl = ssl_
[160]444        self._debug = debug
[166]445        self._sock = None
[151]446        self._options = {}
447
[174]448        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
449            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]450            auth = ['Authorization: Basic %s' % token]
451            self._options['header'] = auth
452
[186]453    def get_target(self):
454        return self._url
455
[160]456    @property
457    def closed(self):
458        return not self._sock
459
[151]460    def open(self):
[181]461        sslwrap = websocket._SSLSocketWrapper
462
[160]463        if not self.closed:
464            raise websocket.WebSocketException('already opened')
[151]465
[181]466        if self._ssl:
467            websocket._SSLSocketWrapper = self._ssl
468
469        try:
470            self._sock = websocket.WebSocket()
471            self._sock.connect(self._url, **self._options)
472            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]473            return self._switch.register_port(self)
[181]474        finally:
475            websocket._SSLSocketWrapper = sslwrap
[186]476            self.dprintf('connected: %s\n', lambda: self._url)
[181]477
[151]478    def close(self):
[160]479        if self.closed:
480            raise websocket.WebSocketException('already closed')
[186]481        self._switch.unregister_port(self)
[178]482        self._ioloop.remove_handler(self.fileno())
[151]483        self._sock.close()
484        self._sock = None
[164]485        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]486
[165]487    def fileno(self):
488        if self.closed:
489            raise websocket.WebSocketException('closed socket')
490        return self._sock.io_sock.fileno()
491
[151]492    def write_message(self, message, binary=False):
[160]493        if self.closed:
494            raise websocket.WebSocketException('closed socket')
[151]495        if binary:
496            flag = websocket.ABNF.OPCODE_BINARY
[160]497        else:
498            flag = websocket.ABNF.OPCODE_TEXT
[151]499        self._sock.send(message, flag)
500
[165]501    def __call__(self, fd, events):
[151]502        try:
[165]503            data = self._sock.recv()
504            if data is not None:
[183]505                self._switch.receive(self, EthernetFrame(data))
[165]506                return
507        except:
508            traceback.print_exc()
[178]509        self.close()
[151]510
511
[183]512class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
513    NAMESPACE = 'etherws.control'
[191]514    IFTYPES = {
515        TapHandler.IFTYPE:           TapHandler,
516        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
[186]517    }
[183]518
519    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
520        super(EtherWebSocketControlHandler, self).__init__(app, req)
521        self._ioloop = ioloop
522        self._switch = switch
523        self._htpasswd = htpasswd
524        self._debug = debug
525
526    def post(self):
527        id_ = None
528
529        try:
530            req = json.loads(self.request.body)
531            method = req['method']
532            params = req['params']
533            id_ = req.get('id')
534
535            if not method.startswith(self.NAMESPACE + '.'):
536                raise ValueError('invalid method: %s' % method)
537
538            if not isinstance(params, list):
539                raise ValueError('invalid params: %s' % params)
540
541            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
542            result = getattr(self, handler)(params)
543            self.finish({'result': result, 'error': None, 'id': id_})
544
545        except Exception as e:
546            traceback.print_exc()
[192]547            msg = '%s: %s' % (e.__class__.__name__, str(e))
548            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
[183]549
550    def handle_listPort(self, params):
[188]551        list_ = [self._portstat(p) for p in self._switch.portlist]
[183]552        return {'portlist': list_}
553
554    def handle_addPort(self, params):
[186]555        list_ = []
[183]556        for p in params:
[186]557            type_ = p['type']
558            target = p['target']
559            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
[191]560            klass = self.IFTYPES[type_]
[186]561            interface = klass(self._ioloop, self._switch, target, **options)
562            portnum = interface.open()
563            list_.append(self._portstat(self._switch.get_port(portnum)))
564        return {'portlist': list_}
[183]565
566    def handle_delPort(self, params):
[186]567        list_ = []
[183]568        for p in params:
[186]569            port = self._switch.get_port(int(p['port']))
570            list_.append(self._portstat(port))
571            port.interface.close()
572        return {'portlist': list_}
[183]573
574    def handle_shutPort(self, params):
[186]575        list_ = []
[183]576        for p in params:
[186]577            port = self._switch.get_port(int(p['port']))
[190]578            port.shut = bool(p['shut'])
[186]579            list_.append(self._portstat(port))
580        return {'portlist': list_}
[183]581
[186]582    def _optparse_tap(self, opt):
583        return {'debug': self._debug}
[183]584
[186]585    def _optparse_client(self, opt):
586        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
587        if opt.get('insecure'):
588            args = {}
589        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
590        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
591        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
[183]592
593    @staticmethod
[186]594    def _portstat(port):
595        return {
596            'port':   port.number,
[191]597            'type':   port.interface.IFTYPE,
[186]598            'target': port.interface.get_target(),
599            'tx':     port.tx,
600            'rx':     port.rx,
601            'shut':   port.shut,
602        }
[183]603
604
[190]605def start_sw(args):
[186]606    def daemonize(nochdir=False, noclose=False):
607        if os.fork() > 0:
608            sys.exit(0)
[134]609
[186]610        os.setsid()
[134]611
[186]612        if os.fork() > 0:
613            sys.exit(0)
[134]614
[186]615        if not nochdir:
616            os.chdir('/')
[134]617
[186]618        if not noclose:
619            os.umask(0)
620            sys.stdin.close()
621            sys.stdout.close()
622            sys.stderr.close()
623            os.close(0)
624            os.close(1)
625            os.close(2)
626            sys.stdin = open(os.devnull)
627            sys.stdout = open(os.devnull, 'a')
628            sys.stderr = open(os.devnull, 'a')
[134]629
[186]630    def checkabspath(ns, path):
[184]631        val = getattr(ns, path, '')
632        if not val.startswith('/'):
633            raise ValueError('invalid %: %s' % (path, val))
634
635    def getsslopt(ns, key, cert):
636        kval = getattr(ns, key, None)
637        cval = getattr(ns, cert, None)
638        if kval and cval:
639            return {'keyfile': kval, 'certfile': cval}
640        elif kval or cval:
641            raise ValueError('both %s and %s are required' % (key, cert))
642        return None
643
[186]644    def setrealpath(ns, *keys):
645        for k in keys:
646            v = getattr(ns, k, None)
647            if v is not None:
648                v = os.path.realpath(v)
649                open(v).close()  # check readable
650                setattr(ns, k, v)
651
[184]652    def setport(ns, port, isssl):
653        val = getattr(ns, port, None)
654        if val is None:
655            if isssl:
656                return setattr(ns, port, 443)
657            return setattr(ns, port, 80)
658        if not (0 <= val <= 65535):
659            raise ValueError('invalid %s: %s' % (port, val))
660
661    def sethtpasswd(ns, htpasswd):
662        val = getattr(ns, htpasswd, None)
663        if val:
664            return setattr(ns, htpasswd, Htpasswd(val))
665
[183]666    #if args.debug:
667    #    websocket.enableTrace(True)
668
669    if args.ageout <= 0:
670        raise ValueError('invalid ageout: %s' % args.ageout)
671
[186]672    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
673    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]674
[186]675    checkabspath(args, 'path')
676    checkabspath(args, 'ctlpath')
[183]677
[184]678    sslopt = getsslopt(args, 'sslkey', 'sslcert')
679    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]680
[184]681    setport(args, 'port', sslopt)
682    setport(args, 'ctlport', ctlsslopt)
[143]683
[184]684    sethtpasswd(args, 'htpasswd')
685    sethtpasswd(args, 'ctlhtpasswd')
[167]686
[183]687    ioloop = IOLoop.instance()
[167]688    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]689    switch = SwitchingHub(fdb, debug=args.debug)
[167]690
[184]691    if args.port == args.ctlport and args.host == args.ctlhost:
692        if args.path == args.ctlpath:
693            raise ValueError('same path/ctlpath on same host')
694        if args.sslkey != args.ctlsslkey:
[189]695            raise ValueError('different sslkey/ctlsslkey on same host')
[184]696        if args.sslcert != args.ctlsslcert:
[189]697            raise ValueError('different sslcert/ctlsslcert on same host')
[133]698
[184]699        app = Application([
700            (args.path, EtherWebSocketHandler, {
701                'switch':   switch,
702                'htpasswd': args.htpasswd,
703                'debug':    args.debug,
704            }),
705            (args.ctlpath, EtherWebSocketControlHandler, {
706                'ioloop':   ioloop,
707                'switch':   switch,
708                'htpasswd': args.ctlhtpasswd,
709                'debug':    args.debug,
710            }),
711        ])
712        server = HTTPServer(app, ssl_options=sslopt)
713        server.listen(args.port, address=args.host)
[151]714
[184]715    else:
716        app = Application([(args.path, EtherWebSocketHandler, {
717            'switch':   switch,
718            'htpasswd': args.htpasswd,
719            'debug':    args.debug,
720        })])
721        server = HTTPServer(app, ssl_options=sslopt)
722        server.listen(args.port, address=args.host)
723
724        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
725            'ioloop':   ioloop,
726            'switch':   switch,
727            'htpasswd': args.ctlhtpasswd,
728            'debug':    args.debug,
729        })])
730        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
731        ctlserver.listen(args.ctlport, address=args.ctlhost)
732
[151]733    if not args.foreground:
734        daemonize()
735
[138]736    ioloop.start()
[133]737
738
[190]739def start_ctl(args):
740    def request(args, method, params):
741        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
742        data = json.dumps({'method': method, 'params': params})
743        req = urllib2.Request(args.ctlurl)
744        req.add_header('Content-type', 'application/json')
745        if args.ctluser:
746            if not args.ctlpasswd:
747                args.ctlpasswd = getpass.getpass()
748            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
749            req.add_header('Authorization', 'Basic %s' % token)
750        return json.loads(urllib2.urlopen(req, data).read())
751
752    def handle_ctl_addport(args):
[191]753        params = [{
754            'type':    args.type,
755            'target':  args.target,
756            'options': {
[193]757                'insecure': args.insecure,
[191]758                'cacerts':  args.cacerts,
759                'user':     args.user,
760                'passwd':   args.passwd,
761            }
762        }]
[192]763        return request(args, 'addPort', params)
[190]764
765    def handle_ctl_shutport(args):
766        if args.port <= 0:
767            raise ValueError('invalid port: %d' % args.port)
[192]768        params = [{'port': args.port, 'shut': args.no}]
769        return request(args, 'shutPort', params)
[190]770
771    def handle_ctl_delport(args):
772        if args.port <= 0:
773            raise ValueError('invalid port: %d' % args.port)
[192]774        params = [{'port': args.port}]
775        return request(args, 'delPort', params)
[190]776
777    def handle_ctl_listport(args):
[192]778        return request(args, 'listPort', [])
[190]779
[192]780    res = locals()['handle_ctl_' + args.control_method](args)
[190]781
[192]782    if res['error']:
783        print(res['error']['message'])
784    else:
785        print(yaml.safe_dump(res['result']['portlist']).strip())
[190]786
[192]787
[186]788def main():
789    parser = argparse.ArgumentParser()
[190]790    subcommand = parser.add_subparsers(dest='subcommand')
[186]791
[190]792    # -- sw command parser
793    parser_s = subcommand.add_parser('sw')
794
[186]795    parser_s.add_argument('--debug', action='store_true', default=False)
796    parser_s.add_argument('--foreground', action='store_true', default=False)
[190]797    parser_s.add_argument('--ageout', type=int, default=300)
[186]798
[190]799    parser_s.add_argument('--path', default='/')
800    parser_s.add_argument('--host', default='')
801    parser_s.add_argument('--port', type=int)
802    parser_s.add_argument('--htpasswd')
803    parser_s.add_argument('--sslkey')
804    parser_s.add_argument('--sslcert')
[186]805
[190]806    parser_s.add_argument('--ctlpath', default='/ctl')
807    parser_s.add_argument('--ctlhost', default='')
808    parser_s.add_argument('--ctlport', type=int)
809    parser_s.add_argument('--ctlhtpasswd')
810    parser_s.add_argument('--ctlsslkey')
811    parser_s.add_argument('--ctlsslcert')
[186]812
[190]813    # -- ctl command parser
814    parser_c = subcommand.add_parser('ctl')
815    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
816    parser_c.add_argument('--ctluser')
817    parser_c.add_argument('--ctlpasswd')
818
819    control_method = parser_c.add_subparsers(dest='control_method')
820
821    parser_c_ap = control_method.add_parser('addport')
[191]822    parser_c_ap.add_argument(
823        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
[190]824    parser_c_ap.add_argument('target')
825    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
826    parser_c_ap.add_argument('--cacerts')
827    parser_c_ap.add_argument('--user')
828    parser_c_ap.add_argument('--passwd')
829
830    parser_c_sp = control_method.add_parser('shutport')
831    parser_c_sp.add_argument('port', type=int)
832    parser_c_sp.add_argument('--no', action='store_false', default=True)
833
834    parser_c_dp = control_method.add_parser('delport')
835    parser_c_dp.add_argument('port', type=int)
836
837    parser_c_lp = control_method.add_parser('listport')
838
839    # -- go
[186]840    args = parser.parse_args()
841    globals()['start_' + args.subcommand](args)
842
843
[133]844if __name__ == '__main__':
845    main()
Note: See TracBrowser for help on using the repository browser.