sources/oauth: revamp types system, move default URLs to type

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-02 14:59:58 +02:00
parent 83fc22005c
commit 1daba5db87
13 changed files with 257 additions and 266 deletions

View File

@ -2,10 +2,11 @@
from django.urls.base import reverse_lazy from django.urls.base import reverse_lazy
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import CharField, SerializerMethodField from rest_framework.fields import BooleanField, CharField, SerializerMethodField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from drf_yasg.utils import swagger_serializer_method
from authentik.core.api.sources import SourceSerializer from authentik.core.api.sources import SourceSerializer
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
@ -13,6 +14,18 @@ from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.manager import MANAGER from authentik.sources.oauth.types.manager import MANAGER
class SourceTypeSerializer(PassiveSerializer):
"""Serializer for SourceType"""
name = CharField(required=True)
slug = CharField(required=True)
urls_customizable = BooleanField()
request_token_url = CharField(read_only=True, allow_null=True)
authorization_url = CharField(read_only=True, allow_null=True)
access_token_url = CharField(read_only=True, allow_null=True)
profile_url = CharField(read_only=True, allow_null=True)
class OAuthSourceSerializer(SourceSerializer): class OAuthSourceSerializer(SourceSerializer):
"""OAuth Source Serializer""" """OAuth Source Serializer"""
@ -28,6 +41,13 @@ class OAuthSourceSerializer(SourceSerializer):
return relative_url return relative_url
return self.context["request"].build_absolute_uri(relative_url) return self.context["request"].build_absolute_uri(relative_url)
type = SerializerMethodField()
@swagger_serializer_method(serializer_or_field=SourceTypeSerializer)
def get_type(self, instace: OAuthSource) -> SourceTypeSerializer:
"""Get source's type configuration"""
return SourceTypeSerializer(instace.type).data
class Meta: class Meta:
model = OAuthSource model = OAuthSource
fields = SourceSerializer.Meta.fields + [ fields = SourceSerializer.Meta.fields + [
@ -39,17 +59,11 @@ class OAuthSourceSerializer(SourceSerializer):
"consumer_key", "consumer_key",
"consumer_secret", "consumer_secret",
"callback_url", "callback_url",
"type",
] ]
extra_kwargs = {"consumer_secret": {"write_only": True}} extra_kwargs = {"consumer_secret": {"write_only": True}}
class OAuthSourceProviderType(PassiveSerializer):
"""OAuth Provider"""
name = CharField(required=True)
value = CharField(required=True)
class OAuthSourceViewSet(ModelViewSet): class OAuthSourceViewSet(ModelViewSet):
"""Source Viewset""" """Source Viewset"""
@ -57,16 +71,11 @@ class OAuthSourceViewSet(ModelViewSet):
serializer_class = OAuthSourceSerializer serializer_class = OAuthSourceSerializer
lookup_field = "slug" lookup_field = "slug"
@swagger_auto_schema(responses={200: OAuthSourceProviderType(many=True)}) @swagger_auto_schema(responses={200: SourceTypeSerializer(many=True)})
@action(detail=False, pagination_class=None, filter_backends=[]) @action(detail=False, pagination_class=None, filter_backends=[])
def provider_types(self, request: Request) -> Response: def source_types(self, request: Request) -> Response:
"""Get all creatable source types""" """Get all creatable source types"""
data = [] data = []
for key, value in MANAGER.get_name_tuple(): for source_type in MANAGER.get():
data.append( data.append(SourceTypeSerializer(source_type).data)
{ return Response(data)
"name": value,
"value": key,
}
)
return Response(OAuthSourceProviderType(data, many=True).data)

View File

