# testing/assertsql.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
#
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Assertion rules for checking the SQL statements emitted on an engine.

:func:`assert_engine` records every execution observed on an engine;
the resulting :class:`SQLAsserter` is then checked against a series of
:class:`AssertRule` objects.
"""

import collections
import contextlib
import itertools
import re

from .. import event
from .. import util
from ..engine import url
from ..engine.default import DefaultDialect
from ..engine.util import _distill_params
from ..schema import _DDLCompiles


class AssertRule(object):
    """Base class for a rule that is matched against observed executions."""

    # set to True once the rule has matched and needs no more statements
    is_consumed = False

    # message describing the most recent failed match, or None
    errormessage = None

    # when True, EachOf stops feeding the current observed execution to
    # subsequent rules after this rule has processed it
    consume_statement = True

    def process_statement(self, execute_observed):
        """Examine one observed execution; the base rule does nothing."""
        pass

    def no_more_statements(self):
        """Fail, as statements ran out while this rule was still pending."""
        assert False, (
            "All statements are complete, but pending "
            "assertion rules remain"
        )


class SQLMatchRule(AssertRule):
    """Marker base class for rules that match a SQL statement."""

    pass


class CursorSQL(SQLMatchRule):
    """Match the exact statement string passed to the cursor.

    The rule compares against the first entry in
    ``execute_observed.statements`` and removes that entry when it
    matches.
    """

    # one observed execution may carry several cursor statements, so by
    # default the execution is left available for the following rule
    consume_statement = False

    def __init__(self, statement, params=None):
        """
        :param statement: exact statement string expected at the cursor.
        :param params: expected cursor parameters; ``None`` skips the
         parameter comparison.
        """
        self.statement = statement
        self.params = params

    def process_statement(self, execute_observed):
        """Compare against the first pending cursor statement."""
        stmt = execute_observed.statements[0]
        if self.statement != stmt.statement or (
            self.params is not None and self.params != stmt.parameters
        ):
            self.errormessage = (
                "Testing for exact SQL %s parameters %s received %s %s"
                % (
                    self.statement,
                    self.params,
                    stmt.statement,
                    stmt.parameters,
                )
            )
        else:
            execute_observed.statements.pop(0)
            self.is_consumed = True
            if not execute_observed.statements:
                # nothing left in this execution; let EachOf move on
                # to the next observed execution
                self.consume_statement = True


class CompiledSQL(SQLMatchRule):
    """Match a statement as recompiled against a chosen dialect.

    Newlines and tabs are stripped from both the expected and the
    received statement before comparison.
    """

    def __init__(self, statement, params=None, dialect="default"):
        """
        :param statement: expected SQL string.
        :param params: expected parameters; a dictionary, a list of
         dictionaries, or a callable receiving the execution context and
         returning either.  A false value skips parameter comparison.
        :param dialect: name of the dialect to compile against;
         ``"default"`` selects :class:`.DefaultDialect`.
        """
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if the received statement equals the expected one."""
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect instance used to recompile the statement."""
        if self.dialect == "default":
            return DefaultDialect()
        else:
            # ugh
            if self.dialect == "postgresql":
                params = {"implicit_returning": True}
            else:
                params = {}
            return url.URL(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect.

        :return: a tuple of the recompiled statement string, with
         newlines and tabs removed, and a list of parameter dictionaries
         built by ``compiled.construct_params()``.
        """

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)
        if isinstance(context.compiled.statement, _DDLCompiles):
            compiled = context.compiled.statement.compile(
                dialect=compare_dialect,
                schema_translate_map=context.execution_options.get(
                    "schema_translate_map"
                ),
            )
        else:
            compiled = context.compiled.statement.compile(
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                inline=context.compiled.inline,
                schema_translate_map=context.execution_options.get(
                    "schema_translate_map"
                ),
            )
        _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [compiled.construct_params()]
        else:
            _received_parameters = [
                compiled.construct_params(m) for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        """Match the observed execution's statement and parameters.

        Parameters are compared "positively": each expected dictionary
        must find a received dictionary that agrees on every expected
        key, while additional keys in the received dictionary are
        ignored.  Each received dictionary may satisfy only one expected
        dictionary, and none may be left over on either side.
        """
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        """Return the expected parameters as a list, or None if not set.

        A callable ``params`` is invoked with the execution context; a
        result that is not a list is wrapped in a one-element list.
        """
        if self.params:
            if util.callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        """Return a format template for the failure message.

        The expected statement and parameters are interpolated here with
        their percent signs doubled; the ``received_statement`` and
        ``received_parameters`` placeholders are left for the caller to
        fill in with a second ``%`` operation.
        """
        return (
            "Testing for compiled statement %r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )


class RegexSQL(CompiledSQL):
    """Match the recompiled statement against a regular expression.

    The statement is always recompiled against the default dialect.
    """

    def __init__(self, regex, params=None):
        """
        :param regex: pattern string; applied with ``match()``, so it is
         anchored at the start of the received statement.
        :param params: expected parameters, as for :class:`CompiledSQL`.
        """
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        self.orig_regex = regex
        self.params = params
        self.dialect = "default"

    def _failure_message(self, expected_params):
        """Return the failure template, reporting the original pattern."""
        return (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.orig_regex.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if the pattern matches the received statement."""
        return bool(self.regex.match(received_statement))


class DialectSQL(CompiledSQL):
    """Match a statement as compiled for the dialect actually in use.

    The expected statement is written with ``:name`` bound parameters
    and is converted to the paramstyle of the executing dialect before
    being compared.
    """

    def _compile_dialect(self, execute_observed):
        """Return the dialect of the observed execution context."""
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        """Compare after removing newlines and tabs from ``real_stmt``."""
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        """Return the recompiled statement and the context's parameters.

        :raises AssertionError: if the recompiled statement is not among
         the statements that were actually passed to the cursor.
        """
        received_stmt, received_params = super(
            DialectSQL, self
        )._received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        """Compare after converting ``:name`` params to the paramstyle.

        ``pyformat`` becomes ``%(name)s``, ``qmark`` becomes ``?``,
        ``format`` becomes ``%s`` and ``numeric`` becomes ``:1``,
        ``:2``, ... in order of occurrence.  Any other paramstyle,
        including ``named``, leaves the ``:name`` form as written.
        """
        stmt = re.sub(r"[\n\t]", "", self.statement)
        bind_name = r":([\w_]+)"

        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == "pyformat":
            stmt = re.sub(bind_name, r"%(\1)s", stmt)
        elif paramstyle == "qmark":
            stmt = re.sub(bind_name, "?", stmt)
        elif paramstyle == "format":
            stmt = re.sub(bind_name, r"%s", stmt)
        elif paramstyle == "numeric":
            # number each bound parameter by position; passing None as
            # the replacement, as was done previously, raises TypeError
            position = itertools.count(1)
            stmt = re.sub(
                bind_name, lambda match: ":%d" % next(position), stmt
            )
        # otherwise ("named" in particular) the statement already has
        # the form the dialect renders, so there is nothing to convert

        return received_statement == stmt


class CountStatements(AssertRule):
    """Assert that an exact number of executions were observed."""

    def __init__(self, count):
        """
        :param count: the number of executions expected.
        """
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        """Count the observed execution."""
        self._statement_count += 1

    def no_more_statements(self):
        """Fail if the number of executions differs from ``count``."""
        if self.count != self._statement_count:
            assert False, "desired statement count %d does not match %d" % (
                self.count,
                self._statement_count,
            )


class AllOf(AssertRule):
    """Require every given rule to match, in any order."""

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        """Offer the execution to the remaining rules until one accepts.

        A rule that becomes consumed is dropped from the set; this rule
        is consumed once the set is empty.  If every remaining rule
        reports an error, the error of one of them is recorded.
        """
        for rule in list(self.rules):
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class EachOf(AssertRule):
    """Require every given rule to match, in the order given."""

    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        """Feed the execution to the pending rules, front to back.

        Processing stops at the first rule whose ``consume_statement``
        flag is true; rules with the flag unset pass the same execution
        on to the rule after them.
        """
        while self.rules:
            rule = self.rules[0]
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.pop(0)
            elif rule.errormessage:
                self.errormessage = rule.errormessage
            if rule.consume_statement:
                break

        if not self.rules:
            self.is_consumed = True

    def no_more_statements(self):
        """Delegate to the first pending rule, or fail if rules remain."""
        if self.rules and not self.rules[0].is_consumed:
            self.rules[0].no_more_statements()
        elif self.rules:
            super(EachOf, self).no_more_statements()


class Or(AllOf):
    """Require at least one of the given rules to match."""

    def process_statement(self, execute_observed):
        """Offer the execution to each rule; the first match consumes."""
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.is_consumed = True
                break
        else:
            self.errormessage = list(self.rules)[0].errormessage


class SQLExecuteObserved(object):
    """One observed execution and the cursor statements it produced."""

    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        self.parameters = _distill_params(multiparams, params)

        # SQLCursorExecuteObserved entries, appended by assert_engine()
        self.statements = []


class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"],
    )
):
    """Record of a single statement passed to the cursor."""

    pass


class SQLAsserter(object):
    """Accumulates observed executions and asserts rules against them."""

    def __init__(self):
        self.accumulated = []

    def _close(self):
        """Freeze the recorded executions for use by :meth:`assert_`."""
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Assert that the recorded executions satisfy ``rules`` in order.

        The rules are combined into an :class:`EachOf`.  The assertion
        fails on the first rule error, if executions remain after all
        rules are consumed, or if rules remain after all executions have
        been processed.
        """

        rule = EachOf(*rules)
        observed = list(self._final)

        while observed:
            statement = observed.pop(0)
            rule.process_statement(statement)
            if rule.is_consumed:
                break
            elif rule.errormessage:
                assert False, rule.errormessage
        if observed:
            assert False, "Additional SQL statements remain"
        elif not rule.is_consumed:
            rule.no_more_statements()


@contextlib.contextmanager
def assert_engine(engine):
    """Record the executions made on ``engine`` within the block.

    Yields a :class:`SQLAsserter`.  The event listeners are removed and
    the asserter is closed when the block exits, after which
    :meth:`SQLAsserter.assert_` may be called.
    """
    asserter = SQLAsserter()

    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(conn, clauseelement, multiparams, params):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        if not context:
            return
        # then grab real cursor statements and associate them all
        # around a single context
        if (
            asserter.accumulated
            and asserter.accumulated[-1].context is context
        ):
            obs = asserter.accumulated[-1]
        else:
            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            asserter.accumulated.append(obs)
        obs.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany
            )
        )

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()