84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
|
"""SAML Request Parse/builder"""
|
||
|
from typing import TYPE_CHECKING, Optional
|
||
|
|
||
|
from defusedxml import ElementTree
|
||
|
from signxml import XMLVerifier
|
||
|
|
||
|
from passbook.channels.out_samlv2.saml.constants import (
|
||
|
NS_SAML_ASSERTION,
|
||
|
NS_SAML_PROTOCOL,
|
||
|
SAML_ATTRIB_ACS_URL,
|
||
|
SAML_ATTRIB_DESTINATION,
|
||
|
SAML_ATTRIB_ID,
|
||
|
SAML_ATTRIB_ISSUE_INSTANT,
|
||
|
SAML_ATTRIB_PROTOCOL_BINDING,
|
||
|
)
|
||
|
from passbook.channels.out_samlv2.saml.utils import decode_base64_and_inflate
|
||
|
from passbook.crypto.models import CertificateKeyPair
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from xml.etree.ElementTree import Element # nosec
|
||
|
|
||
|
|
||
|
# pylint: disable=too-many-instance-attributes
|
||
|
class SAMLRequest:
    """SAML Request data class, parses raw encoded data, checks signature and more.

    Instances are normally created through `SAMLRequest.parse`, which decodes the
    raw request, validates the root element and copies the request attributes
    onto the instance. `verify_signature` can then be used to check the signature
    against a certificate.
    """

    # Standard XML-DSig namespace; used to detect an enveloped <ds:Signature>
    # child element on the request root.
    _NS_XML_DSIG = "http://www.w3.org/2000/09/xmldsig#"

    # Parsed root element of the request; None until parse() has run.
    _root: Optional["Element"]

    acs_url: str
    destination: str
    id: str
    issue_instant: str
    protocol_binding: str

    issuer: str

    # True when a detached signature was supplied or the document carries an
    # enveloped signature element.
    is_signed: bool
    # Detached signature handed to parse(), if any.
    _detached_signature: Optional[str]

    def __init__(self):
        # Every attribute gets a defined default so that instances are always
        # fully initialised, even when parse() does not receive a detached
        # signature (previously verify_signature() raised AttributeError then).
        self._root = None
        self.acs_url = ""
        self.destination = ""
        # pylint: disable=invalid-name
        self.id = ""
        self.issue_instant = ""
        self.protocol_binding = ""
        self.issuer = ""
        self.is_signed = False
        self._detached_signature = None

    @staticmethod
    def parse(raw: str, detached_signature: Optional[str] = None) -> "SAMLRequest":
        """Parse SAML request from raw string, which can be base64 encoded and deflated.

        Optionally accepts a detached_signature, as from a REDIRECT request.

        Raises `ValueError` if the root element is not a SAML protocol
        AuthnRequest or if the request has no Issuer element, and `KeyError` if
        one of the expected attributes is missing from the root element.
        """
        decoded_xml = decode_base64_and_inflate(raw)
        root = ElementTree.fromstring(decoded_xml)

        # Verify the root element's tag before reading anything else from it
        expected_tag = f"{{{NS_SAML_PROTOCOL}}}AuthnRequest"
        if root.tag != expected_tag:
            raise ValueError(
                f"Invalid root tag, got '{root.tag}', expected '{expected_tag}'."
            )

        # Fail with a descriptive error instead of an AttributeError on None
        issuer_element = root.find(f"{{{NS_SAML_ASSERTION}}}Issuer")
        if issuer_element is None:
            raise ValueError("SAML request does not contain an Issuer element.")

        req = SAMLRequest()
        req._root = root  # pylint: disable=protected-access
        req.acs_url = root.attrib[SAML_ATTRIB_ACS_URL]
        req.destination = root.attrib[SAML_ATTRIB_DESTINATION]
        req.id = root.attrib[SAML_ATTRIB_ID]
        req.issue_instant = root.attrib[SAML_ATTRIB_ISSUE_INSTANT]
        req.protocol_binding = root.attrib[SAML_ATTRIB_PROTOCOL_BINDING]
        req.issuer = issuer_element.text

        # Check if this Request is signed, either through a detached signature
        # or through an enveloped <ds:Signature> child of the root element.
        # pylint: disable=protected-access
        embedded_signature = root.find(f"{{{SAMLRequest._NS_XML_DSIG}}}Signature")
        if detached_signature:
            req._detached_signature = detached_signature
        req.is_signed = bool(detached_signature) or embedded_signature is not None
        return req

    def verify_signature(self, keypair: "CertificateKeyPair"):
        """Verify signature of SAML Request.

        Uses the detached signature when one was passed to `parse`, otherwise
        the parsed request document itself, and checks it against the
        certificate data of `keypair`.

        Raises `cryptography.exceptions.InvalidSignature` on validation failure."""
        verifier = XMLVerifier()
        # NOTE(review): the detached signature is handed to the XML verifier as
        # the data to verify - confirm this matches how callers supply it.
        # An explicit conditional is used rather than `or`, since XML elements
        # without children evaluate as falsy.
        if self._detached_signature:
            signed_data = self._detached_signature
        else:
            signed_data = self._root
        verifier.verify(signed_data, x509_cert=keypair.certificate_data)
|