diff --git a/common/djangoapps/enrollment/tests/test_views.py b/common/djangoapps/enrollment/tests/test_views.py index 6d08f93137e5d61d115dcc53158e5bd9d1c99276..42a7aff93fe0f33565a407fa0217cf019be1d1ca 100644 --- a/common/djangoapps/enrollment/tests/test_views.py +++ b/common/djangoapps/enrollment/tests/test_views.py @@ -30,6 +30,7 @@ from enrollment.views import EnrollmentUserThrottle from openedx.core.djangoapps.content.course_overviews.models import CourseOverview from openedx.core.djangoapps.embargo.models import Country, CountryAccessRule, RestrictedCourse from openedx.core.djangoapps.embargo.test_utils import restrict_course +from openedx.core.djangoapps.course_groups import cohorts from openedx.core.djangoapps.user_api.models import ( RetirementState, UserRetirementStatus, @@ -67,6 +68,7 @@ class EnrollmentTestMixin(object): min_mongo_calls=0, max_mongo_calls=0, linked_enterprise_customer=None, + cohort=None, ): """ Enroll in the course and verify the response's status code. If the expected status is 200, also validates @@ -96,6 +98,9 @@ class EnrollmentTestMixin(object): if linked_enterprise_customer is not None: data['linked_enterprise_customer'] = linked_enterprise_customer + if cohort is not None: + data['cohort'] = cohort + extra = {} if as_server: extra['HTTP_X_EDX_API_KEY'] = self.API_KEY @@ -576,6 +581,28 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase, Ente except ImproperlyConfigured: self.fail("No throttle rate set for {}".format(user_scope)) + def test_create_enrollment_with_cohort(self): + """Enroll in the course, and also add to a cohort.""" + # Create a cohort + cohort_name = 'masters' + cohorts.set_course_cohorted(self.course.id, True) + cohorts.add_cohort(self.course.id, cohort_name, 'test') + # Create an enrollment + + self.assert_enrollment_status(cohort=cohort_name) + self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id)) + course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id) + self.assertTrue(is_active) + self.assertEqual(cohorts.get_cohort(self.user, self.course.id, assign=False).name, cohort_name) + + def test_create_enrollment_with_wrong_cohort(self): + """Enroll in the course, and also add to a cohort.""" + # Create a cohort + cohorts.set_course_cohorted(self.course.id, True) + cohorts.add_cohort(self.course.id, 'masters', 'test') + # Create an enrollment + self.assert_enrollment_status(cohort='missing', expected_status=status.HTTP_400_BAD_REQUEST) + def test_create_enrollment_with_mode(self): """With the right API key, create a new enrollment with a mode set other than the default.""" # Create a professional ed course mode. diff --git a/common/djangoapps/enrollment/views.py b/common/djangoapps/enrollment/views.py index e45820c1cf8bcbefc082bfd7e24c5af067ddb515..35ecb2f9947830d4271939246d9b3e73fb0f7b7d 100644 --- a/common/djangoapps/enrollment/views.py +++ b/common/djangoapps/enrollment/views.py @@ -6,6 +6,7 @@ consist primarily of authentication, request validation, and serialization. import logging from course_modes.models import CourseMode +from django.db import transaction from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist from django.utils.decorators import method_decorator @@ -21,6 +22,7 @@ from openedx.core.djangoapps.embargo import api as embargo_api from openedx.core.djangoapps.user_api.accounts.permissions import CanRetireUser from openedx.core.djangoapps.user_api.models import UserRetirementStatus from openedx.core.djangoapps.user_api.preferences.api import update_email_opt_in +from openedx.core.djangoapps.course_groups.cohorts import add_user_to_cohort, get_cohort_by_name, CourseUserGroup from openedx.core.lib.api.authentication import ( OAuth2AuthenticationAllowInactiveUser, SessionAuthenticationAllowInactiveUser @@ -729,6 +731,10 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): enrollment_attributes=enrollment_attributes ) + cohort_name = request.data.get('cohort') + if cohort_name is not None: + cohort = get_cohort_by_name(course_id, cohort_name) + add_user_to_cohort(cohort, user) email_opt_in = request.data.get('email_opt_in', None) if email_opt_in is not None: org = course_id.org @@ -767,6 +773,13 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn): ).format(username=username, course_id=course_id) } ) + except CourseUserGroup.DoesNotExist: + log.exception('Missing cohort [%s] in course run [%s]', cohort_name, course_id) + return Response( + status=status.HTTP_400_BAD_REQUEST, + data={ + "message": "An error occured while adding to cohort [%s]" % cohort_name + }) finally: # Assumes that the ecommerce service uses an API key to authenticate. if has_api_key_permissions: diff --git a/openedx/core/djangoapps/course_groups/cohorts.py b/openedx/core/djangoapps/course_groups/cohorts.py index 599f8161a4e354b1f592dcf4c1b797e94b09dfc6..876031c2b2d3c6276f1539ff80d6d2f9cbd08756 100644 --- a/openedx/core/djangoapps/course_groups/cohorts.py +++ b/openedx/core/djangoapps/course_groups/cohorts.py @@ -244,24 +244,19 @@ def get_cohort(user, course_key, assign=True, use_cached=False): # Otherwise assign the user a cohort. try: - with transaction.atomic(): - # If learner has been pre-registered in a cohort, get that cohort. Otherwise assign to a random cohort. - course_user_group = None - for assignment in UnregisteredLearnerCohortAssignments.objects.filter(email=user.email, course_id=course_key): - course_user_group = assignment.course_user_group - unregistered_learner = assignment - - if course_user_group: - unregistered_learner.delete() - else: - course_user_group = get_random_cohort(course_key) - - membership = CohortMembership.objects.create( - user=user, - course_user_group=course_user_group, - ) - - return cache.setdefault(cache_key, membership.course_user_group) + # If learner has been pre-registered in a cohort, get that cohort. Otherwise assign to a random cohort. + course_user_group = None + for assignment in UnregisteredLearnerCohortAssignments.objects.filter(email=user.email, course_id=course_key): + course_user_group = assignment.course_user_group + assignment.delete() + break + else: + course_user_group = get_random_cohort(course_key) + add_user_to_cohort(course_user_group, user) + return course_user_group + except ValueError: + # user already in cohort + return course_user_group except IntegrityError as integrity_error: # An IntegrityError is raised when multiple workers attempt to # create the same row in one of the cohort model entries: @@ -348,7 +343,7 @@ def get_cohort_names(course): return {cohort.id: cohort.name for cohort in get_course_cohorts(course)} -### Helpers for cohort management views +# Helpers for cohort management views def get_cohort_by_name(course_key, name): @@ -432,13 +427,13 @@ def remove_user_from_cohort(cohort, username_or_email): raise ValueError("User {} was not present in cohort {}".format(username_or_email, cohort)) -def add_user_to_cohort(cohort, username_or_email): +def add_user_to_cohort(cohort, username_or_email_or_user): """ Look up the given user, and if successful, add them to the specified cohort. Arguments: cohort: CourseUserGroup - username_or_email: string. Treated as email if has '@' + username_or_email_or_user: user or string. Treated as email if has '@' Returns: User object (or None if the email address is preassigned), @@ -453,36 +448,41 @@ def add_user_to_cohort(cohort, username_or_email): User.DoesNotExist if a user could not be found. """ try: - user = get_user_by_username_or_email(username_or_email) + if hasattr(username_or_email_or_user, 'email'): + user = username_or_email_or_user + else: + user = get_user_by_username_or_email(username_or_email_or_user) - membership = CohortMembership(course_user_group=cohort, user=user) - membership.save() # This will handle both cases, creation and updating, of a CohortMembership for this user. - COHORT_MEMBERSHIP_UPDATED.send(sender=None, user=user, course_key=membership.course_id) + membership, previous_cohort = CohortMembership.assign(cohort, user) tracker.emit( "edx.cohort.user_add_requested", { "user_id": user.id, "cohort_id": cohort.id, "cohort_name": cohort.name, - "previous_cohort_id": membership.previous_cohort_id, - "previous_cohort_name": membership.previous_cohort_name, + "previous_cohort_id": getattr(previous_cohort, 'id', None), + "previous_cohort_name": getattr(previous_cohort, 'name', None), } ) - return (user, membership.previous_cohort_name, False) + cache = RequestCache(COHORT_CACHE_NAMESPACE).data + cache_key = _cohort_cache_key(user.id, membership.course_id) + cache[cache_key] = membership.course_user_group + COHORT_MEMBERSHIP_UPDATED.send(sender=None, user=user, course_key=membership.course_id) + return user, getattr(previous_cohort, 'name', None), False except User.DoesNotExist as ex: # If username_or_email is an email address, store in database. try: - validate_email(username_or_email) + validate_email(username_or_email_or_user) try: assignment = UnregisteredLearnerCohortAssignments.objects.get( - email=username_or_email, course_id=cohort.course_id + email=username_or_email_or_user, course_id=cohort.course_id ) assignment.course_user_group = cohort assignment.save() except UnregisteredLearnerCohortAssignments.DoesNotExist: assignment = UnregisteredLearnerCohortAssignments.objects.create( - course_user_group=cohort, email=username_or_email, course_id=cohort.course_id + course_user_group=cohort, email=username_or_email_or_user, course_id=cohort.course_id ) tracker.emit( @@ -496,7 +496,7 @@ def add_user_to_cohort(cohort, username_or_email): return (None, None, True) except ValidationError as invalid: - if "@" in username_or_email: + if "@" in username_or_email_or_user: raise invalid else: raise ex diff --git a/openedx/core/djangoapps/course_groups/management/commands/post_cohort_membership_fix.py b/openedx/core/djangoapps/course_groups/management/commands/post_cohort_membership_fix.py deleted file mode 100644 index 29afef6a78c549d45943a8fc3cecc2c1fd2856d6..0000000000000000000000000000000000000000 --- a/openedx/core/djangoapps/course_groups/management/commands/post_cohort_membership_fix.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -Intended to fix any inconsistencies that may arise during the rollout of the CohortMembership model. -Illustration: https://gist.github.com/efischer19/d62f8ee42b7fbfbc6c9a -""" -from django.core.management.base import BaseCommand -from django.db import IntegrityError - -from openedx.core.djangoapps.course_groups.models import CourseUserGroup, CohortMembership - - -class Command(BaseCommand): - """ - Repair any inconsistencies between CourseUserGroup and CohortMembership. To be run after migration 0006. - """ - help = ''' - Repairs any potential inconsistencies made in the window between running migrations 0005 and 0006, and deploying - the code changes to enforce use of CohortMembership that go with said migrations. - - commit: optional argument. If not provided, will dry-run and list number of operations that would be made. - ''' - - def add_arguments(self, parser): - """ - Add arguments to the command parser. - """ - parser.add_argument( - '--commit', - action='store_true', - dest='commit', - default=False, - help='Really commit the changes, otherwise, just dry run', - ) - - def handle(self, *args, **options): - """ - Execute the command. Since this is designed to fix any issues cause by running pre-CohortMembership code - with the database already migrated to post-CohortMembership state, we will use the pre-CohortMembership - table CourseUserGroup as the canonical source of truth. This way, changes made in the window are persisted. - """ - commit = options['commit'] - memberships_to_delete = 0 - memberships_to_add = 0 - - # Begin by removing any data in CohortMemberships that does not match CourseUserGroups data - for membership in CohortMembership.objects.all(): - try: - CourseUserGroup.objects.get( - group_type=CourseUserGroup.COHORT, - users__id=membership.user.id, - course_id=membership.course_id, - id=membership.course_user_group.id - ) - except CourseUserGroup.DoesNotExist: - memberships_to_delete += 1 - if commit: - membership.delete() - - # Now we can add any CourseUserGroup data that is missing a backing CohortMembership - for course_group in CourseUserGroup.objects.filter(group_type=CourseUserGroup.COHORT): - for user in course_group.users.all(): - try: - CohortMembership.objects.get( - user=user, - course_id=course_group.course_id, - course_user_group_id=course_group.id - ) - except CohortMembership.DoesNotExist: - memberships_to_add += 1 - if commit: - membership = CohortMembership( - course_user_group=course_group, - user=user, - course_id=course_group.course_id - ) - try: - membership.save() - except IntegrityError: # If the user is in multiple cohorts, we arbitrarily choose between them - # In this case, allow the pre-existing entry to be "correct" - course_group.users.remove(user) - user.course_groups.remove(course_group) - - print '{} CohortMemberships did not match the CourseUserGroup table and will be deleted'.format( - memberships_to_delete - ) - print '{} CourseUserGroup users do not have a CohortMembership; one will be added if it is valid'.format( - memberships_to_add - ) - if commit: - print 'Changes have been made and saved.' - else: - print 'Dry run, changes have not been saved. Run again with "commit" argument to save changes' diff --git a/openedx/core/djangoapps/course_groups/management/commands/tests/test_post_cohort_membership_fix.py b/openedx/core/djangoapps/course_groups/management/commands/tests/test_post_cohort_membership_fix.py deleted file mode 100644 index 8049bfba3425e8d8d2f996e982daa2258fe3b165..0000000000000000000000000000000000000000 --- a/openedx/core/djangoapps/course_groups/management/commands/tests/test_post_cohort_membership_fix.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Test for the post-migration fix commands that are included with this djangoapp -""" -from django.core.management import call_command -from django.test.client import RequestFactory - -from openedx.core.djangoapps.course_groups.views import cohort_handler -from openedx.core.djangoapps.course_groups.cohorts import get_cohort_by_name -from openedx.core.djangoapps.course_groups.tests.helpers import config_course_cohorts -from openedx.core.djangoapps.course_groups.models import CohortMembership -from student.tests.factories import UserFactory -from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase -from xmodule.modulestore.tests.factories import CourseFactory - - -class TestPostMigrationFix(ModuleStoreTestCase): - """ - Base class for testing post-migration fix commands - """ - shard = 2 - - def setUp(self): - """ - setup course, user and request for tests - """ - super(TestPostMigrationFix, self).setUp() - self.course1 = CourseFactory.create() - self.course2 = CourseFactory.create() - self.user1 = UserFactory(is_staff=True) - self.user2 = UserFactory(is_staff=True) - self.request = RequestFactory().get("dummy_url") - self.request.user = self.user1 - - def test_post_cohortmembership_fix(self): - """ - Test that changes made *after* migration, but *before* turning on new code are handled properly - """ - # First, we're going to simulate some problem states that can arise during this window - config_course_cohorts(self.course1, is_cohorted=True, auto_cohorts=["Course1AutoGroup1", "Course1AutoGroup2"]) - - # Get the cohorts from the courses, which will cause auto cohorts to be created - cohort_handler(self.request, unicode(self.course1.id)) - course_1_auto_cohort_1 = get_cohort_by_name(self.course1.id, "Course1AutoGroup1") - course_1_auto_cohort_2 = get_cohort_by_name(self.course1.id, "Course1AutoGroup2") - - # When migrations were first run, the users were assigned to CohortMemberships correctly - membership1 = CohortMembership( - course_id=course_1_auto_cohort_1.course_id, - user=self.user1, - course_user_group=course_1_auto_cohort_1 - ) - membership1.save() - membership2 = CohortMembership( - course_id=course_1_auto_cohort_1.course_id, - user=self.user2, - course_user_group=course_1_auto_cohort_1 - ) - membership2.save() - - # But before CohortMembership code was turned on, some changes were made: - course_1_auto_cohort_2.users.add(self.user1) # user1 is now in 2 cohorts in the same course! - course_1_auto_cohort_2.users.add(self.user2) - course_1_auto_cohort_1.users.remove(self.user2) # and user2 was moved, but no one told CohortMembership! - - # run the post-CohortMembership command, dry-run - call_command('post_cohort_membership_fix') - - # Verify nothing was changed in dry-run mode. - self.assertEqual(self.user1.course_groups.count(), 2) # CourseUserGroup has 2 entries for user1 - - self.assertEqual(CohortMembership.objects.get(user=self.user2).course_user_group.name, 'Course1AutoGroup1') - user2_cohorts = list(self.user2.course_groups.values_list('name', flat=True)) - self.assertEqual(user2_cohorts, ['Course1AutoGroup2']) # CourseUserGroup and CohortMembership disagree - - # run the post-CohortMembership command, and commit it - call_command('post_cohort_membership_fix', commit='commit') - - # verify that both databases agree about the (corrected) state of the memberships - self.assertEqual(self.user1.course_groups.count(), 1) - self.assertEqual(CohortMembership.objects.filter(user=self.user1).count(), 1) - - self.assertEqual(self.user2.course_groups.count(), 1) - self.assertEqual(CohortMembership.objects.filter(user=self.user2).count(), 1) - self.assertEqual(CohortMembership.objects.get(user=self.user2).course_user_group.name, 'Course1AutoGroup2') - user2_cohorts = list(self.user2.course_groups.values_list('name', flat=True)) - self.assertEqual(user2_cohorts, ['Course1AutoGroup2']) diff --git a/openedx/core/djangoapps/course_groups/models.py b/openedx/core/djangoapps/course_groups/models.py index c17adea74dde0ba6e463fdccc8048953bdf3f5a4..89cdf1986412921f2c227b05741c79c51342ecc9 100644 --- a/openedx/core/djangoapps/course_groups/models.py +++ b/openedx/core/djangoapps/course_groups/models.py @@ -10,7 +10,6 @@ from django.core.exceptions import ValidationError from django.db import models, transaction from django.db.models.signals import pre_delete from django.dispatch import receiver -from util.db import outer_atomic from opaque_keys.edx.django.models import CourseKeyField from openedx.core.djangolib.model_mixins import DeletableByUserValue @@ -74,10 +73,6 @@ class CohortMembership(models.Model): user = models.ForeignKey(User, on_delete=models.CASCADE) course_id = CourseKeyField(max_length=255) - previous_cohort = None - previous_cohort_name = None - previous_cohort_id = None - class Meta(object): unique_together = (('user', 'course_id'), ) @@ -92,52 +87,47 @@ class CohortMembership(models.Model): if self.course_user_group.course_id != self.course_id: raise ValidationError("Non-matching course_ids provided") - def save(self, *args, **kwargs): - self.full_clean(validate_unique=False) - - log.info("Saving CohortMembership for user '%s' in '%s'", self.user.id, self.course_id) - - # Avoid infinite recursion if creating from get_or_create() call below. - # This block also allows middleware to use CohortMembership.get_or_create without worrying about outer_atomic - if 'force_insert' in kwargs and kwargs['force_insert'] is True: - with transaction.atomic(): - self.course_user_group.users.add(self.user) - super(CohortMembership, self).save(*args, **kwargs) - return - - # This block will transactionally commit updates to CohortMembership and underlying course_user_groups. - # Note the use of outer_atomic, which guarantees that operations are committed to the database on block exit. - # If called from a view method, that method must be marked with @transaction.non_atomic_requests. - with outer_atomic(read_committed=True): - - saved_membership, created = CohortMembership.objects.select_for_update().get_or_create( - user__id=self.user.id, - course_id=self.course_id, + @classmethod + def assign(cls, cohort, user): + """ + Assign user to cohort, switching them to this cohort if they had previously been assigned to another + cohort + Returns CohortMembership, previous_cohort (if any) + """ + with transaction.atomic(): + membership, created = cls.objects.select_for_update().get_or_create( + user__id=user.id, + course_id=cohort.course_id, defaults={ - 'course_user_group': self.course_user_group, - 'user': self.user - } - ) + 'course_user_group': cohort, + 'user': user + }) - # If the membership was newly created, all the validation and course_user_group logic was settled - # with a call to self.save(force_insert=True), which gets handled above. if created: - return - - if saved_membership.course_user_group == self.course_user_group: + membership.course_user_group.users.add(user) + previous_cohort = None + elif membership.course_user_group == cohort: raise ValueError("User {user_name} already present in cohort {cohort_name}".format( - user_name=self.user.username, - cohort_name=self.course_user_group.name - )) - self.previous_cohort = saved_membership.course_user_group - self.previous_cohort_name = saved_membership.course_user_group.name - self.previous_cohort_id = saved_membership.course_user_group.id - self.previous_cohort.users.remove(self.user) + user_name=user.username, + cohort_name=cohort.name)) + else: + previous_cohort = membership.course_user_group + previous_cohort.users.remove(user) + + membership.course_user_group = cohort + membership.course_user_group.users.add(user) + membership.save() + return membership, previous_cohort + + def save(self, force_insert=False, force_update=False, using=None, update_fields=None): + self.full_clean(validate_unique=False) - saved_membership.course_user_group = self.course_user_group - self.course_user_group.users.add(self.user) + log.info("Saving CohortMembership for user '%s' in '%s'", self.user.id, self.course_id) - super(CohortMembership, saved_membership).save(update_fields=['course_user_group']) + return super(CohortMembership, self).save(force_insert=force_insert, + force_update=force_update, + using=using, + update_fields=update_fields) # Needs to exist outside class definition in order to use 'sender=CohortMembership' @@ -243,7 +233,7 @@ class UnregisteredLearnerCohortAssignments(DeletableByUserValue, models.Model): """ Tracks the assignment of an unregistered learner to a course's cohort. """ - #pylint: disable=model-missing-unicode + # pylint: disable=model-missing-unicode class Meta(object): unique_together = (('course_id', 'email'), ) diff --git a/openedx/core/djangoapps/course_groups/views.py b/openedx/core/djangoapps/course_groups/views.py index da9f65612c49c069df6bcf32c66fd92ae9a445bb..caf8ebf37aa9dcb68bf37ba69ed7c9c633d84255 100644 --- a/openedx/core/djangoapps/course_groups/views.py +++ b/openedx/core/djangoapps/course_groups/views.py @@ -10,7 +10,6 @@ from django.contrib.auth.models import User from django.core.exceptions import ValidationError from django.core.paginator import EmptyPage, Paginator from django.urls import reverse -from django.db import transaction from django.http import Http404, HttpResponseBadRequest from django.utils.translation import ugettext from django.views.decorators.csrf import ensure_csrf_cookie @@ -253,7 +252,6 @@ def users_in_cohort(request, course_key_string, cohort_id): 'users': user_info}) -@transaction.non_atomic_requests @ensure_csrf_cookie @require_POST def add_users_to_cohort(request, course_key_string, cohort_id):