@ -1,138 +0,0 @@
"""authentik oauth_client forms"""
from django import forms
from authentik.flows.models import Flow, FlowDesignation
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.manager import MANAGER
class OAuthSourceForm(forms.ModelForm):
"""OAuthSource Form"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields["authentication_flow"].queryset = Flow.objects.filter(
designation=FlowDesignation.AUTHENTICATION
)
self.fields["authentication_flow"].required = True
self.fields["enrollment_flow"].queryset = Flow.objects.filter(
designation=FlowDesignation.ENROLLMENT
)
self.fields["enrollment_flow"].required = True
if hasattr(self.Meta, "overrides"):
for overide_field, overide_value in getattr(self.Meta, "overrides").items():
self.fields[overide_field].initial = overide_value
self.fields[overide_field].widget.attrs["readonly"] = "readonly"
class Meta:
model = OAuthSource
fields = [
"name",
"slug",
"enabled",
"policy_engine_mode",
"authentication_flow",
"enrollment_flow",
"provider_type",
"request_token_url",
"authorization_url",
"access_token_url",
"profile_url",
"consumer_key",
"consumer_secret",
]
widgets = {
"name": forms.TextInput(),
"consumer_key": forms.TextInput(),
"consumer_secret": forms.TextInput(),
"provider_type": forms.Select(choices=MANAGER.get_name_tuple()),
}
class GitHubOAuthSourceForm(OAuthSourceForm):
"""OAuth Source form with pre-determined URL for GitHub"""
class Meta(OAuthSourceForm.Meta):
overrides = {
"provider_type": "github",
"request_token_url": "",
"authorization_url": "https://github.com/login/oauth/authorize",
"access_token_url": "https://github.com/login/oauth/access_token",
"profile_url": "https://api.github.com/user",
}
class TwitterOAuthSourceForm(OAuthSourceForm):
"""OAuth Source form with pre-determined URL for Twitter"""
class Meta(OAuthSourceForm.Meta):
overrides = {
"provider_type": "twitter",
"request_token_url": "https://api.twitter.com/oauth/request_token",
"authorization_url": "https://api.twitter.com/oauth/authenticate",
"access_token_url": "https://api.twitter.com/oauth/access_token",
"profile_url": (
"https://api.twitter.com/1.1/account/"
"verify_credentials.json?include_email=true"
),
}
class FacebookOAuthSourceForm(OAuthSourceForm):
"""OAuth Source form with pre-determined URL for Facebook"""
class Meta(OAuthSourceForm.Meta):
overrides = {
"provider_type": "facebook",
"request_token_url": "",
"authorization_url": "https://www.facebook.com/v7.0/dialog/oauth",
"access_token_url": "https://graph.facebook.com/v7.0/oauth/access_token",
"profile_url": "https://graph.facebook.com/v7.0/me?fields=id,name,email",
}
class DiscordOAuthSourceForm(OAuthSourceForm):
"""OAuth Source form with pre-determined URL for Discord"""
class Meta(OAuthSourceForm.Meta):
overrides = {
"provider_type": "discord",
"request_token_url": "",
"authorization_url": "https://discord.com/api/oauth2/authorize",
"access_token_url": "https://discord.com/api/oauth2/token",
"profile_url": "https://discord.com/api/users/@me",
}
class GoogleOAuthSourceForm(OAuthSourceForm):
"""OAuth Source form with pre-determined URL for Google"""
class Meta(OAuthSourceForm.Meta):
overrides = {
"provider_type": "google",
"request_token_url": "",
"authorization_url": "https://accounts.google.com/o/oauth2/auth",
"access_token_url": "https://accounts.google.com/o/oauth2/token",
"profile_url": "https://www.googleapis.com/oauth2/v1/userinfo",
}
class AzureADOAuthSourceForm(OAuthSourceForm):
"""OAuth Source form with pre-determined URL for AzureAD"""
class Meta(OAuthSourceForm.Meta):
overrides = {
"provider_type": "azure-ad",
"request_token_url": "",
"authorization_url": "https://login.microsoftonline.com/common/oauth2/authorize",
"access_token_url": "https://login.microsoftonline.com/common/oauth2/token",
"profile_url": "https://graph.windows.net/myorganization/me?api-version=1.6",
}

View File

@ -1,8 +1,7 @@
"""OAuth Client models""" """OAuth Client models"""
from typing import Optional, Type from typing import TYPE_CHECKING, Optional, Type
from django.db import models from django.db import models
from django.forms import ModelForm
from django.templatetags.static import static from django.templatetags.static import static
from django.urls import reverse from django.urls import reverse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -11,6 +10,9 @@ from rest_framework.serializers import Serializer
from authentik.core.models import Source, UserSourceConnection from authentik.core.models import Source, UserSourceConnection
from authentik.core.types import UILoginButton, UserSettingSerializer from authentik.core.types import UILoginButton, UserSettingSerializer
if TYPE_CHECKING:
from authentik.sources.oauth.types.manager import SourceType
class OAuthSource(Source): class OAuthSource(Source):
"""Login using a Generic OAuth provider.""" """Login using a Generic OAuth provider."""
@ -43,10 +45,15 @@ class OAuthSource(Source):
consumer_secret = models.TextField() consumer_secret = models.TextField()
@property @property
def form(self) -> Type[ModelForm]: def type(self) -> "SourceType":
from authentik.sources.oauth.forms import OAuthSourceForm """Return the provider instance for this source"""
from authentik.sources.oauth.types.manager import MANAGER
return OAuthSourceForm return MANAGER.find_type(self)
@property
def component(self) -> str:
return "ak-source-oauth-form"
@property @property
def serializer(self) -> Type[Serializer]: def serializer(self) -> Type[Serializer]:
@ -86,12 +93,6 @@ class OAuthSource(Source):
class GitHubOAuthSource(OAuthSource): class GitHubOAuthSource(OAuthSource):
"""Social Login using GitHub.com or a GitHub-Enterprise Instance.""" """Social Login using GitHub.com or a GitHub-Enterprise Instance."""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import GitHubOAuthSourceForm
return GitHubOAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True
@ -102,12 +103,6 @@ class GitHubOAuthSource(OAuthSource):
class TwitterOAuthSource(OAuthSource): class TwitterOAuthSource(OAuthSource):
"""Social Login using Twitter.com""" """Social Login using Twitter.com"""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import TwitterOAuthSourceForm
return TwitterOAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True
@ -118,12 +113,6 @@ class TwitterOAuthSource(OAuthSource):
class FacebookOAuthSource(OAuthSource): class FacebookOAuthSource(OAuthSource):
"""Social Login using Facebook.com.""" """Social Login using Facebook.com."""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import FacebookOAuthSourceForm
return FacebookOAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True
@ -134,12 +123,6 @@ class FacebookOAuthSource(OAuthSource):
class DiscordOAuthSource(OAuthSource): class DiscordOAuthSource(OAuthSource):
"""Social Login using Discord.""" """Social Login using Discord."""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import DiscordOAuthSourceForm
return DiscordOAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True
@ -150,12 +133,6 @@ class DiscordOAuthSource(OAuthSource):
class GoogleOAuthSource(OAuthSource): class GoogleOAuthSource(OAuthSource):
"""Social Login using Google or Gsuite.""" """Social Login using Google or Gsuite."""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import GoogleOAuthSourceForm
return GoogleOAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True
@ -166,12 +143,6 @@ class GoogleOAuthSource(OAuthSource):
class AzureADOAuthSource(OAuthSource): class AzureADOAuthSource(OAuthSource):
"""Social Login using Azure AD.""" """Social Login using Azure AD."""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import AzureADOAuthSourceForm
return AzureADOAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True
@ -182,12 +153,6 @@ class AzureADOAuthSource(OAuthSource):
class OpenIDOAuthSource(OAuthSource): class OpenIDOAuthSource(OAuthSource):
"""Login using a Generic OpenID-Connect compliant provider.""" """Login using a Generic OpenID-Connect compliant provider."""
@property
def form(self) -> Type[ModelForm]:
from authentik.sources.oauth.forms import OAuthSourceForm
return OAuthSourceForm
class Meta: class Meta:
abstract = True abstract = True

View File

@ -3,11 +3,10 @@ from typing import Any
from uuid import UUID from uuid import UUID
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@MANAGER.source(kind=RequestKind.CALLBACK, name="Azure AD")
class AzureADOAuthCallback(OAuthCallback): class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback""" """AzureAD OAuth2 Callback"""
@ -26,3 +25,18 @@ class AzureADOAuthCallback(OAuthCallback):
"email": mail, "email": mail,
"name": info.get("displayName"), "name": info.get("displayName"),
} }
@MANAGER.type()
class AzureADType(SourceType):
"""Azure AD Type definition"""
callback_view = AzureADOAuthCallback
name = "Azure AD"
slug = "azure-ad"
urls_customizable = True
authorization_url = "https://login.microsoftonline.com/common/oauth2/authorize"
access_token_url = "https://login.microsoftonline.com/common/oauth2/token" # nosec
profile_url = "https://graph.windows.net/myorganization/me?api-version=1.6"

