aboutsummaryrefslogtreecommitdiffstats
path: root/mediagoblin/db/open.py
diff options
context:
space:
mode:
Diffstat (limited to 'mediagoblin/db/open.py')
-rw-r--r--mediagoblin/db/open.py98
1 files changed, 85 insertions, 13 deletions
diff --git a/mediagoblin/db/open.py b/mediagoblin/db/open.py
index f4c38511..0b1679fb 100644
--- a/mediagoblin/db/open.py
+++ b/mediagoblin/db/open.py
@@ -14,16 +14,88 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-try:
- from mediagoblin.db.sql_switch import use_sql
-except ImportError:
- use_sql = False
-
-if use_sql:
- from mediagoblin.db.sql.open import \
- setup_connection_and_db_from_config, check_db_migrations_current, \
- load_models
-else:
- from mediagoblin.db.mongo.open import \
- setup_connection_and_db_from_config, check_db_migrations_current, \
- load_models
+
+from sqlalchemy import create_engine, event
+import logging
+
+from mediagoblin.db.base import Base, Session
+from mediagoblin import mg_globals
+
+_log = logging.getLogger(__name__)
+
+
+class DatabaseMaster(object):
+ def __init__(self, engine):
+ self.engine = engine
+
+ for k, v in Base._decl_class_registry.iteritems():
+ setattr(self, k, v)
+
+ def commit(self):
+ Session.commit()
+
+ def save(self, obj):
+ Session.add(obj)
+ Session.flush()
+
+ def check_session_clean(self):
+ for dummy in Session():
+ _log.warn("STRANGE: There are elements in the sql session. "
+ "Please report this and help us track this down.")
+ break
+
+ def reset_after_request(self):
+ Session.rollback()
+ Session.remove()
+
+
+def load_models(app_config):
+ import mediagoblin.db.models
+
+ for media_type in app_config['media_types']:
+ _log.debug("Loading %s.models", media_type)
+ __import__(media_type + ".models")
+
+ for plugin in mg_globals.global_config.get('plugins', {}).keys():
+ _log.debug("Loading %s.models", plugin)
+ try:
+ __import__(plugin + ".models")
+ except ImportError as exc:
+ _log.debug("Could not load {0}.models: {1}".format(
+ plugin,
+ exc))
+
+
+def _sqlite_fk_pragma_on_connect(dbapi_con, con_record):
+ """Enable foreign key checking on each new sqlite connection"""
+ dbapi_con.execute('pragma foreign_keys=on')
+
+
+def _sqlite_disable_fk_pragma_on_connect(dbapi_con, con_record):
+ """
+ Disable foreign key checking on each new sqlite connection
+ (Good for migrations!)
+ """
+ dbapi_con.execute('pragma foreign_keys=off')
+
+
+def setup_connection_and_db_from_config(app_config, migrations=False):
+ engine = create_engine(app_config['sql_engine'])
+
+ # Enable foreign key checking for sqlite
+ if app_config['sql_engine'].startswith('sqlite://'):
+ if migrations:
+ event.listen(engine, 'connect',
+ _sqlite_disable_fk_pragma_on_connect)
+ else:
+ event.listen(engine, 'connect', _sqlite_fk_pragma_on_connect)
+
+ # logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
+
+ Session.configure(bind=engine)
+
+ return DatabaseMaster(engine)
+
+
+def check_db_migrations_current(db):
+ pass