#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#              Ethernet over WebSocket tunneling server/client
#
# depends on:
#   - python-2.7.2
#   - python-pytun-0.2
#   - websocket-client-0.7.0
#   - tornado-2.3
#
# todo:
#   - servant mode support (like typical p2p software)
#
# ===========================================================================
# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# ===========================================================================
#
# $Id$

import os
import sys
import ssl
import time
import json
import fcntl
import base64
import hashlib
import getpass
import argparse
import traceback

import tornado
import websocket

from tornado.web import Application, RequestHandler
from tornado.websocket import WebSocketHandler
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop

from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI


class DebugMixIn(object):
    def dprintf(self, msg, func=lambda: ()):
        if self._debug:
            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
            sys.stderr.write(prefix + (msg % func()))


class EthernetFrame(object):
    def __init__(self, data):
        self.data = data

    @property
    def dst_multicast(self):
        return ord(self.data[0]) & 1

    @property
    def src_multicast(self):
        return ord(self.data[6]) & 1

    @property
    def dst_mac(self):
        return self.data[:6]

    @property
    def src_mac(self):
        return self.data[6:12]

    @property
    def tagged(self):
        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0

    @property
    def vid(self):
        if self.tagged:
            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
        return 0


class FDB(DebugMixIn):
    def __init__(self, ageout, debug=False):
        self._ageout = ageout
        self._debug = debug
        self._dict = {}

    def lookup(self, frame):
        mac = frame.dst_mac
        vid = frame.vid

        group = self._dict.get(vid)
        if not group:
            return None

        entry = group.get(mac)
        if not entry:
            return None

        if time.time() - entry['time'] > self._ageout:
            port = self._dict[vid][mac]['port']
            del self._dict[vid][mac]
            if not self._dict[vid]:
                del self._dict[vid]
            self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
                         lambda: (port.number, vid, mac.encode('hex')))
            return None

        return entry['port']

    def learn(self, port, frame):
        mac = frame.src_mac
        vid = frame.vid

        if vid not in self._dict:
            self._dict[vid] = {}

        self._dict[vid][mac] = {'time': time.time(), 'port': port}
        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
                     lambda: (port.number, vid, mac.encode('hex')))

    def delete(self, port):
        for vid in self._dict.keys():
            for mac in self._dict[vid].keys():
                if self._dict[vid][mac]['port'].number == port.number:
                    del self._dict[vid][mac]
                    self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
                                 lambda: (port.number, vid, mac.encode('hex')))
            if not self._dict[vid]:
                del self._dict[vid]


class SwitchPort(object):
    def __init__(self, number, interface):
        self.number = number
        self.interface = interface
        self.tx = 0
        self.rx = 0
        self.shut = False

    @staticmethod
    def cmp_by_number(x, y):
        return cmp(x.number, y.number)


class SwitchingHub(DebugMixIn):
    def __init__(self, fdb, debug=False):
        self._fdb = fdb
        self._debug = debug
        self._table = {}
        self._next = 1

    @property
    def portlist(self):
        return sorted(self._table.itervalues(), cmp=SwitchPort.cmp_by_number)

    def shut_port(self, portnum, flag=True):
        self._table[portnum].shut = flag

    def get_port(self, portnum):
        return self._table[portnum]

    def register_port(self, interface):
        interface._switch_portnum = self._next  # XXX
        self._table[self._next] = SwitchPort(self._next, interface)
        self._next += 1

    def unregister_port(self, interface):
        self._fdb.delete(self._table[interface._switch_portnum])
        del self._table[interface._switch_portnum]
        del interface._switch_portnum

    def send(self, dst_interfaces, frame):
        ports = sorted((self._table[i._switch_portnum] for i in dst_interfaces
                        if not self._table[i._switch_portnum].shut),
                       cmp=SwitchPort.cmp_by_number)

        for p in ports:
            p.interface.write_message(frame.data, True)
            p.tx += 1

        if ports:
            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
                         lambda: (','.join(str(p.number) for p in ports),
                                  frame.vid,
                                  frame.src_mac.encode('hex'),
                                  frame.dst_mac.encode('hex')))

    def receive(self, src_interface, frame):
        port = self._table[src_interface._switch_portnum]

        if not port.shut:
            port.rx += 1
            self._forward(port, frame)

    def _forward(self, src_port, frame):
        try:
            if not frame.src_multicast:
                self._fdb.learn(src_port, frame)

            if not frame.dst_multicast:
                dst_port = self._fdb.lookup(frame)

                if dst_port:
                    self.send([dst_port.interface], frame)
                    return

            ports = set(self._table.itervalues()) - set([src_port])
            self.send((p.interface for p in ports), frame)

        except:  # ex. received invalid frame
            traceback.print_exc()


