diff --git a/common/djangoapps/entitlements/tests/test_tasks.py b/common/djangoapps/entitlements/tests/test_tasks.py
index aad5c618bad6fb33ff4ff46f9d3f4f93ae2990b1..da324fdaeeb876b6a33c4c92ece54a13cc215f6b 100644
--- a/common/djangoapps/entitlements/tests/test_tasks.py
+++ b/common/djangoapps/entitlements/tests/test_tasks.py
@@ -27,23 +27,26 @@ def boom():
 
 
 @skip_unless_lms
-@mock.patch('entitlements.models.CourseEntitlement.expired_at_datetime', new_callable=mock.PropertyMock)
 class TestExpireOldEntitlementsTask(TestCase):
     """
     Tests for the 'expire_old_entitlements' method.
     """
-    def test_checks_expiration(self, mock_datetime):
+    def test_checks_expiration(self):
         """
         Test that we actually do check expiration on each entitlement (happy path)
         """
         make_entitlement()
         make_entitlement()
 
-        tasks.expire_old_entitlements.delay(1, 3).get()
+        with mock.patch(
+            'entitlements.models.CourseEntitlement.expired_at_datetime',
+            new_callable=mock.PropertyMock
+        ) as mock_datetime:
+            tasks.expire_old_entitlements.delay(1, 3).get()
 
         self.assertEqual(mock_datetime.call_count, 2)
 
-    def test_only_unexpired(self, mock_datetime):
+    def test_only_unexpired(self):
         """
         Verify that only unexpired entitlements are included
         """
@@ -51,21 +54,28 @@ class TestExpireOldEntitlementsTask(TestCase):
         make_entitlement(expired=True)
         make_entitlement()
 
-        # Run expiration
-        tasks.expire_old_entitlements.delay(1, 3).get()
+        with mock.patch(
+            'entitlements.models.CourseEntitlement.expired_at_datetime',
+            new_callable=mock.PropertyMock
+        ) as mock_datetime:
+            tasks.expire_old_entitlements.delay(1, 3).get()
 
         # Make sure only the unexpired one gets used
         self.assertEqual(mock_datetime.call_count, 1)
 
-    def test_retry(self, mock_datetime):
+    def test_retry(self):
         """
         Test that we retry when an exception occurs while checking old
         entitlements.
         """
-        mock_datetime.side_effect = boom
-
         make_entitlement()
-        task = tasks.expire_old_entitlements.delay(1, 2)
+
+        with mock.patch(
+            'entitlements.models.CourseEntitlement.expired_at_datetime',
+            new_callable=mock.PropertyMock,
+            side_effect=boom
+        ) as mock_datetime:
+            task = tasks.expire_old_entitlements.delay(1, 2)
 
         self.assertRaises(Exception, task.get)
         self.assertEqual(mock_datetime.call_count, tasks.MAX_RETRIES + 1)