refactor(database): migrate to sqlalchemy

This commit is contained in:
Romain J 2019-09-29 18:31:01 +02:00
parent 8f17085cf7
commit 29808d41d6
8 changed files with 188 additions and 180 deletions

8
bot.py
View file

@ -4,7 +4,7 @@ import sys
from collections import deque, Counter
import aiohttp
import asyncpg
import sqlalchemy
import discord
import git
from discord.ext import commands
@ -40,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
class TuxBot(commands.AutoShardedBot):
def __init__(self, unload: list, db: asyncpg.pool.Pool):
def __init__(self, unload: list, engine: sqlalchemy.engine.Engine):
super().__init__(command_prefix=_prefix_callable, pm_help=None,
help_command=None, description=description,
help_attrs=dict(hidden=True),
@ -53,7 +53,7 @@ class TuxBot(commands.AutoShardedBot):
self.uptime: datetime = datetime.datetime.utcnow()
self.config = config
self.db = db
self.engine = engine
self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop)
@ -137,8 +137,6 @@ class TuxBot(commands.AutoShardedBot):
async def close(self):
await super().close()
await self.db.close()
await self.session.close()
def run(self):
super().run(config.token, reconnect=True)

View file

@ -3,12 +3,14 @@ import logging
from typing import Union
import asyncio
import discord
import humanize
from discord.ext import commands
from bot import TuxBot
from .utils.lang import Texts
from .utils.models.warn import Warn
log = logging.getLogger(__name__)
@ -63,12 +65,13 @@ class Admin(commands.Cog):
@commands.group(name='say', invoke_without_command=True)
async def _say(self, ctx: commands.Context, *, content: str):
try:
await ctx.message.delete()
except discord.errors.Forbidden:
pass
if ctx.invoked_subcommand is None:
try:
await ctx.message.delete()
except discord.errors.Forbidden:
pass
await ctx.send(content)
await ctx.send(content)
@_say.command(name='edit')
async def _say_edit(self, ctx: commands.Context, message_id: int, *,
@ -226,158 +229,142 @@ class Admin(commands.Cog):
async def get_warn(self, ctx: commands.Context,
member: discord.Member = False):
query = """
SELECT * FROM warns
WHERE created_at >= $1 AND server_id = $2
"""
query += """AND user_id = $3""" if member else ""
query += """ORDER BY created_at DESC"""
await ctx.trigger_typing()
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
async with self.bot.db.acquire() as con:
await ctx.trigger_typing()
args = [week_ago, ctx.guild.id]
if member:
args.append(member.id)
if member:
warns = self.bot.engine \
.query(Warn) \
.filter(Warn.user_id == member.id, Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc())
else:
warns = self.bot.engine \
.query(Warn) \
.filter(Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc())
warns_list = ''
warns = await con.fetch(query, *args)
warns_list = ''
for warn in warns:
row_id = warn.id
user_id = warn.user_id
user = await self.bot.fetch_user(user_id)
reason = warn.reason
ago = humanize.naturaldelta(
datetime.datetime.now() - warn.created_at
)
for warn in warns:
row_id = warn.get('id')
user_id = warn.get('user_id')
user = await self.bot.fetch_user(user_id)
reason = warn.get('reason')
ago = humanize.naturaldelta(
datetime.datetime.now() - warn.get('created_at')
)
warns_list += f"[{row_id}] **{user}**: `{reason}` " \
f"*({ago} ago)*\n"
warns_list += f"[{row_id}] **{user}**: `{reason}` *({ago} ago)*\n"
return warns_list, warns
async def add_warn(self, ctx: commands.Context, member: discord.Member,
reason):
now = datetime.datetime.now()
warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason,
created_at=now)
self.bot.engine.add(warn)
self.bot.engine.commit()
@commands.group(name='warn', aliases=['warns'])
async def _warn(self, ctx: commands.Context):
await ctx.trigger_typing()
if ctx.invoked_subcommand is None:
warns_list, warns = await self.get_warn(ctx)
e = discord.Embed(
title=f"{len(warns)} {Texts('admin').get('last warns')}: ",
title=f"{warns.count()} {Texts('admin').get('last warns')}: ",
description=warns_list
)
await ctx.send(embed=e)
async def add_warn(self, ctx: commands.Context, member: discord.Member,
reason):
query = """
INSERT INTO warns (server_id, user_id, reason, created_at)
VALUES ($1, $2, $3, $4)
"""
now = datetime.datetime.now()
await self.bot.db.execute(query, ctx.guild.id, member.id, reason, now)
@_warn.command(name='add', aliases=['new'])
async def _warn_new(self, ctx: commands.Context, member: discord.Member,
*, reason="N/A"):
member = await ctx.guild.fetch_member(member.id)
if not member:
return await ctx.send(
Texts('utils').get("Unable to find the user...")
)
query = """
SELECT user_id, reason, created_at FROM warns
WHERE created_at >= $1 AND server_id = $2 and user_id = $3
"""
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
def check(pld: discord.RawReactionActionEvent):
if pld.message_id != choice.id \
or pld.user_id != ctx.author.id:
return False
return pld.emoji.name in ('1⃣', '2⃣', '3⃣')
async with self.bot.db.acquire() as con:
await ctx.trigger_typing()
warns = await con.fetch(query, week_ago, ctx.guild.id, member.id)
warns_list, warns = await self.get_warn(ctx)
if len(warns) >= 2:
e = discord.Embed(
title=Texts('admin').get('More than 2 warns'),
description=f"{member.mention} "
+ Texts('admin').get('has more than 2 warns')
)
e.add_field(
name='__Actions__',
value=':one: kick\n'
':two: ban\n'
':three: ' + Texts('admin').get('ignore')
)
choice = await ctx.send(embed=e)
for reaction in ('1⃣', '2⃣', '3⃣'):
await choice.add_reaction(reaction)
try:
payload = await self.bot.wait_for(
'raw_reaction_add',
check=check,
timeout=50.0
)
except asyncio.TimeoutError:
return await ctx.send(
Texts('admin').get('Took too long. Aborting.')
)
finally:
await choice.delete()
if payload.emoji.name == '1⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"kick "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
elif payload.emoji.name == '2⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"ban "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
await self.add_warn(ctx, member, reason)
await ctx.send(
content=f"{member.mention} "
f"**{Texts('admin').get('got a warn')}**"
f"\n**{Texts('admin').get('Reason')}:** `{reason}`"
if reason != 'N/A' else ''
if warns.count() >= 3:
e = discord.Embed(
title=Texts('admin').get('More than 2 warns'),
description=f"{member.mention} "
+ Texts('admin').get('has more than 2 warns')
)
e.add_field(
name='__Actions__',
value=':one: kick\n'
':two: ban\n'
':three: ' + Texts('admin').get('ignore')
)
choice = await ctx.send(embed=e)
for reaction in ('1⃣', '2⃣', '3⃣'):
await choice.add_reaction(reaction)
try:
payload = await self.bot.wait_for(
'raw_reaction_add',
check=check,
timeout=50.0
)
except asyncio.TimeoutError:
return await ctx.send(
Texts('admin').get('Took too long. Aborting.')
)
finally:
await choice.delete()
if payload.emoji.name == '1⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"kick "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
elif payload.emoji.name == '2⃣':
from jishaku.models import copy_context_with
alt_ctx = await copy_context_with(
ctx,
content=f"{ctx.prefix}"
f"ban "
f"{member} "
f"{Texts('admin').get('More than 2 warns')}"
)
return await alt_ctx.command.invoke(alt_ctx)
await self.add_warn(ctx, member, reason)
await ctx.send(
content=f"{member.mention} "
f"**{Texts('admin').get('got a warn')}**"
f"\n**{Texts('admin').get('Reason')}:** `{reason}`"
)
@_warn.command(name='remove', aliases=['revoke'])
async def _warn_remove(self, ctx: commands.Context, warn_id: int):
query = """
DELETE FROM warns
WHERE id = $1
"""
async with self.bot.db.acquire() as con:
await ctx.trigger_typing()
await con.fetch(query, warn_id)
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
self.bot.engine.delete(warn)
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
f" {Texts('admin').get('successfully removed')}")
@ -385,8 +372,9 @@ class Admin(commands.Cog):
@_warn.command(name='show', aliases=['list'])
async def _warn_show(self, ctx: commands.Context, member: discord.Member):
warns_list, warns = await self.get_warn(ctx, member)
e = discord.Embed(
title=f"{len(warns)} {Texts('admin').get('last warns')}: ",
title=f"{warns.count()} {Texts('admin').get('last warns')}: ",
description=warns_list
)
@ -394,27 +382,40 @@ class Admin(commands.Cog):
@_warn.command(name='edit', aliases=['change'])
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
query = """
UPDATE warns
SET reason = $2
WHERE id = $1
"""
async with self.bot.db.acquire() as con:
await ctx.trigger_typing()
await con.fetch(query, warn_id, reason)
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
warn.reason = reason
self.bot.engine.commit()
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
f" {Texts('admin').get('successfully edited')}")
"""---------------------------------------------------------------------"""
@commands.command(name='set-language', aliases=['set-lang'])
async def _set_language(self, ctx: commands.Context, lang):
"""
todo: set lang for guild
@commands.command(name='language', aliases=['lang', 'langue', 'langage'])
async def _language(self, ctx: commands.Context, locale):
query = """
SELECT locale
FROM lang
WHERE key = 'available'
"""
async with self.bot.engine.begin() as con:
await ctx.trigger_typing()
available = list(await con.fetchrow(query))
if str(locale) in available:
query = """
IF EXISTS(SELECT * FROM lang WHERE key = $1 )
then
UPDATE lang
SET locale = $2
WHERE key = $1
ELSE
INSERT INTO lang (key, locale)
VALUES ($1, $2)
"""
await con.fetch(query, str(ctx.guild.id), str(locale))
def setup(bot: TuxBot):
bot.add_cog(Admin(bot))

View file

@ -1,5 +1,4 @@
from .checks import *
from .config import *
from .db import *
from .lang import *
from .version import *
from .version import *

View file

@ -1,12 +0,0 @@
import logging
import asyncpg
log = logging.getLogger(__name__)
class Table:
@classmethod
async def create_pool(cls, uri, **kwargs) -> asyncpg.pool.Pool:
cls._pool = db = await asyncpg.create_pool(uri, **kwargs)
return db

View file

@ -5,12 +5,13 @@ import config
class Texts:
def __init__(self, base: str = 'base'):
self.locale = config.locale
self.texts = gettext.translation(base, localedir='extras/locales',
languages=[self.locale])
self.texts.install()
def __str__(self) -> str:
return self.texts
self.base = base
def get(self, text: str) -> str:
return self.texts.gettext(text)
texts = gettext.translation(self.base, localedir='extras/locales',
languages=[self.locale])
texts.install()
return texts.gettext(text)
def set(self, lang: str):
self.locale = lang

20
cogs/utils/models/warn.py Normal file
View file

@ -0,0 +1,20 @@
import datetime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, BIGINT, TIMESTAMP
Base = declarative_base()
class Warn(Base):
__tablename__ = 'warns'
id = Column(Integer, primary_key=True)
server_id = Column(BIGINT)
user_id = Column(BIGINT)
reason = Column(String)
created_at = Column(TIMESTAMP, default=datetime.datetime.now())
def __repr__(self):
return "<Warn(server_id='%s', user_id='%s', reason='%s', " \
"created_at='%s')>"\
% (self.server_id, self.user_id, self.reason, self.created_at)

View file

@ -1,4 +1,9 @@
import asyncio
try:
import config
from cogs.utils.lang import Texts
except ModuleNotFoundError:
import extras.first_run
import contextlib
import logging
import socket
@ -9,13 +14,8 @@ import git
import requests
from bot import TuxBot
from cogs.utils.db import Table
try:
import config
from cogs.utils.lang import Texts
except ModuleNotFoundError:
import extras.first_run
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@contextlib.contextmanager
@ -39,36 +39,36 @@ def setup_logging():
yield
finally:
handlers = log.handlers[:]
for hdlr in handlers:
hdlr.close()
log.removeHandler(hdlr)
for handler in handlers:
handler.close()
log.removeHandler(handler)
def run_bot(unload: list = []):
loop = asyncio.get_event_loop()
log = logging.getLogger()
print(Texts().get('Starting...'))
try:
db = loop.run_until_complete(
Table.create_pool(config.postgresql, command_timeout=60)
)
engine = create_engine(config.postgresql)
Session = sessionmaker()
Session.configure(bind=engine)
except socket.gaierror:
click.echo(Texts().get("Could not set up PostgreSQL..."),
file=sys.stderr)
log.exception(Texts().get("Could not set up PostgreSQL..."))
return
bot = TuxBot(unload, db)
bot = TuxBot(unload, Session())
bot.run()
@click.command()
@click.option('-d', '--unload', multiple=True, type=str,
help=Texts().get("Launch without loading the <TEXT> module"))
@click.option('-u', '--update', help=Texts().get("Search for update"),
is_flag=True)
@click.option('-u', '--update', is_flag=True,
help=Texts().get("Search for update"))
def main(**kwargs):
if kwargs.get('update'):
_update()
@ -77,8 +77,8 @@ def main(**kwargs):
run_bot(kwargs.get('unload'))
@click.option('-d', '--update', help=Texts().get("Search for update"),
is_flag=True)
@click.option('-d', '--update', is_flag=True,
help=Texts().get("Search for update"))
def _update():
print(Texts().get("Checking for update..."))

View file

@ -4,6 +4,7 @@ jishaku
lxml
click
asyncpg>=0.12.0
sqlalchemy
gitpython
requests
psutil