From 13d4091a1afb04f4ec5f37a9a0d76a28870f698e Mon Sep 17 00:00:00 2001
From: Nimisha Asthagiri <nasthagiri@edx.org>
Date: Mon, 8 Oct 2018 11:29:38 -0400
Subject: [PATCH] Fix overriding of token expiration in DOT (ARCH-246)

---
 openedx/core/djangoapps/oauth_dispatch/api.py | 21 ++++---
 .../dot_overrides/validators.py               | 60 ++++++++++++++-----
 .../oauth_dispatch/tests/test_api.py          | 23 ++++---
 3 files changed, 66 insertions(+), 38 deletions(-)

diff --git a/openedx/core/djangoapps/oauth_dispatch/api.py b/openedx/core/djangoapps/oauth_dispatch/api.py
index e1603a5ba2a..0c56bf901c8 100644
--- a/openedx/core/djangoapps/oauth_dispatch/api.py
+++ b/openedx/core/djangoapps/oauth_dispatch/api.py
@@ -7,7 +7,6 @@ from oauthlib.oauth2.rfc6749.errors import OAuth2Error
 from oauthlib.oauth2.rfc6749.tokens import BearerToken
 from oauth2_provider.models import AccessToken as dot_access_token
 from oauth2_provider.models import RefreshToken as dot_refresh_token
-from oauth2_provider.oauth2_backends import get_oauthlib_core
 from oauth2_provider.settings import oauth2_settings as dot_settings
 from provider.oauth2.models import AccessToken as dop_access_token
 from provider.oauth2.models import RefreshToken as dop_refresh_token
@@ -51,8 +50,8 @@ def refresh_dot_access_token(request, client_id, refresh_token, expires_in=None)
     Create and return a new (persisted) access token, given a previously created
     refresh_token, possibly returned from create_dot_access_token above.
     """
-    auth_core = get_oauthlib_core()
     expires_in = _get_expires_in_value(expires_in)
+    auth_core = _get_oauthlib_core(expires_in)
     _populate_refresh_token_request(request, client_id, refresh_token)
 
     # Note: Unlike create_dot_access_token, we use the top-level auth library
@@ -70,13 +69,7 @@ def _get_expires_in_value(expires_in):
     """
     Returns the expires_in value to use for the token.
     """
-    # TODO (ARCH-246) Fix expiration configuration as this does not actually
-    # override the token's expiration. Rather, DOT's save_bearer_token method
-    # will always use dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS.
-    if not expires_in:
-        seconds_in_a_day = 24 * 60 * 60
-        expires_in = settings.OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS * seconds_in_a_day
-    return expires_in
+    return expires_in or dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS
 
 
 def _populate_create_access_token_request(request, user, client, scope=None):
@@ -105,3 +98,13 @@ def _populate_refresh_token_request(request, client_id, refresh_token):
         refresh_token=refresh_token,
         grant_type='refresh_token',
     )
+
+
+def _get_oauthlib_core(expires_in):
+    """
+    Based on oauth2_provider.oauth2_backends.get_oauthlib_core, but allows
+    passing in a value for token_expires_in.
+    """
+    validator = dot_settings.OAUTH2_VALIDATOR_CLASS()
+    server = dot_settings.OAUTH2_SERVER_CLASS(validator, token_expires_in=expires_in)
+    return dot_settings.OAUTH2_BACKEND_CLASS(server)
diff --git a/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py b/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py
index bc265a035cd..5c9aa5ddb65 100644
--- a/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py
+++ b/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py
@@ -3,7 +3,7 @@ Classes that override default django-oauth-toolkit behavior
 """
 from __future__ import unicode_literals
 
-from datetime import datetime
+from datetime import datetime, timedelta
 
 from django.contrib.auth import authenticate, get_user_model
 from django.db.models.signals import pre_save
@@ -82,21 +82,9 @@ class EdxOAuth2Validator(OAuth2Validator):
 
         super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs)
 
-        if RestrictedApplication.should_expire_access_token(request.client):
-            # Since RestrictedApplications will override the DOT defined expiry, so that access_tokens
-            # are always expired, we need to re-read the token from the database and then calculate the
-            # expires_in (in seconds) from what we stored in the database. This value should be a negative
-            #value, meaning that it is already expired
-
-            access_token = AccessToken.objects.get(token=token['access_token'])
-            utc_now = datetime.utcnow().replace(tzinfo=utc)
-            expires_in = (access_token.expires - utc_now).total_seconds()
-
-            # assert that RestrictedApplications only issue expired tokens
-            # blow up processing if we see otherwise
-            assert expires_in < 0
-
-            token['expires_in'] = expires_in
+        is_restricted_client = self._update_token_expiry_if_restricted_client(token, request.client)
+        if not is_restricted_client:
+            self._update_token_expiry_if_overridden_in_request(token, request)
 
         # Restore the original request attributes
         request.grant_type = grant_type
