refactor(command|sondage): continue rewrite of sondage

known issues: more than 2 votes are not peristed
This commit is contained in:
Romain J 2019-10-06 23:49:00 +02:00
parent 76e845e5be
commit 98b241d51b
9 changed files with 129 additions and 72 deletions

7
bot.py
View file

@ -4,7 +4,6 @@ import sys
from collections import deque, Counter from collections import deque, Counter
import aiohttp import aiohttp
import sqlalchemy
import discord import discord
import git import git
from discord.ext import commands from discord.ext import commands
@ -41,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
class TuxBot(commands.AutoShardedBot): class TuxBot(commands.AutoShardedBot):
def __init__(self, unload: list, engine: sqlalchemy.engine.Engine): def __init__(self, unload: list, database):
super().__init__(command_prefix=_prefix_callable, pm_help=None, super().__init__(command_prefix=_prefix_callable, pm_help=None,
help_command=None, description=description, help_command=None, description=description,
help_attrs=dict(hidden=True), help_attrs=dict(hidden=True),
@ -54,14 +53,14 @@ class TuxBot(commands.AutoShardedBot):
self.uptime: datetime = datetime.datetime.utcnow() self.uptime: datetime = datetime.datetime.utcnow()
self.config = config self.config = config
self.engine = engine self.database = database
self._prev_events = deque(maxlen=10) self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop) self.session = aiohttp.ClientSession(loop=self.loop)
self.prefixes = Config('prefixes.json') self.prefixes = Config('prefixes.json')
self.blacklist = Config('blacklist.json') self.blacklist = Config('blacklist.json')
self.version = Version(10, 0, 0, pre_release='a20', build=build) self.version = Version(10, 0, 0, pre_release='a21', build=build)
for extension in l_extensions: for extension in l_extensions:
if extension not in unload: if extension not in unload:

View file

@ -243,13 +243,13 @@ class Admin(commands.Cog):
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6) week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
if member: if member:
warns = self.bot.engine \ warns = self.bot.database.session \
.query(Warn) \ .query(Warn) \
.filter(Warn.user_id == member.id, Warn.created_at > week_ago, .filter(Warn.user_id == member.id, Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \ Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc()) .order_by(Warn.created_at.desc())
else: else:
warns = self.bot.engine \ warns = self.bot.database.session \
.query(Warn) \ .query(Warn) \
.filter(Warn.created_at > week_ago, .filter(Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \ Warn.server_id == ctx.guild.id) \
@ -276,8 +276,8 @@ class Admin(commands.Cog):
warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason, warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason,
created_at=now) created_at=now)
self.bot.engine.add(warn) self.bot.database.session.add(warn)
self.bot.engine.commit() self.bot.database.session.commit()
@commands.group(name='warn', aliases=['warns']) @commands.group(name='warn', aliases=['warns'])
async def _warn(self, ctx: commands.Context): async def _warn(self, ctx: commands.Context):
@ -372,8 +372,12 @@ class Admin(commands.Cog):
@_warn.command(name='remove', aliases=['revoke']) @_warn.command(name='remove', aliases=['revoke'])
async def _warn_remove(self, ctx: commands.Context, warn_id: int): async def _warn_remove(self, ctx: commands.Context, warn_id: int):
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one() warn = self.bot.database.session\
self.bot.engine.delete(warn) .query(Warn)\
.filter(Warn.id == warn_id)\
.one()
self.bot.database.session.delete(warn)
await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`" await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`"
f" {Texts('admin', ctx).get('successfully removed')}") f" {Texts('admin', ctx).get('successfully removed')}")
@ -391,9 +395,13 @@ class Admin(commands.Cog):
@_warn.command(name='edit', aliases=['change']) @_warn.command(name='edit', aliases=['change'])
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason): async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one() warn = self.bot.database.session\
.query(Warn)\
.filter(Warn.id == warn_id)\
.one()
warn.reason = reason warn.reason = reason
self.bot.engine.commit()
self.bot.database.session.commit()
await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`" await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`"
f" {Texts('admin', ctx).get('successfully edited')}") f" {Texts('admin', ctx).get('successfully edited')}")
@ -402,7 +410,7 @@ class Admin(commands.Cog):
@commands.command(name='language', aliases=['lang', 'langue', 'langage']) @commands.command(name='language', aliases=['lang', 'langue', 'langage'])
async def _language(self, ctx: commands.Context, locale: str): async def _language(self, ctx: commands.Context, locale: str):
available = self.bot.engine \ available = self.bot.database.session \
.query(Lang.value) \ .query(Lang.value) \
.filter(Lang.key == 'available') \ .filter(Lang.key == 'available') \
.one()[0] \ .one()[0] \
@ -412,18 +420,18 @@ class Admin(commands.Cog):
await ctx.send( await ctx.send(
Texts('admin', ctx).get('Unable to find this language')) Texts('admin', ctx).get('Unable to find this language'))
else: else:
current = self.bot.engine \ current = self.bot.database.session \
.query(Lang) \ .query(Lang) \
.filter(Lang.key == str(ctx.guild.id)) .filter(Lang.key == str(ctx.guild.id))
if current.count() > 0: if current.count() > 0:
current = current.one() current = current.one()
current.value = locale.lower() current.value = locale.lower()
self.bot.engine.commit() self.bot.database.session.commit()
else: else:
new_row = Lang(key=str(ctx.guild.id), value=locale.lower()) new_row = Lang(key=str(ctx.guild.id), value=locale.lower())
self.bot.engine.add(new_row) self.bot.database.session.add(new_row)
self.bot.engine.commit() self.bot.database.session.commit()
await ctx.send( await ctx.send(
Texts('admin', ctx).get('Language changed successfully')) Texts('admin', ctx).get('Language changed successfully'))

