diff options
Diffstat (limited to 'mediagoblin/db/open.py')
-rw-r--r-- | mediagoblin/db/open.py | 134 |
1 files changed, 111 insertions, 23 deletions
diff --git a/mediagoblin/db/open.py b/mediagoblin/db/open.py index 4ff0945f..8f81c8d9 100644 --- a/mediagoblin/db/open.py +++ b/mediagoblin/db/open.py @@ -15,38 +15,117 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. -from sqlalchemy import create_engine, event +from contextlib import contextmanager import logging -from mediagoblin.db.base import Base, Session +import six +from sqlalchemy import create_engine, event + from mediagoblin import mg_globals +from mediagoblin.db.base import Base _log = logging.getLogger(__name__) +from mediagoblin.tools.transition import DISABLE_GLOBALS + +def set_models_as_attributes(obj): + """ + Set all models as attributes on this object, for convenience + + TODO: This should eventually be deprecated. + """ + for k, v in six.iteritems(Base._decl_class_registry): + setattr(obj, k, v) + + +if not DISABLE_GLOBALS: + from mediagoblin.db.base import Session + + class DatabaseMaster(object): + def __init__(self, engine): + self.engine = engine + + set_models_as_attributes(self) + + def commit(self): + Session.commit() -class DatabaseMaster(object): - def __init__(self, engine): - self.engine = engine + def save(self, obj): + Session.add(obj) + Session.flush() - for k, v in Base._decl_class_registry.iteritems(): - setattr(self, k, v) + def check_session_clean(self): + for dummy in Session(): + _log.warn("STRANGE: There are elements in the sql session. " + "Please report this and help us track this down.") + break - def commit(self): - Session.commit() + def reset_after_request(self): + Session.rollback() + Session.remove() - def save(self, obj): - Session.add(obj) - Session.flush() + @property + def query(self): + return Session.query - def check_session_clean(self): - for dummy in Session(): - _log.warn("STRANGE: There are elements in the sql session. " - "Please report this and help us track this down.") - break +else: + from sqlalchemy.orm import sessionmaker + + class DatabaseManager(object): + """ + Manage database connections. + + The main method here is session_scope which can be used with a + "with" statement to get a session that is properly torn down + by the end of execution. + """ + def __init__(self, engine): + self.engine = engine + self.Session = sessionmaker(bind=engine) + set_models_as_attributes(self) + + @contextmanager + def session_scope(self): + """ + This is a context manager, use like:: + + with dbmanager.session_scope() as request.db: + some_view(request) + """ + session = self.Session() + + ##################################### + # Functions to emulate DatabaseMaster + ##################################### + def save(obj): + session.add(obj) + session.flush() + + def check_session_clean(): + # Is this implemented right? + for dummy in session: + _log.warn("STRANGE: There are elements in the sql session. " + "Please report this and help us track this down.") + break + + def reset_after_request(): + session.rollback() + session.remove() + + # now attach + session.save = save + session.check_session_clean = check_session_clean + session.reset_after_request = reset_after_request + + set_models_as_attributes(session) + ##################################### + + try: + yield session + finally: + session.rollback() + session.close() - def reset_after_request(self): - Session.rollback() - Session.remove() def load_models(app_config): @@ -75,9 +154,14 @@ def _sqlite_disable_fk_pragma_on_connect(dbapi_con, con_record): dbapi_con.execute('pragma foreign_keys=off') -def setup_connection_and_db_from_config(app_config, migrations=False): +def setup_connection_and_db_from_config(app_config, migrations=False, app=None): engine = create_engine(app_config['sql_engine']) + # @@: Maybe make a weak-ref so an engine can get garbage + # collected? Not that we expect to make a lot of MediaGoblinApp + # instances in a single process... + engine.app = app + # Enable foreign key checking for sqlite if app_config['sql_engine'].startswith('sqlite://'): if migrations: @@ -88,9 +172,13 @@ def setup_connection_and_db_from_config(app_config, migrations=False): # logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) - Session.configure(bind=engine) + if DISABLE_GLOBALS: + return DatabaseManager(engine) + + else: + Session.configure(bind=engine) - return DatabaseMaster(engine) + return DatabaseMaster(engine) def check_db_migrations_current(db): |