From 77ed25ae34850012543d4b12f27226530c213e30 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Tue, 3 Aug 2021 17:45:16 +0200 Subject: [PATCH] root: reformat to 100 line width Signed-off-by: Jens Langhammer --- authentik/admin/api/metrics.py | 7 +- authentik/admin/api/system.py | 4 +- authentik/admin/api/tasks.py | 5 +- authentik/admin/api/version.py | 4 +- authentik/admin/api/workers.py | 4 +- authentik/admin/tasks.py | 8 +-- authentik/admin/tests/test_api.py | 8 +-- authentik/api/decorators.py | 4 +- authentik/api/schema.py | 4 +- authentik/api/tests/test_auth.py | 8 +-- authentik/api/v2/urls.py | 26 ++------ authentik/core/api/applications.py | 8 +-- authentik/core/api/authenticated_sessions.py | 4 +- authentik/core/api/propertymappings.py | 10 +-- authentik/core/api/sources.py | 4 +- authentik/core/api/tokens.py | 8 +-- authentik/core/api/used_by.py | 4 +- authentik/core/api/users.py | 17 ++--- authentik/core/api/utils.py | 14 +--- authentik/core/migrations/0001_initial.py | 32 +++------ .../core/migrations/0003_default_user.py | 4 +- .../migrations/0006_auto_20200709_1608.py | 4 +- .../migrations/0007_auto_20200815_1841.py | 4 +- .../migrations/0009_group_is_superuser.py | 4 +- .../migrations/0014_auto_20201018_1158.py | 4 +- .../core/migrations/0015_application_icon.py | 4 +- .../migrations/0016_auto_20201202_2234.py | 4 +- .../migrations/0022_authenticatedsession.py | 8 +-- .../0025_alter_application_meta_icon.py | 4 +- authentik/core/models.py | 32 +++------ authentik/core/signals.py | 12 +--- authentik/core/sources/flow_manager.py | 30 ++------- authentik/core/tasks.py | 4 +- authentik/core/tests/test_applications_api.py | 12 +--- authentik/core/tests/test_impersonation.py | 4 +- authentik/core/tests/test_models.py | 4 +- authentik/core/tests/test_property_mapping.py | 4 +- .../core/tests/test_property_mapping_api.py | 4 +- authentik/core/tests/test_token_api.py | 7 +- authentik/core/views/impersonate.py | 9 +-- authentik/core/views/session.py | 4 +- authentik/crypto/api.py | 27 ++------ authentik/crypto/builder.py | 12 +--- authentik/crypto/models.py | 18 ++--- authentik/events/api/event.py | 4 +- authentik/events/middleware.py | 16 ++--- .../migrations/0003_auto_20200917_1155.py | 8 +-- .../0011_notification_rules_default_v1.py | 36 +++------- authentik/events/migrations/0014_expiry.py | 8 +-- .../events/migrations/0016_add_tenant.py | 4 +- authentik/events/models.py | 17 ++--- authentik/events/monitored_tasks.py | 11 +--- authentik/events/signals.py | 18 ++--- authentik/events/tasks.py | 28 ++------ authentik/events/tests/test_event.py | 4 +- authentik/events/tests/test_notifications.py | 12 +--- authentik/flows/api/flows.py | 29 ++------- .../flows/management/commands/benchmark.py | 4 +- .../flows/migrations/0008_default_flows.py | 20 ++---- .../flows/migrations/0009_source_flows.py | 16 ++--- .../flows/migrations/0010_provider_flows.py | 4 +- authentik/flows/migrations/0018_oob_flows.py | 12 +--- authentik/flows/models.py | 8 +-- authentik/flows/planner.py | 14 ++-- authentik/flows/signals.py | 4 +- authentik/flows/stage.py | 9 +-- authentik/flows/tests/test_planner.py | 12 +--- authentik/flows/tests/test_transfer.py | 4 +- authentik/flows/tests/test_views.py | 63 +++++------------- authentik/flows/tests/test_views_helper.py | 8 +-- authentik/flows/transfer/common.py | 4 +- authentik/flows/transfer/exporter.py | 18 ++--- authentik/flows/transfer/importer.py | 14 +--- authentik/flows/views.py | 39 +++-------- authentik/lib/config.py | 4 +- authentik/lib/models.py | 8 +-- authentik/lib/sentry.py | 6 +- authentik/lib/tests/test_evaluator.py | 4 +- authentik/lib/tests/test_http.py | 13 +--- authentik/lib/tests/test_sentry.py | 4 +- authentik/lib/utils/http.py | 11 +--- authentik/managed/tasks.py | 4 +- authentik/outposts/api/outposts.py | 7 +- authentik/outposts/api/service_connections.py | 10 +-- authentik/outposts/channels.py | 8 +-- authentik/outposts/controllers/docker.py | 22 ++----- .../outposts/controllers/k8s/deployment.py | 9 +-- authentik/outposts/controllers/k8s/secret.py | 13 +--- authentik/outposts/controllers/k8s/service.py | 9 +-- authentik/outposts/controllers/kubernetes.py | 4 +- .../migrations/0002_auto_20200826_1306.py | 4 +- .../0009_fix_missing_token_identifier.py | 4 +- .../migrations/0010_service_connection.py | 8 +-- .../migrations/0013_auto_20201203_2009.py | 4 +- .../migrations/0016_alter_outpost_type.py | 4 +- authentik/outposts/models.py | 15 ++--- authentik/outposts/signals.py | 10 +-- authentik/outposts/tasks.py | 16 ++--- authentik/outposts/tests/test_api.py | 4 +- authentik/policies/api/bindings.py | 6 +- authentik/policies/api/policies.py | 14 +--- authentik/policies/denied.py | 11 +--- authentik/policies/dummy/tests.py | 4 +- authentik/policies/engine.py | 27 ++------ authentik/policies/event_matcher/tests.py | 16 ++--- authentik/policies/expiry/models.py | 5 +- .../migrations/0002_auto_20200926_1156.py | 4 +- .../migrations/0003_auto_20201203_1223.py | 8 +-- authentik/policies/hibp/models.py | 8 +-- authentik/policies/models.py | 8 +-- authentik/policies/password/models.py | 8 +-- authentik/policies/reputation/api.py | 6 +- authentik/policies/reputation/signals.py | 5 +- authentik/policies/reputation/tasks.py | 13 +--- authentik/policies/reputation/tests.py | 12 +--- authentik/policies/signals.py | 4 +- authentik/policies/tests/test_bindings_api.py | 6 +- authentik/policies/tests/test_engine.py | 47 ++++---------- authentik/policies/tests/test_policies_api.py | 4 +- authentik/policies/tests/test_process.py | 16 ++--- authentik/policies/views.py | 4 +- authentik/providers/oauth2/api/provider.py | 9 +-- authentik/providers/oauth2/errors.py | 19 ++---- authentik/providers/oauth2/generators.py | 4 +- .../oauth2/migrations/0001_initial.py | 40 +++--------- .../oauth2/migrations/0011_managed.py | 4 +- authentik/providers/oauth2/models.py | 35 +++------- authentik/providers/oauth2/tests/test_api.py | 4 +- .../providers/oauth2/tests/test_authorize.py | 23 ++----- authentik/providers/oauth2/tests/test_jwks.py | 8 +-- .../providers/oauth2/tests/test_token.py | 55 ++++------------ .../providers/oauth2/tests/test_userinfo.py | 16 +---- authentik/providers/oauth2/tests/utils.py | 6 +- authentik/providers/oauth2/urls_github.py | 13 +--- authentik/providers/oauth2/utils.py | 4 +- authentik/providers/oauth2/views/authorize.py | 39 +++-------- .../providers/oauth2/views/introspection.py | 13 +--- authentik/providers/oauth2/views/jwks.py | 4 +- authentik/providers/oauth2/views/provider.py | 8 +-- authentik/providers/oauth2/views/token.py | 39 +++-------- authentik/providers/oauth2/views/userinfo.py | 20 ++---- authentik/providers/proxy/api.py | 4 +- .../proxy/controllers/k8s/ingress.py | 25 ++----- .../proxy/controllers/k8s/traefik.py | 5 +- .../providers/proxy/controllers/kubernetes.py | 4 +- .../proxy/migrations/0001_initial.py | 12 +--- .../0002_proxyprovider_cookie_secret.py | 4 +- .../migrations/0004_auto_20200913_1947.py | 12 +--- .../0011_proxyprovider_forward_auth_mode.py | 6 +- authentik/providers/proxy/models.py | 14 +--- authentik/providers/saml/api.py | 14 +--- .../providers/saml/migrations/0001_initial.py | 12 +--- .../migrations/0008_auto_20201112_1036.py | 4 +- .../providers/saml/migrations/0012_managed.py | 8 +-- authentik/providers/saml/models.py | 13 +--- .../providers/saml/processors/assertion.py | 20 ++---- .../providers/saml/processors/metadata.py | 12 +--- .../saml/processors/metadata_parser.py | 26 ++------ .../saml/processors/request_parser.py | 8 +-- .../saml/tests/test_auth_n_request.py | 17 ++--- .../providers/saml/tests/test_metadata.py | 16 ++--- authentik/providers/saml/tests/test_schema.py | 8 +-- authentik/providers/saml/views/flows.py | 12 +--- authentik/providers/saml/views/sso.py | 26 ++------ .../commands/create_recovery_key.py | 9 +-- authentik/recovery/tests.py | 8 +-- authentik/root/asgi.py | 8 +-- authentik/root/celery.py | 4 +- authentik/root/middleware.py | 8 +-- authentik/root/settings.py | 16 ++--- authentik/root/tests.py | 4 +- authentik/root/websocket.py | 4 +- authentik/sources/ldap/auth.py | 12 +--- .../sources/ldap/migrations/0001_initial.py | 6 +- .../migrations/0005_auto_20200913_1947.py | 6 +- .../sources/ldap/migrations/0008_managed.py | 4 +- ...0011_ldapsource_property_mappings_group.py | 4 +- authentik/sources/ldap/password.py | 8 +-- authentik/sources/ldap/signals.py | 4 +- authentik/sources/ldap/sync/base.py | 16 ++--- authentik/sources/ldap/sync/groups.py | 4 +- authentik/sources/ldap/sync/membership.py | 12 +--- authentik/sources/ldap/sync/users.py | 4 +- authentik/sources/ldap/tests/test_auth.py | 8 +-- authentik/sources/ldap/tests/test_password.py | 8 +-- authentik/sources/ldap/tests/test_sync.py | 8 +-- authentik/sources/oauth/api/source.py | 4 +- authentik/sources/oauth/clients/base.py | 4 +- .../sources/oauth/migrations/0001_initial.py | 4 +- authentik/sources/oauth/types/azure_ad.py | 4 +- authentik/sources/oauth/types/twitter.py | 3 +- authentik/sources/oauth/views/callback.py | 4 +- .../migrations/0002_auto_20210505_1717.py | 4 +- authentik/sources/plex/plex.py | 8 +-- authentik/sources/plex/tasks.py | 4 +- authentik/sources/plex/tests.py | 8 +-- ...0010_samlsource_pre_authentication_flow.py | 4 +- .../sources/saml/processors/constants.py | 8 +-- authentik/sources/saml/processors/metadata.py | 12 +--- authentik/sources/saml/processors/request.py | 18 ++--- authentik/sources/saml/processors/response.py | 13 +--- authentik/sources/saml/tasks.py | 16 ++--- authentik/sources/saml/tests/test_metadata.py | 16 ++--- authentik/sources/saml/views.py | 9 +-- .../migrations/0001_initial.py | 4 +- authentik/stages/authenticator_duo/models.py | 4 +- authentik/stages/authenticator_duo/stage.py | 12 +--- .../migrations/0005_default_setup_flow.py | 4 +- .../stages/authenticator_static/models.py | 8 +-- .../stages/authenticator_static/signals.py | 4 +- .../stages/authenticator_static/stage.py | 10 +-- .../migrations/0006_default_setup_flow.py | 4 +- authentik/stages/authenticator_totp/models.py | 4 +- .../migrations/0004_auto_20210301_0949.py | 4 +- .../stages/authenticator_validate/models.py | 8 +-- .../stages/authenticator_validate/stage.py | 27 ++------ .../stages/authenticator_validate/tests.py | 13 +--- .../stages/authenticator_webauthn/api.py | 5 +- .../migrations/0002_default_setup_flow.py | 4 +- .../0003_webauthndevice_confirmed.py | 4 +- .../stages/authenticator_webauthn/models.py | 8 +-- .../stages/authenticator_webauthn/stage.py | 26 ++------ authentik/stages/captcha/models.py | 8 +-- authentik/stages/captcha/stage.py | 4 +- authentik/stages/captcha/tests.py | 16 ++--- .../migrations/0002_auto_20200720_0941.py | 4 +- authentik/stages/consent/models.py | 9 +-- authentik/stages/consent/stage.py | 8 +-- authentik/stages/consent/tests.py | 24 ++----- authentik/stages/deny/tests.py | 8 +-- authentik/stages/dummy/tests.py | 4 +- authentik/stages/email/apps.py | 4 +- authentik/stages/email/models.py | 4 +- authentik/stages/email/stage.py | 4 +- authentik/stages/email/tasks.py | 4 +- authentik/stages/email/tests/test_sending.py | 32 +++------ authentik/stages/email/tests/test_stage.py | 40 +++--------- .../0009_identificationstage_sources.py | 4 +- authentik/stages/identification/models.py | 12 +--- authentik/stages/identification/stage.py | 21 ++---- authentik/stages/identification/tests.py | 36 +++------- authentik/stages/invitation/stage.py | 8 +-- authentik/stages/invitation/tests.py | 32 +++------ authentik/stages/password/stage.py | 16 ++--- authentik/stages/password/tests.py | 60 +++++------------ authentik/stages/prompt/models.py | 7 +- authentik/stages/prompt/stage.py | 16 ++--- authentik/stages/prompt/tests.py | 52 ++++----------- authentik/stages/user_delete/tests.py | 12 +--- authentik/stages/user_login/tests.py | 16 ++--- authentik/stages/user_logout/tests.py | 8 +-- authentik/stages/user_write/stage.py | 10 ++- authentik/stages/user_write/tests.py | 53 ++++----------- authentik/tenants/migrations/0001_initial.py | 4 +- authentik/tenants/models.py | 4 +- authentik/tenants/tests.py | 8 +-- lifecycle/migrate.py | 4 +- lifecycle/system_migrations/to_0_10.py | 4 +- pyproject.toml | 3 +- tests/e2e/test_flows_authenticators.py | 40 ++++-------- tests/e2e/test_flows_enroll.py | 38 ++++------- tests/e2e/test_flows_stage_setup.py | 12 ++-- tests/e2e/test_provider_ldap.py | 21 +----- tests/e2e/test_provider_oauth2_github.py | 45 ++++--------- tests/e2e/test_provider_oauth2_grafana.py | 65 +++++-------------- tests/e2e/test_provider_oauth2_oidc.py | 21 ++---- .../e2e/test_provider_oauth2_oidc_implicit.py | 21 ++---- tests/e2e/test_provider_proxy.py | 23 ++----- tests/e2e/test_provider_saml.py | 34 ++-------- tests/e2e/test_source_oauth.py | 49 ++++---------- tests/e2e/test_source_saml.py | 36 +++------- tests/e2e/utils.py | 46 ++++--------- 272 files changed, 825 insertions(+), 2590 deletions(-) diff --git a/authentik/admin/api/metrics.py b/authentik/admin/api/metrics.py index 88b5b007c..739fdba31 100644 --- a/authentik/admin/api/metrics.py +++ b/authentik/admin/api/metrics.py @@ -23,9 +23,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]: date_from = now() - timedelta(days=1) result = ( Event.objects.filter(created__gte=date_from, **filter_kwargs) - .annotate( - age=ExpressionWrapper(now() - F("created"), output_field=DurationField()) - ) + .annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField())) .annotate(age_hours=ExtractHour("age")) .values("age_hours") .annotate(count=Count("pk")) @@ -37,8 +35,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]: for hour in range(0, -24, -1): results.append( { - "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) - * 1000, + "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000, "y_cord": data[hour * -1], } ) diff --git a/authentik/admin/api/system.py b/authentik/admin/api/system.py index f531825fe..eb0c65d12 100644 --- a/authentik/admin/api/system.py +++ b/authentik/admin/api/system.py @@ -61,9 +61,7 @@ class SystemSerializer(PassiveSerializer): return { "python_version": python_version, "gunicorn_version": ".".join(str(x) for x in gunicorn_version), - "environment": "kubernetes" - if SERVICE_HOST_ENV_NAME in os.environ - else "compose", + "environment": "kubernetes" if SERVICE_HOST_ENV_NAME in os.environ else "compose", "architecture": platform.machine(), "platform": platform.platform(), "uname": " ".join(platform.uname()), diff --git a/authentik/admin/api/tasks.py b/authentik/admin/api/tasks.py index 5ea6e3678..2cfc58efc 100644 --- a/authentik/admin/api/tasks.py +++ b/authentik/admin/api/tasks.py @@ -92,10 +92,7 @@ class TaskViewSet(ViewSet): task_func.delay(*task.task_call_args, **task.task_call_kwargs) messages.success( self.request, - _( - "Successfully re-scheduled Task %(name)s!" - % {"name": task.task_name} - ), + _("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}), ) return Response(status=204) except ImportError: # pragma: no cover diff --git a/authentik/admin/api/version.py b/authentik/admin/api/version.py index 49b07f9be..8ea95e094 100644 --- a/authentik/admin/api/version.py +++ b/authentik/admin/api/version.py @@ -41,9 +41,7 @@ class VersionSerializer(PassiveSerializer): def get_outdated(self, instance) -> bool: """Check if we're running the latest version""" - return parse(self.get_version_current(instance)) < parse( - self.get_version_latest(instance) - ) + return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance)) class VersionView(APIView): diff --git a/authentik/admin/api/workers.py b/authentik/admin/api/workers.py index ff9b7c5e2..c4fe3f5ba 100644 --- a/authentik/admin/api/workers.py +++ b/authentik/admin/api/workers.py @@ -17,9 +17,7 @@ class WorkerView(APIView): permission_classes = [IsAdminUser] - @extend_schema( - responses=inline_serializer("Workers", fields={"count": IntegerField()}) - ) + @extend_schema(responses=inline_serializer("Workers", fields={"count": IntegerField()})) def get(self, request: Request) -> Response: """Get currently connected worker count.""" count = len(CELERY_APP.control.ping(timeout=0.5)) diff --git a/authentik/admin/tasks.py b/authentik/admin/tasks.py index 26fe1bbe3..9b8d1bff4 100644 --- a/authentik/admin/tasks.py +++ b/authentik/admin/tasks.py @@ -37,18 +37,14 @@ def _set_prom_info(): def update_latest_version(self: MonitoredTask): """Update latest version info""" try: - response = get( - "https://api.github.com/repos/goauthentik/authentik/releases/latest" - ) + response = get("https://api.github.com/repos/goauthentik/authentik/releases/latest") response.raise_for_status() data = response.json() tag_name = data.get("tag_name") upstream_version = tag_name.split("/")[1] cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) self.set_status( - TaskResult( - TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"] - ) + TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"]) ) _set_prom_info() # Check if upstream version is newer than what we're running, diff --git a/authentik/admin/tests/test_api.py b/authentik/admin/tests/test_api.py index dbcbd647b..b04e879c9 100644 --- a/authentik/admin/tests/test_api.py +++ b/authentik/admin/tests/test_api.py @@ -27,9 +27,7 @@ class TestAdminAPI(TestCase): response = self.client.get(reverse("authentik_api:admin_system_tasks-list")) self.assertEqual(response.status_code, 200) body = loads(response.content) - self.assertTrue( - any(task["task_name"] == "clean_expired_models" for task in body) - ) + self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body)) def test_tasks_single(self): """Test Task API (read single)""" @@ -45,9 +43,7 @@ class TestAdminAPI(TestCase): self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name) self.assertEqual(body["task_name"], "clean_expired_models") response = self.client.get( - reverse( - "authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"} - ) + reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"}) ) self.assertEqual(response.status_code, 404) diff --git a/authentik/api/decorators.py b/authentik/api/decorators.py index 00a53ed0f..2c253b234 100644 --- a/authentik/api/decorators.py +++ b/authentik/api/decorators.py @@ -7,9 +7,7 @@ from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet -def permission_required( - perm: Optional[str] = None, other_perms: Optional[list[str]] = None -): +def permission_required(perm: Optional[str] = None, other_perms: Optional[list[str]] = None): """Check permissions for a single custom action""" def wrapper_outter(func: Callable): diff --git a/authentik/api/schema.py b/authentik/api/schema.py index 994ed5af7..976706c01 100644 --- a/authentik/api/schema.py +++ b/authentik/api/schema.py @@ -63,9 +63,7 @@ def postprocess_schema_responses(result, generator, **kwargs): # noqa: W0613 method["responses"].setdefault("400", validation_error.ref) method["responses"].setdefault("403", generic_error.ref) - result["components"] = generator.registry.build( - spectacular_settings.APPEND_COMPONENTS - ) + result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS) # This is a workaround for authentik/stages/prompt/stage.py # since the serializer PromptChallengeResponse diff --git a/authentik/api/tests/test_auth.py b/authentik/api/tests/test_auth.py index 1f22059fc..b965487c9 100644 --- a/authentik/api/tests/test_auth.py +++ b/authentik/api/tests/test_auth.py @@ -16,17 +16,13 @@ class TestAPIAuth(TestCase): def test_valid_basic(self): """Test valid token""" - token = Token.objects.create( - intent=TokenIntents.INTENT_API, user=get_anonymous_user() - ) + token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user()) auth = b64encode(f":{token.key}".encode()).decode() self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user) def test_valid_bearer(self): """Test valid token""" - token = Token.objects.create( - intent=TokenIntents.INTENT_API, user=get_anonymous_user() - ) + token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user()) self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user) def test_invalid_type(self): diff --git a/authentik/api/v2/urls.py b/authentik/api/v2/urls.py index c1b20d572..64ef4de66 100644 --- a/authentik/api/v2/urls.py +++ b/authentik/api/v2/urls.py @@ -52,20 +52,12 @@ from authentik.policies.reputation.api import ( from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet from authentik.providers.oauth2.api.scope import ScopeMappingViewSet -from authentik.providers.oauth2.api.tokens import ( - AuthorizationCodeViewSet, - RefreshTokenViewSet, -) -from authentik.providers.proxy.api import ( - ProxyOutpostConfigViewSet, - ProxyProviderViewSet, -) +from authentik.providers.oauth2.api.tokens import AuthorizationCodeViewSet, RefreshTokenViewSet +from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet from authentik.sources.oauth.api.source import OAuthSourceViewSet -from authentik.sources.oauth.api.source_connection import ( - UserOAuthSourceConnectionViewSet, -) +from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet from authentik.sources.plex.api import PlexSourceViewSet from authentik.sources.saml.api import SAMLSourceViewSet from authentik.stages.authenticator_duo.api import ( @@ -83,9 +75,7 @@ from authentik.stages.authenticator_totp.api import ( TOTPAdminDeviceViewSet, TOTPDeviceViewSet, ) -from authentik.stages.authenticator_validate.api import ( - AuthenticatorValidateStageViewSet, -) +from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageViewSet from authentik.stages.authenticator_webauthn.api import ( AuthenticateWebAuthnStageViewSet, WebAuthnAdminDeviceViewSet, @@ -122,9 +112,7 @@ router.register("core/tenants", TenantViewSet) router.register("outposts/instances", OutpostViewSet) router.register("outposts/service_connections/all", ServiceConnectionViewSet) router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet) -router.register( - "outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet -) +router.register("outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet) router.register("outposts/proxy", ProxyOutpostConfigViewSet) router.register("outposts/ldap", LDAPOutpostConfigViewSet) @@ -184,9 +172,7 @@ router.register( StaticAdminDeviceViewSet, basename="admin-staticdevice", ) -router.register( - "authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice" -) +router.register("authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice") router.register( "authenticators/admin/webauthn", WebAuthnAdminDeviceViewSet, diff --git a/authentik/core/api/applications.py b/authentik/core/api/applications.py index 4a640e1ff..7c9dced9d 100644 --- a/authentik/core/api/applications.py +++ b/authentik/core/api/applications.py @@ -147,9 +147,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): """Custom list method that checks Policy based access instead of guardian""" should_cache = request.GET.get("search", "") == "" - superuser_full_list = ( - str(request.GET.get("superuser_full_list", "false")).lower() == "true" - ) + superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true" if superuser_full_list and request.user.is_superuser: return super().list(request) @@ -240,9 +238,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): app.save() return Response({}) - @permission_required( - "authentik_core.view_application", ["authentik_events.view_event"] - ) + @permission_required("authentik_core.view_application", ["authentik_events.view_event"]) @extend_schema(responses={200: CoordinateSerializer(many=True)}) @action(detail=True, pagination_class=None, filter_backends=[]) # pylint: disable=unused-argument diff --git a/authentik/core/api/authenticated_sessions.py b/authentik/core/api/authenticated_sessions.py index 55989fa04..1745811a3 100644 --- a/authentik/core/api/authenticated_sessions.py +++ b/authentik/core/api/authenticated_sessions.py @@ -68,9 +68,7 @@ class AuthenticatedSessionSerializer(ModelSerializer): """Get parsed user agent""" return user_agent_parser.Parse(instance.last_user_agent) - def get_geo_ip( - self, instance: AuthenticatedSession - ) -> Optional[GeoIPDict]: # pragma: no cover + def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]: # pragma: no cover """Get parsed user agent""" return GEOIP_READER.city_dict(instance.last_ip) diff --git a/authentik/core/api/propertymappings.py b/authentik/core/api/propertymappings.py index 2593dc089..16c9bbd55 100644 --- a/authentik/core/api/propertymappings.py +++ b/authentik/core/api/propertymappings.py @@ -15,11 +15,7 @@ from rest_framework.viewsets import GenericViewSet from authentik.api.decorators import permission_required from authentik.core.api.used_by import UsedByMixin -from authentik.core.api.utils import ( - MetaNameSerializer, - PassiveSerializer, - TypeCreateSerializer, -) +from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer from authentik.core.expression import PropertyMappingEvaluator from authentik.core.models import PropertyMapping from authentik.lib.utils.reflection import all_subclasses @@ -141,9 +137,7 @@ class PropertyMappingViewSet( self.request, **test_params.validated_data.get("context", {}), ) - response_data["result"] = dumps( - result, indent=(4 if format_result else None) - ) + response_data["result"] = dumps(result, indent=(4 if format_result else None)) except Exception as exc: # pylint: disable=broad-except response_data["result"] = str(exc) response_data["successful"] = False diff --git a/authentik/core/api/sources.py b/authentik/core/api/sources.py index 97f966cf9..73bf16b76 100644 --- a/authentik/core/api/sources.py +++ b/authentik/core/api/sources.py @@ -93,9 +93,7 @@ class SourceViewSet( @action(detail=False, pagination_class=None, filter_backends=[]) def user_settings(self, request: Request) -> Response: """Get all sources the user can configure""" - _all_sources: Iterable[Source] = Source.objects.filter( - enabled=True - ).select_subclasses() + _all_sources: Iterable[Source] = Source.objects.filter(enabled=True).select_subclasses() matching_sources: list[UserSettingSerializer] = [] for source in _all_sources: user_settings = source.ui_user_settings diff --git a/authentik/core/api/tokens.py b/authentik/core/api/tokens.py index e822c02c3..f707e037a 100644 --- a/authentik/core/api/tokens.py +++ b/authentik/core/api/tokens.py @@ -70,9 +70,7 @@ class TokenViewSet(UsedByMixin, ModelViewSet): serializer.save( user=self.request.user, intent=TokenIntents.INTENT_API, - expiring=self.request.user.attributes.get( - USER_ATTRIBUTE_TOKEN_EXPIRING, True - ), + expiring=self.request.user.attributes.get(USER_ATTRIBUTE_TOKEN_EXPIRING, True), ) @permission_required("authentik_core.view_token_key") @@ -89,7 +87,5 @@ class TokenViewSet(UsedByMixin, ModelViewSet): token: Token = self.get_object() if token.is_expired: raise Http404 - Event.new(EventAction.SECRET_VIEW, secret=token).from_http( # noqa # nosec - request - ) + Event.new(EventAction.SECRET_VIEW, secret=token).from_http(request) # noqa # nosec return Response(TokenViewSerializer({"key": token.key}).data) diff --git a/authentik/core/api/used_by.py b/authentik/core/api/used_by.py index b1143d989..b507653b4 100644 --- a/authentik/core/api/used_by.py +++ b/authentik/core/api/used_by.py @@ -79,9 +79,7 @@ class UsedByMixin: ).all(): # Only merge shadows on first object if first_object: - shadows += getattr( - manager.model._meta, "authentik_used_by_shadows", [] - ) + shadows += getattr(manager.model._meta, "authentik_used_by_shadows", []) first_object = False serializer = UsedBySerializer( data={ diff --git a/authentik/core/api/users.py b/authentik/core/api/users.py index 6593fb937..de6d0f9e8 100644 --- a/authentik/core/api/users.py +++ b/authentik/core/api/users.py @@ -26,10 +26,7 @@ from authentik.api.decorators import permission_required from authentik.core.api.groups import GroupSerializer from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict -from authentik.core.middleware import ( - SESSION_IMPERSONATE_ORIGINAL_USER, - SESSION_IMPERSONATE_USER, -) +from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER from authentik.core.models import Token, TokenIntents, User from authentik.events.models import EventAction from authentik.tenants.models import Tenant @@ -87,17 +84,13 @@ class UserMetricsSerializer(PassiveSerializer): def get_logins_failed_per_1h(self, _): """Get failed logins per hour for the last 24 hours""" user = self.context["user"] - return get_events_per_1h( - action=EventAction.LOGIN_FAILED, context__username=user.username - ) + return get_events_per_1h(action=EventAction.LOGIN_FAILED, context__username=user.username) @extend_schema_field(CoordinateSerializer(many=True)) def get_authorizations_per_1h(self, _): """Get failed logins per hour for the last 24 hours""" user = self.context["user"] - return get_events_per_1h( - action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk - ) + return get_events_per_1h(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk) class UsersFilter(FilterSet): @@ -154,9 +147,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): # pylint: disable=invalid-name def me(self, request: Request) -> Response: """Get information about current user""" - serializer = SessionUserSerializer( - data={"user": UserSerializer(request.user).data} - ) + serializer = SessionUserSerializer(data={"user": UserSerializer(request.user).data}) if SESSION_IMPERSONATE_USER in request._request.session: serializer.initial_data["original"] = UserSerializer( request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER] diff --git a/authentik/core/api/utils.py b/authentik/core/api/utils.py index 87aa87260..1ebd8eea0 100644 --- a/authentik/core/api/utils.py +++ b/authentik/core/api/utils.py @@ -3,20 +3,14 @@ from typing import Any from django.db.models import Model from rest_framework.fields import CharField, IntegerField -from rest_framework.serializers import ( - Serializer, - SerializerMethodField, - ValidationError, -) +from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError def is_dict(value: Any): """Ensure a value is a dictionary, useful for JSONFields""" if isinstance(value, dict): return - raise ValidationError( - "Value must be a dictionary, and not have any duplicate keys." - ) + raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") class PassiveSerializer(Serializer): @@ -25,9 +19,7 @@ class PassiveSerializer(Serializer): def create(self, validated_data: dict) -> Model: # pragma: no cover return Model() - def update( - self, instance: Model, validated_data: dict - ) -> Model: # pragma: no cover + def update(self, instance: Model, validated_data: dict) -> Model: # pragma: no cover return Model() class Meta: diff --git a/authentik/core/migrations/0001_initial.py b/authentik/core/migrations/0001_initial.py index e79bbbfd1..fb4e18297 100644 --- a/authentik/core/migrations/0001_initial.py +++ b/authentik/core/migrations/0001_initial.py @@ -38,9 +38,7 @@ class Migration(migrations.Migration): ("password", models.CharField(max_length=128, verbose_name="password")), ( "last_login", - models.DateTimeField( - blank=True, null=True, verbose_name="last login" - ), + models.DateTimeField(blank=True, null=True, verbose_name="last login"), ), ( "is_superuser", @@ -53,35 +51,25 @@ class Migration(migrations.Migration): ( "username", models.CharField( - error_messages={ - "unique": "A user with that username already exists." - }, + error_messages={"unique": "A user with that username already exists."}, help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.", max_length=150, unique=True, - validators=[ - django.contrib.auth.validators.UnicodeUsernameValidator() - ], + validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name="username", ), ), ( "first_name", - models.CharField( - blank=True, max_length=30, verbose_name="first name" - ), + models.CharField(blank=True, max_length=30, verbose_name="first name"), ), ( "last_name", - models.CharField( - blank=True, max_length=150, verbose_name="last name" - ), + models.CharField(blank=True, max_length=150, verbose_name="last name"), ), ( "email", - models.EmailField( - blank=True, max_length=254, verbose_name="email address" - ), + models.EmailField(blank=True, max_length=254, verbose_name="email address"), ), ( "is_staff", @@ -217,9 +205,7 @@ class Migration(migrations.Migration): ), ( "expires", - models.DateTimeField( - default=authentik.core.models.default_token_duration - ), + models.DateTimeField(default=authentik.core.models.default_token_duration), ), ("expiring", models.BooleanField(default=True)), ("description", models.TextField(blank=True, default="")), @@ -306,9 +292,7 @@ class Migration(migrations.Migration): ("name", models.TextField(help_text="Application's display Name.")), ( "slug", - models.SlugField( - help_text="Internal application name, used in URLs." - ), + models.SlugField(help_text="Internal application name, used in URLs."), ), ("skip_authorization", models.BooleanField(default=False)), ("meta_launch_url", models.URLField(blank=True, default="")), diff --git a/authentik/core/migrations/0003_default_user.py b/authentik/core/migrations/0003_default_user.py index a955ac60c..6d45f6e39 100644 --- a/authentik/core/migrations/0003_default_user.py +++ b/authentik/core/migrations/0003_default_user.py @@ -17,9 +17,7 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): username="akadmin", email="root@localhost", name="authentik Default Admin" ) if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST: - akadmin.set_password( - environ.get("AK_ADMIN_PASS", "akadmin"), signal=False - ) # noqa # nosec + akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False) # noqa # nosec else: akadmin.set_unusable_password() akadmin.save() diff --git a/authentik/core/migrations/0006_auto_20200709_1608.py b/authentik/core/migrations/0006_auto_20200709_1608.py index 2dec93721..7b5de5fcf 100644 --- a/authentik/core/migrations/0006_auto_20200709_1608.py +++ b/authentik/core/migrations/0006_auto_20200709_1608.py @@ -13,8 +13,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="source", name="slug", - field=models.SlugField( - help_text="Internal source name, used in URLs.", unique=True - ), + field=models.SlugField(help_text="Internal source name, used in URLs.", unique=True), ), ] diff --git a/authentik/core/migrations/0007_auto_20200815_1841.py b/authentik/core/migrations/0007_auto_20200815_1841.py index 51fe03d1e..a7f0de6e5 100644 --- a/authentik/core/migrations/0007_auto_20200815_1841.py +++ b/authentik/core/migrations/0007_auto_20200815_1841.py @@ -13,8 +13,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="first_name", - field=models.CharField( - blank=True, max_length=150, verbose_name="first name" - ), + field=models.CharField(blank=True, max_length=150, verbose_name="first name"), ), ] diff --git a/authentik/core/migrations/0009_group_is_superuser.py b/authentik/core/migrations/0009_group_is_superuser.py index 37133587e..9b4d43fdd 100644 --- a/authentik/core/migrations/0009_group_is_superuser.py +++ b/authentik/core/migrations/0009_group_is_superuser.py @@ -40,9 +40,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="user", name="pb_groups", - field=models.ManyToManyField( - related_name="users", to="authentik_core.Group" - ), + field=models.ManyToManyField(related_name="users", to="authentik_core.Group"), ), migrations.AddField( model_name="group", diff --git a/authentik/core/migrations/0014_auto_20201018_1158.py b/authentik/core/migrations/0014_auto_20201018_1158.py index 0f3f9dc96..750532913 100644 --- a/authentik/core/migrations/0014_auto_20201018_1158.py +++ b/authentik/core/migrations/0014_auto_20201018_1158.py @@ -42,9 +42,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="token", - index=models.Index( - fields=["identifier"], name="authentik_co_identif_1a34a8_idx" - ), + index=models.Index(fields=["identifier"], name="authentik_co_identif_1a34a8_idx"), ), migrations.RunPython(set_default_token_key), ] diff --git a/authentik/core/migrations/0015_application_icon.py b/authentik/core/migrations/0015_application_icon.py index 4ea6ac2c8..75c8c42bd 100644 --- a/authentik/core/migrations/0015_application_icon.py +++ b/authentik/core/migrations/0015_application_icon.py @@ -17,8 +17,6 @@ class Migration(migrations.Migration): migrations.AddField( model_name="application", name="meta_icon", - field=models.FileField( - blank=True, default="", upload_to="application-icons/" - ), + field=models.FileField(blank=True, default="", upload_to="application-icons/"), ), ] diff --git a/authentik/core/migrations/0016_auto_20201202_2234.py b/authentik/core/migrations/0016_auto_20201202_2234.py index e03ab30e0..a2b5cc0db 100644 --- a/authentik/core/migrations/0016_auto_20201202_2234.py +++ b/authentik/core/migrations/0016_auto_20201202_2234.py @@ -25,9 +25,7 @@ class Migration(migrations.Migration): ), migrations.AddIndex( model_name="token", - index=models.Index( - fields=["identifier"], name="authentik_c_identif_d9d032_idx" - ), + index=models.Index(fields=["identifier"], name="authentik_c_identif_d9d032_idx"), ), migrations.AddIndex( model_name="token", diff --git a/authentik/core/migrations/0022_authenticatedsession.py b/authentik/core/migrations/0022_authenticatedsession.py index d28c2ccb9..df859a1a2 100644 --- a/authentik/core/migrations/0022_authenticatedsession.py +++ b/authentik/core/migrations/0022_authenticatedsession.py @@ -32,16 +32,12 @@ class Migration(migrations.Migration): fields=[ ( "expires", - models.DateTimeField( - default=authentik.core.models.default_token_duration - ), + models.DateTimeField(default=authentik.core.models.default_token_duration), ), ("expiring", models.BooleanField(default=True)), ( "uuid", - models.UUIDField( - default=uuid.uuid4, primary_key=True, serialize=False - ), + models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), ), ("session_key", models.CharField(max_length=40)), ("last_ip", models.TextField()), diff --git a/authentik/core/migrations/0025_alter_application_meta_icon.py b/authentik/core/migrations/0025_alter_application_meta_icon.py index cd4965b04..e612cf8ac 100644 --- a/authentik/core/migrations/0025_alter_application_meta_icon.py +++ b/authentik/core/migrations/0025_alter_application_meta_icon.py @@ -13,8 +13,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="application", name="meta_icon", - field=models.FileField( - default=None, null=True, upload_to="application-icons/" - ), + field=models.FileField(default=None, null=True, upload_to="application-icons/"), ), ] diff --git a/authentik/core/models.py b/authentik/core/models.py index 03eb5cf21..de6a99b19 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -154,9 +154,7 @@ class User(GuardianUserMixin, AbstractUser): ("s", "158"), ("r", "g"), ] - gravatar_url = ( - f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}" - ) + gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}" return escape(gravatar_url) return mode % { "username": self.username, @@ -186,9 +184,7 @@ class Provider(SerializerModel): related_name="provider_authorization", ) - property_mappings = models.ManyToManyField( - "PropertyMapping", default=None, blank=True - ) + property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) objects = InheritanceManager() @@ -218,9 +214,7 @@ class Application(PolicyBindingModel): add custom fields and other properties""" name = models.TextField(help_text=_("Application's display Name.")) - slug = models.SlugField( - help_text=_("Internal application name, used in URLs."), unique=True - ) + slug = models.SlugField(help_text=_("Internal application name, used in URLs."), unique=True) provider = models.OneToOneField( "Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT ) @@ -244,9 +238,7 @@ class Application(PolicyBindingModel): it is returned as-is""" if not self.meta_icon: return None - if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith( - "/static" - ): + if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"): return self.meta_icon.name return self.meta_icon.url @@ -301,14 +293,10 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" name = models.TextField(help_text=_("Source's display Name.")) - slug = models.SlugField( - help_text=_("Internal source name, used in URLs."), unique=True - ) + slug = models.SlugField(help_text=_("Internal source name, used in URLs."), unique=True) enabled = models.BooleanField(default=True) - property_mappings = models.ManyToManyField( - "PropertyMapping", default=None, blank=True - ) + property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True) authentication_flow = models.ForeignKey( Flow, @@ -481,9 +469,7 @@ class PropertyMapping(SerializerModel, ManagedModel): """Get serializer for this model""" raise NotImplementedError - def evaluate( - self, user: Optional[User], request: Optional[HttpRequest], **kwargs - ) -> Any: + def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: """Evaluate `self.expression` using `**kwargs` as Context.""" from authentik.core.expression import PropertyMappingEvaluator @@ -522,9 +508,7 @@ class AuthenticatedSession(ExpiringModel): last_used = models.DateTimeField(auto_now=True) @staticmethod - def from_request( - request: HttpRequest, user: User - ) -> Optional["AuthenticatedSession"]: + def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: """Create a new session from a http request""" if not hasattr(request, "session") or not request.session.session_key: return None diff --git a/authentik/core/signals.py b/authentik/core/signals.py index 497fbdc8b..c114ffa65 100644 --- a/authentik/core/signals.py +++ b/authentik/core/signals.py @@ -14,9 +14,7 @@ from prometheus_client import Gauge # Arguments: user: User, password: str password_changed = Signal() -GAUGE_MODELS = Gauge( - "authentik_models", "Count of various objects", ["model_name", "app"] -) +GAUGE_MODELS = Gauge("authentik_models", "Count of various objects", ["model_name", "app"]) if TYPE_CHECKING: from authentik.core.models import AuthenticatedSession, User @@ -60,15 +58,11 @@ def user_logged_out_session(sender, request: HttpRequest, user: "User", **_): """Delete AuthenticatedSession if it exists""" from authentik.core.models import AuthenticatedSession - AuthenticatedSession.objects.filter( - session_key=request.session.session_key - ).delete() + AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete() @receiver(pre_delete) -def authenticated_session_delete( - sender: Type[Model], instance: "AuthenticatedSession", **_ -): +def authenticated_session_delete(sender: Type[Model], instance: "AuthenticatedSession", **_): """Delete session when authenticated session is deleted""" from authentik.core.models import AuthenticatedSession diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index 4e0866f75..7666da07f 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -11,16 +11,8 @@ from django.urls import reverse from django.utils.translation import gettext as _ from structlog.stdlib import get_logger -from authentik.core.models import ( - Source, - SourceUserMatchingModes, - User, - UserSourceConnection, -) -from authentik.core.sources.stage import ( - PLAN_CONTEXT_SOURCES_CONNECTION, - PostUserEnrollmentStage, -) +from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection +from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage from authentik.events.models import Event, EventAction from authentik.flows.models import Flow, Stage, in_memory_stage from authentik.flows.planner import ( @@ -76,9 +68,7 @@ class SourceFlowManager: # pylint: disable=too-many-return-statements def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]: """decide which action should be taken""" - new_connection = self.connection_type( - source=self.source, identifier=self.identifier - ) + new_connection = self.connection_type(source=self.source, identifier=self.identifier) # When request is authenticated, always link if self.request.user.is_authenticated: new_connection.user = self.request.user @@ -113,9 +103,7 @@ class SourceFlowManager: SourceUserMatchingModes.USERNAME_DENY, ]: if not self.enroll_info.get("username", None): - self._logger.warning( - "Refusing to use none username", source=self.source - ) + self._logger.warning("Refusing to use none username", source=self.source) return Action.DENY, None query = Q(username__exact=self.enroll_info.get("username", None)) self._logger.debug("trying to link with existing user", query=query) @@ -229,10 +217,7 @@ class SourceFlowManager: """Login user and redirect.""" messages.success( self.request, - _( - "Successfully authenticated with %(source)s!" - % {"source": self.source.name} - ), + _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), ) flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user} return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs) @@ -270,10 +255,7 @@ class SourceFlowManager: """User was not authenticated and previous request was not authenticated.""" messages.success( self.request, - _( - "Successfully authenticated with %(source)s!" - % {"source": self.source.name} - ), + _("Successfully authenticated with %(source)s!" % {"source": self.source.name}), ) # We run the Flow planner here so we can pass the Pending user in the context diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index 2e32f3d9e..591880502 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -27,9 +27,7 @@ def clean_expired_models(self: MonitoredTask): for cls in ExpiringModel.__subclasses__(): cls: ExpiringModel objects = ( - cls.objects.all() - .exclude(expiring=False) - .exclude(expiring=True, expires__gt=now()) + cls.objects.all().exclude(expiring=False).exclude(expiring=True, expires__gt=now()) ) for obj in objects: obj.expire_action() diff --git a/authentik/core/tests/test_applications_api.py b/authentik/core/tests/test_applications_api.py index 21c0edef6..81e8171ac 100644 --- a/authentik/core/tests/test_applications_api.py +++ b/authentik/core/tests/test_applications_api.py @@ -17,9 +17,7 @@ class TestApplicationsAPI(APITestCase): self.denied = Application.objects.create(name="denied", slug="denied") PolicyBinding.objects.create( target=self.denied, - policy=DummyPolicy.objects.create( - name="deny", result=False, wait_min=1, wait_max=2 - ), + policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2), order=0, ) @@ -33,9 +31,7 @@ class TestApplicationsAPI(APITestCase): ) ) self.assertEqual(response.status_code, 200) - self.assertJSONEqual( - force_str(response.content), {"messages": [], "passing": True} - ) + self.assertJSONEqual(force_str(response.content), {"messages": [], "passing": True}) response = self.client.get( reverse( "authentik_api:application-check-access", @@ -43,9 +39,7 @@ class TestApplicationsAPI(APITestCase): ) ) self.assertEqual(response.status_code, 200) - self.assertJSONEqual( - force_str(response.content), {"messages": ["dummy"], "passing": False} - ) + self.assertJSONEqual(force_str(response.content), {"messages": ["dummy"], "passing": False}) def test_list(self): """Test list operation without superuser_full_list""" diff --git a/authentik/core/tests/test_impersonation.py b/authentik/core/tests/test_impersonation.py index 36bf5ee71..164c902c8 100644 --- a/authentik/core/tests/test_impersonation.py +++ b/authentik/core/tests/test_impersonation.py @@ -46,9 +46,7 @@ class TestImpersonation(TestCase): self.client.force_login(self.other_user) self.client.get( - reverse( - "authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk} - ) + reverse("authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk}) ) response = self.client.get(reverse("authentik_api:user-me")) diff --git a/authentik/core/tests/test_models.py b/authentik/core/tests/test_models.py index 5d04d4bf2..53e412c86 100644 --- a/authentik/core/tests/test_models.py +++ b/authentik/core/tests/test_models.py @@ -22,9 +22,7 @@ class TestModels(TestCase): def test_token_expire_no_expire(self): """Test token expiring with "expiring" set""" - token = Token.objects.create( - expires=now(), user=get_anonymous_user(), expiring=False - ) + token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False) sleep(0.5) self.assertFalse(token.is_expired) diff --git a/authentik/core/tests/test_property_mapping.py b/authentik/core/tests/test_property_mapping.py index 75a3610b2..a7fce579c 100644 --- a/authentik/core/tests/test_property_mapping.py +++ b/authentik/core/tests/test_property_mapping.py @@ -16,9 +16,7 @@ class TestPropertyMappings(TestCase): def test_expression(self): """Test expression""" - mapping = PropertyMapping.objects.create( - name="test", expression="return 'test'" - ) + mapping = PropertyMapping.objects.create(name="test", expression="return 'test'") self.assertEqual(mapping.evaluate(None, None), "test") def test_expression_syntax(self): diff --git a/authentik/core/tests/test_property_mapping_api.py b/authentik/core/tests/test_property_mapping_api.py index 74da28f5e..bb3bcf2ca 100644 --- a/authentik/core/tests/test_property_mapping_api.py +++ b/authentik/core/tests/test_property_mapping_api.py @@ -23,9 +23,7 @@ class TestPropertyMappingAPI(APITestCase): def test_test_call(self): """Test PropertMappings's test endpoint""" response = self.client.post( - reverse( - "authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk} - ), + reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}), data={ "user": self.user.pk, }, diff --git a/authentik/core/tests/test_token_api.py b/authentik/core/tests/test_token_api.py index 74295ab43..0e19a9a98 100644 --- a/authentik/core/tests/test_token_api.py +++ b/authentik/core/tests/test_token_api.py @@ -4,12 +4,7 @@ from django.utils.timezone import now from guardian.shortcuts import get_anonymous_user from rest_framework.test import APITestCase -from authentik.core.models import ( - USER_ATTRIBUTE_TOKEN_EXPIRING, - Token, - TokenIntents, - User, -) +from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User from authentik.core.tasks import clean_expired_models diff --git a/authentik/core/views/impersonate.py b/authentik/core/views/impersonate.py index 4b5523d44..ddeb4f940 100644 --- a/authentik/core/views/impersonate.py +++ b/authentik/core/views/impersonate.py @@ -5,10 +5,7 @@ from django.shortcuts import get_object_or_404, redirect from django.views import View from structlog.stdlib import get_logger -from authentik.core.middleware import ( - SESSION_IMPERSONATE_ORIGINAL_USER, - SESSION_IMPERSONATE_USER, -) +from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER from authentik.core.models import User from authentik.events.models import Event, EventAction @@ -21,9 +18,7 @@ class ImpersonateInitView(View): def get(self, request: HttpRequest, user_id: int) -> HttpResponse: """Impersonation handler, checks permissions""" if not request.user.has_perm("impersonate"): - LOGGER.debug( - "User attempted to impersonate without permissions", user=request.user - ) + LOGGER.debug("User attempted to impersonate without permissions", user=request.user) return HttpResponse("Unauthorized", status=401) user_to_be = get_object_or_404(User, pk=user_id) diff --git a/authentik/core/views/session.py b/authentik/core/views/session.py index 864d9b8de..11d5ad940 100644 --- a/authentik/core/views/session.py +++ b/authentik/core/views/session.py @@ -14,9 +14,7 @@ class EndSessionView(TemplateView, PolicyAccessView): template_name = "if/end_session.html" def resolve_provider_application(self): - self.application = get_object_or_404( - Application, slug=self.kwargs["application_slug"] - ) + self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) def get_context_data(self, **kwargs: Any) -> dict[str, Any]: context = super().get_context_data(**kwargs) diff --git a/authentik/crypto/api.py b/authentik/crypto/api.py index 0d5f97830..f787d2935 100644 --- a/authentik/crypto/api.py +++ b/authentik/crypto/api.py @@ -10,12 +10,7 @@ from django_filters.filters import BooleanFilter from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from rest_framework.decorators import action -from rest_framework.fields import ( - CharField, - DateTimeField, - IntegerField, - SerializerMethodField, -) +from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField from rest_framework.request import Request from rest_framework.response import Response from rest_framework.serializers import ModelSerializer, ValidationError @@ -86,9 +81,7 @@ class CertificateKeyPairSerializer(ModelSerializer): backend=default_backend(), ) except (ValueError, TypeError): - raise ValidationError( - "Unable to load private key (possibly encrypted?)." - ) + raise ValidationError("Unable to load private key (possibly encrypted?).") return value class Meta: @@ -123,9 +116,7 @@ class CertificateGenerationSerializer(PassiveSerializer): """Certificate generation parameters""" common_name = CharField() - subject_alt_name = CharField( - required=False, allow_blank=True, label=_("Subject-alt name") - ) + subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) validity_days = IntegerField(initial=365) @@ -170,9 +161,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): builder = CertificateBuilder() builder.common_name = data.validated_data["common_name"] builder.build( - subject_alt_names=data.validated_data.get("subject_alt_name", "").split( - "," - ), + subject_alt_names=data.validated_data.get("subject_alt_name", "").split(","), validity_days=int(data.validated_data["validity_days"]), ) instance = builder.save() @@ -208,9 +197,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): "Content-Disposition" ] = f'attachment; filename="{certificate.name}_certificate.pem"' return response - return Response( - CertificateDataSerializer({"data": certificate.certificate_data}).data - ) + return Response(CertificateDataSerializer({"data": certificate.certificate_data}).data) @extend_schema( parameters=[ @@ -234,9 +221,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): ).from_http(request) if "download" in request._request.GET: # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html - response = HttpResponse( - certificate.key_data, content_type="application/x-pem-file" - ) + response = HttpResponse(certificate.key_data, content_type="application/x-pem-file") response[ "Content-Disposition" ] = f'attachment; filename="{certificate.name}_private_key.pem"' diff --git a/authentik/crypto/builder.py b/authentik/crypto/builder.py index 9b6848981..07318aeeb 100644 --- a/authentik/crypto/builder.py +++ b/authentik/crypto/builder.py @@ -46,9 +46,7 @@ class CertificateBuilder: public_exponent=65537, key_size=2048, backend=default_backend() ) self.__public_key = self.__private_key.public_key() - alt_names: list[x509.GeneralName] = [ - x509.DNSName(x) for x in subject_alt_names or [] - ] + alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []] self.__builder = ( x509.CertificateBuilder() .subject_name( @@ -59,9 +57,7 @@ class CertificateBuilder: self.common_name, ), x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"), - x509.NameAttribute( - NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed" - ), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"), ] ) ) @@ -77,9 +73,7 @@ class CertificateBuilder: ) .add_extension(x509.SubjectAlternativeName(alt_names), critical=True) .not_valid_before(datetime.datetime.today() - one_day) - .not_valid_after( - datetime.datetime.today() + datetime.timedelta(days=validity_days) - ) + .not_valid_after(datetime.datetime.today() + datetime.timedelta(days=validity_days)) .serial_number(int(uuid.uuid4())) .public_key(self.__public_key) ) diff --git a/authentik/crypto/models.py b/authentik/crypto/models.py index 37d875dc9..7e3610b07 100644 --- a/authentik/crypto/models.py +++ b/authentik/crypto/models.py @@ -57,9 +57,7 @@ class CertificateKeyPair(CreatedUpdatedModel): if not self._private_key and self._private_key != "": try: self._private_key = load_pem_private_key( - str.encode( - "\n".join([x.strip() for x in self.key_data.split("\n")]) - ), + str.encode("\n".join([x.strip() for x in self.key_data.split("\n")])), password=None, backend=default_backend(), ) @@ -70,25 +68,19 @@ class CertificateKeyPair(CreatedUpdatedModel): @property def fingerprint_sha256(self) -> str: """Get SHA256 Fingerprint of certificate_data""" - return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode( - "utf-8" - ) + return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode("utf-8") @property def fingerprint_sha1(self) -> str: """Get SHA1 Fingerprint of certificate_data""" - return hexlify( - self.certificate.fingerprint(hashes.SHA1()), ":" # nosec - ).decode("utf-8") + return hexlify(self.certificate.fingerprint(hashes.SHA1()), ":").decode("utf-8") # nosec @property def kid(self): """Get Key ID used for JWKS""" return "{0}".format( - md5(self.key_data.encode("utf-8")).hexdigest() # nosec - if self.key_data - else "" - ) + md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else "" + ) # nosec def __str__(self) -> str: return f"Certificate-Key Pair {self.name}" diff --git a/authentik/events/api/event.py b/authentik/events/api/event.py index 570559a42..767f58a41 100644 --- a/authentik/events/api/event.py +++ b/authentik/events/api/event.py @@ -143,7 +143,5 @@ class EventViewSet(ModelViewSet): """Get all actions""" data = [] for value, name in EventAction.choices: - data.append( - {"name": name, "description": "", "component": value, "model_name": ""} - ) + data.append({"name": name, "description": "", "component": value, "model_name": ""}) return Response(TypeCreateSerializer(data, many=True).data) diff --git a/authentik/events/middleware.py b/authentik/events/middleware.py index 4a8838a0f..9a3c3d0ae 100644 --- a/authentik/events/middleware.py +++ b/authentik/events/middleware.py @@ -29,12 +29,8 @@ class AuditMiddleware: def __call__(self, request: HttpRequest) -> HttpResponse: # Connect signal for automatic logging - if hasattr(request, "user") and getattr( - request.user, "is_authenticated", False - ): - post_save_handler = partial( - self.post_save_handler, user=request.user, request=request - ) + if hasattr(request, "user") and getattr(request.user, "is_authenticated", False): + post_save_handler = partial(self.post_save_handler, user=request.user, request=request) pre_delete_handler = partial( self.pre_delete_handler, user=request.user, request=request ) @@ -94,13 +90,9 @@ class AuditMiddleware: @staticmethod # pylint: disable=unused-argument - def pre_delete_handler( - user: User, request: HttpRequest, sender, instance: Model, **_ - ): + def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_): """Signal handler for all object's pre_delete""" - if isinstance( - instance, (Event, Notification, UserObjectPermission) - ): # pragma: no cover + if isinstance(instance, (Event, Notification, UserObjectPermission)): # pragma: no cover return EventNewThread( diff --git a/authentik/events/migrations/0003_auto_20200917_1155.py b/authentik/events/migrations/0003_auto_20200917_1155.py index 26d9b9782..f33dabb08 100644 --- a/authentik/events/migrations/0003_auto_20200917_1155.py +++ b/authentik/events/migrations/0003_auto_20200917_1155.py @@ -14,9 +14,7 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): event.delete() # Because event objects cannot be updated, we have to re-create them event.pk = None - event.user_json = ( - authentik.events.models.get_user(event.user) if event.user else {} - ) + event.user_json = authentik.events.models.get_user(event.user) if event.user else {} event._state.adding = True event.save() @@ -58,7 +56,5 @@ class Migration(migrations.Migration): model_name="event", name="user", ), - migrations.RenameField( - model_name="event", old_name="user_json", new_name="user" - ), + migrations.RenameField(model_name="event", old_name="user_json", new_name="user"), ] diff --git a/authentik/events/migrations/0011_notification_rules_default_v1.py b/authentik/events/migrations/0011_notification_rules_default_v1.py index 70289bbf8..efa9ea64c 100644 --- a/authentik/events/migrations/0011_notification_rules_default_v1.py +++ b/authentik/events/migrations/0011_notification_rules_default_v1.py @@ -11,16 +11,12 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit db_alias = schema_editor.connection.alias Group = apps.get_model("authentik_core", "Group") PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") - EventMatcherPolicy = apps.get_model( - "authentik_policies_event_matcher", "EventMatcherPolicy" - ) + EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") NotificationRule = apps.get_model("authentik_events", "NotificationRule") NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") admin_group = ( - Group.objects.using(db_alias) - .filter(name="authentik Admins", is_superuser=True) - .first() + Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() ) policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( @@ -32,9 +28,7 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, ) trigger.transports.set( - NotificationTransport.objects.using(db_alias).filter( - name="default-email-transport" - ) + NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") ) trigger.save() PolicyBinding.objects.using(db_alias).update_or_create( @@ -50,16 +44,12 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): db_alias = schema_editor.connection.alias Group = apps.get_model("authentik_core", "Group") PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") - EventMatcherPolicy = apps.get_model( - "authentik_policies_event_matcher", "EventMatcherPolicy" - ) + EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") NotificationRule = apps.get_model("authentik_events", "NotificationRule") NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") admin_group = ( - Group.objects.using(db_alias) - .filter(name="authentik Admins", is_superuser=True) - .first() + Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() ) policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( @@ -71,9 +61,7 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, ) trigger.transports.set( - NotificationTransport.objects.using(db_alias).filter( - name="default-email-transport" - ) + NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") ) trigger.save() PolicyBinding.objects.using(db_alias).update_or_create( @@ -89,16 +77,12 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): db_alias = schema_editor.connection.alias Group = apps.get_model("authentik_core", "Group") PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") - EventMatcherPolicy = apps.get_model( - "authentik_policies_event_matcher", "EventMatcherPolicy" - ) + EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy") NotificationRule = apps.get_model("authentik_events", "NotificationRule") NotificationTransport = apps.get_model("authentik_events", "NotificationTransport") admin_group = ( - Group.objects.using(db_alias) - .filter(name="authentik Admins", is_superuser=True) - .first() + Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first() ) policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create( @@ -114,9 +98,7 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): defaults={"group": admin_group, "severity": NotificationSeverity.ALERT}, ) trigger.transports.set( - NotificationTransport.objects.using(db_alias).filter( - name="default-email-transport" - ) + NotificationTransport.objects.using(db_alias).filter(name="default-email-transport") ) trigger.save() PolicyBinding.objects.using(db_alias).update_or_create( diff --git a/authentik/events/migrations/0014_expiry.py b/authentik/events/migrations/0014_expiry.py index ef3939165..b59fc1f70 100644 --- a/authentik/events/migrations/0014_expiry.py +++ b/authentik/events/migrations/0014_expiry.py @@ -38,9 +38,7 @@ def progress_bar( def print_progress_bar(iteration): """Progress Bar Printing Function""" - percent = ("{0:." + str(decimals) + "f}").format( - 100 * (iteration / float(total)) - ) + percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) filledLength = int(length * iteration // total) bar = fill * filledLength + "-" * (length - filledLength) print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end) @@ -78,9 +76,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="event", name="expires", - field=models.DateTimeField( - default=authentik.events.models.default_event_duration - ), + field=models.DateTimeField(default=authentik.events.models.default_event_duration), ), migrations.AddField( model_name="event", diff --git a/authentik/events/migrations/0016_add_tenant.py b/authentik/events/migrations/0016_add_tenant.py index f8c199b4e..853b0146f 100644 --- a/authentik/events/migrations/0016_add_tenant.py +++ b/authentik/events/migrations/0016_add_tenant.py @@ -15,9 +15,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="event", name="tenant", - field=models.JSONField( - blank=True, default=authentik.events.models.default_tenant - ), + field=models.JSONField(blank=True, default=authentik.events.models.default_tenant), ), migrations.AlterField( model_name="event", diff --git a/authentik/events/models.py b/authentik/events/models.py index 3c2aede09..3cc04d322 100644 --- a/authentik/events/models.py +++ b/authentik/events/models.py @@ -15,10 +15,7 @@ from requests import RequestException, post from structlog.stdlib import get_logger from authentik import __version__ -from authentik.core.middleware import ( - SESSION_IMPERSONATE_ORIGINAL_USER, - SESSION_IMPERSONATE_USER, -) +from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER from authentik.core.models import ExpiringModel, Group, User from authentik.events.geo import GEOIP_READER from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict @@ -159,9 +156,7 @@ class Event(ExpiringModel): if hasattr(request, "user"): original_user = None if hasattr(request, "session"): - original_user = request.session.get( - SESSION_IMPERSONATE_ORIGINAL_USER, None - ) + original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None) self.user = get_user(request.user, original_user) if user: self.user = get_user(user) @@ -169,9 +164,7 @@ class Event(ExpiringModel): if hasattr(request, "session"): if SESSION_IMPERSONATE_ORIGINAL_USER in request.session: self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER]) - self.user["on_behalf_of"] = get_user( - request.session[SESSION_IMPERSONATE_USER] - ) + self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER]) # User 255.255.255.255 as fallback if IP cannot be determined self.client_ip = get_client_ip(request) # Apply GeoIP Data, when enabled @@ -414,9 +407,7 @@ class NotificationRule(PolicyBindingModel): severity = models.TextField( choices=NotificationSeverity.choices, default=NotificationSeverity.NOTICE, - help_text=_( - "Controls which severity level the created notifications will have." - ), + help_text=_("Controls which severity level the created notifications will have."), ) group = models.ForeignKey( Group, diff --git a/authentik/events/monitored_tasks.py b/authentik/events/monitored_tasks.py index 162fa73a4..b7575d8dd 100644 --- a/authentik/events/monitored_tasks.py +++ b/authentik/events/monitored_tasks.py @@ -135,9 +135,7 @@ class MonitoredTask(Task): self._result = result # pylint: disable=too-many-arguments - def after_return( - self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo - ): + def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): if self._result: if not self._result.uid: self._result.uid = self._uid @@ -159,9 +157,7 @@ class MonitoredTask(Task): # pylint: disable=too-many-arguments def on_failure(self, exc, task_id, args, kwargs, einfo): if not self._result: - self._result = TaskResult( - status=TaskResultStatus.ERROR, messages=[str(exc)] - ) + self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)]) if not self._result.uid: self._result.uid = self._uid TaskInfo( @@ -179,8 +175,7 @@ class MonitoredTask(Task): Event.new( EventAction.SYSTEM_TASK_EXCEPTION, message=( - f"Task {self.__name__} encountered an error: " - "\n".join(self._result.messages) + f"Task {self.__name__} encountered an error: " "\n".join(self._result.messages) ), ).save() return super().on_failure(exc, task_id, args, kwargs, einfo=einfo) diff --git a/authentik/events/signals.py b/authentik/events/signals.py index c0f7124aa..7a4b5a032 100644 --- a/authentik/events/signals.py +++ b/authentik/events/signals.py @@ -2,11 +2,7 @@ from threading import Thread from typing import Any, Optional -from django.contrib.auth.signals import ( - user_logged_in, - user_logged_out, - user_login_failed, -) +from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed from django.db.models.signals import post_save from django.dispatch import receiver from django.http import HttpRequest @@ -30,9 +26,7 @@ class EventNewThread(Thread): kwargs: dict[str, Any] user: Optional[User] = None - def __init__( - self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs - ): + def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs): super().__init__() self.action = action self.request = request @@ -68,9 +62,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_): @receiver(user_write) # pylint: disable=unused-argument -def on_user_write( - sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs -): +def on_user_write(sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs): """Log User write""" thread = EventNewThread(EventAction.USER_WRITE, request, **data) thread.kwargs["created"] = kwargs.get("created", False) @@ -80,9 +72,7 @@ def on_user_write( @receiver(user_login_failed) # pylint: disable=unused-argument -def on_user_login_failed( - sender, credentials: dict[str, str], request: HttpRequest, **_ -): +def on_user_login_failed(sender, credentials: dict[str, str], request: HttpRequest, **_): """Failed Login""" thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials) thread.run() diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index 8ab4f8662..d4791f427 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -22,9 +22,7 @@ LOGGER = get_logger() def event_notification_handler(event_uuid: str): """Start task for each trigger definition""" for trigger in NotificationRule.objects.all(): - event_trigger_handler.apply_async( - args=[event_uuid, trigger.name], queue="authentik_events" - ) + event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") @CELERY_APP.task() @@ -43,17 +41,13 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): if "policy_uuid" in event.context: policy_uuid = event.context["policy_uuid"] if PolicyBinding.objects.filter( - target__in=NotificationRule.objects.all().values_list( - "pbm_uuid", flat=True - ), + target__in=NotificationRule.objects.all().values_list("pbm_uuid", flat=True), policy=policy_uuid, ).exists(): # If policy that caused this event to be created is attached # to *any* NotificationRule, we return early. # This is the most effective way to prevent infinite loops. - LOGGER.debug( - "e(trigger): attempting to prevent infinite loop", trigger=trigger - ) + LOGGER.debug("e(trigger): attempting to prevent infinite loop", trigger=trigger) return if not trigger.group: @@ -62,9 +56,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger) try: - user = ( - User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user() - ) + user = User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user() except User.DoesNotExist: LOGGER.warning("e(trigger): failed to get user", trigger=trigger) return @@ -99,20 +91,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): retry_backoff=True, base=MonitoredTask, ) -def notification_transport( - self: MonitoredTask, notification_pk: int, transport_pk: int -): +def notification_transport(self: MonitoredTask, notification_pk: int, transport_pk: int): """Send notification over specified transport""" self.save_on_success = False try: - notification: Notification = Notification.objects.filter( - pk=notification_pk - ).first() + notification: Notification = Notification.objects.filter(pk=notification_pk).first() if not notification: return - transport: NotificationTransport = NotificationTransport.objects.get( - pk=transport_pk - ) + transport: NotificationTransport = NotificationTransport.objects.get(pk=transport_pk) transport.send(notification) self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL)) except NotificationTransportError as exc: diff --git a/authentik/events/tests/test_event.py b/authentik/events/tests/test_event.py index 9389ad144..cfc6fb1c7 100644 --- a/authentik/events/tests/test_event.py +++ b/authentik/events/tests/test_event.py @@ -38,7 +38,5 @@ class TestEvents(TestCase): event = Event.new("unittest", model=temp_model) event.save() # We save to ensure nothing is un-saveable model_content_type = ContentType.objects.get_for_model(temp_model) - self.assertEqual( - event.context.get("model").get("app"), model_content_type.app_label - ) + self.assertEqual(event.context.get("model").get("app"), model_content_type.app_label) self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex) diff --git a/authentik/events/tests/test_notifications.py b/authentik/events/tests/test_notifications.py index ae80c0ea1..848608a61 100644 --- a/authentik/events/tests/test_notifications.py +++ b/authentik/events/tests/test_notifications.py @@ -81,12 +81,8 @@ class TestEventsNotifications(TestCase): execute_mock = MagicMock() passes = MagicMock(side_effect=PolicyException) - with patch( - "authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes - ): - with patch( - "authentik.events.models.NotificationTransport.send", execute_mock - ): + with patch("authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes): + with patch("authentik.events.models.NotificationTransport.send", execute_mock): Event.new(EventAction.CUSTOM_PREFIX).save() self.assertEqual(passes.call_count, 1) @@ -96,9 +92,7 @@ class TestEventsNotifications(TestCase): self.group.users.add(user2) self.group.save() - transport = NotificationTransport.objects.create( - name="transport", send_once=True - ) + transport = NotificationTransport.objects.create(name="transport", send_once=True) NotificationRule.objects.filter(name__startswith="default").delete() trigger = NotificationRule.objects.create(name="trigger", group=self.group) trigger.transports.add(transport) diff --git a/authentik/flows/api/flows.py b/authentik/flows/api/flows.py index 21bedfcd6..d6ebd4e83 100644 --- a/authentik/flows/api/flows.py +++ b/authentik/flows/api/flows.py @@ -14,12 +14,7 @@ from rest_framework.fields import BooleanField, FileField, ReadOnlyField from rest_framework.parsers import MultiPartParser from rest_framework.request import Request from rest_framework.response import Response -from rest_framework.serializers import ( - CharField, - ModelSerializer, - Serializer, - SerializerMethodField, -) +from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField from rest_framework.viewsets import ModelViewSet from structlog.stdlib import get_logger @@ -152,11 +147,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): ], ) @extend_schema( - request={ - "multipart/form-data": inline_serializer( - "SetIcon", fields={"file": FileField()} - ) - }, + request={"multipart/form-data": inline_serializer("SetIcon", fields={"file": FileField()})}, responses={ 204: OpenApiResponse(description="Successfully imported flow"), 400: OpenApiResponse(description="Bad request"), @@ -221,9 +212,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet): .order_by("order") ): for p_index, policy_binding in enumerate( - get_objects_for_user( - request.user, "authentik_policies.view_policybinding" - ) + get_objects_for_user(request.user, "authentik_policies.view_policybinding") .filter(target=stage_binding) .exclude(policy__isnull=True) .order_by("order") @@ -256,20 +245,14 @@ class FlowViewSet(UsedByMixin, ModelViewSet): element: DiagramElement = body[index] if element.type == "condition": # Policy passes, link policy yes to next stage - footer.append( - f"{element.identifier}(yes, right)->{body[index + 1].identifier}" - ) + footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}") # Policy doesn't pass, go to stage after next stage no_element = body[index + 1] if no_element.type != "end": no_element = body[index + 2] - footer.append( - f"{element.identifier}(no, bottom)->{no_element.identifier}" - ) + footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}") elif element.type == "operation": - footer.append( - f"{element.identifier}(bottom)->{body[index + 1].identifier}" - ) + footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}") diagram = "\n".join([str(x) for x in header + body + footer]) return Response({"diagram": diagram}) diff --git a/authentik/flows/management/commands/benchmark.py b/authentik/flows/management/commands/benchmark.py index 586c83e47..5b445357a 100644 --- a/authentik/flows/management/commands/benchmark.py +++ b/authentik/flows/management/commands/benchmark.py @@ -95,9 +95,7 @@ class Command(BaseCommand): # pragma: no cover """Output results human readable""" total_max: int = max([max(inner) for inner in values]) total_min: int = min([min(inner) for inner in values]) - total_avg = sum([sum(inner) for inner in values]) / sum( - [len(inner) for inner in values] - ) + total_avg = sum([sum(inner) for inner in values]) / sum([len(inner) for inner in values]) print(f"Version: {__version__}") print(f"Processes: {len(values)}") diff --git a/authentik/flows/migrations/0008_default_flows.py b/authentik/flows/migrations/0008_default_flows.py index 7f92644eb..8de070f80 100644 --- a/authentik/flows/migrations/0008_default_flows.py +++ b/authentik/flows/migrations/0008_default_flows.py @@ -9,21 +9,15 @@ from authentik.stages.identification.models import UserFields from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP -def create_default_authentication_flow( - apps: Apps, schema_editor: BaseDatabaseSchemaEditor -): +def create_default_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Flow = apps.get_model("authentik_flows", "Flow") FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage") UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage") - IdentificationStage = apps.get_model( - "authentik_stages_identification", "IdentificationStage" - ) + IdentificationStage = apps.get_model("authentik_stages_identification", "IdentificationStage") db_alias = schema_editor.connection.alias - identification_stage, _ = IdentificationStage.objects.using( - db_alias - ).update_or_create( + identification_stage, _ = IdentificationStage.objects.using(db_alias).update_or_create( name="default-authentication-identification", defaults={ "user_fields": [UserFields.E_MAIL, UserFields.USERNAME], @@ -69,17 +63,13 @@ def create_default_authentication_flow( ) -def create_default_invalidation_flow( - apps: Apps, schema_editor: BaseDatabaseSchemaEditor -): +def create_default_invalidation_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Flow = apps.get_model("authentik_flows", "Flow") FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage") db_alias = schema_editor.connection.alias - UserLogoutStage.objects.using(db_alias).update_or_create( - name="default-invalidation-logout" - ) + UserLogoutStage.objects.using(db_alias).update_or_create(name="default-invalidation-logout") flow, _ = Flow.objects.using(db_alias).update_or_create( slug="default-invalidation-flow", diff --git a/authentik/flows/migrations/0009_source_flows.py b/authentik/flows/migrations/0009_source_flows.py index d262250f4..1662311f3 100644 --- a/authentik/flows/migrations/0009_source_flows.py +++ b/authentik/flows/migrations/0009_source_flows.py @@ -15,16 +15,12 @@ PROMPT_POLICY_EXPRESSION = """# Check if we've not been given a username by the return 'username' not in context.get('prompt_data', {})""" -def create_default_source_enrollment_flow( - apps: Apps, schema_editor: BaseDatabaseSchemaEditor -): +def create_default_source_enrollment_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Flow = apps.get_model("authentik_flows", "Flow") FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") - ExpressionPolicy = apps.get_model( - "authentik_policies_expression", "ExpressionPolicy" - ) + ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage") Prompt = apps.get_model("authentik_stages_prompt", "Prompt") @@ -99,16 +95,12 @@ def create_default_source_enrollment_flow( ) -def create_default_source_authentication_flow( - apps: Apps, schema_editor: BaseDatabaseSchemaEditor -): +def create_default_source_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Flow = apps.get_model("authentik_flows", "Flow") FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding") - ExpressionPolicy = apps.get_model( - "authentik_policies_expression", "ExpressionPolicy" - ) + ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage") diff --git a/authentik/flows/migrations/0010_provider_flows.py b/authentik/flows/migrations/0010_provider_flows.py index b80e11b6f..45eafcb0d 100644 --- a/authentik/flows/migrations/0010_provider_flows.py +++ b/authentik/flows/migrations/0010_provider_flows.py @@ -7,9 +7,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from authentik.flows.models import FlowDesignation -def create_default_provider_authorization_flow( - apps: Apps, schema_editor: BaseDatabaseSchemaEditor -): +def create_default_provider_authorization_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Flow = apps.get_model("authentik_flows", "Flow") FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding") diff --git a/authentik/flows/migrations/0018_oob_flows.py b/authentik/flows/migrations/0018_oob_flows.py index 9058dccd4..38c7c409d 100644 --- a/authentik/flows/migrations/0018_oob_flows.py +++ b/authentik/flows/migrations/0018_oob_flows.py @@ -32,9 +32,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage") Prompt = apps.get_model("authentik_stages_prompt", "Prompt") - ExpressionPolicy = apps.get_model( - "authentik_policies_expression", "ExpressionPolicy" - ) + ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") db_alias = schema_editor.connection.alias @@ -52,9 +50,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor name="default-oobe-prefill-user", defaults={"expression": PREFILL_POLICY_EXPRESSION}, ) - password_usable_policy, _ = ExpressionPolicy.objects.using( - db_alias - ).update_or_create( + password_usable_policy, _ = ExpressionPolicy.objects.using(db_alias).update_or_create( name="default-oobe-password-usable", defaults={"expression": PW_USABLE_POLICY_EXPRESSION}, ) @@ -83,9 +79,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create( name="default-oobe-password", ) - prompt_stage.fields.set( - [prompt_header, prompt_email, password_first, password_second] - ) + prompt_stage.fields.set([prompt_header, prompt_email, password_first, password_second]) prompt_stage.save() user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create( diff --git a/authentik/flows/models.py b/authentik/flows/models.py index d4821ec27..944bc09b1 100644 --- a/authentik/flows/models.py +++ b/authentik/flows/models.py @@ -138,9 +138,7 @@ class Flow(SerializerModel, PolicyBindingModel): it is returned as-is""" if not self.background: return "/static/dist/assets/images/flow_background.jpg" - if self.background.name.startswith("http") or self.background.name.startswith( - "/static" - ): + if self.background.name.startswith("http") or self.background.name.startswith("/static"): return self.background.name return self.background.url @@ -165,9 +163,7 @@ class Flow(SerializerModel, PolicyBindingModel): if result.passing: LOGGER.debug("with_policy: flow passing", flow=flow) return flow - LOGGER.warning( - "with_policy: flow not passing", flow=flow, messages=result.messages - ) + LOGGER.warning("with_policy: flow not passing", flow=flow, messages=result.messages) LOGGER.debug("with_policy: no flow found", filters=flow_filter) return None diff --git a/authentik/flows/planner.py b/authentik/flows/planner.py index ed1373601..193421fe7 100644 --- a/authentik/flows/planner.py +++ b/authentik/flows/planner.py @@ -78,14 +78,10 @@ class FlowPlan: marker = self.markers[0] if marker.__class__ is not StageMarker: - LOGGER.debug( - "f(plan_inst): stage has marker", binding=binding, marker=marker - ) + LOGGER.debug("f(plan_inst): stage has marker", binding=binding, marker=marker) marked_stage = marker.process(self, binding, http_request) if not marked_stage: - LOGGER.debug( - "f(plan_inst): marker returned none, next stage", binding=binding - ) + LOGGER.debug("f(plan_inst): marker returned none, next stage", binding=binding) self.bindings.remove(binding) self.markers.remove(marker) if not self.has_stages: @@ -193,9 +189,9 @@ class FlowPlanner: if default_context: plan.context = default_context # Check Flow policies - for binding in FlowStageBinding.objects.filter( - target__pk=self.flow.pk - ).order_by("order"): + for binding in FlowStageBinding.objects.filter(target__pk=self.flow.pk).order_by( + "order" + ): binding: FlowStageBinding stage = binding.stage marker = StageMarker() diff --git a/authentik/flows/signals.py b/authentik/flows/signals.py index ac3a94024..c6773280d 100644 --- a/authentik/flows/signals.py +++ b/authentik/flows/signals.py @@ -26,9 +26,7 @@ def invalidate_flow_cache(sender, instance, **_): LOGGER.debug("Invalidating Flow cache", flow=instance, len=total) if isinstance(instance, FlowStageBinding): total = delete_cache_prefix(f"{cache_key(instance.target)}*") - LOGGER.debug( - "Invalidating Flow cache from FlowStageBinding", binding=instance, len=total - ) + LOGGER.debug("Invalidating Flow cache from FlowStageBinding", binding=instance, len=total) if isinstance(instance, Stage): total = 0 for binding in FlowStageBinding.objects.filter(stage=instance): diff --git a/authentik/flows/stage.py b/authentik/flows/stage.py index f8c4afe8a..be26a8104 100644 --- a/authentik/flows/stage.py +++ b/authentik/flows/stage.py @@ -42,14 +42,9 @@ class StageView(View): other things besides the form display. If no user is pending, returns request.user""" - if ( - PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context - and for_display - ): + if PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context and for_display: return User( - username=self.executor.plan.context.get( - PLAN_CONTEXT_PENDING_USER_IDENTIFIER - ), + username=self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER_IDENTIFIER), email="", ) if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context: diff --git a/authentik/flows/tests/test_planner.py b/authentik/flows/tests/test_planner.py index 8e185da58..63c4c5193 100644 --- a/authentik/flows/tests/test_planner.py +++ b/authentik/flows/tests/test_planner.py @@ -89,14 +89,10 @@ class TestFlowPlanner(TestCase): planner = FlowPlanner(flow) planner.plan(request) - self.assertEqual( - CACHE_MOCK.set.call_count, 1 - ) # Ensure plan is written to cache + self.assertEqual(CACHE_MOCK.set.call_count, 1) # Ensure plan is written to cache planner = FlowPlanner(flow) planner.plan(request) - self.assertEqual( - CACHE_MOCK.set.call_count, 1 - ) # Ensure nothing is written to cache + self.assertEqual(CACHE_MOCK.set.call_count, 1) # Ensure nothing is written to cache self.assertEqual(CACHE_MOCK.get.call_count, 2) # Get is called twice def test_planner_default_context(self): @@ -176,9 +172,7 @@ class TestFlowPlanner(TestCase): request.session.save() # Here we patch the dummy policy to evaluate to true so the stage is included - with patch( - "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE - ): + with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): planner = FlowPlanner(flow) plan = planner.plan(request) diff --git a/authentik/flows/tests/test_transfer.py b/authentik/flows/tests/test_transfer.py index f268232eb..ca98920ea 100644 --- a/authentik/flows/tests/test_transfer.py +++ b/authentik/flows/tests/test_transfer.py @@ -76,9 +76,7 @@ class TestFlowTransfer(TransactionTestCase): PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0) user_login = UserLoginStage.objects.create(name=stage_name) - fsb = FlowStageBinding.objects.create( - target=flow, stage=user_login, order=0 - ) + fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0) PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0) exporter = FlowExporter(flow) diff --git a/authentik/flows/tests/test_views.py b/authentik/flows/tests/test_views.py index 3ccf58f0d..c3bb838c2 100644 --- a/authentik/flows/tests/test_views.py +++ b/authentik/flows/tests/test_views.py @@ -11,12 +11,7 @@ from authentik.core.models import User from authentik.flows.challenge import ChallengeTypes from authentik.flows.exceptions import FlowNonApplicableException from authentik.flows.markers import ReevaluateMarker, StageMarker -from authentik.flows.models import ( - Flow, - FlowDesignation, - FlowStageBinding, - InvalidResponseAction, -) +from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction from authentik.flows.planner import FlowPlan, FlowPlanner from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView @@ -61,9 +56,7 @@ class TestFlowExecutor(TestCase): ) stage = DummyStage.objects.create(name="dummy") binding = FlowStageBinding(target=flow, stage=stage, order=0) - plan = FlowPlan( - flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -163,9 +156,7 @@ class TestFlowExecutor(TestCase): target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1 ) - exec_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} - ) + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) # First Request, start planning, renders form response = self.client.get(exec_url) self.assertEqual(response.status_code, 200) @@ -209,13 +200,9 @@ class TestFlowExecutor(TestCase): PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0) # Here we patch the dummy policy to evaluate to true so the stage is included - with patch( - "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE - ): + with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): - exec_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} - ) + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) # First request, run the planner response = self.client.get(exec_url) self.assertEqual(response.status_code, 200) @@ -263,13 +250,9 @@ class TestFlowExecutor(TestCase): PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0) # Here we patch the dummy policy to evaluate to true so the stage is included - with patch( - "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE - ): + with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): - exec_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} - ) + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) # First request, run the planner response = self.client.get(exec_url) @@ -334,13 +317,9 @@ class TestFlowExecutor(TestCase): PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0) # Here we patch the dummy policy to evaluate to true so the stage is included - with patch( - "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE - ): + with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): - exec_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} - ) + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) # First request, run the planner response = self.client.get(exec_url) @@ -422,13 +401,9 @@ class TestFlowExecutor(TestCase): PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0) # Here we patch the dummy policy to evaluate to true so the stage is included - with patch( - "authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE - ): + with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE): - exec_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} - ) + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) # First request, run the planner response = self.client.get(exec_url) self.assertEqual(response.status_code, 200) @@ -511,9 +486,7 @@ class TestFlowExecutor(TestCase): ) request.user = user planner = FlowPlanner(flow) - plan = planner.plan( - request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident} - ) + plan = planner.plan(request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident}) executor = FlowExecutorView() executor.plan = plan @@ -542,9 +515,7 @@ class TestFlowExecutor(TestCase): evaluate_on_plan=False, re_evaluate_policies=True, ) - PolicyBinding.objects.create( - policy=reputation_policy, target=deny_binding, order=0 - ) + PolicyBinding.objects.create(policy=reputation_policy, target=deny_binding, order=0) # Stage 1 is an identification stage ident_stage = IdentificationStage.objects.create( @@ -557,9 +528,7 @@ class TestFlowExecutor(TestCase): order=1, invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT, ) - exec_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": flow.slug} - ) + exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) # First request, run the planner response = self.client.get(exec_url) self.assertEqual(response.status_code, 200) @@ -579,9 +548,7 @@ class TestFlowExecutor(TestCase): "user_fields": [UserFields.E_MAIL], }, ) - response = self.client.post( - exec_url, {"uid_field": "invalid-string"}, follow=True - ) + response = self.client.post(exec_url, {"uid_field": "invalid-string"}, follow=True) self.assertEqual(response.status_code, 200) self.assertJSONEqual( force_str(response.content), diff --git a/authentik/flows/tests/test_views_helper.py b/authentik/flows/tests/test_views_helper.py index 23bb5d5c5..5e04b5232 100644 --- a/authentik/flows/tests/test_views_helper.py +++ b/authentik/flows/tests/test_views_helper.py @@ -21,9 +21,7 @@ class TestHelperView(TestCase): response = self.client.get( reverse("authentik_flows:default-invalidation"), ) - expected_url = reverse( - "authentik_core:if-flow", kwargs={"flow_slug": flow.slug} - ) + expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) self.assertEqual(response.status_code, 302) self.assertEqual(response.url, expected_url) @@ -40,8 +38,6 @@ class TestHelperView(TestCase): response = self.client.get( reverse("authentik_flows:default-invalidation"), ) - expected_url = reverse( - "authentik_core:if-flow", kwargs={"flow_slug": flow.slug} - ) + expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) self.assertEqual(response.status_code, 302) self.assertEqual(response.url, expected_url) diff --git a/authentik/flows/transfer/common.py b/authentik/flows/transfer/common.py index f9d7d79fd..fb42282ad 100644 --- a/authentik/flows/transfer/common.py +++ b/authentik/flows/transfer/common.py @@ -44,9 +44,7 @@ class FlowBundleEntry: attrs: dict[str, Any] @staticmethod - def from_model( - model: SerializerModel, *extra_identifier_names: str - ) -> "FlowBundleEntry": + def from_model(model: SerializerModel, *extra_identifier_names: str) -> "FlowBundleEntry": """Convert a SerializerModel instance to a Bundle Entry""" identifiers = { "pk": model.pk, diff --git a/authentik/flows/transfer/exporter.py b/authentik/flows/transfer/exporter.py index d2b09b0f4..a39fdd5ed 100644 --- a/authentik/flows/transfer/exporter.py +++ b/authentik/flows/transfer/exporter.py @@ -6,11 +6,7 @@ from uuid import UUID from django.db.models import Q from authentik.flows.models import Flow, FlowStageBinding, Stage -from authentik.flows.transfer.common import ( - DataclassEncoder, - FlowBundle, - FlowBundleEntry, -) +from authentik.flows.transfer.common import DataclassEncoder, FlowBundle, FlowBundleEntry from authentik.policies.models import Policy, PolicyBinding from authentik.stages.prompt.models import PromptStage @@ -37,9 +33,7 @@ class FlowExporter: def walk_stages(self) -> Iterator[FlowBundleEntry]: """Convert all stages attached to self.flow into FlowBundleEntry objects""" - stages = ( - Stage.objects.filter(flow=self.flow).select_related().select_subclasses() - ) + stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses() for stage in stages: if isinstance(stage, PromptStage): pass @@ -56,9 +50,7 @@ class FlowExporter: a direct foreign key to a policy.""" # Special case for PromptStage as that has a direct M2M to policy, we have to ensure # all policies referenced in there we also include here - prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list( - "pk", flat=True - ) + prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list("pk", flat=True) query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages) policies = Policy.objects.filter(query).select_related() for policy in policies: @@ -67,9 +59,7 @@ class FlowExporter: def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]: """Walk over all policybindings relative to us. This is run at the end of the export, as we are sure all objects exist now.""" - bindings = PolicyBinding.objects.filter( - target__in=self.pbm_uuids - ).select_related() + bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related() for binding in bindings: yield FlowBundleEntry.from_model(binding, "policy", "target", "order") diff --git a/authentik/flows/transfer/importer.py b/authentik/flows/transfer/importer.py index c0e942a68..62b0382ba 100644 --- a/authentik/flows/transfer/importer.py +++ b/authentik/flows/transfer/importer.py @@ -16,11 +16,7 @@ from rest_framework.serializers import BaseSerializer, Serializer from structlog.stdlib import BoundLogger, get_logger from authentik.flows.models import Flow, FlowStageBinding, Stage -from authentik.flows.transfer.common import ( - EntryInvalidError, - FlowBundle, - FlowBundleEntry, -) +from authentik.flows.transfer.common import EntryInvalidError, FlowBundle, FlowBundleEntry from authentik.lib.models import SerializerModel from authentik.policies.models import Policy, PolicyBinding from authentik.stages.prompt.models import Prompt @@ -105,9 +101,7 @@ class FlowImporter: if isinstance(value, dict) and "pk" in value: del updated_identifiers[key] updated_identifiers[f"{key}"] = value["pk"] - existing_models = model.objects.filter( - self.__query_from_identifier(updated_identifiers) - ) + existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers)) serializer_kwargs = {} if existing_models.exists(): @@ -120,9 +114,7 @@ class FlowImporter: ) serializer_kwargs["instance"] = model_instance else: - self.logger.debug( - "initialise new instance", model=model, **updated_identifiers - ) + self.logger.debug("initialise new instance", model=model, **updated_identifiers) full_data = self.__update_pks_for_attrs(entry.attrs) full_data.update(updated_identifiers) serializer_kwargs["data"] = full_data diff --git a/authentik/flows/views.py b/authentik/flows/views.py index ddfb8fa44..4ff9f2d57 100644 --- a/authentik/flows/views.py +++ b/authentik/flows/views.py @@ -38,13 +38,7 @@ from authentik.flows.challenge import ( WithUserInfoChallenge, ) from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException -from authentik.flows.models import ( - ConfigurableStage, - Flow, - FlowDesignation, - FlowStageBinding, - Stage, -) +from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage from authentik.flows.planner import ( PLAN_CONTEXT_PENDING_USER, PLAN_CONTEXT_REDIRECT, @@ -155,9 +149,7 @@ class FlowExecutorView(APIView): try: self.plan = self._initiate_plan() except FlowNonApplicableException as exc: - self._logger.warning( - "f(exec): Flow not applicable to current user", exc=exc - ) + self._logger.warning("f(exec): Flow not applicable to current user", exc=exc) return to_stage_response(self.request, self.handle_invalid_flow(exc)) except EmptyFlowException as exc: self._logger.warning("f(exec): Flow is empty", exc=exc) @@ -174,9 +166,7 @@ class FlowExecutorView(APIView): # in which case we just delete the plan and invalidate everything next_binding = self.plan.next(self.request) except Exception as exc: # pylint: disable=broad-except - self._logger.warning( - "f(exec): found incompatible flow plan, invalidating run", exc=exc - ) + self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc) keys = cache.keys("flow_*") cache.delete_many(keys) return self.stage_invalid() @@ -314,9 +304,7 @@ class FlowExecutorView(APIView): self.request.session[SESSION_KEY_PLAN] = plan kwargs = self.kwargs kwargs.update({"flow_slug": self.flow.slug}) - return redirect_with_qs( - "authentik_api:flow-executor", self.request.GET, **kwargs - ) + return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs) def _flow_done(self) -> HttpResponse: """User Successfully passed all stages""" @@ -350,9 +338,7 @@ class FlowExecutorView(APIView): ) kwargs = self.kwargs kwargs.update({"flow_slug": self.flow.slug}) - return redirect_with_qs( - "authentik_api:flow-executor", self.request.GET, **kwargs - ) + return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs) # User passed all stages self._logger.debug( "f(exec): User passed all stages", @@ -408,18 +394,13 @@ class FlowErrorResponse(TemplateResponse): super().__init__(request=request, template="flows/error.html") self.error = error - def resolve_context( - self, context: Optional[dict[str, Any]] - ) -> Optional[dict[str, Any]]: + def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: if not context: context = {} context["error"] = self.error if self._request.user and self._request.user.is_authenticated: - if ( - self._request.user.is_superuser - or self._request.user.group_attributes().get( - USER_ATTRIBUTE_DEBUG, False - ) + if self._request.user.is_superuser or self._request.user.group_attributes().get( + USER_ATTRIBUTE_DEBUG, False ): context["tb"] = "".join(format_tb(self.error.__traceback__)) return context @@ -464,9 +445,7 @@ class ToDefaultFlow(View): flow_slug=flow.slug, ) del self.request.session[SESSION_KEY_PLAN] - return redirect_with_qs( - "authentik_core:if-flow", request.GET, flow_slug=flow.slug - ) + return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug) def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse: diff --git a/authentik/lib/config.py b/authentik/lib/config.py index 3673b1d8a..b5ef28aa4 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -115,9 +115,7 @@ class ConfigLoader: for key, value in os.environ.items(): if not key.startswith(ENV_PREFIX): continue - relative_key = ( - key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() - ) + relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower() # Recursively convert path from a.b.c into outer[a][b][c] current_obj = outer dot_parts = relative_key.split(".") diff --git a/authentik/lib/models.py b/authentik/lib/models.py index ca5589e55..202c692c3 100644 --- a/authentik/lib/models.py +++ b/authentik/lib/models.py @@ -37,15 +37,11 @@ class InheritanceAutoManager(InheritanceManager): return super().get_queryset().select_subclasses() -class InheritanceForwardManyToOneDescriptor( - models.fields.related.ForwardManyToOneDescriptor -): +class InheritanceForwardManyToOneDescriptor(models.fields.related.ForwardManyToOneDescriptor): """Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager.""" def get_queryset(self, **hints): - return self.field.remote_field.model.objects.db_manager( - hints=hints - ).select_subclasses() + return self.field.remote_field.model.objects.db_manager(hints=hints).select_subclasses() class InheritanceForeignKey(models.ForeignKey): diff --git a/authentik/lib/sentry.py b/authentik/lib/sentry.py index 4746d426b..7ed780cb9 100644 --- a/authentik/lib/sentry.py +++ b/authentik/lib/sentry.py @@ -8,11 +8,7 @@ from botocore.exceptions import BotoCoreError from celery.exceptions import CeleryError from channels.middleware import BaseMiddleware from channels_redis.core import ChannelFull -from django.core.exceptions import ( - ImproperlyConfigured, - SuspiciousOperation, - ValidationError, -) +from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError from django.db import InternalError, OperationalError, ProgrammingError from django.http.response import Http404 from django_redis.exceptions import ConnectionInterrupted diff --git a/authentik/lib/tests/test_evaluator.py b/authentik/lib/tests/test_evaluator.py index 931224d46..eb61be90a 100644 --- a/authentik/lib/tests/test_evaluator.py +++ b/authentik/lib/tests/test_evaluator.py @@ -26,7 +26,5 @@ class TestEvaluator(TestCase): def test_is_group_member(self): """Test expr_is_group_member""" self.assertFalse( - BaseEvaluator.expr_is_group_member( - User.objects.get(username="akadmin"), name="test" - ) + BaseEvaluator.expr_is_group_member(User.objects.get(username="akadmin"), name="test") ) diff --git a/authentik/lib/tests/test_http.py b/authentik/lib/tests/test_http.py index 0c6399950..ac90cfdf8 100644 --- a/authentik/lib/tests/test_http.py +++ b/authentik/lib/tests/test_http.py @@ -1,17 +1,8 @@ """Test HTTP Helpers""" from django.test import RequestFactory, TestCase -from authentik.core.models import ( - USER_ATTRIBUTE_CAN_OVERRIDE_IP, - Token, - TokenIntents, - User, -) -from authentik.lib.utils.http import ( - OUTPOST_REMOTE_IP_HEADER, - OUTPOST_TOKEN_HEADER, - get_client_ip, -) +from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User +from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip class TestHTTP(TestCase): diff --git a/authentik/lib/tests/test_sentry.py b/authentik/lib/tests/test_sentry.py index ba899dabb..958ef5041 100644 --- a/authentik/lib/tests/test_sentry.py +++ b/authentik/lib/tests/test_sentry.py @@ -9,9 +9,7 @@ class TestSentry(TestCase): def test_error_not_sent(self): """Test SentryIgnoredError not sent""" - self.assertIsNone( - before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}) - ) + self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})) def test_error_sent(self): """Test error sent""" diff --git a/authentik/lib/utils/http.py b/authentik/lib/utils/http.py index ac4b9b5dd..53d823d2c 100644 --- a/authentik/lib/utils/http.py +++ b/authentik/lib/utils/http.py @@ -29,16 +29,9 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]: """Get the actual remote IP when set by an outpost. Only allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set to outpost""" - from authentik.core.models import ( - USER_ATTRIBUTE_CAN_OVERRIDE_IP, - Token, - TokenIntents, - ) + from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents - if ( - OUTPOST_REMOTE_IP_HEADER not in request.META - or OUTPOST_TOKEN_HEADER not in request.META - ): + if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META: return None fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER] tokens = Token.filter_not_expired( diff --git a/authentik/managed/tasks.py b/authentik/managed/tasks.py index 0dbacb8c5..589ccf3c2 100644 --- a/authentik/managed/tasks.py +++ b/authentik/managed/tasks.py @@ -12,9 +12,7 @@ def managed_reconcile(self: MonitoredTask): try: ObjectManager().run() self.set_status( - TaskResult( - TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."] - ) + TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."]) ) except DatabaseError as exc: self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)])) diff --git a/authentik/outposts/api/outposts.py b/authentik/outposts/api/outposts.py index 2e42265ba..a2a0ddefb 100644 --- a/authentik/outposts/api/outposts.py +++ b/authentik/outposts/api/outposts.py @@ -15,12 +15,7 @@ from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import PassiveSerializer, is_dict from authentik.core.models import Provider from authentik.outposts.api.service_connections import ServiceConnectionSerializer -from authentik.outposts.models import ( - Outpost, - OutpostConfig, - OutpostType, - default_outpost_config, -) +from authentik.outposts.models import Outpost, OutpostConfig, OutpostType, default_outpost_config from authentik.providers.ldap.models import LDAPProvider from authentik.providers.proxy.models import ProxyProvider diff --git a/authentik/outposts/api/service_connections.py b/authentik/outposts/api/service_connections.py index 5e722aae3..8cf25b994 100644 --- a/authentik/outposts/api/service_connections.py +++ b/authentik/outposts/api/service_connections.py @@ -15,11 +15,7 @@ from rest_framework.serializers import ModelSerializer from rest_framework.viewsets import GenericViewSet, ModelViewSet from authentik.core.api.used_by import UsedByMixin -from authentik.core.api.utils import ( - MetaNameSerializer, - PassiveSerializer, - TypeCreateSerializer, -) +from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer from authentik.lib.utils.reflection import all_subclasses from authentik.outposts.models import ( DockerServiceConnection, @@ -129,9 +125,7 @@ class KubernetesServiceConnectionSerializer(ServiceConnectionSerializer): if kubeconfig == {}: if not self.initial_data["local"]: raise serializers.ValidationError( - _( - "You can only use an empty kubeconfig when connecting to a local cluster." - ) + _("You can only use an empty kubeconfig when connecting to a local cluster.") ) # Empty kubeconfig is valid return kubeconfig diff --git a/authentik/outposts/channels.py b/authentik/outposts/channels.py index 29ab35ab7..718640703 100644 --- a/authentik/outposts/channels.py +++ b/authentik/outposts/channels.py @@ -59,9 +59,7 @@ class OutpostConsumer(AuthJsonConsumer): def connect(self): super().connect() uuid = self.scope["url_route"]["kwargs"]["pk"] - outpost = get_objects_for_user( - self.user, "authentik_outposts.view_outpost" - ).filter(pk=uuid) + outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid) if not outpost.exists(): raise DenyConnection() self.accept() @@ -129,7 +127,5 @@ class OutpostConsumer(AuthJsonConsumer): def event_update(self, event): """Event handler which is called by post_save signals, Send update instruction""" self.send_json( - asdict( - WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE) - ) + asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) ) diff --git a/authentik/outposts/controllers/docker.py b/authentik/outposts/controllers/docker.py index c233073db..3412e9587 100644 --- a/authentik/outposts/controllers/docker.py +++ b/authentik/outposts/controllers/docker.py @@ -9,11 +9,7 @@ from yaml import safe_dump from authentik import __version__ from authentik.outposts.controllers.base import BaseController, ControllerException -from authentik.outposts.models import ( - DockerServiceConnection, - Outpost, - ServiceConnectionInvalid, -) +from authentik.outposts.models import DockerServiceConnection, Outpost, ServiceConnectionInvalid class DockerController(BaseController): @@ -37,9 +33,7 @@ class DockerController(BaseController): def _get_env(self) -> dict[str, str]: return { "AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(), - "AUTHENTIK_INSECURE": str( - self.outpost.config.authentik_host_insecure - ).lower(), + "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure).lower(), "AUTHENTIK_TOKEN": self.outpost.token.key, } @@ -141,9 +135,7 @@ class DockerController(BaseController): .lower() != "unless-stopped" ): - self.logger.info( - "Container has mis-matched restart policy, re-creating..." - ) + self.logger.info("Container has mis-matched restart policy, re-creating...") self.down() return self.up() # Check that container is healthy @@ -157,9 +149,7 @@ class DockerController(BaseController): if has_been_created: # Since we've just created the container, give it some time to start. # If its still not up by then, restart it - self.logger.info( - "Container is unhealthy and new, giving it time to boot." - ) + self.logger.info("Container is unhealthy and new, giving it time to boot.") sleep(60) self.logger.info("Container is unhealthy, restarting...") container.restart() @@ -198,9 +188,7 @@ class DockerController(BaseController): "ports": ports, "environment": { "AUTHENTIK_HOST": self.outpost.config.authentik_host, - "AUTHENTIK_INSECURE": str( - self.outpost.config.authentik_host_insecure - ), + "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure), "AUTHENTIK_TOKEN": self.outpost.token.key, }, "labels": self._get_labels(), diff --git a/authentik/outposts/controllers/k8s/deployment.py b/authentik/outposts/controllers/k8s/deployment.py index 8c75118af..f9abacd2a 100644 --- a/authentik/outposts/controllers/k8s/deployment.py +++ b/authentik/outposts/controllers/k8s/deployment.py @@ -17,10 +17,7 @@ from kubernetes.client import ( ) from authentik.outposts.controllers.base import FIELD_MANAGER -from authentik.outposts.controllers.k8s.base import ( - KubernetesObjectReconciler, - NeedsUpdate, -) +from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate from authentik.outposts.models import Outpost if TYPE_CHECKING: @@ -124,9 +121,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]): ) def delete(self, reference: V1Deployment): - return self.api.delete_namespaced_deployment( - reference.metadata.name, self.namespace - ) + return self.api.delete_namespaced_deployment(reference.metadata.name, self.namespace) def retrieve(self) -> V1Deployment: return self.api.read_namespaced_deployment(self.name, self.namespace) diff --git a/authentik/outposts/controllers/k8s/secret.py b/authentik/outposts/controllers/k8s/secret.py index 99ec34fea..f15047f6b 100644 --- a/authentik/outposts/controllers/k8s/secret.py +++ b/authentik/outposts/controllers/k8s/secret.py @@ -5,10 +5,7 @@ from typing import TYPE_CHECKING from kubernetes.client import CoreV1Api, V1Secret from authentik.outposts.controllers.base import FIELD_MANAGER -from authentik.outposts.controllers.k8s.base import ( - KubernetesObjectReconciler, - NeedsUpdate, -) +from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate if TYPE_CHECKING: from authentik.outposts.controllers.kubernetes import KubernetesController @@ -38,9 +35,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]): return V1Secret( metadata=meta, data={ - "authentik_host": b64string( - self.controller.outpost.config.authentik_host - ), + "authentik_host": b64string(self.controller.outpost.config.authentik_host), "authentik_host_insecure": b64string( str(self.controller.outpost.config.authentik_host_insecure) ), @@ -54,9 +49,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]): ) def delete(self, reference: V1Secret): - return self.api.delete_namespaced_secret( - reference.metadata.name, self.namespace - ) + return self.api.delete_namespaced_secret(reference.metadata.name, self.namespace) def retrieve(self) -> V1Secret: return self.api.read_namespaced_secret(self.name, self.namespace) diff --git a/authentik/outposts/controllers/k8s/service.py b/authentik/outposts/controllers/k8s/service.py index b076de161..c2d7d015a 100644 --- a/authentik/outposts/controllers/k8s/service.py +++ b/authentik/outposts/controllers/k8s/service.py @@ -4,10 +4,7 @@ from typing import TYPE_CHECKING from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec from authentik.outposts.controllers.base import FIELD_MANAGER -from authentik.outposts.controllers.k8s.base import ( - KubernetesObjectReconciler, - NeedsUpdate, -) +from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler if TYPE_CHECKING: @@ -58,9 +55,7 @@ class ServiceReconciler(KubernetesObjectReconciler[V1Service]): ) def delete(self, reference: V1Service): - return self.api.delete_namespaced_service( - reference.metadata.name, self.namespace - ) + return self.api.delete_namespaced_service(reference.metadata.name, self.namespace) def retrieve(self) -> V1Service: return self.api.read_namespaced_service(self.name, self.namespace) diff --git a/authentik/outposts/controllers/kubernetes.py b/authentik/outposts/controllers/kubernetes.py index 68363fabb..fb2aa689f 100644 --- a/authentik/outposts/controllers/kubernetes.py +++ b/authentik/outposts/controllers/kubernetes.py @@ -24,9 +24,7 @@ class KubernetesController(BaseController): client: ApiClient connection: KubernetesServiceConnection - def __init__( - self, outpost: Outpost, connection: KubernetesServiceConnection - ) -> None: + def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection) -> None: super().__init__(outpost, connection) self.client = connection.client() self.reconcilers = { diff --git a/authentik/outposts/migrations/0002_auto_20200826_1306.py b/authentik/outposts/migrations/0002_auto_20200826_1306.py index 4a916f37d..2eb6f0d2e 100644 --- a/authentik/outposts/migrations/0002_auto_20200826_1306.py +++ b/authentik/outposts/migrations/0002_auto_20200826_1306.py @@ -15,9 +15,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="outpost", name="_config", - field=models.JSONField( - default=authentik.outposts.models.default_outpost_config - ), + field=models.JSONField(default=authentik.outposts.models.default_outpost_config), ), migrations.AddField( model_name="outpost", diff --git a/authentik/outposts/migrations/0009_fix_missing_token_identifier.py b/authentik/outposts/migrations/0009_fix_missing_token_identifier.py index c6a70abb6..2cc600574 100644 --- a/authentik/outposts/migrations/0009_fix_missing_token_identifier.py +++ b/authentik/outposts/migrations/0009_fix_missing_token_identifier.py @@ -10,9 +10,7 @@ def fix_missing_token_identifier(apps: Apps, schema_editor: BaseDatabaseSchemaEd Token = apps.get_model("authentik_core", "Token") from authentik.outposts.models import Outpost - for outpost in ( - Outpost.objects.using(schema_editor.connection.alias).all().only("pk") - ): + for outpost in Outpost.objects.using(schema_editor.connection.alias).all().only("pk"): user_identifier = outpost.user_identifier users = User.objects.filter(username=user_identifier) if not users.exists(): diff --git a/authentik/outposts/migrations/0010_service_connection.py b/authentik/outposts/migrations/0010_service_connection.py index 51a0c3283..b13433def 100644 --- a/authentik/outposts/migrations/0010_service_connection.py +++ b/authentik/outposts/migrations/0010_service_connection.py @@ -14,9 +14,7 @@ import authentik.lib.models def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): db_alias = schema_editor.connection.alias Outpost = apps.get_model("authentik_outposts", "Outpost") - DockerServiceConnection = apps.get_model( - "authentik_outposts", "DockerServiceConnection" - ) + DockerServiceConnection = apps.get_model("authentik_outposts", "DockerServiceConnection") KubernetesServiceConnection = apps.get_model( "authentik_outposts", "KubernetesServiceConnection" ) @@ -25,9 +23,7 @@ def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaE k8s = KubernetesServiceConnection.objects.filter(local=True).first() try: - for outpost in ( - Outpost.objects.using(db_alias).all().exclude(deployment_type="custom") - ): + for outpost in Outpost.objects.using(db_alias).all().exclude(deployment_type="custom"): if outpost.deployment_type == "kubernetes": outpost.service_connection = k8s elif outpost.deployment_type == "docker": diff --git a/authentik/outposts/migrations/0013_auto_20201203_2009.py b/authentik/outposts/migrations/0013_auto_20201203_2009.py index 58c67dcf9..223fae32d 100644 --- a/authentik/outposts/migrations/0013_auto_20201203_2009.py +++ b/authentik/outposts/migrations/0013_auto_20201203_2009.py @@ -11,9 +11,7 @@ def remove_pb_prefix_users(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Outpost = apps.get_model("authentik_outposts", "Outpost") for outpost in Outpost.objects.using(alias).all(): - matching = User.objects.using(alias).filter( - username=f"pb-outpost-{outpost.uuid.hex}" - ) + matching = User.objects.using(alias).filter(username=f"pb-outpost-{outpost.uuid.hex}") if matching.exists(): matching.delete() diff --git a/authentik/outposts/migrations/0016_alter_outpost_type.py b/authentik/outposts/migrations/0016_alter_outpost_type.py index 966b86997..9103e40a3 100644 --- a/authentik/outposts/migrations/0016_alter_outpost_type.py +++ b/authentik/outposts/migrations/0016_alter_outpost_type.py @@ -13,8 +13,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="outpost", name="type", - field=models.TextField( - choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy" - ), + field=models.TextField(choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"), ), ] diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 4b76fa1b0..c159cb633 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -64,9 +64,7 @@ class OutpostConfig: log_level: str = CONFIG.y("log_level") error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled") - error_reporting_environment: str = CONFIG.y( - "error_reporting.environment", "customer" - ) + error_reporting_environment: str = CONFIG.y("error_reporting.environment", "customer") object_naming_template: str = field(default="ak-outpost-%(name)s") kubernetes_replicas: int = field(default=1) @@ -264,9 +262,7 @@ class KubernetesServiceConnection(OutpostServiceConnection): client = self.client() api_instance = VersionApi(client) version: VersionInfo = api_instance.get_code() - return OutpostServiceConnectionState( - version=version.git_version, healthy=True - ) + return OutpostServiceConnectionState(version=version.git_version, healthy=True) except (OpenApiException, HTTPError, ServiceConnectionInvalid): return OutpostServiceConnectionState(version="", healthy=False) @@ -360,8 +356,7 @@ class Outpost(ManagedModel): if isinstance(model_or_perm, models.Model): model_or_perm: models.Model code_name = ( - f"{model_or_perm._meta.app_label}." - f"view_{model_or_perm._meta.model_name}" + f"{model_or_perm._meta.app_label}." f"view_{model_or_perm._meta.model_name}" ) assign_perm(code_name, user, model_or_perm) else: @@ -417,9 +412,7 @@ class Outpost(ManagedModel): self, "authentik_events.add_event", ] - for provider in ( - Provider.objects.filter(outpost=self).select_related().select_subclasses() - ): + for provider in Provider.objects.filter(outpost=self).select_related().select_subclasses(): if isinstance(provider, OutpostModel): objects.extend(provider.get_required_objects()) else: diff --git a/authentik/outposts/signals.py b/authentik/outposts/signals.py index 9a712613e..025f8d970 100644 --- a/authentik/outposts/signals.py +++ b/authentik/outposts/signals.py @@ -9,11 +9,7 @@ from authentik.core.models import Provider from authentik.crypto.models import CertificateKeyPair from authentik.lib.utils.reflection import class_to_path from authentik.outposts.models import Outpost, OutpostServiceConnection -from authentik.outposts.tasks import ( - CACHE_KEY_OUTPOST_DOWN, - outpost_controller, - outpost_post_save, -) +from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save LOGGER = get_logger() UPDATE_TRIGGERING_MODELS = ( @@ -37,9 +33,7 @@ def pre_save_outpost(sender, instance: Outpost, **_): # Name changes the deployment name, need to recreate dirty += old_instance.name != instance.name # namespace requires re-create - dirty += ( - old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace - ) + dirty += old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace if bool(dirty): LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index da683f8fd..eb27d904f 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -62,9 +62,7 @@ def controller_for_outpost(outpost: Outpost) -> Optional[BaseController]: def outpost_service_connection_state(connection_pk: Any): """Update cached state of a service connection""" connection: OutpostServiceConnection = ( - OutpostServiceConnection.objects.filter(pk=connection_pk) - .select_subclasses() - .first() + OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() ) if not connection: return @@ -157,9 +155,7 @@ def outpost_post_save(model_class: str, model_pk: Any): outpost_controller.delay(instance.pk) if isinstance(instance, (OutpostModel, Outpost)): - LOGGER.debug( - "triggering outpost update from outpostmodel/outpost", instance=instance - ) + LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) outpost_send_update(instance) if isinstance(instance, OutpostServiceConnection): @@ -208,9 +204,7 @@ def _outpost_single_update(outpost: Outpost, layer=None): layer = get_channel_layer() for state in OutpostState.for_outpost(outpost): for channel in state.channel_ids: - LOGGER.debug( - "sending update", channel=channel, instance=state.uid, outpost=outpost - ) + LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) async_to_sync(layer.send)(channel, {"type": "event.update"}) @@ -231,9 +225,7 @@ def outpost_local_connection(): if Path(kubeconfig_path).exists(): LOGGER.debug("Detected kubeconfig") kubeconfig_local_name = f"k8s-{gethostname()}" - if not KubernetesServiceConnection.objects.filter( - name=kubeconfig_local_name - ).exists(): + if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): LOGGER.debug("Creating kubeconfig Service Connection") with open(kubeconfig_path, "r") as _kubeconfig: KubernetesServiceConnection.objects.create( diff --git a/authentik/outposts/tests/test_api.py b/authentik/outposts/tests/test_api.py index 923b33f6f..483a4e50b 100644 --- a/authentik/outposts/tests/test_api.py +++ b/authentik/outposts/tests/test_api.py @@ -63,9 +63,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase): provider = ProxyProvider.objects.create( name="test", authorization_flow=Flow.objects.first() ) - invalid = OutpostSerializer( - data={"name": "foo", "providers": [provider.pk], "config": {}} - ) + invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": {}}) self.assertFalse(invalid.is_valid()) self.assertIn("config", invalid.errors) valid = OutpostSerializer( diff --git a/authentik/policies/api/bindings.py b/authentik/policies/api/bindings.py index a9d03fe61..844c9acde 100644 --- a/authentik/policies/api/bindings.py +++ b/authentik/policies/api/bindings.py @@ -2,11 +2,7 @@ from typing import OrderedDict from django.core.exceptions import ObjectDoesNotExist -from rest_framework.serializers import ( - ModelSerializer, - PrimaryKeyRelatedField, - ValidationError, -) +from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError from rest_framework.viewsets import ModelViewSet from structlog.stdlib import get_logger diff --git a/authentik/policies/api/policies.py b/authentik/policies/api/policies.py index 6be8b36b6..3d3753802 100644 --- a/authentik/policies/api/policies.py +++ b/authentik/policies/api/policies.py @@ -15,11 +15,7 @@ from structlog.stdlib import get_logger from authentik.api.decorators import permission_required from authentik.core.api.applications import user_app_cache_key from authentik.core.api.used_by import UsedByMixin -from authentik.core.api.utils import ( - CacheSerializer, - MetaNameSerializer, - TypeCreateSerializer, -) +from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer from authentik.lib.utils.reflection import all_subclasses from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSerializer from authentik.policies.models import Policy, PolicyBinding @@ -58,9 +54,7 @@ class PolicySerializer(ModelSerializer, MetaNameSerializer): # pyright: reportGeneralTypeIssues=false if instance.__class__ == Policy or not self._resolve_inheritance: return super().to_representation(instance) - return dict( - instance.serializer(instance=instance, resolve_inheritance=False).data - ) + return dict(instance.serializer(instance=instance, resolve_inheritance=False).data) class Meta: @@ -95,9 +89,7 @@ class PolicyViewSet( search_fields = ["name"] def get_queryset(self): # pragma: no cover - return Policy.objects.select_subclasses().prefetch_related( - "bindings", "promptstage_set" - ) + return Policy.objects.select_subclasses().prefetch_related("bindings", "promptstage_set") @extend_schema(responses={200: TypeCreateSerializer(many=True)}) @action(detail=False, pagination_class=None, filter_backends=[]) diff --git a/authentik/policies/denied.py b/authentik/policies/denied.py index 9ffc2a7b1..349a17ce5 100644 --- a/authentik/policies/denied.py +++ b/authentik/policies/denied.py @@ -23,9 +23,7 @@ class AccessDeniedResponse(TemplateResponse): super().__init__(request, template) self.title = _("Access denied") - def resolve_context( - self, context: Optional[dict[str, Any]] - ) -> Optional[dict[str, Any]]: + def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: if not context: context = {} context["title"] = self.title @@ -35,11 +33,8 @@ class AccessDeniedResponse(TemplateResponse): # either superuser or has USER_ATTRIBUTE_DEBUG set if self.policy_result: if self._request.user and self._request.user.is_authenticated: - if ( - self._request.user.is_superuser - or self._request.user.group_attributes().get( - USER_ATTRIBUTE_DEBUG, False - ) + if self._request.user.is_superuser or self._request.user.group_attributes().get( + USER_ATTRIBUTE_DEBUG, False ): context["policy_result"] = self.policy_result return context diff --git a/authentik/policies/dummy/tests.py b/authentik/policies/dummy/tests.py index c433d0d39..364ae54e7 100644 --- a/authentik/policies/dummy/tests.py +++ b/authentik/policies/dummy/tests.py @@ -15,9 +15,7 @@ class TestDummyPolicy(TestCase): def test_policy(self): """test policy .passes""" - policy: DummyPolicy = DummyPolicy.objects.create( - name="dummy", wait_min=1, wait_max=2 - ) + policy: DummyPolicy = DummyPolicy.objects.create(name="dummy", wait_min=1, wait_max=2) result = policy.passes(self.request) self.assertFalse(result.passing) self.assertEqual(result.messages, ("dummy",)) diff --git a/authentik/policies/engine.py b/authentik/policies/engine.py index 9f69b514b..98e2a70a3 100644 --- a/authentik/policies/engine.py +++ b/authentik/policies/engine.py @@ -11,12 +11,7 @@ from sentry_sdk.tracing import Span from structlog.stdlib import BoundLogger, get_logger from authentik.core.models import User -from authentik.policies.models import ( - Policy, - PolicyBinding, - PolicyBindingModel, - PolicyEngineMode, -) +from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode from authentik.policies.process import PolicyProcess, cache_key from authentik.policies.types import PolicyRequest, PolicyResult from authentik.root.monitoring import UpdatingGauge @@ -42,9 +37,7 @@ class PolicyProcessInfo: result: Optional[PolicyResult] binding: PolicyBinding - def __init__( - self, process: PolicyProcess, connection: Connection, binding: PolicyBinding - ): + def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding): self.process = process self.connection = connection self.binding = binding @@ -62,9 +55,7 @@ class PolicyEngine: # Allow objects with no policies attached to pass empty_result: bool - def __init__( - self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None - ): + def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None): self.logger = get_logger().bind() self.mode = pbm.policy_engine_mode # For backwards compatibility, set empty_result to true @@ -123,15 +114,11 @@ class PolicyEngine: ) self.__cached_policies.append(cached_policy) continue - self.logger.debug( - "P_ENG: Evaluating policy", binding=binding, request=self.request - ) + self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request) our_end, task_end = Pipe(False) task = PolicyProcess(binding, self.request, task_end) task.daemon = False - self.logger.debug( - "P_ENG: Starting Process", binding=binding, request=self.request - ) + self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request) if not CURRENT_PROCESS._config.get("daemon"): task.run() else: @@ -151,9 +138,7 @@ class PolicyEngine: @property def result(self) -> PolicyResult: """Get policy-checking result""" - process_results: list[PolicyResult] = [ - x.result for x in self.__processes if x.result - ] + process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result] all_results = list(process_results + self.__cached_policies) if len(all_results) < self.__expected_result_count: # pragma: no cover raise AssertionError("Got less results than polices") diff --git a/authentik/policies/event_matcher/tests.py b/authentik/policies/event_matcher/tests.py index 156dcf86b..2c6347430 100644 --- a/authentik/policies/event_matcher/tests.py +++ b/authentik/policies/event_matcher/tests.py @@ -15,9 +15,7 @@ class TestEventMatcherPolicy(TestCase): event = Event.new(EventAction.LOGIN) request = PolicyRequest(get_anonymous_user()) request.context["event"] = event - policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( - action=EventAction.LOGIN - ) + policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(action=EventAction.LOGIN) response = policy.passes(request) self.assertTrue(response.passing) self.assertTupleEqual(response.messages, ("Action matched.",)) @@ -28,9 +26,7 @@ class TestEventMatcherPolicy(TestCase): event.client_ip = "1.2.3.4" request = PolicyRequest(get_anonymous_user()) request.context["event"] = event - policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( - client_ip="1.2.3.4" - ) + policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(client_ip="1.2.3.4") response = policy.passes(request) self.assertTrue(response.passing) self.assertTupleEqual(response.messages, ("Client IP matched.",)) @@ -52,17 +48,13 @@ class TestEventMatcherPolicy(TestCase): event.client_ip = "1.2.3.4" request = PolicyRequest(get_anonymous_user()) request.context["event"] = event - policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( - client_ip="1.2.3.5" - ) + policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(client_ip="1.2.3.5") response = policy.passes(request) self.assertFalse(response.passing) def test_invalid(self): """Test passing event""" request = PolicyRequest(get_anonymous_user()) - policy: EventMatcherPolicy = EventMatcherPolicy.objects.create( - client_ip="1.2.3.4" - ) + policy: EventMatcherPolicy = EventMatcherPolicy.objects.create(client_ip="1.2.3.4") response = policy.passes(request) self.assertFalse(response.passing) diff --git a/authentik/policies/expiry/models.py b/authentik/policies/expiry/models.py index dd76d91af..47b7de56b 100644 --- a/authentik/policies/expiry/models.py +++ b/authentik/policies/expiry/models.py @@ -42,10 +42,7 @@ class PasswordExpiryPolicy(Policy): request.user.set_unusable_password() request.user.save() message = _( - ( - "Password expired %(days)d days ago. " - "Please update your password." - ) + ("Password expired %(days)d days ago. " "Please update your password.") % {"days": days_since_expiry} ) return PolicyResult(False, message) diff --git a/authentik/policies/expression/migrations/0002_auto_20200926_1156.py b/authentik/policies/expression/migrations/0002_auto_20200926_1156.py index 0a9f1ddcc..59d1a0a70 100644 --- a/authentik/policies/expression/migrations/0002_auto_20200926_1156.py +++ b/authentik/policies/expression/migrations/0002_auto_20200926_1156.py @@ -6,9 +6,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor def remove_pb_flow_plan(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): - ExpressionPolicy = apps.get_model( - "authentik_policies_expression", "ExpressionPolicy" - ) + ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") db_alias = schema_editor.connection.alias diff --git a/authentik/policies/expression/migrations/0003_auto_20201203_1223.py b/authentik/policies/expression/migrations/0003_auto_20201203_1223.py index f0a0c4086..f9f335681 100644 --- a/authentik/policies/expression/migrations/0003_auto_20201203_1223.py +++ b/authentik/policies/expression/migrations/0003_auto_20201203_1223.py @@ -6,18 +6,14 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor def replace_pb_prefix(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): - ExpressionPolicy = apps.get_model( - "authentik_policies_expression", "ExpressionPolicy" - ) + ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy") db_alias = schema_editor.connection.alias for policy in ExpressionPolicy.objects.using(db_alias).all(): # Because the previous migration had a broken replace, we have to replace here again policy.expression = policy.expression.replace("pb_flow_plan.", "context.") - policy.expression = policy.expression.replace( - "pb_is_sso_flow", "ak_is_sso_flow" - ) + policy.expression = policy.expression.replace("pb_is_sso_flow", "ak_is_sso_flow") policy.save() diff --git a/authentik/policies/hibp/models.py b/authentik/policies/hibp/models.py index 4b2032a1d..13b90a049 100644 --- a/authentik/policies/hibp/models.py +++ b/authentik/policies/hibp/models.py @@ -19,9 +19,7 @@ class HaveIBeenPwendPolicy(Policy): password_field = models.TextField( default="password", - help_text=_( - "Field key to check, field keys defined in Prompt stages are available." - ), + help_text=_("Field key to check, field keys defined in Prompt stages are available."), ) allowed_count = models.IntegerField(default=0) @@ -59,9 +57,7 @@ class HaveIBeenPwendPolicy(Policy): final_count = int(count) LOGGER.debug("got hibp result", count=final_count, hash=pw_hash[:5]) if final_count > self.allowed_count: - message = _( - "Password exists on %(count)d online lists." % {"count": final_count} - ) + message = _("Password exists on %(count)d online lists." % {"count": final_count}) return PolicyResult(False, message) return PolicyResult(True) diff --git a/authentik/policies/models.py b/authentik/policies/models.py index 380e6eaa6..3a4f30ffa 100644 --- a/authentik/policies/models.py +++ b/authentik/policies/models.py @@ -49,9 +49,7 @@ class PolicyBindingModel(models.Model): class PolicyBinding(SerializerModel): """Relationship between a Policy and a PolicyBindingModel.""" - policy_binding_uuid = models.UUIDField( - primary_key=True, editable=False, default=uuid4 - ) + policy_binding_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) enabled = models.BooleanField(default=True) @@ -81,9 +79,7 @@ class PolicyBinding(SerializerModel): blank=True, ) - target = InheritanceForeignKey( - PolicyBindingModel, on_delete=models.CASCADE, related_name="+" - ) + target = InheritanceForeignKey(PolicyBindingModel, on_delete=models.CASCADE, related_name="+") negate = models.BooleanField( default=False, help_text=_("Negates the outcome of the policy. Messages are unaffected."), diff --git a/authentik/policies/password/models.py b/authentik/policies/password/models.py index 286033eec..be3fea1c3 100644 --- a/authentik/policies/password/models.py +++ b/authentik/policies/password/models.py @@ -17,9 +17,7 @@ class PasswordPolicy(Policy): password_field = models.TextField( default="password", - help_text=_( - "Field key to check, field keys defined in Prompt stages are available." - ), + help_text=_("Field key to check, field keys defined in Prompt stages are available."), ) amount_uppercase = models.IntegerField(default=0) @@ -55,9 +53,7 @@ class PasswordPolicy(Policy): if self.amount_uppercase > 0: filter_regex.append(r"[A-Z]{%d,}" % self.amount_uppercase) if self.amount_symbols > 0: - filter_regex.append( - r"[%s]{%d,}" % (self.symbol_charset, self.amount_symbols) - ) + filter_regex.append(r"[%s]{%d,}" % (self.symbol_charset, self.amount_symbols)) full_regex = "|".join(filter_regex) LOGGER.debug("Built regex", regexp=full_regex) result = bool(re.compile(full_regex).match(password)) diff --git a/authentik/policies/reputation/api.py b/authentik/policies/reputation/api.py index e3bf747be..2d5f610c3 100644 --- a/authentik/policies/reputation/api.py +++ b/authentik/policies/reputation/api.py @@ -5,11 +5,7 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet from authentik.core.api.used_by import UsedByMixin from authentik.policies.api.policies import PolicySerializer -from authentik.policies.reputation.models import ( - IPReputation, - ReputationPolicy, - UserReputation, -) +from authentik.policies.reputation.models import IPReputation, ReputationPolicy, UserReputation class ReputationPolicySerializer(PolicySerializer): diff --git a/authentik/policies/reputation/signals.py b/authentik/policies/reputation/signals.py index c02670a9c..4f4267fc3 100644 --- a/authentik/policies/reputation/signals.py +++ b/authentik/policies/reputation/signals.py @@ -7,10 +7,7 @@ from structlog.stdlib import get_logger from authentik.lib.config import CONFIG from authentik.lib.utils.http import get_client_ip -from authentik.policies.reputation.models import ( - CACHE_KEY_IP_PREFIX, - CACHE_KEY_USER_PREFIX, -) +from authentik.policies.reputation.models import CACHE_KEY_IP_PREFIX, CACHE_KEY_USER_PREFIX from authentik.stages.identification.signals import identification_failed LOGGER = get_logger() diff --git a/authentik/policies/reputation/tasks.py b/authentik/policies/reputation/tasks.py index 4af1f39e6..503d3a739 100644 --- a/authentik/policies/reputation/tasks.py +++ b/authentik/policies/reputation/tasks.py @@ -4,10 +4,7 @@ from structlog.stdlib import get_logger from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus from authentik.policies.reputation.models import IPReputation, UserReputation -from authentik.policies.reputation.signals import ( - CACHE_KEY_IP_PREFIX, - CACHE_KEY_USER_PREFIX, -) +from authentik.policies.reputation.signals import CACHE_KEY_IP_PREFIX, CACHE_KEY_USER_PREFIX from authentik.root.celery import CELERY_APP LOGGER = get_logger() @@ -23,9 +20,7 @@ def save_ip_reputation(self: MonitoredTask): rep.score = score objects_to_update.append(rep) IPReputation.objects.bulk_update(objects_to_update, ["score"]) - self.set_status( - TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated IP Reputation"]) - ) + self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated IP Reputation"])) @CELERY_APP.task(bind=True, base=MonitoredTask) @@ -39,7 +34,5 @@ def save_user_reputation(self: MonitoredTask): objects_to_update.append(rep) UserReputation.objects.bulk_update(objects_to_update, ["score"]) self.set_status( - TaskResult( - TaskResultStatus.SUCCESSFUL, ["Successfully updated User Reputation"] - ) + TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated User Reputation"]) ) diff --git a/authentik/policies/reputation/tests.py b/authentik/policies/reputation/tests.py index 9e944c584..db648618d 100644 --- a/authentik/policies/reputation/tests.py +++ b/authentik/policies/reputation/tests.py @@ -33,9 +33,7 @@ class TestReputationPolicy(TestCase): def test_ip_reputation(self): """test IP reputation""" # Trigger negative reputation - authenticate( - self.request, username=self.test_username, password=self.test_username - ) + authenticate(self.request, username=self.test_username, password=self.test_username) # Test value in cache self.assertEqual(cache.get(CACHE_KEY_IP_PREFIX + self.test_ip), -1) # Save cache and check db values @@ -45,16 +43,12 @@ class TestReputationPolicy(TestCase): def test_user_reputation(self): """test User reputation""" # Trigger negative reputation - authenticate( - self.request, username=self.test_username, password=self.test_username - ) + authenticate(self.request, username=self.test_username, password=self.test_username) # Test value in cache self.assertEqual(cache.get(CACHE_KEY_USER_PREFIX + self.test_username), -1) # Save cache and check db values save_user_reputation.delay().get() - self.assertEqual( - UserReputation.objects.get(username=self.test_username).score, -1 - ) + self.assertEqual(UserReputation.objects.get(username=self.test_username).score, -1) def test_policy(self): """Test Policy""" diff --git a/authentik/policies/signals.py b/authentik/policies/signals.py index fabde968f..c9baf9624 100644 --- a/authentik/policies/signals.py +++ b/authentik/policies/signals.py @@ -18,9 +18,7 @@ def invalidate_policy_cache(sender, instance, **_): if isinstance(instance, Policy): total = 0 for binding in PolicyBinding.objects.filter(policy=instance): - prefix = ( - f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}*" - ) + prefix = f"policy_{binding.policy_binding_uuid.hex}_{binding.policy.pk.hex}*" keys = cache.keys(prefix) total += len(keys) cache.delete_many(keys) diff --git a/authentik/policies/tests/test_bindings_api.py b/authentik/policies/tests/test_bindings_api.py index fc699391d..51be90073 100644 --- a/authentik/policies/tests/test_bindings_api.py +++ b/authentik/policies/tests/test_bindings_api.py @@ -37,11 +37,7 @@ class TestBindingsAPI(APITestCase): ) self.assertJSONEqual( response.content.decode(), - { - "non_field_errors": [ - "Only one of 'policy', 'group' or 'user' can be set." - ] - }, + {"non_field_errors": ["Only one of 'policy', 'group' or 'user' can be set."]}, ) def test_invalid_too_little(self): diff --git a/authentik/policies/tests/test_engine.py b/authentik/policies/tests/test_engine.py index 0223718ca..28b60ba8f 100644 --- a/authentik/policies/tests/test_engine.py +++ b/authentik/policies/tests/test_engine.py @@ -6,12 +6,7 @@ from authentik.core.models import User from authentik.policies.dummy.models import DummyPolicy from authentik.policies.engine import PolicyEngine from authentik.policies.expression.models import ExpressionPolicy -from authentik.policies.models import ( - Policy, - PolicyBinding, - PolicyBindingModel, - PolicyEngineMode, -) +from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode from authentik.policies.tests.test_process import clear_policy_cache @@ -21,16 +16,10 @@ class TestPolicyEngine(TestCase): def setUp(self): clear_policy_cache() self.user = User.objects.create_user(username="policyuser") - self.policy_false = DummyPolicy.objects.create( - result=False, wait_min=0, wait_max=1 - ) - self.policy_true = DummyPolicy.objects.create( - result=True, wait_min=0, wait_max=1 - ) + self.policy_false = DummyPolicy.objects.create(result=False, wait_min=0, wait_max=1) + self.policy_true = DummyPolicy.objects.create(result=True, wait_min=0, wait_max=1) self.policy_wrong_type = Policy.objects.create(name="wrong_type") - self.policy_raises = ExpressionPolicy.objects.create( - name="raises", expression="{{ 0/0 }}" - ) + self.policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}") def test_engine_empty(self): """Ensure empty policy list passes""" @@ -51,9 +40,7 @@ class TestPolicyEngine(TestCase): def test_engine_mode_all(self): """Ensure all policies passes with AND mode (false and true -> false)""" - pbm = PolicyBindingModel.objects.create( - policy_engine_mode=PolicyEngineMode.MODE_ALL - ) + pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ALL) PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) engine = PolicyEngine(pbm, self.user) @@ -69,9 +56,7 @@ class TestPolicyEngine(TestCase): def test_engine_mode_any(self): """Ensure all policies passes with OR mode (false and true -> true)""" - pbm = PolicyBindingModel.objects.create( - policy_engine_mode=PolicyEngineMode.MODE_ANY - ) + pbm = PolicyBindingModel.objects.create(policy_engine_mode=PolicyEngineMode.MODE_ANY) PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) PolicyBinding.objects.create(target=pbm, policy=self.policy_true, order=1) engine = PolicyEngine(pbm, self.user) @@ -88,9 +73,7 @@ class TestPolicyEngine(TestCase): def test_engine_negate(self): """Test negate flag""" pbm = PolicyBindingModel.objects.create() - PolicyBinding.objects.create( - target=pbm, policy=self.policy_true, negate=True, order=0 - ) + PolicyBinding.objects.create(target=pbm, policy=self.policy_true, negate=True, order=0) engine = PolicyEngine(pbm, self.user) result = engine.build().result self.assertEqual(result.passing, False) @@ -116,18 +99,10 @@ class TestPolicyEngine(TestCase): def test_engine_cache(self): """Ensure empty policy list passes""" pbm = PolicyBindingModel.objects.create() - binding = PolicyBinding.objects.create( - target=pbm, policy=self.policy_false, order=0 - ) + binding = PolicyBinding.objects.create(target=pbm, policy=self.policy_false, order=0) engine = PolicyEngine(pbm, self.user) - self.assertEqual( - len(cache.keys(f"policy_{binding.policy_binding_uuid.hex}*")), 0 - ) + self.assertEqual(len(cache.keys(f"policy_{binding.policy_binding_uuid.hex}*")), 0) self.assertEqual(engine.build().passing, False) - self.assertEqual( - len(cache.keys(f"policy_{binding.policy_binding_uuid.hex}*")), 1 - ) + self.assertEqual(len(cache.keys(f"policy_{binding.policy_binding_uuid.hex}*")), 1) self.assertEqual(engine.build().passing, False) - self.assertEqual( - len(cache.keys(f"policy_{binding.policy_binding_uuid.hex}*")), 1 - ) + self.assertEqual(len(cache.keys(f"policy_{binding.policy_binding_uuid.hex}*")), 1) diff --git a/authentik/policies/tests/test_policies_api.py b/authentik/policies/tests/test_policies_api.py index fc20a0177..53ec110a6 100644 --- a/authentik/policies/tests/test_policies_api.py +++ b/authentik/policies/tests/test_policies_api.py @@ -23,9 +23,7 @@ class TestPoliciesAPI(APITestCase): "user": self.user.pk, }, ) - self.assertJSONEqual( - response.content.decode(), {"passing": True, "messages": ["dummy"]} - ) + self.assertJSONEqual(response.content.decode(), {"passing": True, "messages": ["dummy"]}) def test_types(self): """Test Policy's types endpoint""" diff --git a/authentik/policies/tests/test_process.py b/authentik/policies/tests/test_process.py index f1e913639..1c2fde4b9 100644 --- a/authentik/policies/tests/test_process.py +++ b/authentik/policies/tests/test_process.py @@ -112,9 +112,7 @@ class TestPolicyProcess(TestCase): def test_exception(self): """Test policy execution""" policy = Policy.objects.create(name="test-execution") - binding = PolicyBinding( - policy=policy, target=Application.objects.create(name="test") - ) + binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) request = PolicyRequest(self.user) response = PolicyProcess(binding, request, None).execute() @@ -129,9 +127,7 @@ class TestPolicyProcess(TestCase): wait_max=1, execution_logging=True, ) - binding = PolicyBinding( - policy=policy, target=Application.objects.create(name="test") - ) + binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) http_request = self.factory.get("/") http_request.user = self.user @@ -163,9 +159,7 @@ class TestPolicyProcess(TestCase): wait_max=1, execution_logging=True, ) - binding = PolicyBinding( - policy=policy, target=Application.objects.create(name="test") - ) + binding = PolicyBinding(policy=policy, target=Application.objects.create(name="test")) user = AnonymousUser() @@ -192,9 +186,7 @@ class TestPolicyProcess(TestCase): def test_raises(self): """Test policy that raises error""" - policy_raises = ExpressionPolicy.objects.create( - name="raises", expression="{{ 0/0 }}" - ) + policy_raises = ExpressionPolicy.objects.create(name="raises", expression="{{ 0/0 }}") binding = PolicyBinding( policy=policy_raises, target=Application.objects.create(name="test") ) diff --git a/authentik/policies/views.py b/authentik/policies/views.py index 6fbb7c139..138e79639 100644 --- a/authentik/policies/views.py +++ b/authentik/policies/views.py @@ -102,9 +102,7 @@ class PolicyAccessView(AccessMixin, View): def user_has_access(self, user: Optional[User] = None) -> PolicyResult: """Check if user has access to application.""" user = user or self.request.user - policy_engine = PolicyEngine( - self.application, user or self.request.user, self.request - ) + policy_engine = PolicyEngine(self.application, user or self.request.user, self.request) policy_engine.use_cache = False policy_engine.build() result = policy_engine.result diff --git a/authentik/providers/oauth2/api/provider.py b/authentik/providers/oauth2/api/provider.py index daa099bd1..e601b6661 100644 --- a/authentik/providers/oauth2/api/provider.py +++ b/authentik/providers/oauth2/api/provider.py @@ -22,13 +22,8 @@ class OAuth2ProviderSerializer(ProviderSerializer): def validate_jwt_alg(self, value): """Ensure that when RS256 is selected, a certificate-key-pair is selected""" - if ( - self.initial_data.get("rsa_key", None) is None - and value == JWTAlgorithms.RS256 - ): - raise ValidationError( - _("RS256 requires a Certificate-Key-Pair to be selected.") - ) + if self.initial_data.get("rsa_key", None) is None and value == JWTAlgorithms.RS256: + raise ValidationError(_("RS256 requires a Certificate-Key-Pair to be selected.")) return value class Meta: diff --git a/authentik/providers/oauth2/errors.py b/authentik/providers/oauth2/errors.py index 8e79b8f6e..3e24ff1d4 100644 --- a/authentik/providers/oauth2/errors.py +++ b/authentik/providers/oauth2/errors.py @@ -71,9 +71,7 @@ class ClientIdError(OAuth2Error): self.client_id = client_id def to_event(self, **kwargs) -> Event: - return super().to_event( - f"Invalid client identifier: {self.client_id}.", **kwargs - ) + return super().to_event(f"Invalid client identifier: {self.client_id}.", **kwargs) class UserAuthError(OAuth2Error): @@ -103,8 +101,7 @@ class AuthorizeError(OAuth2Error): "invalid_request": "The request is otherwise malformed", "unauthorized_client": "The client is not authorized to request an " "authorization code using this method", - "access_denied": "The resource owner or authorization server denied " - "the request", + "access_denied": "The resource owner or authorization server denied " "the request", "unsupported_response_type": "The authorization server does not " "support obtaining an authorization code " "using this method", @@ -118,17 +115,14 @@ class AuthorizeError(OAuth2Error): # http://openid.net/specs/openid-connect-core-1_0.html#AuthError "interaction_required": "The Authorization Server requires End-User " "interaction of some form to proceed", - "login_required": "The Authorization Server requires End-User " - "authentication", + "login_required": "The Authorization Server requires End-User " "authentication", "account_selection_required": "The End-User is required to select a " "session at the Authorization Server", "consent_required": "The Authorization Server requires End-User" "consent", "invalid_request_uri": "The request_uri in the Authorization Request " "returns an error or contains invalid data", - "invalid_request_object": "The request parameter contains an invalid " - "Request Object", - "request_not_supported": "The provider does not support use of the " - "request parameter", + "invalid_request_object": "The request parameter contains an invalid " "Request Object", + "request_not_supported": "The provider does not support use of the " "request parameter", "request_uri_not_supported": "The provider does not support use of the " "request_uri parameter", "registration_not_supported": "The provider does not support use of " @@ -210,8 +204,7 @@ class BearerTokenError(OAuth2Error): 401, ), "insufficient_scope": ( - "The request requires higher privileges than provided by " - "the access token", + "The request requires higher privileges than provided by " "the access token", 403, ), } diff --git a/authentik/providers/oauth2/generators.py b/authentik/providers/oauth2/generators.py index 57df54ea8..0df19d51d 100644 --- a/authentik/providers/oauth2/generators.py +++ b/authentik/providers/oauth2/generators.py @@ -12,6 +12,4 @@ def generate_client_id(): def generate_client_secret(): """Generate a suitable client secret""" rand = SystemRandom() - return "".join( - rand.choice(string.ascii_letters + string.digits) for x in range(128) - ) + return "".join(rand.choice(string.ascii_letters + string.digits) for x in range(128)) diff --git a/authentik/providers/oauth2/migrations/0001_initial.py b/authentik/providers/oauth2/migrations/0001_initial.py index 3140ebba6..9a531742d 100644 --- a/authentik/providers/oauth2/migrations/0001_initial.py +++ b/authentik/providers/oauth2/migrations/0001_initial.py @@ -22,12 +22,8 @@ class Migration(migrations.Migration): ] operations = [ - migrations.RunSQL( - "DROP TABLE IF EXISTS authentik_providers_oauth_oauth2provider CASCADE;" - ), - migrations.RunSQL( - "DROP TABLE IF EXISTS authentik_providers_oidc_openidprovider CASCADE;" - ), + migrations.RunSQL("DROP TABLE IF EXISTS authentik_providers_oauth_oauth2provider CASCADE;"), + migrations.RunSQL("DROP TABLE IF EXISTS authentik_providers_oidc_openidprovider CASCADE;"), migrations.CreateModel( name="OAuth2Provider", fields=[ @@ -136,9 +132,7 @@ class Migration(migrations.Migration): models.TextField( default="minutes=10", help_text="Tokens not valid on or after current time + this value (Format: hours=1;minutes=2;seconds=3).", - validators=[ - authentik.lib.utils.time.timedelta_string_validator - ], + validators=[authentik.lib.utils.time.timedelta_string_validator], ), ), ( @@ -202,23 +196,17 @@ class Migration(migrations.Migration): ), ( "expires", - models.DateTimeField( - default=authentik.core.models.default_token_duration - ), + models.DateTimeField(default=authentik.core.models.default_token_duration), ), ("expiring", models.BooleanField(default=True)), ("_scope", models.TextField(default="", verbose_name="Scopes")), ( "access_token", - models.CharField( - max_length=255, unique=True, verbose_name="Access Token" - ), + models.CharField(max_length=255, unique=True, verbose_name="Access Token"), ), ( "refresh_token", - models.CharField( - max_length=255, unique=True, verbose_name="Refresh Token" - ), + models.CharField(max_length=255, unique=True, verbose_name="Refresh Token"), ), ("_id_token", models.TextField(verbose_name="ID Token")), ( @@ -256,9 +244,7 @@ class Migration(migrations.Migration): ), ( "expires", - models.DateTimeField( - default=authentik.core.models.default_token_duration - ), + models.DateTimeField(default=authentik.core.models.default_token_duration), ), ("expiring", models.BooleanField(default=True)), ("_scope", models.TextField(default="", verbose_name="Scopes")), @@ -268,21 +254,15 @@ class Migration(migrations.Migration): ), ( "nonce", - models.CharField( - blank=True, default="", max_length=255, verbose_name="Nonce" - ), + models.CharField(blank=True, default="", max_length=255, verbose_name="Nonce"), ), ( "is_open_id", - models.BooleanField( - default=False, verbose_name="Is Authentication?" - ), + models.BooleanField(default=False, verbose_name="Is Authentication?"), ), ( "code_challenge", - models.CharField( - max_length=255, null=True, verbose_name="Code Challenge" - ), + models.CharField(max_length=255, null=True, verbose_name="Code Challenge"), ), ( "code_challenge_method", diff --git a/authentik/providers/oauth2/migrations/0011_managed.py b/authentik/providers/oauth2/migrations/0011_managed.py index 9920131a9..5b1764920 100644 --- a/authentik/providers/oauth2/migrations/0011_managed.py +++ b/authentik/providers/oauth2/migrations/0011_managed.py @@ -14,9 +14,7 @@ scope_uid_map = { def set_managed_flag(apps: Apps, schema_editor): ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping") db_alias = schema_editor.connection.alias - for mapping in ScopeMapping.objects.using(db_alias).filter( - name__startswith="Autogenerated " - ): + for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): mapping.managed = scope_uid_map[mapping.scope_name] mapping.save() diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index 5ae621389..11b6eb960 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -25,10 +25,7 @@ from authentik.events.utils import get_user from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator from authentik.providers.oauth2.apps import AuthentikProviderOAuth2Config from authentik.providers.oauth2.constants import ACR_AUTHENTIK_DEFAULT -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret class ClientTypes(models.TextChoices): @@ -208,9 +205,7 @@ class OAuth2Provider(Provider): issuer_mode = models.TextField( choices=IssuerMode.choices, default=IssuerMode.PER_PROVIDER, - help_text=_( - ("Configure how the issuer field of the ID Token should be filled.") - ), + help_text=_(("Configure how the issuer field of the ID Token should be filled.")), ) rsa_key = models.ForeignKey( @@ -339,12 +334,8 @@ class AuthorizationCode(ExpiringModel, BaseGrantModel): code = models.CharField(max_length=255, unique=True, verbose_name=_("Code")) nonce = models.TextField(null=True, default=None, verbose_name=_("Nonce")) - is_open_id = models.BooleanField( - default=False, verbose_name=_("Is Authentication?") - ) - code_challenge = models.CharField( - max_length=255, null=True, verbose_name=_("Code Challenge") - ) + is_open_id = models.BooleanField(default=False, verbose_name=_("Is Authentication?")) + code_challenge = models.CharField(max_length=255, null=True, verbose_name=_("Code Challenge")) code_challenge_method = models.CharField( max_length=255, null=True, verbose_name=_("Code Challenge Method") ) @@ -354,9 +345,7 @@ class AuthorizationCode(ExpiringModel, BaseGrantModel): """https://openid.net/specs/openid-connect-core-1_0.html#IDToken""" hashed_code = sha256(self.code.encode("ascii")).hexdigest().encode("ascii") return ( - base64.urlsafe_b64encode( - binascii.unhexlify(hashed_code[: len(hashed_code) // 2]) - ) + base64.urlsafe_b64encode(binascii.unhexlify(hashed_code[: len(hashed_code) // 2])) .rstrip(b"=") .decode("ascii") ) @@ -407,9 +396,7 @@ class RefreshToken(ExpiringModel, BaseGrantModel): """OAuth2 Refresh Token""" access_token = models.TextField(verbose_name=_("Access Token")) - refresh_token = models.CharField( - max_length=255, unique=True, verbose_name=_("Refresh Token") - ) + refresh_token = models.CharField(max_length=255, unique=True, verbose_name=_("Refresh Token")) _id_token = models.TextField(verbose_name=_("ID Token")) class Meta: @@ -434,9 +421,7 @@ class RefreshToken(ExpiringModel, BaseGrantModel): @property def at_hash(self): """Get hashed access_token""" - hashed_access_token = ( - sha256(self.access_token.encode("ascii")).hexdigest().encode("ascii") - ) + hashed_access_token = sha256(self.access_token.encode("ascii")).hexdigest().encode("ascii") return ( base64.urlsafe_b64encode( binascii.unhexlify(hashed_access_token[: len(hashed_access_token) // 2]) @@ -477,9 +462,9 @@ class RefreshToken(ExpiringModel, BaseGrantModel): iat_time = now exp_time = int(dateformat.format(self.expires, "U")) # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time - auth_events = Event.objects.filter( - action=EventAction.LOGIN, user=get_user(user) - ).order_by("-created") + auth_events = Event.objects.filter(action=EventAction.LOGIN, user=get_user(user)).order_by( + "-created" + ) # Fallback in case we can't find any login events auth_time = datetime.now() if auth_events.exists(): diff --git a/authentik/providers/oauth2/tests/test_api.py b/authentik/providers/oauth2/tests/test_api.py index 6ff1dbfb5..6830a09e2 100644 --- a/authentik/providers/oauth2/tests/test_api.py +++ b/authentik/providers/oauth2/tests/test_api.py @@ -24,9 +24,7 @@ class TestOAuth2ProviderAPI(APITestCase): data={ "name": "test", "jwt_alg": str(JWTAlgorithms.RS256), - "authorization_flow": Flow.objects.filter( - designation=FlowDesignation.AUTHORIZATION - ) + "authorization_flow": Flow.objects.filter(designation=FlowDesignation.AUTHORIZATION) .first() .pk, }, diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index f96533fc6..603549968 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -7,15 +7,8 @@ from authentik.core.models import Application, User from authentik.crypto.models import CertificateKeyPair from authentik.flows.challenge import ChallengeTypes from authentik.flows.models import Flow -from authentik.providers.oauth2.errors import ( - AuthorizeError, - ClientIdError, - RedirectUriError, -) -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.providers.oauth2.models import ( AuthorizationCode, GrantTypes, @@ -42,9 +35,7 @@ class TestAuthorize(OAuthTestCase): def test_invalid_client_id(self): """Test invalid client ID""" with self.assertRaises(ClientIdError): - request = self.factory.get( - "/", data={"response_type": "code", "client_id": "invalid"} - ) + request = self.factory.get("/", data={"response_type": "code", "client_id": "invalid"}) OAuthAuthorizationParams.from_request(request) def test_request(self): @@ -76,9 +67,7 @@ class TestAuthorize(OAuthTestCase): redirect_uris="http://local.invalid", ) with self.assertRaises(RedirectUriError): - request = self.factory.get( - "/", data={"response_type": "code", "client_id": "test"} - ) + request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) OAuthAuthorizationParams.from_request(request) with self.assertRaises(RedirectUriError): request = self.factory.get( @@ -99,9 +88,7 @@ class TestAuthorize(OAuthTestCase): authorization_flow=Flow.objects.first(), ) with self.assertRaises(RedirectUriError): - request = self.factory.get( - "/", data={"response_type": "code", "client_id": "test"} - ) + request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) OAuthAuthorizationParams.from_request(request) request = self.factory.get( "/", diff --git a/authentik/providers/oauth2/tests/test_jwks.py b/authentik/providers/oauth2/tests/test_jwks.py index cede1014b..9c928a083 100644 --- a/authentik/providers/oauth2/tests/test_jwks.py +++ b/authentik/providers/oauth2/tests/test_jwks.py @@ -30,9 +30,7 @@ class TestJWKS(OAuthTestCase): ) app = Application.objects.create(name="test", slug="test", provider=provider) response = self.client.get( - reverse( - "authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug} - ) + reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug}) ) body = json.loads(force_str(response.content)) self.assertEqual(len(body["keys"]), 1) @@ -47,8 +45,6 @@ class TestJWKS(OAuthTestCase): ) app = Application.objects.create(name="test", slug="test", provider=provider) response = self.client.get( - reverse( - "authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug} - ) + reverse("authentik_providers_oauth2:jwks", kwargs={"application_slug": app.slug}) ) self.assertJSONEqual(force_str(response.content), {}) diff --git a/authentik/providers/oauth2/tests/test_token.py b/authentik/providers/oauth2/tests/test_token.py index 41a48981e..ed919f3bb 100644 --- a/authentik/providers/oauth2/tests/test_token.py +++ b/authentik/providers/oauth2/tests/test_token.py @@ -13,15 +13,8 @@ from authentik.providers.oauth2.constants import ( GRANT_TYPE_AUTHORIZATION_CODE, GRANT_TYPE_REFRESH_TOKEN, ) -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) -from authentik.providers.oauth2.models import ( - AuthorizationCode, - OAuth2Provider, - RefreshToken, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret +from authentik.providers.oauth2.models import AuthorizationCode, OAuth2Provider, RefreshToken from authentik.providers.oauth2.tests.utils import OAuthTestCase from authentik.providers.oauth2.views.token import TokenParams @@ -44,13 +37,9 @@ class TestToken(OAuthTestCase): redirect_uris="http://testserver", rsa_key=CertificateKeyPair.objects.first(), ) - header = b64encode( - f"{provider.client_id}:{provider.client_secret}".encode() - ).decode() + header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = User.objects.get(username="akadmin") - code = AuthorizationCode.objects.create( - code="foobar", provider=provider, user=user - ) + code = AuthorizationCode.objects.create(code="foobar", provider=provider, user=user) request = self.factory.post( "/", data={ @@ -60,9 +49,7 @@ class TestToken(OAuthTestCase): }, HTTP_AUTHORIZATION=f"Basic {header}", ) - params = TokenParams.parse( - request, provider, provider.client_id, provider.client_secret - ) + params = TokenParams.parse(request, provider, provider.client_id, provider.client_secret) self.assertEqual(params.provider, provider) def test_request_refresh_token(self): @@ -75,9 +62,7 @@ class TestToken(OAuthTestCase): redirect_uris="http://local.invalid", rsa_key=CertificateKeyPair.objects.first(), ) - header = b64encode( - f"{provider.client_id}:{provider.client_secret}".encode() - ).decode() + header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = User.objects.get(username="akadmin") token: RefreshToken = RefreshToken.objects.create( provider=provider, @@ -93,9 +78,7 @@ class TestToken(OAuthTestCase): }, HTTP_AUTHORIZATION=f"Basic {header}", ) - params = TokenParams.parse( - request, provider, provider.client_id, provider.client_secret - ) + params = TokenParams.parse(request, provider, provider.client_id, provider.client_secret) self.assertEqual(params.provider, provider) def test_auth_code_view(self): @@ -111,9 +94,7 @@ class TestToken(OAuthTestCase): # Needs to be assigned to an application for iss to be set self.app.provider = provider self.app.save() - header = b64encode( - f"{provider.client_id}:{provider.client_secret}".encode() - ).decode() + header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = User.objects.get(username="akadmin") code = AuthorizationCode.objects.create( code="foobar", provider=provider, user=user, is_open_id=True @@ -155,9 +136,7 @@ class TestToken(OAuthTestCase): # Needs to be assigned to an application for iss to be set self.app.provider = provider self.app.save() - header = b64encode( - f"{provider.client_id}:{provider.client_secret}".encode() - ).decode() + header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = User.objects.get(username="akadmin") token: RefreshToken = RefreshToken.objects.create( provider=provider, @@ -178,9 +157,7 @@ class TestToken(OAuthTestCase): RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first() ) self.assertEqual(response["Access-Control-Allow-Credentials"], "true") - self.assertEqual( - response["Access-Control-Allow-Origin"], "http://local.invalid" - ) + self.assertEqual(response["Access-Control-Allow-Origin"], "http://local.invalid") self.assertJSONEqual( force_str(response.content), { @@ -205,9 +182,7 @@ class TestToken(OAuthTestCase): redirect_uris="http://local.invalid", rsa_key=CertificateKeyPair.objects.first(), ) - header = b64encode( - f"{provider.client_id}:{provider.client_secret}".encode() - ).decode() + header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = User.objects.get(username="akadmin") token: RefreshToken = RefreshToken.objects.create( provider=provider, @@ -255,9 +230,7 @@ class TestToken(OAuthTestCase): # Needs to be assigned to an application for iss to be set self.app.provider = provider self.app.save() - header = b64encode( - f"{provider.client_id}:{provider.client_secret}".encode() - ).decode() + header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode() user = User.objects.get(username="akadmin") token: RefreshToken = RefreshToken.objects.create( provider=provider, @@ -300,6 +273,4 @@ class TestToken(OAuthTestCase): HTTP_AUTHORIZATION=f"Basic {header}", ) self.assertEqual(response.status_code, 400) - self.assertTrue( - Event.objects.filter(action=EventAction.SUSPICIOUS_REQUEST).exists() - ) + self.assertTrue(Event.objects.filter(action=EventAction.SUSPICIOUS_REQUEST).exists()) diff --git a/authentik/providers/oauth2/tests/test_userinfo.py b/authentik/providers/oauth2/tests/test_userinfo.py index 462e748da..fbe66383c 100644 --- a/authentik/providers/oauth2/tests/test_userinfo.py +++ b/authentik/providers/oauth2/tests/test_userinfo.py @@ -10,16 +10,8 @@ from authentik.crypto.models import CertificateKeyPair from authentik.events.models import Event, EventAction from authentik.flows.models import Flow from authentik.managed.manager import ObjectManager -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) -from authentik.providers.oauth2.models import ( - IDToken, - OAuth2Provider, - RefreshToken, - ScopeMapping, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret +from authentik.providers.oauth2.models import IDToken, OAuth2Provider, RefreshToken, ScopeMapping from authentik.providers.oauth2.tests.utils import OAuthTestCase @@ -78,9 +70,7 @@ class TestUserinfo(OAuthTestCase): def test_userinfo_invalid_scope(self): """test user info with a broken scope""" - scope = ScopeMapping.objects.create( - name="test", scope_name="openid", expression="q" - ) + scope = ScopeMapping.objects.create(name="test", scope_name="openid", expression="q") self.provider.property_mappings.add(scope) res = self.client.get( diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index adf9c7e7b..db2bec6f5 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -2,11 +2,7 @@ from django.test import TestCase from jwt import decode -from authentik.providers.oauth2.models import ( - JWTAlgorithms, - OAuth2Provider, - RefreshToken, -) +from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider, RefreshToken class OAuthTestCase(TestCase): diff --git a/authentik/providers/oauth2/urls_github.py b/authentik/providers/oauth2/urls_github.py index 77dba0826..d33b66b66 100644 --- a/authentik/providers/oauth2/urls_github.py +++ b/authentik/providers/oauth2/urls_github.py @@ -2,10 +2,7 @@ from django.urls import include, path from django.views.decorators.csrf import csrf_exempt -from authentik.providers.oauth2.constants import ( - SCOPE_GITHUB_ORG_READ, - SCOPE_GITHUB_USER_EMAIL, -) +from authentik.providers.oauth2.constants import SCOPE_GITHUB_ORG_READ, SCOPE_GITHUB_USER_EMAIL from authentik.providers.oauth2.utils import protected_resource_view from authentik.providers.oauth2.views.authorize import AuthorizationFlowInitView from authentik.providers.oauth2.views.github import GitHubUserTeamsView, GitHubUserView @@ -24,17 +21,13 @@ github_urlpatterns = [ ), path( "user", - csrf_exempt( - protected_resource_view([SCOPE_GITHUB_USER_EMAIL])(GitHubUserView.as_view()) - ), + csrf_exempt(protected_resource_view([SCOPE_GITHUB_USER_EMAIL])(GitHubUserView.as_view())), name="github-user", ), path( "user/teams", csrf_exempt( - protected_resource_view([SCOPE_GITHUB_ORG_READ])( - GitHubUserTeamsView.as_view() - ) + protected_resource_view([SCOPE_GITHUB_ORG_READ])(GitHubUserTeamsView.as_view()) ), name="github-user-teams", ), diff --git a/authentik/providers/oauth2/utils.py b/authentik/providers/oauth2/utils.py index c97cf67d6..3baf0e94b 100644 --- a/authentik/providers/oauth2/utils.py +++ b/authentik/providers/oauth2/utils.py @@ -133,9 +133,7 @@ def protected_resource_view(scopes: list[str]): raise BearerTokenError("invalid_token") try: - token: RefreshToken = RefreshToken.objects.get( - access_token=access_token - ) + token: RefreshToken = RefreshToken.objects.get(access_token=access_token) except RefreshToken.DoesNotExist: LOGGER.debug("Token does not exist", access_token=access_token) raise BearerTokenError("invalid_token") diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 5473297a5..226ac4e68 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -132,9 +132,7 @@ class OAuthAuthorizationParams: scope=query_dict.get("scope", "").split(), state=state, nonce=query_dict.get("nonce"), - prompt=ALLOWED_PROMPT_PARAMS.intersection( - set(query_dict.get("prompt", "").split()) - ), + prompt=ALLOWED_PROMPT_PARAMS.intersection(set(query_dict.get("prompt", "").split())), request=query_dict.get("request", None), max_age=int(max_age) if max_age else None, code_challenge=query_dict.get("code_challenge"), @@ -143,9 +141,7 @@ class OAuthAuthorizationParams: def __post_init__(self): try: - self.provider: OAuth2Provider = OAuth2Provider.objects.get( - client_id=self.client_id - ) + self.provider: OAuth2Provider = OAuth2Provider.objects.get(client_id=self.client_id) except OAuth2Provider.DoesNotExist: LOGGER.warning("Invalid client identifier", client_id=self.client_id) raise ClientIdError(client_id=self.client_id) @@ -182,13 +178,10 @@ class OAuthAuthorizationParams: """Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" if SCOPE_OPENID not in self.scope and ( self.grant_type == GrantTypes.HYBRID - or self.response_type - in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] + or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] ): LOGGER.warning("Missing 'openid' scope.") - raise AuthorizeError( - self.redirect_uri, "invalid_scope", self.grant_type, self.state - ) + raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state) def check_nonce(self): """Nonce parameter validation.""" @@ -226,9 +219,7 @@ class OAuthAuthorizationParams: code.code_challenge = self.code_challenge code.code_challenge_method = self.code_challenge_method - code.expires_at = timezone.now() + timedelta_from_string( - self.provider.access_code_validity - ) + code.expires_at = timezone.now() + timedelta_from_string(self.provider.access_code_validity) code.scope = self.scope code.nonce = self.nonce code.is_open_id = SCOPE_OPENID in self.scope @@ -253,12 +244,8 @@ class OAuthFulfillmentStage(StageView): if PLAN_CONTEXT_PARAMS not in self.executor.plan.context: LOGGER.warning("Got to fulfillment stage with no pending context") return HttpResponseBadRequest() - self.params: OAuthAuthorizationParams = self.executor.plan.context.pop( - PLAN_CONTEXT_PARAMS - ) - application: Application = self.executor.plan.context.pop( - PLAN_CONTEXT_APPLICATION - ) + self.params: OAuthAuthorizationParams = self.executor.plan.context.pop(PLAN_CONTEXT_PARAMS) + application: Application = self.executor.plan.context.pop(PLAN_CONTEXT_APPLICATION) self.provider = get_object_or_404(OAuth2Provider, pk=application.provider_id) try: # At this point we don't need to check permissions anymore @@ -303,9 +290,7 @@ class OAuthFulfillmentStage(StageView): if self.params.grant_type == GrantTypes.AUTHORIZATION_CODE: query_params["code"] = code.code - query_params["state"] = [ - str(self.params.state) if self.params.state else "" - ] + query_params["state"] = [str(self.params.state) if self.params.state else ""] uri = uri._replace(query=urlencode(query_params, doseq=True)) return urlunsplit(uri) @@ -433,9 +418,7 @@ class AuthorizationFlowInitView(PolicyAccessView): if self.params.max_age: current_age: timedelta = ( timezone.now() - - Event.objects.filter( - action=EventAction.LOGIN, user=get_user(self.request.user) - ) + - Event.objects.filter(action=EventAction.LOGIN, user=get_user(self.request.user)) .latest("created") .created ) @@ -465,9 +448,7 @@ class AuthorizationFlowInitView(PolicyAccessView): # OAuth2 related params PLAN_CONTEXT_PARAMS: self.params, # Consent related params - PLAN_CONTEXT_CONSENT_HEADER: _( - "You're about to sign into %(application)s." - ) + PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") % {"application": self.application.name}, PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions, }, diff --git a/authentik/providers/oauth2/views/introspection.py b/authentik/providers/oauth2/views/introspection.py index 7e36866f3..626449aaa 100644 --- a/authentik/providers/oauth2/views/introspection.py +++ b/authentik/providers/oauth2/views/introspection.py @@ -45,10 +45,7 @@ class TokenIntrospectionParams: client_id, client_secret = extract_client_auth(request) if client_id == client_secret == "": return False - if ( - client_id != self.provider.client_id - or client_secret != self.provider.client_secret - ): + if client_id != self.provider.client_id or client_secret != self.provider.client_secret: LOGGER.debug("(basic) Provider for basic auth does not exist") raise TokenIntrospectionError() return True @@ -58,9 +55,7 @@ class TokenIntrospectionParams: body_token = extract_access_token(request) if not body_token: return False - tokens = RefreshToken.objects.filter(access_token=body_token).select_related( - "provider" - ) + tokens = RefreshToken.objects.filter(access_token=body_token).select_related("provider") if not tokens.exists(): LOGGER.debug("(bearer) Token does not exist") raise TokenIntrospectionError() @@ -89,9 +84,7 @@ class TokenIntrospectionParams: raise TokenIntrospectionError() params = TokenIntrospectionParams(token=token) - if not any( - [params.authenticate_basic(request), params.authenticate_bearer(request)] - ): + if not any([params.authenticate_basic(request), params.authenticate_bearer(request)]): LOGGER.debug("Not authenticated") raise TokenIntrospectionError() return params diff --git a/authentik/providers/oauth2/views/jwks.py b/authentik/providers/oauth2/views/jwks.py index a1d0fe56c..cf65698e7 100644 --- a/authentik/providers/oauth2/views/jwks.py +++ b/authentik/providers/oauth2/views/jwks.py @@ -24,9 +24,7 @@ class JWKSView(View): def get(self, request: HttpRequest, application_slug: str) -> HttpResponse: """Show RSA Key data for Provider""" application = get_object_or_404(Application, slug=application_slug) - provider: OAuth2Provider = get_object_or_404( - OAuth2Provider, pk=application.provider_id - ) + provider: OAuth2Provider = get_object_or_404(OAuth2Provider, pk=application.provider_id) response_data = {} diff --git a/authentik/providers/oauth2/views/provider.py b/authentik/providers/oauth2/views/provider.py index 050e79d23..3706ac1ab 100644 --- a/authentik/providers/oauth2/views/provider.py +++ b/authentik/providers/oauth2/views/provider.py @@ -35,9 +35,7 @@ class ProviderInfoView(View): def get_info(self, provider: OAuth2Provider) -> dict[str, Any]: """Get dictionary for OpenID Connect information""" scopes = list( - ScopeMapping.objects.filter(provider=provider).values_list( - "scope_name", flat=True - ) + ScopeMapping.objects.filter(provider=provider).values_list("scope_name", flat=True) ) if SCOPE_OPENID not in scopes: scopes.append(SCOPE_OPENID) @@ -99,9 +97,7 @@ class ProviderInfoView(View): # pylint: disable=unused-argument def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: """OpenID-compliant Provider Info""" - return JsonResponse( - self.get_info(self.provider), json_dumps_params={"indent": 2} - ) + return JsonResponse(self.get_info(self.provider), json_dumps_params={"indent": 2}) def dispatch( self, request: HttpRequest, application_slug: str, *args: Any, **kwargs: Any diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index c97086f7a..bf40dc3e0 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -21,11 +21,7 @@ from authentik.providers.oauth2.models import ( OAuth2Provider, RefreshToken, ) -from authentik.providers.oauth2.utils import ( - TokenResponse, - cors_allow, - extract_client_auth, -) +from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth LOGGER = get_logger() @@ -132,9 +128,7 @@ class TokenParams: "Provider has no allowed redirect_uri set, allowing all.", allow=self.redirect_uri.lower(), ) - elif self.redirect_uri.lower() not in [ - x.lower() for x in allowed_redirect_urls - ]: + elif self.redirect_uri.lower() not in [x.lower() for x in allowed_redirect_urls]: LOGGER.warning( "Invalid redirect uri", uri=self.redirect_uri, @@ -148,10 +142,7 @@ class TokenParams: LOGGER.warning("Code does not exist", code=raw_code) raise TokenError("invalid_grant") - if ( - self.authorization_code.provider != self.provider - or self.authorization_code.is_expired - ): + if self.authorization_code.provider != self.provider or self.authorization_code.is_expired: LOGGER.warning("Invalid code: invalid client or code has expired") raise TokenError("invalid_grant") @@ -159,9 +150,7 @@ class TokenParams: if self.code_verifier: if self.authorization_code.code_challenge_method == "S256": new_code_challenge = ( - urlsafe_b64encode( - sha256(self.code_verifier.encode("ascii")).digest() - ) + urlsafe_b64encode(sha256(self.code_verifier.encode("ascii")).digest()) .decode("utf-8") .replace("=", "") ) @@ -197,16 +186,12 @@ class TokenView(View): try: self.provider = OAuth2Provider.objects.get(client_id=client_id) except OAuth2Provider.DoesNotExist: - LOGGER.warning( - "OAuth2Provider does not exist", client_id=self.client_id - ) + LOGGER.warning("OAuth2Provider does not exist", client_id=self.client_id) raise TokenError("invalid_client") if not self.provider: raise ValueError - self.params = TokenParams.parse( - request, self.provider, client_id, client_secret - ) + self.params = TokenParams.parse(request, self.provider, client_id, client_secret) if self.params.grant_type == GRANT_TYPE_AUTHORIZATION_CODE: return TokenResponse(self.create_code_response()) @@ -247,9 +232,7 @@ class TokenView(View): "refresh_token": refresh_token.refresh_token, "token_type": "bearer", "expires_in": int( - timedelta_from_string( - self.params.provider.token_validity - ).total_seconds() + timedelta_from_string(self.params.provider.token_validity).total_seconds() ), "id_token": refresh_token.provider.encode(refresh_token.id_token.to_dict()), } @@ -257,9 +240,7 @@ class TokenView(View): def create_refresh_response(self) -> dict[str, Any]: """See https://tools.ietf.org/html/rfc6749#section-6""" - unauthorized_scopes = set(self.params.scope) - set( - self.params.refresh_token.scope - ) + unauthorized_scopes = set(self.params.scope) - set(self.params.refresh_token.scope) if unauthorized_scopes: raise TokenError("invalid_scope") @@ -291,9 +272,7 @@ class TokenView(View): "refresh_token": refresh_token.refresh_token, "token_type": "bearer", "expires_in": int( - timedelta_from_string( - refresh_token.provider.token_validity - ).total_seconds() + timedelta_from_string(refresh_token.provider.token_validity).total_seconds() ), "id_token": self.params.provider.encode(refresh_token.id_token.to_dict()), } diff --git a/authentik/providers/oauth2/views/userinfo.py b/authentik/providers/oauth2/views/userinfo.py index 119ec79c2..223a316a6 100644 --- a/authentik/providers/oauth2/views/userinfo.py +++ b/authentik/providers/oauth2/views/userinfo.py @@ -30,30 +30,20 @@ class UserInfoView(View): def get_scope_descriptions(self, scopes: list[str]) -> list[dict[str, str]]: """Get a list of all Scopes's descriptions""" scope_descriptions = [] - for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by( - "scope_name" - ): + for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by("scope_name"): if scope.description != "": - scope_descriptions.append( - {"id": scope.scope_name, "name": scope.description} - ) + scope_descriptions.append({"id": scope.scope_name, "name": scope.description}) # GitHub Compatibility Scopes are handeled differently, since they required custom paths # Hence they don't exist as Scope objects github_scope_map = { SCOPE_GITHUB_USER: ("GitHub Compatibility: Access your User Information"), - SCOPE_GITHUB_USER_READ: ( - "GitHub Compatibility: Access your User Information" - ), - SCOPE_GITHUB_USER_EMAIL: ( - "GitHub Compatibility: Access you Email addresses" - ), + SCOPE_GITHUB_USER_READ: ("GitHub Compatibility: Access your User Information"), + SCOPE_GITHUB_USER_EMAIL: ("GitHub Compatibility: Access you Email addresses"), SCOPE_GITHUB_ORG_READ: ("GitHub Compatibility: Access your Groups"), } for scope in scopes: if scope in github_scope_map: - scope_descriptions.append( - {"id": scope, "name": github_scope_map[scope]} - ) + scope_descriptions.append({"id": scope, "name": github_scope_map[scope]}) return scope_descriptions def get_claims(self, token: RefreshToken) -> dict[str, Any]: diff --git a/authentik/providers/proxy/api.py b/authentik/providers/proxy/api.py index afae83fa1..450fde22a 100644 --- a/authentik/providers/proxy/api.py +++ b/authentik/providers/proxy/api.py @@ -42,9 +42,7 @@ class ProxyProviderSerializer(ProviderSerializer): attrs.get("mode", ProxyMode.PROXY) == ProxyMode.PROXY and attrs.get("internal_host", "") == "" ): - raise ValidationError( - "Internal host cannot be empty when forward auth is disabled." - ) + raise ValidationError("Internal host cannot be empty when forward auth is disabled.") return attrs def create(self, validated_data): diff --git a/authentik/providers/proxy/controllers/k8s/ingress.py b/authentik/providers/proxy/controllers/k8s/ingress.py index 2cb2153bc..967a63ad6 100644 --- a/authentik/providers/proxy/controllers/k8s/ingress.py +++ b/authentik/providers/proxy/controllers/k8s/ingress.py @@ -11,15 +11,10 @@ from kubernetes.client import ( NetworkingV1beta1IngressSpec, NetworkingV1beta1IngressTLS, ) -from kubernetes.client.models.networking_v1beta1_ingress_rule import ( - NetworkingV1beta1IngressRule, -) +from kubernetes.client.models.networking_v1beta1_ingress_rule import NetworkingV1beta1IngressRule from authentik.outposts.controllers.base import FIELD_MANAGER -from authentik.outposts.controllers.k8s.base import ( - KubernetesObjectReconciler, - NeedsUpdate, -) +from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate from authentik.providers.proxy.models import ProxyMode, ProxyProvider if TYPE_CHECKING: @@ -41,9 +36,7 @@ class IngressReconciler(KubernetesObjectReconciler[NetworkingV1beta1Ingress]): if reference.metadata.annotations[key] != value: raise NeedsUpdate() - def reconcile( - self, current: NetworkingV1beta1Ingress, reference: NetworkingV1beta1Ingress - ): + def reconcile(self, current: NetworkingV1beta1Ingress, reference: NetworkingV1beta1Ingress): super().reconcile(current, reference) self._check_annotations(reference) # Create a list of all expected host and tls hosts @@ -84,9 +77,7 @@ class IngressReconciler(KubernetesObjectReconciler[NetworkingV1beta1Ingress]): "nginx.ingress.kubernetes.io/proxy-buffers-number": "4", "nginx.ingress.kubernetes.io/proxy-buffer-size": "16k", } - annotations.update( - self.controller.outpost.config.kubernetes_ingress_annotations - ) + annotations.update(self.controller.outpost.config.kubernetes_ingress_annotations) return annotations def get_reference_object(self) -> NetworkingV1beta1Ingress: @@ -155,16 +146,12 @@ class IngressReconciler(KubernetesObjectReconciler[NetworkingV1beta1Ingress]): ) def delete(self, reference: NetworkingV1beta1Ingress): - return self.api.delete_namespaced_ingress( - reference.metadata.name, self.namespace - ) + return self.api.delete_namespaced_ingress(reference.metadata.name, self.namespace) def retrieve(self) -> NetworkingV1beta1Ingress: return self.api.read_namespaced_ingress(self.name, self.namespace) - def update( - self, current: NetworkingV1beta1Ingress, reference: NetworkingV1beta1Ingress - ): + def update(self, current: NetworkingV1beta1Ingress, reference: NetworkingV1beta1Ingress): return self.api.patch_namespaced_ingress( current.metadata.name, self.namespace, diff --git a/authentik/providers/proxy/controllers/k8s/traefik.py b/authentik/providers/proxy/controllers/k8s/traefik.py index 926630f9b..adc19ca91 100644 --- a/authentik/providers/proxy/controllers/k8s/traefik.py +++ b/authentik/providers/proxy/controllers/k8s/traefik.py @@ -6,10 +6,7 @@ from dacite import from_dict from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi from authentik.outposts.controllers.base import FIELD_MANAGER -from authentik.outposts.controllers.k8s.base import ( - KubernetesObjectReconciler, - NeedsUpdate, -) +from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate from authentik.providers.proxy.models import ProxyMode, ProxyProvider if TYPE_CHECKING: diff --git a/authentik/providers/proxy/controllers/kubernetes.py b/authentik/providers/proxy/controllers/kubernetes.py index 048d81f14..7c728ece2 100644 --- a/authentik/providers/proxy/controllers/kubernetes.py +++ b/authentik/providers/proxy/controllers/kubernetes.py @@ -3,9 +3,7 @@ from authentik.outposts.controllers.base import DeploymentPort from authentik.outposts.controllers.kubernetes import KubernetesController from authentik.outposts.models import KubernetesServiceConnection, Outpost from authentik.providers.proxy.controllers.k8s.ingress import IngressReconciler -from authentik.providers.proxy.controllers.k8s.traefik import ( - TraefikMiddlewareReconciler, -) +from authentik.providers.proxy.controllers.k8s.traefik import TraefikMiddlewareReconciler class ProxyKubernetesController(KubernetesController): diff --git a/authentik/providers/proxy/migrations/0001_initial.py b/authentik/providers/proxy/migrations/0001_initial.py index 873690d26..4b2a02136 100644 --- a/authentik/providers/proxy/migrations/0001_initial.py +++ b/authentik/providers/proxy/migrations/0001_initial.py @@ -31,21 +31,13 @@ class Migration(migrations.Migration): ( "internal_host", models.TextField( - validators=[ - django.core.validators.URLValidator( - schemes=("http", "https") - ) - ] + validators=[django.core.validators.URLValidator(schemes=("http", "https"))] ), ), ( "external_host", models.TextField( - validators=[ - django.core.validators.URLValidator( - schemes=("http", "https") - ) - ] + validators=[django.core.validators.URLValidator(schemes=("http", "https"))] ), ), ], diff --git a/authentik/providers/proxy/migrations/0002_proxyprovider_cookie_secret.py b/authentik/providers/proxy/migrations/0002_proxyprovider_cookie_secret.py index bcf25c252..c6b1ad4ee 100644 --- a/authentik/providers/proxy/migrations/0002_proxyprovider_cookie_secret.py +++ b/authentik/providers/proxy/migrations/0002_proxyprovider_cookie_secret.py @@ -15,8 +15,6 @@ class Migration(migrations.Migration): migrations.AddField( model_name="proxyprovider", name="cookie_secret", - field=models.TextField( - default=authentik.providers.proxy.models.get_cookie_secret - ), + field=models.TextField(default=authentik.providers.proxy.models.get_cookie_secret), ), ] diff --git a/authentik/providers/proxy/migrations/0004_auto_20200913_1947.py b/authentik/providers/proxy/migrations/0004_auto_20200913_1947.py index 34426eaf8..ab5ce8ba3 100644 --- a/authentik/providers/proxy/migrations/0004_auto_20200913_1947.py +++ b/authentik/providers/proxy/migrations/0004_auto_20200913_1947.py @@ -16,22 +16,14 @@ class Migration(migrations.Migration): model_name="proxyprovider", name="external_host", field=models.TextField( - validators=[ - authentik.lib.models.DomainlessURLValidator( - schemes=("http", "https") - ) - ] + validators=[authentik.lib.models.DomainlessURLValidator(schemes=("http", "https"))] ), ), migrations.AlterField( model_name="proxyprovider", name="internal_host", field=models.TextField( - validators=[ - authentik.lib.models.DomainlessURLValidator( - schemes=("http", "https") - ) - ] + validators=[authentik.lib.models.DomainlessURLValidator(schemes=("http", "https"))] ), ), ] diff --git a/authentik/providers/proxy/migrations/0011_proxyprovider_forward_auth_mode.py b/authentik/providers/proxy/migrations/0011_proxyprovider_forward_auth_mode.py index b9be9af13..45f728836 100644 --- a/authentik/providers/proxy/migrations/0011_proxyprovider_forward_auth_mode.py +++ b/authentik/providers/proxy/migrations/0011_proxyprovider_forward_auth_mode.py @@ -25,11 +25,7 @@ class Migration(migrations.Migration): name="internal_host", field=models.TextField( blank=True, - validators=[ - authentik.lib.models.DomainlessURLValidator( - schemes=("http", "https") - ) - ], + validators=[authentik.lib.models.DomainlessURLValidator(schemes=("http", "https"))], ), ), ] diff --git a/authentik/providers/proxy/models.py b/authentik/providers/proxy/models.py index c1fbe7ec9..69c00c468 100644 --- a/authentik/providers/proxy/models.py +++ b/authentik/providers/proxy/models.py @@ -28,9 +28,7 @@ SCOPE_AK_PROXY = "ak_proxy" def get_cookie_secret(): """Generate random 32-character string for cookie-secret""" - return "".join( - SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32) - ) + return "".join(SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(32)) def _get_callback_url(uri: str) -> str: @@ -53,9 +51,7 @@ class ProxyProvider(OutpostModel, OAuth2Provider): validators=[DomainlessURLValidator(schemes=("http", "https"))], blank=True, ) - external_host = models.TextField( - validators=[DomainlessURLValidator(schemes=("http", "https"))] - ) + external_host = models.TextField(validators=[DomainlessURLValidator(schemes=("http", "https"))]) internal_host_ssl_validation = models.BooleanField( default=True, help_text=_("Validate SSL Certificates of upstream servers"), @@ -101,11 +97,7 @@ class ProxyProvider(OutpostModel, OAuth2Provider): basic_auth_password_attribute = models.TextField( blank=True, verbose_name=_("HTTP-Basic Password Key"), - help_text=_( - ( - "User/Group Attribute used for the password part of the HTTP-Basic Header." - ) - ), + help_text=_(("User/Group Attribute used for the password part of the HTTP-Basic Header.")), ) certificate = models.ForeignKey( diff --git a/authentik/providers/saml/api.py b/authentik/providers/saml/api.py index 510f2a159..7182fb591 100644 --- a/authentik/providers/saml/api.py +++ b/authentik/providers/saml/api.py @@ -9,12 +9,7 @@ from django.utils.translation import gettext_lazy as _ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from rest_framework.decorators import action -from rest_framework.fields import ( - CharField, - FileField, - ReadOnlyField, - SerializerMethodField, -) +from rest_framework.fields import CharField, FileField, ReadOnlyField, SerializerMethodField from rest_framework.parsers import MultiPartParser from rest_framework.permissions import AllowAny from rest_framework.relations import SlugRelatedField @@ -33,9 +28,7 @@ from authentik.core.models import Provider from authentik.flows.models import Flow, FlowDesignation from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider from authentik.providers.saml.processors.metadata import MetadataProcessor -from authentik.providers.saml.processors.metadata_parser import ( - ServiceProviderMetadataParser, -) +from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser LOGGER = get_logger() @@ -48,8 +41,7 @@ class SAMLProviderSerializer(ProviderSerializer): def get_metadata_download_url(self, instance: SAMLProvider) -> str: """Get metadata download URL""" return ( - reverse("authentik_api:samlprovider-metadata", kwargs={"pk": instance.pk}) - + "?download" + reverse("authentik_api:samlprovider-metadata", kwargs={"pk": instance.pk}) + "?download" ) class Meta: diff --git a/authentik/providers/saml/migrations/0001_initial.py b/authentik/providers/saml/migrations/0001_initial.py index c877030fe..e78eef07d 100644 --- a/authentik/providers/saml/migrations/0001_initial.py +++ b/authentik/providers/saml/migrations/0001_initial.py @@ -66,9 +66,7 @@ class Migration(migrations.Migration): models.TextField( default="minutes=-5", help_text="Assertion valid not before current time + this value (Format: hours=-1;minutes=-2;seconds=-3).", - validators=[ - authentik.lib.utils.time.timedelta_string_validator - ], + validators=[authentik.lib.utils.time.timedelta_string_validator], ), ), ( @@ -76,9 +74,7 @@ class Migration(migrations.Migration): models.TextField( default="minutes=5", help_text="Assertion not valid on or after current time + this value (Format: hours=1;minutes=2;seconds=3).", - validators=[ - authentik.lib.utils.time.timedelta_string_validator - ], + validators=[authentik.lib.utils.time.timedelta_string_validator], ), ), ( @@ -86,9 +82,7 @@ class Migration(migrations.Migration): models.TextField( default="minutes=86400", help_text="Session not valid on or after current time + this value (Format: hours=1;minutes=2;seconds=3).", - validators=[ - authentik.lib.utils.time.timedelta_string_validator - ], + validators=[authentik.lib.utils.time.timedelta_string_validator], ), ), ( diff --git a/authentik/providers/saml/migrations/0008_auto_20201112_1036.py b/authentik/providers/saml/migrations/0008_auto_20201112_1036.py index 18dc55c89..a0dace536 100644 --- a/authentik/providers/saml/migrations/0008_auto_20201112_1036.py +++ b/authentik/providers/saml/migrations/0008_auto_20201112_1036.py @@ -27,9 +27,7 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="samlprovider", name="issuer", - field=models.TextField( - default="authentik", help_text="Also known as EntityID" - ), + field=models.TextField(default="authentik", help_text="Also known as EntityID"), ), migrations.AlterField( model_name="samlprovider", diff --git a/authentik/providers/saml/migrations/0012_managed.py b/authentik/providers/saml/migrations/0012_managed.py index c290b42c8..ae5754f25 100644 --- a/authentik/providers/saml/migrations/0012_managed.py +++ b/authentik/providers/saml/migrations/0012_managed.py @@ -27,13 +27,9 @@ saml_name_uid_map = { def add_managed_update(apps, schema_editor): """Create default SAML Property Mappings""" - SAMLPropertyMapping = apps.get_model( - "authentik_providers_saml", "SAMLPropertyMapping" - ) + SAMLPropertyMapping = apps.get_model("authentik_providers_saml", "SAMLPropertyMapping") db_alias = schema_editor.connection.alias - for pm in SAMLPropertyMapping.objects.using(db_alias).filter( - name__startswith="Autogenerated " - ): + for pm in SAMLPropertyMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): if pm.saml_name not in saml_name_map: continue new_name = saml_name_map[pm.saml_name] diff --git a/authentik/providers/saml/models.py b/authentik/providers/saml/models.py index b1d77dac9..b506bfa28 100644 --- a/authentik/providers/saml/models.py +++ b/authentik/providers/saml/models.py @@ -46,18 +46,13 @@ class SAMLProvider(Provider): ) ), ) - issuer = models.TextField( - help_text=_("Also known as EntityID"), default="authentik" - ) + issuer = models.TextField(help_text=_("Also known as EntityID"), default="authentik") sp_binding = models.TextField( choices=SAMLBindings.choices, default=SAMLBindings.REDIRECT, verbose_name=_("Service Provider Binding"), help_text=_( - ( - "This determines how authentik sends the " - "response back to the Service Provider." - ) + ("This determines how authentik sends the " "response back to the Service Provider.") ), ) @@ -150,9 +145,7 @@ class SAMLProvider(Provider): default=None, null=True, blank=True, - help_text=_( - "Keypair used to sign outgoing Responses going to the Service Provider." - ), + help_text=_("Keypair used to sign outgoing Responses going to the Service Provider."), on_delete=models.SET_NULL, verbose_name=_("Signing Keypair"), ) diff --git a/authentik/providers/saml/processors/assertion.py b/authentik/providers/saml/processors/assertion.py index 765947ce7..db6f46400 100644 --- a/authentik/providers/saml/processors/assertion.py +++ b/authentik/providers/saml/processors/assertion.py @@ -47,9 +47,7 @@ class AssertionProcessor: _valid_not_before: str _valid_not_on_or_after: str - def __init__( - self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest - ): + def __init__(self, provider: SAMLProvider, request: HttpRequest, auth_n_request: AuthNRequest): self.provider = provider self.http_request = request self.auth_n_request = auth_n_request @@ -120,9 +118,7 @@ class AssertionProcessor: auth_n_statement.attrib["AuthnInstant"] = self._valid_not_before auth_n_statement.attrib["SessionIndex"] = self._assertion_id - auth_n_context = SubElement( - auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext" - ) + auth_n_context = SubElement(auth_n_statement, f"{{{NS_SAML_ASSERTION}}}AuthnContext") auth_n_context_class_ref = SubElement( auth_n_context, f"{{{NS_SAML_ASSERTION}}}AuthnContextClassRef" ) @@ -140,9 +136,7 @@ class AssertionProcessor: audience_restriction = SubElement( conditions, f"{{{NS_SAML_ASSERTION}}}AudienceRestriction" ) - audience = SubElement( - audience_restriction, f"{{{NS_SAML_ASSERTION}}}Audience" - ) + audience = SubElement(audience_restriction, f"{{{NS_SAML_ASSERTION}}}Audience") audience.text = self.provider.audience return conditions @@ -205,9 +199,7 @@ class AssertionProcessor: subject = Element(f"{{{NS_SAML_ASSERTION}}}Subject") subject.append(self.get_name_id()) - subject_confirmation = SubElement( - subject, f"{{{NS_SAML_ASSERTION}}}SubjectConfirmation" - ) + subject_confirmation = SubElement(subject, f"{{{NS_SAML_ASSERTION}}}SubjectConfirmation") subject_confirmation.attrib["Method"] = "urn:oasis:names:tc:SAML:2.0:cm:bearer" subject_confirmation_data = SubElement( @@ -274,9 +266,7 @@ class AssertionProcessor: ) assertion = root_response.xpath("//saml:Assertion", namespaces=NS_MAP)[0] xmlsec.tree.add_ids(assertion, ["ID"]) - signature_node = xmlsec.tree.find_node( - assertion, xmlsec.constants.NodeSignature - ) + signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature) ref = xmlsec.template.add_reference( signature_node, digest_algorithm_transform, diff --git a/authentik/providers/saml/processors/metadata.py b/authentik/providers/saml/processors/metadata.py index 32edb7d15..a296f46fc 100644 --- a/authentik/providers/saml/processors/metadata.py +++ b/authentik/providers/saml/processors/metadata.py @@ -100,13 +100,9 @@ class MetadataProcessor: digest_algorithm_transform = DIGEST_ALGORITHM_TRANSLATION_MAP.get( self.provider.digest_algorithm, xmlsec.constants.TransformSha1 ) - assertion = entity_descriptor.xpath("//md:EntityDescriptor", namespaces=NS_MAP)[ - 0 - ] + assertion = entity_descriptor.xpath("//md:EntityDescriptor", namespaces=NS_MAP)[0] xmlsec.tree.add_ids(assertion, ["ID"]) - signature_node = xmlsec.tree.find_node( - assertion, xmlsec.constants.NodeSignature - ) + signature_node = xmlsec.tree.find_node(assertion, xmlsec.constants.NodeSignature) ref = xmlsec.template.add_reference( signature_node, digest_algorithm_transform, @@ -133,9 +129,7 @@ class MetadataProcessor: def build_entity_descriptor(self) -> str: """Build full EntityDescriptor""" - entity_descriptor = Element( - f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP - ) + entity_descriptor = Element(f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP) entity_descriptor.attrib["ID"] = self.xml_id entity_descriptor.attrib["entityID"] = self.provider.issuer diff --git a/authentik/providers/saml/processors/metadata_parser.py b/authentik/providers/saml/processors/metadata_parser.py index 8708a6c06..cd6068d91 100644 --- a/authentik/providers/saml/processors/metadata_parser.py +++ b/authentik/providers/saml/processors/metadata_parser.py @@ -11,11 +11,7 @@ from structlog.stdlib import get_logger from authentik.crypto.models import CertificateKeyPair from authentik.flows.models import Flow -from authentik.providers.saml.models import ( - SAMLBindings, - SAMLPropertyMapping, - SAMLProvider, -) +from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER from authentik.sources.saml.processors.constants import ( NS_MAP, @@ -68,14 +64,10 @@ class ServiceProviderMetadata: self.signing_keypair.save() provider.verification_kp = self.signing_keypair if self.assertion_signed: - provider.signing_kp = CertificateKeyPair.objects.exclude( - key_data__iexact="" - ).first() + provider.signing_kp = CertificateKeyPair.objects.exclude(key_data__iexact="").first() # Set all auto-generated Property-mappings as defaults # They should provide a sane default for most applications: - provider.property_mappings.set( - SAMLPropertyMapping.objects.exclude(managed__isnull=True) - ) + provider.property_mappings.set(SAMLPropertyMapping.objects.exclude(managed__isnull=True)) provider.save() return provider @@ -101,9 +93,7 @@ class ServiceProviderMetadataParser: def check_signature(self, root: etree.Element, keypair: CertificateKeyPair): """If Metadata is signed, check validity of signature""" xmlsec.tree.add_ids(root, ["ID"]) - signature_nodes = root.xpath( - "/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP - ) + signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP) if len(signature_nodes) != 1: # No Signature return @@ -134,14 +124,10 @@ class ServiceProviderMetadataParser: # For now we'll only look at the first descriptor. # Even if multiple descriptors exist, we can only configure one descriptor = sp_sso_descriptors[0] - auth_n_request_signed = ( - descriptor.attrib["AuthnRequestsSigned"].lower() == "true" - ) + auth_n_request_signed = descriptor.attrib["AuthnRequestsSigned"].lower() == "true" assertion_signed = descriptor.attrib["WantAssertionsSigned"].lower() == "true" - acs_services = descriptor.findall( - f"{{{NS_SAML_METADATA}}}AssertionConsumerService" - ) + acs_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}AssertionConsumerService") if len(acs_services) < 1: raise ValueError("No AssertionConsumerService found.") diff --git a/authentik/providers/saml/processors/request_parser.py b/authentik/providers/saml/processors/request_parser.py index 0746e6b5c..7fb51904e 100644 --- a/authentik/providers/saml/processors/request_parser.py +++ b/authentik/providers/saml/processors/request_parser.py @@ -81,9 +81,7 @@ class AuthNRequestParser: return auth_n_request - def parse( - self, saml_request: str, relay_state: Optional[str] = None - ) -> AuthNRequest: + def parse(self, saml_request: str, relay_state: Optional[str] = None) -> AuthNRequest: """Validate and parse raw request with enveloped signautre.""" try: decoded_xml = b64decode(saml_request.encode()) @@ -94,9 +92,7 @@ class AuthNRequestParser: root = etree.fromstring(decoded_xml) # nosec xmlsec.tree.add_ids(root, ["ID"]) - signature_nodes = root.xpath( - "/samlp:AuthnRequest/ds:Signature", namespaces=NS_MAP - ) + signature_nodes = root.xpath("/samlp:AuthnRequest/ds:Signature", namespaces=NS_MAP) # No signatures, no verifier configured -> decode xml directly if len(signature_nodes) < 1 and not verifier: return self._parse_xml(decoded_xml, relay_state) diff --git a/authentik/providers/saml/tests/test_auth_n_request.py b/authentik/providers/saml/tests/test_auth_n_request.py index 45e6c9007..ae8986956 100644 --- a/authentik/providers/saml/tests/test_auth_n_request.py +++ b/authentik/providers/saml/tests/test_auth_n_request.py @@ -21,10 +21,7 @@ from authentik.sources.saml.processors.constants import ( SAML_NAME_ID_FORMAT_EMAIL, SAML_NAME_ID_FORMAT_UNSPECIFIED, ) -from authentik.sources.saml.processors.request import ( - SESSION_REQUEST_ID, - RequestProcessor, -) +from authentik.sources.saml.processors.request import SESSION_REQUEST_ID, RequestProcessor from authentik.sources.saml.processors.response import ResponseProcessor POST_REQUEST = ( @@ -54,9 +51,7 @@ REDIRECT_SIGNATURE = ( "jVvPdh96AhBFj2HCuGZhP0CGotafTciu6YlsiwUpuBkIYgZmNWYa3FR9LS4Q==" ) REDIRECT_SIG_ALG = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" -REDIRECT_RELAY_STATE = ( - "ss:mem:7a054b4af44f34f89dd2d973f383c250b6b076e7f06cfa8276008a6504eaf3c7" -) +REDIRECT_RELAY_STATE = "ss:mem:7a054b4af44f34f89dd2d973f383c250b6b076e7f06cfa8276008a6504eaf3c7" REDIRECT_CERT = """-----BEGIN CERTIFICATE----- MIIDCDCCAfCgAwIBAgIRAM5s+bhOHk4ChSpPkGSh0NswDQYJKoZIhvcNAQELBQAw KzEpMCcGA1UEAwwgcGFzc2Jvb2sgU2VsZi1zaWduZWQgQ2VydGlmaWNhdGUwHhcN @@ -97,9 +92,7 @@ class TestAuthNRequest(TestCase): self.source = SAMLSource.objects.create( slug="provider", issuer="authentik", - pre_authentication_flow=Flow.objects.get( - slug="default-source-pre-authentication" - ), + pre_authentication_flow=Flow.objects.get(slug="default-source-pre-authentication"), signing_kp=cert, ) self.factory = RequestFactory() @@ -283,9 +276,7 @@ class TestAuthNRequest(TestCase): request = request_proc.build_auth_n() # Create invalid PropertyMapping - scope = SAMLPropertyMapping.objects.create( - name="test", saml_name="test", expression="q" - ) + scope = SAMLPropertyMapping.objects.create(name="test", saml_name="test", expression="q") self.provider.property_mappings.add(scope) # To get an assertion we need a parsed request (parsed by provider) diff --git a/authentik/providers/saml/tests/test_metadata.py b/authentik/providers/saml/tests/test_metadata.py index 29dc74b36..e07931643 100644 --- a/authentik/providers/saml/tests/test_metadata.py +++ b/authentik/providers/saml/tests/test_metadata.py @@ -5,9 +5,7 @@ from django.test import TestCase from authentik.flows.models import Flow from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping -from authentik.providers.saml.processors.metadata_parser import ( - ServiceProviderMetadataParser, -) +from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser METADATA_SIMPLE = """ HttpResponse: application: Application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION] - provider: SAMLProvider = get_object_or_404( - SAMLProvider, pk=application.provider_id - ) + provider: SAMLProvider = get_object_or_404(SAMLProvider, pk=application.provider_id) if SESSION_KEY_AUTH_N_REQUEST not in self.request.session: return self.executor.stage_invalid() - auth_n_request: AuthNRequest = self.request.session.pop( - SESSION_KEY_AUTH_N_REQUEST - ) + auth_n_request: AuthNRequest = self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST) try: - response = AssertionProcessor( - provider, request, auth_n_request - ).build_response() + response = AssertionProcessor(provider, request, auth_n_request).build_response() except SAMLException as exc: Event.new( EventAction.CONFIGURATION_ERROR, diff --git a/authentik/providers/saml/views/sso.py b/authentik/providers/saml/views/sso.py index 732d582aa..664173637 100644 --- a/authentik/providers/saml/views/sso.py +++ b/authentik/providers/saml/views/sso.py @@ -12,11 +12,7 @@ from structlog.stdlib import get_logger from authentik.core.models import Application from authentik.events.models import Event, EventAction from authentik.flows.models import in_memory_stage -from authentik.flows.planner import ( - PLAN_CONTEXT_APPLICATION, - PLAN_CONTEXT_SSO, - FlowPlanner, -) +from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner from authentik.flows.views import SESSION_KEY_PLAN from authentik.lib.utils.urls import redirect_with_qs from authentik.lib.views import bad_request_message @@ -45,9 +41,7 @@ class SAMLSSOView(PolicyAccessView): Calls get/post handler.""" def resolve_provider_application(self): - self.application = get_object_or_404( - Application, slug=self.kwargs["application_slug"] - ) + self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) self.provider: SAMLProvider = get_object_or_404( SAMLProvider, pk=self.application.provider_id ) @@ -72,9 +66,7 @@ class SAMLSSOView(PolicyAccessView): { PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_APPLICATION: self.application, - PLAN_CONTEXT_CONSENT_HEADER: _( - "You're about to sign into %(application)s." - ) + PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") % {"application": self.application.name}, PLAN_CONTEXT_CONSENT_PERMISSIONS: [], }, @@ -100,9 +92,7 @@ class SAMLSSOBindingRedirectView(SAMLSSOView): """Handle REDIRECT bindings""" if REQUEST_KEY_SAML_REQUEST not in self.request.GET: LOGGER.info("handle_saml_request: SAML payload missing") - return bad_request_message( - self.request, "The SAML request payload is missing." - ) + return bad_request_message(self.request, "The SAML request payload is missing.") try: auth_n_request = AuthNRequestParser(self.provider).parse_detached( @@ -132,9 +122,7 @@ class SAMLSSOBindingPOSTView(SAMLSSOView): """Handle POST bindings""" if REQUEST_KEY_SAML_REQUEST not in self.request.POST: LOGGER.info("check_saml_request: SAML payload missing") - return bad_request_message( - self.request, "The SAML request payload is missing." - ) + return bad_request_message(self.request, "The SAML request payload is missing.") try: auth_n_request = AuthNRequestParser(self.provider).parse( @@ -153,8 +141,6 @@ class SAMLSSOBindingInitView(SAMLSSOView): def check_saml_request(self) -> Optional[HttpRequest]: """Create SAML Response from scratch""" - LOGGER.debug( - "handle_saml_no_request: No SAML Request, using IdP-initiated flow." - ) + LOGGER.debug("handle_saml_no_request: No SAML Request, using IdP-initiated flow.") auth_n_request = AuthNRequestParser(self.provider).idp_initiated() self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request diff --git a/authentik/recovery/management/commands/create_recovery_key.py b/authentik/recovery/management/commands/create_recovery_key.py index aa502d6bd..63736a9e3 100644 --- a/authentik/recovery/management/commands/create_recovery_key.py +++ b/authentik/recovery/management/commands/create_recovery_key.py @@ -25,9 +25,7 @@ class Command(BaseCommand): action="store", help="How long the token is valid for (in years).", ) - parser.add_argument( - "user", action="store", help="Which user the Token gives access to." - ) + parser.add_argument("user", action="store", help="Which user the Token gives access to.") def get_url(self, token: Token) -> str: """Get full recovery link""" @@ -47,9 +45,6 @@ class Command(BaseCommand): identifier=f"ak-recovery-{user}-{_now}", ) self.stdout.write( - ( - f"Store this link safely, as it will allow" - f" anyone to access authentik as {user}." - ) + (f"Store this link safely, as it will allow" f" anyone to access authentik as {user}.") ) self.stdout.write(self.get_url(token)) diff --git a/authentik/recovery/tests.py b/authentik/recovery/tests.py index 4357a7c80..b1f7886c8 100644 --- a/authentik/recovery/tests.py +++ b/authentik/recovery/tests.py @@ -28,14 +28,10 @@ class TestRecovery(TestCase): out = StringIO() call_command("create_recovery_key", "1", self.user.username, stdout=out) token = Token.objects.get(intent=TokenIntents.INTENT_RECOVERY, user=self.user) - self.client.get( - reverse("authentik_recovery:use-token", kwargs={"key": token.key}) - ) + self.client.get(reverse("authentik_recovery:use-token", kwargs={"key": token.key})) self.assertEqual(int(self.client.session["_auth_user_id"]), token.user.pk) def test_recovery_view_invalid(self): """Test recovery view with invalid token""" - response = self.client.get( - reverse("authentik_recovery:use-token", kwargs={"key": "abc"}) - ) + response = self.client.get(reverse("authentik_recovery:use-token", kwargs={"key": "abc"})) self.assertEqual(response.status_code, 404) diff --git a/authentik/root/asgi.py b/authentik/root/asgi.py index df17143bb..4d55e795a 100644 --- a/authentik/root/asgi.py +++ b/authentik/root/asgi.py @@ -74,13 +74,9 @@ class ASGILogger: if message["type"] == "http.response.start": response_headers = dict(message["headers"]) nonlocal request_id - request_id = response_headers.get( - RESPONSE_HEADER_ID.encode(), b"" - ).decode() + request_id = response_headers.get(RESPONSE_HEADER_ID.encode(), b"").decode() - if message["type"] == "http.response.body" and not message.get( - "more_body", True - ): + if message["type"] == "http.response.body" and not message.get("more_body", True): runtime = int((time() - self.start) * 1000) self.log(scope, runtime, content_length, request_id=request_id) await send(message) diff --git a/authentik/root/celery.py b/authentik/root/celery.py index 20b515348..16f3065a0 100644 --- a/authentik/root/celery.py +++ b/authentik/root/celery.py @@ -26,9 +26,7 @@ def config_loggers(*args, **kwags): def after_task_publish_hook(sender=None, headers=None, body=None, **kwargs): """Log task_id after it was published""" info = headers if "task" in headers else body - LOGGER.debug( - "Task published", task_id=info.get("id", ""), task_name=info.get("task", "") - ) + LOGGER.debug("Task published", task_id=info.get("id", ""), task_name=info.get("task", "")) # pylint: disable=unused-argument diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index 5b7832130..bdba86421 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -4,9 +4,7 @@ import time from django.conf import settings from django.contrib.sessions.backends.base import UpdateError from django.contrib.sessions.exceptions import SessionInterrupted -from django.contrib.sessions.middleware import ( - SessionMiddleware as UpstreamSessionMiddleware, -) +from django.contrib.sessions.middleware import SessionMiddleware as UpstreamSessionMiddleware from django.http.request import HttpRequest from django.http.response import HttpResponse from django.utils.cache import patch_vary_headers @@ -31,9 +29,7 @@ class SessionMiddleware(UpstreamSessionMiddleware): return True return False - def process_response( - self, request: HttpRequest, response: HttpResponse - ) -> HttpResponse: + def process_response(self, request: HttpRequest, response: HttpResponse) -> HttpResponse: """ If request.session was modified, or if the configuration is to save the session every time, save the changes and set a session cookie or delete diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 7602cd123..11f0dc90e 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -175,9 +175,7 @@ REST_FRAMEWORK = { "rest_framework.filters.OrderingFilter", "rest_framework.filters.SearchFilter", ], - "DEFAULT_PERMISSION_CLASSES": ( - "rest_framework.permissions.DjangoObjectPermissions", - ), + "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.DjangoObjectPermissions",), "DEFAULT_AUTHENTICATION_CLASSES": ( "authentik.api.authentication.TokenAuthentication", "rest_framework.authentication.SessionAuthentication", @@ -398,9 +396,7 @@ if _ERROR_REPORTING: if build_hash == "": build_hash = "tagged" set_tag("authentik.build_hash", build_hash) - set_tag( - "authentik.env", "kubernetes" if "KUBERNETES_PORT" in os.environ else "compose" - ) + set_tag("authentik.env", "kubernetes" if "KUBERNETES_PORT" in os.environ else "compose") set_tag("authentik.component", "backend") j_print( "Error reporting is enabled", @@ -514,12 +510,8 @@ for _app in INSTALLED_APPS: app_settings = importlib.import_module("%s.settings" % _app) INSTALLED_APPS.extend(getattr(app_settings, "INSTALLED_APPS", [])) MIDDLEWARE.extend(getattr(app_settings, "MIDDLEWARE", [])) - AUTHENTICATION_BACKENDS.extend( - getattr(app_settings, "AUTHENTICATION_BACKENDS", []) - ) - CELERY_BEAT_SCHEDULE.update( - getattr(app_settings, "CELERY_BEAT_SCHEDULE", {}) - ) + AUTHENTICATION_BACKENDS.extend(getattr(app_settings, "AUTHENTICATION_BACKENDS", [])) + CELERY_BEAT_SCHEDULE.update(getattr(app_settings, "CELERY_BEAT_SCHEDULE", {})) for _attr in dir(app_settings): if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: globals()[_attr] = getattr(app_settings, _attr) diff --git a/authentik/root/tests.py b/authentik/root/tests.py index 94ae0a5cf..9824a03c6 100644 --- a/authentik/root/tests.py +++ b/authentik/root/tests.py @@ -20,9 +20,7 @@ class TestRoot(TestCase): def test_monitoring_ok(self): """Test monitoring with credentials""" - creds = "Basic " + b64encode(f"monitor:{settings.SECRET_KEY}".encode()).decode( - "utf-8" - ) + creds = "Basic " + b64encode(f"monitor:{settings.SECRET_KEY}".encode()).decode("utf-8") auth_headers = {"HTTP_AUTHORIZATION": creds} response = self.client.get(reverse("metrics"), **auth_headers) self.assertEqual(response.status_code, 200) diff --git a/authentik/root/websocket.py b/authentik/root/websocket.py index 398b47f62..8a8bc1d86 100644 --- a/authentik/root/websocket.py +++ b/authentik/root/websocket.py @@ -8,7 +8,5 @@ from authentik.root.messages.consumer import MessageConsumer websocket_urlpatterns = [ path("ws/outpost//", SentryWSMiddleware(OutpostConsumer.as_asgi())), - path( - "ws/client/", AuthMiddlewareStack(SentryWSMiddleware(MessageConsumer.as_asgi())) - ), + path("ws/client/", AuthMiddlewareStack(SentryWSMiddleware(MessageConsumer.as_asgi()))), ] diff --git a/authentik/sources/ldap/auth.py b/authentik/sources/ldap/auth.py index f2dcdf839..9cf55d36d 100644 --- a/authentik/sources/ldap/auth.py +++ b/authentik/sources/ldap/auth.py @@ -27,9 +27,7 @@ class LDAPBackend(ModelBackend): return user return None - def auth_user( - self, source: LDAPSource, password: str, **filters: str - ) -> Optional[User]: + def auth_user(self, source: LDAPSource, password: str, **filters: str) -> Optional[User]: """Try to bind as either user_dn or mail with password. Returns True on success, otherwise False""" users = User.objects.filter(**filters) @@ -37,9 +35,7 @@ class LDAPBackend(ModelBackend): return None user: User = users.first() if LDAP_DISTINGUISHED_NAME not in user.attributes: - LOGGER.debug( - "User doesn't have DN set, assuming not LDAP imported.", user=user - ) + LOGGER.debug("User doesn't have DN set, assuming not LDAP imported.", user=user) return None # Either has unusable password, # or has a password, but couldn't be authenticated by ModelBackend. @@ -54,9 +50,7 @@ class LDAPBackend(ModelBackend): LOGGER.debug("Failed to bind, password invalid") return None - def auth_user_by_bind( - self, source: LDAPSource, user: User, password: str - ) -> Optional[User]: + def auth_user_by_bind(self, source: LDAPSource, user: User, password: str) -> Optional[User]: """Attempt authentication by binding to the LDAP server as `user`. This method should be avoided as its slow to do the bind.""" # Try to bind as new user diff --git a/authentik/sources/ldap/migrations/0001_initial.py b/authentik/sources/ldap/migrations/0001_initial.py index dd42e98d7..3599fd069 100644 --- a/authentik/sources/ldap/migrations/0001_initial.py +++ b/authentik/sources/ldap/migrations/0001_initial.py @@ -53,11 +53,7 @@ class Migration(migrations.Migration): ( "server_uri", models.TextField( - validators=[ - django.core.validators.URLValidator( - schemes=["ldap", "ldaps"] - ) - ], + validators=[django.core.validators.URLValidator(schemes=["ldap", "ldaps"])], verbose_name="Server URI", ), ), diff --git a/authentik/sources/ldap/migrations/0005_auto_20200913_1947.py b/authentik/sources/ldap/migrations/0005_auto_20200913_1947.py index c33be769c..81218f77e 100644 --- a/authentik/sources/ldap/migrations/0005_auto_20200913_1947.py +++ b/authentik/sources/ldap/migrations/0005_auto_20200913_1947.py @@ -16,11 +16,7 @@ class Migration(migrations.Migration): model_name="ldapsource", name="server_uri", field=models.TextField( - validators=[ - authentik.lib.models.DomainlessURLValidator( - schemes=["ldap", "ldaps"] - ) - ], + validators=[authentik.lib.models.DomainlessURLValidator(schemes=["ldap", "ldaps"])], verbose_name="Server URI", ), ), diff --git a/authentik/sources/ldap/migrations/0008_managed.py b/authentik/sources/ldap/migrations/0008_managed.py index 8264c05ff..107906e6b 100644 --- a/authentik/sources/ldap/migrations/0008_managed.py +++ b/authentik/sources/ldap/migrations/0008_managed.py @@ -5,9 +5,7 @@ from django.db import migrations def set_managed_flag(apps: Apps, schema_editor): - LDAPPropertyMapping = apps.get_model( - "authentik_sources_ldap", "LDAPPropertyMapping" - ) + LDAPPropertyMapping = apps.get_model("authentik_sources_ldap", "LDAPPropertyMapping") db_alias = schema_editor.connection.alias field_to_uid = { "name": "goauthentik.io/sources/ldap/default-name", diff --git a/authentik/sources/ldap/migrations/0011_ldapsource_property_mappings_group.py b/authentik/sources/ldap/migrations/0011_ldapsource_property_mappings_group.py index bd5edca8d..6a2c55c21 100644 --- a/authentik/sources/ldap/migrations/0011_ldapsource_property_mappings_group.py +++ b/authentik/sources/ldap/migrations/0011_ldapsource_property_mappings_group.py @@ -5,9 +5,7 @@ from django.db import migrations, models def set_default_group_mappings(apps: Apps, schema_editor): - LDAPPropertyMapping = apps.get_model( - "authentik_sources_ldap", "LDAPPropertyMapping" - ) + LDAPPropertyMapping = apps.get_model("authentik_sources_ldap", "LDAPPropertyMapping") LDAPSource = apps.get_model("authentik_sources_ldap", "LDAPSource") db_alias = schema_editor.connection.alias diff --git a/authentik/sources/ldap/password.py b/authentik/sources/ldap/password.py index 6a5d49601..0234fdfd2 100644 --- a/authentik/sources/ldap/password.py +++ b/authentik/sources/ldap/password.py @@ -118,9 +118,7 @@ class LDAPPasswordChanger: return False return True - def ad_password_complexity( - self, password: str, user: Optional[User] = None - ) -> bool: + def ad_password_complexity(self, password: str, user: Optional[User] = None) -> bool: """Check if password matches Active direcotry password policies https://docs.microsoft.com/en-us/windows/security/threat-protection/ @@ -160,7 +158,5 @@ class LDAPPasswordChanger: must=required, ) return False - LOGGER.debug( - "Password matched categories", has=matched_categories, must=required - ) + LOGGER.debug("Password matched categories", has=matched_categories, must=required) return True diff --git a/authentik/sources/ldap/signals.py b/authentik/sources/ldap/signals.py index 1ec5d5f2a..fc7edc74f 100644 --- a/authentik/sources/ldap/signals.py +++ b/authentik/sources/ldap/signals.py @@ -39,9 +39,7 @@ def ldap_password_validate(sender, password: str, plan_context: dict[str, Any], password, plan_context.get(PLAN_CONTEXT_PENDING_USER, None) ) if not passing: - raise ValidationError( - _("Password does not match Active Direcory Complexity.") - ) + raise ValidationError(_("Password does not match Active Direcory Complexity.")) @receiver(password_changed) diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index 6870c899b..66a05b9c7 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -50,9 +50,7 @@ class BaseLDAPSynchronizer: def build_user_properties(self, user_dn: str, **kwargs) -> dict[str, Any]: """Build attributes for User object based on property mappings.""" - return self._build_object_properties( - user_dn, self._source.property_mappings, **kwargs - ) + return self._build_object_properties(user_dn, self._source.property_mappings, **kwargs) def build_group_properties(self, group_dn: str, **kwargs) -> dict[str, Any]: """Build attributes for Group object based on property mappings.""" @@ -69,18 +67,14 @@ class BaseLDAPSynchronizer: continue mapping: LDAPPropertyMapping try: - value = mapping.evaluate( - user=None, request=None, ldap=kwargs, dn=object_dn - ) + value = mapping.evaluate(user=None, request=None, ldap=kwargs, dn=object_dn) if value is None: continue object_field = mapping.object_field if object_field.startswith("attributes."): # Because returning a list might desired, we can't # rely on self._flatten here. Instead, just save the result as-is - properties["attributes"][ - object_field.replace("attributes.", "") - ] = value + properties["attributes"][object_field.replace("attributes.", "")] = value else: properties[object_field] = self._flatten(value) except PropertyMappingExpressionException as exc: @@ -89,9 +83,7 @@ class BaseLDAPSynchronizer: message=f"Failed to evaluate property-mapping: {str(exc)}", mapping=mapping, ).save() - self._logger.warning( - "Mapping failed to evaluate", exc=exc, mapping=mapping - ) + self._logger.warning("Mapping failed to evaluate", exc=exc, mapping=mapping) continue if self._source.object_uniqueness_field in kwargs: properties["attributes"][LDAP_UNIQUENESS] = self._flatten( diff --git a/authentik/sources/ldap/sync/groups.py b/authentik/sources/ldap/sync/groups.py index 291dbb67a..55771fd55 100644 --- a/authentik/sources/ldap/sync/groups.py +++ b/authentik/sources/ldap/sync/groups.py @@ -26,9 +26,7 @@ class GroupLDAPSynchronizer(BaseLDAPSynchronizer): group_count = 0 for group in groups: attributes = group.get("attributes", {}) - group_dn = self._flatten( - self._flatten(group.get("entryDN", group.get("dn"))) - ) + group_dn = self._flatten(self._flatten(group.get("entryDN", group.get("dn")))) if self._source.object_uniqueness_field not in attributes: self._logger.warning( "Cannot find uniqueness Field in attributes", diff --git a/authentik/sources/ldap/sync/membership.py b/authentik/sources/ldap/sync/membership.py index 90849a9a2..1f623b651 100644 --- a/authentik/sources/ldap/sync/membership.py +++ b/authentik/sources/ldap/sync/membership.py @@ -34,9 +34,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): ) membership_count = 0 for group in groups: - members = group.get("attributes", {}).get( - self._source.group_membership_field, [] - ) + members = group.get("attributes", {}).get(self._source.group_membership_field, []) ak_group = self.get_group(group) if not ak_group: continue @@ -60,9 +58,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): def get_group(self, group_dict: dict[str, Any]) -> Optional[Group]: """Check if we fetched the group already, and if not cache it for later""" group_dn = group_dict.get("attributes", {}).get(LDAP_DISTINGUISHED_NAME, []) - group_uniq = group_dict.get("attributes", {}).get( - self._source.object_uniqueness_field, [] - ) + group_uniq = group_dict.get("attributes", {}).get(self._source.object_uniqueness_field, []) # group_uniq might be a single string or an array with (hopefully) a single string if isinstance(group_uniq, list): if len(group_uniq) < 1: @@ -73,9 +69,7 @@ class MembershipLDAPSynchronizer(BaseLDAPSynchronizer): return None group_uniq = group_uniq[0] if group_uniq not in self.group_cache: - groups = Group.objects.filter( - **{f"attributes__{LDAP_UNIQUENESS}": group_uniq} - ) + groups = Group.objects.filter(**{f"attributes__{LDAP_UNIQUENESS}": group_uniq}) if not groups.exists(): self._logger.warning( "Group does not exist in our DB yet, run sync_groups first.", diff --git a/authentik/sources/ldap/sync/users.py b/authentik/sources/ldap/sync/users.py index a7fd98186..6462c8483 100644 --- a/authentik/sources/ldap/sync/users.py +++ b/authentik/sources/ldap/sync/users.py @@ -61,9 +61,7 @@ class UserLDAPSynchronizer(BaseLDAPSynchronizer): dn=user_dn, ).save() else: - self._logger.debug( - "Synced User", user=ak_user.username, created=created - ) + self._logger.debug("Synced User", user=ak_user.username, created=created) user_count += 1 pwd_last_set: datetime = attributes.get("pwdLastSet", datetime.now()) pwd_last_set = pwd_last_set.replace(tzinfo=UTC) diff --git a/authentik/sources/ldap/tests/test_auth.py b/authentik/sources/ldap/tests/test_auth.py index 92eb7cb52..2c5c9893f 100644 --- a/authentik/sources/ldap/tests/test_auth.py +++ b/authentik/sources/ldap/tests/test_auth.py @@ -51,9 +51,7 @@ class LDAPSyncTests(TestCase): ): backend = LDAPBackend() self.assertEqual( - backend.authenticate( - None, username="user0_sn", password=LDAP_PASSWORD - ), + backend.authenticate(None, username="user0_sn", password=LDAP_PASSWORD), user, ) @@ -80,8 +78,6 @@ class LDAPSyncTests(TestCase): ): backend = LDAPBackend() self.assertEqual( - backend.authenticate( - None, username="user0_sn", password=LDAP_PASSWORD - ), + backend.authenticate(None, username="user0_sn", password=LDAP_PASSWORD), user, ) diff --git a/authentik/sources/ldap/tests/test_password.py b/authentik/sources/ldap/tests/test_password.py index 91a32abf2..fa308c6c7 100644 --- a/authentik/sources/ldap/tests/test_password.py +++ b/authentik/sources/ldap/tests/test_password.py @@ -46,9 +46,5 @@ class LDAPPasswordTests(TestCase): self.assertFalse(pwc.ad_password_complexity("test", user)) # 1 category self.assertFalse(pwc.ad_password_complexity("test1", user)) # 2 categories self.assertTrue(pwc.ad_password_complexity("test1!", user)) # 2 categories - self.assertFalse( - pwc.ad_password_complexity("erin!qewrqewr", user) - ) # displayName token - self.assertFalse( - pwc.ad_password_complexity("hagens!qewrqewr", user) - ) # displayName token + self.assertFalse(pwc.ad_password_complexity("erin!qewrqewr", user)) # displayName token + self.assertFalse(pwc.ad_password_complexity("hagens!qewrqewr", user)) # displayName token diff --git a/authentik/sources/ldap/tests/test_sync.py b/authentik/sources/ldap/tests/test_sync.py index 98496e85f..bd150dc02 100644 --- a/authentik/sources/ldap/tests/test_sync.py +++ b/authentik/sources/ldap/tests/test_sync.py @@ -101,9 +101,7 @@ class LDAPSyncTests(TestCase): ) ) self.source.property_mappings_group.set( - LDAPPropertyMapping.objects.filter( - managed="goauthentik.io/sources/ldap/default-name" - ) + LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/default-name") ) self.source.save() connection = PropertyMock(return_value=mock_ad_connection(LDAP_PASSWORD)) @@ -126,9 +124,7 @@ class LDAPSyncTests(TestCase): ) ) self.source.property_mappings_group.set( - LDAPPropertyMapping.objects.filter( - managed="goauthentik.io/sources/ldap/openldap-cn" - ) + LDAPPropertyMapping.objects.filter(managed="goauthentik.io/sources/ldap/openldap-cn") ) self.source.save() connection = PropertyMock(return_value=mock_slapd_connection(LDAP_PASSWORD)) diff --git a/authentik/sources/oauth/api/source.py b/authentik/sources/oauth/api/source.py index f0af31661..dc17be781 100644 --- a/authentik/sources/oauth/api/source.py +++ b/authentik/sources/oauth/api/source.py @@ -58,9 +58,7 @@ class OAuthSourceSerializer(SourceSerializer): ]: if getattr(provider_type, url, None) is None: if url not in attrs: - raise ValidationError( - f"{url} is required for provider {provider_type.name}" - ) + raise ValidationError(f"{url} is required for provider {provider_type.name}") return attrs class Meta: diff --git a/authentik/sources/oauth/clients/base.py b/authentik/sources/oauth/clients/base.py index d13a02078..842872848 100644 --- a/authentik/sources/oauth/clients/base.py +++ b/authentik/sources/oauth/clients/base.py @@ -25,9 +25,7 @@ class BaseOAuthClient: callback: Optional[str] - def __init__( - self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None - ): + def __init__(self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None): self.source = source self.session = Session() self.request = request diff --git a/authentik/sources/oauth/migrations/0001_initial.py b/authentik/sources/oauth/migrations/0001_initial.py index b13defbe8..a312f1cec 100644 --- a/authentik/sources/oauth/migrations/0001_initial.py +++ b/authentik/sources/oauth/migrations/0001_initial.py @@ -30,9 +30,7 @@ class Migration(migrations.Migration): ("provider_type", models.CharField(max_length=255)), ( "request_token_url", - models.CharField( - blank=True, max_length=255, verbose_name="Request Token URL" - ), + models.CharField(blank=True, max_length=255, verbose_name="Request Token URL"), ), ( "authorization_url", diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index 879664d33..e6066de21 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -24,9 +24,7 @@ class AzureADClient(OAuth2Client): response = self.session.request( "get", profile_url, - headers={ - "Authorization": f"{token['token_type']} {token['access_token']}" - }, + headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, ) LOGGER.debug(response.text) response.raise_for_status() diff --git a/authentik/sources/oauth/types/twitter.py b/authentik/sources/oauth/types/twitter.py index b4df3d607..91bd1d70d 100644 --- a/authentik/sources/oauth/types/twitter.py +++ b/authentik/sources/oauth/types/twitter.py @@ -31,6 +31,5 @@ class TwitterType(SourceType): authorization_url = "https://api.twitter.com/oauth/authenticate" access_token_url = "https://api.twitter.com/oauth/access_token" # nosec profile_url = ( - "https://api.twitter.com/1.1/account/" - "verify_credentials.json?include_email=true" + "https://api.twitter.com/1.1/account/" "verify_credentials.json?include_email=true" ) diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index 122f94804..a192382dd 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -34,9 +34,7 @@ class OAuthCallback(OAuthClientMixin, View): if not self.source.enabled: raise Http404(f"Source {slug} is not enabled.") - client = self.get_client( - self.source, callback=self.get_callback_url(self.source) - ) + client = self.get_client(self.source, callback=self.get_callback_url(self.source)) # Fetch access token token = client.get_access_token() if token is None: diff --git a/authentik/sources/plex/migrations/0002_auto_20210505_1717.py b/authentik/sources/plex/migrations/0002_auto_20210505_1717.py index 43869fd08..cf5977ead 100644 --- a/authentik/sources/plex/migrations/0002_auto_20210505_1717.py +++ b/authentik/sources/plex/migrations/0002_auto_20210505_1717.py @@ -24,9 +24,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="plexsource", name="plex_token", - field=models.TextField( - default="", help_text="Plex token used to check firends" - ), + field=models.TextField(default="", help_text="Plex token used to check firends"), ), migrations.AlterField( model_name="plexsource", diff --git a/authentik/sources/plex/plex.py b/authentik/sources/plex/plex.py index ecfb7a5a4..81cfe2f0e 100644 --- a/authentik/sources/plex/plex.py +++ b/authentik/sources/plex/plex.py @@ -92,9 +92,7 @@ class PlexAuth: if resource["provides"] != "server": continue if resource["clientIdentifier"] in self._source.allowed_servers: - LOGGER.info( - "Plex allowed access from server", name=resource["name"] - ) + LOGGER.info("Plex allowed access from server", name=resource["name"]) return True return False @@ -104,9 +102,7 @@ class PlexSourceFlowManager(SourceFlowManager): connection_type = PlexSourceConnection - def update_connection( - self, connection: PlexSourceConnection, **kwargs - ) -> PlexSourceConnection: + def update_connection(self, connection: PlexSourceConnection, **kwargs) -> PlexSourceConnection: """Set the access_token on the connection""" connection.plex_token = kwargs.get("plex_token") return connection diff --git a/authentik/sources/plex/tasks.py b/authentik/sources/plex/tasks.py index 705d53826..57ecd08f5 100644 --- a/authentik/sources/plex/tasks.py +++ b/authentik/sources/plex/tasks.py @@ -26,9 +26,7 @@ def check_plex_token(self: MonitoredTask, source_slug: int): auth = PlexAuth(source, source.plex_token) try: auth.get_user_info() - self.set_status( - TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."]) - ) + self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, ["Plex token is valid."])) except RequestException as exc: self.set_status( TaskResult( diff --git a/authentik/sources/plex/tests.py b/authentik/sources/plex/tests.py index 5ee57d5dc..21c049977 100644 --- a/authentik/sources/plex/tests.py +++ b/authentik/sources/plex/tests.py @@ -71,12 +71,8 @@ class TestPlexSource(TestCase): with Mocker() as mocker: mocker.get("https://plex.tv/api/v2/user", json=USER_INFO_RESPONSE) check_plex_token_all() - self.assertFalse( - Event.objects.filter(action=EventAction.CONFIGURATION_ERROR).exists() - ) + self.assertFalse(Event.objects.filter(action=EventAction.CONFIGURATION_ERROR).exists()) with Mocker() as mocker: mocker.get("https://plex.tv/api/v2/user", exc=RequestException()) check_plex_token_all() - self.assertTrue( - Event.objects.filter(action=EventAction.CONFIGURATION_ERROR).exists() - ) + self.assertTrue(Event.objects.filter(action=EventAction.CONFIGURATION_ERROR).exists()) diff --git a/authentik/sources/saml/migrations/0010_samlsource_pre_authentication_flow.py b/authentik/sources/saml/migrations/0010_samlsource_pre_authentication_flow.py index 5b1b8cde6..324fe35e0 100644 --- a/authentik/sources/saml/migrations/0010_samlsource_pre_authentication_flow.py +++ b/authentik/sources/saml/migrations/0010_samlsource_pre_authentication_flow.py @@ -8,9 +8,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from authentik.flows.models import FlowDesignation -def create_default_pre_authentication_flow( - apps: Apps, schema_editor: BaseDatabaseSchemaEditor -): +def create_default_pre_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): Flow = apps.get_model("authentik_flows", "Flow") SAMLSource = apps.get_model("authentik_sources_saml", "samlsource") diff --git a/authentik/sources/saml/processors/constants.py b/authentik/sources/saml/processors/constants.py index 967a365cb..d7e4bcd07 100644 --- a/authentik/sources/saml/processors/constants.py +++ b/authentik/sources/saml/processors/constants.py @@ -15,13 +15,9 @@ NS_MAP = { SAML_NAME_ID_FORMAT_EMAIL = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress" SAML_NAME_ID_FORMAT_PERSISTENT = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent" -SAML_NAME_ID_FORMAT_UNSPECIFIED = ( - "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified" -) +SAML_NAME_ID_FORMAT_UNSPECIFIED = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified" SAML_NAME_ID_FORMAT_X509 = "urn:oasis:names:tc:SAML:2.0:nameid-format:X509SubjectName" -SAML_NAME_ID_FORMAT_WINDOWS = ( - "urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName" -) +SAML_NAME_ID_FORMAT_WINDOWS = "urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName" SAML_NAME_ID_FORMAT_TRANSIENT = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" SAML_BINDING_POST = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" diff --git a/authentik/sources/saml/processors/metadata.py b/authentik/sources/saml/processors/metadata.py index a0cd61a35..93f135b39 100644 --- a/authentik/sources/saml/processors/metadata.py +++ b/authentik/sources/saml/processors/metadata.py @@ -36,9 +36,7 @@ class MetadataProcessor: key_descriptor.attrib["use"] = "signing" key_info = SubElement(key_descriptor, f"{{{NS_SIGNATURE}}}KeyInfo") x509_data = SubElement(key_info, f"{{{NS_SIGNATURE}}}X509Data") - x509_certificate = SubElement( - x509_data, f"{{{NS_SIGNATURE}}}X509Certificate" - ) + x509_certificate = SubElement(x509_data, f"{{{NS_SIGNATURE}}}X509Certificate") x509_certificate.text = strip_pem_header( self.source.signing_kp.certificate_data.replace("\r", "") ).replace("\n", "") @@ -61,14 +59,10 @@ class MetadataProcessor: def build_entity_descriptor(self) -> str: """Build full EntityDescriptor""" - entity_descriptor = Element( - f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP - ) + entity_descriptor = Element(f"{{{NS_SAML_METADATA}}}EntityDescriptor", nsmap=NS_MAP) entity_descriptor.attrib["entityID"] = self.source.get_issuer(self.http_request) - sp_sso_descriptor = SubElement( - entity_descriptor, f"{{{NS_SAML_METADATA}}}SPSSODescriptor" - ) + sp_sso_descriptor = SubElement(entity_descriptor, f"{{{NS_SAML_METADATA}}}SPSSODescriptor") sp_sso_descriptor.attrib[ "protocolSupportEnumeration" ] = "urn:oasis:names:tc:SAML:2.0:protocol" diff --git a/authentik/sources/saml/processors/request.py b/authentik/sources/saml/processors/request.py index e16ca2777..90072f8df 100644 --- a/authentik/sources/saml/processors/request.py +++ b/authentik/sources/saml/processors/request.py @@ -56,9 +56,9 @@ class RequestProcessor: def get_auth_n(self) -> Element: """Get full AuthnRequest""" auth_n_request = Element(f"{{{NS_SAML_PROTOCOL}}}AuthnRequest", nsmap=NS_MAP) - auth_n_request.attrib[ - "AssertionConsumerServiceURL" - ] = self.source.build_full_url(self.http_request) + auth_n_request.attrib["AssertionConsumerServiceURL"] = self.source.build_full_url( + self.http_request + ) auth_n_request.attrib["Destination"] = self.source.sso_url auth_n_request.attrib["ID"] = self.request_id auth_n_request.attrib["IssueInstant"] = self.issue_instant @@ -106,9 +106,7 @@ class RequestProcessor: self.source.digest_algorithm, xmlsec.constants.TransformSha1 ) - signature_node = xmlsec.tree.find_node( - auth_n_request, xmlsec.constants.NodeSignature - ) + signature_node = xmlsec.tree.find_node(auth_n_request, xmlsec.constants.NodeSignature) ref = xmlsec.template.add_reference( signature_node, @@ -129,9 +127,7 @@ class RequestProcessor: Signature. See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf""" auth_n_request = self.get_auth_n() - saml_request = deflate_and_base64_encode( - etree.tostring(auth_n_request).decode() - ) + saml_request = deflate_and_base64_encode(etree.tostring(auth_n_request).decode()) response_dict = { "SAMLRequest": saml_request, @@ -162,9 +158,7 @@ class RequestProcessor: ) ctx.key = key - signature = ctx.sign_binary( - querystring.encode("utf-8"), sign_algorithm_transform - ) + signature = ctx.sign_binary(querystring.encode("utf-8"), sign_algorithm_transform) response_dict["Signature"] = b64encode(signature).decode() response_dict["SigAlg"] = self.source.signature_algorithm diff --git a/authentik/sources/saml/processors/response.py b/authentik/sources/saml/processors/response.py index 8ccf26232..0e41b44c8 100644 --- a/authentik/sources/saml/processors/response.py +++ b/authentik/sources/saml/processors/response.py @@ -110,10 +110,7 @@ class ResponseProcessor: seen_ids.append(self._root.attrib["ID"]) cache.set(CACHE_SEEN_REQUEST_ID % self._source.pk, seen_ids) return - if ( - SESSION_REQUEST_ID not in request.session - or "InResponseTo" not in self._root.attrib - ): + if SESSION_REQUEST_ID not in request.session or "InResponseTo" not in self._root.attrib: raise MismatchedRequestID( "Missing InResponseTo and IdP-initiated Logins are not allowed" ) @@ -129,9 +126,7 @@ class ResponseProcessor: name_id = self._get_name_id().text user: User = User.objects.create( username=name_id, - attributes={ - "saml": {"source": self._source.pk.hex, "delete_on_logout": True} - }, + attributes={"saml": {"source": self._source.pk.hex, "delete_on_logout": True}}, ) LOGGER.debug("Created temporary user for NameID Transient", username=name_id) user.set_unusable_password() @@ -214,9 +209,7 @@ class ResponseProcessor: **{PLAN_CONTEXT_PROMPT: delete_none_keys(name_id_filter)}, ) - def _flow_response( - self, request: HttpRequest, flow: Flow, **kwargs - ) -> HttpResponse: + def _flow_response(self, request: HttpRequest, flow: Flow, **kwargs) -> HttpResponse: kwargs[PLAN_CONTEXT_SSO] = True kwargs[PLAN_CONTEXT_SOURCE] = self._source request.session[SESSION_KEY_PLAN] = FlowPlanner(flow).plan(request, kwargs) diff --git a/authentik/sources/saml/tasks.py b/authentik/sources/saml/tasks.py index bac63ad81..91ec96b85 100644 --- a/authentik/sources/saml/tasks.py +++ b/authentik/sources/saml/tasks.py @@ -18,16 +18,10 @@ def clean_temporary_users(self: MonitoredTask): messages = [] deleted_users = 0 for user in User.objects.filter(attributes__saml__isnull=False): - sources = SAMLSource.objects.filter( - pk=user.attributes.get("saml", {}).get("source", "") - ) + sources = SAMLSource.objects.filter(pk=user.attributes.get("saml", {}).get("source", "")) if not sources.exists(): - LOGGER.warning( - "User has an invalid SAML Source and won't be deleted!", user=user - ) - messages.append( - f"User {user} has an invalid SAML Source and won't be deleted!" - ) + LOGGER.warning("User has an invalid SAML Source and won't be deleted!", user=user) + messages.append(f"User {user} has an invalid SAML Source and won't be deleted!") continue source = sources.first() source_delta = timedelta_from_string(source.temporary_user_delete_after) @@ -35,9 +29,7 @@ def clean_temporary_users(self: MonitoredTask): _now - user.last_login >= source_delta and not AuthenticatedSession.objects.filter(user=user).exists() ): - LOGGER.debug( - "User is expired and will be deleted.", user=user, delta=source_delta - ) + LOGGER.debug("User is expired and will be deleted.", user=user, delta=source_delta) user.delete() deleted_users += 1 messages.append(f"Successfully deleted {deleted_users} users.") diff --git a/authentik/sources/saml/tests/test_metadata.py b/authentik/sources/saml/tests/test_metadata.py index 95e1482db..672c20262 100644 --- a/authentik/sources/saml/tests/test_metadata.py +++ b/authentik/sources/saml/tests/test_metadata.py @@ -21,17 +21,13 @@ class TestMetadataProcessor(TestCase): slug="provider", issuer="authentik", signing_kp=CertificateKeyPair.objects.first(), - pre_authentication_flow=Flow.objects.get( - slug="default-source-pre-authentication" - ), + pre_authentication_flow=Flow.objects.get(slug="default-source-pre-authentication"), ) request = self.factory.get("/") xml = MetadataProcessor(source, request).build_entity_descriptor() metadata = etree.fromstring(xml) # nosec - schema = etree.XMLSchema( - etree.parse("xml/saml-schema-metadata-2.0.xsd") - ) # nosec + schema = etree.XMLSchema(etree.parse("xml/saml-schema-metadata-2.0.xsd")) # nosec self.assertTrue(schema.validate(metadata)) def test_metadata(self): @@ -40,9 +36,7 @@ class TestMetadataProcessor(TestCase): slug="provider", issuer="authentik", signing_kp=CertificateKeyPair.objects.first(), - pre_authentication_flow=Flow.objects.get( - slug="default-source-pre-authentication" - ), + pre_authentication_flow=Flow.objects.get(slug="default-source-pre-authentication"), ) request = self.factory.get("/") xml = MetadataProcessor(source, request).build_entity_descriptor() @@ -54,9 +48,7 @@ class TestMetadataProcessor(TestCase): source = SAMLSource.objects.create( slug="provider", issuer="authentik", - pre_authentication_flow=Flow.objects.get( - slug="default-source-pre-authentication" - ), + pre_authentication_flow=Flow.objects.get(slug="default-source-pre-authentication"), ) request = self.factory.get("/") xml = MetadataProcessor(source, request).build_entity_descriptor() diff --git a/authentik/sources/saml/views.py b/authentik/sources/saml/views.py index d85496d3f..ef304c0b2 100644 --- a/authentik/sources/saml/views.py +++ b/authentik/sources/saml/views.py @@ -27,10 +27,7 @@ from authentik.lib.utils.urls import redirect_with_qs from authentik.lib.views import bad_request_message from authentik.providers.saml.utils.encoding import nice64 from authentik.providers.saml.views.flows import AutosubmitChallenge -from authentik.sources.saml.exceptions import ( - MissingSAMLResponse, - UnsupportedNameIDFormat, -) +from authentik.sources.saml.exceptions import MissingSAMLResponse, UnsupportedNameIDFormat from authentik.sources.saml.models import SAMLBindingTypes, SAMLSource from authentik.sources.saml.processors.metadata import MetadataProcessor from authentik.sources.saml.processors.request import RequestProcessor @@ -69,9 +66,7 @@ class AutosubmitStageView(ChallengeStageView): class InitiateView(View): """Get the Form with SAML Request, which sends us to the IDP""" - def handle_login_flow( - self, source: SAMLSource, *stages_to_append, **kwargs - ) -> HttpResponse: + def handle_login_flow(self, source: SAMLSource, *stages_to_append, **kwargs) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" # Ensure redirect is carried through when user was trying to # authorize application diff --git a/authentik/stages/authenticator_duo/migrations/0001_initial.py b/authentik/stages/authenticator_duo/migrations/0001_initial.py index 89fd69207..4585fe39e 100644 --- a/authentik/stages/authenticator_duo/migrations/0001_initial.py +++ b/authentik/stages/authenticator_duo/migrations/0001_initial.py @@ -70,9 +70,7 @@ class Migration(migrations.Migration): ), ( "confirmed", - models.BooleanField( - default=True, help_text="Is this device ready for use?" - ), + models.BooleanField(default=True, help_text="Is this device ready for use?"), ), ("duo_user_id", models.TextField()), ( diff --git a/authentik/stages/authenticator_duo/models.py b/authentik/stages/authenticator_duo/models.py index 7edd1bda5..4e9db0478 100644 --- a/authentik/stages/authenticator_duo/models.py +++ b/authentik/stages/authenticator_duo/models.py @@ -23,9 +23,7 @@ class AuthenticatorDuoStage(ConfigurableStage, Stage): @property def serializer(self) -> BaseSerializer: - from authentik.stages.authenticator_duo.api import ( - AuthenticatorDuoStageSerializer, - ) + from authentik.stages.authenticator_duo.api import AuthenticatorDuoStageSerializer return AuthenticatorDuoStageSerializer diff --git a/authentik/stages/authenticator_duo/stage.py b/authentik/stages/authenticator_duo/stage.py index 96c5908c6..d8ea49165 100644 --- a/authentik/stages/authenticator_duo/stage.py +++ b/authentik/stages/authenticator_duo/stage.py @@ -55,9 +55,7 @@ class AuthenticatorDuoStageView(ChallengeStageView): raise InvalidStageError(str(exc)) from exc user_id = enroll["user_id"] self.request.session[SESSION_KEY_DUO_USER_ID] = user_id - self.request.session[SESSION_KEY_DUO_ACTIVATION_CODE] = enroll[ - "activation_code" - ] + self.request.session[SESSION_KEY_DUO_ACTIVATION_CODE] = enroll["activation_code"] return AuthenticatorDuoChallenge( data={ "type": ChallengeTypes.NATIVE.value, @@ -86,11 +84,7 @@ class AuthenticatorDuoStageView(ChallengeStageView): self.request.session.pop(SESSION_KEY_DUO_USER_ID) self.request.session.pop(SESSION_KEY_DUO_ACTIVATION_CODE) if not existing_device: - DuoDevice.objects.create( - user=self.get_pending_user(), duo_user_id=user_id, stage=stage - ) + DuoDevice.objects.create(user=self.get_pending_user(), duo_user_id=user_id, stage=stage) else: - return self.executor.stage_invalid( - "Device with Credential ID already exists." - ) + return self.executor.stage_invalid("Device with Credential ID already exists.") return self.executor.stage_ok() diff --git a/authentik/stages/authenticator_static/migrations/0005_default_setup_flow.py b/authentik/stages/authenticator_static/migrations/0005_default_setup_flow.py index de787105e..7fc920678 100644 --- a/authentik/stages/authenticator_static/migrations/0005_default_setup_flow.py +++ b/authentik/stages/authenticator_static/migrations/0005_default_setup_flow.py @@ -34,9 +34,7 @@ def create_default_setup_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEdito target=flow, stage=stage, defaults={"order": 0} ) - for stage in AuthenticatorStaticStage.objects.using(db_alias).filter( - configure_flow=None - ): + for stage in AuthenticatorStaticStage.objects.using(db_alias).filter(configure_flow=None): stage.configure_flow = flow stage.save() diff --git a/authentik/stages/authenticator_static/models.py b/authentik/stages/authenticator_static/models.py index 125abb02f..5b26ca8aa 100644 --- a/authentik/stages/authenticator_static/models.py +++ b/authentik/stages/authenticator_static/models.py @@ -17,17 +17,13 @@ class AuthenticatorStaticStage(ConfigurableStage, Stage): @property def serializer(self) -> BaseSerializer: - from authentik.stages.authenticator_static.api import ( - AuthenticatorStaticStageSerializer, - ) + from authentik.stages.authenticator_static.api import AuthenticatorStaticStageSerializer return AuthenticatorStaticStageSerializer @property def type(self) -> Type[View]: - from authentik.stages.authenticator_static.stage import ( - AuthenticatorStaticStageView, - ) + from authentik.stages.authenticator_static.stage import AuthenticatorStaticStageView return AuthenticatorStaticStageView diff --git a/authentik/stages/authenticator_static/signals.py b/authentik/stages/authenticator_static/signals.py index 0b81c3663..81a75799e 100644 --- a/authentik/stages/authenticator_static/signals.py +++ b/authentik/stages/authenticator_static/signals.py @@ -11,8 +11,6 @@ from authentik.events.models import Event def pre_delete_event(sender, instance: StaticDevice, **_): """Create event before deleting Static Devices""" # Create event with email notification - event = Event.new( - "static_authenticator_disable", message="User disabled Static OTP Tokens." - ) + event = Event.new("static_authenticator_disable", message="User disabled Static OTP Tokens.") event.set_user(instance.user) event.save() diff --git a/authentik/stages/authenticator_static/stage.py b/authentik/stages/authenticator_static/stage.py index 5f40a10e6..92316bc34 100644 --- a/authentik/stages/authenticator_static/stage.py +++ b/authentik/stages/authenticator_static/stage.py @@ -4,11 +4,7 @@ from django_otp.plugins.otp_static.models import StaticDevice, StaticToken from rest_framework.fields import CharField, ListField from structlog.stdlib import get_logger -from authentik.flows.challenge import ( - ChallengeResponse, - ChallengeTypes, - WithUserInfoChallenge, -) +from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_static.models import AuthenticatorStaticStage @@ -62,9 +58,7 @@ class AuthenticatorStaticStageView(ChallengeStageView): device = StaticDevice(user=user, confirmed=True) tokens = [] for _ in range(0, stage.token_count): - tokens.append( - StaticToken(device=device, token=StaticToken.random_token()) - ) + tokens.append(StaticToken(device=device, token=StaticToken.random_token())) self.request.session[SESSION_STATIC_DEVICE] = device self.request.session[SESSION_STATIC_TOKENS] = tokens return super().get(request, *args, **kwargs) diff --git a/authentik/stages/authenticator_totp/migrations/0006_default_setup_flow.py b/authentik/stages/authenticator_totp/migrations/0006_default_setup_flow.py index 85788c348..a9af0239f 100644 --- a/authentik/stages/authenticator_totp/migrations/0006_default_setup_flow.py +++ b/authentik/stages/authenticator_totp/migrations/0006_default_setup_flow.py @@ -35,9 +35,7 @@ def create_default_setup_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEdito target=flow, stage=stage, defaults={"order": 0} ) - for stage in AuthenticatorTOTPStage.objects.using(db_alias).filter( - configure_flow=None - ): + for stage in AuthenticatorTOTPStage.objects.using(db_alias).filter(configure_flow=None): stage.configure_flow = flow stage.save() diff --git a/authentik/stages/authenticator_totp/models.py b/authentik/stages/authenticator_totp/models.py index daf69d015..9b36fe303 100644 --- a/authentik/stages/authenticator_totp/models.py +++ b/authentik/stages/authenticator_totp/models.py @@ -24,9 +24,7 @@ class AuthenticatorTOTPStage(ConfigurableStage, Stage): @property def serializer(self) -> BaseSerializer: - from authentik.stages.authenticator_totp.api import ( - AuthenticatorTOTPStageSerializer, - ) + from authentik.stages.authenticator_totp.api import AuthenticatorTOTPStageSerializer return AuthenticatorTOTPStageSerializer diff --git a/authentik/stages/authenticator_validate/migrations/0004_auto_20210301_0949.py b/authentik/stages/authenticator_validate/migrations/0004_auto_20210301_0949.py index 4e5ca58e6..8d30e44ed 100644 --- a/authentik/stages/authenticator_validate/migrations/0004_auto_20210301_0949.py +++ b/authentik/stages/authenticator_validate/migrations/0004_auto_20210301_0949.py @@ -16,8 +16,6 @@ class Migration(migrations.Migration): migrations.AlterField( model_name="authenticatorvalidatestage", name="not_configured_action", - field=models.TextField( - choices=[("skip", "Skip"), ("deny", "Deny")], default="skip" - ), + field=models.TextField(choices=[("skip", "Skip"), ("deny", "Deny")], default="skip"), ), ] diff --git a/authentik/stages/authenticator_validate/models.py b/authentik/stages/authenticator_validate/models.py index a9276babd..197b0eda3 100644 --- a/authentik/stages/authenticator_validate/models.py +++ b/authentik/stages/authenticator_validate/models.py @@ -59,17 +59,13 @@ class AuthenticatorValidateStage(Stage): @property def serializer(self) -> BaseSerializer: - from authentik.stages.authenticator_validate.api import ( - AuthenticatorValidateStageSerializer, - ) + from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer return AuthenticatorValidateStageSerializer @property def type(self) -> Type[View]: - from authentik.stages.authenticator_validate.stage import ( - AuthenticatorValidateStageView, - ) + from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView return AuthenticatorValidateStageView diff --git a/authentik/stages/authenticator_validate/stage.py b/authentik/stages/authenticator_validate/stage.py index d8491522b..775bd41fd 100644 --- a/authentik/stages/authenticator_validate/stage.py +++ b/authentik/stages/authenticator_validate/stage.py @@ -5,11 +5,7 @@ from rest_framework.fields import CharField, IntegerField, JSONField, ListField from rest_framework.serializers import ValidationError from structlog.stdlib import get_logger -from authentik.flows.challenge import ( - ChallengeResponse, - ChallengeTypes, - WithUserInfoChallenge, -) +from authentik.flows.challenge import ChallengeResponse, ChallengeTypes, WithUserInfoChallenge from authentik.flows.models import NotConfiguredAction, Stage from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView @@ -20,10 +16,7 @@ from authentik.stages.authenticator_validate.challenge import ( validate_challenge_duo, validate_challenge_webauthn, ) -from authentik.stages.authenticator_validate.models import ( - AuthenticatorValidateStage, - DeviceClasses, -) +from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses LOGGER = get_logger() @@ -46,18 +39,14 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): component = CharField(default="ak-stage-authenticator-validate") def _challenge_allowed(self, classes: list): - device_challenges: list[dict] = self.stage.request.session.get( - "device_challenges" - ) + device_challenges: list[dict] = self.stage.request.session.get("device_challenges") if not any(x["device_class"] in classes for x in device_challenges): raise ValidationError("No compatible device class allowed") def validate_code(self, code: str) -> str: """Validate code-based response, raise error if code isn't allowed""" self._challenge_allowed([DeviceClasses.TOTP, DeviceClasses.STATIC]) - return validate_challenge_code( - code, self.stage.request, self.stage.get_pending_user() - ) + return validate_challenge_code(code, self.stage.request, self.stage.get_pending_user()) def validate_webauthn(self, webauthn: dict) -> dict: """Validate webauthn response, raise error if webauthn wasn't allowed @@ -70,9 +59,7 @@ class AuthenticatorValidationChallengeResponse(ChallengeResponse): def validate_duo(self, duo: int) -> int: """Initiate Duo authentication""" self._challenge_allowed([DeviceClasses.DUO]) - return validate_challenge_duo( - duo, self.stage.request, self.stage.get_pending_user() - ) + return validate_challenge_duo(duo, self.stage.request, self.stage.get_pending_user()) def validate(self, attrs: dict): # Checking if the given data is from a valid device class is done above @@ -162,8 +149,6 @@ class AuthenticatorValidateStageView(ChallengeStageView): ) # pylint: disable=unused-argument - def challenge_valid( - self, response: AuthenticatorValidationChallengeResponse - ) -> HttpResponse: + def challenge_valid(self, response: AuthenticatorValidationChallengeResponse) -> HttpResponse: # All validation is done by the serializer return self.executor.stage_ok() diff --git a/authentik/stages/authenticator_validate/tests.py b/authentik/stages/authenticator_validate/tests.py index e66ef266e..91b0dc8d3 100644 --- a/authentik/stages/authenticator_validate/tests.py +++ b/authentik/stages/authenticator_validate/tests.py @@ -13,14 +13,9 @@ from authentik.core.models import User from authentik.flows.challenge import ChallengeTypes from authentik.flows.models import Flow, FlowStageBinding, NotConfiguredAction from authentik.flows.tests.test_planner import dummy_get_response -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.stages.authenticator_duo.models import AuthenticatorDuoStage, DuoDevice -from authentik.stages.authenticator_validate.api import ( - AuthenticatorValidateStageSerializer, -) +from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageSerializer from authentik.stages.authenticator_validate.challenge import ( get_challenge_for_device, validate_challenge_code, @@ -95,9 +90,7 @@ class AuthenticatorValidateStageTests(TestCase): def test_device_challenge_totp(self): """Test device challenge""" request = self.request_factory.get("/") - totp_device = TOTPDevice.objects.create( - user=self.user, confirmed=True, digits=6 - ) + totp_device = TOTPDevice.objects.create(user=self.user, confirmed=True, digits=6) self.assertEqual(get_challenge_for_device(request, totp_device), {}) with self.assertRaises(ValidationError): validate_challenge_code("1234", request, self.user) diff --git a/authentik/stages/authenticator_webauthn/api.py b/authentik/stages/authenticator_webauthn/api.py index 9f95be3f5..3bccebca1 100644 --- a/authentik/stages/authenticator_webauthn/api.py +++ b/authentik/stages/authenticator_webauthn/api.py @@ -9,10 +9,7 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet, ReadOnlyModelV from authentik.api.authorization import OwnerFilter, OwnerPermissions from authentik.core.api.used_by import UsedByMixin from authentik.flows.api.stages import StageSerializer -from authentik.stages.authenticator_webauthn.models import ( - AuthenticateWebAuthnStage, - WebAuthnDevice, -) +from authentik.stages.authenticator_webauthn.models import AuthenticateWebAuthnStage, WebAuthnDevice class AuthenticateWebAuthnStageSerializer(StageSerializer): diff --git a/authentik/stages/authenticator_webauthn/migrations/0002_default_setup_flow.py b/authentik/stages/authenticator_webauthn/migrations/0002_default_setup_flow.py index 294850fc8..446393b68 100644 --- a/authentik/stages/authenticator_webauthn/migrations/0002_default_setup_flow.py +++ b/authentik/stages/authenticator_webauthn/migrations/0002_default_setup_flow.py @@ -34,9 +34,7 @@ def create_default_setup_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEdito target=flow, stage=stage, defaults={"order": 0} ) - for stage in AuthenticateWebAuthnStage.objects.using(db_alias).filter( - configure_flow=None - ): + for stage in AuthenticateWebAuthnStage.objects.using(db_alias).filter(configure_flow=None): stage.configure_flow = flow stage.save() diff --git a/authentik/stages/authenticator_webauthn/migrations/0003_webauthndevice_confirmed.py b/authentik/stages/authenticator_webauthn/migrations/0003_webauthndevice_confirmed.py index be8e9035d..0243fa2b8 100644 --- a/authentik/stages/authenticator_webauthn/migrations/0003_webauthndevice_confirmed.py +++ b/authentik/stages/authenticator_webauthn/migrations/0003_webauthndevice_confirmed.py @@ -13,8 +13,6 @@ class Migration(migrations.Migration): migrations.AddField( model_name="webauthndevice", name="confirmed", - field=models.BooleanField( - default=True, help_text="Is this device ready for use?" - ), + field=models.BooleanField(default=True, help_text="Is this device ready for use?"), ), ] diff --git a/authentik/stages/authenticator_webauthn/models.py b/authentik/stages/authenticator_webauthn/models.py index 8358653fc..2131fc5f2 100644 --- a/authentik/stages/authenticator_webauthn/models.py +++ b/authentik/stages/authenticator_webauthn/models.py @@ -18,17 +18,13 @@ class AuthenticateWebAuthnStage(ConfigurableStage, Stage): @property def serializer(self) -> BaseSerializer: - from authentik.stages.authenticator_webauthn.api import ( - AuthenticateWebAuthnStageSerializer, - ) + from authentik.stages.authenticator_webauthn.api import AuthenticateWebAuthnStageSerializer return AuthenticateWebAuthnStageSerializer @property def type(self) -> Type[View]: - from authentik.stages.authenticator_webauthn.stage import ( - AuthenticatorWebAuthnStageView, - ) + from authentik.stages.authenticator_webauthn.stage import AuthenticatorWebAuthnStageView return AuthenticatorWebAuthnStageView diff --git a/authentik/stages/authenticator_webauthn/stage.py b/authentik/stages/authenticator_webauthn/stage.py index 20489194c..1a383e9ca 100644 --- a/authentik/stages/authenticator_webauthn/stage.py +++ b/authentik/stages/authenticator_webauthn/stage.py @@ -22,17 +22,11 @@ from authentik.flows.challenge import ( from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.stage import ChallengeStageView from authentik.stages.authenticator_webauthn.models import WebAuthnDevice -from authentik.stages.authenticator_webauthn.utils import ( - generate_challenge, - get_origin, - get_rp_id, -) +from authentik.stages.authenticator_webauthn.utils import generate_challenge, get_origin, get_rp_id LOGGER = get_logger() -SESSION_KEY_WEBAUTHN_AUTHENTICATED = ( - "authentik_stages_authenticator_webauthn_authenticated" -) +SESSION_KEY_WEBAUTHN_AUTHENTICATED = "authentik_stages_authenticator_webauthn_authenticated" class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge): @@ -89,9 +83,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): if credential_id_exists: raise ValidationError("Credential ID already exists.") - webauthn_credential.credential_id = str( - webauthn_credential.credential_id, "utf-8" - ) + webauthn_credential.credential_id = str(webauthn_credential.credential_id, "utf-8") webauthn_credential.public_key = str(webauthn_credential.public_key, "utf-8") return webauthn_credential @@ -145,12 +137,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): return self.executor.stage_ok() return super().get(request, *args, **kwargs) - def get_response_instance( - self, data: QueryDict - ) -> AuthenticatorWebAuthnChallengeResponse: - response: AuthenticatorWebAuthnChallengeResponse = ( - super().get_response_instance(data) - ) + def get_response_instance(self, data: QueryDict) -> AuthenticatorWebAuthnChallengeResponse: + response: AuthenticatorWebAuthnChallengeResponse = super().get_response_instance(data) response.request = self.request response.user = self.get_pending_user() return response @@ -170,7 +158,5 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): rp_id=get_rp_id(self.request), ) else: - return self.executor.stage_invalid( - "Device with Credential ID already exists." - ) + return self.executor.stage_invalid("Device with Credential ID already exists.") return self.executor.stage_ok() diff --git a/authentik/stages/captcha/models.py b/authentik/stages/captcha/models.py index 17d4a721e..ac15752f3 100644 --- a/authentik/stages/captcha/models.py +++ b/authentik/stages/captcha/models.py @@ -13,14 +13,10 @@ class CaptchaStage(Stage): """Verify the user is human using Google's reCaptcha.""" public_key = models.TextField( - help_text=_( - "Public key, acquired from https://www.google.com/recaptcha/intro/v3.html" - ) + help_text=_("Public key, acquired from https://www.google.com/recaptcha/intro/v3.html") ) private_key = models.TextField( - help_text=_( - "Private key, acquired from https://www.google.com/recaptcha/intro/v3.html" - ) + help_text=_("Private key, acquired from https://www.google.com/recaptcha/intro/v3.html") ) @property diff --git a/authentik/stages/captcha/stage.py b/authentik/stages/captcha/stage.py index 1bc0d8492..28c89b888 100644 --- a/authentik/stages/captcha/stage.py +++ b/authentik/stages/captcha/stage.py @@ -49,9 +49,7 @@ class CaptchaChallengeResponse(ChallengeResponse): response.raise_for_status() data = response.json() if not data.get("success", False): - raise ValidationError( - f"Failed to validate token: {data.get('error-codes', '')}" - ) + raise ValidationError(f"Failed to validate token: {data.get('error-codes', '')}") except RequestException as exc: raise ValidationError("Failed to validate token") from exc return token diff --git a/authentik/stages/captcha/tests.py b/authentik/stages/captcha/tests.py index 8c863e2ee..b9cebb8c1 100644 --- a/authentik/stages/captcha/tests.py +++ b/authentik/stages/captcha/tests.py @@ -21,9 +21,7 @@ class TestCaptchaStage(TestCase): def setUp(self): super().setUp() - self.user = User.objects.create_user( - username="unittest", email="test@beryju.org" - ) + self.user = User.objects.create_user(username="unittest", email="test@beryju.org") self.client = Client() self.flow = Flow.objects.create( @@ -36,22 +34,16 @@ class TestCaptchaStage(TestCase): public_key=RECAPTCHA_PUBLIC_KEY, private_key=RECAPTCHA_PRIVATE_KEY, ) - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_valid(self): """Test valid captcha""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), {"token": "PASSED"}, ) self.assertEqual(response.status_code, 200) diff --git a/authentik/stages/consent/migrations/0002_auto_20200720_0941.py b/authentik/stages/consent/migrations/0002_auto_20200720_0941.py index 6521b4c0b..ff420a5ae 100644 --- a/authentik/stages/consent/migrations/0002_auto_20200720_0941.py +++ b/authentik/stages/consent/migrations/0002_auto_20200720_0941.py @@ -53,9 +53,7 @@ class Migration(migrations.Migration): ), ( "expires", - models.DateTimeField( - default=authentik.core.models.default_token_duration - ), + models.DateTimeField(default=authentik.core.models.default_token_duration), ), ("expiring", models.BooleanField(default=True)), ( diff --git a/authentik/stages/consent/models.py b/authentik/stages/consent/models.py index 9c2fa5433..0294a2aa4 100644 --- a/authentik/stages/consent/models.py +++ b/authentik/stages/consent/models.py @@ -22,18 +22,13 @@ class ConsentMode(models.TextChoices): class ConsentStage(Stage): """Prompt the user for confirmation.""" - mode = models.TextField( - choices=ConsentMode.choices, default=ConsentMode.ALWAYS_REQUIRE - ) + mode = models.TextField(choices=ConsentMode.choices, default=ConsentMode.ALWAYS_REQUIRE) consent_expire_in = models.TextField( validators=[timedelta_string_validator], default="weeks=4", verbose_name="Consent expires in", help_text=_( - ( - "Offset after which consent expires. " - "(Format: hours=1;minutes=2;seconds=3)." - ) + ("Offset after which consent expires. " "(Format: hours=1;minutes=2;seconds=3).") ), ) diff --git a/authentik/stages/consent/stage.py b/authentik/stages/consent/stage.py index 8227ebf91..bfeff6046 100644 --- a/authentik/stages/consent/stage.py +++ b/authentik/stages/consent/stage.py @@ -42,16 +42,12 @@ class ConsentStageView(ChallengeStageView): def get_challenge(self) -> Challenge: data = { "type": ChallengeTypes.NATIVE.value, - "permissions": self.executor.plan.context.get( - PLAN_CONTEXT_CONSENT_PERMISSIONS, [] - ), + "permissions": self.executor.plan.context.get(PLAN_CONTEXT_CONSENT_PERMISSIONS, []), } if PLAN_CONTEXT_CONSENT_TITLE in self.executor.plan.context: data["title"] = self.executor.plan.context[PLAN_CONTEXT_CONSENT_TITLE] if PLAN_CONTEXT_CONSENT_HEADER in self.executor.plan.context: - data["header_text"] = self.executor.plan.context[ - PLAN_CONTEXT_CONSENT_HEADER - ] + data["header_text"] = self.executor.plan.context[PLAN_CONTEXT_CONSENT_HEADER] challenge = ConsentChallenge(data=data) return challenge diff --git a/authentik/stages/consent/tests.py b/authentik/stages/consent/tests.py index 728b61649..c73b11c28 100644 --- a/authentik/stages/consent/tests.py +++ b/authentik/stages/consent/tests.py @@ -20,9 +20,7 @@ class TestConsentStage(TestCase): def setUp(self): super().setUp() - self.user = User.objects.create_user( - username="unittest", email="test@beryju.org" - ) + self.user = User.objects.create_user(username="unittest", email="test@beryju.org") self.application = Application.objects.create( name="test-application", slug="test-application", @@ -36,14 +34,10 @@ class TestConsentStage(TestCase): slug="test-consent", designation=FlowDesignation.AUTHENTICATION, ) - stage = ConsentStage.objects.create( - name="consent", mode=ConsentMode.ALWAYS_REQUIRE - ) + stage = ConsentStage.objects.create(name="consent", mode=ConsentMode.ALWAYS_REQUIRE) binding = FlowStageBinding.objects.create(target=flow, stage=stage, order=2) - plan = FlowPlan( - flow_pk=flow.pk.hex, bindings=[binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=flow.pk.hex, bindings=[binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -96,9 +90,7 @@ class TestConsentStage(TestCase): }, ) self.assertTrue( - UserConsent.objects.filter( - user=self.user, application=self.application - ).exists() + UserConsent.objects.filter(user=self.user, application=self.application).exists() ) def test_expire(self): @@ -137,14 +129,10 @@ class TestConsentStage(TestCase): }, ) self.assertTrue( - UserConsent.objects.filter( - user=self.user, application=self.application - ).exists() + UserConsent.objects.filter(user=self.user, application=self.application).exists() ) sleep(1) clean_expired_models.delay().get() self.assertFalse( - UserConsent.objects.filter( - user=self.user, application=self.application - ).exists() + UserConsent.objects.filter(user=self.user, application=self.application).exists() ) diff --git a/authentik/stages/deny/tests.py b/authentik/stages/deny/tests.py index 9a15181ce..7713eaf62 100644 --- a/authentik/stages/deny/tests.py +++ b/authentik/stages/deny/tests.py @@ -26,15 +26,11 @@ class TestUserDenyStage(TestCase): designation=FlowDesignation.AUTHENTICATION, ) self.stage = DenyStage.objects.create(name="logout") - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_valid_password(self): """Test with a valid pending user and backend""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() diff --git a/authentik/stages/dummy/tests.py b/authentik/stages/dummy/tests.py index 0b173feff..388366d9b 100644 --- a/authentik/stages/dummy/tests.py +++ b/authentik/stages/dummy/tests.py @@ -39,9 +39,7 @@ class TestDummyStage(TestCase): def test_post(self): """Test with valid email, check that URL redirects back to itself""" - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.post(url, {}) self.assertEqual(response.status_code, 200) self.assertJSONEqual( diff --git a/authentik/stages/email/apps.py b/authentik/stages/email/apps.py index 499cc600e..c3c1c5788 100644 --- a/authentik/stages/email/apps.py +++ b/authentik/stages/email/apps.py @@ -33,9 +33,7 @@ class AuthentikStageEmailConfig(AppConfig): try: get_template(stage.template) except TemplateDoesNotExist: - LOGGER.warning( - "Stage template does not exist, resetting", path=stage.template - ) + LOGGER.warning("Stage template does not exist, resetting", path=stage.template) Event.new( EventAction.CONFIGURATION_ERROR, stage=stage, diff --git a/authentik/stages/email/models.py b/authentik/stages/email/models.py index 28ea6af95..b9b7fe12d 100644 --- a/authentik/stages/email/models.py +++ b/authentik/stages/email/models.py @@ -42,9 +42,7 @@ def get_template_choices(): for template in template_dir.glob("**/*.html"): path = str(template) if not access(path, R_OK): - LOGGER.warning( - "Custom template file is not readable, check permissions", path=path - ) + LOGGER.warning("Custom template file is not readable, check permissions", path=path) continue rel_path = template.relative_to(template_dir) static_choices.append((str(rel_path), f"Custom Template: {rel_path}")) diff --git a/authentik/stages/email/stage.py b/authentik/stages/email/stage.py index 6c94e8ed7..fb1106477 100644 --- a/authentik/stages/email/stage.py +++ b/authentik/stages/email/stage.py @@ -99,9 +99,7 @@ class EmailStageView(ChallengeStageView): def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: # Check if the user came back from the email link to verify if QS_KEY_TOKEN in request.session.get(SESSION_KEY_GET, {}): - token = get_object_or_404( - Token, key=request.session[SESSION_KEY_GET][QS_KEY_TOKEN] - ) + token = get_object_or_404(Token, key=request.session[SESSION_KEY_GET][QS_KEY_TOKEN]) self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = token.user token.delete() messages.success(request, _("Successfully verified Email.")) diff --git a/authentik/stages/email/tasks.py b/authentik/stages/email/tasks.py index 59ebf6e63..d08da9683 100644 --- a/authentik/stages/email/tasks.py +++ b/authentik/stages/email/tasks.py @@ -45,9 +45,7 @@ def get_email_body(email: EmailMultiAlternatives) -> str: retry_backoff=True, base=MonitoredTask, ) -def send_mail( - self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None -): +def send_mail(self: MonitoredTask, message: dict[Any, Any], email_stage_pk: Optional[int] = None): """Send Email for Email Stage. Retries are scheduled automatically.""" self.save_on_success = False message_id = make_msgid(domain=DNS_NAME) diff --git a/authentik/stages/email/tests/test_sending.py b/authentik/stages/email/tests/test_sending.py index 5467999b4..8966e5a71 100644 --- a/authentik/stages/email/tests/test_sending.py +++ b/authentik/stages/email/tests/test_sending.py @@ -21,9 +21,7 @@ class TestEmailStageSending(TestCase): def setUp(self): super().setUp() - self.user = User.objects.create_user( - username="unittest", email="test@beryju.org" - ) + self.user = User.objects.create_user(username="unittest", email="test@beryju.org") self.client = Client() self.flow = Flow.objects.create( @@ -34,26 +32,18 @@ class TestEmailStageSending(TestCase): self.stage = EmailStage.objects.create( name="email", ) - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_pending_user(self): """Test with pending user""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) - with self.settings( - EMAIL_BACKEND="django.core.mail.backends.locmem.EmailBackend" - ): + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + with self.settings(EMAIL_BACKEND="django.core.mail.backends.locmem.EmailBackend"): response = self.client.post(url) self.assertEqual(response.status_code, 200) self.assertEqual(len(mail.outbox), 1) @@ -68,20 +58,14 @@ class TestEmailStageSending(TestCase): def test_send_error(self): """Test error during sending (sending will be retried)""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) - with self.settings( - EMAIL_BACKEND="django.core.mail.backends.locmem.EmailBackend" - ): + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + with self.settings(EMAIL_BACKEND="django.core.mail.backends.locmem.EmailBackend"): with patch( "django.core.mail.backends.locmem.EmailBackend.send_messages", MagicMock(side_effect=[SMTPException, EmailBackend.send_messages]), diff --git a/authentik/stages/email/tests/test_stage.py b/authentik/stages/email/tests/test_stage.py index 541e21750..0c88c8f5a 100644 --- a/authentik/stages/email/tests/test_stage.py +++ b/authentik/stages/email/tests/test_stage.py @@ -22,9 +22,7 @@ class TestEmailStage(TestCase): def setUp(self): super().setUp() - self.user = User.objects.create_user( - username="unittest", email="test@beryju.org" - ) + self.user = User.objects.create_user(username="unittest", email="test@beryju.org") self.client = Client() self.flow = Flow.objects.create( @@ -35,57 +33,41 @@ class TestEmailStage(TestCase): self.stage = EmailStage.objects.create( name="email", ) - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_rendering(self): """Test with pending user""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.get(url) self.assertEqual(response.status_code, 200) def test_without_user(self): """Test without pending user""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.get(url) self.assertEqual(response.status_code, 200) def test_pending_user(self): """Test with pending user""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) - with self.settings( - EMAIL_BACKEND="django.core.mail.backends.locmem.EmailBackend" - ): + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + with self.settings(EMAIL_BACKEND="django.core.mail.backends.locmem.EmailBackend"): response = self.client.post(url) self.assertEqual(response.status_code, 200) self.assertEqual(len(mail.outbox), 1) @@ -103,9 +85,7 @@ class TestEmailStage(TestCase): """Test with token""" # Make sure token exists self.test_pending_user() - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() diff --git a/authentik/stages/identification/migrations/0009_identificationstage_sources.py b/authentik/stages/identification/migrations/0009_identificationstage_sources.py index b2231b441..39806fea7 100644 --- a/authentik/stages/identification/migrations/0009_identificationstage_sources.py +++ b/authentik/stages/identification/migrations/0009_identificationstage_sources.py @@ -8,9 +8,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor def assign_sources(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): db_alias = schema_editor.connection.alias - IdentificationStage = apps.get_model( - "authentik_stages_identification", "identificationstage" - ) + IdentificationStage = apps.get_model("authentik_stages_identification", "identificationstage") Source = apps.get_model("authentik_core", "source") sources = Source.objects.all() diff --git a/authentik/stages/identification/models.py b/authentik/stages/identification/models.py index f8fc7ba15..e12dca425 100644 --- a/authentik/stages/identification/models.py +++ b/authentik/stages/identification/models.py @@ -46,9 +46,7 @@ class IdentificationStage(Stage): ) case_insensitive_matching = models.BooleanField( default=True, - help_text=_( - "When enabled, user fields are matched regardless of their casing." - ), + help_text=_("When enabled, user fields are matched regardless of their casing."), ) show_matched_user = models.BooleanField( default=True, @@ -68,9 +66,7 @@ class IdentificationStage(Stage): blank=True, related_name="+", default=None, - help_text=_( - "Optional enrollment flow, which is linked at the bottom of the page." - ), + help_text=_("Optional enrollment flow, which is linked at the bottom of the page."), ) recovery_flow = models.ForeignKey( Flow, @@ -79,9 +75,7 @@ class IdentificationStage(Stage): blank=True, related_name="+", default=None, - help_text=_( - "Optional recovery flow, which is linked at the bottom of the page." - ), + help_text=_("Optional recovery flow, which is linked at the bottom of the page."), ) sources = models.ManyToManyField( diff --git a/authentik/stages/identification/stage.py b/authentik/stages/identification/stage.py index dad27f923..21654ce7e 100644 --- a/authentik/stages/identification/stage.py +++ b/authentik/stages/identification/stage.py @@ -17,10 +17,7 @@ from authentik.core.api.utils import PassiveSerializer from authentik.core.models import Application, Source, User from authentik.flows.challenge import Challenge, ChallengeResponse, ChallengeTypes from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER -from authentik.flows.stage import ( - PLAN_CONTEXT_PENDING_USER_IDENTIFIER, - ChallengeStageView, -) +from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, ChallengeStageView from authentik.flows.views import SESSION_KEY_APPLICATION_PRE, challenge_types from authentik.stages.identification.models import IdentificationStage from authentik.stages.identification.signals import identification_failed @@ -82,9 +79,7 @@ class IdentificationChallengeResponse(ChallengeResponse): if not pre_user: sleep(0.150) LOGGER.debug("invalid_login", identifier=uid_field) - identification_failed.send( - sender=self, request=self.stage.request, uid_field=uid_field - ) + identification_failed.send(sender=self, request=self.stage.request, uid_field=uid_field) # We set the pending_user even on failure so it's part of the context, even # when the input is invalid # This is so its part of the current flow plan, and on flow restart can be kept, and @@ -94,9 +89,7 @@ class IdentificationChallengeResponse(ChallengeResponse): email=uid_field, ) if not current_stage.show_matched_user: - self.stage.executor.plan.context[ - PLAN_CONTEXT_PENDING_USER_IDENTIFIER - ] = uid_field + self.stage.executor.plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = uid_field raise ValidationError("Failed to authenticate.") self.pre_user = pre_user if not current_stage.password_stage: @@ -177,9 +170,7 @@ class IdentificationStageView(ChallengeStageView): # Check all enabled source, add them if they have a UI Login button. ui_sources = [] sources: list[Source] = ( - current_stage.sources.filter(enabled=True) - .order_by("name") - .select_subclasses() + current_stage.sources.filter(enabled=True).order_by("name").select_subclasses() ) for source in sources: ui_login_button = source.ui_login_button @@ -190,9 +181,7 @@ class IdentificationStageView(ChallengeStageView): challenge.initial_data["sources"] = ui_sources return challenge - def challenge_valid( - self, response: IdentificationChallengeResponse - ) -> HttpResponse: + def challenge_valid(self, response: IdentificationChallengeResponse) -> HttpResponse: self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = response.pre_user current_stage: IdentificationStage = self.executor.current_stage if not current_stage.show_matched_user: diff --git a/authentik/stages/identification/tests.py b/authentik/stages/identification/tests.py index 158db23ab..c9d5e267f 100644 --- a/authentik/stages/identification/tests.py +++ b/authentik/stages/identification/tests.py @@ -54,9 +54,7 @@ class TestIdentificationStage(TestCase): def test_valid_with_email(self): """Test with valid email, check that URL redirects back to itself""" form_data = {"uid_field": self.user.email} - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.post(url, form_data) self.assertEqual(response.status_code, 200) self.assertJSONEqual( @@ -70,15 +68,11 @@ class TestIdentificationStage(TestCase): def test_valid_with_password(self): """Test with valid email and password in single step""" - pw_stage = PasswordStage.objects.create( - name="password", backends=[BACKEND_DJANGO] - ) + pw_stage = PasswordStage.objects.create(name="password", backends=[BACKEND_DJANGO]) self.stage.password_stage = pw_stage self.stage.save() form_data = {"uid_field": self.user.email, "password": self.password} - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.post(url, form_data) self.assertEqual(response.status_code, 200) self.assertJSONEqual( @@ -92,18 +86,14 @@ class TestIdentificationStage(TestCase): def test_invalid_with_password(self): """Test with valid email and invalid password in single step""" - pw_stage = PasswordStage.objects.create( - name="password", backends=[BACKEND_DJANGO] - ) + pw_stage = PasswordStage.objects.create(name="password", backends=[BACKEND_DJANGO]) self.stage.password_stage = pw_stage self.stage.save() form_data = { "uid_field": self.user.email, "password": self.password + "test", } - url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.post(url, form_data) self.assertEqual(response.status_code, 200) self.assertJSONEqual( @@ -142,9 +132,7 @@ class TestIdentificationStage(TestCase): """Test invalid with username (user exists but stage only allows email)""" form_data = {"uid_field": self.user.username} response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), form_data, ) self.assertEqual(response.status_code, 200) @@ -153,9 +141,7 @@ class TestIdentificationStage(TestCase): """Test with invalid email (user doesn't exist) -> Will return to login form""" form_data = {"uid_field": self.user.email + "test"} response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), form_data, ) self.assertEqual(response.status_code, 200) @@ -177,9 +163,7 @@ class TestIdentificationStage(TestCase): ) response = self.client.get( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), ) self.assertEqual(response.status_code, 200) self.assertJSONEqual( @@ -228,9 +212,7 @@ class TestIdentificationStage(TestCase): order=0, ) response = self.client.get( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), ) self.assertEqual(response.status_code, 200) self.assertJSONEqual( diff --git a/authentik/stages/invitation/stage.py b/authentik/stages/invitation/stage.py index d48ee03a9..0042feb2a 100644 --- a/authentik/stages/invitation/stage.py +++ b/authentik/stages/invitation/stage.py @@ -27,9 +27,7 @@ class InvitationStageView(StageView): """Get token from saved get-arguments or prompt_data""" if INVITATION_TOKEN_KEY in self.request.session.get(SESSION_KEY_GET, {}): return self.request.session[SESSION_KEY_GET][INVITATION_TOKEN_KEY] - if INVITATION_TOKEN_KEY in self.executor.plan.context.get( - PLAN_CONTEXT_PROMPT, {} - ): + if INVITATION_TOKEN_KEY in self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}): return self.executor.plan.context[PLAN_CONTEXT_PROMPT][INVITATION_TOKEN_KEY] return None @@ -48,9 +46,7 @@ class InvitationStageView(StageView): self.executor.plan.context[INVITATION] = invite context = {} - always_merger.merge( - context, self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}) - ) + always_merger.merge(context, self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})) always_merger.merge(context, invite.fixed_data) self.executor.plan.context[PLAN_CONTEXT_PROMPT] = context diff --git a/authentik/stages/invitation/tests.py b/authentik/stages/invitation/tests.py index c82e2b449..45d0c97f5 100644 --- a/authentik/stages/invitation/tests.py +++ b/authentik/stages/invitation/tests.py @@ -35,9 +35,7 @@ class TestUserLoginStage(TestCase): designation=FlowDesignation.AUTHENTICATION, ) self.stage = InvitationStage.objects.create(name="invitation") - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) @patch( "authentik.flows.views.to_stage_response", @@ -45,9 +43,7 @@ class TestUserLoginStage(TestCase): ) def test_without_invitation_fail(self): """Test without any invitation, continue_flow_without_invitation not set.""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_AUTHENTICATION_BACKEND] = BACKEND_DJANGO session = self.client.session @@ -76,9 +72,7 @@ class TestUserLoginStage(TestCase): """Test without any invitation, continue_flow_without_invitation is set.""" self.stage.continue_flow_without_invitation = True self.stage.save() - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_AUTHENTICATION_BACKEND] = BACKEND_DJANGO session = self.client.session @@ -104,22 +98,16 @@ class TestUserLoginStage(TestCase): def test_with_invitation_get(self): """Test with invitation, check data in session""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() data = {"foo": "bar"} - invite = Invitation.objects.create( - created_by=get_anonymous_user(), fixed_data=data - ) + invite = Invitation.objects.create(created_by=get_anonymous_user(), fixed_data=data) with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()): - base_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + base_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) args = urlencode({INVITATION_TOKEN_KEY: invite.pk.hex}) response = self.client.get(base_url + f"?query={args}") @@ -144,18 +132,14 @@ class TestUserLoginStage(TestCase): created_by=get_anonymous_user(), fixed_data=data, single_use=True ) - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PROMPT] = {INVITATION_TOKEN_KEY: invite.pk.hex} session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() with patch("authentik.flows.views.FlowExecutorView.cancel", MagicMock()): - base_url = reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ) + base_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) response = self.client.get(base_url, follow=True) session = self.client.session diff --git a/authentik/stages/password/stage.py b/authentik/stages/password/stage.py index a548cc7be..f5cf96c22 100644 --- a/authentik/stages/password/stage.py +++ b/authentik/stages/password/stage.py @@ -91,9 +91,7 @@ class PasswordStageView(ChallengeStageView): "authentik_core:if-flow", kwargs={"flow_slug": recovery_flow.first().slug}, ) - challenge.initial_data["recovery_url"] = self.request.build_absolute_uri( - recover_url - ) + challenge.initial_data["recovery_url"] = self.request.build_absolute_uri(recover_url) return challenge def challenge_invalid(self, response: PasswordChallengeResponse) -> HttpResponse: @@ -122,9 +120,7 @@ class PasswordStageView(ChallengeStageView): "username": pending_user.username, } try: - user = authenticate( - self.request, self.executor.current_stage.backends, **auth_kwargs - ) + user = authenticate(self.request, self.executor.current_stage.backends, **auth_kwargs) except PermissionDenied: del auth_kwargs["password"] # User was found, but permission was denied (i.e. user is not active) @@ -142,13 +138,9 @@ class PasswordStageView(ChallengeStageView): LOGGER.debug("Invalid credentials") # Manually inject error into form response._errors.setdefault("password", []) - response._errors["password"].append( - ErrorDetail(_("Invalid password"), "invalid") - ) + response._errors["password"].append(ErrorDetail(_("Invalid password"), "invalid")) return self.challenge_invalid(response) # User instance returned from authenticate() has .backend property set self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = user - self.executor.plan.context[ - PLAN_CONTEXT_AUTHENTICATION_BACKEND - ] = user.backend + self.executor.plan.context[PLAN_CONTEXT_AUTHENTICATION_BACKEND] = user.backend return self.executor.stage_ok() diff --git a/authentik/stages/password/tests.py b/authentik/stages/password/tests.py index 5a653451f..39a191098 100644 --- a/authentik/stages/password/tests.py +++ b/authentik/stages/password/tests.py @@ -36,12 +36,8 @@ class TestPasswordStage(TestCase): slug="test-password", designation=FlowDesignation.AUTHENTICATION, ) - self.stage = PasswordStage.objects.create( - name="password", backends=[BACKEND_DJANGO] - ) - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.stage = PasswordStage.objects.create(name="password", backends=[BACKEND_DJANGO]) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) @patch( "authentik.flows.views.to_stage_response", @@ -49,17 +45,13 @@ class TestPasswordStage(TestCase): ) def test_without_user(self): """Test without user""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), # Still have to send the password so the form is valid {"password": self.password}, ) @@ -81,39 +73,29 @@ class TestPasswordStage(TestCase): def test_recovery_flow_link(self): """Test link to the default recovery flow""" - flow = Flow.objects.create( - designation=FlowDesignation.RECOVERY, slug="qewrqerqr" - ) + flow = Flow.objects.create(designation=FlowDesignation.RECOVERY, slug="qewrqerqr") - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() response = self.client.get( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), ) self.assertEqual(response.status_code, 200) self.assertIn(flow.slug, force_str(response.content)) def test_valid_password(self): """Test with a valid pending user and valid password""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), # Form data {"password": self.password}, ) @@ -130,18 +112,14 @@ class TestPasswordStage(TestCase): def test_invalid_password(self): """Test with a valid pending user and invalid password""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), # Form data {"password": self.password + "test"}, ) @@ -149,9 +127,7 @@ class TestPasswordStage(TestCase): def test_invalid_password_lockout(self): """Test with a valid pending user and invalid password (trigger logout counter)""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan @@ -169,9 +145,7 @@ class TestPasswordStage(TestCase): self.assertEqual(response.status_code, 200) response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), # Form data {"password": self.password + "test"}, ) @@ -190,18 +164,14 @@ class TestPasswordStage(TestCase): def test_permission_denied(self): """Test with a valid pending user and valid password. Backend is patched to return PermissionError""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() response = self.client.post( - reverse( - "authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug} - ), + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), # Form data {"password": self.password + "test"}, ) diff --git a/authentik/stages/prompt/models.py b/authentik/stages/prompt/models.py index 686931103..6f73d5d42 100644 --- a/authentik/stages/prompt/models.py +++ b/authentik/stages/prompt/models.py @@ -29,12 +29,7 @@ class FieldTypes(models.TextChoices): # Same as text, but has autocomplete for password managers USERNAME = ( "username", - _( - ( - "Username: Same as Text input, but checks for " - "and prevents duplicate usernames." - ) - ), + _(("Username: Same as Text input, but checks for " "and prevents duplicate usernames.")), ) EMAIL = "email", _("Email: Text field with Email type.") PASSWORD = ( diff --git a/authentik/stages/prompt/stage.py b/authentik/stages/prompt/stage.py index b9ab2d578..b5b2d978b 100644 --- a/authentik/stages/prompt/stage.py +++ b/authentik/stages/prompt/stage.py @@ -99,13 +99,9 @@ class PromptChallengeResponse(ChallengeResponse): attrs[static_hidden.field_key] = static_hidden.placeholder # Check if we have two password fields, and make sure they are the same - password_fields: QuerySet[Prompt] = self.stage.fields.filter( - type=FieldTypes.PASSWORD - ) + password_fields: QuerySet[Prompt] = self.stage.fields.filter(type=FieldTypes.PASSWORD) if password_fields.exists() and password_fields.count() == 2: - self._validate_password_fields( - *[field.field_key for field in password_fields] - ) + self._validate_password_fields(*[field.field_key for field in password_fields]) user = self.plan.context.get(PLAN_CONTEXT_PENDING_USER, get_anonymous_user()) engine = ListPolicyEngine(self.stage.validation_policies.all(), user) @@ -135,9 +131,7 @@ def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any] def password_single_clean(self: PromptChallenge, value: str) -> Any: """Send password validation signals for e.g. LDAP Source""" - password_validate.send( - sender=self, password=value, plan_context=self.plan.context - ) + password_validate.send(sender=self, password=value, plan_context=self.plan.context) return value return password_single_clean @@ -146,9 +140,7 @@ def password_single_validator_factory() -> Callable[[PromptChallenge, str], Any] class ListPolicyEngine(PolicyEngine): """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel""" - def __init__( - self, policies: list[Policy], user: User, request: HttpRequest = None - ) -> None: + def __init__(self, policies: list[Policy], user: User, request: HttpRequest = None) -> None: super().__init__(PolicyBindingModel(), user, request) self.__list = policies self.use_cache = False diff --git a/authentik/stages/prompt/tests.py b/authentik/stages/prompt/tests.py index cc33e2acd..644448ebd 100644 --- a/authentik/stages/prompt/tests.py +++ b/authentik/stages/prompt/tests.py @@ -110,15 +110,11 @@ class TestPromptStage(TestCase): static_prompt.field_key: static_prompt.placeholder, } - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_render(self): """Test render of form, check if all prompts are rendered correctly""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -134,13 +130,9 @@ class TestPromptStage(TestCase): def test_valid_challenge_with_policy(self) -> PromptChallengeResponse: """Test challenge_response validation""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) expr = "return request.context['password_prompt'] == request.context['password2_prompt']" - expr_policy = ExpressionPolicy.objects.create( - name="validate-form", expression=expr - ) + expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr) self.stage.validation_policies.set([expr_policy]) self.stage.save() challenge_response = PromptChallengeResponse( @@ -151,13 +143,9 @@ class TestPromptStage(TestCase): def test_invalid_challenge(self) -> PromptChallengeResponse: """Test challenge_response validation""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) expr = "False" - expr_policy = ExpressionPolicy.objects.create( - name="validate-form", expression=expr - ) + expr_policy = ExpressionPolicy.objects.create(name="validate-form", expression=expr) self.stage.validation_policies.set([expr_policy]) self.stage.save() challenge_response = PromptChallengeResponse( @@ -168,9 +156,7 @@ class TestPromptStage(TestCase): def test_valid_challenge_request(self): """Test a request with valid challenge_response data""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -205,9 +191,7 @@ class TestPromptStage(TestCase): def test_invalid_password(self): """Test challenge_response validation""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) self.prompt_data["password2_prompt"] = "qwerqwerqr" challenge_response = PromptChallengeResponse( None, stage=self.stage, plan=plan, data=self.prompt_data @@ -215,18 +199,12 @@ class TestPromptStage(TestCase): self.assertEqual(challenge_response.is_valid(), False) self.assertEqual( challenge_response.errors, - { - "non_field_errors": [ - ErrorDetail(string="Passwords don't match.", code="invalid") - ] - }, + {"non_field_errors": [ErrorDetail(string="Passwords don't match.", code="invalid")]}, ) def test_invalid_username(self): """Test challenge_response validation""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) self.prompt_data["username_prompt"] = "akadmin" challenge_response = PromptChallengeResponse( None, stage=self.stage, plan=plan, data=self.prompt_data @@ -234,18 +212,12 @@ class TestPromptStage(TestCase): self.assertEqual(challenge_response.is_valid(), False) self.assertEqual( challenge_response.errors, - { - "username_prompt": [ - ErrorDetail(string="Username is already taken.", code="invalid") - ] - }, + {"username_prompt": [ErrorDetail(string="Username is already taken.", code="invalid")]}, ) def test_static_hidden_overwrite(self): """Test that static and hidden fields ignore any value sent to them""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) self.prompt_data["hidden_prompt"] = "foo" self.prompt_data["static_prompt"] = "foo" challenge_response = PromptChallengeResponse( diff --git a/authentik/stages/user_delete/tests.py b/authentik/stages/user_delete/tests.py index e1e357d61..544121036 100644 --- a/authentik/stages/user_delete/tests.py +++ b/authentik/stages/user_delete/tests.py @@ -30,9 +30,7 @@ class TestUserDeleteStage(TestCase): designation=FlowDesignation.AUTHENTICATION, ) self.stage = UserDeleteStage.objects.create(name="delete") - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) @patch( "authentik.flows.views.to_stage_response", @@ -40,9 +38,7 @@ class TestUserDeleteStage(TestCase): ) def test_no_user(self): """Test without user set""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -67,9 +63,7 @@ class TestUserDeleteStage(TestCase): def test_user_delete_get(self): """Test Form render""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan diff --git a/authentik/stages/user_login/tests.py b/authentik/stages/user_login/tests.py index ebcb90569..cc31133b8 100644 --- a/authentik/stages/user_login/tests.py +++ b/authentik/stages/user_login/tests.py @@ -30,15 +30,11 @@ class TestUserLoginStage(TestCase): designation=FlowDesignation.AUTHENTICATION, ) self.stage = UserLoginStage.objects.create(name="login") - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_valid_password(self): """Test with a valid pending user and backend""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan @@ -62,9 +58,7 @@ class TestUserLoginStage(TestCase): """Test with expiry""" self.stage.session_duration = "seconds=2" self.stage.save() - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user session = self.client.session session[SESSION_KEY_PLAN] = plan @@ -93,9 +87,7 @@ class TestUserLoginStage(TestCase): ) def test_without_user(self): """Test a plan without any pending user, resulting in a denied""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() diff --git a/authentik/stages/user_logout/tests.py b/authentik/stages/user_logout/tests.py index 2472fc7cc..2958e3f97 100644 --- a/authentik/stages/user_logout/tests.py +++ b/authentik/stages/user_logout/tests.py @@ -28,15 +28,11 @@ class TestUserLogoutStage(TestCase): designation=FlowDesignation.AUTHENTICATION, ) self.stage = UserLogoutStage.objects.create(name="logout") - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) def test_valid_password(self): """Test with a valid pending user and backend""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_AUTHENTICATION_BACKEND] = BACKEND_DJANGO session = self.client.session diff --git a/authentik/stages/user_write/stage.py b/authentik/stages/user_write/stage.py index 0e726e05f..ea86fe9a3 100644 --- a/authentik/stages/user_write/stage.py +++ b/authentik/stages/user_write/stage.py @@ -42,9 +42,9 @@ class UserWriteStageView(StageView): self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = User( is_active=not self.executor.current_stage.create_users_as_inactive ) - self.executor.plan.context[ - PLAN_CONTEXT_AUTHENTICATION_BACKEND - ] = class_to_path(ModelBackend) + self.executor.plan.context[PLAN_CONTEXT_AUTHENTICATION_BACKEND] = class_to_path( + ModelBackend + ) LOGGER.debug( "Created new user", flow_slug=self.executor.flow.slug, @@ -98,9 +98,7 @@ class UserWriteStageView(StageView): except IntegrityError as exc: LOGGER.warning("Failed to save user", exc=exc) return self.executor.stage_invalid() - user_write.send( - sender=self, request=request, user=user, data=data, created=user_created - ) + user_write.send(sender=self, request=request, user=user, data=data, created=user_created) # Check if the password has been updated, and update the session auth hash if should_update_seesion: update_session_auth_hash(self.request, user) diff --git a/authentik/stages/user_write/tests.py b/authentik/stages/user_write/tests.py index bd8fea12c..eb9ac2454 100644 --- a/authentik/stages/user_write/tests.py +++ b/authentik/stages/user_write/tests.py @@ -7,12 +7,7 @@ from django.test import Client, TestCase from django.urls import reverse from django.utils.encoding import force_str -from authentik.core.models import ( - USER_ATTRIBUTE_SOURCES, - Source, - User, - UserSourceConnection, -) +from authentik.core.models import USER_ATTRIBUTE_SOURCES, Source, User, UserSourceConnection from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION from authentik.flows.challenge import ChallengeTypes from authentik.flows.markers import StageMarker @@ -37,30 +32,23 @@ class TestUserWriteStage(TestCase): designation=FlowDesignation.AUTHENTICATION, ) self.stage = UserWriteStage.objects.create(name="write") - self.binding = FlowStageBinding.objects.create( - target=self.flow, stage=self.stage, order=2 - ) + self.binding = FlowStageBinding.objects.create(target=self.flow, stage=self.stage, order=2) self.source = Source.objects.create(name="fake_source") def test_user_create(self): """Test creation of user""" password = "".join( - SystemRandom().choice(string.ascii_uppercase + string.digits) - for _ in range(8) + SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(8) ) - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PROMPT] = { "username": "test-user", "name": "name", "email": "test@beryju.org", "password": password, } - plan.context[PLAN_CONTEXT_SOURCES_CONNECTION] = UserSourceConnection( - source=self.source - ) + plan.context[PLAN_CONTEXT_SOURCES_CONNECTION] = UserSourceConnection(source=self.source) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -78,24 +66,17 @@ class TestUserWriteStage(TestCase): "type": ChallengeTypes.REDIRECT.value, }, ) - user_qs = User.objects.filter( - username=plan.context[PLAN_CONTEXT_PROMPT]["username"] - ) + user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"]) self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.first().check_password(password)) - self.assertEqual( - user_qs.first().attributes, {USER_ATTRIBUTE_SOURCES: [self.source.name]} - ) + self.assertEqual(user_qs.first().attributes, {USER_ATTRIBUTE_SOURCES: [self.source.name]}) def test_user_update(self): """Test update of existing user""" new_password = "".join( - SystemRandom().choice(string.ascii_uppercase + string.digits) - for _ in range(8) - ) - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] + SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(8) ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan.context[PLAN_CONTEXT_PENDING_USER] = User.objects.create( username="unittest", email="test@beryju.org" ) @@ -122,9 +103,7 @@ class TestUserWriteStage(TestCase): "type": ChallengeTypes.REDIRECT.value, }, ) - user_qs = User.objects.filter( - username=plan.context[PLAN_CONTEXT_PROMPT]["username"] - ) + user_qs = User.objects.filter(username=plan.context[PLAN_CONTEXT_PROMPT]["username"]) self.assertTrue(user_qs.exists()) self.assertTrue(user_qs.first().check_password(new_password)) self.assertEqual(user_qs.first().attributes["some-custom-attribute"], "test") @@ -136,9 +115,7 @@ class TestUserWriteStage(TestCase): ) def test_without_data(self): """Test without data results in error""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session session[SESSION_KEY_PLAN] = plan session.save() @@ -168,9 +145,7 @@ class TestUserWriteStage(TestCase): ) def test_blank_username(self): """Test with blank username results in error""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session plan.context[PLAN_CONTEXT_PROMPT] = { "username": "", @@ -205,9 +180,7 @@ class TestUserWriteStage(TestCase): ) def test_duplicate_data(self): """Test with duplicate data, should trigger error""" - plan = FlowPlan( - flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()] - ) + plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) session = self.client.session plan.context[PLAN_CONTEXT_PROMPT] = { "username": "akadmin", diff --git a/authentik/tenants/migrations/0001_initial.py b/authentik/tenants/migrations/0001_initial.py index c2dee9a05..354290149 100644 --- a/authentik/tenants/migrations/0001_initial.py +++ b/authentik/tenants/migrations/0001_initial.py @@ -37,9 +37,7 @@ class Migration(migrations.Migration): ("branding_title", models.TextField(default="authentik")), ( "branding_logo", - models.TextField( - default="/static/dist/assets/icons/icon_left_brand.svg" - ), + models.TextField(default="/static/dist/assets/icons/icon_left_brand.svg"), ), ( "flow_authentication", diff --git a/authentik/tenants/models.py b/authentik/tenants/models.py index 7be1c2f84..539cbc8f8 100644 --- a/authentik/tenants/models.py +++ b/authentik/tenants/models.py @@ -24,9 +24,7 @@ class Tenant(models.Model): branding_title = models.TextField(default="authentik") - branding_logo = models.TextField( - default="/static/dist/assets/icons/icon_left_brand.svg" - ) + branding_logo = models.TextField(default="/static/dist/assets/icons/icon_left_brand.svg") branding_favicon = models.TextField(default="/static/dist/assets/icons/icon.png") flow_authentication = models.ForeignKey( diff --git a/authentik/tenants/tests.py b/authentik/tenants/tests.py index 383006e37..bc2d58b07 100644 --- a/authentik/tenants/tests.py +++ b/authentik/tenants/tests.py @@ -72,12 +72,8 @@ class TestTenants(TestCase): factory = RequestFactory() request = factory.get("/") request.tenant = tenant - event = Event.new( - action=EventAction.SYSTEM_EXCEPTION, message="test" - ).from_http(request) - self.assertEqual( - event.expires.day, (event.created + timedelta_from_string("weeks=3")).day - ) + event = Event.new(action=EventAction.SYSTEM_EXCEPTION, message="test").from_http(request) + self.assertEqual(event.expires.day, (event.created + timedelta_from_string("weeks=3")).day) self.assertEqual( event.expires.month, (event.created + timedelta_from_string("weeks=3")).month, diff --git a/lifecycle/migrate.py b/lifecycle/migrate.py index 6d22a1df4..b21f64fe8 100755 --- a/lifecycle/migrate.py +++ b/lifecycle/migrate.py @@ -47,9 +47,7 @@ if __name__ == "__main__": LOGGER.info("waiting to acquire database lock") curr.execute("SELECT pg_advisory_lock(%s)", (ADV_LOCK_UID,)) try: - for migration in ( - Path(__file__).parent.absolute().glob("system_migrations/*.py") - ): + for migration in Path(__file__).parent.absolute().glob("system_migrations/*.py"): spec = spec_from_file_location("lifecycle.system_migrations", migration) mod = module_from_spec(spec) # pyright: reportGeneralTypeIssues=false diff --git a/lifecycle/system_migrations/to_0_10.py b/lifecycle/system_migrations/to_0_10.py index 7ea1e5d10..77ff3f69c 100644 --- a/lifecycle/system_migrations/to_0_10.py +++ b/lifecycle/system_migrations/to_0_10.py @@ -46,8 +46,6 @@ class Migration(BaseMigration): self.system_crit("./manage.py migrate passbook_stages_prompt") self.system_crit("./manage.py migrate passbook_flows 0008_default_flows --fake") self.system_crit("./manage.py migrate passbook_flows 0009_source_flows --fake") - self.system_crit( - "./manage.py migrate passbook_flows 0010_provider_flows --fake" - ) + self.system_crit("./manage.py migrate passbook_flows 0010_provider_flows --fake") self.system_crit("./manage.py migrate passbook_flows") self.system_crit("./manage.py migrate passbook_stages_password --fake") diff --git a/pyproject.toml b/pyproject.toml index f3601d61d..63fe19cf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,5 @@ [tool.black] +line-length = 100 target-version = ['py38'] exclude = 'node_modules' @@ -7,7 +8,7 @@ multi_line_output = 3 include_trailing_comma = true force_grid_wrap = 0 use_parentheses = true -line_length = 88 +line_length = 100 src_paths = ["authentik", "tests", "lifecycle"] force_to_top = "*" diff --git a/tests/e2e/test_flows_authenticators.py b/tests/e2e/test_flows_authenticators.py index 6b62cdd8f..fc6119233 100644 --- a/tests/e2e/test_flows_authenticators.py +++ b/tests/e2e/test_flows_authenticators.py @@ -47,19 +47,11 @@ class TestFlowsAuthenticator(SeleniumTestCase): totp = TOTP(device.bin_key, device.step, device.t0, device.digits, device.drift) flow_executor = self.get_shadow_root("ak-flow-executor") - validation_stage = self.get_shadow_root( - "ak-stage-authenticator-validate", flow_executor - ) - code_stage = self.get_shadow_root( - "ak-stage-authenticator-validate-code", validation_stage - ) + validation_stage = self.get_shadow_root("ak-stage-authenticator-validate", flow_executor) + code_stage = self.get_shadow_root("ak-stage-authenticator-validate-code", validation_stage) - code_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys( - totp.token() - ) - code_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys( - Keys.ENTER - ) + code_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys(totp.token()) + code_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys(Keys.ENTER) self.wait_for_url(self.if_admin_url("/library")) self.assert_user(USER()) @@ -89,12 +81,10 @@ class TestFlowsAuthenticator(SeleniumTestCase): totp_stage = self.get_shadow_root("ak-stage-authenticator-totp", flow_executor) wait = WebDriverWait(totp_stage, self.wait_timeout) - wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=otp_uri]")) + wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=otp_uri]"))) + otp_uri = totp_stage.find_element(By.CSS_SELECTOR, "input[name=otp_uri]").get_attribute( + "value" ) - otp_uri = totp_stage.find_element( - By.CSS_SELECTOR, "input[name=otp_uri]" - ).get_attribute("value") # Parse the OTP URI, extract the secret and get the next token otp_args = urlparse(otp_uri) @@ -104,12 +94,8 @@ class TestFlowsAuthenticator(SeleniumTestCase): totp = TOTP(secret_key) - totp_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys( - totp.token() - ) - totp_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys( - Keys.ENTER - ) + totp_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys(totp.token()) + totp_stage.find_element(By.CSS_SELECTOR, "input[name=code]").send_keys(Keys.ENTER) sleep(3) self.assertTrue(TOTPDevice.objects.filter(user=USER(), confirmed=True).exists()) @@ -140,9 +126,7 @@ class TestFlowsAuthenticator(SeleniumTestCase): destination_url = self.driver.current_url flow_executor = self.get_shadow_root("ak-flow-executor") - authenticator_stage = self.get_shadow_root( - "ak-stage-authenticator-static", flow_executor - ) + authenticator_stage = self.get_shadow_root("ak-stage-authenticator-static", flow_executor) token = authenticator_stage.find_element( By.CSS_SELECTOR, ".ak-otp-tokens li:nth-child(1)" ).text @@ -152,8 +136,6 @@ class TestFlowsAuthenticator(SeleniumTestCase): self.wait_for_url(destination_url) sleep(1) - self.assertTrue( - StaticDevice.objects.filter(user=USER(), confirmed=True).exists() - ) + self.assertTrue(StaticDevice.objects.filter(user=USER(), confirmed=True).exists()) device = StaticDevice.objects.filter(user=USER(), confirmed=True).first() self.assertTrue(StaticToken.objects.filter(token=token, device=device).exists()) diff --git a/tests/e2e/test_flows_enroll.py b/tests/e2e/test_flows_enroll.py index 82a22d912..31f27dac7 100644 --- a/tests/e2e/test_flows_enroll.py +++ b/tests/e2e/test_flows_enroll.py @@ -186,9 +186,7 @@ class TestFlowsEnroll(SeleniumTestCase): self.driver.get("http://localhost:8025") # Click on first message - self.wait.until( - ec.presence_of_element_located((By.CLASS_NAME, "msglist-message")) - ) + self.wait.until(ec.presence_of_element_located((By.CLASS_NAME, "msglist-message"))) self.driver.find_element(By.CLASS_NAME, "msglist-message").click() self.driver.switch_to.frame(self.driver.find_element(By.CLASS_NAME, "tab-pane")) self.driver.find_element(By.ID, "confirm").click() @@ -197,9 +195,7 @@ class TestFlowsEnroll(SeleniumTestCase): sleep(2) # We're now logged in - wait = WebDriverWait( - self.get_shadow_root("ak-interface-admin"), self.wait_timeout - ) + wait = WebDriverWait(self.get_shadow_root("ak-interface-admin"), self.wait_timeout) wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "ak-sidebar"))) self.driver.get(self.if_admin_url("/user")) @@ -210,9 +206,7 @@ class TestFlowsEnroll(SeleniumTestCase): """Fill out initial stages""" # Identification stage, click enroll flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll"))) @@ -223,18 +217,14 @@ class TestFlowsEnroll(SeleniumTestCase): prompt_stage = self.get_shadow_root("ak-stage-prompt", flow_executor) wait = WebDriverWait(prompt_stage, self.wait_timeout) - wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=username]")) - ) - prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").send_keys( - "foo" - ) + wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=username]"))) + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").send_keys("foo") prompt_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( USER().username ) - prompt_stage.find_element( - By.CSS_SELECTOR, "input[name=password_repeat]" - ).send_keys(USER().username) + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=password_repeat]").send_keys( + USER().username + ) prompt_stage.find_element(By.CSS_SELECTOR, ".pf-c-button").click() # Second prompt stage @@ -242,13 +232,7 @@ class TestFlowsEnroll(SeleniumTestCase): prompt_stage = self.get_shadow_root("ak-stage-prompt", flow_executor) wait = WebDriverWait(prompt_stage, self.wait_timeout) - wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=name]")) - ) - prompt_stage.find_element(By.CSS_SELECTOR, "input[name=name]").send_keys( - "some name" - ) - prompt_stage.find_element(By.CSS_SELECTOR, "input[name=email]").send_keys( - "foo@bar.baz" - ) + wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=name]"))) + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=name]").send_keys("some name") + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=email]").send_keys("foo@bar.baz") prompt_stage.find_element(By.CSS_SELECTOR, ".pf-c-button").click() diff --git a/tests/e2e/test_flows_stage_setup.py b/tests/e2e/test_flows_stage_setup.py index 7e2ccab44..93b77733a 100644 --- a/tests/e2e/test_flows_stage_setup.py +++ b/tests/e2e/test_flows_stage_setup.py @@ -54,15 +54,13 @@ class TestFlowsStageSetup(SeleniumTestCase): flow_executor = self.get_shadow_root("ak-flow-executor") prompt_stage = self.get_shadow_root("ak-stage-prompt", flow_executor) - prompt_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys(new_password) + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=password_repeat]").send_keys( new_password ) - prompt_stage.find_element( - By.CSS_SELECTOR, "input[name=password_repeat]" - ).send_keys(new_password) - prompt_stage.find_element( - By.CSS_SELECTOR, "input[name=password_repeat]" - ).send_keys(Keys.ENTER) + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=password_repeat]").send_keys( + Keys.ENTER + ) self.wait_for_url(self.if_admin_url("/library")) # Because USER() is cached, we need to get the user manually here diff --git a/tests/e2e/test_provider_ldap.py b/tests/e2e/test_provider_ldap.py index dbba8dcf1..d3c66e9ef 100644 --- a/tests/e2e/test_provider_ldap.py +++ b/tests/e2e/test_provider_ldap.py @@ -6,14 +6,7 @@ from unittest.case import skipUnless from docker.client import DockerClient, from_env from docker.models.containers import Container from guardian.shortcuts import get_anonymous_user -from ldap3 import ( - ALL, - ALL_ATTRIBUTES, - ALL_OPERATIONAL_ATTRIBUTES, - SUBTREE, - Connection, - Server, -) +from ldap3 import ALL, ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES, SUBTREE, Connection, Server from ldap3.core.exceptions import LDAPInvalidCredentialsResult from authentik.core.models import Application, Group, User @@ -22,13 +15,7 @@ from authentik.flows.models import Flow from authentik.outposts.managed import MANAGED_OUTPOST from authentik.outposts.models import Outpost, OutpostType from authentik.providers.ldap.models import LDAPProvider -from tests.e2e.utils import ( - USER, - SeleniumTestCase, - apply_migration, - object_manager, - retry, -) +from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, object_manager, retry @skipUnless(platform.startswith("linux"), "requires local docker") @@ -276,9 +263,7 @@ class TestProviderLDAP(SeleniumTestCase): ], "uidNumber": [str(2000 + USER().pk)], "gidNumber": [str(2000 + USER().pk)], - "memberOf": [ - "cn=authentik Admins,ou=groups,dc=ldap,dc=goauthentik,dc=io" - ], + "memberOf": ["cn=authentik Admins,ou=groups,dc=ldap,dc=goauthentik,dc=io"], "accountStatus": ["true"], "superuser": ["true"], "goauthentik.io/ldap/active": ["true"], diff --git a/tests/e2e/test_provider_oauth2_github.py b/tests/e2e/test_provider_oauth2_github.py index 17b7caaf5..bd0ef5d13 100644 --- a/tests/e2e/test_provider_oauth2_github.py +++ b/tests/e2e/test_provider_oauth2_github.py @@ -12,10 +12,7 @@ from authentik.core.models import Application from authentik.flows.models import Flow from authentik.policies.expression.models import ExpressionPolicy from authentik.policies.models import PolicyBinding -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, retry @@ -53,9 +50,7 @@ class TestProviderOAuth2Github(SeleniumTestCase): "GF_AUTH_GITHUB_TOKEN_URL": self.url( "authentik_providers_oauth2_github:github-access-token" ), - "GF_AUTH_GITHUB_API_URL": self.url( - "authentik_providers_oauth2_github:github-user" - ), + "GF_AUTH_GITHUB_API_URL": self.url("authentik_providers_oauth2_github:github-user"), "GF_LOG_LEVEL": "debug", }, } @@ -97,21 +92,15 @@ class TestProviderOAuth2Github(SeleniumTestCase): USER().username, ) self.assertEqual( - self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute( - "value" - ), + self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute("value"), USER().username, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=email]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=email]").get_attribute("value"), USER().email, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=login]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=login]").get_attribute("value"), USER().username, ) @@ -146,9 +135,7 @@ class TestProviderOAuth2Github(SeleniumTestCase): self.login() sleep(3) - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor"))) flow_executor = self.get_shadow_root("ak-flow-executor") consent_stage = self.get_shadow_root("ak-stage-consent", flow_executor) @@ -159,9 +146,7 @@ class TestProviderOAuth2Github(SeleniumTestCase): ) self.assertEqual( "GitHub Compatibility: Access you Email addresses", - consent_stage.find_element( - By.CSS_SELECTOR, "[data-permission-code='user:email']" - ).text, + consent_stage.find_element(By.CSS_SELECTOR, "[data-permission-code='user:email']").text, ) consent_stage.find_element( By.CSS_SELECTOR, @@ -175,21 +160,15 @@ class TestProviderOAuth2Github(SeleniumTestCase): USER().username, ) self.assertEqual( - self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute( - "value" - ), + self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute("value"), USER().username, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=email]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=email]").get_attribute("value"), USER().email, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=login]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=login]").get_attribute("value"), USER().username, ) @@ -228,9 +207,7 @@ class TestProviderOAuth2Github(SeleniumTestCase): self.driver.find_element(By.CLASS_NAME, "btn-service--github").click() self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1"))) self.assertEqual( self.driver.find_element(By.CSS_SELECTOR, "header > h1").text, "Permission denied", diff --git a/tests/e2e/test_provider_oauth2_grafana.py b/tests/e2e/test_provider_oauth2_grafana.py index a234a065e..289441b31 100644 --- a/tests/e2e/test_provider_oauth2_grafana.py +++ b/tests/e2e/test_provider_oauth2_grafana.py @@ -19,18 +19,9 @@ from authentik.providers.oauth2.constants import ( SCOPE_OPENID_EMAIL, SCOPE_OPENID_PROFILE, ) -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, ScopeMapping -from tests.e2e.utils import ( - USER, - SeleniumTestCase, - apply_migration, - object_manager, - retry, -) +from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, object_manager, retry LOGGER = get_logger() APPLICATION_SLUG = "grafana" @@ -64,12 +55,8 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): "GF_AUTH_GENERIC_OAUTH_AUTH_URL": ( self.url("authentik_providers_oauth2:authorize") ), - "GF_AUTH_GENERIC_OAUTH_TOKEN_URL": ( - self.url("authentik_providers_oauth2:token") - ), - "GF_AUTH_GENERIC_OAUTH_API_URL": ( - self.url("authentik_providers_oauth2:userinfo") - ), + "GF_AUTH_GENERIC_OAUTH_TOKEN_URL": (self.url("authentik_providers_oauth2:token")), + "GF_AUTH_GENERIC_OAUTH_API_URL": (self.url("authentik_providers_oauth2:userinfo")), "GF_AUTH_SIGNOUT_REDIRECT_URL": ( self.url( "authentik_core:if-session-end", @@ -167,21 +154,15 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): USER().name, ) self.assertEqual( - self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute( - "value" - ), + self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute("value"), USER().name, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=email]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=email]").get_attribute("value"), USER().email, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=login]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=login]").get_attribute("value"), USER().email, ) @@ -230,21 +211,15 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): USER().name, ) self.assertEqual( - self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute( - "value" - ), + self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute("value"), USER().name, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=email]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=email]").get_attribute("value"), USER().email, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=login]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=login]").get_attribute("value"), USER().email, ) self.driver.get("http://localhost:3000/logout") @@ -295,9 +270,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): self.driver.find_element(By.CLASS_NAME, "btn-service--oauth").click() self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor"))) sleep(1) flow_executor = self.get_shadow_root("ak-flow-executor") @@ -320,21 +293,15 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): USER().name, ) self.assertEqual( - self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute( - "value" - ), + self.driver.find_element(By.CSS_SELECTOR, "input[name=name]").get_attribute("value"), USER().name, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=email]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=email]").get_attribute("value"), USER().email, ) self.assertEqual( - self.driver.find_element( - By.CSS_SELECTOR, "input[name=login]" - ).get_attribute("value"), + self.driver.find_element(By.CSS_SELECTOR, "input[name=login]").get_attribute("value"), USER().email, ) @@ -380,9 +347,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase): self.driver.find_element(By.CLASS_NAME, "btn-service--oauth").click() self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1"))) self.assertEqual( self.driver.find_element(By.CSS_SELECTOR, "header > h1").text, "Permission denied", diff --git a/tests/e2e/test_provider_oauth2_oidc.py b/tests/e2e/test_provider_oauth2_oidc.py index 0ddf828ba..1c8dac26a 100644 --- a/tests/e2e/test_provider_oauth2_oidc.py +++ b/tests/e2e/test_provider_oauth2_oidc.py @@ -21,18 +21,9 @@ from authentik.providers.oauth2.constants import ( SCOPE_OPENID_EMAIL, SCOPE_OPENID_PROFILE, ) -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, ScopeMapping -from tests.e2e.utils import ( - USER, - SeleniumTestCase, - apply_migration, - object_manager, - retry, -) +from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, object_manager, retry LOGGER = get_logger() @@ -206,9 +197,7 @@ class TestProviderOAuth2OIDC(SeleniumTestCase): self.driver.get("http://localhost:9009") self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor"))) flow_executor = self.get_shadow_root("ak-flow-executor") consent_stage = self.get_shadow_root("ak-stage-consent", flow_executor) @@ -276,9 +265,7 @@ class TestProviderOAuth2OIDC(SeleniumTestCase): self.container = self.setup_client() self.driver.get("http://localhost:9009") self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1"))) self.assertEqual( self.driver.find_element(By.CSS_SELECTOR, "header > h1").text, "Permission denied", diff --git a/tests/e2e/test_provider_oauth2_oidc_implicit.py b/tests/e2e/test_provider_oauth2_oidc_implicit.py index 2c555bddc..4251817b0 100644 --- a/tests/e2e/test_provider_oauth2_oidc_implicit.py +++ b/tests/e2e/test_provider_oauth2_oidc_implicit.py @@ -21,18 +21,9 @@ from authentik.providers.oauth2.constants import ( SCOPE_OPENID_EMAIL, SCOPE_OPENID_PROFILE, ) -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, ScopeMapping -from tests.e2e.utils import ( - USER, - SeleniumTestCase, - apply_migration, - object_manager, - retry, -) +from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, object_manager, retry LOGGER = get_logger() @@ -204,9 +195,7 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase): sleep(2) self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor"))) flow_executor = self.get_shadow_root("ak-flow-executor") consent_stage = self.get_shadow_root("ak-stage-consent", flow_executor) @@ -271,9 +260,7 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase): self.driver.get("http://localhost:9009/implicit/") sleep(2) self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1"))) self.assertEqual( self.driver.find_element(By.CSS_SELECTOR, "header > h1").text, "Permission denied", diff --git a/tests/e2e/test_provider_proxy.py b/tests/e2e/test_provider_proxy.py index 0242b3430..0a7559840 100644 --- a/tests/e2e/test_provider_proxy.py +++ b/tests/e2e/test_provider_proxy.py @@ -13,21 +13,10 @@ from selenium.webdriver.common.by import By from authentik import __version__ from authentik.core.models import Application from authentik.flows.models import Flow -from authentik.outposts.models import ( - DockerServiceConnection, - Outpost, - OutpostConfig, - OutpostType, -) +from authentik.outposts.models import DockerServiceConnection, Outpost, OutpostConfig, OutpostType from authentik.outposts.tasks import outpost_local_connection from authentik.providers.proxy.models import ProxyProvider -from tests.e2e.utils import ( - USER, - SeleniumTestCase, - apply_migration, - object_manager, - retry, -) +from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, object_manager, retry @skipUnless(platform.startswith("linux"), "requires local docker") @@ -121,9 +110,7 @@ class TestProviderProxy(SeleniumTestCase): self.driver.get("http://localhost:4180/akprox/sign_out") sleep(2) - full_body_text = self.driver.find_element( - By.CSS_SELECTOR, ".pf-c-title.pf-m-3xl" - ).text + full_body_text = self.driver.find_element(By.CSS_SELECTOR, ".pf-c-title.pf-m-3xl").text self.assertIn("You've logged out of proxy.", full_body_text) @@ -159,9 +146,7 @@ class TestProviderProxyConnect(ChannelsLiveServerTestCase): name="proxy_outpost", type=OutpostType.PROXY, service_connection=service_connection, - _config=asdict( - OutpostConfig(authentik_host=self.live_server_url, log_level="debug") - ), + _config=asdict(OutpostConfig(authentik_host=self.live_server_url, log_level="debug")), ) outpost.providers.add(proxy) outpost.save() diff --git a/tests/e2e/test_provider_saml.py b/tests/e2e/test_provider_saml.py index f841b268d..afbb192e9 100644 --- a/tests/e2e/test_provider_saml.py +++ b/tests/e2e/test_provider_saml.py @@ -16,18 +16,8 @@ from authentik.crypto.models import CertificateKeyPair from authentik.flows.models import Flow from authentik.policies.expression.models import ExpressionPolicy from authentik.policies.models import PolicyBinding -from authentik.providers.saml.models import ( - SAMLBindings, - SAMLPropertyMapping, - SAMLProvider, -) -from tests.e2e.utils import ( - USER, - SeleniumTestCase, - apply_migration, - object_manager, - retry, -) +from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider +from tests.e2e.utils import USER, SeleniumTestCase, apply_migration, object_manager, retry LOGGER = get_logger() @@ -126,9 +116,7 @@ class TestProviderSAML(SeleniumTestCase): [str(USER().pk)], ) self.assertEqual( - body["attr"][ - "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress" - ], + body["attr"]["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"], [USER().email], ) self.assertEqual( @@ -169,9 +157,7 @@ class TestProviderSAML(SeleniumTestCase): self.driver.get("http://localhost:9009") self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "ak-flow-executor"))) flow_executor = self.get_shadow_root("ak-flow-executor") consent_stage = self.get_shadow_root("ak-stage-consent", flow_executor) @@ -208,9 +194,7 @@ class TestProviderSAML(SeleniumTestCase): [str(USER().pk)], ) self.assertEqual( - body["attr"][ - "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress" - ], + body["attr"]["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"], [USER().email], ) self.assertEqual( @@ -279,9 +263,7 @@ class TestProviderSAML(SeleniumTestCase): [str(USER().pk)], ) self.assertEqual( - body["attr"][ - "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress" - ], + body["attr"]["http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress"], [USER().email], ) self.assertEqual( @@ -326,9 +308,7 @@ class TestProviderSAML(SeleniumTestCase): self.driver.get("http://localhost:9009/") self.login() - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "header > h1"))) self.assertEqual( self.driver.find_element(By.CSS_SELECTOR, "header > h1").text, "Permission denied", diff --git a/tests/e2e/test_source_oauth.py b/tests/e2e/test_source_oauth.py index c8a73dac5..e32b8853b 100644 --- a/tests/e2e/test_source_oauth.py +++ b/tests/e2e/test_source_oauth.py @@ -18,10 +18,7 @@ from yaml import safe_dump from authentik.core.models import User from authentik.flows.models import Flow -from authentik.providers.oauth2.generators import ( - generate_client_id, - generate_client_secret, -) +from authentik.providers.oauth2.generators import generate_client_id, generate_client_secret from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.manager import SourceType from authentik.sources.oauth.types.twitter import TwitterOAuthCallback @@ -145,9 +142,7 @@ class TestSourceOAuth2(SeleniumTestCase): self.driver.get(self.live_server_url) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( @@ -166,9 +161,7 @@ class TestSourceOAuth2(SeleniumTestCase): self.driver.find_element(By.ID, "password").send_keys(Keys.ENTER) # Wait until we're logged in - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "button[type=submit]")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "button[type=submit]"))) self.driver.find_element(By.CSS_SELECTOR, "button[type=submit]").click() # At this point we've been redirected back @@ -177,12 +170,8 @@ class TestSourceOAuth2(SeleniumTestCase): prompt_stage = self.get_shadow_root("ak-stage-prompt", flow_executor) prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").click() - prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").send_keys( - "foo" - ) - prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").send_keys( - Keys.ENTER - ) + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").send_keys("foo") + prompt_stage.find_element(By.CSS_SELECTOR, "input[name=username]").send_keys(Keys.ENTER) # Wait until we've logged in self.wait_for_url(self.if_admin_url("/library")) @@ -205,9 +194,7 @@ class TestSourceOAuth2(SeleniumTestCase): self.driver.get(self.live_server_url) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( @@ -226,9 +213,7 @@ class TestSourceOAuth2(SeleniumTestCase): self.driver.find_element(By.ID, "password").send_keys(Keys.ENTER) # Wait until we're logged in - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "button[type=submit]")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "button[type=submit]"))) self.driver.find_element(By.CSS_SELECTOR, "button[type=submit]").click() @retry() @@ -245,9 +230,7 @@ class TestSourceOAuth2(SeleniumTestCase): self.driver.get(self.url("authentik_flows:default-invalidation")) sleep(1) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( @@ -266,9 +249,7 @@ class TestSourceOAuth2(SeleniumTestCase): self.driver.find_element(By.ID, "password").send_keys(Keys.ENTER) # Wait until we're logged in - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "button[type=submit]")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "button[type=submit]"))) self.driver.find_element(By.CSS_SELECTOR, "button[type=submit]").click() # Wait until we've logged in @@ -342,9 +323,7 @@ class TestSourceOAuth1(SeleniumTestCase): self.driver.get(self.live_server_url) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( @@ -363,9 +342,7 @@ class TestSourceOAuth1(SeleniumTestCase): sleep(2) # Wait until we're logged in - self.wait.until( - ec.presence_of_element_located((By.CSS_SELECTOR, "[name='confirm']")) - ) + self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "[name='confirm']"))) self.driver.find_element(By.CSS_SELECTOR, "[name='confirm']").click() # Wait until we've loaded the user info page @@ -374,6 +351,4 @@ class TestSourceOAuth1(SeleniumTestCase): self.wait_for_url(self.if_admin_url("/library")) self.driver.get(self.if_admin_url("/user")) - self.assert_user( - User(username="example-user", name="test name", email="foo@example.com") - ) + self.assert_user(User(username="example-user", name="test name", email="foo@example.com")) diff --git a/tests/e2e/test_source_saml.py b/tests/e2e/test_source_saml.py index 314b0272e..c561ca090 100644 --- a/tests/e2e/test_source_saml.py +++ b/tests/e2e/test_source_saml.py @@ -102,18 +102,14 @@ class TestSourceSAML(SeleniumTestCase): @apply_migration("authentik_flows", "0011_flow_title") @apply_migration("authentik_flows", "0009_source_flows") @apply_migration("authentik_crypto", "0002_create_self_signed_kp") - @apply_migration( - "authentik_sources_saml", "0010_samlsource_pre_authentication_flow" - ) + @apply_migration("authentik_sources_saml", "0010_samlsource_pre_authentication_flow") @object_manager def test_idp_redirect(self): """test SAML Source With redirect binding""" # Bootstrap all needed objects authentication_flow = Flow.objects.get(slug="default-source-authentication") enrollment_flow = Flow.objects.get(slug="default-source-enrollment") - pre_authentication_flow = Flow.objects.get( - slug="default-source-pre-authentication" - ) + pre_authentication_flow = Flow.objects.get(slug="default-source-pre-authentication") keypair = CertificateKeyPair.objects.create( name="test-idp-cert", certificate_data=IDP_CERT, @@ -138,9 +134,7 @@ class TestSourceSAML(SeleniumTestCase): self.driver.get(self.live_server_url) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( @@ -175,18 +169,14 @@ class TestSourceSAML(SeleniumTestCase): @apply_migration("authentik_flows", "0011_flow_title") @apply_migration("authentik_flows", "0009_source_flows") @apply_migration("authentik_crypto", "0002_create_self_signed_kp") - @apply_migration( - "authentik_sources_saml", "0010_samlsource_pre_authentication_flow" - ) + @apply_migration("authentik_sources_saml", "0010_samlsource_pre_authentication_flow") @object_manager def test_idp_post(self): """test SAML Source With post binding""" # Bootstrap all needed objects authentication_flow = Flow.objects.get(slug="default-source-authentication") enrollment_flow = Flow.objects.get(slug="default-source-enrollment") - pre_authentication_flow = Flow.objects.get( - slug="default-source-pre-authentication" - ) + pre_authentication_flow = Flow.objects.get(slug="default-source-pre-authentication") keypair = CertificateKeyPair.objects.create( name="test-idp-cert", certificate_data=IDP_CERT, @@ -211,9 +201,7 @@ class TestSourceSAML(SeleniumTestCase): self.driver.get(self.live_server_url) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( @@ -261,18 +249,14 @@ class TestSourceSAML(SeleniumTestCase): @apply_migration("authentik_flows", "0011_flow_title") @apply_migration("authentik_flows", "0009_source_flows") @apply_migration("authentik_crypto", "0002_create_self_signed_kp") - @apply_migration( - "authentik_sources_saml", "0010_samlsource_pre_authentication_flow" - ) + @apply_migration("authentik_sources_saml", "0010_samlsource_pre_authentication_flow") @object_manager def test_idp_post_auto(self): """test SAML Source With post binding (auto redirect)""" # Bootstrap all needed objects authentication_flow = Flow.objects.get(slug="default-source-authentication") enrollment_flow = Flow.objects.get(slug="default-source-enrollment") - pre_authentication_flow = Flow.objects.get( - slug="default-source-pre-authentication" - ) + pre_authentication_flow = Flow.objects.get(slug="default-source-pre-authentication") keypair = CertificateKeyPair.objects.create( name="test-idp-cert", certificate_data=IDP_CERT, @@ -297,9 +281,7 @@ class TestSourceSAML(SeleniumTestCase): self.driver.get(self.live_server_url) flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) wait = WebDriverWait(identification_stage, self.wait_timeout) wait.until( diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 5ace1eab0..51dec3e92 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -15,11 +15,7 @@ from django.urls import reverse from docker import DockerClient, from_env from docker.models.containers import Container from selenium import webdriver -from selenium.common.exceptions import ( - NoSuchElementException, - TimeoutException, - WebDriverException, -) +from selenium.common.exceptions import NoSuchElementException, TimeoutException, WebDriverException from selenium.webdriver.common.by import By from selenium.webdriver.common.desired_capabilities import DesiredCapabilities from selenium.webdriver.common.keys import Keys @@ -94,16 +90,12 @@ class SeleniumTestCase(StaticLiveServerTestCase): def tearDown(self): if "TF_BUILD" in environ: makedirs("selenium_screenshots/", exist_ok=True) - screenshot_file = ( - f"selenium_screenshots/{self.__class__.__name__}_{time()}.png" - ) + screenshot_file = f"selenium_screenshots/{self.__class__.__name__}_{time()}.png" self.driver.save_screenshot(screenshot_file) self.logger.warning("Saved screenshot", file=screenshot_file) self.logger.debug("--------browser logs") for line in self.driver.get_log("browser"): - self.logger.debug( - line["message"], source=line["source"], level=line["level"] - ) + self.logger.debug(line["message"], source=line["source"], level=line["level"]) self.logger.debug("--------end browser logs") if self.container: self.output_container_logs() @@ -126,43 +118,33 @@ class SeleniumTestCase(StaticLiveServerTestCase): """same as self.url() but show URL in shell""" return f"{self.live_server_url}/if/admin/#{view}" - def get_shadow_root( - self, selector: str, container: Optional[WebElement] = None - ) -> WebElement: + def get_shadow_root(self, selector: str, container: Optional[WebElement] = None) -> WebElement: """Get shadow root element's inner shadowRoot""" if not container: container = self.driver shadow_root = container.find_element(By.CSS_SELECTOR, selector) - element = self.driver.execute_script( - "return arguments[0].shadowRoot", shadow_root - ) + element = self.driver.execute_script("return arguments[0].shadowRoot", shadow_root) return element def login(self): """Do entire login flow and check user afterwards""" flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root( - "ak-stage-identification", flow_executor - ) + identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) - identification_stage.find_element( - By.CSS_SELECTOR, "input[name=uidField]" - ).click() - identification_stage.find_element( - By.CSS_SELECTOR, "input[name=uidField]" - ).send_keys(USER().username) - identification_stage.find_element( - By.CSS_SELECTOR, "input[name=uidField]" - ).send_keys(Keys.ENTER) + identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").click() + identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").send_keys( + USER().username + ) + identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").send_keys( + Keys.ENTER + ) flow_executor = self.get_shadow_root("ak-flow-executor") password_stage = self.get_shadow_root("ak-stage-password", flow_executor) password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( USER().username ) - password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( - Keys.ENTER - ) + password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys(Keys.ENTER) sleep(1) def assert_user(self, expected_user: User):