118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
"""SAML AuthNRequest Parser and dataclass"""
|
|
from base64 import b64decode
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
from urllib.parse import quote_plus
|
|
|
|
from cryptography.exceptions import InvalidSignature
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.asymmetric import padding
|
|
from defusedxml import ElementTree
|
|
from signxml import XMLVerifier
|
|
from structlog import get_logger
|
|
|
|
from passbook.providers.saml.exceptions import CannotHandleAssertion
|
|
from passbook.providers.saml.models import SAMLProvider
|
|
from passbook.providers.saml.utils.encoding import decode_base64_and_inflate
|
|
from passbook.sources.saml.processors.constants import (
|
|
NS_SAML_PROTOCOL,
|
|
SAML_NAME_ID_FORMAT_EMAIL,
|
|
)
|
|
|
|
LOGGER = get_logger()
|
|
|
|
|
|
@dataclass
class AuthNRequest:
    """Values extracted from a SAML AuthnRequest.

    Every field carries a default, so ``AuthNRequest()`` with no arguments
    is a valid (empty) request.
    """

    # Value of the ``ID`` attribute of the request's root element. The short
    # name is kept on purpose, so pylint's naming check is silenced for it.
    # pylint: disable=invalid-name
    id: Optional[str] = None

    # RelayState value handed in together with the request, if any.
    relay_state: Optional[str] = None

    # NameID format for this request; the e-mail format is the default.
    name_id_policy: str = SAML_NAME_ID_FORMAT_EMAIL
