From 98b241d51b26cb61c2ad19c1b3296f526f7b06ca Mon Sep 17 00:00:00 2001 From: Romain J Date: Sun, 6 Oct 2019 23:49:00 +0200 Subject: [PATCH] refactor(command|sondage): continue rewrite of sondage known issues: more than 2 votes are not peristed --- bot.py | 7 +-- cogs/admin.py | 34 +++++++----- cogs/basics.py | 3 + cogs/poll.py | 112 ++++++++++++++++++++++++++------------ cogs/utils/database.py | 11 ++++ cogs/utils/lang.py | 14 ++--- cogs/utils/models/poll.py | 7 ++- launcher.py | 10 +--- requirements.txt | 3 +- 9 files changed, 129 insertions(+), 72 deletions(-) create mode 100644 cogs/utils/database.py diff --git a/bot.py b/bot.py index 5ca163b..d4cd75b 100755 --- a/bot.py +++ b/bot.py @@ -4,7 +4,6 @@ import sys from collections import deque, Counter import aiohttp -import sqlalchemy import discord import git from discord.ext import commands @@ -41,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list: class TuxBot(commands.AutoShardedBot): - def __init__(self, unload: list, engine: sqlalchemy.engine.Engine): + def __init__(self, unload: list, database): super().__init__(command_prefix=_prefix_callable, pm_help=None, help_command=None, description=description, help_attrs=dict(hidden=True), @@ -54,14 +53,14 @@ class TuxBot(commands.AutoShardedBot): self.uptime: datetime = datetime.datetime.utcnow() self.config = config - self.engine = engine + self.database = database self._prev_events = deque(maxlen=10) self.session = aiohttp.ClientSession(loop=self.loop) self.prefixes = Config('prefixes.json') self.blacklist = Config('blacklist.json') - self.version = Version(10, 0, 0, pre_release='a20', build=build) + self.version = Version(10, 0, 0, pre_release='a21', build=build) for extension in l_extensions: if extension not in unload: diff --git a/cogs/admin.py b/cogs/admin.py index 7ad57f1..dcf82f6 100644 --- a/cogs/admin.py +++ b/cogs/admin.py @@ -243,13 +243,13 @@ class Admin(commands.Cog): week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6) if member: - warns = self.bot.engine \ + warns = self.bot.database.session \ .query(Warn) \ .filter(Warn.user_id == member.id, Warn.created_at > week_ago, Warn.server_id == ctx.guild.id) \ .order_by(Warn.created_at.desc()) else: - warns = self.bot.engine \ + warns = self.bot.database.session \ .query(Warn) \ .filter(Warn.created_at > week_ago, Warn.server_id == ctx.guild.id) \ @@ -276,8 +276,8 @@ class Admin(commands.Cog): warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason, created_at=now) - self.bot.engine.add(warn) - self.bot.engine.commit() + self.bot.database.session.add(warn) + self.bot.database.session.commit() @commands.group(name='warn', aliases=['warns']) async def _warn(self, ctx: commands.Context): @@ -372,8 +372,12 @@ class Admin(commands.Cog): @_warn.command(name='remove', aliases=['revoke']) async def _warn_remove(self, ctx: commands.Context, warn_id: int): - warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one() - self.bot.engine.delete(warn) + warn = self.bot.database.session\ + .query(Warn)\ + .filter(Warn.id == warn_id)\ + .one() + + self.bot.database.session.delete(warn) await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`" f" {Texts('admin', ctx).get('successfully removed')}") @@ -391,9 +395,13 @@ class Admin(commands.Cog): @_warn.command(name='edit', aliases=['change']) async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason): - warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one() + warn = self.bot.database.session\ + .query(Warn)\ + .filter(Warn.id == warn_id)\ + .one() warn.reason = reason - self.bot.engine.commit() + + self.bot.database.session.commit() await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`" f" {Texts('admin', ctx).get('successfully edited')}") @@ -402,7 +410,7 @@ class Admin(commands.Cog): @commands.command(name='language', aliases=['lang', 'langue', 'langage']) async def _language(self, ctx: commands.Context, locale: str): - available = self.bot.engine \ + available = self.bot.database.session \ .query(Lang.value) \ .filter(Lang.key == 'available') \ .one()[0] \ @@ -412,18 +420,18 @@ class Admin(commands.Cog): await ctx.send( Texts('admin', ctx).get('Unable to find this language')) else: - current = self.bot.engine \ + current = self.bot.database.session \ .query(Lang) \ .filter(Lang.key == str(ctx.guild.id)) if current.count() > 0: current = current.one() current.value = locale.lower() - self.bot.engine.commit() + self.bot.database.session.commit() else: new_row = Lang(key=str(ctx.guild.id), value=locale.lower()) - self.bot.engine.add(new_row) - self.bot.engine.commit() + self.bot.database.session.add(new_row) + self.bot.database.session.commit() await ctx.send( Texts('admin', ctx).get('Language changed successfully')) diff --git a/cogs/basics.py b/cogs/basics.py index 78cb5c3..f97dcbf 100644 --- a/cogs/basics.py +++ b/cogs/basics.py @@ -10,6 +10,7 @@ from discord.ext import commands from bot import TuxBot from .utils.lang import Texts +from tcp_latency import measure_latency class Basics(commands.Cog): @@ -33,10 +34,12 @@ class Basics(commands.Cog): latency = round(self.bot.latency * 1000, 2) typing = round((end - start) * 1000, 2) + discordapp = measure_latency(host='google.com', wait=0)[0] e = discord.Embed(title='Ping', color=discord.Color.teal()) e.add_field(name='Websocket', value=f'{latency}ms') e.add_field(name='Typing', value=f'{typing}ms') + e.add_field(name='discordapp.com', value=f'{discordapp}ms') await ctx.send(embed=e) """---------------------------------------------------------------------""" diff --git a/cogs/poll.py b/cogs/poll.py index 3006dc3..94a1434 100644 --- a/cogs/poll.py +++ b/cogs/poll.py @@ -1,3 +1,4 @@ +import json from typing import Union import discord @@ -17,12 +18,12 @@ class Polls(commands.Cog): def get_poll(self, pld) -> Union[bool, Poll]: if pld.user_id != self.bot.user.id: - poll = self.bot.engine \ + poll = self.bot.database.session \ .query(Poll) \ - .filter(Poll.message_id == pld.message_id) \ - .one_or_none() + .filter(Poll.message_id == pld.message_id) - if poll is not None: + if poll.count() != 0: + poll = poll.one() emotes = utils_emotes.get(len(poll.responses)) if pld.emoji.name in emotes: @@ -31,16 +32,52 @@ class Polls(commands.Cog): return False async def remove_reaction(self, pld): - channel: discord.TextChannel = self.bot.get_channel( - pld.channel_id - ) - message: discord.Message = await channel.fetch_message( - pld.message_id - ) + channel: discord.TextChannel = self.bot.get_channel(pld.channel_id) + message: discord.Message = await channel.fetch_message(pld.message_id) user: discord.User = await self.bot.fetch_user(pld.user_id) await message.remove_reaction(pld.emoji.name, user) + async def update_poll(self, poll_id: int): + poll = self.bot.database.session \ + .query(Poll) \ + .filter(Poll.id == poll_id) \ + .one() + channel: discord.TextChannel = self.bot.get_channel(poll.channel_id) + message: discord.Message = await channel.fetch_message(poll.message_id) + + content = json.loads(poll.content) \ + if isinstance(poll.content, str) \ + else poll.content + responses = json.loads(poll.responses) \ + if isinstance(poll.responses, str) \ + else poll.responses + + for i, field in enumerate(content.get('fields')): + responders = len(responses.get(str(i + 1))) + if responders <= 1: + field['value'] = f"**{responders}** vote" + else: + field['value'] = f"**{responders}** votes" + + e = discord.Embed(description=content.get('description')) + e.set_author( + name=content.get('author').get('name'), + icon_url=content.get('author').get('icon_url') + ) + for field in content.get('fields'): + e.add_field( + name=field.get('name'), + value=field.get('value'), + inline=True + ) + e.set_footer(text=content.get('footer').get('text')) + + await message.edit(embed=e) + + poll.content = json.dumps(content) + self.bot.database.session.commit() + @commands.Cog.listener() async def on_raw_reaction_add(self, pld: discord.RawReactionActionEvent): poll = self.get_poll(pld) @@ -50,33 +87,39 @@ class Polls(commands.Cog): await self.remove_reaction(pld) user_id = str(pld.user_id).encode() - responses = poll.responses choice = utils_emotes.get_index(pld.emoji.name) + 1 - responders = responses.get(str(choice)) + responses = json.loads(poll.responses) \ + if isinstance(poll.responses, str) \ + else poll.responses - if not responders: - print(responders, 'before0') + if not responses.get(str(choice)): + print(97) user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt()) - responders.append(user_id_hash) - print(responders, 'after0') + responses \ + .get(str(choice)) \ + .append(user_id_hash.decode()) else: - for i, responder in enumerate(responders): + print(responses.get(str(choice))) + print(103) + for i, responder in enumerate(responses.get(str(choice))): + print(105) if bcrypt.checkpw(user_id, responder.encode()): - print(responders, 'before1') - responders.pop(i) - print(responders, 'after1') + print(107) + responses \ + .get(str(choice)) \ + .pop(i) else: - print(responders, 'before2') + print(112) user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt()) - responders.append(user_id_hash) - print(responders, 'after2') + responses \ + .get(str(choice)) \ + .append(user_id_hash.decode()) + print(117) - poll.responses = responses - print(poll.responses) - self.bot.engine.commit() - - return 1 + poll.responses = json.dumps(responses) + self.bot.database.session.commit() + await self.update_poll(poll.id) """---------------------------------------------------------------------""" @@ -89,16 +132,16 @@ class Polls(commands.Cog): stmt = await ctx.send(Texts('poll', ctx).get('**Preparation...**')) poll_row = Poll() - self.bot.engine.add(poll_row) - self.bot.engine.flush() + self.bot.database.session.add(poll_row) + self.bot.database.session.flush() e = discord.Embed(description=f"**{question}**") e.set_author( name=ctx.author, - icon_url='https://cdn.pixabay.com/photo/2017/05/15/23/48/survey-2316468_960_720.png' + icon_url="https://cdn.gnous.eu/tuxbot/survey1.png" ) for i, response in enumerate(responses): - responses_row[str(i+1)] = [] + responses_row[str(i + 1)] = [] e.add_field( name=f"{emotes[i]} __{response.capitalize()}__", value="**0** vote" @@ -106,11 +149,12 @@ class Polls(commands.Cog): e.set_footer(text=f"ID: {poll_row.id}") poll_row.message_id = stmt.id - poll_row.poll = e.to_dict() + poll_row.channel_id = stmt.channel.id + poll_row.content = e.to_dict() poll_row.is_anonymous = anonymous poll_row.responses = responses_row - self.bot.engine.commit() + self.bot.database.session.commit() await stmt.edit(content='', embed=e) for emote in range(len(responses)): diff --git a/cogs/utils/database.py b/cogs/utils/database.py new file mode 100644 index 0000000..3e8d2bc --- /dev/null +++ b/cogs/utils/database.py @@ -0,0 +1,11 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, session + + +class Database: + def __init__(self, config): + self.engine = create_engine(config.postgresql) + + Session = sessionmaker() + Session.configure(bind=self.engine) + self.session: session = Session() diff --git a/cogs/utils/lang.py b/cogs/utils/lang.py index 48b69ca..a357e85 100644 --- a/cogs/utils/lang.py +++ b/cogs/utils/lang.py @@ -1,9 +1,8 @@ import gettext import config +from cogs.utils.database import Database from .models.lang import Lang -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker from discord.ext import commands @@ -23,21 +22,16 @@ class Texts: @staticmethod def get_locale(ctx): - engine = create_engine(config.postgresql) - - Session = sessionmaker() - Session.configure(bind=engine) - - session = Session() + database = Database(config) if ctx is not None: - current = session\ + current = database.session\ .query(Lang.value)\ .filter(Lang.key == str(ctx.guild.id)) if current.count() > 0: return current.one()[0] - default = session\ + default = database.session\ .query(Lang.value)\ .filter(Lang.key == 'default')\ .one()[0] diff --git a/cogs/utils/models/poll.py b/cogs/utils/models/poll.py index aa30ece..ae6e4ce 100644 --- a/cogs/utils/models/poll.py +++ b/cogs/utils/models/poll.py @@ -8,13 +8,14 @@ class Poll(Base): __tablename__ = 'polls' id = Column(Integer, primary_key=True) + channel_id = Column(BigInteger) message_id = Column(BigInteger) - poll = Column(JSON) + content = Column(JSON) is_anonymous = Column(Boolean) responses = Column(JSON, nullable=True) def __repr__(self): - return "" % \ - (self.id, self.message_id, self.poll, + (self.id, self.channel_id, self.message_id, self.content, self.is_anonymous, self.responses) diff --git a/launcher.py b/launcher.py index ff9146a..59d5240 100644 --- a/launcher.py +++ b/launcher.py @@ -14,8 +14,7 @@ import git import requests from bot import TuxBot -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from cogs.utils.database import Database @contextlib.contextmanager @@ -50,17 +49,14 @@ def run_bot(unload: list = []): print(Texts().get('Starting...')) try: - engine = create_engine(config.postgresql) - - Session = sessionmaker() - Session.configure(bind=engine) + database = Database(config) except socket.gaierror: click.echo(Texts().get("Could not set up PostgreSQL..."), file=sys.stderr) log.exception(Texts().get("Could not set up PostgreSQL...")) return - bot = TuxBot(unload, Session()) + bot = TuxBot(unload, database) bot.run() diff --git a/requirements.txt b/requirements.txt index aeebb77..1fa6dca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ sqlalchemy gitpython requests psutil -bcrypt \ No newline at end of file +bcrypt +tcp_latency \ No newline at end of file