providers/saml: fix string being passed to lxml

This commit is contained in:
Jens Langhammer 2020-12-30 22:03:01 +01:00
parent d0ee7908ab
commit a9e53cd52a

View file

@ -5,10 +5,12 @@ from typing import Optional
import xmlsec import xmlsec
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate from cryptography.x509 import load_pem_x509_certificate
from defusedxml.lxml import fromstring
from lxml import etree # nosec from lxml import etree # nosec
from structlog import get_logger from structlog import get_logger
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.flows.models import Flow, FlowDesignation
from authentik.providers.saml.models import SAMLBindings, SAMLProvider from authentik.providers.saml.models import SAMLBindings, SAMLProvider
from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER
from authentik.sources.saml.processors.constants import ( from authentik.sources.saml.processors.constants import (
@ -23,6 +25,8 @@ LOGGER = get_logger()
def format_pem_certificate(unformatted_cert: str) -> str: def format_pem_certificate(unformatted_cert: str) -> str:
"""Format single, inline certificate into PEM Format""" """Format single, inline certificate into PEM Format"""
# Ensure that all linebreaks are gone
unformatted_cert = unformatted_cert.replace("\n", "")
chunks, chunk_size = len(unformatted_cert), 64 chunks, chunk_size = len(unformatted_cert), 64
lines = [PEM_HEADER] lines = [PEM_HEADER]
for i in range(0, chunks, chunk_size): for i in range(0, chunks, chunk_size):
@ -104,7 +108,7 @@ class ServiceProviderMetadataParser:
def parse(self, raw_xml: str) -> ServiceProviderMetadata: def parse(self, raw_xml: str) -> ServiceProviderMetadata:
"""Parse raw XML to ServiceProviderMetadata""" """Parse raw XML to ServiceProviderMetadata"""
root = etree.fromstring(raw_xml) # nosec root = fromstring(raw_xml.encode())
entity_id = root.attrib["entityID"] entity_id = root.attrib["entityID"]
sp_sso_descriptors = root.findall(f"{{{NS_SAML_METADATA}}}SPSSODescriptor") sp_sso_descriptors = root.findall(f"{{{NS_SAML_METADATA}}}SPSSODescriptor")