sources/oauth: revamp types system, move default URLs to type
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
83fc22005c
commit
1daba5db87
|
@ -2,10 +2,11 @@
|
||||||
from django.urls.base import reverse_lazy
|
from django.urls.base import reverse_lazy
|
||||||
from drf_yasg.utils import swagger_auto_schema
|
from drf_yasg.utils import swagger_auto_schema
|
||||||
from rest_framework.decorators import action
|
from rest_framework.decorators import action
|
||||||
from rest_framework.fields import CharField, SerializerMethodField
|
from rest_framework.fields import BooleanField, CharField, SerializerMethodField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
from drf_yasg.utils import swagger_serializer_method
|
||||||
|
|
||||||
from authentik.core.api.sources import SourceSerializer
|
from authentik.core.api.sources import SourceSerializer
|
||||||
from authentik.core.api.utils import PassiveSerializer
|
from authentik.core.api.utils import PassiveSerializer
|
||||||
|
@ -13,6 +14,18 @@ from authentik.sources.oauth.models import OAuthSource
|
||||||
from authentik.sources.oauth.types.manager import MANAGER
|
from authentik.sources.oauth.types.manager import MANAGER
|
||||||
|
|
||||||
|
|
||||||
|
class SourceTypeSerializer(PassiveSerializer):
|
||||||
|
"""Serializer for SourceType"""
|
||||||
|
|
||||||
|
name = CharField(required=True)
|
||||||
|
slug = CharField(required=True)
|
||||||
|
urls_customizable = BooleanField()
|
||||||
|
request_token_url = CharField(read_only=True, allow_null=True)
|
||||||
|
authorization_url = CharField(read_only=True, allow_null=True)
|
||||||
|
access_token_url = CharField(read_only=True, allow_null=True)
|
||||||
|
profile_url = CharField(read_only=True, allow_null=True)
|
||||||
|
|
||||||
|
|
||||||
class OAuthSourceSerializer(SourceSerializer):
|
class OAuthSourceSerializer(SourceSerializer):
|
||||||
"""OAuth Source Serializer"""
|
"""OAuth Source Serializer"""
|
||||||
|
|
||||||
|
@ -28,6 +41,13 @@ class OAuthSourceSerializer(SourceSerializer):
|
||||||
return relative_url
|
return relative_url
|
||||||
return self.context["request"].build_absolute_uri(relative_url)
|
return self.context["request"].build_absolute_uri(relative_url)
|
||||||
|
|
||||||
|
type = SerializerMethodField()
|
||||||
|
|
||||||
|
@swagger_serializer_method(serializer_or_field=SourceTypeSerializer)
|
||||||
|
def get_type(self, instace: OAuthSource) -> SourceTypeSerializer:
|
||||||
|
"""Get source's type configuration"""
|
||||||
|
return SourceTypeSerializer(instace.type).data
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = OAuthSource
|
model = OAuthSource
|
||||||
fields = SourceSerializer.Meta.fields + [
|
fields = SourceSerializer.Meta.fields + [
|
||||||
|
@ -39,17 +59,11 @@ class OAuthSourceSerializer(SourceSerializer):
|
||||||
"consumer_key",
|
"consumer_key",
|
||||||
"consumer_secret",
|
"consumer_secret",
|
||||||
"callback_url",
|
"callback_url",
|
||||||
|
"type",
|
||||||
]
|
]
|
||||||
extra_kwargs = {"consumer_secret": {"write_only": True}}
|
extra_kwargs = {"consumer_secret": {"write_only": True}}
|
||||||
|
|
||||||
|
|
||||||
class OAuthSourceProviderType(PassiveSerializer):
|
|
||||||
"""OAuth Provider"""
|
|
||||||
|
|
||||||
name = CharField(required=True)
|
|
||||||
value = CharField(required=True)
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthSourceViewSet(ModelViewSet):
|
class OAuthSourceViewSet(ModelViewSet):
|
||||||
"""Source Viewset"""
|
"""Source Viewset"""
|
||||||
|
|
||||||
|
@ -57,16 +71,11 @@ class OAuthSourceViewSet(ModelViewSet):
|
||||||
serializer_class = OAuthSourceSerializer
|
serializer_class = OAuthSourceSerializer
|
||||||
lookup_field = "slug"
|
lookup_field = "slug"
|
||||||
|
|
||||||
@swagger_auto_schema(responses={200: OAuthSourceProviderType(many=True)})
|
@swagger_auto_schema(responses={200: SourceTypeSerializer(many=True)})
|
||||||
@action(detail=False, pagination_class=None, filter_backends=[])
|
@action(detail=False, pagination_class=None, filter_backends=[])
|
||||||
def provider_types(self, request: Request) -> Response:
|
def source_types(self, request: Request) -> Response:
|
||||||
"""Get all creatable source types"""
|
"""Get all creatable source types"""
|
||||||
data = []
|
data = []
|
||||||
for key, value in MANAGER.get_name_tuple():
|
for source_type in MANAGER.get():
|
||||||
data.append(
|
data.append(SourceTypeSerializer(source_type).data)
|
||||||
{
|
return Response(data)
|
||||||
"name": value,
|
|
||||||
"value": key,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return Response(OAuthSourceProviderType(data, many=True).data)
|
|
||||||
|
|
|
@ -1,138 +0,0 @@
|
||||||
"""authentik oauth_client forms"""
|
|
||||||
|
|
||||||
from django import forms
|
|
||||||
|
|
||||||
from authentik.flows.models import Flow, FlowDesignation
|
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
|
||||||
from authentik.sources.oauth.types.manager import MANAGER
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthSourceForm(forms.ModelForm):
|
|
||||||
"""OAuthSource Form"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.fields["authentication_flow"].queryset = Flow.objects.filter(
|
|
||||||
designation=FlowDesignation.AUTHENTICATION
|
|
||||||
)
|
|
||||||
self.fields["authentication_flow"].required = True
|
|
||||||
self.fields["enrollment_flow"].queryset = Flow.objects.filter(
|
|
||||||
designation=FlowDesignation.ENROLLMENT
|
|
||||||
)
|
|
||||||
self.fields["enrollment_flow"].required = True
|
|
||||||
if hasattr(self.Meta, "overrides"):
|
|
||||||
for overide_field, overide_value in getattr(self.Meta, "overrides").items():
|
|
||||||
self.fields[overide_field].initial = overide_value
|
|
||||||
self.fields[overide_field].widget.attrs["readonly"] = "readonly"
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
|
|
||||||
model = OAuthSource
|
|
||||||
fields = [
|
|
||||||
"name",
|
|
||||||
"slug",
|
|
||||||
"enabled",
|
|
||||||
"policy_engine_mode",
|
|
||||||
"authentication_flow",
|
|
||||||
"enrollment_flow",
|
|
||||||
"provider_type",
|
|
||||||
"request_token_url",
|
|
||||||
"authorization_url",
|
|
||||||
"access_token_url",
|
|
||||||
"profile_url",
|
|
||||||
"consumer_key",
|
|
||||||
"consumer_secret",
|
|
||||||
]
|
|
||||||
widgets = {
|
|
||||||
"name": forms.TextInput(),
|
|
||||||
"consumer_key": forms.TextInput(),
|
|
||||||
"consumer_secret": forms.TextInput(),
|
|
||||||
"provider_type": forms.Select(choices=MANAGER.get_name_tuple()),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubOAuthSourceForm(OAuthSourceForm):
|
|
||||||
"""OAuth Source form with pre-determined URL for GitHub"""
|
|
||||||
|
|
||||||
class Meta(OAuthSourceForm.Meta):
|
|
||||||
|
|
||||||
overrides = {
|
|
||||||
"provider_type": "github",
|
|
||||||
"request_token_url": "",
|
|
||||||
"authorization_url": "https://github.com/login/oauth/authorize",
|
|
||||||
"access_token_url": "https://github.com/login/oauth/access_token",
|
|
||||||
"profile_url": "https://api.github.com/user",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TwitterOAuthSourceForm(OAuthSourceForm):
|
|
||||||
"""OAuth Source form with pre-determined URL for Twitter"""
|
|
||||||
|
|
||||||
class Meta(OAuthSourceForm.Meta):
|
|
||||||
|
|
||||||
overrides = {
|
|
||||||
"provider_type": "twitter",
|
|
||||||
"request_token_url": "https://api.twitter.com/oauth/request_token",
|
|
||||||
"authorization_url": "https://api.twitter.com/oauth/authenticate",
|
|
||||||
"access_token_url": "https://api.twitter.com/oauth/access_token",
|
|
||||||
"profile_url": (
|
|
||||||
"https://api.twitter.com/1.1/account/"
|
|
||||||
"verify_credentials.json?include_email=true"
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class FacebookOAuthSourceForm(OAuthSourceForm):
|
|
||||||
"""OAuth Source form with pre-determined URL for Facebook"""
|
|
||||||
|
|
||||||
class Meta(OAuthSourceForm.Meta):
|
|
||||||
|
|
||||||
overrides = {
|
|
||||||
"provider_type": "facebook",
|
|
||||||
"request_token_url": "",
|
|
||||||
"authorization_url": "https://www.facebook.com/v7.0/dialog/oauth",
|
|
||||||
"access_token_url": "https://graph.facebook.com/v7.0/oauth/access_token",
|
|
||||||
"profile_url": "https://graph.facebook.com/v7.0/me?fields=id,name,email",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordOAuthSourceForm(OAuthSourceForm):
|
|
||||||
"""OAuth Source form with pre-determined URL for Discord"""
|
|
||||||
|
|
||||||
class Meta(OAuthSourceForm.Meta):
|
|
||||||
|
|
||||||
overrides = {
|
|
||||||
"provider_type": "discord",
|
|
||||||
"request_token_url": "",
|
|
||||||
"authorization_url": "https://discord.com/api/oauth2/authorize",
|
|
||||||
"access_token_url": "https://discord.com/api/oauth2/token",
|
|
||||||
"profile_url": "https://discord.com/api/users/@me",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleOAuthSourceForm(OAuthSourceForm):
|
|
||||||
"""OAuth Source form with pre-determined URL for Google"""
|
|
||||||
|
|
||||||
class Meta(OAuthSourceForm.Meta):
|
|
||||||
|
|
||||||
overrides = {
|
|
||||||
"provider_type": "google",
|
|
||||||
"request_token_url": "",
|
|
||||||
"authorization_url": "https://accounts.google.com/o/oauth2/auth",
|
|
||||||
"access_token_url": "https://accounts.google.com/o/oauth2/token",
|
|
||||||
"profile_url": "https://www.googleapis.com/oauth2/v1/userinfo",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AzureADOAuthSourceForm(OAuthSourceForm):
|
|
||||||
"""OAuth Source form with pre-determined URL for AzureAD"""
|
|
||||||
|
|
||||||
class Meta(OAuthSourceForm.Meta):
|
|
||||||
|
|
||||||
overrides = {
|
|
||||||
"provider_type": "azure-ad",
|
|
||||||
"request_token_url": "",
|
|
||||||
"authorization_url": "https://login.microsoftonline.com/common/oauth2/authorize",
|
|
||||||
"access_token_url": "https://login.microsoftonline.com/common/oauth2/token",
|
|
||||||
"profile_url": "https://graph.windows.net/myorganization/me?api-version=1.6",
|
|
||||||
}
|
|
|
@ -1,8 +1,7 @@
|
||||||
"""OAuth Client models"""
|
"""OAuth Client models"""
|
||||||
from typing import Optional, Type
|
from typing import TYPE_CHECKING, Optional, Type
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.forms import ModelForm
|
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
@ -11,6 +10,9 @@ from rest_framework.serializers import Serializer
|
||||||
from authentik.core.models import Source, UserSourceConnection
|
from authentik.core.models import Source, UserSourceConnection
|
||||||
from authentik.core.types import UILoginButton, UserSettingSerializer
|
from authentik.core.types import UILoginButton, UserSettingSerializer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.sources.oauth.types.manager import SourceType
|
||||||
|
|
||||||
|
|
||||||
class OAuthSource(Source):
|
class OAuthSource(Source):
|
||||||
"""Login using a Generic OAuth provider."""
|
"""Login using a Generic OAuth provider."""
|
||||||
|
@ -43,10 +45,15 @@ class OAuthSource(Source):
|
||||||
consumer_secret = models.TextField()
|
consumer_secret = models.TextField()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def form(self) -> Type[ModelForm]:
|
def type(self) -> "SourceType":
|
||||||
from authentik.sources.oauth.forms import OAuthSourceForm
|
"""Return the provider instance for this source"""
|
||||||
|
from authentik.sources.oauth.types.manager import MANAGER
|
||||||
|
|
||||||
return OAuthSourceForm
|
return MANAGER.find_type(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def component(self) -> str:
|
||||||
|
return "ak-source-oauth-form"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def serializer(self) -> Type[Serializer]:
|
def serializer(self) -> Type[Serializer]:
|
||||||
|
@ -86,12 +93,6 @@ class OAuthSource(Source):
|
||||||
class GitHubOAuthSource(OAuthSource):
|
class GitHubOAuthSource(OAuthSource):
|
||||||
"""Social Login using GitHub.com or a GitHub-Enterprise Instance."""
|
"""Social Login using GitHub.com or a GitHub-Enterprise Instance."""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import GitHubOAuthSourceForm
|
|
||||||
|
|
||||||
return GitHubOAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -102,12 +103,6 @@ class GitHubOAuthSource(OAuthSource):
|
||||||
class TwitterOAuthSource(OAuthSource):
|
class TwitterOAuthSource(OAuthSource):
|
||||||
"""Social Login using Twitter.com"""
|
"""Social Login using Twitter.com"""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import TwitterOAuthSourceForm
|
|
||||||
|
|
||||||
return TwitterOAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -118,12 +113,6 @@ class TwitterOAuthSource(OAuthSource):
|
||||||
class FacebookOAuthSource(OAuthSource):
|
class FacebookOAuthSource(OAuthSource):
|
||||||
"""Social Login using Facebook.com."""
|
"""Social Login using Facebook.com."""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import FacebookOAuthSourceForm
|
|
||||||
|
|
||||||
return FacebookOAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -134,12 +123,6 @@ class FacebookOAuthSource(OAuthSource):
|
||||||
class DiscordOAuthSource(OAuthSource):
|
class DiscordOAuthSource(OAuthSource):
|
||||||
"""Social Login using Discord."""
|
"""Social Login using Discord."""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import DiscordOAuthSourceForm
|
|
||||||
|
|
||||||
return DiscordOAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -150,12 +133,6 @@ class DiscordOAuthSource(OAuthSource):
|
||||||
class GoogleOAuthSource(OAuthSource):
|
class GoogleOAuthSource(OAuthSource):
|
||||||
"""Social Login using Google or Gsuite."""
|
"""Social Login using Google or Gsuite."""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import GoogleOAuthSourceForm
|
|
||||||
|
|
||||||
return GoogleOAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -166,12 +143,6 @@ class GoogleOAuthSource(OAuthSource):
|
||||||
class AzureADOAuthSource(OAuthSource):
|
class AzureADOAuthSource(OAuthSource):
|
||||||
"""Social Login using Azure AD."""
|
"""Social Login using Azure AD."""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import AzureADOAuthSourceForm
|
|
||||||
|
|
||||||
return AzureADOAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
@ -182,12 +153,6 @@ class AzureADOAuthSource(OAuthSource):
|
||||||
class OpenIDOAuthSource(OAuthSource):
|
class OpenIDOAuthSource(OAuthSource):
|
||||||
"""Login using a Generic OpenID-Connect compliant provider."""
|
"""Login using a Generic OpenID-Connect compliant provider."""
|
||||||
|
|
||||||
@property
|
|
||||||
def form(self) -> Type[ModelForm]:
|
|
||||||
from authentik.sources.oauth.forms import OAuthSourceForm
|
|
||||||
|
|
||||||
return OAuthSourceForm
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
abstract = True
|
abstract = True
|
||||||
|
|
|
@ -3,11 +3,10 @@ from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Azure AD")
|
|
||||||
class AzureADOAuthCallback(OAuthCallback):
|
class AzureADOAuthCallback(OAuthCallback):
|
||||||
"""AzureAD OAuth2 Callback"""
|
"""AzureAD OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -26,3 +25,18 @@ class AzureADOAuthCallback(OAuthCallback):
|
||||||
"email": mail,
|
"email": mail,
|
||||||
"name": info.get("displayName"),
|
"name": info.get("displayName"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class AzureADType(SourceType):
|
||||||
|
"""Azure AD Type definition"""
|
||||||
|
|
||||||
|
callback_view = AzureADOAuthCallback
|
||||||
|
name = "Azure AD"
|
||||||
|
slug = "azure-ad"
|
||||||
|
|
||||||
|
urls_customizable = True
|
||||||
|
|
||||||
|
authorization_url = "https://login.microsoftonline.com/common/oauth2/authorize"
|
||||||
|
access_token_url = "https://login.microsoftonline.com/common/oauth2/token" # nosec
|
||||||
|
profile_url = "https://graph.windows.net/myorganization/me?api-version=1.6"
|
||||||
|
|
|
@ -2,12 +2,11 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="Discord")
|
|
||||||
class DiscordOAuthRedirect(OAuthRedirect):
|
class DiscordOAuthRedirect(OAuthRedirect):
|
||||||
"""Discord OAuth2 Redirect"""
|
"""Discord OAuth2 Redirect"""
|
||||||
|
|
||||||
|
@ -17,7 +16,6 @@ class DiscordOAuthRedirect(OAuthRedirect):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Discord")
|
|
||||||
class DiscordOAuth2Callback(OAuthCallback):
|
class DiscordOAuth2Callback(OAuthCallback):
|
||||||
"""Discord OAuth2 Callback"""
|
"""Discord OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -32,3 +30,17 @@ class DiscordOAuth2Callback(OAuthCallback):
|
||||||
"email": info.get("email", None),
|
"email": info.get("email", None),
|
||||||
"name": info.get("username"),
|
"name": info.get("username"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class DiscordType(SourceType):
|
||||||
|
"""Discord Type definition"""
|
||||||
|
|
||||||
|
callback_view = DiscordOAuth2Callback
|
||||||
|
redirect_view = DiscordOAuthRedirect
|
||||||
|
name = "Discord"
|
||||||
|
slug = "discord"
|
||||||
|
|
||||||
|
authorization_url = "https://discord.com/api/oauth2/authorize"
|
||||||
|
access_token_url = "https://discord.com/api/oauth2/token" # nosec
|
||||||
|
profile_url = "https://discord.com/api/users/@me"
|
||||||
|
|
|
@ -5,12 +5,11 @@ from facebook import GraphAPI
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="Facebook")
|
|
||||||
class FacebookOAuthRedirect(OAuthRedirect):
|
class FacebookOAuthRedirect(OAuthRedirect):
|
||||||
"""Facebook OAuth2 Redirect"""
|
"""Facebook OAuth2 Redirect"""
|
||||||
|
|
||||||
|
@ -28,7 +27,6 @@ class FacebookOAuth2Client(OAuth2Client):
|
||||||
return api.get_object("me", fields="id,name,email")
|
return api.get_object("me", fields="id,name,email")
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Facebook")
|
|
||||||
class FacebookOAuth2Callback(OAuthCallback):
|
class FacebookOAuth2Callback(OAuthCallback):
|
||||||
"""Facebook OAuth2 Callback"""
|
"""Facebook OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -45,3 +43,17 @@ class FacebookOAuth2Callback(OAuthCallback):
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class FacebookType(SourceType):
|
||||||
|
"""Facebook Type definition"""
|
||||||
|
|
||||||
|
callback_view = FacebookOAuth2Callback
|
||||||
|
redirect_view = FacebookOAuthRedirect
|
||||||
|
name = "Facebook"
|
||||||
|
slug = "facebook"
|
||||||
|
|
||||||
|
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
|
||||||
|
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
|
||||||
|
profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email"
|
||||||
|
|
|
@ -2,11 +2,10 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="GitHub")
|
|
||||||
class GitHubOAuth2Callback(OAuthCallback):
|
class GitHubOAuth2Callback(OAuthCallback):
|
||||||
"""GitHub OAuth2 Callback"""
|
"""GitHub OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -21,3 +20,18 @@ class GitHubOAuth2Callback(OAuthCallback):
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class GitHubType(SourceType):
|
||||||
|
"""GitHub Type definition"""
|
||||||
|
|
||||||
|
callback_view = GitHubOAuth2Callback
|
||||||
|
name = "GitHub"
|
||||||
|
slug = "github"
|
||||||
|
|
||||||
|
urls_customizable = True
|
||||||
|
|
||||||
|
authorization_url = "https://github.com/login/oauth/authorize"
|
||||||
|
access_token_url = "https://github.com/login/oauth/access_token" # nosec
|
||||||
|
profile_url = "https://api.github.com/user"
|
||||||
|
|
|
@ -2,12 +2,11 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="Google")
|
|
||||||
class GoogleOAuthRedirect(OAuthRedirect):
|
class GoogleOAuthRedirect(OAuthRedirect):
|
||||||
"""Google OAuth2 Redirect"""
|
"""Google OAuth2 Redirect"""
|
||||||
|
|
||||||
|
@ -17,7 +16,6 @@ class GoogleOAuthRedirect(OAuthRedirect):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Google")
|
|
||||||
class GoogleOAuth2Callback(OAuthCallback):
|
class GoogleOAuth2Callback(OAuthCallback):
|
||||||
"""Google OAuth2 Callback"""
|
"""Google OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -32,3 +30,17 @@ class GoogleOAuth2Callback(OAuthCallback):
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class GoogleType(SourceType):
|
||||||
|
"""Google Type definition"""
|
||||||
|
|
||||||
|
callback_view = GoogleOAuth2Callback
|
||||||
|
redirect_view = GoogleOAuthRedirect
|
||||||
|
name = "Google"
|
||||||
|
slug = "google"
|
||||||
|
|
||||||
|
authorization_url = "https://accounts.google.com/o/oauth2/auth"
|
||||||
|
access_token_url = "https://accounts.google.com/o/oauth2/token" # nosec
|
||||||
|
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
"""Source type manager"""
|
"""Source type manager"""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
from django.utils.text import slugify
|
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
|
|
||||||
|
|
||||||
class RequestKind(Enum):
|
class RequestKind(Enum):
|
||||||
"""Enum of OAuth Request types"""
|
"""Enum of OAuth Request types"""
|
||||||
|
@ -19,46 +20,67 @@ class RequestKind(Enum):
|
||||||
REDIRECT = "redirect"
|
REDIRECT = "redirect"
|
||||||
|
|
||||||
|
|
||||||
|
class SourceType:
|
||||||
|
"""Source type, allows overriding of urls and views per type"""
|
||||||
|
|
||||||
|
callback_view = OAuthCallback
|
||||||
|
redirect_view = OAuthRedirect
|
||||||
|
name: str
|
||||||
|
slug: str
|
||||||
|
|
||||||
|
urls_customizable = False
|
||||||
|
|
||||||
|
request_token_url: Optional[str] = None
|
||||||
|
authorization_url: Optional[str] = None
|
||||||
|
access_token_url: Optional[str] = None
|
||||||
|
profile_url: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class SourceTypeManager:
|
class SourceTypeManager:
|
||||||
"""Manager to hold all Source types."""
|
"""Manager to hold all Source types."""
|
||||||
|
|
||||||
__source_types: dict[RequestKind, dict[str, Callable]] = {}
|
__sources: list[SourceType] = []
|
||||||
__names: list[str] = []
|
|
||||||
|
|
||||||
def source(self, kind: RequestKind, name: str):
|
def type(self):
|
||||||
"""Class decorator to register classes inline."""
|
"""Class decorator to register classes inline."""
|
||||||
|
|
||||||
def inner_wrapper(cls):
|
def inner_wrapper(cls):
|
||||||
if kind.value not in self.__source_types:
|
self.__sources.append(cls)
|
||||||
self.__source_types[kind.value] = {}
|
|
||||||
self.__source_types[kind.value][slugify(name)] = cls
|
|
||||||
self.__names.append(name)
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return inner_wrapper
|
return inner_wrapper
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
"""Get a list of all source types"""
|
||||||
|
return self.__sources
|
||||||
|
|
||||||
def get_name_tuple(self):
|
def get_name_tuple(self):
|
||||||
"""Get list of tuples of all registered names"""
|
"""Get list of tuples of all registered names"""
|
||||||
return [(slugify(x), x) for x in set(self.__names)]
|
return [(x.slug, x.name) for x in self.__sources]
|
||||||
|
|
||||||
def find(self, source: OAuthSource, kind: RequestKind) -> Callable:
|
def find_type(self, source: "OAuthSource") -> SourceType:
|
||||||
"""Find fitting Source Type"""
|
"""Find type based on source"""
|
||||||
if kind.value in self.__source_types:
|
found_type = None
|
||||||
if source.provider_type in self.__source_types[kind.value]:
|
for src_type in self.__sources:
|
||||||
return self.__source_types[kind.value][source.provider_type]
|
if src_type.slug == source.provider_type:
|
||||||
|
return src_type
|
||||||
|
if not found_type:
|
||||||
|
found_type = SourceType()
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"no matching type found, using default",
|
"no matching type found, using default",
|
||||||
wanted=source.provider_type,
|
wanted=source.provider_type,
|
||||||
have=self.__source_types[kind.value].keys(),
|
have=[x.name for x in self.__sources],
|
||||||
)
|
)
|
||||||
# Return defaults
|
return found_type
|
||||||
|
|
||||||
|
def find(self, source: "OAuthSource", kind: RequestKind) -> Callable:
|
||||||
|
"""Find fitting Source Type"""
|
||||||
|
found_type = self.find_type(source)
|
||||||
if kind == RequestKind.CALLBACK:
|
if kind == RequestKind.CALLBACK:
|
||||||
return OAuthCallback
|
return found_type.callback_view
|
||||||
if kind == RequestKind.REDIRECT:
|
if kind == RequestKind.REDIRECT:
|
||||||
return OAuthRedirect
|
return found_type.redirect_view
|
||||||
raise KeyError(
|
raise ValueError
|
||||||
f"Provider Type {source.provider_type} (type {kind.value}) not found."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
MANAGER = SourceTypeManager()
|
MANAGER = SourceTypeManager()
|
||||||
|
|
|
@ -2,12 +2,11 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="OpenID Connect")
|
|
||||||
class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
||||||
"""OpenIDConnect OAuth2 Redirect"""
|
"""OpenIDConnect OAuth2 Redirect"""
|
||||||
|
|
||||||
|
@ -17,7 +16,6 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="OpenID Connect")
|
|
||||||
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
||||||
"""OpenIDConnect OAuth2 Callback"""
|
"""OpenIDConnect OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -35,3 +33,15 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class OpenIDConnectType(SourceType):
|
||||||
|
"""OpenIDConnect Type definition"""
|
||||||
|
|
||||||
|
callback_view = OpenIDConnectOAuth2Callback
|
||||||
|
redirect_view = OpenIDConnectOAuthRedirect
|
||||||
|
name = "OpenID Connect"
|
||||||
|
slug = "openid-connect"
|
||||||
|
|
||||||
|
urls_customizable = True
|
||||||
|
|
|
@ -5,12 +5,11 @@ from requests.auth import HTTPBasicAuth
|
||||||
|
|
||||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.REDIRECT, name="reddit")
|
|
||||||
class RedditOAuthRedirect(OAuthRedirect):
|
class RedditOAuthRedirect(OAuthRedirect):
|
||||||
"""Reddit OAuth2 Redirect"""
|
"""Reddit OAuth2 Redirect"""
|
||||||
|
|
||||||
|
@ -30,7 +29,6 @@ class RedditOAuth2Client(OAuth2Client):
|
||||||
return super().get_access_token(auth=auth)
|
return super().get_access_token(auth=auth)
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="reddit")
|
|
||||||
class RedditOAuth2Callback(OAuthCallback):
|
class RedditOAuth2Callback(OAuthCallback):
|
||||||
"""Reddit OAuth2 Callback"""
|
"""Reddit OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -48,3 +46,17 @@ class RedditOAuth2Callback(OAuthCallback):
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
"password": None,
|
"password": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class RedditType(SourceType):
|
||||||
|
"""Reddit Type definition"""
|
||||||
|
|
||||||
|
callback_view = RedditOAuth2Callback
|
||||||
|
redirect_view = RedditOAuthRedirect
|
||||||
|
name = "reddit"
|
||||||
|
slug = "reddit"
|
||||||
|
|
||||||
|
authorization_url = "https://accounts.google.com/o/oauth2/auth"
|
||||||
|
access_token_url = "https://accounts.google.com/o/oauth2/token" # nosec
|
||||||
|
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
|
||||||
|
|
|
@ -2,11 +2,10 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
|
|
||||||
|
|
||||||
@MANAGER.source(kind=RequestKind.CALLBACK, name="Twitter")
|
|
||||||
class TwitterOAuthCallback(OAuthCallback):
|
class TwitterOAuthCallback(OAuthCallback):
|
||||||
"""Twitter OAuth2 Callback"""
|
"""Twitter OAuth2 Callback"""
|
||||||
|
|
||||||
|
@ -21,3 +20,20 @@ class TwitterOAuthCallback(OAuthCallback):
|
||||||
"email": info.get("email", None),
|
"email": info.get("email", None),
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@MANAGER.type()
|
||||||
|
class TwitterType(SourceType):
|
||||||
|
"""Twitter Type definition"""
|
||||||
|
|
||||||
|
callback_view = TwitterOAuthCallback
|
||||||
|
name = "Twitter"
|
||||||
|
slug = "twitter"
|
||||||
|
|
||||||
|
request_token_url = "https://api.twitter.com/oauth/request_token" # nosec
|
||||||
|
authorization_url = "https://api.twitter.com/oauth/authenticate"
|
||||||
|
access_token_url = "https://api.twitter.com/oauth/access_token" # nosec
|
||||||
|
profile_url = (
|
||||||
|
"https://api.twitter.com/1.1/account/"
|
||||||
|
"verify_credentials.json?include_email=true"
|
||||||
|
)
|
||||||
|
|
65
swagger.yaml
65
swagger.yaml
|
@ -9642,9 +9642,9 @@ paths:
|
||||||
tags:
|
tags:
|
||||||
- sources
|
- sources
|
||||||
parameters: []
|
parameters: []
|
||||||
/sources/oauth/provider_types/:
|
/sources/oauth/source_types/:
|
||||||
get:
|
get:
|
||||||
operationId: sources_oauth_provider_types
|
operationId: sources_oauth_source_types
|
||||||
description: Get all creatable source types
|
description: Get all creatable source types
|
||||||
parameters: []
|
parameters: []
|
||||||
responses:
|
responses:
|
||||||
|
@ -9653,7 +9653,7 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: '#/definitions/OAuthSourceProviderType'
|
$ref: '#/definitions/SourceType'
|
||||||
'403':
|
'403':
|
||||||
description: Authentication credentials were invalid, absent or insufficient.
|
description: Authentication credentials were invalid, absent or insufficient.
|
||||||
schema:
|
schema:
|
||||||
|
@ -16907,6 +16907,49 @@ definitions:
|
||||||
type: string
|
type: string
|
||||||
format: date-time
|
format: date-time
|
||||||
readOnly: true
|
readOnly: true
|
||||||
|
SourceType:
|
||||||
|
description: Get source's type configuration
|
||||||
|
required:
|
||||||
|
- name
|
||||||
|
- slug
|
||||||
|
- urls_customizable
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
title: Name
|
||||||
|
type: string
|
||||||
|
minLength: 1
|
||||||
|
slug:
|
||||||
|
title: Slug
|
||||||
|
type: string
|
||||||
|
minLength: 1
|
||||||
|
urls_customizable:
|
||||||
|
title: Urls customizable
|
||||||
|
type: boolean
|
||||||
|
request_token_url:
|
||||||
|
title: Request token url
|
||||||
|
type: string
|
||||||
|
readOnly: true
|
||||||
|
minLength: 1
|
||||||
|
x-nullable: true
|
||||||
|
authorization_url:
|
||||||
|
title: Authorization url
|
||||||
|
type: string
|
||||||
|
readOnly: true
|
||||||
|
minLength: 1
|
||||||
|
x-nullable: true
|
||||||
|
access_token_url:
|
||||||
|
title: Access token url
|
||||||
|
type: string
|
||||||
|
readOnly: true
|
||||||
|
minLength: 1
|
||||||
|
x-nullable: true
|
||||||
|
profile_url:
|
||||||
|
title: Profile url
|
||||||
|
type: string
|
||||||
|
readOnly: true
|
||||||
|
minLength: 1
|
||||||
|
x-nullable: true
|
||||||
OAuthSource:
|
OAuthSource:
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
|
@ -17011,20 +17054,8 @@ definitions:
|
||||||
title: Callback url
|
title: Callback url
|
||||||
type: string
|
type: string
|
||||||
readOnly: true
|
readOnly: true
|
||||||
OAuthSourceProviderType:
|
type:
|
||||||
required:
|
$ref: '#/definitions/SourceType'
|
||||||
- name
|
|
||||||
- value
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
name:
|
|
||||||
title: Name
|
|
||||||
type: string
|
|
||||||
minLength: 1
|
|
||||||
value:
|
|
||||||
title: Value
|
|
||||||
type: string
|
|
||||||
minLength: 1
|
|
||||||
UserOAuthSourceConnection:
|
UserOAuthSourceConnection:
|
||||||
required:
|
required:
|
||||||
- user
|
- user
|
||||||
|
|
Reference in New Issue