class Htpasswd(object):
    def __init__(self, path):
        self._path = path
        self._stat = None
        self._data = {}

    def auth(self, name, passwd):
        passwd = base64.b64encode(hashlib.sha1(passwd).digest())
        return self._data.get(name) == passwd

    def load(self):
        old_stat = self._stat

        with open(self._path) as fp:
            fileno = fp.fileno()
            fcntl.flock(fileno, fcntl.LOCK_SH | fcntl.LOCK_NB)
            self._stat = os.fstat(fileno)

            unchanged = old_stat and \
                        old_stat.st_ino == self._stat.st_ino and \
                        old_stat.st_dev == self._stat.st_dev and \
                        old_stat.st_mtime == self._stat.st_mtime

            if not unchanged:
                self._data = self._parse(fp)

        return self

    def _parse(self, fp):
        data = {}
        for line in fp:
            line = line.strip()
            if 0 <= line.find(':'):
                name, passwd = line.split(':', 1)
                if passwd.startswith('{SHA}'):
                    data[name] = passwd[5:]
        return data


class BasicAuthMixIn(object):
    def _execute(self, transforms, *args, **kwargs):
        def do_execute():
            sp = super(BasicAuthMixIn, self)
            return sp._execute(transforms, *args, **kwargs)

        def auth_required():
            stream = getattr(self, 'stream', self.request.connection.stream)
            stream.write(tornado.escape.utf8(
                'HTTP/1.1 401 Authorization Required\r\n'
                'WWW-Authenticate: Basic realm=etherws\r\n\r\n'
            ))
            stream.close()

        try:
            if not self._htpasswd:
                return do_execute()

            creds = self.request.headers.get('Authorization')

            if not creds or not creds.startswith('Basic '):
                return auth_required()

            name, passwd = base64.b64decode(creds[6:]).split(':', 1)

            if self._htpasswd.load().auth(name, passwd):
                return do_execute()
        except:
            traceback.print_exc()

        return auth_required()


class TapHandler(DebugMixIn):
    READ_SIZE = 65535

    def __init__(self, ioloop, switch, dev, debug=False):
        self._ioloop = ioloop
        self._switch = switch
        self._dev = dev
        self._debug = debug
        self._tap = None

    @property
    def closed(self):
        return not self._tap

    def get_type(self):
        return 'tap'

    def get_name(self):
        if self.closed:
            return self._dev
        return self._tap.name

    def open(self):
        if not self.closed:
            raise ValueError('already opened')
        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
        self._tap.up()
        self._switch.register_port(self)
        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)

    def close(self):
        if self.closed:
            raise ValueError('I/O operation on closed tap')
        self._ioloop.remove_handler(self.fileno())
        self._switch.unregister_port(self)
        self._tap.close()
        self._tap = None

    def fileno(self):
        if self.closed:
            raise ValueError('I/O operation on closed tap')
        return self._tap.fileno()

    def write_message(self, message, binary=False):
        if self.closed:
            raise ValueError('I/O operation on closed tap')
        self._tap.write(message)

    def __call__(self, fd, events):
        try:
            self._switch.receive(self, EthernetFrame(self._read()))
            return
        except:
            traceback.print_exc()
        self.close()

    def _read(self):
        if self.closed:
            raise ValueError('I/O operation on closed tap')
        buf = []
        while True:
            buf.append(self._tap.read(self.READ_SIZE))
            if len(buf[-1]) < self.READ_SIZE:
                break
        return ''.join(buf)


class EtherWebSocketHandler(DebugMixIn, BasicAuthMixIn, WebSocketHandler):
    def __init__(self, app, req, switch, htpasswd=None, debug=False):
        super(EtherWebSocketHandler, self).__init__(app, req)
        self._switch = switch
        self._htpasswd = htpasswd
        self._debug = debug

    def get_type(self):
        return 'server'

    def get_name(self):
        return self.request.remote_ip

    def open(self):
        self._switch.register_port(self)
        self.dprintf('connected: %s\n', lambda: self.request.remote_ip)

    def on_message(self, message):
        self._switch.receive(self, EthernetFrame(message))

    def on_close(self):
        self._switch.unregister_port(self)
        self.dprintf('disconnected: %s\n', lambda: self.request.remote_ip)


