providers/saml: big cleanup, simplify base processor

add New fields for
 - assertion_valid_not_before
 - assertion_valid_not_on_or_after
 - session_valid_not_on_or_after
allow flexible time durations for these fields
fall back to Provider's ACS if none is specified in AuthNRequest
This commit is contained in:
Jens Langhammer 2020-02-14 15:19:48 +01:00
parent 2be026dd44
commit e36d7928e4
19 changed files with 495 additions and 392 deletions

View File

@ -1,336 +0,0 @@
"""Basic SAML Processor"""
import time
import uuid
from defusedxml import ElementTree
from structlog import get_logger
from passbook.providers.saml import exceptions, utils, xml_render
MINUTES = 60
HOURS = 60 * MINUTES
def get_random_id():
"""Random hex id"""
# It is very important that these random IDs NOT start with a number.
random_id = "_" + uuid.uuid4().hex
return random_id
def get_time_string(delta=0):
"""Get Data formatted in SAML format"""
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(time.time() + delta))
# Design note: I've tried to make this easy to sub-class and override
# just the bits you need to override. I've made use of object properties,
# so that your sub-classes have access to all information: use wisely.
# Formatting note: These methods are alphabetized.
# pylint: disable=too-many-instance-attributes
class Processor:
"""Base SAML 2.0 AuthnRequest to Response Processor.
Sub-classes should provide Service Provider-specific functionality."""
is_idp_initiated = False
_audience = ""
_assertion_params = None
_assertion_xml = None
_assertion_id = None
_django_request = None
_relay_state = None
_request = None
_request_id = None
_request_xml = None
_request_params = None
_response_id = None
_response_xml = None
_response_params = None
_saml_request = None
_saml_response = None
_session_index = None
_subject = None
_subject_format = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
_system_params = {}
@property
def dotted_path(self):
"""Return a dotted path to this class"""
return "{module}.{class_name}".format(
module=self.__module__, class_name=self.__class__.__name__
)
def __init__(self, remote):
self.name = remote.name
self._remote = remote
self._logger = get_logger()
self._system_params["ISSUER"] = self._remote.issuer
self._logger.debug("processor configured")
def _build_assertion(self):
"""Builds _assertion_params."""
self._determine_assertion_id()
self._determine_audience()
self._determine_subject()
self._determine_session_index()
self._assertion_params = {
"ASSERTION_ID": self._assertion_id,
"ASSERTION_SIGNATURE": "", # it's unsigned
"AUDIENCE": self._audience,
"AUTH_INSTANT": get_time_string(),
"ISSUE_INSTANT": get_time_string(),
"NOT_BEFORE": get_time_string(-1 * HOURS), # TODO: Make these settings.
"NOT_ON_OR_AFTER": get_time_string(86400 * MINUTES),
"SESSION_INDEX": self._session_index,
"SESSION_NOT_ON_OR_AFTER": get_time_string(8 * HOURS),
"SP_NAME_QUALIFIER": self._audience,
"SUBJECT": self._subject,
"SUBJECT_FORMAT": self._subject_format,
}
self._assertion_params.update(self._system_params)
self._assertion_params.update(self._request_params)
def _build_response(self):
"""Builds _response_params."""
self._determine_response_id()
self._response_params = {
"ASSERTION": self._assertion_xml,
"ISSUE_INSTANT": get_time_string(),
"RESPONSE_ID": self._response_id,
"RESPONSE_SIGNATURE": "", # initially unsigned
}
self._response_params.update(self._system_params)
self._response_params.update(self._request_params)
def _decode_request(self):
"""Decodes _request_xml from _saml_request."""
self._request_xml = utils.decode_base64_and_inflate(self._saml_request).decode(
"utf-8"
)
self._logger.debug("SAML request decoded")
def _determine_assertion_id(self):
"""Determines the _assertion_id."""
self._assertion_id = get_random_id()
def _determine_audience(self):
"""Determines the _audience."""
self._audience = self._remote.audience
self._logger.info("determined audience")
def _determine_response_id(self):
"""Determines _response_id."""
self._response_id = get_random_id()
def _determine_session_index(self):
self._session_index = self._django_request.session.session_key
def _determine_subject(self):
"""Determines _subject and _subject_type for Assertion Subject."""
self._subject = self._django_request.user.email
def _encode_response(self):
"""Encodes _response_xml to _encoded_xml."""
self._saml_response = utils.nice64(str.encode(self._response_xml))
def _extract_saml_request(self):
"""Retrieves the _saml_request AuthnRequest from the _django_request."""
self._saml_request = self._django_request.session["SAMLRequest"]
self._relay_state = self._django_request.session["RelayState"]
def _format_assertion(self):
"""Formats _assertion_params as _assertion_xml."""
# https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions
self._assertion_params["ATTRIBUTES"] = [
{
"FriendlyName": "eduPersonPrincipalName",
"Name": "urn:oid:1.3.6.1.4.1.5923.1.1.1.6",
"Value": self._django_request.user.email,
},
{
"FriendlyName": "cn",
"Name": "urn:oid:2.5.4.3",
"Value": self._django_request.user.name,
},
{
"FriendlyName": "mail",
"Name": "urn:oid:0.9.2342.19200300.100.1.3",
"Value": self._django_request.user.email,
},
{
"FriendlyName": "displayName",
"Name": "urn:oid:2.16.840.1.113730.3.1.241",
"Value": self._django_request.user.username,
},
{
"FriendlyName": "uid",
"Name": "urn:oid:0.9.2342.19200300.100.1.1",
"Value": self._django_request.user.pk,
},
]
from passbook.providers.saml.models import SAMLPropertyMapping
for mapping in self._remote.property_mappings.all().select_subclasses():
if isinstance(mapping, SAMLPropertyMapping):
mapping_payload = {
"Name": mapping.saml_name,
"ValueArray": [],
"FriendlyName": mapping.friendly_name,
}
for value in mapping.values:
mapping_payload["ValueArray"].append(
value.format(
user=self._django_request.user, request=self._django_request
)
)
self._assertion_params["ATTRIBUTES"].append(mapping_payload)
self._assertion_xml = xml_render.get_assertion_xml(
"saml/xml/assertions/generic.xml", self._assertion_params, signed=True
)
def _format_response(self):
"""Formats _response_params as _response_xml."""
assertion_id = self._assertion_params["ASSERTION_ID"]
self._response_xml = xml_render.get_response_xml(
self._response_params, saml_provider=self._remote, assertion_id=assertion_id
)
def _get_django_response_params(self):
"""Returns a dictionary of parameters for the response template."""
return {
"acs_url": self._request_params["ACS_URL"],
"saml_response": self._saml_response,
"relay_state": self._relay_state,
"autosubmit": self._remote.application.skip_authorization,
}
def _parse_request(self):
"""Parses various parameters from _request_xml into _request_params."""
# Minimal test to verify that it's not binarily encoded still:
if not str(self._request_xml.strip()).startswith("<"):
raise Exception(
"RequestXML is not valid XML; "
"it may need to be decoded or decompressed."
)
root = ElementTree.fromstring(self._request_xml)
params = {}
params["ACS_URL"] = root.attrib["AssertionConsumerServiceURL"]
params["REQUEST_ID"] = root.attrib["ID"]
params["DESTINATION"] = root.attrib.get("Destination", "")
params["PROVIDER_NAME"] = root.attrib.get("ProviderName", "")
self._request_params = params
def _reset(self, django_request, sp_config=None):
"""Initialize (and reset) object properties, so we don't risk carrying
over anything from the last authentication.
If provided, use sp_config throughout; otherwise, it will be set in
_validate_request(). """
self._assertion_params = sp_config
self._assertion_xml = sp_config
self._assertion_id = sp_config
self._django_request = django_request
self._relay_state = sp_config
self._request = sp_config
self._request_id = sp_config
self._request_xml = sp_config
self._request_params = sp_config
self._response_id = sp_config
self._response_xml = sp_config
self._response_params = sp_config
self._saml_request = sp_config
self._saml_response = sp_config
self._session_index = sp_config
self._subject = sp_config
self._subject_format = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
self._system_params = {"ISSUER": self._remote.issuer}
def _validate_request(self):
"""
Validates the SAML request against the SP configuration of this
processor. Sub-classes should override this and raise a
`CannotHandleAssertion` exception if the validation fails.
Raises:
CannotHandleAssertion: if the ACS URL specified in the SAML request
doesn't match the one specified in the processor config.
"""
request_acs_url = self._request_params["ACS_URL"]
if self._remote.acs_url != request_acs_url:
msg = "couldn't find ACS url '{}' in SAML2IDP_REMOTES " "setting.".format(
request_acs_url
)
self._logger.info(msg)
raise exceptions.CannotHandleAssertion(msg)
def _validate_user(self):
"""Validates the User. Sub-classes should override this and
throw an CannotHandleAssertion Exception if the validation does not succeed."""
def can_handle(self, request):
"""Returns true if this processor can handle this request."""
self._reset(request)
# Read the request.
try:
self._extract_saml_request()
except Exception as exc:
msg = "can't find SAML request in user session: %s" % exc
self._logger.info(msg)
raise exceptions.CannotHandleAssertion(msg)
try:
self._decode_request()
except Exception as exc:
msg = "can't decode SAML request: %s" % exc
self._logger.info(msg)
raise exceptions.CannotHandleAssertion(msg)
try:
self._parse_request()
except Exception as exc:
msg = "can't parse SAML request: %s" % exc
self._logger.info(msg)
raise exceptions.CannotHandleAssertion(msg)
self._validate_request()
return True
def generate_response(self):
"""Processes request and returns template variables suitable for a response."""
# Build the assertion and response.
# Only call can_handle if SP initiated Request, otherwise we have no Request
if not self.is_idp_initiated:
self.can_handle(self._django_request)
self._validate_user()
self._build_assertion()
self._format_assertion()
self._build_response()
self._format_response()
self._encode_response()
# Return proper template params.
return self._get_django_response_params()
def init_deep_link(self, request, url):
"""Initialize this Processor to make an IdP-initiated call to the SP's
deep-linked URL."""
self._reset(request)
acs_url = self._remote.acs_url
# NOTE: The following request params are made up. Some are blank,
# because they comes over in the AuthnRequest, but we don't have an
# AuthnRequest in this case:
# - Destination: Should be this IdP's SSO endpoint URL. Not used in the response?
# - ProviderName: According to the spec, this is optional.
self._request_params = {
"ACS_URL": acs_url,
"DESTINATION": "",
"PROVIDER_NAME": "",
}
self._relay_state = url

View File

@ -10,7 +10,7 @@ from passbook.providers.saml.models import (
SAMLProvider, SAMLProvider,
get_provider_choices, get_provider_choices,
) )
from passbook.providers.saml.utils import CertificateBuilder from passbook.providers.saml.utils.cert import CertificateBuilder
class SAMLProviderForm(forms.ModelForm): class SAMLProviderForm(forms.ModelForm):
@ -32,12 +32,14 @@ class SAMLProviderForm(forms.ModelForm):
model = SAMLProvider model = SAMLProvider
fields = [ fields = [
"name", "name",
"property_mappings", "processor_path",
"acs_url", "acs_url",
"audience", "audience",
"processor_path",
"issuer", "issuer",
"assertion_valid_for", "assertion_valid_not_before",
"assertion_valid_not_on_or_after",
"session_valid_not_on_or_after",
"property_mappings",
"signing", "signing",
"signing_cert", "signing_cert",
"signing_key", "signing_key",
@ -50,6 +52,9 @@ class SAMLProviderForm(forms.ModelForm):
"name": forms.TextInput(), "name": forms.TextInput(),
"audience": forms.TextInput(), "audience": forms.TextInput(),
"issuer": forms.TextInput(), "issuer": forms.TextInput(),
"assertion_valid_not_before": forms.TextInput(),
"assertion_valid_not_on_or_after": forms.TextInput(),
"session_valid_not_on_or_after": forms.TextInput(),
"property_mappings": FilteredSelectMultiple(_("Property Mappings"), False), "property_mappings": FilteredSelectMultiple(_("Property Mappings"), False),
} }

