diff --git a/common/djangoapps/entitlements/api/v1/tests/test_views.py b/common/djangoapps/entitlements/api/v1/tests/test_views.py index a74ac51ecbdaef5b6c8442eb173b334cb68d9898..6b535ac642bb20d588f5872014d23704152ffdfc 100644 --- a/common/djangoapps/entitlements/api/v1/tests/test_views.py +++ b/common/djangoapps/entitlements/api/v1/tests/test_views.py @@ -1,4 +1,5 @@ import json +import logging import unittest import uuid from datetime import datetime, timedelta @@ -6,11 +7,16 @@ from datetime import datetime, timedelta import pytz from django.conf import settings from django.core.urlresolvers import reverse - -from student.tests.factories import (TEST_PASSWORD, CourseEnrollmentFactory, UserFactory) +from mock import patch +from opaque_keys.edx.locator import CourseKey from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase from xmodule.modulestore.tests.factories import CourseFactory +from student.models import CourseEnrollment +from student.tests.factories import (TEST_PASSWORD, CourseEnrollmentFactory, UserFactory) + +log = logging.getLogger(__name__) + # Entitlements is not in CMS' INSTALLED_APPS so these imports will error during test collection if settings.ROOT_URLCONF == 'lms.urls': from entitlements.tests.factories import CourseEntitlementFactory @@ -81,7 +87,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase): not_staff_user = UserFactory() self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) - course_entitlement = CourseEntitlementFactory() + course_entitlement = CourseEntitlementFactory.create() url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(course_entitlement.uuid)]) response = self.client.delete( @@ -122,7 +128,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase): results = response.data.get('results', []) # pylint: disable=no-member assert results == CourseEntitlementSerializer([entitlement], many=True).data - def test_staff_not_get_all_entitlements(self): + def test_staff_get_only_staff_entitlements(self): CourseEntitlementFactory.create_batch(2) entitlement = CourseEntitlementFactory.create(user=self.user) @@ -189,7 +195,7 @@ class EntitlementViewSetTest(ModuleStoreTestCase): assert results == CourseEntitlementSerializer([entitlement_user2], many=True).data def test_get_entitlement_by_uuid(self): - entitlement = CourseEntitlementFactory() + entitlement = CourseEntitlementFactory.create() CourseEntitlementFactory.create_batch(2) url = reverse(self.ENTITLEMENTS_DETAILS_PATH, args=[str(entitlement.uuid)]) @@ -253,3 +259,173 @@ class EntitlementViewSetTest(ModuleStoreTestCase): course_entitlement.refresh_from_db() assert course_entitlement.expired_at is not None assert course_entitlement.enrollment_course_run is None + + +@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') +class EntitlementEnrollmentViewSetTest(ModuleStoreTestCase): + """ + Tests for the EntitlementEnrollmentViewSets + """ + ENTITLEMENTS_ENROLLMENT_NAMESPACE = 'entitlements_api:v1:enrollments' + + def setUp(self): + super(EntitlementEnrollmentViewSetTest, self).setUp() + self.user = UserFactory() + self.client.login(username=self.user.username, password=TEST_PASSWORD) + self.course = CourseFactory.create(org='edX', number='DemoX', display_name='Demo_Course') + self.course2 = CourseFactory.create(org='edX', number='DemoX2', display_name='Demo_Course 2') + + self.return_values = [ + {'key': str(self.course.id)}, + {'key': str(self.course2.id)} + ] + + @patch("entitlements.api.v1.views.get_course_runs_for_course") + def test_user_can_enroll(self, mock_get_course_runs): + course_entitlement = CourseEntitlementFactory.create(user=self.user) + mock_get_course_runs.return_value = self.return_values + url = reverse( + self.ENTITLEMENTS_ENROLLMENT_NAMESPACE, + args=[str(course_entitlement.uuid)] + ) + assert course_entitlement.enrollment_course_run is None + + data = { + 'course_run_id': str(self.course.id) + } + response = self.client.post( + url, + data=json.dumps(data), + content_type='application/json', + ) + course_entitlement.refresh_from_db() + + assert response.status_code == 201 + assert CourseEnrollment.is_enrolled(self.user, self.course.id) + assert course_entitlement.enrollment_course_run is not None + + @patch("entitlements.api.v1.views.get_course_runs_for_course") + def test_user_can_unenroll(self, mock_get_course_runs): + course_entitlement = CourseEntitlementFactory.create(user=self.user) + mock_get_course_runs.return_value = self.return_values + + url = reverse( + self.ENTITLEMENTS_ENROLLMENT_NAMESPACE, + args=[str(course_entitlement.uuid)] + ) + assert course_entitlement.enrollment_course_run is None + + data = { + 'course_run_id': str(self.course.id) + } + response = self.client.post( + url, + data=json.dumps(data), + content_type='application/json', + ) + course_entitlement.refresh_from_db() + + assert response.status_code == 201 + assert CourseEnrollment.is_enrolled(self.user, self.course.id) + + response = self.client.delete( + url, + content_type='application/json', + ) + assert response.status_code == 204 + + course_entitlement.refresh_from_db() + assert not CourseEnrollment.is_enrolled(self.user, self.course.id) + assert course_entitlement.enrollment_course_run is None + + @patch("entitlements.api.v1.views.get_course_runs_for_course") + def test_user_can_switch(self, mock_get_course_runs): + mock_get_course_runs.return_value = self.return_values + course_entitlement = CourseEntitlementFactory.create(user=self.user) + + url = reverse( + self.ENTITLEMENTS_ENROLLMENT_NAMESPACE, + args=[str(course_entitlement.uuid)] + ) + assert course_entitlement.enrollment_course_run is None + + data = { + 'course_run_id': str(self.course.id) + } + response = self.client.post( + url, + data=json.dumps(data), + content_type='application/json', + ) + course_entitlement.refresh_from_db() + + assert response.status_code == 201 + assert CourseEnrollment.is_enrolled(self.user, self.course.id) + + data = { + 'course_run_id': str(self.course2.id) + } + response = self.client.post( + url, + data=json.dumps(data), + content_type='application/json', + ) + assert response.status_code == 201 + + course_entitlement.refresh_from_db() + assert CourseEnrollment.is_enrolled(self.user, self.course2.id) + assert course_entitlement.enrollment_course_run is not None + + @patch("entitlements.api.v1.views.get_course_runs_for_course") + def test_user_already_enrolled(self, mock_get_course_runs): + course_entitlement = CourseEntitlementFactory.create(user=self.user) + mock_get_course_runs.return_value = self.return_values + + url = reverse( + self.ENTITLEMENTS_ENROLLMENT_NAMESPACE, + args=[str(course_entitlement.uuid)] + ) + + CourseEnrollment.enroll(self.user, self.course.id, mode=course_entitlement.mode) + data = { + 'course_run_id': str(self.course.id) + } + response = self.client.post( + url, + data=json.dumps(data), + content_type='application/json', + ) + course_entitlement.refresh_from_db() + + assert response.status_code == 201 + assert CourseEnrollment.is_enrolled(self.user, self.course.id) + + course_entitlement.refresh_from_db() + assert CourseEnrollment.is_enrolled(self.user, self.course.id) + assert course_entitlement.enrollment_course_run is not None + + @patch("entitlements.api.v1.views.get_course_runs_for_course") + def test_user_cannot_enroll_in_unknown_course_run_id(self, mock_get_course_runs): + fake_course_str = str(self.course.id) + 'fake' + fake_course_key = CourseKey.from_string(fake_course_str) + course_entitlement = CourseEntitlementFactory.create(user=self.user) + mock_get_course_runs.return_value = self.return_values + + url = reverse( + self.ENTITLEMENTS_ENROLLMENT_NAMESPACE, + args=[str(course_entitlement.uuid)] + ) + + data = { + 'course_run_id': str(fake_course_key) + } + response = self.client.post( + url, + data=json.dumps(data), + content_type='application/json', + ) + + expected_message = 'The Course Run ID is not a match for this Course Entitlement.' + assert response.status_code == 400 + assert response.data['message'] == expected_message # pylint: disable=no-member + assert not CourseEnrollment.is_enrolled(self.user, fake_course_key) diff --git a/common/djangoapps/entitlements/api/v1/urls.py b/common/djangoapps/entitlements/api/v1/urls.py index a8a81e0de9a68d815e8f7a690fc527a73c9cd819..f716d39b9c639c1a94bd8e1d81a87aa1bfd6716b 100644 --- a/common/djangoapps/entitlements/api/v1/urls.py +++ b/common/djangoapps/entitlements/api/v1/urls.py @@ -1,11 +1,22 @@ from django.conf.urls import url, include from rest_framework.routers import DefaultRouter -from .views import EntitlementViewSet +from .views import EntitlementViewSet, EntitlementEnrollmentViewSet router = DefaultRouter() router.register(r'entitlements', EntitlementViewSet, base_name='entitlements') +ENROLLMENTS_VIEW = EntitlementEnrollmentViewSet.as_view({ + 'post': 'create', + 'delete': 'destroy', +}) + + urlpatterns = [ url(r'', include(router.urls)), + url( + r'entitlements/(?P<uuid>{regex})/enrollments$'.format(regex=EntitlementViewSet.ENTITLEMENT_UUID4_REGEX), + ENROLLMENTS_VIEW, + name='enrollments' + ) ] diff --git a/common/djangoapps/entitlements/api/v1/views.py b/common/djangoapps/entitlements/api/v1/views.py index 68f3a1eb3a2c5ed3e6658a9d37c1b7d5af51bb2e..13aa941a05d40b0f924f11582a7696e74ce50503 100644 --- a/common/djangoapps/entitlements/api/v1/views.py +++ b/common/djangoapps/entitlements/api/v1/views.py @@ -4,23 +4,30 @@ from django.db import transaction from django.utils import timezone from django_filters.rest_framework import DjangoFilterBackend from edx_rest_framework_extensions.authentication import JwtAuthentication -from rest_framework import permissions, viewsets +from opaque_keys import InvalidKeyError +from opaque_keys.edx.keys import CourseKey +from rest_framework import permissions, viewsets, status +from rest_framework.authentication import SessionAuthentication from rest_framework.response import Response from entitlements.api.v1.filters import CourseEntitlementFilter from entitlements.api.v1.permissions import IsAdminOrAuthenticatedReadOnly from entitlements.api.v1.serializers import CourseEntitlementSerializer from entitlements.models import CourseEntitlement +from openedx.core.djangoapps.catalog.utils import get_course_runs_for_course from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf from student.models import CourseEnrollment +from student.models import CourseEnrollmentException, AlreadyEnrolledError log = logging.getLogger(__name__) class EntitlementViewSet(viewsets.ModelViewSet): + ENTITLEMENT_UUID4_REGEX = '[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}' + authentication_classes = (JwtAuthentication, SessionAuthenticationCrossDomainCsrf,) permission_classes = (permissions.IsAuthenticated, IsAdminOrAuthenticatedReadOnly,) - lookup_value_regex = '[0-9a-f-]+' + lookup_value_regex = ENTITLEMENT_UUID4_REGEX lookup_field = 'uuid' serializer_class = CourseEntitlementSerializer filter_backends = (DjangoFilterBackend,) @@ -102,3 +109,169 @@ class EntitlementViewSet(viewsets.ModelViewSet): ) if save_model: instance.save() + + +class EntitlementEnrollmentViewSet(viewsets.GenericViewSet): + """ + Endpoint in the Entitlement API to handle the Enrollment of a User's Entitlement. + This API will handle + - Enroll + - Unenroll + - Switch Enrollment + """ + authentication_classes = (JwtAuthentication, SessionAuthentication,) + permission_classes = (permissions.IsAuthenticated,) + queryset = CourseEntitlement.objects.all() + + def _verify_course_run_for_entitlement(self, entitlement, course_run_id): + """ + Verifies that a Course run is a child of the Course assigned to the entitlement. + """ + course_runs = get_course_runs_for_course(entitlement.course_uuid) + for run in course_runs: + if course_run_id == run.get('key', ''): + return True + return False + + def _enroll_entitlement(self, entitlement, course_run_key, user): + """ + Internal method to handle the details of enrolling a User in a Course Run. + + Returns a response object is there is an error or exception, None otherwise + """ + try: + enrollment = CourseEnrollment.enroll( + user=user, + course_key=course_run_key, + mode=entitlement.mode, + check_access=True + ) + except AlreadyEnrolledError: + enrollment = CourseEnrollment.get_enrollment(user, course_run_key) + if enrollment.mode == entitlement.mode: + CourseEntitlement.set_enrollment(entitlement, enrollment) + # Else the User is already enrolled in another Mode and we should + # not do anything else related to Entitlements. + except CourseEnrollmentException: + message = ( + 'Course Entitlement Enroll for {username} failed for course: {course_id}, ' + 'mode: {mode}, and entitlement: {entitlement}' + ).format( + username=user.username, + course_id=course_run_key, + mode=entitlement.mode, + entitlement=entitlement.uuid + ) + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={'message': message} + ) + + CourseEntitlement.set_enrollment(entitlement, enrollment) + return None + + def _unenroll_entitlement(self, entitlement, course_run_key, user): + """ + Internal method to handle the details of Unenrolling a User in a Course Run. + """ + CourseEnrollment.unenroll(user, course_run_key, skip_refund=True) + CourseEntitlement.set_enrollment(entitlement, None) + + def create(self, request, uuid): + """ + On POST this method will be called and will handle enrolling a user in the + provided course_run_id from the data. This is called on a specific entitlement + UUID so the course_run_id has to correspond to the Course that is assigned to + the Entitlement. + + When this API is called for a user who is already enrolled in a run that User + will be unenrolled from their current run and enrolled in the new run if it is + available. + """ + course_run_id = request.data.get('course_run_id', None) + + if not course_run_id: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data='The Course Run ID was not provided.' + ) + + # Verify that the user has an Entitlement for the provided Course UUID. + try: + entitlement = CourseEntitlement.objects.get(uuid=uuid, user=request.user, expired_at=None) + except CourseEntitlement.DoesNotExist: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data='The Entitlement for this UUID does not exist or is Expired.' + ) + + # Verify the course run ID is of the same type as the Course entitlement. + course_run_valid = self._verify_course_run_for_entitlement(entitlement, course_run_id) + if not course_run_valid: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + 'message': 'The Course Run ID is not a match for this Course Entitlement.' + } + ) + + # Determine if this is a Switch session or a simple enroll and handle both. + try: + course_run_string = CourseKey.from_string(course_run_id) + except InvalidKeyError: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + 'message': 'Invalid {course_id}'.format(course_id=course_run_id) + } + ) + if entitlement.enrollment_course_run is None: + response = self._enroll_entitlement( + entitlement=entitlement, + course_run_key=course_run_string, + user=request.user + ) + if response: + return response + elif entitlement.enrollment_course_run.course_id != course_run_id: + self._unenroll_entitlement( + entitlement=entitlement, + course_run_key=entitlement.enrollment_course_run.course_id, + user=request.user + ) + response = self._enroll_entitlement( + entitlement=entitlement, + course_run_key=course_run_string, + user=request.user + ) + if response: + return response + + return Response( + status=status.HTTP_201_CREATED, + data={ + 'course_run_id': course_run_id, + } + ) + + def destroy(self, request, uuid): + """ + On DELETE call to this API we will unenroll the course enrollment for the provided uuid + """ + try: + entitlement = CourseEntitlement.objects.get(uuid=uuid, user=request.user, expired_at=None) + except CourseEntitlement.DoesNotExist: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data='The Entitlement for this UUID does not exist or is Expired.' + ) + + if entitlement.enrollment_course_run is None: + return Response(status=status.HTTP_204_NO_CONTENT) + + self._unenroll_entitlement( + entitlement=entitlement, + course_run_key=entitlement.enrollment_course_run.course_id, + user=request.user + ) + return Response(status=status.HTTP_204_NO_CONTENT) diff --git a/common/djangoapps/entitlements/migrations/0004_auto_20171206_1729.py b/common/djangoapps/entitlements/migrations/0004_auto_20171206_1729.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3f9881394762abef979ac5b9862e6914df1a1d --- /dev/null +++ b/common/djangoapps/entitlements/migrations/0004_auto_20171206_1729.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('entitlements', '0003_auto_20171205_1431'), + ] + + operations = [ + migrations.AlterField( + model_name='courseentitlement', + name='uuid', + field=models.UUIDField(default=uuid.uuid4, unique=True, editable=False), + ), + ] diff --git a/common/djangoapps/entitlements/models.py b/common/djangoapps/entitlements/models.py index a272fe1724213f43dc36574e150cc084900cd97b..404d3a550bacc3921cb88c29fafedce241fe9909 100644 --- a/common/djangoapps/entitlements/models.py +++ b/common/djangoapps/entitlements/models.py @@ -125,7 +125,7 @@ class CourseEntitlement(TimeStampedModel): """ user = models.ForeignKey(settings.AUTH_USER_MODEL) - uuid = models.UUIDField(default=uuid_tools.uuid4, editable=False) + uuid = models.UUIDField(default=uuid_tools.uuid4, editable=False, unique=True) course_uuid = models.UUIDField(help_text='UUID for the Course, not the Course Run') expired_at = models.DateTimeField( null=True, @@ -212,3 +212,10 @@ class CourseEntitlement(TimeStampedModel): Returns a boolean as to whether or not the entitlement can be redeemed based on the entitlement's policy """ return self.policy.is_entitlement_redeemable(self) + + @classmethod + def set_enrollment(cls, entitlement, enrollment): + """ + Fulfills an entitlement by specifying a session. + """ + cls.objects.filter(id=entitlement.id).update(enrollment_course_run=enrollment) diff --git a/openedx/core/djangoapps/catalog/utils.py b/openedx/core/djangoapps/catalog/utils.py index c71dca2a7b9240d942ae32c6c72363c7d3199a40..4f06845900ef7bf4ee53064a5764ec9d81986a1d 100644 --- a/openedx/core/djangoapps/catalog/utils.py +++ b/openedx/core/djangoapps/catalog/utils.py @@ -235,7 +235,6 @@ def get_course_runs_for_course(course_uuid): cache_key=cache_key if catalog_integration.is_cache_enabled else None, long_term_cache=True ) - return data.get('course_runs', []) else: return []