sources/oauth: fix error when creating an oauth source which has fixed URLs

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-16 09:49:25 +02:00
parent c579540473
commit cfe0a7a694
6 changed files with 83 additions and 10 deletions

View file

@ -5,6 +5,7 @@ from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, SerializerMethodField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import ValidationError
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.sources import SourceSerializer
@ -47,6 +48,20 @@ class OAuthSourceSerializer(SourceSerializer):
"""Get source's type configuration"""
return SourceTypeSerializer(instace.type).data
def validate(self, attrs: dict) -> dict:
provider_type = MANAGER.find_type(attrs.get("provider_type", ""))
for url in [
"authorization_url",
"access_token_url",
"profile_url",
]:
if getattr(provider_type, url, None) is None:
if url not in attrs:
raise ValidationError(
f"{url} is required for provider {provider_type.name}"
)
return attrs
class Meta:
model = OAuthSource
fields = SourceSerializer.Meta.fields + [

View file

@ -0,0 +1,43 @@
# Generated by Django 3.2 on 2021-04-16 07:26
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_sources_oauth", "0002_auto_20200520_1108"),
]
operations = [
migrations.AlterField(
model_name="oauthsource",
name="access_token_url",
field=models.CharField(
blank=True,
help_text="URL used by authentik to retrive tokens.",
max_length=255,
verbose_name="Access Token URL",
),
),
migrations.AlterField(
model_name="oauthsource",
name="authorization_url",
field=models.CharField(
blank=True,
help_text="URL the user is redirect to to conest the flow.",
max_length=255,
verbose_name="Authorization URL",
),
),
migrations.AlterField(
model_name="oauthsource",
name="profile_url",
field=models.CharField(
blank=True,
help_text="URL used by authentik to get user information.",
max_length=255,
verbose_name="Profile URL",
),
),
]

View file

@ -28,16 +28,19 @@ class OAuthSource(Source):
)
authorization_url = models.CharField(
max_length=255,
blank=True,
verbose_name=_("Authorization URL"),
help_text=_("URL the user is redirect to to conest the flow."),
)
access_token_url = models.CharField(
max_length=255,
blank=True,
verbose_name=_("Access Token URL"),
help_text=_("URL used by authentik to retrive tokens."),
)
profile_url = models.CharField(
max_length=255,
blank=True,
verbose_name=_("Profile URL"),
help_text=_("URL used by authentik to get user information."),
)
@ -49,7 +52,7 @@ class OAuthSource(Source):
"""Return the provider instance for this source"""
from authentik.sources.oauth.types.manager import MANAGER
return MANAGER.find_type(self)
return MANAGER.find_type(self.provider_type)
@property
def component(self) -> str:

View file

@ -1,4 +1,5 @@
"""OAuth Source tests"""
from authentik.sources.oauth.api.source import OAuthSourceSerializer
from django.test import TestCase
from django.urls import reverse
@ -18,6 +19,23 @@ class TestOAuthSource(TestCase):
consumer_key="",
)
def test_api_validate(self):
"""Test API validation"""
self.assertTrue(OAuthSourceSerializer(data={
"name": "foo",
"slug": "bar",
"provider_type": "google",
"consumer_key": "foo",
"consumer_secret": "foo",
}).is_valid())
self.assertFalse(OAuthSourceSerializer(data={
"name": "foo",
"slug": "bar",
"provider_type": "openid-connect",
"consumer_key": "foo",
"consumer_secret": "foo",
}).is_valid())
def test_source_redirect(self):
"""test redirect view"""
self.client.get(

View file

@ -58,17 +58,17 @@ class SourceTypeManager:
"""Get list of tuples of all registered names"""
return [(x.slug, x.name) for x in self.__sources]
def find_type(self, source: "OAuthSource") -> SourceType:
def find_type(self, type_name: str) -> SourceType:
"""Find type based on source"""
found_type = None
for src_type in self.__sources:
if src_type.slug == source.provider_type:
if src_type.slug == type_name:
return src_type
if not found_type:
found_type = SourceType()
LOGGER.warning(
"no matching type found, using default",
wanted=source.provider_type,
wanted=type_name,
have=[x.name for x in self.__sources],
)
return found_type

View file

@ -16963,9 +16963,6 @@ definitions:
- name
- slug
- provider_type
- authorization_url
- access_token_url
- profile_url
- consumer_key
- consumer_secret
type: object
@ -17037,19 +17034,16 @@ definitions:
description: URL the user is redirect to to conest the flow.
type: string
maxLength: 255
minLength: 1
access_token_url:
title: Access Token URL
description: URL used by authentik to retrive tokens.
type: string
maxLength: 255
minLength: 1
profile_url:
title: Profile URL
description: URL used by authentik to get user information.
type: string
maxLength: 255
minLength: 1
consumer_key:
title: Consumer key
type: string