View File

@ -2,12 +2,11 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@MANAGER.source(kind=RequestKind.REDIRECT, name="Discord")
class DiscordOAuthRedirect(OAuthRedirect): class DiscordOAuthRedirect(OAuthRedirect):
"""Discord OAuth2 Redirect""" """Discord OAuth2 Redirect"""
@ -17,7 +16,6 @@ class DiscordOAuthRedirect(OAuthRedirect):
} }
@MANAGER.source(kind=RequestKind.CALLBACK, name="Discord")
class DiscordOAuth2Callback(OAuthCallback): class DiscordOAuth2Callback(OAuthCallback):
"""Discord OAuth2 Callback""" """Discord OAuth2 Callback"""
@ -32,3 +30,17 @@ class DiscordOAuth2Callback(OAuthCallback):
"email": info.get("email", None), "email": info.get("email", None),
"name": info.get("username"), "name": info.get("username"),
} }
@MANAGER.type()
class DiscordType(SourceType):
"""Discord Type definition"""
callback_view = DiscordOAuth2Callback
redirect_view = DiscordOAuthRedirect
name = "Discord"
slug = "discord"
authorization_url = "https://discord.com/api/oauth2/authorize"
access_token_url = "https://discord.com/api/oauth2/token" # nosec
profile_url = "https://discord.com/api/users/@me"

View File

@ -5,12 +5,11 @@ from facebook import GraphAPI
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@MANAGER.source(kind=RequestKind.REDIRECT, name="Facebook")
class FacebookOAuthRedirect(OAuthRedirect): class FacebookOAuthRedirect(OAuthRedirect):
"""Facebook OAuth2 Redirect""" """Facebook OAuth2 Redirect"""
@ -28,7 +27,6 @@ class FacebookOAuth2Client(OAuth2Client):
return api.get_object("me", fields="id,name,email") return api.get_object("me", fields="id,name,email")
@MANAGER.source(kind=RequestKind.CALLBACK, name="Facebook")
class FacebookOAuth2Callback(OAuthCallback): class FacebookOAuth2Callback(OAuthCallback):
"""Facebook OAuth2 Callback""" """Facebook OAuth2 Callback"""
@ -45,3 +43,17 @@ class FacebookOAuth2Callback(OAuthCallback):
"email": info.get("email"), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
} }
@MANAGER.type()
class FacebookType(SourceType):
"""Facebook Type definition"""
callback_view = FacebookOAuth2Callback
redirect_view = FacebookOAuthRedirect
name = "Facebook"
slug = "facebook"
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"

View File

@ -2,11 +2,10 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@MANAGER.source(kind=RequestKind.CALLBACK, name="GitHub")
class GitHubOAuth2Callback(OAuthCallback): class GitHubOAuth2Callback(OAuthCallback):
"""GitHub OAuth2 Callback""" """GitHub OAuth2 Callback"""
@ -21,3 +20,18 @@ class GitHubOAuth2Callback(OAuthCallback):
"email": info.get("email"), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
} }
@MANAGER.type()
class GitHubType(SourceType):
"""GitHub Type definition"""
callback_view = GitHubOAuth2Callback
name = "GitHub"
slug = "github"
urls_customizable = True
authorization_url = "https://github.com/login/oauth/authorize"
access_token_url = "https://github.com/login/oauth/access_token" # nosec
profile_url = "https://api.github.com/user"

View File

@ -2,12 +2,11 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@MANAGER.source(kind=RequestKind.REDIRECT, name="Google")
class GoogleOAuthRedirect(OAuthRedirect): class GoogleOAuthRedirect(OAuthRedirect):
"""Google OAuth2 Redirect""" """Google OAuth2 Redirect"""
@ -17,7 +16,6 @@ class GoogleOAuthRedirect(OAuthRedirect):
} }
@MANAGER.source(kind=RequestKind.CALLBACK, name="Google")
class GoogleOAuth2Callback(OAuthCallback): class GoogleOAuth2Callback(OAuthCallback):
"""Google OAuth2 Callback""" """Google OAuth2 Callback"""
@ -32,3 +30,17 @@ class GoogleOAuth2Callback(OAuthCallback):
"email": info.get("email"), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
} }
@MANAGER.type()
class GoogleType(SourceType):
"""Google Type definition"""
callback_view = GoogleOAuth2Callback
redirect_view = GoogleOAuthRedirect
name = "Google"
slug = "google"
authorization_url = "https://accounts.google.com/o/oauth2/auth"
access_token_url = "https://accounts.google.com/o/oauth2/token" # nosec
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"

View File

