import unittest
from django.conf import settings
from django.core.handlers.wsgi import WSGIRequest
from django.core.urlresolvers import reverse
from django.test.utils import override_settings
from mock import patch
from rest_framework.test import APITestCase

from experiments.factories import ExperimentDataFactory, ExperimentKeyValueFactory
from experiments.models import ExperimentData, ExperimentKeyValue
from experiments.serializers import ExperimentDataSerializer
from student.tests.factories import UserFactory


# Referer sent with simulated cross-domain requests in the tests below; its host
# (ecommerce.edx.org) is the one listed in CORS_ORIGIN_WHITELIST by cross_domain_config.
CROSS_DOMAIN_REFERER = 'https://ecommerce.edx.org'


class ExperimentDataViewSetTests(APITestCase):
    """ Tests for the ExperimentData endpoints (``api_experiments:v0:data-*``). """

    def assert_data_created_for_user(self, user, method='post', status=201):
        """
        Log in as ``user``, write a datum through the list endpoint, and verify the result.

        Arguments:
            user: the user to log in as. No explicit ``user`` field is submitted, so the
                datum is expected to be stored for (and echoed back with) this user.
            method (str): name of the test-client method to call, e.g. ``'post'`` or ``'put'``.
            status (int): the expected HTTP status code of the response.
        """
        url = reverse('api_experiments:v0:data-list')
        data = {
            'experiment_id': 1,
            'key': 'foo',
            'value': 'bar',
        }
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
        response = getattr(self.client, method)(url, data)
        self.assertEqual(response.status_code, status)

        # This will raise an exception if no data exists (or if more than one row exists,
        # which catches repeated writes creating duplicates instead of updating).
        ExperimentData.objects.get(user=user)

        # Every submitted field, plus the owning user, must be echoed in the response.
        # An explicit comparison is used because assertDictContainsSubset is deprecated
        # and was removed in Python 3.12.
        data['user'] = user.username
        self.assertEqual({key: response.data.get(key) for key in data}, data)

    def test_list_permissions(self):
        """ Users should only be able to list their own data. """
        url = reverse('api_experiments:v0:data-list')
        user = UserFactory()

        response = self.client.get(url)
        self.assertEqual(response.status_code, 401)

        # Data belonging to some other user must not appear in the listing.
        ExperimentDataFactory()
        datum = ExperimentDataFactory(user=user)
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data['results'], ExperimentDataSerializer([datum], many=True).data)

    def test_list_filtering(self):
        """ Users should be able to filter by the experiment_id and key fields. """
        # urlencode lives in ``urllib`` on Python 2 and in ``urllib.parse`` on Python 3.
        try:
            from urllib import urlencode  # pylint: disable=no-name-in-module
        except ImportError:
            from urllib.parse import urlencode  # pylint: disable=no-name-in-module,import-error

        url = reverse('api_experiments:v0:data-list')
        user = UserFactory()
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        experiment_id = 1
        ExperimentDataFactory()
        ExperimentDataFactory(user=user)
        data = ExperimentDataFactory.create_batch(3, user=user, experiment_id=experiment_id)

        qs = urlencode({'experiment_id': experiment_id})
        response = self.client.get('{url}?{qs}'.format(url=url, qs=qs))
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data['results'], ExperimentDataSerializer(data, many=True).data)

        datum = data[0]
        qs = urlencode({'key': datum.key})
        response = self.client.get('{url}?{qs}'.format(url=url, qs=qs))
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data['results'], ExperimentDataSerializer([datum], many=True).data)

        qs = urlencode({'experiment_id': experiment_id, 'key': datum.key})
        response = self.client.get('{url}?{qs}'.format(url=url, qs=qs))
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data['results'], ExperimentDataSerializer([datum], many=True).data)

    def test_read_permissions(self):
        """ Users should only be allowed to read their own data. """
        user = UserFactory()
        datum = ExperimentDataFactory(user=user)
        url = reverse('api_experiments:v0:data-detail', kwargs={'pk': datum.id})

        response = self.client.get(url)
        self.assertEqual(response.status_code, 401)

        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)

        # Another user's datum is reported as not found rather than forbidden.
        other_user = UserFactory()
        self.client.login(username=other_user.username, password=UserFactory._DEFAULT_PASSWORD)
        response = self.client.get(url)
        self.assertEqual(response.status_code, 404)

    def test_create_permissions(self):
        """ Users should only be allowed to create data for themselves. """
        url = reverse('api_experiments:v0:data-list')

        # Authentication is required
        response = self.client.post(url, {})
        self.assertEqual(response.status_code, 401)

        user = UserFactory()
        data = {
            'experiment_id': 1,
            'key': 'foo',
            'value': 'bar',
        }
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        # Users can create data for themselves
        response = self.client.post(url, data)
        self.assertEqual(response.status_code, 201)
        ExperimentData.objects.get(user=user)

        # A non-staff user cannot create data for another user
        other_user = UserFactory()
        data['user'] = other_user.username
        response = self.client.post(url, data)
        self.assertEqual(response.status_code, 403)
        self.assertFalse(ExperimentData.objects.filter(user=other_user).exists())

        # A staff user can create data for other users
        user.is_staff = True
        user.save()
        response = self.client.post(url, data)
        self.assertEqual(response.status_code, 201)
        ExperimentData.objects.get(user=other_user)

    def test_put_as_create(self):
        """ Users should be able to use PUT to create new data. """
        user = UserFactory()
        self.assert_data_created_for_user(user, 'put')

        # Subsequent requests should update the data
        self.assert_data_created_for_user(user, 'put', 200)

    def test_update_permissions(self):
        """ Users should only be allowed to update their own data. """
        user = UserFactory()
        other_user = UserFactory()
        datum = ExperimentDataFactory(user=user)
        url = reverse('api_experiments:v0:data-detail', kwargs={'pk': datum.id})
        data = {}

        response = self.client.patch(url, data)
        self.assertEqual(response.status_code, 401)

        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
        response = self.client.patch(url, data)
        self.assertEqual(response.status_code, 200)

        self.client.login(username=other_user.username, password=UserFactory._DEFAULT_PASSWORD)
        response = self.client.patch(url, data)
        self.assertEqual(response.status_code, 404)

    def test_bulk_upsert_permissions(self):
        """ Only staff users can access the bulk upsert endpoint. """
        url = reverse('api_experiments:v0:data-bulk-upsert')
        data = []

        # Authentication is required
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 401)

        user = UserFactory()
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        # No access to non-staff users
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 403)

        user.is_staff = True
        user.save()
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 200)

    def test_bulk_upsert(self):
        """ The endpoint should support creating/updating multiple ExperimentData objects with a single call. """
        url = reverse('api_experiments:v0:data-bulk-upsert')
        experiment_id = 1
        user = UserFactory(is_staff=True)
        other_user = UserFactory()
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        data = [
            {
                'experiment_id': experiment_id,
                'key': 'foo',
                'value': 'bar',
                'user': user.username,
            },
            {
                'experiment_id': experiment_id,
                'key': 'foo',
                'value': 'bar',
                'user': other_user.username,
            },
        ]
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 200)
        kwargs = {
            'experiment_id': experiment_id,
            'key': 'foo',
            'value': 'bar',
        }
        ExperimentData.objects.get(user=user, **kwargs)
        ExperimentData.objects.get(user=other_user, **kwargs)

        # Subsequent calls should update the existing data rather than create more
        # (``get`` raises if a duplicate row was created).
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 200)
        ExperimentData.objects.get(user=user, **kwargs)
        ExperimentData.objects.get(user=other_user, **kwargs)


