root: reformat to 100 line width

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
Jens Langhammer 2021-08-03 17:45:16 +02:00
parent b87903a209
commit 77ed25ae34
272 changed files with 825 additions and 2590 deletions


@@ -23,9 +23,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
     date_from = now() - timedelta(days=1)
     result = (
         Event.objects.filter(created__gte=date_from, **filter_kwargs)
-        .annotate(
-            age=ExpressionWrapper(now() - F("created"), output_field=DurationField())
-        )
+        .annotate(age=ExpressionWrapper(now() - F("created"), output_field=DurationField()))
         .annotate(age_hours=ExtractHour("age"))
         .values("age_hours")
         .annotate(count=Count("pk"))
@@ -37,8 +35,7 @@ def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
     for hour in range(0, -24, -1):
         results.append(
             {
-                "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple())
-                * 1000,
+                "x_cord": time.mktime((_now + timedelta(hours=hour)).timetuple()) * 1000,
                 "y_cord": data[hour * -1],
             }
         )


@@ -61,9 +61,7 @@ class SystemSerializer(PassiveSerializer):
         return {
             "python_version": python_version,
             "gunicorn_version": ".".join(str(x) for x in gunicorn_version),
-            "environment": "kubernetes"
-            if SERVICE_HOST_ENV_NAME in os.environ
-            else "compose",
+            "environment": "kubernetes" if SERVICE_HOST_ENV_NAME in os.environ else "compose",
             "architecture": platform.machine(),
             "platform": platform.platform(),
             "uname": " ".join(platform.uname()),


@@ -92,10 +92,7 @@ class TaskViewSet(ViewSet):
             task_func.delay(*task.task_call_args, **task.task_call_kwargs)
             messages.success(
                 self.request,
-                _(
-                    "Successfully re-scheduled Task %(name)s!"
-                    % {"name": task.task_name}
-                ),
+                _("Successfully re-scheduled Task %(name)s!" % {"name": task.task_name}),
             )
             return Response(status=204)
         except ImportError:  # pragma: no cover


@@ -41,9 +41,7 @@ class VersionSerializer(PassiveSerializer):
     def get_outdated(self, instance) -> bool:
         """Check if we're running the latest version"""
-        return parse(self.get_version_current(instance)) < parse(
-            self.get_version_latest(instance)
-        )
+        return parse(self.get_version_current(instance)) < parse(self.get_version_latest(instance))


 class VersionView(APIView):


@@ -17,9 +17,7 @@ class WorkerView(APIView):
     permission_classes = [IsAdminUser]

-    @extend_schema(
-        responses=inline_serializer("Workers", fields={"count": IntegerField()})
-    )
+    @extend_schema(responses=inline_serializer("Workers", fields={"count": IntegerField()}))
     def get(self, request: Request) -> Response:
         """Get currently connected worker count."""
         count = len(CELERY_APP.control.ping(timeout=0.5))


@@ -37,18 +37,14 @@ def _set_prom_info():
 def update_latest_version(self: MonitoredTask):
     """Update latest version info"""
     try:
-        response = get(
-            "https://api.github.com/repos/goauthentik/authentik/releases/latest"
-        )
+        response = get("https://api.github.com/repos/goauthentik/authentik/releases/latest")
         response.raise_for_status()
         data = response.json()
         tag_name = data.get("tag_name")
         upstream_version = tag_name.split("/")[1]
         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
         self.set_status(
-            TaskResult(
-                TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"]
-            )
+            TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated latest Version"])
         )
         _set_prom_info()
         # Check if upstream version is newer than what we're running,


@@ -27,9 +27,7 @@ class TestAdminAPI(TestCase):
         response = self.client.get(reverse("authentik_api:admin_system_tasks-list"))
         self.assertEqual(response.status_code, 200)
         body = loads(response.content)
-        self.assertTrue(
-            any(task["task_name"] == "clean_expired_models" for task in body)
-        )
+        self.assertTrue(any(task["task_name"] == "clean_expired_models" for task in body))

     def test_tasks_single(self):
         """Test Task API (read single)"""
@@ -45,9 +43,7 @@ class TestAdminAPI(TestCase):
         self.assertEqual(body["status"], TaskResultStatus.SUCCESSFUL.name)
         self.assertEqual(body["task_name"], "clean_expired_models")
         response = self.client.get(
-            reverse(
-                "authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"}
-            )
+            reverse("authentik_api:admin_system_tasks-detail", kwargs={"pk": "qwerqwer"})
         )
         self.assertEqual(response.status_code, 404)


@@ -7,9 +7,7 @@ from rest_framework.response import Response
 from rest_framework.viewsets import ModelViewSet


-def permission_required(
-    perm: Optional[str] = None, other_perms: Optional[list[str]] = None
-):
+def permission_required(perm: Optional[str] = None, other_perms: Optional[list[str]] = None):
     """Check permissions for a single custom action"""

     def wrapper_outter(func: Callable):


@@ -63,9 +63,7 @@ def postprocess_schema_responses(result, generator, **kwargs):  # noqa: W0613
         method["responses"].setdefault("400", validation_error.ref)
         method["responses"].setdefault("403", generic_error.ref)

-    result["components"] = generator.registry.build(
-        spectacular_settings.APPEND_COMPONENTS
-    )
+    result["components"] = generator.registry.build(spectacular_settings.APPEND_COMPONENTS)

     # This is a workaround for authentik/stages/prompt/stage.py
     # since the serializer PromptChallengeResponse


@@ -16,17 +16,13 @@ class TestAPIAuth(TestCase):
     def test_valid_basic(self):
         """Test valid token"""
-        token = Token.objects.create(
-            intent=TokenIntents.INTENT_API, user=get_anonymous_user()
-        )
+        token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user())
         auth = b64encode(f":{token.key}".encode()).decode()
         self.assertEqual(bearer_auth(f"Basic {auth}".encode()), token.user)

     def test_valid_bearer(self):
         """Test valid token"""
-        token = Token.objects.create(
-            intent=TokenIntents.INTENT_API, user=get_anonymous_user()
-        )
+        token = Token.objects.create(intent=TokenIntents.INTENT_API, user=get_anonymous_user())
         self.assertEqual(bearer_auth(f"Bearer {token.key}".encode()), token.user)

     def test_invalid_type(self):


