update(commands:iplocalise,getheaders|Network): speed optimisation

This commit is contained in:
Romain J 2021-04-22 00:16:37 +02:00
parent eca6e7b268
commit 561f56ca27
10 changed files with 198 additions and 161 deletions

View file

@ -1,5 +1,5 @@
import socket
from typing import Union, NoReturn, Optional
from typing import NoReturn, Optional
import asyncio
import re
@ -25,8 +25,8 @@ def _(x):
return x
async def get_ip(ip: str, inet: str = "") -> str:
_inet: Union[socket.AddressFamily, int] = 0 # pylint: disable=no-member
def get_ip(ip: str, inet: str = "") -> str:
_inet: socket.AddressFamily | int = 0 # pylint: disable=no-member
if inet == "6":
_inet = socket.AF_INET6
@ -44,14 +44,26 @@ async def get_ip(ip: str, inet: str = "") -> str:
) from e
async def get_hostname(ip: str) -> str:
async def get_hostname(loop, ip: str) -> str:
def _get_hostname(_ip: str):
try:
return socket.gethostbyaddr(ip)[0]
except socket.herror:
return "N/A"
try:
return await asyncio.wait_for(
loop.run_in_executor(None, _get_hostname, str(ip)),
timeout=0.200,
)
# assuming that if the hostname isn't retrieved in first .3sec,
# it doesn't exists
except asyncio.exceptions.TimeoutError:
return "N/A"
async def get_ipwhois_result(ip_address: str) -> Union[NoReturn, dict]:
async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict:
def _get_ipwhois_result(_ip_address: str) -> NoReturn | dict:
try:
net = Net(ip_address)
obj = IPASN(net)
@ -66,6 +78,14 @@ async def get_ipwhois_result(ip_address: str) -> Union[NoReturn, dict]:
)
) from e
try:
return await asyncio.wait_for(
loop.run_in_executor(None, _get_ipwhois_result, str(ip_address)),
timeout=0.200,
)
except asyncio.exceptions.TimeoutError:
return {}
async def get_ipinfo_result(apikey: str, ip_address: str) -> dict:
try:
@ -130,7 +150,7 @@ def merge_ipinfo_ipwhois(ipinfo_result: dict, ipwhois_result: dict) -> dict:
async def get_pydig_result(
domain: str, query_type: str, dnssec: Union[str, bool]
domain: str, query_type: str, dnssec: str | bool
) -> list:
additional_args = [] if dnssec is False else ["+dnssec"]
@ -145,14 +165,14 @@ async def get_pydig_result(
return resolver.query(domain, query_type)
def check_ip_version_or_raise(version: str) -> Union[bool, NoReturn]:
def check_ip_version_or_raise(version: str) -> bool | NoReturn:
if version in ("4", "6", "None"):
return True
raise InvalidIp(_("Invalid ip version"))
def check_query_type_or_raise(query_type: str) -> Union[bool, NoReturn]:
def check_query_type_or_raise(query_type: str) -> bool | NoReturn:
query_types = (
"a",
"aaaa",

View file

@ -1,7 +1,7 @@
import asyncio
import logging
import time
from typing import Union, Optional
from typing import Optional
import aiohttp
import discord
@ -87,13 +87,15 @@ class Network(commands.Cog):
):
check_ip_version_or_raise(str(version))
ip_address = await get_ip(str(ip), str(version))
ip_hostname = await get_hostname(ip_address)
ip_address = await self.bot.loop.run_in_executor(
None, get_ip, str(ip), str(version)
)
ip_hostname = await get_hostname(self.bot.loop, str(ip_address))
ipinfo_result = await get_ipinfo_result(
self.__config.ipinfoKey, ip_address
)
ipwhois_result = await get_ipwhois_result(ip_address)
ipwhois_result = await get_ipwhois_result(self.bot.loop, ip_address)
merged_results = merge_ipinfo_ipwhois(ipinfo_result, ipwhois_result)
@ -185,8 +187,10 @@ class Network(commands.Cog):
headers = dict(s.headers.items())
headers.pop("Set-Cookie", headers)
fail = False
for key, value in headers.items():
output = await shorten(ctx.session, value, 50)
fail, output = await shorten(ctx.session, value, 50, fail)
if output["link"]:
value = _(
@ -209,7 +213,7 @@ class Network(commands.Cog):
ctx: ContextPlus,
domain: IPConverter,
query_type: QueryTypeConverter,
dnssec: Union[str, bool] = False,
dnssec: str | bool = False,
):
check_query_type_or_raise(str(query_type))

View file

@ -1,6 +1,6 @@
import json
import logging
from typing import Union, Dict
from typing import Dict
import discord
from discord.ext import commands
@ -90,7 +90,7 @@ class Polls(commands.Cog):
async def get_poll(
self, pld: discord.RawReactionActionEvent
) -> Union[bool, Poll]:
) -> bool | Poll:
if pld.user_id != self.bot.user.id:
poll = await Poll.get_or_none(message_id=pld.message_id)
@ -225,7 +225,7 @@ class Polls(commands.Cog):
async def get_suggest(
self, pld: discord.RawReactionActionEvent
) -> Union[bool, Suggest]:
) -> bool | Suggest:
if pld.user_id != self.bot.user.id:
suggest = await Suggest.get_or_none(message_id=pld.message_id)

View file

