source: etherws/trunk/etherws.py @ 198

Revision 198, 26.8 KB checked in by atzm, 12 years ago (diff)
  • add ctl command "listfdb"
  • Property svn:keywords set to Id
RevLine 
[133]1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
[186]4#                          Ethernet over WebSocket
[133]5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
[136]9#   - websocket-client-0.7.0
[183]10#   - tornado-2.3
[195]11#   - PyYAML-3.10
[133]12#
13# ===========================================================================
14# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
15# All rights reserved.
16#
17# Redistribution and use in source and binary forms, with or without
18# modification, are permitted provided that the following conditions are met:
19#
20# 1. Redistributions of source code must retain the above copyright notice,
21#    this list of conditions and the following disclaimer.
22# 2. Redistributions in binary form must reproduce the above copyright
23#    notice, this list of conditions and the following disclaimer in the
24#    documentation and/or other materials provided with the distribution.
25#
26# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
29# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
30# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
31# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
32# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
33# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
34# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
35# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
36# POSSIBILITY OF SUCH DAMAGE.
37# ===========================================================================
38#
39# $Id$
40
41import os
42import sys
[156]43import ssl
[160]44import time
[183]45import json
[194]46import yaml
[175]47import fcntl
[150]48import base64
[190]49import urllib2
[150]50import hashlib
[151]51import getpass
[133]52import argparse
[165]53import traceback
[133]54
[185]55import tornado
[133]56import websocket
57
[183]58from tornado.web import Application, RequestHandler
[182]59from tornado.websocket import WebSocketHandler
[183]60from tornado.httpserver import HTTPServer
61from tornado.ioloop import IOLoop
62
[166]63from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
[133]64
[166]65
[160]66class DebugMixIn(object):
[166]67    def dprintf(self, msg, func=lambda: ()):
[160]68        if self._debug:
69            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
[164]70            sys.stderr.write(prefix + (msg % func()))
[160]71
72
[164]73class EthernetFrame(object):
74    def __init__(self, data):
75        self.data = data
76
[176]77    @property
78    def dst_multicast(self):
79        return ord(self.data[0]) & 1
[164]80
81    @property
[176]82    def src_multicast(self):
83        return ord(self.data[6]) & 1
84
85    @property
[164]86    def dst_mac(self):
87        return self.data[:6]
88
89    @property
90    def src_mac(self):
91        return self.data[6:12]
92
93    @property
94    def tagged(self):
95        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
96
97    @property
98    def vid(self):
99        if self.tagged:
100            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
[183]101        return 0
[164]102
[198]103    @staticmethod
104    def format_mac(mac, sep=':'):
105        return sep.join(b.encode('hex') for b in mac)
[164]106
[198]107
[166]108class FDB(DebugMixIn):
[196]109    class Entry(object):
[195]110        def __init__(self, port, ageout):
111            self.port = port
112            self._time = time.time()
113            self._ageout = ageout
114
115        @property
116        def age(self):
117            return time.time() - self._time
118
119        @property
120        def agedout(self):
121            return self.age > self._ageout
122
[167]123    def __init__(self, ageout, debug=False):
[164]124        self._ageout = ageout
125        self._debug = debug
[195]126        self._table = {}
[164]127
[195]128    def _set_entry(self, vid, mac, port):
129        if vid not in self._table:
130            self._table[vid] = {}
[196]131        self._table[vid][mac] = self.Entry(port, self._ageout)
[164]132
[195]133    def _del_entry(self, vid, mac):
134        if vid in self._table:
135            if mac in self._table[vid]:
136                del self._table[vid][mac]
137            if not self._table[vid]:
138                del self._table[vid]
[164]139
[195]140    def get_entry(self, vid, mac):
141        try:
142            entry = self._table[vid][mac]
143        except KeyError:
[164]144            return None
145
[195]146        if not entry.agedout:
147            return entry
[164]148
[195]149        self._del_entry(vid, mac)
150        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
151                     lambda: (entry.port.number, vid, mac.encode('hex')))
[164]152
[195]153    def get_vid_list(self):
[198]154        return sorted(self._table.iterkeys())
[195]155
156    def get_mac_list(self, vid):
[198]157        return sorted(self._table[vid].iterkeys())
[195]158
159    def lookup(self, frame):
160        mac = frame.dst_mac
161        vid = frame.vid
162        entry = self.get_entry(vid, mac)
163        return getattr(entry, 'port', None)
164
[166]165    def learn(self, port, frame):
166        mac = frame.src_mac
167        vid = frame.vid
[195]168        self._set_entry(vid, mac, port)
[183]169        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
170                     lambda: (port.number, vid, mac.encode('hex')))
[166]171
[164]172    def delete(self, port):
[195]173        for vid in self.get_vid_list():
174            for mac in self.get_mac_list(vid):
175                entry = self.get_entry(vid, mac)
176                if entry and entry.port.number == port.number:
177                    self._del_entry(vid, mac)
[183]178                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
179                                 lambda: (port.number, vid, mac.encode('hex')))
[164]180
181
[195]182class SwitchingHub(DebugMixIn):
183    class Port(object):
184        def __init__(self, number, interface):
185            self.number = number
186            self.interface = interface
187            self.tx = 0
188            self.rx = 0
189            self.shut = False
[183]190
[195]191        @staticmethod
192        def cmp_by_number(x, y):
193            return cmp(x.number, y.number)
[183]194
[166]195    def __init__(self, fdb, debug=False):
[197]196        self.fdb = fdb
[133]197        self._debug = debug
[183]198        self._table = {}
199        self._next = 1
[133]200
[183]201    @property
202    def portlist(self):
[195]203        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
[133]204
[183]205    def get_port(self, portnum):
206        return self._table[portnum]
207
208    def register_port(self, interface):
[186]209        try:
[187]210            self._set_privattr('portnum', interface, self._next)  # XXX
[195]211            self._table[self._next] = self.Port(self._next, interface)
[186]212            return self._next
213        finally:
214            self._next += 1
[183]215
216    def unregister_port(self, interface):
[187]217        portnum = self._get_privattr('portnum', interface)
218        self._del_privattr('portnum', interface)
[197]219        self.fdb.delete(self._table[portnum])
[187]220        del self._table[portnum]
[183]221
222    def send(self, dst_interfaces, frame):
[187]223        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
224        ports = (self._table[n] for n in portnums)
225        ports = (p for p in ports if not p.shut)
[195]226        ports = sorted(ports, cmp=self.Port.cmp_by_number)
[183]227
228        for p in ports:
229            p.interface.write_message(frame.data, True)
230            p.tx += 1
231
232        if ports:
233            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
234                         lambda: (','.join(str(p.number) for p in ports),
235                                  frame.vid,
236                                  frame.src_mac.encode('hex'),
237                                  frame.dst_mac.encode('hex')))
238
239    def receive(self, src_interface, frame):
[187]240        port = self._table[self._get_privattr('portnum', src_interface)]
[183]241
242        if not port.shut:
243            port.rx += 1
244            self._forward(port, frame)
245
246    def _forward(self, src_port, frame):
[166]247        try:
[176]248            if not frame.src_multicast:
[197]249                self.fdb.learn(src_port, frame)
[133]250
[176]251            if not frame.dst_multicast:
[197]252                dst_port = self.fdb.lookup(frame)
[164]253
[166]254                if dst_port:
[183]255                    self.send([dst_port.interface], frame)
[166]256                    return
[133]257
[187]258            ports = set(self.portlist) - set([src_port])
[183]259            self.send((p.interface for p in ports), frame)
[162]260
[166]261        except:  # ex. received invalid frame
262            traceback.print_exc()
[133]263
[187]264    def _privattr(self, name):
265        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
[164]266
[187]267    def _set_privattr(self, name, obj, value):
268        return setattr(obj, self._privattr(name), value)
269
270    def _get_privattr(self, name, obj, defaults=None):
271        return getattr(obj, self._privattr(name), defaults)
272
273    def _del_privattr(self, name, obj):
274        return delattr(obj, self._privattr(name))
275
276
[179]277class Htpasswd(object):
278    def __init__(self, path):
279        self._path = path
280        self._stat = None
281        self._data = {}
282
283    def auth(self, name, passwd):
284        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
285        return self._data.get(name) == passwd
286
287    def load(self):
288        old_stat = self._stat
289
290        with open(self._path) as fp:
291            fileno = fp.fileno()
292            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
293            self._stat = os.fstat(fileno)
294
295            unchanged = old_stat and \
296                        old_stat.st_ino == self._stat.st_ino and \
297                        old_stat.st_dev == self._stat.st_dev and \
298                        old_stat.st_mtime == self._stat.st_mtime
299
300            if not unchanged:
301                self._data = self._parse(fp)
302
303        return self
304
305    def _parse(self, fp):
306        data = {}
307        for line in fp:
308            line = line.strip()
309            if 0 <= line.find(':'):
310                name, passwd = line.split(':', 1)
311                if passwd.startswith('{SHA}'):
312                    data[name] = passwd[5:]
313        return data
314
315
[182]316class BasicAuthMixIn(object):
317    def _execute(self, transforms, *args, **kwargs):
318        def do_execute():
319            sp = super(BasicAuthMixIn, self)
320            return sp._execute(transforms, *args, **kwargs)
321
322        def auth_required():
[185]323            stream = getattr(self, 'stream', self.request.connection.stream)
324            stream.write(tornado.escape.utf8(
[182]325                'HTTP/1.1 401 Authorization Required\r\n'
326                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
327            ))
[185]328            stream.close()
[182]329
330        try:
331            if not self._htpasswd:
332                return do_execute()
333
334            creds = self.request.headers.get('Authorization')
335
336            if not creds or not creds.startswith('Basic '):
337                return auth_required()
338
339            name, passwd = base64.b64decode(creds[6:]).split(':', 1)
340
341            if self._htpasswd.load().auth(name, passwd):
342                return do_execute()
343        except:
344            traceback.print_exc()
345
346        return auth_required()
347
348
[186]349class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
[191]350    IFTYPE = 'server'
351
[186]352    def __init__(self, app, req, switch, htpasswd=None, debug=False):
353        super(EtherWebSocketHandler, self).__init__(app, req)
354        self._switch = switch
355        self._htpasswd = htpasswd
356        self._debug = debug
357
358    def get_target(self):
359        return self.request.remote_ip
360
361    def open(self):
362        try:
363            return self._switch.register_port(self)
364        finally:
365            self.dprintf('connected: %s\n', lambda: self.request.remote_ip)
366
367    def on_message(self, message):
368        self._switch.receive(self, EthernetFrame(message))
369
370    def on_close(self):
371        self._switch.unregister_port(self)
372        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)
373
374
[166]375class TapHandler(DebugMixIn):
[191]376    IFTYPE = 'tap'
[166]377    READ_SIZE = 65535
378
[178]379    def __init__(self, ioloop, switch, dev, debug=False):
380        self._ioloop = ioloop
[166]381        self._switch = switch
382        self._dev = dev
383        self._debug = debug
384        self._tap = None
385
[186]386    def get_target(self):
[183]387        if self.closed:
388            return self._dev
389        return self._tap.name
390
[186]391    @property
392    def closed(self):
393        return not self._tap
394
[166]395    def open(self):
396        if not self.closed:
397            raise ValueError('already opened')
398        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
399        self._tap.up()
[178]400        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]401        return self._switch.register_port(self)
[166]402
403    def close(self):
404        if self.closed:
405            raise ValueError('I/O operation on closed tap')
[186]406        self._switch.unregister_port(self)
[178]407        self._ioloop.remove_handler(self.fileno())
[166]408        self._tap.close()
409        self._tap = None
410
411    def fileno(self):
412        if self.closed:
413            raise ValueError('I/O operation on closed tap')
414        return self._tap.fileno()
415
416    def write_message(self, message, binary=False):
417        if self.closed:
418            raise ValueError('I/O operation on closed tap')
419        self._tap.write(message)
420
[138]421    def __call__(self, fd, events):
[166]422        try:
[183]423            self._switch.receive(self, EthernetFrame(self._read()))
[166]424            return
425        except:
426            traceback.print_exc()
[178]427        self.close()
[166]428
429    def _read(self):
430        if self.closed:
431            raise ValueError('I/O operation on closed tap')
[162]432        buf = []
433        while True:
[166]434            buf.append(self._tap.read(self.READ_SIZE))
435            if len(buf[-1]) < self.READ_SIZE:
[162]436                break
[166]437        return ''.join(buf)
[162]438
439
[160]440class EtherWebSocketClient(DebugMixIn):
[191]441    IFTYPE = 'client'
442
[181]443    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
[178]444        self._ioloop = ioloop
[166]445        self._switch = switch
[151]446        self._url = url
[181]447        self._ssl = ssl_
[160]448        self._debug = debug
[166]449        self._sock = None
[151]450        self._options = {}
451
[174]452        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
453            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
[151]454            auth = ['Authorization: Basic %s' % token]
455            self._options['header'] = auth
456
[186]457    def get_target(self):
458        return self._url
459
[160]460    @property
461    def closed(self):
462        return not self._sock
463
[151]464    def open(self):
[181]465        sslwrap = websocket._SSLSocketWrapper
466
[160]467        if not self.closed:
468            raise websocket.WebSocketException('already opened')
[151]469
[181]470        if self._ssl:
471            websocket._SSLSocketWrapper = self._ssl
472
473        try:
474            self._sock = websocket.WebSocket()
475            self._sock.connect(self._url, **self._options)
476            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
[186]477            return self._switch.register_port(self)
[181]478        finally:
479            websocket._SSLSocketWrapper = sslwrap
[186]480            self.dprintf('connected: %s\n', lambda: self._url)
[181]481
[151]482    def close(self):
[160]483        if self.closed:
484            raise websocket.WebSocketException('already closed')
[186]485        self._switch.unregister_port(self)
[178]486        self._ioloop.remove_handler(self.fileno())
[151]487        self._sock.close()
488        self._sock = None
[164]489        self.dprintf('disconnected: %s\n', lambda: self._url)
[151]490
[165]491    def fileno(self):
492        if self.closed:
493            raise websocket.WebSocketException('closed socket')
494        return self._sock.io_sock.fileno()
495
[151]496    def write_message(self, message, binary=False):
[160]497        if self.closed:
498            raise websocket.WebSocketException('closed socket')
[151]499        if binary:
500            flag = websocket.ABNF.OPCODE_BINARY
[160]501        else:
502            flag = websocket.ABNF.OPCODE_TEXT
[151]503        self._sock.send(message, flag)
504
[165]505    def __call__(self, fd, events):
[151]506        try:
[165]507            data = self._sock.recv()
508            if data is not None:
[183]509                self._switch.receive(self, EthernetFrame(data))
[165]510                return
511        except:
512            traceback.print_exc()
[178]513        self.close()
[151]514
515
[183]516class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
517    NAMESPACE = 'etherws.control'
[191]518    IFTYPES = {
519        TapHandler.IFTYPE:           TapHandler,
520        EtherWebSocketClient.IFTYPE: EtherWebSocketClient,
[186]521    }
[183]522
523    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
524        super(EtherWebSocketControlHandler, self).__init__(app, req)
525        self._ioloop = ioloop
526        self._switch = switch
527        self._htpasswd = htpasswd
528        self._debug = debug
529
530    def post(self):
531        id_ = None
532
533        try:
534            req = json.loads(self.request.body)
535            method = req['method']
536            params = req['params']
537            id_ = req.get('id')
538
539            if not method.startswith(self.NAMESPACE + '.'):
540                raise ValueError('invalid method: %s' % method)
541
542            if not isinstance(params, list):
543                raise ValueError('invalid params: %s' % params)
544
545            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
546            result = getattr(self, handler)(params)
547            self.finish({'result': result, 'error': None, 'id': id_})
548
549        except Exception as e:
550            traceback.print_exc()
[192]551            msg = '%s: %s' % (e.__class__.__name__, str(e))
552            self.finish({'result': None, 'error': {'message': msg}, 'id': id_})
[183]553
[198]554    def handle_listFdb(self, params):
555        list_ = []
556        for vid in self._switch.fdb.get_vid_list():
557            for mac in self._switch.fdb.get_mac_list(vid):
558                entry = self._switch.fdb.get_entry(vid, mac)
559                if entry:
560                    mac = EthernetFrame.format_mac(mac)
561                    list_.append([vid, mac, entry.port.number, int(entry.age)])
562        return {'result': list_}
563
[183]564    def handle_listPort(self, params):
[188]565        list_ = [self._portstat(p) for p in self._switch.portlist]
[198]566        return {'result': list_}
[183]567
568    def handle_addPort(self, params):
[186]569        list_ = []
[183]570        for p in params:
[186]571            type_ = p['type']
572            target = p['target']
573            options = getattr(self, '_optparse_' + type_)(p.get('options', {}))
[191]574            klass = self.IFTYPES[type_]
[186]575            interface = klass(self._ioloop, self._switch, target, **options)
576            portnum = interface.open()
577            list_.append(self._portstat(self._switch.get_port(portnum)))
[198]578        return {'result': list_}
[183]579
580    def handle_delPort(self, params):
[186]581        list_ = []
[183]582        for p in params:
[186]583            port = self._switch.get_port(int(p['port']))
584            list_.append(self._portstat(port))
585            port.interface.close()
[198]586        return {'result': list_}
[183]587
588    def handle_shutPort(self, params):
[186]589        list_ = []
[183]590        for p in params:
[186]591            port = self._switch.get_port(int(p['port']))
[190]592            port.shut = bool(p['shut'])
[186]593            list_.append(self._portstat(port))
[198]594        return {'result': list_}
[183]595
[186]596    def _optparse_tap(self, opt):
597        return {'debug': self._debug}
[183]598
[186]599    def _optparse_client(self, opt):
600        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': opt.get('cacerts')}
601        if opt.get('insecure'):
602            args = {}
603        ssl_ = lambda sock: ssl.wrap_socket(sock, **args)
604        cred = {'user': opt.get('user'), 'passwd': opt.get('passwd')}
605        return {'ssl_': ssl_, 'cred': cred, 'debug': self._debug}
[183]606
607    @staticmethod
[186]608    def _portstat(port):
609        return {
610            'port':   port.number,
[191]611            'type':   port.interface.IFTYPE,
[186]612            'target': port.interface.get_target(),
613            'tx':     port.tx,
614            'rx':     port.rx,
615            'shut':   port.shut,
616        }
[183]617
618
[190]619def start_sw(args):
[186]620    def daemonize(nochdir=False, noclose=False):
621        if os.fork() > 0:
622            sys.exit(0)
[134]623
[186]624        os.setsid()
[134]625
[186]626        if os.fork() > 0:
627            sys.exit(0)
[134]628
[186]629        if not nochdir:
630            os.chdir('/')
[134]631
[186]632        if not noclose:
633            os.umask(0)
634            sys.stdin.close()
635            sys.stdout.close()
636            sys.stderr.close()
637            os.close(0)
638            os.close(1)
639            os.close(2)
640            sys.stdin = open(os.devnull)
641            sys.stdout = open(os.devnull, 'a')
642            sys.stderr = open(os.devnull, 'a')
[134]643
[186]644    def checkabspath(ns, path):
[184]645        val = getattr(ns, path, '')
646        if not val.startswith('/'):
647            raise ValueError('invalid %: %s' % (path, val))
648
649    def getsslopt(ns, key, cert):
650        kval = getattr(ns, key, None)
651        cval = getattr(ns, cert, None)
652        if kval and cval:
653            return {'keyfile': kval, 'certfile': cval}
654        elif kval or cval:
655            raise ValueError('both %s and %s are required' % (key, cert))
656        return None
657
[186]658    def setrealpath(ns, *keys):
659        for k in keys:
660            v = getattr(ns, k, None)
661            if v is not None:
662                v = os.path.realpath(v)
663                open(v).close()  # check readable
664                setattr(ns, k, v)
665
[184]666    def setport(ns, port, isssl):
667        val = getattr(ns, port, None)
668        if val is None:
669            if isssl:
670                return setattr(ns, port, 443)
671            return setattr(ns, port, 80)
672        if not (0 <= val <= 65535):
673            raise ValueError('invalid %s: %s' % (port, val))
674
675    def sethtpasswd(ns, htpasswd):
676        val = getattr(ns, htpasswd, None)
677        if val:
678            return setattr(ns, htpasswd, Htpasswd(val))
679
[183]680    #if args.debug:
681    #    websocket.enableTrace(True)
682
683    if args.ageout <= 0:
684        raise ValueError('invalid ageout: %s' % args.ageout)
685
[186]686    setrealpath(args, 'htpasswd', 'sslkey', 'sslcert')
687    setrealpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')
[183]688
[186]689    checkabspath(args, 'path')
690    checkabspath(args, 'ctlpath')
[183]691
[184]692    sslopt = getsslopt(args, 'sslkey', 'sslcert')
693    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')
[143]694
[184]695    setport(args, 'port', sslopt)
696    setport(args, 'ctlport', ctlsslopt)
[143]697
[184]698    sethtpasswd(args, 'htpasswd')
699    sethtpasswd(args, 'ctlhtpasswd')
[167]700
[183]701    ioloop = IOLoop.instance()
[167]702    fdb = FDB(ageout=args.ageout, debug=args.debug)
[183]703    switch = SwitchingHub(fdb, debug=args.debug)
[167]704
[184]705    if args.port == args.ctlport and args.host == args.ctlhost:
706        if args.path == args.ctlpath:
707            raise ValueError('same path/ctlpath on same host')
708        if args.sslkey != args.ctlsslkey:
[189]709            raise ValueError('different sslkey/ctlsslkey on same host')
[184]710        if args.sslcert != args.ctlsslcert:
[189]711            raise ValueError('different sslcert/ctlsslcert on same host')
[133]712
[184]713        app = Application([
714            (args.path, EtherWebSocketHandler, {
715                'switch':   switch,
716                'htpasswd': args.htpasswd,
717                'debug':    args.debug,
718            }),
719            (args.ctlpath, EtherWebSocketControlHandler, {
720                'ioloop':   ioloop,
721                'switch':   switch,
722                'htpasswd': args.ctlhtpasswd,
723                'debug':    args.debug,
724            }),
725        ])
726        server = HTTPServer(app, ssl_options=sslopt)
727        server.listen(args.port, address=args.host)
[151]728
[184]729    else:
730        app = Application([(args.path, EtherWebSocketHandler, {
731            'switch':   switch,
732            'htpasswd': args.htpasswd,
733            'debug':    args.debug,
734        })])
735        server = HTTPServer(app, ssl_options=sslopt)
736        server.listen(args.port, address=args.host)
737
738        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
739            'ioloop':   ioloop,
740            'switch':   switch,
741            'htpasswd': args.ctlhtpasswd,
742            'debug':    args.debug,
743        })])
744        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
745        ctlserver.listen(args.ctlport, address=args.ctlhost)
746
[151]747    if not args.foreground:
748        daemonize()
749
[138]750    ioloop.start()
[133]751
752
[190]753def start_ctl(args):
754    def request(args, method, params):
755        method = '.'.join([EtherWebSocketControlHandler.NAMESPACE, method])
756        data = json.dumps({'method': method, 'params': params})
757        req = urllib2.Request(args.ctlurl)
758        req.add_header('Content-type', 'application/json')
759        if args.ctluser:
760            if not args.ctlpasswd:
761                args.ctlpasswd = getpass.getpass()
762            token = base64.b64encode('%s:%s' % (args.ctluser, args.ctlpasswd))
763            req.add_header('Authorization', 'Basic %s' % token)
764        return json.loads(urllib2.urlopen(req, data).read())
765
766    def handle_ctl_addport(args):
[191]767        params = [{
768            'type':    args.type,
769            'target':  args.target,
770            'options': {
[193]771                'insecure': args.insecure,
[191]772                'cacerts':  args.cacerts,
773                'user':     args.user,
774                'passwd':   args.passwd,
775            }
776        }]
[192]777        return request(args, 'addPort', params)
[190]778
779    def handle_ctl_shutport(args):
780        if args.port <= 0:
781            raise ValueError('invalid port: %d' % args.port)
[192]782        params = [{'port': args.port, 'shut': args.no}]
783        return request(args, 'shutPort', params)
[190]784
785    def handle_ctl_delport(args):
786        if args.port <= 0:
787            raise ValueError('invalid port: %d' % args.port)
[192]788        params = [{'port': args.port}]
789        return request(args, 'delPort', params)
[190]790
791    def handle_ctl_listport(args):
[192]792        return request(args, 'listPort', [])
[190]793
[198]794    def handle_ctl_listfdb(args):
795        return request(args, 'listFdb', [])
796
[192]797    res = locals()['handle_ctl_' + args.control_method](args)
[190]798
[192]799    if res['error']:
800        print(res['error']['message'])
801    else:
[198]802        print(yaml.safe_dump(res['result']).strip())
[190]803
[192]804
[186]805def main():
806    parser = argparse.ArgumentParser()
[190]807    subcommand = parser.add_subparsers(dest='subcommand')
[186]808
[190]809    # -- sw command parser
810    parser_s = subcommand.add_parser('sw')
811
[186]812    parser_s.add_argument('--debug', action='store_true', default=False)
813    parser_s.add_argument('--foreground', action='store_true', default=False)
[190]814    parser_s.add_argument('--ageout', type=int, default=300)
[186]815
[190]816    parser_s.add_argument('--path', default='/')
817    parser_s.add_argument('--host', default='')
818    parser_s.add_argument('--port', type=int)
819    parser_s.add_argument('--htpasswd')
820    parser_s.add_argument('--sslkey')
821    parser_s.add_argument('--sslcert')
[186]822
[190]823    parser_s.add_argument('--ctlpath', default='/ctl')
824    parser_s.add_argument('--ctlhost', default='')
825    parser_s.add_argument('--ctlport', type=int)
826    parser_s.add_argument('--ctlhtpasswd')
827    parser_s.add_argument('--ctlsslkey')
828    parser_s.add_argument('--ctlsslcert')
[186]829
[190]830    # -- ctl command parser
831    parser_c = subcommand.add_parser('ctl')
832    parser_c.add_argument('--ctlurl', default='http://localhost/ctl')
833    parser_c.add_argument('--ctluser')
834    parser_c.add_argument('--ctlpasswd')
835
836    control_method = parser_c.add_subparsers(dest='control_method')
837
838    parser_c_ap = control_method.add_parser('addport')
[191]839    parser_c_ap.add_argument(
840        'type', choices=EtherWebSocketControlHandler.IFTYPES.keys())
[190]841    parser_c_ap.add_argument('target')
842    parser_c_ap.add_argument('--insecure', action='store_true', default=False)
843    parser_c_ap.add_argument('--cacerts')
844    parser_c_ap.add_argument('--user')
845    parser_c_ap.add_argument('--passwd')
846
847    parser_c_sp = control_method.add_parser('shutport')
848    parser_c_sp.add_argument('port', type=int)
849    parser_c_sp.add_argument('--no', action='store_false', default=True)
850
851    parser_c_dp = control_method.add_parser('delport')
852    parser_c_dp.add_argument('port', type=int)
853
854    parser_c_lp = control_method.add_parser('listport')
855
[198]856    parser_c_lf = control_method.add_parser('listfdb')
857
[190]858    # -- go
[186]859    args = parser.parse_args()
860    globals()['start_' + args.subcommand](args)
861
862
[133]863if __name__ == '__main__':
864    main()
Note: See TracBrowser for help on using the repository browser.