Skip to content
Snippets Groups Projects
Commit 13d4091a authored by Nimisha Asthagiri's avatar Nimisha Asthagiri
Browse files

Fix overriding of token expiration in DOT (ARCH-246)

parent f9488a85
Branches
Tags
No related merge requests found
......@@ -7,7 +7,6 @@ from oauthlib.oauth2.rfc6749.errors import OAuth2Error
from oauthlib.oauth2.rfc6749.tokens import BearerToken
from oauth2_provider.models import AccessToken as dot_access_token
from oauth2_provider.models import RefreshToken as dot_refresh_token
from oauth2_provider.oauth2_backends import get_oauthlib_core
from oauth2_provider.settings import oauth2_settings as dot_settings
from provider.oauth2.models import AccessToken as dop_access_token
from provider.oauth2.models import RefreshToken as dop_refresh_token
......@@ -51,8 +50,8 @@ def refresh_dot_access_token(request, client_id, refresh_token, expires_in=None)
Create and return a new (persisted) access token, given a previously created
refresh_token, possibly returned from create_dot_access_token above.
"""
auth_core = get_oauthlib_core()
expires_in = _get_expires_in_value(expires_in)
auth_core = _get_oauthlib_core(expires_in)
_populate_refresh_token_request(request, client_id, refresh_token)
# Note: Unlike create_dot_access_token, we use the top-level auth library
......@@ -70,13 +69,7 @@ def _get_expires_in_value(expires_in):
"""
Returns the expires_in value to use for the token.
"""
# TODO (ARCH-246) Fix expiration configuration as this does not actually
# override the token's expiration. Rather, DOT's save_bearer_token method
# will always use dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS.
if not expires_in:
seconds_in_a_day = 24 * 60 * 60
expires_in = settings.OAUTH_EXPIRE_PUBLIC_CLIENT_DAYS * seconds_in_a_day
return expires_in
return expires_in or dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS
def _populate_create_access_token_request(request, user, client, scope=None):
......@@ -105,3 +98,13 @@ def _populate_refresh_token_request(request, client_id, refresh_token):
refresh_token=refresh_token,
grant_type='refresh_token',
)
def _get_oauthlib_core(expires_in):
"""
Based on oauth2_provider.oauth2_backends.get_oauthlib_core, but allows
passing in a value for token_expires_in.
"""
validator = dot_settings.OAUTH2_VALIDATOR_CLASS()
server = dot_settings.OAUTH2_SERVER_CLASS(validator, token_expires_in=expires_in)
return dot_settings.OAUTH2_BACKEND_CLASS(server)
......@@ -3,7 +3,7 @@ Classes that override default django-oauth-toolkit behavior
"""
from __future__ import unicode_literals
from datetime import datetime
from datetime import datetime, timedelta
from django.contrib.auth import authenticate, get_user_model
from django.db.models.signals import pre_save
......@@ -82,21 +82,9 @@ class EdxOAuth2Validator(OAuth2Validator):
super(EdxOAuth2Validator, self).save_bearer_token(token, request, *args, **kwargs)
if RestrictedApplication.should_expire_access_token(request.client):
# Since RestrictedApplications will override the DOT defined expiry, so that access_tokens
# are always expired, we need to re-read the token from the database and then calculate the
# expires_in (in seconds) from what we stored in the database. This value should be a negative
#value, meaning that it is already expired
access_token = AccessToken.objects.get(token=token['access_token'])
utc_now = datetime.utcnow().replace(tzinfo=utc)
expires_in = (access_token.expires - utc_now).total_seconds()
# assert that RestrictedApplications only issue expired tokens
# blow up processing if we see otherwise
assert expires_in < 0
token['expires_in'] = expires_in
is_restricted_client = self._update_token_expiry_if_restricted_client(token, request.client)
if not is_restricted_client:
self._update_token_expiry_if_overridden_in_request(token, request)
# Restore the original request attributes
request.grant_type = grant_type
......@@ -108,3 +96,43 @@ class EdxOAuth2Validator(OAuth2Validator):
"""
available_scopes = get_scopes_backend().get_available_scopes(application=client, request=request)
return set(scopes).issubset(set(available_scopes))
def _update_token_expiry_if_restricted_client(self, token, client):
"""
Update the token's expires_in value if the given client is a
RestrictedApplication and return whether the given client is restricted.
"""
# Since RestrictedApplications override the DOT defined expiry such that
# access_tokens are always expired, re-read the token from the database
# and calculate expires_in (in seconds) from the database value. This
# value should be a negative value, meaning that it is already expired.
if RestrictedApplication.should_expire_access_token(client):
access_token = AccessToken.objects.get(token=token['access_token'])
expires_in = (access_token.expires - _get_utc_now()).total_seconds()
assert expires_in < 0
token['expires_in'] = expires_in
return True
def _update_token_expiry_if_overridden_in_request(self, token, request):
"""
Update the token's expires_in value if the request specifies an
expiration value and update the expires value on the stored AccessToken
object.
This is needed since DOT's save_bearer_token method always uses
the dot_settings.ACCESS_TOKEN_EXPIRE_SECONDS value instead of applying
the requesting expiration value.
"""
expires_in = getattr(request, 'expires_in', None)
if expires_in:
access_token = AccessToken.objects.get(token=token['access_token'])
access_token.expires = _get_utc_now() + timedelta(seconds=expires_in)
access_token.save()
token['expires_in'] = expires_in
def _get_utc_now():
"""
Return current time in UTC.
"""
return datetime.utcnow().replace(tzinfo=utc)
......@@ -30,7 +30,6 @@ class TestOAuthDispatchAPI(TestCase):
redirect_uri=DUMMY_REDIRECT_URL,
client_id='public-client-id',
)
self.request = HttpRequest()
def _assert_stored_token(self, stored_token_value, expected_token_user, expected_client):
stored_access_token = AccessToken.objects.get(token=stored_token_value)
......@@ -39,7 +38,7 @@ class TestOAuthDispatchAPI(TestCase):
self.assertEqual(stored_access_token.application.user.id, expected_client.user.id)
def test_create_token_success(self):
token = api.create_dot_access_token(self.request, self.user, self.client)
token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
self.assertTrue(token['access_token'])
self.assertTrue(token['refresh_token'])
self.assertDictContainsSubset(
......@@ -54,20 +53,18 @@ class TestOAuthDispatchAPI(TestCase):
def test_create_token_another_user(self):
another_user = UserFactory()
token = api.create_dot_access_token(self.request, another_user, self.client)
token = api.create_dot_access_token(HttpRequest(), another_user, self.client)
self._assert_stored_token(token['access_token'], another_user, self.client)
def test_create_token_overrides(self):
expires_in = 4800
token = api.create_dot_access_token(self.request, self.user, self.client, expires_in=expires_in, scope=2)
token = api.create_dot_access_token(HttpRequest(), self.user, self.client, expires_in=expires_in, scope=2)
self.assertDictContainsSubset({u'scope': u'profile'}, token)
with self.assertRaises(AssertionError): # TODO (ARCH-246) expiration override does not actually work
self.assertDictContainsSubset({u'expires_in': expires_in}, token)
self.assertDictContainsSubset({u'expires_in': EXPECTED_DEFAULT_EXPIRES_IN}, token)
self.assertDictContainsSubset({u'expires_in': expires_in}, token)
def test_refresh_token_success(self):
old_token = api.create_dot_access_token(self.request, self.user, self.client)
new_token = api.refresh_dot_access_token(self.request, self.client.client_id, old_token['refresh_token'])
old_token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
new_token = api.refresh_dot_access_token(HttpRequest(), self.client.client_id, old_token['refresh_token'])
self.assertDictContainsSubset(
{
u'token_type': u'Bearer',
......@@ -87,17 +84,17 @@ class TestOAuthDispatchAPI(TestCase):
self._assert_stored_token(new_token['access_token'], self.user, self.client)
def test_refresh_token_invalid_client(self):
token = api.create_dot_access_token(self.request, self.user, self.client)
token = api.create_dot_access_token(HttpRequest(), self.user, self.client)
with self.assertRaises(api.OAuth2Error) as error:
api.refresh_dot_access_token(
self.request, 'invalid_client_id', token['refresh_token'],
HttpRequest(), 'invalid_client_id', token['refresh_token'],
)
self.assertIn('invalid_client', error.exception.description)
def test_refresh_token_invalid_token(self):
api.create_dot_access_token(self.request, self.user, self.client)
api.create_dot_access_token(HttpRequest(), self.user, self.client)
with self.assertRaises(api.OAuth2Error) as error:
api.refresh_dot_access_token(
self.request, self.client.client_id, 'invalid_refresh_token',
HttpRequest(), self.client.client_id, 'invalid_refresh_token',
)
self.assertIn('invalid_grant', error.exception.description)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment