diff --git a/passbook/providers/saml/processors/request_parser.py b/passbook/providers/saml/processors/request_parser.py index b37d0ffdb..34ff8364b 100644 --- a/passbook/providers/saml/processors/request_parser.py +++ b/passbook/providers/saml/processors/request_parser.py @@ -14,6 +14,10 @@ from structlog import get_logger from passbook.providers.saml.exceptions import CannotHandleAssertion from passbook.providers.saml.models import SAMLProvider from passbook.providers.saml.utils.encoding import decode_base64_and_inflate +from passbook.sources.saml.processors.constants import ( + NS_SAML_PROTOCOL, + SAML_NAME_ID_FORMAT_EMAIL, +) LOGGER = get_logger() @@ -27,6 +31,8 @@ class AuthNRequest: relay_state: Optional[str] = None + name_id_policy: str = SAML_NAME_ID_FORMAT_EMAIL + class AuthNRequestParser: """AuthNRequest Parser""" @@ -37,7 +43,6 @@ class AuthNRequestParser: self.provider = provider def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest: - root = ElementTree.fromstring(decoded_xml) request_acs_url = root.attrib["AssertionConsumerServiceURL"] @@ -51,6 +56,13 @@ class AuthNRequestParser: raise CannotHandleAssertion(msg) auth_n_request = AuthNRequest(id=root.attrib["ID"], relay_state=relay_state) + + # Check if AuthnRequest has a NameID Policy object + name_id_policies = root.findall(f"{{{NS_SAML_PROTOCOL}}}:NameIDPolicy") + if len(name_id_policies) > 0: + name_id_policy = name_id_policies[0] + auth_n_request.name_id_policy = name_id_policy.attrib["Format"] + return auth_n_request def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest: