From 86e79c59ed1051920e0f8fa2642079ac329202ce Mon Sep 17 00:00:00 2001
From: Michael Terry <mterry@edx.org>
Date: Thu, 5 Jul 2018 15:46:44 -0400
Subject: [PATCH] Paginate notify_credentials queries

Avoid blowing up on the giant queries that notify_credentials does by
paginating the queries before we resolve them.
---
 .../management/commands/notify_credentials.py | 54 +++++++++++++------
 .../commands/tests/test_notify_credentials.py | 16 +++++-
 2 files changed, 53 insertions(+), 17 deletions(-)

diff --git a/openedx/core/djangoapps/credentials/management/commands/notify_credentials.py b/openedx/core/djangoapps/credentials/management/commands/notify_credentials.py
index c71a71259b2..ebba0170d80 100644
--- a/openedx/core/djangoapps/credentials/management/commands/notify_credentials.py
+++ b/openedx/core/djangoapps/credentials/management/commands/notify_credentials.py
@@ -9,8 +9,9 @@ This management command will manually trigger the receivers we care about.
 (We don't want to trigger all receivers for these signals, since these are busy
 signals.)
 """
-from __future__ import print_function
+from __future__ import print_function, division
 import logging
+import math
 import time
 import sys
 
@@ -45,6 +46,25 @@ def parsetime(timestr):
     return dt
 
 
+def paged_query(queryset, delay, page_size):
+    """
+    A generator that iterates through a queryset but only resolves chunks of it at once, to avoid overwhelming memory
+    with a giant query. Also adds an optional delay between yields, to help with load.
+    """
+    count = queryset.count()
+    pages = int(math.ceil(count / page_size))
+
+    for page in range(pages):
+        page_start = page * page_size
+        page_end = page_start + page_size
+        subquery = queryset[page_start:page_end]
+
+        for i, item in enumerate(subquery, start=1):
+            if delay:
+                time.sleep(delay)
+            yield page_start + i, item
+
+
 class Command(BaseCommand):
     """
     Example usage:
@@ -101,6 +121,12 @@ class Command(BaseCommand):
             default=0,
             help="Number of seconds to sleep between processing certificates, so that we don't flood our queues.",
         )
+        parser.add_argument(
+            '--page-size',
+            type=int,
+            default=100,
+            help="Number of items to query at once.",
+        )
 
     def handle(self, *args, **options):
         log.info(
@@ -133,23 +159,21 @@ class Command(BaseCommand):
         grades = PersistentCourseGrade.objects.filter(**grade_filter_args).order_by('modified')
 
         if options['dry_run']:
-            self.print_dry_run(list(certs), list(grades))
+            self.print_dry_run(certs, grades)
         else:
-            self.send_notifications(certs, grades, delay=options['delay'])
+            self.send_notifications(certs, grades, delay=options['delay'], page_size=options['page_size'])
 
         log.info('notify_credentials finished')
 
-    def send_notifications(self, certs, grades, delay=0):
+    def send_notifications(self, certs, grades, delay=0, page_size=0):
         """ Run actual handler commands for the provided certs and grades. """
 
         # First, do certs
-        for i, cert in enumerate(certs, start=1):
+        for i, cert in paged_query(certs, delay, page_size):
             log.info(
                 "Handling credential changes %d for certificate %s",
                 i, certstr(cert),
             )
-            if delay:
-                time.sleep(delay)
 
             signal_args = {
                 'sender': None,
@@ -163,13 +187,11 @@ class Command(BaseCommand):
             handle_cert_change(**signal_args)
 
         # Then do grades
-        for i, grade in enumerate(grades, start=1):
+        for i, grade in paged_query(grades, delay, page_size):
             log.info(
                 "Handling grade changes %d for grade %s",
                 i, gradestr(grade),
             )
-            if delay:
-                time.sleep(delay)
 
             user = User.objects.get(id=grade.user_id)
             send_grade_if_interesting(user, grade.course_id, None, None, grade.letter_grade, grade.percent_grade)
@@ -201,14 +223,14 @@ class Command(BaseCommand):
         print("DRY-RUN: This command would have handled changes for...")
         ITEMS_TO_SHOW = 10
 
-        print(len(certs), "Certificates:")
+        print(certs.count(), "Certificates:")
         for cert in certs[:ITEMS_TO_SHOW]:
             print("   ", certstr(cert))
-        if len(certs) > ITEMS_TO_SHOW:
-            print("    (+ {} more)".format(len(certs) - ITEMS_TO_SHOW))
+        if certs.count() > ITEMS_TO_SHOW:
+            print("    (+ {} more)".format(certs.count() - ITEMS_TO_SHOW))
 
-        print(len(grades), "Grades:")
+        print(grades.count(), "Grades:")
         for grade in grades[:ITEMS_TO_SHOW]:
             print("   ", gradestr(grade))
-        if len(grades) > ITEMS_TO_SHOW:
-            print("    (+ {} more)".format(len(grades) - ITEMS_TO_SHOW))
+        if grades.count() > ITEMS_TO_SHOW:
+            print("    (+ {} more)".format(grades.count() - ITEMS_TO_SHOW))
diff --git a/openedx/core/djangoapps/credentials/management/commands/tests/test_notify_credentials.py b/openedx/core/djangoapps/credentials/management/commands/tests/test_notify_credentials.py
index 4eebc399652..05fea62a87b 100644
--- a/openedx/core/djangoapps/credentials/management/commands/tests/test_notify_credentials.py
+++ b/openedx/core/djangoapps/credentials/management/commands/tests/test_notify_credentials.py
@@ -8,7 +8,8 @@ import mock
 
 from django.core.management import call_command
 from django.core.management.base import CommandError
-from django.test import TestCase
+from django.db import connection, reset_queries
+from django.test import TestCase, override_settings
 from freezegun import freeze_time
 
 from lms.djangoapps.certificates.tests.factories import GeneratedCertificateFactory
@@ -107,3 +108,16 @@ class TestNotifyCredentials(TestCase):
         call_command(Command(), '--start-date', '2017-02-01', '--delay', '0.2')
         self.assertEqual(mock_time.sleep.call_count, 4)  # After each cert and each grade (2 each)
         self.assertEqual(mock_time.sleep.call_args[0][0], 0.2)
+
+    @override_settings(DEBUG=True)
+    def test_page_size(self):
+        call_command(Command(), '--start-date', '2017-01-01')
+        baseline = len(connection.queries)
+
+        reset_queries()
+        call_command(Command(), '--start-date', '2017-01-01', '--page-size=1')
+        self.assertEqual(len(connection.queries), baseline + 4)  # two extra page queries each for certs & grades
+
+        reset_queries()
+        call_command(Command(), '--start-date', '2017-01-01', '--page-size=2')
+        self.assertEqual(len(connection.queries), baseline + 2)  # one extra page query each for certs & grades
-- 
GitLab