aboutsummaryrefslogtreecommitdiffstats
path: root/mediagoblin/db/migration_tools.py
diff options
context:
space:
mode:
Diffstat (limited to 'mediagoblin/db/migration_tools.py')
-rw-r--r--mediagoblin/db/migration_tools.py58
1 files changed, 56 insertions, 2 deletions
diff --git a/mediagoblin/db/migration_tools.py b/mediagoblin/db/migration_tools.py
index e39070c3..fae98643 100644
--- a/mediagoblin/db/migration_tools.py
+++ b/mediagoblin/db/migration_tools.py
@@ -14,14 +14,68 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+from __future__ import unicode_literals
+
+import logging
+import os
+
+from alembic import command
+from alembic.config import Config
+from alembic.migration import MigrationContext
+
+from mediagoblin.db.base import Base
from mediagoblin.tools.common import simple_printer
from sqlalchemy import Table
from sqlalchemy.sql import select
+log = logging.getLogger(__name__)
+
+
class TableAlreadyExists(Exception):
pass
+class AlembicMigrationManager(object):
+
+ def __init__(self, session):
+ root_dir = os.path.abspath(os.path.dirname(os.path.dirname(
+ os.path.dirname(__file__))))
+ alembic_cfg_path = os.path.join(root_dir, 'alembic.ini')
+ self.alembic_cfg = Config(alembic_cfg_path)
+ self.session = session
+
+ def get_current_revision(self):
+ context = MigrationContext.configure(self.session.bind)
+ return context.get_current_revision()
+
+ def upgrade(self, version):
+ return command.upgrade(self.alembic_cfg, version or 'head')
+
+ def downgrade(self, version):
+ if isinstance(version, int) or version is None or version.isdigit():
+ version = 'base'
+ return command.downgrade(self.alembic_cfg, version)
+
+ def stamp(self, revision):
+ return command.stamp(self.alembic_cfg, revision=revision)
+
+ def init_tables(self):
+ Base.metadata.create_all(self.session.bind)
+ # load the Alembic configuration and generate the
+ # version table, "stamping" it with the most recent rev:
+ # XXX: we need to find a better way to detect current installations
+ # using sqlalchemy-migrate because we don't have to create all table
+ # for them
+ command.stamp(self.alembic_cfg, 'head')
+
+ def init_or_migrate(self, version=None):
+ # XXX: we need to call this method when we ditch
+ # sqlalchemy-migrate entirely
+ # if self.get_current_revision() is None:
+ # self.init_tables()
+ self.upgrade(version)
+
+
class MigrationManager(object):
"""
Migration handling tool.
@@ -39,7 +93,7 @@ class MigrationManager(object):
- migration_registry: where we should find all migrations to
run
"""
- self.name = unicode(name)
+ self.name = name
self.models = models
self.foundations = foundations
self.session = session
@@ -230,7 +284,7 @@ class MigrationManager(object):
for migration_number, migration_func in migrations_to_run:
self.printer(
u' + Running migration %s, "%s"... ' % (
- migration_number, migration_func.func_name))
+ migration_number, migration_func.__name__))
migration_func(self.session)
self.set_current_migration(migration_number)
self.printer('done.\n')