diff --git a/passbook/core/policies.py b/passbook/core/policies.py index f26919cef..917b1a039 100644 --- a/passbook/core/policies.py +++ b/passbook/core/policies.py @@ -54,6 +54,8 @@ class PolicyEngine: def build(self): """Build task group""" + if not self._user: + raise ValueError("User not set.") signatures = [] kwargs = { '__password__': getattr(self._user, '__password__', None), @@ -74,6 +76,7 @@ class PolicyEngine: for policy_action, policy_result, policy_message in self._group.get(): passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ (policy_action == Policy.ACTION_DENY and not policy_result) + LOGGER.debug('Action=%s, Result=%r => %r', policy_action, policy_result, passing) if policy_message: messages.append(policy_message) if not passing: diff --git a/passbook/saml_idp/views.py b/passbook/saml_idp/views.py index f5cf03cd1..147a77399 100644 --- a/passbook/saml_idp/views.py +++ b/passbook/saml_idp/views.py @@ -2,7 +2,7 @@ from logging import getLogger from django.contrib.auth import logout -from django.contrib.auth.mixins import LoginRequiredMixin +from django.contrib.auth.mixins import AccessMixin from django.core.exceptions import ValidationError from django.core.validators import URLValidator from django.http import HttpResponse, HttpResponseBadRequest @@ -46,7 +46,7 @@ def render_xml(request, template, ctx): return render(request, template, context=ctx, content_type="application/xml") -class ProviderMixin: +class AccessRequiredView(AccessMixin, View): """Mixin class for Views using a provider instance""" _provider = None @@ -59,8 +59,24 @@ class ProviderMixin: self._provider = get_object_or_404(SAMLProvider, pk=application.provider_id) return self._provider + def _has_access(self): + """Check if user has access to application""" + policy_engine = PolicyEngine(self.provider.application.policies.all()) + policy_engine.for_user(self.request.user).with_request(self.request).build() + return policy_engine.passing -class LoginBeginView(LoginRequiredMixin, View): + def dispatch(self, request, *args, **kwargs): + if not request.user.is_authenticated: + return self.handle_no_permission() + if not self._has_access(): + return render(request, 'login/denied.html', { + 'title': _("You don't have access to this application"), + 'is_login': True + }) + return super().dispatch(request, *args, **kwargs) + + +class LoginBeginView(AccessRequiredView): """Receives a SAML 2.0 AuthnRequest from a Service Provider and stores it in the session prior to enforcing login.""" @@ -83,7 +99,7 @@ class LoginBeginView(LoginRequiredMixin, View): })) -class RedirectToSPView(LoginRequiredMixin, View): +class RedirectToSPView(AccessRequiredView): """Return autosubmit form""" def get(self, request, acs_url, saml_response, relay_state): @@ -97,24 +113,13 @@ class RedirectToSPView(LoginRequiredMixin, View): }) - -class LoginProcessView(ProviderMixin, LoginRequiredMixin, View): +class LoginProcessView(AccessRequiredView): """Processor-based login continuation. Presents a SAML 2.0 Assertion for POSTing back to the Service Provider.""" - def _has_access(self): - """Check if user has access to application""" - policy_engine = PolicyEngine(self.provider.application.policies.all()) - policy_engine.for_user(self.request.user).with_request(self.request).build() - return policy_engine.passing - def get(self, request, application): """Handle get request, i.e. render form""" LOGGER.debug("Request: %s", request) - if not self._has_access(): - return render(request, 'login/denied.html', { - 'title': _("You don't have access to this application") - }) # Check if user has access if self.provider.application.skip_authorization: ctx = self.provider.processor.generate_response() @@ -138,10 +143,6 @@ class LoginProcessView(ProviderMixin, LoginRequiredMixin, View): def post(self, request, application): """Handle post request, return back to ACS""" LOGGER.debug("Request: %s", request) - if not self._has_access(): - return render(request, 'login/denied.html', { - 'title': _("You don't have access to this application") - }) # Check if user has access if request.POST.get('ACSUrl', None): # User accepted request @@ -162,7 +163,7 @@ class LoginProcessView(ProviderMixin, LoginRequiredMixin, View): LOGGER.debug(exc) -class LogoutView(CSRFExemptMixin, LoginRequiredMixin, View): +class LogoutView(CSRFExemptMixin, AccessRequiredView): """Allows a non-SAML 2.0 URL to log out the user and returns a standard logged-out page. (SalesForce and others use this method, though it's technically not SAML 2.0).""" @@ -183,7 +184,7 @@ class LogoutView(CSRFExemptMixin, LoginRequiredMixin, View): return render(request, 'saml/idp/logged_out.html') -class SLOLogout(CSRFExemptMixin, LoginRequiredMixin, View): +class SLOLogout(CSRFExemptMixin, AccessRequiredView): """Receives a SAML 2.0 LogoutRequest from a Service Provider, logs out the user and returns a standard logged-out page.""" @@ -199,7 +200,7 @@ class SLOLogout(CSRFExemptMixin, LoginRequiredMixin, View): return render(request, 'saml/idp/logged_out.html') -class DescriptorDownloadView(ProviderMixin, View): +class DescriptorDownloadView(AccessRequiredView): """Replies with the XML Metadata IDSSODescriptor.""" def get(self, request, application): @@ -223,10 +224,10 @@ class DescriptorDownloadView(ProviderMixin, View): return response -class InitiateLoginView(ProviderMixin, LoginRequiredMixin, View): +class InitiateLoginView(AccessRequiredView): """IdP-initiated Login""" - def dispatch(self, request, application): + def get(self, request, application): """Initiates an IdP-initiated link to a simple SP resource/target URL.""" self.provider.processor.init_deep_link(request, '') return _generate_response(request, self.provider)