diff --git a/common/djangoapps/student/tests/tests.py b/common/djangoapps/student/tests/tests.py index cc25c1eb23a30a64c2a1ee2a26af3015a92d43a1..f7035db51cf7e06bdfe4e7ab611cb02b9cb7a012 100644 --- a/common/djangoapps/student/tests/tests.py +++ b/common/djangoapps/student/tests/tests.py @@ -11,6 +11,7 @@ import unittest from datetime import datetime, timedelta import pytz +from django.core.cache import cache from django.conf import settings from django.test import TestCase from django.test.utils import override_settings @@ -89,6 +90,23 @@ class ResetPasswordTests(TestCase): 'value': "('registration/password_reset_done.html', [])", }) + @patch('student.views.render_to_string', Mock(side_effect=mock_render_to_string, autospec=True)) + def test_password_reset_ratelimited(self): + """ Try (and fail) resetting password 30 times in a row on an non-existant email address """ + cache.clear() + + for i in xrange(30): + good_req = self.request_factory.post('/password_reset/', {'email': 'thisdoesnotexist@foo.com'}) + good_resp = password_reset(good_req) + self.assertEquals(good_resp.status_code, 200) + + # then the rate limiter should kick in and give a HttpForbidden response + bad_req = self.request_factory.post('/password_reset/', {'email': 'thisdoesnotexist@foo.com'}) + bad_resp = password_reset(bad_req) + self.assertEquals(bad_resp.status_code, 403) + + cache.clear() + @unittest.skipIf( settings.FEATURES.get('DISABLE_RESET_EMAIL_TEST', False), dedent(""" diff --git a/common/djangoapps/student/views.py b/common/djangoapps/student/views.py index fcd4508bf64839e0baf810d21665de41a99c3577..c685994cb8f25961008c76485c74675b311d5565 100644 --- a/common/djangoapps/student/views.py +++ b/common/djangoapps/student/views.py @@ -72,6 +72,7 @@ import track.views from dogapi import dog_stats_api from util.json_request import JsonResponse +from util.bad_request_rate_limiter import BadRequestRateLimiter from microsite_configuration.middleware import MicrositeConfiguration @@ -86,7 +87,6 @@ AUDIT_LOG = logging.getLogger("audit") Article = namedtuple('Article', 'title url author image deck publication publish_date') ReverifyInfo = namedtuple('ReverifyInfo', 'course_id course_name course_number date status display') # pylint: disable=C0103 - def csrf_token(context): """A csrf token that can be included in a form.""" csrf_token = context.get('csrf_token', '') @@ -1345,12 +1345,23 @@ def password_reset(request): if request.method != "POST": raise Http404 + # Add some rate limiting here by re-using the RateLimitMixin as a helper class + limiter = BadRequestRateLimiter() + if limiter.is_rated_limit_exceeded(request): + AUDIT_LOG.warning("Rate limit exceeded in password_reset") + return HttpResponseForbidden() + form = PasswordResetFormNoActive(request.POST) if form.is_valid(): form.save(use_https=request.is_secure(), from_email=settings.DEFAULT_FROM_EMAIL, request=request, domain_override=request.get_host()) + else: + # bad user? tick the rate limiter counter + AUDIT_LOG.info("Bad password_reset user passed in.") + limiter.tick_bad_request_counter(request) + return JsonResponse({ 'success': True, 'value': render_to_string('registration/password_reset_done.html', {}), diff --git a/common/djangoapps/util/bad_request_rate_limiter.py b/common/djangoapps/util/bad_request_rate_limiter.py new file mode 100644 index 0000000000000000000000000000000000000000..fe596cc84904ed167529f9405a1336c7e26bfa88 --- /dev/null +++ b/common/djangoapps/util/bad_request_rate_limiter.py @@ -0,0 +1,23 @@ +""" +A utility class which wraps the RateLimitMixin 3rd party class to do bad request counting +which can be used for rate limiting +""" +from ratelimitbackend.backends import RateLimitMixin + +class BadRequestRateLimiter(RateLimitMixin): + """ + Use the 3rd party RateLimitMixin to help do rate limiting on the Password Reset flows + """ + + def is_rated_limit_exceeded(self, request): + """ + Returns if the client has been rated limited + """ + counts = self.get_counters(request) + return sum(counts.values()) >= self.requests + + def tick_bad_request_counter(self, request): + """ + Ticks any counters used to compute when rate limt has been reached + """ + self.cache_incr(self.get_cache_key(request))