diff --git a/common/lib/xmodule/xmodule/tests/test_error_module.py b/common/lib/xmodule/xmodule/tests/test_error_module.py
index ce2f68754c70bd8e58e5acc60c3cee494e48d56e..341bca4d05c5ba43d6bceb01f3555ed445157af6 100644
--- a/common/lib/xmodule/xmodule/tests/test_error_module.py
+++ b/common/lib/xmodule/xmodule/tests/test_error_module.py
@@ -3,10 +3,13 @@ Tests for ErrorModule and NonStaffErrorModule
 """
 import unittest
 from xmodule.tests import get_test_system
-import xmodule.error_module as error_module
+from xmodule.error_module import ErrorDescriptor, ErrorModule, NonStaffErrorDescriptor
 from xmodule.modulestore import Location
-from xmodule.x_module import XModuleDescriptor
-from mock import MagicMock
+from xmodule.x_module import XModuleDescriptor, XModule
+from mock import MagicMock, Mock, patch
+from xblock.runtime import Runtime, UsageStore
+from xblock.field_data import FieldData
+from xblock.fields import ScopeIds
 
 
 class SetupTestErrorModules():
@@ -27,9 +30,9 @@ class TestErrorModule(unittest.TestCase, SetupTestErrorModules):
         SetupTestErrorModules.setUp(self)
 
     def test_error_module_xml_rendering(self):
-        descriptor = error_module.ErrorDescriptor.from_xml(
+        descriptor = ErrorDescriptor.from_xml(
             self.valid_xml, self.system, self.org, self.course, self.error_msg)
-        self.assertIsInstance(descriptor, error_module.ErrorDescriptor)
+        self.assertIsInstance(descriptor, ErrorDescriptor)
         descriptor.xmodule_runtime = self.system
         context_repr = self.system.render(descriptor, 'student_view').content
         self.assertIn(self.error_msg, context_repr)
@@ -41,9 +44,9 @@ class TestErrorModule(unittest.TestCase, SetupTestErrorModules):
                                location=self.location,
                                _field_data=self.valid_xml)
 
-        error_descriptor = error_module.ErrorDescriptor.from_descriptor(
+        error_descriptor = ErrorDescriptor.from_descriptor(
             descriptor, self.error_msg)
-        self.assertIsInstance(error_descriptor, error_module.ErrorDescriptor)
+        self.assertIsInstance(error_descriptor, ErrorDescriptor)
         error_descriptor.xmodule_runtime = self.system
         context_repr = self.system.render(error_descriptor, 'student_view').content
         self.assertIn(self.error_msg, context_repr)
@@ -58,12 +61,12 @@ class TestNonStaffErrorModule(unittest.TestCase, SetupTestErrorModules):
         SetupTestErrorModules.setUp(self)
 
     def test_non_staff_error_module_create(self):
-        descriptor = error_module.NonStaffErrorDescriptor.from_xml(
+        descriptor = NonStaffErrorDescriptor.from_xml(
             self.valid_xml, self.system, self.org, self.course)
-        self.assertIsInstance(descriptor, error_module.NonStaffErrorDescriptor)
+        self.assertIsInstance(descriptor, NonStaffErrorDescriptor)
 
     def test_from_xml_render(self):
-        descriptor = error_module.NonStaffErrorDescriptor.from_xml(
+        descriptor = NonStaffErrorDescriptor.from_xml(
             self.valid_xml, self.system, self.org, self.course)
         descriptor.xmodule_runtime = self.system
         context_repr = self.system.render(descriptor, 'student_view').content
@@ -76,10 +79,66 @@ class TestNonStaffErrorModule(unittest.TestCase, SetupTestErrorModules):
                                location=self.location,
                                _field_data=self.valid_xml)
 
-        error_descriptor = error_module.NonStaffErrorDescriptor.from_descriptor(
+        error_descriptor = NonStaffErrorDescriptor.from_descriptor(
             descriptor, self.error_msg)
-        self.assertIsInstance(error_descriptor, error_module.ErrorDescriptor)
+        self.assertIsInstance(error_descriptor, ErrorDescriptor)
         error_descriptor.xmodule_runtime = self.system
         context_repr = self.system.render(error_descriptor, 'student_view').content
         self.assertNotIn(self.error_msg, context_repr)
         self.assertNotIn(str(descriptor), context_repr)
+
+
+class BrokenModule(XModule):
+    def __init__(self, *args, **kwargs):
+        super(BrokenModule, self).__init__(*args, **kwargs)
+        raise Exception("This is a broken xmodule")
+
+
+class BrokenDescriptor(XModuleDescriptor):
+    module_class = BrokenModule
+
+
+class TestException(Exception):
+    """An exception type to use to verify raises in tests"""
+    pass
+
+
+class TestErrorModuleConstruction(unittest.TestCase):
+    """
+    Test that error module construction happens correctly
+    """
+
+    def setUp(self):
+        field_data = Mock(spec=FieldData)
+        self.descriptor = BrokenDescriptor(
+            Runtime(Mock(spec=UsageStore), field_data),
+            field_data,
+            ScopeIds(None, None, None, 'i4x://org/course/broken/name')
+        )
+        self.descriptor.xmodule_runtime = Runtime(Mock(spec=UsageStore), field_data)
+        self.descriptor.xmodule_runtime.error_descriptor_class = ErrorDescriptor
+        self.descriptor.xmodule_runtime.xmodule_instance = None
+
+    def test_broken_module(self):
+        """
+        Test that when an XModule throws an error during __init__, we
+        get an ErrorModule back from XModuleDescriptor._xmodule
+        """
+        module = self.descriptor._xmodule
+        self.assertIsInstance(module, ErrorModule)
+
+    @patch.object(ErrorDescriptor, '__init__', Mock(side_effect=TestException))
+    def test_broken_error_descriptor(self):
+        """
+        Test that a broken error descriptor doesn't cause an infinite loop
+        """
+        with self.assertRaises(TestException):
+            module = self.descriptor._xmodule
+
+    @patch.object(ErrorModule, '__init__', Mock(side_effect=TestException))
+    def test_broken_error_module(self):
+        """
+        Test that a broken error module doesn't cause an infinite loop
+        """
+        with self.assertRaises(TestException):
+            module = self.descriptor._xmodule
diff --git a/common/lib/xmodule/xmodule/x_module.py b/common/lib/xmodule/xmodule/x_module.py
index cd5a5ef3d9018a392ef59fd701eeaea1d4b195ea..d39095586c238df8255645e9ecb5bf44ea63aa31 100644
--- a/common/lib/xmodule/xmodule/x_module.py
+++ b/common/lib/xmodule/xmodule/x_module.py
@@ -771,6 +771,10 @@ class XModuleDescriptor(XModuleMixin, HTMLSnippet, ResourceTemplates, XBlock):
                 )
                 self.xmodule_runtime.xmodule_instance.save()
             except Exception:  # pylint: disable=broad-except
+                # xmodule_instance is set by the XModule.__init__. If we had an error after that,
+                # we need to clean it out so that we can set up the ErrorModule instead
+                self.xmodule_runtime.xmodule_instance = None
+
                 if isinstance(self, self.xmodule_runtime.error_descriptor_class):
                     log.exception('Error creating an ErrorModule from an ErrorDescriptor')
                     raise