import asyncio
import datetime
import importlib
import logging
from collections import Counter
from typing import List, Union

import aiohttp
import discord
from discord.ext import commands
from rich import box
from rich.columns import Columns
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, TextColumn, BarColumn
from rich.table import Table
from tortoise import Tortoise

from tuxbot import version_info
from tuxbot.core.utils.data_manager import (
    logs_data_path,
    data_path,
    config_dir,
)
from .config import (
    Config,
    ConfigFile,
    search_for,
    AppConfig,
    set_for_key,
)
from . import __version__, ExitCodes
from . import exceptions
from .utils.functions.extra import ContextPlus
from .utils.functions.prefix import get_prefixes

# Shared "tuxbot" logger and the rich console used for all terminal output.
log: logging.Logger = logging.getLogger(name="tuxbot")
console: Console = Console()

# Extensions loaded at startup, in this order: the third-party "jishaku"
# extension first, then Tuxbot's own cogs ("Dev" is currently disabled).
packages: List[str] = ["jishaku"] + [
    f"tuxbot.cogs.{cog}"
    for cog in ("Admin", "Logs", "Utils", "Polls", "Custom")
]


class Tux(commands.AutoShardedBot):
    """Tuxbot's auto-sharded Discord bot.

    Ties together the instance configuration, command prefixes, owner
    resolution, per-user command aliases, the PostgreSQL connection (through
    Tortoise ORM) and the rich console progress display.
    """

    _loading: asyncio.Task
    # Class-level attribute: this dict, and the Progress it holds, is shared
    # by every Tux instance. "main" is the console progress display and
    # "tasks" maps a task name to the rich task id returned by add_task().
    _progress = {
        "main": Progress(
            TextColumn("[bold blue]{task.fields[task_name]}", justify="right"),
            BarColumn(),
        ),
        "tasks": {},
    }

    def __init__(self, *args, cli_flags=None, **kwargs):
        """Load the instance configuration and set up the underlying bot.

        ``cli_flags`` must expose ``instance_name``; it selects the config
        file and log directory of this instance. Unless given explicitly,
        ``command_prefix`` and ``owner_ids`` are taken from the config.
        """
        # by default, if the bot shutdown without any intervention,
        # it's a crash
        self.shutdown_code = ExitCodes.CRITICAL
        self.cli_flags = cli_flags
        self.instance_name = self.cli_flags.instance_name
        self.last_exception = None
        self.logs = logs_data_path(self.instance_name)

        self.console = console

        self.stats = {"commands": Counter(), "socket": Counter()}

        self.config: Config = ConfigFile(
            str(data_path(self.instance_name) / "config.yaml"), Config
        ).config

        async def _prefixes(bot, message) -> List[str]:
            # Work on a copy: extending self.config.Core.prefixes in place
            # would mutate the configuration and make the list grow with
            # every message received.
            prefixes = list(self.config.Core.prefixes)
            prefixes.extend(get_prefixes(self, message.guild))

            if self.config.Core.mentionable:
                return commands.when_mentioned_or(*prefixes)(bot, message)
            return prefixes

        if "command_prefix" not in kwargs:
            kwargs["command_prefix"] = _prefixes

        if "owner_ids" in kwargs:
            kwargs["owner_ids"] = set(kwargs["owner_ids"])
        else:
            kwargs["owner_ids"] = self.config.Core.owners_id

        message_cache_size = 100_000
        kwargs["max_messages"] = message_cache_size
        self.max_messages = message_cache_size

        self.uptime = None
        self._app_owners_fetched = False  # to prevent abusive API calls

        super().__init__(
            *args, help_command=None, intents=discord.Intents.all(), **kwargs
        )
        self.session = aiohttp.ClientSession(loop=self.loop)

    def _drop_progress_task(self, name: str) -> None:
        """Stop and remove the progress task tracked under ``name``.

        Does nothing when no task is tracked under that name, e.g. when it
        was already removed by an earlier call.
        """
        task_id = self._progress["tasks"].pop(name, None)
        # rich task ids start at 0, so compare against None explicitly
        if task_id is None:
            return

        self._progress["main"].stop_task(task_id)
        self._progress["main"].remove_task(task_id)

    async def _is_blacklister(self, message: discord.Message) -> bool:
        """Check for blacklists.

        Returns ``True`` when the author is a bot, or when the message's
        server, channel or author is flagged ``blacklisted`` in the config.
        Direct messages have no server, so the server check is skipped for
        them.
        """
        if message.author.bot:
            return True

        if message.guild is not None and search_for(
            self.config.Servers, message.guild.id, "blacklisted"
        ):
            return True

        return bool(
            search_for(self.config.Channels, message.channel.id, "blacklisted")
            or search_for(self.config.Users, message.author.id, "blacklisted")
        )

    async def load_packages(self):
        """Load every extension listed in ``packages``, with a progress bar.

        A failing extension is logged (with its traceback) and reported on
        the console; loading continues with the next one.
        """
        if not packages:
            return

        with Progress() as progress:
            task = progress.add_task(
                "Loading packages...", total=len(packages)
            )

            for package in packages:
                try:
                    self.load_extension(package)
                    progress.console.print(f"{package} loaded")
                except Exception as e:
                    log.exception(
                        "Failed to load package %s", package, exc_info=e
                    )
                    progress.console.print(
                        f"[red]Failed to load package {package} "
                        f"[i](see "
                        f"{str((self.logs / 'tuxbot.log').resolve())} "
                        f"for more details)[/i]"
                    )

                progress.advance(task)

    async def on_ready(self):
        """Mark the instance active and print the startup summary.

        discord.py may dispatch this event more than once (after a
        reconnect), so removing the "connecting" progress task tolerates the
        task being already gone.
        """
        self.uptime = datetime.datetime.now()
        app_config = ConfigFile(config_dir / "config.yaml", AppConfig).config
        set_for_key(
            app_config.Instances,
            self.instance_name,
            AppConfig.Instance,
            active=True,
            last_run=datetime.datetime.timestamp(self.uptime),
        )

        self._drop_progress_task("connecting")
        console.clear()

        console.print(
            Panel(f"[bold blue]Tuxbot V{version_info.major}", style="blue"),
            justify="center",
        )
        console.print()

        columns = Columns(expand=True, align="center")

        table = Table(style="dim", border_style="not dim", box=box.HEAVY_HEAD)
        table.add_column(
            "INFO",
        )
        table.add_row(str(self.user))
        table.add_row(f"Prefixes: {', '.join(self.config.Core.prefixes)}")
        table.add_row(f"Language: {self.config.Core.locale}")
        table.add_row(f"Tuxbot Version: {__version__}")
        table.add_row(f"Discord.py Version: {discord.__version__}")
        table.add_row(f"Shards: {self.shard_count}")
        table.add_row(f"Servers: {len(self.guilds)}")
        table.add_row(f"Users: {len(self.users)}")
        columns.add_renderable(table)

        table = Table(style="dim", border_style="not dim", box=box.HEAVY_HEAD)
        table.add_column(
            "COGS",
        )
        for extension in packages:
            if extension in self.extensions:
                status = f"[green]:heavy_check_mark: {extension} "
            else:
                status = f"[red]:heavy_multiplication_x: {extension} "

            table.add_row(status)
        columns.add_renderable(table)

        console.print(columns)
        console.print()

    async def is_owner(
        self, user: Union[discord.User, discord.Member]
    ) -> bool:
        """Determines if the user is a bot owner.

        Owners are the ids listed in ``config.Core.owners_id``. The first
        time a user is not found there, the application info is fetched once.
        Its owners (every team member, or the application owner when there
        is no team) are merged into that list.

        Parameters
        ----------
        user: Union[discord.User, discord.Member]

        Returns
        -------
        bool
        """
        if user.id in self.config.Core.owners_id:
            return True

        if self._app_owners_fetched:
            return False

        app = await self.application_info()
        if app.team:
            app_owner_ids = [member.id for member in app.team.members]
        else:
            app_owner_ids = [app.owner.id]

        # Merge instead of overwriting so owners declared in the config keep
        # their rights once the application owners are known.
        owner_ids = list(self.config.Core.owners_id)
        for owner_id in app_owner_ids:
            if owner_id not in owner_ids:
                owner_ids.append(owner_id)

        self.config.Core.owners_id = owner_ids
        self._app_owners_fetched = True

        return user.id in owner_ids

    # pylint: disable=unused-argument
    async def get_context(self, message: discord.Message, *, cls=None):
        """Build the context for ``message``; always uses ``ContextPlus``."""
        return await super().get_context(message, cls=ContextPlus)

    @staticmethod
    def _is_command_disabled(command, disabled) -> bool:
        """Tell whether ``command`` appears in the ``disabled`` collection.

        Matches the command object itself as well as its name and qualified
        name. Lists read from the YAML config presumably hold command names
        rather than command objects.
        """
        if not disabled:
            return False

        return (
            command in disabled
            or command.name in disabled
            or command.qualified_name in disabled
        )

    async def process_commands(self, message: discord.Message):
        """Resolve and invoke the command in ``message``, if any.

        When the message is not a valid command, the author's aliases (from
        ``config.Users``) are tried in turn. Messages that still do not
        resolve to a command dispatch ``message_without_command``.

        Raises
        ------
        exceptions.DisabledCommandByServerOwner
            The command is disabled on this server.
        exceptions.DisabledCommandByBotOwner
            The command is disabled globally.
        """
        ctx: ContextPlus = await self.get_context(message)

        if ctx is None or not ctx.valid:
            if user_aliases := search_for(
                self.config.Users, message.author.id, "aliases"
            ):
                for alias, command in user_aliases.items():
                    back_content = message.content
                    message.content = message.content.replace(
                        alias, command, 1
                    )

                    if (
                        ctx := await self.get_context(message)
                    ) is None or not ctx.valid:
                        message.content = back_content
                    else:
                        break

        if ctx is None or not ctx.valid:
            # only messages that resolved to no command, even via an alias
            self.dispatch("message_without_command", message)
            return

        if message.author.bot:
            # discord.py's default process_commands never lets bots (this
            # bot included) invoke commands; keep that rule in this override
            return

        server_disabled = []
        if message.guild is not None:  # direct messages have no server
            server_disabled = search_for(
                self.config.Servers, message.guild.id, "disabled_command", []
            )

        if self._is_command_disabled(ctx.command, server_disabled):
            raise exceptions.DisabledCommandByServerOwner

        if self._is_command_disabled(
            ctx.command, self.config.Core.disabled_command
        ):
            raise exceptions.DisabledCommandByBotOwner

        await self.invoke(ctx)

    async def on_message(self, message: discord.Message):
        """Route every incoming message to ``process_commands``."""
        await self.process_commands(message)

    async def start(self, token, bot):  # pylint: disable=arguments-differ
        """Connect to PostgreSQL, then to Discord.

        Every loaded extension except jishaku whose module sets a truthy
        ``HAS_MODELS`` attribute contributes its ``models`` package to the
        Tortoise ORM app. The schemas are generated before the Discord
        connection is started.
        """
        with self._progress["main"] as progress:
            task_id = progress.add_task(
                "connecting",
                task_name="Connecting to PostgreSQL...",
                start=False,
            )
            self._progress["tasks"]["connecting"] = task_id

            models = []
            for extension in self.extensions:
                if extension == "jishaku":
                    continue

                # A cog module without HAS_MODELS simply has no models;
                # it must not abort the whole startup.
                module = importlib.import_module(extension)
                if getattr(module, "HAS_MODELS", False):
                    models.append(f"{extension}.models.__init__")

            progress.update(task_id)
            # NOTE(review): credentials are interpolated unescaped, so a
            # password containing URL-reserved characters (@ : / ...) will
            # produce an invalid URL.
            await Tortoise.init(
                db_url="postgres://{}:{}@{}:{}/{}".format(
                    self.config.Core.Database.username,
                    self.config.Core.Database.password,
                    self.config.Core.Database.domain,
                    self.config.Core.Database.port,
                    self.config.Core.Database.db_name,
                ),
                modules={"models": models},
            )
            await Tortoise.generate_schemas()

        self._drop_progress_task("connecting")

        with self._progress["main"] as progress:
            task_id = progress.add_task(
                "connecting", task_name="Connecting to Discord...", start=False
            )
            self._progress["tasks"]["connecting"] = task_id
            progress.update(task_id)
            await super().start(token, bot=bot)

    async def logout(self):
        """Disconnect from Discord and close all active connections.

        Marks the instance inactive in the app config, removes the tracked
        progress tasks, and cancels every other asyncio task (waiting for
        them to finish). It then closes the HTTP session and the database
        connections before logging out of Discord.
        """
        app_config = ConfigFile(config_dir / "config.yaml", AppConfig).config
        set_for_key(
            app_config.Instances,
            self.instance_name,
            AppConfig.Instance,
            active=False,
        )

        for name in list(self._progress["tasks"]):
            self._progress["main"].log("Shutting down", name)
            self._drop_progress_task(name)
        self._progress["main"].stop()

        current = asyncio.current_task()
        pending = [t for t in asyncio.all_tasks() if t is not current]

        for task in pending:
            console.log("Canceling", task.get_name(), f"({task.get_coro()})")
            task.cancel()
        # return_exceptions=True: otherwise the CancelledError of the first
        # cancelled task is re-raised here and aborts the logout.
        await asyncio.gather(*pending, return_exceptions=True)

        await self.session.close()
        await Tortoise.close_connections()

        await super().logout()

    async def shutdown(self, *, restart: bool = False):
        """Gracefully quit.

        Parameters
        ----------
        restart:bool
            If `True`, systemd or the launcher gonna see custom exit code
            and reboot.

        Raises
        ------
        SystemExit
            Always, with ``ExitCodes.RESTART`` or ``ExitCodes.SHUTDOWN`` as
            its ``code``.
        """
        self.shutdown_code = (
            ExitCodes.RESTART if restart else ExitCodes.SHUTDOWN
        )

        await self.logout()

        raise SystemExit(self.shutdown_code)