From b3521e0f8832ce4e7b2dafc5b88b7a28877481e0 Mon Sep 17 00:00:00 2001
From: "J. Cliff Dyer" <jcd@sdf.org>
Date: Tue, 10 Apr 2018 17:03:42 -0400
Subject: [PATCH] Add endpoint to get SAML providers for a user.

View is combined with user SSO views.

Includes a new version of the view that takes explicit "username" or "email".

OC-4285
---
 .../third_party_auth/api/tests/test_views.py  |  98 +++++++-
 .../djangoapps/third_party_auth/api/urls.py   |   3 +-
 .../djangoapps/third_party_auth/api/views.py  | 236 +++++++++++++++---
 lms/envs/aws.py                               |   7 +
 lms/envs/common.py                            |   4 +
 5 files changed, 298 insertions(+), 50 deletions(-)

diff --git a/common/djangoapps/third_party_auth/api/tests/test_views.py b/common/djangoapps/third_party_auth/api/tests/test_views.py
index 2fac4d2a5c7..cb5e15aef9b 100644
--- a/common/djangoapps/third_party_auth/api/tests/test_views.py
+++ b/common/djangoapps/third_party_auth/api/tests/test_views.py
@@ -4,15 +4,16 @@ Tests for the Third Party Auth REST API
 import unittest
 
 import ddt
-from django.urls import reverse
+import six
+from django.conf import settings
 from django.http import QueryDict
+from django.test.utils import override_settings
+from django.urls import reverse
 from mock import patch
 from provider.constants import CONFIDENTIAL
 from provider.oauth2.models import Client, AccessToken
 from openedx.core.lib.api.permissions import ApiKeyHeaderPermission
 from rest_framework.test import APITestCase
-from django.conf import settings
-from django.test.utils import override_settings
 from social_django.models import UserSocialAuth
 
 from student.tests.factories import UserFactory
@@ -29,6 +30,7 @@ ALICE_USERNAME = "alice"
 CARL_USERNAME = "carl"
 STAFF_USERNAME = "staff"
 ADMIN_USERNAME = "admin"
+NONEXISTENT_USERNAME = "nobody"
 # These users will be created and linked to third party accounts:
 LINKED_USERS = (ALICE_USERNAME, STAFF_USERNAME, ADMIN_USERNAME)
 PASSWORD = "edx"
@@ -62,9 +64,10 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
             make_staff = (username == STAFF_USERNAME) or make_superuser
             user = UserFactory.create(
                 username=username,
+                email='{}@example.com'.format(username),
                 password=PASSWORD,
                 is_staff=make_staff,
-                is_superuser=make_superuser
+                is_superuser=make_superuser,
             )
             UserSocialAuth.objects.create(
                 user=user,
@@ -77,15 +80,13 @@ class TpaAPITestCase(ThirdPartyAuthTestMixin, APITestCase):
                 uid='{}:remote_{}'.format(testshib.slug, username),
             )
         # Create another user not linked to any providers:
-        UserFactory.create(username=CARL_USERNAME, password=PASSWORD)
+        UserFactory.create(username=CARL_USERNAME, email='{}@example.com'.format(CARL_USERNAME), password=PASSWORD)
 
 
-@override_settings(EDX_API_KEY=VALID_API_KEY)
 @ddt.ddt
-@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
-class UserViewAPITests(TpaAPITestCase):
+class UserViewsMixin(object):
     """
-    Test the Third Party Auth User REST API
+    Generic TestCase to exercise the v1 and v2 UserViews.
     """
 
     def expected_active(self, username):
@@ -124,7 +125,7 @@ class UserViewAPITests(TpaAPITestCase):
     @ddt.unpack
     def test_list_connected_providers(self, request_user, target_user, expect_result):
         self.client.login(username=request_user, password=PASSWORD)
-        url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
+        url = self.make_url({'username': target_user})
 
         response = self.client.get(url)
         self.assertEqual(response.status_code, expect_result)
@@ -140,14 +141,87 @@ class UserViewAPITests(TpaAPITestCase):
         (None, ALICE_USERNAME, 403),
     )
     @ddt.unpack
