diff --git a/lms/djangoapps/experiments/tests/test_views.py b/lms/djangoapps/experiments/tests/test_views.py index b26f5f586cc778cf2a6322d08374bd557d69c6d3..b007cb1cea432d2400cfbdeef80dd16098f07730 100644 --- a/lms/djangoapps/experiments/tests/test_views.py +++ b/lms/djangoapps/experiments/tests/test_views.py @@ -1,6 +1,11 @@ import urllib +import unittest +from django.conf import settings +from django.core.handlers.wsgi import WSGIRequest from django.core.urlresolvers import reverse +from django.test.utils import override_settings +from mock import patch from rest_framework.test import APITestCase from experiments.factories import ExperimentDataFactory, ExperimentKeyValueFactory @@ -9,7 +14,11 @@ from experiments.serializers import ExperimentDataSerializer from student.tests.factories import UserFactory +CROSS_DOMAIN_REFERER = 'https://ecommerce.edx.org' + + class ExperimentDataViewSetTests(APITestCase): + def assert_data_created_for_user(self, user, method='post', status=201): url = reverse('api_experiments:v0:data-list') data = { @@ -210,6 +219,99 @@ class ExperimentDataViewSetTests(APITestCase): ExperimentData.objects.get(user=other_user, **kwargs) +def cross_domain_config(func): + """Decorator for configuring a cross-domain request. """ + feature_flag_decorator = patch.dict(settings.FEATURES, { + 'ENABLE_CORS_HEADERS': True, + 'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True + }) + settings_decorator = override_settings( + CORS_ORIGIN_WHITELIST=['ecommerce.edx.org'], + CSRF_COOKIE_NAME="prod-edx-csrftoken", + CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken", + CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org" + ) + is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True) + + return feature_flag_decorator( + settings_decorator( + is_secure_decorator(func) + ) + ) + + +@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms') +class ExperimentCrossDomainTests(APITestCase): + """Tests for handling cross-domain requests""" + + def setUp(self): + super(ExperimentCrossDomainTests, self).setUp() + self.client = self.client_class(enforce_csrf_checks=True) + + @cross_domain_config + def test_cross_domain_create(self, *args): # pylint: disable=unused-argument + user = UserFactory() + self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD) # pylint: disable=protected-access + csrf_cookie = self._get_csrf_cookie() + data = { + 'experiment_id': 1, + 'key': 'foo', + 'value': 'bar', + } + resp = self._cross_domain_post(csrf_cookie, data) + + # Expect that the request gets through successfully, + # passing the CSRF checks (including the referer check). + self.assertEqual(resp.status_code, 201) + + @cross_domain_config + def test_cross_domain_invalid_csrf_header(self, *args): # pylint: disable=unused-argument + user = UserFactory() + self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD) # pylint: disable=protected-access + self._get_csrf_cookie() + data = { + 'experiment_id': 1, + 'key': 'foo', + 'value': 'bar', + } + resp = self._cross_domain_post('invalid_csrf_token', data) + self.assertEqual(resp.status_code, 403) + + @cross_domain_config + def test_cross_domain_not_in_whitelist(self, *args): # pylint: disable=unused-argument + user = UserFactory() + self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD) # pylint: disable=protected-access + csrf_cookie = self._get_csrf_cookie() + data = { + 'experiment_id': 1, + 'key': 'foo', + 'value': 'bar', + } + resp = self._cross_domain_post(csrf_cookie, data, referer='www.example.com') + self.assertEqual(resp.status_code, 403) + + def _get_csrf_cookie(self): + """Retrieve the cross-domain CSRF cookie. """ + url = reverse('courseenrollments') + resp = self.client.get(url, HTTP_REFERER=CROSS_DOMAIN_REFERER) + self.assertEqual(resp.status_code, 200) + self.assertIn(settings.CSRF_COOKIE_NAME, resp.cookies) + return resp.cookies[settings.CSRF_COOKIE_NAME].value + + def _cross_domain_post(self, csrf_token, data, referer=CROSS_DOMAIN_REFERER): + """Perform a cross-domain POST request. """ + url = reverse('api_experiments:v0:data-list') + kwargs = { + 'HTTP_REFERER': referer, + settings.CSRF_HEADER_NAME: csrf_token, + } + return self.client.post( + url, + data, + **kwargs + ) + + class ExperimentKeyValueViewSetTests(APITestCase): def test_permissions(self): """ Staff access is required for write operations. """ diff --git a/lms/djangoapps/experiments/views.py b/lms/djangoapps/experiments/views.py index e26ce73d1caa4125e2ecc3fcea8b326e4def015d..a22f66b39eaae68163c6a95dfa77d14307f1d969 100644 --- a/lms/djangoapps/experiments/views.py +++ b/lms/djangoapps/experiments/views.py @@ -10,12 +10,18 @@ from experiments import filters, serializers from experiments.models import ExperimentData, ExperimentKeyValue from experiments.permissions import IsStaffOrOwner, IsStaffOrReadOnly from openedx.core.lib.api.authentication import SessionAuthenticationAllowInactiveUser +from openedx.core.djangoapps.cors_csrf.authentication import SessionAuthenticationCrossDomainCsrf User = get_user_model() # pylint: disable=invalid-name +class ExperimentCrossDomainSessionAuth(SessionAuthenticationAllowInactiveUser, SessionAuthenticationCrossDomainCsrf): + """Session authentication that allows inactive users and cross-domain requests. """ + pass + + class ExperimentDataViewSet(viewsets.ModelViewSet): - authentication_classes = (JwtAuthentication, SessionAuthenticationAllowInactiveUser,) + authentication_classes = (JwtAuthentication, ExperimentCrossDomainSessionAuth,) filter_backends = (DjangoFilterBackend,) filter_class = filters.ExperimentDataFilter permission_classes = (permissions.IsAuthenticated, IsStaffOrOwner,) @@ -83,7 +89,7 @@ class ExperimentDataViewSet(viewsets.ModelViewSet): class ExperimentKeyValueViewSet(viewsets.ModelViewSet): - authentication_classes = (JwtAuthentication, SessionAuthenticationAllowInactiveUser,) + authentication_classes = (JwtAuthentication, ExperimentCrossDomainSessionAuth,) filter_backends = (DjangoFilterBackend,) filter_class = filters.ExperimentKeyValueFilter permission_classes = (IsStaffOrReadOnly,)