213 lines
6.6 KiB
Python
213 lines
6.6 KiB
Python
|
# Copyright © 2018–2019 Io Mintz <io@mintz.cc>
|
|||
|
|
|||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|||
|
# of this software and associated documentation files (the “Software”),
|
|||
|
# to deal in the Software without restriction, including without limitation the
|
|||
|
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
|||
|
# sell copies of the Software, and to permit persons to whom the Software is
|
|||
|
# furnished to do so, subject to the following conditions:
|
|||
|
|
|||
|
# The above copyright notice and this permission notice shall be included in
|
|||
|
# all copies or substantial portions of the Software.
|
|||
|
|
|||
|
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|||
|
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
|||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|||
|
# THE SOFTWARE.
|
|||
|
|
|||
|
import ast
import contextlib
import sys
from collections import namedtuple

from .constants import *
|
|||
|
|
|||
|
def parse_ast(root_node, **kwargs):
    """Rewrite the import expressions in *root_node* and return the new tree.

    *kwargs* are forwarded to ``Transformer`` (which accepts ``filename`` and
    ``source``). Nodes created by the transformer lack position info, so the
    result is passed through ``ast.fix_missing_locations`` before returning.
    """
    transformer = Transformer(**kwargs)
    new_root = transformer.visit(root_node)
    return ast.fix_missing_locations(new_root)
|
|||
|
|
|||
|
def remove_string_right(haystack, needle):
    """Return *haystack* with *needle* removed, but only when it is a suffix.

    If *needle* occurs only earlier in *haystack*, or not at all, *haystack*
    is returned unchanged. An empty *needle* raises ValueError (from
    ``str.rpartition``).
    """
    head, _, tail = haystack.rpartition(needle)
    # rpartition puts everything after the LAST occurrence into `tail`, so an
    # empty tail means the needle ended the string (or haystack was empty).
    # A non-empty tail means the needle is absent or not at the end.
    return haystack if tail else head
|
|||
|
|
|||
|
def remove_import_op(name):
    """Strip a trailing import marker from *name*; otherwise return it unchanged."""
    return remove_string_right(haystack=name, needle=MARKER)
|
|||
|
def has_any_import_op(name):
    """Report whether *name* contains the import marker at any position."""
    return name.find(MARKER) != -1
|
|||
|
def has_invalid_import_op(name):
    """Report whether *name* misuses the import marker.

    A trailing marker (if any) is stripped first. The name is invalid if
    nothing is left (e.g. it was only the marker) or if a marker still
    remains somewhere else in it.
    """
    without_suffix = remove_import_op(name)
    return not without_suffix or MARKER in without_suffix
|
|||
|
def has_valid_import_op(name):
    """Return *name* without its trailing import marker, or False if it has none.

    Callers use the result both as a flag and as the stripped name. Note that
    it is '' (also falsy) when *name* consists of the marker alone.
    """
    if not name.endswith(MARKER):
        return False
    return remove_import_op(name)