@ -29,12 +29,10 @@ class Utils(commands.Cog):
@command_extra(name="info", aliases=["about"])
async def _info(self, ctx: ContextPlus):
proc = psutil.Process()
infos = fetch_info()
with proc.oneshot():
mem = proc.memory_full_info()
cpu = proc.cpu_percent() / psutil.cpu_count()
mem = psutil.Process().memory_full_info()
cpu = psutil.cpu_percent()
e = discord.Embed(
title=_("Information about TuxBot", ctx, self.bot.config),

View file

@ -3,7 +3,7 @@ import datetime
import importlib
import logging
from collections import Counter
from typing import List, Union, Tuple
from typing import List, Tuple
import aiohttp
import discord
@ -94,7 +94,6 @@ class Tux(commands.AutoShardedBot):
super().__init__(
*args,
# help_command=None,
intents=discord.Intents.all(),
loop=self.loop,
**kwargs,
@ -204,13 +203,13 @@ class Tux(commands.AutoShardedBot):
self.console.print()
async def is_owner(
self, user: Union[discord.User, discord.Member, discord.Object]
self, user: discord.User | discord.Member | discord.Object
) -> bool:
"""Determines if the user is a bot owner.
Parameters
----------
user: Union[discord.User, discord.Member]
user: discord.User | discord.Member
Returns
-------

View file

@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import Union, Dict, NoReturn, Any, Tuple
from typing import Dict, NoReturn, Any, Tuple
from babel.messages.pofile import read_po
@ -19,7 +19,7 @@ available_locales: Dict[str, Tuple] = {
}
def find_locale(locale: str) -> Union[str, NoReturn]:
def find_locale(locale: str) -> str | NoReturn:
"""We suppose `locale` is in `_available_locales.values()`"""
for key, val in available_locales.items():
@ -46,9 +46,7 @@ def get_locale_name(locale: str) -> str:
class Translator:
"""Class to load texts at init."""
def __init__(
self, name: str, file_location: Union[Path, os.PathLike, str]
):
def __init__(self, name: str, file_location: Path | os.PathLike | str):
"""Initializes the Translator object.
Parameters

View file

@ -0,0 +1,18 @@
import time
class TimeSpent:
def __init__(self, *breakpoints):
self.breakpoints: tuple = breakpoints
self.times: list = [time.perf_counter()]
def update(self) -> None:
self.times.append(time.perf_counter())
def display(self) -> str:
output = ""
for i, value in enumerate(self.breakpoints):
output += f'\n{value}: {f"{(self.times[i + 1] - self.times[i]) * 1000:.2f}ms" if i + 1 < len(self.times) else "..."}'
return output

View file

@ -1,18 +1,18 @@
from typing import List, Union
from typing import List, Optional
import discord
from tuxbot.core.config import search_for
def get_prefixes(tux, guild: Union[discord.Guild, None]) -> List[str]:
def get_prefixes(tux, guild: Optional[discord.Guild]) -> List[str]:
"""Get custom prefixes for one guild.
Parameters
----------
tux:Tux
The bot instance.
guild:Union[discord.Guild, None]
guild:Optional[discord.Guild]
The required guild prefixes.
Returns
-------

View file

@ -27,24 +27,28 @@ def typing(func):
return wrapped
async def shorten(session, text: str, length: int) -> dict:
async def shorten(
session, text: str, length: int, fail: bool = False
) -> tuple[bool, dict]:
output: Dict[str, str] = {"text": text[:length], "link": ""}
if len(text) > length:
output["text"] += "[...]"
if not fail:
try:
async with session.post(
"https://paste.ramle.be/documents",
data=text.encode(),
timeout=aiohttp.ClientTimeout(total=2),
timeout=aiohttp.ClientTimeout(total=0.300),
) as r:
output[
"link"
] = f"https://paste.ramle.be/{(await r.json())['key']}"
except (aiohttp.ClientError, asyncio.exceptions.TimeoutError):
pass
fail = True
return output
return fail, output
def replace_in_dict(value: dict, search: str, replace: str) -> dict:

View file

@ -7,7 +7,7 @@ import sys
import json
from argparse import Namespace
from pathlib import Path
from typing import Union, List
from typing import List
from urllib import request
from rich.prompt import Prompt, IntPrompt
@ -121,7 +121,7 @@ def get_ip() -> str:
def get_multiple(
question: str, confirmation: str, value_type: type
) -> List[Union[str, int]]:
) -> List[str | int]:
"""Give possibility to user to fill multiple value.
Parameters
@ -135,12 +135,10 @@ def get_multiple(
Returns
-------
List[Union[str, int]]
List[str | int]
List containing user filled values.
"""
prompt: Union[IntPrompt, Prompt] = (
IntPrompt() if value_type is int else Prompt()
)
prompt: IntPrompt | Prompt = IntPrompt() if value_type is int else Prompt()
user_input = prompt.ask(question, console=console)
@ -168,14 +166,12 @@ def get_multiple(
return values
def get_extra(question: str, value_type: type) -> Union[str, int]:
prompt: Union[IntPrompt, Prompt] = (
IntPrompt() if value_type is int else Prompt()
)
def get_extra(question: str, value_type: type) -> str | int:
prompt: IntPrompt | Prompt = IntPrompt() if value_type is int else Prompt()
return prompt.ask(question, console=console)
def additional_config(cogs: Union[str, list] = "**"):
def additional_config(cogs: str | list = "**"):
"""Asking for additional configs in cogs.
Returns