diff --git a/common/djangoapps/entitlements/api/v1/tests/test_views.py b/common/djangoapps/entitlements/api/v1/tests/test_views.py index 44b5086b3316de7d72896d4a874549965c9d0524..03fc7dddff1cf32f9fbb1b9f5c02b98f396a61a0 100644 --- a/common/djangoapps/entitlements/api/v1/tests/test_views.py +++ b/common/djangoapps/entitlements/api/v1/tests/test_views.py @@ -7,6 +7,9 @@ from datetime import datetime, timedelta import pytz from django.conf import settings from django.core.urlresolvers import reverse + +from course_modes.models import CourseMode +from course_modes.tests.factories import CourseModeFactory from mock import patch from opaque_keys.edx.locator import CourseKey from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase @@ -34,6 +37,13 @@ class EntitlementViewSetTest(ModuleStoreTestCase): self.user = UserFactory(is_staff=True) self.client.login(username=self.user.username, password=TEST_PASSWORD) self.course = CourseFactory() + self.course_mode = CourseModeFactory( + course_id=self.course.id, + mode_slug=CourseMode.VERIFIED, + # This must be in the future to ensure it is returned by downstream code. + expiration_datetime=datetime.now(pytz.UTC) + timedelta(days=1) + ) + self.entitlements_list_url = reverse('entitlements_api:v1:entitlements-list') def _get_data_set(self, user, course_uuid): @@ -115,6 +125,37 @@ class EntitlementViewSetTest(ModuleStoreTestCase): ) assert results == CourseEntitlementSerializer(course_entitlement).data + @patch("entitlements.api.v1.views.get_course_runs_for_course") + def test_add_entitlement_and_upgrade_audit_enrollment(self, mock_get_course_runs): + """ + Verify that if an entitlement is added for a user, if the user has one upgradeable enrollment + that enrollment is upgraded to the mode of the entitlement and linked to the entitlement. + """ + course_uuid = uuid.uuid4() + entitlement_data = self._get_data_set(self.user, str(course_uuid)) + mock_get_course_runs.return_value = [{'key': str(self.course.id)}] + + # Add an audit course enrollment for user. + enrollment = CourseEnrollment.enroll(self.user, self.course.id, mode=CourseMode.AUDIT) + + response = self.client.post( + self.entitlements_list_url, + data=json.dumps(entitlement_data), + content_type='application/json', + ) + assert response.status_code == 201 + results = response.data + + course_entitlement = CourseEntitlement.objects.get( + user=self.user, + course_uuid=course_uuid + ) + # Assert that enrollment mode is now verified + enrollment_mode = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id)[0] + assert enrollment_mode == course_entitlement.mode + assert course_entitlement.enrollment_course_run == enrollment + assert results == CourseEntitlementSerializer(course_entitlement).data + def test_non_staff_get_select_entitlements(self): not_staff_user = UserFactory() self.client.login(username=not_staff_user.username, password=TEST_PASSWORD) diff --git a/common/djangoapps/entitlements/api/v1/views.py b/common/djangoapps/entitlements/api/v1/views.py index 0858fdfbeea5503dcb9049752545d1bba313430f..a5fdfe22a0fb91feb25e8a6767328af0bd75390c 100644 --- a/common/djangoapps/entitlements/api/v1/views.py +++ b/common/djangoapps/entitlements/api/v1/views.py @@ -53,6 +53,48 @@ class EntitlementViewSet(viewsets.ModelViewSet): # to Admin users return CourseEntitlement.objects.all().select_related('user').select_related('enrollment_course_run') + def create(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + self.perform_create(serializer) + + entitlement = serializer.instance + user = entitlement.user + + # find all course_runs within the course + course_runs = get_course_runs_for_course(entitlement.course_uuid) + + # check if the user has enrollments for any of the course_runs + user_run_enrollments = [ + CourseEnrollment.get_enrollment(user, CourseKey.from_string(course_run.get('key'))) + for course_run + in course_runs + if CourseEnrollment.get_enrollment(user, CourseKey.from_string(course_run.get('key'))) + ] + + # filter to just enrollments that can be upgraded. + upgradeable_enrollments = [ + enrollment + for enrollment + in user_run_enrollments + if enrollment.upgrade_deadline and enrollment.upgrade_deadline > timezone.now() + ] + + # if there is only one upgradeable enrollment, convert it from audit to the entitlement.mode + # if there is any ambiguity about which enrollment to upgrade + # (i.e. multiple upgradeable enrollments or no available upgradeable enrollment), dont enroll + if len(upgradeable_enrollments) == 1: + enrollment = upgradeable_enrollments[0] + log.info('Upgrading enrollment [%s] from audit to [%s] while adding entitlement for user [%s] for course [%s] ', enrollment, serializer.data.get('mode'), user.username, serializer.data.get('course_uuid')) + enrollment.update_enrollment(mode=entitlement.mode) + entitlement.set_enrollment(enrollment) + else: + log.info('No enrollment upgraded while adding entitlement for user [%s] for course [%s] ', user.username, serializer.data.get('course_uuid')) + + headers = self.get_success_headers(serializer.data) + # Note, the entitlement is re-serialized before getting added to the Response, so that the 'modified' date reflects changes that occur when upgrading enrollment. + return Response(CourseEntitlementSerializer(entitlement).data, status=status.HTTP_201_CREATED, headers=headers) + def retrieve(self, request, *args, **kwargs): """ Override the retrieve method to expire a record that is past the