outposts: improve validation of providers (must match outpost type)

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-06-02 16:04:41 +02:00
parent b339452843
commit 6d0e0cbe5a
5 changed files with 75 additions and 3 deletions

View file

@ -4,6 +4,7 @@ from dacite.exceptions import DaciteError
from drf_spectacular.utils import extend_schema from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, DateTimeField from rest_framework.fields import BooleanField, CharField, DateTimeField
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import JSONField, ModelSerializer, ValidationError from rest_framework.serializers import JSONField, ModelSerializer, ValidationError
@ -11,15 +12,45 @@ from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer from authentik.core.api.providers import ProviderSerializer
from authentik.core.api.utils import PassiveSerializer, is_dict from authentik.core.api.utils import PassiveSerializer, is_dict
from authentik.outposts.models import Outpost, OutpostConfig, default_outpost_config from authentik.core.models import Provider
from authentik.outposts.models import (
Outpost,
OutpostConfig,
OutpostType,
default_outpost_config,
)
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider
class OutpostSerializer(ModelSerializer): class OutpostSerializer(ModelSerializer):
"""Outpost Serializer""" """Outpost Serializer"""
config = JSONField(validators=[is_dict], source="_config") config = JSONField(validators=[is_dict], source="_config")
providers = PrimaryKeyRelatedField(
allow_empty=False,
many=True,
queryset=Provider.objects.select_subclasses().all(),
)
providers_obj = ProviderSerializer(source="providers", many=True, read_only=True) providers_obj = ProviderSerializer(source="providers", many=True, read_only=True)
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
"""Check that all providers match the type of the outpost"""
type_map = {
OutpostType.LDAP: LDAPProvider,
OutpostType.PROXY: ProxyProvider,
None: Provider,
}
for provider in providers:
if not isinstance(provider, type_map[self.initial_data.get("type")]):
raise ValidationError(
(
f"Outpost type {self.initial_data['type']} can't be used with "
f"{type(provider)} providers."
)
)
return providers
def validate_config(self, config) -> dict: def validate_config(self, config) -> dict:
"""Check that the config has all required fields""" """Check that the config has all required fields"""
try: try:
@ -41,6 +72,7 @@ class OutpostSerializer(ModelSerializer):
"token_identifier", "token_identifier",
"config", "config",
] ]
extra_kwargs = {"type": {"required": True}}
class OutpostDefaultConfigSerializer(PassiveSerializer): class OutpostDefaultConfigSerializer(PassiveSerializer):

View file

@ -376,7 +376,11 @@ class Outpost(models.Model):
@property @property
def token(self) -> Token: def token(self) -> Token:
"""Get/create token for auto-generated user""" """Get/create token for auto-generated user"""
token = Token.filter_not_expired(user=self.user, intent=TokenIntents.INTENT_API) token = Token.filter_not_expired(
user=self.user,
intent=TokenIntents.INTENT_API,
managed=f"goauthentik.io/outpost/{self.token_identifier}",
)
if token.exists(): if token.exists():
return token.first() return token.first()
return Token.objects.create( return Token.objects.create(

View file

@ -65,6 +65,8 @@ def outpost_service_connection_state(connection_pk: Any):
.select_subclasses() .select_subclasses()
.first() .first()
) )
if not connection:
return
state = connection.fetch_state() state = connection.fetch_state()
cache.set(connection.state_key, state, timeout=None) cache.set(connection.state_key, state, timeout=None)

View file

@ -5,7 +5,8 @@ from rest_framework.test import APITestCase
from authentik.core.models import PropertyMapping, User from authentik.core.models import PropertyMapping, User
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.outposts.api.outposts import OutpostSerializer from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.models import default_outpost_config from authentik.outposts.models import OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
@ -20,6 +21,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.user = User.objects.get(username="akadmin") self.user = User.objects.get(username="akadmin")
self.client.force_login(self.user) self.client.force_login(self.user)
def test_outpost_validaton(self):
"""Test Outpost validation"""
valid = OutpostSerializer(
data={
"name": "foo",
"type": OutpostType.PROXY,
"config": default_outpost_config(),
"providers": [
ProxyProvider.objects.create(
name="test", authorization_flow=Flow.objects.first()
).pk
],
}
)
self.assertTrue(valid.is_valid())
invalid = OutpostSerializer(
data={
"name": "foo",
"type": OutpostType.PROXY,
"config": default_outpost_config(),
"providers": [
LDAPProvider.objects.create(
name="test", authorization_flow=Flow.objects.first()
).pk
],
}
)
self.assertFalse(invalid.is_valid())
self.assertIn("providers", invalid.errors)
def test_types(self): def test_types(self):
"""Test OutpostServiceConnections's types endpoint""" """Test OutpostServiceConnections's types endpoint"""
response = self.client.get( response = self.client.get(
@ -42,6 +73,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
"name": "foo", "name": "foo",
"providers": [provider.pk], "providers": [provider.pk],
"config": default_outpost_config("foo"), "config": default_outpost_config("foo"),
"type": OutpostType.PROXY,
} }
) )
self.assertTrue(valid.is_valid()) self.assertTrue(valid.is_valid())

View file

@ -18975,6 +18975,7 @@ components:
- providers - providers
- providers_obj - providers_obj
- token_identifier - token_identifier
- type
OutpostDefaultConfig: OutpostDefaultConfig:
type: object type: object
description: Global default outpost config description: Global default outpost config
@ -19032,6 +19033,7 @@ components:
- config - config
- name - name
- providers - providers
- type
OutpostTypeEnum: OutpostTypeEnum:
enum: enum:
- proxy - proxy