flows: remove need for post() wrapper by using dispatch (#6765)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-09-05 22:15:03 +02:00 committed by GitHub
parent 7cbce1bb3d
commit e373bae189
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 15 additions and 48 deletions

View File

@ -48,7 +48,7 @@ class Action(Enum):
class MessageStage(StageView): class MessageStage(StageView):
"""Show a pre-configured message after the flow is done""" """Show a pre-configured message after the flow is done"""
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Show a pre-configured message after the flow is done""" """Show a pre-configured message after the flow is done"""
message = getattr(self.executor.current_stage, "message", "") message = getattr(self.executor.current_stage, "message", "")
level = getattr(self.executor.current_stage, "level", messages.SUCCESS) level = getattr(self.executor.current_stage, "level", messages.SUCCESS)
@ -59,10 +59,6 @@ class MessageStage(StageView):
) )
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
class SourceFlowManager: class SourceFlowManager:
"""Help sources decide what they should do after authorization. Based on source settings and """Help sources decide what they should do after authorization. Based on source settings and

View File

@ -13,7 +13,7 @@ class PostUserEnrollmentStage(StageView):
"""Dynamically injected stage which saves the Connection after """Dynamically injected stage which saves the Connection after
the user has been enrolled.""" the user has been enrolled."""
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Stage used after the user has been enrolled""" """Stage used after the user has been enrolled"""
connection: UserSourceConnection = self.executor.plan.context[ connection: UserSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_CONNECTION PLAN_CONTEXT_SOURCES_CONNECTION
@ -27,7 +27,3 @@ class PostUserEnrollmentStage(StageView):
source=connection.source, source=connection.source,
).from_http(self.request) ).from_http(self.request)
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -21,8 +21,9 @@ def view_tester_factory(view_class: type[StageView]) -> Callable:
def tester(self: TestViews): def tester(self: TestViews):
model_class = view_class(self.exec) model_class = view_class(self.exec)
self.assertIsNotNone(model_class.post) if not hasattr(model_class, "dispatch"):
self.assertIsNotNone(model_class.get) self.assertIsNotNone(model_class.post)
self.assertIsNotNone(model_class.get)
return tester return tester

View File

@ -295,7 +295,7 @@ class FlowExecutorView(APIView):
span.set_data("Method", "GET") span.set_data("Method", "GET")
span.set_data("authentik Stage", self.current_stage_view) span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug) span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.get(request, *args, **kwargs) stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response) return to_stage_response(request, stage_response)
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
return self.handle_exception(exc) return self.handle_exception(exc)
@ -339,7 +339,7 @@ class FlowExecutorView(APIView):
span.set_data("Method", "POST") span.set_data("Method", "POST")
span.set_data("authentik Stage", self.current_stage_view) span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug) span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.post(request, *args, **kwargs) stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response) return to_stage_response(request, stage_response)
except Exception as exc: # pylint: disable=broad-except except Exception as exc: # pylint: disable=broad-except
return self.handle_exception(exc) return self.handle_exception(exc)

View File

@ -7,10 +7,6 @@ from authentik.flows.stage import StageView
class DenyStageView(StageView): class DenyStageView(StageView):
"""Cancels the current flow""" """Cancels the current flow"""
def get(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Cancels the current flow""" """Cancels the current flow"""
return self.executor.stage_invalid() return self.executor.stage_invalid()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -21,10 +21,6 @@ INVITATION = "invitation"
class InvitationStageView(StageView): class InvitationStageView(StageView):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get_token(self) -> Optional[str]: def get_token(self) -> Optional[str]:
"""Get token from saved get-arguments or prompt_data""" """Get token from saved get-arguments or prompt_data"""
# Check for ?token= and ?itoken= # Check for ?token= and ?itoken=
@ -55,7 +51,7 @@ class InvitationStageView(StageView):
return None return None
return invite return invite
def get(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Apply data to the current flow based on a URL""" """Apply data to the current flow based on a URL"""
stage: InvitationStage = self.executor.current_stage stage: InvitationStage = self.executor.current_stage

View File

@ -11,11 +11,7 @@ from authentik.flows.stage import StageView
class UserDeleteStageView(StageView): class UserDeleteStageView(StageView):
"""Finalise unenrollment flow by deleting the user object.""" """Finalise unenrollment flow by deleting the user object."""
def post(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get(self, request: HttpRequest) -> HttpResponse:
"""Delete currently pending user""" """Delete currently pending user"""
user = self.get_pending_user() user = self.get_pending_user()
if not user.is_authenticated: if not user.is_authenticated:

View File

@ -41,17 +41,11 @@ class UserLoginStageView(ChallengeStageView):
} }
) )
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests""" """Check for remember_me, and do login"""
stage: UserLoginStage = self.executor.current_stage stage: UserLoginStage = self.executor.current_stage
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0: if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
return super().post(request, *args, **kwargs) return super().dispatch(request)
return self.do_login(request)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
stage: UserLoginStage = self.executor.current_stage
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
return super().get(request, *args, **kwargs)
return self.do_login(request) return self.do_login(request)
def challenge_valid(self, response: UserLoginChallengeResponse) -> HttpResponse: def challenge_valid(self, response: UserLoginChallengeResponse) -> HttpResponse:

View File

@ -8,7 +8,7 @@ from authentik.flows.stage import StageView
class UserLogoutStageView(StageView): class UserLogoutStageView(StageView):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
def get(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Remove the user from the current session""" """Remove the user from the current session"""
self.logger.debug( self.logger.debug(
"Logged out", "Logged out",
@ -17,7 +17,3 @@ class UserLogoutStageView(StageView):
) )
logout(self.request) logout(self.request)
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -51,10 +51,6 @@ class UserWriteStageView(StageView):
attrs = attrs.get(comp) attrs = attrs.get(comp)
attrs[parts[-1]] = value attrs[parts[-1]] = value
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def ensure_user(self) -> tuple[Optional[User], bool]: def ensure_user(self) -> tuple[Optional[User], bool]:
"""Ensure a user exists""" """Ensure a user exists"""
user_created = False user_created = False
@ -127,7 +123,7 @@ class UserWriteStageView(StageView):
if connection.source.name not in user.attributes[USER_ATTRIBUTE_SOURCES]: if connection.source.name not in user.attributes[USER_ATTRIBUTE_SOURCES]:
user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name) user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name)
def get(self, request: HttpRequest) -> HttpResponse: def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Save data in the current flow to the currently pending user. If no user is pending, """Save data in the current flow to the currently pending user. If no user is pending,
a new user is created.""" a new user is created."""
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context: if PLAN_CONTEXT_PROMPT not in self.executor.plan.context: