aboutsummaryrefslogtreecommitdiffstats
path: root/mediagoblin/db/open.py
diff options
context:
space:
mode:
Diffstat (limited to 'mediagoblin/db/open.py')
-rw-r--r--mediagoblin/db/open.py134
1 files changed, 111 insertions, 23 deletions
diff --git a/mediagoblin/db/open.py b/mediagoblin/db/open.py
index 4ff0945f..8f81c8d9 100644
--- a/mediagoblin/db/open.py
+++ b/mediagoblin/db/open.py
@@ -15,38 +15,117 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-from sqlalchemy import create_engine, event
+from contextlib import contextmanager
import logging
-from mediagoblin.db.base import Base, Session
+import six
+from sqlalchemy import create_engine, event
+
from mediagoblin import mg_globals
+from mediagoblin.db.base import Base
_log = logging.getLogger(__name__)
+from mediagoblin.tools.transition import DISABLE_GLOBALS
+
+def set_models_as_attributes(obj):
+ """
+ Set all models as attributes on this object, for convenience
+
+ TODO: This should eventually be deprecated.
+ """
+ for k, v in six.iteritems(Base._decl_class_registry):
+ setattr(obj, k, v)
+
+
+if not DISABLE_GLOBALS:
+ from mediagoblin.db.base import Session
+
+ class DatabaseMaster(object):
+ def __init__(self, engine):
+ self.engine = engine
+
+ set_models_as_attributes(self)
+
+ def commit(self):
+ Session.commit()
-class DatabaseMaster(object):
- def __init__(self, engine):
- self.engine = engine
+ def save(self, obj):
+ Session.add(obj)
+ Session.flush()
- for k, v in Base._decl_class_registry.iteritems():
- setattr(self, k, v)
+ def check_session_clean(self):
+ for dummy in Session():
+ _log.warn("STRANGE: There are elements in the sql session. "
+ "Please report this and help us track this down.")
+ break
- def commit(self):
- Session.commit()
+ def reset_after_request(self):
+ Session.rollback()
+ Session.remove()
- def save(self, obj):
- Session.add(obj)
- Session.flush()
+ @property
+ def query(self):
+ return Session.query
- def check_session_clean(self):
- for dummy in Session():
- _log.warn("STRANGE: There are elements in the sql session. "
- "Please report this and help us track this down.")
- break
+else:
+ from sqlalchemy.orm import sessionmaker
+
+ class DatabaseManager(object):
+ """
+ Manage database connections.
+
+ The main method here is session_scope which can be used with a
+ "with" statement to get a session that is properly torn down
+ by the end of execution.
+ """
+ def __init__(self, engine):
+ self.engine = engine
+ self.Session = sessionmaker(bind=engine)
+ set_models_as_attributes(self)
+
+ @contextmanager
+ def session_scope(self):
+ """
+ This is a context manager, use like::
+
+ with dbmanager.session_scope() as request.db:
+ some_view(request)
+ """
+ session = self.Session()
+
+ #####################################
+ # Functions to emulate DatabaseMaster
+ #####################################
+ def save(obj):
+ session.add(obj)
+ session.flush()
+
+ def check_session_clean():
+ # Is this implemented right?
+ for dummy in session:
+ _log.warn("STRANGE: There are elements in the sql session. "
+ "Please report this and help us track this down.")
+ break
+
+ def reset_after_request():
+ session.rollback()
+ session.remove()
+
+ # now attach
+ session.save = save
+ session.check_session_clean = check_session_clean
+ session.reset_after_request = reset_after_request
+
+ set_models_as_attributes(session)
+ #####################################
+
+ try:
+ yield session
+ finally:
+ session.rollback()
+ session.close()
- def reset_after_request(self):
- Session.rollback()
- Session.remove()
def load_models(app_config):
@@ -75,9 +154,14 @@ def _sqlite_disable_fk_pragma_on_connect(dbapi_con, con_record):
dbapi_con.execute('pragma foreign_keys=off')
-def setup_connection_and_db_from_config(app_config, migrations=False):
+def setup_connection_and_db_from_config(app_config, migrations=False, app=None):
engine = create_engine(app_config['sql_engine'])
+ # @@: Maybe make a weak-ref so an engine can get garbage
+ # collected? Not that we expect to make a lot of MediaGoblinApp
+ # instances in a single process...
+ engine.app = app
+
# Enable foreign key checking for sqlite
if app_config['sql_engine'].startswith('sqlite://'):
if migrations:
@@ -88,9 +172,13 @@ def setup_connection_and_db_from_config(app_config, migrations=False):
# logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
- Session.configure(bind=engine)
+ if DISABLE_GLOBALS:
+ return DatabaseManager(engine)
+
+ else:
+ Session.configure(bind=engine)
- return DatabaseMaster(engine)
+ return DatabaseMaster(engine)
def check_db_migrations_current(db):