aboutsummaryrefslogtreecommitdiffstats
path: root/mediagoblin/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'mediagoblin/plugins')
-rw-r--r--mediagoblin/plugins/api/__init__.py4
-rw-r--r--mediagoblin/plugins/httpapiauth/__init__.py3
-rw-r--r--mediagoblin/plugins/oauth/__init__.py2
-rw-r--r--mediagoblin/plugins/oauth/forms.py2
-rw-r--r--mediagoblin/plugins/oauth/migrations.py34
-rw-r--r--mediagoblin/plugins/oauth/models.py87
-rw-r--r--mediagoblin/plugins/oauth/tools.py73
-rw-r--r--mediagoblin/plugins/oauth/views.py150
-rw-r--r--mediagoblin/plugins/piwigo/__init__.py5
-rw-r--r--mediagoblin/plugins/piwigo/forms.py16
-rw-r--r--mediagoblin/plugins/piwigo/tools.py51
-rw-r--r--mediagoblin/plugins/piwigo/views.py68
12 files changed, 367 insertions, 128 deletions
diff --git a/mediagoblin/plugins/api/__init__.py b/mediagoblin/plugins/api/__init__.py
index d3fdf2ef..1eddd9e0 100644
--- a/mediagoblin/plugins/api/__init__.py
+++ b/mediagoblin/plugins/api/__init__.py
@@ -23,11 +23,11 @@ _log = logging.getLogger(__name__)
PLUGIN_DIR = os.path.dirname(__file__)
-config = pluginapi.get_config(__name__)
-
def setup_plugin():
_log.info('Setting up API...')
+ config = pluginapi.get_config(__name__)
+
_log.debug('API config: {0}'.format(config))
routes = [
diff --git a/mediagoblin/plugins/httpapiauth/__init__.py b/mediagoblin/plugins/httpapiauth/__init__.py
index 081b590e..99b6a4b0 100644
--- a/mediagoblin/plugins/httpapiauth/__init__.py
+++ b/mediagoblin/plugins/httpapiauth/__init__.py
@@ -15,9 +15,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
-import base64
-from werkzeug.exceptions import BadRequest, Unauthorized
+from werkzeug.exceptions import Unauthorized
from mediagoblin.plugins.api.tools import Auth
diff --git a/mediagoblin/plugins/oauth/__init__.py b/mediagoblin/plugins/oauth/__init__.py
index 4714d95d..5762379d 100644
--- a/mediagoblin/plugins/oauth/__init__.py
+++ b/mediagoblin/plugins/oauth/__init__.py
@@ -34,7 +34,7 @@ def setup_plugin():
_log.debug('OAuth config: {0}'.format(config))
routes = [
- ('mediagoblin.plugins.oauth.authorize',
+ ('mediagoblin.plugins.oauth.authorize',
'/oauth/authorize',
'mediagoblin.plugins.oauth.views:authorize'),
('mediagoblin.plugins.oauth.authorize_client',
diff --git a/mediagoblin/plugins/oauth/forms.py b/mediagoblin/plugins/oauth/forms.py
index d0a4e9b8..5edd992a 100644
--- a/mediagoblin/plugins/oauth/forms.py
+++ b/mediagoblin/plugins/oauth/forms.py
@@ -19,7 +19,7 @@ import wtforms
from urlparse import urlparse
from mediagoblin.tools.extlib.wtf_html5 import URLField
-from mediagoblin.tools.translate import fake_ugettext_passthrough as _
+from mediagoblin.tools.translate import lazy_pass_to_ugettext as _
class AuthorizationForm(wtforms.Form):
diff --git a/mediagoblin/plugins/oauth/migrations.py b/mediagoblin/plugins/oauth/migrations.py
index 6aa0d7cb..d7b89da3 100644
--- a/mediagoblin/plugins/oauth/migrations.py
+++ b/mediagoblin/plugins/oauth/migrations.py
@@ -102,6 +102,21 @@ class OAuthCode_v0(declarative_base()):
client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
+class OAuthRefreshToken_v0(declarative_base()):
+ __tablename__ = 'oauth__refresh_tokens'
+
+ id = Column(Integer, primary_key=True)
+ created = Column(DateTime, nullable=False,
+ default=datetime.now)
+
+ token = Column(Unicode, index=True)
+
+ user_id = Column(Integer, ForeignKey(User.id), nullable=False)
+
+ # XXX: Is it OK to use OAuthClient_v0.id in this way?
+ client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False)
+
+
@RegisterMigration(1, MIGRATIONS)
def remove_and_replace_token_and_code(db):
metadata = MetaData(bind=db.bind)
@@ -122,3 +137,22 @@ def remove_and_replace_token_and_code(db):
OAuthCode_v0.__table__.create(db.bind)
db.commit()
+
+
+@RegisterMigration(2, MIGRATIONS)
+def remove_refresh_token_field(db):
+ metadata = MetaData(bind=db.bind)
+
+ token_table = Table('oauth__tokens', metadata, autoload=True,
+ autoload_with=db.bind)
+
+ refresh_token = token_table.columns['refresh_token']
+
+ refresh_token.drop()
+ db.commit()
+
+@RegisterMigration(3, MIGRATIONS)
+def create_refresh_token_table(db):
+ OAuthRefreshToken_v0.__table__.create(db.bind)
+
+ db.commit()
diff --git a/mediagoblin/plugins/oauth/models.py b/mediagoblin/plugins/oauth/models.py
index 695dad31..439424d3 100644
--- a/mediagoblin/plugins/oauth/models.py
+++ b/mediagoblin/plugins/oauth/models.py
@@ -14,17 +14,17 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-import uuid
-import bcrypt
from datetime import datetime, timedelta
-from mediagoblin.db.base import Base
-from mediagoblin.db.models import User
from sqlalchemy import (
Column, Unicode, Integer, DateTime, ForeignKey, Enum)
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, backref
+from mediagoblin.db.base import Base
+from mediagoblin.db.models import User
+from mediagoblin.plugins.oauth.tools import generate_identifier, \
+ generate_secret, generate_token, generate_code, generate_refresh_token
# Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto
# the models.
@@ -41,11 +41,14 @@ class OAuthClient(Base):
name = Column(Unicode)
description = Column(Unicode)
- identifier = Column(Unicode, unique=True, index=True)
- secret = Column(Unicode, index=True)
+ identifier = Column(Unicode, unique=True, index=True,
+ default=generate_identifier)
+ secret = Column(Unicode, index=True, default=generate_secret)
owner_id = Column(Integer, ForeignKey(User.id))
- owner = relationship(User, backref='registered_clients')
+ owner = relationship(
+ User,
+ backref=backref('registered_clients', cascade='all, delete-orphan'))
redirect_uri = Column(Unicode)
@@ -54,14 +57,8 @@ class OAuthClient(Base):
u'public',
name=u'oauth__client_type'))
- def generate_identifier(self):
- self.identifier = unicode(uuid.uuid4())
-
- def generate_secret(self):
- self.secret = unicode(
- bcrypt.hashpw(
- unicode(uuid.uuid4()),
- bcrypt.gensalt()))
+ def update_secret(self):
+ self.secret = generate_secret()
def __repr__(self):
return '<{0} {1}:{2} ({3})>'.format(
@@ -76,10 +73,15 @@ class OAuthUserClient(Base):
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey(User.id))
- user = relationship(User, backref='oauth_clients')
+ user = relationship(
+ User,
+ backref=backref('oauth_client_relations',
+ cascade='all, delete-orphan'))
client_id = Column(Integer, ForeignKey(OAuthClient.id))
- client = relationship(OAuthClient, backref='users')
+ client = relationship(
+ OAuthClient,
+ backref=backref('oauth_user_relations', cascade='all, delete-orphan'))
state = Column(Enum(
u'approved',
@@ -103,15 +105,18 @@ class OAuthToken(Base):
default=datetime.now)
expires = Column(DateTime, nullable=False,
default=lambda: datetime.now() + timedelta(days=30))
- token = Column(Unicode, index=True)
- refresh_token = Column(Unicode, index=True)
+ token = Column(Unicode, index=True, default=generate_token)
user_id = Column(Integer, ForeignKey(User.id), nullable=False,
index=True)
- user = relationship(User)
+ user = relationship(
+ User,
+ backref=backref('oauth_tokens', cascade='all, delete-orphan'))
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
- client = relationship(OAuthClient)
+ client = relationship(
+ OAuthClient,
+ backref=backref('oauth_tokens', cascade='all, delete-orphan'))
def __repr__(self):
return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
@@ -121,6 +126,34 @@ class OAuthToken(Base):
self.user,
self.client)
+class OAuthRefreshToken(Base):
+ __tablename__ = 'oauth__refresh_tokens'
+
+ id = Column(Integer, primary_key=True)
+ created = Column(DateTime, nullable=False,
+ default=datetime.now)
+
+ token = Column(Unicode, index=True,
+ default=generate_refresh_token)
+
+ user_id = Column(Integer, ForeignKey(User.id), nullable=False)
+
+ user = relationship(User, backref=backref('oauth_refresh_tokens',
+ cascade='all, delete-orphan'))
+
+ client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
+ client = relationship(OAuthClient,
+ backref=backref(
+ 'oauth_refresh_tokens',
+ cascade='all, delete-orphan'))
+
+ def __repr__(self):
+ return '<{0} #{1} [{3}, {4}]>'.format(
+ self.__class__.__name__,
+ self.id,
+ self.user,
+ self.client)
+
class OAuthCode(Base):
__tablename__ = 'oauth__codes'
@@ -130,14 +163,17 @@ class OAuthCode(Base):
default=datetime.now)
expires = Column(DateTime, nullable=False,
default=lambda: datetime.now() + timedelta(minutes=5))
- code = Column(Unicode, index=True)
+ code = Column(Unicode, index=True, default=generate_code)
user_id = Column(Integer, ForeignKey(User.id), nullable=False,
index=True)
- user = relationship(User)
+ user = relationship(User, backref=backref('oauth_codes',
+ cascade='all, delete-orphan'))
client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False)
- client = relationship(OAuthClient)
+ client = relationship(OAuthClient, backref=backref(
+ 'oauth_codes',
+ cascade='all, delete-orphan'))
def __repr__(self):
return '<{0} #{1} expires {2} [{3}, {4}]>'.format(
@@ -150,6 +186,7 @@ class OAuthCode(Base):
MODELS = [
OAuthToken,
+ OAuthRefreshToken,
OAuthCode,
OAuthClient,
OAuthUserClient]
diff --git a/mediagoblin/plugins/oauth/tools.py b/mediagoblin/plugins/oauth/tools.py
index d21c8a5b..27ff32b4 100644
--- a/mediagoblin/plugins/oauth/tools.py
+++ b/mediagoblin/plugins/oauth/tools.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
# GNU MediaGoblin -- federated, autonomous media hosting
# Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
#
@@ -14,13 +15,26 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
+import uuid
+
+from random import getrandbits
+
+from datetime import datetime
+
from functools import wraps
-from mediagoblin.plugins.oauth.models import OAuthClient
from mediagoblin.plugins.api.tools import json_response
def require_client_auth(controller):
+ '''
+ View decorator
+
+ - Requires the presence of ``?client_id``
+ '''
+ # Avoid circular import
+ from mediagoblin.plugins.oauth.models import OAuthClient
+
@wraps(controller)
def wrapper(request, *args, **kw):
if not request.GET.get('client_id'):
@@ -41,3 +55,60 @@ def require_client_auth(controller):
return controller(request, client)
return wrapper
+
+
+def create_token(client, user):
+ '''
+ Create an OAuthToken and an OAuthRefreshToken entry in the database
+
+ Returns the data structure expected by the OAuth clients.
+ '''
+ from mediagoblin.plugins.oauth.models import OAuthToken, OAuthRefreshToken
+
+ token = OAuthToken()
+ token.user = user
+ token.client = client
+ token.save()
+
+ refresh_token = OAuthRefreshToken()
+ refresh_token.user = user
+ refresh_token.client = client
+ refresh_token.save()
+
+ # expire time of token in full seconds
+ # timedelta.total_seconds is python >= 2.7 or we would use that
+ td = token.expires - datetime.now()
+ exp_in = 86400*td.days + td.seconds # just ignore µsec
+
+ return {'access_token': token.token, 'token_type': 'bearer',
+ 'refresh_token': refresh_token.token, 'expires_in': exp_in}
+
+
+def generate_identifier():
+ ''' Generates a ``uuid.uuid4()`` '''
+ return unicode(uuid.uuid4())
+
+
+def generate_token():
+ ''' Uses generate_identifier '''
+ return generate_identifier()
+
+
+def generate_refresh_token():
+ ''' Uses generate_identifier '''
+ return generate_identifier()
+
+
+def generate_code():
+ ''' Uses generate_identifier '''
+ return generate_identifier()
+
+
+def generate_secret():
+ '''
+ Generate a long string of pseudo-random characters
+ '''
+ # XXX: We might not want it to use bcrypt, since bcrypt takes its time to
+ # generate the result.
+ return unicode(getrandbits(192))
+
diff --git a/mediagoblin/plugins/oauth/views.py b/mediagoblin/plugins/oauth/views.py
index ea45c209..d6fd314f 100644
--- a/mediagoblin/plugins/oauth/views.py
+++ b/mediagoblin/plugins/oauth/views.py
@@ -16,21 +16,21 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
-import json
from urllib import urlencode
-from uuid import uuid4
-from datetime import datetime
+
+from werkzeug.exceptions import BadRequest
from mediagoblin.tools.response import render_to_response, redirect
from mediagoblin.decorators import require_active_login
-from mediagoblin.messages import add_message, SUCCESS, ERROR
+from mediagoblin.messages import add_message, SUCCESS
from mediagoblin.tools.translate import pass_to_ugettext as _
-from mediagoblin.plugins.oauth.models import OAuthCode, OAuthToken, \
- OAuthClient, OAuthUserClient
+from mediagoblin.plugins.oauth.models import OAuthCode, OAuthClient, \
+ OAuthUserClient, OAuthRefreshToken
from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \
AuthorizationForm
-from mediagoblin.plugins.oauth.tools import require_client_auth
+from mediagoblin.plugins.oauth.tools import require_client_auth, \
+ create_token
from mediagoblin.plugins.api.tools import json_response
_log = logging.getLogger(__name__)
@@ -51,9 +51,6 @@ def register_client(request):
client.owner_id = request.user.id
client.redirect_uri = unicode(form.redirect_uri.data)
- client.generate_identifier()
- client.generate_secret()
-
client.save()
add_message(request, SUCCESS, _('The client {0} has been registered!')\
@@ -92,9 +89,9 @@ def authorize_client(request):
form.client_id.data).first()
if not client:
- _log.error('''No such client id as received from client authorization
- form.''')
- return BadRequest()
+ _log.error('No such client id as received from client authorization \
+form.')
+ raise BadRequest()
if form.validate():
relation = OAuthUserClient()
@@ -105,7 +102,7 @@ def authorize_client(request):
elif form.deny.data:
relation.state = u'rejected'
else:
- return BadRequest
+ raise BadRequest()
relation.save()
@@ -136,7 +133,7 @@ def authorize(request, client):
return json_response({
'status': 400,
'errors':
- [u'Public clients MUST have a redirect_uri pre-set']},
+ [u'Public clients should have a redirect_uri pre-set.']},
_disable_cors=True)
redirect_uri = client.redirect_uri
@@ -146,11 +143,10 @@ def authorize(request, client):
if not redirect_uri:
return json_response({
'status': 400,
- 'errors': [u'Can not find a redirect_uri for client: {0}'\
- .format(client.name)]}, _disable_cors=True)
+ 'errors': [u'No redirect_uri supplied!']},
+ _disable_cors=True)
code = OAuthCode()
- code.code = unicode(uuid4())
code.user = request.user
code.client = client
code.save()
@@ -180,59 +176,79 @@ def authorize(request, client):
def access_token(request):
+ '''
+ Access token endpoint provides access tokens to any clients that have the
+ right grants/credentials
+ '''
+
+ client = None
+ user = None
+
if request.GET.get('code'):
+ # Validate the code arg, then get the client object from the db.
code = OAuthCode.query.filter(OAuthCode.code ==
request.GET.get('code')).first()
- if code:
- if code.client.type == u'confidential':
- client_identifier = request.GET.get('client_id')
-
- if not client_identifier:
- return json_response({
- 'error': 'invalid_request',
- 'error_description':
- 'Missing client_id in request'})
-
- client_secret = request.GET.get('client_secret')
-
- if not client_secret:
- return json_response({
- 'error': 'invalid_request',
- 'error_description':
- 'Missing client_secret in request'})
-
- if not client_secret == code.client.secret or \
- not client_identifier == code.client.identifier:
- return json_response({
- 'error': 'invalid_client',
- 'error_description':
- 'The client_id or client_secret does not match the'
- ' code'})
-
- token = OAuthToken()
- token.token = unicode(uuid4())
- token.user = code.user
- token.client = code.client
- token.save()
-
- # expire time of token in full seconds
- # timedelta.total_seconds is python >= 2.7 or we would use that
- td = token.expires - datetime.now()
- exp_in = 86400*td.days + td.seconds # just ignore µsec
-
- access_token_data = {
- 'access_token': token.token,
- 'token_type': 'bearer',
- 'expires_in': exp_in}
- return json_response(access_token_data, _disable_cors=True)
- else:
+ if not code:
return json_response({
'error': 'invalid_request',
'error_description':
- 'Invalid code'})
- else:
- return json_response({
- 'error': 'invalid_request',
- 'error_descriptin':
- 'Missing `code` parameter in request'})
+ 'Invalid code.'})
+
+ client = code.client
+ user = code.user
+
+ elif request.args.get('refresh_token'):
+ # Validate a refresh token, then get the client object from the db.
+ refresh_token = OAuthRefreshToken.query.filter(
+ OAuthRefreshToken.token ==
+ request.args.get('refresh_token')).first()
+
+ if not refresh_token:
+ return json_response({
+ 'error': 'invalid_request',
+ 'error_description':
+ 'Invalid refresh token.'})
+
+ client = refresh_token.client
+ user = refresh_token.user
+
+ if client:
+ client_identifier = request.GET.get('client_id')
+
+ if not client_identifier:
+ return json_response({
+ 'error': 'invalid_request',
+ 'error_description':
+ 'Missing client_id in request.'})
+
+ if not client_identifier == client.identifier:
+ return json_response({
+ 'error': 'invalid_client',
+ 'error_description':
+ 'Mismatching client credentials.'})
+
+ if client.type == u'confidential':
+ client_secret = request.GET.get('client_secret')
+
+ if not client_secret:
+ return json_response({
+ 'error': 'invalid_request',
+ 'error_description':
+ 'Missing client_secret in request.'})
+
+ if not client_secret == client.secret:
+ return json_response({
+ 'error': 'invalid_client',
+ 'error_description':
+ 'Mismatching client credentials.'})
+
+
+ access_token_data = create_token(client, user)
+
+ return json_response(access_token_data, _disable_cors=True)
+
+ return json_response({
+ 'error': 'invalid_request',
+ 'error_description':
+ 'Missing `code` or `refresh_token` parameter in request.'})
diff --git a/mediagoblin/plugins/piwigo/__init__.py b/mediagoblin/plugins/piwigo/__init__.py
index 73326e9e..c4da708a 100644
--- a/mediagoblin/plugins/piwigo/__init__.py
+++ b/mediagoblin/plugins/piwigo/__init__.py
@@ -17,6 +17,8 @@
import logging
from mediagoblin.tools import pluginapi
+from mediagoblin.tools.session import SessionManager
+from .tools import PWGSession
_log = logging.getLogger(__name__)
@@ -32,6 +34,9 @@ def setup_plugin():
pluginapi.register_routes(routes)
+ PWGSession.session_manager = SessionManager("pwg_id", "plugins.piwigo")
+
+
hooks = {
'setup': setup_plugin
}
diff --git a/mediagoblin/plugins/piwigo/forms.py b/mediagoblin/plugins/piwigo/forms.py
index 5bb12e62..18cbd5c5 100644
--- a/mediagoblin/plugins/piwigo/forms.py
+++ b/mediagoblin/plugins/piwigo/forms.py
@@ -26,3 +26,19 @@ class AddSimpleForm(wtforms.Form):
# tags = wtforms.FieldList(wtforms.TextField())
category = wtforms.IntegerField()
level = wtforms.IntegerField()
+
+
+_md5_validator = wtforms.validators.Regexp(r"^[0-9a-fA-F]{32}$")
+
+
+class AddForm(wtforms.Form):
+ original_sum = wtforms.TextField(None,
+ [_md5_validator,
+ wtforms.validators.Required()])
+ thumbnail_sum = wtforms.TextField(None,
+ [wtforms.validators.Optional(False),
+ _md5_validator])
+ file_sum = wtforms.TextField(None, [_md5_validator])
+ name = wtforms.TextField()
+ date_creation = wtforms.TextField()
+ categories = wtforms.TextField()
diff --git a/mediagoblin/plugins/piwigo/tools.py b/mediagoblin/plugins/piwigo/tools.py
index 85d77310..400be615 100644
--- a/mediagoblin/plugins/piwigo/tools.py
+++ b/mediagoblin/plugins/piwigo/tools.py
@@ -16,9 +16,11 @@
import logging
+import six
import lxml.etree as ET
-from werkzeug.exceptions import MethodNotAllowed
+from werkzeug.exceptions import MethodNotAllowed, BadRequest
+from mediagoblin.tools.request import setup_user_in_request
from mediagoblin.tools.response import Response
@@ -43,7 +45,7 @@ class PwgNamedArray(list):
def _fill_element_dict(el, data, as_attr=()):
for k, v in data.iteritems():
if k in as_attr:
- if not isinstance(v, basestring):
+ if not isinstance(v, six.string_types):
v = str(v)
el.set(k, v)
else:
@@ -57,7 +59,7 @@ def _fill_element(el, data):
el.text = "1"
else:
el.text = "0"
- elif isinstance(data, basestring):
+ elif isinstance(data, six.string_types):
el.text = data
elif isinstance(data, int):
el.text = str(data)
@@ -105,3 +107,46 @@ class CmdTable(object):
_log.warn("Method %s only allowed for POST", cmd_name)
raise MethodNotAllowed()
return func
+
+
+def check_form(form):
+ if not form.validate():
+ _log.error("form validation failed for form %r", form)
+ for f in form:
+ if len(f.error):
+ _log.error("Errors for %s: %r", f.name, f.errors)
+ raise BadRequest()
+ dump = []
+ for f in form:
+ dump.append("%s=%r" % (f.name, f.data))
+ _log.debug("form: %s", " ".join(dump))
+
+
+class PWGSession(object):
+ session_manager = None
+
+ def __init__(self, request):
+ self.request = request
+ self.in_pwg_session = False
+
+ def __enter__(self):
+ # Backup old state
+ self.old_session = self.request.session
+ self.old_user = self.request.user
+ # Load piwigo session into state
+ self.request.session = self.session_manager.load_session_from_cookie(
+ self.request)
+ setup_user_in_request(self.request)
+ self.in_pwg_session = True
+ return self
+
+ def __exit__(self, *args):
+ # Restore state
+ self.request.session = self.old_session
+ self.request.user = self.old_user
+ self.in_pwg_session = False
+
+ def save_to_cookie(self, response):
+ assert self.in_pwg_session
+ self.session_manager.save_session_to_cookie(self.request.session,
+ self.request, response)
diff --git a/mediagoblin/plugins/piwigo/views.py b/mediagoblin/plugins/piwigo/views.py
index 3dee09cd..b59247ad 100644
--- a/mediagoblin/plugins/piwigo/views.py
+++ b/mediagoblin/plugins/piwigo/views.py
@@ -17,14 +17,15 @@
import logging
import re
-from werkzeug.exceptions import MethodNotAllowed, BadRequest
+from werkzeug.exceptions import MethodNotAllowed, BadRequest, NotImplemented
from werkzeug.wrappers import BaseResponse
-from mediagoblin import mg_globals
from mediagoblin.meddleware.csrf import csrf_exempt
-from mediagoblin.tools.response import render_404
-from .tools import CmdTable, PwgNamedArray, response_xml
-from .forms import AddSimpleForm
+from mediagoblin.submit.lib import check_file_field
+from mediagoblin.auth.lib import fake_login_attempt
+from .tools import CmdTable, PwgNamedArray, response_xml, check_form, \
+ PWGSession
+from .forms import AddSimpleForm, AddForm
_log = logging.getLogger(__name__)
@@ -34,13 +35,25 @@ _log = logging.getLogger(__name__)
def pwg_login(request):
username = request.form.get("username")
password = request.form.get("password")
- _log.info("Login for %r/%r...", username, password)
+ _log.debug("Login for %r/%r...", username, password)
+ user = request.db.User.query.filter_by(username=username).first()
+ if not user:
+ _log.info("User %r not found", username)
+ fake_login_attempt()
+ return False
+ if not user.check_login(password):
+ _log.warn("Wrong password for %r", username)
+ return False
+ _log.info("Logging %r in", username)
+ request.session["user_id"] = user.id
+ request.session.save()
return True
@CmdTable("pwg.session.logout")
def pwg_logout(request):
_log.info("Logout")
+ request.session.delete()
return True
@@ -51,7 +64,11 @@ def pwg_getversion(request):
@CmdTable("pwg.session.getStatus")
def pwg_session_getStatus(request):
- return {'username': "fake_user"}
+ if request.user:
+ username = request.user.username
+ else:
+ username = "guest"
+ return {'username': username}
@CmdTable("pwg.categories.getList")
@@ -92,6 +109,9 @@ def pwg_images_addSimple(request):
dump.append("%s=%r" % (f.name, f.data))
_log.info("addimple: %r %s %r", request.form, " ".join(dump), request.files)
+ if not check_file_field(request, 'image'):
+ raise BadRequest()
+
return {'image_id': 123456, 'url': ''}
@@ -130,17 +150,13 @@ def pwg_images_addChunk(request):
return True
-def possibly_add_cookie(request, response):
- # TODO: We should only add a *real* cookie, if
- # authenticated. And if there is no cookie already.
- if True:
- response.set_cookie(
- 'pwg_id',
- "some_fake_for_now",
- path=request.environ['SCRIPT_NAME'],
- domain=mg_globals.app_config.get('csrf_cookie_domain'),
- secure=(request.scheme.lower() == 'https'),
- httponly=True)
+@CmdTable("pwg.images.add", True)
+def pwg_images_add(request):
+ _log.info("add: %r", request.form)
+ form = AddForm(request.form)
+ check_form(form)
+
+ return {'image_id': 123456, 'url': ''}
@csrf_exempt
@@ -153,15 +169,15 @@ def ws_php(request):
if not func:
_log.warn("wsphp: Unhandled %s %r %r", request.method,
request.args, request.form)
- return render_404(request)
-
- result = func(request)
+ raise NotImplemented()
- if isinstance(result, BaseResponse):
- return result
+ with PWGSession(request) as session:
+ result = func(request)
- response = response_xml(result)
+ if isinstance(result, BaseResponse):
+ return result
- possibly_add_cookie(request, response)
+ response = response_xml(result)
+ session.save_to_cookie(response)
- return response
+ return response