sources/oauth: fix error when creating an oauth source which has fixed URLs
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
c579540473
commit
cfe0a7a694
|
@ -5,6 +5,7 @@ from rest_framework.decorators import action
|
||||||
from rest_framework.fields import BooleanField, CharField, SerializerMethodField
|
from rest_framework.fields import BooleanField, CharField, SerializerMethodField
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.serializers import ValidationError
|
||||||
from rest_framework.viewsets import ModelViewSet
|
from rest_framework.viewsets import ModelViewSet
|
||||||
|
|
||||||
from authentik.core.api.sources import SourceSerializer
|
from authentik.core.api.sources import SourceSerializer
|
||||||
|
@ -47,6 +48,20 @@ class OAuthSourceSerializer(SourceSerializer):
|
||||||
"""Get source's type configuration"""
|
"""Get source's type configuration"""
|
||||||
return SourceTypeSerializer(instace.type).data
|
return SourceTypeSerializer(instace.type).data
|
||||||
|
|
||||||
|
def validate(self, attrs: dict) -> dict:
|
||||||
|
provider_type = MANAGER.find_type(attrs.get("provider_type", ""))
|
||||||
|
for url in [
|
||||||
|
"authorization_url",
|
||||||
|
"access_token_url",
|
||||||
|
"profile_url",
|
||||||
|
]:
|
||||||
|
if getattr(provider_type, url, None) is None:
|
||||||
|
if url not in attrs:
|
||||||
|
raise ValidationError(
|
||||||
|
f"{url} is required for provider {provider_type.name}"
|
||||||
|
)
|
||||||
|
return attrs
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = OAuthSource
|
model = OAuthSource
|
||||||
fields = SourceSerializer.Meta.fields + [
|
fields = SourceSerializer.Meta.fields + [
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Generated by Django 3.2 on 2021-04-16 07:26
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("authentik_sources_oauth", "0002_auto_20200520_1108"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="oauthsource",
|
||||||
|
name="access_token_url",
|
||||||
|
field=models.CharField(
|
||||||
|
blank=True,
|
||||||
|
help_text="URL used by authentik to retrive tokens.",
|
||||||
|
max_length=255,
|
||||||
|
verbose_name="Access Token URL",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="oauthsource",
|
||||||
|
name="authorization_url",
|
||||||
|
field=models.CharField(
|
||||||
|
blank=True,
|
||||||
|
help_text="URL the user is redirect to to conest the flow.",
|
||||||
|
max_length=255,
|
||||||
|
verbose_name="Authorization URL",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="oauthsource",
|
||||||
|
name="profile_url",
|
||||||
|
field=models.CharField(
|
||||||
|
blank=True,
|
||||||
|
help_text="URL used by authentik to get user information.",
|
||||||
|
max_length=255,
|
||||||
|
verbose_name="Profile URL",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -28,16 +28,19 @@ class OAuthSource(Source):
|
||||||
)
|
)
|
||||||
authorization_url = models.CharField(
|
authorization_url = models.CharField(
|
||||||
max_length=255,
|
max_length=255,
|
||||||
|
blank=True,
|
||||||
verbose_name=_("Authorization URL"),
|
verbose_name=_("Authorization URL"),
|
||||||
help_text=_("URL the user is redirect to to conest the flow."),
|
help_text=_("URL the user is redirect to to conest the flow."),
|
||||||
)
|
)
|
||||||
access_token_url = models.CharField(
|
access_token_url = models.CharField(
|
||||||
max_length=255,
|
max_length=255,
|
||||||
|
blank=True,
|
||||||
verbose_name=_("Access Token URL"),
|
verbose_name=_("Access Token URL"),
|
||||||
help_text=_("URL used by authentik to retrive tokens."),
|
help_text=_("URL used by authentik to retrive tokens."),
|
||||||
)
|
)
|
||||||
profile_url = models.CharField(
|
profile_url = models.CharField(
|
||||||
max_length=255,
|
max_length=255,
|
||||||
|
blank=True,
|
||||||
verbose_name=_("Profile URL"),
|
verbose_name=_("Profile URL"),
|
||||||
help_text=_("URL used by authentik to get user information."),
|
help_text=_("URL used by authentik to get user information."),
|
||||||
)
|
)
|
||||||
|
@ -49,7 +52,7 @@ class OAuthSource(Source):
|
||||||
"""Return the provider instance for this source"""
|
"""Return the provider instance for this source"""
|
||||||
from authentik.sources.oauth.types.manager import MANAGER
|
from authentik.sources.oauth.types.manager import MANAGER
|
||||||
|
|
||||||
return MANAGER.find_type(self)
|
return MANAGER.find_type(self.provider_type)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def component(self) -> str:
|
def component(self) -> str:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""OAuth Source tests"""
|
"""OAuth Source tests"""
|
||||||
|
from authentik.sources.oauth.api.source import OAuthSourceSerializer
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
|
||||||
|
@ -18,6 +19,23 @@ class TestOAuthSource(TestCase):
|
||||||
consumer_key="",
|
consumer_key="",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_api_validate(self):
|
||||||
|
"""Test API validation"""
|
||||||
|
self.assertTrue(OAuthSourceSerializer(data={
|
||||||
|
"name": "foo",
|
||||||
|
"slug": "bar",
|
||||||
|
"provider_type": "google",
|
||||||
|
"consumer_key": "foo",
|
||||||
|
"consumer_secret": "foo",
|
||||||
|
}).is_valid())
|
||||||
|
self.assertFalse(OAuthSourceSerializer(data={
|
||||||
|
"name": "foo",
|
||||||
|
"slug": "bar",
|
||||||
|
"provider_type": "openid-connect",
|
||||||
|
"consumer_key": "foo",
|
||||||
|
"consumer_secret": "foo",
|
||||||
|
}).is_valid())
|
||||||
|
|
||||||
def test_source_redirect(self):
|
def test_source_redirect(self):
|
||||||
"""test redirect view"""
|
"""test redirect view"""
|
||||||
self.client.get(
|
self.client.get(
|
||||||
|
|
|
@ -58,17 +58,17 @@ class SourceTypeManager:
|
||||||
"""Get list of tuples of all registered names"""
|
"""Get list of tuples of all registered names"""
|
||||||
return [(x.slug, x.name) for x in self.__sources]
|
return [(x.slug, x.name) for x in self.__sources]
|
||||||
|
|
||||||
def find_type(self, source: "OAuthSource") -> SourceType:
|
def find_type(self, type_name: str) -> SourceType:
|
||||||
"""Find type based on source"""
|
"""Find type based on source"""
|
||||||
found_type = None
|
found_type = None
|
||||||
for src_type in self.__sources:
|
for src_type in self.__sources:
|
||||||
if src_type.slug == source.provider_type:
|
if src_type.slug == type_name:
|
||||||
return src_type
|
return src_type
|
||||||
if not found_type:
|
if not found_type:
|
||||||
found_type = SourceType()
|
found_type = SourceType()
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"no matching type found, using default",
|
"no matching type found, using default",
|
||||||
wanted=source.provider_type,
|
wanted=type_name,
|
||||||
have=[x.name for x in self.__sources],
|
have=[x.name for x in self.__sources],
|
||||||
)
|
)
|
||||||
return found_type
|
return found_type
|
||||||
|
|
|
@ -16963,9 +16963,6 @@ definitions:
|
||||||
- name
|
- name
|
||||||
- slug
|
- slug
|
||||||
- provider_type
|
- provider_type
|
||||||
- authorization_url
|
|
||||||
- access_token_url
|
|
||||||
- profile_url
|
|
||||||
- consumer_key
|
- consumer_key
|
||||||
- consumer_secret
|
- consumer_secret
|
||||||
type: object
|
type: object
|
||||||
|
@ -17037,19 +17034,16 @@ definitions:
|
||||||
description: URL the user is redirect to to conest the flow.
|
description: URL the user is redirect to to conest the flow.
|
||||||
type: string
|
type: string
|
||||||
maxLength: 255
|
maxLength: 255
|
||||||
minLength: 1
|
|
||||||
access_token_url:
|
access_token_url:
|
||||||
title: Access Token URL
|
title: Access Token URL
|
||||||
description: URL used by authentik to retrive tokens.
|
description: URL used by authentik to retrive tokens.
|
||||||
type: string
|
type: string
|
||||||
maxLength: 255
|
maxLength: 255
|
||||||
minLength: 1
|
|
||||||
profile_url:
|
profile_url:
|
||||||
title: Profile URL
|
title: Profile URL
|
||||||
description: URL used by authentik to get user information.
|
description: URL used by authentik to get user information.
|
||||||
type: string
|
type: string
|
||||||
maxLength: 255
|
maxLength: 255
|
||||||
minLength: 1
|
|
||||||
consumer_key:
|
consumer_key:
|
||||||
title: Consumer key
|
title: Consumer key
|
||||||
type: string
|
type: string
|
||||||
|
|
Reference in a new issue