import asyncio
import datetime
import logging
import sys
from typing import List, Union

import discord
from discord.ext import commands
from rich import box
from rich.columns import Columns
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, TextColumn, BarColumn
from rich.table import Table
from rich.traceback import install

from tuxbot import version_info

from . import Config
from .data_manager import logs_data_path
from . import __version__, ExitCodes
from .utils.functions.extra import ContextPlus

log = logging.getLogger("tuxbot")
console = Console()
# Render uncaught tracebacks through rich on the shared console.
install(console=console)

# Extensions loaded by `Tux.load_packages` and reported by `Tux.on_ready`.
packages: List[str] = ["jishaku", "tuxbot.cogs.warnings", "tuxbot.cogs.admin"]


class Tux(commands.AutoShardedBot):
    """Tuxbot's sharded Discord bot.

    Adds config-driven prefixes, blacklists and owners, a rich console
    start-up/shutdown display, and exit codes the launcher uses to decide
    whether to restart.
    """

    _loading: asyncio.Task
    # "main" is the rich progress display; "tasks" maps a task name to the
    # rich TaskID currently shown for it (entries are removed once torn down).
    _progress = {
        'main': Progress(
            TextColumn("[bold blue]{task.fields[task_name]}", justify="right"),
            BarColumn()
        ),
        'tasks': {}
    }

    def __init__(self, *args, cli_flags=None, **kwargs):
        # by default, if the bot shutdown without any intervention,
        # it's a crash
        self.shutdown_code = ExitCodes.CRITICAL

        self.cli_flags = cli_flags
        self.instance_name = self.cli_flags.instance_name
        self.last_exception = None
        self.logs = logs_data_path(self.instance_name)

        self.config = Config(self.instance_name)

        async def _prefixes(bot, message) -> List[str]:
            # Copy the configured list: extending the object returned by the
            # config would permanently append guild prefixes to it on every
            # message.
            prefixes = list(self.config("core").get("prefixes"))

            # DMs have no guild, so only add guild prefixes when one exists.
            if message.guild is not None:
                prefixes.extend(self.config.get_prefixes(message.guild))

            if self.config("core").get("mentionable"):
                return commands.when_mentioned_or(*prefixes)(bot, message)

            return prefixes

        if "command_prefix" not in kwargs:
            kwargs["command_prefix"] = _prefixes

        if "owner_ids" in kwargs:
            kwargs["owner_ids"] = set(kwargs["owner_ids"])
        else:
            kwargs["owner_ids"] = self.config.owners_id()

        message_cache_size = 100_000
        kwargs["max_messages"] = message_cache_size
        self.max_messages = message_cache_size

        self.uptime = None
        self._app_owners_fetched = False  # to prevent abusive API calls

        super().__init__(*args, help_command=None, **kwargs)

    async def load_packages(self):
        """Load every extension listed in `packages`.

        A failing extension is logged (with traceback) and reported on the
        console; loading continues with the next one.
        """
        if packages:
            with Progress() as progress:
                task = progress.add_task(
                    "Loading packages...", total=len(packages)
                )

                for package in packages:
                    try:
                        self.load_extension(package)
                        progress.console.print(f"{package} loaded")
                    except Exception as e:
                        log.exception(
                            f"Failed to load package {package}", exc_info=e
                        )
                        progress.console.print(
                            f"[red]Failed to load package {package} "
                            f"[i](see "
                            f"{str((self.logs / 'tuxbot.log').resolve())} "
                            f"for more details)[/i]"
                        )

                    progress.advance(task)

    async def on_ready(self):
        """Record uptime, clear the connecting bar and print a status summary."""
        self.uptime = datetime.datetime.now()

        progress = self._progress.get("main")
        # Pop the id so a later `logout` (or a second `on_ready` after a
        # reconnect) does not try to stop/remove a task rich no longer has.
        connecting = self._progress.get("tasks").pop("connecting", None)
        if connecting is not None:
            progress.stop_task(connecting)
            progress.remove_task(connecting)

        console.clear()

        console.print(
            Panel(f"[bold blue]Tuxbot V{version_info.major}", style="blue"),
            justify="center"
        )
        console.print()

        columns = Columns(expand=True, padding=2, align="center")

        table = Table(
            style="dim", border_style="not dim", box=box.HEAVY_HEAD
        )
        table.add_column(
            "INFO",
        )
        table.add_row(str(self.user))
        table.add_row(f"Prefixes: {', '.join(self.config('core').get('prefixes'))}")
        table.add_row(f"Language: {self.config('core').get('locale')}")
        table.add_row(f"Tuxbot Version: {__version__}")
        table.add_row(f"Discord.py Version: {discord.__version__}")
        table.add_row(f"Shards: {self.shard_count}")
        table.add_row(f"Servers: {len(self.guilds)}")
        table.add_row(f"Users: {len(self.users)}")
        columns.add_renderable(table)

        table = Table(
            style="dim", border_style="not dim", box=box.HEAVY_HEAD
        )
        table.add_column(
            "COGS",
        )
        for extension in packages:
            if extension in self.extensions:
                status = f"[green]:heavy_check_mark: {extension}"
            else:
                status = f"[red]:heavy_multiplication_x: {extension}"

            table.add_row(status)
        columns.add_renderable(table)

        console.print(columns)
        console.print()

    async def is_owner(self, user: Union[discord.User, discord.Member]) -> bool:
        """Determines if the user is a bot owner.

        Parameters
        ----------
        user: Union[discord.User, discord.Member]

        Returns
        -------
        bool
            True if the user's id is in the configured owners. Otherwise,
            the first call fetches the application info once; if the app
            belongs to a team, the team member ids are stored as owners and
            used for the check.
        """
        if user.id in self.config.owners_id():
            return True

        owner = False
        if not self._app_owners_fetched:
            app = await self.application_info()
            if app.team:
                ids = [m.id for m in app.team.members]
                await self.config.update("core", "owners_id", ids)
                owner = user.id in ids
            # Only query the API once per process.
            self._app_owners_fetched = True

        return owner

    async def get_context(self, message: discord.Message, *, cls=None):
        """Build the invocation context, defaulting to `ContextPlus`."""
        return await super().get_context(message, cls=cls or ContextPlus)

    async def process_commands(self, message: discord.Message):
        """Check for blacklists, then invoke the command (if any).

        Messages from bots and from blacklisted guilds, channels or users
        are ignored. Messages that are not valid commands dispatch a
        ``message_without_command`` event instead.
        """
        if message.author.bot:
            return

        # `message.guild` is None in DMs; only check the guild blacklist
        # when there is a guild.
        guild = message.guild
        if (
            (guild is not None and guild.id in self.config.get_blacklist("guild"))
            or message.channel.id in self.config.get_blacklist("channel")
            or message.author.id in self.config.get_blacklist("user")
        ):
            return

        ctx = await self.get_context(message)

        if ctx is None or ctx.valid is False:
            self.dispatch("message_without_command", message)
        else:
            await self.invoke(ctx)

    async def on_message(self, message: discord.Message):
        await self.process_commands(message)

    async def start(self, token, bot):
        """Connect to Discord and start all connections.

        Shows a "Connecting to Discord..." progress task while connecting;
        `on_ready` removes it.

        Todo: add postgresql connect here
        """
        with self._progress.get("main") as pg:
            task_id = self._progress.get("tasks")["connecting"] = pg.add_task(
                "connecting", task_name="Connecting to Discord...", start=False
            )
            pg.update(task_id)
            await super().start(token, bot=bot)

    async def logout(self):
        """Disconnect from Discord and closes all actives connections.

        Tears down any remaining progress tasks, cancels every other running
        asyncio task and waits for them to finish before logging out.

        Todo: add postgresql logout here
        """
        progress = self._progress.get("main")
        tracked = self._progress.get("tasks")
        # Iterate over a snapshot so entries can be popped while looping;
        # stop and remove the *same* id that is being shut down.
        for name in list(tracked):
            task_id = tracked.pop(name)
            progress.log("Shutting down", name)
            progress.stop_task(task_id)
            progress.remove_task(task_id)
        progress.stop()

        pending = [
            t for t in asyncio.all_tasks() if t is not asyncio.current_task()
        ]

        for pending_task in pending:
            console.log(
                "Canceling",
                pending_task.get_name(),
                f"({pending_task.get_coro()})"
            )
            pending_task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)

        await super().logout()

    async def shutdown(self, *, restart: bool = False):
        """Gracefully quit.

        Parameters
        ----------
        restart:bool
            If `True`, systemd or the launcher gonna see custom exit code
            and reboot.
        """
        if not restart:
            self.shutdown_code = ExitCodes.SHUTDOWN
        else:
            self.shutdown_code = ExitCodes.RESTART

        await self.logout()

        sys.exit(self.shutdown_code)