tuxbot-bot/venv/lib/python3.7/site-packages/import_expression/_parser.py
2019-12-16 18:12:10 +01:00

212 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright © 2018–2019 Io Mintz <io@mintz.cc>
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import ast
import contextlib
from collections import namedtuple
from .constants import *
def parse_ast(root_node, **kwargs):
    """Rewrite the import expressions in *root_node* and fill in missing locations.

    Keyword arguments (``filename``, ``source``) are forwarded unchanged to
    ``Transformer``; they only enrich any SyntaxError it raises.
    """
    transformer = Transformer(**kwargs)
    rewritten = transformer.visit(root_node)
    return ast.fix_missing_locations(rewritten)
def remove_string_right(haystack, needle):
    """Strip *needle* from the end of *haystack* if it is a suffix.

    When *needle* does not occur, or its last occurrence is not at the very
    end, *haystack* is returned unchanged. Only the last occurrence is removed.
    """
    before, _separator, after = haystack.rpartition(needle)
    # An empty tail means the final occurrence of needle was the suffix.
    return haystack if after else before
def remove_import_op(name):
    """Return *name* with a single trailing MARKER removed, if it ends with one."""
    return remove_string_right(haystack=name, needle=MARKER)
def has_any_import_op(name):
    """Return True if MARKER occurs anywhere in *name*, in any position."""
    return MARKER in name
def has_invalid_import_op(name):
    """Return True if *name* uses MARKER in a way that is not allowed.

    This happens when a marker is still present after stripping a trailing one
    (a marker that is not at the end, or more than one marker). It also
    happens when nothing is left after stripping (the name was only a marker).
    """
    stripped = remove_import_op(name)
    return not stripped or MARKER in stripped
def has_valid_import_op(name):
    """Return the stripped name if *name* ends with MARKER, otherwise False.

    The stripped name is the truthy "yes" answer that callers use directly.
    """
    if not name.endswith(MARKER):
        return False
    return remove_import_op(name)
