Skip to content
Snippets Groups Projects
Commit e3dbfe9c authored by Waheed Ahmed's avatar Waheed Ahmed
Browse files

Fix refund entitlement on audit course un-enroll.

LEARNER-6247
parent 2892ce6e
No related merge requests found
......@@ -57,13 +57,12 @@ def _process_revoke_and_unenroll_entitlement(course_entitlement, is_refund=False
IntegrityError if there is an issue that should reverse the database changes
"""
if course_entitlement.expired_at is None:
course_entitlement.expired_at = timezone.now()
course_entitlement.expire_entitlement()
log.info(
'Set expired_at to [%s] for course entitlement [%s]',
course_entitlement.expired_at,
course_entitlement.uuid
)
course_entitlement.save()
if course_entitlement.enrollment_course_run is not None:
course_id = course_entitlement.enrollment_course_run.course_id
......
......@@ -210,8 +210,7 @@ class CourseEntitlement(TimeStampedModel):
if not self.expired_at:
if (self.policy.get_days_until_expiration(self) < 0 or
(self.enrollment_course_run and not self.is_entitlement_regainable())):
self.expired_at = now()
self.save()
self.expire_entitlement()
def get_days_until_expiration(self):
"""
......@@ -269,6 +268,13 @@ class CourseEntitlement(TimeStampedModel):
self.enrollment_course_run = enrollment
self.save()
def expire_entitlement(self):
"""
Expire the entitlement.
"""
self.expired_at = now()
self.save()
@classmethod
def unexpired_entitlements_for_user(cls, user):
return cls.objects.filter(user=user, expired_at=None).select_related('user')
......@@ -412,11 +418,13 @@ class CourseEntitlement(TimeStampedModel):
"""
course_uuid = get_course_uuid_for_course(course_enrollment.course_id)
course_entitlement = cls.get_entitlement_if_active(course_enrollment.user, course_uuid)
if course_entitlement:
if course_entitlement and course_entitlement.enrollment_course_run == course_enrollment:
course_entitlement.set_enrollment(None)
if not skip_refund and course_entitlement.is_entitlement_refundable():
course_entitlement.refund()
course_entitlement.expire_entitlement()
def refund(self):
"""
Initiate refund process for the entitlement.
......
......@@ -300,3 +300,31 @@ class TestModels(TestCase):
expired_at_datetime = entitlement.expired_at_datetime
assert expired_at_datetime
assert entitlement.expired_at
@patch("entitlements.models.get_course_uuid_for_course")
@patch("entitlements.models.CourseEntitlement.refund")
def test_unenroll_entitlement_with_audit_course_enrollment(self, mock_refund, mock_get_course_uuid):
"""
Test that entitlement is not refunded if un-enroll is called on audit course un-enroll.
"""
self.enrollment.mode = CourseMode.AUDIT
self.enrollment.user = self.user
self.enrollment.save()
entitlement = CourseEntitlementFactory.create(user=self.user)
mock_get_course_uuid.return_value = entitlement.course_uuid
CourseEnrollment.unenroll(self.user, self.course.id)
assert not mock_refund.called
entitlement.refresh_from_db()
assert entitlement.expired_at is None
self.enrollment.mode = CourseMode.VERIFIED
self.enrollment.is_active = True
self.enrollment.save()
entitlement.enrollment_course_run = self.enrollment
entitlement.save()
CourseEnrollment.unenroll(self.user, self.course.id)
assert mock_refund.called
entitlement.refresh_from_db()
assert entitlement.expired_at < now()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment