# Copyright © 2018–2019 Io Mintz # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the “Software”), # to deal in the Software without restriction, including without limitation the # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or # sell copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
import ast
import contextlib
from collections import namedtuple

# MARKER, IMPORT_OP and IMPORTER used below are not defined in this module;
# they come from this star import.
from .constants import *


def parse_ast(root_node, **kwargs):
    """Rewrite the import expressions in *root_node* and return the new tree.

    Keyword arguments (``filename``, ``source``) are forwarded to
    :class:`Transformer`; they are only used to enrich any SyntaxError raised.
    Locations missing from synthesized nodes are filled in afterwards.
    """
    return ast.fix_missing_locations(Transformer(**kwargs).visit(root_node))


def remove_string_right(haystack, needle):
    """Return *haystack* with one trailing *needle* removed.

    If *haystack* does not end with *needle* it is returned unchanged.
    """
    left, _sep, right = haystack.rpartition(needle)
    if not right:
        # the last occurrence of needle is at the very end (or haystack is empty)
        return left
    # needle not found, or found somewhere other than the end
    return haystack


def remove_import_op(name):
    """Strip one trailing import marker from *name*, if present."""
    return remove_string_right(name, MARKER)


def has_any_import_op(name):
    """Return whether *name* contains the import marker anywhere."""
    return MARKER in name


def has_invalid_import_op(name):
    """Return whether *name* misuses the import marker.

    True when, after removing one trailing marker, a marker is still present
    (it appeared before the end) or nothing is left (the name was only a marker).
    """
    removed = remove_import_op(name)
    return MARKER in removed or not removed


def has_valid_import_op(name):
    """If *name* ends with the import marker, return it with that marker removed.

    Otherwise return False. The stripped result may itself be falsy (``''``) or
    still contain a marker, so callers check has_invalid_import_op separately.
    """
    return name.endswith(MARKER) and remove_import_op(name)


# Same field order as the (filename, lineno, offset, text) details tuple that
# SyntaxError accepts as its second argument.
SyntaxErrorContext = namedtuple('SyntaxErrorContext', 'filename lineno column line')


