From 4cb4be6afecb64ca14ece8e46a4a5cdaf93b4d81 Mon Sep 17 00:00:00 2001
From: Nizar Mahmoud <nizarmah@hotmail.com>
Date: Thu, 1 Apr 2021 19:56:46 +0300
Subject: [PATCH] feat: associates user by email for oauth when tpa is required

This change associates users signing in using oauth providers when tpa is required, verifying that only a single database user is associated with the email.

For more information as to why this was added in a separate pipeline, check edx-platform#25935.
---
 .../djangoapps/third_party_auth/pipeline.py   | 61 ++++++++++---------
 .../djangoapps/third_party_auth/settings.py   |  1 +
 .../third_party_auth/tests/specs/base.py      | 21 +++++++
 .../tests/specs/test_azuread.py               |  4 ++
 .../tests/specs/test_google.py                |  4 ++
 .../tests/specs/test_linkedin.py              |  4 ++
 .../tests/specs/test_twitter.py               |  4 ++
 .../third_party_auth/tests/test_utils.py      | 51 ++++++++++++++++
 common/djangoapps/third_party_auth/utils.py   | 34 +++++++++++
 9 files changed, 156 insertions(+), 28 deletions(-)

diff --git a/common/djangoapps/third_party_auth/pipeline.py b/common/djangoapps/third_party_auth/pipeline.py
index 12562828969..4c4df14cbb5 100644
--- a/common/djangoapps/third_party_auth/pipeline.py
+++ b/common/djangoapps/third_party_auth/pipeline.py
@@ -79,7 +79,6 @@ from django.urls import reverse
 from edx_django_utils.monitoring import set_custom_attribute
 from social_core.exceptions import AuthException
 from social_core.pipeline import partial
-from social_core.pipeline.social_auth import associate_by_email
 from social_core.utils import module_member, slugify
 
 from common.djangoapps import third_party_auth
@@ -90,9 +89,12 @@ from openedx.core.djangoapps.site_configuration import helpers as configuration_
 from openedx.core.djangoapps.user_api import accounts
 from openedx.core.djangoapps.user_api.accounts.utils import is_multiple_sso_accounts_association_to_saml_user_enabled
 from openedx.core.djangoapps.user_authn import cookies as user_authn_cookies
+from openedx.core.djangoapps.user_authn.toggles import is_require_third_party_auth_enabled
 from common.djangoapps.third_party_auth.utils import (
+    get_associated_user_by_email_response,
     get_user_from_email,
     is_enterprise_customer_user,
+    is_oauth_provider,
     is_saml_provider,
     user_exists,
 )
