breaking change !

update(database): change database ORM

todo: update Admin, Poll and User cogs
This commit is contained in:
Romain J 2020-01-15 22:56:54 +01:00
parent 96618fa502
commit be1e6d24e4
16 changed files with 111 additions and 149 deletions

51
bot.py
View file

@ -10,11 +10,11 @@ import discord
import git import git
from discord.ext import commands from discord.ext import commands
from utils import Config from utils.functions import Config
from utils import Database from utils.functions import Database
from utils import Texts from utils.functions import Texts
from utils import Version from utils.functions import Version
from utils import ContextPlus from utils.functions import ContextPlus
description = """ description = """
Je suis TuxBot, le bot qui vit de l'OpenSource ! ;) Je suis TuxBot, le bot qui vit de l'OpenSource ! ;)
@ -52,7 +52,7 @@ async def _prefix_callable(bot, message: discord.message) -> list:
class TuxBot(commands.AutoShardedBot): class TuxBot(commands.AutoShardedBot):
def __init__(self, database): def __init__(self,):
super().__init__(command_prefix=_prefix_callable, pm_help=None, super().__init__(command_prefix=_prefix_callable, pm_help=None,
help_command=None, description=description, help_command=None, description=description,
help_attrs=dict(hidden=True), help_attrs=dict(hidden=True),
@ -63,20 +63,20 @@ class TuxBot(commands.AutoShardedBot):
self.socket_stats = Counter() self.socket_stats = Counter()
self.command_stats = Counter() self.command_stats = Counter()
self.uptime: datetime = datetime.datetime.utcnow()
self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop)
self.database = database
self.config = Config('./configs/config.cfg') self.config = Config('./configs/config.cfg')
self.prefixes = Config('./configs/prefixes.cfg') self.prefixes = Config('./configs/prefixes.cfg')
self.blacklist = Config('./configs/blacklist.cfg') self.blacklist = Config('./configs/blacklist.cfg')
self.fallbacks = Config('./configs/fallbacks.cfg') self.fallbacks = Config('./configs/fallbacks.cfg')
self.cluster = self.fallbacks.find('True', key='This', first=True) self.cluster = self.fallbacks.find('True', key='This', first=True)
self.uptime: datetime = datetime.datetime.utcnow()
self._prev_events = deque(maxlen=10)
self.session = aiohttp.ClientSession(loop=self.loop)
self.database = Database(self.config)
self.version = Version(*version, pre_release='rc2', build=build) self.version = Version(*version, pre_release='rc2', build=build)
self.owner: discord.User = discord.User self.owner_ids = self.config.get('permissions', 'Owners').split(', ')
self.owners: List[discord.User] = [] self.owner_id = int(self.owner_ids[0])
for extension in l_extensions: for extension in l_extensions:
try: try:
@ -93,8 +93,7 @@ class TuxBot(commands.AutoShardedBot):
+ extension, exc_info=e) + extension, exc_info=e)
async def is_owner(self, user: discord.User) -> bool: async def is_owner(self, user: discord.User) -> bool:
return str(user.id) in self.config.get("permissions", "Owners").split( return str(user.id) in self.owner_ids
', ')
async def get_context(self, message, *, cls=None): async def get_context(self, message, *, cls=None):
return await super().get_context(message, cls=cls or ContextPlus) return await super().get_context(message, cls=cls or ContextPlus)
@ -114,6 +113,8 @@ class TuxBot(commands.AutoShardedBot):
"Sorry. This command is disabled and cannot be used." "Sorry. This command is disabled and cannot be used."
) )
) )
elif isinstance(error, commands.CommandOnCooldown):
await ctx.send(str(error))
async def process_commands(self, message: discord.message): async def process_commands(self, message: discord.message):
ctx: commands.Context = await self.get_context(message) ctx: commands.Context = await self.get_context(message)
@ -193,13 +194,13 @@ class TuxBot(commands.AutoShardedBot):
@contextlib.contextmanager @contextlib.contextmanager
def setup_logging(): def setup_logging():
logging.getLogger('discord').setLevel(logging.INFO)
logging.getLogger('discord.http').setLevel(logging.WARNING)
log = logging.getLogger()
log.setLevel(logging.INFO)
try: try:
logging.getLogger('discord').setLevel(logging.INFO)
logging.getLogger('discord.http').setLevel(logging.WARNING)
log = logging.getLogger()
log.setLevel(logging.INFO)
handler = logging.FileHandler(filename='logs/tuxbot.log', handler = logging.FileHandler(filename='logs/tuxbot.log',
encoding='utf-8', mode='w') encoding='utf-8', mode='w')
fmt = logging.Formatter('[{levelname:<7}] [{asctime}]' fmt = logging.Formatter('[{levelname:<7}] [{asctime}]'
@ -218,14 +219,12 @@ def setup_logging():
if __name__ == "__main__": if __name__ == "__main__":
log = logging.getLogger()
print(Texts().get('Starting...')) print(Texts().get('Starting...'))
bot = TuxBot(Database(Config("./configs/config.cfg"))) app = TuxBot()
try: try:
with setup_logging(): with setup_logging():
bot.run() app.run()
except KeyboardInterrupt: except KeyboardInterrupt:
bot.close() app.close()

View file

@ -9,7 +9,7 @@ from discord.ext import commands
from bot import TuxBot from bot import TuxBot
from utils import Texts from utils import Texts
from utils import WarnModel, LangModel from utils.models import WarnModel
from utils import commandExtra, groupExtra from utils import commandExtra, groupExtra
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View file

@ -150,13 +150,6 @@ class Useful(commands.Cog):
########################################################################### ###########################################################################
@commands.Cog.listener()
async def on_command_error(self, ctx, error):
if isinstance(error, commands.CommandOnCooldown):
await ctx.send(error)
###########################################################################
@commandExtra(name='getheaders', category='network') @commandExtra(name='getheaders', category='network')
async def _getheaders(self, ctx: commands.Context, addr: str): async def _getheaders(self, ctx: commands.Context, addr: str):
if (addr.startswith('http') or addr.startswith('ftp')) is not True: if (addr.startswith('http') or addr.startswith('ftp')) is not True:

5
configs/langs.json Normal file
View file

@ -0,0 +1,5 @@
{
"default": "fr",
"available": ["en", "fr"],
"280805240977227776": "fr"
}

View file

@ -1,7 +1,5 @@
from utils import Config import sqlalchemy
from utils.models import Base from utils.models import database, metadata
from utils import Database
from utils.models.lang import LangModel
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -9,20 +7,13 @@ parser.add_argument("-m", "--migrate", action="store_true")
parser.add_argument("-s", "--seed", action="store_true") parser.add_argument("-s", "--seed", action="store_true")
args = parser.parse_args() args = parser.parse_args()
database = Database(Config("./configs/config.cfg"))
if args.migrate: if args.migrate:
print("Migrate...") print("Migrate...")
Base.metadata.create_all(database.engine) engine = sqlalchemy.create_engine(str(database.url))
metadata.create_all(engine)
print("Done!") print("Done!")
if args.seed: if args.seed:
print('Seeding...') print('Seeding...')
default = LangModel(key="default", value="fr") # todo: add seeding
available = LangModel(key="available", value="fr,en")
database.session.add(default)
database.session.add(available)
database.session.commit()
print("Done!") print("Done!")

View file

@ -3,7 +3,8 @@ humanize
git+https://github.com/Rapptz/discord.py@master git+https://github.com/Rapptz/discord.py@master
jishaku jishaku
gitpython gitpython
sqlalchemy orm
asyncpg
psycopg2 psycopg2
configparser configparser
psutil psutil

View file

@ -1,5 +1,3 @@
from .models import *
from utils.functions.config import * from utils.functions.config import *
from utils.functions.lang import * from utils.functions.lang import *
from utils.functions.version import * from utils.functions.version import *

View file

@ -0,0 +1,6 @@
from .config import Config
from .database import Database
from .extra import *
from .lang import Texts
from .paginator import *
from .version import Version

View file

@ -1,7 +1,7 @@
from .config import Config from .config import Config
from sqlalchemy import create_engine import sqlalchemy
from sqlalchemy.orm import sessionmaker, session import databases
class Database: class Database:
@ -10,8 +10,7 @@ class Database:
postgresql = 'postgresql://{}:{}@{}/{}'.format( postgresql = 'postgresql://{}:{}@{}/{}'.format(
conf_postgresql.get("Username"), conf_postgresql.get("Password"), conf_postgresql.get("Username"), conf_postgresql.get("Password"),
conf_postgresql.get("Host"), conf_postgresql.get("DBName")) conf_postgresql.get("Host"), conf_postgresql.get("DBName"))
self.engine = create_engine(postgresql, echo=False)
Session = sessionmaker() self.database = databases.Database(postgresql)
Session.configure(bind=self.engine) self.metadata = sqlalchemy.MetaData()
self.session: session = Session() self.engine = sqlalchemy.create_engine(str(self.database.url))

View file

@ -1,5 +1,5 @@
from discord.ext import commands from discord.ext import commands
from utils import Config from utils.functions import Config
class CommandsPlus(commands.Command): class CommandsPlus(commands.Command):

View file

@ -1,8 +1,6 @@
import gettext import gettext
from .config import Config import json
from .database import Database
from utils.models.lang import LangModel
from discord.ext import commands from discord.ext import commands
@ -22,17 +20,10 @@ class Texts:
@staticmethod @staticmethod
def get_locale(ctx): def get_locale(ctx):
database = Database(Config("./configs/config.cfg")) with open('./configs/langs.json') as f:
data = json.load(f)
if ctx is not None: if ctx is not None:
current = database.session\ return data.get(str(ctx.guild.id), data['default'])
.query(LangModel.value)\ else:
.filter(LangModel.key == str(ctx.guild.id)) return data['default']
if current.count() > 0:
return current.one()[0]
default = database.session\
.query(LangModel.value)\
.filter(LangModel.key == 'default')\
.one()[0]
return default

View file

@ -1,7 +1,15 @@
from sqlalchemy.ext.declarative import declarative_base import databases
Base = declarative_base() import sqlalchemy
from utils.functions import Config
conf_postgresql = Config('./configs/config.cfg')["postgresql"]
postgresql = 'postgresql://{}:{}@{}/{}'.format(
conf_postgresql.get("Username"), conf_postgresql.get("Password"),
conf_postgresql.get("Host"), conf_postgresql.get("DBName"))
database = databases.Database(postgresql)
metadata = sqlalchemy.MetaData()
from .lang import LangModel
from .warn import WarnModel from .warn import WarnModel
from .poll import PollModel, ResponsesModel from .poll import PollModel, ResponsesModel
from .alias import AliasesModel from .alias import AliasesModel

View file

@ -1,28 +1,14 @@
from sqlalchemy import Column, String, BigInteger, Integer import orm
from . import database, metadata
from . import Base
class AliasesModel(Base): class AliasesModel(orm.Model):
__tablename__ = 'aliases' __tablename__ = 'aliases'
__database__ = database
__metadata__ = metadata
id = Column(Integer, primary_key=True) id = orm.Integer(primary_key=True)
user_id = Column(BigInteger) user_id = orm.String(max_length=18)
alias = Column(String) alias = orm.String(max_length=255)
command = Column(String) command = orm.String(max_length=255)
guild = Column(String) guild = orm.String(max_length=255)
def __repr__(self):
return "<AliasesModel(" \
"id='%s', " \
"user_id='%s', " \
"alias='%s', " \
"command='%s', " \
"guild='%s', " \
")>" % (
self.id,
self.user_id,
self.alias,
self.command,
self.guild
)

View file

@ -1,12 +0,0 @@
from . import Base
from sqlalchemy import Column, String
class LangModel(Base):
__tablename__ = 'langs'
key = Column(String, primary_key=True)
value = Column(String)
def __repr__(self):
return "<LangModel(key='%s', locale='%s')>" % (self.key, self.value)

View file

@ -1,27 +1,29 @@
from . import Base import orm
from sqlalchemy import Column, Integer, BigInteger, JSON, ForeignKey, Boolean from . import database, metadata
from sqlalchemy.orm import relationship
class PollModel(Base): class ResponsesModel(orm.Model):
__tablename__ = 'polls'
id = Column(Integer, primary_key=True, autoincrement=True)
channel_id = Column(BigInteger)
message_id = Column(BigInteger)
content = Column(JSON)
is_anonymous = Column(Boolean)
available_choices = Column(Integer)
choice = relationship("ResponsesModel")
class ResponsesModel(Base):
__tablename__ = 'responses' __tablename__ = 'responses'
__database__ = database
__metadata__ = metadata
id = Column(Integer, primary_key=True, autoincrement=True) id = orm.Integer(primary_key=True)
user = Column(BigInteger) user = orm.String(max_length=18)
poll_id = Column(Integer, ForeignKey('polls.id')) choice = orm.Integer()
choice = Column(Integer)
class PollModel(orm.Model):
__tablename__ = 'polls'
__database__ = database
__metadata__ = metadata
id = orm.Integer(primary_key=True)
channel_id = orm.String(max_length=18)
message_id = orm.String(max_length=18)
content = orm.JSON()
is_anonymous = orm.Boolean()
available_choices = orm.Integer()
choice = orm.ForeignKey(ResponsesModel)

View file

@ -1,19 +1,14 @@
import datetime import orm
from . import database, metadata
from . import Base
from sqlalchemy import Column, Integer, String, BIGINT, TIMESTAMP
class WarnModel(Base): class WarnModel(orm.Model):
__tablename__ = 'warns' __tablename__ = 'warns'
__database__ = database
__metadata__ = metadata
id = Column(Integer, primary_key=True) id = orm.Integer(primary_key=True)
server_id = Column(BIGINT) server_id = orm.String(max_length=18)
user_id = Column(BIGINT) user_id = orm.String(max_length=18)
reason = Column(String) reason = orm.String(max_length=255)
created_at = Column(TIMESTAMP, default=datetime.datetime.now()) created_at = orm.DateTime()
def __repr__(self):
return "<WarnModel(server_id='%s', user_id='%s', reason='%s', " \
"created_at='%s')>" \
% (self.server_id, self.user_id, self.reason, self.created_at)