|
|
|
|
|
|
class AuthNRequestParser:
    """Validate and parse SAML AuthnRequests for a single provider.

    ``parse`` handles requests carrying an enveloped XML signature,
    ``parse_detached`` handles requests whose signature is passed separately
    (``Signature`` / ``SigAlg`` parameters), and ``idp_initiated`` builds an
    empty request. Every validation failure is reported as
    ``CannotHandleAssertion``.
    """

    provider: SAMLProvider

    # Signature algorithm URIs accepted for detached signatures, mapped to
    # the digest each one uses. All of them are RSASSA-PKCS1-v1_5 schemes.
    _SIG_ALG_HASHES = {
        # SHA-1 is kept because the original implementation accepted it.
        "http://www.w3.org/2000/09/xmldsig#rsa-sha1": hashes.SHA1,  # nosec
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256": hashes.SHA256,
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384": hashes.SHA384,
        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512": hashes.SHA512,
    }

    def __init__(self, provider: SAMLProvider):
        """Bind the parser to the provider whose settings requests are checked against."""
        self.provider = provider

    def _parse_xml(self, decoded_xml: str, relay_state: Optional[str]) -> AuthNRequest:
        """Build an ``AuthNRequest`` from already decoded request XML.

        Args:
            decoded_xml: The AuthnRequest document as text.
            relay_state: RelayState value to carry over into the result.

        Returns:
            The parsed request, with ``name_id_policy`` taken from the
            request's NameIDPolicy element when it specifies a Format.

        Raises:
            CannotHandleAssertion: If the request names an ACS URL that
                differs from the provider's, or has no ``ID`` attribute.
        """
        root = ElementTree.fromstring(decoded_xml)

        # AssertionConsumerServiceURL is optional in an AuthnRequest; when
        # the request omits it, the provider's configured URL applies.
        request_acs_url = root.attrib.get(
            "AssertionConsumerServiceURL", self.provider.acs_url
        )
        if self.provider.acs_url.lower() != request_acs_url.lower():
            msg = (
                f"ACS URL of {request_acs_url} doesn't match Provider "
                f"ACS URL of {self.provider.acs_url}."
            )
            LOGGER.info(msg)
            raise CannotHandleAssertion(msg)

        request_id = root.attrib.get("ID")
        if request_id is None:
            # Report a malformed request through the domain exception
            # instead of leaking a KeyError to the caller.
            raise CannotHandleAssertion("AuthnRequest is missing the ID attribute")
        auth_n_request = AuthNRequest(id=request_id, relay_state=relay_state)

        # ElementTree qualified names are "{namespace}tag" (no colon).
        # Compare with None: an element without children is falsy.
        name_id_policy = root.find(f"{{{NS_SAML_PROTOCOL}}}NameIDPolicy")
        if name_id_policy is not None:
            # Format is optional; keep the dataclass default when absent.
            auth_n_request.name_id_policy = name_id_policy.attrib.get(
                "Format", auth_n_request.name_id_policy
            )

        return auth_n_request

    def parse(self, saml_request: str, relay_state: Optional[str]) -> AuthNRequest:
        """Validate and parse raw request with enveloped signature.

        The signature is only checked when the provider has a signing
        keypair configured.

        Raises:
            CannotHandleAssertion: If signature verification fails or the
                request content is rejected by ``_parse_xml``.
        """
        decoded_xml = decode_base64_and_inflate(saml_request)

        if self.provider.signing_kp:
            try:
                XMLVerifier().verify(
                    decoded_xml, x509_cert=self.provider.signing_kp.certificate_data
                )
            except InvalidSignature as exc:
                raise CannotHandleAssertion("Failed to verify signature") from exc

        return self._parse_xml(decoded_xml, relay_state)

    def parse_detached(
        self,
        saml_request: str,
        relay_state: Optional[str],
        signature: Optional[str] = None,
        sig_alg: Optional[str] = None,
    ) -> AuthNRequest:
        """Validate and parse raw request with detached signature.

        Args:
            saml_request: The encoded SAMLRequest value.
            relay_state: The RelayState value, or None when not sent.
            signature: Base64 encoded detached signature, if any.
            sig_alg: Signature algorithm URI belonging to ``signature``.

        Raises:
            CannotHandleAssertion: If a supplied signature cannot be
                verified or the request content is rejected.
        """
        decoded_xml = decode_base64_and_inflate(saml_request)

        # Like ``parse``, only verify when the provider has a keypair; this
        # also avoids dereferencing a missing keypair.
        # NOTE(review): a request sent without Signature/SigAlg is accepted
        # unverified even when a keypair is configured - confirm intended.
        if signature and sig_alg and self.provider.signing_kp:
            self._verify_detached_signature(
                saml_request, relay_state, signature, sig_alg
            )

        return self._parse_xml(decoded_xml, relay_state)

    def _verify_detached_signature(
        self,
        saml_request: str,
        relay_state: Optional[str],
        signature: str,
        sig_alg: str,
    ) -> None:
        """Check a detached signature over the rebuilt query string.

        Raises:
            CannotHandleAssertion: If the algorithm is not supported, the
                signature is not valid base64, or verification fails.
        """
        hash_cls = self._SIG_ALG_HASHES.get(sig_alg)
        if hash_cls is None:
            raise CannotHandleAssertion(f"Unsupported signature algorithm {sig_alg}")

        # The signed string is SAMLRequest[&RelayState]&SigAlg, in that
        # order, with every value URL-encoded (SigAlg included).
        signed_parts = [f"SAMLRequest={quote_plus(saml_request)}"]
        if relay_state is not None:
            signed_parts.append(f"RelayState={quote_plus(relay_state)}")
        signed_parts.append(f"SigAlg={quote_plus(sig_alg)}")
        querystring = "&".join(signed_parts)

        try:
            raw_signature = b64decode(signature)
        except ValueError as exc:  # binascii.Error is a ValueError
            raise CannotHandleAssertion("Failed to verify signature") from exc

        public_key = self.provider.signing_kp.private_key.public_key()
        try:
            # The rsa-sha* algorithms are PKCS#1 v1.5 signatures, not PSS.
            public_key.verify(
                raw_signature,
                querystring.encode(),
                padding.PKCS1v15(),
                hash_cls(),
            )
        except InvalidSignature as exc:
            raise CannotHandleAssertion("Failed to verify signature") from exc

    def idp_initiated(self) -> AuthNRequest:
        """Create IdP Initiated AuthNRequest"""
        return AuthNRequest()
|