outposts: improve validation of providers (must match outpost type)
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
b339452843
commit
6d0e0cbe5a
|
@ -4,6 +4,7 @@ from dacite.exceptions import DaciteError
|
|||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.fields import BooleanField, CharField, DateTimeField
|
||||
from rest_framework.relations import PrimaryKeyRelatedField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import JSONField, ModelSerializer, ValidationError
|
||||
|
@ -11,15 +12,45 @@ from rest_framework.viewsets import ModelViewSet
|
|||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.utils import PassiveSerializer, is_dict
|
||||
from authentik.outposts.models import Outpost, OutpostConfig, default_outpost_config
|
||||
from authentik.core.models import Provider
|
||||
from authentik.outposts.models import (
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
OutpostType,
|
||||
default_outpost_config,
|
||||
)
|
||||
from authentik.providers.ldap.models import LDAPProvider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
|
||||
|
||||
class OutpostSerializer(ModelSerializer):
|
||||
"""Outpost Serializer"""
|
||||
|
||||
config = JSONField(validators=[is_dict], source="_config")
|
||||
providers = PrimaryKeyRelatedField(
|
||||
allow_empty=False,
|
||||
many=True,
|
||||
queryset=Provider.objects.select_subclasses().all(),
|
||||
)
|
||||
providers_obj = ProviderSerializer(source="providers", many=True, read_only=True)
|
||||
|
||||
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
|
||||
"""Check that all providers match the type of the outpost"""
|
||||
type_map = {
|
||||
OutpostType.LDAP: LDAPProvider,
|
||||
OutpostType.PROXY: ProxyProvider,
|
||||
None: Provider,
|
||||
}
|
||||
for provider in providers:
|
||||
if not isinstance(provider, type_map[self.initial_data.get("type")]):
|
||||
raise ValidationError(
|
||||
(
|
||||
f"Outpost type {self.initial_data['type']} can't be used with "
|
||||
f"{type(provider)} providers."
|
||||
)
|
||||
)
|
||||
return providers
|
||||
|
||||
def validate_config(self, config) -> dict:
|
||||
"""Check that the config has all required fields"""
|
||||
try:
|
||||
|
@ -41,6 +72,7 @@ class OutpostSerializer(ModelSerializer):
|
|||
"token_identifier",
|
||||
"config",
|
||||
]
|
||||
extra_kwargs = {"type": {"required": True}}
|
||||
|
||||
|
||||
class OutpostDefaultConfigSerializer(PassiveSerializer):
|
||||
|
|
|
@ -376,7 +376,11 @@ class Outpost(models.Model):
|
|||
@property
|
||||
def token(self) -> Token:
|
||||
"""Get/create token for auto-generated user"""
|
||||
token = Token.filter_not_expired(user=self.user, intent=TokenIntents.INTENT_API)
|
||||
token = Token.filter_not_expired(
|
||||
user=self.user,
|
||||
intent=TokenIntents.INTENT_API,
|
||||
managed=f"goauthentik.io/outpost/{self.token_identifier}",
|
||||
)
|
||||
if token.exists():
|
||||
return token.first()
|
||||
return Token.objects.create(
|
||||
|
|
|
@ -65,6 +65,8 @@ def outpost_service_connection_state(connection_pk: Any):
|
|||
.select_subclasses()
|
||||
.first()
|
||||
)
|
||||
if not connection:
|
||||
return
|
||||
state = connection.fetch_state()
|
||||
cache.set(connection.state_key, state, timeout=None)
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@ from rest_framework.test import APITestCase
|
|||
from authentik.core.models import PropertyMapping, User
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.outposts.api.outposts import OutpostSerializer
|
||||
from authentik.outposts.models import default_outpost_config
|
||||
from authentik.outposts.models import OutpostType, default_outpost_config
|
||||
from authentik.providers.ldap.models import LDAPProvider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
|
||||
|
||||
|
@ -20,6 +21,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||
self.user = User.objects.get(username="akadmin")
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_outpost_validaton(self):
|
||||
"""Test Outpost validation"""
|
||||
valid = OutpostSerializer(
|
||||
data={
|
||||
"name": "foo",
|
||||
"type": OutpostType.PROXY,
|
||||
"config": default_outpost_config(),
|
||||
"providers": [
|
||||
ProxyProvider.objects.create(
|
||||
name="test", authorization_flow=Flow.objects.first()
|
||||
).pk
|
||||
],
|
||||
}
|
||||
)
|
||||
self.assertTrue(valid.is_valid())
|
||||
invalid = OutpostSerializer(
|
||||
data={
|
||||
"name": "foo",
|
||||
"type": OutpostType.PROXY,
|
||||
"config": default_outpost_config(),
|
||||
"providers": [
|
||||
LDAPProvider.objects.create(
|
||||
name="test", authorization_flow=Flow.objects.first()
|
||||
).pk
|
||||
],
|
||||
}
|
||||
)
|
||||
self.assertFalse(invalid.is_valid())
|
||||
self.assertIn("providers", invalid.errors)
|
||||
|
||||
def test_types(self):
|
||||
"""Test OutpostServiceConnections's types endpoint"""
|
||||
response = self.client.get(
|
||||
|
@ -42,6 +73,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||
"name": "foo",
|
||||
"providers": [provider.pk],
|
||||
"config": default_outpost_config("foo"),
|
||||
"type": OutpostType.PROXY,
|
||||
}
|
||||
)
|
||||
self.assertTrue(valid.is_valid())
|
||||
|
|
|
@ -18975,6 +18975,7 @@ components:
|
|||
- providers
|
||||
- providers_obj
|
||||
- token_identifier
|
||||
- type
|
||||
OutpostDefaultConfig:
|
||||
type: object
|
||||
description: Global default outpost config
|
||||
|
@ -19032,6 +19033,7 @@ components:
|
|||
- config
|
||||
- name
|
||||
- providers
|
||||
- type
|
||||
OutpostTypeEnum:
|
||||
enum:
|
||||
- proxy
|
||||
|
|
Reference in New Issue