@@ -108,3 +96,43 @@ class EdxOAuth2Validator(OAuth2Validator):
         """
         available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request)
         return set(scopes).issubset(set(available_scopes))
+
+    def _update_token_expiry_if_restricted_client(self, token, client):
+        """
+        Update the token's expires_in value if the given client is a
+        RestrictedApplication and return whether the given client is restricted.
+        """
+        # Since RestrictedApplications override the DOT defined expiry such that
+        # access_tokens are always expired, re-read the token from the database
+        # and calculate expires_in (in seconds) from the database value. This
+        # value should be a negative value, meaning that it is already expired.
+        if RestrictedApplication.should_expire_access_token(client):
+            access_token = AccessToken.objects.get(token=token['access_token'])
+            expires_in = (access_token.expires - _get_utc_now()).total_seconds()
+            assert expires_in < 0
+            token['expires_in'] = expires_in
+            return True
+
+    def _update_token_expiry_if_overridden_in_request(self, token, request):
+        """
+        Update the token's expires_in value if the request specifies an
+        expiration value and update the expires value on the stored AccessToken
+        object.
+
+        This is needed since DOT's save_bearer_token method always uses
+        the dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS value instead of applying
+        the requesting expiration value.
+        """
+        expires_in = getattr(request, 'expires_in', None)
+        if expires_in:
+            access_token = AccessToken.objects.get(token=token['access_token'])
+            access_token.expires = _get_utc_now() + timedelta(seconds=expires_in)
+            access_token.save()
+            token['expires_in'] = expires_in
+
+
+def _get_utc_now():
+    """
+    Return current time in UTC.
+    """
+    return datetime.utcnow().replace(tzinfo=utc)
diff --git a/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py b/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py
index 247bc5fa6fd..aba40a15c85 100644
--- a/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py
+++ b/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py
@@ -30,7 +30,6 @@ class TestOAuthDispatchAPI(TestCase):
             redirect_uri=DUMMY_REDIRECT_URL,
             client_id='public-client-id',
         )
-        self.request = HttpRequest()
 
     def _assert_stored_token(self, stored_token_value, expected_token_user, expected_client):
         stored_access_token = AccessToken.objects.get(token=stored_token_value)
@@ -39,7 +38,7 @@ class TestOAuthDispatchAPI(TestCase):
         self.assertEqual(stored_access_token.application.user.id, expected_client.user.id)
 
     def test_create_token_success(self):
-        token = api.create_dot_access_token(self.request, self.user, self.client)
+        token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
         self.assertTrue(token['access_token'])
         self.assertTrue(token['refresh_token'])
         self.assertDictContainsSubset(
@@ -54,20 +53,18 @@ class TestOAuthDispatchAPI(TestCase):
 
     def test_create_token_another_user(self):
         another_user = UserFactory()
-        token = api.create_dot_access_token(self.request, another_user, self.client)
+        token = api.create_dot_access_token(HttpRequest(), another_user, self.client)
         self._assert_stored_token(token['access_token'], another_user, self.client)
 
     def test_create_token_overrides(self):
         expires_in = 4800
-        token = api.create_dot_access_token(self.request, self.user, self.client, expires_in=expires_in, scope=2)
+        token = api.create_dot_access_token(HttpRequest(), self.user, self.client, expires_in=expires_in, scope=2)
         self.assertDictContainsSubset({u'scope': u'profile'}, token)
-        with self.assertRaises(AssertionError):  # TODO (ARCH-246) expiration override does not actually work
-            self.assertDictContainsSubset({u'expires_in': expires_in}, token)
-        self.assertDictContainsSubset({u'expires_in': EXPECTED_DEFAULT_EXPIRES_IN}, token)
+        self.assertDictContainsSubset({u'expires_in': expires_in}, token)
 
     def test_refresh_token_success(self):
-        old_token = api.create_dot_access_token(self.request, self.user, self.client)
-        new_token = api.refresh_dot_access_token(self.request, self.client.client_id, old_token['refresh_token'])
+        old_token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
+        new_token = api.refresh_dot_access_token(HttpRequest(), self.client.client_id, old_token['refresh_token'])
         self.assertDictContainsSubset(
             {
                 u'token_type': u'Bearer',
@@ -87,17 +84,17 @@ class TestOAuthDispatchAPI(TestCase):
         self._assert_stored_token(new_token['access_token'], self.user, self.client)
 
     def test_refresh_token_invalid_client(self):
-        token = api.create_dot_access_token(self.request, self.user, self.client)
+        token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
         with self.assertRaises(api.OAuth2Error) as error:
             api.refresh_dot_access_token(
-                self.request, 'invalid_client_id', token['refresh_token'],
+                HttpRequest(), 'invalid_client_id', token['refresh_token'],
             )
         self.assertIn('invalid_client', error.exception.description)
 
     def test_refresh_token_invalid_token(self):
-        api.create_dot_access_token(self.request, self.user, self.client)
+        api.create_dot_access_token(HttpRequest(), self.user, self.client)
         with self.assertRaises(api.OAuth2Error) as error:
             api.refresh_dot_access_token(
-                self.request, self.client.client_id, 'invalid_refresh_token',
+                HttpRequest(), self.client.client_id, 'invalid_refresh_token',
             )
         self.assertIn('invalid_grant', error.exception.description)
-- 
GitLab