diff --git a/bot.py b/bot.py
index 5ca163b..d4cd75b 100755
--- a/bot.py
+++ b/bot.py
@@ -4,7 +4,6 @@ import sys
from collections import deque, Counter
import aiohttp
-import sqlalchemy
import discord
import git
from discord.ext import commands
@@ -41,7 +40,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
class TuxBot(commands.AutoShardedBot):
- def __init__(self, unload: list, engine: sqlalchemy.engine.Engine):
+ def __init__(self, unload: list, database):
super().__init__(command_prefix=_prefix_callable, pm_help=None,
help_command=None, description=description,
help_attrs=dict(hidden=True),
@@ -54,14 +53,14 @@ class TuxBot(commands.AutoShardedBot):
self.uptime: datetime = datetime.datetime.utcnow()
self.config = config
- self.engine = engine
+ self.database = database
self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop)
self.prefixes = Config('prefixes.json')
self.blacklist = Config('blacklist.json')
- self.version = Version(10, 0, 0, pre_release='a20', build=build)
+ self.version = Version(10, 0, 0, pre_release='a21', build=build)
for extension in l_extensions:
if extension not in unload:
diff --git a/cogs/admin.py b/cogs/admin.py
index 7ad57f1..dcf82f6 100644
--- a/cogs/admin.py
+++ b/cogs/admin.py
@@ -243,13 +243,13 @@ class Admin(commands.Cog):
week_ago = datetime.datetime.now() - datetime.timedelta(weeks=6)
if member:
- warns = self.bot.engine \
+ warns = self.bot.database.session \
.query(Warn) \
.filter(Warn.user_id == member.id, Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
.order_by(Warn.created_at.desc())
else:
- warns = self.bot.engine \
+ warns = self.bot.database.session \
.query(Warn) \
.filter(Warn.created_at > week_ago,
Warn.server_id == ctx.guild.id) \
@@ -276,8 +276,8 @@ class Admin(commands.Cog):
warn = Warn(server_id=ctx.guild.id, user_id=member.id, reason=reason,
created_at=now)
- self.bot.engine.add(warn)
- self.bot.engine.commit()
+ self.bot.database.session.add(warn)
+ self.bot.database.session.commit()
@commands.group(name='warn', aliases=['warns'])
async def _warn(self, ctx: commands.Context):
@@ -372,8 +372,12 @@ class Admin(commands.Cog):
@_warn.command(name='remove', aliases=['revoke'])
async def _warn_remove(self, ctx: commands.Context, warn_id: int):
- warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
- self.bot.engine.delete(warn)
+ warn = self.bot.database.session\
+ .query(Warn)\
+ .filter(Warn.id == warn_id)\
+ .one()
+
+ self.bot.database.session.delete(warn)
await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`"
f" {Texts('admin', ctx).get('successfully removed')}")
@@ -391,9 +395,13 @@ class Admin(commands.Cog):
@_warn.command(name='edit', aliases=['change'])
async def _warn_edit(self, ctx: commands.Context, warn_id: int, *, reason):
- warn = self.bot.engine.query(Warn).filter(Warn.id == warn_id).one()
+ warn = self.bot.database.session\
+ .query(Warn)\
+ .filter(Warn.id == warn_id)\
+ .one()
warn.reason = reason
- self.bot.engine.commit()
+
+ self.bot.database.session.commit()
await ctx.send(f"{Texts('admin', ctx).get('Warn with id')} `{warn_id}`"
f" {Texts('admin', ctx).get('successfully edited')}")
@@ -402,7 +410,7 @@ class Admin(commands.Cog):
@commands.command(name='language', aliases=['lang', 'langue', 'langage'])
async def _language(self, ctx: commands.Context, locale: str):
- available = self.bot.engine \
+ available = self.bot.database.session \
.query(Lang.value) \
.filter(Lang.key == 'available') \
.one()[0] \
@@ -412,18 +420,18 @@ class Admin(commands.Cog):
await ctx.send(
Texts('admin', ctx).get('Unable to find this language'))
else:
- current = self.bot.engine \
+ current = self.bot.database.session \
.query(Lang) \
.filter(Lang.key == str(ctx.guild.id))
if current.count() > 0:
current = current.one()
current.value = locale.lower()
- self.bot.engine.commit()
+ self.bot.database.session.commit()
else:
new_row = Lang(key=str(ctx.guild.id), value=locale.lower())
- self.bot.engine.add(new_row)
- self.bot.engine.commit()
+ self.bot.database.session.add(new_row)
+ self.bot.database.session.commit()
await ctx.send(
Texts('admin', ctx).get('Language changed successfully'))
diff --git a/cogs/basics.py b/cogs/basics.py
index 78cb5c3..f97dcbf 100644
--- a/cogs/basics.py
+++ b/cogs/basics.py
@@ -10,6 +10,7 @@ from discord.ext import commands
from bot import TuxBot
from .utils.lang import Texts
+from tcp_latency import measure_latency
class Basics(commands.Cog):
@@ -33,10 +34,12 @@ class Basics(commands.Cog):
latency = round(self.bot.latency * 1000, 2)
typing = round((end - start) * 1000, 2)
+ discordapp = measure_latency(host='google.com', wait=0)[0]
e = discord.Embed(title='Ping', color=discord.Color.teal())
e.add_field(name='Websocket', value=f'{latency}ms')
e.add_field(name='Typing', value=f'{typing}ms')
+ e.add_field(name='discordapp.com', value=f'{discordapp}ms')
await ctx.send(embed=e)
"""---------------------------------------------------------------------"""
diff --git a/cogs/poll.py b/cogs/poll.py
index 3006dc3..94a1434 100644
--- a/cogs/poll.py
+++ b/cogs/poll.py
@@ -1,3 +1,4 @@
+import json
from typing import Union
import discord
@@ -17,12 +18,12 @@ class Polls(commands.Cog):
def get_poll(self, pld) -> Union[bool, Poll]:
if pld.user_id != self.bot.user.id:
- poll = self.bot.engine \
+ poll = self.bot.database.session \
.query(Poll) \
- .filter(Poll.message_id == pld.message_id) \
- .one_or_none()
+ .filter(Poll.message_id == pld.message_id)
- if poll is not None:
+ if poll.count() != 0:
+ poll = poll.one()
emotes = utils_emotes.get(len(poll.responses))
if pld.emoji.name in emotes:
@@ -31,16 +32,52 @@ class Polls(commands.Cog):
return False
async def remove_reaction(self, pld):
- channel: discord.TextChannel = self.bot.get_channel(
- pld.channel_id
- )
- message: discord.Message = await channel.fetch_message(
- pld.message_id
- )
+ channel: discord.TextChannel = self.bot.get_channel(pld.channel_id)
+ message: discord.Message = await channel.fetch_message(pld.message_id)
user: discord.User = await self.bot.fetch_user(pld.user_id)
await message.remove_reaction(pld.emoji.name, user)
+ async def update_poll(self, poll_id: int):
+ poll = self.bot.database.session \
+ .query(Poll) \
+ .filter(Poll.id == poll_id) \
+ .one()
+ channel: discord.TextChannel = self.bot.get_channel(poll.channel_id)
+ message: discord.Message = await channel.fetch_message(poll.message_id)
+
+ content = json.loads(poll.content) \
+ if isinstance(poll.content, str) \
+ else poll.content
+ responses = json.loads(poll.responses) \
+ if isinstance(poll.responses, str) \
+ else poll.responses
+
+ for i, field in enumerate(content.get('fields')):
+ responders = len(responses.get(str(i + 1)))
+ if responders <= 1:
+ field['value'] = f"**{responders}** vote"
+ else:
+ field['value'] = f"**{responders}** votes"
+
+ e = discord.Embed(description=content.get('description'))
+ e.set_author(
+ name=content.get('author').get('name'),
+ icon_url=content.get('author').get('icon_url')
+ )
+ for field in content.get('fields'):
+ e.add_field(
+ name=field.get('name'),
+ value=field.get('value'),
+ inline=True
+ )
+ e.set_footer(text=content.get('footer').get('text'))
+
+ await message.edit(embed=e)
+
+ poll.content = json.dumps(content)
+ self.bot.database.session.commit()
+
@commands.Cog.listener()
async def on_raw_reaction_add(self, pld: discord.RawReactionActionEvent):
poll = self.get_poll(pld)
@@ -50,33 +87,39 @@ class Polls(commands.Cog):
await self.remove_reaction(pld)
user_id = str(pld.user_id).encode()
- responses = poll.responses
choice = utils_emotes.get_index(pld.emoji.name) + 1
- responders = responses.get(str(choice))
+ responses = json.loads(poll.responses) \
+ if isinstance(poll.responses, str) \
+ else poll.responses
- if not responders:
- print(responders, 'before0')
+ if not responses.get(str(choice)):
+ print(97)
user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt())
- responders.append(user_id_hash)
- print(responders, 'after0')
+ responses \
+ .get(str(choice)) \
+ .append(user_id_hash.decode())
else:
- for i, responder in enumerate(responders):
+ print(responses.get(str(choice)))
+ print(103)
+ for i, responder in enumerate(responses.get(str(choice))):
+ print(105)
if bcrypt.checkpw(user_id, responder.encode()):
- print(responders, 'before1')
- responders.pop(i)
- print(responders, 'after1')
+ print(107)
+ responses \
+ .get(str(choice)) \
+ .pop(i)
else:
- print(responders, 'before2')
+ print(112)
user_id_hash = bcrypt.hashpw(user_id, bcrypt.gensalt())
- responders.append(user_id_hash)
- print(responders, 'after2')
+ responses \
+ .get(str(choice)) \
+ .append(user_id_hash.decode())
+ print(117)
- poll.responses = responses
- print(poll.responses)
- self.bot.engine.commit()
-
- return 1
+ poll.responses = json.dumps(responses)
+ self.bot.database.session.commit()
+ await self.update_poll(poll.id)
"""---------------------------------------------------------------------"""
@@ -89,16 +132,16 @@ class Polls(commands.Cog):
stmt = await ctx.send(Texts('poll', ctx).get('**Preparation...**'))
poll_row = Poll()
- self.bot.engine.add(poll_row)
- self.bot.engine.flush()
+ self.bot.database.session.add(poll_row)
+ self.bot.database.session.flush()
e = discord.Embed(description=f"**{question}**")
e.set_author(
name=ctx.author,
- icon_url='https://cdn.pixabay.com/photo/2017/05/15/23/48/survey-2316468_960_720.png'
+ icon_url="https://cdn.gnous.eu/tuxbot/survey1.png"
)
for i, response in enumerate(responses):
- responses_row[str(i+1)] = []
+ responses_row[str(i + 1)] = []
e.add_field(
name=f"{emotes[i]} __{response.capitalize()}__",
value="**0** vote"
@@ -106,11 +149,12 @@ class Polls(commands.Cog):
e.set_footer(text=f"ID: {poll_row.id}")
poll_row.message_id = stmt.id
- poll_row.poll = e.to_dict()
+ poll_row.channel_id = stmt.channel.id
+ poll_row.content = e.to_dict()
poll_row.is_anonymous = anonymous
poll_row.responses = responses_row
- self.bot.engine.commit()
+ self.bot.database.session.commit()
await stmt.edit(content='', embed=e)
for emote in range(len(responses)):
diff --git a/cogs/utils/database.py b/cogs/utils/database.py
new file mode 100644
index 0000000..3e8d2bc
--- /dev/null
+++ b/cogs/utils/database.py
@@ -0,0 +1,11 @@
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker, session
+
+
+class Database:
+ def __init__(self, config):
+ self.engine = create_engine(config.postgresql)
+
+ Session = sessionmaker()
+ Session.configure(bind=self.engine)
+ self.session: session = Session()
diff --git a/cogs/utils/lang.py b/cogs/utils/lang.py
index 48b69ca..a357e85 100644
--- a/cogs/utils/lang.py
+++ b/cogs/utils/lang.py
@@ -1,9 +1,8 @@
import gettext
import config
+from cogs.utils.database import Database
from .models.lang import Lang
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
from discord.ext import commands
@@ -23,21 +22,16 @@ class Texts:
@staticmethod
def get_locale(ctx):
- engine = create_engine(config.postgresql)
-
- Session = sessionmaker()
- Session.configure(bind=engine)
-
- session = Session()
+ database = Database(config)
if ctx is not None:
- current = session\
+ current = database.session\
.query(Lang.value)\
.filter(Lang.key == str(ctx.guild.id))
if current.count() > 0:
return current.one()[0]
- default = session\
+ default = database.session\
.query(Lang.value)\
.filter(Lang.key == 'default')\
.one()[0]
diff --git a/cogs/utils/models/poll.py b/cogs/utils/models/poll.py
index aa30ece..ae6e4ce 100644
--- a/cogs/utils/models/poll.py
+++ b/cogs/utils/models/poll.py
@@ -8,13 +8,14 @@ class Poll(Base):
__tablename__ = 'polls'
id = Column(Integer, primary_key=True)
+ channel_id = Column(BigInteger)
message_id = Column(BigInteger)
- poll = Column(JSON)
+ content = Column(JSON)
is_anonymous = Column(Boolean)
responses = Column(JSON, nullable=True)
def __repr__(self):
- return "<Poll(id='%s', message_id='%s', poll='%s', " \
+ return "<Poll(id='%s', channel_id='%s', message_id='%s', poll='%s', " \
"is_anonymous='%s', responses='%s')>" % \
- (self.id, self.message_id, self.poll,
+ (self.id, self.channel_id, self.message_id, self.content,
self.is_anonymous, self.responses)
diff --git a/launcher.py b/launcher.py
index ff9146a..59d5240 100644
--- a/launcher.py
+++ b/launcher.py
@@ -14,8 +14,7 @@ import git
import requests
from bot import TuxBot
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
+from cogs.utils.database import Database
@contextlib.contextmanager
@@ -50,17 +49,14 @@ def run_bot(unload: list = []):
print(Texts().get('Starting...'))
try:
- engine = create_engine(config.postgresql)
-
- Session = sessionmaker()
- Session.configure(bind=engine)
+ database = Database(config)
except socket.gaierror:
click.echo(Texts().get("Could not set up PostgreSQL..."),
file=sys.stderr)
log.exception(Texts().get("Could not set up PostgreSQL..."))
return
- bot = TuxBot(unload, Session())
+ bot = TuxBot(unload, database)
bot.run()
diff --git a/requirements.txt b/requirements.txt
index aeebb77..1fa6dca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,5 @@ sqlalchemy
gitpython
requests
psutil
-bcrypt
\ No newline at end of file
+bcrypt
+tcp_latency
\ No newline at end of file