@@ -52,20 +52,12 @@ from authentik.policies.reputation.api import (
 from authentik.providers.ldap.api import LDAPOutpostConfigViewSet, LDAPProviderViewSet
 from authentik.providers.oauth2.api.provider import OAuth2ProviderViewSet
 from authentik.providers.oauth2.api.scope import ScopeMappingViewSet
-from authentik.providers.oauth2.api.tokens import (
-    AuthorizationCodeViewSet,
-    RefreshTokenViewSet,
-)
-from authentik.providers.proxy.api import (
-    ProxyOutpostConfigViewSet,
-    ProxyProviderViewSet,
-)
+from authentik.providers.oauth2.api.tokens import AuthorizationCodeViewSet, RefreshTokenViewSet
+from authentik.providers.proxy.api import ProxyOutpostConfigViewSet, ProxyProviderViewSet
 from authentik.providers.saml.api import SAMLPropertyMappingViewSet, SAMLProviderViewSet
 from authentik.sources.ldap.api import LDAPPropertyMappingViewSet, LDAPSourceViewSet
 from authentik.sources.oauth.api.source import OAuthSourceViewSet
-from authentik.sources.oauth.api.source_connection import (
-    UserOAuthSourceConnectionViewSet,
-)
+from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet
 from authentik.sources.plex.api import PlexSourceViewSet
 from authentik.sources.saml.api import SAMLSourceViewSet
 from authentik.stages.authenticator_duo.api import (
@@ -83,9 +75,7 @@ from authentik.stages.authenticator_totp.api import (
     TOTPAdminDeviceViewSet,
     TOTPDeviceViewSet,
 )
-from authentik.stages.authenticator_validate.api import (
-    AuthenticatorValidateStageViewSet,
-)
+from authentik.stages.authenticator_validate.api import AuthenticatorValidateStageViewSet
 from authentik.stages.authenticator_webauthn.api import (
     AuthenticateWebAuthnStageViewSet,
     WebAuthnAdminDeviceViewSet,
@@ -122,9 +112,7 @@ router.register("core/tenants", TenantViewSet)
 router.register("outposts/instances", OutpostViewSet)
 router.register("outposts/service_connections/all", ServiceConnectionViewSet)
 router.register("outposts/service_connections/docker", DockerServiceConnectionViewSet)
-router.register(
-    "outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet
-)
+router.register("outposts/service_connections/kubernetes", KubernetesServiceConnectionViewSet)
 router.register("outposts/proxy", ProxyOutpostConfigViewSet)
 router.register("outposts/ldap", LDAPOutpostConfigViewSet)
@@ -184,9 +172,7 @@ router.register(
     StaticAdminDeviceViewSet,
     basename="admin-staticdevice",
 )
-router.register(
-    "authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice"
-)
+router.register("authenticators/admin/totp", TOTPAdminDeviceViewSet, basename="admin-totpdevice")
 router.register(
     "authenticators/admin/webauthn",
     WebAuthnAdminDeviceViewSet,


@@ -147,9 +147,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
         """Custom list method that checks Policy based access instead of guardian"""
         should_cache = request.GET.get("search", "") == ""
-        superuser_full_list = (
-            str(request.GET.get("superuser_full_list", "false")).lower() == "true"
-        )
+        superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
         if superuser_full_list and request.user.is_superuser:
             return super().list(request)
@@ -240,9 +238,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet):
         app.save()
         return Response({})

-    @permission_required(
-        "authentik_core.view_application", ["authentik_events.view_event"]
-    )
+    @permission_required("authentik_core.view_application", ["authentik_events.view_event"])
     @extend_schema(responses={200: CoordinateSerializer(many=True)})
     @action(detail=True, pagination_class=None, filter_backends=[])
     # pylint: disable=unused-argument


@@ -68,9 +68,7 @@ class AuthenticatedSessionSerializer(ModelSerializer):
         """Get parsed user agent"""
         return user_agent_parser.Parse(instance.last_user_agent)

-    def get_geo_ip(
-        self, instance: AuthenticatedSession
-    ) -> Optional[GeoIPDict]:  # pragma: no cover
+    def get_geo_ip(self, instance: AuthenticatedSession) -> Optional[GeoIPDict]:  # pragma: no cover
         """Get parsed user agent"""
         return GEOIP_READER.city_dict(instance.last_ip)


@@ -15,11 +15,7 @@ from rest_framework.viewsets import GenericViewSet

 from authentik.api.decorators import permission_required
 from authentik.core.api.used_by import UsedByMixin
-from authentik.core.api.utils import (
-    MetaNameSerializer,
-    PassiveSerializer,
-    TypeCreateSerializer,
-)
+from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
 from authentik.core.expression import PropertyMappingEvaluator
 from authentik.core.models import PropertyMapping
 from authentik.lib.utils.reflection import all_subclasses
@@ -141,9 +137,7 @@ class PropertyMappingViewSet(
                 self.request,
                 **test_params.validated_data.get("context", {}),
             )
-            response_data["result"] = dumps(
-                result, indent=(4 if format_result else None)
-            )
+            response_data["result"] = dumps(result, indent=(4 if format_result else None))
         except Exception as exc:  # pylint: disable=broad-except
             response_data["result"] = str(exc)
             response_data["successful"] = False


@@ -93,9 +93,7 @@ class SourceViewSet(
     @action(detail=False, pagination_class=None, filter_backends=[])
     def user_settings(self, request: Request) -> Response:
         """Get all sources the user can configure"""
-        _all_sources: Iterable[Source] = Source.objects.filter(
-            enabled=True
-        ).select_subclasses()
+        _all_sources: Iterable[Source] = Source.objects.filter(enabled=True).select_subclasses()
         matching_sources: list[UserSettingSerializer] = []
         for source in _all_sources:
             user_settings = source.ui_user_settings


@@ -70,9 +70,7 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
         serializer.save(
             user=self.request.user,
             intent=TokenIntents.INTENT_API,
-            expiring=self.request.user.attributes.get(
-                USER_ATTRIBUTE_TOKEN_EXPIRING, True
-            ),
+            expiring=self.request.user.attributes.get(USER_ATTRIBUTE_TOKEN_EXPIRING, True),
         )

     @permission_required("authentik_core.view_token_key")
@@ -89,7 +87,5 @@ class TokenViewSet(UsedByMixin, ModelViewSet):
         token: Token = self.get_object()
         if token.is_expired:
             raise Http404
-        Event.new(EventAction.SECRET_VIEW, secret=token).from_http(  # noqa # nosec
-            request
-        )
+        Event.new(EventAction.SECRET_VIEW, secret=token).from_http(request)  # noqa # nosec
         return Response(TokenViewSerializer({"key": token.key}).data)


@@ -79,9 +79,7 @@ class UsedByMixin:
         ).all():
             # Only merge shadows on first object
             if first_object:
-                shadows += getattr(
-                    manager.model._meta, "authentik_used_by_shadows", []
-                )
+                shadows += getattr(manager.model._meta, "authentik_used_by_shadows", [])
             first_object = False
             serializer = UsedBySerializer(
                 data={


@@ -26,10 +26,7 @@ from authentik.api.decorators import permission_required
 from authentik.core.api.groups import GroupSerializer
 from authentik.core.api.used_by import UsedByMixin
 from authentik.core.api.utils import LinkSerializer, PassiveSerializer, is_dict
-from authentik.core.middleware import (
-    SESSION_IMPERSONATE_ORIGINAL_USER,
-    SESSION_IMPERSONATE_USER,
-)
+from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
 from authentik.core.models import Token, TokenIntents, User
 from authentik.events.models import EventAction
 from authentik.tenants.models import Tenant
@@ -87,17 +84,13 @@ class UserMetricsSerializer(PassiveSerializer):
     def get_logins_failed_per_1h(self, _):
         """Get failed logins per hour for the last 24 hours"""
         user = self.context["user"]
-        return get_events_per_1h(
-            action=EventAction.LOGIN_FAILED, context__username=user.username
-        )
+        return get_events_per_1h(action=EventAction.LOGIN_FAILED, context__username=user.username)

     @extend_schema_field(CoordinateSerializer(many=True))
     def get_authorizations_per_1h(self, _):
         """Get failed logins per hour for the last 24 hours"""
         user = self.context["user"]
-        return get_events_per_1h(
-            action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk
-        )
+        return get_events_per_1h(action=EventAction.AUTHORIZE_APPLICATION, user__pk=user.pk)


 class UsersFilter(FilterSet):
@@ -154,9 +147,7 @@ class UserViewSet(UsedByMixin, ModelViewSet):
     # pylint: disable=invalid-name
     def me(self, request: Request) -> Response:
         """Get information about current user"""
-        serializer = SessionUserSerializer(
-            data={"user": UserSerializer(request.user).data}
-        )
+        serializer = SessionUserSerializer(data={"user": UserSerializer(request.user).data})
         if SESSION_IMPERSONATE_USER in request._request.session:
             serializer.initial_data["original"] = UserSerializer(
                 request._request.session[SESSION_IMPERSONATE_ORIGINAL_USER]


@@ -3,20 +3,14 @@ from typing import Any

 from django.db.models import Model
 from rest_framework.fields import CharField, IntegerField
-from rest_framework.serializers import (
-    Serializer,
-    SerializerMethodField,
-    ValidationError,
-)
+from rest_framework.serializers import Serializer, SerializerMethodField, ValidationError


 def is_dict(value: Any):
     """Ensure a value is a dictionary, useful for JSONFields"""
     if isinstance(value, dict):
         return
-    raise ValidationError(
-        "Value must be a dictionary, and not have any duplicate keys."
-    )
+    raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")


 class PassiveSerializer(Serializer):
@@ -25,9 +19,7 @@ class PassiveSerializer(Serializer):
     def create(self, validated_data: dict) -> Model:  # pragma: no cover
         return Model()

-    def update(
-        self, instance: Model, validated_data: dict
-    ) -> Model:  # pragma: no cover
+    def update(self, instance: Model, validated_data: dict) -> Model:  # pragma: no cover
         return Model()

     class Meta:


@@ -38,9 +38,7 @@ class Migration(migrations.Migration):
                 ("password", models.CharField(max_length=128, verbose_name="password")),
                 (
                     "last_login",
-                    models.DateTimeField(
-                        blank=True, null=True, verbose_name="last login"
-                    ),
+                    models.DateTimeField(blank=True, null=True, verbose_name="last login"),
                 ),
                 (
                     "is_superuser",
@@ -53,35 +51,25 @@ class Migration(migrations.Migration):
                 (
                     "username",
                     models.CharField(
-                        error_messages={
-                            "unique": "A user with that username already exists."
-                        },
+                        error_messages={"unique": "A user with that username already exists."},
                         help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
                         max_length=150,
                         unique=True,
-                        validators=[
-                            django.contrib.auth.validators.UnicodeUsernameValidator()
-                        ],
+                        validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
                         verbose_name="username",
                     ),
                 ),
                 (
                     "first_name",
-                    models.CharField(
-                        blank=True, max_length=30, verbose_name="first name"
-                    ),
+                    models.CharField(blank=True, max_length=30, verbose_name="first name"),
                 ),
                 (
                     "last_name",
-                    models.CharField(
-                        blank=True, max_length=150, verbose_name="last name"
-                    ),
+                    models.CharField(blank=True, max_length=150, verbose_name="last name"),
                 ),
                 (
                     "email",
-                    models.EmailField(
-                        blank=True, max_length=254, verbose_name="email address"
-                    ),
+                    models.EmailField(blank=True, max_length=254, verbose_name="email address"),
                 ),
                 (
                     "is_staff",
@@ -217,9 +205,7 @@ class Migration(migrations.Migration):
                 ),
                 (
                     "expires",
-                    models.DateTimeField(
-                        default=authentik.core.models.default_token_duration
-                    ),
+                    models.DateTimeField(default=authentik.core.models.default_token_duration),
                 ),
                 ("expiring", models.BooleanField(default=True)),
                 ("description", models.TextField(blank=True, default="")),
@@ -306,9 +292,7 @@ class Migration(migrations.Migration):
                 ("name", models.TextField(help_text="Application's display Name.")),
                 (
                     "slug",
-                    models.SlugField(
-                        help_text="Internal application name, used in URLs."
-                    ),
+                    models.SlugField(help_text="Internal application name, used in URLs."),
                 ),
                 ("skip_authorization", models.BooleanField(default=False)),
                 ("meta_launch_url", models.URLField(blank=True, default="")),


@@ -17,9 +17,7 @@ def create_default_user(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
         username="akadmin", email="root@localhost", name="authentik Default Admin"
     )
     if "TF_BUILD" in environ or "AK_ADMIN_PASS" in environ or settings.TEST:
-        akadmin.set_password(
-            environ.get("AK_ADMIN_PASS", "akadmin"), signal=False
-        )  # noqa # nosec
+        akadmin.set_password(environ.get("AK_ADMIN_PASS", "akadmin"), signal=False)  # noqa # nosec
     else:
         akadmin.set_unusable_password()
     akadmin.save()


@@ -13,8 +13,6 @@ class Migration(migrations.Migration):
         migrations.AlterField(
             model_name="source",
             name="slug",
-            field=models.SlugField(
-                help_text="Internal source name, used in URLs.", unique=True
-            ),
+            field=models.SlugField(help_text="Internal source name, used in URLs.", unique=True),
         ),
     ]


@@ -13,8 +13,6 @@ class Migration(migrations.Migration):
         migrations.AlterField(
             model_name="user",
             name="first_name",
-            field=models.CharField(
-                blank=True, max_length=150, verbose_name="first name"
-            ),
+            field=models.CharField(blank=True, max_length=150, verbose_name="first name"),
         ),
     ]


@@ -40,9 +40,7 @@ class Migration(migrations.Migration):
         migrations.AlterField(
             model_name="user",
             name="pb_groups",
-            field=models.ManyToManyField(
-                related_name="users", to="authentik_core.Group"
-            ),
+            field=models.ManyToManyField(related_name="users", to="authentik_core.Group"),
         ),
         migrations.AddField(
             model_name="group",


@@ -42,9 +42,7 @@ class Migration(migrations.Migration):
         ),
         migrations.AddIndex(
             model_name="token",
-            index=models.Index(
-                fields=["identifier"], name="authentik_co_identif_1a34a8_idx"
-            ),
+            index=models.Index(fields=["identifier"], name="authentik_co_identif_1a34a8_idx"),
         ),
         migrations.RunPython(set_default_token_key),
     ]


@@ -17,8 +17,6 @@ class Migration(migrations.Migration):
         migrations.AddField(
             model_name="application",
             name="meta_icon",
-            field=models.FileField(
-                blank=True, default="", upload_to="application-icons/"
-            ),
+            field=models.FileField(blank=True, default="", upload_to="application-icons/"),
         ),
     ]


@@ -25,9 +25,7 @@ class Migration(migrations.Migration):
         ),
         migrations.AddIndex(
             model_name="token",
-            index=models.Index(
-                fields=["identifier"], name="authentik_c_identif_d9d032_idx"
-            ),
+            index=models.Index(fields=["identifier"], name="authentik_c_identif_d9d032_idx"),
         ),
         migrations.AddIndex(
             model_name="token",


@@ -32,16 +32,12 @@ class Migration(migrations.Migration):
             fields=[
                 (
                     "expires",
-                    models.DateTimeField(
-                        default=authentik.core.models.default_token_duration
-                    ),
+                    models.DateTimeField(default=authentik.core.models.default_token_duration),
                 ),
                 ("expiring", models.BooleanField(default=True)),
                 (
                     "uuid",
-                    models.UUIDField(
-                        default=uuid.uuid4, primary_key=True, serialize=False
-                    ),
+                    models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
                 ),
                 ("session_key", models.CharField(max_length=40)),
                 ("last_ip", models.TextField()),


@@ -13,8 +13,6 @@ class Migration(migrations.Migration):
         migrations.AlterField(
             model_name="application",
             name="meta_icon",
-            field=models.FileField(
-                default=None, null=True, upload_to="application-icons/"
-            ),
+            field=models.FileField(default=None, null=True, upload_to="application-icons/"),
         ),
     ]


@@ -154,9 +154,7 @@ class User(GuardianUserMixin, AbstractUser):
                 ("s", "158"),
                 ("r", "g"),
             ]
-            gravatar_url = (
-                f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
-            )
+            gravatar_url = f"{GRAVATAR_URL}/avatar/{mail_hash}?{urlencode(parameters, doseq=True)}"
             return escape(gravatar_url)
         return mode % {
             "username": self.username,
@@ -186,9 +184,7 @@ class Provider(SerializerModel):
         related_name="provider_authorization",
     )

-    property_mappings = models.ManyToManyField(
-        "PropertyMapping", default=None, blank=True
-    )
+    property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)

     objects = InheritanceManager()
@@ -218,9 +214,7 @@ class Application(PolicyBindingModel):
     add custom fields and other properties"""

     name = models.TextField(help_text=_("Application's display Name."))
-    slug = models.SlugField(
-        help_text=_("Internal application name, used in URLs."), unique=True
-    )
+    slug = models.SlugField(help_text=_("Internal application name, used in URLs."), unique=True)
     provider = models.OneToOneField(
         "Provider", null=True, blank=True, default=None, on_delete=models.SET_DEFAULT
     )
@@ -244,9 +238,7 @@ class Application(PolicyBindingModel):
         it is returned as-is"""
         if not self.meta_icon:
             return None
-        if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith(
-            "/static"
-        ):
+        if self.meta_icon.name.startswith("http") or self.meta_icon.name.startswith("/static"):
             return self.meta_icon.name
         return self.meta_icon.url
@@ -301,14 +293,10 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
     """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""

     name = models.TextField(help_text=_("Source's display Name."))
-    slug = models.SlugField(
-        help_text=_("Internal source name, used in URLs."), unique=True
-    )
+    slug = models.SlugField(help_text=_("Internal source name, used in URLs."), unique=True)

     enabled = models.BooleanField(default=True)
-    property_mappings = models.ManyToManyField(
-        "PropertyMapping", default=None, blank=True
-    )
+    property_mappings = models.ManyToManyField("PropertyMapping", default=None, blank=True)

     authentication_flow = models.ForeignKey(
         Flow,
@@ -481,9 +469,7 @@ class PropertyMapping(SerializerModel, ManagedModel):
         """Get serializer for this model"""
         raise NotImplementedError

-    def evaluate(
-        self, user: Optional[User], request: Optional[HttpRequest], **kwargs
-    ) -> Any:
+    def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any:
         """Evaluate `self.expression` using `**kwargs` as Context."""
         from authentik.core.expression import PropertyMappingEvaluator
@@ -522,9 +508,7 @@ class AuthenticatedSession(ExpiringModel):
     last_used = models.DateTimeField(auto_now=True)

     @staticmethod
-    def from_request(
-        request: HttpRequest, user: User
-    ) -> Optional["AuthenticatedSession"]:
+    def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
         """Create a new session from a http request"""
         if not hasattr(request, "session") or not request.session.session_key:
             return None


@@ -14,9 +14,7 @@ from prometheus_client import Gauge
 # Arguments: user: User, password: str
 password_changed = Signal()

-GAUGE_MODELS = Gauge(
-    "authentik_models", "Count of various objects", ["model_name", "app"]
-)
+GAUGE_MODELS = Gauge("authentik_models", "Count of various objects", ["model_name", "app"])

 if TYPE_CHECKING:
     from authentik.core.models import AuthenticatedSession, User
@@ -60,15 +58,11 @@ def user_logged_out_session(sender, request: HttpRequest, user: "User", **_):
     """Delete AuthenticatedSession if it exists"""
     from authentik.core.models import AuthenticatedSession

-    AuthenticatedSession.objects.filter(
-        session_key=request.session.session_key
-    ).delete()
+    AuthenticatedSession.objects.filter(session_key=request.session.session_key).delete()


 @receiver(pre_delete)
-def authenticated_session_delete(
-    sender: Type[Model], instance: "AuthenticatedSession", **_
-):
+def authenticated_session_delete(sender: Type[Model], instance: "AuthenticatedSession", **_):
     """Delete session when authenticated session is deleted"""
     from authentik.core.models import AuthenticatedSession


@@ -11,16 +11,8 @@ from django.urls import reverse
 from django.utils.translation import gettext as _
 from structlog.stdlib import get_logger

-from authentik.core.models import (
-    Source,
-    SourceUserMatchingModes,
-    User,
-    UserSourceConnection,
-)
-from authentik.core.sources.stage import (
-    PLAN_CONTEXT_SOURCES_CONNECTION,
-    PostUserEnrollmentStage,
-)
+from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
+from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage
 from authentik.events.models import Event, EventAction
 from authentik.flows.models import Flow, Stage, in_memory_stage
 from authentik.flows.planner import (
@@ -76,9 +68,7 @@ class SourceFlowManager:
     # pylint: disable=too-many-return-statements
     def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
         """decide which action should be taken"""
-        new_connection = self.connection_type(
-            source=self.source, identifier=self.identifier
-        )
+        new_connection = self.connection_type(source=self.source, identifier=self.identifier)
         # When request is authenticated, always link
         if self.request.user.is_authenticated:
             new_connection.user = self.request.user
@@ -113,9 +103,7 @@ class SourceFlowManager:
             SourceUserMatchingModes.USERNAME_DENY,
         ]:
             if not self.enroll_info.get("username", None):
-                self._logger.warning(
-                    "Refusing to use none username", source=self.source
-                )
+                self._logger.warning("Refusing to use none username", source=self.source)
                 return Action.DENY, None
             query = Q(username__exact=self.enroll_info.get("username", None))
             self._logger.debug("trying to link with existing user", query=query)
@@ -229,10 +217,7 @@ class SourceFlowManager:
         """Login user and redirect."""
         messages.success(
             self.request,
-            _(
-                "Successfully authenticated with %(source)s!"
-                % {"source": self.source.name}
-            ),
+            _("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
         )
         flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
         return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
@@ -270,10 +255,7 @@ class SourceFlowManager:
         """User was not authenticated and previous request was not authenticated."""
         messages.success(
             self.request,
-            _(
-                "Successfully authenticated with %(source)s!"
-                % {"source": self.source.name}
-            ),
+            _("Successfully authenticated with %(source)s!" % {"source": self.source.name}),
         )
         # We run the Flow planner here so we can pass the Pending user in the context


@@ -27,9 +27,7 @@ def clean_expired_models(self: MonitoredTask):
     for cls in ExpiringModel.__subclasses__():
         cls: ExpiringModel
         objects = (
-            cls.objects.all()
-            .exclude(expiring=False)
-            .exclude(expiring=True, expires__gt=now())
+            cls.objects.all().exclude(expiring=False).exclude(expiring=True, expires__gt=now())
         )
         for obj in objects:
             obj.expire_action()


@@ -17,9 +17,7 @@ class TestApplicationsAPI(APITestCase):
         self.denied = Application.objects.create(name="denied", slug="denied")
         PolicyBinding.objects.create(
             target=self.denied,
-            policy=DummyPolicy.objects.create(
-                name="deny", result=False, wait_min=1, wait_max=2
-            ),
+            policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
             order=0,
         )
@@ -33,9 +31,7 @@ class TestApplicationsAPI(APITestCase):
             )
         )
         self.assertEqual(response.status_code, 200)
-        self.assertJSONEqual(
-            force_str(response.content), {"messages": [], "passing": True}
-        )
+        self.assertJSONEqual(force_str(response.content), {"messages": [], "passing": True})
         response = self.client.get(
             reverse(
                 "authentik_api:application-check-access",
@@ -43,9 +39,7 @@ class TestApplicationsAPI(APITestCase):
             )
         )
         self.assertEqual(response.status_code, 200)
-        self.assertJSONEqual(
-            force_str(response.content), {"messages": ["dummy"], "passing": False}
-        )
+        self.assertJSONEqual(force_str(response.content), {"messages": ["dummy"], "passing": False})

     def test_list(self):
         """Test list operation without superuser_full_list"""


@@ -46,9 +46,7 @@ class TestImpersonation(TestCase):
         self.client.force_login(self.other_user)

         self.client.get(
-            reverse(
-                "authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk}
-            )
+            reverse("authentik_core:impersonate-init", kwargs={"user_id": self.akadmin.pk})
         )

         response = self.client.get(reverse("authentik_api:user-me"))


@@ -22,9 +22,7 @@ class TestModels(TestCase):
     def test_token_expire_no_expire(self):
         """Test token expiring with "expiring" set"""
-        token = Token.objects.create(
-            expires=now(), user=get_anonymous_user(), expiring=False
-        )
+        token = Token.objects.create(expires=now(), user=get_anonymous_user(), expiring=False)
         sleep(0.5)
         self.assertFalse(token.is_expired)


@@ -16,9 +16,7 @@ class TestPropertyMappings(TestCase):
     def test_expression(self):
         """Test expression"""
-        mapping = PropertyMapping.objects.create(
-            name="test", expression="return 'test'"
-        )
+        mapping = PropertyMapping.objects.create(name="test", expression="return 'test'")
         self.assertEqual(mapping.evaluate(None, None), "test")

     def test_expression_syntax(self):


@@ -23,9 +23,7 @@ class TestPropertyMappingAPI(APITestCase):
     def test_test_call(self):
         """Test PropertMappings's test endpoint"""
         response = self.client.post(
-            reverse(
-                "authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}
-            ),
+            reverse("authentik_api:propertymapping-test", kwargs={"pk": self.mapping.pk}),
             data={
                 "user": self.user.pk,
             },


@@ -4,12 +4,7 @@ from django.utils.timezone import now
 from guardian.shortcuts import get_anonymous_user
 from rest_framework.test import APITestCase

-from authentik.core.models import (
-    USER_ATTRIBUTE_TOKEN_EXPIRING,
-    Token,
-    TokenIntents,
-    User,
-)
+from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User
 from authentik.core.tasks import clean_expired_models


@@ -5,10 +5,7 @@ from django.shortcuts import get_object_or_404, redirect
 from django.views import View
 from structlog.stdlib import get_logger

-from authentik.core.middleware import (
-    SESSION_IMPERSONATE_ORIGINAL_USER,
-    SESSION_IMPERSONATE_USER,
-)
+from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
 from authentik.core.models import User
 from authentik.events.models import Event, EventAction
@@ -21,9 +18,7 @@ class ImpersonateInitView(View):
     def get(self, request: HttpRequest, user_id: int) -> HttpResponse:
         """Impersonation handler, checks permissions"""
         if not request.user.has_perm("impersonate"):
-            LOGGER.debug(
-                "User attempted to impersonate without permissions", user=request.user
-            )
+            LOGGER.debug("User attempted to impersonate without permissions", user=request.user)
             return HttpResponse("Unauthorized", status=401)

         user_to_be = get_object_or_404(User, pk=user_id)


@@ -14,9 +14,7 @@ class EndSessionView(TemplateView, PolicyAccessView):
     template_name = "if/end_session.html"

     def resolve_provider_application(self):
-        self.application = get_object_or_404(
-            Application, slug=self.kwargs["application_slug"]
-        )
+        self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])

     def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
         context = super().get_context_data(**kwargs)


