sources/oauth: fix error when creating an oauth source which has fixed URLs
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
c579540473
commit
cfe0a7a694
|
@ -5,6 +5,7 @@ from rest_framework.decorators import action
|
|||
from rest_framework.fields import BooleanField, CharField, SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ValidationError
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.sources import SourceSerializer
|
||||
|
@ -47,6 +48,20 @@ class OAuthSourceSerializer(SourceSerializer):
|
|||
"""Get source's type configuration"""
|
||||
return SourceTypeSerializer(instace.type).data
|
||||
|
||||
def validate(self, attrs: dict) -> dict:
|
||||
provider_type = MANAGER.find_type(attrs.get("provider_type", ""))
|
||||
for url in [
|
||||
"authorization_url",
|
||||
"access_token_url",
|
||||
"profile_url",
|
||||
]:
|
||||
if getattr(provider_type, url, None) is None:
|
||||
if url not in attrs:
|
||||
raise ValidationError(
|
||||
f"{url} is required for provider {provider_type.name}"
|
||||
)
|
||||
return attrs
|
||||
|
||||
class Meta:
|
||||
model = OAuthSource
|
||||
fields = SourceSerializer.Meta.fields + [
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
# Generated by Django 3.2 on 2021-04-16 07:26
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("authentik_sources_oauth", "0002_auto_20200520_1108"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="oauthsource",
|
||||
name="access_token_url",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
help_text="URL used by authentik to retrive tokens.",
|
||||
max_length=255,
|
||||
verbose_name="Access Token URL",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="oauthsource",
|
||||
name="authorization_url",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
help_text="URL the user is redirect to to conest the flow.",
|
||||
max_length=255,
|
||||
verbose_name="Authorization URL",
|
||||
),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name="oauthsource",
|
||||
name="profile_url",
|
||||
field=models.CharField(
|
||||
blank=True,
|
||||
help_text="URL used by authentik to get user information.",
|
||||
max_length=255,
|
||||
verbose_name="Profile URL",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -28,16 +28,19 @@ class OAuthSource(Source):
|
|||
)
|
||||
authorization_url = models.CharField(
|
||||
max_length=255,
|
||||
blank=True,
|
||||
verbose_name=_("Authorization URL"),
|
||||
help_text=_("URL the user is redirect to to conest the flow."),
|
||||
)
|
||||
access_token_url = models.CharField(
|
||||
max_length=255,
|
||||
blank=True,
|
||||
verbose_name=_("Access Token URL"),
|
||||
help_text=_("URL used by authentik to retrive tokens."),
|
||||
)
|
||||
profile_url = models.CharField(
|
||||
max_length=255,
|
||||
blank=True,
|
||||
verbose_name=_("Profile URL"),
|
||||
help_text=_("URL used by authentik to get user information."),
|
||||
)
|
||||
|
@ -49,7 +52,7 @@ class OAuthSource(Source):
|
|||
"""Return the provider instance for this source"""
|
||||
from authentik.sources.oauth.types.manager import MANAGER
|
||||
|
||||
return MANAGER.find_type(self)
|
||||
return MANAGER.find_type(self.provider_type)
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""OAuth Source tests"""
|
||||
from authentik.sources.oauth.api.source import OAuthSourceSerializer
|
||||
from django.test import TestCase
|
||||
from django.urls import reverse
|
||||
|
||||
|
@ -18,6 +19,23 @@ class TestOAuthSource(TestCase):
|
|||
consumer_key="",
|
||||
)
|
||||
|
||||
def test_api_validate(self):
|
||||
"""Test API validation"""
|
||||
self.assertTrue(OAuthSourceSerializer(data={
|
||||
"name": "foo",
|
||||
"slug": "bar",
|
||||
"provider_type": "google",
|
||||
"consumer_key": "foo",
|
||||
"consumer_secret": "foo",
|
||||
}).is_valid())
|
||||
self.assertFalse(OAuthSourceSerializer(data={
|
||||
"name": "foo",
|
||||
"slug": "bar",
|
||||
"provider_type": "openid-connect",
|
||||
"consumer_key": "foo",
|
||||
"consumer_secret": "foo",
|
||||
}).is_valid())
|
||||
|
||||
def test_source_redirect(self):
|
||||
"""test redirect view"""
|
||||
self.client.get(
|
||||
|
|
|
@ -58,17 +58,17 @@ class SourceTypeManager:
|
|||
"""Get list of tuples of all registered names"""
|
||||
return [(x.slug, x.name) for x in self.__sources]
|
||||
|
||||
def find_type(self, source: "OAuthSource") -> SourceType:
|
||||
def find_type(self, type_name: str) -> SourceType:
|
||||
"""Find type based on source"""
|
||||
found_type = None
|
||||
for src_type in self.__sources:
|
||||
if src_type.slug == source.provider_type:
|
||||
if src_type.slug == type_name:
|
||||
return src_type
|
||||
if not found_type:
|
||||
found_type = SourceType()
|
||||
LOGGER.warning(
|
||||
"no matching type found, using default",
|
||||
wanted=source.provider_type,
|
||||
wanted=type_name,
|
||||
have=[x.name for x in self.__sources],
|
||||
)
|
||||
return found_type
|
||||
|
|
|
@ -16963,9 +16963,6 @@ definitions:
|
|||
- name
|
||||
- slug
|
||||
- provider_type
|
||||
- authorization_url
|
||||
- access_token_url
|
||||
- profile_url
|
||||
- consumer_key
|
||||
- consumer_secret
|
||||
type: object
|
||||
|
@ -17037,19 +17034,16 @@ definitions:
|
|||
description: URL the user is redirect to to conest the flow.
|
||||
type: string
|
||||
maxLength: 255
|
||||
minLength: 1
|
||||
access_token_url:
|
||||
title: Access Token URL
|
||||
description: URL used by authentik to retrive tokens.
|
||||
type: string
|
||||
maxLength: 255
|
||||
minLength: 1
|
||||
profile_url:
|
||||
title: Profile URL
|
||||
description: URL used by authentik to get user information.
|
||||
type: string
|
||||
maxLength: 255
|
||||
minLength: 1
|
||||
consumer_key:
|
||||
title: Consumer key
|
||||
type: string
|
||||
|
|
Reference in a new issue