diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py
index 01918157ae0acf6787c408e9395708dd443bc97d..dd2c6c104915bf24e2343093a134ec0d5837fd46 100644
--- a/common/djangoapps/third_party_auth/management/commands/saml.py
+++ b/common/djangoapps/third_party_auth/management/commands/saml.py
@@ -12,16 +12,14 @@ class Command(BaseCommand):
     """ manage.py commands to manage SAML/Shibboleth SSO """
     help = '''Configure/maintain/update SAML-based SSO'''
 
-    def handle(self, *args, **options):
-        if len(args) != 1:
-            raise CommandError("saml requires one argument: pull")
+    def add_arguments(self, parser):
+        parser.add_argument('--pull', action='store_true', help="Pull updated metadata from external IDPs")
 
+    def handle(self, *args, **options):
         if not SAMLConfiguration.is_enabled():
             raise CommandError("SAML support is disabled via SAMLConfiguration.")
 
-        subcommand = args[0]
-
-        if subcommand == "pull":
+        if options['pull']:
             log_handler = logging.StreamHandler(self.stdout)
             log_handler.setLevel(logging.DEBUG)
             log = logging.getLogger('third_party_auth.tasks')
diff --git a/common/test/acceptance/tests/lms/test_lms_problems.py b/common/test/acceptance/tests/lms/test_lms_problems.py
index aa6e08dec4703c40519fc7c890de030d2b1e3c2a..8a353ea7359d64e706f21b57c53495567b1540f4 100644
--- a/common/test/acceptance/tests/lms/test_lms_problems.py
+++ b/common/test/acceptance/tests/lms/test_lms_problems.py
@@ -10,6 +10,7 @@ from ..helpers import UniqueCourseTest
 from ...pages.studio.auto_auth import AutoAuthPage
 from ...pages.lms.courseware import CoursewarePage
 from ...pages.lms.problem import ProblemPage
+from ...pages.lms.login_and_register import CombinedLoginAndRegisterPage
 from ...fixtures.course import CourseFixture, XBlockFixtureDesc
 from ..helpers import EventsTestMixin
 
@@ -20,6 +21,7 @@ class ProblemsTest(UniqueCourseTest):
     """
     USERNAME = "joe_student"
     EMAIL = "joe@example.com"
+    PASSWORD = "keep it secret; keep it safe."
 
     def setUp(self):
         super(ProblemsTest, self).setUp()
@@ -42,8 +44,14 @@ class ProblemsTest(UniqueCourseTest):
         ).install()
 
         # Auto-auth register for the course.
-        AutoAuthPage(self.browser, username=self.USERNAME, email=self.EMAIL,
-                     course_id=self.course_id, staff=False).visit()
+        AutoAuthPage(
+            self.browser,
+            username=self.USERNAME,
+            email=self.EMAIL,
+            password=self.PASSWORD,
+            course_id=self.course_id,
+            staff=False
+        ).visit()
 
     def get_problem(self):
         """ Subclasses should override this to complete the fixture """
@@ -321,3 +329,85 @@ class ProblemPartialCredit(ProblemsTest):
         problem_page.click_check()
         problem_page.wait_for_status_icon()
         self.assertTrue(problem_page.simpleprob_is_partially_correct())
+
+
+class LogoutDuringAnswering(ProblemsTest):
+    """
+    Tests for the scenario where a user is logged out (their session expires
+    or is revoked) just before they click "check" on a problem.
+    """
+    def get_problem(self):
+        """
+        Create a problem.
+        """
+        xml = dedent("""
+            <problem>
+                <p>The answer is 1</p>
+                <numericalresponse answer="1">
+                    <formulaequationinput label="where are the songs of spring?" />
+                    <responseparam type="tolerance" default="0.01" />
+                </numericalresponse>
+            </problem>
+        """)
+        return XBlockFixtureDesc('problem', 'TEST PROBLEM', data=xml)
+
+    def log_user_out(self):
+        """
+        Log the user out by deleting their session cookie.
+        """
+        self.browser.delete_cookie('sessionid')
+
+    def test_logout_after_click_redirect(self):
+        """
+        1) User goes to a problem page.
+        2) User fills out an answer to the problem.
+        3) User is logged out because their session id is invalidated or removed.
+        4) User clicks "check", and sees a confirmation modal asking them to
+           re-authenticate, since they've just been logged out.
+        5) User clicks "ok".
+        6) User is redirected to the login page.
+        7) User logs in.
+        8) User is redirected back to the problem page they started out on.
+        9) User is able to submit an answer
+        """
+        self.courseware_page.visit()
+        problem_page = ProblemPage(self.browser)
+        self.assertEqual(problem_page.problem_name, 'TEST PROBLEM')
+        problem_page.fill_answer_numerical('1')
+
+        self.log_user_out()
+        with problem_page.handle_alert(confirm=True):
+            problem_page.click_check()
+
+        login_page = CombinedLoginAndRegisterPage(self.browser)
+        login_page.wait_for_page()
+
+        login_page.login(self.EMAIL, self.PASSWORD)
+
+        problem_page.wait_for_page()
+        self.assertEqual(problem_page.problem_name, 'TEST PROBLEM')
+
+        problem_page.fill_answer_numerical('1')
+        problem_page.click_check()
+        self.assertTrue(problem_page.simpleprob_is_correct())
+
+    def test_logout_cancel_no_redirect(self):
+        """
+        1) User goes to a problem page.
+        2) User fills out an answer to the problem.
+        3) User is logged out because their session id is invalidated or removed.
+        4) User clicks "check", and sees a confirmation modal asking them to
+           re-authenticate, since they've just been logged out.
+        5) User clicks "cancel".
+        6) User is not redirected to the login page.
+        """
+        self.courseware_page.visit()
+        problem_page = ProblemPage(self.browser)
+        self.assertEqual(problem_page.problem_name, 'TEST PROBLEM')
+        problem_page.fill_answer_numerical('1')
+        self.log_user_out()
+        with problem_page.handle_alert(confirm=False):
+            problem_page.click_check()
+
+        self.assertTrue(problem_page.is_browser_on_page())
+        self.assertEqual(problem_page.problem_name, 'TEST PROBLEM')
diff --git a/lms/static/js/ajax-error.js b/lms/static/js/ajax-error.js
index 460d2511fe6dc9d1c289d6364830695c2e2c771c..6b29d4c24bb1ed483a2be53b23d994e4f3f7af53 100644
--- a/lms/static/js/ajax-error.js
+++ b/lms/static/js/ajax-error.js
@@ -8,8 +8,8 @@ $(document).ajaxError(function (event, jXHR) {
         );
 
         if (window.confirm(message)) {
-            var currentLocation = window.location.href;
-            window.location.href = '/login?next=' + currentLocation;
+            var currentLocation = window.location.pathname;
+            window.location.href = '/login?next=' + encodeURIComponent(currentLocation);
         };
     }
 });