@@ -735,16 +737,30 @@ def associate_by_email_if_login_api(auth_entry, backend, details, user, current_
     if auth_entry == AUTH_ENTRY_LOGIN_API:
         # Temporary custom attribute to help ensure there is no usage.
         set_custom_attribute('deprecated_auth_entry_login_api', True)
-        association_response = associate_by_email(backend, details, user, *args, **kwargs)
-        if (
-            association_response and
-            association_response.get('user') and
-            association_response['user'].is_active
-        ):
-            # Only return the user matched by email if their email has been activated.
-            # Otherwise, an illegitimate user can create an account with another user's
-            # email address and the legitimate user would now login to the illegitimate
-            # account.
+
+        association_response, user_is_active = get_associated_user_by_email_response(
+            backend, details, user, *args, **kwargs)
+
+        if user_is_active:
+            return association_response
+
+
+@partial.partial
+def associate_by_email_if_oauth(auth_entry, backend, details, user, strategy, *args, **kwargs):
+    """
+    This pipeline step associates the current social auth with the user with the
+    same email address in the database.  It defers to the social library's associate_by_email
+    implementation, which verifies that only a single database user is associated with the email.
+
+    This association is done ONLY if the user entered the pipeline belongs to Oauth provider and
+    `ENABLE_REQUIRE_THIRD_PARTY_AUTH` is enabled.
+    """
+
+    if is_require_third_party_auth_enabled() and is_oauth_provider(backend.name, **kwargs):
+        association_response, user_is_active = get_associated_user_by_email_response(
+            backend, details, user, *args, **kwargs)
+
+        if user_is_active:
             return association_response
 
 
@@ -786,23 +802,10 @@ def associate_by_email_if_saml(auth_entry, backend, details, user, strategy, *ar
             if enterprise_customer_user:
                 # this is python social auth pipeline default method to automatically associate social accounts
                 # if the email already matches a user account.
-                association_response = associate_by_email(backend, details, user, *args, **kwargs)
-
-                if (
-                    association_response and
-                    association_response.get('user') and
-                    association_response['user'].is_active
-                ):
-                    # Only return the user matched by email if their email has been activated.
-                    # Otherwise, an illegitimate user can create an account with another user's
-                    # email address and the legitimate user would now login to the illegitimate
-                    # account.
-                    return association_response
-                elif (
-                    association_response and
-                    association_response.get('user') and
-                    not association_response['user'].is_active
-                ):
+                association_response, user_is_active = get_associated_user_by_email_response(
+                    backend, details, user, *args, **kwargs)
+
+                if not user_is_active:
                     logger.info(
                         '[Multiple_SSO_SAML_Accounts_Association_to_User] User association account is not'
                         ' active: User Email: {email}, User ID: {user_id}, Provider ID: {provider_id},'
@@ -815,6 +818,8 @@ def associate_by_email_if_saml(auth_entry, backend, details, user, strategy, *ar
                     )
                     return None
 
+                return association_response
+
         except Exception as ex:  # pylint: disable=broad-except
             logger.exception('[Multiple_SSO_SAML_Accounts_Association_to_User] Error in'
                              ' saml multiple accounts association: User ID: %s, User Email: %s:,'
diff --git a/common/djangoapps/third_party_auth/settings.py b/common/djangoapps/third_party_auth/settings.py
index dc787ba06da..2d714704c40 100644
--- a/common/djangoapps/third_party_auth/settings.py
+++ b/common/djangoapps/third_party_auth/settings.py
@@ -55,6 +55,7 @@ def apply_settings(django_settings):
         'social_core.pipeline.social_auth.social_user',
         'common.djangoapps.third_party_auth.pipeline.associate_by_email_if_login_api',
         'common.djangoapps.third_party_auth.pipeline.associate_by_email_if_saml',
+        'common.djangoapps.third_party_auth.pipeline.associate_by_email_if_oauth',
         'common.djangoapps.third_party_auth.pipeline.get_username',
         'common.djangoapps.third_party_auth.pipeline.set_pipeline_timeout',
         'common.djangoapps.third_party_auth.pipeline.ensure_user_information',
diff --git a/common/djangoapps/third_party_auth/tests/specs/base.py b/common/djangoapps/third_party_auth/tests/specs/base.py
index aeadd6d37dd..fd9bf3fa96d 100644
--- a/common/djangoapps/third_party_auth/tests/specs/base.py
+++ b/common/djangoapps/third_party_auth/tests/specs/base.py
@@ -789,6 +789,27 @@ class IntegrationTest(testutil.TestCase, test.TestCase, HelperMixin):
         post_request = self._get_login_post_request(strategy)
         self.assert_json_failure_response_is_missing_social_auth(login_user(post_request))
 
+    @django_utils.override_settings(ENABLE_REQUIRE_THIRD_PARTY_AUTH=True)
+    def test_signin_associates_user_if_oauth_provider_and_tpa_is_required(self):
+        """
+        Tests associate user by email with oauth provider and `ENABLE_REQUIRE_THIRD_PARTY_AUTH` enabled
+        """
+        username, email, password = self.get_username(), 'user@example.com', 'password'
+
+        _, strategy = self.get_request_and_strategy(
+            auth_entry=pipeline.AUTH_ENTRY_LOGIN, redirect_uri='social:complete')
+
+        user = self.create_user_models_for_existing_account(strategy, email, password, username, skip_social_auth=True)
+
+        with mock.patch(
+            'common.djangoapps.third_party_auth.pipeline.get_associated_user_by_email_response',
+            return_value=[{'user': user}, True],
+        ):
+            strategy.request.backend.auth_complete = mock.MagicMock(return_value=self.fake_auth_complete(strategy))
+
+            post_request = self._get_login_post_request(strategy)
+            self.assert_json_success_response_looks_correct(login_user(post_request), verify_redirect_url=True)
+
     def test_first_party_auth_trumps_third_party_auth_but_is_invalid_when_only_email_in_request(self):
         self.assert_first_party_auth_trumps_third_party_auth(email='user@example.com')
 
diff --git a/common/djangoapps/third_party_auth/tests/specs/test_azuread.py b/common/djangoapps/third_party_auth/tests/specs/test_azuread.py
index eeffc4afe76..283cbc521da 100644
--- a/common/djangoapps/third_party_auth/tests/specs/test_azuread.py
+++ b/common/djangoapps/third_party_auth/tests/specs/test_azuread.py
@@ -7,6 +7,10 @@ from common.djangoapps.third_party_auth.tests.specs import base
 class AzureADOauth2IntegrationTest(base.Oauth2IntegrationTest):  # lint-amnesty, pylint: disable=test-inherits-tests
     """Integration tests for Azure Active Directory / Microsoft Account provider."""
 
+    PROVIDER_NAME = "azure"
+    PROVIDER_BACKEND = "azure-oauth2"
+    PROVIDER_ID = "oa2-azure-oauth2"
+
     def setUp(self):
         super().setUp()
         self.provider = self.configure_azure_ad_provider(
diff --git a/common/djangoapps/third_party_auth/tests/specs/test_google.py b/common/djangoapps/third_party_auth/tests/specs/test_google.py
index c74b78d1c19..0592f50734a 100644
--- a/common/djangoapps/third_party_auth/tests/specs/test_google.py
+++ b/common/djangoapps/third_party_auth/tests/specs/test_google.py
@@ -19,6 +19,10 @@ from common.djangoapps.third_party_auth.tests.specs import base
 class GoogleOauth2IntegrationTest(base.Oauth2IntegrationTest):  # lint-amnesty, pylint: disable=test-inherits-tests
     """Integration tests for provider.GoogleOauth2."""
 
+    PROVIDER_NAME = "google"
+    PROVIDER_BACKEND = "google-oauth2"
+    PROVIDER_ID = "oa2-google-oauth2"
+
     def setUp(self):
         super().setUp()
         self.provider = self.configure_google_provider(
diff --git a/common/djangoapps/third_party_auth/tests/specs/test_linkedin.py b/common/djangoapps/third_party_auth/tests/specs/test_linkedin.py
index fd9ce1da5ff..f8184cd109c 100644
--- a/common/djangoapps/third_party_auth/tests/specs/test_linkedin.py
+++ b/common/djangoapps/third_party_auth/tests/specs/test_linkedin.py
@@ -16,6 +16,10 @@ def get_localized_name(name):
 class LinkedInOauth2IntegrationTest(base.Oauth2IntegrationTest):  # lint-amnesty, pylint: disable=test-inherits-tests
     """Integration tests for provider.LinkedInOauth2."""
 
+    PROVIDER_NAME = "linkedin"
+    PROVIDER_BACKEND = "linkedin-oauth2"
+    PROVIDER_ID = "oa2-linkedin-oauth2"
+
     def setUp(self):
         super().setUp()
         self.provider = self.configure_linkedin_provider(
diff --git a/common/djangoapps/third_party_auth/tests/specs/test_twitter.py b/common/djangoapps/third_party_auth/tests/specs/test_twitter.py
index e592a624198..d67373ccf67 100644
--- a/common/djangoapps/third_party_auth/tests/specs/test_twitter.py
+++ b/common/djangoapps/third_party_auth/tests/specs/test_twitter.py
@@ -10,6 +10,10 @@ from common.djangoapps.third_party_auth.tests.specs import base
 class TwitterIntegrationTest(base.Oauth2IntegrationTest):  # lint-amnesty, pylint: disable=test-inherits-tests
     """Integration tests for Twitter backend."""
 
+    PROVIDER_NAME = "twitter"
+    PROVIDER_BACKEND = "twitter-oauth2"
+    PROVIDER_ID = "oa2-twitter-oauth2"
+
     def setUp(self):
         super().setUp()
         self.provider = self.configure_twitter_provider(
diff --git a/common/djangoapps/third_party_auth/tests/test_utils.py b/common/djangoapps/third_party_auth/tests/test_utils.py
index 8af2909415c..dc13bc8a30d 100644
--- a/common/djangoapps/third_party_auth/tests/test_utils.py
+++ b/common/djangoapps/third_party_auth/tests/test_utils.py
@@ -4,14 +4,19 @@ Tests for third_party_auth utility functions.
 
 
 import unittest
+from unittest import mock
+from unittest.mock import MagicMock
 
+import ddt
 from django.conf import settings
 
 from common.djangoapps.student.tests.factories import UserFactory
 from common.djangoapps.third_party_auth.tests.testutil import TestCase
 from common.djangoapps.third_party_auth.utils import (
+    get_associated_user_by_email_response,
     get_user_from_email,
     is_enterprise_customer_user,
+    is_oauth_provider,
     user_exists,
     convert_saml_slug_provider_id,
 )
@@ -21,6 +26,7 @@ from openedx.features.enterprise_support.tests.factories import (
 )
 
 
+@ddt.ddt
 @unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
 class TestUtils(TestCase):
     """
@@ -77,3 +83,48 @@ class TestUtils(TestCase):
 
         assert is_enterprise_customer_user('the-provider', user)
         assert not is_enterprise_customer_user('the-provider', other_user)
+
+    @ddt.data(
+        ('saml-farkle', False),
+        ('oa2-fergus', True),
+        ('oa2-felicia', True),
+    )
+    @ddt.unpack
+    def test_is_oauth_provider(self, provider_id, oauth_provider):
+        """
+        Tests if the backend name is that of an auth provider or not
+        """
+        with mock.patch(
+            'common.djangoapps.third_party_auth.utils.provider.Registry.get_from_pipeline'
+        ) as get_from_pipeline:
+            get_from_pipeline.return_value.provider_id = provider_id
+
+            self.assertEqual(is_oauth_provider('backend_name'), oauth_provider)
+
+    @ddt.data(
+        (None, False),
+        (None, False),
+        ('The Muffin Man', True),
+        ('Gingerbread Man', False),
+    )
+    @ddt.unpack
+    def test_get_associated_user_by_email_response(self, user, user_is_active):
+        """
+        Tests if an association response is returned for a user
+        """
+        with mock.patch(
+            'common.djangoapps.third_party_auth.utils.associate_by_email',
+            side_effect=lambda _b, _d, u, *_a, **_k: {'user': u} if u else None,
+        ):
+            mock_user = MagicMock(return_value=user)
+            mock_user.is_active = user_is_active
+
+            association_response, user_is_active_resonse = get_associated_user_by_email_response(
+                backend=None, details=None, user=mock_user)
+
+            if association_response:
+                self.assertEqual(association_response['user'](), user)
+                self.assertEqual(user_is_active_resonse, user_is_active)
+            else:
+                self.assertIsNone(association_response)
+                self.assertFalse(user_is_active_resonse)
diff --git a/common/djangoapps/third_party_auth/utils.py b/common/djangoapps/third_party_auth/utils.py
index da0d55af12a..46da90794d1 100644
--- a/common/djangoapps/third_party_auth/utils.py
+++ b/common/djangoapps/third_party_auth/utils.py
@@ -5,6 +5,9 @@ Utility functions for third_party_auth
 from uuid import UUID
 from django.contrib.auth.models import User  # lint-amnesty, pylint: disable=imported-auth-user
 from enterprise.models import EnterpriseCustomerUser, EnterpriseCustomerIdentityProvider
+from social_core.pipeline.social_auth import associate_by_email
+
+from common.djangoapps.third_party_auth.models import OAuth2ProviderConfig
 from . import provider
 
 
@@ -92,3 +95,34 @@ def is_enterprise_customer_user(provider_id, user):
 
     return EnterpriseCustomerUser.objects.filter(enterprise_customer=enterprise_idp.enterprise_customer,
                                                  user_id=user.id).exists()
+
+
+def is_oauth_provider(backend_name, **kwargs):
+    """
+    Verify that the third party provider uses oauth
+    """
+    current_provider = provider.Registry.get_from_pipeline({'backend': backend_name, 'kwargs': kwargs})
+    if current_provider:
+        return current_provider.provider_id.startswith(OAuth2ProviderConfig.prefix)
+
+    return False
+
+
+def get_associated_user_by_email_response(backend, details, user, *args, **kwargs):
+    """
+    Gets the user associated by the `associate_by_email` social auth method
+    """
+
+    association_response = associate_by_email(backend, details, user, *args, **kwargs)
+
+    if (
+        association_response and
+        association_response.get('user')
+    ):
+        # Only return the user matched by email if their email has been activated.
+        # Otherwise, an illegitimate user can create an account with another user's
+        # email address and the legitimate user would now login to the illegitimate
+        # account.
+        return (association_response, association_response['user'].is_active)
+
+    return (None, False)
-- 
GitLab