#!/usr/bin/env python

import os
import pwd
import sys
import time
import errno
import socket
import thread
import signal
import asyncore
import subprocess
from struct import unpack

import psutil

# Lookup table from name to value for every attribute of the signal module
# whose name begins with 'SIG' (this includes the signal numbers such as
# SIGUSR1, and any other attribute sharing that prefix).
SIGNAL = dict(
	(_signame, getattr(signal, _signame))
	for _signame in dir(signal)
	if _signame.startswith('SIG')
)

def bytestream (value):
	"""
	Return value rendered as uppercase hexadecimal, two digits per
	character and no separator between characters.
	"""
	digits = []
	for char in value:
		digits.append('%02X' % ord(char))
	return ''.join(digits)

def dump (value):
	"""
	Return value rendered as uppercase hexadecimal, two digits per
	character, with a single space inserted after every second character
	(so the output reads as groups of four hex digits).
	"""
	pieces = []
	for position, char in enumerate(value):
		# a separator precedes every character at an even offset but the first
		if position and position % 2 == 0:
			pieces.append(' ')
		pieces.append('%02X' % ord(char))
	return ''.join(pieces)


class BGPHandler(asyncore.dispatcher_with_send):

	keepalive = chr(0xFF)*16 + chr(0x0) + chr(0x13) + chr(0x4)

	_name = {
		chr(1) : 'OPEN',
		chr(2) : 'UPDATE',
		chr(3) : 'NOTIFICATION',
		chr(4) : 'KEEPALIVE',
	}

	def signal (self,myself,signal_name='SIGUSR1'):
		signal_number = SIGNAL.get(signal_name,'')
		if not signal_number:
			self.announce('invalid signal name in configuration : %s' % signal_name)
			self.announce('options are: %s' % ','.join(SIGNAL.keys()))
			sys.exit(1)

		conf_name = sys.argv[1].split('/')[-1].split('.')[0]

		processes = []

		for pid in psutil.pids():
			try:
				process = psutil.Process(pid)

				process_name = process.name().lower()
				if 'python' not in process_name and 'pypy' not in process_name:
					continue

				cmdline = process.cmdline()
				if len(cmdline) < 2:
					continue

				if not cmdline[1].endswith('/bgp.py'):
					continue

				if conf_name not in cmdline[-1]:
					continue

				if not cmdline[-1].endswith('.conf'):
					continue

				processes.append(pid)

			except psutil.NoSuchProcess:
				pass
			except psutil.AccessDenied:
				pass

		if len(processes) == 0:
			self.announce('no running process found, this should not happend, quitting')
			sys.exit(1)

		if len(processes) > 1:
			self.announce('more than one process running, this should not happend, quitting')
			sys.exit(1)

		try:
			self.announce('sending signal %s to ExaBGP (pid %s)\n' % (signal_name,processes[0]))
			os.kill(int(processes[0]),signal_number)
		except Exception,exc:
			self.announce('\n     failed: %s' % str(exc))

	def kind (self, header):
		return header[18]

	def isupdate (self, header):
		return header[18] == chr(2)

	def isnotification (self, header):
		return header[18] == chr(4)

	def name (self, header):
		return self._name.get(header[18],'SOME WEIRD RFC PACKET')

	def routes (self, header, body):
		len_w = unpack('!H',body[0:2])[0]
		withdrawn = [ord(_) for _ in body[2:2+len_w]]
		len_a = unpack('!H',body[2+len_w:2+len_w+2])[0]
		announced = [ord(_) for _ in body[2+len_w + 2+len_a:]]

		if not withdrawn and not announced:
			if len(body) == 4:
				yield 'eor:1:1'
			elif len(body) == 11:
				yield 'eor:%d:%d' % (ord(body[-2]),ord(body[-1]))
			else:  # undecoded MP route
				yield 'mp:'

		while withdrawn:
			if len(withdrawn) > 5:
				yield ''
				break
			m = withdrawn.pop(0)
			r = [0,0,0,0]
			for index in range(4):
				if index*8 >= m: break
				r[index] = withdrawn.pop(0)
			yield 'withdraw:%s' % '.'.join(str(_) for _ in r) + '/' + str(m)

		while announced:
			if len(announced) > 5:
				yield ''
				break
			m = announced.pop(0)
			r = [0,0,0,0]
			for index in range(4):
				if index*8 >= m: break
				r[index] = announced.pop(0)
			yield 'announce:%s' % '.'.join(str(_) for _ in r) + '/' + str(m)

	def notification (self, header, body):
		yield 'notification:%d,%d' % (ord(body[0]),ord(body[1])), bytestream(body)

	def announce (self, *args):
		print '    ',self.ip,self.port,' '.join(str(_) for _ in args) if len(args) > 1 else args[0]

	def check_signal (self):
		if self.messages and self.messages[0].startswith('signal:'):
			name = self.messages.pop(0).split(':')[-1]
			self.signal(os.getppid(),name)

	def setup (self, ip, port, messages, options):
		self.ip = ip
		self.port = port
		self.options = options
		self.handle_read = self.handle_open
		self.sequence = {}
		self.raw = False
		for rule in messages:
			sequence,announcement = rule.split(':',1)
			if announcement.startswith('raw:'):
				self.raw = True
				announcement = ''.join(announcement[4:].replace(':',''))
			self.sequence.setdefault(sequence,[]).append(announcement)
		self.update_sequence()
		return self

	def update_sequence (self):
		keys = sorted(list(self.sequence))
		if keys:
			key = keys[0]
			self.messages = self.sequence[key]
			self.step = key
			del self.sequence[key]

			self.check_signal()
			# we had a list with only one signal
			if not self.messages:
				return self.update_sequence()
			return True
		return False

	def read_message (self):
		header = ''
		while len(header) != 19:
			try:
				left = 19-len(header)
				header += self.recv(left)
				if left == 19-len(header):  # ugly
					# the TCP session is gone.
					return None,None
			except socket.error,exc:
				if exc.args[0] in (errno.EWOULDBLOCK,errno.EAGAIN):
					continue
				raise exc

		length = unpack('!H',header[16:18])[0] - 19

		body = ''
		while len(body) != length:
			try:
				left = length-len(body)
				body += self.recv(left)
			except socket.error,exc:
				if exc.args[0] in (errno.EWOULDBLOCK,errno.EAGAIN):
					continue
				raise exc

		return header,body

	def handle_open (self):
		# reply with a IBGP response with the same capability (just changing routerID)
		header,body = self.read_message()
		routerid = chr((ord(body[8])+1) & 0xFF)
		o = header+body[:8]+routerid+body[9:]

		if self.options['send-unknown-capability']:
			# hack capability 66 into the message
			content = 'loremipsum'
			cap66 = chr(66) + chr(len(content)) + content
			param = chr(2) + chr(len(cap66)) + cap66
			o = o[:17] + chr(ord(o[17])+len(param)) + o[18:28] + \
				chr(ord(o[28])+len(param)) + o[29:] + param

		self.send(o)
		self.send(self.keepalive)

		if self.options['send-default-route']:
			self.send(
				chr(0xFF)*16+
				chr(0x00)+chr(0x31)+
				chr(0x02)+
				chr(0x00)+chr(0x00)+
				chr(0x00)+chr(0x15)+
					chr(0x40)+chr(0x01)+chr(0x01)+chr(0x00)+
					chr(0x40)+chr(0x02)+chr(0x00)+
					chr(0x40)+chr(0x03)+chr(0x04)+chr(0x7F)+chr(0x00)+chr(0x00)+chr(0x01)+
					chr(0x40)+chr(0x05)+chr(0x04)+chr(0x00)+chr(0x00)+chr(0x00)+chr(0x64)+
				chr(0x20)+chr(0x00)+chr(0x00)+chr(0x00)+chr(0x00)
			)

		self.handle_read = self.handle_keepalive

	def handle_keepalive (self):
		header,body = self.read_message()

		if header is None:
			self.announce('connection closed')
			self.close()
			if self.options['send-notification']:
				self.announce('successful')
				sys.exit(0)
			return

		if self.raw:
			def parser (self, header, body):
				if body:
					yield bytestream(header+body)
		else:
			parser = self._decoder.get(self.kind(header),None)

		if parser:
			for announcement in parser(self,header,body):
				if announcement.startswith('eor:'):  # skip EOR
					self.announce('skipping eor',announcement)
					continue

				if announcement.startswith('mp:'):  # skip unparsed MP
					self.announce('skipping multiprotocol :',dump(body))
					continue

				if announcement in self.messages:
					self.messages.remove(announcement)
					self.announce('received (%1s%s):' % (self.options['letter'],self.step),announcement)
					self.check_signal()
				else:
					if self.raw:
						self.announce('received (%1s%s):' % (self.options['letter'],self.step),'%s:%s:%s' % (bytestream(header[:16]),bytestream(header[16:]),bytestream(body)))
					else:
						self.announce('received     :',announcement)

					if len(self.messages) > 1:
						self.announce('expected one of the following :')
						for message in self.messages:
							if message.startswith('F'*32):
								self.announce('               %s:%s:%s' % (message[:32],message[32:38],message[38:]))
							else:
								self.announce('               %s' % message)
					elif self.messages:
						message = self.messages[0].upper()
						if message.startswith('F'*32):
							self.announce('expected     : %s:%s:%s' % (message[:32],message[32:38],message[38:]))
						else:
							self.announce('expected     : %s' % message)
					else:
						# can happen when the thread is still running
						self.announce('extra data')
						sys.exit(1)

					sys.exit(1)

				if not self.messages:
					if self.options['single-shot']:
						self.announce('successful (partial test)')
						sys.exit(0)

					if not self.update_sequence():
						if self.options['exit']:
							self.announce('successful')
							sys.exit(0)

		self.send(self.keepalive)

		if self.options['send-notification']:
			notification = 'closing session because we can'
			self.send(
				chr(0xFF)*16+
				chr(0x00)+ chr(19+2+len(notification))+
				chr(0x03)+
				chr(0x06)+
				chr(0x00)+
				notification
			)

	_decoder= {
		chr(2) : routes,
		chr(3) : notification,
	}

