From b3521e0f8832ce4e7b2dafc5b88b7a28877481e0 Mon Sep 17 00:00:00 2001 From: "J. Cliff Dyer" <jcd@sdf.org> Date: Tue, 10 Apr 2018 17:03:42 -0400 Subject: [PATCH] Add endpoint to get SAML providers for a user. View is combined with user SSO views. Includes a new version of the view that takes explicit "username" or "email". OC-4285 --- .../third_party_auth/api/tests/test_views.py | 98 +++++++- .../djangoapps/third_party_auth/api/urls.py | 3 +- .../djangoapps/third_party_auth/api/views.py | 236 +++++++++++++++--- lms/envs/aws.py | 7 + lms/envs/common.py | 4 + 5 files changed, 298 insertions(+), 50 deletions(-) diff --git a/common/djangoapps/third_party_auth/api/tests/test_views.py b/common/djangoapps/third_party_auth/api/tests/test_views.py index 2fac4d2a5c7..cb5e15aef9b 100644 --- a/common/djangoapps/third_party_auth/api/tests/test_views.py +++ b/common/djangoapps/third_party_auth/api/tests/test_views.py @@ -4,15 +4,16 @@ Tests for the Third Party Auth REST API import unittest import ddt -from django.urls import reverse +import six +from django.conf import settings from django.http import QueryDict +from django.test.utils import override_settings +from django.urls import reverse from mock import patch from provider.constants import CONFIDENTIAL from provider.oauth2.models import Client, AccessToken from openedx.core.lib.api.permissions import ApiKeyHeaderPermission from rest_framework.test import APITestCase -from django.conf import settings -from django.test.utils import override_settings from social_django.models import UserSocialAuth from student.tests.factories import UserFactory @@ -29,6 +30,7 @@ ALICE_USERNAME = "alice" CARL_USERNAME = "carl" STAFF_USERNAME = "staff" ADMIN_USERNAME = "admin" +NONEXISTENT_USERNAME = "nobody" # These users will be created and linked to third party accounts: LINKED_USERS = (ALICE_USERNAME, STAFF_USERNAME, ADMIN_USERNAME) PASSWORD = "edx" @@ -62,9 +64,10 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase): make_staff = (username == STAFF_USERNAME) or make_superuser user = UserFactory.create( username=username, + email='{}@example.com'.format(username), password=PASSWORD, is_staff=make_staff, - is_superuser=make_superuser + is_superuser=make_superuser, ) UserSocialAuth.objects.create( user=user, @@ -77,15 +80,13 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase): uid='{}:remote_{}'.format(testshib.slug, username), ) # Create another user not linked to any providers: - UserFactory.create(username=CARL_USERNAME, password=PASSWORD) + UserFactory.create(username=CARL_USERNAME, email='{}@example.com'.format(CARL_USERNAME), password=PASSWORD) -@override_settings(EDX_API_KEY=VALID_API_KEY) @ddt.ddt -@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') -class UserViewAPITests(TpaAPITestCase): +class UserViewsMixin(object): """ - Test the Third Party Auth User REST API + Generic TestCase to exercise the v1 and v2 UserViews. """ def expected_active(self, username): @@ -124,7 +125,7 @@ class UserViewAPITests(TpaAPITestCase): @ddt.unpack def test_list_connected_providers(self, request_user, target_user, expect_result): self.client.login(username=request_user, password=PASSWORD) - url = reverse('third_party_auth_users_api', kwargs={'username': target_user}) + url = self.make_url({'username': target_user}) response = self.client.get(url) self.assertEqual(response.status_code, expect_result) @@ -140,14 +141,87 @@ class UserViewAPITests(TpaAPITestCase): (None, ALICE_USERNAME, 403), ) @ddt.unpack - def test_list_connected_providers__withapi_key(self, api_key, target_user, expect_result): - url = reverse('third_party_auth_users_api', kwargs={'username': target_user}) + def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result): + url = self.make_url({'username': target_user}) response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key) self.assertEqual(response.status_code, expect_result) if expect_result == 200: self.assertIn("active", response.data) self.assertItemsEqual(response.data["active"], self.expected_active(target_user)) + @ddt.data( + (True, ALICE_USERNAME, 200, True), + (True, CARL_USERNAME, 200, False), + (False, ALICE_USERNAME, 200, True), + (False, CARL_USERNAME, 403, None), + ) + @ddt.unpack + def test_allow_unprivileged_response(self, allow_unprivileged, requesting_user, expect, include_remote_id): + self.client.login(username=requesting_user, password=PASSWORD) + with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged): + url = self.make_url({'username': ALICE_USERNAME}) + response = self.client.get(url) + self.assertEqual(response.status_code, expect) + if response.status_code == 200: + self.assertGreater(len(response.data['active']), 0) + for provider_data in response.data['active']: + self.assertEqual(include_remote_id, 'remote_id' in provider_data) + + def test_allow_query_by_email(self): + self.client.login(username=ALICE_USERNAME, password=PASSWORD) + url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)}) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + self.assertGreater(len(response.data['active']), 0) + + def test_throttling(self): + # Default throttle is 10/min. Make 11 requests to verify + throttling_user = UserFactory.create(password=PASSWORD) + self.client.login(username=throttling_user.username, password=PASSWORD) + url = self.make_url({'username': ALICE_USERNAME}) + with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True): + for _ in range(10): + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + +@override_settings(EDX_API_KEY=VALID_API_KEY) +@ddt.ddt +@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') +class UserViewAPITests(UserViewsMixin, TpaAPITestCase): + """ + Test the Third Party Auth User REST API + """ + + def make_url(self, identifier): + """ + Return the view URL, with the identifier provided + """ + return reverse( + 'third_party_auth_users_api', + kwargs={'username': identifier.values()[0]} + ) + + +@override_settings(EDX_API_KEY=VALID_API_KEY) +@ddt.ddt +@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') +class UserViewV2APITests(UserViewsMixin, TpaAPITestCase): + """ + Test the Third Party Auth User REST API + """ + + def make_url(self, identifier): + """ + Return the view URL, with the identifier provided + """ + return '?'.join([ + reverse('third_party_auth_users_api_v2'), + six.moves.urllib.parse.urlencode(identifier) + ]) + @override_settings(EDX_API_KEY=VALID_API_KEY) @ddt.ddt diff --git a/common/djangoapps/third_party_auth/api/urls.py b/common/djangoapps/third_party_auth/api/urls.py index dc1673dae2e..df6209ebade 100644 --- a/common/djangoapps/third_party_auth/api/urls.py +++ b/common/djangoapps/third_party_auth/api/urls.py @@ -3,7 +3,7 @@ from django.conf import settings from django.conf.urls import url -from .views import UserMappingView, UserView +from .views import UserMappingView, UserView, UserViewV2 PROVIDER_PATTERN = r'(?P<provider_id>[\w.+-]+)(?:\:(?P<idp_slug>[\w.+-]+))?' @@ -14,6 +14,7 @@ urlpatterns = [ UserView.as_view(), name='third_party_auth_users_api', ), + url(r'^v0/users/', UserViewV2.as_view(), name='third_party_auth_users_api_v2'), url( r'^v0/providers/{provider_pattern}/users$'.format(provider_pattern=PROVIDER_PATTERN), UserMappingView.as_view(), diff --git a/common/djangoapps/third_party_auth/api/views.py b/common/djangoapps/third_party_auth/api/views.py index fe590a46eb6..5004a653691 100644 --- a/common/djangoapps/third_party_auth/api/views.py +++ b/common/djangoapps/third_party_auth/api/views.py @@ -1,11 +1,15 @@ """ Third Party Auth REST API views """ + +from collections import namedtuple + +from django.conf import settings from django.contrib.auth.models import User from django.db.models import Q from django.http import Http404 from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser -from rest_framework import exceptions, status +from rest_framework import exceptions, status, throttling from rest_framework.generics import ListAPIView from rest_framework.response import Response from rest_framework.views import APIView @@ -20,13 +24,132 @@ from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission from third_party_auth.provider import Registry -class UserView(APIView): +class ProviderBaseThrottle(throttling.UserRateThrottle): + """ + Base throttle for provider queries + """ + + def allow_request(self, request, view): + """ + Only throttle unprivileged requests. + """ + if view.is_unprivileged_query(request, view.get_identifier_for_requested_user(request)): + return super(ProviderBaseThrottle, self).allow_request(request, view) + return True + + +class ProviderBurstThrottle(ProviderBaseThrottle): + """ + Maximum number of provider requests in a quick burst. + """ + rate = settings.TPA_PROVIDER_BURST_THROTTLE # Default '10/min' + + +class ProviderSustainedThrottle(ProviderBaseThrottle): + """ + Maximum number of provider requests over time. + """ + rate = settings.TPA_PROVIDER_SUSTAINED_THROTTLE # Default '50/day' + + +class BaseUserView(APIView): + """ + Common core of UserView and UserViewV2 + """ + identifier = namedtuple('identifier', ['kind', 'value']) + identifier_kinds = ['email', 'username'] + + authentication_classes = ( + # Users may want to view/edit the providers used for authentication before they've + # activated their account, so we allow inactive users. + OAuth2AuthenticationAllowInactiveUser, + SessionAuthenticationAllowInactiveUser, + ) + throttle_classes = [ProviderSustainedThrottle, ProviderBurstThrottle] + + def do_get(self, request, identifier): + """ + Fulfill the request, now that the identifier has been specified. + """ + is_unprivileged = self.is_unprivileged_query(request, identifier) + + if is_unprivileged: + if not getattr(settings, 'ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False): + return Response(status=status.HTTP_403_FORBIDDEN) + try: + user = User.objects.get(**{identifier.kind: identifier.value}) + except User.DoesNotExist: + return Response(status=status.HTTP_404_NOT_FOUND) + + providers = pipeline.get_provider_user_states(user) + + active_providers = [ + self.get_provider_data(assoc, is_unprivileged) + for assoc in providers if assoc.has_account + ] + + # In the future this can be trivially modified to return the inactive/disconnected providers as well. + + return Response({ + "active": active_providers + }) + + def get_provider_data(self, assoc, is_unprivileged): + """ + Return the data for the specified provider. + + If the request is unprivileged, do not return the remote ID of the user. + """ + provider_data = { + "provider_id": assoc.provider.provider_id, + "name": assoc.provider.name, + } + if not is_unprivileged: + provider_data["remote_id"] = assoc.remote_id + return provider_data + + def is_unprivileged_query(self, request, identifier): + """ + Return True if a non-superuser requests information about another user. + + Params must be a dict that includes only one of 'username' or 'email' + """ + if identifier.kind not in self.identifier_kinds: + # This is already checked before we get here, so raise a 500 error + # if the check fails. + raise ValueError("Identifier kind {} not in {}".format(identifier.kind, self.identifier_kinds)) + + self_request = False + if identifier == self.identifier('username', request.user.username): + self_request = True + elif identifier.kind == 'email' and getattr(identifier, 'value', object()) == request.user.email: + # AnonymousUser does not have an email attribute, so fall back to + # something that will never compare equal to the provided email. + self_request = True + if self_request: + # We can always ask for our own provider + return False + # We are querying permissions for a user other than the current user. + if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self): + # The user does not have elevated permissions. + return True + return False + + +class UserView(BaseUserView): """ List the third party auth accounts linked to the specified user account. + [DEPRECATED] + + This view uses heuristics to guess whether the provided identifier is a + username or email address. Instead, use /api/third_party_auth/v0/users/ + and specify ?username=foo or ?email=foo@exmaple.com. + **Example Request** GET /api/third_party_auth/v0/users/{username} + GET /api/third_party_auth/v0/users/{email@example.com} **Response Values** @@ -45,18 +168,11 @@ class UserView(APIView): is what is used to link the user to their edX account during login. """ - authentication_classes = ( - # Users may want to view/edit the providers used for authentication before they've - # activated their account, so we allow inactive users. - OAuth2AuthenticationAllowInactiveUser, - SessionAuthenticationAllowInactiveUser, - ) def get(self, request, username): - """Create, read, or update enrollment information for a user. + """Read provider information for a user. - HTTP Endpoint for all CRUD operations for a user course enrollment. Allows creation, reading, and - updates of the current enrollment for a particular course. + Allows reading the list of providers for a specified user. Args: request (Request): The HTTP GET request @@ -66,34 +182,80 @@ class UserView(APIView): JSON serialized list of the providers linked to this user. """ - if request.user.username != username: - # We are querying permissions for a user other than the current user. - if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self): - # Return a 403 (Unauthorized) without validating 'username', so that we - # do not let users probe the existence of other user accounts. - return Response(status=status.HTTP_403_FORBIDDEN) + identifier = self.get_identifier_for_requested_user(request) + return self.do_get(request, identifier) - try: - user = User.objects.get(username=username) - except User.DoesNotExist: - return Response(status=status.HTTP_404_NOT_FOUND) + def get_identifier_for_requested_user(self, _request): + """ + Return an identifier namedtuple for the requested user. + """ + if u'@' in self.kwargs[u'username']: + id_kind = u'email' + else: + id_kind = u'username' + return self.identifier(id_kind, self.kwargs[u'username']) - providers = pipeline.get_provider_user_states(user) - active_providers = [ - { - "provider_id": assoc.provider.provider_id, - "name": assoc.provider.name, - "remote_id": assoc.remote_id, - } - for assoc in providers if assoc.has_account - ] +# TODO: When removing deprecated UserView, rename this view to UserView. +class UserViewV2(BaseUserView): + """ + List the third party auth accounts linked to the specified user account. - # In the future this can be trivially modified to return the inactive/disconnected providers as well. + **Example Request** - return Response({ - "active": active_providers - }) + GET /api/third_party_auth/v0/users/?username={username} + GET /api/third_party_auth/v0/users/?email={email@example.com} + + **Response Values** + + If the request for information about the user is successful, an HTTP 200 "OK" response + is returned. + + The HTTP 200 response has the following values. + + * active: A list of all the third party auth providers currently linked + to the given user's account. Each object in this list has the + following attributes: + + * provider_id: The unique identifier of this provider (string) + * name: The name of this provider (string) + * remote_id: The ID of the user according to the provider. This ID + is what is used to link the user to their edX account during + login. + """ + + def get(self, request): + """ + Read provider information for a user. + + Allows reading the list of providers for a specified user. + + Args: + request (Request): The HTTP GET request + + Request Parameters: + Must provide one of 'email' or 'username'. If both are provided, + the username will be ignored. + + Return: + JSON serialized list of the providers linked to this user. + + """ + identifier = self.get_identifier_for_requested_user(request) + return self.do_get(request, identifier) + + def get_identifier_for_requested_user(self, request): + """ + Return an identifier namedtuple for the requested user. + """ + identifier = None + for id_kind in self.identifier_kinds: + if id_kind in request.GET: + identifier = self.identifier(id_kind, request.GET[id_kind]) + break + if identifier is None: + raise exceptions.ValidationError(u"Must provide one of {}".format(self.identifier_kinds)) + return identifier class UserMappingView(ListAPIView): @@ -195,7 +357,7 @@ class UserMappingView(ListAPIView): # When using multi-IdP backend, we only retrieve the ones that are for current IdP. # test if the current provider has a slug uid = self.provider.get_social_auth_uid('uid') - if uid is not 'uid': + if uid != 'uid': # if yes, we add a filter for the slug on uid column query_set = query_set.filter(uid__startswith=uid[:-3]) @@ -207,13 +369,13 @@ class UserMappingView(ListAPIView): if usernames: usernames = ','.join(usernames) usernames = set(usernames.split(',')) if usernames else set() - if len(usernames): + if usernames: query = query | Q(user__username__in=usernames) if remote_ids: remote_ids = ','.join(remote_ids) remote_ids = set(remote_ids.split(',')) if remote_ids else set() - if len(remote_ids): + if remote_ids: query = query | Q(uid__in=[self.provider.get_social_auth_uid(remote_id) for remote_id in remote_ids]) return query_set.filter(query) diff --git a/lms/envs/aws.py b/lms/envs/aws.py index 3c8f62ab5c6..e9a0e463d83 100644 --- a/lms/envs/aws.py +++ b/lms/envs/aws.py @@ -707,6 +707,9 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): # dict with an arbitrary 'secret_key' and a 'url'. THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS = AUTH_TOKENS.get('THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS', {}) + # Whether to allow unprivileged users to discover SSO providers for arbitrary usernames. + ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY = ENV_TOKENS.get('ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False) + ##### OAUTH2 Provider ############## if FEATURES.get('ENABLE_OAUTH2_PROVIDER'): OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER'] @@ -722,6 +725,10 @@ if FEATURES.get('ENABLE_OAUTH2_PROVIDER'): OAUTH_ID_TOKEN_EXPIRATION = ENV_TOKENS.get('OAUTH_ID_TOKEN_EXPIRATION', OAUTH_ID_TOKEN_EXPIRATION) OAUTH_DELETE_EXPIRED = ENV_TOKENS.get('OAUTH_DELETE_EXPIRED', OAUTH_DELETE_EXPIRED) +##### THIRD_PARTY_AUTH ############# +TPA_PROVIDER_BURST_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_BURST_THROTTLE', TPA_PROVIDER_BURST_THROTTLE) +TPA_PROVIDER_SUSTAINED_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_SUSTAINED_THROTTLE', TPA_PROVIDER_SUSTAINED_THROTTLE) + ##### ADVANCED_SECURITY_CONFIG ##### ADVANCED_SECURITY_CONFIG = ENV_TOKENS.get('ADVANCED_SECURITY_CONFIG', {}) diff --git a/lms/envs/common.py b/lms/envs/common.py index 7df659fcb09..7fda273083e 100644 --- a/lms/envs/common.py +++ b/lms/envs/common.py @@ -521,6 +521,10 @@ OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application' OAUTH_DELETE_EXPIRED = True OAUTH_ID_TOKEN_EXPIRATION = 60 * 60 +################################## THIRD_PARTY_AUTH CONFIGURATION ############################# +TPA_PROVIDER_BURST_THROTTLE = '10/min' +TPA_PROVIDER_SUSTAINED_THROTTLE = '50/hr' + ################################## TEMPLATE CONFIGURATION ##################################### # Mako templating import tempfile -- GitLab