From 00226bf3c068924ad2ffbb8af252071e6afe2a78 Mon Sep 17 00:00:00 2001 From: Braden MacDonald <braden@opencraft.com> Date: Mon, 15 Jun 2015 01:54:38 -0700 Subject: [PATCH] Asynchronous metadata fetching using celery beat - PR 8518 --- .gitignore | 3 + common/djangoapps/third_party_auth/admin.py | 14 +- .../management/commands/saml.py | 139 ++-------------- common/djangoapps/third_party_auth/tasks.py | 157 ++++++++++++++++++ .../tests/specs/test_testshib.py | 15 +- .../third_party_auth/tests/test_views.py | 2 +- lms/envs/aws.py | 8 + pavelib/servers.py | 4 +- 8 files changed, 204 insertions(+), 138 deletions(-) create mode 100644 common/djangoapps/third_party_auth/tasks.py diff --git a/.gitignore b/.gitignore index 861ceec67f9..1c2845e1756 100644 --- a/.gitignore +++ b/.gitignore @@ -91,6 +91,9 @@ logs chromedriver.log ghostdriver.log +### Celery artifacts ### +celerybeat-schedule + ### Unknown artifacts database.sqlite courseware/static/js/mathjax/* diff --git a/common/djangoapps/third_party_auth/admin.py b/common/djangoapps/third_party_auth/admin.py index d36ca9dd412..8495ef3a2b2 100644 --- a/common/djangoapps/third_party_auth/admin.py +++ b/common/djangoapps/third_party_auth/admin.py @@ -7,6 +7,7 @@ from django.contrib import admin from config_models.admin import ConfigurationModelAdmin, KeyedConfigurationModelAdmin from .models import OAuth2ProviderConfig, SAMLProviderConfig, SAMLConfiguration, SAMLProviderData +from .tasks import fetch_saml_metadata admin.site.register(OAuth2ProviderConfig, KeyedConfigurationModelAdmin) @@ -29,6 +30,17 @@ class SAMLProviderConfigAdmin(KeyedConfigurationModelAdmin): has_data.short_description = u'Metadata Ready' has_data.boolean = True + def save_model(self, request, obj, form, change): + """ + Post save: Queue an asynchronous metadata fetch to update SAMLProviderData. + We only want to do this for manual edits done using the admin interface. + + Note: This only works if the celery worker and the app worker are using the + same 'configuration' cache. + """ + super(SAMLProviderConfigAdmin, self).save_model(request, obj, form, change) + fetch_saml_metadata.apply_async((), countdown=2) + admin.site.register(SAMLProviderConfig, SAMLProviderConfigAdmin) @@ -54,7 +66,7 @@ admin.site.register(SAMLConfiguration, SAMLConfigurationAdmin) class SAMLProviderDataAdmin(admin.ModelAdmin): - """ Django Admin class for SAMLProviderData """ + """ Django Admin class for SAMLProviderData (Read Only) """ list_display = ('entity_id', 'is_valid', 'fetched_at', 'expires_at', 'sso_url') readonly_fields = ('is_valid', ) diff --git a/common/djangoapps/third_party_auth/management/commands/saml.py b/common/djangoapps/third_party_auth/management/commands/saml.py index ca15bcaf4db..01918157ae0 100644 --- a/common/djangoapps/third_party_auth/management/commands/saml.py +++ b/common/djangoapps/third_party_auth/management/commands/saml.py @@ -2,20 +2,10 @@ """ Management commands for third_party_auth """ -import datetime -import dateutil.parser from django.core.management.base import BaseCommand, CommandError -from lxml import etree -import requests -from onelogin.saml2.utils import OneLogin_Saml2_Utils -from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData - -#pylint: disable=superfluous-parens,no-member - - -class MetadataParseError(Exception): - """ An error occurred while parsing the SAML metadata from an IdP """ - pass +import logging +from third_party_auth.models import SAMLConfiguration +from third_party_auth.tasks import fetch_saml_metadata class Command(BaseCommand): @@ -27,120 +17,21 @@ class Command(BaseCommand): raise CommandError("saml requires one argument: pull") if not SAMLConfiguration.is_enabled(): - self.stdout.write("Warning: SAML support is disabled via SAMLConfiguration.\n") + raise CommandError("SAML support is disabled via SAMLConfiguration.") subcommand = args[0] if subcommand == "pull": - self.cmd_pull() + log_handler = logging.StreamHandler(self.stdout) + log_handler.setLevel(logging.DEBUG) + log = logging.getLogger('third_party_auth.tasks') + log.propagate = False + log.addHandler(log_handler) + num_changed, num_failed, num_total = fetch_saml_metadata() + self.stdout.write( + "\nDone. Fetched {num_total} total. {num_changed} were updated and {num_failed} failed.\n".format( + num_changed=num_changed, num_failed=num_failed, num_total=num_total + ) + ) else: raise CommandError("Unknown argment: {}".format(subcommand)) - - @staticmethod - def tag_name(tag_name): - """ Get the namespaced-qualified name for an XML tag """ - return '{urn:oasis:names:tc:SAML:2.0:metadata}' + tag_name - - def cmd_pull(self): - """ Fetch the metadata for each provider and update the DB """ - # First make a list of all the metadata XML URLs: - url_map = {} - for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True): - config = SAMLProviderConfig.current(idp_slug) - if not config.enabled: - continue - url = config.metadata_source - if url not in url_map: - url_map[url] = [] - if config.entity_id not in url_map[url]: - url_map[url].append(config.entity_id) - # Now fetch the metadata: - for url, entity_ids in url_map.items(): - try: - self.stdout.write("\n→ Fetching {}\n".format(url)) - if not url.lower().startswith('https'): - self.stdout.write("→ WARNING: This URL is not secure! It should use HTTPS.\n") - response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError - response.raise_for_status() # May raise an HTTPError - - try: - parser = etree.XMLParser(remove_comments=True) - xml = etree.fromstring(response.text, parser) - except etree.XMLSyntaxError: - raise - # TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that - - for entity_id in entity_ids: - self.stdout.write("→ Processing IdP with entityID {}\n".format(entity_id)) - public_key, sso_url, expires_at = self._parse_metadata_xml(xml, entity_id) - self._update_data(entity_id, public_key, sso_url, expires_at) - except Exception as err: # pylint: disable=broad-except - self.stderr.write(u"→ ERROR: {}\n\n".format(err.message)) - - @classmethod - def _parse_metadata_xml(cls, xml, entity_id): - """ - Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of - (public_key, sso_url, expires_at) for the specified entityID. - - Raises MetadataParseError if anything is wrong. - """ - if xml.tag == cls.tag_name('EntityDescriptor'): - entity_desc = xml - else: - if xml.tag != cls.tag_name('EntitiesDescriptor'): - raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag)) - entity_desc = xml.find(".//{}[@entityID='{}']".format(cls.tag_name('EntityDescriptor'), entity_id)) - if not entity_desc: - raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id)) - - expires_at = None - if "validUntil" in xml.attrib: - expires_at = dateutil.parser.parse(xml.attrib["validUntil"]) - if "cacheDuration" in xml.attrib: - cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"]) - if expires_at is None or cache_expires < expires_at: - expires_at = cache_expires - - sso_desc = entity_desc.find(cls.tag_name("IDPSSODescriptor")) - if not sso_desc: - raise MetadataParseError("IDPSSODescriptor missing") - if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"): - raise MetadataParseError("This IdP does not support SAML 2.0") - - # Now we just need to get the public_key and sso_url - public_key = sso_desc.findtext("./{}//{}".format( - cls.tag_name("KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate" - )) - if not public_key: - raise MetadataParseError("Public Key missing. Expected an <X509Certificate>") - public_key = public_key.replace(" ", "") - binding_elements = sso_desc.iterfind("./{}".format(cls.tag_name("SingleSignOnService"))) - sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements} - try: - # The only binding supported by python-saml and python-social-auth is HTTP-Redirect: - sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect'] - except KeyError: - raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.") - return public_key, sso_url, expires_at - - def _update_data(self, entity_id, public_key, sso_url, expires_at): - """ - Update/Create the SAMLProviderData for the given entity ID. - """ - data_obj = SAMLProviderData.current(entity_id) - fetched_at = datetime.datetime.now() - if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url): - data_obj.expires_at = expires_at - data_obj.fetched_at = fetched_at - data_obj.save() - self.stdout.write("→ Updated existing SAMLProviderData. Nothing has changed.\n") - else: - SAMLProviderData.objects.create( - entity_id=entity_id, - fetched_at=fetched_at, - expires_at=expires_at, - sso_url=sso_url, - public_key=public_key, - ) - self.stdout.write("→ Created new record for SAMLProviderData\n") diff --git a/common/djangoapps/third_party_auth/tasks.py b/common/djangoapps/third_party_auth/tasks.py new file mode 100644 index 00000000000..7466e113afe --- /dev/null +++ b/common/djangoapps/third_party_auth/tasks.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +""" +Code to manage fetching and storing the metadata of IdPs. +""" +#pylint: disable=no-member +from celery.task import task # pylint: disable=import-error,no-name-in-module +import datetime +import dateutil.parser +import logging +from lxml import etree +import requests +from onelogin.saml2.utils import OneLogin_Saml2_Utils +from third_party_auth.models import SAMLConfiguration, SAMLProviderConfig, SAMLProviderData + +log = logging.getLogger(__name__) + +SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata' # The SAML Metadata XML namespace + + +class MetadataParseError(Exception): + """ An error occurred while parsing the SAML metadata from an IdP """ + pass + + +@task(name='third_party_auth.fetch_saml_metadata') +def fetch_saml_metadata(): + """ + Fetch and store/update the metadata of all IdPs + + This task should be run on a daily basis. + It's OK to run this whether or not SAML is enabled. + + Return value: + tuple(num_changed, num_failed, num_total) + num_changed: Number of providers that are either new or whose metadata has changed + num_failed: Number of providers that could not be updated + num_total: Total number of providers whose metadata was fetched + """ + if not SAMLConfiguration.is_enabled(): + return (0, 0, 0) # Nothing to do until SAML is enabled. + + num_changed, num_failed = 0, 0 + + # First make a list of all the metadata XML URLs: + url_map = {} + for idp_slug in SAMLProviderConfig.key_values('idp_slug', flat=True): + config = SAMLProviderConfig.current(idp_slug) + if not config.enabled: + continue + url = config.metadata_source + if url not in url_map: + url_map[url] = [] + if config.entity_id not in url_map[url]: + url_map[url].append(config.entity_id) + # Now fetch the metadata: + for url, entity_ids in url_map.items(): + try: + log.info("Fetching %s", url) + if not url.lower().startswith('https'): + log.warning("This SAML metadata URL is not secure! It should use HTTPS. (%s)", url) + response = requests.get(url, verify=True) # May raise HTTPError or SSLError or ConnectionError + response.raise_for_status() # May raise an HTTPError + + try: + parser = etree.XMLParser(remove_comments=True) + xml = etree.fromstring(response.text, parser) + except etree.XMLSyntaxError: + raise + # TODO: Can use OneLogin_Saml2_Utils to validate signed XML if anyone is using that + + for entity_id in entity_ids: + log.info(u"Processing IdP with entityID %s", entity_id) + public_key, sso_url, expires_at = _parse_metadata_xml(xml, entity_id) + changed = _update_data(entity_id, public_key, sso_url, expires_at) + if changed: + log.info(u"→ Created new record for SAMLProviderData") + num_changed += 1 + else: + log.info(u"→ Updated existing SAMLProviderData. Nothing has changed.") + except Exception as err: # pylint: disable=broad-except + log.exception(err.message) + num_failed += 1 + return (num_changed, num_failed, len(url_map)) + + +def _parse_metadata_xml(xml, entity_id): + """ + Given an XML document containing SAML 2.0 metadata, parse it and return a tuple of + (public_key, sso_url, expires_at) for the specified entityID. + + Raises MetadataParseError if anything is wrong. + """ + if xml.tag == etree.QName(SAML_XML_NS, 'EntityDescriptor'): + entity_desc = xml + else: + if xml.tag != etree.QName(SAML_XML_NS, 'EntitiesDescriptor'): + raise MetadataParseError("Expected root element to be <EntitiesDescriptor>, not {}".format(xml.tag)) + entity_desc = xml.find( + ".//{}[@entityID='{}']".format(etree.QName(SAML_XML_NS, 'EntityDescriptor'), entity_id) + ) + if not entity_desc: + raise MetadataParseError("Can't find EntityDescriptor for entityID {}".format(entity_id)) + + expires_at = None + if "validUntil" in xml.attrib: + expires_at = dateutil.parser.parse(xml.attrib["validUntil"]) + if "cacheDuration" in xml.attrib: + cache_expires = OneLogin_Saml2_Utils.parse_duration(xml.attrib["cacheDuration"]) + if expires_at is None or cache_expires < expires_at: + expires_at = cache_expires + + sso_desc = entity_desc.find(etree.QName(SAML_XML_NS, "IDPSSODescriptor")) + if not sso_desc: + raise MetadataParseError("IDPSSODescriptor missing") + if 'urn:oasis:names:tc:SAML:2.0:protocol' not in sso_desc.get("protocolSupportEnumeration"): + raise MetadataParseError("This IdP does not support SAML 2.0") + + # Now we just need to get the public_key and sso_url + public_key = sso_desc.findtext("./{}//{}".format( + etree.QName(SAML_XML_NS, "KeyDescriptor"), "{http://www.w3.org/2000/09/xmldsig#}X509Certificate" + )) + if not public_key: + raise MetadataParseError("Public Key missing. Expected an <X509Certificate>") + public_key = public_key.replace(" ", "") + binding_elements = sso_desc.iterfind("./{}".format(etree.QName(SAML_XML_NS, "SingleSignOnService"))) + sso_bindings = {element.get('Binding'): element.get('Location') for element in binding_elements} + try: + # The only binding supported by python-saml and python-social-auth is HTTP-Redirect: + sso_url = sso_bindings['urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect'] + except KeyError: + raise MetadataParseError("Unable to find SSO URL with HTTP-Redirect binding.") + return public_key, sso_url, expires_at + + +def _update_data(entity_id, public_key, sso_url, expires_at): + """ + Update/Create the SAMLProviderData for the given entity ID. + Return value: + False if nothing has changed and existing data's "fetched at" timestamp is just updated. + True if a new record was created. (Either this is a new provider or something changed.) + """ + data_obj = SAMLProviderData.current(entity_id) + fetched_at = datetime.datetime.now() + if data_obj and (data_obj.public_key == public_key and data_obj.sso_url == sso_url): + data_obj.expires_at = expires_at + data_obj.fetched_at = fetched_at + data_obj.save() + return False + else: + SAMLProviderData.objects.create( + entity_id=entity_id, + fetched_at=fetched_at, + expires_at=expires_at, + sso_url=sso_url, + public_key=public_key, + ) + return True diff --git a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py index 689066113f0..be17cf74a88 100644 --- a/common/djangoapps/third_party_auth/tests/specs/test_testshib.py +++ b/common/djangoapps/third_party_auth/tests/specs/test_testshib.py @@ -1,12 +1,11 @@ """ Third_party_auth integration tests using a mock version of the TestShib provider """ -from django.core.management import call_command from django.core.urlresolvers import reverse import httpretty from mock import patch -import StringIO from student.tests.factories import UserFactory +from third_party_auth.tasks import fetch_saml_metadata from third_party_auth.tests import testutil import unittest @@ -209,15 +208,11 @@ class TestShibIntegrationTest(testutil.SAMLTestCase): self.configure_saml_provider(**kwargs) if fetch_metadata: - stdout = StringIO.StringIO() - stderr = StringIO.StringIO() self.assertTrue(httpretty.is_enabled()) - call_command('saml', 'pull', stdout=stdout, stderr=stderr) - stdout = stdout.getvalue().decode('utf-8') - stderr = stderr.getvalue().decode('utf-8') - self.assertEqual(stderr, '') - self.assertIn(u'Fetching {}'.format(TESTSHIB_METADATA_URL), stdout) - self.assertIn(u'Created new record for SAMLProviderData', stdout) + num_changed, num_failed, num_total = fetch_saml_metadata() + self.assertEqual(num_failed, 0) + self.assertEqual(num_changed, 1) + self.assertEqual(num_total, 1) def _fake_testshib_login_and_return(self): """ Mocked: the user logs in to TestShib and then gets redirected back """ diff --git a/common/djangoapps/third_party_auth/tests/test_views.py b/common/djangoapps/third_party_auth/tests/test_views.py index 11659011a52..8e88629801d 100644 --- a/common/djangoapps/third_party_auth/tests/test_views.py +++ b/common/djangoapps/third_party_auth/tests/test_views.py @@ -8,7 +8,7 @@ import unittest from .testutil import AUTH_FEATURE_ENABLED, SAMLTestCase # Define some XML namespaces: -SAML_XML_NS = 'urn:oasis:names:tc:SAML:2.0:metadata' +from third_party_auth.tasks import SAML_XML_NS XMLDSIG_XML_NS = 'http://www.w3.org/2000/09/xmldsig#' diff --git a/lms/envs/aws.py b/lms/envs/aws.py index 9a4ce09cc5e..8582aa7ef9e 100644 --- a/lms/envs/aws.py +++ b/lms/envs/aws.py @@ -16,6 +16,7 @@ Common traits: # and throws spurious errors. Therefore, we disable invalid-name checking. # pylint: disable=invalid-name +import datetime import json from .common import * @@ -107,6 +108,7 @@ CELERY_QUEUES = { if os.environ.get('QUEUE') == 'high_mem': CELERYD_MAX_TASKS_PER_CHILD = 1 +CELERYBEAT_SCHEDULE = {} # For scheduling tasks, entries can be added to this dict ########################## NON-SECURE ENV CONFIG ############################## # Things like server locations, ports, etc. @@ -552,6 +554,12 @@ if FEATURES.get('ENABLE_THIRD_PARTY_AUTH'): # third_party_auth config moved to ConfigurationModels. This is for data migration only: THIRD_PARTY_AUTH_OLD_CONFIG = AUTH_TOKENS.get('THIRD_PARTY_AUTH', None) + if ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24) is not None: + CELERYBEAT_SCHEDULE['refresh-saml-metadata'] = { + 'task': 'third_party_auth.fetch_saml_metadata', + 'schedule': datetime.timedelta(hours=ENV_TOKENS.get('THIRD_PARTY_AUTH_SAML_FETCH_PERIOD_HOURS', 24)), + } + ##### OAUTH2 Provider ############## if FEATURES.get('ENABLE_OAUTH2_PROVIDER'): OAUTH_OIDC_ISSUER = ENV_TOKENS['OAUTH_OIDC_ISSUER'] diff --git a/pavelib/servers.py b/pavelib/servers.py index fd613af5493..8076a0e46e0 100644 --- a/pavelib/servers.py +++ b/pavelib/servers.py @@ -109,7 +109,7 @@ def celery(options): Runs Celery workers. """ settings = getattr(options, 'settings', 'dev_with_worker') - run_process(django_cmd('lms', settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.')) + run_process(django_cmd('lms', settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.')) @task @@ -142,7 +142,7 @@ def run_all_servers(options): run_multi_processes([ django_cmd('lms', settings_lms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['lms'])), django_cmd('studio', settings_cms, 'runserver', '--traceback', '--pythonpath=.', "0.0.0.0:{}".format(DEFAULT_PORT['studio'])), - django_cmd('lms', worker_settings, 'celery', 'worker', '--loglevel=INFO', '--pythonpath=.') + django_cmd('lms', worker_settings, 'celery', 'worker', '--beat', '--loglevel=INFO', '--pythonpath=.') ]) -- GitLab