refactor(database): migrate to sqlalchemy
This commit is contained in:
parent
8f17085cf7
commit
29808d41d6
8 changed files with 188 additions and 180 deletions
8
bot.py
8
bot.py
|
@ -4,7 +4,7 @@ import sys
|
||||||
from collections import deque, Counter
|
from collections import deque, Counter
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncpg
|
import sqlalchemy
|
||||||
import discord
|
import discord
|
||||||
import git
|
import git
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
@ -40,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
|
||||||
|
|
||||||
class TuxBot(commands.AutoShardedBot):
|
class TuxBot(commands.AutoShardedBot):
|
||||||
|
|
||||||
def __init__(self, unload: list, db: asyncpg.pool.Pool):
|
def __init__(self, unload: list, engine: sqlalchemy.engine.Engine):
|
||||||
super().__init__(command_prefix=_prefix_callable, pm_help=None,
|
super().__init__(command_prefix=_prefix_callable, pm_help=None,
|
||||||
help_command=None, description=description,
|
help_command=None, description=description,
|
||||||
help_attrs=dict(hidden=True),
|
help_attrs=dict(hidden=True),
|
||||||
|
@ -53,7 +53,7 @@ class TuxBot(commands.AutoShardedBot):
|
||||||
|
|
||||||
self.uptime: datetime = datetime.datetime.utcnow()
|
self.uptime: datetime = datetime.datetime.utcnow()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.db = db
|
self.engine = engine
|
||||||
self._prev_events = deque(maxlen=10)
|
self._prev_events = deque(maxlen=10)
|
||||||
self.session = aiohttp.ClientSession(loop=self.loop)
|
self.session = aiohttp.ClientSession(loop=self.loop)
|
||||||
|
|
||||||
|
@ -137,8 +137,6 @@ class TuxBot(commands.AutoShardedBot):
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
await super().close()
|
await super().close()
|
||||||
await self.db.close()
|
|
||||||
await self.session.close()
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run(config.token, reconnect=True)
|
super().run(config.token, reconnect=True)
|
||||||
|
|
127
cogs/admin.py
127
cogs/admin.py
|
@ -3,12 +3,14 @@ import logging
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import humanize
|
import humanize
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from bot import TuxBot
|
from bot import TuxBot
|
||||||
from .utils.lang import Texts
|
from .utils.lang import Texts
|
||||||
|
from .utils.models.warn import Warn
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -63,6 +65,7 @@ class Admin(commands.Cog):
|
||||||
|
|
||||||
@commands.group(name='say', invoke_without_command=True)
|
@commands.group(name='say', invoke_without_command=True)
|
||||||
async def _say(self, ctx: commands.Context, *, content: str):
|
async def _say(self, ctx: commands.Context, *, content: str):
|
||||||
|
if ctx.invoked_subcommand is None:
|
||||||
try:
|
try:
|
||||||
await ctx.message.delete()
|
await ctx.message.delete()
|
||||||
except discord.errors.Forbidden:
|
except discord.errors.Forbidden:
|
||||||
|
@ -226,86 +229,77 @@ class Admin(commands.Cog):
|
||||||
|
|
||||||
async def get_warn(self, ctx: commands.Context,
|
async def get_warn(self, ctx: commands.Context,
|
||||||
member: discord.Member = False):
|
member: discord.Member = False):
|
||||||
query = """
|
await ctx.trigger_typing()
|
||||||
SELECT * FROM warns
|
|
||||||
WHERE created_at >= $1 AND server_id = $2
|
|
||||||
"""
|
|
||||||
query += """AND user_id = $3""" if member else ""
|
|
||||||
query += """ORDER BY created_at DESC"""
|
|
||||||
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
|
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
|
||||||
|
|
||||||
async with self.bot.db.acquire() as con:
|
|
||||||
await ctx.trigger_typing()
|
|
||||||
args = [week_ago, ctx.guild.id]
|
|
||||||
if member:
|
if member:
|
||||||
args.append(member.id)
|
warns = self.bot.engine \
|
||||||
|
.query(Warn) \
|
||||||
warns = await con.fetch(query, *args)
|
.filter(Warn.user_id == member.id, Warn.created_at > week_ago,
|
||||||
|
Warn.server_id == ctx.guild.id) \
|
||||||
|
.order_by(Warn.created_at.desc())
|
||||||
|
else:
|
||||||
|
warns = self.bot.engine \
|
||||||
|
.query(Warn) \
|
||||||
|
.filter(Warn.created_at > week_ago,
|
||||||
|
Warn.server_id == ctx.guild.id) \
|
||||||
|
.order_by(Warn.created_at.desc())
|
||||||
warns_list = ''
|
warns_list = ''
|
||||||
|
|
||||||
for warn in warns:
|
for warn in warns:
|
||||||
row_id = warn.get('id')
|
row_id = warn.id
|
||||||
user_id = warn.get('user_id')
|
user_id = warn.user_id
|
||||||
user = await self.bot.fetch_user(user_id)
|
user = await self.bot.fetch_user(user_id)
|
||||||
reason = warn.get('reason')
|
reason = warn.reason
|
||||||
ago = humanize.naturaldelta(
|
ago = humanize.naturaldelta(
|
||||||
datetime.datetime.now() - warn.get('created_at')
|
datetime.datetime.now() - warn.created_at
|
||||||
)
|
)
|
||||||
|
|
||||||
warns_list += f"[{row_id}] **{user}**: `{reason}` " \
|
warns_list += f"[{row_id}] **{user}**: `{reason}` *({ago} ago)*\n"
|
||||||
f"*({ago} ago)*\n"
|
|
||||||
|
|
||||||
return warns_list, warns
|
return warns_list, warns
|
||||||
|
|
||||||
|
async def add_warn(self, ctx: commands.Context, member: discord.Member,
|
||||||
|
reason):
|
||||||
|
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason,
|
||||||
|
created_at=now)
|
||||||
|
|
||||||
|
self.bot.engine.add(warn)
|
||||||
|
self.bot.engine.commit()
|
||||||
|
|
||||||
@commands.group(name='warn', aliases=['warns'])
|
@commands.group(name='warn', aliases=['warns'])
|
||||||
async def _warn(self, ctx: commands.Context):
|
async def _warn(self, ctx: commands.Context):
|
||||||
|
await ctx.trigger_typing()
|
||||||
if ctx.invoked_subcommand is None:
|
if ctx.invoked_subcommand is None:
|
||||||
warns_list, warns = await self.get_warn(ctx)
|
warns_list, warns = await self.get_warn(ctx)
|
||||||
e = discord.Embed(
|
e = discord.Embed(
|
||||||
title=f"{len(warns)} {Texts('admin').get('last warns')}: ",
|
title=f"{warns.count()} {Texts('admin').get('last warns')}: ",
|
||||||
description=warns_list
|
description=warns_list
|
||||||
)
|
)
|
||||||
|
|
||||||
await ctx.send(embed=e)
|
await ctx.send(embed=e)
|
||||||
|
|
||||||
async def add_warn(self, ctx: commands.Context, member: discord.Member,
|
|
||||||
reason):
|
|
||||||
|
|
||||||
query = """
|
|
||||||
INSERT INTO warns (server_id, user_id, reason, created_at)
|
|
||||||
VALUES ($1, $2, $3, $4)
|
|
||||||
"""
|
|
||||||
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
await self.bot.db.execute(query, ctx.guild.id, member.id, reason, now)
|
|
||||||
|
|
||||||
@_warn.command(name='add', aliases=['new'])
|
@_warn.command(name='add', aliases=['new'])
|
||||||
async def _warn_new(self, ctx: commands.Context, member: discord.Member,
|
async def _warn_new(self, ctx: commands.Context, member: discord.Member,
|
||||||
*, reason="N/A"):
|
*, reason="N/A"):
|
||||||
|
|
||||||
member = await ctx.guild.fetch_member(member.id)
|
member = await ctx.guild.fetch_member(member.id)
|
||||||
if not member:
|
if not member:
|
||||||
return await ctx.send(
|
return await ctx.send(
|
||||||
Texts('utils').get("Unable to find the user...")
|
Texts('utils').get("Unable to find the user...")
|
||||||
)
|
)
|
||||||
|
|
||||||
query = """
|
|
||||||
SELECT user_id, reason, created_at FROM warns
|
|
||||||
WHERE created_at >= $1 AND server_id = $2 and user_id = $3
|
|
||||||
"""
|
|
||||||
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
|
|
||||||
|
|
||||||
def check(pld: discord.RawReactionActionEvent):
|
def check(pld: discord.RawReactionActionEvent):
|
||||||
if pld.message_id != choice.id \
|
if pld.message_id != choice.id \
|
||||||
or pld.user_id != ctx.author.id:
|
or pld.user_id != ctx.author.id:
|
||||||
return False
|
return False
|
||||||
return pld.emoji.name in ('1⃣', '2⃣', '3⃣')
|
return pld.emoji.name in ('1⃣', '2⃣', '3⃣')
|
||||||
|
|
||||||
async with self.bot.db.acquire() as con:
|
warns_list, warns = await self.get_warn(ctx)
|
||||||
await ctx.trigger_typing()
|
|
||||||
warns = await con.fetch(query, week_ago, ctx.guild.id, member.id)
|
|
||||||
|
|
||||||
if len(warns) >= 2:
|
if warns.count() >= 3:
|
||||||
e = discord.Embed(
|
e = discord.Embed(
|
||||||
title=Texts('admin').get('More than 2 warns'),
|
title=Texts('admin').get('More than 2 warns'),
|
||||||
description=f"{member.mention} "
|
description=f"{member.mention} "
|
||||||
|
@ -365,19 +359,12 @@ class Admin(commands.Cog):
|
||||||
content=f"{member.mention} "
|
content=f"{member.mention} "
|
||||||
f"**{Texts('admin').get('got a warn')}**"
|
f"**{Texts('admin').get('got a warn')}**"
|
||||||
f"\n**{Texts('admin').get('Reason')}:** `{reason}`"
|
f"\n**{Texts('admin').get('Reason')}:** `{reason}`"
|
||||||
if reason != 'N/A' else ''
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@_warn.command(name='remove', aliases=['revoke'])
|
@_warn.command(name='remove', aliases=['revoke'])
|
||||||
async def _warn_remove(self, ctx: commands.Context, warn_id: int):
|
async def _warn_remove(self, ctx: commands.Context, warn_id: int):
|
||||||
query = """
|
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
|
||||||
DELETE FROM warns
|
self.bot.engine.delete(warn)
|
||||||
WHERE id = $1
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with self.bot.db.acquire() as con:
|
|
||||||
await ctx.trigger_typing()
|
|
||||||
await con.fetch(query, warn_id)
|
|
||||||
|
|
||||||
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
|
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
|
||||||
f" {Texts('admin').get('successfully removed')}")
|
f" {Texts('admin').get('successfully removed')}")
|
||||||
|
@ -385,8 +372,9 @@ class Admin(commands.Cog):
|
||||||
@_warn.command(name='show', aliases=['list'])
|
@_warn.command(name='show', aliases=['list'])
|
||||||
async def _warn_show(self, ctx: commands.Context, member: discord.Member):
|
async def _warn_show(self, ctx: commands.Context, member: discord.Member):
|
||||||
warns_list, warns = await self.get_warn(ctx, member)
|
warns_list, warns = await self.get_warn(ctx, member)
|
||||||
|
|
||||||
e = discord.Embed(
|
e = discord.Embed(
|
||||||
title=f"{len(warns)} {Texts('admin').get('last warns')}: ",
|
title=f"{warns.count()} {Texts('admin').get('last warns')}: ",
|
||||||
description=warns_list
|
description=warns_list
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -394,26 +382,39 @@ class Admin(commands.Cog):
|
||||||
|
|
||||||
@_warn.command(name='edit', aliases=['change'])
|
@_warn.command(name='edit', aliases=['change'])
|
||||||
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
|
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
|
||||||
query = """
|
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
|
||||||
UPDATE warns
|
warn.reason = reason
|
||||||
SET reason = $2
|
self.bot.engine.commit()
|
||||||
WHERE id = $1
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with self.bot.db.acquire() as con:
|
|
||||||
await ctx.trigger_typing()
|
|
||||||
await con.fetch(query, warn_id, reason)
|
|
||||||
|
|
||||||
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
|
await ctx.send(f"{Texts('admin').get('Warn with id')} `{warn_id}`"
|
||||||
f" {Texts('admin').get('successfully edited')}")
|
f" {Texts('admin').get('successfully edited')}")
|
||||||
|
|
||||||
"""---------------------------------------------------------------------"""
|
"""---------------------------------------------------------------------"""
|
||||||
|
|
||||||
@commands.command(name='set-language', aliases=['set-lang'])
|
@commands.command(name='language', aliases=['lang', 'langue', 'langage'])
|
||||||
async def _set_language(self, ctx: commands.Context, lang):
|
async def _language(self, ctx: commands.Context, locale):
|
||||||
|
query = """
|
||||||
|
SELECT locale
|
||||||
|
FROM lang
|
||||||
|
WHERE key = 'available'
|
||||||
"""
|
"""
|
||||||
todo: set lang for guild
|
|
||||||
|
async with self.bot.engine.begin() as con:
|
||||||
|
await ctx.trigger_typing()
|
||||||
|
available = list(await con.fetchrow(query))
|
||||||
|
|
||||||
|
if str(locale) in available:
|
||||||
|
query = """
|
||||||
|
IF EXISTS(SELECT * FROM lang WHERE key = $1 )
|
||||||
|
then
|
||||||
|
UPDATE lang
|
||||||
|
SET locale = $2
|
||||||
|
WHERE key = $1
|
||||||
|
ELSE
|
||||||
|
INSERT INTO lang (key, locale)
|
||||||
|
VALUES ($1, $2)
|
||||||
"""
|
"""
|
||||||
|
await con.fetch(query, str(ctx.guild.id), str(locale))
|
||||||
|
|
||||||
|
|
||||||
def setup(bot: TuxBot):
|
def setup(bot: TuxBot):
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from .checks import *
|
from .checks import *
|
||||||
from .config import *
|
from .config import *
|
||||||
from .db import *
|
|
||||||
from .lang import *
|
from .lang import *
|
||||||
from .version import *
|
from .version import *
|
|
@ -1,12 +0,0 @@
|
||||||
import logging
|
|
||||||
|
|
||||||
import asyncpg
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Table:
|
|
||||||
@classmethod
|
|
||||||
async def create_pool(cls, uri, **kwargs) -> asyncpg.pool.Pool:
|
|
||||||
cls._pool = db = await asyncpg.create_pool(uri, **kwargs)
|
|
||||||
return db
|
|
|
@ -5,12 +5,13 @@ import config
|
||||||
class Texts:
|
class Texts:
|
||||||
def __init__(self, base: str = 'base'):
|
def __init__(self, base: str = 'base'):
|
||||||
self.locale = config.locale
|
self.locale = config.locale
|
||||||
self.texts = gettext.translation(base, localedir='extras/locales',
|
self.base = base
|
||||||
languages=[self.locale])
|
|
||||||
self.texts.install()
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.texts
|
|
||||||
|
|
||||||
def get(self, text: str) -> str:
|
def get(self, text: str) -> str:
|
||||||
return self.texts.gettext(text)
|
texts = gettext.translation(self.base, localedir='extras/locales',
|
||||||
|
languages=[self.locale])
|
||||||
|
texts.install()
|
||||||
|
return texts.gettext(text)
|
||||||
|
|
||||||
|
def set(self, lang: str):
|
||||||
|
self.locale = lang
|
||||||
|
|
20
cogs/utils/models/warn.py
Normal file
20
cogs/utils/models/warn.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy import Column, Integer, String, BIGINT, TIMESTAMP
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class Warn(Base):
|
||||||
|
__tablename__ = 'warns'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
server_id = Column(BIGINT)
|
||||||
|
user_id = Column(BIGINT)
|
||||||
|
reason = Column(String)
|
||||||
|
created_at = Column(TIMESTAMP, default=datetime.datetime.now())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<Warn(server_id='%s', user_id='%s', reason='%s', " \
|
||||||
|
"created_at='%s')>"\
|
||||||
|
% (self.server_id, self.user_id, self.reason, self.created_at)
|
40
launcher.py
40
launcher.py
|
@ -1,4 +1,9 @@
|
||||||
import asyncio
|
try:
|
||||||
|
import config
|
||||||
|
from cogs.utils.lang import Texts
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
import extras.first_run
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
|
@ -9,13 +14,8 @@ import git
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from bot import TuxBot
|
from bot import TuxBot
|
||||||
from cogs.utils.db import Table
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
try:
|
|
||||||
import config
|
|
||||||
from cogs.utils.lang import Texts
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
import extras.first_run
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
|
@ -39,36 +39,36 @@ def setup_logging():
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
handlers = log.handlers[:]
|
handlers = log.handlers[:]
|
||||||
for hdlr in handlers:
|
for handler in handlers:
|
||||||
hdlr.close()
|
handler.close()
|
||||||
log.removeHandler(hdlr)
|
log.removeHandler(handler)
|
||||||
|
|
||||||
|
|
||||||
def run_bot(unload: list = []):
|
def run_bot(unload: list = []):
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
log = logging.getLogger()
|
log = logging.getLogger()
|
||||||
|
|
||||||
print(Texts().get('Starting...'))
|
print(Texts().get('Starting...'))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db = loop.run_until_complete(
|
engine = create_engine(config.postgresql)
|
||||||
Table.create_pool(config.postgresql, command_timeout=60)
|
|
||||||
)
|
Session = sessionmaker()
|
||||||
|
Session.configure(bind=engine)
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
click.echo(Texts().get("Could not set up PostgreSQL..."),
|
click.echo(Texts().get("Could not set up PostgreSQL..."),
|
||||||
file=sys.stderr)
|
file=sys.stderr)
|
||||||
log.exception(Texts().get("Could not set up PostgreSQL..."))
|
log.exception(Texts().get("Could not set up PostgreSQL..."))
|
||||||
return
|
return
|
||||||
|
|
||||||
bot = TuxBot(unload, db)
|
bot = TuxBot(unload, Session())
|
||||||
bot.run()
|
bot.run()
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option('-d', '--unload', multiple=True, type=str,
|
@click.option('-d', '--unload', multiple=True, type=str,
|
||||||
help=Texts().get("Launch without loading the <TEXT> module"))
|
help=Texts().get("Launch without loading the <TEXT> module"))
|
||||||
@click.option('-u', '--update', help=Texts().get("Search for update"),
|
@click.option('-u', '--update', is_flag=True,
|
||||||
is_flag=True)
|
help=Texts().get("Search for update"))
|
||||||
def main(**kwargs):
|
def main(**kwargs):
|
||||||
if kwargs.get('update'):
|
if kwargs.get('update'):
|
||||||
_update()
|
_update()
|
||||||
|
@ -77,8 +77,8 @@ def main(**kwargs):
|
||||||
run_bot(kwargs.get('unload'))
|
run_bot(kwargs.get('unload'))
|
||||||
|
|
||||||
|
|
||||||
@click.option('-d', '--update', help=Texts().get("Search for update"),
|
@click.option('-d', '--update', is_flag=True,
|
||||||
is_flag=True)
|
help=Texts().get("Search for update"))
|
||||||
def _update():
|
def _update():
|
||||||
print(Texts().get("Checking for update..."))
|
print(Texts().get("Checking for update..."))
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ jishaku
|
||||||
lxml
|
lxml
|
||||||
click
|
click
|
||||||
asyncpg>=0.12.0
|
asyncpg>=0.12.0
|
||||||
|
sqlalchemy
|
||||||
gitpython
|
gitpython
|
||||||
requests
|
requests
|
||||||
psutil
|
psutil
|
Loading…
Reference in a new issue