View File

@ -0,0 +1,61 @@
# Generated by Django 2.2.9 on 2020-02-14 13:54
from django.db import migrations, models
import passbook.providers.saml.utils.time
def migrate_valid_for(apps, schema_editor):
"""Migrate from single number standing for minutes to 'minutes=3'"""
SAMLProvider = apps.get_model("passbook_providers_saml", "SAMLProvider")
db_alias = schema_editor.connection.alias
for provider in SAMLProvider.objects.using(db_alias).all():
provider.assertion_valid_not_on_or_after = (
f"minutes={provider.assertion_valid_for}"
)
provider.save()
class Migration(migrations.Migration):
dependencies = [
("passbook_providers_saml", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="samlprovider",
name="assertion_valid_not_before",
field=models.TextField(
default="minutes=5",
help_text="Assertion valid not before current time - this value (Format: hours=1;minutes=2;seconds=3).",
validators=[
passbook.providers.saml.utils.time.timedelta_string_validator
],
),
),
migrations.AddField(
model_name="samlprovider",
name="assertion_valid_not_on_or_after",
field=models.TextField(
default="minutes=5",
help_text="Assertion not valid on or after current time + this value (Format: hours=1;minutes=2;seconds=3).",
validators=[
passbook.providers.saml.utils.time.timedelta_string_validator
],
),
),
migrations.RunPython(migrate_valid_for),
migrations.RemoveField(model_name="samlprovider", name="assertion_valid_for",),
migrations.AddField(
model_name="samlprovider",
name="session_valid_not_on_or_after",
field=models.TextField(
default="minutes=86400",
help_text="Session not valid on or after current time + this value (Format: hours=1;minutes=2;seconds=3).",
validators=[
passbook.providers.saml.utils.time.timedelta_string_validator
],
),
),
]