class BGPServer (asyncore.dispatcher):
	def announce (self, *args):
		print '    ' + ' '.join(str(_) for _ in args) if len(args) > 1 else args[0]

	def __init__ (self, host, port, messages):
		asyncore.dispatcher.__init__(self)
		self.create_socket(socket.AF_INET,socket.SOCK_STREAM)
		self.set_reuse_addr()
		self.bind((host,port))
		self.listen(5)

		self.messages = {}

		self.options = {
			'send-unknown-capability': False,  # add an unknown capability to the open message
			'send-default-route': False,       # send a default route to the peer
			'send-notification': False,        # send notification messages to the backend
			'signal-SIGUSR1': 0,               # send SIGUSR1 after X seconds
			'single-shot': False,              # we can not test signal on python 2.6
		}

		for message in messages:
			if message.strip() == 'option:open:send-unknown-capability':
				self.options['send-unknown-capability'] = True
				continue
			if message.strip() == 'option:update:send-default-route':
				self.options['send-default-route'] = True
				continue
			if message.strip() == 'option:notification:send-notification':
				self.options['send-notification'] = True
				continue
			if message.strip().startswith('option:SIGUSR1:'):
				def notify (delay,myself):
					time.sleep(delay)
					self.signal(myself)
					time.sleep(10)
					thread.exit()

				# Python 2.6 can not perform this test as it misses the function
				if 'check_output' in dir(subprocess):
					thread.start_new_thread(notify,(int(message.split(':')[-1]),os.getpid()))
				else:
					self.options['single-shot'] = True
				continue

			if message[0].isalpha():
				index,content = message[:1].upper(), message[1:]
			else:
				index,content = 'A',message
			self.messages.setdefault(index,[]).append(content)

	def handle_accept (self):
		messages = None
		for number in range(ord('A'),ord('Z')+1):
			letter = chr(number)
			if letter in self.messages:
				messages = self.messages[letter]
				del self.messages[letter]
				break

		if not messages:
			self.announce('we used all the test data available, can not handle this new connection')
			sys.exit(1)
		else:
			print 'using :\n   ', '\n    '.join(messages),'\n\nconversation:\n'

		self.options['exit'] = not len(self.messages.keys())
		self.options['letter'] = letter

		pair = self.accept()
		if pair is not None:
			sock,addr = pair
			handler = BGPHandler(sock).setup(
				*addr,
				messages=messages,
				options=self.options
			)

def drop ():
	"""
	Drop root privileges by switching to the 'nobody' user and its group.

	Does nothing when neither the user id nor the group id of the process
	is 0.  Exits with status 1 when no unprivileged account can be found,
	rather than carrying on as root.
	"""
	uid = os.getuid()
	gid = os.getgid()

	if uid and gid:
		return

	user = None
	for name in ['nobody',]:
		try:
			user = pwd.getpwnam(name)
			break
		except KeyError:
			continue

	if user is None:
		sys.stderr.write('failure: could not find an unprivileged user to drop privileges to\n')
		sys.exit(1)

	nuid = int(user.pw_uid)
	# the group id comes from pw_gid, it is not always equal to the user id
	ngid = int(user.pw_gid)

	# the group must be changed first: once the uid is dropped we may no
	# longer be allowed to change it
	if not gid:
		os.setgid(ngid)
	if not uid:
		os.setuid(nuid)

def main ():
	if len(sys.argv) <= 1:
		print 'a list of expected route announcement/withdrawl in the format <number>:announce:<ipv4-route> <number>:withdraw:<ipv4-route> <number>:raw:<exabgp hex dump : separated>'
		print 'for example:',sys.argv[0],'1:announce:10.0.0.0/8 1:announce:192.0.2.0/24 2:withdraw:10.0.0.0/8 '
		print 'routes with the same <number> can arrive in any order'
		sys.exit(1)

	try:
		with open(sys.argv[1]) as content:
			messages = [_.strip() for _ in content.readlines() if _.strip() and '#' not in _]
	except IOError:
		print 'could not open file', sys.argv[1]
		sys.exit(1)

	try:
		BGPServer('localhost',int(os.environ.get('exabgp.tcp.port','179')),messages)
		drop()
		asyncore.loop()
	except socket.error,exc:
		if exc.errno == errno.EACCES:
			print "failure: could not bind to port %s - most likely not run as root" % os.environ.get('exabgp.tcp.port','179')
		elif exc.errno == errno.EADDRINUSE:
			print "failure: could not bind to port %s - port already in use" % os.environ.get('exabgp.tcp.port','179')
		else:
			print "failure", str(exc)

# Run the test peer only when the file is executed as a script, not when it is imported.
if __name__ == '__main__':
	main()
