refactor(database): migrate to sqlalchemy

This commit is contained in:
Romain J 2019-09-29 18:31:01 +02:00
parent 8f17085cf7
commit 29808d41d6
8 changed files with 188 additions and 180 deletions

8
bot.py
View file

@ -4,7 +4,7 @@ import sys
from collections import deque, Counter from collections import deque, Counter
import aiohttp import aiohttp
import asyncpg import sqlalchemy
import discord import discord
import git import git
from discord.ext import commands from discord.ext import commands
@ -40,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
class TuxBot(commands.AutoShardedBot): class TuxBot(commands.AutoShardedBot):
def __init__(self, unload: list, db: asyncpg.pool.Pool): def __init__(self, unload: list, engine: sqlalchemy.engine.Engine):
super().__init__(command_prefix=_prefix_callable, pm_help=None, super().__init__(command_prefix=_prefix_callable, pm_help=None,
help_command=None, description=description, help_command=None, description=description,
help_attrs=dict(hidden=True), help_attrs=dict(hidden=True),
@ -53,7 +53,7 @@ class TuxBot(commands.AutoShardedBot):
self.uptime: datetime = datetime.datetime.utcnow() self.uptime: datetime = datetime.datetime.utcnow()
self.config = config self.config = config
self.db = db self.engine = engine
self._prev_events = deque(maxlen=10) self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop) self.session = aiohttp.ClientSession(loop=self.loop)
@ -137,8 +137,6 @@ class TuxBot(commands.AutoShardedBot):
async def close(self): async def close(self):
await super().close() await super().close()
await self.db.close()
await self.session.close()
def run(self): def run(self):
super().run(config.token, reconnect=True) super().run(config.token, reconnect=True)

View file

@ -3,12 +3,14 @@ import logging
from typing import Union from typing import Union
import asyncio import asyncio
import discord import discord
import humanize import humanize
from discord.ext import commands from discord.ext import commands
from bot import TuxBot from bot import TuxBot
from .utils.lang import Texts from .utils.lang import Texts
from .utils.models.warn import Warn
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -63,12 +65,13 @@ class Admin(commands.Cog):
@commands.group(name='say', invoke_without_command=True) @commands.group(name='say', invoke_without_command=True)
async def _say(self, ctx: commands.Context, *, content: str): async def _say(self, ctx: commands.Context, *, content: str):
try: if ctx.invoked_subcommand is None:
await ctx.message.delete() try:
except discord.errors.Forbidden: await ctx.message.delete()
pass except discord.errors.Forbidden:
pass
await ctx.send(content) await ctx.send(content)
@_say.command(name='edit') @_say.command(name='edit')
async def _say_edit(self, ctx: commands.Context, message_id: int, *, async def _say_edit(self, ctx: commands.Context, message_id: int, *,
@ -226,158 +229,142 @@ class Admin(commands.Cog):
async def get_warn(self, ctx: commands.Context, async def get_warn(self, ctx: commands.Context,
member: discord.Member = False): member: discord.Member = False):
query = """ await ctx.trigger_typing()
SELECT * FROM warns
WHERE created_at >= $1 AND server_id = $2
"""
query += """AND user_id = $3""" if member else ""
query += """ORDER BY created_at DESC"""
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6) week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
async with self.bot.db.acquire() as con: if member:
await ctx.trigger_typing() warns = self.bot.engine \
args = [week_ago, ctx.guild.id] .query(Warn) \
if member: .filter(Warn.user_id == member.id, Warn.created_at > week_ago,
args.append(member.id) Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc())
else:
warns = self.bot.engine \
.query(Warn) \
.filter(Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc())
warns_list = ''
warns = await con.fetch(query, *args) for warn in warns:
warns_list = '' row_id = warn.id
user_id = warn.user_id
user = await self.bot.fetch_user(user_id)
reason = warn.reason
ago = humanize.naturaldelta(
datetime.datetime.now() - warn.created_at
)
for warn in warns: warns_list += f"[{row_id}] **{user}**: `{reason}` *({ago} ago)*\n"
row_id = warn.get('id')
user_id = warn.get('user_id')
user = await self.bot.fetch_user(user_id)
reason = warn.get('reason')
ago = humanize.naturaldelta(
datetime.datetime.now() - warn.get('created_at')
)
warns_list += f"[{row_id}] **{user}**: `{reason}` " \
f"*({ago} ago)*\n"
return warns_list, warns return warns_list, warns
async def add_warn(self, ctx: commands.Context, member: discord.Member,
reason):
now = datetime.datetime.now()
warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason,
created_at=now)
self.bot.engine.add(warn)
self.bot.engine.commit()
@commands.group(name='warn', aliases=['warns']) @commands.group(name='warn', aliases=['warns'])
async def _warn(self, ctx: commands.Context): async def _warn(self, ctx: commands.Context):
await ctx.trigger_typing()
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
warns_list, warns = await self.get_warn(ctx) warns_list, warns = await self.get_warn(ctx)
e = discord.Embed( e = discord.Embed(
title=f"{len(warns)} {Texts('admin').get('last warns')}: ", title=f"{warns.count()} {Texts('admin').get('last warns')}: ",
description=warns_list description=warns_list
) )
await ctx.send(embed=e) await ctx.send(embed=e)
async def add_warn(self, ctx: commands.Context, member: discord.Member,
reason):
query = """
INSERT INTO warns (server_id, user_id, reason, created_at)
VALUES ($1, $2, $3, $4)
"""
now = datetime.datetime.now()
await self.bot.db.execute(query, ctx.guild.id, member.id, reason, now)
@_warn.command(name='add', aliases=['new']) @_warn.command(name='add', aliases=['new'])
async def _warn_new(self, ctx: commands.Context, member: discord.Member, async def _warn_new(self, ctx: commands.Context, member: discord.Member,
*, reason="N/A"): *, reason="N/A"):
member = await ctx.guild.fetch_member(member.id) member = await ctx.guild.fetch_member(member.id)
if not member: if not member:
return await ctx.send( return await ctx.send(
Texts('utils').get("Unable to find the user...") Texts('utils').get("Unable to find the user...")
) )
query = """
SELECT user_id, reason, created_at FROM warns
WHERE created_at >= $1 AND server_id = $2 and user_id = $3
"""
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
def check(pld: discord.RawReactionActionEvent): def check(pld: discord.RawReactionActionEvent):
if pld.message_id != choice.id \ if pld.message_id != choice.id \
or pld.user_id != ctx.author.id: or pld.user_id != ctx.author.id:
return False return False
return pld.emoji.name in ('1⃣', '2⃣', '3⃣') return pld.emoji.name in ('1⃣', '2⃣', '3⃣')
async with self.bot.db.acquire() as con: warns_list, warns = await self.get_warn(ctx)
await ctx.trigger_typing()
warns = await con.fetch(query, week_ago, ctx.guild.id, member.id)
if len(warns) >= 2: if warns.count() >= 3:
e = discord.Embed( e = discord.Embed(
title=Texts('admin').get('More than 2 warns'), title=Texts('admin').get('More than 2 warns'),
description=f"{member.mention} " description=f"{member.mention} "
+ Texts('admin').get('has more than 2 warns') + Texts('admin').get('has more than 2 warns')
)
e.add_field(
name='__Actions__',
value=':one: kick\n'
':two: ban\n'
':three: ' + Texts('admin').get('ignore')
)
choice = await ctx.send(embed=e)
for reaction in ('1⃣', '2⃣', '3⃣'):
await choice.add_reaction(reaction)
try:
payload = await self.bot.wait_for(
'raw_reaction_add',
check=check,
timeout=50.0
)
except asyncio.TimeoutError:
return await ctx.send(
Texts('admin').get('Took too long. Aborting.')
)
finally:
await choice.delete()
if payload.emoji.name == '1⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"kick "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
elif payload.emoji.name == '2⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"ban "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
await self.add_warn(ctx, member, reason)
await ctx.send(
content=f"{member.mention} "
f"**{Texts('admin').get('got a warn')}**"
f"\n**{Texts('admin').get('Reason')}:** `{reason}`"
if reason != 'N/A' else ''
) )
e.add_field(
name='__Actions__',
value=':one: kick\n'
':two: ban\n'
':three: ' + Texts('admin').get('ignore')
)
choice = await ctx.send(embed=e)
for reaction in ('1⃣', '2⃣', '3⃣'):
await choice.add_reaction(reaction)
try:
payload = await self.bot.wait_for(
'raw_reaction_add',
check=check,
timeout=50.0
)
except asyncio.TimeoutError:
return await ctx.send(
Texts('admin').get('Took too long. Aborting.')
)
finally:
await choice.delete()
if payload.emoji.name == '1⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"kick "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
elif payload.emoji.name == '2⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"ban "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
await self.add_warn(ctx, member, reason)
await ctx.send(
content=f"{member.mention} "
f"**{Texts('admin').get('got a warn')}**"
f"\n**{Texts('admin').get('Reason')}:** `{reason}`"
)
@_warn.command(name='remove', aliases=['revoke']) @_warn.command(name='remove', aliases=['revoke'])
async def _warn_remove(self, ctx: commands.Context, warn_id: int): async def _warn_remove(self, ctx: commands.Context, warn_id: int):
query = """ warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
DELETE FROM warns self.bot.engine.delete(warn)
WHERE id = $1
"""
async with self.bot.db.acquire() as con:
await ctx.trigger_typing()
await con.fetch(query, warn_id)
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`" await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
f" {Texts('admin').get('successfully removed')}") f" {Texts('admin').get('successfully removed')}")
@ -385,8 +372,9 @@ class Admin(commands.Cog):
@_warn.command(name='show', aliases=['list']) @_warn.command(name='show', aliases=['list'])
async def _warn_show(self, ctx: commands.Context, member: discord.Member): async def _warn_show(self, ctx: commands.Context, member: discord.Member):
warns_list, warns = await self.get_warn(ctx, member) warns_list, warns = await self.get_warn(ctx, member)
e = discord.Embed( e = discord.Embed(
title=f"{len(warns)} {Texts('admin').get('last warns')}: ", title=f"{warns.count()} {Texts('admin').get('last warns')}: ",
description=warns_list description=warns_list
) )
@ -394,27 +382,40 @@ class Admin(commands.Cog):
@_warn.command(name='edit', aliases=['change']) @_warn.command(name='edit', aliases=['change'])
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason): async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
query = """ warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
UPDATE warns warn.reason = reason
SET reason = $2 self.bot.engine.commit()
WHERE id = $1
"""
async with self.bot.db.acquire() as con:
await ctx.trigger_typing()
await con.fetch(query, warn_id, reason)
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`" await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
f" {Texts('admin').get('successfully edited')}") f" {Texts('admin').get('successfully edited')}")
"""---------------------------------------------------------------------""" """---------------------------------------------------------------------"""
@commands.command(name='set-language', aliases=['set-lang']) @commands.command(name='language', aliases=['lang', 'langue', 'langage'])
async def _set_language(self, ctx: commands.Context, lang): async def _language(self, ctx: commands.Context, locale):
""" query = """
todo: set lang for guild SELECT locale
FROM lang
WHERE key = 'available'
""" """
async with self.bot.engine.begin() as con:
await ctx.trigger_typing()
available = list(await con.fetchrow(query))
if str(locale) in available:
query = """
IF EXISTS(SELECT * FROM lang WHERE key = $1 )
then
UPDATE lang
SET locale = $2
WHERE key = $1
ELSE
INSERT INTO lang (key, locale)
VALUES ($1, $2)
"""
await con.fetch(query, str(ctx.guild.id), str(locale))
def setup(bot: TuxBot): def setup(bot: TuxBot):
bot.add_cog(Admin(bot)) bot.add_cog(Admin(bot))

View file

@ -1,5 +1,4 @@
from .checks import * from .checks import *
from .config import * from .config import *
from .db import *
from .lang import * from .lang import *
from .version import * from .version import *

View file

@ -1,12 +0,0 @@
import logging
import asyncpg
log = logging.getLogger(__name__)
class Table:
@classmethod
async def create_pool(cls, uri, **kwargs) -> asyncpg.pool.Pool:
cls._pool = db = await asyncpg.create_pool(uri, **kwargs)
return db

View file

@ -5,12 +5,13 @@ import config
class Texts: class Texts:
def __init__(self, base: str = 'base'): def __init__(self, base: str = 'base'):
self.locale = config.locale self.locale = config.locale
self.texts = gettext.translation(base, localedir='extras/locales', self.base = base
languages=[self.locale])
self.texts.install()
def __str__(self) -> str:
return self.texts
def get(self, text: str) -> str: def get(self, text: str) -> str:
return self.texts.gettext(text) texts = gettext.translation(self.base, localedir='extras/locales',
languages=[self.locale])
texts.install()
return texts.gettext(text)
def set(self, lang: str):
self.locale = lang

20
cogs/utils/models/warn.py Normal file
View file

@ -0,0 +1,20 @@
import datetime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, BIGINT, TIMESTAMP
Base = declarative_base()
class Warn(Base):
__tablename__ = 'warns'
id = Column(Integer, primary_key=True)
server_id = Column(BIGINT)
user_id = Column(BIGINT)
reason = Column(String)
created_at = Column(TIMESTAMP, default=datetime.datetime.now())
def __repr__(self):
return "<Warn(server_id='%s', user_id='%s', reason='%s', " \
"created_at='%s')>"\
% (self.server_id, self.user_id, self.reason, self.created_at)

View file

@ -1,4 +1,9 @@
import asyncio try:
import config
from cogs.utils.lang import Texts
except ModuleNotFoundError:
import extras.first_run
import contextlib import contextlib
import logging import logging
import socket import socket
@ -9,13 +14,8 @@ import git
import requests import requests
from bot import TuxBot from bot import TuxBot
from cogs.utils.db import Table from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
try:
import config
from cogs.utils.lang import Texts
except ModuleNotFoundError:
import extras.first_run
@contextlib.contextmanager @contextlib.contextmanager
@ -39,36 +39,36 @@ def setup_logging():
yield yield
finally: finally:
handlers = log.handlers[:] handlers = log.handlers[:]
for hdlr in handlers: for handler in handlers:
hdlr.close() handler.close()
log.removeHandler(hdlr) log.removeHandler(handler)
def run_bot(unload: list = []): def run_bot(unload: list = []):
loop = asyncio.get_event_loop()
log = logging.getLogger() log = logging.getLogger()
print(Texts().get('Starting...')) print(Texts().get('Starting...'))
try: try:
db = loop.run_until_complete( engine = create_engine(config.postgresql)
Table.create_pool(config.postgresql, command_timeout=60)
) Session = sessionmaker()
Session.configure(bind=engine)
except socket.gaierror: except socket.gaierror:
click.echo(Texts().get("Could not set up PostgreSQL..."), click.echo(Texts().get("Could not set up PostgreSQL..."),
file=sys.stderr) file=sys.stderr)
log.exception(Texts().get("Could not set up PostgreSQL...")) log.exception(Texts().get("Could not set up PostgreSQL..."))
return return
bot = TuxBot(unload, db) bot = TuxBot(unload, Session())
bot.run() bot.run()
@click.command() @click.command()
@click.option('-d', '--unload', multiple=True, type=str, @click.option('-d', '--unload', multiple=True, type=str,
help=Texts().get("Launch without loading the <TEXT> module")) help=Texts().get("Launch without loading the <TEXT> module"))
@click.option('-u', '--update', help=Texts().get("Search for update"), @click.option('-u', '--update', is_flag=True,
is_flag=True) help=Texts().get("Search for update"))
def main(**kwargs): def main(**kwargs):
if kwargs.get('update'): if kwargs.get('update'):
_update() _update()
@ -77,8 +77,8 @@ def main(**kwargs):
run_bot(kwargs.get('unload')) run_bot(kwargs.get('unload'))
@click.option('-d', '--update', help=Texts().get("Search for update"), @click.option('-d', '--update', is_flag=True,
is_flag=True) help=Texts().get("Search for update"))
def _update(): def _update():
print(Texts().get("Checking for update...")) print(Texts().get("Checking for update..."))

View file

@ -4,6 +4,7 @@ jishaku
lxml lxml
click click
asyncpg>=0.12.0 asyncpg>=0.12.0
sqlalchemy
gitpython gitpython
requests requests
psutil psutil