2020-06-05 00:29:14 +02:00
|
|
|
import asyncio
|
|
|
|
import datetime
|
2020-06-04 19:16:51 +02:00
|
|
|
import logging
|
2020-06-05 00:29:14 +02:00
|
|
|
from typing import List, Union
|
2020-06-03 01:07:43 +02:00
|
|
|
|
2020-06-04 19:16:51 +02:00
|
|
|
import discord
|
2020-06-03 01:07:43 +02:00
|
|
|
from discord.ext import commands
|
2020-08-28 23:05:04 +02:00
|
|
|
from rich import box
|
|
|
|
from rich.columns import Columns
|
|
|
|
from rich.console import Console
|
|
|
|
from rich.panel import Panel
|
|
|
|
from rich.progress import Progress, TextColumn, BarColumn
|
|
|
|
from rich.table import Table
|
2020-08-28 01:06:57 +02:00
|
|
|
from rich.traceback import install
|
2020-10-19 00:20:58 +02:00
|
|
|
|
2020-08-28 23:05:04 +02:00
|
|
|
from tuxbot import version_info
|
2020-08-28 01:06:57 +02:00
|
|
|
|
2020-10-19 00:53:26 +02:00
|
|
|
from .config import Config, ConfigFile, search_for
|
2020-10-19 00:20:58 +02:00
|
|
|
from .data_manager import logs_data_path, data_path
|
2020-06-03 01:07:43 +02:00
|
|
|
|
2020-08-28 01:06:57 +02:00
|
|
|
from . import __version__, ExitCodes
|
2020-10-19 00:20:58 +02:00
|
|
|
from . import exceptions
|
2020-06-05 00:29:14 +02:00
|
|
|
from .utils.functions.extra import ContextPlus
|
2020-10-19 00:20:58 +02:00
|
|
|
from .utils.functions.prefix import get_prefixes
|
2020-06-04 19:16:51 +02:00
|
|
|
|
|
|
|
# Logger used by this module, registered under the "tuxbot" name.
log = logging.getLogger("tuxbot")

# One rich console for the whole module; rich's traceback handler is
# installed on that same console.
console = Console()
install(console=console)