# Location details passed as the second argument of SyntaxError, in the
# (filename, lineno, offset, text) order that SyntaxError unpacks.
SyntaxErrorContext = namedtuple(
    'SyntaxErrorContext',
    ['filename', 'lineno', 'column', 'line'],
)
class Transformer(ast.NodeTransformer):
    """Rewrite import expressions in an AST into calls to the importer.

    A Name or Attribute whose identifier ends with ``MARKER`` (presumably the
    tokenizer's stand-in for ``IMPORT_OP`` -- confirm in the tokenizer) becomes
    ``IMPORTER("dotted.name")``. Any other placement of the marker raises
    ``SyntaxError``: in class/function names, argument or keyword names, import
    statements, or anywhere except the end of an identifier.
    """

    # AST node types whose string fields hold literal data, not identifiers.
    # ast.Str is what Python 3.7's parser produces; it is deprecated since 3.8
    # and gone in later versions, hence the hasattr guard.
    _STRING_LITERAL_NODES = tuple(
        getattr(ast, node_name)
        for node_name in ('Str', 'Constant')
        if hasattr(ast, node_name)
    )

    def __init__(self, *, filename=None, source=None):
        """
        :param filename: reported in SyntaxErrors raised by this transformer
        :param source: source text; when given, the offending line is included in SyntaxErrors
        """
        self.filename = filename
        self.source_lines = source.splitlines() if source is not None else None

    def visit_Attribute(self, node):
        """
        convert Attribute nodes containing import expressions into Attribute nodes containing import calls
        """
        self._ensure_only_valid_import_ops(node)

        maybe_transformed = self._transform_attribute_attr(node)
        if maybe_transformed:
            return maybe_transformed

        # Not an import itself, but the left-hand side may contain one (e.g. "a!.b").
        transformed_lhs = self.visit(node.value)
        return ast.copy_location(
            ast.Attribute(
                value=transformed_lhs,
                ctx=node.ctx,
                attr=node.attr),
            node)

    def visit_Name(self, node):
        """convert solitary Names that have import expressions, such as "a!", into import calls"""
        self._ensure_only_valid_import_ops(node)

        module_name = has_valid_import_op(node.id)
        if module_name:
            return ast.copy_location(self._import_call(module_name, node.ctx), node)
        return node

    @staticmethod
    def _import_call(attribute_source, ctx):
        """build the AST for ``IMPORTER("<attribute_source>")``"""
        return ast.Call(
            func=ast.Name(id=IMPORTER, ctx=ctx),
            args=[ast.Str(attribute_source)],
            keywords=[])

    def _transform_attribute_attr(self, node):
        """convert an Attribute node whose attr ends with an import op into an import call

        Returns None when node.attr has no valid import op.
        Note: on success node.attr is overwritten in place with the stripped name.
        """
        attr = has_valid_import_op(node.attr)
        if not attr:
            return None

        node.attr = attr
        as_source = self.attribute_source(node)
        return ast.copy_location(
            self._import_call(as_source, node.ctx),
            node)

    def attribute_source(self, node: ast.Attribute, _seen_import_op=False):
        """return a source-code representation of an Attribute node

        Follows the ``value`` chain down to the leftmost Name. It strips the
        import operators and joins the parts with dots.

        Raises SyntaxError if an import operator appears below the node being
        converted, or if the chain starts with anything other than a Name
        (e.g. a call, as in ``foo().bar!``).
        """
        if type(node) not in (ast.Attribute, ast.Name):
            # No dotted module path can be built from e.g. a Call. The name/attr
            # helpers below cannot inspect such nodes either.
            raise self._syntax_error(
                'import expressions are only allowed in variables and attributes',
                node
            ) from None

        if self._has_valid_import_op(node):
            _seen_import_op = True

        stripped = self._remove_import_op(node)
        if type(node) is ast.Name:
            if _seen_import_op:
                raise self._syntax_error('multiple import expressions not allowed', node) from None
            return stripped

        lhs = self.attribute_source(node.value, _seen_import_op)
        return lhs + '.' + stripped

    def visit_def_(self, node):
        """ensure that class and (async) function names do not contain import ops"""
        if not has_any_import_op(node.name):
            # it's valid so far, just ensure that arguments and body are also visited
            return self.generic_visit(node)

        type_name = 'class' if isinstance(node, ast.ClassDef) else 'function'
        raise self._syntax_error(
            f'"{IMPORT_OP}" not allowed in the name of a {type_name}',
            node
        ) from None

    visit_FunctionDef = visit_def_
    # "async def" names need the same check as regular function names
    visit_AsyncFunctionDef = visit_def_
    visit_ClassDef = visit_def_

    def visit_arg(self, node):
        """ensure foo(x! = 1) or def foo(x!) does not occur"""
        # keyword.arg is None for **kwargs
        if node.arg is not None and has_any_import_op(node.arg):
            raise self._syntax_error(
                f'"{IMPORT_OP}" not allowed in function arguments',
                node
            ) from None

        # regular arguments may have import expr annotations as children
        return super().generic_visit(node)

    def visit_keyword(self, node):
        """ensure foo(x! = 1) does not occur, then visit the keyword's value"""
        # visit_arg does the name check AND visits the children (the value may
        # contain import expressions). Visiting them a second time here would
        # double the work at every level of nested keyword calls.
        return self.visit_arg(node)

    def visit_alias(self, node):
        # from x import y **as z**
        self._ensure_no_import_ops(node)
        return node

    def visit_ImportFrom(self, node):
        self._ensure_no_import_ops(node)
        # ImportFrom nodes can have alias children that we also need to check
        return super().generic_visit(node)

    def _ensure_only_valid_import_ops(self, node):
        """raise SyntaxError if node or a descendant uses an import op other than at the end of an identifier"""
        if self._for_any_child_node_string(has_invalid_import_op, node):
            raise self._syntax_error(
                f'"{IMPORT_OP}" only allowed at end of attribute name',
                node
            ) from None

    def _ensure_no_import_ops(self, node):
        """raise SyntaxError if node or a descendant contains any import op"""
        if self._for_any_child_node_string(has_any_import_op, node):
            raise self._syntax_error(
                'import expressions are only allowed in variables and attributes',
                node
            ) from None

    @classmethod
    def _for_any_child_node_string(cls, predicate, node):
        """return True if predicate holds for a string field of node or any of its descendants

        String-literal nodes are skipped. Their contents are data, not
        identifiers, and an empty literal (as in ``"".join``) would otherwise be
        reported as a bare import op by has_invalid_import_op.
        """
        for child_node in ast.walk(node):
            if isinstance(child_node, cls._STRING_LITERAL_NODES):
                continue
            if cls._for_any_node_string(predicate, child_node):
                return True
        return False

    @staticmethod
    def _for_any_node_string(predicate, node):
        """return True if predicate holds for any str-valued field of node itself"""
        for field, value in ast.iter_fields(node):
            if isinstance(value, str) and predicate(value):
                return True
        return False

    def _call_on_name_or_attribute(func):
        # Build a staticmethod that applies func to a Name's id or an Attribute's attr.
        # Only used while the class body executes; deleted below.
        def checker(node):
            if type(node) is ast.Attribute:
                to_check = node.attr
            elif type(node) is ast.Name:
                to_check = node.id
            else:
                raise TypeError(f'expected an Attribute or Name node, got {type(node).__name__}')
            return func(to_check)
        return staticmethod(checker)

    _has_valid_import_op = _call_on_name_or_attribute(has_valid_import_op)
    _remove_import_op = _call_on_name_or_attribute(remove_import_op)

    del _call_on_name_or_attribute

    def _syntax_error(self, message, node):
        """build (but do not raise) a SyntaxError pointing at node"""
        lineno = getattr(node, 'lineno', None)
        # NOTE(review): col_offset is 0-based, while SyntaxError.offset is
        # conventionally 1-based -- confirm whether callers compensate.
        column = getattr(node, 'col_offset', None)
        line = None
        if self.source_lines is not None and lineno:
            # lineno is 1-based; a synthetic/stale lineno may be out of range
            with contextlib.suppress(IndexError):
                line = self.source_lines[lineno - 1]

        ctx = SyntaxErrorContext(filename=self.filename, lineno=lineno, column=column, line=line)
        return SyntaxError(message, ctx)