@@ -10,12 +10,7 @@ from django_filters.filters import BooleanFilter
 from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
 from rest_framework.decorators import action
-from rest_framework.fields import (
-    CharField,
-    DateTimeField,
-    IntegerField,
-    SerializerMethodField,
-)
+from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField
 from rest_framework.request import Request
 from rest_framework.response import Response
 from rest_framework.serializers import ModelSerializer, ValidationError
@@ -86,9 +81,7 @@ class CertificateKeyPairSerializer(ModelSerializer):
                 backend=default_backend(),
             )
         except (ValueError, TypeError):
-            raise ValidationError(
-                "Unable to load private key (possibly encrypted?)."
-            )
+            raise ValidationError("Unable to load private key (possibly encrypted?).")
         return value

     class Meta:
@@ -123,9 +116,7 @@ class CertificateGenerationSerializer(PassiveSerializer):
     """Certificate generation parameters"""

     common_name = CharField()
-    subject_alt_name = CharField(
-        required=False, allow_blank=True, label=_("Subject-alt name")
-    )
+    subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name"))
     validity_days = IntegerField(initial=365)
@@ -170,9 +161,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
         builder = CertificateBuilder()
         builder.common_name = data.validated_data["common_name"]
         builder.build(
-            subject_alt_names=data.validated_data.get("subject_alt_name", "").split(
-                ","
-            ),
+            subject_alt_names=data.validated_data.get("subject_alt_name", "").split(","),
             validity_days=int(data.validated_data["validity_days"]),
         )
         instance = builder.save()
