212 lines
6.6 KiB
Python
212 lines
6.6 KiB
Python
|
from .. import config
|
||
|
from .. import fixtures
|
||
|
from ..assertions import eq_
|
||
|
from ..schema import Column
|
||
|
from ..schema import Table
|
||
|
from ... import ForeignKey
|
||
|
from ... import Integer
|
||
|
from ... import select
|
||
|
from ... import String
|
||
|
from ... import testing
|
||
|
|
||
|
|
||
|
class CTETest(fixtures.TablesTest):
    """Round-trip tests for common table expressions (CTEs).

    Every test builds a CTE named ``some_cte`` over ``some_table`` and
    uses it inside a SELECT, an INSERT..FROM SELECT, an UPDATE or a
    DELETE executed through ``config.db``, then compares the fetched
    rows with a fixed expected list.
    """

    __backend__ = True
    __requires__ = ("ctes",)

    # NOTE(review): "each" presumably means the fixture rows are inserted
    # and removed around every individual test; the exact semantics are
    # defined by fixtures.TablesTest -- confirm there.
    run_inserts = "each"
    run_deletes = "each"

    @classmethod
    def define_tables(cls, metadata):
        """Declare ``some_table`` and ``some_other_table`` on ``metadata``.

        Both tables have ``id``, ``data`` and ``parent_id`` columns; only
        ``some_table.parent_id`` is a foreign key (to ``some_table.id``).
        """
        # parent_id refers back to the same table's primary key
        self_referential_columns = [
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        ]
        Table("some_table", metadata, *self_referential_columns)

        # same column names, but parent_id is a plain Integer here
        plain_columns = [
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        ]
        Table("some_other_table", metadata, *plain_columns)

    @classmethod
    def insert_data(cls):
        """Insert five rows (ids 1-5, data ``d1``-``d5``) into some_table."""
        fixture_rows = [
            dict(id=pk, data=label, parent_id=parent)
            for pk, label, parent in (
                (1, "d1", None),
                (2, "d2", 1),
                (3, "d3", 1),
                (4, "d4", 3),
                (5, "d5", 3),
            )
        ]
        config.db.execute(cls.tables.some_table.insert(), fixture_rows)

    def test_select_nonrecursive_round_trip(self):
        """SELECT from a plain CTE and filter its rows a second time."""
        source = self.tables.some_table

        with config.db.connect() as connection:
            filtered = select([source]).where(
                source.c.data.in_(["d2", "d3", "d4"])
            )
            some_cte = filtered.cte("some_cte")

            # only "d4" is in both the CTE's filter and this one
            stmt = select([some_cte.c.data]).where(
                some_cte.c.data.in_(["d4", "d5"])
            )
            fetched = connection.execute(stmt).fetchall()
            eq_(fetched, [("d4",)])

    def test_select_recursive_round_trip(self):
        """SELECT from a recursive CTE that joins rows to their parents."""
        source = self.tables.some_table

        with config.db.connect() as connection:
            filtered = select([source]).where(
                source.c.data.in_(["d2", "d3", "d4"])
            )
            anchor = filtered.cte("some_cte", recursive=True)

            # the recursive part refers to the CTE through alias "c1"
            anchor_alias = anchor.alias("c1")
            parent_side = source.alias()
            parent_lookup = select([parent_side]).where(
                parent_side.c.id == anchor_alias.c.parent_id
            )

            # note that SQL Server requires this to be UNION ALL,
            # can't be UNION
            recursive_cte = anchor.union_all(parent_lookup)

            stmt = select([recursive_cte.c.data])
            stmt = stmt.where(recursive_cte.c.data != "d2")
            stmt = stmt.order_by(recursive_cte.c.data.desc())

            fetched = connection.execute(stmt).fetchall()
            # the expected result contains repeated values
            expected = [
                ("d4",),
                ("d3",),
                ("d3",),
                ("d1",),
                ("d1",),
                ("d1",),
            ]
            eq_(fetched, expected)

    def test_insert_from_select_round_trip(self):
        """INSERT..FROM SELECT where the SELECT reads from a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        with config.db.connect() as connection:
            filtered = select([source]).where(
                source.c.data.in_(["d2", "d3", "d4"])
            )
            some_cte = filtered.cte("some_cte")

            insert_stmt = target.insert().from_select(
                ["id", "data", "parent_id"], select([some_cte])
            )
            connection.execute(insert_stmt)

            readback = select([target]).order_by(target.c.id)
            fetched = connection.execute(readback).fetchall()
            eq_(fetched, [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)])

    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self):
        """UPDATE rows of some_other_table matched against a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        with config.db.connect() as connection:
            # start from a full copy of some_table's rows
            copy_stmt = target.insert().from_select(
                ["id", "data", "parent_id"], select([source])
            )
            connection.execute(copy_stmt)

            filtered = select([source]).where(
                source.c.data.in_(["d2", "d3", "d4"])
            )
            some_cte = filtered.cte("some_cte")

            update_stmt = target.update().values(parent_id=5)
            update_stmt = update_stmt.where(
                target.c.data == some_cte.c.data
            )
            connection.execute(update_stmt)

            readback = select([target]).order_by(target.c.id)
            fetched = connection.execute(readback).fetchall()
            # rows d2..d4 get parent_id 5; d1 and d5 are unchanged
            expected = [
                (1, "d1", None),
                (2, "d2", 5),
                (3, "d3", 5),
                (4, "d4", 5),
                (5, "d5", 3),
            ]
            eq_(fetched, expected)

    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self):
        """DELETE rows of some_other_table matched against a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        with config.db.connect() as connection:
            # start from a full copy of some_table's rows
            copy_stmt = target.insert().from_select(
                ["id", "data", "parent_id"], select([source])
            )
            connection.execute(copy_stmt)

            filtered = select([source]).where(
                source.c.data.in_(["d2", "d3", "d4"])
            )
            some_cte = filtered.cte("some_cte")

            delete_stmt = target.delete().where(
                target.c.data == some_cte.c.data
            )
            connection.execute(delete_stmt)

            readback = select([target]).order_by(target.c.id)
            fetched = connection.execute(readback).fetchall()
            eq_(fetched, [(1, "d1", None), (5, "d5", 3)])

    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self):
        """DELETE using a SELECT from a CTE on the right side of ``==``."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        with config.db.connect() as connection:
            # start from a full copy of some_table's rows
            copy_stmt = target.insert().from_select(
                ["id", "data", "parent_id"], select([source])
            )
            connection.execute(copy_stmt)

            filtered = select([source]).where(
                source.c.data.in_(["d2", "d3", "d4"])
            )
            some_cte = filtered.cte("some_cte")

            # SELECT of the CTE's data for the row with the same id as
            # the row being considered for deletion
            matching_data = select([some_cte.c.data]).where(
                some_cte.c.id == target.c.id
            )
            delete_stmt = target.delete().where(
                target.c.data == matching_data
            )
            connection.execute(delete_stmt)

            readback = select([target]).order_by(target.c.id)
            fetched = connection.execute(readback).fetchall()
            eq_(fetched, [(1, "d1", None), (5, "d5", 3)])
|