View file

@ -10,6 +10,7 @@ from discord.ext import commands
from bot import TuxBot from bot import TuxBot
from .utils.lang import Texts from .utils.lang import Texts
from tcp_latency import measure_latency
class Basics(commands.Cog): class Basics(commands.Cog):
@ -33,10 +34,12 @@ class Basics(commands.Cog):
latency = round(self.bot.latency * 1000, 2) latency = round(self.bot.latency * 1000, 2)
typing = round((end - start) * 1000, 2) typing = round((end - start) * 1000, 2)
discordapp = measure_latency(host='google.com', wait=0)[0]
e = discord.Embed(title='Ping', color=discord.Color.teal()) e = discord.Embed(title='Ping', color=discord.Color.teal())
e.add_field(name='Websocket', value=f'{latency}ms') e.add_field(name='Websocket', value=f'{latency}ms')
e.add_field(name='Typing', value=f'{typing}ms') e.add_field(name='Typing', value=f'{typing}ms')
e.add_field(name='discordapp.com', value=f'{discordapp}ms')
await ctx.send(embed=e) await ctx.send(embed=e)
"""---------------------------------------------------------------------""" """---------------------------------------------------------------------"""

View file

@ -1,3 +1,4 @@
import json
from typing import Union from typing import Union
import discord import discord
@ -17,12 +18,12 @@ class Polls(commands.Cog):
def get_poll(self, pld) -> Union[bool, Poll]: def get_poll(self, pld) -> Union[bool, Poll]:
if pld.user_id != self.bot.user.id: if pld.user_id != self.bot.user.id:
poll = self.bot.engine \ poll = self.bot.database.session \
.query(Poll) \ .query(Poll) \
.filter(Poll.message_id == pld.message_id) \ .filter(Poll.message_id == pld.message_id)
.one_or_none()
if poll is not None: if poll.count() != 0:
poll = poll.one()
emotes = utils_emotes.get(len(poll.responses)) emotes = utils_emotes.get(len(poll.responses))
if pld.emoji.name in emotes: if pld.emoji.name in emotes:
@ -31,16 +32,52 @@ class Polls(commands.Cog):
return False return False
async def remove_reaction(self, pld): async def remove_reaction(self, pld):
channel: discord.TextChannel = self.bot.get_channel( channel: discord.TextChannel = self.bot.get_channel(pld.channel_id)
pld.channel_id message: discord.Message = await channel.fetch_message(pld.message_id)
)
message: discord.Message = await channel.fetch_message(
pld.message_id
)
user: discord.User = await self.bot.fetch_user(pld.user_id) user: discord.User = await self.bot.fetch_user(pld.user_id)
await message.remove_reaction(pld.emoji.name, user) await message.remove_reaction(pld.emoji.name, user)
async def update_poll(self, poll_id: int):
poll = self.bot.database.session \
.query(Poll) \
.filter(Poll.id == poll_id) \
.one()
channel: discord.TextChannel = self.bot.get_channel(poll.channel_id)
message: discord.Message = await channel.fetch_message(poll.message_id)
content = json.loads(poll.content) \
if isinstance(poll.content, str) \
else poll.content
responses = json.loads(poll.responses) \
if isinstance(poll.responses, str) \
else poll.responses
for i, field in enumerate(content.get('fields')):
responders = len(responses.get(str(i + 1)))
if responders <= 1:
field['value'] = f"**{responders}** vote"
else:
field['value'] = f"**{responders}** votes"
e = discord.Embed(description=content.get('description'))
e.set_author(
name=content.get('author').get('name'),
icon_url=content.get('author').get('icon_url')
)
for field in content.get('fields'):
e.add_field(
name=field.get('name'),
value=field.get('value'),
inline=True
)
e.set_footer(text=content.get('footer').get('text'))
await message.edit(embed=e)
poll.content = json.dumps(content)
self.bot.database.session.commit()
@commands.Cog.listener() @commands.Cog.listener()
async def on_raw_reaction_add(self, pld: discord.RawReactionActionEvent): async def on_raw_reaction_add(self, pld: discord.RawReactionActionEvent):
poll = self.get_poll(pld) poll = self.get_poll(pld)
@ -50,33 +87,39 @@ class Polls(commands.Cog):
await self.remove_reaction(pld) await self.remove_reaction(pld)
user_id = str(pld.user_id).encode() user_id = str(pld.user_id).encode()
responses = poll.responses
choice = utils_emotes.get_index(pld.emoji.name) + 1 choice = utils_emotes.get_index(pld.emoji.name) + 1
responders = responses.get(str(choice)) responses = json.loads(poll.responses) \
if isinstance(poll.responses, str) \
else poll.responses
if not responders: if not responses.get(str(choice)):
print(responders, 'before0') print(97)
user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt()) user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt())
responders.append(user_id_hash) responses \
print(responders, 'after0') .get(str(choice)) \
.append(user_id_hash.decode())
else: else:
for i, responder in enumerate(responders): print(responses.get(str(choice)))
print(103)
for i, responder in enumerate(responses.get(str(choice))):
print(105)
if bcrypt.checkpw(user_id, responder.encode()): if bcrypt.checkpw(user_id, responder.encode()):
print(responders, 'before1') print(107)
responders.pop(i) responses \
print(responders, 'after1') .get(str(choice)) \
.pop(i)
else: else:
print(responders, 'before2') print(112)
user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt()) user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt())
responders.append(user_id_hash) responses \
print(responders, 'after2') .get(str(choice)) \
.append(user_id_hash.decode())
print(117)
poll.responses = responses poll.responses = json.dumps(responses)
print(poll.responses) self.bot.database.session.commit()
self.bot.engine.commit() await self.update_poll(poll.id)
return 1
"""---------------------------------------------------------------------""" """---------------------------------------------------------------------"""
@ -89,13 +132,13 @@ class Polls(commands.Cog):
stmt = await ctx.send(Texts('poll', ctx).get('**Preparation...**')) stmt = await ctx.send(Texts('poll', ctx).get('**Preparation...**'))
poll_row = Poll() poll_row = Poll()
self.bot.engine.add(poll_row) self.bot.database.session.add(poll_row)
self.bot.engine.flush() self.bot.database.session.flush()
e = discord.Embed(description=f"**{question}**") e = discord.Embed(description=f"**{question}**")
e.set_author( e.set_author(
name=ctx.author, name=ctx.author,
icon_url='https://cdn.pixabay.com/photo/2017/05/15/23/48/survey-2316468_960_720.png' icon_url="https://cdn.gnous.eu/tuxbot/survey1.png"
) )
for i, response in enumerate(responses): for i, response in enumerate(responses):
responses_row[str(i + 1)] = [] responses_row[str(i + 1)] = []
@ -106,11 +149,12 @@ class Polls(commands.Cog):
e.set_footer(text=f"ID: {poll_row.id}") e.set_footer(text=f"ID: {poll_row.id}")
poll_row.message_id = stmt.id poll_row.message_id = stmt.id
poll_row.poll = e.to_dict() poll_row.channel_id = stmt.channel.id
poll_row.content = e.to_dict()
poll_row.is_anonymous = anonymous poll_row.is_anonymous = anonymous
poll_row.responses = responses_row poll_row.responses = responses_row
self.bot.engine.commit() self.bot.database.session.commit()
await stmt.edit(content='', embed=e) await stmt.edit(content='', embed=e)
for emote in range(len(responses)): for emote in range(len(responses)):

