update(commands:*|Network): speed optimisation

This commit is contained in:
Romain J 2021-04-22 18:11:55 +02:00
parent c6a5dc4ad6
commit c6c61a0886
8 changed files with 144 additions and 36 deletions

1
.gitignore vendored
View file

@ -46,3 +46,4 @@ build
data/settings/
dump.rdb

View file

@ -34,6 +34,7 @@
<w>outoutxyz</w>
<w>outouxyz</w>
<w>pacman</w>
<w>peeringdb</w>
<w>perso</w>
<w>postgre</w>
<w>postgresql</w>

View file

@ -15,6 +15,7 @@ platforms = linux
packages = find_namespace:
python_requires = >=3.9
install_requires =
asyncstdlib>=3.9.1
asyncpg>=0.21.0
Babel>=2.8.0
discord.py @ git+https://github.com/Rapptz/discord.py

View file

@ -1,4 +1,5 @@
from discord.ext import commands
from discord.ext.commands import Context
def _(x):
@ -6,7 +7,7 @@ def _(x):
class IPConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613
async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
argument = argument.replace("http://", "").replace("https://", "")
argument = argument.rstrip("/")
@ -17,7 +18,7 @@ class IPConverter(commands.Converter):
class DomainConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613
async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
if not argument.startswith("http"):
return f"http://{argument}"
@ -25,13 +26,18 @@ class DomainConverter(commands.Converter):
class QueryTypeConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613
async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
return argument.lower()
class IPVersionConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613
async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
if not argument:
return argument
return argument.replace("-", "").replace("ip", "").replace("v", "")
class ASConverter(commands.Converter):
async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
return argument.lower().lstrip("as")

View file

@ -23,3 +23,7 @@ class InvalidQueryType(NetworkException):
class VersionNotFound(NetworkException):
pass
class InvalidAsn(NetworkException):
pass

View file

@ -13,11 +13,15 @@ from ipinfo.exceptions import RequestQuotaExceededError
from ipwhois import Net
from ipwhois.asn import IPASN
from aiocache import cached
from aiocache.serializers import PickleSerializer
from tuxbot.cogs.Network.functions.exceptions import (
VersionNotFound,
RFC18,
InvalidIp,
InvalidQueryType,
InvalidAsn,
)
@ -25,7 +29,8 @@ def _(x):
return x
def get_ip(ip: str, inet: str = "") -> str:
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_ip(loop, ip: str, inet: str = "") -> str:
_inet: socket.AddressFamily | int = 0 # pylint: disable=no-member
if inet == "6":
@ -33,17 +38,21 @@ def get_ip(ip: str, inet: str = "") -> str:
elif inet == "4":
_inet = socket.AF_INET
try:
return socket.getaddrinfo(str(ip), None, _inet)[1][4][0]
except socket.gaierror as e:
raise VersionNotFound(
_(
"Unable to collect information on this in the given "
"version",
)
) from e
def _get_ip(_ip: str):
try:
return socket.getaddrinfo(_ip, None, _inet)[1][4][0]
except socket.gaierror as e:
raise VersionNotFound(
_(
"Unable to collect information on this in the given "
"version",
)
) from e
return await loop.run_in_executor(None, _get_ip, str(ip))
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_hostname(loop, ip: str) -> str:
def _get_hostname(_ip: str):
try:
@ -62,6 +71,7 @@ async def get_hostname(loop, ip: str) -> str:
return "N/A"
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict:
def _get_ipwhois_result(_ip_address: str) -> NoReturn | dict:
try:
@ -87,6 +97,7 @@ async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict:
return {}
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_ipinfo_result(apikey: str, ip_address: str) -> dict:
try:
handler = ipinfo.getHandlerAsync(
@ -97,6 +108,7 @@ async def get_ipinfo_result(apikey: str, ip_address: str) -> dict:
return {}
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_crimeflare_result(
session: aiohttp.ClientSession, ip_address: str
) -> Optional[str]:
@ -149,20 +161,74 @@ def merge_ipinfo_ipwhois(ipinfo_result: dict, ipwhois_result: dict) -> dict:
return output
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_pydig_result(
domain: str, query_type: str, dnssec: str | bool
loop, domain: str, query_type: str, dnssec: str | bool
) -> list:
additional_args = [] if dnssec is False else ["+dnssec"]
resolver = pydig.Resolver(
nameservers=[
"80.67.169.40",
"80.67.169.12",
],
additional_args=additional_args,
)
def _get_pydig_result(_domain: str) -> NoReturn | dict:
resolver = pydig.Resolver(
nameservers=[
"80.67.169.40",
"80.67.169.12",
],
additional_args=additional_args,
)
return resolver.query(domain, query_type)
return resolver.query(_domain, query_type)
try:
return await asyncio.wait_for(
loop.run_in_executor(None, _get_pydig_result, str(domain)),
timeout=0.500,
)
except asyncio.exceptions.TimeoutError:
return []
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_peeringdb_as_set_result(
session: aiohttp.ClientSession, asn: str
) -> Optional[dict]:
try:
async with session.get(
f"https://www.peeringdb.com/api/as_set/{asn}",
timeout=aiohttp.ClientTimeout(total=5),
) as s:
return await s.json()
except (
aiohttp.ClientError,
aiohttp.ContentTypeError,
asyncio.exceptions.TimeoutError,
):
pass
return None
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_peeringdb_net_irr_as_set_result(
session: aiohttp.ClientSession, asn: str
) -> Optional[dict]:
try:
async with session.get(
f"https://www.peeringdb.com/api/net?irr_as_set={asn}",
timeout=aiohttp.ClientTimeout(total=10),
) as s:
json = await s.json()
for data in json:
if data["asn"] == int(asn):
return data
except (
aiohttp.ClientError,
aiohttp.ContentTypeError,
asyncio.exceptions.TimeoutError,
):
pass
return None
def check_ip_version_or_raise(version: str) -> bool | NoReturn:
@ -194,3 +260,10 @@ def check_query_type_or_raise(query_type: str) -> bool | NoReturn:
"Supported queries : A, AAAA, CNAME, NS, DS, DNSKEY, SOA, TXT, PTR, MX"
)
)
def check_asn_or_raise(asn: str) -> bool | NoReturn:
if asn.isdigit() and int(asn) < 4_294_967_295:
return True
raise InvalidAsn(_("Invalid ASN provided"))

