Skip to content
Snippets Groups Projects
Commit b3521e0f authored by J. Cliff Dyer's avatar J. Cliff Dyer
Browse files

Add endpoint to get SAML providers for a user.

View is combined with user SSO views.

Includes a new version of the view that takes explicit "username" or "email".

OC-4285
parent 75a739e2
Branches
Tags
No related merge requests found
......@@ -4,15 +4,16 @@ Tests for the Third Party Auth REST API
import unittest
import ddt
from django.urls import reverse
import six
from django.conf import settings
from django.http import QueryDict
from django.test.utils import override_settings
from django.urls import reverse
from mock import patch
from provider.constants import CONFIDENTIAL
from provider.oauth2.models import Client, AccessToken
from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
from rest_framework.test import APITestCase
from django.conf import settings
from django.test.utils import override_settings
from social_django.models import UserSocialAuth
from student.tests.factories import UserFactory
......@@ -29,6 +30,7 @@ ALICE_USERNAME = "alice"
CARL_USERNAME = "carl"
STAFF_USERNAME = "staff"
ADMIN_USERNAME = "admin"
NONEXISTENT_USERNAME = "nobody"
# These users will be created and linked to third party accounts:
LINKED_USERS = (ALICE_USERNAME, STAFF_USERNAME, ADMIN_USERNAME)
PASSWORD = "edx"
......@@ -62,9 +64,10 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
make_staff = (username == STAFF_USERNAME) or make_superuser
user = UserFactory.create(
username=username,
email='{}@example.com'.format(username),
password=PASSWORD,
is_staff=make_staff,
is_superuser=make_superuser
is_superuser=make_superuser,
)
UserSocialAuth.objects.create(
user=user,
......@@ -77,15 +80,13 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
uid='{}:remote_{}'.format(testshib.slug, username),
)
# Create another user not linked to any providers:
UserFactory.create(username=CARL_USERNAME, password=PASSWORD)
UserFactory.create(username=CARL_USERNAME, email='{}@example.com'.format(CARL_USERNAME), password=PASSWORD)
@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class UserViewAPITests(TpaAPITestCase):
class UserViewsMixin(object):
"""
Test the Third Party Auth User REST API
Generic TestCase to exercise the v1 and v2 UserViews.
"""
def expected_active(self, username):
......@@ -124,7 +125,7 @@ class UserViewAPITests(TpaAPITestCase):
@ddt.unpack
def test_list_connected_providers(self, request_user, target_user, expect_result):
self.client.login(username=request_user, password=PASSWORD)
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
url = self.make_url({'username': target_user})
response = self.client.get(url)
self.assertEqual(response.status_code, expect_result)
......@@ -140,14 +141,87 @@ class UserViewAPITests(TpaAPITestCase):
(None, ALICE_USERNAME, 403),
)
@ddt.unpack
def test_list_connected_providers__withapi_key(self, api_key, target_user, expect_result):
url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result):
url = self.make_url({'username': target_user})
response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
self.assertEqual(response.status_code, expect_result)
if expect_result == 200:
self.assertIn("active", response.data)
self.assertItemsEqual(response.data["active"], self.expected_active(target_user))
@ddt.data(
(True, ALICE_USERNAME, 200, True),
(True, CARL_USERNAME, 200, False),
(False, ALICE_USERNAME, 200, True),
(False, CARL_USERNAME, 403, None),
)
@ddt.unpack
def test_allow_unprivileged_response(self, allow_unprivileged, requesting_user, expect, include_remote_id):
self.client.login(username=requesting_user, password=PASSWORD)
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged):
url = self.make_url({'username': ALICE_USERNAME})
response = self.client.get(url)
self.assertEqual(response.status_code, expect)
if response.status_code == 200:
self.assertGreater(len(response.data['active']), 0)
for provider_data in response.data['active']:
self.assertEqual(include_remote_id, 'remote_id' in provider_data)
def test_allow_query_by_email(self):
self.client.login(username=ALICE_USERNAME, password=PASSWORD)
url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)})
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
self.assertGreater(len(response.data['active']), 0)
def test_throttling(self):
# Default throttle is 10/min. Make 11 requests to verify
throttling_user = UserFactory.create(password=PASSWORD)
self.client.login(username=throttling_user.username, password=PASSWORD)
url = self.make_url({'username': ALICE_USERNAME})
with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True):
for _ in range(10):
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
response = self.client.get(url)
self.assertEqual(response.status_code, 200)
@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class UserViewAPITests(UserViewsMixin, TpaAPITestCase):
"""
Test the Third Party Auth User REST API
"""
def make_url(self, identifier):
"""
Return the view URL, with the identifier provided
"""
return reverse(
'third_party_auth_users_api',
kwargs={'username': identifier.values()[0]}
)
@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class UserViewV2APITests(UserViewsMixin, TpaAPITestCase):
"""
Test the Third Party Auth User REST API
"""
def make_url(self, identifier):
"""
Return the view URL, with the identifier provided
"""
return '?'.join([
reverse('third_party_auth_users_api_v2'),
six.moves.urllib.parse.urlencode(identifier)
])
@override_settings(EDX_API_KEY=VALID_API_KEY)
@ddt.ddt
......
......@@ -3,7 +3,7 @@
from django.conf import settings
from django.conf.urls import url
from .views import UserMappingView, UserView
from .views import UserMappingView, UserView, UserViewV2
PROVIDER_PATTERN = r'(?P<provider_id>[\w.+-]+)(?:\:(?P<idp_slug>[\w.+-]+))?'
......@@ -14,6 +14,7 @@ urlpatterns = [
UserView.as_view(),
name='third_party_auth_users_api',
),
url(r'^v0/users/', UserViewV2.as_view(), name='third_party_auth_users_api_v2'),
url(
r'^v0/providers/{provider_pattern}/users$'.format(provider_pattern=PROVIDER_PATTERN),
UserMappingView.as_view(),
......
"""
Third Party Auth REST API views
"""
from collections import namedtuple
from django.conf import settings
from django.contrib.auth.models import User
from django.db.models import Q
from django.http import Http404
from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser
from rest_framework import exceptions, status
from rest_framework import exceptions, status, throttling
from rest_framework.generics import ListAPIView
from rest_framework.response import Response
from rest_framework.views import APIView
......@@ -20,13 +24,132 @@ from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission
from third_party_auth.provider import Registry
class UserView(APIView):
class ProviderBaseThrottle(throttling.UserRateThrottle):
"""
Base throttle for provider queries
"""
def allow_request(self, request, view):
"""
Only throttle unprivileged requests.
"""
if view.is_unprivileged_query(request, view.get_identifier_for_requested_user(request)):
return super(ProviderBaseThrottle, self).allow_request(request, view)
return True
class ProviderBurstThrottle(ProviderBaseThrottle):
"""
Maximum number of provider requests in a quick burst.
"""
rate = settings.TPA_PROVIDER_BURST_THROTTLE # Default '10/min'
class ProviderSustainedThrottle(ProviderBaseThrottle):
"""
Maximum number of provider requests over time.
"""
rate = settings.TPA_PROVIDER_SUSTAINED_THROTTLE # Default '50/day'
class BaseUserView(APIView):
"""
Common core of UserView and UserViewV2
"""
identifier = namedtuple('identifier', ['kind', 'value'])
identifier_kinds = ['email', 'username']
authentication_classes = (
# Users may want to view/edit the providers used for authentication before they've
# activated their account, so we allow inactive users.
OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser,
)
throttle_classes = [ProviderSustainedThrottle, ProviderBurstThrottle]
def do_get(self, request, identifier):
"""
Fulfill the request, now that the identifier has been specified.
"""
is_unprivileged = self.is_unprivileged_query(request, identifier)
if is_unprivileged:
if not getattr(settings, 'ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False):
return Response(status=status.HTTP_403_FORBIDDEN)
try:
user = User.objects.get(**{identifier.kind: identifier.value})
except User.DoesNotExist:
return Response(status=status.HTTP_404_NOT_FOUND)
providers = pipeline.get_provider_user_states(user)
active_providers = [
self.get_provider_data(assoc, is_unprivileged)
for assoc in providers if assoc.has_account
]
# In the future this can be trivially modified to return the inactive/disconnected providers as well.
return Response({
"active": active_providers
})
def get_provider_data(self, assoc, is_unprivileged):
"""
Return the data for the specified provider.
If the request is unprivileged, do not return the remote ID of the user.
"""
provider_data = {
"provider_id": assoc.provider.provider_id,
"name": assoc.provider.name,
}
if not is_unprivileged:
provider_data["remote_id"] = assoc.remote_id
return provider_data
def is_unprivileged_query(self, request, identifier):
"""
Return True if a non-superuser requests information about another user.
Params must be a dict that includes only one of 'username' or 'email'
"""
if identifier.kind not in self.identifier_kinds:
# This is already checked before we get here, so raise a 500 error
# if the check fails.
raise ValueError("Identifier kind {} not in {}".format(identifier.kind, self.identifier_kinds))
self_request = False
if identifier == self.identifier('username', request.user.username):
self_request = True
elif identifier.kind == 'email' and getattr(identifier, 'value', object()) == request.user.email:
# AnonymousUser does not have an email attribute, so fall back to
# something that will never compare equal to the provided email.
self_request = True
if self_request:
# We can always ask for our own provider
return False
# We are querying permissions for a user other than the current user.
if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
# The user does not have elevated permissions.
return True
return False
class UserView(BaseUserView):
"""
List the third party auth accounts linked to the specified user account.
[DEPRECATED]
This view uses heuristics to guess whether the provided identifier is a
username or email address. Instead, use /api/third_party_auth/v0/users/
and specify ?username=foo or ?email=foo@exmaple.com.
**Example Request**
GET /api/third_party_auth/v0/users/{username}
GET /api/third_party_auth/v0/users/{email@example.com}
**Response Values**
......@@ -45,18 +168,11 @@ class UserView(APIView):
is what is used to link the user to their edX account during
login.
"""
authentication_classes = (
# Users may want to view/edit the providers used for authentication before they've
# activated their account, so we allow inactive users.
OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser,
)
def get(self, request, username):
"""Create, read, or update enrollment information for a user.
"""Read provider information for a user.
HTTP Endpoint for all CRUD operations for a user course enrollment. Allows creation, reading, and
updates of the current enrollment for a particular course.
Allows reading the list of providers for a specified user.
Args:
request (Request): The HTTP GET request
......@@ -66,34 +182,80 @@ class UserView(APIView):
JSON serialized list of the providers linked to this user.
"""
if request.user.username != username:
# We are querying permissions for a user other than the current user.
if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
# Return a 403 (Unauthorized) without validating 'username', so that we
# do not let users probe the existence of other user accounts.
return Response(status=status.HTTP_403_FORBIDDEN)
identifier = self.get_identifier_for_requested_user(request)
return self.do_get(request, identifier)
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
return Response(status=status.HTTP_404_NOT_FOUND)
def get_identifier_for_requested_user(self, _request):
"""
Return an identifier namedtuple for the requested user.
"""
if u'@' in self.kwargs[u'username']:
id_kind = u'email'
else:
id_kind = u'username'
return self.identifier(id_kind, self.kwargs[u'username'])
providers = pipeline.get_provider_user_states(user)
active_providers = [
{
"provider_id": assoc.provider.provider_id,
"name": assoc.provider.name,
"remote_id": assoc.remote_id,
}
for assoc in providers if assoc.has_account
]
# TODO: When removing deprecated UserView, rename this view to UserView.
class UserViewV2(BaseUserView):
"""
List the third party auth accounts linked to the specified user account.
# In the future this can be trivially modified to return the inactive/disconnected providers as well.
**Example Request**
return Response({
"active": active_providers
})
GET /api/third_party_auth/v0/users/?username={username}
GET /api/third_party_auth/v0/users/?email={email@example.com}
**Response Values**
If the request for information about the user is successful, an HTTP 200 "OK" response
is returned.
The HTTP 200 response has the following values.
* active: A list of all the third party auth providers currently linked
to the given user's account. Each object in this list has the
following attributes:
* provider_id: The unique identifier of this provider (string)
* name: The name of this provider (string)
* remote_id: The ID of the user according to the provider. This ID
is what is used to link the user to their edX account during
login.
"""
def get(self, request):
"""
Read provider information for a user.
Allows reading the list of providers for a specified user.
Args:
request (Request): The HTTP GET request
Request Parameters:
Must provide one of 'email' or 'username'. If both are provided,
the username will be ignored.
Return:
JSON serialized list of the providers linked to this user.
"""
identifier = self.get_identifier_for_requested_user(request)
return self.do_get(request, identifier)
def get_identifier_for_requested_user(self, request):
"""
Return an identifier namedtuple for the requested user.
"""
identifier = None
for id_kind in self.identifier_kinds:
if id_kind in request.GET:
identifier = self.identifier(id_kind, request.GET[id_kind])
break
if identifier is None:
raise exceptions.ValidationError(u"Must provide one of {}".format(self.identifier_kinds))
return identifier
class UserMappingView(ListAPIView):
......@@ -195,7 +357,7 @@ class UserMappingView(ListAPIView):
# When using multi-IdP backend, we only retrieve the ones that are for current IdP.
# test if the current provider has a slug
uid = self.provider.get_social_auth_uid('uid')
if uid is not 'uid':
if uid != 'uid':
# if yes, we add a filter for the slug on uid column
query_set = query_set.filter(uid__startswith=uid[:-3])
......@@ -207,13 +369,13 @@ class UserMappingView(ListAPIView):
if usernames:
usernames = ','.join(usernames)
usernames = set(usernames.split(',')) if usernames else set()
if len(usernames):
if usernames:
query = query | Q(user__username__in=usernames)
if remote_ids:
remote_ids = ','.join(remote_ids)
remote_ids = set(remote_ids.split(',')) if remote_ids else set()
if len(remote_ids):
if remote_ids:
query = query | Q(uid__in=[self.provider.get_social_auth_uid(remote_id) for remote_id in remote_ids])
return query_set.filter(query)
......
......@@ -707,6 +707,9 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
# dict with an arbitrary 'secret_key' and a 'url'.
THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS = AUTH_TOKENS.get('THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS', {})
# Whether to allow unprivileged users to discover SSO providers for arbitrary usernames.
ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY = ENV_TOKENS.get('ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False)
##### OAUTH2 Provider ##############
if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
......@@ -722,6 +725,10 @@ if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
OAUTH_ID_TOKEN_EXPIRATION = ENV_TOKENS.get('OAUTH_ID_TOKEN_EXPIRATION', OAUTH_ID_TOKEN_EXPIRATION)
OAUTH_DELETE_EXPIRED = ENV_TOKENS.get('OAUTH_DELETE_EXPIRED', OAUTH_DELETE_EXPIRED)
##### THIRD_PARTY_AUTH #############
TPA_PROVIDER_BURST_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_BURST_THROTTLE', TPA_PROVIDER_BURST_THROTTLE)
TPA_PROVIDER_SUSTAINED_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_SUSTAINED_THROTTLE', TPA_PROVIDER_SUSTAINED_THROTTLE)
##### ADVANCED_SECURITY_CONFIG #####
ADVANCED_SECURITY_CONFIG = ENV_TOKENS.get('ADVANCED_SECURITY_CONFIG', {})
......
......@@ -521,6 +521,10 @@ OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application'
OAUTH_DELETE_EXPIRED = True
OAUTH_ID_TOKEN_EXPIRATION = 60 * 60
################################## THIRD_PARTY_AUTH CONFIGURATION #############################
TPA_PROVIDER_BURST_THROTTLE = '10/min'
TPA_PROVIDER_SUSTAINED_THROTTLE = '50/hr'
################################## TEMPLATE CONFIGURATION #####################################
# Mako templating
import tempfile
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment