update(commands:*|Network): speed optimisation

This commit is contained in:
Romain J 2021-04-22 18:11:55 +02:00
parent c6a5dc4ad6
commit c6c61a0886
8 changed files with 144 additions and 36 deletions

1
.gitignore vendored
View file

@ -46,3 +46,4 @@ build
data/settings/ data/settings/
dump.rdb

View file

@ -34,6 +34,7 @@
<w>outoutxyz</w> <w>outoutxyz</w>
<w>outouxyz</w> <w>outouxyz</w>
<w>pacman</w> <w>pacman</w>
<w>peeringdb</w>
<w>perso</w> <w>perso</w>
<w>postgre</w> <w>postgre</w>
<w>postgresql</w> <w>postgresql</w>

View file

@ -15,6 +15,7 @@ platforms = linux
packages = find_namespace: packages = find_namespace:
python_requires = >=3.9 python_requires = >=3.9
install_requires = install_requires =
asyncstdlib>=3.9.1
asyncpg>=0.21.0 asyncpg>=0.21.0
Babel>=2.8.0 Babel>=2.8.0
discord.py @ git+https://github.com/Rapptz/discord.py discord.py @ git+https://github.com/Rapptz/discord.py

View file

@ -1,4 +1,5 @@
from discord.ext import commands from discord.ext import commands
from discord.ext.commands import Context
def _(x): def _(x):
@ -6,7 +7,7 @@ def _(x):
class IPConverter(commands.Converter): class IPConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613 async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
argument = argument.replace("http://", "").replace("https://", "") argument = argument.replace("http://", "").replace("https://", "")
argument = argument.rstrip("/") argument = argument.rstrip("/")
@ -17,7 +18,7 @@ class IPConverter(commands.Converter):
class DomainConverter(commands.Converter): class DomainConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613 async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
if not argument.startswith("http"): if not argument.startswith("http"):
return f"http://{argument}" return f"http://{argument}"
@ -25,13 +26,18 @@ class DomainConverter(commands.Converter):
class QueryTypeConverter(commands.Converter): class QueryTypeConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613 async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
return argument.lower() return argument.lower()
class IPVersionConverter(commands.Converter): class IPVersionConverter(commands.Converter):
async def convert(self, ctx, argument): # skipcq: PYL-W0613 async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
if not argument: if not argument:
return argument return argument
return argument.replace("-", "").replace("ip", "").replace("v", "") return argument.replace("-", "").replace("ip", "").replace("v", "")
class ASConverter(commands.Converter):
async def convert(self, ctx: Context, argument: str): # skipcq: PYL-W0613
return argument.lower().lstrip("as")

View file

@ -23,3 +23,7 @@ class InvalidQueryType(NetworkException):
class VersionNotFound(NetworkException): class VersionNotFound(NetworkException):
pass pass
class InvalidAsn(NetworkException):
pass

View file

