#!/usr/bin/env python3
# encoding: utf-8
"""
bgp

Created by Thomas Mangin
Copyright (c) 2013-2017 Exa Networks. All rights reserved.
License: 3-clause BSD. (See the COPYRIGHT file)
"""

import os
import pwd
import sys
import time
import errno
import socket
import threading
import signal
import asyncore
import subprocess
from struct import unpack

# Map every SIG* name exported by the signal module to its value,
# e.g. SIGNAL['SIGUSR1'] -> signal.SIGUSR1.
SIGNAL = {name: getattr(signal, name) for name in dir(signal) if name.startswith('SIG')}


def flushed(*output):
    """Write the arguments as one space-separated line on stdout and flush it.

    Every argument is converted with str(); flushing right away makes the
    line visible immediately even when stdout is block-buffered.
    """
    text = ' '.join(map(str, output))
    print(text, flush=True)


def bytestream(value):
    """Return the byte values of ``value`` as one uppercase hex string.

    Each byte becomes exactly two hex digits with no separator, e.g.
    ``bytestream(b'\\x00\\xff')`` is ``'00FF'``.
    """
    return ''.join(format(octet, '02X') for octet in value)


def dump(value):
    """Hex-dump ``value`` in uppercase, two bytes per space-separated word.

    ``dump([1, 2, 3, 4, 5])`` returns ``'0102 0304 05'``.
    """
    digits = ['%02X' % octet for octet in value]
    words = [''.join(digits[start:start + 2]) for start in range(0, len(digits), 2)]
    return ' '.join(words)


def cdr_to_length(cidr):
    """Return how many address bytes an IPv4 prefix of ``cidr`` bits occupies.

    This is ceil(cidr / 8) limited to 0..4: every 8-bit boundary that the
    mask length goes past (0, 8, 16, 24) adds one byte.
    """
    return sum(1 for floor in (0, 8, 16, 24) if cidr > floor)


class BGPHandler(asyncore.dispatcher_with_send):
    """One server-side BGP session used by the functional tests.

    After ``setup()`` the handler answers the peer's OPEN with a copy of it
    (last router-id byte incremented) followed by a KEEPALIVE.  It then
    compares every route found in the received UPDATEs with the expected
    messages, one sequence number at a time.  In ``sink`` mode every message
    is logged and answered with a KEEPALIVE; in ``echo`` mode every message
    is sent straight back.  The result is reported through the process exit
    status: ``sys.exit(0)`` on success, ``sys.exit(1)`` on a mismatch.
    """

    # number of routes received; the first ``self.counter += 1`` creates a
    # per-instance copy shadowing this class default
    counter = 0

    # a complete KEEPALIVE: 16-byte all-ones marker, length 19, type 4
    keepalive = bytearray([0xFF,] * 16 + [0x0, 0x13, 0x4])

    # BGP message type codes (RFC 4271); keys are ints because indexing a
    # bytearray header yields an int
    _name = {
        1: 'OPEN',
        2: 'UPDATE',
        3: 'NOTIFICATION',
        4: 'KEEPALIVE',
    }

    def signal(self, myself, signal_name='SIGUSR1'):
        """Send ``signal_name`` to the single ExaBGP process under test.

        The process is found by scanning ``/bin/ps x`` for a python/pypy
        command line whose last argument ends in ``.conf`` and contains the
        base name of ``sys.argv[1]``.  Exits with status 1 when the signal
        name is unknown or when not exactly one process matches.

        :param myself: unused, kept for interface compatibility
        :param signal_name: signal name such as ``'SIGUSR1'``
        """
        signal_number = SIGNAL.get(signal_name, '')
        if not signal_number:
            self.announce('invalid signal name in configuration : %s' % signal_name)
            self.announce('options are: %s' % ','.join(SIGNAL.keys()))
            sys.exit(1)

        conf_name = sys.argv[1].split('/')[-1].split('.')[0]

        processes = []

        # the context manager closes the pipe (and reaps ps) once read
        with os.popen('/bin/ps x') as ps:
            for line in ps:
                low = line.strip().lower()
                if not low:
                    continue
                if 'python' not in low and 'pypy' not in low:
                    continue

                fields = line.strip().split()
                # ps x columns: PID TTY STAT TIME COMMAND...
                cmdline = fields[4:]
                pid = fields[0]

                # a line without a command cannot be the process we want
                if not cmdline:
                    continue

                if len(cmdline) > 1 and not cmdline[1].endswith('/bgp.py'):
                    continue

                if conf_name not in cmdline[-1]:
                    continue

                if not cmdline[-1].endswith('.conf'):
                    continue

                processes.append(pid)

        if len(processes) == 0:
            self.announce('no running process found, this should not happend, quitting')
            sys.exit(1)

        if len(processes) > 1:
            self.announce('more than one process running, this should not happend, quitting')
            sys.exit(1)

        try:
            self.announce('sending signal %s to ExaBGP (pid %s)\n' % (signal_name, processes[0]))
            os.kill(int(processes[0]), signal_number)
        except Exception as exc:
            self.announce('\n     failed: %s' % str(exc))

    def kind(self, header):
        """Return the message type code (header byte 18)."""
        return header[18]

    def isupdate(self, header):
        """Return True for an UPDATE (type 2)."""
        return self.kind(header) == 2

    def isnotification(self, header):
        """Return True for a NOTIFICATION (type 3; type 4 is KEEPALIVE)."""
        return self.kind(header) == 3

    def name(self, header):
        """Return the message type name, or a placeholder for unknown types."""
        return self._name.get(self.kind(header), 'SOME WEIRD RFC PACKET')

    @staticmethod
    def _prefixes(data):
        """Yield 'a.b.c.d/len' for each IPv4 prefix encoded in NLRI bytes."""
        data = bytearray(data)
        while data:
            cdr, data = data[0], data[1:]
            size = cdr_to_length(cdr)
            octets = [0, 0, 0, 0]
            for index in range(size):
                octets[index], data = data[0], data[1:]
            yield '.'.join(str(_) for _ in octets) + '/' + str(cdr)

    def routes(self, header, body):
        """Yield one text item per IPv4 route carried by an UPDATE body.

        Items are 'withdraw:<prefix>' and 'announce:<prefix>'.  An UPDATE
        without IPv4 routes yields 'eor:<afi>:<safi>' when it looks like an
        End-of-RIB marker, otherwise 'mp:' (routes only in attributes, not
        decoded here).
        """
        len_w = unpack('!H', body[0:2])[0]
        withdrawn = bytearray(body[2 : 2 + len_w])
        len_a = unpack('!H', body[2 + len_w : 2 + len_w + 2])[0]
        # NLRI follow the path attributes, which are skipped entirely
        announced = bytearray(body[2 + len_w + 2 + len_a :])

        if not withdrawn and not announced:
            if len(body) == 4:
                # empty UPDATE: IPv4 unicast End-of-RIB
                yield 'eor:1:1'
            elif len(body) == 11:
                # presumably a lone MP_UNREACH_NLRI holding only AFI/SAFI
                # (4 + 7 bytes); only the length is checked here
                yield 'eor:%d:%d' % (body[-2], body[-1])
            else:  # undecoded MP route
                yield 'mp:'
            return

        for prefix in self._prefixes(withdrawn):
            yield 'withdraw:%s' % prefix

        for prefix in self._prefixes(announced):
            yield 'announce:%s' % prefix

    def notification(self, header, body):
        """Yield 'notification:<code>,<subcode>' for a NOTIFICATION body."""
        # yield a plain string: the caller calls str methods on every item
        if len(body) < 2:
            # malformed: error code and subcode are mandatory
            yield 'notification:%s' % bytestream(body)
            return
        yield 'notification:%d,%d' % (body[0], body[1])

    def announce(self, *args):
        """Print a log line prefixed with the peer address and port."""
        flushed('    ', self.ip, self.port, ' '.join(str(_) for _ in args) if len(args) > 1 else args[0])

    def check_signal(self):
        """If the next expected message is 'signal:<NAME>', consume it and send that signal."""
        if self.messages and self.messages[0].startswith('signal:'):
            name = self.messages.pop(0).split(':')[-1]
            self.signal(os.getppid(), name)

    def setup(self, ip, port, messages, options):
        """Initialise the session state and return ``self`` for chaining.

        :param ip: peer address (used for logging)
        :param port: peer port (used for logging)
        :param messages: rules of the form '<sequence>:<announcement>'; a
            '<sequence>:raw:<hex>' rule switches the session to raw
            comparison (':' separators in the hex are removed)
        :param options: shared option dict (sink, echo, exit, letter, ...)
        """
        self.ip = ip
        self.port = port
        self.options = options
        self.handle_read = self.handle_open
        self.sequence = {}
        self.raw = False
        for rule in messages:
            sequence, announcement = rule.split(':', 1)
            if announcement.startswith('raw:'):
                self.raw = True
                announcement = announcement[4:].replace(':', '')
            self.sequence.setdefault(sequence, []).append(announcement)
        self.update_sequence()
        return self

    @staticmethod
    def _sequence_order(key):
        """Sort key putting numeric sequence ids first, in numeric order ('2' before '10')."""
        if key.isdigit():
            return (0, int(key), key)
        return (1, 0, key)

    def update_sequence(self):
        """Load the lowest pending sequence into ``self.messages``.

        :returns: True when a sequence was loaded (always in sink/echo
            mode), False once every sequence has been consumed
        """
        if self.options['sink'] or self.options['echo']:
            self.messages = []
            return True
        if not self.sequence:
            return False

        key = min(self.sequence, key=self._sequence_order)
        self.messages = self.sequence.pop(key)
        self.step = key

        self.check_signal()
        # we had a list with only one signal
        if not self.messages:
            return self.update_sequence()
        return True

    def _recv_exactly(self, size):
        """Read exactly ``size`` bytes; return None if the session closes first."""
        data = b''
        while len(data) < size:
            try:
                chunk = self.recv(size - len(data))
            except socket.error as exc:
                # non-blocking socket with nothing to read yet: retry
                if exc.args[0] in (errno.EWOULDBLOCK, errno.EAGAIN):
                    continue
                raise
            if not chunk:
                # the TCP session is gone.
                return None
            data += chunk
        return data

    def read_message(self):
        """Read one complete BGP message.

        :returns: ``(header, body)`` as bytearrays, or ``(None, None)`` when
            the connection closed before a full message was received
        """
        header = self._recv_exactly(19)
        if header is None:
            return None, None

        # the length field includes the 19-byte header
        length = unpack('!H', header[16:18])[0] - 19

        body = self._recv_exactly(length)
        if body is None:
            return None, None

        return bytearray(header), bytearray(body)

    def handle_open(self):
        """Reply to the peer's OPEN, then switch to ``handle_keepalive``."""
        # reply with a IBGP response with the same capability (just changing routerID)
        header, body = self.read_message()
        if header is None:
            # the peer went away before sending its OPEN
            self.announce('connection closed')
            self.close()
            return

        # body[5:9] is the router-id; bump its last byte so the ids differ
        routerid = bytearray([(body[8] + 1) & 0xFF])
        o = header + body[:8] + routerid + body[9:]

        if self.options['send-unknown-capability']:
            # hack capability 66 into the message as an extra optional parameter
            content = b'loremipsum'
            cap66 = bytearray([66, len(content)]) + content
            param = bytearray([2, len(cap66)]) + cap66  # parameter type 2: capabilities
            # byte 28 is the OPEN optional-parameters length (19 header + 9)
            o = o[:28] + bytearray([o[28] + len(param)]) + o[29:] + param
            # rewrite the whole 2-byte message length so the low byte cannot overflow
            o[16:18] = len(o).to_bytes(2, 'big')

        self.send(o)
        self.send(self.keepalive)

        if self.options['send-default-route']:
            self.send(
                bytearray(
                    [0xFF,] * 16  # marker
                    + [0x00, 0x31]  # message length: 49
                    + [0x02,]  # type: UPDATE
                    + [0x00, 0x00]  # withdrawn routes length
                    + [0x00, 0x15]  # path attributes length: 21
                    + [0x40, 0x01, 0x01, 0x00]  # ORIGIN: IGP
                    + [0x40, 0x02, 0x00]  # AS_PATH: empty
                    + [0x40, 0x03, 0x04, 0x7F, 0x00, 0x00, 0x01]  # NEXT_HOP: 127.0.0.1
                    + [0x40, 0x05, 0x04, 0x00, 0x00, 0x00, 0x64]  # LOCAL_PREF: 100
                    # NOTE: prefix length 0x20 encodes 0.0.0.0/32, not 0.0.0.0/0
                    + [0x20, 0x00, 0x00, 0x00, 0x00]
                )
            )
            self.announce('sending default-route\n')

        self.handle_read = self.handle_keepalive

    @staticmethod
    def _split(header, body):
        """Format a message as marker:length:type:payload hex fields."""
        return '%s:%s:%s:%s' % (bytestream(header[:16]), bytestream(header[16:18]), bytestream(header[18:]), bytestream(body))

    @staticmethod
    def _split_hex(text):
        """Split a hex-encoded message string into marker:length:type:payload."""
        return '%s:%s:%s:%s' % (text[:32], text[32:36], text[36:38], text[38:])

    def _report_mismatch(self, header, body, announcement):
        """Print the unexpected message and what was expected, then exit(1)."""
        if self.raw:
            self.announce(
                'received %d (%1s%s):' % (self.counter, self.options['letter'], self.step),
                self._split(header, body),
            )
        else:
            self.announce('received %d     :' % self.counter, announcement)

        if len(self.messages) > 1:
            self.announce('expected one of the following :')
            for message in self.messages:
                if message.startswith('F' * 32):
                    self.announce('                 %s' % self._split_hex(message))
                else:
                    self.announce('                 %s' % message)
        elif self.messages:
            message = self.messages[0].upper()
            if message.startswith('F' * 32):
                self.announce('expected       : %s' % self._split_hex(message))
            else:
                self.announce('expected       : %s' % message)
        else:
            # can happen when the thread is still running
            self.announce('extra data')

        sys.exit(1)

    def handle_keepalive(self):
        """Handle every message after the OPEN exchange."""
        header, body = self.read_message()

        if header is None:
            self.announce('connection closed')
            self.close()
            if self.options['send-notification']:
                self.announce('successful')
                sys.exit(0)
            return

        if self.raw:

            def parser(self, header, body):
                # raw mode compares whole messages; bodiless ones are ignored
                if body:
                    yield bytestream(header + body)

        else:
            parser = self._decoder.get(self.kind(header), None)

        if self.options['sink']:
            self.announce('received %d: %s' % (self.counter, self._split(header, body)))
            self.send(self.keepalive)
            return

        if self.options['echo']:
            self.announce('received %d: %s' % (self.counter, self._split(header, body)))
            self.send(header + body)
            self.announce('sent     %d: %s' % (self.counter, self._split(header, body)))
            return

        if parser:
            for announcement in parser(self, header, body):
                self.send(self.keepalive)
                if announcement.startswith('eor:'):  # skip EOR
                    self.announce('skipping eor', announcement)
                    continue

                if announcement.startswith('mp:'):  # skip unparsed MP
                    self.announce('skipping multiprotocol :', dump(body))
                    continue

                self.counter += 1

                if announcement not in self.messages:
                    self._report_mismatch(header, body, announcement)  # exits

                self.messages.remove(announcement)
                received = 'received %d (%1s%s):' % (self.counter, self.options['letter'], self.step)
                if self.raw:
                    self.announce(received, self._split_hex(announcement))
                else:
                    self.announce(received, announcement)
                self.check_signal()

                if not self.messages:
                    if self.options['single-shot']:
                        self.announce('successful (partial test)')
                        sys.exit(0)

                    if not self.update_sequence():
                        if self.options['exit']:
                            self.announce('successful')
                            sys.exit(0)
        else:
            self.send(self.keepalive)

        if self.options['send-notification']:
            # NOTIFICATION (type 3), error code 6 (Cease), subcode 0, text as data
            notification = b'closing session because we can'
            self.send(
                bytearray([0xFF,] * 16 + [0x00, 19 + 2 + len(notification)] + [0x03] + [0x06] + [0x00]) + notification
            )

    # message type -> parser; the values are plain functions called as
    # parser(self, header, body)
    _decoder = {
        2: routes,
        3: notification,
    }