@@ -208,9 +197,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
                 "Content-Disposition"
             ] = f'attachment; filename="{certificate.name}_certificate.pem"'
             return response
-        return Response(
-            CertificateDataSerializer({"data": certificate.certificate_data}).data
-        )
+        return Response(CertificateDataSerializer({"data": certificate.certificate_data}).data)

     @extend_schema(
         parameters=[
@@ -234,9 +221,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
         ).from_http(request)
         if "download" in request._request.GET:
             # Mime type from https://pki-tutorial.readthedocs.io/en/latest/mime.html
-            response = HttpResponse(
-                certificate.key_data, content_type="application/x-pem-file"
-            )
+            response = HttpResponse(certificate.key_data, content_type="application/x-pem-file")
             response[
                 "Content-Disposition"
             ] = f'attachment; filename="{certificate.name}_private_key.pem"'


@@ -46,9 +46,7 @@ class CertificateBuilder:
             public_exponent=65537, key_size=2048, backend=default_backend()
         )
         self.__public_key = self.__private_key.public_key()
-        alt_names: list[x509.GeneralName] = [
-            x509.DNSName(x) for x in subject_alt_names or []
-        ]
+        alt_names: list[x509.GeneralName] = [x509.DNSName(x) for x in subject_alt_names or []]
         self.__builder = (
             x509.CertificateBuilder()
             .subject_name(
@@ -59,9 +57,7 @@ class CertificateBuilder:
                             self.common_name,
                         ),
                         x509.NameAttribute(NameOID.ORGANIZATION_NAME, "authentik"),
-                        x509.NameAttribute(
-                            NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"
-                        ),
+                        x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Self-signed"),
                     ]
                 )
             )
@@ -77,9 +73,7 @@ class CertificateBuilder:
             )
             .add_extension(x509.SubjectAlternativeName(alt_names), critical=True)
             .not_valid_before(datetime.datetime.today() - one_day)
-            .not_valid_after(
-                datetime.datetime.today() + datetime.timedelta(days=validity_days)
-            )
+            .not_valid_after(datetime.datetime.today() + datetime.timedelta(days=validity_days))
             .serial_number(int(uuid.uuid4()))
             .public_key(self.__public_key)
         )


@@ -57,9 +57,7 @@ class CertificateKeyPair(CreatedUpdatedModel):
         if not self._private_key and self._private_key != "":
             try:
                 self._private_key = load_pem_private_key(
-                    str.encode(
-                        "\n".join([x.strip() for x in self.key_data.split("\n")])
-                    ),
+                    str.encode("\n".join([x.strip() for x in self.key_data.split("\n")])),
                     password=None,
                     backend=default_backend(),
                 )
@@ -70,25 +68,19 @@ class CertificateKeyPair(CreatedUpdatedModel):
     @property
     def fingerprint_sha256(self) -> str:
         """Get SHA256 Fingerprint of certificate_data"""
-        return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode(
-            "utf-8"
-        )
+        return hexlify(self.certificate.fingerprint(hashes.SHA256()), ":").decode("utf-8")

     @property
     def fingerprint_sha1(self) -> str:
         """Get SHA1 Fingerprint of certificate_data"""
-        return hexlify(
-            self.certificate.fingerprint(hashes.SHA1()), ":"  # nosec
-        ).decode("utf-8")
+        return hexlify(self.certificate.fingerprint(hashes.SHA1()), ":").decode("utf-8")  # nosec

     @property
     def kid(self):
         """Get Key ID used for JWKS"""
         return "{0}".format(
-            md5(self.key_data.encode("utf-8")).hexdigest()  # nosec
-            if self.key_data
-            else ""
-        )
+            md5(self.key_data.encode("utf-8")).hexdigest() if self.key_data else ""
+        )  # nosec

     def __str__(self) -> str:
         return f"Certificate-Key Pair {self.name}"


@@ -143,7 +143,5 @@ class EventViewSet(ModelViewSet):
         """Get all actions"""
         data = []
         for value, name in EventAction.choices:
-            data.append(
-                {"name": name, "description": "", "component": value, "model_name": ""}
-            )
+            data.append({"name": name, "description": "", "component": value, "model_name": ""})
         return Response(TypeCreateSerializer(data, many=True).data)


@@ -29,12 +29,8 @@ class AuditMiddleware:
     def __call__(self, request: HttpRequest) -> HttpResponse:
         # Connect signal for automatic logging
-        if hasattr(request, "user") and getattr(
-            request.user, "is_authenticated", False
-        ):
-            post_save_handler = partial(
-                self.post_save_handler, user=request.user, request=request
-            )
+        if hasattr(request, "user") and getattr(request.user, "is_authenticated", False):
+            post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
             pre_delete_handler = partial(
                 self.pre_delete_handler, user=request.user, request=request
             )
@@ -94,13 +90,9 @@ class AuditMiddleware:
     @staticmethod
     # pylint: disable=unused-argument
-    def pre_delete_handler(
-        user: User, request: HttpRequest, sender, instance: Model, **_
-    ):
+    def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_):
         """Signal handler for all object's pre_delete"""
-        if isinstance(
-            instance, (Event, Notification, UserObjectPermission)
-        ):  # pragma: no cover
+        if isinstance(instance, (Event, Notification, UserObjectPermission)):  # pragma: no cover
             return

         EventNewThread(


@@ -14,9 +14,7 @@ def convert_user_to_json(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
         event.delete()
         # Because event objects cannot be updated, we have to re-create them
         event.pk = None
-        event.user_json = (
-            authentik.events.models.get_user(event.user) if event.user else {}
-        )
+        event.user_json = authentik.events.models.get_user(event.user) if event.user else {}
         event._state.adding = True
         event.save()
@@ -58,7 +56,5 @@ class Migration(migrations.Migration):
             model_name="event",
             name="user",
         ),
-        migrations.RenameField(
-            model_name="event", old_name="user_json", new_name="user"
-        ),
+        migrations.RenameField(model_name="event", old_name="user_json", new_name="user"),
     ]


@@ -11,16 +11,12 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
     db_alias = schema_editor.connection.alias
     Group = apps.get_model("authentik_core", "Group")
     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
-    EventMatcherPolicy = apps.get_model(
-        "authentik_policies_event_matcher", "EventMatcherPolicy"
-    )
+    EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
     NotificationRule = apps.get_model("authentik_events", "NotificationRule")
     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")

     admin_group = (
-        Group.objects.using(db_alias)
-        .filter(name="authentik Admins", is_superuser=True)
-        .first()
+        Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
     )

     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