class EtherWebSocketClient(DebugMixIn):
    def __init__(self, ioloop, switch, url, ssl_=None, cred=None, debug=False):
        self._ioloop = ioloop
        self._switch = switch
        self._url = url
        self._ssl = ssl_
        self._debug = debug
        self._sock = None
        self._options = {}

        if isinstance(cred, dict) and cred['user'] and cred['passwd']:
            token = base64.b64encode('%s:%s' % (cred['user'], cred['passwd']))
            auth = ['Authorization: Basic %s' % token]
            self._options['header'] = auth

    @property
    def closed(self):
        return not self._sock

    def get_type(self):
        return 'client'

    def get_name(self):
        return self._url

    def open(self):
        sslwrap = websocket._SSLSocketWrapper

        if not self.closed:
            raise websocket.WebSocketException('already opened')

        if self._ssl:
            websocket._SSLSocketWrapper = self._ssl

        try:
            self._sock = websocket.WebSocket()
            self._sock.connect(self._url, **self._options)
            self._switch.register_port(self)
            self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
            self.dprintf('connected: %s\n', lambda: self._url)
        finally:
            websocket._SSLSocketWrapper = sslwrap

    def close(self):
        if self.closed:
            raise websocket.WebSocketException('already closed')
        self._ioloop.remove_handler(self.fileno())
        self._switch.unregister_port(self)
        self._sock.close()
        self._sock = None
        self.dprintf('disconnected: %s\n', lambda: self._url)

    def fileno(self):
        if self.closed:
            raise websocket.WebSocketException('closed socket')
        return self._sock.io_sock.fileno()

    def write_message(self, message, binary=False):
        if self.closed:
            raise websocket.WebSocketException('closed socket')
        if binary:
            flag = websocket.ABNF.OPCODE_BINARY
        else:
            flag = websocket.ABNF.OPCODE_TEXT
        self._sock.send(message, flag)

    def __call__(self, fd, events):
        try:
            data = self._sock.recv()
            if data is not None:
                self._switch.receive(self, EthernetFrame(data))
                return
        except:
            traceback.print_exc()
        self.close()


class EtherWebSocketControlHandler(DebugMixIn, BasicAuthMixIn, RequestHandler):
    NAMESPACE = 'etherws.control'

    def __init__(self, app, req, ioloop, switch, htpasswd=None, debug=False):
        super(EtherWebSocketControlHandler, self).__init__(app, req)
        self._ioloop = ioloop
        self._switch = switch
        self._htpasswd = htpasswd
        self._debug = debug

    def post(self):
        id_ = None

        try:
            req = json.loads(self.request.body)
            method = req['method']
            params = req['params']
            id_ = req.get('id')

            if not method.startswith(self.NAMESPACE + '.'):
                raise ValueError('invalid method: %s' % method)

            if not isinstance(params, list):
                raise ValueError('invalid params: %s' % params)

            handler = 'handle_' + method[len(self.NAMESPACE) + 1:]
            result = getattr(self, handler)(params)
            self.finish({'result': result, 'error': None, 'id': id_})

        except Exception as e:
            traceback.print_exc()
            self.finish({'result': None, 'error': str(e), 'id': id_})

    def handle_listPort(self, params):
        list_ = []
        for port in self._switch.portlist:
            list_.append({
                'port': port.number,
                'type': port.interface.get_type(),
                'name': port.interface.get_name(),
                'tx':   port.tx,
                'rx':   port.rx,
                'shut': port.shut,
            })
        return {'portlist': list_}

    def handle_addPort(self, params):
        for p in params:
            getattr(self, '_openport_' + p['type'])(p)
        return self.handle_listPort(params)

    def handle_delPort(self, params):
        for p in params:
            self._switch.get_port(int(p['port'])).interface.close()
        return self.handle_listPort(params)

    def handle_shutPort(self, params):
        for p in params:
            self._switch.shut_port(int(p['port']), bool(p['flag']))
        return self.handle_listPort(params)

    def _openport_tap(self, p):
        dev = p['device']
        tap = TapHandler(self._ioloop, self._switch, dev, debug=self._debug)
        tap.open()

    def _openport_client(self, p):
        ssl_ = self._ssl_wrapper(p.get('insecure'), p.get('cacerts'))
        cred = {'user': p.get('user'), 'passwd': p.get('passwd')}
        url = p['url']
        client = EtherWebSocketClient(self._ioloop, self._switch,
                                      url, ssl_, cred, self._debug)
        client.open()

    @staticmethod
    def _ssl_wrapper(insecure, ca_certs):
        args = {'cert_reqs': ssl.CERT_REQUIRED, 'ca_certs': ca_certs}
        if insecure:
            args = {}
        return lambda sock: ssl.wrap_socket(sock, **args)


def daemonize(nochdir=False, noclose=False):
    if os.fork() > 0:
        sys.exit(0)

    os.setsid()

    if os.fork() > 0:
        sys.exit(0)

    if not nochdir:
        os.chdir('/')

    if not noclose:
        os.umask(0)
        sys.stdin.close()
        sys.stdout.close()
        sys.stderr.close()
        os.close(0)
        os.close(1)
        os.close(2)
        sys.stdin = open(os.devnull)
        sys.stdout = open(os.devnull, 'a')
        sys.stderr = open(os.devnull, 'a')