class BGPServer(asyncore.dispatcher):
    """Listening socket that hands every accepted session to a BGPHandler.

    Expected messages are grouped by an optional leading letter ('A' when
    absent); each accepted connection consumes the next group in
    alphabetical order.
    """

    def announce(self, *args):
        # NOTE: by operator precedence the indent is only prepended when
        # several arguments are given; a single argument is printed as-is
        flushed('    ' + ' '.join(str(_) for _ in args) if len(args) > 1 else args[0])

    def __init__(self, host, options):
        """Bind to ``(host, options['port'])`` and start listening.

        :param host: IPv4 or IPv6 literal (IPv6 is detected by a ':')
        :param options: command-line options; 'messages' holds the expected
            messages and 'option:...' lines that toggle features
        """
        asyncore.dispatcher.__init__(self)

        family = socket.AF_INET6 if ':' in host else socket.AF_INET
        self.create_socket(family, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind((host, options['port']))
        self.listen(5)

        # expected messages keyed by connection letter ('A', 'B', ...)
        self.messages = {}

        self.options = {
            'send-unknown-capability': False,  # add an unknown capability to the open message
            'send-default-route': False,  # send a default route to the peer
            'send-notification': False,  # send notification messages to the backend
            'signal-SIGUSR1': 0,  # send SIGUSR1 after X seconds
            'single-shot': False,  # we can not test signal on python 2.6
            'sink': False,  # just accept whatever is sent
            'echo': False,  # just accept whatever is sent
        }
        self.options.update(options)

        for message in options['messages']:
            stripped = message.strip()
            if stripped == 'option:open:send-unknown-capability':
                self.options['send-unknown-capability'] = True
                continue
            if stripped == 'option:update:send-default-route':
                self.options['send-default-route'] = True
                continue
            if stripped == 'option:notification:send-notification':
                self.options['send-notification'] = True
                continue
            if stripped.startswith('option:SIGUSR1:'):

                def notify(delay, myself):
                    time.sleep(delay)
                    self.signal(myself)
                    time.sleep(10)

                # Python 2.6 can not perform this test as it misses the function
                if 'check_output' in dir(subprocess):
                    # NOTE(review): this thread is created but never started,
                    # and BGPServer defines no signal() method, so the option
                    # currently has no effect beyond validating the delay.
                    threading.Thread(target=notify, args=(int(message.split(':')[-1]), os.getpid()))
                else:
                    self.options['single-shot'] = True
                continue

            if message[0].isalpha():
                index, content = message[:1].upper(), message[1:]
            else:
                index, content = 'A', message
            self.messages.setdefault(index, []).append(content)

    def handle_accept(self):
        """Accept a connection and give it the next group of expected messages.

        Exits with status 1 when no expected messages are left, unless the
        server runs in sink or echo mode.
        """
        pair = self.accept()
        if pair is None:
            # asyncore documents None as "the connection didn't take place":
            # keep the expected messages for the next real connection
            return

        messages = None
        for number in range(ord('A'), ord('Z') + 1):
            letter = chr(number)
            if letter in self.messages:
                messages = self.messages.pop(letter)
                break

        if self.options['sink']:
            flushed('\nsink mode - send us whatever, we can take it ! :p\n')
            messages = []
        elif self.options['echo']:
            flushed('\necho mode - send us whatever, we can parrot it ! :p\n')
            messages = []
        elif not messages:
            self.announce('we used all the test data available, can not handle this new connection')
            sys.exit(1)
        else:
            flushed('using :\n   ', '\n    '.join(messages), '\n\nconversation:\n')

        # this dict is shared by reference with every handler created so far
        self.options['exit'] = not self.messages
        self.options['letter'] = letter

        sock, addr = pair
        BGPHandler(sock).setup(*addr[:2], messages=messages, options=self.options)


def drop():
    """Switch to the 'nobody' account when running with a root uid or gid.

    Only the ids that are currently 0 are changed.  If no candidate account
    exists, a warning is printed and the current privileges are kept.
    """
    uid = os.getuid()
    gid = os.getgid()

    # neither the user nor the group is root: nothing to drop
    if uid and gid:
        return

    for name in [
        'nobody',
    ]:
        try:
            user = pwd.getpwnam(name)
        except KeyError:
            continue
        nuid = int(user.pw_uid)
        ngid = int(user.pw_gid)
        break
    else:
        flushed('could not find an unprivileged user, keeping current privileges')
        return

    # change the group first: once the uid is unprivileged, setgid is refused
    if not gid:
        os.setgid(ngid)
    if not uid:
        os.setuid(nuid)


def main():
    """Parse the command line and run the BGP test server.

    The port comes from the ``exabgp.tcp.port`` / ``exabgp_tcp_port``
    environment variables (default 179) and can be overridden with
    ``--port``.  Any other argument that is not a flag is read as the file
    of expected messages; lines containing '#' are ignored.
    """
    port = os.environ.get('exabgp.tcp.port', os.environ.get('exabgp_tcp_port', '179'))

    # check the range numerically: comparing the str itself with ints raises TypeError
    valid_port = port.isdigit() and 0 < int(port) <= 65535
    if not valid_port or len(sys.argv) <= 1:
        flushed('--sink   accept any BGP messages and reply with a keepalive')
        flushed('--echo   accept any BGP messages send it back to the emiter')
        flushed('--port <port>   port to bind to')
        flushed(
            'a list of expected route announcement/withdrawl in the format <number>:announce:<ipv4-route> <number>:withdraw:<ipv4-route> <number>:raw:<exabgp hex dump : separated>'
        )
        flushed('for example:', sys.argv[0], '1:announce:10.0.0.0/8 1:announce:192.0.2.0/24 2:withdraw:10.0.0.0/8 ')
        flushed('routes with the same <number> can arrive in any order')
        sys.exit(1)

    options = {'sink': False, 'echo': False, 'port': int(port), 'messages': []}

    for arg in sys.argv[1:]:
        if arg == '--sink':
            options['sink'] = True
            continue

        if arg == '--echo':
            options['echo'] = True
            continue

        if arg == '--port':
            args = sys.argv[1:] + ['']
            port = args[args.index('--port') + 1]
            if port.isdigit() and 0 < int(port) <= 65535:
                options['port'] = int(port)
                continue
            print('invalid port %s' % port)
            sys.exit(1)

        # the value following --port
        if arg == str(options['port']):
            continue

        try:
            with open(arg) as content:
                options['messages'] = [line.strip() for line in content if line.strip() and '#' not in line]
        except IOError:
            flushed('could not open file', arg)
            sys.exit(1)

    port = options['port']
    try:
        BGPServer('127.0.0.1', options)
        try:
            BGPServer('::1', options)
        except socket.error:
            # IPv6 loopback is not available everywhere (does not work on travis-ci)
            pass
        drop()
        asyncore.loop()
    except socket.error as exc:
        if exc.errno == errno.EACCES:
            flushed('failure: could not bind to port %s - most likely not run as root' % port)
        elif exc.errno == errno.EADDRINUSE:
            flushed('failure: could not bind to port %s - port already in use' % port)
        else:
            flushed('failure', str(exc))
        sys.exit(1)


# Start the server only when run as a script, not when imported.
if __name__ == '__main__':
    main()
