sources/oauth: fix resolution of sources' provider type
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
5e67f68f2b
commit
2b48ba4103
|
@ -1,6 +1,6 @@
|
|||
"""Source type manager"""
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
|
@ -9,9 +9,6 @@ from authentik.sources.oauth.views.redirect import OAuthRedirect
|
|||
|
||||
LOGGER = get_logger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
|
||||
|
||||
class RequestKind(Enum):
|
||||
"""Enum of OAuth Request types"""
|
||||
|
@ -69,13 +66,13 @@ class SourceTypeManager:
|
|||
LOGGER.warning(
|
||||
"no matching type found, using default",
|
||||
wanted=type_name,
|
||||
have=[x.name for x in self.__sources],
|
||||
have=[x.slug for x in self.__sources],
|
||||
)
|
||||
return found_type
|
||||
|
||||
def find(self, source: "OAuthSource", kind: RequestKind) -> Callable:
|
||||
def find(self, type_name: str, kind: RequestKind) -> Callable:
|
||||
"""Find fitting Source Type"""
|
||||
found_type = self.find_type(source)
|
||||
found_type = self.find_type(type_name)
|
||||
if kind == RequestKind.CALLBACK:
|
||||
return found_type.callback_view
|
||||
if kind == RequestKind.REDIRECT:
|
||||
|
|
|
@ -21,6 +21,6 @@ class DispatcherView(View):
|
|||
if not slug:
|
||||
raise Http404
|
||||
source = get_object_or_404(OAuthSource, slug=slug)
|
||||
view = MANAGER.find(source, kind=RequestKind(self.kind))
|
||||
view = MANAGER.find(source.provider_type, kind=RequestKind(self.kind))
|
||||
LOGGER.debug("dispatching OAuth2 request to", view=view, kind=self.kind)
|
||||
return view.as_view()(*args, **kwargs)
|
||||
|
|
Reference in a new issue