diff options
Diffstat (limited to 'mediagoblin/db/migration_tools.py')
-rw-r--r-- | mediagoblin/db/migration_tools.py | 58 |
1 files changed, 56 insertions, 2 deletions
diff --git a/mediagoblin/db/migration_tools.py b/mediagoblin/db/migration_tools.py index e39070c3..fae98643 100644 --- a/mediagoblin/db/migration_tools.py +++ b/mediagoblin/db/migration_tools.py @@ -14,14 +14,68 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +from __future__ import unicode_literals + +import logging +import os + +from alembic import command +from alembic.config import Config +from alembic.migration import MigrationContext + +from mediagoblin.db.base import Base from mediagoblin.tools.common import simple_printer from sqlalchemy import Table from sqlalchemy.sql import select +log = logging.getLogger(__name__) + + class TableAlreadyExists(Exception): pass +class AlembicMigrationManager(object): + + def __init__(self, session): + root_dir = os.path.abspath(os.path.dirname(os.path.dirname( + os.path.dirname(__file__)))) + alembic_cfg_path = os.path.join(root_dir, 'alembic.ini') + self.alembic_cfg = Config(alembic_cfg_path) + self.session = session + + def get_current_revision(self): + context = MigrationContext.configure(self.session.bind) + return context.get_current_revision() + + def upgrade(self, version): + return command.upgrade(self.alembic_cfg, version or 'head') + + def downgrade(self, version): + if isinstance(version, int) or version is None or version.isdigit(): + version = 'base' + return command.downgrade(self.alembic_cfg, version) + + def stamp(self, revision): + return command.stamp(self.alembic_cfg, revision=revision) + + def init_tables(self): + Base.metadata.create_all(self.session.bind) + # load the Alembic configuration and generate the + # version table, "stamping" it with the most recent rev: + # XXX: we need to find a better way to detect current installations + # using sqlalchemy-migrate because we don't have to create all table + # for them + command.stamp(self.alembic_cfg, 'head') + + def init_or_migrate(self, version=None): + # XXX: we need to call this method when we ditch + # sqlalchemy-migrate entirely + # if self.get_current_revision() is None: + # self.init_tables() + self.upgrade(version) + + class MigrationManager(object): """ Migration handling tool. @@ -39,7 +93,7 @@ class MigrationManager(object): - migration_registry: where we should find all migrations to run """ - self.name = unicode(name) + self.name = name self.models = models self.foundations = foundations self.session = session @@ -230,7 +284,7 @@ class MigrationManager(object): for migration_number, migration_func in migrations_to_run: self.printer( u' + Running migration %s, "%s"... ' % ( - migration_number, migration_func.func_name)) + migration_number, migration_func.__name__)) migration_func(self.session) self.set_current_migration(migration_number) self.printer('done.\n') |