def cross_domain_config(func):
    """
    Decorate ``func`` so it runs under a secure, cross-domain request configuration.

    While ``func`` runs:
      * the ENABLE_CORS_HEADERS and ENABLE_CROSS_DOMAIN_CSRF_COOKIE feature flags are on;
      * CORS_ORIGIN_WHITELIST allows ecommerce.edx.org, and the (cross-domain) CSRF
        cookie settings are overridden;
      * ``WSGIRequest.is_secure`` is patched to return True.

    The ``patch.object`` decorator passes its mock as an extra positional argument,
    which is why the decorated tests accept ``*args``.
    """
    # Listed outermost first; they are applied innermost first below.
    decorators = (
        patch.dict(settings.FEATURES, {
            'ENABLE_CORS_HEADERS': True,
            'ENABLE_CROSS_DOMAIN_CSRF_COOKIE': True
        }),
        override_settings(
            CORS_ORIGIN_WHITELIST=['ecommerce.edx.org'],
            CSRF_COOKIE_NAME="prod-edx-csrftoken",
            CROSS_DOMAIN_CSRF_COOKIE_NAME="prod-edx-csrftoken",
            CROSS_DOMAIN_CSRF_COOKIE_DOMAIN=".edx.org"
        ),
        patch.object(WSGIRequest, 'is_secure', return_value=True),
    )

    decorated = func
    for decorator in reversed(decorators):
        decorated = decorator(decorated)
    return decorated


@unittest.skipUnless(settings.ROOT_URLCONF == 'lms.urls', 'Test only valid in lms')
class ExperimentCrossDomainTests(APITestCase):
    """Tests for handling cross-domain requests"""

    def setUp(self):
        super(ExperimentCrossDomainTests, self).setUp()
        # Enforce CSRF checks so the cross-domain CSRF handling is actually exercised.
        self.client = self.client_class(enforce_csrf_checks=True)

    @cross_domain_config
    def test_cross_domain_create(self, *args):  # pylint: disable=unused-argument
        """A cross-domain POST with a valid CSRF token and a whitelisted referer succeeds."""
        user = UserFactory()
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)  # pylint: disable=protected-access

        csrf_cookie = self._get_csrf_cookie()
        data = {
            'experiment_id': 1,
            'key': 'foo',
            'value': 'bar',
        }
        resp = self._cross_domain_post(csrf_cookie, data)

        # Expect that the request gets through successfully,
        # passing the CSRF checks (including the referer check).
        self.assertEqual(resp.status_code, 201)

    @cross_domain_config
    def test_cross_domain_invalid_csrf_header(self, *args):  # pylint: disable=unused-argument
        """A cross-domain POST whose CSRF header does not match the cookie is rejected."""
        user = UserFactory()
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)  # pylint: disable=protected-access

        # Obtain a real CSRF cookie, but send a different token in the header.
        self._get_csrf_cookie()
        data = {
            'experiment_id': 1,
            'key': 'foo',
            'value': 'bar',
        }
        resp = self._cross_domain_post('invalid_csrf_token', data)
        self.assertEqual(resp.status_code, 403)

    @cross_domain_config
    def test_cross_domain_not_in_whitelist(self, *args):  # pylint: disable=unused-argument
        """A cross-domain POST from a referer host outside CORS_ORIGIN_WHITELIST is rejected."""
        user = UserFactory()
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)  # pylint: disable=protected-access

        csrf_cookie = self._get_csrf_cookie()
        data = {
            'experiment_id': 1,
            'key': 'foo',
            'value': 'bar',
        }
        # The referer must be a well-formed HTTPS URL; a scheme-less value is rejected
        # by Django's CSRF referer check as malformed, and then the whitelist is never tested.
        resp = self._cross_domain_post(csrf_cookie, data, referer='https://www.example.com')
        self.assertEqual(resp.status_code, 403)

    def _get_csrf_cookie(self):
        """Retrieve the cross-domain CSRF cookie. """
        url = reverse('courseenrollments')
        resp = self.client.get(url, HTTP_REFERER=CROSS_DOMAIN_REFERER)
        self.assertEqual(resp.status_code, 200)
        self.assertIn(settings.CSRF_COOKIE_NAME, resp.cookies)
        return resp.cookies[settings.CSRF_COOKIE_NAME].value

    def _cross_domain_post(self, csrf_token, data, referer=CROSS_DOMAIN_REFERER):
        """
        Perform a cross-domain POST to the ExperimentData list endpoint.

        Arguments:
            csrf_token (str): value sent in the CSRF header (``settings.CSRF_HEADER_NAME``).
            data (dict): form data to POST.
            referer (str): value of the Referer header.
        """
        url = reverse('api_experiments:v0:data-list')
        kwargs = {
            'HTTP_REFERER': referer,
            settings.CSRF_HEADER_NAME: csrf_token,
        }
        return self.client.post(
            url,
            data,
            **kwargs
        )


class ExperimentKeyValueViewSetTests(APITestCase):
    """ Tests for the ExperimentKeyValue endpoints (``api_experiments:v0:key_value-*``). """

    def test_permissions(self):
        """ Staff access is required for write operations. """
        url = reverse('api_experiments:v0:key_value-list')

        # Anonymous users may list, but not create.
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)

        response = self.client.post(url, {})
        self.assertEqual(response.status_code, 401)

        instance = ExperimentKeyValueFactory()
        url = reverse('api_experiments:v0:key_value-detail', kwargs={'pk': instance.id})

        # Anonymous users may read a single instance.
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)

        # Authenticated non-staff users may not modify or delete it.
        user = UserFactory(is_staff=False)
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        response = self.client.put(url, {})
        self.assertEqual(response.status_code, 403)

        response = self.client.patch(url, {})
        self.assertEqual(response.status_code, 403)

        response = self.client.delete(url)
        self.assertEqual(response.status_code, 403)

    def test_bulk_upsert_permissions(self):
        """ Non-staff users should not be allowed to access the endpoint. """
        data = []
        url = reverse('api_experiments:v0:key_value-bulk-upsert')
        user = UserFactory(is_staff=False)

        # Authentication required
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 401)

        # Staff permission required
        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 403)

    def test_bulk_upsert(self):
        """ The endpoint should support creating/updating multiple ExperimentKeyValue objects with a single call. """
        url = reverse('api_experiments:v0:key_value-bulk-upsert')
        experiment_id = 1
        user = UserFactory(is_staff=True)
        data = [
            {
                'experiment_id': experiment_id,
                'key': 'foo',
                'value': 'bar',
            },
            {
                'experiment_id': experiment_id,
                'key': 'foo1',
                'value': 'bar',
            },
        ]

        self.client.login(username=user.username, password=UserFactory._DEFAULT_PASSWORD)

        # New data should be created
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 200)
        kwargs = {
            'experiment_id': experiment_id,
            'value': 'bar',
        }
        ExperimentKeyValue.objects.get(key='foo', **kwargs)
        ExperimentKeyValue.objects.get(key='foo1', **kwargs)

        # Subsequent calls should update the existing data rather than create more
        # (``get`` raises if a duplicate row was created).
        response = self.client.put(url, data, format='json')
        self.assertEqual(response.status_code, 200)
        ExperimentKeyValue.objects.get(key='foo', **kwargs)
        ExperimentKeyValue.objects.get(key='foo1', **kwargs)