diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 7c6cac932..984fa140a 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -7,7 +7,6 @@ from urllib.parse import urlparse import yaml from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer from django.core.cache import cache from django.db import DatabaseError, InternalError, ProgrammingError from django.db.models.base import Model @@ -43,6 +42,7 @@ from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesContro from authentik.providers.proxy.controllers.docker import ProxyDockerController from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController from authentik.root.celery import CELERY_APP +from authentik.root.messages.storage import closing_send LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s" @@ -217,26 +217,23 @@ def outpost_post_save(model_class: str, model_pk: Any): def outpost_send_update(model_instace: Model): """Send outpost update to all registered outposts, regardless to which authentik instance they are connected""" - channel_layer = get_channel_layer() if isinstance(model_instace, OutpostModel): for outpost in model_instace.outpost_set.all(): - _outpost_single_update(outpost, channel_layer) + _outpost_single_update(outpost) elif isinstance(model_instace, Outpost): - _outpost_single_update(model_instace, channel_layer) + _outpost_single_update(model_instace) -def _outpost_single_update(outpost: Outpost, layer=None): +def _outpost_single_update(outpost: Outpost): """Update outpost instances connected to a single outpost""" # Ensure token again, because this function is called when anything related to an # OutpostModel is saved, so we can be sure permissions are right _ = outpost.token outpost.build_user_permissions(outpost.user) - if not layer: # pragma: no cover - layer = get_channel_layer() for state in OutpostState.for_outpost(outpost): for channel in state.channel_ids: LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) - async_to_sync(layer.send)(channel, {"type": "event.update"}) + async_to_sync(closing_send)(channel, {"type": "event.update"}) @CELERY_APP.task() diff --git a/authentik/root/messages/storage.py b/authentik/root/messages/storage.py index 200cf44e0..b821547b5 100644 --- a/authentik/root/messages/storage.py +++ b/authentik/root/messages/storage.py @@ -1,6 +1,7 @@ """Channels Messages storage""" from asgiref.sync import async_to_sync -from channels.layers import get_channel_layer +from channels import DEFAULT_CHANNEL_LAYER +from channels.layers import channel_layers from django.contrib.messages.storage.base import Message from django.contrib.messages.storage.session import SessionStorage from django.core.cache import cache @@ -10,13 +11,21 @@ SESSION_KEY = "_messages" CACHE_PREFIX = "goauthentik.io/root/messages_" +async def closing_send(channel, message): + """Wrapper around layer send that closes the connection""" + # See https://github.com/django/channels_redis/issues/332 + # TODO: Remove this after channels_redis 4.1 is released + channel_layer = channel_layers.make_backend(DEFAULT_CHANNEL_LAYER) + await channel_layer.send(channel, message) + await channel_layer.close_pools() + + class ChannelsStorage(SessionStorage): """Send contrib.messages over websocket""" def __init__(self, request: HttpRequest) -> None: # pyright: reportGeneralTypeIssues=false super().__init__(request) - self.channel = get_channel_layer() def _store(self, messages: list[Message], response, *args, **kwargs): prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_" @@ -28,7 +37,7 @@ class ChannelsStorage(SessionStorage): for key in keys: uid = key.replace(prefix, "") for message in messages: - async_to_sync(self.channel.send)( + async_to_sync(closing_send)( uid, { "type": "event.update", diff --git a/pyproject.toml b/pyproject.toml index a57b00c20..f3e7f1ca5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,5 @@ [tool.pyright] -ignore = [ - "**/migrations/**", - "**/node_modules/**" -] +ignore = ["**/migrations/**", "**/node_modules/**"] reportMissingTypeStubs = false strictParameterNoneValue = true strictDictionaryInference = true @@ -63,14 +60,7 @@ exclude_lines = [ show_missing = true [tool.pylint.basic] -good-names = [ - "pk", - "id", - "i", - "j", - "k", - "_", -] +good-names = ["pk", "id", "i", "j", "k", "_"] [tool.pylint.master] disable = [ @@ -85,6 +75,7 @@ disable = [ "protected-access", "unused-argument", "raise-missing-from", + "fixme", # To preserve django's translation function we need to use %-formatting "consider-using-f-string", ] @@ -120,7 +111,7 @@ authors = ["authentik Team "] [tool.poetry.dependencies] celery = "*" -channels = {version = "*", extras = ["daphne"]} +channels = { version = "*", extras = ["daphne"] } channels-redis = "*" codespell = "*" colorama = "*" @@ -147,7 +138,7 @@ gunicorn = "*" kubernetes = "*" ldap3 = "*" lxml = "*" -opencontainers = {extras = ["reggie"],version = "*"} +opencontainers = { extras = ["reggie"], version = "*" } packaging = "*" paramiko = "*" psycopg2-binary = "*" @@ -163,8 +154,8 @@ swagger-spec-validator = "*" twilio = "*" twisted = "*" ua-parser = "*" -urllib3 = {extras = ["secure"],version = "*"} -uvicorn = {extras = ["standard"],version = "*"} +urllib3 = { extras = ["secure"], version = "*" } +uvicorn = { extras = ["standard"], version = "*" } webauthn = "*" wsproto = "*" xmlsec = "*" @@ -176,7 +167,7 @@ bandit = "*" black = "*" bump2version = "*" colorama = "*" -coverage = {extras = ["toml"],version = "*"} +coverage = { extras = ["toml"], version = "*" } importlib-metadata = "*" pylint = "*" pylint-django = "*"