11
cogs/utils/database.py Normal file
View file

@ -0,0 +1,11 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, session
class Database:
def __init__(self, config):
self.engine = create_engine(config.postgresql)
Session = sessionmaker()
Session.configure(bind=self.engine)
self.session: session = Session()

View file

@ -1,9 +1,8 @@
import gettext import gettext
import config import config
from cogs.utils.database import Database
from .models.lang import Lang from .models.lang import Lang
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from discord.ext import commands from discord.ext import commands
@ -23,21 +22,16 @@ class Texts:
@staticmethod @staticmethod
def get_locale(ctx): def get_locale(ctx):
engine = create_engine(config.postgresql) database = Database(config)
Session = sessionmaker()
Session.configure(bind=engine)
session = Session()
if ctx is not None: if ctx is not None:
current = session\ current = database.session\
.query(Lang.value)\ .query(Lang.value)\
.filter(Lang.key == str(ctx.guild.id)) .filter(Lang.key == str(ctx.guild.id))
if current.count() > 0: if current.count() > 0:
return current.one()[0] return current.one()[0]
default = session\ default = database.session\
.query(Lang.value)\ .query(Lang.value)\
.filter(Lang.key == 'default')\ .filter(Lang.key == 'default')\
.one()[0] .one()[0]

View file

