source: etherirc/trunk/etherirc.py @ 223

Revision 223, 15.5 KB checked in by atzm, 12 years ago (diff)

initial commit

  • Property svn:keywords set to Id
Line 
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4#                         Ethernet over IRC client
5#
6# depends on:
7#   - python-2.7.2
8#   - python-pytun-0.2
9#   - tornado-2.3
10#
11# ===========================================================================
12# Copyright (c) 2012, Atzm WATANABE <atzm@atzm.org>
13# All rights reserved.
14#
15# Redistribution and use in source and binary forms, with or without
16# modification, are permitted provided that the following conditions are met:
17#
18# 1. Redistributions of source code must retain the above copyright notice,
19#    this list of conditions and the following disclaimer.
20# 2. Redistributions in binary form must reproduce the above copyright
21#    notice, this list of conditions and the following disclaimer in the
22#    documentation and/or other materials provided with the distribution.
23#
24# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
27# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
28# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
29# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
30# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
31# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
32# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
33# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34# POSSIBILITY OF SUCH DAMAGE.
35# ===========================================================================
36#
37# $Id$
38
39import os
40import sys
41import time
42import errno
43import socket
44import argparse
45import traceback
46
47import tornado.ioloop
48from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI
49
50
51class DebugMixIn(object):
52    def dprintf(self, msg, func=lambda: ()):
53        if self._debug:
54            prefix = '[%s] %s - ' % (time.asctime(), self.__class__.__name__)
55            sys.stderr.write(prefix + (msg % func()))
56
57
58class EthernetFrame(object):
59    def __init__(self, data):
60        self.data = data
61
62    @property
63    def dst_multicast(self):
64        return ord(self.data[0]) & 1
65
66    @property
67    def src_multicast(self):
68        return ord(self.data[6]) & 1
69
70    @property
71    def dst_mac(self):
72        return self.data[:6]
73
74    @property
75    def src_mac(self):
76        return self.data[6:12]
77
78    @property
79    def tagged(self):
80        return ord(self.data[12]) == 0x81 and ord(self.data[13]) == 0
81
82    @property
83    def vid(self):
84        if self.tagged:
85            return ((ord(self.data[14]) << 8) | ord(self.data[15])) & 0x0fff
86        return 0
87
88    @staticmethod
89    def format_mac(mac, sep=':'):
90        return sep.join(b.encode('hex') for b in mac)
91
92
93class FDB(DebugMixIn):
94    class Entry(object):
95        def __init__(self, port, ageout):
96            self.port = port
97            self._time = time.time()
98            self._ageout = ageout
99
100        @property
101        def age(self):
102            return time.time() - self._time
103
104        @property
105        def agedout(self):
106            return self.age > self._ageout
107
108    def __init__(self, ageout, debug=False):
109        self._ageout = ageout
110        self._debug = debug
111        self._table = {}
112
113    def _set_entry(self, vid, mac, port):
114        if vid not in self._table:
115            self._table[vid] = {}
116        self._table[vid][mac] = self.Entry(port, self._ageout)
117
118    def _del_entry(self, vid, mac):
119        if vid in self._table:
120            if mac in self._table[vid]:
121                del self._table[vid][mac]
122            if not self._table[vid]:
123                del self._table[vid]
124
125    def _get_entry(self, vid, mac):
126        try:
127            entry = self._table[vid][mac]
128        except KeyError:
129            return None
130
131        if not entry.agedout:
132            return entry
133
134        self._del_entry(vid, mac)
135        self.dprintf('aged out: port:%d; vid:%d; mac:%s\n',
136                     lambda: (entry.port.number, vid, mac.encode('hex')))
137
138    def each(self):
139        for vid in sorted(self._table.iterkeys()):
140            for mac in sorted(self._table[vid].iterkeys()):
141                entry = self._get_entry(vid, mac)
142                if entry:
143                    yield (vid, mac, entry)
144
145    def lookup(self, frame):
146        mac = frame.dst_mac
147        vid = frame.vid
148        entry = self._get_entry(vid, mac)
149        return getattr(entry, 'port', None)
150
151    def learn(self, port, frame):
152        mac = frame.src_mac
153        vid = frame.vid
154        self._set_entry(vid, mac, port)
155        self.dprintf('learned: port:%d; vid:%d; mac:%s\n',
156                     lambda: (port.number, vid, mac.encode('hex')))
157
158    def delete(self, port):
159        for vid, mac, entry in self.each():
160            if entry.port.number == port.number:
161                self._del_entry(vid, mac)
162                self.dprintf('deleted: port:%d; vid:%d; mac:%s\n',
163                             lambda: (port.number, vid, mac.encode('hex')))
164
165
166class SwitchingHub(DebugMixIn):
167    class Port(object):
168        def __init__(self, number, interface):
169            self.number = number
170            self.interface = interface
171            self.tx = 0
172            self.rx = 0
173            self.shut = False
174
175        @staticmethod
176        def cmp_by_number(x, y):
177            return cmp(x.number, y.number)
178
179    def __init__(self, fdb, debug=False):
180        self.fdb = fdb
181        self._debug = debug
182        self._table = {}
183        self._next = 1
184
185    @property
186    def portlist(self):
187        return sorted(self._table.itervalues(), cmp=self.Port.cmp_by_number)
188
189    def get_port(self, portnum):
190        return self._table[portnum]
191
192    def register_port(self, interface):
193        try:
194            self._set_privattr('portnum', interface, self._next)  # XXX
195            self._table[self._next] = self.Port(self._next, interface)
196            return self._next
197        finally:
198            self._next += 1
199
200    def unregister_port(self, interface):
201        portnum = self._get_privattr('portnum', interface)
202        self._del_privattr('portnum', interface)
203        self.fdb.delete(self._table[portnum])
204        del self._table[portnum]
205
206    def send(self, dst_interfaces, frame):
207        portnums = (self._get_privattr('portnum', i) for i in dst_interfaces)
208        ports = (self._table[n] for n in portnums)
209        ports = (p for p in ports if not p.shut)
210        ports = sorted(ports, cmp=self.Port.cmp_by_number)
211
212        for p in ports:
213            p.interface.write_message(frame.data)
214            p.tx += 1
215
216        if ports:
217            self.dprintf('sent: port:%s; vid:%d; %s -> %s\n',
218                         lambda: (','.join(str(p.number) for p in ports),
219                                  frame.vid,
220                                  frame.src_mac.encode('hex'),
221                                  frame.dst_mac.encode('hex')))
222
223    def receive(self, src_interface, frame):
224        port = self._table[self._get_privattr('portnum', src_interface)]
225
226        if not port.shut:
227            port.rx += 1
228            self._forward(port, frame)
229
230    def _forward(self, src_port, frame):
231        try:
232            if not frame.src_multicast:
233                self.fdb.learn(src_port, frame)
234
235            if not frame.dst_multicast:
236                dst_port = self.fdb.lookup(frame)
237
238                if dst_port:
239                    self.send([dst_port.interface], frame)
240                    return
241
242            ports = set(self.portlist) - set([src_port])
243            self.send((p.interface for p in ports), frame)
244
245        except:  # ex. received invalid frame
246            traceback.print_exc()
247
248    def _privattr(self, name):
249        return '_%s_%s_%s' % (self.__class__.__name__, id(self), name)
250
251    def _set_privattr(self, name, obj, value):
252        return setattr(obj, self._privattr(name), value)
253
254    def _get_privattr(self, name, obj, defaults=None):
255        return getattr(obj, self._privattr(name), defaults)
256
257    def _del_privattr(self, name, obj):
258        return delattr(obj, self._privattr(name))
259
260
261class TapHandler(DebugMixIn):
262    READ_SIZE = 65535
263
264    def __init__(self, ioloop, switch, dev, debug=False):
265        self._ioloop = ioloop
266        self._switch = switch
267        self._dev = dev
268        self._debug = debug
269        self._tap = None
270
271    @property
272    def target(self):
273        if self.closed:
274            return self._dev
275        return self._tap.name
276
277    @property
278    def closed(self):
279        return not self._tap
280
281    @property
282    def address(self):
283        if self.closed:
284            raise ValueError('I/O operation on closed tap')
285        try:
286            return self._tap.addr
287        except:
288            return ''
289
290    @property
291    def netmask(self):
292        if self.closed:
293            raise ValueError('I/O operation on closed tap')
294        try:
295            return self._tap.netmask
296        except:
297            return ''
298
299    @property
300    def mtu(self):
301        if self.closed:
302            raise ValueError('I/O operation on closed tap')
303        return self._tap.mtu
304
305    @address.setter
306    def address(self, address):
307        if self.closed:
308            raise ValueError('I/O operation on closed tap')
309        self._tap.addr = address
310
311    @netmask.setter
312    def netmask(self, netmask):
313        if self.closed:
314            raise ValueError('I/O operation on closed tap')
315        self._tap.netmask = netmask
316
317    @mtu.setter
318    def mtu(self, mtu):
319        if self.closed:
320            raise ValueError('I/O operation on closed tap')
321        self._tap.mtu = mtu
322
323    def open(self):
324        if not self.closed:
325            raise ValueError('Already opened')
326        self._tap = TunTapDevice(self._dev, IFF_TAP | IFF_NO_PI)
327        self._tap.up()
328        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
329        return self._switch.register_port(self)
330
331    def close(self):
332        if self.closed:
333            raise ValueError('I/O operation on closed tap')
334        self._switch.unregister_port(self)
335        self._ioloop.remove_handler(self.fileno())
336        self._tap.close()
337        self._tap = None
338
339    def fileno(self):
340        if self.closed:
341            raise ValueError('I/O operation on closed tap')
342        return self._tap.fileno()
343
344    def write_message(self, message, binary=False):
345        if self.closed:
346            raise ValueError('I/O operation on closed tap')
347        self._tap.write(message)
348
349    def __call__(self, fd, events):
350        try:
351            self._switch.receive(self, EthernetFrame(self._read()))
352            return
353        except:
354            traceback.print_exc()
355        self.close()
356
357    def _read(self):
358        if self.closed:
359            raise ValueError('I/O operation on closed tap')
360        buf = []
361        while True:
362            buf.append(self._tap.read(self.READ_SIZE))
363            if len(buf[-1]) < self.READ_SIZE:
364                break
365        return ''.join(buf)
366
367
368class EtherIRCHandler(DebugMixIn):
369    READ_SIZE = 65535
370
371    def __init__(self, ioloop, switch, server, nick, channel, debug=False):
372        self._ioloop = ioloop
373        self._switch = switch
374        self._server = server
375        self._nick = nick
376        self._channel = channel
377        self._debug = debug
378        self._sock = None
379        self._buffer = []
380
381    @property
382    def closed(self):
383        return not self._sock
384
385    def open(self):
386        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
387        self._sock.connect(self._server)
388        self._sendraw('NICK %s' % self._nick)
389        self._sendraw('USER %s 0 * :%s' % (self._nick, self._nick))
390        self._sendraw('JOIN %s' % self._channel)
391
392        self._sock.setblocking(False)
393        self._ioloop.add_handler(self.fileno(), self, self._ioloop.READ)
394        return self._switch.register_port(self)
395
396    def close(self):
397        self._switch.unregister_port(self)
398        self._ioloop.remove_handler(self.fileno())
399        self._sock.close()
400        self._sock = None
401
402    def fileno(self):
403        return self._sock.fileno()
404
405    def write_message(self, message):
406        data = message.encode('hex')
407        self._sendraw('PRIVMSG %s %s' % (self._channel, data))
408
409    def _sendraw(self, data):
410        self._sock.send(data + '\r\n')
411
412    def __call__(self, fd, events):
413        close = False
414
415        try:
416            while True:
417                data = self._sock.recv(self.READ_SIZE)
418                if not data:
419                    close = True
420                    break
421                self._buffer.append(data)
422        except socket.error as e:
423            if e.errno != errno.EAGAIN:
424                raise e
425
426        lines = ''.join(self._buffer).split('\r\n')
427        rest = lines.pop(-1)
428
429        self._buffer = []
430        if rest:
431            self._buffer.append(rest)
432
433        for line in lines:
434            line = line.strip()
435            prefix = None
436            if line.startswith(':'):
437                prefix, cmd, params = line.split(' ', 2)
438            else:
439                cmd, params = line.split(' ', 1)
440            method = getattr(self, '_handle_%s' % cmd, self.__handle_default)
441            method(prefix, cmd, params)
442
443        if close:
444            self.close()
445            self.dprintf('connection closed\n')
446            self._ioloop.stop()  # XXX: shutdown process when connection closed
447
448    def __handle_default(self, prefix, cmd, params):
449        self.dprintf('UNKNOWN %s %s %s\n', lambda: (prefix, cmd, params))
450
451    def _handle_PING(self, prefix, cmd, params):
452        self.dprintf('%s %s %s\n', lambda: (prefix, cmd, params))
453        self._sendraw('PONG 0')
454
455    def _handle_PRIVMSG(self, prefix, cmd, params):
456        self.dprintf('%s %s %s\n', lambda: (prefix, cmd, params))
457        to, data = params.split(' ', 1)
458        try:
459            message = data[1:].decode('hex')
460            self._switch.receive(self, EthernetFrame(message))
461        except:
462            traceback.print_exc()
463
464
465def _main():
466    def daemonize(nochdir=False, noclose=False):
467        if os.fork() > 0:
468            sys.exit(0)
469
470        os.setsid()
471
472        if os.fork() > 0:
473            sys.exit(0)
474
475        if not nochdir:
476            os.chdir('/')
477
478        if not noclose:
479            os.umask(0)
480            sys.stdin.close()
481            sys.stdout.close()
482            sys.stderr.close()
483            os.close(0)
484            os.close(1)
485            os.close(2)
486            sys.stdin = open(os.devnull)
487            sys.stdout = open(os.devnull, 'a')
488            sys.stderr = open(os.devnull, 'a')
489
490    parser = argparse.ArgumentParser()
491    parser.add_argument('--device', action='append', default=[])
492    parser.add_argument('--ageout', action='store', type=int, default=300)
493    parser.add_argument('--foreground', action='store_true', default=False)
494    parser.add_argument('--debug', action='store_true', default=False)
495
496    parser.add_argument('host')
497    parser.add_argument('port', type=int, default=6667)
498    parser.add_argument('nick')
499    parser.add_argument('channel')
500
501    args = parser.parse_args()
502
503    if args.ageout <= 0:
504        raise ValueError('invalid ageout: %s' % args.ageout)
505
506    if not (0 < args.port < 65535):
507        raise ValueError('invalid port: %s' % args.port)
508
509    ioloop = tornado.ioloop.IOLoop.instance()
510    fdb = FDB(ageout=args.ageout, debug=args.debug)
511    switch = SwitchingHub(fdb, debug=args.debug)
512
513    client = EtherIRCHandler(ioloop, switch, (args.host, args.port),
514                             args.nick, args.channel, args.debug)
515    client.open()
516
517    for dev in args.device:
518        tap = TapHandler(ioloop, switch, dev, debug=args.debug)
519        tap.open()
520
521    if not args.foreground:
522        daemonize()
523
524    ioloop.start()
525
526
527if __name__ == '__main__':
528    _main()
Note: See TracBrowser for help on using the repository browser.