95 lines
3.4 KiB
Python
95 lines
3.4 KiB
Python
|
"""SAML Service Provider Metadata Processor"""
|
||
|
from typing import Iterator, Optional
|
||
|
|
||
|
from defusedxml import ElementTree
|
||
|
from django.http import HttpRequest
|
||
|
from lxml.etree import Element, SubElement # nosec
|
||
|
from signxml.util import strip_pem_header
|
||
|
|
||
|
from passbook.sources.saml.models import SAMLSource
|
||
|
from passbook.sources.saml.processors.constants import (
|
||
|
NS_MAP,
|
||
|
NS_SAML_METADATA,
|
||
|
NS_SIGNATURE,
|
||
|
SAML_BINDING_POST,
|
||
|
SAML_NAME_ID_FORMAT_EMAIL,
|
||
|
SAML_NAME_ID_FORMAT_PRESISTENT,
|
||
|
SAML_NAME_ID_FORMAT_TRANSIENT,
|
||
|
SAML_NAME_ID_FORMAT_WINDOWS,
|
||
|
SAML_NAME_ID_FORMAT_X509,
|
||
|
)
|
||
|
|
||
|
|
||
|
class MetadataProcessor:
    """SAML Service Provider Metadata Processor

    Builds the SAML 2.0 metadata XML (an ``EntityDescriptor`` containing one
    ``SPSSODescriptor``) that describes a :class:`SAMLSource` acting as a
    Service Provider.
    """

    # Source whose settings (issuer, signing keypair, ACS URL) are rendered.
    source: SAMLSource
    # Current request; handed to the source to build issuer / absolute URLs.
    http_request: HttpRequest

    def __init__(self, source: SAMLSource, request: HttpRequest):
        self.source = source
        self.http_request = request

    def get_signing_key_descriptor(self) -> Optional[Element]:
        """Get Signing KeyDescriptor, if enabled for the source

        Returns:
            A ``KeyDescriptor`` element with ``use="signing"`` wrapping
            ``KeyInfo/X509Data/X509Certificate`` for the source's signing
            keypair, or ``None`` when ``source.signing_kp`` is not set.
        """
        if not self.source.signing_kp:
            return None

        key_descriptor = Element(f"{{{NS_SAML_METADATA}}}KeyDescriptor")
        key_descriptor.attrib["use"] = "signing"
        key_info = SubElement(key_descriptor, f"{{{NS_SIGNATURE}}}KeyInfo")
        x509_data = SubElement(key_info, f"{{{NS_SIGNATURE}}}X509Data")
        x509_certificate = SubElement(
            x509_data, f"{{{NS_SIGNATURE}}}X509Certificate"
        )
        # Assumes certificate_data is PEM text: strip the PEM header/footer
        # and every line break (CR and LF) so only the encoded body remains.
        x509_certificate.text = strip_pem_header(
            self.source.signing_kp.certificate_data.replace("\r", "")
        ).replace("\n", "")
        return key_descriptor

    def get_name_id_formats(self) -> Iterator[Element]:
        """Get compatible NameID Formats

        Yields one fresh ``NameIDFormat`` element per supported format, in a
        fixed order: email, persistent, X509, Windows, transient.
        """
        formats = (
            SAML_NAME_ID_FORMAT_EMAIL,
            SAML_NAME_ID_FORMAT_PRESISTENT,
            SAML_NAME_ID_FORMAT_X509,
            SAML_NAME_ID_FORMAT_WINDOWS,
            SAML_NAME_ID_FORMAT_TRANSIENT,
        )
        for name_id_format in formats:
            element = Element(f"{{{NS_SAML_METADATA}}}NameIDFormat")
            element.text = name_id_format
            yield element

    def build_entity_descriptor(self) -> str:
        """Build full EntityDescriptor

        Returns:
            The serialized metadata document as a ``str``. The
            ``entityID`` comes from ``source.get_issuer(request)``. The
            single POST-binding ``AssertionConsumerService`` points at
            ``source.build_full_url(request)``.
        """
        # Serialize with lxml itself: the stdlib serializer (re-exported by
        # defusedxml) ignores lxml's nsmap and would emit ns0/ns1 prefixes.
        from lxml.etree import tostring  # nosec

        entity_descriptor = Element(
            f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP
        )
        entity_descriptor.attrib["entityID"] = self.source.get_issuer(self.http_request)

        sp_sso_descriptor = SubElement(
            entity_descriptor, f"{{{NS_SAML_METADATA}}}SPSSODescriptor"
        )
        sp_sso_descriptor.attrib[
            "protocolSupportEnumeration"
        ] = "urn:oasis:names:tc:SAML:2.0:protocol"

        signing_descriptor = self.get_signing_key_descriptor()
        # lxml elements are truthy only when they have children (and warn
        # about it), so test identity against None instead.
        if signing_descriptor is not None:
            sp_sso_descriptor.append(signing_descriptor)

        for name_id_format in self.get_name_id_formats():
            sp_sso_descriptor.append(name_id_format)

        assertion_consumer_service = SubElement(
            sp_sso_descriptor, f"{{{NS_SAML_METADATA}}}AssertionConsumerService"
        )
        # lxml only accepts string attribute values; use the XML lexical
        # forms for xs:boolean / xs:unsignedShort.
        assertion_consumer_service.attrib["isDefault"] = "true"
        assertion_consumer_service.attrib["index"] = "0"
        assertion_consumer_service.attrib["Binding"] = SAML_BINDING_POST
        assertion_consumer_service.attrib["Location"] = self.source.build_full_url(
            self.http_request
        )

        return tostring(entity_descriptor).decode()
|