class Transformer(ast.NodeTransformer):
    """Rewrite import expressions (``name!``) into calls to ``IMPORTER``.

    A Name or dotted Attribute chain whose last identifier ends with the import
    marker becomes ``IMPORTER('dotted.name')``. Markers in the other positions
    this class checks raise SyntaxError.
    """

    def __init__(self, *, filename=None, source=None):
        # Both are only used by _syntax_error to give errors a file and line text.
        self.filename = filename
        self.source_lines = source.splitlines() if source is not None else None

    def visit_Attribute(self, node):
        """Convert an attribute chain ending in an import expression into an import call.

        ``a.b.c!`` becomes ``IMPORTER('a.b.c')``. Otherwise the left-hand side is
        visited, so ``a.b!.c`` becomes ``IMPORTER('a.b').c``.
        """
        self._ensure_only_valid_import_ops(node)
        transformed = self._transform_attribute_attr(node)
        if transformed is not None:
            return transformed
        transformed_lhs = self.visit(node.value)
        return ast.copy_location(
            ast.Attribute(value=transformed_lhs, attr=node.attr, ctx=node.ctx),
            node,
        )

    def visit_Name(self, node):
        """Convert a bare name with an import expression (``a!``) into an import call."""
        self._ensure_only_valid_import_ops(node)
        module_name = has_valid_import_op(node.id)
        if not module_name:
            return node
        self._ensure_loaded(node)
        return ast.copy_location(self._import_call(module_name, node.ctx), node)

    @staticmethod
    def _import_call(attribute_source, ctx):
        """Build the AST for ``IMPORTER('<attribute_source>')``.

        *ctx* is kept for backwards compatibility but ignored. The callee name is
        always loaded; a Store/Del context here would produce an invalid AST.
        """
        return ast.Call(
            func=ast.Name(id=IMPORTER, ctx=ast.Load()),
            # ast.Constant replaces ast.Str (deprecated in 3.8, removed in 3.14)
            args=[ast.Constant(value=attribute_source)],
            keywords=[],
        )

    def _transform_attribute_attr(self, node):
        """Replace *node* with an import call if its attr ends in the import marker.

        Returns None, leaving *node* untouched, when the attr has no import expression.
        """
        stripped_attr = has_valid_import_op(node.attr)
        if not stripped_attr:
            return None
        self._ensure_loaded(node)
        # Strip the marker before building the source: attribute_source reads
        # node.attr, and this keeps the outermost component out of its
        # "multiple import expressions" count.
        node.attr = stripped_attr
        source = self.attribute_source(node)
        return ast.copy_location(self._import_call(source, node.ctx), node)

    def attribute_source(self, node: ast.expr, _seen_import_op=False):
        """Return the dotted source text (e.g. ``'a.b.c'``) of a Name/Attribute chain.

        Import markers are stripped from each component. Raises SyntaxError if
        any of the following holds:

        - a component is not a plain name (``f().x!``);
        - a component has a misplaced marker;
        - more than one inner component carries an import expression.
        """
        if not isinstance(node, (ast.Name, ast.Attribute)):
            raise self._syntax_error(
                f'"{IMPORT_OP}" is only allowed at the end of a dotted name',
                node,
            ) from None
        # The chain being converted is never visited by the transformer, so each
        # link must be validated here.
        self._ensure_only_valid_import_ops(node)
        if self._has_valid_import_op(node):
            _seen_import_op = True
        stripped = self._remove_import_op(node)
        if isinstance(node, ast.Name):
            if _seen_import_op:
                raise self._syntax_error('multiple import expressions not allowed', node) from None
            return stripped
        return self.attribute_source(node.value, _seen_import_op) + '.' + stripped

    def _ensure_loaded(self, node):
        """Reject import expressions used as assignment or ``del`` targets."""
        if not isinstance(node.ctx, ast.Load):
            raise self._syntax_error(
                'cannot assign to or delete an import expression', node
            ) from None

    def visit_def_(self, node):
        """Reject import markers in class/function names, then visit the rest."""
        if not has_any_import_op(node.name):
            # The name is valid; arguments, decorators and body may still
            # contain import expressions.
            return self.generic_visit(node)
        type_name = 'class' if isinstance(node, ast.ClassDef) else 'function'
        raise self._syntax_error(
            f'"{IMPORT_OP}" not allowed in the name of a {type_name}', node
        ) from None

    visit_FunctionDef = visit_def_
    visit_AsyncFunctionDef = visit_def_
    visit_ClassDef = visit_def_

    def visit_arg(self, node):
        """Reject markers in parameter and keyword-argument names.

        Covers ``def foo(x!)`` and ``foo(x! = 1)``. Children (annotations,
        keyword values) are still visited because they may hold import
        expressions.
        """
        # keyword.arg is None for ``**kwargs`` unpacking
        if node.arg is not None and has_any_import_op(node.arg):
            raise self._syntax_error(
                f'"{IMPORT_OP}" not allowed in function arguments', node
            ) from None
        return self.generic_visit(node)

    def visit_keyword(self, node):
        # visit_arg already visits the keyword's value. Visiting it again here
        # would double the work at every level of nested keyword arguments.
        return self.visit_arg(node)

    def visit_alias(self, node):
        # names in `import x as y` / `from x import y as z` may not be import expressions
        self._ensure_no_import_ops(node)
        return node

    def visit_ImportFrom(self, node):
        """Reject import markers anywhere in an import statement, then visit its aliases."""
        # Checking the whole statement reports the error at the statement's
        # location; alias nodes have no location on older Python versions.
        self._ensure_no_import_ops(node)
        return self.generic_visit(node)

    visit_Import = visit_ImportFrom

    def _ensure_only_valid_import_ops(self, node):
        """Raise SyntaxError if an identifier of *node* itself misplaces the marker.

        Only the node's own string fields are checked. Descendants are checked
        when they are visited. Descending here would also inspect string
        literals (e.g. the ``''`` in ``''.join``), which are not identifiers.
        """
        if self._for_any_node_string(has_invalid_import_op, node):
            raise self._syntax_error(
                f'"{IMPORT_OP}" only allowed at end of attribute name', node
            ) from None

    def _ensure_no_import_ops(self, node):
        """Raise SyntaxError if *node* or any descendant contains the import marker."""
        if self._for_any_child_node_string(has_any_import_op, node):
            raise self._syntax_error(
                'import expressions are only allowed in variables and attributes', node
            ) from None

    @classmethod
    def _for_any_child_node_string(cls, predicate, node):
        """Return whether *predicate* holds for a str field of *node* or any descendant."""
        return any(
            cls._for_any_node_string(predicate, child) for child in ast.walk(node)
        )

    @staticmethod
    def _for_any_node_string(predicate, node):
        """Return whether *predicate* holds for any str-valued field of *node* itself."""
        return any(
            isinstance(value, str) and predicate(value)
            for _field, value in ast.iter_fields(node)
        )

    def _call_on_name_or_attribute(func):
        # Class-body helper, deleted below. Wraps a str -> X function so it is
        # applied to a Name's id or an Attribute's attr.
        def checker(node):
            if isinstance(node, ast.Attribute):
                return func(node.attr)
            if isinstance(node, ast.Name):
                return func(node.id)
            raise TypeError(
                f'expected ast.Name or ast.Attribute, got {type(node).__name__}'
            )

        return staticmethod(checker)

    _has_valid_import_op = _call_on_name_or_attribute(has_valid_import_op)
    _remove_import_op = _call_on_name_or_attribute(remove_import_op)
    del _call_on_name_or_attribute

    def _syntax_error(self, message, node):
        """Build (but do not raise) a SyntaxError located at *node*."""
        lineno = getattr(node, 'lineno', None)
        col_offset = getattr(node, 'col_offset', None)
        # ast col_offset is 0-based; SyntaxError's offset is 1-based.
        column = col_offset + 1 if col_offset is not None else None
        line = None
        if self.source_lines is not None and lineno:
            with contextlib.suppress(IndexError):
                line = self.source_lines[lineno - 1]
        ctx = SyntaxErrorContext(
            filename=self.filename, lineno=lineno, column=column, line=line
        )
        return SyntaxError(message, ctx)