aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mediagoblin/decorators.py28
-rw-r--r--mediagoblin/tests/tools.py9
2 files changed, 18 insertions, 19 deletions
diff --git a/mediagoblin/decorators.py b/mediagoblin/decorators.py
index b2791083..0eb1361d 100644
--- a/mediagoblin/decorators.py
+++ b/mediagoblin/decorators.py
@@ -14,27 +14,19 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+from functools import wraps
from webob import exc
-from mediagoblin.tools.response import redirect, render_404
from mediagoblin.db.util import ObjectId, InvalidId
-
-
-def _make_safe(decorator, original):
- """
- Copy the function data from the old function to the decorator.
- """
- decorator.__name__ = original.__name__
- decorator.__dict__ = original.__dict__
- decorator.__doc__ = original.__doc__
- return decorator
+from mediagoblin.tools.response import redirect, render_404
def require_active_login(controller):
"""
Require an active login from the user.
"""
+ @wraps(controller)
def new_controller_func(request, *args, **kwargs):
if request.user and \
request.user.get('status') == u'needs_email_verification':
@@ -49,13 +41,14 @@ def require_active_login(controller):
return controller(request, *args, **kwargs)
- return _make_safe(new_controller_func, controller)
+ return new_controller_func
def user_may_delete_media(controller):
"""
Require user ownership of the MediaEntry to delete.
"""
+ @wraps(controller)
def wrapper(request, *args, **kwargs):
uploader_id = request.db.MediaEntry.find_one(
{'_id': ObjectId(request.matchdict['media'])}).uploader
@@ -65,13 +58,14 @@ def user_may_delete_media(controller):
return controller(request, *args, **kwargs)
- return _make_safe(wrapper, controller)
+ return wrapper
def uses_pagination(controller):
"""
Check request GET 'page' key for wrong values
"""
+ @wraps(controller)
def wrapper(request, *args, **kwargs):
try:
page = int(request.GET.get('page', 1))
@@ -82,13 +76,14 @@ def uses_pagination(controller):
return controller(request, page=page, *args, **kwargs)
- return _make_safe(wrapper, controller)
+ return wrapper
def get_user_media_entry(controller):
"""
Pass in a MediaEntry based off of a url component
"""
+ @wraps(controller)
def wrapper(request, *args, **kwargs):
user = request.db.User.find_one(
{'username': request.matchdict['user']})
@@ -116,13 +111,14 @@ def get_user_media_entry(controller):
return controller(request, media=media, *args, **kwargs)
- return _make_safe(wrapper, controller)
+ return wrapper
def get_media_entry_by_id(controller):
"""
Pass in a MediaEntry based off of a url component
"""
+ @wraps(controller)
def wrapper(request, *args, **kwargs):
try:
media = request.db.MediaEntry.find_one(
@@ -137,4 +133,4 @@ def get_media_entry_by_id(controller):
return controller(request, media=media, *args, **kwargs)
- return _make_safe(wrapper, controller)
+ return wrapper
diff --git a/mediagoblin/tests/tools.py b/mediagoblin/tests/tools.py
index bf40ea8b..a40569e4 100644
--- a/mediagoblin/tests/tools.py
+++ b/mediagoblin/tests/tools.py
@@ -15,8 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+import os
import pkg_resources
-import os, shutil
+import shutil
+
+from functools import wraps
from paste.deploy import loadapp
from webtest import TestApp
@@ -24,7 +27,6 @@ from webtest import TestApp
from mediagoblin import mg_globals
from mediagoblin.tools import testing
from mediagoblin.init.config import read_mediagoblin_config
-from mediagoblin.decorators import _make_safe
from mediagoblin.db.open import setup_connection_and_db_from_config
from mediagoblin.db.sql.base import Session
from mediagoblin.meddleware import BaseMeddleware
@@ -159,12 +161,13 @@ def setup_fresh_app(func):
Cleans out test buckets and passes in a new, fresh test_app.
"""
+ @wraps(func)
def wrapper(*args, **kwargs):
test_app = get_test_app()
testing.clear_test_buckets()
return func(test_app, *args, **kwargs)
- return _make_safe(wrapper, func)
+ return wrapper
def install_fixtures_simple(db, fixtures):