outposts: use channel groups instead of saving channel names (#7183)

* outposts: use channel groups instead of saving channel names

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use pubsub

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* support storing other args with state

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-10-16 17:01:44 +02:00 committed by GitHub
parent 00b2a773b4
commit 25d4905d6c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 39 additions and 48 deletions

View file

@ -4,6 +4,7 @@ from datetime import datetime
from enum import IntEnum from enum import IntEnum
from typing import Any, Optional from typing import Any, Optional
from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection from channels.exceptions import DenyConnection
from dacite.core import from_dict from dacite.core import from_dict
from dacite.data import Data from dacite.data import Data
@ -14,6 +15,8 @@ from authentik.core.channels import AuthJsonConsumer
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
class WebsocketMessageInstruction(IntEnum): class WebsocketMessageInstruction(IntEnum):
"""Commands which can be triggered over Websocket""" """Commands which can be triggered over Websocket"""
@ -47,8 +50,6 @@ class OutpostConsumer(AuthJsonConsumer):
last_uid: Optional[str] = None last_uid: Optional[str] = None
first_msg = False
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.logger = get_logger() self.logger = get_logger()
@ -71,22 +72,26 @@ class OutpostConsumer(AuthJsonConsumer):
raise DenyConnection() raise DenyConnection()
self.outpost = outpost self.outpost = outpost
self.last_uid = self.channel_name self.last_uid = self.channel_name
async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).inc()
def disconnect(self, code): def disconnect(self, code):
if self.outpost:
async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
if self.outpost and self.last_uid: if self.outpost and self.last_uid:
state = OutpostState.for_instance_uid(self.outpost, self.last_uid)
if self.channel_name in state.channel_ids:
state.channel_ids.remove(self.channel_name)
state.save()
GAUGE_OUTPOSTS_CONNECTED.labels( GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name, outpost=self.outpost.name,
uid=self.last_uid, uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas, expected=self.outpost.config.kubernetes_replicas,
).dec() ).dec()
self.logger.debug(
"removed outpost instance from cache",
instance_uuid=self.last_uid,
)
def receive_json(self, content: Data): def receive_json(self, content: Data):
msg = from_dict(WebsocketMessage, content) msg = from_dict(WebsocketMessage, content)
@ -97,26 +102,13 @@ class OutpostConsumer(AuthJsonConsumer):
raise DenyConnection() raise DenyConnection()
state = OutpostState.for_instance_uid(self.outpost, uid) state = OutpostState.for_instance_uid(self.outpost, uid)
if self.channel_name not in state.channel_ids:
state.channel_ids.append(self.channel_name)
state.last_seen = datetime.now() state.last_seen = datetime.now()
state.hostname = msg.args.get("hostname", "") state.hostname = msg.args.pop("hostname", "")
if not self.first_msg:
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).inc()
self.logger.debug(
"added outpost instance to cache",
instance_uuid=self.last_uid,
)
self.first_msg = True
if msg.instruction == WebsocketMessageInstruction.HELLO: if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.get("version", None) state.version = msg.args.pop("version", None)
state.build_hash = msg.args.get("buildHash", "") state.build_hash = msg.args.pop("buildHash", "")
state.args = msg.args
elif msg.instruction == WebsocketMessageInstruction.ACK: elif msg.instruction == WebsocketMessageInstruction.ACK:
return return
GAUGE_OUTPOSTS_LAST_UPDATE.labels( GAUGE_OUTPOSTS_LAST_UPDATE.labels(

View file

@ -411,12 +411,12 @@ class OutpostState:
"""Outpost instance state, last_seen and version""" """Outpost instance state, last_seen and version"""
uid: str uid: str
channel_ids: list[str] = field(default_factory=list)
last_seen: Optional[datetime] = field(default=None) last_seen: Optional[datetime] = field(default=None)
version: Optional[str] = field(default=None) version: Optional[str] = field(default=None)
version_should: Version = field(default=OUR_VERSION) version_should: Version = field(default=OUR_VERSION)
build_hash: str = field(default="") build_hash: str = field(default="")
hostname: str = field(default="") hostname: str = field(default="")
args: dict = field(default_factory=dict)
_outpost: Optional[Outpost] = field(default=None) _outpost: Optional[Outpost] = field(default=None)

View file

@ -25,6 +25,7 @@ from authentik.events.monitored_tasks import (
) )
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import path_to_class from authentik.lib.utils.reflection import path_to_class
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.controllers.base import BaseController, ControllerException from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.controllers.docker import DockerClient from authentik.outposts.controllers.docker import DockerClient
from authentik.outposts.controllers.kubernetes import KubernetesClient from authentik.outposts.controllers.kubernetes import KubernetesClient
@ -34,7 +35,6 @@ from authentik.outposts.models import (
Outpost, Outpost,
OutpostModel, OutpostModel,
OutpostServiceConnection, OutpostServiceConnection,
OutpostState,
OutpostType, OutpostType,
ServiceConnectionInvalid, ServiceConnectionInvalid,
) )
@ -243,10 +243,9 @@ def _outpost_single_update(outpost: Outpost, layer=None):
outpost.build_user_permissions(outpost.user) outpost.build_user_permissions(outpost.user)
if not layer: # pragma: no cover if not layer: # pragma: no cover
layer = get_channel_layer() layer = get_channel_layer()
for state in OutpostState.for_outpost(outpost): group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
for channel in state.channel_ids: LOGGER.debug("sending update", channel=group, outpost=outpost)
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) async_to_sync(layer.group_send)(group, {"type": "event.update"})
async_to_sync(layer.send)(channel, {"type": "event.update"})
@CELERY_APP.task( @CELERY_APP.task(

View file

@ -7,7 +7,7 @@ from django.test import TransactionTestCase
from authentik import __version__ from authentik import __version__
from authentik.core.tests.utils import create_test_flow from authentik.core.tests.utils import create_test_flow
from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction from authentik.outposts.consumer import WebsocketMessage, WebsocketMessageInstruction
from authentik.outposts.models import Outpost, OutpostType from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
from authentik.root import websocket from authentik.root import websocket

View file

@ -7,7 +7,7 @@ from authentik.outposts.api.service_connections import (
KubernetesServiceConnectionViewSet, KubernetesServiceConnectionViewSet,
ServiceConnectionViewSet, ServiceConnectionViewSet,
) )
from authentik.outposts.channels import OutpostConsumer from authentik.outposts.consumer import OutpostConsumer
from authentik.root.middleware import ChannelsLoggingMiddleware from authentik.root.middleware import ChannelsLoggingMiddleware
websocket_urlpatterns = [ websocket_urlpatterns = [

View file

@ -3,7 +3,8 @@ from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from authentik.outposts.models import Outpost, OutpostState, OutpostType from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
@ -23,13 +24,12 @@ def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost""" """Update outpost instances connected to a single outpost"""
layer = get_channel_layer() layer = get_channel_layer()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY): for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
for state in OutpostState.for_outpost(outpost): group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
for channel in state.channel_ids: async_to_sync(layer.group_send)(
async_to_sync(layer.send)( group,
channel, {
{ "type": "event.provider.specific",
"type": "event.provider.specific", "sub_type": "logout",
"sub_type": "logout", "session_id": session_id,
"session_id": session_id, },
}, )
)

View file

@ -253,10 +253,10 @@ ASGI_APPLICATION = "authentik.root.asgi.application"
CHANNEL_LAYERS = { CHANNEL_LAYERS = {
"default": { "default": {
"BACKEND": "channels_redis.core.RedisChannelLayer", "BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
"CONFIG": { "CONFIG": {
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"], "hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
"prefix": "authentik_channels", "prefix": "authentik_channels_",
}, },
}, },
} }