def main():
    def realpath(ns, *keys):
        for k in keys:
            v = getattr(ns, k, None)
            if v is not None:
                v = os.path.realpath(v)
                open(v).close()  # check readable
                setattr(ns, k, v)

    def checkpath(ns, path):
        val = getattr(ns, path, '')
        if not val.startswith('/'):
            raise ValueError('invalid %: %s' % (path, val))

    def getsslopt(ns, key, cert):
        kval = getattr(ns, key, None)
        cval = getattr(ns, cert, None)
        if kval and cval:
            return {'keyfile': kval, 'certfile': cval}
        elif kval or cval:
            raise ValueError('both %s and %s are required' % (key, cert))
        return None

    def setport(ns, port, isssl):
        val = getattr(ns, port, None)
        if val is None:
            if isssl:
                return setattr(ns, port, 443)
            return setattr(ns, port, 80)
        if not (0 <= val <= 65535):
            raise ValueError('invalid %s: %s' % (port, val))

    def sethtpasswd(ns, htpasswd):
        val = getattr(ns, htpasswd, None)
        if val:
            return setattr(ns, htpasswd, Htpasswd(val))

    parser = argparse.ArgumentParser()

    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--foreground', action='store_true', default=False)
    parser.add_argument('--ageout', action='store', type=int, default=300)

    parser.add_argument('--path', action='store', default='/')
    parser.add_argument('--host', action='store', default='')
    parser.add_argument('--port', action='store', type=int)
    parser.add_argument('--htpasswd', action='store')
    parser.add_argument('--sslkey', action='store')
    parser.add_argument('--sslcert', action='store')

    parser.add_argument('--ctlpath', action='store', default='/ctl')
    parser.add_argument('--ctlhost', action='store', default='')
    parser.add_argument('--ctlport', action='store', type=int)
    parser.add_argument('--ctlhtpasswd', action='store')
    parser.add_argument('--ctlsslkey', action='store')
    parser.add_argument('--ctlsslcert', action='store')

    args = parser.parse_args()

    #if args.debug:
    #    websocket.enableTrace(True)

    if args.ageout <= 0:
        raise ValueError('invalid ageout: %s' % args.ageout)

    realpath(args, 'htpasswd', 'sslkey', 'sslcert')
    realpath(args, 'ctlhtpasswd', 'ctlsslkey', 'ctlsslcert')

    checkpath(args, 'path')
    checkpath(args, 'ctlpath')

    sslopt = getsslopt(args, 'sslkey', 'sslcert')
    ctlsslopt = getsslopt(args, 'ctlsslkey', 'ctlsslcert')

    setport(args, 'port', sslopt)
    setport(args, 'ctlport', ctlsslopt)

    sethtpasswd(args, 'htpasswd')
    sethtpasswd(args, 'ctlhtpasswd')

    ioloop = IOLoop.instance()
    fdb = FDB(ageout=args.ageout, debug=args.debug)
    switch = SwitchingHub(fdb, debug=args.debug)

    if args.port == args.ctlport and args.host == args.ctlhost:
        if args.path == args.ctlpath:
            raise ValueError('same path/ctlpath on same host')
        if args.sslkey != args.ctlsslkey:
            raise ValueError('differ sslkey/ctlsslkey on same host')
        if args.sslcert != args.ctlsslcert:
            raise ValueError('differ sslcert/ctlsslcert on same host')

        app = Application([
            (args.path, EtherWebSocketHandler, {
                'switch':   switch,
                'htpasswd': args.htpasswd,
                'debug':    args.debug,
            }),
            (args.ctlpath, EtherWebSocketControlHandler, {
                'ioloop':   ioloop,
                'switch':   switch,
                'htpasswd': args.ctlhtpasswd,
                'debug':    args.debug,
            }),
        ])
        server = HTTPServer(app, ssl_options=sslopt)
        server.listen(args.port, address=args.host)

    else:
        app = Application([(args.path, EtherWebSocketHandler, {
            'switch':   switch,
            'htpasswd': args.htpasswd,
            'debug':    args.debug,
        })])
        server = HTTPServer(app, ssl_options=sslopt)
        server.listen(args.port, address=args.host)

        ctl = Application([(args.ctlpath, EtherWebSocketControlHandler, {
            'ioloop':   ioloop,
            'switch':   switch,
            'htpasswd': args.ctlhtpasswd,
            'debug':    args.debug,
        })])
        ctlserver = HTTPServer(ctl, ssl_options=ctlsslopt)
        ctlserver.listen(args.ctlport, address=args.ctlhost)

    if not args.foreground:
        daemonize()

    ioloop.start()


if __name__ == '__main__':
    main()
