outposts: use channel groups instead of saving channel names (#7183)
* outposts: use channel groups instead of saving channel names Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use pubsub Signed-off-by: Jens Langhammer <jens@goauthentik.io> * support storing other args with state Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
00b2a773b4
commit
25d4905d6c
|
@ -4,6 +4,7 @@ from datetime import datetime
|
|||
from enum import IntEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.exceptions import DenyConnection
|
||||
from dacite.core import from_dict
|
||||
from dacite.data import Data
|
||||
|
@ -14,6 +15,8 @@ from authentik.core.channels import AuthJsonConsumer
|
|||
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
||||
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
||||
|
||||
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
|
||||
|
||||
|
||||
class WebsocketMessageInstruction(IntEnum):
|
||||
"""Commands which can be triggered over Websocket"""
|
||||
|
@ -47,8 +50,6 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
|
||||
last_uid: Optional[str] = None
|
||||
|
||||
first_msg = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.logger = get_logger()
|
||||
|
@ -71,22 +72,26 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
raise DenyConnection()
|
||||
self.outpost = outpost
|
||||
self.last_uid = self.channel_name
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||
)
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).inc()
|
||||
|
||||
def disconnect(self, code):
|
||||
if self.outpost:
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||
)
|
||||
if self.outpost and self.last_uid:
|
||||
state = OutpostState.for_instance_uid(self.outpost, self.last_uid)
|
||||
if self.channel_name in state.channel_ids:
|
||||
state.channel_ids.remove(self.channel_name)
|
||||
state.save()
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).dec()
|
||||
self.logger.debug(
|
||||
"removed outpost instance from cache",
|
||||
instance_uuid=self.last_uid,
|
||||
)
|
||||
|
||||
def receive_json(self, content: Data):
|
||||
msg = from_dict(WebsocketMessage, content)
|
||||
|
@ -97,26 +102,13 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
raise DenyConnection()
|
||||
|
||||
state = OutpostState.for_instance_uid(self.outpost, uid)
|
||||
if self.channel_name not in state.channel_ids:
|
||||
state.channel_ids.append(self.channel_name)
|
||||
state.last_seen = datetime.now()
|
||||
state.hostname = msg.args.get("hostname", "")
|
||||
|
||||
if not self.first_msg:
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).inc()
|
||||
self.logger.debug(
|
||||
"added outpost instance to cache",
|
||||
instance_uuid=self.last_uid,
|
||||
)
|
||||
self.first_msg = True
|
||||
state.hostname = msg.args.pop("hostname", "")
|
||||
|
||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||
state.version = msg.args.get("version", None)
|
||||
state.build_hash = msg.args.get("buildHash", "")
|
||||
state.version = msg.args.pop("version", None)
|
||||
state.build_hash = msg.args.pop("buildHash", "")
|
||||
state.args = msg.args
|
||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||
return
|
||||
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
|
@ -411,12 +411,12 @@ class OutpostState:
|
|||
"""Outpost instance state, last_seen and version"""
|
||||
|
||||
uid: str
|
||||
channel_ids: list[str] = field(default_factory=list)
|
||||
last_seen: Optional[datetime] = field(default=None)
|
||||
version: Optional[str] = field(default=None)
|
||||
version_should: Version = field(default=OUR_VERSION)
|
||||
build_hash: str = field(default="")
|
||||
hostname: str = field(default="")
|
||||
args: dict = field(default_factory=dict)
|
||||
|
||||
_outpost: Optional[Outpost] = field(default=None)
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ from authentik.events.monitored_tasks import (
|
|||
)
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.utils.reflection import path_to_class
|
||||
from authentik.outposts.consumer import OUTPOST_GROUP
|
||||
from authentik.outposts.controllers.base import BaseController, ControllerException
|
||||
from authentik.outposts.controllers.docker import DockerClient
|
||||
from authentik.outposts.controllers.kubernetes import KubernetesClient
|
||||
|
@ -34,7 +35,6 @@ from authentik.outposts.models import (
|
|||
Outpost,
|
||||
OutpostModel,
|
||||
OutpostServiceConnection,
|
||||
OutpostState,
|
||||
OutpostType,
|
||||
ServiceConnectionInvalid,
|
||||
)
|
||||
|
@ -243,10 +243,9 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
|||
outpost.build_user_permissions(outpost.user)
|
||||
if not layer: # pragma: no cover
|
||||
layer = get_channel_layer()
|
||||
for state in OutpostState.for_outpost(outpost):
|
||||
for channel in state.channel_ids:
|
||||
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
|
||||
async_to_sync(layer.send)(channel, {"type": "event.update"})
|
||||
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
|
||||
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||
|
||||
|
||||
@CELERY_APP.task(
|
||||
|
|
|
@ -7,7 +7,7 @@ from django.test import TransactionTestCase
|
|||
|
||||
from authentik import __version__
|
||||
from authentik.core.tests.utils import create_test_flow
|
||||
from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction
|
||||
from authentik.outposts.consumer import WebsocketMessage, WebsocketMessageInstruction
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
from authentik.root import websocket
|
||||
|
|
|
@ -7,7 +7,7 @@ from authentik.outposts.api.service_connections import (
|
|||
KubernetesServiceConnectionViewSet,
|
||||
ServiceConnectionViewSet,
|
||||
)
|
||||
from authentik.outposts.channels import OutpostConsumer
|
||||
from authentik.outposts.consumer import OutpostConsumer
|
||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
websocket_urlpatterns = [
|
||||
|
|
|
@ -3,7 +3,8 @@ from asgiref.sync import async_to_sync
|
|||
from channels.layers import get_channel_layer
|
||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
|
||||
from authentik.outposts.models import Outpost, OutpostState, OutpostType
|
||||
from authentik.outposts.consumer import OUTPOST_GROUP
|
||||
from authentik.outposts.models import Outpost, OutpostType
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
|
@ -23,13 +24,12 @@ def proxy_on_logout(session_id: str):
|
|||
"""Update outpost instances connected to a single outpost"""
|
||||
layer = get_channel_layer()
|
||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||
for state in OutpostState.for_outpost(outpost):
|
||||
for channel in state.channel_ids:
|
||||
async_to_sync(layer.send)(
|
||||
channel,
|
||||
{
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "logout",
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
|
||||
async_to_sync(layer.group_send)(
|
||||
group,
|
||||
{
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "logout",
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -253,10 +253,10 @@ ASGI_APPLICATION = "authentik.root.asgi.application"
|
|||
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
||||
"BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
|
||||
"CONFIG": {
|
||||
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
|
||||
"prefix": "authentik_channels",
|
||||
"prefix": "authentik_channels_",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
Reference in a new issue