outposts: improve validation of providers (must match outpost type)

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-06-02 16:04:41 +02:00
parent b339452843
commit 6d0e0cbe5a
5 changed files with 75 additions and 3 deletions

View File

@ -4,6 +4,7 @@ from dacite.exceptions import DaciteError
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, DateTimeField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import JSONField, ModelSerializer, ValidationError
@ -11,15 +12,45 @@ from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.utils import PassiveSerializer, is_dict
from authentik.outposts.models import Outpost, OutpostConfig, default_outpost_config
from authentik.core.models import Provider
from authentik.outposts.models import (
Outpost,
OutpostConfig,
OutpostType,
default_outpost_config,
)
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider
class OutpostSerializer(ModelSerializer):
"""Outpost Serializer"""
config = JSONField(validators=[is_dict], source="_config")
providers = PrimaryKeyRelatedField(
allow_empty=False,
many=True,
queryset=Provider.objects.select_subclasses().all(),
)
providers_obj = ProviderSerializer(source="providers", many=True, read_only=True)
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
"""Check that all providers match the type of the outpost"""
type_map = {
OutpostType.LDAP: LDAPProvider,
OutpostType.PROXY: ProxyProvider,
None: Provider,
}
for provider in providers:
if not isinstance(provider, type_map[self.initial_data.get("type")]):
raise ValidationError(
(
f"Outpost type {self.initial_data['type']} can't be used with "
f"{type(provider)} providers."
)
)
return providers
def validate_config(self, config) -> dict:
"""Check that the config has all required fields"""
try:
@ -41,6 +72,7 @@ class OutpostSerializer(ModelSerializer):
"token_identifier",
"config",
]
extra_kwargs = {"type": {"required": True}}
class OutpostDefaultConfigSerializer(PassiveSerializer):

View File

@ -376,7 +376,11 @@ class Outpost(models.Model):
@property
def token(self) -> Token:
"""Get/create token for auto-generated user"""
token = Token.filter_not_expired(user=self.user, intent=TokenIntents.INTENT_API)
token = Token.filter_not_expired(
user=self.user,
intent=TokenIntents.INTENT_API,
managed=f"goauthentik.io/outpost/{self.token_identifier}",
)
if token.exists():
return token.first()
return Token.objects.create(

View File

@ -65,6 +65,8 @@ def outpost_service_connection_state(connection_pk: Any):
.select_subclasses()
.first()
)
if not connection:
return
state = connection.fetch_state()
cache.set(connection.state_key, state, timeout=None)

View File

@ -5,7 +5,8 @@ from rest_framework.test import APITestCase
from authentik.core.models import PropertyMapping, User
from authentik.flows.models import Flow
from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.models import default_outpost_config
from authentik.outposts.models import OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider
@ -20,6 +21,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.user = User.objects.get(username="akadmin")
self.client.force_login(self.user)
def test_outpost_validaton(self):
"""Test Outpost validation"""
valid = OutpostSerializer(
data={
"name": "foo",
"type": OutpostType.PROXY,
"config": default_outpost_config(),
"providers": [
ProxyProvider.objects.create(
name="test", authorization_flow=Flow.objects.first()
).pk
],
}
)
self.assertTrue(valid.is_valid())
invalid = OutpostSerializer(
data={
"name": "foo",
"type": OutpostType.PROXY,
"config": default_outpost_config(),
"providers": [
LDAPProvider.objects.create(
name="test", authorization_flow=Flow.objects.first()
).pk
],
}
)
self.assertFalse(invalid.is_valid())
self.assertIn("providers", invalid.errors)
def test_types(self):
"""Test OutpostServiceConnections's types endpoint"""
response = self.client.get(
@ -42,6 +73,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
"name": "foo",
"providers": [provider.pk],
"config": default_outpost_config("foo"),
"type": OutpostType.PROXY,
}
)
self.assertTrue(valid.is_valid())

View File

@ -18975,6 +18975,7 @@ components:
- providers
- providers_obj
- token_identifier
- type
OutpostDefaultConfig:
type: object
description: Global default outpost config
@ -19032,6 +19033,7 @@ components:
- config
- name
- providers
- type
OutpostTypeEnum:
enum:
- proxy