outposts: use channel groups instead of saving channel names (#7183)
* outposts: use channel groups instead of saving channel names Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use pubsub Signed-off-by: Jens Langhammer <jens@goauthentik.io> * support storing other args with state Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
00b2a773b4
commit
25d4905d6c
|
@ -4,6 +4,7 @@ from datetime import datetime
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from asgiref.sync import async_to_sync
|
||||||
from channels.exceptions import DenyConnection
|
from channels.exceptions import DenyConnection
|
||||||
from dacite.core import from_dict
|
from dacite.core import from_dict
|
||||||
from dacite.data import Data
|
from dacite.data import Data
|
||||||
|
@ -14,6 +15,8 @@ from authentik.core.channels import AuthJsonConsumer
|
||||||
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
||||||
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
||||||
|
|
||||||
|
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
|
||||||
|
|
||||||
|
|
||||||
class WebsocketMessageInstruction(IntEnum):
|
class WebsocketMessageInstruction(IntEnum):
|
||||||
"""Commands which can be triggered over Websocket"""
|
"""Commands which can be triggered over Websocket"""
|
||||||
|
@ -47,8 +50,6 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
|
|
||||||
last_uid: Optional[str] = None
|
last_uid: Optional[str] = None
|
||||||
|
|
||||||
first_msg = False
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
@ -71,22 +72,26 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
self.outpost = outpost
|
self.outpost = outpost
|
||||||
self.last_uid = self.channel_name
|
self.last_uid = self.channel_name
|
||||||
|
async_to_sync(self.channel_layer.group_add)(
|
||||||
|
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||||
|
)
|
||||||
|
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||||
|
outpost=self.outpost.name,
|
||||||
|
uid=self.last_uid,
|
||||||
|
expected=self.outpost.config.kubernetes_replicas,
|
||||||
|
).inc()
|
||||||
|
|
||||||
def disconnect(self, code):
|
def disconnect(self, code):
|
||||||
|
if self.outpost:
|
||||||
|
async_to_sync(self.channel_layer.group_discard)(
|
||||||
|
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||||
|
)
|
||||||
if self.outpost and self.last_uid:
|
if self.outpost and self.last_uid:
|
||||||
state = OutpostState.for_instance_uid(self.outpost, self.last_uid)
|
|
||||||
if self.channel_name in state.channel_ids:
|
|
||||||
state.channel_ids.remove(self.channel_name)
|
|
||||||
state.save()
|
|
||||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||||
outpost=self.outpost.name,
|
outpost=self.outpost.name,
|
||||||
uid=self.last_uid,
|
uid=self.last_uid,
|
||||||
expected=self.outpost.config.kubernetes_replicas,
|
expected=self.outpost.config.kubernetes_replicas,
|
||||||
).dec()
|
).dec()
|
||||||
self.logger.debug(
|
|
||||||
"removed outpost instance from cache",
|
|
||||||
instance_uuid=self.last_uid,
|
|
||||||
)
|
|
||||||
|
|
||||||
def receive_json(self, content: Data):
|
def receive_json(self, content: Data):
|
||||||
msg = from_dict(WebsocketMessage, content)
|
msg = from_dict(WebsocketMessage, content)
|
||||||
|
@ -97,26 +102,13 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
raise DenyConnection()
|
raise DenyConnection()
|
||||||
|
|
||||||
state = OutpostState.for_instance_uid(self.outpost, uid)
|
state = OutpostState.for_instance_uid(self.outpost, uid)
|
||||||
if self.channel_name not in state.channel_ids:
|
|
||||||
state.channel_ids.append(self.channel_name)
|
|
||||||
state.last_seen = datetime.now()
|
state.last_seen = datetime.now()
|
||||||
state.hostname = msg.args.get("hostname", "")
|
state.hostname = msg.args.pop("hostname", "")
|
||||||
|
|
||||||
if not self.first_msg:
|
|
||||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
|
||||||
outpost=self.outpost.name,
|
|
||||||
uid=self.last_uid,
|
|
||||||
expected=self.outpost.config.kubernetes_replicas,
|
|
||||||
).inc()
|
|
||||||
self.logger.debug(
|
|
||||||
"added outpost instance to cache",
|
|
||||||
instance_uuid=self.last_uid,
|
|
||||||
)
|
|
||||||
self.first_msg = True
|
|
||||||
|
|
||||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||||
state.version = msg.args.get("version", None)
|
state.version = msg.args.pop("version", None)
|
||||||
state.build_hash = msg.args.get("buildHash", "")
|
state.build_hash = msg.args.pop("buildHash", "")
|
||||||
|
state.args = msg.args
|
||||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||||
return
|
return
|
||||||
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
|
@ -411,12 +411,12 @@ class OutpostState:
|
||||||
"""Outpost instance state, last_seen and version"""
|
"""Outpost instance state, last_seen and version"""
|
||||||
|
|
||||||
uid: str
|
uid: str
|
||||||
channel_ids: list[str] = field(default_factory=list)
|
|
||||||
last_seen: Optional[datetime] = field(default=None)
|
last_seen: Optional[datetime] = field(default=None)
|
||||||
version: Optional[str] = field(default=None)
|
version: Optional[str] = field(default=None)
|
||||||
version_should: Version = field(default=OUR_VERSION)
|
version_should: Version = field(default=OUR_VERSION)
|
||||||
build_hash: str = field(default="")
|
build_hash: str = field(default="")
|
||||||
hostname: str = field(default="")
|
hostname: str = field(default="")
|
||||||
|
args: dict = field(default_factory=dict)
|
||||||
|
|
||||||
_outpost: Optional[Outpost] = field(default=None)
|
_outpost: Optional[Outpost] = field(default=None)
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ from authentik.events.monitored_tasks import (
|
||||||
)
|
)
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.reflection import path_to_class
|
from authentik.lib.utils.reflection import path_to_class
|
||||||
|
from authentik.outposts.consumer import OUTPOST_GROUP
|
||||||
from authentik.outposts.controllers.base import BaseController, ControllerException
|
from authentik.outposts.controllers.base import BaseController, ControllerException
|
||||||
from authentik.outposts.controllers.docker import DockerClient
|
from authentik.outposts.controllers.docker import DockerClient
|
||||||
from authentik.outposts.controllers.kubernetes import KubernetesClient
|
from authentik.outposts.controllers.kubernetes import KubernetesClient
|
||||||
|
@ -34,7 +35,6 @@ from authentik.outposts.models import (
|
||||||
Outpost,
|
Outpost,
|
||||||
OutpostModel,
|
OutpostModel,
|
||||||
OutpostServiceConnection,
|
OutpostServiceConnection,
|
||||||
OutpostState,
|
|
||||||
OutpostType,
|
OutpostType,
|
||||||
ServiceConnectionInvalid,
|
ServiceConnectionInvalid,
|
||||||
)
|
)
|
||||||
|
@ -243,10 +243,9 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
||||||
outpost.build_user_permissions(outpost.user)
|
outpost.build_user_permissions(outpost.user)
|
||||||
if not layer: # pragma: no cover
|
if not layer: # pragma: no cover
|
||||||
layer = get_channel_layer()
|
layer = get_channel_layer()
|
||||||
for state in OutpostState.for_outpost(outpost):
|
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
|
||||||
for channel in state.channel_ids:
|
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||||
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
|
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||||
async_to_sync(layer.send)(channel, {"type": "event.update"})
|
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task(
|
@CELERY_APP.task(
|
||||||
|
|
|
@ -7,7 +7,7 @@ from django.test import TransactionTestCase
|
||||||
|
|
||||||
from authentik import __version__
|
from authentik import __version__
|
||||||
from authentik.core.tests.utils import create_test_flow
|
from authentik.core.tests.utils import create_test_flow
|
||||||
from authentik.outposts.channels import WebsocketMessage, WebsocketMessageInstruction
|
from authentik.outposts.consumer import WebsocketMessage, WebsocketMessageInstruction
|
||||||
from authentik.outposts.models import Outpost, OutpostType
|
from authentik.outposts.models import Outpost, OutpostType
|
||||||
from authentik.providers.proxy.models import ProxyProvider
|
from authentik.providers.proxy.models import ProxyProvider
|
||||||
from authentik.root import websocket
|
from authentik.root import websocket
|
||||||
|
|
|
@ -7,7 +7,7 @@ from authentik.outposts.api.service_connections import (
|
||||||
KubernetesServiceConnectionViewSet,
|
KubernetesServiceConnectionViewSet,
|
||||||
ServiceConnectionViewSet,
|
ServiceConnectionViewSet,
|
||||||
)
|
)
|
||||||
from authentik.outposts.channels import OutpostConsumer
|
from authentik.outposts.consumer import OutpostConsumer
|
||||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||||
|
|
||||||
websocket_urlpatterns = [
|
websocket_urlpatterns = [
|
||||||
|
|
|
@ -3,7 +3,8 @@ from asgiref.sync import async_to_sync
|
||||||
from channels.layers import get_channel_layer
|
from channels.layers import get_channel_layer
|
||||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||||
|
|
||||||
from authentik.outposts.models import Outpost, OutpostState, OutpostType
|
from authentik.outposts.consumer import OUTPOST_GROUP
|
||||||
|
from authentik.outposts.models import Outpost, OutpostType
|
||||||
from authentik.providers.proxy.models import ProxyProvider
|
from authentik.providers.proxy.models import ProxyProvider
|
||||||
from authentik.root.celery import CELERY_APP
|
from authentik.root.celery import CELERY_APP
|
||||||
|
|
||||||
|
@ -23,10 +24,9 @@ def proxy_on_logout(session_id: str):
|
||||||
"""Update outpost instances connected to a single outpost"""
|
"""Update outpost instances connected to a single outpost"""
|
||||||
layer = get_channel_layer()
|
layer = get_channel_layer()
|
||||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||||
for state in OutpostState.for_outpost(outpost):
|
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
|
||||||
for channel in state.channel_ids:
|
async_to_sync(layer.group_send)(
|
||||||
async_to_sync(layer.send)(
|
group,
|
||||||
channel,
|
|
||||||
{
|
{
|
||||||
"type": "event.provider.specific",
|
"type": "event.provider.specific",
|
||||||
"sub_type": "logout",
|
"sub_type": "logout",
|
||||||
|
|
|
@ -253,10 +253,10 @@ ASGI_APPLICATION = "authentik.root.asgi.application"
|
||||||
|
|
||||||
CHANNEL_LAYERS = {
|
CHANNEL_LAYERS = {
|
||||||
"default": {
|
"default": {
|
||||||
"BACKEND": "channels_redis.core.RedisChannelLayer",
|
"BACKEND": "channels_redis.pubsub.RedisPubSubChannelLayer",
|
||||||
"CONFIG": {
|
"CONFIG": {
|
||||||
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
|
"hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"],
|
||||||
"prefix": "authentik_channels",
|
"prefix": "authentik_channels_",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue