update(commands:iplocalise,getheaders|Network): speed optimisation

This commit is contained in:
Romain J 2021-04-22 00:16:37 +02:00
parent eca6e7b268
commit 561f56ca27
10 changed files with 198 additions and 161 deletions

View file

@ -1,5 +1,5 @@
import socket import socket
from typing import Union, NoReturn, Optional from typing import NoReturn, Optional
import asyncio import asyncio
import re import re
@ -25,8 +25,8 @@ def _(x):
return x return x
async def get_ip(ip: str, inet: str = "") -> str: def get_ip(ip: str, inet: str = "") -> str:
_inet: Union[socket.AddressFamily, int] = 0 # pylint: disable=no-member _inet: socket.AddressFamily | int = 0 # pylint: disable=no-member
if inet == "6": if inet == "6":
_inet = socket.AF_INET6 _inet = socket.AF_INET6
@ -44,27 +44,47 @@ async def get_ip(ip: str, inet: str = "") -> str:
) from e ) from e
async def get_hostname(ip: str) -> str: async def get_hostname(loop, ip: str) -> str:
def _get_hostname(_ip: str):
try:
return socket.gethostbyaddr(ip)[0]
except socket.herror:
return "N/A"
try: try:
return socket.gethostbyaddr(ip)[0] return await asyncio.wait_for(
except socket.herror: loop.run_in_executor(None, _get_hostname, str(ip)),
timeout=0.200,
)
# assuming that if the hostname isn't retrieved in first .3sec,
# it doesn't exists
except asyncio.exceptions.TimeoutError:
return "N/A" return "N/A"
async def get_ipwhois_result(ip_address: str) -> Union[NoReturn, dict]: async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict:
def _get_ipwhois_result(_ip_address: str) -> NoReturn | dict:
try:
net = Net(ip_address)
obj = IPASN(net)
return obj.lookup()
except ipwhois.exceptions.ASNRegistryError:
return {}
except ipwhois.exceptions.IPDefinedError as e:
raise RFC18(
_(
"IP address {ip_address} is already defined as Private-Use"
" Networks via RFC 1918."
)
) from e
try: try:
net = Net(ip_address) return await asyncio.wait_for(
obj = IPASN(net) loop.run_in_executor(None, _get_ipwhois_result, str(ip_address)),
return obj.lookup() timeout=0.200,
except ipwhois.exceptions.ASNRegistryError: )
except asyncio.exceptions.TimeoutError:
return {} return {}
except ipwhois.exceptions.IPDefinedError as e:
raise RFC18(
_(
"IP address {ip_address} is already defined as Private-Use"
" Networks via RFC 1918."
)
) from e
async def get_ipinfo_result(apikey: str, ip_address: str) -> dict: async def get_ipinfo_result(apikey: str, ip_address: str) -> dict:
@ -130,7 +150,7 @@ def merge_ipinfo_ipwhois(ipinfo_result: dict, ipwhois_result: dict) -> dict:
async def get_pydig_result( async def get_pydig_result(
domain: str, query_type: str, dnssec: Union[str, bool] domain: str, query_type: str, dnssec: str | bool
) -> list: ) -> list:
additional_args = [] if dnssec is False else ["+dnssec"] additional_args = [] if dnssec is False else ["+dnssec"]
@ -145,14 +165,14 @@ async def get_pydig_result(
return resolver.query(domain, query_type) return resolver.query(domain, query_type)
def check_ip_version_or_raise(version: str) -> Union[bool, NoReturn]: def check_ip_version_or_raise(version: str) -> bool | NoReturn:
if version in ("4", "6", "None"): if version in ("4", "6", "None"):
return True return True
raise InvalidIp(_("Invalid ip version")) raise InvalidIp(_("Invalid ip version"))
def check_query_type_or_raise(query_type: str) -> Union[bool, NoReturn]: def check_query_type_or_raise(query_type: str) -> bool | NoReturn:
query_types = ( query_types = (
"a", "a",
"aaaa", "aaaa",

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import logging import logging
import time import time
from typing import Union, Optional from typing import Optional
import aiohttp import aiohttp
import discord import discord
@ -87,13 +87,15 @@ class Network(commands.Cog):
): ):
check_ip_version_or_raise(str(version)) check_ip_version_or_raise(str(version))
ip_address = await get_ip(str(ip), str(version)) ip_address = await self.bot.loop.run_in_executor(
ip_hostname = await get_hostname(ip_address) None, get_ip, str(ip), str(version)
)
ip_hostname = await get_hostname(self.bot.loop, str(ip_address))
ipinfo_result = await get_ipinfo_result( ipinfo_result = await get_ipinfo_result(
self.__config.ipinfoKey, ip_address self.__config.ipinfoKey, ip_address
) )
ipwhois_result = await get_ipwhois_result(ip_address) ipwhois_result = await get_ipwhois_result(self.bot.loop, ip_address)
merged_results = merge_ipinfo_ipwhois(ipinfo_result, ipwhois_result) merged_results = merge_ipinfo_ipwhois(ipinfo_result, ipwhois_result)
@ -185,8 +187,10 @@ class Network(commands.Cog):
headers = dict(s.headers.items()) headers = dict(s.headers.items())
headers.pop("Set-Cookie", headers) headers.pop("Set-Cookie", headers)
fail = False
for key, value in headers.items(): for key, value in headers.items():
output = await shorten(ctx.session, value, 50) fail, output = await shorten(ctx.session, value, 50, fail)
if output["link"]: if output["link"]:
value = _( value = _(
@ -209,7 +213,7 @@ class Network(commands.Cog):
ctx: ContextPlus, ctx: ContextPlus,
domain: IPConverter, domain: IPConverter,
query_type: QueryTypeConverter, query_type: QueryTypeConverter,
dnssec: Union[str, bool] = False, dnssec: str | bool = False,
): ):
check_query_type_or_raise(str(query_type)) check_query_type_or_raise(str(query_type))

View file

@ -1,6 +1,6 @@
import json import json
import logging import logging
from typing import Union, Dict from typing import Dict
import discord import discord
from discord.ext import commands from discord.ext import commands
@ -90,7 +90,7 @@ class Polls(commands.Cog):
async def get_poll( async def get_poll(
self, pld: discord.RawReactionActionEvent self, pld: discord.RawReactionActionEvent
) -> Union[bool, Poll]: ) -> bool | Poll:
if pld.user_id != self.bot.user.id: if pld.user_id != self.bot.user.id:
poll = await Poll.get_or_none(message_id=pld.message_id) poll = await Poll.get_or_none(message_id=pld.message_id)
@ -225,7 +225,7 @@ class Polls(commands.Cog):
async def get_suggest( async def get_suggest(
self, pld: discord.RawReactionActionEvent self, pld: discord.RawReactionActionEvent
) -> Union[bool, Suggest]: ) -> bool | Suggest:
if pld.user_id != self.bot.user.id: if pld.user_id != self.bot.user.id:
suggest = await Suggest.get_or_none(message_id=pld.message_id) suggest = await Suggest.get_or_none(message_id=pld.message_id)

View file

@ -29,110 +29,108 @@ class Utils(commands.Cog):
@command_extra(name="info", aliases=["about"]) @command_extra(name="info", aliases=["about"])
async def _info(self, ctx: ContextPlus): async def _info(self, ctx: ContextPlus):
proc = psutil.Process()
infos = fetch_info() infos = fetch_info()
with proc.oneshot(): mem = psutil.Process().memory_full_info()
mem = proc.memory_full_info() cpu = psutil.cpu_percent()
cpu = proc.cpu_percent() / psutil.cpu_count()
e = discord.Embed( e = discord.Embed(
title=_("Information about TuxBot", ctx, self.bot.config), title=_("Information about TuxBot", ctx, self.bot.config),
color=0x89C4F9, color=0x89C4F9,
) )
e.add_field( e.add_field(
name=_( name=_(
"__:busts_in_silhouette: Development__", "__:busts_in_silhouette: Development__",
ctx, ctx,
self.bot.config, self.bot.config,
), ),
value="**Romain#5117:** [git](https://git.gnous.eu/Romain)\n" value="**Romain#5117:** [git](https://git.gnous.eu/Romain)\n"
"**Outout#4039:** [git](https://git.gnous.eu/mael)\n", "**Outout#4039:** [git](https://git.gnous.eu/mael)\n",
inline=True, inline=True,
) )
e.add_field( e.add_field(
name="__<:python:596577462335307777> Python__", name="__<:python:596577462335307777> Python__",
value=f"**python** `{platform.python_version()}`\n" value=f"**python** `{platform.python_version()}`\n"
f"**discord.py** `{discord.__version__}`", f"**discord.py** `{discord.__version__}`",
inline=True, inline=True,
) )
e.add_field( e.add_field(
name="__:gear: Usage__", name="__:gear: Usage__",
value=_( value=_(
"**{}** physical memory\n" "**{}** physical memory\n"
"**{}** virtual memory\n" "**{}** virtual memory\n"
"**{:.2f}**% CPU", "**{:.2f}**% CPU",
ctx, ctx,
self.bot.config, self.bot.config,
).format( ).format(
humanize.naturalsize(mem.rss), humanize.naturalsize(mem.rss),
humanize.naturalsize(mem.vms), humanize.naturalsize(mem.vms),
cpu, cpu,
), ),
inline=True, inline=True,
) )
e.add_field( e.add_field(
name=_("__Servers count__", ctx, self.bot.config), name=_("__Servers count__", ctx, self.bot.config),
value=str(len(self.bot.guilds)), value=str(len(self.bot.guilds)),
inline=True, inline=True,
) )
e.add_field( e.add_field(
name=_("__Channels count__", ctx, self.bot.config), name=_("__Channels count__", ctx, self.bot.config),
value=str(len(list(self.bot.get_all_channels()))), value=str(len(list(self.bot.get_all_channels()))),
inline=True, inline=True,
) )
e.add_field( e.add_field(
name=_("__Members count__", ctx, self.bot.config), name=_("__Members count__", ctx, self.bot.config),
value=str(len(list(self.bot.get_all_members()))), value=str(len(list(self.bot.get_all_members()))),
inline=True, inline=True,
) )
e.add_field( e.add_field(
name=_("__:file_folder: Files__", ctx, self.bot.config), name=_("__:file_folder: Files__", ctx, self.bot.config),
value=f"{infos.get('file_amount')} " value=f"{infos.get('file_amount')} "
f"*({infos.get('python_file_amount')}" f"*({infos.get('python_file_amount')}"
f" <:python:596577462335307777>)*", f" <:python:596577462335307777>)*",
inline=True, inline=True,
) )
e.add_field( e.add_field(
name=_("__¶ Lines__", ctx, self.bot.config), name=_("__¶ Lines__", ctx, self.bot.config),
value=f"{infos.get('total_lines')} " value=f"{infos.get('total_lines')} "
f"*({infos.get('total_python_class')} " f"*({infos.get('total_python_class')} "
+ _("class", ctx, self.bot.config) + _("class", ctx, self.bot.config)
+ "," + ","
f" {infos.get('total_python_functions')} " f" {infos.get('total_python_functions')} "
+ _("functions", ctx, self.bot.config) + _("functions", ctx, self.bot.config)
+ "," + ","
f" {infos.get('total_python_coroutines')} " f" {infos.get('total_python_coroutines')} "
+ _("coroutines", ctx, self.bot.config) + _("coroutines", ctx, self.bot.config)
+ "," + ","
f" {infos.get('total_python_comments')} " f" {infos.get('total_python_comments')} "
+ _("comments", ctx, self.bot.config) + _("comments", ctx, self.bot.config)
+ ")*", + ")*",
inline=True, inline=True,
) )
e.add_field( e.add_field(
name=_("__Latest changes__", ctx, self.bot.config), name=_("__Latest changes__", ctx, self.bot.config),
value=version_info.info, value=version_info.info,
inline=False, inline=False,
) )
e.add_field( e.add_field(
name=_("__:link: Links__", ctx, self.bot.config), name=_("__:link: Links__", ctx, self.bot.config),
value="[tuxbot.gnous.eu](https://tuxbot.gnous.eu/) " value="[tuxbot.gnous.eu](https://tuxbot.gnous.eu/) "
"| [gnous.eu](https://gnous.eu/) " "| [gnous.eu](https://gnous.eu/) "
"| [git](https://git.gnous.eu/gnouseu/tuxbot-bot) " "| [git](https://git.gnous.eu/gnouseu/tuxbot-bot) "
"| [status](https://status.gnous.eu/check/154250) " "| [status](https://status.gnous.eu/check/154250) "
+ _("| [Invite]", ctx, self.bot.config) + _("| [Invite]", ctx, self.bot.config)
+ "(https://discordapp.com/oauth2/authorize?client_id=" + "(https://discordapp.com/oauth2/authorize?client_id="
"301062143942590465&scope=bot&permissions=268749888)", "301062143942590465&scope=bot&permissions=268749888)",
inline=False, inline=False,
) )
e.set_footer(text=f"version: {__version__} • prefix: {ctx.prefix}") e.set_footer(text=f"version: {__version__} • prefix: {ctx.prefix}")
await ctx.send(embed=e) await ctx.send(embed=e)

View file

@ -3,7 +3,7 @@ import datetime
import importlib import importlib
import logging import logging
from collections import Counter from collections import Counter
from typing import List, Union, Tuple from typing import List, Tuple
import aiohttp import aiohttp
import discord import discord
@ -94,7 +94,6 @@ class Tux(commands.AutoShardedBot):
super().__init__( super().__init__(
*args, *args,
# help_command=None,
intents=discord.Intents.all(), intents=discord.Intents.all(),
loop=self.loop, loop=self.loop,
**kwargs, **kwargs,
@ -204,13 +203,13 @@ class Tux(commands.AutoShardedBot):
self.console.print() self.console.print()
async def is_owner( async def is_owner(
self, user: Union[discord.User, discord.Member, discord.Object] self, user: discord.User | discord.Member | discord.Object
) -> bool: ) -> bool:
"""Determines if the user is a bot owner. """Determines if the user is a bot owner.
Parameters Parameters
---------- ----------
user: Union[discord.User, discord.Member] user: discord.User | discord.Member
Returns Returns
------- -------

View file

@ -1,7 +1,7 @@
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Union, Dict, NoReturn, Any, Tuple from typing import Dict, NoReturn, Any, Tuple
from babel.messages.pofile import read_po from babel.messages.pofile import read_po
@ -19,7 +19,7 @@ available_locales: Dict[str, Tuple] = {
} }
def find_locale(locale: str) -> Union[str, NoReturn]: def find_locale(locale: str) -> str | NoReturn:
"""We suppose `locale` is in `_available_locales.values()`""" """We suppose `locale` is in `_available_locales.values()`"""
for key, val in available_locales.items(): for key, val in available_locales.items():
@ -46,9 +46,7 @@ def get_locale_name(locale: str) -> str:
class Translator: class Translator:
"""Class to load texts at init.""" """Class to load texts at init."""
def __init__( def __init__(self, name: str, file_location: Path | os.PathLike | str):
self, name: str, file_location: Union[Path, os.PathLike, str]
):
"""Initializes the Translator object. """Initializes the Translator object.
Parameters Parameters

View file

@ -0,0 +1,18 @@
import time
class TimeSpent:
def __init__(self, *breakpoints):
self.breakpoints: tuple = breakpoints
self.times: list = [time.perf_counter()]
def update(self) -> None:
self.times.append(time.perf_counter())
def display(self) -> str:
output = ""
for i, value in enumerate(self.breakpoints):
output += f'\n{value}: {f"{(self.times[i + 1] - self.times[i]) * 1000:.2f}ms" if i + 1 < len(self.times) else "..."}'
return output

View file

@ -1,18 +1,18 @@
from typing import List, Union from typing import List, Optional
import discord import discord
from tuxbot.core.config import search_for from tuxbot.core.config import search_for
def get_prefixes(tux, guild: Union[discord.Guild, None]) -> List[str]: def get_prefixes(tux, guild: Optional[discord.Guild]) -> List[str]:
"""Get custom prefixes for one guild. """Get custom prefixes for one guild.
Parameters Parameters
---------- ----------
tux:Tux tux:Tux
The bot instance. The bot instance.
guild:Union[discord.Guild, None] guild:Optional[discord.Guild]
The required guild prefixes. The required guild prefixes.
Returns Returns
------- -------

View file

@ -27,24 +27,28 @@ def typing(func):
return wrapped return wrapped
async def shorten(session, text: str, length: int) -> dict: async def shorten(
session, text: str, length: int, fail: bool = False
) -> tuple[bool, dict]:
output: Dict[str, str] = {"text": text[:length], "link": ""} output: Dict[str, str] = {"text": text[:length], "link": ""}
if len(text) > length: if len(text) > length:
output["text"] += "[...]" output["text"] += "[...]"
try:
async with session.post(
"https://paste.ramle.be/documents",
data=text.encode(),
timeout=aiohttp.ClientTimeout(total=2),
) as r:
output[
"link"
] = f"https://paste.ramle.be/{(await r.json())['key']}"
except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
pass
return output if not fail:
try:
async with session.post(
"https://paste.ramle.be/documents",
data=text.encode(),
timeout=aiohttp.ClientTimeout(total=0.300),
) as r:
output[
"link"
] = f"https://paste.ramle.be/{(await r.json())['key']}"
except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
fail = True
return fail, output
def replace_in_dict(value: dict, search: str, replace: str) -> dict: def replace_in_dict(value: dict, search: str, replace: str) -> dict:

View file

@ -7,7 +7,7 @@ import sys
import json import json
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import Union, List from typing import List
from urllib import request from urllib import request
from rich.prompt import Prompt, IntPrompt from rich.prompt import Prompt, IntPrompt
@ -121,7 +121,7 @@ def get_ip() -> str:
def get_multiple( def get_multiple(
question: str, confirmation: str, value_type: type question: str, confirmation: str, value_type: type
) -> List[Union[str, int]]: ) -> List[str | int]:
"""Give possibility to user to fill multiple value. """Give possibility to user to fill multiple value.
Parameters Parameters
@ -135,12 +135,10 @@ def get_multiple(
Returns Returns
------- -------
List[Union[str, int]] List[str | int]
List containing user filled values. List containing user filled values.
""" """
prompt: Union[IntPrompt, Prompt] = ( prompt: IntPrompt | Prompt = IntPrompt() if value_type is int else Prompt()
IntPrompt() if value_type is int else Prompt()
)
user_input = prompt.ask(question, console=console) user_input = prompt.ask(question, console=console)
@ -168,14 +166,12 @@ def get_multiple(
return values return values
def get_extra(question: str, value_type: type) -> Union[str, int]: def get_extra(question: str, value_type: type) -> str | int:
prompt: Union[IntPrompt, Prompt] = ( prompt: IntPrompt | Prompt = IntPrompt() if value_type is int else Prompt()
IntPrompt() if value_type is int else Prompt()
)
return prompt.ask(question, console=console) return prompt.ask(question, console=console)
def additional_config(cogs: Union[str, list] = "**"): def additional_config(cogs: str | list = "**"):
"""Asking for additional configs in cogs. """Asking for additional configs in cogs.
Returns Returns