"""SAML Assertion generator""" from hashlib import sha256 from types import GeneratorType from django.http import HttpRequest from lxml import etree # nosec from lxml.etree import Element, SubElement # nosec from signxml import XMLSigner, XMLVerifier, strip_pem_header from structlog import get_logger from passbook.core.exceptions import PropertyMappingExpressionException from passbook.lib.utils.time import timedelta_from_string from passbook.providers.saml.models import SAMLPropertyMapping, SAMLProvider from passbook.providers.saml.processors.request_parser import AuthNRequest from passbook.providers.saml.utils import get_random_id from passbook.providers.saml.utils.time import get_time_string from passbook.sources.saml.exceptions import UnsupportedNameIDFormat from passbook.sources.saml.processors.constants import ( NS_MAP, NS_SAML_ASSERTION, NS_SAML_PROTOCOL, NS_SIGNATURE, SAML_NAME_ID_FORMAT_EMAIL, SAML_NAME_ID_FORMAT_PERSISTENT, SAML_NAME_ID_FORMAT_TRANSIENT, SAML_NAME_ID_FORMAT_X509, ) LOGGER = get_logger() class AssertionProcessor: """Generate a SAML Response from an AuthNRequest""" provider: SAMLProvider http_request: HttpRequest auth_n_request: AuthNRequest _issue_instant: str _assertion_id: str _valid_not_before: str _valid_not_on_or_after: str def __init__( self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest ): self.provider = provider self.http_request = request self.auth_n_request = auth_n_request self._issue_instant = get_time_string() self._assertion_id = get_random_id() self._valid_not_before = get_time_string( timedelta_from_string(self.provider.assertion_valid_not_before) ) self._valid_not_on_or_after = get_time_string( timedelta_from_string(self.provider.assertion_valid_not_on_or_after) ) def get_attributes(self) -> Element: """Get AttributeStatement Element with Attributes from Property Mappings.""" # https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions attribute_statement = Element(f"{{{NS_SAML_ASSERTION}}}AttributeStatement") for mapping in self.provider.property_mappings.all().select_subclasses(): if not isinstance(mapping, SAMLPropertyMapping): continue try: mapping: SAMLPropertyMapping value = mapping.evaluate( user=self.http_request.user, request=self.http_request, provider=self.provider, ) if value is None: continue attribute = Element(f"{{{NS_SAML_ASSERTION}}}Attribute") attribute.attrib["FriendlyName"] = mapping.friendly_name attribute.attrib["Name"] = mapping.saml_name if not isinstance(value, (list, GeneratorType)): value = [value] for value_item in value: attribute_value = SubElement( attribute, f"{{{NS_SAML_ASSERTION}}}AttributeValue" ) if not isinstance(value_item, str): value_item = str(value_item) attribute_value.text = value_item attribute_statement.append(attribute) except PropertyMappingExpressionException as exc: LOGGER.warning(exc) continue return attribute_statement def get_issuer(self) -> Element: """Get Issuer Element""" issuer = Element(f"{{{NS_SAML_ASSERTION}}}Issuer", nsmap=NS_MAP) issuer.text = self.provider.issuer return issuer def get_assertion_auth_n_statement(self) -> Element: """Generate AuthnStatement with AuthnContext and ContextClassRef Elements.""" auth_n_statement = Element(f"{{{NS_SAML_ASSERTION}}}AuthnStatement") auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before auth_n_statement.attrib["SessionIndex"] = self._assertion_id auth_n_context = SubElement( auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext" ) auth_n_context_class_ref = SubElement( auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef" ) auth_n_context_class_ref.text = ( "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport" ) return auth_n_statement def get_assertion_conditions(self) -> Element: """Generate Conditions with AudienceRestriction and Audience Elements.""" conditions = Element(f"{{{NS_SAML_ASSERTION}}}Conditions") conditions.attrib["NotBefore"] = self._valid_not_before conditions.attrib["NotOnOrAfter"] = self._valid_not_on_or_after audience_restriction = SubElement( conditions, f"{{{NS_SAML_ASSERTION}}}AudienceRestriction" ) audience = SubElement(audience_restriction, f"{{{NS_SAML_ASSERTION}}}Audience") audience.text = self.provider.audience return conditions def get_name_id(self) -> Element: """Get NameID Element""" name_id = Element(f"{{{NS_SAML_ASSERTION}}}NameID") name_id.attrib["Format"] = self.auth_n_request.name_id_policy if name_id.attrib["Format"] == SAML_NAME_ID_FORMAT_EMAIL: name_id.text = self.http_request.user.email return name_id if name_id.attrib["Format"] == SAML_NAME_ID_FORMAT_PERSISTENT: name_id.text = self.http_request.user.username return name_id if name_id.attrib["Format"] == SAML_NAME_ID_FORMAT_X509: # This attribute is statically set by the LDAP source name_id.text = self.http_request.user.attributes.get( "distinguishedName", "" ) return name_id if name_id.attrib["Format"] == SAML_NAME_ID_FORMAT_TRANSIENT: # This attribute is statically set by the LDAP source session_key: str = self.http_request.user.session.session_key name_id.text = sha256(session_key.encode()).hexdigest() return name_id raise UnsupportedNameIDFormat( f"Assertion contains NameID with unsupported format {name_id.attrib['Format']}." ) def get_assertion_subject(self) -> Element: """Generate Subject Element with NameID and SubjectConfirmation Objects""" subject = Element(f"{{{NS_SAML_ASSERTION}}}Subject") subject.append(self.get_name_id()) subject_confirmation = SubElement( subject, f"{{{NS_SAML_ASSERTION}}}SubjectConfirmation" ) subject_confirmation.attrib["Method"] = "urn:oasis:names:tc:SAML:2.0:cm:bearer" subject_confirmation_data = SubElement( subject_confirmation, f"{{{NS_SAML_ASSERTION}}}SubjectConfirmationData" ) if self.auth_n_request.id: subject_confirmation_data.attrib["InResponseTo"] = self.auth_n_request.id subject_confirmation_data.attrib["NotOnOrAfter"] = self._valid_not_on_or_after subject_confirmation_data.attrib["Recipient"] = self.provider.acs_url return subject def get_assertion(self) -> Element: """Generate Main Assertion Element""" assertion = Element(f"{{{NS_SAML_ASSERTION}}}Assertion", nsmap=NS_MAP) assertion.attrib["Version"] = "2.0" assertion.attrib["ID"] = self._assertion_id assertion.attrib["IssueInstant"] = self._issue_instant assertion.append(self.get_issuer()) if self.provider.signing_kp: # We need a placeholder signature as SAML requires the signature to be between # Issuer and subject signature_placeholder = SubElement( assertion, f"{{{NS_SIGNATURE}}}Signature", nsmap=NS_MAP ) signature_placeholder.attrib["Id"] = "placeholder" assertion.append(self.get_assertion_subject()) assertion.append(self.get_assertion_conditions()) assertion.append(self.get_assertion_auth_n_statement()) assertion.append(self.get_attributes()) return assertion def get_response(self) -> Element: """Generate Root response element""" response = Element(f"{{{NS_SAML_PROTOCOL}}}Response", nsmap=NS_MAP) response.attrib["Version"] = "2.0" response.attrib["IssueInstant"] = self._issue_instant response.attrib["Destination"] = self.provider.acs_url response.attrib["ID"] = get_random_id() if self.auth_n_request.id: response.attrib["InResponseTo"] = self.auth_n_request.id response.append(self.get_issuer()) status = SubElement(response, f"{{{NS_SAML_PROTOCOL}}}Status") status_code = SubElement(status, f"{{{NS_SAML_PROTOCOL}}}StatusCode") status_code.attrib["Value"] = "urn:oasis:names:tc:SAML:2.0:status:Success" response.append(self.get_assertion()) return response def build_response(self) -> str: """Build string XML Response and sign if signing is enabled.""" root_response = self.get_response() if self.provider.signing_kp: signer = XMLSigner( c14n_algorithm="http://www.w3.org/2001/10/xml-exc-c14n#", signature_algorithm=self.provider.signature_algorithm, digest_algorithm=self.provider.digest_algorithm, ) x509_data = strip_pem_header( self.provider.signing_kp.certificate_data ).replace("\n", "") signed = signer.sign( root_response, key=self.provider.signing_kp.private_key, cert=[x509_data], reference_uri=self._assertion_id, ) XMLVerifier().verify(signed, x509_cert=x509_data) return etree.tostring(signed).decode("utf-8") # nosec return etree.tostring(root_response).decode("utf-8") # nosec