diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 9c604327a..179822fdf 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -7,6 +7,7 @@ from urllib.parse import urlparse import yaml from asgiref.sync import async_to_sync +from channels.layers import get_channel_layer from django.core.cache import cache from django.db import DatabaseError, InternalError, ProgrammingError from django.db.models.base import Model @@ -42,7 +43,6 @@ from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesContro from authentik.providers.proxy.controllers.docker import ProxyDockerController from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController from authentik.root.celery import CELERY_APP -from authentik.root.messages.storage import closing_send LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s" @@ -214,26 +214,29 @@ def outpost_post_save(model_class: str, model_pk: Any): outpost_send_update(reverse) -def outpost_send_update(model_instace: Model): +def outpost_send_update(model_instance: Model): """Send outpost update to all registered outposts, regardless to which authentik instance they are connected""" - if isinstance(model_instace, OutpostModel): - for outpost in model_instace.outpost_set.all(): - _outpost_single_update(outpost) - elif isinstance(model_instace, Outpost): - _outpost_single_update(model_instace) + channel_layer = get_channel_layer() + if isinstance(model_instance, OutpostModel): + for outpost in model_instance.outpost_set.all(): + _outpost_single_update(outpost, channel_layer) + elif isinstance(model_instance, Outpost): + _outpost_single_update(model_instance, channel_layer) -def _outpost_single_update(outpost: Outpost): +def _outpost_single_update(outpost: Outpost, layer=None): """Update outpost instances connected to a single outpost""" # Ensure token again, because this function is called when anything related to an # OutpostModel is saved, so we can be sure permissions are right _ = outpost.token outpost.build_user_permissions(outpost.user) + if not layer: # pragma: no cover + layer = get_channel_layer() for state in OutpostState.for_outpost(outpost): for channel in state.channel_ids: LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) - async_to_sync(closing_send)(channel, {"type": "event.update"}) + async_to_sync(layer.send)(channel, {"type": "event.update"}) @CELERY_APP.task( diff --git a/authentik/root/messages/storage.py b/authentik/root/messages/storage.py index b821547b5..200cf44e0 100644 --- a/authentik/root/messages/storage.py +++ b/authentik/root/messages/storage.py @@ -1,7 +1,6 @@ """Channels Messages storage""" from asgiref.sync import async_to_sync -from channels import DEFAULT_CHANNEL_LAYER -from channels.layers import channel_layers +from channels.layers import get_channel_layer from django.contrib.messages.storage.base import Message from django.contrib.messages.storage.session import SessionStorage from django.core.cache import cache @@ -11,21 +10,13 @@ SESSION_KEY = "_messages" CACHE_PREFIX = "goauthentik.io/root/messages_" -async def closing_send(channel, message): - """Wrapper around layer send that closes the connection""" - # See https://github.com/django/channels_redis/issues/332 - # TODO: Remove this after channels_redis 4.1 is released - channel_layer = channel_layers.make_backend(DEFAULT_CHANNEL_LAYER) - await channel_layer.send(channel, message) - await channel_layer.close_pools() - - class ChannelsStorage(SessionStorage): """Send contrib.messages over websocket""" def __init__(self, request: HttpRequest) -> None: # pyright: reportGeneralTypeIssues=false super().__init__(request) + self.channel = get_channel_layer() def _store(self, messages: list[Message], response, *args, **kwargs): prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_" @@ -37,7 +28,7 @@ class ChannelsStorage(SessionStorage): for key in keys: uid = key.replace(prefix, "") for message in messages: - async_to_sync(closing_send)( + async_to_sync(self.channel.send)( uid, { "type": "event.update", diff --git a/poetry.lock b/poetry.lock index 3e94f6a6d..85c4a41eb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -661,25 +661,25 @@ tests = ["async-timeout", "coverage (>=4.5,<5.0)", "pytest", "pytest-asyncio", " [[package]] name = "channels-redis" -version = "4.0.0" +version = "4.1.0" description = "Redis-backed ASGI channel layer implementation" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "channels_redis-4.0.0-py3-none-any.whl", hash = "sha256:81b59d68f53313e1aa891f23591841b684abb936b42e4d1a966d9e4dc63a95ec"}, - {file = "channels_redis-4.0.0.tar.gz", hash = "sha256:122414f29f525f7b9e0c9d59cdcfc4dc1b0eecba16fbb6a1c23f1d9b58f49dcb"}, + {file = "channels_redis-4.1.0-py3-none-any.whl", hash = "sha256:3696f5b9fe367ea495d402ba83d7c3c99e8ca0e1354ff8d913535976ed0abf73"}, + {file = "channels_redis-4.1.0.tar.gz", hash = "sha256:6bd4f75f4ab4a7db17cee495593ace886d7e914c66f8214a1f247ff6659c073a"}, ] [package.dependencies] asgiref = ">=3.2.10,<4" channels = "*" msgpack = ">=1.0,<2.0" -redis = ">=4.2.0" +redis = ">=4.5.3" [package.extras] cryptography = ["cryptography (>=1.3.0)"] -tests = ["async-timeout", "cryptography (>=1.3.0)", "pytest", "pytest-asyncio"] +tests = ["async-timeout", "cryptography (>=1.3.0)", "pytest", "pytest-asyncio", "pytest-timeout"] [[package]] name = "charset-normalizer"