From 73857f60398202b6deb299a03ad8fc9103e7a202 Mon Sep 17 00:00:00 2001
From: Ahsan Ulhaq <ahsan.haq@arbisoft.com>
Date: Mon, 23 Jul 2018 19:20:38 +0500
Subject: [PATCH] Added Management command for bulk unenrollment of users
 LEARNER-5852

---
 .../management/commands/bulk_unenroll.py      |  64 +++++++++++
 .../management/tests/test_bulk_unenroll.py    | 104 ++++++++++++++++++
 lms/djangoapps/shoppingcart/models.py         |   2 +-
 3 files changed, 169 insertions(+), 1 deletion(-)
 create mode 100644 common/djangoapps/student/management/commands/bulk_unenroll.py
 create mode 100644 common/djangoapps/student/management/tests/test_bulk_unenroll.py

diff --git a/common/djangoapps/student/management/commands/bulk_unenroll.py b/common/djangoapps/student/management/commands/bulk_unenroll.py
new file mode 100644
index 00000000000..6671b92bc7c
--- /dev/null
+++ b/common/djangoapps/student/management/commands/bulk_unenroll.py
@@ -0,0 +1,64 @@
+import logging
+
+import unicodecsv
+from django.core.exceptions import ObjectDoesNotExist
+from django.core.management.base import BaseCommand, CommandError
+from django.db.models import Q
+from opaque_keys import InvalidKeyError
+from opaque_keys.edx.keys import CourseKey
+
+from student.models import CourseEnrollment, User
+
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+class Command(BaseCommand):
+
+    help = """"
+    Un-enroll bulk users from the courses.
+    It expect that the data will be provided in a csv file format with
+    first row being the header and columns will be as follows:
+    user_id, username, email, course_id, is_verified, verification_date
+    """
+
+    def add_arguments(self, parser):
+        parser.add_argument('-p', '--csv_path',
+                            metavar='csv_path',
+                            dest='csv_path',
+                            required=True,
+                            help='Path to CSV file.')
+
+    def handle(self, *args, **options):
+        csv_path = options['csv_path']
+        with open(csv_path) as csvfile:
+            reader = unicodecsv.DictReader(csvfile)
+            for row in reader:
+                username = row['username']
+                email = row['email']
+                course_key = row['course_id']
+                try:
+                    user = User.objects.get(Q(username=username) | Q(email=email))
+                except ObjectDoesNotExist:
+                    user = None
+                    msg = 'User with username {} or email {} does not exist'.format(username, email)
+                    logger.warning(msg)
+
+                try:
+                    course_id = CourseKey.from_string(course_key)
+                except InvalidKeyError:
+                    course_id = None
+                    msg = 'Invalid course id {course_id}, skipping un-enrollement for {username}, {email}'.format(**row)
+                    logger.warning(msg)
+
+                if user and course_id:
+                    enrollment = CourseEnrollment.get_enrollment(user, course_id)
+                    if not enrollment:
+                        msg = 'Enrollment for the user {} in course {} does not exist!'.format(username, course_key)
+                        logger.info(msg)
+                    else:
+                        try:
+                            CourseEnrollment.unenroll(user, course_id, skip_refund=True)
+                        except Exception as err:
+                            msg = 'Error un-enrolling User {} from course {}: '.format(username, course_key, err)
+                            logger.error(msg, exc_info=True)
diff --git a/common/djangoapps/student/management/tests/test_bulk_unenroll.py b/common/djangoapps/student/management/tests/test_bulk_unenroll.py
new file mode 100644
index 00000000000..a86352bf3ef
--- /dev/null
+++ b/common/djangoapps/student/management/tests/test_bulk_unenroll.py
@@ -0,0 +1,104 @@
+from tempfile import NamedTemporaryFile
+
+from django.core.management import call_command
+from xmodule.modulestore.tests.django_utils import SharedModuleStoreTestCase
+from xmodule.modulestore.tests.factories import CourseFactory
+from testfixtures import LogCapture
+
+from course_modes.tests.factories import CourseModeFactory
+from student.tests.factories import UserFactory
+from student.models import CourseEnrollment, User
+
+
+LOGGER_NAME = 'student.management.commands.bulk_unenroll'
+
+
+class BulkUnenrollTests(SharedModuleStoreTestCase):
+    """Test Bulk un-enroll command works fine for all test cases."""
+    def setUp(self):
+        super(BulkUnenrollTests, self).setUp()
+        self.course = CourseFactory.create()
+        self.audit_mode = CourseModeFactory.create(
+            course_id=self.course.id,
+            mode_slug='audit',
+            mode_display_name='Audit',
+        )
+
+        self.user_info = [
+            ('amy', 'amy@pond.com', 'password'),
+            ('rory', 'rory@theroman.com', 'password'),
+            ('river', 'river@song.com', 'password')
+        ]
+        self.enrollments = []
+        self.users = []
+
+        for username, email, password in self.user_info:
+            user = UserFactory.create(username=username, email=email, password=password)
+            self.users.append(user)
+            self.enrollments.append(CourseEnrollment.enroll(user, self.course.id, mode='audit'))
+
+    def _write_test_csv(self, csv, lines=None):
+        """Write a test csv file with the lines procided"""
+        csv.write("user_id,username,email,course_id\n")
+        csv.writelines(lines)
+        csv.seek(0)
+        return csv
+
+    def test_user_not_exist(self):
+        """Verify that warning user not exist is logged for non existing user."""
+        with NamedTemporaryFile() as csv:
+            csv = self._write_test_csv(csv, lines="111,test,test@example.com,course-v1:edX+DemoX+Demo_Course\n")
+
+            with LogCapture(LOGGER_NAME) as log:
+                call_command("bulk_unenroll", "--csv_path={}".format(csv.name))
+                log.check(
+                    (
+                        LOGGER_NAME,
+                        'WARNING',
+                        'User with username {} or email {} does not exist'.format('test', 'test@example.com')
+                    )
+                )
+
+    def test_invalid_course_key(self):
+        """Verify in case of invalid course key warning is logged."""
+        with NamedTemporaryFile() as csv:
+            csv = self._write_test_csv(csv, lines="111,amy,amy@pond.com,test_course\n")
+
+            with LogCapture(LOGGER_NAME) as log:
+                call_command("bulk_unenroll", "--csv_path={}".format(csv.name))
+                log.check(
+                    (
+                        LOGGER_NAME,
+                        'WARNING',
+                        'Invalid course id {}, skipping un-enrollement for {}, {}'.format(
+                            'test_course', 'amy', 'amy@pond.com')
+                    )
+                )
+
+    def test_user_not_enrolled(self):
+        """Verify in case of user not enrolled in course warning is logged."""
+        with NamedTemporaryFile() as csv:
+            csv = self._write_test_csv(csv, lines="111,amy,amy@pond.com,course-v1:edX+DemoX+Demo_Course\n")
+
+            with LogCapture(LOGGER_NAME) as log:
+                call_command("bulk_unenroll", "--csv_path={}".format(csv.name))
+                log.check(
+                    (
+                        LOGGER_NAME,
+                        'INFO',
+                        'Enrollment for the user {} in course {} does not exist!'.format(
+                            'amy', 'course-v1:edX+DemoX+Demo_Course')
+                    )
+                )
+
+    def test_bulk_un_enroll(self):
+        """Verify users are unenrolled using the command."""
+        lines = (str(enrollment.user.id) + "," + enrollment.user.username + "," +
+                 enrollment.user.email + "," + str(enrollment.course.id) + "\n"
+                 for enrollment in self.enrollments)
+        with NamedTemporaryFile() as csv:
+            csv = self._write_test_csv(csv, lines=lines)\
+
+            call_command("bulk_unenroll", "--csv_path={}".format(csv.name))
+            for enrollment in CourseEnrollment.objects.all():
+                self.assertEqual(enrollment.is_active, False)
diff --git a/lms/djangoapps/shoppingcart/models.py b/lms/djangoapps/shoppingcart/models.py
index c8afe931407..9f415937873 100644
--- a/lms/djangoapps/shoppingcart/models.py
+++ b/lms/djangoapps/shoppingcart/models.py
@@ -1882,7 +1882,7 @@ class CertificateItem(OrderItem):
         """
 
         # Only refund verified cert unenrollments that are within bounds of the expiration date
-        if (not course_enrollment.refundable()) or skip_refund:
+        if skip_refund or (not course_enrollment.refundable()):
             return
 
         target_certs = CertificateItem.objects.filter(course_id=course_enrollment.course_id, user_id=course_enrollment.user, status='purchased', mode='verified')
-- 
GitLab