diff --git a/common/djangoapps/entitlements/tests/test_tasks.py b/common/djangoapps/entitlements/tests/test_tasks.py index aad5c618bad6fb33ff4ff46f9d3f4f93ae2990b1..da324fdaeeb876b6a33c4c92ece54a13cc215f6b 100644 --- a/common/djangoapps/entitlements/tests/test_tasks.py +++ b/common/djangoapps/entitlements/tests/test_tasks.py @@ -27,23 +27,26 @@ def boom(): @skip_unless_lms -@mock.patch('entitlements.models.CourseEntitlement.expired_at_datetime', new_callable=mock.PropertyMock) class TestExpireOldEntitlementsTask(TestCase): """ Tests for the 'expire_old_entitlements' method. """ - def test_checks_expiration(self, mock_datetime): + def test_checks_expiration(self): """ Test that we actually do check expiration on each entitlement (happy path) """ make_entitlement() make_entitlement() - tasks.expire_old_entitlements.delay(1, 3).get() + with mock.patch( + 'entitlements.models.CourseEntitlement.expired_at_datetime', + new_callable=mock.PropertyMock + ) as mock_datetime: + tasks.expire_old_entitlements.delay(1, 3).get() self.assertEqual(mock_datetime.call_count, 2) - def test_only_unexpired(self, mock_datetime): + def test_only_unexpired(self): """ Verify that only unexpired entitlements are included """ @@ -51,21 +54,28 @@ class TestExpireOldEntitlementsTask(TestCase): make_entitlement(expired=True) make_entitlement() - # Run expiration - tasks.expire_old_entitlements.delay(1, 3).get() + with mock.patch( + 'entitlements.models.CourseEntitlement.expired_at_datetime', + new_callable=mock.PropertyMock + ) as mock_datetime: + tasks.expire_old_entitlements.delay(1, 3).get() # Make sure only the unexpired one gets used self.assertEqual(mock_datetime.call_count, 1) - def test_retry(self, mock_datetime): + def test_retry(self): """ Test that we retry when an exception occurs while checking old entitlements. """ - mock_datetime.side_effect = boom - make_entitlement() - task = tasks.expire_old_entitlements.delay(1, 2) + + with mock.patch( + 'entitlements.models.CourseEntitlement.expired_at_datetime', + new_callable=mock.PropertyMock, + side_effect=boom + ) as mock_datetime: + task = tasks.expire_old_entitlements.delay(1, 2) self.assertRaises(Exception, task.get) self.assertEqual(mock_datetime.call_count, tasks.MAX_RETRIES + 1)