tuxbot-bot/tuxbot/core/bot.py

376 lines
12 KiB
Python
Raw Normal View History

2020-06-05 00:29:14 +02:00
import asyncio
import datetime
import importlib
2020-06-04 19:16:51 +02:00
import logging
2020-10-22 00:00:48 +02:00
from collections import Counter
2020-06-05 00:29:14 +02:00
from typing import List, Union
2020-10-22 00:00:48 +02:00
import aiohttp
2020-06-04 19:16:51 +02:00
import discord
from discord.ext import commands
from rich import box
from rich.columns import Columns
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, TextColumn, BarColumn
from rich.table import Table
from tortoise import Tortoise
from tuxbot import version_info
from tuxbot.core.utils.data_manager import (
logs_data_path,
data_path,
config_dir,
)
2020-10-19 22:17:19 +02:00
from .config import (
Config,
ConfigFile,
search_for,
AppConfig,
set_for_key,
)
2020-08-28 01:06:57 +02:00
from . import __version__, ExitCodes
from . import exceptions
2020-06-05 00:29:14 +02:00
from .utils.functions.extra import ContextPlus
from .utils.functions.prefix import get_prefixes
2020-06-04 19:16:51 +02:00
console = Console()
log = logging.getLogger("tuxbot")

# Extensions loaded (in this order) by ``Tux.load_packages`` and listed in the
# COGS table printed by ``Tux.on_ready``.
# The Dev cog ("tuxbot.cogs.Dev") is currently disabled; add "Dev" to the
# tuple below to load it again.
packages: List[str] = ["jishaku"] + [
    f"tuxbot.cogs.{cog_name}"
    for cog_name in ("Admin", "Logs", "Utils", "Polls", "Custom", "Network")
]
class Tux(commands.AutoShardedBot):
    """Tuxbot's sharded Discord bot.

    Loads the extensions listed in ``packages``, connects to PostgreSQL via
    Tortoise ORM before connecting to Discord, and renders startup/shutdown
    progress with ``rich``.
    """

    _loading: asyncio.Task
    # Shared rich progress display. "tasks" maps a label (e.g. "connecting")
    # to the rich task id so the task can be stopped/removed later.
    _progress = {
        "main": Progress(
            TextColumn("[bold blue]{task.fields[task_name]}", justify="right"),
            BarColumn(),
        ),
        "tasks": {},
    }

    def __init__(self, *args, cli_flags=None, **kwargs):
        # by default, if the bot shutdown without any intervention,
        # it's a crash
        self.shutdown_code = ExitCodes.CRITICAL
        self.cli_flags = cli_flags
        # NOTE(review): cli_flags must be provided despite its None default.
        self.instance_name = self.cli_flags.instance_name
        self.last_exception = None

        self.logs = logs_data_path(self.instance_name)

        self.console = console
        self.stats = {"commands": Counter(), "socket": Counter()}
        self.config: Config = ConfigFile(
            str(data_path(self.instance_name) / "config.yaml"), Config
        ).config

        async def _prefixes(bot, message) -> List[str]:
            # Work on a copy: extending the config list in place would make
            # the configured prefixes grow on every single message.
            prefixes = list(self.config.Core.prefixes)
            prefixes.extend(get_prefixes(self, message.guild))

            if self.config.Core.mentionable:
                return commands.when_mentioned_or(*prefixes)(bot, message)
            return prefixes

        if "command_prefix" not in kwargs:
            kwargs["command_prefix"] = _prefixes

        if "owner_ids" in kwargs:
            kwargs["owner_ids"] = set(kwargs["owner_ids"])
        else:
            kwargs["owner_ids"] = self.config.Core.owners_id

        message_cache_size = 100_000
        kwargs["max_messages"] = message_cache_size
        self.max_messages = message_cache_size

        self.uptime = None
        self._app_owners_fetched = False  # to prevent abusive API calls

        super().__init__(
            *args, help_command=None, intents=discord.Intents.all(), **kwargs
        )
        self.session = aiohttp.ClientSession(loop=self.loop)

    async def _is_blacklister(self, message: discord.Message) -> bool:
        """Check for blacklists.

        Returns True when the author is a bot, or when the message's server
        (if any), channel or author is flagged ``blacklisted`` in the config.
        """
        if message.author.bot:
            return True

        # Direct messages have no guild: only check servers when present.
        if message.guild is not None and search_for(
            self.config.Servers, message.guild.id, "blacklisted"
        ):
            return True

        return bool(
            search_for(self.config.Channels, message.channel.id, "blacklisted")
            or search_for(self.config.Users, message.author.id, "blacklisted")
        )

    async def load_packages(self):
        """Load every extension in ``packages``, reporting progress.

        A failing extension is logged (with traceback) and reported on the
        console; loading continues with the next one.
        """
        if not packages:
            return

        log_file = str((self.logs / "tuxbot.log").resolve())
        with Progress() as progress:
            task = progress.add_task(
                "Loading packages...", total=len(packages)
            )

            for package in packages:
                try:
                    self.load_extension(package)
                    progress.console.print(f"{package} loaded")
                except Exception as e:
                    log.exception(
                        "Failed to load package %s", package, exc_info=e
                    )
                    progress.console.print(
                        f"[red]Failed to load package {package} "
                        f"[i](see "
                        f"{log_file} "
                        f"for more details)[/i]"
                    )

                progress.advance(task)

    async def on_ready(self):
        """Mark the instance active and print the startup summary.

        discord.py may dispatch ``on_ready`` more than once (e.g. after a
        reconnect), so ``uptime`` is only set the first time and the
        "connecting" progress task is removed only if it still exists.
        """
        now = datetime.datetime.now()
        if self.uptime is None:
            self.uptime = now

        app_config = ConfigFile(config_dir / "config.yaml", AppConfig).config
        set_for_key(
            app_config.Instances,
            self.instance_name,
            AppConfig.Instance,
            active=True,
            last_run=datetime.datetime.timestamp(now),
        )

        connecting = self._progress["tasks"].pop("connecting", None)
        if connecting is not None:
            self._progress["main"].stop_task(connecting)
            self._progress["main"].remove_task(connecting)

        console.clear()

        console.print(
            Panel(f"[bold blue]Tuxbot V{version_info.major}", style="blue"),
            justify="center",
        )
        console.print()

        columns = Columns(expand=True, align="center")

        table = Table(style="dim", border_style="not dim", box=box.HEAVY_HEAD)
        table.add_column(
            "INFO",
        )
        table.add_row(str(self.user))
        table.add_row(f"Prefixes: {', '.join(self.config.Core.prefixes)}")
        table.add_row(f"Language: {self.config.Core.locale}")
        table.add_row(f"Tuxbot Version: {__version__}")
        table.add_row(f"Discord.py Version: {discord.__version__}")
        table.add_row(f"Shards: {self.shard_count}")
        table.add_row(f"Servers: {len(self.guilds)}")
        table.add_row(f"Users: {len(self.users)}")
        columns.add_renderable(table)

        table = Table(style="dim", border_style="not dim", box=box.HEAVY_HEAD)
        table.add_column(
            "COGS",
        )
        for extension in packages:
            if extension in self.extensions:
                status = f"[green]:heavy_check_mark: {extension} "
            else:
                status = f"[red]:heavy_multiplication_x: {extension} "

            table.add_row(status)
        columns.add_renderable(table)

        console.print(columns)
        console.print()

    async def is_owner(
        self, user: Union[discord.User, discord.Member]
    ) -> bool:
        """Determines if the user is a bot owner.

        Owners are the ids configured in ``Core.owners_id`` plus the
        application owner(s) fetched once from Discord (team members, or the
        single application owner when there is no team).

        Parameters
        ----------
        user: Union[discord.User, discord.Member]

        Returns
        -------
        bool
        """
        if user.id in self.config.Core.owners_id:
            return True

        if self._app_owners_fetched:
            return False

        app = await self.application_info()
        if app.team:
            app_owner_ids = [member.id for member in app.team.members]
        else:
            app_owner_ids = [app.owner.id]

        # Merge instead of replacing, so owners declared in the config keep
        # their rights once the application owners have been fetched.
        owners = list(self.config.Core.owners_id)
        for owner_id in app_owner_ids:
            if owner_id not in owners:
                owners.append(owner_id)
        self.config.Core.owners_id = owners
        self._app_owners_fetched = True

        return user.id in app_owner_ids

    # pylint: disable=unused-argument
    async def get_context(self, message: discord.Message, *, cls=None):
        """Build the context; ``cls`` is ignored, ContextPlus is forced."""
        return await super().get_context(message, cls=ContextPlus)

    async def process_commands(self, message: discord.Message):
        """Resolve user aliases, enforce disabled commands, then invoke.

        Raises
        ------
        exceptions.DisabledCommandByServerOwner
            The command is disabled for the message's server.
        exceptions.DisabledCommandByBotOwner
            The command is disabled globally.
        """
        ctx: ContextPlus = await self.get_context(message)

        if ctx is None or not ctx.valid:
            if user_aliases := search_for(
                self.config.Users, message.author.id, "aliases"
            ):
                # Try each alias of the author; keep the first rewrite that
                # yields a valid command, restore the content otherwise.
                for alias, command in user_aliases.items():
                    back_content = message.content
                    message.content = message.content.replace(
                        alias, command, 1
                    )

                    if (
                        ctx := await self.get_context(message)
                    ) is None or not ctx.valid:
                        message.content = back_content
                    else:
                        break

            self.dispatch("message_without_command", message)

        if ctx is not None and ctx.valid:
            # Direct messages have no guild, hence no server-level settings.
            if message.guild is not None and ctx.command in search_for(
                self.config.Servers, message.guild.id, "disabled_command", []
            ):
                raise exceptions.DisabledCommandByServerOwner

            if ctx.command in self.config.Core.disabled_command:
                raise exceptions.DisabledCommandByBotOwner

            await self.invoke(ctx)

    async def on_message(self, message: discord.Message):
        await self.process_commands(message)

    async def start(self, token, bot):  # pylint: disable=arguments-differ
        """Connect to PostgreSQL, then to Discord.

        Registers the ``models`` package of every loaded extension (except
        jishaku) whose module has a truthy ``HAS_MODELS`` attribute,
        initializes Tortoise ORM and generates the schemas, then starts the
        Discord connection.
        """
        with self._progress.get("main") as progress:
            task_id = self._progress.get("tasks")[
                "connecting"
            ] = progress.add_task(
                "connecting",
                task_name="Connecting to PostgreSQL...",
                start=False,
            )

            models = []
            for extension in self.extensions:
                if extension == "jishaku":
                    continue
                # Extensions without a HAS_MODELS flag simply have no models.
                module = importlib.import_module(extension)
                if getattr(module, "HAS_MODELS", False):
                    models.append(f"{extension}.models.__init__")

            progress.update(task_id)
            # NOTE(review): credentials are interpolated unescaped; a
            # password containing URL-reserved characters (@, :, /) would
            # produce an invalid URL.
            await Tortoise.init(
                db_url="postgres://{}:{}@{}:{}/{}".format(
                    self.config.Core.Database.username,
                    self.config.Core.Database.password,
                    self.config.Core.Database.domain,
                    self.config.Core.Database.port,
                    self.config.Core.Database.db_name,
                ),
                modules={"models": models},
            )
            await Tortoise.generate_schemas()

        connecting = self._progress["tasks"].pop("connecting")
        self._progress["main"].stop_task(connecting)
        self._progress["main"].remove_task(connecting)

        with self._progress.get("main") as progress:
            task_id = self._progress.get("tasks")[
                "connecting"
            ] = progress.add_task(
                "connecting", task_name="Connecting to Discord...", start=False
            )
            progress.update(task_id)
            # The "connecting" task is removed by on_ready.
            await super().start(token, bot=bot)

    async def logout(self):
        """Disconnect from Discord and close all active connections.

        Marks the instance inactive in the app config, clears the progress
        display, cancels every other pending asyncio task, closes the HTTP
        session and the Tortoise ORM connections, then logs out of Discord.
        """
        app_config = ConfigFile(config_dir / "config.yaml", AppConfig).config
        set_for_key(
            app_config.Instances,
            self.instance_name,
            AppConfig.Instance,
            active=False,
        )

        progress = self._progress["main"]
        for name, task_id in self._progress["tasks"].items():
            progress.log("Shutting down", name)
            progress.stop_task(task_id)
            progress.remove_task(task_id)
        self._progress["tasks"].clear()
        progress.stop()

        pending = [
            t for t in asyncio.all_tasks() if t is not asyncio.current_task()
        ]
        for task in pending:
            console.log("Canceling", task.get_name(), f"({task.get_coro()})")
            task.cancel()
        # Cancelled tasks finish with CancelledError; without
        # return_exceptions=True gather would re-raise it here and abort the
        # logout before the connections below are closed.
        await asyncio.gather(*pending, return_exceptions=True)

        await self.session.close()
        await Tortoise.close_connections()

        await super().logout()

    async def shutdown(self, *, restart: bool = False):
        """Gracefully quit.

        Parameters
        ----------
        restart:bool
            If `True`, systemd or the launcher gonna see custom exit code
            and reboot.

        Raises
        ------
        SystemExit
            Always, with ``shutdown_code`` as exit code.
        """
        self.shutdown_code = (
            ExitCodes.RESTART if restart else ExitCodes.SHUTDOWN
        )
        await self.logout()

        raise SystemExit(self.shutdown_code)