@ -1,16 +1,17 @@
"""Source type manager""" """Source type manager"""
from enum import Enum from enum import Enum
from typing import Callable from typing import TYPE_CHECKING, Callable, Optional
from django.utils.text import slugify
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING:
from authentik.sources.oauth.models import OAuthSource
class RequestKind(Enum): class RequestKind(Enum):
"""Enum of OAuth Request types""" """Enum of OAuth Request types"""
@ -19,46 +20,67 @@ class RequestKind(Enum):
REDIRECT = "redirect" REDIRECT = "redirect"
class SourceType:
"""Source type, allows overriding of urls and views per type"""
callback_view = OAuthCallback
redirect_view = OAuthRedirect
name: str
slug: str
urls_customizable = False
request_token_url: Optional[str] = None
authorization_url: Optional[str] = None
access_token_url: Optional[str] = None
profile_url: Optional[str] = None
class SourceTypeManager: class SourceTypeManager:
"""Manager to hold all Source types.""" """Manager to hold all Source types."""
__source_types: dict[RequestKind, dict[str, Callable]] = {} __sources: list[SourceType] = []
__names: list[str] = []
def source(self, kind: RequestKind, name: str): def type(self):
"""Class decorator to register classes inline.""" """Class decorator to register classes inline."""
def inner_wrapper(cls): def inner_wrapper(cls):
if kind.value not in self.__source_types: self.__sources.append(cls)
self.__source_types[kind.value] = {}
self.__source_types[kind.value][slugify(name)] = cls
self.__names.append(name)
return cls return cls
return inner_wrapper return inner_wrapper
def get(self):
"""Get a list of all source types"""
return self.__sources
def get_name_tuple(self): def get_name_tuple(self):
"""Get list of tuples of all registered names""" """Get list of tuples of all registered names"""
return [(slugify(x), x) for x in set(self.__names)] return [(x.slug, x.name) for x in self.__sources]
def find(self, source: OAuthSource, kind: RequestKind) -> Callable: def find_type(self, source: "OAuthSource") -> SourceType:
"""Find fitting Source Type""" """Find type based on source"""
if kind.value in self.__source_types: found_type = None
if source.provider_type in self.__source_types[kind.value]: for src_type in self.__sources:
return self.__source_types[kind.value][source.provider_type] if src_type.slug == source.provider_type:
return src_type
if not found_type:
found_type = SourceType()
LOGGER.warning( LOGGER.warning(
"no matching type found, using default", "no matching type found, using default",
wanted=source.provider_type, wanted=source.provider_type,
have=self.__source_types[kind.value].keys(), have=[x.name for x in self.__sources],
) )
# Return defaults return found_type
def find(self, source: "OAuthSource", kind: RequestKind) -> Callable:
"""Find fitting Source Type"""
found_type = self.find_type(source)
if kind == RequestKind.CALLBACK: if kind == RequestKind.CALLBACK:
return OAuthCallback return found_type.callback_view
if kind == RequestKind.REDIRECT: if kind == RequestKind.REDIRECT:
return OAuthRedirect return found_type.redirect_view
raise KeyError( raise ValueError
f"Provider Type {source.provider_type} (type {kind.value}) not found."
)
MANAGER = SourceTypeManager() MANAGER = SourceTypeManager()

View File

