diff options
Diffstat (limited to 'mediagoblin/db/migration_tools.py')
-rw-r--r-- | mediagoblin/db/migration_tools.py | 115 |
1 files changed, 94 insertions, 21 deletions
diff --git a/mediagoblin/db/migration_tools.py b/mediagoblin/db/migration_tools.py index e39070c3..f4273fa0 100644 --- a/mediagoblin/db/migration_tools.py +++ b/mediagoblin/db/migration_tools.py @@ -14,10 +14,24 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +from __future__ import unicode_literals + +import logging +import os +import pkg_resources + +from alembic import command +from alembic.config import Config +from alembic.migration import MigrationContext + +from mediagoblin.db.base import Base from mediagoblin.tools.common import simple_printer from sqlalchemy import Table from sqlalchemy.sql import select +log = logging.getLogger(__name__) + + class TableAlreadyExists(Exception): pass @@ -30,7 +44,7 @@ class MigrationManager(object): to the latest migrations, etc. """ - def __init__(self, name, models, foundations, migration_registry, session, + def __init__(self, name, models, migration_registry, session, printer=simple_printer): """ Args: @@ -39,9 +53,8 @@ class MigrationManager(object): - migration_registry: where we should find all migrations to run """ - self.name = unicode(name) + self.name = name self.models = models - self.foundations = foundations self.session = session self.migration_registry = migration_registry self._sorted_migrations = None @@ -112,14 +125,14 @@ class MigrationManager(object): def migrations_to_run(self): """ Get a list of migrations to run still, if any. - + Note that this will fail if there's no migration record for this class! """ assert self.database_current_migration is not None db_current_migration = self.database_current_migration - + return [ (migration_number, migration_func) for migration_number, migration_func in self.sorted_migrations @@ -142,18 +155,6 @@ class MigrationManager(object): self.session.bind, tables=[model.__table__ for model in self.models]) - def populate_table_foundations(self): - """ - Create the table foundations (default rows) as layed out in FOUNDATIONS - in mediagoblin.db.models - """ - for Model, rows in self.foundations.items(): - self.printer(u' + Laying foundations for %s table\n' % - (Model.__name__)) - for parameters in rows: - new_row = Model(**parameters) - self.session.add(new_row) - def create_new_migration_record(self): """ Create a new migration record for this migration set @@ -184,7 +185,7 @@ class MigrationManager(object): migration_number, migration_func.func_name)) return u'migrated' - + def name_for_printing(self): if self.name == u'__main__': return u"main mediagoblin tables" @@ -218,7 +219,6 @@ class MigrationManager(object): # auto-set at latest migration number self.create_new_migration_record() self.printer(u"done.\n") - self.populate_table_foundations() self.set_current_migration() return u'inited' @@ -230,7 +230,7 @@ class MigrationManager(object): for migration_number, migration_func in migrations_to_run: self.printer( u' + Running migration %s, "%s"... ' % ( - migration_number, migration_func.func_name)) + migration_number, migration_func.__name__)) migration_func(self.session) self.set_current_migration(migration_number) self.printer('done.\n') @@ -263,6 +263,8 @@ class RegisterMigration(object): assert migration_number > 0, "Migration number must be > 0!" assert migration_number not in migration_registry, \ "Duplicate migration numbers detected! That's not allowed!" + assert migration_number <= 44, ('Alembic should be used for ' + 'new migrations') self.migration_number = migration_number self.migration_registry = migration_registry @@ -295,7 +297,7 @@ def replace_table_hack(db, old_table, replacement_table): -tion, for example, dropping a boolean column in sqlite is impossible w/o this method - :param old_table A ref to the old table, gotten through + :param old_table A ref to the old table, gotten through inspect_table :param replacement_table A ref to the new table, gotten through @@ -319,3 +321,74 @@ def replace_table_hack(db, old_table, replacement_table): replacement_table.rename(old_table_name) db.commit() + +def model_iteration_hack(db, query): + """ + This will return either the query you gave if it's postgres or in the case + of sqlite it will return a list with all the results. This is because in + migrations it seems sqlite can't deal with concurrent quries so if you're + iterating over models and doing a commit inside the loop, you will run into + an exception which says you've closed the connection on your iteration + query. This fixes it. + + NB: This loads all of the query reuslts into memeory, there isn't a good + way around this, we're assuming sqlite users have small databases. + """ + # If it's SQLite just return all the objects + if db.bind.url.drivername == "sqlite": + return [obj for obj in db.execute(query)] + + # Postgres return the query as it knows how to deal with it. + return db.execute(query) + + +def populate_table_foundations(session, foundations, name, + printer=simple_printer): + """ + Create the table foundations (default rows) as layed out in FOUNDATIONS + in mediagoblin.db.models + """ + printer(u'Laying foundations for %s:\n' % name) + for Model, rows in foundations.items(): + printer(u' + Laying foundations for %s table\n' % + (Model.__name__)) + for parameters in rows: + new_row = Model(**parameters) + session.add(new_row) + + session.commit() + + +def build_alembic_config(global_config, cmd_options, session): + """ + Build up a config that the alembic tooling can use based on our + configuration. Initialize the database session appropriately + as well. + """ + root_dir = os.path.abspath(os.path.dirname(os.path.dirname( + os.path.dirname(__file__)))) + alembic_cfg_path = os.path.join(root_dir, 'alembic.ini') + cfg = Config(alembic_cfg_path, + cmd_opts=cmd_options) + cfg.attributes["session"] = session + + version_locations = [ + pkg_resources.resource_filename( + "mediagoblin.db", os.path.join("migrations", "versions")), + ] + + cfg.set_main_option("sqlalchemy.url", str(session.get_bind().url)) + + for plugin in global_config.get("plugins", []): + plugin_migrations = pkg_resources.resource_filename( + plugin, "migrations") + is_migrations_dir = (os.path.exists(plugin_migrations) and + os.path.isdir(plugin_migrations)) + if is_migrations_dir: + version_locations.append(plugin_migrations) + + cfg.set_main_option( + "version_locations", + " ".join(version_locations)) + + return cfg |