# -*- coding: utf-8 -*- """ The MIT License (MIT) Copyright (c) 2017 Rapptz Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ # These are just things that allow me to make tables for PostgreSQL easier # This isn't exactly good. It's just good enough for my uses. # Also shoddy migration support. import asyncio import datetime import decimal import inspect import json import logging import pydoc import uuid from collections import OrderedDict from pathlib import Path import asyncpg log = logging.getLogger(__name__) class SchemaError(Exception): pass class SQLType: python = None def to_dict(self): o = self.__dict__.copy() cls = self.__class__ o['__meta__'] = cls.__module__ + '.' + cls.__qualname__ return o @classmethod def from_dict(cls, data): meta = data.pop('__meta__') given = cls.__module__ + '.' + cls.__qualname__ if given != meta: cls = pydoc.locate(meta) if cls is None: raise RuntimeError('Could not locate "%s".' % meta) self = cls.__new__(cls) self.__dict__.update(data) return self def __eq__(self, other): return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) def to_sql(self): raise NotImplementedError() def is_real_type(self): return True class Binary(SQLType): python = bytes def to_sql(self): return 'BYTEA' class Boolean(SQLType): python = bool def to_sql(self): return 'BOOLEAN' class Date(SQLType): python = datetime.date def to_sql(self): return 'DATE' class Datetime(SQLType): python = datetime.datetime def __init__(self, *, timezone=False): self.timezone = timezone def to_sql(self): if self.timezone: return 'TIMESTAMP WITH TIME ZONE' return 'TIMESTAMP' class Double(SQLType): python = float def to_sql(self): return 'REAL' class Float(SQLType): python = float def to_sql(self): return 'FLOAT' class Integer(SQLType): python = int def __init__(self, *, big=False, small=False, auto_increment=False): self.big = big self.small = small self.auto_increment = auto_increment if big and small: raise SchemaError( 'Integer column type cannot be both big and small.') def to_sql(self): if self.auto_increment: if self.big: return 'BIGSERIAL' if self.small: return 'SMALLSERIAL' return 'SERIAL' if self.big: return 'BIGINT' if self.small: return 'SMALLINT' return 'INTEGER' def is_real_type(self): return not self.auto_increment class Interval(SQLType): python = datetime.timedelta def __init__(self, field=None): if field: field = field.upper() if field not in ( 'YEAR', 'MONTH', 'DAY', 'HOUR', 'MINUTE', 'SECOND', 'YEAR TO MONTH', 'DAY TO HOUR', 'DAY TO MINUTE', 'DAY TO SECOND', 'HOUR TO MINUTE', 'HOUR TO SECOND', 'MINUTE TO SECOND'): raise SchemaError('invalid interval specified') self.field = field else: self.field = None def to_sql(self): if self.field: return 'INTERVAL ' + self.field return 'INTERVAL' class Numeric(SQLType): python = decimal.Decimal def __init__(self, *, precision=None, scale=None): if precision is not None: if precision < 0 or precision > 1000: raise SchemaError( 'precision must be greater than 0 and below 1000') if scale is None: scale = 0 self.precision = precision self.scale = scale def to_sql(self): if self.precision is not None: return 'NUMERIC({0.precision}, {0.scale})'.format(self) return 'NUMERIC' class String(SQLType): python = str def __init__(self, *, length=None, fixed=False): self.length = length self.fixed = fixed if fixed and length is None: raise SchemaError('Cannot have fixed string with no length') def to_sql(self): if self.length is None: return 'TEXT' if self.fixed: return 'CHAR({0.length})'.format(self) return 'VARCHAR({0.length})'.format(self) class Time(SQLType): python = datetime.time def __init__(self, *, timezone=False): self.timezone = timezone def to_sql(self): if self.timezone: return 'TIME WITH TIME ZONE' return 'TIME' class JSON(SQLType): python = None def to_sql(self): return 'JSONB' class ForeignKey(SQLType): def __init__(self, table, column, *, sql_type=None, on_delete='CASCADE', on_update='NO ACTION'): if not table or not isinstance(table, str): raise SchemaError('missing table to reference (must be string)') valid_actions = ( 'NO ACTION', 'RESTRICT', 'CASCADE', 'SET NULL', 'SET DEFAULT', ) on_delete = on_delete.upper() on_update = on_update.upper() if on_delete not in valid_actions: raise TypeError('on_delete must be one of %s.' % valid_actions) if on_update not in valid_actions: raise TypeError('on_update must be one of %s.' % valid_actions) self.table = table self.column = column self.on_update = on_update self.on_delete = on_delete if sql_type is None: sql_type = Integer if inspect.isclass(sql_type): sql_type = sql_type() if not isinstance(sql_type, SQLType): raise TypeError('Cannot have non-SQLType derived sql_type') if not sql_type.is_real_type(): raise SchemaError('sql_type must be a "real" type') self.sql_type = sql_type.to_sql() def is_real_type(self): return False def to_sql(self): fmt = '{0.sql_type} REFERENCES {0.table} ({0.column})' \ ' ON DELETE {0.on_delete} ON UPDATE {0.on_update}' return fmt.format(self) class Array(SQLType): python = list def __init__(self, sql_type): if inspect.isclass(sql_type): sql_type = sql_type() if not isinstance(sql_type, SQLType): raise TypeError('Cannot have non-SQLType derived sql_type') if not sql_type.is_real_type(): raise SchemaError('sql_type must be a "real" type') self.sql_type = sql_type.to_sql() def to_sql(self): return '{0.sql_type} ARRAY'.format(self) def is_real_type(self): # technically, it is a real type # however, it doesn't play very well with migrations # so we're going to pretend that it isn't return False class Column: __slots__ = ('column_type', 'index', 'primary_key', 'nullable', 'default', 'unique', 'name', 'index_name') def __init__(self, column_type, *, index=False, primary_key=False, nullable=True, unique=False, default=None, name=None): if inspect.isclass(column_type): column_type = column_type() if not isinstance(column_type, SQLType): raise TypeError('Cannot have a non-SQLType derived column_type') self.column_type = column_type self.index = index self.unique = unique self.primary_key = primary_key self.nullable = nullable self.default = default self.name = name self.index_name = None # to be filled later if sum(map(bool, (unique, primary_key, default is not None))) > 1: raise SchemaError( "'unique', 'primary_key', and 'default' are mutually exclusive.") @classmethod def from_dict(cls, data): index_name = data.pop('index_name', None) column_type = data.pop('column_type') column_type = SQLType.from_dict(column_type) self = cls(column_type=column_type, **data) self.index_name = index_name return self @property def _comparable_id(self): return '-'.join( '%s:%s' % (attr, getattr(self, attr)) for attr in self.__slots__) def _to_dict(self): d = { attr: getattr(self, attr) for attr in self.__slots__ } d['column_type'] = self.column_type.to_dict() return d def _qualifiers_dict(self): return {attr: getattr(self, attr) for attr in ('nullable', 'default')} def _is_rename(self, other): if self.name == other.name: return False return self.unique == other.unique and self.primary_key == other.primary_key def _create_table(self): builder = [] builder.append(self.name) builder.append(self.column_type.to_sql()) default = self.default if default is not None: builder.append('DEFAULT') if isinstance(default, str) and isinstance(self.column_type, String): builder.append("'%s'" % default) elif isinstance(default, bool): builder.append(str(default).upper()) else: builder.append("(%s)" % default) elif self.unique: builder.append('UNIQUE') if not self.nullable: builder.append('NOT NULL') return ' '.join(builder) class PrimaryKeyColumn(Column): """Shortcut for a SERIAL PRIMARY KEY column.""" def __init__(self): super().__init__(Integer(auto_increment=True), primary_key=True) class SchemaDiff: __slots__ = ('table', 'upgrade', 'downgrade') def __init__(self, table, upgrade, downgrade): self.table = table self.upgrade = upgrade self.downgrade = downgrade def to_dict(self): return {'upgrade': self.upgrade, 'downgrade': self.downgrade} def is_empty(self): return len(self.upgrade) == 0 and len(self.downgrade) == 0 def to_sql(self, *, downgrade=False): statements = [] base = 'ALTER TABLE %s ' % self.table.__tablename__ path = self.upgrade if not downgrade else self.downgrade for rename in path.get('rename_columns', []): fmt = '{0}RENAME COLUMN {1[before]} TO {1[after]};'.format(base, rename) statements.append(fmt) sub_statements = [] for dropped in path.get('remove_columns', []): fmt = 'DROP COLUMN {0[name]} RESTRICT'.format(dropped) sub_statements.append(fmt) for changed_types in path.get('changed_column_types', []): fmt = 'ALTER COLUMN {0[name]} SET DATA TYPE {0[type]}'.format( changed_types) using = changed_types.get('using') if using is not None: fmt = '%s USING %s' % (fmt, using) sub_statements.append(fmt) for constraints in path.get('changed_constraints', []): before, after = constraints['before'], constraints['after'] before_default, after_default = before.get('default'), after.get( 'default') if before_default is None and after_default is not None: fmt = 'ALTER COLUMN {0[name]} SET DEFAULT {1[default]}'.format( constraints, after) sub_statements.append(fmt) elif before_default is not None and after_default is None: fmt = 'ALTER COLUMN {0[name]} DROP DEFAULT'.format(constraints) sub_statements.append(fmt) before_nullable, after_nullable = before.get( 'nullable'), after.get('nullable') if not before_nullable and after_nullable: fmt = 'ALTER COLUMN {0[name]} DROP NOT NULL'.format( constraints) sub_statements.append(fmt) elif before_nullable and not after_nullable: fmt = 'ALTER COLUMN {0[name]} SET NOT NULL'.format(constraints) sub_statements.append(fmt) for added in path.get('add_columns', []): column = Column.from_dict(added) sub_statements.append('ADD COLUMN ' + column._create_table()) if sub_statements: statements.append(base + ', '.join(sub_statements) + ';') # handle the index creation bits for dropped in path.get('drop_index', []): statements.append( 'DROP INDEX IF EXISTS {0[index]};'.format(dropped)) for added in path.get('add_index', []): fmt = 'CREATE INDEX IF NOT EXISTS {0[index]} ON {1.__tablename__} ({0[name]});' statements.append(fmt.format(added, self.table)) return '\n'.join(statements) class MaybeAcquire: def __init__(self, connection, *, pool): self.connection = connection self.pool = pool self._cleanup = False async def __aenter__(self): if self.connection is None: self._cleanup = True self._connection = c = await self.pool.acquire() return c return self.connection async def __aexit__(self, *args): if self._cleanup: await self.pool.release(self._connection) class TableMeta(type): @classmethod def __prepare__(cls, name, bases, **kwargs): return OrderedDict() def __new__(cls, name, parents, dct, **kwargs): columns = [] try: table_name = kwargs['table_name'] except KeyError: table_name = name.lower() dct['__tablename__'] = table_name for elem, value in dct.items(): if isinstance(value, Column): if value.name is None: value.name = elem if value.index: value.index_name = '%s_%s_idx' % (table_name, value.name) columns.append(value) dct['columns'] = columns return super().__new__(cls, name, parents, dct) def __init__(self, name, parents, dct, **kwargs): super().__init__(name, parents, dct) class Table(metaclass=TableMeta): @classmethod async def create_pool(cls, uri, **kwargs): """Sets up and returns the PostgreSQL connection pool that is used. .. note:: This must be called at least once before doing anything with the tables. And must be called on the ``Table`` class. Parameters ----------- uri: str The PostgreSQL URI to connect to. \*\*kwargs The arguments to forward to asyncpg.create_pool. """ def _encode_jsonb(value): return json.dumps(value) def _decode_jsonb(value): return json.loads(value) old_init = kwargs.pop('init', None) async def init(con): await con.set_type_codec('jsonb', schema='pg_catalog', encoder=_encode_jsonb, decoder=_decode_jsonb, format='text') if old_init is not None: await old_init(con) cls._pool = pool = await asyncpg.create_pool(uri, init=init, **kwargs) return pool @classmethod def acquire_connection(cls, connection): return MaybeAcquire(connection, pool=cls._pool) @classmethod def write_migration(cls, *, directory='migrations'): """Writes the migration diff into the data file. Note ------ This doesn't actually commit/do the migration. To do so, use :meth:`migrate`. Returns -------- bool ``True`` if a migration was written, ``False`` otherwise. Raises ------- RuntimeError Could not find the migration data necessary. """ directory = Path(directory) / cls.__tablename__ p = directory.with_suffix('.json') if not p.exists(): raise RuntimeError('Could not find migration file.') current = directory.with_name('current-' + p.name) if not current.exists(): raise RuntimeError('Could not find current data file.') with current.open() as fp: current_table = cls.from_dict(json.load(fp)) diff = cls().diff(current_table) # the most common case, no difference if diff.is_empty(): return None # load the migration data with p.open('r', encoding='utf-8') as fp: data = json.load(fp) migrations = data['migrations'] # check if we should add it our_migrations = diff.to_dict() if len(migrations) == 0 or migrations[-1] != our_migrations: # we have a new migration, so add it migrations.append(our_migrations) temp_file = p.with_name('%s-%s.tmp' % (uuid.uuid4(), p.name)) with temp_file.open('w', encoding='utf-8') as tmp: json.dump(data, tmp, ensure_ascii=True, indent=4) temp_file.replace(p) return True return False @classmethod async def migrate(cls, *, directory='migrations', index=-1, downgrade=False, verbose=False, connection=None): """Actually run the latest migration pointed by the data file. Parameters ----------- directory: str The directory of where the migration data file resides. index: int The index of the migration array to use. downgrade: bool Whether to run an upgrade or a downgrade. verbose: bool Whether to output some information to stdout. connection: Optional[asyncpg.Connection] The connection to use, if not provided will acquire one from the internal pool. """ directory = Path(directory) / cls.__tablename__ p = directory.with_suffix('.json') if not p.exists(): raise RuntimeError('Could not find migration file.') with p.open('r', encoding='utf-8') as fp: data = json.load(fp) migrations = data['migrations'] try: migration = migrations[index] except IndexError: return False diff = SchemaDiff(cls, migration['upgrade'], migration['downgrade']) if diff.is_empty(): return False async with MaybeAcquire(connection, pool=cls._pool) as con: sql = diff.to_sql(downgrade=downgrade) if verbose: print(sql) await con.execute(sql) current = directory.with_name('current-' + p.name) with current.open('w', encoding='utf-8') as fp: json.dump(cls.to_dict(), fp, indent=4, ensure_ascii=True) @classmethod async def create(cls, *, directory='migrations', verbose=False, connection=None, run_migrations=True): """Creates the database and manages migrations, if any. Parameters ----------- directory: str The migrations directory. verbose: bool Whether to output some information to stdout. connection: Optional[asyncpg.Connection] The connection to use, if not provided will acquire one from the internal pool. run_migrations: bool Whether to run migrations at all. Returns -------- Optional[bool] ``True`` if the table was successfully created or ``False`` if the table was successfully migrated or ``None`` if no migration took place. """ directory = Path(directory) / cls.__tablename__ p = directory.with_suffix('.json') current = directory.with_name('current-' + p.name) table_data = cls.to_dict() if not p.exists(): p.parent.mkdir(parents=True, exist_ok=True) # we're creating this table for the first time, # it's an uncommon case so let's get it out of the way # first, try to actually create the table async with MaybeAcquire(connection, pool=cls._pool) as con: sql = cls.create_table(exists_ok=True) if verbose: print(sql) await con.execute(sql) # since that step passed, let's go ahead and make the migration with p.open('w', encoding='utf-8') as fp: data = {'table': table_data, 'migrations': []} json.dump(data, fp, indent=4, ensure_ascii=True) with current.open('w', encoding='utf-8') as fp: json.dump(table_data, fp, indent=4, ensure_ascii=True) return True if not run_migrations: return None with current.open() as fp: current_table = cls.from_dict(json.load(fp)) diff = cls().diff(current_table) # the most common case, no difference if diff.is_empty(): return None # execute the upgrade SQL async with MaybeAcquire(connection, pool=cls._pool) as con: sql = diff.to_sql() if verbose: print(sql) await con.execute(sql) # load the migration data with p.open('r', encoding='utf-8') as fp: data = json.load(fp) migrations = data['migrations'] # check if we should add it our_migrations = diff.to_dict() if len(migrations) == 0 or migrations[-1] != our_migrations: # we have a new migration, so add it migrations.append(our_migrations) temp_file = p.with_name('%s-%s.tmp' % (uuid.uuid4(), p.name)) with temp_file.open('w', encoding='utf-8') as tmp: json.dump(data, tmp, ensure_ascii=True, indent=4) temp_file.replace(p) # update our "current" data in the filesystem with current.open('w', encoding='utf-8') as fp: json.dump(table_data, fp, indent=4, ensure_ascii=True) return False @classmethod async def drop(cls, *, directory='migrations', verbose=False, connection=None): """Drops the database and migrations, if any. Parameters ----------- directory: str The migrations directory. verbose: bool Whether to output some information to stdout. connection: Optional[asyncpg.Connection] The connection to use, if not provided will acquire one from the internal pool. """ directory = Path(directory) / cls.__tablename__ p = directory.with_suffix('.json') current = directory.with_name('current-' + p.name) if not p.exists() or not current.exists(): raise RuntimeError('Could not find the appropriate data files.') try: p.unlink() except: raise RuntimeError('Could not delete migration file') try: current.unlink() except: raise RuntimeError('Could not delete current migration file') async with MaybeAcquire(connection, pool=cls._pool) as con: sql = 'DROP TABLE {0} CASCADE;'.format(cls.__tablename__) if verbose: print(sql) await con.execute(sql) @classmethod def create_table(cls, *, exists_ok=True): """Generates the CREATE TABLE stub.""" statements = [] builder = ['CREATE TABLE'] if exists_ok: builder.append('IF NOT EXISTS') builder.append(cls.__tablename__) column_creations = [] primary_keys = [] for col in cls.columns: column_creations.append(col._create_table()) if col.primary_key: primary_keys.append(col.name) column_creations.append('PRIMARY KEY (%s)' % ', '.join(primary_keys)) builder.append('(%s)' % ', '.join(column_creations)) statements.append(' '.join(builder) + ';') # handle the index creations for column in cls.columns: if column.index: fmt = 'CREATE INDEX IF NOT EXISTS {1.index_name} ON {0} ({1.name});'.format( cls.__tablename__, column) statements.append(fmt) return '\n'.join(statements) @classmethod async def insert(cls, connection=None, **kwargs): """Inserts an element to the table.""" # verify column names: verified = {} for column in cls.columns: try: value = kwargs[column.name] except KeyError: continue check = column.column_type.python if value is None and not column.nullable: raise TypeError( 'Cannot pass None to non-nullable column %s.' % column.name) elif not check or not isinstance(value, check): fmt = 'column {0.name} expected {1.__name__}, received {2.__class__.__name__}' raise TypeError(fmt.format(column, check, value)) verified[column.name] = value sql = 'INSERT INTO {0} ({1}) VALUES ({2});'.format(cls.__tablename__, ', '.join(verified), ', '.join( '$' + str(i) for i, _ in enumerate( verified, 1))) async with MaybeAcquire(connection, pool=cls._pool) as con: await con.execute(sql, *verified.values()) @classmethod def to_dict(cls): x = {} x['name'] = cls.__tablename__ x['__meta__'] = cls.__module__ + '.' + cls.__qualname__ # nb: columns is ordered due to the ordered dict usage # this is used to help detect renames x['columns'] = [a._to_dict() for a in cls.columns] return x @classmethod def from_dict(cls, data): meta = data['__meta__'] given = cls.__module__ + '.' + cls.__qualname__ if given != meta: cls = pydoc.locate(meta) if cls is None: raise RuntimeError('Could not locate "%s".' % meta) self = cls() self.__tablename__ = data['name'] self.columns = [Column.from_dict(a) for a in data['columns']] return self @classmethod def all_tables(cls): return cls.__subclasses__() def diff(self, before): """Outputs the upgrade and downgrade path in JSON. This isn't necessarily good, but it outputs it in a format that allows the user to manually make edits if something is wrong. The following JSON schema is used: Note that every major key takes a list of objects as noted below. Note that add_column and drop_column automatically create and drop indices as necessary. changed_column_types: name: str [The column name] type: str [The new column type] using: Optional[str] [The USING expression to use, if applicable] add_columns: column: object remove_columns: column: object rename_columns: before: str [The previous column name] after: str [The new column name] drop_index: name: str [The column name] index: str [The index name] add_index: name: str [The column name] index: str [The index name] changed_constraints: name: str [The column name] before: nullable: Optional[bool] default: Optional[str] after: nullable: Optional[bool] default: Optional[str] """ upgrade = {} downgrade = {} def check_index_diff(a, b): if a.index != b.index: # Let's assume we have {name: thing, index: True} # and we're going to { name: foo, index: False } # This is a 'dropped' column when we upgrade with a rename # care must be taken to use the old name when dropping # check if we're dropping the index if not a.index: # we could also be renaming so make sure to use the old index name upgrade.setdefault('drop_index', []).append( {'name': a.name, 'index': b.index_name}) # if we want to roll back, we need to re-add the old index to the old column name downgrade.setdefault('add_index', []).append( {'name': b.name, 'index': b.index_name}) else: # we're not dropping an index, instead we're adding one upgrade.setdefault('add_index', []).append( {'name': a.name, 'index': a.index_name}) downgrade.setdefault('drop_index', []).append( {'name': a.name, 'index': a.index_name}) def insert_column_diff(a, b): if a.column_type != b.column_type: if a.name == b.name and a.column_type.is_real_type() and b.column_type.is_real_type(): upgrade.setdefault('changed_column_types', []).append( {'name': a.name, 'type': a.column_type.to_sql()}) downgrade.setdefault('changed_column_types', []).append( {'name': a.name, 'type': b.column_type.to_sql()}) else: a_dict, b_dict = a._to_dict(), b._to_dict() upgrade.setdefault('add_columns', []).append(a_dict) upgrade.setdefault('remove_columns', []).append(b_dict) downgrade.setdefault('remove_columns', []).append(a_dict) downgrade.setdefault('add_columns', []).append(b_dict) check_index_diff(a, b) return elif a._is_rename(b): upgrade.setdefault('rename_columns', []).append( {'before': b.name, 'after': a.name}) downgrade.setdefault('rename_columns', []).append( {'before': a.name, 'after': b.name}) # technically, adding UNIQUE or PRIMARY KEY is rather simple and straight forward # however, since the inverse is a little bit more complicated (you have to remove # the index it maintains and you can't easily know what it is), it's not exactly # worth supporting any sort of change to the uniqueness/primary_key as it stands. # So.. just drop/add the column and call it a day. if a.unique != b.unique or a.primary_key != b.primary_key: a_dict, b_dict = a._to_dict(), b._to_dict() upgrade.setdefault('add_columns', []).append(a_dict) upgrade.setdefault('remove_columns', []).append(b_dict) downgrade.setdefault('remove_columns', []).append(a_dict) downgrade.setdefault('add_columns', []).append(b_dict) check_index_diff(a, b) return check_index_diff(a, b) b_qual, a_qual = b._qualifiers_dict(), a._qualifiers_dict() if a_qual != b_qual: upgrade.setdefault('changed_constraints', []).append( {'name': a.name, 'before': b_qual, 'after': a_qual}) downgrade.setdefault('changed_constraints', []).append( {'name': a.name, 'before': a_qual, 'after': b_qual}) if len(self.columns) == len(before.columns): # check if we have any changes at all for a, b in zip(self.columns, before.columns): if a._comparable_id == b._comparable_id: # no change continue insert_column_diff(a, b) elif len(self.columns) > len(before.columns): # check if we have more columns # typically when we add columns we add them at the end of # the table, this assumption makes this particularly bit easier. # Breaking this assumption will probably break this portion and thus # will require manual handling, sorry. for a, b in zip(self.columns, before.columns): if a._comparable_id == b._comparable_id: # no change continue insert_column_diff(a, b) new_columns = self.columns[len(before.columns):] add, remove = upgrade.setdefault('add_columns', []), downgrade.setdefault( 'remove_columns', []) for column in new_columns: as_dict = column._to_dict() add.append(as_dict) remove.append(as_dict) if column.index: upgrade.setdefault('add_index', []).append( {'name': column.name, 'index': column.index_name}) downgrade.setdefault('drop_index', []).append( {'name': column.name, 'index': column.index_name}) elif len(self.columns) < len(before.columns): # check if we have fewer columns # this one is a little bit more complicated # first we sort the columns by comparable IDs. sorted_before = sorted(before.columns, key=lambda c: c._comparable_id) sorted_after = sorted(self.columns, key=lambda c: c._comparable_id) # handle the column diffs: for a, b in zip(sorted_after, sorted_before): if a._comparable_id == b._comparable_id: continue insert_column_diff(a, b) # check which columns are 'left over' and remove them removed = [c._to_dict() for c in sorted_before[len(sorted_after):]] upgrade.setdefault('remove_columns', []).extend(removed) downgrade.setdefault('add_columns', []).extend(removed) return SchemaDiff(self, upgrade, downgrade) async def _table_creator(tables, *, verbose=True): for table in tables: try: await table.create(verbose=verbose) except: log.error('Failed to create table %s.', table.__tablename__) def create_tables(*tables, verbose=True, loop=None): if loop is None: loop = asyncio.get_event_loop() loop.create_task(_table_creator(tables, verbose=verbose))