View file

@ -15,6 +15,7 @@ from tuxbot.cogs.Network.functions.converters import (
IPVersionConverter,
DomainConverter,
QueryTypeConverter,
ASConverter,
)
from tuxbot.cogs.Network.functions.exceptions import (
RFC18,
@ -22,6 +23,7 @@ from tuxbot.cogs.Network.functions.exceptions import (
VersionNotFound,
InvalidDomain,
InvalidQueryType,
InvalidAsn,
)
from tuxbot.core.bot import Tux
from tuxbot.core.i18n import (
@ -37,13 +39,16 @@ from .config import NetworkConfig
from .functions.utils import (
get_ip,
get_hostname,
get_crimeflare_result,
get_ipinfo_result,
get_ipwhois_result,
merge_ipinfo_ipwhois,
get_pydig_result,
# get_peeringdb_as_set_result,
# get_peeringdb_net_irr_as_set_result,
merge_ipinfo_ipwhois,
check_query_type_or_raise,
check_ip_version_or_raise,
get_crimeflare_result,
check_asn_or_raise,
)
log = logging.getLogger("tuxbot.cogs.Network")
@ -68,6 +73,7 @@ class Network(commands.Cog):
InvalidDomain,
InvalidQueryType,
VersionNotFound,
InvalidAsn,
),
):
await ctx.send(_(str(error), ctx, self.bot.config))
@ -87,9 +93,8 @@ class Network(commands.Cog):
):
check_ip_version_or_raise(str(version))
ip_address = await self.bot.loop.run_in_executor(
None, get_ip, str(ip), str(version)
)
ip_address = await get_ip(self.bot.loop, str(ip), str(version))
ip_hostname = await get_hostname(self.bot.loop, str(ip_address))
ipinfo_result = await get_ipinfo_result(
@ -222,7 +227,7 @@ class Network(commands.Cog):
check_query_type_or_raise(str(query_type))
pydig_result = await get_pydig_result(
str(domain), str(query_type), dnssec
self.bot.loop, str(domain), str(query_type), dnssec
)
e = discord.Embed(title=f"DIG {domain} {query_type}", color=0x5858D7)
@ -285,3 +290,26 @@ class Network(commands.Cog):
domain
)
)
@command_extra(
name="peeringdb", aliases=["peer", "peering"], deletable=True
)
async def _peeringdb(self, ctx: ContextPlus, asn: ASConverter):
check_asn_or_raise(str(asn))
return await ctx.send("Not implemented yet")
# peeringdb_as_set_result = await get_peeringdb_as_set_result(
# self.bot.session, str(asn)
# )
# peeringdb_net_irr_as_set_result = (
# await get_peeringdb_net_irr_as_set_result(
# self.bot.session, peeringdb_as_set_result["data"][0][asn]
# )
# )["data"]
#
# data = peeringdb_net_irr_as_set_result
#
# self.bot.console.log(data)
#
# await ctx.send("done")

View file

@ -90,8 +90,6 @@ class Tux(commands.AutoShardedBot):
self._app_owners_fetched = False # to prevent abusive API calls
self.loop = asyncio.get_event_loop()
self.before_invoke(self._typing)
super().__init__(
*args,
intents=discord.Intents.all(),
@ -121,10 +119,6 @@ class Tux(commands.AutoShardedBot):
return False
@staticmethod
async def _typing(ctx: ContextPlus) -> None:
await ctx.trigger_typing()
async def load_packages(self):
if packages:
with Progress() as progress: