"""SAML AuthNRequest Parser and dataclass""" from base64 import b64decode from dataclasses import dataclass from typing import Optional from urllib.parse import quote_plus from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding from defusedxml import ElementTree from signxml import XMLVerifier from structlog import get_logger from passbook.providers.saml.exceptions import CannotHandleAssertion from passbook.providers.saml.models import SAMLProvider from passbook.providers.saml.utils.encoding import decode_base64_and_inflate from passbook.sources.saml.processors.constants import ( NS_SAML_PROTOCOL, SAML_NAME_ID_FORMAT_EMAIL, ) LOGGER = get_logger() @dataclass class AuthNRequest: """AuthNRequest Dataclass""" # pylint: disable=invalid-name id: Optional[str] = None relay_state: Optional[str] = None name_id_policy: str = SAML_NAME_ID_FORMAT_EMAIL class AuthNRequestParser: """AuthNRequest Parser""" provider: SAMLProvider def __init__(self, provider: SAMLProvider): self.provider = provider def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest: root = ElementTree.fromstring(decoded_xml) request_acs_url = root.attrib["AssertionConsumerServiceURL"] if self.provider.acs_url != request_acs_url: msg = ( f"ACS URL of {request_acs_url} doesn't match Provider " f"ACS URL of {self.provider.acs_url}." ) LOGGER.info(msg) raise CannotHandleAssertion(msg) auth_n_request = AuthNRequest(id=root.attrib["ID"], relay_state=relay_state) # Check if AuthnRequest has a NameID Policy object name_id_policies = root.findall(f"{{{NS_SAML_PROTOCOL}}}:NameIDPolicy") if len(name_id_policies) > 0: name_id_policy = name_id_policies[0] auth_n_request.name_id_policy = name_id_policy.attrib["Format"] return auth_n_request def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest: """Validate and parse raw request with enveloped signautre.""" decoded_xml = decode_base64_and_inflate(saml_request) if self.provider.signing_kp: try: XMLVerifier().verify( decoded_xml, x509_cert=self.provider.signing_kp.certificate_data ) except InvalidSignature as exc: raise CannotHandleAssertion("Failed to verify signature") from exc return self._parse_xml(decoded_xml, relay_state) def parse_detached( self, saml_request: str, relay_state: Optional[str], signature: Optional[str] = None, sig_alg: Optional[str] = None, ) -> AuthNRequest: """Validate and parse raw request with detached signature""" decoded_xml = decode_base64_and_inflate(saml_request) if signature and sig_alg: # if sig_alg == "http://www.w3.org/2000/09/xmldsig#rsa-sha1": sig_hash = hashes.SHA1() # nosec querystring = f"SAMLRequest={quote_plus(saml_request)}&" if relay_state is not None: querystring += f"RelayState={quote_plus(relay_state)}&" querystring += f"SigAlg={sig_alg}" public_key = self.provider.signing_kp.private_key.public_key() try: public_key.verify( b64decode(signature), querystring.encode(), padding.PSS( mgf=padding.MGF1(sig_hash), salt_length=padding.PSS.MAX_LENGTH ), sig_hash, ) except InvalidSignature as exc: raise CannotHandleAssertion("Failed to verify signature") from exc return self._parse_xml(decoded_xml, relay_state) def idp_initiated(self) -> AuthNRequest: """Create IdP Initiated AuthNRequest""" return AuthNRequest()