From 29808d41d6012f0b02c165d5160a287b26815c56 Mon Sep 17 00:00:00 2001 From: Romain J Date: Sun, 29 Sep 2019 18:31:01 +0200 Subject: [PATCH] refactor(database): migrate to sqlalchemy --- bot.py | 8 +- cogs/admin.py | 269 +++++++++++++++++++------------------- cogs/utils/__init__.py | 3 +- cogs/utils/db.py | 12 -- cogs/utils/lang.py | 15 ++- cogs/utils/models/warn.py | 20 +++ launcher.py | 40 +++--- requirements.txt | 1 + 8 files changed, 188 insertions(+), 180 deletions(-) delete mode 100755 cogs/utils/db.py create mode 100644 cogs/utils/models/warn.py diff --git a/bot.py b/bot.py index 120e46c..b1c06cd 100755 --- a/bot.py +++ b/bot.py @@ -4,7 +4,7 @@ import sys from collections import deque, Counter import aiohttp -import asyncpg +import sqlalchemy import discord import git from discord.ext import commands @@ -40,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list: class TuxBot(commands.AutoShardedBot): - def __init__(self, unload: list, db: asyncpg.pool.Pool): + def __init__(self, unload: list, engine: sqlalchemy.engine.Engine): super().__init__(command_prefix=_prefix_callable, pm_help=None, help_command=None, description=description, help_attrs=dict(hidden=True), @@ -53,7 +53,7 @@ class TuxBot(commands.AutoShardedBot): self.uptime: datetime = datetime.datetime.utcnow() self.config = config - self.db = db + self.engine = engine self._prev_events = deque(maxlen=10) self.session = aiohttp.ClientSession(loop=self.loop) @@ -137,8 +137,6 @@ class TuxBot(commands.AutoShardedBot): async def close(self): await super().close() - await self.db.close() - await self.session.close() def run(self): super().run(config.token, reconnect=True) diff --git a/cogs/admin.py b/cogs/admin.py index 7986cdc..a211a9a 100644 --- a/cogs/admin.py +++ b/cogs/admin.py @@ -3,12 +3,14 @@ import logging from typing import Union import asyncio + import discord import humanize from discord.ext import commands from bot import TuxBot from .utils.lang import Texts +from .utils.models.warn import Warn log = logging.getLogger(__name__) @@ -63,12 +65,13 @@ class Admin(commands.Cog): @commands.group(name='say', invoke_without_command=True) async def _say(self, ctx: commands.Context, *, content: str): - try: - await ctx.message.delete() - except discord.errors.Forbidden: - pass + if ctx.invoked_subcommand is None: + try: + await ctx.message.delete() + except discord.errors.Forbidden: + pass - await ctx.send(content) + await ctx.send(content) @_say.command(name='edit') async def _say_edit(self, ctx: commands.Context, message_id: int, *, @@ -226,158 +229,142 @@ class Admin(commands.Cog): async def get_warn(self, ctx: commands.Context, member: discord.Member = False): - query = """ - SELECT * FROM warns - WHERE created_at >= $1 AND server_id = $2 - """ - query += """AND user_id = $3""" if member else "" - query += """ORDER BY created_at DESC""" + await ctx.trigger_typing() + week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6) - async with self.bot.db.acquire() as con: - await ctx.trigger_typing() - args = [week_ago, ctx.guild.id] - if member: - args.append(member.id) + if member: + warns = self.bot.engine \ + .query(Warn) \ + .filter(Warn.user_id == member.id, Warn.created_at > week_ago, + Warn.server_id == ctx.guild.id) \ + .order_by(Warn.created_at.desc()) + else: + warns = self.bot.engine \ + .query(Warn) \ + .filter(Warn.created_at > week_ago, + Warn.server_id == ctx.guild.id) \ + .order_by(Warn.created_at.desc()) + warns_list = '' - warns = await con.fetch(query, *args) - warns_list = '' + for warn in warns: + row_id = warn.id + user_id = warn.user_id + user = await self.bot.fetch_user(user_id) + reason = warn.reason + ago = humanize.naturaldelta( + datetime.datetime.now() - warn.created_at + ) - for warn in warns: - row_id = warn.get('id') - user_id = warn.get('user_id') - user = await self.bot.fetch_user(user_id) - reason = warn.get('reason') - ago = humanize.naturaldelta( - datetime.datetime.now() - warn.get('created_at') - ) - - warns_list += f"[{row_id}] **{user}**: `{reason}` " \ - f"*({ago} ago)*\n" + warns_list += f"[{row_id}] **{user}**: `{reason}` *({ago} ago)*\n" return warns_list, warns + async def add_warn(self, ctx: commands.Context, member: discord.Member, + reason): + + now = datetime.datetime.now() + warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason, + created_at=now) + + self.bot.engine.add(warn) + self.bot.engine.commit() + @commands.group(name='warn', aliases=['warns']) async def _warn(self, ctx: commands.Context): + await ctx.trigger_typing() if ctx.invoked_subcommand is None: warns_list, warns = await self.get_warn(ctx) e = discord.Embed( - title=f"{len(warns)} {Texts('admin').get('last warns')}: ", + title=f"{warns.count()} {Texts('admin').get('last warns')}: ", description=warns_list ) await ctx.send(embed=e) - async def add_warn(self, ctx: commands.Context, member: discord.Member, - reason): - - query = """ - INSERT INTO warns (server_id, user_id, reason, created_at) - VALUES ($1, $2, $3, $4) - """ - - now = datetime.datetime.now() - await self.bot.db.execute(query, ctx.guild.id, member.id, reason, now) - @_warn.command(name='add', aliases=['new']) async def _warn_new(self, ctx: commands.Context, member: discord.Member, *, reason="N/A"): - member = await ctx.guild.fetch_member(member.id) if not member: return await ctx.send( Texts('utils').get("Unable to find the user...") ) - query = """ - SELECT user_id, reason, created_at FROM warns - WHERE created_at >= $1 AND server_id = $2 and user_id = $3 - """ - week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6) - def check(pld: discord.RawReactionActionEvent): if pld.message_id != choice.id \ or pld.user_id != ctx.author.id: return False return pld.emoji.name in ('1⃣', '2⃣', '3⃣') - async with self.bot.db.acquire() as con: - await ctx.trigger_typing() - warns = await con.fetch(query, week_ago, ctx.guild.id, member.id) + warns_list, warns = await self.get_warn(ctx) - if len(warns) >= 2: - e = discord.Embed( - title=Texts('admin').get('More than 2 warns'), - description=f"{member.mention} " - + Texts('admin').get('has more than 2 warns') - ) - e.add_field( - name='__Actions__', - value=':one: kick\n' - ':two: ban\n' - ':three: ' + Texts('admin').get('ignore') - ) - - choice = await ctx.send(embed=e) - - for reaction in ('1⃣', '2⃣', '3⃣'): - await choice.add_reaction(reaction) - - try: - payload = await self.bot.wait_for( - 'raw_reaction_add', - check=check, - timeout=50.0 - ) - except asyncio.TimeoutError: - return await ctx.send( - Texts('admin').get('Took too long. Aborting.') - ) - finally: - await choice.delete() - - if payload.emoji.name == '1⃣': - from jishaku.models import copy_context_with - - alt_ctx = await copy_context_with( - ctx, - content=f"{ctx.prefix}" - f"kick " - f"{member} " - f"{Texts('admin').get('More than 2 warns')}" - ) - return await alt_ctx.command.invoke(alt_ctx) - - elif payload.emoji.name == '2⃣': - from jishaku.models import copy_context_with - - alt_ctx = await copy_context_with( - ctx, - content=f"{ctx.prefix}" - f"ban " - f"{member} " - f"{Texts('admin').get('More than 2 warns')}" - ) - return await alt_ctx.command.invoke(alt_ctx) - - await self.add_warn(ctx, member, reason) - await ctx.send( - content=f"{member.mention} " - f"**{Texts('admin').get('got a warn')}**" - f"\n**{Texts('admin').get('Reason')}:** `{reason}`" - if reason != 'N/A' else '' + if warns.count() >= 3: + e = discord.Embed( + title=Texts('admin').get('More than 2 warns'), + description=f"{member.mention} " + + Texts('admin').get('has more than 2 warns') ) + e.add_field( + name='__Actions__', + value=':one: kick\n' + ':two: ban\n' + ':three: ' + Texts('admin').get('ignore') + ) + + choice = await ctx.send(embed=e) + + for reaction in ('1⃣', '2⃣', '3⃣'): + await choice.add_reaction(reaction) + + try: + payload = await self.bot.wait_for( + 'raw_reaction_add', + check=check, + timeout=50.0 + ) + except asyncio.TimeoutError: + return await ctx.send( + Texts('admin').get('Took too long. Aborting.') + ) + finally: + await choice.delete() + + if payload.emoji.name == '1⃣': + from jishaku.models import copy_context_with + + alt_ctx = await copy_context_with( + ctx, + content=f"{ctx.prefix}" + f"kick " + f"{member} " + f"{Texts('admin').get('More than 2 warns')}" + ) + return await alt_ctx.command.invoke(alt_ctx) + + elif payload.emoji.name == '2⃣': + from jishaku.models import copy_context_with + + alt_ctx = await copy_context_with( + ctx, + content=f"{ctx.prefix}" + f"ban " + f"{member} " + f"{Texts('admin').get('More than 2 warns')}" + ) + return await alt_ctx.command.invoke(alt_ctx) + + await self.add_warn(ctx, member, reason) + await ctx.send( + content=f"{member.mention} " + f"**{Texts('admin').get('got a warn')}**" + f"\n**{Texts('admin').get('Reason')}:** `{reason}`" + ) @_warn.command(name='remove', aliases=['revoke']) async def _warn_remove(self, ctx: commands.Context, warn_id: int): - query = """ - DELETE FROM warns - WHERE id = $1 - """ - - async with self.bot.db.acquire() as con: - await ctx.trigger_typing() - await con.fetch(query, warn_id) + warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one() + self.bot.engine.delete(warn) await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`" f" {Texts('admin').get('successfully removed')}") @@ -385,8 +372,9 @@ class Admin(commands.Cog): @_warn.command(name='show', aliases=['list']) async def _warn_show(self, ctx: commands.Context, member: discord.Member): warns_list, warns = await self.get_warn(ctx, member) + e = discord.Embed( - title=f"{len(warns)} {Texts('admin').get('last warns')}: ", + title=f"{warns.count()} {Texts('admin').get('last warns')}: ", description=warns_list ) @@ -394,27 +382,40 @@ class Admin(commands.Cog): @_warn.command(name='edit', aliases=['change']) async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason): - query = """ - UPDATE warns - SET reason = $2 - WHERE id = $1 - """ - - async with self.bot.db.acquire() as con: - await ctx.trigger_typing() - await con.fetch(query, warn_id, reason) + warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one() + warn.reason = reason + self.bot.engine.commit() await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`" f" {Texts('admin').get('successfully edited')}") """---------------------------------------------------------------------""" - @commands.command(name='set-language', aliases=['set-lang']) - async def _set_language(self, ctx: commands.Context, lang): - """ - todo: set lang for guild + @commands.command(name='language', aliases=['lang', 'langue', 'langage']) + async def _language(self, ctx: commands.Context, locale): + query = """ + SELECT locale + FROM lang + WHERE key = 'available' """ + async with self.bot.engine.begin() as con: + await ctx.trigger_typing() + available = list(await con.fetchrow(query)) + + if str(locale) in available: + query = """ + IF EXISTS(SELECT * FROM lang WHERE key = $1 ) + then + UPDATE lang + SET locale = $2 + WHERE key = $1 + ELSE + INSERT INTO lang (key, locale) + VALUES ($1, $2) + """ + await con.fetch(query, str(ctx.guild.id), str(locale)) + def setup(bot: TuxBot): bot.add_cog(Admin(bot)) diff --git a/cogs/utils/__init__.py b/cogs/utils/__init__.py index caaa513..9bc1674 100755 --- a/cogs/utils/__init__.py +++ b/cogs/utils/__init__.py @@ -1,5 +1,4 @@ from .checks import * from .config import * -from .db import * from .lang import * -from .version import * \ No newline at end of file +from .version import * diff --git a/cogs/utils/db.py b/cogs/utils/db.py deleted file mode 100755 index 554a7e6..0000000 --- a/cogs/utils/db.py +++ /dev/null @@ -1,12 +0,0 @@ -import logging - -import asyncpg - -log = logging.getLogger(__name__) - - -class Table: - @classmethod - async def create_pool(cls, uri, **kwargs) -> asyncpg.pool.Pool: - cls._pool = db = await asyncpg.create_pool(uri, **kwargs) - return db diff --git a/cogs/utils/lang.py b/cogs/utils/lang.py index cd00c01..2978056 100644 --- a/cogs/utils/lang.py +++ b/cogs/utils/lang.py @@ -5,12 +5,13 @@ import config class Texts: def __init__(self, base: str = 'base'): self.locale = config.locale - self.texts = gettext.translation(base, localedir='extras/locales', - languages=[self.locale]) - self.texts.install() - - def __str__(self) -> str: - return self.texts + self.base = base def get(self, text: str) -> str: - return self.texts.gettext(text) + texts = gettext.translation(self.base, localedir='extras/locales', + languages=[self.locale]) + texts.install() + return texts.gettext(text) + + def set(self, lang: str): + self.locale = lang diff --git a/cogs/utils/models/warn.py b/cogs/utils/models/warn.py new file mode 100644 index 0000000..7cf3ef5 --- /dev/null +++ b/cogs/utils/models/warn.py @@ -0,0 +1,20 @@ +import datetime + +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, Integer, String, BIGINT, TIMESTAMP +Base = declarative_base() + + +class Warn(Base): + __tablename__ = 'warns' + + id = Column(Integer, primary_key=True) + server_id = Column(BIGINT) + user_id = Column(BIGINT) + reason = Column(String) + created_at = Column(TIMESTAMP, default=datetime.datetime.now()) + + def __repr__(self): + return ""\ + % (self.server_id, self.user_id, self.reason, self.created_at) diff --git a/launcher.py b/launcher.py index fb78fec..ff9146a 100644 --- a/launcher.py +++ b/launcher.py @@ -1,4 +1,9 @@ -import asyncio +try: + import config + from cogs.utils.lang import Texts +except ModuleNotFoundError: + import extras.first_run + import contextlib import logging import socket @@ -9,13 +14,8 @@ import git import requests from bot import TuxBot -from cogs.utils.db import Table - -try: - import config - from cogs.utils.lang import Texts -except ModuleNotFoundError: - import extras.first_run +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker @contextlib.contextmanager @@ -39,36 +39,36 @@ def setup_logging(): yield finally: handlers = log.handlers[:] - for hdlr in handlers: - hdlr.close() - log.removeHandler(hdlr) + for handler in handlers: + handler.close() + log.removeHandler(handler) def run_bot(unload: list = []): - loop = asyncio.get_event_loop() log = logging.getLogger() print(Texts().get('Starting...')) try: - db = loop.run_until_complete( - Table.create_pool(config.postgresql, command_timeout=60) - ) + engine = create_engine(config.postgresql) + + Session = sessionmaker() + Session.configure(bind=engine) except socket.gaierror: click.echo(Texts().get("Could not set up PostgreSQL..."), file=sys.stderr) log.exception(Texts().get("Could not set up PostgreSQL...")) return - bot = TuxBot(unload, db) + bot = TuxBot(unload, Session()) bot.run() @click.command() @click.option('-d', '--unload', multiple=True, type=str, help=Texts().get("Launch without loading the module")) -@click.option('-u', '--update', help=Texts().get("Search for update"), - is_flag=True) +@click.option('-u', '--update', is_flag=True, + help=Texts().get("Search for update")) def main(**kwargs): if kwargs.get('update'): _update() @@ -77,8 +77,8 @@ def main(**kwargs): run_bot(kwargs.get('unload')) -@click.option('-d', '--update', help=Texts().get("Search for update"), - is_flag=True) +@click.option('-d', '--update', is_flag=True, + help=Texts().get("Search for update")) def _update(): print(Texts().get("Checking for update...")) diff --git a/requirements.txt b/requirements.txt index 789361e..4b42423 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ jishaku lxml click asyncpg>=0.12.0 +sqlalchemy gitpython requests psutil \ No newline at end of file