providers/saml: parse NameID Policy from AuthnRequest

This commit is contained in:
Jens Langhammer 2020-07-12 17:05:48 +02:00
parent 06f73512df
commit f8e5383ba2
1 changed files with 13 additions and 1 deletions

View File

@ -14,6 +14,10 @@ from structlog import get_logger
from passbook.providers.saml.exceptions import CannotHandleAssertion from passbook.providers.saml.exceptions import CannotHandleAssertion
from passbook.providers.saml.models import SAMLProvider from passbook.providers.saml.models import SAMLProvider
from passbook.providers.saml.utils.encoding import decode_base64_and_inflate from passbook.providers.saml.utils.encoding import decode_base64_and_inflate
from passbook.sources.saml.processors.constants import (
NS_SAML_PROTOCOL,
SAML_NAME_ID_FORMAT_EMAIL,
)
LOGGER = get_logger() LOGGER = get_logger()
@ -27,6 +31,8 @@ class AuthNRequest:
relay_state: Optional[str] = None relay_state: Optional[str] = None
name_id_policy: str = SAML_NAME_ID_FORMAT_EMAIL
class AuthNRequestParser: class AuthNRequestParser:
"""AuthNRequest Parser""" """AuthNRequest Parser"""
@ -37,7 +43,6 @@ class AuthNRequestParser:
self.provider = provider self.provider = provider
def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest: def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest:
root = ElementTree.fromstring(decoded_xml) root = ElementTree.fromstring(decoded_xml)
request_acs_url = root.attrib["AssertionConsumerServiceURL"] request_acs_url = root.attrib["AssertionConsumerServiceURL"]
@ -51,6 +56,13 @@ class AuthNRequestParser:
raise CannotHandleAssertion(msg) raise CannotHandleAssertion(msg)
auth_n_request = AuthNRequest(id=root.attrib["ID"], relay_state=relay_state) auth_n_request = AuthNRequest(id=root.attrib["ID"], relay_state=relay_state)
# Check if AuthnRequest has a NameID Policy object
name_id_policies = root.findall(f"{{{NS_SAML_PROTOCOL}}}:NameIDPolicy")
if len(name_id_policies) > 0:
name_id_policy = name_id_policies[0]
auth_n_request.name_id_policy = name_id_policy.attrib["Format"]
return auth_n_request return auth_n_request
def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest: def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest: