From dcd74e6dd072009c8289fe2a7ac9152de6f6def3 Mon Sep 17 00:00:00 2001
From: Calen Pennington <calen.pennington@gmail.com>
Date: Thu, 28 Jun 2012 16:27:46 -0400
Subject: [PATCH] Make abtests work, using the new abtest xml format

---
 common/lib/xmodule/abtest_module.py | 90 +++++++++++++++++++----------
 common/lib/xmodule/exceptions.py    |  2 +
 common/lib/xmodule/setup.py         |  1 +
 common/lib/xmodule/x_module.py      |  1 +
 4 files changed, 62 insertions(+), 32 deletions(-)
 create mode 100644 common/lib/xmodule/exceptions.py

diff --git a/common/lib/xmodule/abtest_module.py b/common/lib/xmodule/abtest_module.py
index e14117eb08c..3bd268184a3 100644
--- a/common/lib/xmodule/abtest_module.py
+++ b/common/lib/xmodule/abtest_module.py
@@ -2,11 +2,10 @@ import json
 import random
 from lxml import etree
 
-from x_module import XModule, XModuleDescriptor
-
-
-class ModuleDescriptor(XModuleDescriptor):
-    pass
+from xmodule.x_module import XModule
+from xmodule.raw_module import RawDescriptor
+from xmodule.xml_module import XmlDescriptor
+from xmodule.exceptions import InvalidDefinitionError
 
 
 def group_from_value(groups, v):
@@ -25,7 +24,7 @@ def group_from_value(groups, v):
     return g
 
 
-class Module(XModule):
+class ABTestModule(XModule):
     """
     Implements an A/B test with an aribtrary number of competing groups
 
@@ -37,20 +36,14 @@ class Module(XModule):
     </abtest>
     """
 
-    def __init__(self, system, xml, item_id, instance_state=None, shared_state=None):
-        XModule.__init__(self, system, xml, item_id, instance_state, shared_state)
-        self.xmltree = etree.fromstring(xml)
+    def __init__(self, system, location, definition, instance_state=None, shared_state=None, **kwargs):
+        XModule.__init__(self, system, location, definition, instance_state, shared_state, **kwargs)
 
-        target_groups = self.xmltree.findall('group')
+        target_groups = self.definition['data'].keys()
         if shared_state is None:
-            target_values = [
-                (elem.get('name'), float(elem.get('portion')))
-                for elem in target_groups
-            ]
-            default_value = 1 - sum(val for (_, val) in target_values)
 
             self.group = group_from_value(
-                target_values + [(None, default_value)],
+                self.definition['data']['group_portions'],
                 random.uniform(0, 1)
             )
         else:
@@ -69,24 +62,57 @@ class Module(XModule):
                 self.group = shared_state['group']
 
     def get_shared_state(self):
+        print self.group
         return json.dumps({'group': self.group})
 
-    def _xml_children(self):
-        group = None
-        if self.group is None:
-            group = self.xmltree.find('default')
-        else:
-            for candidate_group in self.xmltree.find('group'):
-                if self.group == candidate_group.get('name'):
-                    group = candidate_group
-                    break
+    def displayable_items(self):
+        return [self.system.get_module(child)
+                for child
+                in self.definition['data']['group_content'][self.group]]
+
+
+class ABTestDescriptor(RawDescriptor, XmlDescriptor):
+    module_class = ABTestModule
+
+    def __init__(self, system, definition=None, **kwargs):
+        kwargs['shared_state_key'] = definition['data']['experiment']
+        RawDescriptor.__init__(self, system, definition, **kwargs)
+
+    @classmethod
+    def definition_from_xml(cls, xml_object, system):
+        experiment = xml_object.get('experiment')
+
+        if experiment is None:
+            raise InvalidDefinitionError("ABTests must specify an experiment. Not found in:\n{xml}".format(xml=etree.tostring(xml_object, pretty_print=True)))
+
+        definition = {
+            'data': {
+                'experiment': experiment,
+                'group_portions': [],
+                'group_content': {None: []},
+            },
+            'children': []}
+        for group in xml_object:
+            if group.tag == 'default':
+                name = None
+            else:
+                name = group.get('name')
+                definition['data']['group_portions'].append(
+                    (name, float(group.get('portion', 0)))
+                )
+
+            child_content_urls = [
+                system.process_xml(etree.tostring(child)).url
+                for child in group
+            ]
+
+            definition['data']['group_content'][name] = child_content_urls
+            definition['children'].extend(child_content_urls)
 
-        if group is None:
-            return []
-        return list(group)
+        default_portion = 1 - sum(portion for (name, portion) in definition['data']['group_portions'])
+        if default_portion < 0:
+            raise InvalidDefinitionError("ABTest portions must add up to less than or equal to 1")
 
-    def get_children(self):
-        return [self.module_from_xml(child) for child in self._xml_children()]
+        definition['data']['group_portions'].append((None, default_portion))
 
-    def get_html(self):
-        return '\n'.join(child.get_html() for child in self.get_children())
+        return definition
diff --git a/common/lib/xmodule/exceptions.py b/common/lib/xmodule/exceptions.py
new file mode 100644
index 00000000000..9a9258d6008
--- /dev/null
+++ b/common/lib/xmodule/exceptions.py
@@ -0,0 +1,2 @@
+class InvalidDefinitionError(Exception):
+    pass
diff --git a/common/lib/xmodule/setup.py b/common/lib/xmodule/setup.py
index 93eddc5c7cf..e45e6654c23 100644
--- a/common/lib/xmodule/setup.py
+++ b/common/lib/xmodule/setup.py
@@ -13,6 +13,7 @@ setup(
     # for a description of entry_points
     entry_points={
         'xmodule.v1': [
+            "abtest = xmodule.abtest_module:ABTestDescriptor",
             "book = xmodule.translation_module:TranslateCustomTagDescriptor",
             "chapter = xmodule.seq_module:SequenceDescriptor",
             "course = xmodule.seq_module:SequenceDescriptor",
diff --git a/common/lib/xmodule/x_module.py b/common/lib/xmodule/x_module.py
index 8ee3df38ffd..d8559c9bb7a 100644
--- a/common/lib/xmodule/x_module.py
+++ b/common/lib/xmodule/x_module.py
@@ -295,6 +295,7 @@ class XModuleDescriptor(Plugin):
         self.display_name = kwargs.get('display_name')
         self.format = kwargs.get('format')
         self.graded = kwargs.get('graded', False)
+        self.shared_state_key = kwargs.get('shared_state_key')
 
         # For now, we represent goals as a list of strings, but this
         # is one of the things that we are going to be iterating on heavily
-- 
GitLab