@@ -32,9 +28,7 @@ def notify_configuration_error(apps: Apps, schema_editor: BaseDatabaseSchemaEdit
         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
     )
     trigger.transports.set(
-        NotificationTransport.objects.using(db_alias).filter(
-            name="default-email-transport"
-        )
+        NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
     )
     trigger.save()
     PolicyBinding.objects.using(db_alias).update_or_create(
@@ -50,16 +44,12 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
     db_alias = schema_editor.connection.alias
     Group = apps.get_model("authentik_core", "Group")
     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
-    EventMatcherPolicy = apps.get_model(
-        "authentik_policies_event_matcher", "EventMatcherPolicy"
-    )
+    EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
     NotificationRule = apps.get_model("authentik_events", "NotificationRule")
     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")

     admin_group = (
-        Group.objects.using(db_alias)
-        .filter(name="authentik Admins", is_superuser=True)
-        .first()
+        Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
     )

     policy, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
@@ -71,9 +61,7 @@ def notify_update(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
     )
     trigger.transports.set(
-        NotificationTransport.objects.using(db_alias).filter(
-            name="default-email-transport"
-        )
+        NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
    )
     trigger.save()
     PolicyBinding.objects.using(db_alias).update_or_create(
@@ -89,16 +77,12 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
     db_alias = schema_editor.connection.alias
     Group = apps.get_model("authentik_core", "Group")
     PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
-    EventMatcherPolicy = apps.get_model(
-        "authentik_policies_event_matcher", "EventMatcherPolicy"
-    )
+    EventMatcherPolicy = apps.get_model("authentik_policies_event_matcher", "EventMatcherPolicy")
     NotificationRule = apps.get_model("authentik_events", "NotificationRule")
     NotificationTransport = apps.get_model("authentik_events", "NotificationTransport")

     admin_group = (
-        Group.objects.using(db_alias)
-        .filter(name="authentik Admins", is_superuser=True)
-        .first()
+        Group.objects.using(db_alias).filter(name="authentik Admins", is_superuser=True).first()
     )

     policy_policy_exc, _ = EventMatcherPolicy.objects.using(db_alias).update_or_create(
@@ -114,9 +98,7 @@ def notify_exception(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
         defaults={"group": admin_group, "severity": NotificationSeverity.ALERT},
     )
     trigger.transports.set(
-        NotificationTransport.objects.using(db_alias).filter(
-            name="default-email-transport"
-        )
+        NotificationTransport.objects.using(db_alias).filter(name="default-email-transport")
     )
     trigger.save()
     PolicyBinding.objects.using(db_alias).update_or_create(

View File

@@ -38,9 +38,7 @@ def progress_bar(
def print_progress_bar(iteration):
"""Progress Bar Printing Function"""
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
filledLength = int(length * iteration // total)
bar = fill * filledLength + "-" * (length - filledLength)
print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)
@@ -78,9 +76,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="event",
name="expires",
field=models.DateTimeField(default=authentik.events.models.default_event_duration),
),
migrations.AddField(
model_name="event",

View File

@@ -15,9 +15,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="event",
name="tenant",
field=models.JSONField(blank=True, default=authentik.events.models.default_tenant),
),
migrations.AlterField(
model_name="event",

View File

@@ -15,10 +15,7 @@ from requests import RequestException, post
from structlog.stdlib import get_logger
from authentik import __version__
from authentik.core.middleware import SESSION_IMPERSONATE_ORIGINAL_USER, SESSION_IMPERSONATE_USER
from authentik.core.models import ExpiringModel, Group, User
from authentik.events.geo import GEOIP_READER
from authentik.events.utils import cleanse_dict, get_user, model_to_dict, sanitize_dict
@@ -159,9 +156,7 @@ class Event(ExpiringModel):
if hasattr(request, "user"):
original_user = None
if hasattr(request, "session"):
original_user = request.session.get(SESSION_IMPERSONATE_ORIGINAL_USER, None)
self.user = get_user(request.user, original_user)
if user:
self.user = get_user(user)
@@ -169,9 +164,7 @@ class Event(ExpiringModel):
if hasattr(request, "session"):
if SESSION_IMPERSONATE_ORIGINAL_USER in request.session:
self.user = get_user(request.session[SESSION_IMPERSONATE_ORIGINAL_USER])
self.user["on_behalf_of"] = get_user(request.session[SESSION_IMPERSONATE_USER])
# User 255.255.255.255 as fallback if IP cannot be determined
self.client_ip = get_client_ip(request)
# Apply GeoIP Data, when enabled
@@ -414,9 +407,7 @@ class NotificationRule(PolicyBindingModel):
severity = models.TextField(
choices=NotificationSeverity.choices,
default=NotificationSeverity.NOTICE,
help_text=_("Controls which severity level the created notifications will have."),
)
group = models.ForeignKey(
Group,

View File

@@ -135,9 +135,7 @@ class MonitoredTask(Task):
self._result = result
# pylint: disable=too-many-arguments
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
if self._result:
if not self._result.uid:
self._result.uid = self._uid
@@ -159,9 +157,7 @@ class MonitoredTask(Task):
# pylint: disable=too-many-arguments
def on_failure(self, exc, task_id, args, kwargs, einfo):
if not self._result:
self._result = TaskResult(status=TaskResultStatus.ERROR, messages=[str(exc)])
if not self._result.uid:
self._result.uid = self._uid
TaskInfo(
@@ -179,8 +175,7 @@ class MonitoredTask(Task):
Event.new(
EventAction.SYSTEM_TASK_EXCEPTION,
message=(
f"Task {self.__name__} encountered an error: " "\n".join(self._result.messages)
),
).save()
return super().on_failure(exc, task_id, args, kwargs, einfo=einfo)

View File

@@ -2,11 +2,7 @@
from threading import Thread
from typing import Any, Optional
from django.contrib.auth.signals import user_logged_in, user_logged_out, user_login_failed
from django.db.models.signals import post_save
from django.dispatch import receiver
from django.http import HttpRequest
@@ -30,9 +26,7 @@ class EventNewThread(Thread):
kwargs: dict[str, Any]
user: Optional[User] = None
def __init__(self, action: str, request: HttpRequest, user: Optional[User] = None, **kwargs):
super().__init__()
self.action = action
self.request = request
@@ -68,9 +62,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
@receiver(user_write)
# pylint: disable=unused-argument
def on_user_write(sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs):
"""Log User write"""
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
thread.kwargs["created"] = kwargs.get("created", False)
@@ -80,9 +72,7 @@ def on_user_write(
@receiver(user_login_failed)
# pylint: disable=unused-argument
def on_user_login_failed(sender, credentials: dict[str, str], request: HttpRequest, **_):
"""Failed Login"""
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
thread.run()

View File

@@ -22,9 +22,7 @@ LOGGER = get_logger()
def event_notification_handler(event_uuid: str):
"""Start task for each trigger definition"""
for trigger in NotificationRule.objects.all():
event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events")
@CELERY_APP.task()
@@ -43,17 +41,13 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
if "policy_uuid" in event.context:
policy_uuid = event.context["policy_uuid"]
if PolicyBinding.objects.filter(
target__in=NotificationRule.objects.all().values_list("pbm_uuid", flat=True),
policy=policy_uuid,
).exists():
# If policy that caused this event to be created is attached
# to *any* NotificationRule, we return early.
# This is the most effective way to prevent infinite loops.
LOGGER.debug("e(trigger): attempting to prevent infinite loop", trigger=trigger)
return
if not trigger.group:
@@ -62,9 +56,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
LOGGER.debug("e(trigger): checking if trigger applies", trigger=trigger)
try:
user = User.objects.filter(pk=event.user.get("pk")).first() or get_anonymous_user()
except User.DoesNotExist:
LOGGER.warning("e(trigger): failed to get user", trigger=trigger)
return
@@ -99,20 +91,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
retry_backoff=True,
base=MonitoredTask,
)
def notification_transport(self: MonitoredTask, notification_pk: int, transport_pk: int):
"""Send notification over specified transport"""
self.save_on_success = False
try:
notification: Notification = Notification.objects.filter(pk=notification_pk).first()
if not notification:
return
transport: NotificationTransport = NotificationTransport.objects.get(pk=transport_pk)
transport.send(notification)
self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL))
except NotificationTransportError as exc:

View File

@@ -38,7 +38,5 @@ class TestEvents(TestCase):
event = Event.new("unittest", model=temp_model)
event.save()  # We save to ensure nothing is un-saveable
model_content_type = ContentType.objects.get_for_model(temp_model)
self.assertEqual(event.context.get("model").get("app"), model_content_type.app_label)
self.assertEqual(event.context.get("model").get("pk"), temp_model.pk.hex)

View File

@@ -81,12 +81,8 @@ class TestEventsNotifications(TestCase):
execute_mock = MagicMock()
passes = MagicMock(side_effect=PolicyException)
with patch("authentik.policies.event_matcher.models.EventMatcherPolicy.passes", passes):
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(passes.call_count, 1)
@@ -96,9 +92,7 @@ class TestEventsNotifications(TestCase):
self.group.users.add(user2)
self.group.save()
transport = NotificationTransport.objects.create(name="transport", send_once=True)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name="trigger", group=self.group)
trigger.transports.add(transport)

View File

@@ -14,12 +14,7 @@ from rest_framework.fields import BooleanField, FileField, ReadOnlyField
from rest_framework.parsers import MultiPartParser
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import CharField, ModelSerializer, Serializer, SerializerMethodField
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
@@ -152,11 +147,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
],
)
@extend_schema(
request={"multipart/form-data": inline_serializer("SetIcon", fields={"file": FileField()})},
responses={
204: OpenApiResponse(description="Successfully imported flow"),
400: OpenApiResponse(description="Bad request"),
@@ -221,9 +212,7 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
.order_by("order")
):
for p_index, policy_binding in enumerate(
get_objects_for_user(request.user, "authentik_policies.view_policybinding")
.filter(target=stage_binding)
.exclude(policy__isnull=True)
.order_by("order")
@@ -256,20 +245,14 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
element: DiagramElement = body[index]
if element.type == "condition":
# Policy passes, link policy yes to next stage
footer.append(f"{element.identifier}(yes, right)->{body[index + 1].identifier}")
# Policy doesn't pass, go to stage after next stage
no_element = body[index + 1]
if no_element.type != "end":
no_element = body[index + 2]
footer.append(f"{element.identifier}(no, bottom)->{no_element.identifier}")
elif element.type == "operation":
footer.append(f"{element.identifier}(bottom)->{body[index + 1].identifier}")
diagram = "\n".join([str(x) for x in header + body + footer])
return Response({"diagram": diagram})

View File

@@ -95,9 +95,7 @@ class Command(BaseCommand):  # pragma: no cover
"""Output results human readable"""
total_max: int = max([max(inner) for inner in values])
total_min: int = min([min(inner) for inner in values])
total_avg = sum([sum(inner) for inner in values]) / sum([len(inner) for inner in values])
print(f"Version: {__version__}")
print(f"Processes: {len(values)}")

View File

@@ -9,21 +9,15 @@ from authentik.stages.identification.models import UserFields
from authentik.stages.password import BACKEND_DJANGO, BACKEND_LDAP
def create_default_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
PasswordStage = apps.get_model("authentik_stages_password", "PasswordStage")
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")
IdentificationStage = apps.get_model("authentik_stages_identification", "IdentificationStage")
db_alias = schema_editor.connection.alias
identification_stage, _ = IdentificationStage.objects.using(db_alias).update_or_create(
name="default-authentication-identification",
defaults={
"user_fields": [UserFields.E_MAIL, UserFields.USERNAME],
@@ -69,17 +63,13 @@ def create_default_authentication_flow(
)
def create_default_invalidation_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
UserLogoutStage = apps.get_model("authentik_stages_user_logout", "UserLogoutStage")
db_alias = schema_editor.connection.alias
UserLogoutStage.objects.using(db_alias).update_or_create(name="default-invalidation-logout")
flow, _ = Flow.objects.using(db_alias).update_or_create(
slug="default-invalidation-flow",

View File

@@ -15,16 +15,12 @@ PROMPT_POLICY_EXPRESSION = """# Check if we've not been given a username by the
return 'username' not in context.get('prompt_data', {})"""
def create_default_source_enrollment_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
@@ -99,16 +95,12 @@ def create_default_source_enrollment_flow(
)
def create_default_source_authentication_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")
PolicyBinding = apps.get_model("authentik_policies", "PolicyBinding")
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
UserLoginStage = apps.get_model("authentik_stages_user_login", "UserLoginStage")

View File

@@ -7,9 +7,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from authentik.flows.models import FlowDesignation
def create_default_provider_authorization_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
Flow = apps.get_model("authentik_flows", "Flow")
FlowStageBinding = apps.get_model("authentik_flows", "FlowStageBinding")

View File

@@ -32,9 +32,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
PromptStage = apps.get_model("authentik_stages_prompt", "PromptStage")
Prompt = apps.get_model("authentik_stages_prompt", "Prompt")
ExpressionPolicy = apps.get_model("authentik_policies_expression", "ExpressionPolicy")
db_alias = schema_editor.connection.alias
@@ -52,9 +50,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
name="default-oobe-prefill-user",
defaults={"expression": PREFILL_POLICY_EXPRESSION},
)
password_usable_policy, _ = ExpressionPolicy.objects.using(db_alias).update_or_create(
name="default-oobe-password-usable",
defaults={"expression": PW_USABLE_POLICY_EXPRESSION},
)
@@ -83,9 +79,7 @@ def create_default_oobe_flow(apps: Apps, schema_editor: BaseDatabaseSchemaEditor
prompt_stage, _ = PromptStage.objects.using(db_alias).update_or_create(
name="default-oobe-password",
)
prompt_stage.fields.set([prompt_header, prompt_email, password_first, password_second])
prompt_stage.save()
user_write, _ = UserWriteStage.objects.using(db_alias).update_or_create(

View File

@@ -138,9 +138,7 @@ class Flow(SerializerModel, PolicyBindingModel):
it is returned as-is"""
if not self.background:
return "/static/dist/assets/images/flow_background.jpg"
if self.background.name.startswith("http") or self.background.name.startswith("/static"):
return self.background.name
return self.background.url
@@ -165,9 +163,7 @@ class Flow(SerializerModel, PolicyBindingModel):
if result.passing:
LOGGER.debug("with_policy: flow passing", flow=flow)
return flow
LOGGER.warning("with_policy: flow not passing", flow=flow, messages=result.messages)
LOGGER.debug("with_policy: no flow found", filters=flow_filter)
return None

View File

@@ -78,14 +78,10 @@ class FlowPlan:
marker = self.markers[0]
if marker.__class__ is not StageMarker:
LOGGER.debug("f(plan_inst): stage has marker", binding=binding, marker=marker)
marked_stage = marker.process(self, binding, http_request)
if not marked_stage:
LOGGER.debug("f(plan_inst): marker returned none, next stage", binding=binding)
self.bindings.remove(binding)
self.markers.remove(marker)
if not self.has_stages:
@@ -193,9 +189,9 @@ class FlowPlanner:
if default_context:
plan.context = default_context
# Check Flow policies
for binding in FlowStageBinding.objects.filter(target__pk=self.flow.pk).order_by(
"order"
):
binding: FlowStageBinding
stage = binding.stage
marker = StageMarker()

View File

@@ -26,9 +26,7 @@ def invalidate_flow_cache(sender, instance, **_):
LOGGER.debug("Invalidating Flow cache", flow=instance, len=total)
if isinstance(instance, FlowStageBinding):
total = delete_cache_prefix(f"{cache_key(instance.target)}*")
LOGGER.debug("Invalidating Flow cache from FlowStageBinding", binding=instance, len=total)
if isinstance(instance, Stage):
total = 0
for binding in FlowStageBinding.objects.filter(stage=instance):

View File

@@ -42,14 +42,9 @@ class StageView(View):
other things besides the form display.
If no user is pending, returns request.user"""
if PLAN_CONTEXT_PENDING_USER_IDENTIFIER in self.executor.plan.context and for_display:
return User(
username=self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER_IDENTIFIER),
email="",
)
if PLAN_CONTEXT_PENDING_USER in self.executor.plan.context:

View File

@@ -89,14 +89,10 @@ class TestFlowPlanner(TestCase):
planner = FlowPlanner(flow)
planner.plan(request)
self.assertEqual(CACHE_MOCK.set.call_count, 1)  # Ensure plan is written to cache
planner = FlowPlanner(flow)
planner.plan(request)
self.assertEqual(CACHE_MOCK.set.call_count, 1)  # Ensure nothing is written to cache
self.assertEqual(CACHE_MOCK.get.call_count, 2)  # Get is called twice
def test_planner_default_context(self):
@@ -176,9 +172,7 @@ class TestFlowPlanner(TestCase):
request.session.save()
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
planner = FlowPlanner(flow)
plan = planner.plan(request)

View File

@@ -76,9 +76,7 @@ class TestFlowTransfer(TransactionTestCase):
PolicyBinding.objects.create(policy=flow_policy, target=flow, order=0)
user_login = UserLoginStage.objects.create(name=stage_name)
fsb = FlowStageBinding.objects.create(target=flow, stage=user_login, order=0)
PolicyBinding.objects.create(policy=flow_policy, target=fsb, order=0)
exporter = FlowExporter(flow)

View File

@@ -11,12 +11,7 @@ from authentik.core.models import User
from authentik.flows.challenge import ChallengeTypes
from authentik.flows.exceptions import FlowNonApplicableException
from authentik.flows.markers import ReevaluateMarker, StageMarker
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, InvalidResponseAction
from authentik.flows.planner import FlowPlan, FlowPlanner
from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView
@@ -61,9 +56,7 @@ class TestFlowExecutor(TestCase):
)
stage = DummyStage.objects.create(name="dummy")
binding = FlowStageBinding(target=flow, stage=stage, order=0)
plan = FlowPlan(flow_pk=flow.pk.hex + "a", bindings=[binding], markers=[StageMarker()])
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
@@ -163,9 +156,7 @@ class TestFlowExecutor(TestCase):
target=flow, stage=DummyStage.objects.create(name="dummy2"), order=1
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First Request, start planning, renders form
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@@ -209,13 +200,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@@ -263,13 +250,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=false_policy, target=binding2, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
@@ -334,13 +317,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=true_policy, target=binding2, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
@@ -422,13 +401,9 @@ class TestFlowExecutor(TestCase):
PolicyBinding.objects.create(policy=false_policy, target=binding3, order=0)
# Here we patch the dummy policy to evaluate to true so the stage is included
with patch("authentik.policies.dummy.models.DummyPolicy.passes", POLICY_RETURN_TRUE):
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@@ -511,9 +486,7 @@ class TestFlowExecutor(TestCase):
)
request.user = user
planner = FlowPlanner(flow)
plan = planner.plan(request, default_context={PLAN_CONTEXT_PENDING_USER_IDENTIFIER: ident})
executor = FlowExecutorView()
executor.plan = plan
@@ -542,9 +515,7 @@ class TestFlowExecutor(TestCase):
evaluate_on_plan=False,
re_evaluate_policies=True,
)
PolicyBinding.objects.create(policy=reputation_policy, target=deny_binding, order=0)
# Stage 1 is an identification stage
ident_stage = IdentificationStage.objects.create(
@@ -557,9 +528,7 @@ class TestFlowExecutor(TestCase):
order=1,
invalid_response_action=InvalidResponseAction.RESTART_WITH_CONTEXT,
)
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
# First request, run the planner
response = self.client.get(exec_url)
self.assertEqual(response.status_code, 200)
@@ -579,9 +548,7 @@ class TestFlowExecutor(TestCase):
"user_fields": [UserFields.E_MAIL],
},
)
response = self.client.post(exec_url, {"uid_field": "invalid-string"}, follow=True)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(
force_str(response.content),

View File

@@ -21,9 +21,7 @@ class TestHelperView(TestCase):
response = self.client.get(
reverse("authentik_flows:default-invalidation"),
)
expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)
@@ -40,8 +38,6 @@ class TestHelperView(TestCase):
response = self.client.get(
reverse("authentik_flows:default-invalidation"),
)
expected_url = reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, expected_url)

View File

@@ -44,9 +44,7 @@ class FlowBundleEntry:
attrs: dict[str, Any]
@staticmethod
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "FlowBundleEntry":
"""Convert a SerializerModel instance to a Bundle Entry"""
identifiers = {
"pk": model.pk,

View File

@@ -6,11 +6,7 @@ from uuid import UUID
from django.db.models import Q
from authentik.flows.models import Flow, FlowStageBinding, Stage
from authentik.flows.transfer.common import DataclassEncoder, FlowBundle, FlowBundleEntry
from authentik.policies.models import Policy, PolicyBinding
from authentik.stages.prompt.models import PromptStage
@@ -37,9 +33,7 @@ class FlowExporter:
def walk_stages(self) -> Iterator[FlowBundleEntry]:
"""Convert all stages attached to self.flow into FlowBundleEntry objects"""
stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
for stage in stages:
if isinstance(stage, PromptStage):
pass
@@ -56,9 +50,7 @@ class FlowExporter:
a direct foreign key to a policy."""
# Special case for PromptStage as that has a direct M2M to policy, we have to ensure
# all policies referenced in there we also include here
prompt_stages = PromptStage.objects.filter(flow=self.flow).values_list("pk", flat=True)
query = Q(bindings__in=self.pbm_uuids) | Q(promptstage__in=prompt_stages)
policies = Policy.objects.filter(query).select_related()
for policy in policies:
@@ -67,9 +59,7 @@ class FlowExporter:
def walk_policy_bindings(self) -> Iterator[FlowBundleEntry]:
"""Walk over all policybindings relative to us. This is run at the end of the export, as
we are sure all objects exist now."""
bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related()
for binding in bindings:
yield FlowBundleEntry.from_model(binding, "policy", "target", "order")

View File

@@ -16,11 +16,7 @@ from rest_framework.serializers import BaseSerializer, Serializer
from structlog.stdlib import BoundLogger, get_logger
from authentik.flows.models import Flow, FlowStageBinding, Stage
from authentik.flows.transfer.common import EntryInvalidError, FlowBundle, FlowBundleEntry
from authentik.lib.models import SerializerModel
from authentik.policies.models import Policy, PolicyBinding
from authentik.stages.prompt.models import Prompt
@@ -105,9 +101,7 @@ class FlowImporter:
if isinstance(value, dict) and "pk" in value:
del updated_identifiers[key]
updated_identifiers[f"{key}"] = value["pk"]
existing_models = model.objects.filter(self.__query_from_identifier(updated_identifiers))
serializer_kwargs = {}
if existing_models.exists():
@@ -120,9 +114,7 @@ class FlowImporter:
)
serializer_kwargs["instance"] = model_instance
else:
self.logger.debug("initialise new instance", model=model, **updated_identifiers)
full_data = self.__update_pks_for_attrs(entry.attrs)
full_data.update(updated_identifiers)
serializer_kwargs["data"] = full_data

View File

@@ -38,13 +38,7 @@ from authentik.flows.challenge import (
WithUserInfoChallenge,
)
from authentik.flows.exceptions import EmptyFlowException, FlowNonApplicableException
from authentik.flows.models import ConfigurableStage, Flow, FlowDesignation, FlowStageBinding, Stage
from authentik.flows.planner import (
PLAN_CONTEXT_PENDING_USER,
PLAN_CONTEXT_REDIRECT,
@@ -155,9 +149,7 @@ class FlowExecutorView(APIView):
try:
self.plan = self._initiate_plan()
except FlowNonApplicableException as exc:
self._logger.warning("f(exec): Flow not applicable to current user", exc=exc)
return to_stage_response(self.request, self.handle_invalid_flow(exc))
except EmptyFlowException as exc:
self._logger.warning("f(exec): Flow is empty", exc=exc)
@@ -174,9 +166,7 @@ class FlowExecutorView(APIView):
# in which case we just delete the plan and invalidate everything
next_binding = self.plan.next(self.request)
except Exception as exc:  # pylint: disable=broad-except
self._logger.warning("f(exec): found incompatible flow plan, invalidating run", exc=exc)
keys = cache.keys("flow_*")
cache.delete_many(keys)
return self.stage_invalid()
@@ -314,9 +304,7 @@ class FlowExecutorView(APIView):
self.request.session[SESSION_KEY_PLAN] = plan
kwargs = self.kwargs
kwargs.update({"flow_slug": self.flow.slug})
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
def _flow_done(self) -> HttpResponse:
"""User Successfully passed all stages"""
@@ -350,9 +338,7 @@ class FlowExecutorView(APIView):
)
kwargs = self.kwargs
kwargs.update({"flow_slug": self.flow.slug})
return redirect_with_qs("authentik_api:flow-executor", self.request.GET, **kwargs)
# User passed all stages
self._logger.debug(
"f(exec): User passed all stages",
@@ -408,18 +394,13 @@ class FlowErrorResponse(TemplateResponse):
super().__init__(request=request, template="flows/error.html")
self.error = error
def resolve_context(self, context: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
if not context:
context = {}
context["error"] = self.error
if self._request.user and self._request.user.is_authenticated:
if self._request.user.is_superuser or self._request.user.group_attributes().get(
USER_ATTRIBUTE_DEBUG, False
):
context["tb"] = "".join(format_tb(self.error.__traceback__))
return context
@@ -464,9 +445,7 @@ class ToDefaultFlow(View):
flow_slug=flow.slug,
)
del self.request.session[SESSION_KEY_PLAN]
return redirect_with_qs("authentik_core:if-flow", request.GET, flow_slug=flow.slug)
def to_stage_response(request: HttpRequest, source: HttpResponse) -> HttpResponse:

View File

@@ -115,9 +115,7 @@ class ConfigLoader:
for key, value in os.environ.items():
if not key.startswith(ENV_PREFIX):
continue
relative_key = key.replace(f"{ENV_PREFIX}_", "", 1).replace("__", ".").lower()
# Recursively convert path from a.b.c into outer[a][b][c]
current_obj = outer
dot_parts = relative_key.split(".")

View File

@@ -37,15 +37,11 @@ class InheritanceAutoManager(InheritanceManager):
return super().get_queryset().select_subclasses()
class InheritanceForwardManyToOneDescriptor(models.fields.related.ForwardManyToOneDescriptor):
"""Forward ManyToOne Descriptor that selects subclass. Requires InheritanceAutoManager."""
def get_queryset(self, **hints):
return self.field.remote_field.model.objects.db_manager(hints=hints).select_subclasses()
class InheritanceForeignKey(models.ForeignKey):

View File

@@ -8,11 +8,7 @@ from botocore.exceptions import BotoCoreError
from celery.exceptions import CeleryError
from channels.middleware import BaseMiddleware
from channels_redis.core import ChannelFull
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
from django.db import InternalError, OperationalError, ProgrammingError
from django.http.response import Http404
from django_redis.exceptions import ConnectionInterrupted

View File

@@ -26,7 +26,5 @@ class TestEvaluator(TestCase):
def test_is_group_member(self):
"""Test expr_is_group_member"""
self.assertFalse(
BaseEvaluator.expr_is_group_member(User.objects.get(username="akadmin"), name="test")
)

View File

@@ -1,17 +1,8 @@
"""Test HTTP Helpers"""
from django.test import RequestFactory, TestCase
from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents, User
from authentik.lib.utils.http import OUTPOST_REMOTE_IP_HEADER, OUTPOST_TOKEN_HEADER, get_client_ip
class TestHTTP(TestCase):

View File

@@ -9,9 +9,7 @@ class TestSentry(TestCase):
 
     def test_error_not_sent(self):
         """Test SentryIgnoredError not sent"""
-        self.assertIsNone(
-            before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})
-        )
+        self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
 
     def test_error_sent(self):
         """Test error sent"""

View File

@@ -29,16 +29,9 @@ def _get_outpost_override_ip(request: HttpRequest) -> Optional[str]:
     """Get the actual remote IP when set by an outpost. Only
     allowed when the request is authenticated, by a user with USER_ATTRIBUTE_CAN_OVERRIDE_IP set
     to outpost"""
-    from authentik.core.models import (
-        USER_ATTRIBUTE_CAN_OVERRIDE_IP,
-        Token,
-        TokenIntents,
-    )
+    from authentik.core.models import USER_ATTRIBUTE_CAN_OVERRIDE_IP, Token, TokenIntents
 
-    if (
-        OUTPOST_REMOTE_IP_HEADER not in request.META
-        or OUTPOST_TOKEN_HEADER not in request.META
-    ):
+    if OUTPOST_REMOTE_IP_HEADER not in request.META or OUTPOST_TOKEN_HEADER not in request.META:
         return None
     fake_ip = request.META[OUTPOST_REMOTE_IP_HEADER]
     tokens = Token.filter_not_expired(

View File

@@ -12,9 +12,7 @@ def managed_reconcile(self: MonitoredTask):
     try:
         ObjectManager().run()
         self.set_status(
-            TaskResult(
-                TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."]
-            )
+            TaskResult(TaskResultStatus.SUCCESSFUL, ["Successfully updated managed models."])
         )
     except DatabaseError as exc:
         self.set_status(TaskResult(TaskResultStatus.WARNING, [str(exc)]))

View File

@@ -15,12 +15,7 @@ from authentik.core.api.used_by import UsedByMixin
 from authentik.core.api.utils import PassiveSerializer, is_dict
 from authentik.core.models import Provider
 from authentik.outposts.api.service_connections import ServiceConnectionSerializer
-from authentik.outposts.models import (
-    Outpost,
-    OutpostConfig,
-    OutpostType,
-    default_outpost_config,
-)
+from authentik.outposts.models import Outpost, OutpostConfig, OutpostType, default_outpost_config
 from authentik.providers.ldap.models import LDAPProvider
 from authentik.providers.proxy.models import ProxyProvider
 

View File

@@ -15,11 +15,7 @@ from rest_framework.serializers import ModelSerializer
 from rest_framework.viewsets import GenericViewSet, ModelViewSet
 
 from authentik.core.api.used_by import UsedByMixin
-from authentik.core.api.utils import (
-    MetaNameSerializer,
-    PassiveSerializer,
-    TypeCreateSerializer,
-)
+from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer
 from authentik.lib.utils.reflection import all_subclasses
 from authentik.outposts.models import (
     DockerServiceConnection,
@@ -129,9 +125,7 @@ class KubernetesServiceConnectionSerializer(ServiceConnectionSerializer):
         if kubeconfig == {}:
             if not self.initial_data["local"]:
                 raise serializers.ValidationError(
-                    _(
-                        "You can only use an empty kubeconfig when connecting to a local cluster."
-                    )
+                    _("You can only use an empty kubeconfig when connecting to a local cluster.")
                 )
             # Empty kubeconfig is valid
             return kubeconfig

View File

@@ -59,9 +59,7 @@ class OutpostConsumer(AuthJsonConsumer):
     def connect(self):
         super().connect()
         uuid = self.scope["url_route"]["kwargs"]["pk"]
-        outpost = get_objects_for_user(
-            self.user, "authentik_outposts.view_outpost"
-        ).filter(pk=uuid)
+        outpost = get_objects_for_user(self.user, "authentik_outposts.view_outpost").filter(pk=uuid)
         if not outpost.exists():
             raise DenyConnection()
         self.accept()
@@ -129,7 +127,5 @@ class OutpostConsumer(AuthJsonConsumer):
     def event_update(self, event):
         """Event handler which is called by post_save signals, Send update instruction"""
         self.send_json(
-            asdict(
-                WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)
-            )
+            asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
         )

View File

@@ -9,11 +9,7 @@ from yaml import safe_dump
 
 from authentik import __version__
 from authentik.outposts.controllers.base import BaseController, ControllerException
-from authentik.outposts.models import (
-    DockerServiceConnection,
-    Outpost,
-    ServiceConnectionInvalid,
-)
+from authentik.outposts.models import DockerServiceConnection, Outpost, ServiceConnectionInvalid
 
 
 class DockerController(BaseController):
@@ -37,9 +33,7 @@ class DockerController(BaseController):
     def _get_env(self) -> dict[str, str]:
         return {
             "AUTHENTIK_HOST": self.outpost.config.authentik_host.lower(),
-            "AUTHENTIK_INSECURE": str(
-                self.outpost.config.authentik_host_insecure
-            ).lower(),
+            "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure).lower(),
             "AUTHENTIK_TOKEN": self.outpost.token.key,
         }
 
@@ -141,9 +135,7 @@ class DockerController(BaseController):
                 .lower()
                 != "unless-stopped"
             ):
-                self.logger.info(
-                    "Container has mis-matched restart policy, re-creating..."
-                )
+                self.logger.info("Container has mis-matched restart policy, re-creating...")
                 self.down()
                 return self.up()
             # Check that container is healthy
@@ -157,9 +149,7 @@ class DockerController(BaseController):
                 if has_been_created:
                     # Since we've just created the container, give it some time to start.
                     # If its still not up by then, restart it
-                    self.logger.info(
-                        "Container is unhealthy and new, giving it time to boot."
-                    )
+                    self.logger.info("Container is unhealthy and new, giving it time to boot.")
                     sleep(60)
                 self.logger.info("Container is unhealthy, restarting...")
                 container.restart()
@@ -198,9 +188,7 @@ class DockerController(BaseController):
                     "ports": ports,
                     "environment": {
                         "AUTHENTIK_HOST": self.outpost.config.authentik_host,
-                        "AUTHENTIK_INSECURE": str(
-                            self.outpost.config.authentik_host_insecure
-                        ),
+                        "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
                         "AUTHENTIK_TOKEN": self.outpost.token.key,
                     },
                     "labels": self._get_labels(),

View File

@@ -17,10 +17,7 @@ from kubernetes.client import (
 )
 
 from authentik.outposts.controllers.base import FIELD_MANAGER
-from authentik.outposts.controllers.k8s.base import (
-    KubernetesObjectReconciler,
-    NeedsUpdate,
-)
+from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
 from authentik.outposts.models import Outpost
 
 if TYPE_CHECKING:
@@ -124,9 +121,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
         )
 
     def delete(self, reference: V1Deployment):
-        return self.api.delete_namespaced_deployment(
-            reference.metadata.name, self.namespace
-        )
+        return self.api.delete_namespaced_deployment(reference.metadata.name, self.namespace)
 
     def retrieve(self) -> V1Deployment:
         return self.api.read_namespaced_deployment(self.name, self.namespace)

View File

@@ -5,10 +5,7 @@ from typing import TYPE_CHECKING
 from kubernetes.client import CoreV1Api, V1Secret
 
 from authentik.outposts.controllers.base import FIELD_MANAGER
-from authentik.outposts.controllers.k8s.base import (
-    KubernetesObjectReconciler,
-    NeedsUpdate,
-)
+from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
 
 if TYPE_CHECKING:
     from authentik.outposts.controllers.kubernetes import KubernetesController
@@ -38,9 +35,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
         return V1Secret(
             metadata=meta,
             data={
-                "authentik_host": b64string(
-                    self.controller.outpost.config.authentik_host
-                ),
+                "authentik_host": b64string(self.controller.outpost.config.authentik_host),
                 "authentik_host_insecure": b64string(
                     str(self.controller.outpost.config.authentik_host_insecure)
                 ),
@@ -54,9 +49,7 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
         )
 
     def delete(self, reference: V1Secret):
-        return self.api.delete_namespaced_secret(
-            reference.metadata.name, self.namespace
-        )
+        return self.api.delete_namespaced_secret(reference.metadata.name, self.namespace)
 
     def retrieve(self) -> V1Secret:
         return self.api.read_namespaced_secret(self.name, self.namespace)

View File

@@ -4,10 +4,7 @@ from typing import TYPE_CHECKING
 from kubernetes.client import CoreV1Api, V1Service, V1ServicePort, V1ServiceSpec
 
 from authentik.outposts.controllers.base import FIELD_MANAGER
-from authentik.outposts.controllers.k8s.base import (
-    KubernetesObjectReconciler,
-    NeedsUpdate,
-)
+from authentik.outposts.controllers.k8s.base import KubernetesObjectReconciler, NeedsUpdate
 from authentik.outposts.controllers.k8s.deployment import DeploymentReconciler
 
 if TYPE_CHECKING:
@@ -58,9 +55,7 @@ class ServiceReconciler(KubernetesObjectReconciler[V1Service]):
         )
 
     def delete(self, reference: V1Service):
-        return self.api.delete_namespaced_service(
-            reference.metadata.name, self.namespace
-        )
+        return self.api.delete_namespaced_service(reference.metadata.name, self.namespace)
 
     def retrieve(self) -> V1Service:
         return self.api.read_namespaced_service(self.name, self.namespace)

View File

@@ -24,9 +24,7 @@ class KubernetesController(BaseController):
     client: ApiClient
     connection: KubernetesServiceConnection
 
-    def __init__(
-        self, outpost: Outpost, connection: KubernetesServiceConnection
-    ) -> None:
+    def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection) -> None:
         super().__init__(outpost, connection)
         self.client = connection.client()
         self.reconcilers = {

View File

@@ -15,9 +15,7 @@ class Migration(migrations.Migration):
         migrations.AddField(
             model_name="outpost",
             name="_config",
-            field=models.JSONField(
-                default=authentik.outposts.models.default_outpost_config
-            ),
+            field=models.JSONField(default=authentik.outposts.models.default_outpost_config),
         ),
         migrations.AddField(
             model_name="outpost",

View File

@@ -10,9 +10,7 @@ def fix_missing_token_identifier(apps: Apps, schema_editor: BaseDatabaseSchemaEd
     Token = apps.get_model("authentik_core", "Token")
     from authentik.outposts.models import Outpost
 
-    for outpost in (
-        Outpost.objects.using(schema_editor.connection.alias).all().only("pk")
-    ):
+    for outpost in Outpost.objects.using(schema_editor.connection.alias).all().only("pk"):
         user_identifier = outpost.user_identifier
         users = User.objects.filter(username=user_identifier)
         if not users.exists():

View File

@@ -14,9 +14,7 @@ import authentik.lib.models
 def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
     db_alias = schema_editor.connection.alias
     Outpost = apps.get_model("authentik_outposts", "Outpost")
-    DockerServiceConnection = apps.get_model(
-        "authentik_outposts", "DockerServiceConnection"
-    )
+    DockerServiceConnection = apps.get_model("authentik_outposts", "DockerServiceConnection")
     KubernetesServiceConnection = apps.get_model(
         "authentik_outposts", "KubernetesServiceConnection"
     )
@@ -25,9 +23,7 @@ def migrate_to_service_connection(apps: Apps, schema_editor: BaseDatabaseSchemaE
     k8s = KubernetesServiceConnection.objects.filter(local=True).first()
 
     try:
-        for outpost in (
-            Outpost.objects.using(db_alias).all().exclude(deployment_type="custom")
-        ):
+        for outpost in Outpost.objects.using(db_alias).all().exclude(deployment_type="custom"):
             if outpost.deployment_type == "kubernetes":
                 outpost.service_connection = k8s
             elif outpost.deployment_type == "docker":

View File

@@ -11,9 +11,7 @@ def remove_pb_prefix_users(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
     Outpost = apps.get_model("authentik_outposts", "Outpost")
 
     for outpost in Outpost.objects.using(alias).all():
-        matching = User.objects.using(alias).filter(
-            username=f"pb-outpost-{outpost.uuid.hex}"
-        )
+        matching = User.objects.using(alias).filter(username=f"pb-outpost-{outpost.uuid.hex}")
         if matching.exists():
             matching.delete()
 

View File

@@ -13,8 +13,6 @@ class Migration(migrations.Migration):
         migrations.AlterField(
            model_name="outpost",
            name="type",
-            field=models.TextField(
-                choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"
-            ),
+            field=models.TextField(choices=[("proxy", "Proxy"), ("ldap", "Ldap")], default="proxy"),
         ),
     ]

View File

@@ -64,9 +64,7 @@ class OutpostConfig:
 
     log_level: str = CONFIG.y("log_level")
     error_reporting_enabled: bool = CONFIG.y_bool("error_reporting.enabled")
-    error_reporting_environment: str = CONFIG.y(
-        "error_reporting.environment", "customer"
-    )
+    error_reporting_environment: str = CONFIG.y("error_reporting.environment", "customer")
 
     object_naming_template: str = field(default="ak-outpost-%(name)s")
     kubernetes_replicas: int = field(default=1)
@@ -264,9 +262,7 @@ class KubernetesServiceConnection(OutpostServiceConnection):
             client = self.client()
             api_instance = VersionApi(client)
             version: VersionInfo = api_instance.get_code()
-            return OutpostServiceConnectionState(
-                version=version.git_version, healthy=True
-            )
+            return OutpostServiceConnectionState(version=version.git_version, healthy=True)
         except (OpenApiException, HTTPError, ServiceConnectionInvalid):
             return OutpostServiceConnectionState(version="", healthy=False)
 
@@ -360,8 +356,7 @@ class Outpost(ManagedModel):
             if isinstance(model_or_perm, models.Model):
                 model_or_perm: models.Model
                 code_name = (
-                    f"{model_or_perm._meta.app_label}."
-                    f"view_{model_or_perm._meta.model_name}"
+                    f"{model_or_perm._meta.app_label}." f"view_{model_or_perm._meta.model_name}"
                 )
                 assign_perm(code_name, user, model_or_perm)
             else:
@@ -417,9 +412,7 @@ class Outpost(ManagedModel):
             self,
             "authentik_events.add_event",
         ]
-        for provider in (
-            Provider.objects.filter(outpost=self).select_related().select_subclasses()
-        ):
+        for provider in Provider.objects.filter(outpost=self).select_related().select_subclasses():
             if isinstance(provider, OutpostModel):
                 objects.extend(provider.get_required_objects())
             else:

View File

@@ -9,11 +9,7 @@ from authentik.core.models import Provider
 from authentik.crypto.models import CertificateKeyPair
 from authentik.lib.utils.reflection import class_to_path
 from authentik.outposts.models import Outpost, OutpostServiceConnection
-from authentik.outposts.tasks import (
-    CACHE_KEY_OUTPOST_DOWN,
-    outpost_controller,
-    outpost_post_save,
-)
+from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
 
 LOGGER = get_logger()
 UPDATE_TRIGGERING_MODELS = (
@@ -37,9 +33,7 @@ def pre_save_outpost(sender, instance: Outpost, **_):
     # Name changes the deployment name, need to recreate
     dirty += old_instance.name != instance.name
     # namespace requires re-create
-    dirty += (
-        old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace
-    )
+    dirty += old_instance.config.kubernetes_namespace != instance.config.kubernetes_namespace
     if bool(dirty):
         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)

View File

@@ -62,9 +62,7 @@ def controller_for_outpost(outpost: Outpost) -> Optional[BaseController]:
 def outpost_service_connection_state(connection_pk: Any):
     """Update cached state of a service connection"""
     connection: OutpostServiceConnection = (
-        OutpostServiceConnection.objects.filter(pk=connection_pk)
-        .select_subclasses()
-        .first()
+        OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first()
     )
     if not connection:
         return
@@ -157,9 +155,7 @@ def outpost_post_save(model_class: str, model_pk: Any):
         outpost_controller.delay(instance.pk)
 
     if isinstance(instance, (OutpostModel, Outpost)):
-        LOGGER.debug(
-            "triggering outpost update from outpostmodel/outpost", instance=instance
-        )
+        LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
         outpost_send_update(instance)
 
     if isinstance(instance, OutpostServiceConnection):
@@ -208,9 +204,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
         layer = get_channel_layer()
     for state in OutpostState.for_outpost(outpost):
         for channel in state.channel_ids:
-            LOGGER.debug(
-                "sending update", channel=channel, instance=state.uid, outpost=outpost
-            )
+            LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
             async_to_sync(layer.send)(channel, {"type": "event.update"})
 
 
@@ -231,9 +225,7 @@ def outpost_local_connection():
     if Path(kubeconfig_path).exists():
         LOGGER.debug("Detected kubeconfig")
         kubeconfig_local_name = f"k8s-{gethostname()}"
-        if not KubernetesServiceConnection.objects.filter(
-            name=kubeconfig_local_name
-        ).exists():
+        if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
             LOGGER.debug("Creating kubeconfig Service Connection")
             with open(kubeconfig_path, "r") as _kubeconfig:
                 KubernetesServiceConnection.objects.create(

View File

@@ -63,9 +63,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
         provider = ProxyProvider.objects.create(
             name="test", authorization_flow=Flow.objects.first()
         )
-        invalid = OutpostSerializer(
-            data={"name": "foo", "providers": [provider.pk], "config": {}}
-        )
+        invalid = OutpostSerializer(data={"name": "foo", "providers": [provider.pk], "config": {}})
         self.assertFalse(invalid.is_valid())
         self.assertIn("config", invalid.errors)
         valid = OutpostSerializer(

View File

@@ -2,11 +2,7 @@
 from typing import OrderedDict
 
 from django.core.exceptions import ObjectDoesNotExist
-from rest_framework.serializers import (
-    ModelSerializer,
-    PrimaryKeyRelatedField,
-    ValidationError,
-)
+from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError
 from rest_framework.viewsets import ModelViewSet
 from structlog.stdlib import get_logger
 

Some files were not shown because too many files have changed in this diff.