View File

@ -7,7 +7,8 @@ from structlog import get_logger
from passbook.core.models import PropertyMapping, Provider from passbook.core.models import PropertyMapping, Provider
from passbook.lib.utils.reflection import class_to_path, path_to_class from passbook.lib.utils.reflection import class_to_path, path_to_class
from passbook.providers.saml.base import Processor from passbook.providers.saml.processors.base import Processor
from passbook.providers.saml.utils.time import timedelta_string_validator
LOGGER = get_logger() LOGGER = get_logger()
@ -16,11 +17,44 @@ class SAMLProvider(Provider):
"""Model to save information about a Remote SAML Endpoint""" """Model to save information about a Remote SAML Endpoint"""
name = models.TextField() name = models.TextField()
processor_path = models.CharField(max_length=255, choices=[])
acs_url = models.URLField() acs_url = models.URLField()
audience = models.TextField(default="") audience = models.TextField(default="")
processor_path = models.CharField(max_length=255, choices=[])
issuer = models.TextField() issuer = models.TextField()
assertion_valid_for = models.IntegerField(default=86400)
assertion_valid_not_before = models.TextField(
default="minutes=5",
validators=[timedelta_string_validator],
help_text=_(
(
"Assertion valid not before current time - this value "
"(Format: hours=1;minutes=2;seconds=3)."
)
),
)
assertion_valid_not_on_or_after = models.TextField(
default="minutes=5",
validators=[timedelta_string_validator],
help_text=_(
(
"Assertion not valid on or after current time + this value "
"(Format: hours=1;minutes=2;seconds=3)."
)
),
)
session_valid_not_on_or_after = models.TextField(
default="minutes=86400",
validators=[timedelta_string_validator],
help_text=_(
(
"Session not valid on or after current time + this value "
"(Format: hours=1;minutes=2;seconds=3)."
)
),
)
signing = models.BooleanField(default=True) signing = models.BooleanField(default=True)
signing_cert = models.TextField() signing_cert = models.TextField()
signing_key = models.TextField() signing_key = models.TextField()
@ -44,7 +78,7 @@ class SAMLProvider(Provider):
return self._processor return self._processor
def __str__(self): def __str__(self):
return "SAML Provider %s" % self.name return f"SAML Provider {self.name}"
def link_download_metadata(self): def link_download_metadata(self):
"""Get link to download XML metadata for admin interface""" """Get link to download XML metadata for admin interface"""
@ -73,7 +107,7 @@ class SAMLPropertyMapping(PropertyMapping):
form = "passbook.providers.saml.forms.SAMLPropertyMappingForm" form = "passbook.providers.saml.forms.SAMLPropertyMappingForm"
def __str__(self): def __str__(self):
return "SAML Property Mapping %s" % self.saml_name return f"SAML Property Mapping {self.saml_name}"
class Meta: class Meta:

View File

@ -0,0 +1,252 @@
"""Basic SAML Processor"""
from typing import TYPE_CHECKING, Dict, List, Union
from defusedxml import ElementTree
from django.http import HttpRequest
from structlog import get_logger
from passbook.providers.saml.exceptions import CannotHandleAssertion
from passbook.providers.saml.utils import get_random_id
from passbook.providers.saml.utils.encoding import decode_base64_and_inflate, nice64
from passbook.providers.saml.utils.time import get_time_string, timedelta_from_string
from passbook.providers.saml.utils.xml_render import get_assertion_xml, get_response_xml
if TYPE_CHECKING:
from passbook.providers.saml.models import SAMLProvider
# pylint: disable=too-many-instance-attributes
class Processor:
"""Base SAML 2.0 AuthnRequest to Response Processor.
Sub-classes should provide Service Provider-specific functionality."""
is_idp_initiated = False
_remote: "SAMLProvider"
_http_request: HttpRequest
_assertion_xml: str
_response_xml: str
_saml_response: str
_relay_state: str
_saml_request: str
_assertion_params: Dict[str, Union[str, List[Dict[str, str]]]]
_request_params: Dict[str, str]
_system_params: Dict[str, str]
_response_params: Dict[str, str]
@property
def subject_format(self) -> str:
"""Get subject Format"""
return "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
def __init__(self, remote: "SAMLProvider"):
self.name = remote.name
self._remote = remote
self._logger = get_logger()
self._system_params = {
"ISSUER": self._remote.issuer,
}
def _build_assertion(self):
"""Builds _assertion_params."""
self._assertion_params = {
"ASSERTION_ID": get_random_id(),
"ASSERTION_SIGNATURE": "", # it's unsigned
"AUDIENCE": self._remote.audience,
"AUTH_INSTANT": get_time_string(),
"ISSUE_INSTANT": get_time_string(),
"NOT_BEFORE": get_time_string(
timedelta_from_string(self._remote.assertion_valid_not_before)
),
"NOT_ON_OR_AFTER": get_time_string(
timedelta_from_string(self._remote.assertion_valid_not_on_or_after)
),
"SESSION_INDEX": self._http_request.session.session_key,
"SESSION_NOT_ON_OR_AFTER": get_time_string(
timedelta_from_string(self._remote.session_valid_not_on_or_after)
),
"SP_NAME_QUALIFIER": self._remote.audience,
"SUBJECT": self._http_request.user.email,
"SUBJECT_FORMAT": self.subject_format,
}
self._assertion_params.update(self._system_params)
self._assertion_params.update(self._request_params)
def _build_response(self):
"""Builds _response_params."""
self._response_params = {
"ASSERTION": self._assertion_xml,
"ISSUE_INSTANT": get_time_string(),
"RESPONSE_ID": get_random_id(),
"RESPONSE_SIGNATURE": "", # initially unsigned
}
self._response_params.update(self._system_params)
self._response_params.update(self._request_params)
def _encode_response(self):
"""Encodes _response_xml to _encoded_xml."""
self._saml_response = nice64(str.encode(self._response_xml))
def _extract_saml_request(self):
"""Retrieves the _saml_request AuthnRequest from the _http_request."""
self._saml_request = self._http_request.session["SAMLRequest"]
self._relay_state = self._http_request.session["RelayState"]
def _format_assertion(self):
"""Formats _assertion_params as _assertion_xml."""
# https://commons.lbl.gov/display/IDMgmt/Attribute+Definitions
self._assertion_params["ATTRIBUTES"] = [
{
"FriendlyName": "eduPersonPrincipalName",
"Name": "urn:oid:1.3.6.1.4.1.5923.1.1.1.6",
"Value": self._http_request.user.email,
},
{
"FriendlyName": "cn",
"Name": "urn:oid:2.5.4.3",
"Value": self._http_request.user.name,
},
{
"FriendlyName": "mail",
"Name": "urn:oid:0.9.2342.19200300.100.1.3",
"Value": self._http_request.user.email,
},
{
"FriendlyName": "displayName",
"Name": "urn:oid:2.16.840.1.113730.3.1.241",
"Value": self._http_request.user.username,
},
{
"FriendlyName": "uid",
"Name": "urn:oid:0.9.2342.19200300.100.1.1",
"Value": self._http_request.user.pk,
},
]
from passbook.providers.saml.models import SAMLPropertyMapping
for mapping in self._remote.property_mappings.all().select_subclasses():
if isinstance(mapping, SAMLPropertyMapping):
mapping_payload = {
"Name": mapping.saml_name,
"ValueArray": [],
"FriendlyName": mapping.friendly_name,
}
for value in mapping.values:
mapping_payload["ValueArray"].append(
value.format(
user=self._http_request.user, request=self._http_request
)
)
self._assertion_params["ATTRIBUTES"].append(mapping_payload)
self._assertion_xml = get_assertion_xml(
"saml/xml/assertions/generic.xml", self._assertion_params, signed=True
)
def _format_response(self):
"""Formats _response_params as _response_xml."""
assertion_id = self._assertion_params["ASSERTION_ID"]
self._response_xml = get_response_xml(
self._response_params, saml_provider=self._remote, assertion_id=assertion_id
)
def _get_django_response_params(self) -> Dict[str, str]:
"""Returns a dictionary of parameters for the response template."""
return {
"acs_url": self._request_params["ACS_URL"],
"saml_response": self._saml_response,
"relay_state": self._relay_state,
"autosubmit": self._remote.application.skip_authorization,
}
def _decode_and_parse_request(self):
"""Parses various parameters from _request_xml into _request_params."""
decoded_xml = decode_base64_and_inflate(self._saml_request).decode("utf-8")
root = ElementTree.fromstring(decoded_xml)
params = {}
params["ACS_URL"] = root.attrib.get(
"AssertionConsumerServiceURL", self._remote.acs_url
)
params["REQUEST_ID"] = root.attrib["ID"]
params["DESTINATION"] = root.attrib.get("Destination", "")
params["PROVIDER_NAME"] = root.attrib.get("ProviderName", "")
self._request_params = params
def _validate_request(self):
"""
Validates the SAML request against the SP configuration of this
processor. Sub-classes should override this and raise a
`CannotHandleAssertion` exception if the validation fails.
Raises:
CannotHandleAssertion: if the ACS URL specified in the SAML request
doesn't match the one specified in the processor config.
"""
request_acs_url = self._request_params["ACS_URL"]
if self._remote.acs_url != request_acs_url:
msg = "couldn't find ACS url '{}' in SAML2IDP_REMOTES " "setting.".format(
request_acs_url
)
self._logger.info(msg)
raise CannotHandleAssertion(msg)
def _validate_user(self):
"""Validates the User. Sub-classes should override this and
throw an CannotHandleAssertion Exception if the validation does not succeed."""
def can_handle(self, request: HttpRequest) -> bool:
"""Returns true if this processor can handle this request."""
self._http_request = request
# Read the request.
try:
self._extract_saml_request()
except Exception as exc:
raise CannotHandleAssertion(
f"can't find SAML request in user session: {exc}"
) from exc
try:
self._decode_and_parse_request()
except Exception as exc:
raise CannotHandleAssertion(f"can't parse SAML request: {exc}") from exc
self._validate_request()
return True
def generate_response(self) -> Dict[str, str]:
"""Processes request and returns template variables suitable for a response."""
# Build the assertion and response.
# Only call can_handle if SP initiated Request, otherwise we have no Request
if not self.is_idp_initiated:
self.can_handle(self._http_request)
self._validate_user()
self._build_assertion()
self._format_assertion()
self._build_response()
self._format_response()
self._encode_response()
# Return proper template params.
return self._get_django_response_params()
def init_deep_link(self, request: HttpRequest, url: str):
"""Initialize this Processor to make an IdP-initiated call to the SP's
deep-linked URL."""
self._http_request = request
acs_url = self._remote.acs_url
# NOTE: The following request params are made up. Some are blank,
# because they comes over in the AuthnRequest, but we don't have an
# AuthnRequest in this case:
# - Destination: Should be this IdP's SSO endpoint URL. Not used in the response?
# - ProviderName: According to the spec, this is optional.
self._request_params = {
"ACS_URL": acs_url,
"DESTINATION": "",
"PROVIDER_NAME": "",
}
self._relay_state = url

View File

@ -1,7 +1,7 @@
"""Generic Processor""" """Generic Processor"""
from passbook.providers.saml.base import Processor from passbook.providers.saml.processors.base import Processor
class GenericProcessor(Processor): class GenericProcessor(Processor):
"""Generic Response Handler Processor for testing against django-saml2-sp.""" """Generic SAML2 Processor"""

View File

@ -1,16 +1,14 @@
"""Salesforce Processor""" """Salesforce Processor"""
from passbook.providers.saml.base import Processor from passbook.providers.saml.processors.generic import GenericProcessor
from passbook.providers.saml.xml_render import get_assertion_xml from passbook.providers.saml.utils.xml_render import get_assertion_xml
class SalesForceProcessor(Processor): class SalesForceProcessor(GenericProcessor):
"""SalesForce.com-specific SAML 2.0 AuthnRequest to Response Handler Processor.""" """SalesForce.com-specific SAML 2.0 AuthnRequest to Response Handler Processor."""
def _determine_audience(self):
self._audience = "IAMShowcase"
def _format_assertion(self): def _format_assertion(self):
super()._format_assertion()
self._assertion_xml = get_assertion_xml( self._assertion_xml = get_assertion_xml(
"saml/xml/assertions/salesforce.xml", self._assertion_params, signed=True "saml/xml/assertions/salesforce.xml", self._assertion_params, signed=True
) )

View File

@ -0,0 +1,30 @@
"""Test time utils"""
from datetime import timedelta
from django.core.exceptions import ValidationError
from django.test import TestCase
from passbook.providers.saml.utils.time import (
timedelta_from_string,
timedelta_string_validator,
)
class TestTimeUtils(TestCase):
"""Test time-utils"""
def test_valid(self):
"""Test valid expression"""
expr = "hours=3;minutes=1"
expected = timedelta(hours=3, minutes=1)
self.assertEqual(timedelta_from_string(expr), expected)
def test_invalid(self):
"""Test invalid expression"""
with self.assertRaises(ValueError):
timedelta_from_string("foo")
def test_validation(self):
"""Test Django model field validator"""
with self.assertRaises(ValidationError):
timedelta_string_validator("foo")

View File

@ -0,0 +1,18 @@
"""Small helper functions"""
import uuid
from django.http import HttpRequest, HttpResponse
from django.shortcuts import render
from django.template.context import Context
def render_xml(request: HttpRequest, template: str, ctx: Context) -> HttpResponse:
"""Render template with content_type application/xml"""
return render(request, template, context=ctx, content_type="application/xml")
def get_random_id() -> str:
"""Random hex id"""
# It is very important that these random IDs NOT start with a number.
random_id = "_" + uuid.uuid4().hex
return random_id

View File

@ -1,8 +1,6 @@
"""Wrappers to de/encode and de/inflate strings""" """Create self-signed certificates"""
import base64
import datetime import datetime
import uuid import uuid
import zlib
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@ -11,24 +9,6 @@ from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID from cryptography.x509.oid import NameOID
def decode_base64_and_inflate(b64string):
"""Base64 decode and ZLib decompress b64string"""
decoded_data = base64.b64decode(b64string)
return zlib.decompress(decoded_data, -15)
def deflate_and_base64_encode(string_val):
"""Base64 and ZLib Compress b64string"""
zlibbed_str = zlib.compress(string_val)
compressed_string = zlibbed_str[2:-4]
return base64.b64encode(compressed_string)
def nice64(src):
""" Returns src base64-encoded and formatted nicely for our XML. """
return base64.b64encode(src).decode("utf-8").replace("\n", "")
class CertificateBuilder: class CertificateBuilder:
"""Build self-signed certificates""" """Build self-signed certificates"""

View File

@ -0,0 +1,21 @@
"""Wrappers to de/encode and de/inflate strings"""
import base64
import zlib
def decode_base64_and_inflate(b64string):
"""Base64 decode and ZLib decompress b64string"""
decoded_data = base64.b64decode(b64string)
return zlib.decompress(decoded_data, -15)
def deflate_and_base64_encode(string_val):
"""Base64 and ZLib Compress b64string"""
zlibbed_str = zlib.compress(string_val)
compressed_string = zlibbed_str[2:-4]
return base64.b64encode(compressed_string)
def nice64(src):
""" Returns src base64-encoded and formatted nicely for our XML. """
return base64.b64encode(src).decode("utf-8").replace("\n", "")

View File

@ -0,0 +1,45 @@
"""Time utilities"""
import datetime
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
ALLOWED_KEYS = (
"days",
"seconds",
"microseconds",
"milliseconds",
"minutes",
"hours",
"weeks",
)
def timedelta_string_validator(value: str):
"""Validator for Django that checks if value can be parsed with `timedelta_from_string`"""
try:
timedelta_from_string(value)
except ValueError as exc:
raise ValidationError(
_("%(value)s is not in the correct format of 'hours=3;minutes=1'."),
params={"value": value},
) from exc
def timedelta_from_string(expr: str) -> datetime.timedelta:
"""Convert a string with the format of 'hours=1;minute=3;seconds=5' to a
`datetime.timedelta` Object with hours = 1, minutes = 3, seconds = 5"""
kwargs = {}
for duration_pair in expr.split(";"):
key, value = duration_pair.split("=")
if key.lower() not in ALLOWED_KEYS:
continue
kwargs[key.lower()] = float(value)
return datetime.timedelta(**kwargs)
def get_time_string(delta: datetime.timedelta = None) -> str:
"""Get Data formatted in SAML format"""
now = datetime.datetime.now()
final = now + delta
return final.strftime("%Y-%m-%dT%H:%M:%SZ")

View File