|
|||
|
|
|||
|
# Location details handed to SyntaxError as its second argument, in the order
# SyntaxError expects them: (filename, lineno, offset, text).
SyntaxErrorContext = namedtuple(
    'SyntaxErrorContext',
    ['filename', 'lineno', 'column', 'line'],
)
|
|||
|
|
|||
|
class Transformer(ast.NodeTransformer):
    """Rewrite import expressions in a parsed tree into importer calls.

    A Name or Attribute whose identifier ends with ``MARKER`` (presumably the
    parsed stand-in for ``IMPORT_OP``; see the constants module) is replaced
    by ``IMPORTER("<dotted.path>")``. The operator anywhere else raises a
    ``SyntaxError`` built by ``_syntax_error``. That covers a marker in the
    middle of an identifier, in a function/class name, in an argument name,
    or in an import statement.
    """

    # String-literal node types. Their str fields hold data rather than
    # identifiers, so they are skipped when scanning for import operators.
    # Otherwise code such as ''.join(x) would be rejected, because an empty
    # string counts as invalid for has_invalid_import_op. String literals
    # parse as ast.Str before Python 3.8 and as ast.Constant afterwards. From
    # 3.8 on, ast.Str is deprecated, so it is only referenced on old versions.
    if sys.version_info < (3, 8):
        _LITERAL_NODE_TYPES = (ast.Constant, ast.Str)
    else:
        _LITERAL_NODE_TYPES = (ast.Constant,)

    def __init__(self, *, filename=None, source=None):
        """
        filename: reported as the file name of raised SyntaxErrors.
        source: full source text; when given, the offending line is attached
            to raised SyntaxErrors.
        """
        self.filename = filename
        self.source_lines = source.splitlines() if source is not None else None

    def visit_Attribute(self, node):
        """
        convert Attribute nodes containing import expressions into import calls;
        other Attributes are kept, with their left hand side transformed
        """
        self._ensure_only_valid_import_ops(node)

        import_call = self._transform_attribute_attr(node)
        if import_call is not None:
            return import_call

        # e.g. `a!.b`: the attribute itself is plain, but its value may not be
        node.value = self.visit(node.value)
        return node

    def visit_Name(self, node):
        """convert solitary Names that have import expressions, such as "a!", into import calls"""
        self._ensure_only_valid_import_ops(node)

        module_name = has_valid_import_op(node.id)
        if not module_name:
            return node
        return ast.copy_location(self._import_call(module_name, node.ctx), node)

    @staticmethod
    def _import_call(attribute_source, ctx):
        """build the call IMPORTER("<attribute_source>"); ctx is used for the IMPORTER name"""
        return ast.Call(
            func=ast.Name(id=IMPORTER, ctx=ctx),
            # ast.Constant rather than ast.Str: Str is deprecated since
            # Python 3.8 and removed in 3.14, while Constant exists since 3.6
            args=[ast.Constant(value=attribute_source)],
            keywords=[])

    def _transform_attribute_attr(self, node):
        """convert an Attribute whose attr ends in an import op into an import call

        returns None if node.attr has no valid import op. Otherwise node.attr
        is stripped of the op in place, and a call importing the whole dotted
        path is returned.
        """
        stripped_attr = has_valid_import_op(node.attr)
        if not stripped_attr:
            return None

        node.attr = stripped_attr
        dotted_path = self.attribute_source(node)
        return ast.copy_location(self._import_call(dotted_path, node.ctx), node)

    def attribute_source(self, node: ast.Attribute, _seen_import_op=False):
        """return a source-code representation (e.g. "a.b.c") of an Attribute/Name chain

        import ops are stripped from each component. Raises SyntaxError in two
        cases: a further import op appears in the chain, or a component is
        neither a Name nor an Attribute (e.g. the call in `f().x!`).
        """
        if not isinstance(node, (ast.Attribute, ast.Name)):
            # only dotted names can be imported; the Name/Attribute helpers
            # below cannot handle any other node type
            raise self._syntax_error(
                f'"{IMPORT_OP}" only allowed after a dotted name',
                node
            ) from None

        if self._has_valid_import_op(node):
            _seen_import_op = True

        stripped = self._remove_import_op(node)
        if isinstance(node, ast.Name):
            # the chain always ends in a Name, so any inner import op is
            # reported here
            if _seen_import_op:
                raise self._syntax_error('multiple import expressions not allowed', node) from None
            return stripped

        return self.attribute_source(node.value, _seen_import_op) + '.' + stripped

    def visit_def_(self, node):
        """reject import ops in class/function names; otherwise visit the definition's children"""
        if not has_any_import_op(node.name):
            # it's valid so far, just ensure that arguments and body are also visited
            return self.generic_visit(node)

        type_name = 'class' if isinstance(node, ast.ClassDef) else 'function'
        raise self._syntax_error(
            f'"{IMPORT_OP}" not allowed in the name of a {type_name}',
            node
        ) from None

    visit_FunctionDef = visit_def_
    visit_AsyncFunctionDef = visit_def_
    visit_ClassDef = visit_def_

    def visit_arg(self, node):
        """ensure foo(x! = 1) or def foo(x!) does not occur (also used for keyword nodes)"""
        # keyword.arg is None for `**kwargs` in a call
        if node.arg is not None and has_any_import_op(node.arg):
            raise self._syntax_error(
                f'"{IMPORT_OP}" not allowed in function arguments',
                node
            ) from None

        # arguments may have import expressions as children
        # (annotations, keyword argument values)
        return super().generic_visit(node)

    def visit_keyword(self, node):
        """same check as visit_arg; children are visited exactly once"""
        # visit_arg already runs generic_visit. Visiting the children again
        # here would double the work at every level of nested keyword calls.
        return self.visit_arg(node)

    def visit_alias(self, node):
        # from x import y **as z**
        self._ensure_no_import_ops(node)
        return node

    def visit_ImportFrom(self, node):
        self._ensure_no_import_ops(node)
        # ImportFrom nodes can have alias children that we also need to check
        return super().generic_visit(node)

    def _ensure_only_valid_import_ops(self, node):
        """raise SyntaxError if node or a descendant has an identifier with a misplaced import op"""
        offender = self._find_child_node_string(has_invalid_import_op, node)
        if offender is not None:
            raise self._syntax_error(
                f'"{IMPORT_OP}" only allowed at end of attribute name',
                self._error_location(offender, node)
            ) from None

    def _ensure_no_import_ops(self, node):
        """raise SyntaxError if node or a descendant has any import op in an identifier"""
        offender = self._find_child_node_string(has_any_import_op, node)
        if offender is not None:
            raise self._syntax_error(
                'import expressions are only allowed in variables and attributes',
                self._error_location(offender, node)
            ) from None

    @staticmethod
    def _error_location(offender, fallback):
        """prefer the offending node's position for error reports

        some node types (e.g. comprehension) carry no position, so fall back
        to the node that was being checked
        """
        return offender if getattr(offender, 'lineno', None) is not None else fallback

    @classmethod
    def _find_child_node_string(cls, predicate, node):
        """return the first node in node's subtree (node included) with a str field
        satisfying predicate, or None; string literals are skipped"""
        for child_node in ast.walk(node):
            if isinstance(child_node, cls._LITERAL_NODE_TYPES):
                continue
            if cls._for_any_node_string(predicate, child_node):
                return child_node
        return None

    @classmethod
    def _for_any_child_node_string(cls, predicate, node):
        """return True if any node in node's subtree (string literals excluded)
        has a str field satisfying predicate"""
        return cls._find_child_node_string(predicate, node) is not None

    @staticmethod
    def _for_any_node_string(predicate, node):
        """return True if any of node's own str-valued fields satisfies predicate"""
        return any(
            isinstance(value, str) and predicate(value)
            for _, value in ast.iter_fields(node))

    def _call_on_name_or_attribute(func):
        # class-body helper (deleted below): adapts a function of a string so it
        # can be applied to an Attribute's attr or a Name's id
        def checker(node):
            if isinstance(node, ast.Attribute):
                to_check = node.attr
            elif isinstance(node, ast.Name):
                to_check = node.id
            else:
                raise TypeError(f'expected an Attribute or Name node, got {type(node).__name__}')
            return func(to_check)

        return staticmethod(checker)

    _has_valid_import_op = _call_on_name_or_attribute(has_valid_import_op)
    _remove_import_op = _call_on_name_or_attribute(remove_import_op)

    del _call_on_name_or_attribute

    def _syntax_error(self, message, node):
        """build (but do not raise) a SyntaxError located at node"""
        lineno = getattr(node, 'lineno', None)
        col_offset = getattr(node, 'col_offset', None)
        # ast col_offset is 0-based, but SyntaxError.offset is 1-based
        column = col_offset + 1 if col_offset is not None else None
        line = None
        if self.source_lines is not None and lineno:
            with contextlib.suppress(IndexError):
                line = self.source_lines[lineno - 1]
        ctx = SyntaxErrorContext(filename=self.filename, lineno=lineno, column=column, line=line)
        return SyntaxError(message, ctx)
|