diff --git a/passbook/sources/saml/api.py b/passbook/sources/saml/api.py index b117e56df..d89014d4a 100644 --- a/passbook/sources/saml/api.py +++ b/passbook/sources/saml/api.py @@ -13,7 +13,7 @@ class SAMLSourceSerializer(ModelSerializer): model = SAMLSource fields = [ "pk", - "entity_id", + "issuer", "idp_url", "idp_logout_url", "auto_logout", diff --git a/passbook/sources/saml/exceptions.py b/passbook/sources/saml/exceptions.py new file mode 100644 index 000000000..7f91e2bad --- /dev/null +++ b/passbook/sources/saml/exceptions.py @@ -0,0 +1,9 @@ +"""passbook saml source exceptions""" + + +class MissingSAMLResponse(Exception): + """Exception raised when request does not contain SAML Response.""" + + +class UnsupportedNameIDFormat(Exception): + """Exception raised when SAML Response contains NameID Format not supported.""" diff --git a/passbook/sources/saml/forms.py b/passbook/sources/saml/forms.py index 14e350e78..ff6ed84ee 100644 --- a/passbook/sources/saml/forms.py +++ b/passbook/sources/saml/forms.py @@ -22,7 +22,7 @@ class SAMLSourceForm(forms.ModelForm): model = SAMLSource fields = SOURCE_FORM_FIELDS + [ - "entity_id", + "issuer", "idp_url", "idp_logout_url", "auto_logout", @@ -31,7 +31,7 @@ class SAMLSourceForm(forms.ModelForm): widgets = { "name": forms.TextInput(), "policies": FilteredSelectMultiple(_("policies"), False), - "entity_id": forms.TextInput(), + "issuer": forms.TextInput(), "idp_url": forms.TextInput(), "idp_logout_url": forms.TextInput(), } diff --git a/passbook/sources/saml/processors/__init__.py b/passbook/sources/saml/processors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/passbook/sources/saml/processors/base.py b/passbook/sources/saml/processors/base.py new file mode 100644 index 000000000..a094b32bf --- /dev/null +++ b/passbook/sources/saml/processors/base.py @@ -0,0 +1,85 @@ +"""passbook saml source processor""" +from typing import Optional +from xml.etree.ElementTree import Element + +from defusedxml import ElementTree +from django.http import HttpRequest +from signxml import XMLVerifier +from structlog import get_logger + +from passbook.core.models import User +from passbook.providers.saml.utils.encoding import decode_base64_and_inflate +from passbook.sources.saml.exceptions import ( + MissingSAMLResponse, + UnsupportedNameIDFormat, +) +from passbook.sources.saml.models import SAMLSource + +LOGGER = get_logger() + + +class Processor: + """SAML Response Processor""" + + _source: SAMLSource + + _root: Element + _root_xml: str + + def __init__(self, source: SAMLSource): + self._source = source + + def parse(self, request: HttpRequest): + """Check if `request` contains SAML Response data, parse and validate it.""" + # First off, check if we have any SAML Data at all. + raw_response = request.POST.get("SAMLResponse", None) + if not raw_response: + raise MissingSAMLResponse("Request does not contain 'SAMLResponse'") + # relay_state = request.POST.get('RelayState', None) + # Check if response is compressed, b64 decode it + self._root_xml = response = decode_base64_and_inflate(raw_response) + self._root = ElementTree.fromstring(self._root_xml) + # Verify signed XML + self._verify_signed() + + def _verify_signed(self): + """Verify SAML Response's Signature""" + verifier = XMLVerifier() + verifier.verify(self._root_xml, x509_cert=self._source.signing_cert) + + def _get_email(self) -> Optional[str]: + """ + Returns the email out of the response. + + At present, response must pass the email address as the Subject, eg.: + + + email@example.com + """ + assertion = self._root.find("{urn:oasis:names:tc:SAML:2.0:assertion}Assertion") + subject = assertion.find("{urn:oasis:names:tc:SAML:2.0:assertion}Subject") + name_id = subject.find("{urn:oasis:names:tc:SAML:2.0:assertion}NameID") + name_id_format = name_id.attrib["Format"] + if name_id_format != "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress": + raise UnsupportedNameIDFormat( + f"Assertion contains NameID with unsupported format {name_id_format}." + ) + return name_id.text + + def get_user(self) -> User: + """ + Gets info out of the response and locally logs in this user. + May create a local user account first. + Returns the user object that was created. + """ + email = self._get_email() + try: + user = User.objects.get(email=email) + except User.DoesNotExist: + user = User.objects.create_user(username=email, email=email) + # TODO: Property Mappings + user.set_unusable_password() + user.save() + return user diff --git a/passbook/sources/saml/utils.py b/passbook/sources/saml/utils.py index a9ae43bbf..139c556b4 100644 --- a/passbook/sources/saml/utils.py +++ b/passbook/sources/saml/utils.py @@ -2,7 +2,6 @@ from django.http import HttpRequest from django.shortcuts import reverse -from passbook.core.models import User from passbook.sources.saml.models import SAMLSource @@ -19,65 +18,3 @@ def build_full_url(view: str, request: HttpRequest, source: SAMLSource) -> str: return request.build_absolute_uri( reverse(f"passbook_sources_saml:{view}", kwargs={"source_slug": source.slug}) ) - - -def _get_email_from_response(root): - """ - Returns the email out of the response. - - At present, response must pass the email address as the Subject, eg.: - - - email@example.com - """ - assertion = root.find("{urn:oasis:names:tc:SAML:2.0:assertion}Assertion") - subject = assertion.find("{urn:oasis:names:tc:SAML:2.0:assertion}Subject") - name_id = subject.find("{urn:oasis:names:tc:SAML:2.0:assertion}NameID") - return name_id.text - - -def _get_attributes_from_response(root): - """ - Returns the SAML Attributes (if any) that are present in the response. - - NOTE: Technically, attribute values could be any XML structure. - But for now, just assume a single string value. - """ - flat_attributes = {} - assertion = root.find("{urn:oasis:names:tc:SAML:2.0:assertion}Assertion") - attributes = assertion.find( - "{urn:oasis:names:tc:SAML:2.0:assertion}AttributeStatement" - ) - for attribute in attributes.getchildren(): - name = attribute.attrib.get("Name") - children = attribute.getchildren() - if not children: - # Ignore empty-valued attributes. (I think these are not allowed.) - continue - if len(children) == 1: - # See NOTE: - flat_attributes[name] = children[0].text - else: - # It has multiple values. - for child in children: - # See NOTE: - flat_attributes.setdefault(name, []).append(child.text) - return flat_attributes - - -def _get_user_from_response(root): - """ - Gets info out of the response and locally logs in this user. - May create a local user account first. - Returns the user object that was created. - """ - email = _get_email_from_response(root) - try: - user = User.objects.get(email=email) - except User.DoesNotExist: - user = User.objects.create_user(username=email, email=email) - user.set_unusable_password() - user.save() - return user diff --git a/passbook/sources/saml/views.py b/passbook/sources/saml/views.py index 38915b177..9f6516afa 100644 --- a/passbook/sources/saml/views.py +++ b/passbook/sources/saml/views.py @@ -1,7 +1,4 @@ """saml sp views""" -import base64 - -from defusedxml import ElementTree from django.contrib.auth import login, logout from django.http import Http404, HttpRequest, HttpResponse from django.shortcuts import get_object_or_404, redirect, render, reverse @@ -10,15 +7,17 @@ from django.views import View from django.views.decorators.csrf import csrf_exempt from signxml.util import strip_pem_header +from passbook.lib.views import bad_request_message from passbook.providers.saml.utils import get_random_id, render_xml from passbook.providers.saml.utils.encoding import nice64 from passbook.providers.saml.utils.time import get_time_string -from passbook.sources.saml.models import SAMLSource -from passbook.sources.saml.utils import ( - _get_user_from_response, - build_full_url, - get_issuer, +from passbook.sources.saml.exceptions import ( + MissingSAMLResponse, + UnsupportedNameIDFormat, ) +from passbook.sources.saml.models import SAMLSource +from passbook.sources.saml.processors.base import Processor +from passbook.sources.saml.utils import build_full_url, get_issuer from passbook.sources.saml.xml_render import get_authnrequest_xml @@ -62,14 +61,18 @@ class ACSView(View): source: SAMLSource = get_object_or_404(SAMLSource, slug=source_slug) if not source.enabled: raise Http404 - # sso_session = request.POST.get('RelayState', None) - data = request.POST.get("SAMLResponse", None) - response = base64.b64decode(data) - root = ElementTree.fromstring(response) - user = _get_user_from_response(root) - # attributes = _get_attributes_from_response(root) - login(request, user, backend="django.contrib.auth.backends.ModelBackend") - return redirect(reverse("passbook_core:overview")) + processor = Processor(source) + try: + processor.parse(request) + except MissingSAMLResponse as exc: + return bad_request_message(request, str(exc)) + + try: + user = processor.get_user() + login(request, user, backend="django.contrib.auth.backends.ModelBackend") + return redirect(reverse("passbook_core:overview")) + except UnsupportedNameIDFormat as exc: + return bad_request_message(request, str(exc)) class SLOView(View):