-    def test_list_connected_providers__withapi_key(self, api_key, target_user, expect_result):
-        url = reverse('third_party_auth_users_api', kwargs={'username': target_user})
+    def test_list_connected_providers_with_api_key(self, api_key, target_user, expect_result):
+        url = self.make_url({'username': target_user})
         response = self.client.get(url, HTTP_X_EDX_API_KEY=api_key)
         self.assertEqual(response.status_code, expect_result)
         if expect_result == 200:
             self.assertIn("active", response.data)
             self.assertItemsEqual(response.data["active"], self.expected_active(target_user))
 
+    @ddt.data(
+        (True, ALICE_USERNAME, 200, True),
+        (True, CARL_USERNAME, 200, False),
+        (False, ALICE_USERNAME, 200, True),
+        (False, CARL_USERNAME, 403, None),
+    )
+    @ddt.unpack
+    def test_allow_unprivileged_response(self, allow_unprivileged, requesting_user, expect, include_remote_id):
+        self.client.login(username=requesting_user, password=PASSWORD)
+        with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=allow_unprivileged):
+            url = self.make_url({'username': ALICE_USERNAME})
+            response = self.client.get(url)
+        self.assertEqual(response.status_code, expect)
+        if response.status_code == 200:
+            self.assertGreater(len(response.data['active']), 0)
+            for provider_data in response.data['active']:
+                self.assertEqual(include_remote_id, 'remote_id' in provider_data)
+
+    def test_allow_query_by_email(self):
+        self.client.login(username=ALICE_USERNAME, password=PASSWORD)
+        url = self.make_url({'email': '{}@example.com'.format(ALICE_USERNAME)})
+        response = self.client.get(url)
+        self.assertEqual(response.status_code, 200)
+        self.assertGreater(len(response.data['active']), 0)
+
+    def test_throttling(self):
+        # Default burst throttle is 10/min: ten requests succeed, the eleventh must be rejected (429).
+        throttling_user = UserFactory.create(password=PASSWORD)
+        self.client.login(username=throttling_user.username, password=PASSWORD)
+        url = self.make_url({'username': ALICE_USERNAME})
+        with override_settings(ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY=True):
+            for _ in range(10):
+                response = self.client.get(url)
+                self.assertEqual(response.status_code, 200)
+            response = self.client.get(url)
+            self.assertEqual(response.status_code, 429)
+
+
+@override_settings(EDX_API_KEY=VALID_API_KEY)
+@ddt.ddt
+@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
+class UserViewAPITests(UserViewsMixin, TpaAPITestCase):
+    """
+    Test the Third Party Auth User REST API
+    """
+
+    def make_url(self, identifier):
+        """
+        Return the view URL, using the single value in `identifier` as the path component
+        """
+        return reverse(
+            'third_party_auth_users_api',
+            kwargs={'username': list(identifier.values())[0]}
+        )
+
+
+@override_settings(EDX_API_KEY=VALID_API_KEY)
+@ddt.ddt
+@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
+class UserViewV2APITests(UserViewsMixin, TpaAPITestCase):
+    """
+    Test the Third Party Auth User REST API (v2: identifier passed in the query string)
+    """
+
+    def make_url(self, identifier):
+        """
+        Return the view URL, with the identifier provided
+        """
+        return '?'.join([
+            reverse('third_party_auth_users_api_v2'),
+            six.moves.urllib.parse.urlencode(identifier)
+        ])
+
 
 @override_settings(EDX_API_KEY=VALID_API_KEY)
 @ddt.ddt
diff --git a/common/djangoapps/third_party_auth/api/urls.py b/common/djangoapps/third_party_auth/api/urls.py
index dc1673dae2e..df6209ebade 100644
--- a/common/djangoapps/third_party_auth/api/urls.py
+++ b/common/djangoapps/third_party_auth/api/urls.py
@@ -3,7 +3,7 @@
 from django.conf import settings
 from django.conf.urls import url
 
-from .views import UserMappingView, UserView
+from .views import UserMappingView, UserView, UserViewV2
 
 
 PROVIDER_PATTERN = r'(?P<provider_id>[\w.+-]+)(?:\:(?P<idp_slug>[\w.+-]+))?'
@@ -14,6 +14,7 @@ urlpatterns = [
         UserView.as_view(),
         name='third_party_auth_users_api',
     ),
