refactor(command|sondage): continue rewrite of sondage

known issues: more than 2 votes are not peristed
This commit is contained in:
Romain J 2019-10-06 23:49:00 +02:00
parent 76e845e5be
commit 98b241d51b
9 changed files with 129 additions and 72 deletions

7
bot.py
View file

@ -4,7 +4,6 @@ import sys
from collections import deque, Counter
import aiohttp
import sqlalchemy
import discord
import git
from discord.ext import commands
@ -41,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
class TuxBot(commands.AutoShardedBot):
def __init__(self, unload: list, engine: sqlalchemy.engine.Engine):
def __init__(self, unload: list, database):
super().__init__(command_prefix=_prefix_callable, pm_help=None,
help_command=None, description=description,
help_attrs=dict(hidden=True),
@ -54,14 +53,14 @@ class TuxBot(commands.AutoShardedBot):
self.uptime: datetime = datetime.datetime.utcnow()
self.config = config
self.engine = engine
self.database = database
self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop)
self.prefixes = Config('prefixes.json')
self.blacklist = Config('blacklist.json')
self.version = Version(10, 0, 0, pre_release='a20', build=build)
self.version = Version(10, 0, 0, pre_release='a21', build=build)
for extension in l_extensions:
if extension not in unload:

View file

@ -243,13 +243,13 @@ class Admin(commands.Cog):
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
if member:
warns = self.bot.engine \
warns = self.bot.database.session \
.query(Warn) \
.filter(Warn.user_id == member.id, Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc())
else:
warns = self.bot.engine \
warns = self.bot.database.session \
.query(Warn) \
.filter(Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
@ -276,8 +276,8 @@ class Admin(commands.Cog):
warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason,
created_at=now)
self.bot.engine.add(warn)
self.bot.engine.commit()
self.bot.database.session.add(warn)
self.bot.database.session.commit()
@commands.group(name='warn', aliases=['warns'])
async def _warn(self, ctx: commands.Context):
@ -372,8 +372,12 @@ class Admin(commands.Cog):
@_warn.command(name='remove', aliases=['revoke'])
async def _warn_remove(self, ctx: commands.Context, warn_id: int):
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
self.bot.engine.delete(warn)
warn = self.bot.database.session\
.query(Warn)\
.filter(Warn.id == warn_id)\
.one()
self.bot.database.session.delete(warn)
await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`"
f" {Texts('admin', ctx).get('successfully removed')}")
@ -391,9 +395,13 @@ class Admin(commands.Cog):
@_warn.command(name='edit', aliases=['change'])
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
warn = self.bot.database.session\
.query(Warn)\
.filter(Warn.id == warn_id)\
.one()
warn.reason = reason
self.bot.engine.commit()
self.bot.database.session.commit()
await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`"
f" {Texts('admin', ctx).get('successfully edited')}")
@ -402,7 +410,7 @@ class Admin(commands.Cog):
@commands.command(name='language', aliases=['lang', 'langue', 'langage'])
async def _language(self, ctx: commands.Context, locale: str):
available = self.bot.engine \
available = self.bot.database.session \
.query(Lang.value) \
.filter(Lang.key == 'available') \
.one()[0] \
@ -412,18 +420,18 @@ class Admin(commands.Cog):
await ctx.send(
Texts('admin', ctx).get('Unable to find this language'))
else:
current = self.bot.engine \
current = self.bot.database.session \
.query(Lang) \
.filter(Lang.key == str(ctx.guild.id))
if current.count() > 0:
current = current.one()
current.value = locale.lower()
self.bot.engine.commit()
self.bot.database.session.commit()
else:
new_row = Lang(key=str(ctx.guild.id), value=locale.lower())
self.bot.engine.add(new_row)
self.bot.engine.commit()
self.bot.database.session.add(new_row)
self.bot.database.session.commit()
await ctx.send(
Texts('admin', ctx).get('Language changed successfully'))

View file

@ -10,6 +10,7 @@ from discord.ext import commands
from bot import TuxBot
from .utils.lang import Texts
from tcp_latency import measure_latency
class Basics(commands.Cog):
@ -33,10 +34,12 @@ class Basics(commands.Cog):
latency = round(self.bot.latency * 1000, 2)
typing = round((end - start) * 1000, 2)
discordapp = measure_latency(host='google.com', wait=0)[0]
e = discord.Embed(title='Ping', color=discord.Color.teal())
e.add_field(name='Websocket', value=f'{latency}ms')
e.add_field(name='Typing', value=f'{typing}ms')
e.add_field(name='discordapp.com', value=f'{discordapp}ms')
await ctx.send(embed=e)
"""---------------------------------------------------------------------"""

View file

@ -1,3 +1,4 @@
import json
from typing import Union
import discord
@ -17,12 +18,12 @@ class Polls(commands.Cog):
def get_poll(self, pld) -> Union[bool, Poll]:
if pld.user_id != self.bot.user.id:
poll = self.bot.engine \
poll = self.bot.database.session \
.query(Poll) \
.filter(Poll.message_id == pld.message_id) \
.one_or_none()
.filter(Poll.message_id == pld.message_id)
if poll is not None:
if poll.count() != 0:
poll = poll.one()
emotes = utils_emotes.get(len(poll.responses))
if pld.emoji.name in emotes:
@ -31,16 +32,52 @@ class Polls(commands.Cog):
return False
async def remove_reaction(self, pld):
channel: discord.TextChannel = self.bot.get_channel(
pld.channel_id
)
message: discord.Message = await channel.fetch_message(
pld.message_id
)
channel: discord.TextChannel = self.bot.get_channel(pld.channel_id)
message: discord.Message = await channel.fetch_message(pld.message_id)
user: discord.User = await self.bot.fetch_user(pld.user_id)
await message.remove_reaction(pld.emoji.name, user)
async def update_poll(self, poll_id: int):
poll = self.bot.database.session \
.query(Poll) \
.filter(Poll.id == poll_id) \
.one()
channel: discord.TextChannel = self.bot.get_channel(poll.channel_id)
message: discord.Message = await channel.fetch_message(poll.message_id)
content = json.loads(poll.content) \
if isinstance(poll.content, str) \
else poll.content
responses = json.loads(poll.responses) \
if isinstance(poll.responses, str) \
else poll.responses
for i, field in enumerate(content.get('fields')):
responders = len(responses.get(str(i + 1)))
if responders <= 1:
field['value'] = f"**{responders}** vote"
else:
field['value'] = f"**{responders}** votes"
e = discord.Embed(description=content.get('description'))
e.set_author(
name=content.get('author').get('name'),
icon_url=content.get('author').get('icon_url')
)
for field in content.get('fields'):
e.add_field(
name=field.get('name'),
value=field.get('value'),
inline=True
)
e.set_footer(text=content.get('footer').get('text'))
await message.edit(embed=e)
poll.content = json.dumps(content)
self.bot.database.session.commit()
@commands.Cog.listener()
async def on_raw_reaction_add(self, pld: discord.RawReactionActionEvent):
poll = self.get_poll(pld)
@ -50,33 +87,39 @@ class Polls(commands.Cog):
await self.remove_reaction(pld)
user_id = str(pld.user_id).encode()
responses = poll.responses
choice = utils_emotes.get_index(pld.emoji.name) + 1
responders = responses.get(str(choice))
responses = json.loads(poll.responses) \
if isinstance(poll.responses, str) \
else poll.responses
if not responders:
print(responders, 'before0')
if not responses.get(str(choice)):
print(97)
user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt())
responders.append(user_id_hash)
print(responders, 'after0')
responses \
.get(str(choice)) \
.append(user_id_hash.decode())
else:
for i, responder in enumerate(responders):
print(responses.get(str(choice)))
print(103)
for i, responder in enumerate(responses.get(str(choice))):
print(105)
if bcrypt.checkpw(user_id, responder.encode()):
print(responders, 'before1')
responders.pop(i)
print(responders, 'after1')
print(107)
responses \
.get(str(choice)) \
.pop(i)
else:
print(responders, 'before2')
print(112)
user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt())
responders.append(user_id_hash)
print(responders, 'after2')
responses \
.get(str(choice)) \
.append(user_id_hash.decode())
print(117)
poll.responses = responses
print(poll.responses)
self.bot.engine.commit()
return 1
poll.responses = json.dumps(responses)
self.bot.database.session.commit()
await self.update_poll(poll.id)
"""---------------------------------------------------------------------"""
@ -89,16 +132,16 @@ class Polls(commands.Cog):
stmt = await ctx.send(Texts('poll', ctx).get('**Preparation...**'))
poll_row = Poll()
self.bot.engine.add(poll_row)
self.bot.engine.flush()
self.bot.database.session.add(poll_row)
self.bot.database.session.flush()
e = discord.Embed(description=f"**{question}**")
e.set_author(
name=ctx.author,
icon_url='https://cdn.pixabay.com/photo/2017/05/15/23/48/survey-2316468_960_720.png'
icon_url="https://cdn.gnous.eu/tuxbot/survey1.png"
)
for i, response in enumerate(responses):
responses_row[str(i+1)] = []
responses_row[str(i + 1)] = []
e.add_field(
name=f"{emotes[i]} __{response.capitalize()}__",
value="**0** vote"
@ -106,11 +149,12 @@ class Polls(commands.Cog):
e.set_footer(text=f"ID: {poll_row.id}")
poll_row.message_id = stmt.id
poll_row.poll = e.to_dict()
poll_row.channel_id = stmt.channel.id
poll_row.content = e.to_dict()
poll_row.is_anonymous = anonymous
poll_row.responses = responses_row
self.bot.engine.commit()
self.bot.database.session.commit()
await stmt.edit(content='', embed=e)
for emote in range(len(responses)):

11
cogs/utils/database.py Normal file
View file

@ -0,0 +1,11 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, session
class Database:
def __init__(self, config):
self.engine = create_engine(config.postgresql)
Session = sessionmaker()
Session.configure(bind=self.engine)
self.session: session = Session()

View file

@ -1,9 +1,8 @@
import gettext
import config
from cogs.utils.database import Database
from .models.lang import Lang
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from discord.ext import commands
@ -23,21 +22,16 @@ class Texts:
@staticmethod
def get_locale(ctx):
engine = create_engine(config.postgresql)
Session = sessionmaker()
Session.configure(bind=engine)
session = Session()
database = Database(config)
if ctx is not None:
current = session\
current = database.session\
.query(Lang.value)\
.filter(Lang.key == str(ctx.guild.id))
if current.count() > 0:
return current.one()[0]
default = session\
default = database.session\
.query(Lang.value)\
.filter(Lang.key == 'default')\
.one()[0]

View file

@ -8,13 +8,14 @@ class Poll(Base):
__tablename__ = 'polls'
id = Column(Integer, primary_key=True)
channel_id = Column(BigInteger)
message_id = Column(BigInteger)
poll = Column(JSON)
content = Column(JSON)
is_anonymous = Column(Boolean)
responses = Column(JSON, nullable=True)
def __repr__(self):
return "<Poll(id='%s', message_id='%s', poll='%s', " \
return "<Poll(id='%s', channel_id='%s', message_id='%s', poll='%s', " \
"is_anonymous='%s', responses='%s')>" % \
(self.id, self.message_id, self.poll,
(self.id, self.channel_id, self.message_id, self.content,
self.is_anonymous, self.responses)

View file

@ -14,8 +14,7 @@ import git
import requests
from bot import TuxBot
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from cogs.utils.database import Database
@contextlib.contextmanager
@ -50,17 +49,14 @@ def run_bot(unload: list = []):
print(Texts().get('Starting...'))
try:
engine = create_engine(config.postgresql)
Session = sessionmaker()
Session.configure(bind=engine)
database = Database(config)
except socket.gaierror:
click.echo(Texts().get("Could not set up PostgreSQL..."),
file=sys.stderr)
log.exception(Texts().get("Could not set up PostgreSQL..."))
return
bot = TuxBot(unload, Session())
bot = TuxBot(unload, database)
bot.run()

View file

@ -9,3 +9,4 @@ gitpython
requests
psutil
bcrypt
tcp_latency