diff --git a/bot.py b/bot.py index b8399bf..daeabba 100755 --- a/bot.py +++ b/bot.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 import asyncio +import functools import glob import importlib import yaml import os import re +import signal import sys import traceback import urllib.parse @@ -78,10 +80,10 @@ class Bot: await self.client.room_send(room.room_id, 'm.room.message', msg) def remove_callback(self, callback): - for cb_object in bot.client.event_callbacks: + for cb_object in self.client.event_callbacks: if cb_object.func == callback: self.logger.info("remove callback") - bot.client.event_callbacks.remove(cb_object) + self.client.event_callbacks.remove(cb_object) def get_room_by_id(self, room_id): return self.client.rooms[room_id] @@ -148,7 +150,7 @@ class Bot: if moduleobject is not None: if moduleobject.enabled: try: - await moduleobject.matrix_message(bot, room, event) + await moduleobject.matrix_message(self, room, event) except CommandRequiresAdmin: await self.send_text(room, f'Sorry, you need admin power level in this room to run that command.') except CommandRequiresOwner: @@ -219,7 +221,7 @@ class Bot: for modulename, moduleobject in self.modules.items(): if moduleobject.enabled: try: - await moduleobject.matrix_poll(bot, self.pollcount) + await moduleobject.matrix_poll(self, self.pollcount) except Exception: traceback.print_exc(file=sys.stderr) await asyncio.sleep(10) @@ -287,7 +289,7 @@ class Bot: for modulename, moduleobject in self.modules.items(): if moduleobject.enabled: try: - moduleobject.matrix_start(bot) + moduleobject.matrix_start(self) except Exception: traceback.print_exc(file=sys.stderr) @@ -295,7 +297,7 @@ class Bot: self.logger.info(f'Stopping {len(self.modules)} modules..') for modulename, moduleobject in self.modules.items(): try: - moduleobject.matrix_stop(bot) + moduleobject.matrix_stop(self) except Exception: traceback.print_exc(file=sys.stderr) @@ -353,15 +355,30 @@ class Bot: else: await self.client.client_session.close() + def handle_exit(self, signame, loop): + self.logger.info(f"Received signal {signame}") + if self.poll_task: + self.poll_task.cancel() + self.bot_task.cancel() + self.stop() + + +async def main(): + bot = Bot() + bot.init() + + loop = asyncio.get_running_loop() + + for signame in {'SIGINT', 'SIGTERM'}: + loop.add_signal_handler( + getattr(signal, signame), + functools.partial(bot.handle_exit, signame, loop)) + + await bot.run() + await bot.shutdown() + -bot = Bot() -bot.init() try: - asyncio.get_event_loop().run_until_complete(bot.run()) -except KeyboardInterrupt: - if bot.poll_task: - bot.poll_task.cancel() - bot.bot_task.cancel() - -bot.stop() -asyncio.get_event_loop().run_until_complete(bot.shutdown()) + asyncio.run(main()) +except Exception as e: + print(e)