+    url(r'^v0/users/', UserViewV2.as_view(), name='third_party_auth_users_api_v2'),
     url(
         r'^v0/providers/{provider_pattern}/users$'.format(provider_pattern=PROVIDER_PATTERN),
         UserMappingView.as_view(),
diff --git a/common/djangoapps/third_party_auth/api/views.py b/common/djangoapps/third_party_auth/api/views.py
index fe590a46eb6..5004a653691 100644
--- a/common/djangoapps/third_party_auth/api/views.py
+++ b/common/djangoapps/third_party_auth/api/views.py
@@ -1,11 +1,15 @@
 """
 Third Party Auth REST API views
 """
+
+from collections import namedtuple
+
+from django.conf import settings
 from django.contrib.auth.models import User
 from django.db.models import Q
 from django.http import Http404
 from edx_rest_framework_extensions.auth.session.authentication import SessionAuthenticationAllowInactiveUser
-from rest_framework import exceptions, status
+from rest_framework import exceptions, status, throttling
 from rest_framework.generics import ListAPIView
 from rest_framework.response import Response
 from rest_framework.views import APIView
@@ -20,13 +24,132 @@ from third_party_auth.api.permissions import ThirdPartyAuthProviderApiPermission
 from third_party_auth.provider import Registry
 
 
-class UserView(APIView):
+class ProviderBaseThrottle(throttling.UserRateThrottle):
+    """
+    Base throttle for provider queries
+    """
+
+    def allow_request(self, request, view):
+        """
+        Only throttle unprivileged requests.
+        """
+        if view.is_unprivileged_query(request, view.get_identifier_for_requested_user(request)):
+            return super(ProviderBaseThrottle, self).allow_request(request, view)
+        return True
+
+
+class ProviderBurstThrottle(ProviderBaseThrottle):
+    """
+    Maximum number of provider requests in a quick burst.
+    """
+    rate = settings.TPA_PROVIDER_BURST_THROTTLE  # Default '10/min'
+
+
+class ProviderSustainedThrottle(ProviderBaseThrottle):
+    """
+    Maximum number of provider requests over time.
+    """
+    rate = settings.TPA_PROVIDER_SUSTAINED_THROTTLE  # Default '50/hr'
+
+
+class BaseUserView(APIView):
+    """
+    Common core of UserView and UserViewV2
+    """
+    identifier = namedtuple('identifier', ['kind', 'value'])
+    identifier_kinds = ['email', 'username']
+
+    authentication_classes = (
+        # Users may want to view/edit the providers used for authentication before they've
+        # activated their account, so we allow inactive users.
+        OAuth2AuthenticationAllowInactiveUser,
+        SessionAuthenticationAllowInactiveUser,
+    )
+    throttle_classes = [ProviderSustainedThrottle, ProviderBurstThrottle]
+
+    def do_get(self, request, identifier):
+        """
+        Fulfill the request, now that the identifier has been specified.
+        """
+        is_unprivileged = self.is_unprivileged_query(request, identifier)
+
+        if is_unprivileged:
+            if not getattr(settings, 'ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False):
+                return Response(status=status.HTTP_403_FORBIDDEN)
+        try:
+            user = User.objects.get(**{identifier.kind: identifier.value})
+        except User.DoesNotExist:
+            return Response(status=status.HTTP_404_NOT_FOUND)
+
+        providers = pipeline.get_provider_user_states(user)
+
+        active_providers = [
+            self.get_provider_data(assoc, is_unprivileged)
+            for assoc in providers if assoc.has_account
+        ]
+
+        # In the future this can be trivially modified to return the inactive/disconnected providers as well.
+
+        return Response({
+            "active": active_providers
+        })
+
+    def get_provider_data(self, assoc, is_unprivileged):
+        """
+        Return the data for the specified provider.
+
+        If the request is unprivileged, do not return the remote ID of the user.
+        """
+        provider_data = {
+            "provider_id": assoc.provider.provider_id,
+            "name": assoc.provider.name,
+        }
+        if not is_unprivileged:
+            provider_data["remote_id"] = assoc.remote_id
+        return provider_data
+
+    def is_unprivileged_query(self, request, identifier):
+        """
+        Return True if a non-superuser requests information about another user.
+
+        `identifier` must be an identifier namedtuple whose kind is 'username' or 'email'.
+        """
+        if identifier.kind not in self.identifier_kinds:
+            # This is already checked before we get here, so raise a 500 error
+            # if the check fails.
+            raise ValueError("Identifier kind {} not in {}".format(identifier.kind, self.identifier_kinds))
+
+        self_request = False
+        if identifier == self.identifier('username', request.user.username):
+            self_request = True
+        elif identifier.kind == 'email' and getattr(request.user, 'email', object()) == identifier.value:
+            # AnonymousUser does not have an email attribute, so fall back to
+            # something that will never compare equal to the provided email.
+            self_request = True
+        if self_request:
+            # We can always ask for our own provider
+            return False
+        # We are querying permissions for a user other than the current user.
+        if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
+            # The user does not have elevated permissions.
+            return True
+        return False
+
+
+class UserView(BaseUserView):
     """
     List the third party auth accounts linked to the specified user account.
 
+    [DEPRECATED]
+
+    This view uses heuristics to guess whether the provided identifier is a
+    username or email address.  Instead, use /api/third_party_auth/v0/users/
+    and specify ?username=foo or ?email=foo@example.com.
+
     **Example Request**
 
         GET /api/third_party_auth/v0/users/{username}
+        GET /api/third_party_auth/v0/users/{email@example.com}
 
     **Response Values**
 
@@ -45,18 +168,11 @@ class UserView(APIView):
               is what is used to link the user to their edX account during
               login.
     """
-    authentication_classes = (
-        # Users may want to view/edit the providers used for authentication before they've
-        # activated their account, so we allow inactive users.
-        OAuth2AuthenticationAllowInactiveUser,
-        SessionAuthenticationAllowInactiveUser,
-    )
 
     def get(self, request, username):
-        """Create, read, or update enrollment information for a user.
+        """Read provider information for a user.
 
-        HTTP Endpoint for all CRUD operations for a user course enrollment. Allows creation, reading, and
-        updates of the current enrollment for a particular course.
+        Allows reading the list of providers for a specified user.
 
         Args:
             request (Request): The HTTP GET request
@@ -66,34 +182,80 @@ class UserView(APIView):
             JSON serialized list of the providers linked to this user.
 
         """
-        if request.user.username != username:
-            # We are querying permissions for a user other than the current user.
-            if not request.user.is_superuser and not ApiKeyHeaderPermission().has_permission(request, self):
-                # Return a 403 (Unauthorized) without validating 'username', so that we
-                # do not let users probe the existence of other user accounts.
-                return Response(status=status.HTTP_403_FORBIDDEN)
+        identifier = self.get_identifier_for_requested_user(request)
+        return self.do_get(request, identifier)
 
-        try:
-            user = User.objects.get(username=username)
-        except User.DoesNotExist:
-            return Response(status=status.HTTP_404_NOT_FOUND)
+    def get_identifier_for_requested_user(self, _request):
+        """
+        Return an identifier namedtuple for the requested user.
+        """
+        if u'@' in self.kwargs[u'username']:
+            id_kind = u'email'
+        else:
+            id_kind = u'username'
+        return self.identifier(id_kind, self.kwargs[u'username'])
 
-        providers = pipeline.get_provider_user_states(user)
 
-        active_providers = [
-            {
-                "provider_id": assoc.provider.provider_id,
-                "name": assoc.provider.name,
-                "remote_id": assoc.remote_id,
-            }
-            for assoc in providers if assoc.has_account
-        ]
+# TODO: When removing deprecated UserView, rename this view to UserView.
+class UserViewV2(BaseUserView):
+    """
+    List the third party auth accounts linked to the specified user account.
 
-        # In the future this can be trivially modified to return the inactive/disconnected providers as well.
+    **Example Request**
 
-        return Response({
-            "active": active_providers
-        })
+        GET /api/third_party_auth/v0/users/?username={username}
+        GET /api/third_party_auth/v0/users/?email={email@example.com}
+
+    **Response Values**
+
+        If the request for information about the user is successful, an HTTP 200 "OK" response
+        is returned.
+
+        The HTTP 200 response has the following values.
+
+        * active: A list of all the third party auth providers currently linked
+          to the given user's account. Each object in this list has the
+          following attributes:
+
+            * provider_id: The unique identifier of this provider (string)
+            * name: The name of this provider (string)
+            * remote_id: The ID of the user according to the provider. This ID
+              is what is used to link the user to their edX account during
+              login.
+    """
+
+    def get(self, request):
+        """
+        Read provider information for a user.
+
+        Allows reading the list of providers for a specified user.
+
+        Args:
+            request (Request): The HTTP GET request
+
+        Request Parameters:
+            Must provide one of 'email' or 'username'.  If both are provided,
+            the username will be ignored.
+
+        Return:
+            JSON serialized list of the providers linked to this user.
+
+        """
+        identifier = self.get_identifier_for_requested_user(request)
+        return self.do_get(request, identifier)
+
+    def get_identifier_for_requested_user(self, request):
+        """
+        Return an identifier namedtuple for the requested user.
+        """
+        identifier = None
+        for id_kind in self.identifier_kinds:
+            if id_kind in request.GET:
+                identifier = self.identifier(id_kind, request.GET[id_kind])
+                break
+        if identifier is None:
+            raise exceptions.ValidationError(u"Must provide one of {}".format(self.identifier_kinds))
+        return identifier
 
 
 class UserMappingView(ListAPIView):
@@ -195,7 +357,7 @@ class UserMappingView(ListAPIView):
         # When using multi-IdP backend, we only retrieve the ones that are for current IdP.
         # test if the current provider has a slug
         uid = self.provider.get_social_auth_uid('uid')
-        if uid is not 'uid':
+        if uid != 'uid':
             # if yes, we add a filter for the slug on uid column
             query_set = query_set.filter(uid__startswith=uid[:-3])
 
@@ -207,13 +369,13 @@ class UserMappingView(ListAPIView):
         if usernames:
             usernames = ','.join(usernames)
             usernames = set(usernames.split(',')) if usernames else set()
-            if len(usernames):
+            if usernames:
                 query = query | Q(user__username__in=usernames)
 
         if remote_ids:
             remote_ids = ','.join(remote_ids)
             remote_ids = set(remote_ids.split(',')) if remote_ids else set()
-            if len(remote_ids):
+            if remote_ids:
                 query = query | Q(uid__in=[self.provider.get_social_auth_uid(remote_id) for remote_id in remote_ids])
 
         return query_set.filter(query)
diff --git a/lms/envs/aws.py b/lms/envs/aws.py
index 3c8f62ab5c6..e9a0e463d83 100644
--- a/lms/envs/aws.py
+++ b/lms/envs/aws.py
@@ -707,6 +707,9 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'):
     # dict with an arbitrary 'secret_key' and a 'url'.
     THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS = AUTH_TOKENS.get('THIRD_PARTY_AUTH_CUSTOM_AUTH_FORMS', {})
 
+    # Whether to allow unprivileged users to discover SSO providers for arbitrary usernames.
+    ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY = ENV_TOKENS.get('ALLOW_UNPRIVILEGED_SSO_PROVIDER_QUERY', False)
+
 ##### OAUTH2 Provider ##############
 if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
     OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER']
@@ -722,6 +725,10 @@ if FEATURES.get('ENABLE_OAUTH2_PROVIDER'):
     OAUTH_ID_TOKEN_EXPIRATION = ENV_TOKENS.get('OAUTH_ID_TOKEN_EXPIRATION', OAUTH_ID_TOKEN_EXPIRATION)
     OAUTH_DELETE_EXPIRED = ENV_TOKENS.get('OAUTH_DELETE_EXPIRED', OAUTH_DELETE_EXPIRED)
 
+##### THIRD_PARTY_AUTH #############
+TPA_PROVIDER_BURST_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_BURST_THROTTLE', TPA_PROVIDER_BURST_THROTTLE)
+TPA_PROVIDER_SUSTAINED_THROTTLE = ENV_TOKENS.get('TPA_PROVIDER_SUSTAINED_THROTTLE', TPA_PROVIDER_SUSTAINED_THROTTLE)
+
 ##### ADVANCED_SECURITY_CONFIG #####
 ADVANCED_SECURITY_CONFIG = ENV_TOKENS.get('ADVANCED_SECURITY_CONFIG', {})
 
diff --git a/lms/envs/common.py b/lms/envs/common.py
index 7df659fcb09..7fda273083e 100644
--- a/lms/envs/common.py
+++ b/lms/envs/common.py
@@ -521,6 +521,10 @@ OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application'
 OAUTH_DELETE_EXPIRED = True
 OAUTH_ID_TOKEN_EXPIRATION = 60 * 60
 
+################################## THIRD_PARTY_AUTH CONFIGURATION #############################
+TPA_PROVIDER_BURST_THROTTLE = '10/min'
+TPA_PROVIDER_SUSTAINED_THROTTLE = '50/hr'
+
 ################################## TEMPLATE CONFIGURATION #####################################
 # Mako templating
 import tempfile
-- 
GitLab