# Extensions handed to ``load_extension`` at start-up, in this order.
packages: List[str] = [
    "jishaku",
    "tuxbot.cogs.admin",
]
|
2020-06-03 01:07:43 +02:00
|
|
|
|
|
|
|
|
|
|
|
class Tux(commands.AutoShardedBot):
    """Auto-sharded Discord bot that drives one Tuxbot instance.

    On top of ``commands.AutoShardedBot`` this class loads the instance
    configuration, resolves command prefixes from it, applies the
    blacklist / disabled-command settings before invoking a command and
    renders start-up and shutdown progress with ``rich``.
    """

    _loading: asyncio.Task
    # Defined on the class, so the progress display and the registry of
    # progress task ids are shared by every instance of ``Tux``.
    _progress = {
        "main": Progress(
            TextColumn("[bold blue]{task.fields[task_name]}", justify="right"),
            BarColumn(),
        ),
        "tasks": {},
    }

    def __init__(self, *args, cli_flags=None, **kwargs):
        """Load the instance configuration and initialise the bot.

        Parameters
        ----------
        *args
            Forwarded unchanged to ``commands.AutoShardedBot``.
        cli_flags
            Object whose ``instance_name`` attribute selects the instance.
            ``instance_name`` is read unconditionally, so leaving this at
            ``None`` raises ``AttributeError``.
        **kwargs
            Forwarded to ``commands.AutoShardedBot``. ``command_prefix``
            and ``owner_ids`` are filled from the configuration when
            absent; ``max_messages`` is always overwritten.
        """
        # by default, if the bot shutdown without any intervention,
        # it's a crash
        self.shutdown_code = ExitCodes.CRITICAL
        self.cli_flags = cli_flags
        self.instance_name = self.cli_flags.instance_name
        self.last_exception = None
        self.logs = logs_data_path(self.instance_name)

        self.config: Config = ConfigFile(
            str(data_path(self.instance_name) / "config.yaml"), Config
        ).config

        async def _prefixes(bot, message) -> List[str]:
            # Work on a copy: extending the configured list in place would
            # make it grow on every message and leak one guild's prefixes
            # into every other guild.
            prefixes = list(self.config.Core.prefixes)
            prefixes.extend(get_prefixes(self, message.guild))

            if self.config.Core.mentionable:
                return commands.when_mentioned_or(*prefixes)(bot, message)
            return prefixes

        if "command_prefix" not in kwargs:
            kwargs["command_prefix"] = _prefixes

        if "owner_ids" in kwargs:
            kwargs["owner_ids"] = set(kwargs["owner_ids"])
        else:
            kwargs["owner_ids"] = self.config.Core.owners_id

        message_cache_size = 100_000
        kwargs["max_messages"] = message_cache_size
        self.max_messages = message_cache_size

        self.uptime = None
        self._app_owners_fetched = False  # to prevent abusive API calls

        super().__init__(*args, help_command=None, **kwargs)

    async def load_packages(self):
        """Load every extension listed in ``packages``.

        A failing extension is logged and reported on the console; the
        remaining extensions are still loaded.
        """
        if not packages:
            return

        with Progress() as progress:
            task = progress.add_task(
                "Loading packages...", total=len(packages)
            )

            for package in packages:
                try:
                    self.load_extension(package)
                    progress.console.print(f"{package} loaded")
                except Exception as e:
                    log.exception(
                        "Failed to load package %s", package, exc_info=e
                    )
                    progress.console.print(
                        f"[red]Failed to load package {package} "
                        f"[i](see "
                        f"{str((self.logs / 'tuxbot.log').resolve())} "
                        f"for more details)[/i]"
                    )

                progress.advance(task)

    async def on_ready(self):
        """Record the start time and print the start-up summary.

        discord.py does not guarantee that ``on_ready`` is dispatched only
        once, so the "connecting" progress task is stopped and removed
        only while it is still registered.
        """
        self.uptime = datetime.datetime.now()

        progress = self._progress.get("main")
        connecting = self._progress.get("tasks").pop("connecting", None)
        # Compare against None explicitly: a task id of 0 is falsy.
        if connecting is not None:
            progress.stop_task(connecting)
            progress.remove_task(connecting)

        console.clear()

        console.print(
            Panel(f"[bold blue]Tuxbot V{version_info.major}", style="blue"),
            justify="center",
        )
        console.print()

        columns = Columns(expand=True, padding=2, align="center")

        # Left column: general information about the running bot.
        table = Table(style="dim", border_style="not dim", box=box.HEAVY_HEAD)
        table.add_column(
            "INFO",
        )
        table.add_row(str(self.user))
        table.add_row(f"Prefixes: {', '.join(self.config.Core.prefixes)}")
        table.add_row(f"Language: {self.config.Core.locale}")
        table.add_row(f"Tuxbot Version: {__version__}")
        table.add_row(f"Discord.py Version: {discord.__version__}")
        table.add_row(f"Shards: {self.shard_count}")
        table.add_row(f"Servers: {len(self.guilds)}")
        table.add_row(f"Users: {len(self.users)}")
        columns.add_renderable(table)

        # Right column: one row per expected extension, green when it is
        # present in ``self.extensions`` and red otherwise.
        table = Table(style="dim", border_style="not dim", box=box.HEAVY_HEAD)
        table.add_column(
            "COGS",
        )
        for extension in packages:
            if extension in self.extensions:
                status = f"[green]:heavy_check_mark: {extension}"
            else:
                status = f"[red]:heavy_multiplication_x: {extension}"

            table.add_row(status)
        columns.add_renderable(table)

        console.print(columns)
        console.print()

    async def is_owner(
        self, user: Union[discord.User, discord.Member]
    ) -> bool:
        """Determines if the user is a bot owner.

        The configured owner ids are checked first. On the first miss the
        application info is fetched once; if the application belongs to a
        team, the team member ids are stored as the owner ids.

        Parameters
        ----------
        user: Union[discord.User, discord.Member]

        Returns
        -------
        bool
        """
        if user.id in self.config.Core.owners_id:
            return True

        owner = False
        if not self._app_owners_fetched:
            app = await self.application_info()
            if app.team:
                ids = [m.id for m in app.team.members]
                # NOTE(review): this replaces the configured owners rather
                # than adding to them - confirm that is intended.
                self.config.Core.owners_id = ids
                owner = user.id in ids
            self._app_owners_fetched = True

        return owner

    # pylint: disable=unused-argument
    async def get_context(self, message: discord.Message, *, cls=None):
        """Build the invocation context, always as a ``ContextPlus``.

        ``cls`` is accepted for signature compatibility and ignored.
        """
        return await super().get_context(message, cls=ContextPlus)

    async def process_commands(self, message: discord.Message):
        """Check for blacklists, then invoke the command if there is one.

        Messages from bots and from blacklisted servers, channels or users
        are dropped. A message that resolves to no valid command is
        re-dispatched as ``message_without_command``.

        Raises
        ------
        exceptions.DisabledCommandByServerOwner
            The command is listed as disabled for the message's server.
        exceptions.DisabledCommandByBotOwner
            The command is listed as disabled in the core configuration.
        """
        if message.author.bot:
            return

        # ``message.guild`` is None for direct messages; the server-level
        # checks are skipped then instead of failing on ``None.id``.
        guild = message.guild

        if (
            (
                guild is not None
                and search_for(self.config.Servers, guild.id, "blacklisted")
            )
            or search_for(
                self.config.Channels, message.channel.id, "blacklisted"
            )
            or search_for(self.config.Users, message.author.id, "blacklisted")
        ):
            return

        ctx: ContextPlus = await self.get_context(message)

        if ctx is None or ctx.valid is False:
            self.dispatch("message_without_command", message)
            return

        # NOTE(review): ``ctx.command`` is compared directly against the
        # configured entries - confirm the stored values compare equal to
        # command objects.
        if guild is not None and ctx.command in search_for(
            self.config.Servers, guild.id, "disabled_command", []
        ):
            raise exceptions.DisabledCommandByServerOwner

        if ctx.command in self.config.Core.disabled_command:
            raise exceptions.DisabledCommandByBotOwner

        await self.invoke(ctx)

    async def on_message(self, message: discord.Message):
        """Hand every received message to ``process_commands``."""
        await self.process_commands(message)

    async def start(self, token, bot):  # pylint: disable=arguments-differ
        """Connect to Discord and start all connections.

        A "connecting" progress task is registered (not started) and the
        progress display stays open while the parent ``start`` runs.

        Todo: add postgresql connect here

        Parameters
        ----------
        token
            Forwarded to the parent ``start``.
        bot
            Forwarded to the parent ``start`` as the ``bot`` keyword.
        """
        with self._progress.get("main") as progress:
            task_id = self._progress.get("tasks")[
                "connecting"
            ] = progress.add_task(
                "connecting", task_name="Connecting to Discord...", start=False
            )
            progress.update(task_id)
            await super().start(token, bot=bot)

    async def logout(self):
        """Disconnect from Discord and closes all actives connections.

        Every registered progress task is stopped and removed, all other
        asyncio tasks are cancelled and awaited, then the parent
        ``logout`` is called.

        Todo: add postgresql logout here
        """
        progress = self._progress.get("main")
        tasks = self._progress.get("tasks")

        # Iterate over a snapshot and remove the task being visited; the
        # registry is emptied afterwards so no removed id is kept around.
        for name, task_id in list(tasks.items()):
            progress.log("Shutting down", name)

            progress.stop_task(task_id)
            progress.remove_task(task_id)
        tasks.clear()
        progress.stop()

        pending = [
            t for t in asyncio.all_tasks() if t is not asyncio.current_task()
        ]

        for task in pending:
            console.log("Canceling", task.get_name(), f"({task.get_coro()})")
            task.cancel()
        # The tasks were just cancelled: collect their CancelledError as
        # results so it is not re-raised here before the parent logout.
        await asyncio.gather(*pending, return_exceptions=True)

        await super().logout()

    async def shutdown(self, *, restart: bool = False):
        """Gracefully quit.

        Parameters
        ----------
        restart:bool
            If `True`, systemd or the launcher gonna see custom exit code
            and reboot.

        Raises
        ------
        SystemExit
            Always, once ``logout`` returns; its ``code`` is the selected
            exit code.
        """
        if not restart:
            self.shutdown_code = ExitCodes.SHUTDOWN
        else:
            self.shutdown_code = ExitCodes.RESTART

        await self.logout()

        sys_e = SystemExit()
        sys_e.code = self.shutdown_code

        raise sys_e
|