@ -2,12 +2,11 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@MANAGER.source(kind=RequestKind.REDIRECT, name="OpenID Connect")
class OpenIDConnectOAuthRedirect(OAuthRedirect): class OpenIDConnectOAuthRedirect(OAuthRedirect):
"""OpenIDConnect OAuth2 Redirect""" """OpenIDConnect OAuth2 Redirect"""
@ -17,7 +16,6 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
} }
@MANAGER.source(kind=RequestKind.CALLBACK, name="OpenID Connect")
class OpenIDConnectOAuth2Callback(OAuthCallback): class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback""" """OpenIDConnect OAuth2 Callback"""
@ -35,3 +33,15 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
"email": info.get("email"), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
} }
@MANAGER.type()
class OpenIDConnectType(SourceType):
"""OpenIDConnect Type definition"""
callback_view = OpenIDConnectOAuth2Callback
redirect_view = OpenIDConnectOAuthRedirect
name = "OpenID Connect"
slug = "openid-connect"
urls_customizable = True

View File

@ -5,12 +5,11 @@ from requests.auth import HTTPBasicAuth
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@MANAGER.source(kind=RequestKind.REDIRECT, name="reddit")
class RedditOAuthRedirect(OAuthRedirect): class RedditOAuthRedirect(OAuthRedirect):
"""Reddit OAuth2 Redirect""" """Reddit OAuth2 Redirect"""
@ -30,7 +29,6 @@ class RedditOAuth2Client(OAuth2Client):
return super().get_access_token(auth=auth) return super().get_access_token(auth=auth)
@MANAGER.source(kind=RequestKind.CALLBACK, name="reddit")
class RedditOAuth2Callback(OAuthCallback): class RedditOAuth2Callback(OAuthCallback):
"""Reddit OAuth2 Callback""" """Reddit OAuth2 Callback"""
@ -48,3 +46,17 @@ class RedditOAuth2Callback(OAuthCallback):
"name": info.get("name"), "name": info.get("name"),
"password": None, "password": None,
} }
@MANAGER.type()
class RedditType(SourceType):
"""Reddit Type definition"""
callback_view = RedditOAuth2Callback
redirect_view = RedditOAuthRedirect
name = "reddit"
slug = "reddit"
authorization_url = "https://accounts.google.com/o/oauth2/auth"
access_token_url = "https://accounts.google.com/o/oauth2/token" # nosec
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"

View File

@ -2,11 +2,10 @@
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@MANAGER.source(kind=RequestKind.CALLBACK, name="Twitter")
class TwitterOAuthCallback(OAuthCallback): class TwitterOAuthCallback(OAuthCallback):
"""Twitter OAuth2 Callback""" """Twitter OAuth2 Callback"""
@ -21,3 +20,20 @@ class TwitterOAuthCallback(OAuthCallback):
"email": info.get("email", None), "email": info.get("email", None),
"name": info.get("name"), "name": info.get("name"),
} }
@MANAGER.type()
class TwitterType(SourceType):
"""Twitter Type definition"""
callback_view = TwitterOAuthCallback
name = "Twitter"
slug = "twitter"
request_token_url = "https://api.twitter.com/oauth/request_token" # nosec
authorization_url = "https://api.twitter.com/oauth/authenticate"
access_token_url = "https://api.twitter.com/oauth/access_token" # nosec
profile_url = (
"https://api.twitter.com/1.1/account/"
"verify_credentials.json?include_email=true"
)

View File

@ -9642,9 +9642,9 @@ paths:
tags: tags:
- sources - sources
parameters: [] parameters: []
/sources/oauth/provider_types/: /sources/oauth/source_types/:
get: get:
operationId: sources_oauth_provider_types operationId: sources_oauth_source_types
description: Get all creatable source types description: Get all creatable source types
parameters: [] parameters: []
responses: responses:
@ -9653,7 +9653,7 @@ paths:
schema: schema:
type: array type: array
items: items:
$ref: '#/definitions/OAuthSourceProviderType' $ref: '#/definitions/SourceType'
'403': '403':
description: Authentication credentials were invalid, absent or insufficient. description: Authentication credentials were invalid, absent or insufficient.
schema: schema:
@ -16907,6 +16907,49 @@ definitions:
type: string type: string
format: date-time format: date-time
readOnly: true readOnly: true
SourceType:
description: Get source's type configuration
required:
- name
- slug
- urls_customizable
type: object
properties:
name:
title: Name
type: string
minLength: 1
slug:
title: Slug
type: string
minLength: 1
urls_customizable:
title: Urls customizable
type: boolean
request_token_url:
title: Request token url
type: string
readOnly: true
minLength: 1
x-nullable: true
authorization_url:
title: Authorization url
type: string
readOnly: true
minLength: 1
x-nullable: true
access_token_url:
title: Access token url
type: string
readOnly: true
minLength: 1
x-nullable: true
profile_url:
title: Profile url
type: string
readOnly: true
minLength: 1
x-nullable: true
OAuthSource: OAuthSource:
required: required:
- name - name
@ -17011,20 +17054,8 @@ definitions:
title: Callback url title: Callback url
type: string type: string
readOnly: true readOnly: true
OAuthSourceProviderType: type:
required: $ref: '#/definitions/SourceType'
- name
- value
type: object
properties:
name:
title: Name
type: string
minLength: 1
value:
title: Value
type: string
minLength: 1
UserOAuthSourceConnection: UserOAuthSourceConnection:
required: required:
- user - user