@ -13,11 +13,15 @@ from ipinfo.exceptions import RequestQuotaExceededError
from ipwhois import Net from ipwhois import Net
from ipwhois.asn import IPASN from ipwhois.asn import IPASN
from aiocache import cached
from aiocache.serializers import PickleSerializer
from tuxbot.cogs.Network.functions.exceptions import ( from tuxbot.cogs.Network.functions.exceptions import (
VersionNotFound, VersionNotFound,
RFC18, RFC18,
InvalidIp, InvalidIp,
InvalidQueryType, InvalidQueryType,
InvalidAsn,
) )
@ -25,7 +29,8 @@ def _(x):
return x return x
def get_ip(ip: str, inet: str = "") -> str: @cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_ip(loop, ip: str, inet: str = "") -> str:
_inet: socket.AddressFamily | int = 0 # pylint: disable=no-member _inet: socket.AddressFamily | int = 0 # pylint: disable=no-member
if inet == "6": if inet == "6":
@ -33,8 +38,9 @@ def get_ip(ip: str, inet: str = "") -> str:
elif inet == "4": elif inet == "4":
_inet = socket.AF_INET _inet = socket.AF_INET
def _get_ip(_ip: str):
try: try:
return socket.getaddrinfo(str(ip), None, _inet)[1][4][0] return socket.getaddrinfo(_ip, None, _inet)[1][4][0]
except socket.gaierror as e: except socket.gaierror as e:
raise VersionNotFound( raise VersionNotFound(
_( _(
@ -43,7 +49,10 @@ def get_ip(ip: str, inet: str = "") -> str:
) )
) from e ) from e
return await loop.run_in_executor(None, _get_ip, str(ip))
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_hostname(loop, ip: str) -> str: async def get_hostname(loop, ip: str) -> str:
def _get_hostname(_ip: str): def _get_hostname(_ip: str):
try: try:
@ -62,6 +71,7 @@ async def get_hostname(loop, ip: str) -> str:
return "N/A" return "N/A"
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict: async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict:
def _get_ipwhois_result(_ip_address: str) -> NoReturn | dict: def _get_ipwhois_result(_ip_address: str) -> NoReturn | dict:
try: try:
@ -87,6 +97,7 @@ async def get_ipwhois_result(loop, ip_address: str) -> NoReturn | dict:
return {} return {}
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_ipinfo_result(apikey: str, ip_address: str) -> dict: async def get_ipinfo_result(apikey: str, ip_address: str) -> dict:
try: try:
handler = ipinfo.getHandlerAsync( handler = ipinfo.getHandlerAsync(
@ -97,6 +108,7 @@ async def get_ipinfo_result(apikey: str, ip_address: str) -> dict:
return {} return {}
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_crimeflare_result( async def get_crimeflare_result(
session: aiohttp.ClientSession, ip_address: str session: aiohttp.ClientSession, ip_address: str
) -> Optional[str]: ) -> Optional[str]:
@ -149,11 +161,13 @@ def merge_ipinfo_ipwhois(ipinfo_result: dict, ipwhois_result: dict) -> dict:
return output return output
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_pydig_result( async def get_pydig_result(
domain: str, query_type: str, dnssec: str | bool loop, domain: str, query_type: str, dnssec: str | bool
) -> list: ) -> list:
additional_args = [] if dnssec is False else ["+dnssec"] additional_args = [] if dnssec is False else ["+dnssec"]
def _get_pydig_result(_domain: str) -> NoReturn | dict:
resolver = pydig.Resolver( resolver = pydig.Resolver(
nameservers=[ nameservers=[
"80.67.169.40", "80.67.169.40",
@ -162,7 +176,59 @@ async def get_pydig_result(
additional_args=additional_args, additional_args=additional_args,
) )
return resolver.query(domain, query_type) return resolver.query(_domain, query_type)
try:
return await asyncio.wait_for(
loop.run_in_executor(None, _get_pydig_result, str(domain)),
timeout=0.500,
)
except asyncio.exceptions.TimeoutError:
return []
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_peeringdb_as_set_result(
session: aiohttp.ClientSession, asn: str
) -> Optional[dict]:
try:
async with session.get(
f"https://www.peeringdb.com/api/as_set/{asn}",
timeout=aiohttp.ClientTimeout(total=5),
) as s:
return await s.json()
except (
aiohttp.ClientError,
aiohttp.ContentTypeError,
asyncio.exceptions.TimeoutError,
):
pass
return None
@cached(ttl=15 * 60, serializer=PickleSerializer())
async def get_peeringdb_net_irr_as_set_result(
session: aiohttp.ClientSession, asn: str
) -> Optional[dict]:
try:
async with session.get(
f"https://www.peeringdb.com/api/net?irr_as_set={asn}",
timeout=aiohttp.ClientTimeout(total=10),
) as s:
json = await s.json()
for data in json:
if data["asn"] == int(asn):
return data
except (
aiohttp.ClientError,
aiohttp.ContentTypeError,
asyncio.exceptions.TimeoutError,
):
pass
return None
def check_ip_version_or_raise(version: str) -> bool | NoReturn: def check_ip_version_or_raise(version: str) -> bool | NoReturn:
@ -194,3 +260,10 @@ def check_query_type_or_raise(query_type: str) -> bool | NoReturn:
"Supported queries : A, AAAA, CNAME, NS, DS, DNSKEY, SOA, TXT, PTR, MX" "Supported queries : A, AAAA, CNAME, NS, DS, DNSKEY, SOA, TXT, PTR, MX"
) )
) )
def check_asn_or_raise(asn: str) -> bool | NoReturn:
if asn.isdigit() and int(asn) < 4_294_967_295:
return True
raise InvalidAsn(_("Invalid ASN provided"))

View file

@ -15,6 +15,7 @@ from tuxbot.cogs.Network.functions.converters import (
IPVersionConverter, IPVersionConverter,
DomainConverter, DomainConverter,
QueryTypeConverter, QueryTypeConverter,
ASConverter,
) )
from tuxbot.cogs.Network.functions.exceptions import ( from tuxbot.cogs.Network.functions.exceptions import (
RFC18, RFC18,
@ -22,6 +23,7 @@ from tuxbot.cogs.Network.functions.exceptions import (
VersionNotFound, VersionNotFound,
InvalidDomain, InvalidDomain,
InvalidQueryType, InvalidQueryType,
InvalidAsn,
) )
from tuxbot.core.bot import Tux from tuxbot.core.bot import Tux
from tuxbot.core.i18n import ( from tuxbot.core.i18n import (
@ -37,13 +39,16 @@ from .config import NetworkConfig
from .functions.utils import ( from .functions.utils import (
get_ip, get_ip,
get_hostname, get_hostname,
get_crimeflare_result,
get_ipinfo_result, get_ipinfo_result,
get_ipwhois_result, get_ipwhois_result,
merge_ipinfo_ipwhois,
get_pydig_result, get_pydig_result,
# get_peeringdb_as_set_result,
# get_peeringdb_net_irr_as_set_result,
merge_ipinfo_ipwhois,
check_query_type_or_raise, check_query_type_or_raise,
check_ip_version_or_raise, check_ip_version_or_raise,
get_crimeflare_result, check_asn_or_raise,
) )
log = logging.getLogger("tuxbot.cogs.Network") log = logging.getLogger("tuxbot.cogs.Network")
@ -68,6 +73,7 @@ class Network(commands.Cog):
InvalidDomain, InvalidDomain,
InvalidQueryType, InvalidQueryType,
VersionNotFound, VersionNotFound,
InvalidAsn,
), ),
): ):
await ctx.send(_(str(error), ctx, self.bot.config)) await ctx.send(_(str(error), ctx, self.bot.config))
@ -87,9 +93,8 @@ class Network(commands.Cog):
): ):
check_ip_version_or_raise(str(version)) check_ip_version_or_raise(str(version))
ip_address = await self.bot.loop.run_in_executor( ip_address = await get_ip(self.bot.loop, str(ip), str(version))
None, get_ip, str(ip), str(version)
)
ip_hostname = await get_hostname(self.bot.loop, str(ip_address)) ip_hostname = await get_hostname(self.bot.loop, str(ip_address))
ipinfo_result = await get_ipinfo_result( ipinfo_result = await get_ipinfo_result(
@ -222,7 +227,7 @@ class Network(commands.Cog):
check_query_type_or_raise(str(query_type)) check_query_type_or_raise(str(query_type))
pydig_result = await get_pydig_result( pydig_result = await get_pydig_result(
str(domain), str(query_type), dnssec self.bot.loop, str(domain), str(query_type), dnssec
) )
e = discord.Embed(title=f"DIG {domain} {query_type}", color=0x5858D7) e = discord.Embed(title=f"DIG {domain} {query_type}", color=0x5858D7)
@ -285,3 +290,26 @@ class Network(commands.Cog):
domain domain
) )
) )
@command_extra(
name="peeringdb", aliases=["peer", "peering"], deletable=True
)
async def _peeringdb(self, ctx: ContextPlus, asn: ASConverter):
check_asn_or_raise(str(asn))
return await ctx.send("Not implemented yet")
# peeringdb_as_set_result = await get_peeringdb_as_set_result(
# self.bot.session, str(asn)
# )
# peeringdb_net_irr_as_set_result = (
# await get_peeringdb_net_irr_as_set_result(
# self.bot.session, peeringdb_as_set_result["data"][0][asn]
# )
# )["data"]
#
# data = peeringdb_net_irr_as_set_result
#
# self.bot.console.log(data)
#
# await ctx.send("done")

View file

@ -90,8 +90,6 @@ class Tux(commands.AutoShardedBot):
self._app_owners_fetched = False # to prevent abusive API calls self._app_owners_fetched = False # to prevent abusive API calls
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.before_invoke(self._typing)
super().__init__( super().__init__(
*args, *args,
intents=discord.Intents.all(), intents=discord.Intents.all(),
@ -121,10 +119,6 @@ class Tux(commands.AutoShardedBot):
return False return False
@staticmethod
async def _typing(ctx: ContextPlus) -> None:
await ctx.trigger_typing()
async def load_packages(self): async def load_packages(self):
if packages: if packages:
with Progress() as progress: with Progress() as progress: