outposts: use channel groups instead of saving channel names (#7183)

* outposts: use channel groups instead of saving channel names

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use pubsub

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* support storing other args with state

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-10-16 17:01:44 +02:00 committed by GitHub
parent 00b2a773b4
commit 25d4905d6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 39 additions and 48 deletions

View file

@ -4,6 +4,7 @@ from datetime import datetime
from enum import IntEnum
from typing import Any, Optional
from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection
from dacite.core import from_dict
from dacite.data import Data
@ -14,6 +15,8 @@ from authentik.core.channels import AuthJsonConsumer
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
class WebsocketMessageInstruction(IntEnum):
"""Commands which can be triggered over Websocket"""
@ -47,8 +50,6 @@ class OutpostConsumer(AuthJsonConsumer):
last_uid: Optional[str] = None
first_msg = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger = get_logger()
@ -71,22 +72,26 @@ class OutpostConsumer(AuthJsonConsumer):
raise DenyConnection()
self.outpost = outpost
self.last_uid = self.channel_name
async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).inc()
def disconnect(self, code):
if self.outpost:
async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
if self.outpost and self.last_uid:
state = OutpostState.for_instance_uid(self.outpost, self.last_uid)
if self.channel_name in state.channel_ids:
state.channel_ids.remove(self.channel_name)
state.save()
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).dec()
self.logger.debug(
"removed outpost instance from cache",
instance_uuid=self.last_uid,
)
def receive_json(self, content: Data):
msg = from_dict(WebsocketMessage, content)
@ -97,26 +102,13 @@ class OutpostConsumer(AuthJsonConsumer):
raise DenyConnection()
state = OutpostState.for_instance_uid(self.outpost, uid)
if self.channel_name not in state.channel_ids:
state.channel_ids.append(self.channel_name)
state.last_seen = datetime.now()
state.hostname = msg.args.get("hostname", "")
if not self.first_msg:
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).inc()
self.logger.debug(
"added outpost instance to cache",
instance_uuid=self.last_uid,
)
self.first_msg = True
state.hostname = msg.args.pop("hostname", "")
if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.get("version", None)
state.build_hash = msg.args.get("buildHash", "")
state.version = msg.args.pop("version", None)
state.build_hash = msg.args.pop("buildHash", "")
state.args = msg.args
elif msg.instruction == WebsocketMessageInstruction.ACK:
return
GAUGE_OUTPOSTS_LAST_UPDATE.labels(

View file

@ -411,12 +411,12 @@ class OutpostState:
"""Outpost instance state, last_seen and version"""
uid: str
channel_ids: list[str] = field(default_factory=list)
last_seen: Optional[datetime] = field(default=None)
version: Optional[str] = field(default=None)
version_should: Version = field(default=OUR_VERSION)
build_hash: str = field(default="")
hostname: str = field(default="")
args: dict = field(default_factory=dict)
_outpost: Optional[Outpost] = field(default=None)

View file

@ -25,6 +25,7 @@ from authentik.events.monitored_tasks import (
)
from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import path_to_class
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.controllers.docker import DockerClient
from authentik.outposts.controllers.kubernetes import KubernetesClient
@ -34,7 +35,6 @@ from authentik.outposts.models import (
Outpost,
OutpostModel,
OutpostServiceConnection,
OutpostState,
OutpostType,
ServiceConnectionInvalid,
)
@ -243,10 +243,9 @@ def _outpost_single_update(outpost: Outpost, layer=None):
outpost.build_user_permissions(outpost.user)
if not layer: # pragma: no cover
layer = get_channel_layer()
for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids:
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
async_to_sync(layer.send)(channel, {"type": "event.update"})
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
LOGGER.debug("sending update", channel=group, outpost=outpost)
async_to_sync(layer.group_send)(group, {"type": "event.update"})
@CELERY_APP.task(

View file

@ -7,7 +7,7 @@ from django.test import TransactionTestCase
from authentik import __version__
from authentik.core.tests.utils import create_test_flow
from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction
from authentik.outposts.consumer import WebsocketMessage, WebsocketMessageInstruction
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.proxy.models import ProxyProvider
from authentik.root import websocket

View file

@ -7,7 +7,7 @@ from authentik.outposts.api.service_connections import (
KubernetesServiceConnectionViewSet,
ServiceConnectionViewSet,
)
from authentik.outposts.channels import OutpostConsumer
from authentik.outposts.consumer import OutpostConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware
websocket_urlpatterns = [

View file

@ -3,7 +3,8 @@ from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError
from authentik.outposts.models import Outpost, OutpostState, OutpostType
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.proxy.models import ProxyProvider
from authentik.root.celery import CELERY_APP
@ -23,10 +24,9 @@ def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids:
async_to_sync(layer.send)(
channel,
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.provider.specific",
"sub_type": "logout",

View file

@ -253,10 +253,10 @@ ASGI_APPLICATION = "authentik.root.asgi.application"
CHANNEL_LAYERS = {
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
"CONFIG": {
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
"prefix": "authentik_channels",
"prefix": "authentik_channels_",
},
},
}