@ -8,13 +8,14 @@ class Poll(Base):
__tablename__ = 'polls' __tablename__ = 'polls'
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
channel_id = Column(BigInteger)
message_id = Column(BigInteger) message_id = Column(BigInteger)
poll = Column(JSON) content = Column(JSON)
is_anonymous = Column(Boolean) is_anonymous = Column(Boolean)
responses = Column(JSON, nullable=True) responses = Column(JSON, nullable=True)
def __repr__(self): def __repr__(self):
return "<Poll(id='%s', message_id='%s', poll='%s', " \ return "<Poll(id='%s', channel_id='%s', message_id='%s', poll='%s', " \
"is_anonymous='%s', responses='%s')>" % \ "is_anonymous='%s', responses='%s')>" % \
(self.id, self.message_id, self.poll, (self.id, self.channel_id, self.message_id, self.content,
self.is_anonymous, self.responses) self.is_anonymous, self.responses)

View file

@ -14,8 +14,7 @@ import git
import requests import requests
from bot import TuxBot from bot import TuxBot
from sqlalchemy import create_engine from cogs.utils.database import Database
from sqlalchemy.orm import sessionmaker
@contextlib.contextmanager @contextlib.contextmanager
@ -50,17 +49,14 @@ def run_bot(unload: list = []):
print(Texts().get('Starting...')) print(Texts().get('Starting...'))
try: try:
engine = create_engine(config.postgresql) database = Database(config)
Session = sessionmaker()
Session.configure(bind=engine)
except socket.gaierror: except socket.gaierror:
click.echo(Texts().get("Could not set up PostgreSQL..."), click.echo(Texts().get("Could not set up PostgreSQL..."),
file=sys.stderr) file=sys.stderr)
log.exception(Texts().get("Could not set up PostgreSQL...")) log.exception(Texts().get("Could not set up PostgreSQL..."))
return return
bot = TuxBot(unload, Session()) bot = TuxBot(unload, database)
bot.run() bot.run()

View file

@ -9,3 +9,4 @@ gitpython
requests requests
psutil psutil
bcrypt bcrypt
tcp_latency