Skip to content
Snippets Groups Projects
Unverified Commit a020c897 authored by Dave St.Germain's avatar Dave St.Germain Committed by GitHub
Browse files

Merge pull request #18912 from edx/dcs/enroll-cohorts

Added the ability to assign learners to a cohort in the enrollment API.
parents 99eca7b2 864f59ed
No related merge requests found
......@@ -30,6 +30,7 @@ from enrollment.views import EnrollmentUserThrottle
from openedx.core.djangoapps.content.course_overviews.models import CourseOverview
from openedx.core.djangoapps.embargo.models import Country, CountryAccessRule, RestrictedCourse
from openedx.core.djangoapps.embargo.test_utils import restrict_course
from openedx.core.djangoapps.course_groups import cohorts
from openedx.core.djangoapps.user_api.models import (
RetirementState,
UserRetirementStatus,
......@@ -67,6 +68,7 @@ class EnrollmentTestMixin(object):
min_mongo_calls=0,
max_mongo_calls=0,
linked_enterprise_customer=None,
cohort=None,
):
"""
Enroll in the course and verify the response's status code. If the expected status is 200, also validates
......@@ -96,6 +98,9 @@ class EnrollmentTestMixin(object):
if linked_enterprise_customer is not None:
data['linked_enterprise_customer'] = linked_enterprise_customer
if cohort is not None:
data['cohort'] = cohort
extra = {}
if as_server:
extra['HTTP_X_EDX_API_KEY'] = self.API_KEY
......@@ -576,6 +581,28 @@ class EnrollmentTest(EnrollmentTestMixin, ModuleStoreTestCase, APITestCase, Ente
except ImproperlyConfigured:
self.fail("No throttle rate set for {}".format(user_scope))
def test_create_enrollment_with_cohort(self):
"""Enroll in the course, and also add to a cohort."""
# Create a cohort
cohort_name = 'masters'
cohorts.set_course_cohorted(self.course.id, True)
cohorts.add_cohort(self.course.id, cohort_name, 'test')
# Create an enrollment
self.assert_enrollment_status(cohort=cohort_name)
self.assertTrue(CourseEnrollment.is_enrolled(self.user, self.course.id))
course_mode, is_active = CourseEnrollment.enrollment_mode_for_user(self.user, self.course.id)
self.assertTrue(is_active)
self.assertEqual(cohorts.get_cohort(self.user, self.course.id, assign=False).name, cohort_name)
def test_create_enrollment_with_wrong_cohort(self):
"""Enroll in the course, and also add to a cohort."""
# Create a cohort
cohorts.set_course_cohorted(self.course.id, True)
cohorts.add_cohort(self.course.id, 'masters', 'test')
# Create an enrollment
self.assert_enrollment_status(cohort='missing', expected_status=status.HTTP_400_BAD_REQUEST)
def test_create_enrollment_with_mode(self):
"""With the right API key, create a new enrollment with a mode set other than the default."""
# Create a professional ed course mode.
......
......@@ -6,6 +6,7 @@ consist primarily of authentication, request validation, and serialization.
import logging
from course_modes.models import CourseMode
from django.db import transaction
from django.contrib.auth import get_user_model
from django.core.exceptions import ObjectDoesNotExist
from django.utils.decorators import method_decorator
......@@ -21,6 +22,7 @@ from openedx.core.djangoapps.embargo import api as embargo_api
from openedx.core.djangoapps.user_api.accounts.permissions import CanRetireUser
from openedx.core.djangoapps.user_api.models import UserRetirementStatus
from openedx.core.djangoapps.user_api.preferences.api import update_email_opt_in
from openedx.core.djangoapps.course_groups.cohorts import add_user_to_cohort, get_cohort_by_name, CourseUserGroup
from openedx.core.lib.api.authentication import (
OAuth2AuthenticationAllowInactiveUser,
SessionAuthenticationAllowInactiveUser
......@@ -729,6 +731,10 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
enrollment_attributes=enrollment_attributes
)
cohort_name = request.data.get('cohort')
if cohort_name is not None:
cohort = get_cohort_by_name(course_id, cohort_name)
add_user_to_cohort(cohort, user)
email_opt_in = request.data.get('email_opt_in', None)
if email_opt_in is not None:
org = course_id.org
......@@ -767,6 +773,13 @@ class EnrollmentListView(APIView, ApiKeyPermissionMixIn):
).format(username=username, course_id=course_id)
}
)
except CourseUserGroup.DoesNotExist:
log.exception('Missing cohort [%s] in course run [%s]', cohort_name, course_id)
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={
"message": "An error occured while adding to cohort [%s]" % cohort_name
})
finally:
# Assumes that the ecommerce service uses an API key to authenticate.
if has_api_key_permissions:
......
......@@ -244,24 +244,19 @@ def get_cohort(user, course_key, assign=True, use_cached=False):
# Otherwise assign the user a cohort.
try:
with transaction.atomic():
# If learner has been pre-registered in a cohort, get that cohort. Otherwise assign to a random cohort.
course_user_group = None
for assignment in UnregisteredLearnerCohortAssignments.objects.filter(email=user.email, course_id=course_key):
course_user_group = assignment.course_user_group
unregistered_learner = assignment
if course_user_group:
unregistered_learner.delete()
else:
course_user_group = get_random_cohort(course_key)
membership = CohortMembership.objects.create(
user=user,
course_user_group=course_user_group,
)
return cache.setdefault(cache_key, membership.course_user_group)
# If learner has been pre-registered in a cohort, get that cohort. Otherwise assign to a random cohort.
course_user_group = None
for assignment in UnregisteredLearnerCohortAssignments.objects.filter(email=user.email, course_id=course_key):
course_user_group = assignment.course_user_group
assignment.delete()
break
else:
course_user_group = get_random_cohort(course_key)
add_user_to_cohort(course_user_group, user)
return course_user_group
except ValueError:
# user already in cohort
return course_user_group
except IntegrityError as integrity_error:
# An IntegrityError is raised when multiple workers attempt to
# create the same row in one of the cohort model entries:
......@@ -348,7 +343,7 @@ def get_cohort_names(course):
return {cohort.id: cohort.name for cohort in get_course_cohorts(course)}
### Helpers for cohort management views
# Helpers for cohort management views
def get_cohort_by_name(course_key, name):
......@@ -432,13 +427,13 @@ def remove_user_from_cohort(cohort, username_or_email):
raise ValueError("User {} was not present in cohort {}".format(username_or_email, cohort))
def add_user_to_cohort(cohort, username_or_email):
def add_user_to_cohort(cohort, username_or_email_or_user):
"""
Look up the given user, and if successful, add them to the specified cohort.
Arguments:
cohort: CourseUserGroup
username_or_email: string. Treated as email if has '@'
username_or_email_or_user: user or string. Treated as email if has '@'
Returns:
User object (or None if the email address is preassigned),
......@@ -453,36 +448,41 @@ def add_user_to_cohort(cohort, username_or_email):
User.DoesNotExist if a user could not be found.
"""
try:
user = get_user_by_username_or_email(username_or_email)
if hasattr(username_or_email_or_user, 'email'):
user = username_or_email_or_user
else:
user = get_user_by_username_or_email(username_or_email_or_user)
membership = CohortMembership(course_user_group=cohort, user=user)
membership.save() # This will handle both cases, creation and updating, of a CohortMembership for this user.
COHORT_MEMBERSHIP_UPDATED.send(sender=None, user=user, course_key=membership.course_id)
membership, previous_cohort = CohortMembership.assign(cohort, user)
tracker.emit(
"edx.cohort.user_add_requested",
{
"user_id": user.id,
"cohort_id": cohort.id,
"cohort_name": cohort.name,
"previous_cohort_id": membership.previous_cohort_id,
"previous_cohort_name": membership.previous_cohort_name,
"previous_cohort_id": getattr(previous_cohort, 'id', None),
"previous_cohort_name": getattr(previous_cohort, 'name', None),
}
)
return (user, membership.previous_cohort_name, False)
cache = RequestCache(COHORT_CACHE_NAMESPACE).data
cache_key = _cohort_cache_key(user.id, membership.course_id)
cache[cache_key] = membership.course_user_group
COHORT_MEMBERSHIP_UPDATED.send(sender=None, user=user, course_key=membership.course_id)
return user, getattr(previous_cohort, 'name', None), False
except User.DoesNotExist as ex:
# If username_or_email is an email address, store in database.
try:
validate_email(username_or_email)
validate_email(username_or_email_or_user)
try:
assignment = UnregisteredLearnerCohortAssignments.objects.get(
email=username_or_email, course_id=cohort.course_id
email=username_or_email_or_user, course_id=cohort.course_id
)
assignment.course_user_group = cohort
assignment.save()
except UnregisteredLearnerCohortAssignments.DoesNotExist:
assignment = UnregisteredLearnerCohortAssignments.objects.create(
course_user_group=cohort, email=username_or_email, course_id=cohort.course_id
course_user_group=cohort, email=username_or_email_or_user, course_id=cohort.course_id
)
tracker.emit(
......@@ -496,7 +496,7 @@ def add_user_to_cohort(cohort, username_or_email):
return (None, None, True)
except ValidationError as invalid:
if "@" in username_or_email:
if "@" in username_or_email_or_user:
raise invalid
else:
raise ex
......
"""
Intended to fix any inconsistencies that may arise during the rollout of the CohortMembership model.
Illustration: https://gist.github.com/efischer19/d62f8ee42b7fbfbc6c9a
"""
from django.core.management.base import BaseCommand
from django.db import IntegrityError
from openedx.core.djangoapps.course_groups.models import CourseUserGroup, CohortMembership
class Command(BaseCommand):
"""
Repair any inconsistencies between CourseUserGroup and CohortMembership. To be run after migration 0006.
"""
help = '''
Repairs any potential inconsistencies made in the window between running migrations 0005 and 0006, and deploying
the code changes to enforce use of CohortMembership that go with said migrations.
commit: optional argument. If not provided, will dry-run and list number of operations that would be made.
'''
def add_arguments(self, parser):
"""
Add arguments to the command parser.
"""
parser.add_argument(
'--commit',
action='store_true',
dest='commit',
default=False,
help='Really commit the changes, otherwise, just dry run',
)
def handle(self, *args, **options):
"""
Execute the command. Since this is designed to fix any issues cause by running pre-CohortMembership code
with the database already migrated to post-CohortMembership state, we will use the pre-CohortMembership
table CourseUserGroup as the canonical source of truth. This way, changes made in the window are persisted.
"""
commit = options['commit']
memberships_to_delete = 0
memberships_to_add = 0
# Begin by removing any data in CohortMemberships that does not match CourseUserGroups data
for membership in CohortMembership.objects.all():
try:
CourseUserGroup.objects.get(
group_type=CourseUserGroup.COHORT,
users__id=membership.user.id,
course_id=membership.course_id,
id=membership.course_user_group.id
)
except CourseUserGroup.DoesNotExist:
memberships_to_delete += 1
if commit:
membership.delete()
# Now we can add any CourseUserGroup data that is missing a backing CohortMembership
for course_group in CourseUserGroup.objects.filter(group_type=CourseUserGroup.COHORT):
for user in course_group.users.all():
try:
CohortMembership.objects.get(
user=user,
course_id=course_group.course_id,
course_user_group_id=course_group.id
)
except CohortMembership.DoesNotExist:
memberships_to_add += 1
if commit:
membership = CohortMembership(
course_user_group=course_group,
user=user,
course_id=course_group.course_id
)
try:
membership.save()
except IntegrityError: # If the user is in multiple cohorts, we arbitrarily choose between them
# In this case, allow the pre-existing entry to be "correct"
course_group.users.remove(user)
user.course_groups.remove(course_group)
print '{} CohortMemberships did not match the CourseUserGroup table and will be deleted'.format(
memberships_to_delete
)
print '{} CourseUserGroup users do not have a CohortMembership; one will be added if it is valid'.format(
memberships_to_add
)
if commit:
print 'Changes have been made and saved.'
else:
print 'Dry run, changes have not been saved. Run again with "commit" argument to save changes'
"""
Test for the post-migration fix commands that are included with this djangoapp
"""
from django.core.management import call_command
from django.test.client import RequestFactory
from openedx.core.djangoapps.course_groups.views import cohort_handler
from openedx.core.djangoapps.course_groups.cohorts import get_cohort_by_name
from openedx.core.djangoapps.course_groups.tests.helpers import config_course_cohorts
from openedx.core.djangoapps.course_groups.models import CohortMembership
from student.tests.factories import UserFactory
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory
class TestPostMigrationFix(ModuleStoreTestCase):
"""
Base class for testing post-migration fix commands
"""
shard = 2
def setUp(self):
"""
setup course, user and request for tests
"""
super(TestPostMigrationFix, self).setUp()
self.course1 = CourseFactory.create()
self.course2 = CourseFactory.create()
self.user1 = UserFactory(is_staff=True)
self.user2 = UserFactory(is_staff=True)
self.request = RequestFactory().get("dummy_url")
self.request.user = self.user1
def test_post_cohortmembership_fix(self):
"""
Test that changes made *after* migration, but *before* turning on new code are handled properly
"""
# First, we're going to simulate some problem states that can arise during this window
config_course_cohorts(self.course1, is_cohorted=True, auto_cohorts=["Course1AutoGroup1", "Course1AutoGroup2"])
# Get the cohorts from the courses, which will cause auto cohorts to be created
cohort_handler(self.request, unicode(self.course1.id))
course_1_auto_cohort_1 = get_cohort_by_name(self.course1.id, "Course1AutoGroup1")
course_1_auto_cohort_2 = get_cohort_by_name(self.course1.id, "Course1AutoGroup2")
# When migrations were first run, the users were assigned to CohortMemberships correctly
membership1 = CohortMembership(
course_id=course_1_auto_cohort_1.course_id,
user=self.user1,
course_user_group=course_1_auto_cohort_1
)
membership1.save()
membership2 = CohortMembership(
course_id=course_1_auto_cohort_1.course_id,
user=self.user2,
course_user_group=course_1_auto_cohort_1
)
membership2.save()
# But before CohortMembership code was turned on, some changes were made:
course_1_auto_cohort_2.users.add(self.user1) # user1 is now in 2 cohorts in the same course!
course_1_auto_cohort_2.users.add(self.user2)
course_1_auto_cohort_1.users.remove(self.user2) # and user2 was moved, but no one told CohortMembership!
# run the post-CohortMembership command, dry-run
call_command('post_cohort_membership_fix')
# Verify nothing was changed in dry-run mode.
self.assertEqual(self.user1.course_groups.count(), 2) # CourseUserGroup has 2 entries for user1
self.assertEqual(CohortMembership.objects.get(user=self.user2).course_user_group.name, 'Course1AutoGroup1')
user2_cohorts = list(self.user2.course_groups.values_list('name', flat=True))
self.assertEqual(user2_cohorts, ['Course1AutoGroup2']) # CourseUserGroup and CohortMembership disagree
# run the post-CohortMembership command, and commit it
call_command('post_cohort_membership_fix', commit='commit')
# verify that both databases agree about the (corrected) state of the memberships
self.assertEqual(self.user1.course_groups.count(), 1)
self.assertEqual(CohortMembership.objects.filter(user=self.user1).count(), 1)
self.assertEqual(self.user2.course_groups.count(), 1)
self.assertEqual(CohortMembership.objects.filter(user=self.user2).count(), 1)
self.assertEqual(CohortMembership.objects.get(user=self.user2).course_user_group.name, 'Course1AutoGroup2')
user2_cohorts = list(self.user2.course_groups.values_list('name', flat=True))
self.assertEqual(user2_cohorts, ['Course1AutoGroup2'])
......@@ -10,7 +10,6 @@ from django.core.exceptions import ValidationError
from django.db import models, transaction
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from util.db import outer_atomic
from opaque_keys.edx.django.models import CourseKeyField
from openedx.core.djangolib.model_mixins import DeletableByUserValue
......@@ -74,10 +73,6 @@ class CohortMembership(models.Model):
user = models.ForeignKey(User, on_delete=models.CASCADE)
course_id = CourseKeyField(max_length=255)
previous_cohort = None
previous_cohort_name = None
previous_cohort_id = None
class Meta(object):
unique_together = (('user', 'course_id'), )
......@@ -92,52 +87,47 @@ class CohortMembership(models.Model):
if self.course_user_group.course_id != self.course_id:
raise ValidationError("Non-matching course_ids provided")
def save(self, *args, **kwargs):
self.full_clean(validate_unique=False)
log.info("Saving CohortMembership for user '%s' in '%s'", self.user.id, self.course_id)
# Avoid infinite recursion if creating from get_or_create() call below.
# This block also allows middleware to use CohortMembership.get_or_create without worrying about outer_atomic
if 'force_insert' in kwargs and kwargs['force_insert'] is True:
with transaction.atomic():
self.course_user_group.users.add(self.user)
super(CohortMembership, self).save(*args, **kwargs)
return
# This block will transactionally commit updates to CohortMembership and underlying course_user_groups.
# Note the use of outer_atomic, which guarantees that operations are committed to the database on block exit.
# If called from a view method, that method must be marked with @transaction.non_atomic_requests.
with outer_atomic(read_committed=True):
saved_membership, created = CohortMembership.objects.select_for_update().get_or_create(
user__id=self.user.id,
course_id=self.course_id,
@classmethod
def assign(cls, cohort, user):
"""
Assign user to cohort, switching them to this cohort if they had previously been assigned to another
cohort
Returns CohortMembership, previous_cohort (if any)
"""
with transaction.atomic():
membership, created = cls.objects.select_for_update().get_or_create(
user__id=user.id,
course_id=cohort.course_id,
defaults={
'course_user_group': self.course_user_group,
'user': self.user
}
)
'course_user_group': cohort,
'user': user
})
# If the membership was newly created, all the validation and course_user_group logic was settled
# with a call to self.save(force_insert=True), which gets handled above.
if created:
return
if saved_membership.course_user_group == self.course_user_group:
membership.course_user_group.users.add(user)
previous_cohort = None
elif membership.course_user_group == cohort:
raise ValueError("User {user_name} already present in cohort {cohort_name}".format(
user_name=self.user.username,
cohort_name=self.course_user_group.name
))
self.previous_cohort = saved_membership.course_user_group
self.previous_cohort_name = saved_membership.course_user_group.name
self.previous_cohort_id = saved_membership.course_user_group.id
self.previous_cohort.users.remove(self.user)
user_name=user.username,
cohort_name=cohort.name))
else:
previous_cohort = membership.course_user_group
previous_cohort.users.remove(user)
membership.course_user_group = cohort
membership.course_user_group.users.add(user)
membership.save()
return membership, previous_cohort
def save(self, force_insert=False, force_update=False, using=None, update_fields=None):
self.full_clean(validate_unique=False)
saved_membership.course_user_group = self.course_user_group
self.course_user_group.users.add(self.user)
log.info("Saving CohortMembership for user '%s' in '%s'", self.user.id, self.course_id)
super(CohortMembership, saved_membership).save(update_fields=['course_user_group'])
return super(CohortMembership, self).save(force_insert=force_insert,
force_update=force_update,
using=using,
update_fields=update_fields)
# Needs to exist outside class definition in order to use 'sender=CohortMembership'
......@@ -243,7 +233,7 @@ class UnregisteredLearnerCohortAssignments(DeletableByUserValue, models.Model):
"""
Tracks the assignment of an unregistered learner to a course's cohort.
"""
#pylint: disable=model-missing-unicode
# pylint: disable=model-missing-unicode
class Meta(object):
unique_together = (('course_id', 'email'), )
......
......@@ -10,7 +10,6 @@ from django.contrib.auth.models import User
from django.core.exceptions import ValidationError
from django.core.paginator import EmptyPage, Paginator
from django.urls import reverse
from django.db import transaction
from django.http import Http404, HttpResponseBadRequest
from django.utils.translation import ugettext
from django.views.decorators.csrf import ensure_csrf_cookie
......@@ -253,7 +252,6 @@ def users_in_cohort(request, course_key_string, cohort_id):
'users': user_info})
@transaction.non_atomic_requests
@ensure_csrf_cookie
@require_POST
def add_users_to_cohort(request, course_key_string, cohort_id):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment