Skip to content
Snippets Groups Projects
Unverified Commit 167d8f27 authored by Binod Pant's avatar Binod Pant Committed by GitHub
Browse files

ENT-3007 : round 2 API endpoints for samlproviderconfig and samlproviderdata (#24456)

* ENT-3007 auth/saml/v0/saml/providerdata and auth/saml/v0/saml/providerconfig endpoints

Move code to subfolder for samlproviderconfig

extra comma

undo accidental remove of import

GET works for a single config now

Use ModelViewSet to get all CRUD method. Test still fails

Add auth/saml/v0/providerdata endpoints

fixup reverse and test issue, remove leading caret

just triggering run, why is it failing in CI?

pycodelint fixes

Skip auth tests unless feature is on

Tests for post/put for samlproviderdata

move urls to their own folders

api tests for post samlprovierconfig

create 1 providerconfig test case

lint fixes

lint

lint

cleanup code local urls /samlproviderconfig works

note needed right now

Fix import errors

lint

unused import

wip: first attempt at rbac auth and jwt cookie in test

round 2 with enterprise uuid as url param for samlproviderconfig

improve tests, still dont pass

fix test by using system role, wip other test

fix create test

add get/post tests for providerdata

isort fixes

string lint fix

Cleanup based on feedback round1

move utils to tests package

Move util fn to openedx.feature area

lint

ENT-3007 : Round 2 of work on auth/saml/v0/providerconfig and auth/saml/v0/providerdata endpoints

* Fix test issue use string uuid for permission obj

* snake case changes provider_config

* snake case

* provider_data, tests and lint

* patch and delete tests for providerdata

* snake_case

* snake_case

* snake_case

* make patch test stronger

* 404 if invalid uuid for get param

* common util for validate uuid4

* unused import

* lint fixes for pycodestyle

* 400 when uuid is missing

* 400 instead of 404 for missing uuid

* spell fix

* update docstring for api usage

* docstring clarify
parent 1cba2a00
No related branches found
Tags release-2021-01-19-10.59
No related merge requests found
Showing
with 551 additions and 3 deletions
"""
Serializer for SAMLProviderConfig
"""
from rest_framework import serializers
from third_party_auth.models import SAMLProviderConfig
class SAMLProviderConfigSerializer(serializers.ModelSerializer):
class Meta:
model = SAMLProviderConfig
fields = '__all__'
"""
Tests for SAMLProviderConfig endpoints
"""
import unittest
import copy
from uuid import uuid4
from django.urls import reverse
from django.contrib.sites.models import Site
from django.contrib.auth.models import User
from django.utils.http import urlencode
from rest_framework import status
from rest_framework.test import APITestCase
from enterprise.models import EnterpriseCustomerIdentityProvider, EnterpriseCustomer
from enterprise.constants import ENTERPRISE_ADMIN_ROLE
from third_party_auth.tests.samlutils import set_jwt_cookie
from third_party_auth.models import SAMLProviderConfig
from third_party_auth.tests import testutil
SINGLE_PROVIDER_CONFIG = {
'entity_id': 'id',
'metadata_source': 'http://test.url',
'name': 'name-of-config',
'enabled': 'true',
'slug': 'test-slug'
}
SINGLE_PROVIDER_CONFIG_2 = copy.copy(SINGLE_PROVIDER_CONFIG)
SINGLE_PROVIDER_CONFIG_2['name'] = 'name-of-config-2'
SINGLE_PROVIDER_CONFIG_2['slug'] = 'test-slug-2'
ENTERPRISE_ID = str(uuid4())
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class SAMLProviderConfigTests(APITestCase):
"""
API Tests for SAMLProviderConfig REST endpoints
The skip annotation above exists because we currently cannot run this test in
the cms mode in CI builds, where the third_party_auth application is not loaded
"""
@classmethod
def setUpTestData(cls):
super(SAMLProviderConfigTests, cls).setUpTestData()
cls.user = User.objects.create_user(username='testuser', password='testpwd')
cls.site, _ = Site.objects.get_or_create(domain='example.com')
cls.enterprise_customer = EnterpriseCustomer.objects.create(
uuid=ENTERPRISE_ID,
name='test-ep',
slug='test-ep',
site=cls.site)
cls.samlproviderconfig, _ = SAMLProviderConfig.objects.get_or_create(
entity_id=SINGLE_PROVIDER_CONFIG['entity_id'],
metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source']
)
cls.enterprisecustomeridp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=cls.samlproviderconfig.id,
enterprise_customer_id=ENTERPRISE_ID
)
def setUp(self):
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID)])
self.client.force_authenticate(user=self.user)
def test_get_one_config_by_enterprise_uuid_found(self):
"""
GET auth/saml/v0/provider_config/?enterprise_customer_uuid=id=id
"""
urlbase = reverse('saml_provider_config-list')
query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID}
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.data['results']
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['entity_id'], SINGLE_PROVIDER_CONFIG['entity_id'])
self.assertEqual(results[0]['metadata_source'], SINGLE_PROVIDER_CONFIG['metadata_source'])
self.assertEqual(SAMLProviderConfig.objects.count(), 1)
def test_get_one_config_by_enterprise_uuid_invalid_uuid(self):
"""
GET auth/saml/v0/provider_config/?enterprise_customer_uuid=invalidUUID
"""
urlbase = reverse('saml_provider_config-list')
query_kwargs = {'enterprise_customer_uuid': 'invalid_uuid'}
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_get_one_config_by_enterprise_uuid_not_found(self):
"""
GET auth/saml/v0/provider_config/?enterprise_customer_uuid=id=id
"""
urlbase = reverse('saml_provider_config-list')
query_kwargs = {'enterprise_customer_uuid': 'abc-notfound'}
url = '{}?{}'.format(urlbase, urlencode(query_kwargs))
orig_count = SAMLProviderConfig.objects.count()
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
def test_create_one_config(self):
"""
POST auth/saml/v0/provider_config/ -d data
"""
url = reverse('saml_provider_config-list')
data = copy.copy(SINGLE_PROVIDER_CONFIG_2)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
orig_count = SAMLProviderConfig.objects.count()
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count + 1)
provider_config = SAMLProviderConfig.objects.get(slug=SINGLE_PROVIDER_CONFIG_2['slug'])
self.assertEqual(provider_config.name, 'name-of-config-2')
def test_create_one_config_with_absent_enterprise_uuid(self):
"""
POST auth/saml/v0/provider_config/ -d data
"""
url = reverse('saml_provider_config-list')
data = copy.copy(SINGLE_PROVIDER_CONFIG_2)
orig_count = SAMLProviderConfig.objects.count()
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(SAMLProviderConfig.objects.count(), orig_count)
"""
Viewset for auth/saml/v0/providerconfig/
"""
from rest_framework import routers
from .views import SAMLProviderConfigViewSet
saml_provider_config_router = routers.DefaultRouter()
saml_provider_config_router.register(r'provider_config', SAMLProviderConfigViewSet, basename="saml_provider_config")
urlpatterns = saml_provider_config_router.urls
"""
Viewset for auth/saml/v0/samlproviderconfig
"""
from django.shortcuts import get_object_or_404
from edx_rbac.mixins import PermissionRequiredMixin
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from rest_framework import permissions, viewsets
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import ParseError
from enterprise.models import EnterpriseCustomerIdentityProvider
from third_party_auth.utils import validate_uuid4_string
from ..models import SAMLProviderConfig
from .serializers import SAMLProviderConfigSerializer
class SAMLProviderMixin(object):
authentication_classes = [JwtAuthentication, SessionAuthentication]
permission_classes = [permissions.IsAuthenticated]
serializer_class = SAMLProviderConfigSerializer
class SAMLProviderConfigViewSet(PermissionRequiredMixin, SAMLProviderMixin, viewsets.ModelViewSet):
"""
A View to handle SAMLProviderConfig CRUD
Usage:
NOTE: Only the GET request requires a request parameter, otherwise pass the uuid as part
of the post body
GET /auth/saml/v0/provider_config/?enterprise-id=uuid
POST /auth/saml/v0/provider_config/ -d postData (must contain 'enterprise_customer_uuid')
DELETE /auth/saml/v0/provider_config/:pk -d postData (must contain 'enterprise_customer_uuid')
PATCH /auth/saml/v0/provider_config/:pk -d postData (must contain 'enterprise_customer_uuid')
permission_required refers to the Django permission name defined
in enterprise.rules.
The associated rule will allow edx-rbac to check if the EnterpriseCustomer
returned by the get_permission_object method here, can be
accessed by the user making this request (request.user)
Access is only allowed if the user has the system role
of 'ENTERPRISE_ADMIN' which is defined in enterprise.constants
"""
permission_required = 'enterprise.can_access_admin_dashboard'
def get_queryset(self):
"""
Find and return the matching providerconfig for the given enterprise uuid
if an association exists in EnterpriseCustomerIdentityProvider model
"""
if self.requested_enterprise_uuid is None:
raise ParseError('Required enterprise_customer_uuid is missing')
enterprise_customer_idp = get_object_or_404(
EnterpriseCustomerIdentityProvider,
enterprise_customer__uuid=self.requested_enterprise_uuid
)
return SAMLProviderConfig.objects.filter(pk=enterprise_customer_idp.provider_id)
@property
def requested_enterprise_uuid(self):
"""
The enterprise customer uuid from request params or post body
"""
if self.request.method == "POST":
uuid_str = self.request.POST.get('enterprise_customer_uuid')
if uuid_str is None:
raise ParseError('Required enterprise_customer_uuid is missing')
return uuid_str
else:
uuid_str = self.request.query_params.get('enterprise_customer_uuid')
if validate_uuid4_string(uuid_str) is False:
raise ParseError('Invalid UUID enterprise_customer_id')
return uuid_str
def get_permission_object(self):
"""
Retrieve an EnterpriseCustomer uuid to do auth against
Right now this is the same as from the request object
meaning that only users belonging to the same enterprise
can access these endpoints, we have to sort out the operator role use case
"""
return self.requested_enterprise_uuid
"""
Serializer for SAMLProviderData
"""
from rest_framework import serializers
from third_party_auth.models import SAMLProviderData
class SAMLProviderDataSerializer(serializers.ModelSerializer):
class Meta:
model = SAMLProviderData
fields = '__all__'
import unittest
import copy
import pytz
from uuid import uuid4
from datetime import datetime
from django.contrib.sites.models import Site
from django.contrib.auth.models import User
from django.urls import reverse
from django.utils.http import urlencode
from rest_framework import status
from rest_framework.test import APITestCase
from enterprise.models import EnterpriseCustomer, EnterpriseCustomerIdentityProvider
from enterprise.constants import ENTERPRISE_ADMIN_ROLE
from third_party_auth.tests import testutil
from third_party_auth.models import SAMLProviderData, SAMLProviderConfig
from third_party_auth.tests.samlutils import set_jwt_cookie
SINGLE_PROVIDER_CONFIG = {
'entity_id': 'http://entity-id-1',
'metadata_source': 'http://test.url',
'name': 'name-of-config',
'enabled': 'true',
'slug': 'test-slug'
}
# entity_id here matches that of the providerconfig, intentionally
# that allows this data entity to be found
SINGLE_PROVIDER_DATA = {
'entity_id': 'http://entity-id-1',
'sso_url': 'http://test.url',
'public_key': 'a-key0Aid98',
'fetched_at': datetime.now(pytz.UTC).replace(microsecond=0)
}
SINGLE_PROVIDER_DATA_2 = copy.copy(SINGLE_PROVIDER_DATA)
SINGLE_PROVIDER_DATA_2['entity_id'] = 'http://entity-id-2'
SINGLE_PROVIDER_DATA_2['sso_url'] = 'http://test2.url'
ENTERPRISE_ID = str(uuid4())
@unittest.skipUnless(testutil.AUTH_FEATURE_ENABLED, testutil.AUTH_FEATURES_KEY + ' not enabled')
class SAMLProviderDataTests(APITestCase):
"""
API Tests for SAMLProviderConfig REST endpoints
"""
@classmethod
def setUpTestData(cls):
super(SAMLProviderDataTests, cls).setUpTestData()
cls.user = User.objects.create_user(username='testuser', password='testpwd')
cls.site, _ = Site.objects.get_or_create(domain='example.com')
cls.enterprise_customer = EnterpriseCustomer.objects.create(
uuid=ENTERPRISE_ID,
name='test-ep',
slug='test-ep',
site=cls.site)
cls.saml_provider_config, _ = SAMLProviderConfig.objects.get_or_create(
entity_id=SINGLE_PROVIDER_CONFIG['entity_id'],
metadata_source=SINGLE_PROVIDER_CONFIG['metadata_source']
)
# the entity_id here must match that of the saml_provider_config
cls.saml_provider_data, _ = SAMLProviderData.objects.get_or_create(
entity_id=SINGLE_PROVIDER_DATA['entity_id'],
sso_url=SINGLE_PROVIDER_DATA['sso_url'],
fetched_at=SINGLE_PROVIDER_DATA['fetched_at']
)
cls.enterprise_customer_idp, _ = EnterpriseCustomerIdentityProvider.objects.get_or_create(
provider_id=cls.saml_provider_config.id,
enterprise_customer_id=ENTERPRISE_ID
)
def setUp(self):
# a cookie with roles: [{enterprise_admin_role: ent_id}] will be
# needed to rbac to authorize access for this view
set_jwt_cookie(self.client, self.user, [(ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ID)])
self.client.force_authenticate(user=self.user)
def test_get_one_provider_data_success(self):
# GET auth/saml/v0/providerdata/?enterprise_customer_uuid=id
url_base = reverse('saml_provider_data-list')
query_kwargs = {'enterprise_customer_uuid': ENTERPRISE_ID}
url = '{}?{}'.format(url_base, urlencode(query_kwargs))
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
results = response.data['results']
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['sso_url'], SINGLE_PROVIDER_DATA['sso_url'])
def test_create_one_provider_data_success(self):
# POST auth/saml/v0/providerdata/ -d data
url = reverse('saml_provider_data-list')
data = copy.copy(SINGLE_PROVIDER_DATA_2)
data['enterprise_customer_uuid'] = ENTERPRISE_ID
orig_count = SAMLProviderData.objects.count()
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(SAMLProviderData.objects.count(), orig_count + 1)
self.assertEqual(
SAMLProviderData.objects.get(entity_id=SINGLE_PROVIDER_DATA_2['entity_id']).sso_url,
SINGLE_PROVIDER_DATA_2['sso_url']
)
def test_create_one_data_with_absent_enterprise_uuid(self):
"""
POST auth/saml/v0/provider_data/ -d data
"""
url = reverse('saml_provider_data-list')
data = copy.copy(SINGLE_PROVIDER_DATA_2)
orig_count = SAMLProviderData.objects.count()
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(SAMLProviderData.objects.count(), orig_count)
def test_patch_one_provider_data(self):
# PATCH auth/saml/v0/providerdata/ -d data
url = reverse('saml_provider_data-detail', kwargs={'pk': self.saml_provider_data.id})
data = {
'sso_url': 'http://new.url'
}
data['enterprise_customer_uuid'] = ENTERPRISE_ID
orig_count = SAMLProviderData.objects.count()
response = self.client.patch(url, data)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(SAMLProviderData.objects.count(), orig_count)
# ensure only the sso_url was updated
fetched_provider_data = SAMLProviderData.objects.get(pk=self.saml_provider_data.id)
self.assertEqual(fetched_provider_data.sso_url, 'http://new.url')
self.assertEqual(fetched_provider_data.fetched_at, SINGLE_PROVIDER_DATA['fetched_at'])
self.assertEqual(fetched_provider_data.entity_id, SINGLE_PROVIDER_DATA['entity_id'])
def test_delete_one_provider_data(self):
# DELETE auth/saml/v0/providerdata/ -d data
url = reverse('saml_provider_data-detail', kwargs={'pk': self.saml_provider_data.id})
data = {}
data['enterprise_customer_uuid'] = ENTERPRISE_ID
orig_count = SAMLProviderData.objects.count()
response = self.client.delete(url, data)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertEqual(SAMLProviderData.objects.count(), orig_count - 1)
# ensure only the sso_url was updated
query_set_count = SAMLProviderData.objects.filter(pk=self.saml_provider_data.id).count()
self.assertEqual(query_set_count, 0)
"""
url mappings for auth/saml/v0/providerdata/
"""
from rest_framework import routers
from .views import SAMLProviderDataViewSet
saml_provider_data_router = routers.DefaultRouter()
saml_provider_data_router.register(r'provider_data', SAMLProviderDataViewSet, basename="saml_provider_data")
urlpatterns = saml_provider_data_router.urls
"""
Viewset for auth/saml/v0/samlproviderdata
"""
from django.shortcuts import get_object_or_404
from edx_rbac.mixins import PermissionRequiredMixin
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from rest_framework import permissions, viewsets
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import ParseError
from enterprise.models import EnterpriseCustomerIdentityProvider
from third_party_auth.utils import validate_uuid4_string
from ..models import SAMLProviderConfig, SAMLProviderData
from .serializers import SAMLProviderDataSerializer
class SAMLProviderDataMixin(object):
authentication_classes = [JwtAuthentication, SessionAuthentication]
permission_classes = [permissions.IsAuthenticated]
serializer_class = SAMLProviderDataSerializer
class SAMLProviderDataViewSet(PermissionRequiredMixin, SAMLProviderDataMixin, viewsets.ModelViewSet):
"""
A View to handle SAMLProviderData CRUD.
Uses the edx-rbac mixin PermissionRequiredMixin to apply enterprise authorization
Usage:
NOTE: Only the GET request requires a request parameter, otherwise pass the uuid as part
of the post body
GET /auth/saml/v0/provider_data/?enterprise-id=uuid
POST /auth/saml/v0/provider_data/ -d postData (must contain 'enterprise_customer_uuid')
DELETE /auth/saml/v0/provider_data/:pk -d postData (must contain 'enterprise_customer_uuid')
PATCH /auth/saml/v0/provider_data/:pk -d postData (must contain 'enterprise_customer_uuid')
"""
permission_required = 'enterprise.can_access_admin_dashboard'
def get_queryset(self):
"""
Find and return the matching providerid for the given enterprise uuid
Note: There is no direct association between samlproviderdata and enterprisecustomer.
So we make that association in code via samlproviderdata > samlproviderconfig ( via entity_id )
then, we fetch enterprisecustomer via samlproviderconfig > enterprisecustomer ( via association table )
"""
if self.requested_enterprise_uuid is None:
raise ParseError('Required enterprise_customer_uuid is missing')
enterprise_customer_idp = get_object_or_404(
EnterpriseCustomerIdentityProvider,
enterprise_customer__uuid=self.requested_enterprise_uuid
)
saml_provider = SAMLProviderConfig.objects.get(pk=enterprise_customer_idp.provider_id)
return SAMLProviderData.objects.filter(entity_id=saml_provider.entity_id)
@property
def requested_enterprise_uuid(self):
"""
The enterprise customer uuid from request params or post body
"""
if self.request.method in ('POST', 'PATCH', 'DELETE'):
uuid_str = self.request.POST.get('enterprise_customer_uuid')
if uuid_str is None:
raise ParseError('Required enterprise_customer_uuid is missing')
return uuid_str
else:
uuid_str = self.request.query_params.get('enterprise_customer_uuid')
if validate_uuid4_string(uuid_str) is False:
raise ParseError('Invalid UUID enterprise_customer_id')
return uuid_str
def get_permission_object(self):
"""
Retrieve an EnterpriseCustomer to do auth against
"""
return self.requested_enterprise_uuid
"""
Utility functions for use in SAMLProviderConfig, SAMLProviderData tests
"""
from edx_rest_framework_extensions.auth.jwt.cookies import jwt_cookie_name
from edx_rest_framework_extensions.auth.jwt.tests.utils import generate_jwt_token, generate_unversioned_payload
def _jwt_token_from_role_context_pairs(user, role_context_pairs):
"""
Generates a new JWT token with roles assigned from pairs of (role name, context).
"""
roles = []
for role, context in role_context_pairs:
role_data = '{role}'.format(role=role)
if context is not None:
role_data += ':{context}'.format(context=context)
roles.append(role_data)
payload = generate_unversioned_payload(user)
payload.update({'roles': roles})
return generate_jwt_token(payload)
def set_jwt_cookie(client, user, role_context_pairs=None):
"""
Set jwt token in cookies
"""
jwt_token = _jwt_token_from_role_context_pairs(user, role_context_pairs or [])
client.cookies[jwt_cookie_name()] = jwt_token
"""Url configuration for the auth module."""
from django.conf.urls import include, url
from .views import (
......@@ -18,4 +17,6 @@ urlpatterns = [
url(r'^auth/login/(?P<backend>lti)/$', lti_login_and_complete_view),
url(r'^auth/idp_redirect/(?P<provider_slug>[\w-]+)', IdPRedirectView.as_view(), name="idp_redirect"),
url(r'^auth/', include('social_django.urls', namespace='social')),
url(r'^auth/saml/v0/', include('third_party_auth.samlproviderconfig.urls')),
url(r'^auth/saml/v0/', include('third_party_auth.samlproviderdata.urls'))
]
......@@ -2,7 +2,7 @@
Utility functions for third_party_auth
"""
from uuid import UUID
from django.contrib.auth.models import User
......@@ -28,3 +28,14 @@ def user_exists(details):
return User.objects.filter(**user_queryset_filter).exists()
return False
def validate_uuid4_string(uuid_string):
"""
Returns True if valid uuid4 string, or False
"""
try:
UUID(uuid_string, version=4)
except ValueError:
return False
return True
......@@ -10,7 +10,7 @@ from django.conf import settings
from django.urls import NoReverseMatch, reverse
from django.utils.translation import ugettext as _
from edx_django_utils.cache import TieredCache, get_cache_key
from enterprise.models import EnterpriseCustomerUser
from enterprise.models import EnterpriseCustomerUser, EnterpriseCustomer
from social_django.models import UserSocialAuth
import third_party_auth
......@@ -342,3 +342,7 @@ def get_provider_login_url(request, provider_id, redirect_url=None):
redirect_url=redirect_url if redirect_url else get_next_url_for_login_page(request)
)
return provider_login_url
def fetch_enterprise_customer_by_id(enterprise_uuid):
return EnterpriseCustomer.objects.get(uuid=enterprise_uuid)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment