diff --git a/passbook/core/migrations/0002_default_user.py b/passbook/core/migrations/0002_default_user.py
new file mode 100644
index 000000000..66e6a2d3e
--- /dev/null
+++ b/passbook/core/migrations/0002_default_user.py
@@ -0,0 +1,28 @@
+# Generated by Django 3.0.6 on 2020-05-23 16:40
+
+from django.apps.registry import Apps
+from django.db import migrations
+from django.db.backends.base.schema import BaseDatabaseSchemaEditor
+
+
+def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
+ # User = apps.get_model("passbook_core", "User")
+ from passbook.core.models import User
+
+ pbadmin = User.objects.create(
+ username="pbadmin", email="root@localhost", # password="pbadmin"
+ )
+ pbadmin.set_password("pbadmin") # nosec
+ pbadmin.is_superuser = True
+ pbadmin.save()
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("passbook_core", "0001_initial"),
+ ]
+
+ operations = [
+ migrations.RunPython(create_default_user),
+ ]
diff --git a/passbook/core/signals.py b/passbook/core/signals.py
index 01299f90e..74b6b49f1 100644
--- a/passbook/core/signals.py
+++ b/passbook/core/signals.py
@@ -1,31 +1,7 @@
"""passbook core signals"""
-from django.core.cache import cache
from django.core.signals import Signal
-from django.db.models.signals import post_save
-from django.dispatch import receiver
-from structlog import get_logger
-
-LOGGER = get_logger()
user_signed_up = Signal(providing_args=["request", "user"])
invitation_created = Signal(providing_args=["request", "invitation"])
invitation_used = Signal(providing_args=["request", "invitation", "user"])
password_changed = Signal(providing_args=["user", "password"])
-
-
-@receiver(post_save)
-# pylint: disable=unused-argument
-def invalidate_policy_cache(sender, instance, **_):
- """Invalidate Policy cache when policy is updated"""
- from passbook.policies.models import Policy, PolicyBinding
- from passbook.policies.process import cache_key
-
- if isinstance(instance, Policy):
- LOGGER.debug("Invalidating policy cache", policy=instance)
- total = 0
- for binding in PolicyBinding.objects.filter(policy=instance):
- prefix = cache_key(binding) + "*"
- keys = cache.keys(prefix)
- total += len(keys)
- cache.delete_many(keys)
- LOGGER.debug("Deleted keys", len=total)
diff --git a/passbook/core/templates/user/settings.html b/passbook/core/templates/user/settings.html
index edbcc4928..5e752f935 100644
--- a/passbook/core/templates/user/settings.html
+++ b/passbook/core/templates/user/settings.html
@@ -17,7 +17,9 @@
diff --git a/passbook/core/views/user.py b/passbook/core/views/user.py
index bb575ee8e..4bc02b27f 100644
--- a/passbook/core/views/user.py
+++ b/passbook/core/views/user.py
@@ -1,4 +1,6 @@
"""passbook core user views"""
+from typing import Any, Dict
+
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.messages.views import SuccessMessageMixin
from django.urls import reverse_lazy
@@ -6,6 +8,7 @@ from django.utils.translation import gettext as _
from django.views.generic import UpdateView
from passbook.core.forms.users import UserDetailForm
+from passbook.flows.models import Flow, FlowDesignation
class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
@@ -19,3 +22,11 @@ class UserSettingsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
def get_object(self):
return self.request.user
+
+ def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ kwargs = super().get_context_data(**kwargs)
+ unenrollment_flow = Flow.with_policy(
+ self.request, designation=FlowDesignation.UNRENOLLMENT
+ )
+ kwargs["unenrollment_enabled"] = bool(unenrollment_flow)
+ return kwargs
diff --git a/passbook/crypto/forms.py b/passbook/crypto/forms.py
index babf25919..79d5f7100 100644
--- a/passbook/crypto/forms.py
+++ b/passbook/crypto/forms.py
@@ -34,7 +34,6 @@ class CertificateKeyPairForm(forms.ModelForm):
password=None,
backend=default_backend(),
)
- load_pem_x509_certificate(key_data.encode("utf-8"), default_backend())
except ValueError:
raise forms.ValidationError("Unable to load private key.")
return key_data
diff --git a/passbook/crypto/migrations/0002_create_self_signed_kp.py b/passbook/crypto/migrations/0002_create_self_signed_kp.py
new file mode 100644
index 000000000..66239b816
--- /dev/null
+++ b/passbook/crypto/migrations/0002_create_self_signed_kp.py
@@ -0,0 +1,26 @@
+# Generated by Django 3.0.6 on 2020-05-23 23:07
+
+from django.db import migrations
+
+
+def create_self_signed(apps, schema_editor):
+ CertificateKeyPair = apps.get_model("passbook_crypto", "CertificateKeyPair")
+ db_alias = schema_editor.connection.alias
+ from passbook.crypto.builder import CertificateBuilder
+
+ builder = CertificateBuilder()
+ builder.build()
+ CertificateKeyPair.objects.using(db_alias).create(
+ name="passbook Self-signed Certificate",
+ certificate_data=builder.certificate,
+ key_data=builder.private_key,
+ )
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("passbook_crypto", "0001_initial"),
+ ]
+
+ operations = [migrations.RunPython(create_self_signed)]
diff --git a/passbook/flows/models.py b/passbook/flows/models.py
index de0147c99..ffb194386 100644
--- a/passbook/flows/models.py
+++ b/passbook/flows/models.py
@@ -3,12 +3,16 @@ from typing import Optional
from uuid import uuid4
from django.db import models
+from django.http import HttpRequest
from django.utils.translation import gettext_lazy as _
from model_utils.managers import InheritanceManager
+from structlog import get_logger
from passbook.core.types import UIUserSettings
from passbook.policies.models import PolicyBindingModel
+LOGGER = get_logger()
+
class FlowDesignation(models.TextChoices):
"""Designation of what a Flow should be used for. At a later point, this
@@ -62,10 +66,29 @@ class Flow(PolicyBindingModel):
PolicyBindingModel, parent_link=True, on_delete=models.CASCADE, related_name="+"
)
- def related_flow(self, designation: str) -> Optional["Flow"]:
+ @staticmethod
+ def with_policy(request: HttpRequest, **flow_filter) -> Optional["Flow"]:
+ """Get a Flow by `**flow_filter` and check if the request from `request` can access it."""
+ from passbook.policies.engine import PolicyEngine
+
+ flows = Flow.objects.filter(**flow_filter)
+ for flow in flows:
+ engine = PolicyEngine(flow, request.user, request)
+ engine.build()
+ result = engine.result
+ if result.passing:
+ LOGGER.debug("with_policy: flow passing", flow=flow)
+ return flow
+ LOGGER.warning(
+ "with_policy: flow not passing", flow=flow, messages=result.messages
+ )
+ LOGGER.debug("with_policy: no flow found", filters=flow_filter)
+ return None
+
+ def related_flow(self, designation: str, request: HttpRequest) -> Optional["Flow"]:
"""Get a related flow with `designation`. Currently this only queries
Flows by `designation`, but will eventually use `self` for related lookups."""
- return Flow.objects.filter(designation=designation).first()
+ return Flow.with_policy(request, designation=designation)
def __str__(self) -> str:
return f"Flow {self.name} ({self.slug})"
diff --git a/passbook/flows/planner.py b/passbook/flows/planner.py
index e44e378ee..262279434 100644
--- a/passbook/flows/planner.py
+++ b/passbook/flows/planner.py
@@ -11,7 +11,6 @@ from passbook.core.models import User
from passbook.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from passbook.flows.models import Flow, Stage
from passbook.policies.engine import PolicyEngine
-from passbook.policies.types import PolicyResult
LOGGER = get_logger()
@@ -52,22 +51,12 @@ class FlowPlanner:
self.use_cache = True
self.flow = flow
- def _check_flow_root_policies(self, request: HttpRequest) -> PolicyResult:
- engine = PolicyEngine(self.flow, request.user, request)
- engine.build()
- return engine.result
-
def plan(
self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None
) -> FlowPlan:
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
and return ordered list"""
LOGGER.debug("f(plan): Starting planning process", flow=self.flow)
- # First off, check the flow's direct policy bindings
- # to make sure the user even has access to the flow
- root_result = self._check_flow_root_policies(request)
- if not root_result.passing:
- raise FlowNonApplicableException(*root_result.messages)
# Bit of a workaround here, if there is a pending user set in the default context
# we use that user for our cache key
# to make sure they don't get the generic response
@@ -75,6 +64,16 @@ class FlowPlanner:
user = default_context[PLAN_CONTEXT_PENDING_USER]
else:
user = request.user
+ # First off, check the flow's direct policy bindings
+ # to make sure the user even has access to the flow
+ engine = PolicyEngine(self.flow, user, request)
+ if default_context:
+ engine.request.context = default_context
+ engine.build()
+ result = engine.result
+ if not result.passing:
+ raise FlowNonApplicableException(result.messages)
+ # User is passing so far, check if we have a cached plan
cached_plan_key = cache_key(self.flow, user)
cached_plan = cache.get(cached_plan_key, None)
if cached_plan and self.use_cache:
@@ -82,6 +81,7 @@ class FlowPlanner:
"f(plan): Taking plan from cache", flow=self.flow, key=cached_plan_key
)
return cached_plan
+ LOGGER.debug("f(plan): building plan", flow=self.flow)
plan = self._build_plan(user, request, default_context)
cache.set(cache_key(self.flow, user), plan)
if not plan.stages:
diff --git a/passbook/flows/tests/test_planner.py b/passbook/flows/tests/test_planner.py
index df0dc5a7e..af6ae98fa 100644
--- a/passbook/flows/tests/test_planner.py
+++ b/passbook/flows/tests/test_planner.py
@@ -1,5 +1,5 @@
"""flow planner tests"""
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
from django.core.cache import cache
from django.shortcuts import reverse
@@ -13,7 +13,7 @@ from passbook.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner, cache
from passbook.policies.types import PolicyResult
from passbook.stages.dummy.models import DummyStage
-POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False))
+POLICY_RESULT_MOCK = PropertyMock(return_value=PolicyResult(False))
TIME_NOW_MOCK = MagicMock(return_value=3)
@@ -40,8 +40,7 @@ class TestFlowPlanner(TestCase):
planner.plan(request)
@patch(
- "passbook.flows.planner.FlowPlanner._check_flow_root_policies",
- POLICY_RESULT_MOCK,
+ "passbook.policies.engine.PolicyEngine.result", POLICY_RESULT_MOCK,
)
def test_non_applicable_plan(self):
"""Test that empty plan raises exception"""
diff --git a/passbook/flows/tests/test_views.py b/passbook/flows/tests/test_views.py
index e6a2ad20c..cacbe2004 100644
--- a/passbook/flows/tests/test_views.py
+++ b/passbook/flows/tests/test_views.py
@@ -1,5 +1,5 @@
"""flow views tests"""
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, PropertyMock, patch
from django.shortcuts import reverse
from django.test import Client, TestCase
@@ -12,7 +12,7 @@ from passbook.lib.config import CONFIG
from passbook.policies.types import PolicyResult
from passbook.stages.dummy.models import DummyStage
-POLICY_RESULT_MOCK = MagicMock(return_value=PolicyResult(False))
+POLICY_RESULT_MOCK = PropertyMock(return_value=PolicyResult(False))
class TestFlowExecutor(TestCase):
@@ -45,8 +45,7 @@ class TestFlowExecutor(TestCase):
self.assertEqual(cancel_mock.call_count, 1)
@patch(
- "passbook.flows.planner.FlowPlanner._check_flow_root_policies",
- POLICY_RESULT_MOCK,
+ "passbook.policies.engine.PolicyEngine.result", POLICY_RESULT_MOCK,
)
def test_invalid_non_applicable_flow(self):
"""Tests that a non-applicable flow returns the correct error message"""
diff --git a/passbook/flows/views.py b/passbook/flows/views.py
index f955106db..5c8c71f11 100644
--- a/passbook/flows/views.py
+++ b/passbook/flows/views.py
@@ -1,7 +1,7 @@
"""passbook multi-stage authentication engine"""
from typing import Any, Dict, Optional
-from django.http import HttpRequest, HttpResponse
+from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, redirect, reverse
from django.utils.decorators import method_decorator
from django.views.decorators.clickjacking import xframe_options_sameorigin
@@ -164,7 +164,9 @@ class ToDefaultFlow(View):
designation: Optional[FlowDesignation] = None
def dispatch(self, request: HttpRequest) -> HttpResponse:
- flow = get_object_or_404(Flow, designation=self.designation)
+ flow = Flow.with_policy(request, designation=self.designation)
+ if not flow:
+ raise Http404
# If user already has a pending plan, clear it so we don't have to later.
if SESSION_KEY_PLAN in self.request.session:
plan: FlowPlan = self.request.session[SESSION_KEY_PLAN]
diff --git a/passbook/policies/apps.py b/passbook/policies/apps.py
index 5795355b6..946f84609 100644
--- a/passbook/policies/apps.py
+++ b/passbook/policies/apps.py
@@ -1,4 +1,6 @@
"""passbook policies app config"""
+from importlib import import_module
+
from django.apps import AppConfig
@@ -8,3 +10,7 @@ class PassbookPoliciesConfig(AppConfig):
name = "passbook.policies"
label = "passbook_policies"
verbose_name = "passbook Policies"
+
+ def ready(self):
+ """Load source_types from config file"""
+ import_module("passbook.policies.signals")
diff --git a/passbook/policies/engine.py b/passbook/policies/engine.py
index 143ad6473..5db4f8cfc 100644
--- a/passbook/policies/engine.py
+++ b/passbook/policies/engine.py
@@ -73,16 +73,20 @@ class PolicyEngine:
"""Build task group"""
for binding in self._iter_bindings():
self._check_policy_type(binding.policy)
- policy = binding.policy
- cached_policy = cache.get(cache_key(binding, self.request.user), None)
+ key = cache_key(binding, self.request)
+ cached_policy = cache.get(key, None)
if cached_policy and self.use_cache:
- LOGGER.debug("P_ENG: Taking result from cache", policy=policy)
+ LOGGER.debug(
+ "P_ENG: Taking result from cache",
+ policy=binding.policy,
+ cache_key=key,
+ )
self.__cached_policies.append(cached_policy)
continue
- LOGGER.debug("P_ENG: Evaluating policy", policy=policy)
+ LOGGER.debug("P_ENG: Evaluating policy", policy=binding.policy)
our_end, task_end = Pipe(False)
task = PolicyProcess(binding, self.request, task_end)
- LOGGER.debug("P_ENG: Starting Process", policy=policy)
+ LOGGER.debug("P_ENG: Starting Process", policy=binding.policy)
task.start()
self.__processes.append(
PolicyProcessInfo(process=task, connection=our_end, binding=binding)
@@ -103,7 +107,9 @@ class PolicyEngine:
x.result for x in self.__processes if x.result
]
for result in process_results + self.__cached_policies:
- LOGGER.debug("P_ENG: result", passing=result.passing)
+ LOGGER.debug(
+ "P_ENG: result", passing=result.passing, messages=result.messages
+ )
if result.messages:
messages += result.messages
if not result.passing:
diff --git a/passbook/policies/process.py b/passbook/policies/process.py
index 1fb906c9f..a187627a6 100644
--- a/passbook/policies/process.py
+++ b/passbook/policies/process.py
@@ -6,7 +6,6 @@ from typing import Optional
from django.core.cache import cache
from structlog import get_logger
-from passbook.core.models import User
from passbook.policies.exceptions import PolicyException
from passbook.policies.models import PolicyBinding
from passbook.policies.types import PolicyRequest, PolicyResult
@@ -14,11 +13,13 @@ from passbook.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger()
-def cache_key(binding: PolicyBinding, user: Optional[User] = None) -> str:
+def cache_key(binding: PolicyBinding, request: PolicyRequest) -> str:
"""Generate Cache key for policy"""
prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}"
- if user:
- prefix += f"#{user.pk}"
+ if request.http_request:
+ prefix += f"_{request.http_request.session.session_key}"
+ if request.user:
+ prefix += f"#{request.user.pk}"
return prefix
@@ -65,7 +66,7 @@ class PolicyProcess(Process):
passing=policy_result.passing,
user=self.request.user,
)
- key = cache_key(self.binding, self.request.user)
+ key = cache_key(self.binding, self.request)
cache.set(key, policy_result)
LOGGER.debug("P_ENG(proc): Cached policy evaluation", key=key)
return policy_result
diff --git a/passbook/policies/signals.py b/passbook/policies/signals.py
new file mode 100644
index 000000000..82e0b3d94
--- /dev/null
+++ b/passbook/policies/signals.py
@@ -0,0 +1,26 @@
+"""passbook policy signals"""
+from django.core.cache import cache
+from django.db.models.signals import post_save
+from django.dispatch import receiver
+from structlog import get_logger
+
+LOGGER = get_logger()
+
+
+@receiver(post_save)
+# pylint: disable=unused-argument
+def invalidate_policy_cache(sender, instance, **_):
+ """Invalidate Policy cache when policy is updated"""
+ from passbook.policies.models import Policy, PolicyBinding
+
+ if isinstance(instance, Policy):
+ LOGGER.debug("Invalidating policy cache", policy=instance)
+ total = 0
+ for binding in PolicyBinding.objects.filter(policy=instance):
+ prefix = (
+ f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}*"
+ )
+ keys = cache.keys(prefix)
+ total += len(keys)
+ cache.delete_many(keys)
+ LOGGER.debug("Deleted keys", len=total)
diff --git a/passbook/providers/saml/migrations/0002_default_saml_property_mappings.py b/passbook/providers/saml/migrations/0002_default_saml_property_mappings.py
new file mode 100644
index 000000000..72575b6d6
--- /dev/null
+++ b/passbook/providers/saml/migrations/0002_default_saml_property_mappings.py
@@ -0,0 +1,63 @@
+# Generated by Django 3.0.6 on 2020-05-23 19:32
+
+from django.db import migrations
+
+
+def create_default_property_mappings(apps, schema_editor):
+ """Create default SAML Property Mappings"""
+ SAMLPropertyMapping = apps.get_model(
+ "passbook_providers_saml", "SAMLPropertyMapping"
+ )
+ db_alias = schema_editor.connection.alias
+ defaults = [
+ {
+ "FriendlyName": "eduPersonPrincipalName",
+ "Name": "urn:oid:1.3.6.1.4.1.5923.1.1.1.6",
+ "Expression": "{{ user.email }}",
+ },
+ {
+ "FriendlyName": "cn",
+ "Name": "urn:oid:2.5.4.3",
+ "Expression": "{{ user.name }}",
+ },
+ {
+ "FriendlyName": "mail",
+ "Name": "urn:oid:0.9.2342.19200300.100.1.3",
+ "Expression": "{{ user.email }}",
+ },
+ {
+ "FriendlyName": "displayName",
+ "Name": "urn:oid:2.16.840.1.113730.3.1.241",
+ "Expression": "{{ user.username }}",
+ },
+ {
+ "FriendlyName": "uid",
+ "Name": "urn:oid:0.9.2342.19200300.100.1.1",
+ "Expression": "{{ user.pk }}",
+ },
+ {
+ "FriendlyName": "member-of",
+ "Name": "member-of",
+ "Expression": "[{% for group in user.groups.all() %}'{{ group.name }}',{% endfor %}]",
+ },
+ ]
+ for default in defaults:
+ SAMLPropertyMapping.objects.using(db_alias).get_or_create(
+ saml_name=default["Name"],
+ friendly_name=default["FriendlyName"],
+ expression=default["Expression"],
+ defaults={
+ "name": f"Autogenerated SAML Mapping: {default['FriendlyName']} -> {default['Expression']}"
+ },
+ )
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("passbook_providers_saml", "0001_initial"),
+ ]
+
+ operations = [
+ migrations.RunPython(create_default_property_mappings),
+ ]
diff --git a/passbook/sources/ldap/api.py b/passbook/sources/ldap/api.py
index a51a5ce12..e5ad2677c 100644
--- a/passbook/sources/ldap/api.py
+++ b/passbook/sources/ldap/api.py
@@ -23,6 +23,7 @@ class LDAPSourceSerializer(ModelSerializer):
"group_object_filter",
"user_group_membership_field",
"object_uniqueness_field",
+ "sync_users",
"sync_groups",
"sync_parent_group",
"property_mappings",
diff --git a/passbook/sources/ldap/connector.py b/passbook/sources/ldap/connector.py
index 064a6a628..748c25e9e 100644
--- a/passbook/sources/ldap/connector.py
+++ b/passbook/sources/ldap/connector.py
@@ -16,26 +16,10 @@ LOGGER = get_logger()
class Connector:
"""Wrapper for ldap3 to easily manage user authentication and creation"""
- _server: ldap3.Server
- _connection = ldap3.Connection
_source: LDAPSource
def __init__(self, source: LDAPSource):
self._source = source
- self._server = ldap3.Server(source.server_uri) # Implement URI parsing
-
- def bind(self):
- """Bind using Source's Credentials"""
- self._connection = ldap3.Connection(
- self._server,
- raise_exceptions=True,
- user=self._source.bind_cn,
- password=self._source.bind_password,
- )
-
- self._connection.bind()
- if self._source.start_tls:
- self._connection.start_tls()
@staticmethod
def encode_pass(password: str) -> bytes:
@@ -45,19 +29,23 @@ class Connector:
@property
def base_dn_users(self) -> str:
"""Shortcut to get full base_dn for user lookups"""
- return ",".join([self._source.additional_user_dn, self._source.base_dn])
+ if self._source.additional_user_dn:
+ return f"{self._source.additional_user_dn},{self._source.base_dn}"
+ return self._source.base_dn
@property
def base_dn_groups(self) -> str:
"""Shortcut to get full base_dn for group lookups"""
- return ",".join([self._source.additional_group_dn, self._source.base_dn])
+ if self._source.additional_group_dn:
+ return f"{self._source.additional_group_dn},{self._source.base_dn}"
+ return self._source.base_dn
def sync_groups(self):
"""Iterate over all LDAP Groups and create passbook_core.Group instances"""
if not self._source.sync_groups:
- LOGGER.debug("Group syncing is disabled for this Source")
+ LOGGER.warning("Group syncing is disabled for this Source")
return
- groups = self._connection.extend.standard.paged_search(
+ groups = self._source.connection.extend.standard.paged_search(
search_base=self.base_dn_groups,
search_filter=self._source.group_object_filter,
search_scope=ldap3.SUBTREE,
@@ -87,7 +75,10 @@ class Connector:
def sync_users(self):
"""Iterate over all LDAP Users and create passbook_core.User instances"""
- users = self._connection.extend.standard.paged_search(
+ if not self._source.sync_users:
+ LOGGER.warning("User syncing is disabled for this Source")
+ return
+ users = self._source.connection.extend.standard.paged_search(
search_base=self.base_dn_users,
search_filter=self._source.user_object_filter,
search_scope=ldap3.SUBTREE,
@@ -101,9 +92,9 @@ class Connector:
LOGGER.warning("Cannot find uniqueness Field in attributes")
continue
try:
+ defaults = self._build_object_properties(attributes)
user, created = User.objects.update_or_create(
- attributes__ldap_uniq=uniq,
- defaults=self._build_object_properties(attributes),
+ attributes__ldap_uniq=uniq, defaults=defaults,
)
except IntegrityError as exc:
LOGGER.warning("Failed to create user", exc=exc)
@@ -123,7 +114,7 @@ class Connector:
def sync_membership(self):
"""Iterate over all Users and assign Groups using memberOf Field"""
- users = self._connection.extend.standard.paged_search(
+ users = self._source.connection.extend.standard.paged_search(
search_base=self.base_dn_users,
search_filter=self._source.user_object_filter,
search_scope=ldap3.SUBTREE,
@@ -220,7 +211,7 @@ class Connector:
LOGGER.debug("Attempting Binding as user", user=user)
try:
temp_connection = ldap3.Connection(
- self._server,
+ self._source.connection.server,
user=user.attributes.get("distinguishedName"),
password=password,
raise_exceptions=True,
diff --git a/passbook/sources/ldap/forms.py b/passbook/sources/ldap/forms.py
index 249ebd5af..48d71d48a 100644
--- a/passbook/sources/ldap/forms.py
+++ b/passbook/sources/ldap/forms.py
@@ -26,6 +26,7 @@ class LDAPSourceForm(forms.ModelForm):
"group_object_filter",
"user_group_membership_field",
"object_uniqueness_field",
+ "sync_users",
"sync_groups",
"sync_parent_group",
"property_mappings",
diff --git a/passbook/sources/ldap/migrations/0002_ldapsource_sync_users.py b/passbook/sources/ldap/migrations/0002_ldapsource_sync_users.py
new file mode 100644
index 000000000..27a0da2b3
--- /dev/null
+++ b/passbook/sources/ldap/migrations/0002_ldapsource_sync_users.py
@@ -0,0 +1,18 @@
+# Generated by Django 3.0.6 on 2020-05-23 19:17
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("passbook_sources_ldap", "0001_initial"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="ldapsource",
+ name="sync_users",
+ field=models.BooleanField(default=True),
+ ),
+ ]
diff --git a/passbook/sources/ldap/migrations/0003_default_ldap_property_mappings.py b/passbook/sources/ldap/migrations/0003_default_ldap_property_mappings.py
new file mode 100644
index 000000000..318952211
--- /dev/null
+++ b/passbook/sources/ldap/migrations/0003_default_ldap_property_mappings.py
@@ -0,0 +1,35 @@
+# Generated by Django 3.0.6 on 2020-05-23 19:30
+
+from django.apps.registry import Apps
+from django.db import migrations
+
+
+def create_default_ad_property_mappings(apps: Apps, schema_editor):
+ LDAPPropertyMapping = apps.get_model("passbook_sources_ldap", "LDAPPropertyMapping")
+ mapping = {
+ "name": "{{ ldap.name }}",
+ "first_name": "{{ ldap.givenName }}",
+ "last_name": "{{ ldap.sn }}",
+ "username": "{{ ldap.sAMAccountName }}",
+ "email": "{{ ldap.mail }}",
+ }
+ db_alias = schema_editor.connection.alias
+ for object_field, expression in mapping.items():
+ LDAPPropertyMapping.objects.using(db_alias).get_or_create(
+ expression=expression,
+ object_field=object_field,
+ defaults={
+ "name": f"Autogenerated LDAP Mapping: {expression} -> {object_field}"
+ },
+ )
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("passbook_sources_ldap", "0002_ldapsource_sync_users"),
+ ]
+
+ operations = [
+ migrations.RunPython(create_default_ad_property_mappings),
+ ]
diff --git a/passbook/sources/ldap/models.py b/passbook/sources/ldap/models.py
index 393cccfa2..34fa96e56 100644
--- a/passbook/sources/ldap/models.py
+++ b/passbook/sources/ldap/models.py
@@ -1,8 +1,10 @@
"""passbook LDAP Models"""
+from typing import Optional
from django.core.validators import URLValidator
from django.db import models
from django.utils.translation import gettext_lazy as _
+from ldap3 import Connection, Server
from passbook.core.models import Group, PropertyMapping, Source
@@ -22,10 +24,12 @@ class LDAPSource(Source):
additional_user_dn = models.TextField(
help_text=_("Prepended to Base DN for User-queries."),
verbose_name=_("Addition User DN"),
+ blank=True,
)
additional_group_dn = models.TextField(
help_text=_("Prepended to Base DN for Group-queries."),
verbose_name=_("Addition Group DN"),
+ blank=True,
)
user_object_filter = models.TextField(
@@ -43,6 +47,7 @@ class LDAPSource(Source):
default="objectSid", help_text=_("Field which contains a unique Identifier.")
)
+ sync_users = models.BooleanField(default=True)
sync_groups = models.BooleanField(default=True)
sync_parent_group = models.ForeignKey(
Group, blank=True, null=True, default=None, on_delete=models.SET_DEFAULT
@@ -50,6 +55,25 @@ class LDAPSource(Source):
form = "passbook.sources.ldap.forms.LDAPSourceForm"
+ _connection: Optional[Connection]
+
+ @property
+ def connection(self) -> Connection:
+ """Get a fully connected and bound LDAP Connection"""
+ if not self._connection:
+ server = Server(self.server_uri)
+ self._connection = Connection(
+ server,
+ raise_exceptions=True,
+ user=self.bind_cn,
+ password=self.bind_password,
+ )
+
+ self._connection.bind()
+ if self.start_tls:
+ self._connection.start_tls()
+ return self._connection
+
class Meta:
verbose_name = _("LDAP Source")
diff --git a/passbook/sources/ldap/tasks.py b/passbook/sources/ldap/tasks.py
index 581d27c7a..eeb1cb282 100644
--- a/passbook/sources/ldap/tasks.py
+++ b/passbook/sources/ldap/tasks.py
@@ -9,7 +9,6 @@ def sync_groups(source_pk: int):
"""Sync LDAP Groups on background worker"""
source = LDAPSource.objects.get(pk=source_pk)
connector = Connector(source)
- connector.bind()
connector.sync_groups()
@@ -18,7 +17,6 @@ def sync_users(source_pk: int):
"""Sync LDAP Users on background worker"""
source = LDAPSource.objects.get(pk=source_pk)
connector = Connector(source)
- connector.bind()
connector.sync_users()
@@ -27,7 +25,6 @@ def sync():
"""Sync all sources"""
for source in LDAPSource.objects.filter(enabled=True):
connector = Connector(source)
- connector.bind()
connector.sync_users()
connector.sync_groups()
connector.sync_membership()
diff --git a/passbook/sources/ldap/tests.py b/passbook/sources/ldap/tests.py
new file mode 100644
index 000000000..faa3f4177
--- /dev/null
+++ b/passbook/sources/ldap/tests.py
@@ -0,0 +1,75 @@
+"""LDAP Source tests"""
+from unittest.mock import PropertyMock, patch
+
+from django.test import TestCase
+from ldap3 import MOCK_SYNC, OFFLINE_AD_2012_R2, Connection, Server
+
+from passbook.core.models import User
+from passbook.sources.ldap.connector import Connector
+from passbook.sources.ldap.models import LDAPPropertyMapping, LDAPSource
+
+
+def _build_mock_connection() -> Connection:
+ """Create mock connection"""
+ server = Server("my_fake_server", get_info=OFFLINE_AD_2012_R2)
+ _pass = "foo" # noqa # nosec
+ connection = Connection(
+ server,
+ user="cn=my_user,ou=test,o=lab",
+ password=_pass,
+ client_strategy=MOCK_SYNC,
+ )
+ connection.strategy.add_entry(
+ "cn=user0,ou=test,o=lab",
+ {
+ "userPassword": "test0000",
+ "sAMAccountName": "user0_sn",
+ "revision": 0,
+ "objectSid": "unique-test0000",
+ "objectCategory": "Person",
+ },
+ )
+ connection.strategy.add_entry(
+ "cn=user1,ou=test,o=lab",
+ {
+ "userPassword": "test1111",
+ "sAMAccountName": "user1_sn",
+ "revision": 0,
+ "objectSid": "unique-test1111",
+ "objectCategory": "Person",
+ },
+ )
+ connection.strategy.add_entry(
+ "cn=user2,ou=test,o=lab",
+ {
+ "userPassword": "test2222",
+ "sAMAccountName": "user2_sn",
+ "revision": 0,
+ "objectSid": "unique-test2222",
+ "objectCategory": "Person",
+ },
+ )
+ connection.bind()
+ return connection
+
+
+LDAP_CONNECTION_PATCH = PropertyMock(return_value=_build_mock_connection())
+
+
+class LDAPSourceTests(TestCase):
+ """LDAP Source tests"""
+
+ def setUp(self):
+ self.source = LDAPSource.objects.create(
+ name="ldap", slug="ldap", base_dn="o=lab"
+ )
+ self.source.property_mappings.set(LDAPPropertyMapping.objects.all())
+ self.source.save()
+
+ @patch("passbook.sources.ldap.models.LDAPSource.connection", LDAP_CONNECTION_PATCH)
+ def test_sync_users(self):
+ """Test user sync"""
+ connector = Connector(self.source)
+ connector.sync_users()
+ user = User.objects.filter(username="user2_sn")
+ self.assertTrue(user.exists())
diff --git a/passbook/sources/oauth/clients.py b/passbook/sources/oauth/clients.py
index 06a9310fc..35c58d7ba 100644
--- a/passbook/sources/oauth/clients.py
+++ b/passbook/sources/oauth/clients.py
@@ -1,6 +1,6 @@
"""OAuth Clients"""
import json
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import parse_qs, urlencode
from django.http import HttpRequest
@@ -14,24 +14,29 @@ from structlog import get_logger
from passbook import __version__
LOGGER = get_logger()
+if TYPE_CHECKING:
+ from passbook.sources.oauth.models import OAuthSource
class BaseOAuthClient:
"""Base OAuth Client"""
session: Session
+ source: "OAuthSource"
- def __init__(self, source, token=""): # nosec
+ def __init__(self, source: "OAuthSource", token=""): # nosec
self.source = source
self.token = token
self.session = Session()
self.session.headers.update({"User-Agent": "passbook %s" % __version__})
- def get_access_token(self, request, callback=None):
+ def get_access_token(
+ self, request: HttpRequest, callback=None
+ ) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
- def get_profile_info(self, token: Dict[str, str]):
+ def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
"Fetch user profile information."
try:
headers = {
@@ -45,7 +50,7 @@ class BaseOAuthClient:
LOGGER.warning("Unable to fetch user profile", exc=exc)
return None
else:
- return response.json() or response.text
+ return response.json()
def get_redirect_args(self, request, callback) -> Dict[str, str]:
"Get request parameters for redirect url."
diff --git a/passbook/sources/oauth/views/core.py b/passbook/sources/oauth/views/core.py
index 7e3249bc7..9166ad3b6 100644
--- a/passbook/sources/oauth/views/core.py
+++ b/passbook/sources/oauth/views/core.py
@@ -21,7 +21,7 @@ from passbook.flows.planner import (
)
from passbook.flows.views import SESSION_KEY_PLAN
from passbook.lib.utils.urls import redirect_with_qs
-from passbook.sources.oauth.clients import get_client
+from passbook.sources.oauth.clients import BaseOAuthClient, get_client
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
@@ -34,7 +34,7 @@ class OAuthClientMixin:
client_class: Optional[Callable] = None
- def get_client(self, source):
+ def get_client(self, source: OAuthSource) -> BaseOAuthClient:
"Get instance of the OAuth client for this source."
if self.client_class is not None:
# pylint: disable=not-callable
diff --git a/passbook/stages/identification/forms.py b/passbook/stages/identification/forms.py
index 04217f7f8..882ce0f03 100644
--- a/passbook/stages/identification/forms.py
+++ b/passbook/stages/identification/forms.py
@@ -16,7 +16,7 @@ class IdentificationStageForm(forms.ModelForm):
class Meta:
model = IdentificationStage
- fields = ["name", "user_fields", "template"]
+ fields = ["name", "user_fields", "template", "enrollment_flow", "recovery_flow"]
widgets = {
"name": forms.TextInput(),
}
diff --git a/passbook/stages/user_write/tests.py b/passbook/stages/user_write/tests.py
index 5bad06809..d37012207 100644
--- a/passbook/stages/user_write/tests.py
+++ b/passbook/stages/user_write/tests.py
@@ -72,6 +72,7 @@ class TestUserWriteStage(TestCase):
plan.context[PLAN_CONTEXT_PROMPT] = {
"username": "test-user-new",
"password": new_password,
+ "some-custom-attribute": "test",
}
session = self.client.session
session[SESSION_KEY_PLAN] = plan
@@ -88,6 +89,7 @@ class TestUserWriteStage(TestCase):
)
self.assertTrue(user_qs.exists())
self.assertTrue(user_qs.first().check_password(new_password))
+ self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test")
def test_without_data(self):
"""Test without data results in error"""
diff --git a/swagger.yaml b/swagger.yaml
index ac38448f0..31c87f5be 100755
--- a/swagger.yaml
+++ b/swagger.yaml
@@ -5606,8 +5606,6 @@ definitions:
- bind_cn
- bind_password
- base_dn
- - additional_user_dn
- - additional_group_dn
type: object
properties:
pk:
@@ -5654,12 +5652,10 @@ definitions:
title: Addition User DN
description: Prepended to Base DN for User-queries.
type: string
- minLength: 1
additional_group_dn:
title: Addition Group DN
description: Prepended to Base DN for Group-queries.
type: string
- minLength: 1
user_object_filter:
title: User object filter
description: Consider Objects matching this filter to be Users.
@@ -5680,6 +5676,9 @@ definitions:
description: Field which contains a unique Identifier.
type: string
minLength: 1
+ sync_users:
+ title: Sync users
+ type: boolean
sync_groups:
title: Sync groups
type: boolean