Skip to content
Snippets Groups Projects
Unverified Commit aa261ba9 authored by Gabe Mulley's avatar Gabe Mulley
Browse files

add tests

parent 5b89c2f5
Branches
Tags
No related merge requests found
import urllib
import unittest
from django.conf import settings
from django.core.handlers.wsgi import WSGIRequest
from django.core.urlresolvers import reverse
from django.test.utils import override_settings
from mock import patch
from rest_framework.test import APITestCase
from experiments.factories import ExperimentDataFactory, ExperimentKeyValueFactory
......@@ -9,7 +14,32 @@ from experiments.serializers import ExperimentDataSerializer
from student.tests.factories import UserFactory
CROSS_DOMAIN_REFERER = 'https://ecommerce.edx.org'
def cross_domain_config(func):
"""Decorator for configuring a cross-domain request. """
feature_flag_decorator = patch.dict(settings.FEATURES, {
'ENABLE_CORS_HEADERS': True,
'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
})
settings_decorator = override_settings(
CORS_ORIGIN_WHITELIST=['ecommerce.edx.org'],
CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
)
is_secure_decorator = patch.object(WSGIRequest, 'is_secure', return_value=True)
return feature_flag_decorator(
settings_decorator(
is_secure_decorator(func)
)
)
class ExperimentDataViewSetTests(APITestCase):
def assert_data_created_for_user(self, user, method='post', status=201):
url = reverse('api_experiments:v0:data-list')
data = {
......@@ -210,6 +240,79 @@ class ExperimentDataViewSetTests(APITestCase):
ExperimentData.objects.get(user=other_user, **kwargs)
class ExperimentCrossDomainTests(APITestCase):
def setUp(self):
super(ExperimentCrossDomainTests, self).setUp()
self.client = self.client_class(enforce_csrf_checks=True)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@cross_domain_config
def test_cross_domain_create(self, *args): # pylint: disable=unused-argument
user = UserFactory()
self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
csrf_cookie = self._get_csrf_cookie()
data = {
'experiment_id': 1,
'key': 'foo',
'value': 'bar',
}
resp = self._cross_domain_post(csrf_cookie, data)
# Expect that the request gets through successfully,
# passing the CSRF checks (including the referer check).
self.assertEqual(resp.status_code, 201)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@cross_domain_config
def test_cross_domain_invalid_csrf_header(self, *args): # pylint: disable=unused-argument
user = UserFactory()
self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
self._get_csrf_cookie()
data = {
'experiment_id': 1,
'key': 'foo',
'value': 'bar',
}
resp = self._cross_domain_post('invalid_csrf_token', data)
self.assertEqual(resp.status_code, 403)
@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
@cross_domain_config
def test_cross_domain_not_in_whitelist(self, *args): # pylint: disable=unused-argument
user = UserFactory()
self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
csrf_cookie = self._get_csrf_cookie()
data = {
'experiment_id': 1,
'key': 'foo',
'value': 'bar',
}
resp = self._cross_domain_post(csrf_cookie, data, referer='www.example.com')
self.assertEqual(resp.status_code, 403)
def _get_csrf_cookie(self):
"""Retrieve the cross-domain CSRF cookie. """
url = reverse('courseenrollments')
resp = self.client.get(url, HTTP_REFERER=CROSS_DOMAIN_REFERER)
self.assertEqual(resp.status_code, 200)
self.assertIn(settings.CSRF_COOKIE_NAME, resp.cookies) # pylint: disable=no-member
return resp.cookies[settings.CSRF_COOKIE_NAME].value # pylint: disable=no-member
def _cross_domain_post(self, csrf_token, data, referer=CROSS_DOMAIN_REFERER):
"""Perform a cross-domain POST request. """
url = reverse('api_experiments:v0:data-list')
kwargs = {
'HTTP_REFERER': referer,
settings.CSRF_HEADER_NAME: csrf_token,
}
return self.client.post(
url,
data,
**kwargs
)
class ExperimentKeyValueViewSetTests(APITestCase):
def test_permissions(self):
""" Staff access is required for write operations. """
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment