Skip to content
Snippets Groups Projects
Commit e85eb91e authored by Greg Price's avatar Greg Price
Browse files

Add an endpoint to exchange OAuth access tokens

This allows the holder of a third-party access token (e.g. from Google
or Facebook) to get a first-party access token for the edX account
linked to the given access token.
parent 38a9d33b
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,17 @@ from uuid import uuid4
# import settings from LMS for consistent behavior with CMS
# pylint: disable=unused-import
from lms.envs.test import (WIKI_ENABLED, PLATFORM_NAME, SITE_NAME, DEFAULT_FILE_STORAGE, MEDIA_ROOT, MEDIA_URL)
from lms.envs.test import (
WIKI_ENABLED,
PLATFORM_NAME,
SITE_NAME,
DEFAULT_FILE_STORAGE,
MEDIA_ROOT,
MEDIA_URL,
# This is practically unused but needed by the oauth2_provider package, which
# some tests in common/ rely on.
OAUTH_OIDC_ISSUER,
)
# mongo connection settings
MONGO_PORT_NUM = int(os.environ.get('EDXAPP_TEST_MONGO_PORT', '27017'))
......
"""
Forms to support third-party to first-party OAuth 2.0 access token exchange
"""
from django.contrib.auth.models import User
from django.forms import CharField
from oauth2_provider.constants import SCOPE_NAMES
import provider.constants
from provider.forms import OAuthForm, OAuthValidationError
from provider.oauth2.forms import ScopeChoiceField, ScopeMixin
from provider.oauth2.models import Client
from requests import HTTPError
from social.backends import oauth as social_oauth
from third_party_auth import pipeline
class AccessTokenExchangeForm(ScopeMixin, OAuthForm):
"""Form for access token exchange endpoint"""
access_token = CharField(required=False)
scope = ScopeChoiceField(choices=SCOPE_NAMES, required=False)
client_id = CharField(required=False)
def __init__(self, request, *args, **kwargs):
super(AccessTokenExchangeForm, self).__init__(*args, **kwargs)
self.request = request
def _require_oauth_field(self, field_name):
"""
Raise an appropriate OAuthValidationError error if the field is missing
"""
field_val = self.cleaned_data.get(field_name)
if not field_val:
raise OAuthValidationError(
{
"error": "invalid_request",
"error_description": "{} is required".format(field_name),
}
)
return field_val
def clean_access_token(self):
return self._require_oauth_field("access_token")
def clean_client_id(self):
return self._require_oauth_field("client_id")
def clean(self):
if self._errors:
return {}
backend = self.request.social_strategy.backend
if not isinstance(backend, social_oauth.BaseOAuth2):
raise OAuthValidationError(
{
"error": "invalid_request",
"error_description": "{} is not a supported provider".format(backend.name),
}
)
self.request.session[pipeline.AUTH_ENTRY_KEY] = pipeline.AUTH_ENTRY_API
client_id = self.cleaned_data["client_id"]
try:
client = Client.objects.get(client_id=client_id)
except Client.DoesNotExist:
raise OAuthValidationError(
{
"error": "invalid_client",
"error_description": "{} is not a valid client_id".format(client_id),
}
)
if client.client_type != provider.constants.PUBLIC:
raise OAuthValidationError(
{
# invalid_client isn't really the right code, but this mirrors
# https://github.com/edx/django-oauth2-provider/blob/edx/provider/oauth2/forms.py#L331
"error": "invalid_client",
"error_description": "{} is not a public client".format(client_id),
}
)
self.cleaned_data["client"] = client
user = None
try:
user = backend.do_auth(self.cleaned_data.get("access_token"))
except HTTPError:
pass
if user and isinstance(user, User):
self.cleaned_data["user"] = user
else:
# Ensure user does not re-enter the pipeline
self.request.social_strategy.clean_partial_pipeline()
raise OAuthValidationError(
{
"error": "invalid_grant",
"error_description": "access_token is not valid",
}
)
return self.cleaned_data
"""
A models.py is required to make this an app (until we move to Django 1.7)
"""
"""
Tests for OAuth token exchange forms
"""
import unittest
from django.conf import settings
from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase
from django.test.client import RequestFactory
import httpretty
from provider import scope
import social.apps.django_app.utils as social_utils
from oauth_exchange.forms import AccessTokenExchangeForm
from oauth_exchange.tests.utils import (
AccessTokenExchangeTestMixin,
AccessTokenExchangeMixinFacebook,
AccessTokenExchangeMixinGoogle
)
class AccessTokenExchangeFormTest(AccessTokenExchangeTestMixin):
"""
Mixin that defines test cases for AccessTokenExchangeForm
"""
def setUp(self):
super(AccessTokenExchangeFormTest, self).setUp()
self.request = RequestFactory().post("dummy_url")
SessionMiddleware().process_request(self.request)
self.request.social_strategy = social_utils.load_strategy(self.request, self.BACKEND)
def _assert_error(self, data, expected_error, expected_error_description):
form = AccessTokenExchangeForm(request=self.request, data=data)
self.assertEqual(
form.errors,
{"error": expected_error, "error_description": expected_error_description}
)
self.assertNotIn("partial_pipeline", self.request.session)
def _assert_success(self, data, expected_scopes):
form = AccessTokenExchangeForm(request=self.request, data=data)
self.assertTrue(form.is_valid())
self.assertEqual(form.cleaned_data["user"], self.user)
self.assertEqual(form.cleaned_data["client"], self.oauth_client)
self.assertEqual(scope.to_names(form.cleaned_data["scope"]), expected_scopes)
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeFormTestFacebook(
AccessTokenExchangeFormTest,
AccessTokenExchangeMixinFacebook,
TestCase
):
"""
Tests for AccessTokenExchangeForm used with Facebook
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeFormTestGoogle(
AccessTokenExchangeFormTest,
AccessTokenExchangeMixinGoogle,
TestCase
):
"""
Tests for AccessTokenExchangeForm used with Google
"""
pass
"""
Tests for OAuth token exchange views
"""
from datetime import timedelta
import json
import mock
import unittest
from django.conf import settings
from django.core.urlresolvers import reverse
from django.test import TestCase
import httpretty
import provider.constants
from provider import scope
from provider.oauth2.models import AccessToken
from oauth_exchange.tests.utils import (
AccessTokenExchangeTestMixin,
AccessTokenExchangeMixinFacebook,
AccessTokenExchangeMixinGoogle
)
class AccessTokenExchangeViewTest(AccessTokenExchangeTestMixin):
"""
Mixin that defines test cases for AccessTokenExchangeView
"""
def setUp(self):
super(AccessTokenExchangeViewTest, self).setUp()
self.url = reverse("exchange_access_token", kwargs={"backend": self.BACKEND})
def _assert_error(self, data, expected_error, expected_error_description):
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, 400)
self.assertEqual(response["Content-Type"], "application/json")
self.assertEqual(
json.loads(response.content),
{"error": expected_error, "error_description": expected_error_description}
)
self.assertNotIn("partial_pipeline", self.client.session)
def _assert_success(self, data, expected_scopes):
response = self.client.post(self.url, data)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "application/json")
content = json.loads(response.content)
self.assertEqual(set(content.keys()), {"access_token", "token_type", "expires_in", "scope"})
self.assertEqual(content["token_type"], "Bearer")
self.assertLessEqual(
timedelta(seconds=int(content["expires_in"])),
provider.constants.EXPIRE_DELTA_PUBLIC
)
self.assertEqual(content["scope"], " ".join(expected_scopes))
token = AccessToken.objects.get(token=content["access_token"])
self.assertEqual(token.user, self.user)
self.assertEqual(token.client, self.oauth_client)
self.assertEqual(scope.to_names(token.scope), expected_scopes)
def test_single_access_token(self):
def extract_token(response):
return json.loads(response.content)["access_token"]
self._setup_provider_response(success=True)
for single_access_token in [True, False]:
with mock.patch(
"oauth_exchange.views.constants.SINGLE_ACCESS_TOKEN",
single_access_token
):
first_response = self.client.post(self.url, self.data)
second_response = self.client.post(self.url, self.data)
self.assertEqual(
extract_token(first_response) == extract_token(second_response),
single_access_token
)
def test_get_method(self):
response = self.client.get(self.url, self.data)
self.assertEqual(response.status_code, 400)
self.assertEqual(
json.loads(response.content),
{
"error": "invalid_request",
"error_description": "Only POST requests allowed.",
}
)
def test_invalid_provider(self):
url = reverse("exchange_access_token", kwargs={"backend": "invalid"})
response = self.client.post(url, self.data)
self.assertEqual(response.status_code, 404)
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestFacebook(
AccessTokenExchangeViewTest,
AccessTokenExchangeMixinFacebook,
TestCase
):
"""
Tests for AccessTokenExchangeView used with Facebook
"""
pass
# This is necessary because cms does not implement third party auth
@unittest.skipUnless(settings.FEATURES.get("ENABLE_THIRD_PARTY_AUTH"), "third party auth not enabled")
@httpretty.activate
class AccessTokenExchangeViewTestGoogle(
AccessTokenExchangeViewTest,
AccessTokenExchangeMixinGoogle,
TestCase
):
"""
Tests for AccessTokenExchangeView used with Google
"""
pass
"""
Test utilities for OAuth access token exchange
"""
import json
import httpretty
import provider.constants
from provider.oauth2.models import Client
from social.apps.django_app.default.models import UserSocialAuth
from student.tests.factories import UserFactory
class AccessTokenExchangeTestMixin(object):
"""
A mixin to define test cases for access token exchange. The following
methods must be implemented by subclasses:
* _assert_error(data, expected_error, expected_error_description)
* _assert_success(data, expected_scopes)
"""
def setUp(self):
super(AccessTokenExchangeTestMixin, self).setUp()
self.client_id = "test_client_id"
self.oauth_client = Client.objects.create(
client_id=self.client_id,
client_type=provider.constants.PUBLIC
)
self.social_uid = "test_social_uid"
self.user = UserFactory()
UserSocialAuth.objects.create(user=self.user, provider=self.BACKEND, uid=self.social_uid)
self.access_token = "test_access_token"
# Initialize to minimal data
self.data = {
"access_token": self.access_token,
"client_id": self.client_id,
}
def _setup_provider_response(self, success):
"""
Register a mock response for the third party user information endpoint;
success indicates whether the response status code should be 200 or 400
"""
if success:
status = 200
body = json.dumps({self.UID_FIELD: self.social_uid})
else:
status = 400
body = json.dumps({})
httpretty.register_uri(
httpretty.GET,
self.USER_URL,
body=body,
status=status,
content_type="application/json"
)
def _assert_error(self, _data, _expected_error, _expected_error_description):
"""
Given request data, execute a test and check that the expected error
was returned (along with any other appropriate assertions).
"""
raise NotImplementedError()
def _assert_success(self, data, expected_scopes):
"""
Given request data, execute a test and check that the expected scopes
were returned (along with any other appropriate assertions).
"""
raise NotImplementedError()
def test_minimal(self):
self._setup_provider_response(success=True)
self._assert_success(self.data, expected_scopes=[])
def test_scopes(self):
self._setup_provider_response(success=True)
self.data["scope"] = "profile email"
self._assert_success(self.data, expected_scopes=["profile", "email"])
def test_missing_fields(self):
for field in ["access_token", "client_id"]:
data = dict(self.data)
del data[field]
self._assert_error(data, "invalid_request", "{} is required".format(field))
def test_invalid_client(self):
self.data["client_id"] = "nonexistent_client"
self._assert_error(
self.data,
"invalid_client",
"nonexistent_client is not a valid client_id"
)
def test_confidential_client(self):
self.oauth_client.client_type = provider.constants.CONFIDENTIAL
self.oauth_client.save()
self._assert_error(
self.data,
"invalid_client",
"test_client_id is not a public client"
)
def test_invalid_acess_token(self):
self._setup_provider_response(success=False)
self._assert_error(self.data, "invalid_grant", "access_token is not valid")
def test_no_linked_user(self):
UserSocialAuth.objects.all().delete()
self._setup_provider_response(success=True)
self._assert_error(self.data, "invalid_grant", "access_token is not valid")
class AccessTokenExchangeMixinFacebook(object):
"""Tests access token exchange with the Facebook backend"""
BACKEND = "facebook"
USER_URL = "https://graph.facebook.com/me"
# In facebook responses, the "id" field is used as the user's identifier
UID_FIELD = "id"
class AccessTokenExchangeMixinGoogle(object):
"""Tests access token exchange with the Google backend"""
BACKEND = "google-oauth2"
USER_URL = "https://www.googleapis.com/oauth2/v1/userinfo"
# In google-oauth2 responses, the "email" field is used as the user's identifier
UID_FIELD = "email"
"""
Views to support third-party to first-party OAuth 2.0 access token exchange
"""
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from provider import constants
from provider.oauth2.views import AccessTokenView as AccessTokenView
import social.apps.django_app.utils as social_utils
from oauth_exchange.forms import AccessTokenExchangeForm
class AccessTokenExchangeView(AccessTokenView):
"""View for access token exchange"""
@method_decorator(csrf_exempt)
@method_decorator(social_utils.strategy("social:complete"))
def dispatch(self, *args, **kwargs):
return super(AccessTokenExchangeView, self).dispatch(*args, **kwargs)
def get(self, request, _backend):
return super(AccessTokenExchangeView, self).get(request)
def post(self, request, _backend):
form = AccessTokenExchangeForm(request=request, data=request.POST)
if not form.is_valid():
return self.error_response(form.errors)
user = form.cleaned_data["user"]
scope = form.cleaned_data["scope"]
client = form.cleaned_data["client"]
if constants.SINGLE_ACCESS_TOKEN:
edx_access_token = self.get_access_token(request, user, scope, client)
else:
edx_access_token = self.create_access_token(request, user, scope, client)
return self.access_token_response(edx_access_token)
......@@ -1548,6 +1548,8 @@ INSTALLED_APPS = (
'provider.oauth2',
'oauth2_provider',
'oauth_exchange',
# For the wiki
'wiki', # The new django-wiki from benjaoming
'django_notify',
......
......@@ -5,6 +5,7 @@ from django.conf.urls.static import static
import django.contrib.auth.views
from microsite_configuration import microsite
import oauth_exchange.views
# Uncomment the next two lines to enable the admin:
if settings.DEBUG or settings.FEATURES.get('ENABLE_DJANGO_ADMIN_SITE'):
......@@ -585,6 +586,11 @@ if settings.FEATURES.get('AUTOMATIC_AUTH_FOR_TESTING'):
if settings.FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
urlpatterns += (
url(r'', include('third_party_auth.urls')),
url(
r'^oauth2/exchange_access_token/(?P<backend>[^/]+)/$',
oauth_exchange.views.AccessTokenExchangeView.as_view(),
name="exchange_access_token"
),
url(r'^login_oauth_token/(?P<backend>[^/]+)/$', 'student.views.login_oauth_token'),
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment