*: migrate ui_* properties to functions to allow context being passed
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
4f05dcec89
commit
e4841d54a1
|
@ -104,14 +104,14 @@ class SourceViewSet(
|
|||
)
|
||||
matching_sources: list[UserSettingSerializer] = []
|
||||
for source in _all_sources:
|
||||
user_settings = source.ui_user_settings
|
||||
user_settings = source.ui_user_settings()
|
||||
if not user_settings:
|
||||
continue
|
||||
policy_engine = PolicyEngine(source, request.user, request)
|
||||
policy_engine.build()
|
||||
if not policy_engine.passing:
|
||||
continue
|
||||
source_settings = source.ui_user_settings
|
||||
source_settings = source.ui_user_settings()
|
||||
source_settings.initial_data["object_uid"] = source.slug
|
||||
if not source_settings.is_valid():
|
||||
LOGGER.warning(source_settings.errors)
|
||||
|
|
|
@ -359,13 +359,11 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
|||
"""Return component used to edit this object"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def ui_login_button(self) -> Optional[UILoginButton]:
|
||||
def ui_login_button(self, request: HttpRequest) -> Optional[UILoginButton]:
|
||||
"""If source uses a http-based flow, return UI Information about the login
|
||||
button. If source doesn't use http-based flow, return None."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
"""Entrypoint to integrate with User settings. Can either return None if no
|
||||
user settings are available, or UserSettingSerializer."""
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from time import sleep
|
||||
from typing import Callable, Type
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test import RequestFactory, TestCase
|
||||
from django.utils.timezone import now
|
||||
from guardian.shortcuts import get_anonymous_user
|
||||
|
||||
|
@ -30,6 +30,9 @@ class TestModels(TestCase):
|
|||
def source_tester_factory(test_model: Type[Stage]) -> Callable:
|
||||
"""Test source"""
|
||||
|
||||
factory = RequestFactory()
|
||||
request = factory.get("/")
|
||||
|
||||
def tester(self: TestModels):
|
||||
model_class = None
|
||||
if test_model._meta.abstract:
|
||||
|
@ -38,8 +41,8 @@ def source_tester_factory(test_model: Type[Stage]) -> Callable:
|
|||
model_class = test_model()
|
||||
model_class.slug = "test"
|
||||
self.assertIsNotNone(model_class.component)
|
||||
_ = model_class.ui_login_button
|
||||
_ = model_class.ui_user_settings
|
||||
_ = model_class.ui_login_button(request)
|
||||
_ = model_class.ui_user_settings()
|
||||
|
||||
return tester
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class StageViewSet(
|
|||
stages += list(configurable_stage.objects.all().order_by("name"))
|
||||
matching_stages: list[dict] = []
|
||||
for stage in stages:
|
||||
user_settings = stage.ui_user_settings
|
||||
user_settings = stage.ui_user_settings()
|
||||
if not user_settings:
|
||||
continue
|
||||
user_settings.initial_data["object_uid"] = str(stage.pk)
|
||||
|
|
|
@ -75,7 +75,6 @@ class Stage(SerializerModel):
|
|||
"""Return component used to edit this object"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
"""Entrypoint to integrate with User settings. Can either return None if no
|
||||
user settings are available, or a challenge."""
|
||||
|
|
|
@ -32,7 +32,7 @@ class TestFlowsAPI(APITestCase):
|
|||
|
||||
def test_models(self):
|
||||
"""Test that ui_user_settings returns none"""
|
||||
self.assertIsNone(Stage().ui_user_settings)
|
||||
self.assertIsNone(Stage().ui_user_settings())
|
||||
|
||||
def test_api_serializer(self):
|
||||
"""Test that stage serializer returns the correct type"""
|
||||
|
|
|
@ -23,7 +23,7 @@ def model_tester_factory(test_model: Type[Stage]) -> Callable:
|
|||
model_class = test_model()
|
||||
self.assertTrue(issubclass(model_class.type, StageView))
|
||||
self.assertIsNotNone(test_model.component)
|
||||
_ = model_class.ui_user_settings
|
||||
_ = model_class.ui_user_settings()
|
||||
|
||||
return tester
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from django.db import models
|
||||
from django.http.request import HttpRequest
|
||||
from django.urls import reverse
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.serializers import Serializer
|
||||
|
@ -63,11 +64,15 @@ class OAuthSource(Source):
|
|||
|
||||
return OAuthSourceSerializer
|
||||
|
||||
@property
|
||||
def ui_login_button(self) -> UILoginButton:
|
||||
return self.type().ui_login_button()
|
||||
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
|
||||
provider_type = self.type
|
||||
provider = provider_type()
|
||||
return UILoginButton(
|
||||
name=self.name,
|
||||
icon_url=provider.icon_url(),
|
||||
challenge=provider.login_challenge(self, request),
|
||||
)
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -2,12 +2,13 @@
|
|||
from enum import Enum
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from django.http.request import HttpRequest
|
||||
from django.templatetags.static import static
|
||||
from django.urls.base import reverse
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.types import UILoginButton
|
||||
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
|
||||
from authentik.flows.challenge import Challenge, ChallengeTypes, RedirectChallenge
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
||||
|
@ -40,20 +41,17 @@ class SourceType:
|
|||
"""Get Icon URL for login"""
|
||||
return static(f"authentik/sources/{self.slug}.svg")
|
||||
|
||||
def ui_login_button(self) -> UILoginButton:
|
||||
# pylint: disable=unused-argument
|
||||
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
|
||||
"""Allow types to return custom challenges"""
|
||||
return UILoginButton(
|
||||
challenge=RedirectChallenge(
|
||||
instance={
|
||||
"type": ChallengeTypes.REDIRECT.value,
|
||||
"to": reverse(
|
||||
"authentik_sources_oauth:oauth-client-login",
|
||||
kwargs={"source_slug": self.slug},
|
||||
),
|
||||
}
|
||||
),
|
||||
icon_url=self.icon_url(),
|
||||
name=self.name,
|
||||
return RedirectChallenge(
|
||||
instance={
|
||||
"type": ChallengeTypes.REDIRECT.value,
|
||||
"to": reverse(
|
||||
"authentik_sources_oauth:oauth-client-login",
|
||||
kwargs={"source_slug": self.slug},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from typing import Optional
|
|||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db import models
|
||||
from django.http.request import HttpRequest
|
||||
from django.templatetags.static import static
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.fields import CharField
|
||||
|
@ -62,8 +63,7 @@ class PlexSource(Source):
|
|||
|
||||
return PlexSourceSerializer
|
||||
|
||||
@property
|
||||
def ui_login_button(self) -> UILoginButton:
|
||||
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
|
||||
return UILoginButton(
|
||||
challenge=PlexAuthenticationChallenge(
|
||||
{
|
||||
|
@ -77,7 +77,6 @@ class PlexSource(Source):
|
|||
name=self.name,
|
||||
)
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -167,8 +167,7 @@ class SAMLSource(Source):
|
|||
reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug})
|
||||
)
|
||||
|
||||
@property
|
||||
def ui_login_button(self) -> UILoginButton:
|
||||
def ui_login_button(self, request: HttpRequest) -> UILoginButton:
|
||||
return UILoginButton(
|
||||
challenge=RedirectChallenge(
|
||||
instance={
|
||||
|
|
|
@ -48,7 +48,6 @@ class AuthenticatorDuoStage(ConfigurableStage, Stage):
|
|||
def component(self) -> str:
|
||||
return "ak-stage-authenticator-duo-form"
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -141,7 +141,6 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage):
|
|||
def component(self) -> str:
|
||||
return "ak-stage-authenticator-sms-form"
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -31,7 +31,6 @@ class AuthenticatorStaticStage(ConfigurableStage, Stage):
|
|||
def component(self) -> str:
|
||||
return "ak-stage-authenticator-static-form"
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -38,7 +38,6 @@ class AuthenticatorTOTPStage(ConfigurableStage, Stage):
|
|||
def component(self) -> str:
|
||||
return "ak-stage-authenticator-totp-form"
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -34,7 +34,6 @@ class AuthenticateWebAuthnStage(ConfigurableStage, Stage):
|
|||
def component(self) -> str:
|
||||
return "ak-stage-authenticator-webauthn-form"
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
return UserSettingSerializer(
|
||||
data={
|
||||
|
|
|
@ -191,7 +191,7 @@ class IdentificationStageView(ChallengeStageView):
|
|||
current_stage.sources.filter(enabled=True).order_by("name").select_subclasses()
|
||||
)
|
||||
for source in sources:
|
||||
ui_login_button = source.ui_login_button
|
||||
ui_login_button = source.ui_login_button(self.request)
|
||||
if ui_login_button:
|
||||
button = asdict(ui_login_button)
|
||||
button["challenge"] = ui_login_button.challenge.data
|
||||
|
|
|
@ -63,7 +63,6 @@ class PasswordStage(ConfigurableStage, Stage):
|
|||
def component(self) -> str:
|
||||
return "ak-stage-password-form"
|
||||
|
||||
@property
|
||||
def ui_user_settings(self) -> Optional[UserSettingSerializer]:
|
||||
if not self.configure_flow:
|
||||
return None
|
||||
|
|
Reference in a new issue