@ -6,7 +6,10 @@ from typing import TYPE_CHECKING
from structlog import get_logger from structlog import get_logger
from passbook.lib.utils.template import render_to_string from passbook.lib.utils.template import render_to_string
from passbook.providers.saml.xml_signing import get_signature_xml, sign_with_signxml from passbook.providers.saml.utils.xml_signing import (
get_signature_xml,
sign_with_signxml,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from passbook.providers.saml.models import SAMLProvider from passbook.providers.saml.models import SAMLProvider
@ -60,7 +63,6 @@ def get_assertion_xml(template, parameters, signed=False):
_get_attribute_statement(params) _get_attribute_statement(params)
unsigned = render_to_string(template, params) unsigned = render_to_string(template, params)
# LOGGER.debug('Unsigned: %s', unsigned)
if not signed: if not signed:
return unsigned return unsigned
@ -80,13 +82,11 @@ def get_response_xml(parameters, saml_provider: SAMLProvider, assertion_id=""):
raw_response = render_to_string("saml/xml/response.xml", params) raw_response = render_to_string("saml/xml/response.xml", params)
# LOGGER.debug('Unsigned: %s', unsigned)
if not saml_provider.signing: if not saml_provider.signing:
return raw_response return raw_response
signature_xml = get_signature_xml() signature_xml = get_signature_xml()
params["RESPONSE_SIGNATURE"] = signature_xml params["RESPONSE_SIGNATURE"] = signature_xml
# LOGGER.debug("Raw response: %s", raw_response)
signed = sign_with_signxml( signed = sign_with_signxml(
saml_provider.signing_key, saml_provider.signing_key,

View File

@ -39,18 +39,13 @@ def _generate_response(request, provider: SAMLProvider):
return render(request, "saml/idp/login.html", ctx) return render(request, "saml/idp/login.html", ctx)
def render_xml(request, template, ctx):
"""Render template with content_type application/xml"""
return render(request, template, context=ctx, content_type="application/xml")
class AccessRequiredView(AccessMixin, View): class AccessRequiredView(AccessMixin, View):
"""Mixin class for Views using a provider instance""" """Mixin class for Views using a provider instance"""
_provider = None _provider: SAMLProvider
@property @property
def provider(self): def provider(self) -> SAMLProvider:
"""Get provider instance""" """Get provider instance"""
if not self._provider: if not self._provider:
application = get_object_or_404( application = get_object_or_404(
@ -147,10 +142,10 @@ class LoginProcessView(AccessRequiredView):
relay_state=ctx["relay_state"], relay_state=ctx["relay_state"],
) )
try: try:
full_res = _generate_response(request, self.provider) return _generate_response(request, self.provider)
return full_res
except exceptions.CannotHandleAssertion as exc: except exceptions.CannotHandleAssertion as exc:
LOGGER.debug(exc) LOGGER.debug(exc)
return HttpResponseBadRequest()
# pylint: disable=unused-argument # pylint: disable=unused-argument
def post(self, request, application): def post(self, request, application):

View File

@ -5,7 +5,7 @@ from django.contrib.admin.widgets import FilteredSelectMultiple
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from passbook.admin.forms.source import SOURCE_FORM_FIELDS from passbook.admin.forms.source import SOURCE_FORM_FIELDS
from passbook.providers.saml.utils import CertificateBuilder from passbook.providers.saml.utils.cert import CertificateBuilder
from passbook.sources.saml.models import SAMLSource from passbook.sources.saml.models import SAMLSource

View File

@ -9,9 +9,9 @@ from django.utils.decorators import method_decorator
from django.views import View from django.views import View
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from passbook.providers.saml.base import get_random_id, get_time_string from passbook.providers.saml.utils import get_random_id, render_xml
from passbook.providers.saml.utils import nice64 from passbook.providers.saml.utils.encoding import nice64
from passbook.providers.saml.views import render_xml from passbook.providers.saml.utils.time import get_time_string
from passbook.sources.saml.models import SAMLSource from passbook.sources.saml.models import SAMLSource
from passbook.sources.saml.utils import ( from passbook.sources.saml.utils import (
_get_user_from_response, _get_user_from_response,

View File

@ -2,7 +2,7 @@
from structlog import get_logger from structlog import get_logger
from passbook.lib.utils.template import render_to_string from passbook.lib.utils.template import render_to_string
from passbook.providers.saml.xml_signing import get_signature_xml from passbook.providers.saml.utils.xml_signing import get_signature_xml
LOGGER = get_logger() LOGGER = get_logger()