aboutsummaryrefslogtreecommitdiffstats
path: root/mediagoblin/db/migration_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'mediagoblin/db/migration_tools.py')
-rw-r--r--mediagoblin/db/migration_tools.py115
1 files changed, 94 insertions, 21 deletions
diff --git a/mediagoblin/db/migration_tools.py b/mediagoblin/db/migration_tools.py
index e39070c3..f4273fa0 100644
--- a/mediagoblin/db/migration_tools.py
+++ b/mediagoblin/db/migration_tools.py
@@ -14,10 +14,24 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+from __future__ import unicode_literals
+
+import logging
+import os
+import pkg_resources
+
+from alembic import command
+from alembic.config import Config
+from alembic.migration import MigrationContext
+
+from mediagoblin.db.base import Base
from mediagoblin.tools.common import simple_printer
from sqlalchemy import Table
from sqlalchemy.sql import select
+log = logging.getLogger(__name__)
+
+
class TableAlreadyExists(Exception):
pass
@@ -30,7 +44,7 @@ class MigrationManager(object):
to the latest migrations, etc.
"""
- def __init__(self, name, models, foundations, migration_registry, session,
+ def __init__(self, name, models, migration_registry, session,
printer=simple_printer):
"""
Args:
@@ -39,9 +53,8 @@ class MigrationManager(object):
- migration_registry: where we should find all migrations to
run
"""
- self.name = unicode(name)
+ self.name = name
self.models = models
- self.foundations = foundations
self.session = session
self.migration_registry = migration_registry
self._sorted_migrations = None
@@ -112,14 +125,14 @@ class MigrationManager(object):
def migrations_to_run(self):
"""
Get a list of migrations to run still, if any.
-
+
Note that this will fail if there's no migration record for
this class!
"""
assert self.database_current_migration is not None
db_current_migration = self.database_current_migration
-
+
return [
(migration_number, migration_func)
for migration_number, migration_func in self.sorted_migrations
@@ -142,18 +155,6 @@ class MigrationManager(object):
self.session.bind,
tables=[model.__table__ for model in self.models])
- def populate_table_foundations(self):
- """
- Create the table foundations (default rows) as layed out in FOUNDATIONS
- in mediagoblin.db.models
- """
- for Model, rows in self.foundations.items():
- self.printer(u' + Laying foundations for %s table\n' %
- (Model.__name__))
- for parameters in rows:
- new_row = Model(**parameters)
- self.session.add(new_row)
-
def create_new_migration_record(self):
"""
Create a new migration record for this migration set
@@ -184,7 +185,7 @@ class MigrationManager(object):
migration_number, migration_func.func_name))
return u'migrated'
-
+
def name_for_printing(self):
if self.name == u'__main__':
return u"main mediagoblin tables"
@@ -218,7 +219,6 @@ class MigrationManager(object):
# auto-set at latest migration number
self.create_new_migration_record()
self.printer(u"done.\n")
- self.populate_table_foundations()
self.set_current_migration()
return u'inited'
@@ -230,7 +230,7 @@ class MigrationManager(object):
for migration_number, migration_func in migrations_to_run:
self.printer(
u' + Running migration %s, "%s"... ' % (
- migration_number, migration_func.func_name))
+ migration_number, migration_func.__name__))
migration_func(self.session)
self.set_current_migration(migration_number)
self.printer('done.\n')
@@ -263,6 +263,8 @@ class RegisterMigration(object):
assert migration_number > 0, "Migration number must be > 0!"
assert migration_number not in migration_registry, \
"Duplicate migration numbers detected! That's not allowed!"
+ assert migration_number <= 44, ('Alembic should be used for '
+ 'new migrations')
self.migration_number = migration_number
self.migration_registry = migration_registry
@@ -295,7 +297,7 @@ def replace_table_hack(db, old_table, replacement_table):
-tion, for example, dropping a boolean column in sqlite is impossible w/o
this method
- :param old_table A ref to the old table, gotten through
+ :param old_table A ref to the old table, gotten through
inspect_table
:param replacement_table A ref to the new table, gotten through
@@ -319,3 +321,74 @@ def replace_table_hack(db, old_table, replacement_table):
replacement_table.rename(old_table_name)
db.commit()
+
+def model_iteration_hack(db, query):
+ """
+ This will return either the query you gave if it's postgres or in the case
+ of sqlite it will return a list with all the results. This is because in
+ migrations it seems sqlite can't deal with concurrent quries so if you're
+ iterating over models and doing a commit inside the loop, you will run into
+ an exception which says you've closed the connection on your iteration
+ query. This fixes it.
+
+ NB: This loads all of the query reuslts into memeory, there isn't a good
+ way around this, we're assuming sqlite users have small databases.
+ """
+ # If it's SQLite just return all the objects
+ if db.bind.url.drivername == "sqlite":
+ return [obj for obj in db.execute(query)]
+
+ # Postgres return the query as it knows how to deal with it.
+ return db.execute(query)
+
+
+def populate_table_foundations(session, foundations, name,
+ printer=simple_printer):
+ """
+ Create the table foundations (default rows) as layed out in FOUNDATIONS
+ in mediagoblin.db.models
+ """
+ printer(u'Laying foundations for %s:\n' % name)
+ for Model, rows in foundations.items():
+ printer(u' + Laying foundations for %s table\n' %
+ (Model.__name__))
+ for parameters in rows:
+ new_row = Model(**parameters)
+ session.add(new_row)
+
+ session.commit()
+
+
+def build_alembic_config(global_config, cmd_options, session):
+ """
+ Build up a config that the alembic tooling can use based on our
+ configuration. Initialize the database session appropriately
+ as well.
+ """
+ root_dir = os.path.abspath(os.path.dirname(os.path.dirname(
+ os.path.dirname(__file__))))
+ alembic_cfg_path = os.path.join(root_dir, 'alembic.ini')
+ cfg = Config(alembic_cfg_path,
+ cmd_opts=cmd_options)
+ cfg.attributes["session"] = session
+
+ version_locations = [
+ pkg_resources.resource_filename(
+ "mediagoblin.db", os.path.join("migrations", "versions")),
+ ]
+
+ cfg.set_main_option("sqlalchemy.url", str(session.get_bind().url))
+
+ for plugin in global_config.get("plugins", []):
+ plugin_migrations = pkg_resources.resource_filename(
+ plugin, "migrations")
+ is_migrations_dir = (os.path.exists(plugin_migrations) and
+ os.path.isdir(plugin_migrations))
+ if is_migrations_dir:
+ version_locations.append(plugin_migrations)
+
+ cfg.set_main_option(
+ "version_locations",
+ " ".join(version_locations))
+
+ return cfg