From 13d4091a1afb04f4ec5f37a9a0d76a28870f698e Mon Sep 17 00:00:00 2001 From: Nimisha Asthagiri <nasthagiri@edx.org> Date: Mon, 8 Oct 2018 11:29:38 -0400 Subject: [PATCH] Fix overriding of token expiration in DOT (ARCH-246) --- openedx/core/djangoapps/oauth_dispatch/api.py | 21 ++++--- .../dot_overrides/validators.py | 60 ++++++++++++++----- .../oauth_dispatch/tests/test_api.py | 23 ++++--- 3 files changed, 66 insertions(+), 38 deletions(-) diff --git a/openedx/core/djangoapps/oauth_dispatch/api.py b/openedx/core/djangoapps/oauth_dispatch/api.py index e1603a5ba2a..0c56bf901c8 100644 --- a/openedx/core/djangoapps/oauth_dispatch/api.py +++ b/openedx/core/djangoapps/oauth_dispatch/api.py @@ -7,7 +7,6 @@ from oauthlib.oauth2.rfc6749.errors import OAuth2Error from oauthlib.oauth2.rfc6749.tokens import BearerToken from oauth2_provider.models import AccessToken as dot_access_token from oauth2_provider.models import RefreshToken as dot_refresh_token -from oauth2_provider.oauth2_backends import get_oauthlib_core from oauth2_provider.settings import oauth2_settings as dot_settings from provider.oauth2.models import AccessToken as dop_access_token from provider.oauth2.models import RefreshToken as dop_refresh_token @@ -51,8 +50,8 @@ def refresh_dot_access_token(request, client_id, refresh_token, expires_in=None) Create and return a new (persisted) access token, given a previously created refresh_token, possibly returned from create_dot_access_token above. """ - auth_core = get_oauthlib_core() expires_in = _get_expires_in_value(expires_in) + auth_core = _get_oauthlib_core(expires_in) _populate_refresh_token_request(request, client_id, refresh_token) # Note: Unlike create_dot_access_token, we use the top-level auth library @@ -70,13 +69,7 @@ def _get_expires_in_value(expires_in): """ Returns the expires_in value to use for the token. """ - # TODO (ARCH-246) Fix expiration configuration as this does not actually - # override the token's expiration. Rather, DOT's save_bearer_token method - # will always use dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS. - if not expires_in: - seconds_in_a_day = 24 * 60 * 60 - expires_in = settings.OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS * seconds_in_a_day - return expires_in + return expires_in or dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS def _populate_create_access_token_request(request, user, client, scope=None): @@ -105,3 +98,13 @@ def _populate_refresh_token_request(request, client_id, refresh_token): refresh_token=refresh_token, grant_type='refresh_token', ) + + +def _get_oauthlib_core(expires_in): + """ + Based on oauth2_provider.oauth2_backends.get_oauthlib_core, but allows + passing in a value for token_expires_in. + """ + validator = dot_settings.OAUTH2_VALIDATOR_CLASS() + server = dot_settings.OAUTH2_SERVER_CLASS(validator, token_expires_in=expires_in) + return dot_settings.OAUTH2_BACKEND_CLASS(server) diff --git a/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py b/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py index bc265a035cd..5c9aa5ddb65 100644 --- a/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py +++ b/openedx/core/djangoapps/oauth_dispatch/dot_overrides/validators.py @@ -3,7 +3,7 @@ Classes that override default django-oauth-toolkit behavior """ from __future__ import unicode_literals -from datetime import datetime +from datetime import datetime, timedelta from django.contrib.auth import authenticate, get_user_model from django.db.models.signals import pre_save @@ -82,21 +82,9 @@ class EdxOAuth2Validator(OAuth2Validator): super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs) - if RestrictedApplication.should_expire_access_token(request.client): - # Since RestrictedApplications will override the DOT defined expiry, so that access_tokens - # are always expired, we need to re-read the token from the database and then calculate the - # expires_in (in seconds) from what we stored in the database. This value should be a negative - #value, meaning that it is already expired - - access_token = AccessToken.objects.get(token=token['access_token']) - utc_now = datetime.utcnow().replace(tzinfo=utc) - expires_in = (access_token.expires - utc_now).total_seconds() - - # assert that RestrictedApplications only issue expired tokens - # blow up processing if we see otherwise - assert expires_in < 0 - - token['expires_in'] = expires_in + is_restricted_client = self._update_token_expiry_if_restricted_client(token, request.client) + if not is_restricted_client: + self._update_token_expiry_if_overridden_in_request(token, request) # Restore the original request attributes request.grant_type = grant_type @@ -108,3 +96,43 @@ class EdxOAuth2Validator(OAuth2Validator): """ available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request) return set(scopes).issubset(set(available_scopes)) + + def _update_token_expiry_if_restricted_client(self, token, client): + """ + Update the token's expires_in value if the given client is a + RestrictedApplication and return whether the given client is restricted. + """ + # Since RestrictedApplications override the DOT defined expiry such that + # access_tokens are always expired, re-read the token from the database + # and calculate expires_in (in seconds) from the database value. This + # value should be a negative value, meaning that it is already expired. + if RestrictedApplication.should_expire_access_token(client): + access_token = AccessToken.objects.get(token=token['access_token']) + expires_in = (access_token.expires - _get_utc_now()).total_seconds() + assert expires_in < 0 + token['expires_in'] = expires_in + return True + + def _update_token_expiry_if_overridden_in_request(self, token, request): + """ + Update the token's expires_in value if the request specifies an + expiration value and update the expires value on the stored AccessToken + object. + + This is needed since DOT's save_bearer_token method always uses + the dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS value instead of applying + the requesting expiration value. + """ + expires_in = getattr(request, 'expires_in', None) + if expires_in: + access_token = AccessToken.objects.get(token=token['access_token']) + access_token.expires = _get_utc_now() + timedelta(seconds=expires_in) + access_token.save() + token['expires_in'] = expires_in + + +def _get_utc_now(): + """ + Return current time in UTC. + """ + return datetime.utcnow().replace(tzinfo=utc) diff --git a/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py b/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py index 247bc5fa6fd..aba40a15c85 100644 --- a/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py +++ b/openedx/core/djangoapps/oauth_dispatch/tests/test_api.py @@ -30,7 +30,6 @@ class TestOAuthDispatchAPI(TestCase): redirect_uri=DUMMY_REDIRECT_URL, client_id='public-client-id', ) - self.request = HttpRequest() def _assert_stored_token(self, stored_token_value, expected_token_user, expected_client): stored_access_token = AccessToken.objects.get(token=stored_token_value) @@ -39,7 +38,7 @@ class TestOAuthDispatchAPI(TestCase): self.assertEqual(stored_access_token.application.user.id, expected_client.user.id) def test_create_token_success(self): - token = api.create_dot_access_token(self.request, self.user, self.client) + token = api.create_dot_access_token(HttpRequest(), self.user, self.client) self.assertTrue(token['access_token']) self.assertTrue(token['refresh_token']) self.assertDictContainsSubset( @@ -54,20 +53,18 @@ class TestOAuthDispatchAPI(TestCase): def test_create_token_another_user(self): another_user = UserFactory() - token = api.create_dot_access_token(self.request, another_user, self.client) + token = api.create_dot_access_token(HttpRequest(), another_user, self.client) self._assert_stored_token(token['access_token'], another_user, self.client) def test_create_token_overrides(self): expires_in = 4800 - token = api.create_dot_access_token(self.request, self.user, self.client, expires_in=expires_in, scope=2) + token = api.create_dot_access_token(HttpRequest(), self.user, self.client, expires_in=expires_in, scope=2) self.assertDictContainsSubset({u'scope': u'profile'}, token) - with self.assertRaises(AssertionError): # TODO (ARCH-246) expiration override does not actually work - self.assertDictContainsSubset({u'expires_in': expires_in}, token) - self.assertDictContainsSubset({u'expires_in': EXPECTED_DEFAULT_EXPIRES_IN}, token) + self.assertDictContainsSubset({u'expires_in': expires_in}, token) def test_refresh_token_success(self): - old_token = api.create_dot_access_token(self.request, self.user, self.client) - new_token = api.refresh_dot_access_token(self.request, self.client.client_id, old_token['refresh_token']) + old_token = api.create_dot_access_token(HttpRequest(), self.user, self.client) + new_token = api.refresh_dot_access_token(HttpRequest(), self.client.client_id, old_token['refresh_token']) self.assertDictContainsSubset( { u'token_type': u'Bearer', @@ -87,17 +84,17 @@ class TestOAuthDispatchAPI(TestCase): self._assert_stored_token(new_token['access_token'], self.user, self.client) def test_refresh_token_invalid_client(self): - token = api.create_dot_access_token(self.request, self.user, self.client) + token = api.create_dot_access_token(HttpRequest(), self.user, self.client) with self.assertRaises(api.OAuth2Error) as error: api.refresh_dot_access_token( - self.request, 'invalid_client_id', token['refresh_token'], + HttpRequest(), 'invalid_client_id', token['refresh_token'], ) self.assertIn('invalid_client', error.exception.description) def test_refresh_token_invalid_token(self): - api.create_dot_access_token(self.request, self.user, self.client) + api.create_dot_access_token(HttpRequest(), self.user, self.client) with self.assertRaises(api.OAuth2Error) as error: api.refresh_dot_access_token( - self.request, self.client.client_id, 'invalid_refresh_token', + HttpRequest(), self.client.client_id, 'invalid_refresh_token', ) self.assertIn('invalid_grant', error.exception.description) -- GitLab