root: use channel send workaround for sync sending of websocket messages

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer 2023-02-15 16:06:17 +01:00
parent 7f009f6d02
commit bff34cc5dc
No known key found for this signature in database
3 changed files with 25 additions and 28 deletions

View File

@ -7,7 +7,6 @@ from urllib.parse import urlparse
import yaml import yaml
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.core.cache import cache from django.core.cache import cache
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.db.models.base import Model from django.db.models.base import Model
@ -43,6 +42,7 @@ from authentik.providers.ldap.controllers.kubernetes import LDAPKubernetesContro
from authentik.providers.proxy.controllers.docker import ProxyDockerController from authentik.providers.proxy.controllers.docker import ProxyDockerController
from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController from authentik.providers.proxy.controllers.kubernetes import ProxyKubernetesController
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
from authentik.root.messages.storage import closing_send
LOGGER = get_logger() LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s" CACHE_KEY_OUTPOST_DOWN = "outpost_teardown_%s"
@ -217,26 +217,23 @@ def outpost_post_save(model_class: str, model_pk: Any):
def outpost_send_update(model_instace: Model): def outpost_send_update(model_instace: Model):
"""Send outpost update to all registered outposts, regardless to which authentik """Send outpost update to all registered outposts, regardless to which authentik
instance they are connected""" instance they are connected"""
channel_layer = get_channel_layer()
if isinstance(model_instace, OutpostModel): if isinstance(model_instace, OutpostModel):
for outpost in model_instace.outpost_set.all(): for outpost in model_instace.outpost_set.all():
_outpost_single_update(outpost, channel_layer) _outpost_single_update(outpost)
elif isinstance(model_instace, Outpost): elif isinstance(model_instace, Outpost):
_outpost_single_update(model_instace, channel_layer) _outpost_single_update(model_instace)
def _outpost_single_update(outpost: Outpost, layer=None): def _outpost_single_update(outpost: Outpost):
"""Update outpost instances connected to a single outpost""" """Update outpost instances connected to a single outpost"""
# Ensure token again, because this function is called when anything related to an # Ensure token again, because this function is called when anything related to an
# OutpostModel is saved, so we can be sure permissions are right # OutpostModel is saved, so we can be sure permissions are right
_ = outpost.token _ = outpost.token
outpost.build_user_permissions(outpost.user) outpost.build_user_permissions(outpost.user)
if not layer: # pragma: no cover
layer = get_channel_layer()
for state in OutpostState.for_outpost(outpost): for state in OutpostState.for_outpost(outpost):
for channel in state.channel_ids: for channel in state.channel_ids:
LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost) LOGGER.debug("sending update", channel=channel, instance=state.uid, outpost=outpost)
async_to_sync(layer.send)(channel, {"type": "event.update"}) async_to_sync(closing_send)(channel, {"type": "event.update"})
@CELERY_APP.task() @CELERY_APP.task()

View File

@ -1,6 +1,7 @@
"""Channels Messages storage""" """Channels Messages storage"""
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer from channels import DEFAULT_CHANNEL_LAYER
from channels.layers import channel_layers
from django.contrib.messages.storage.base import Message from django.contrib.messages.storage.base import Message
from django.contrib.messages.storage.session import SessionStorage from django.contrib.messages.storage.session import SessionStorage
from django.core.cache import cache from django.core.cache import cache
@ -10,13 +11,21 @@ SESSION_KEY = "_messages"
CACHE_PREFIX = "goauthentik.io/root/messages_" CACHE_PREFIX = "goauthentik.io/root/messages_"
async def closing_send(channel, message):
"""Wrapper around layer send that closes the connection"""
# See https://github.com/django/channels_redis/issues/332
# TODO: Remove this after channels_redis 4.1 is released
channel_layer = channel_layers.make_backend(DEFAULT_CHANNEL_LAYER)
await channel_layer.send(channel, message)
await channel_layer.close_pools()
class ChannelsStorage(SessionStorage): class ChannelsStorage(SessionStorage):
"""Send contrib.messages over websocket""" """Send contrib.messages over websocket"""
def __init__(self, request: HttpRequest) -> None: def __init__(self, request: HttpRequest) -> None:
# pyright: reportGeneralTypeIssues=false # pyright: reportGeneralTypeIssues=false
super().__init__(request) super().__init__(request)
self.channel = get_channel_layer()
def _store(self, messages: list[Message], response, *args, **kwargs): def _store(self, messages: list[Message], response, *args, **kwargs):
prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_" prefix = f"{CACHE_PREFIX}{self.request.session.session_key}_messages_"
@ -28,7 +37,7 @@ class ChannelsStorage(SessionStorage):
for key in keys: for key in keys:
uid = key.replace(prefix, "") uid = key.replace(prefix, "")
for message in messages: for message in messages:
async_to_sync(self.channel.send)( async_to_sync(closing_send)(
uid, uid,
{ {
"type": "event.update", "type": "event.update",

View File

@ -1,8 +1,5 @@
[tool.pyright] [tool.pyright]
ignore = [ ignore = ["**/migrations/**", "**/node_modules/**"]
"**/migrations/**",
"**/node_modules/**"
]
reportMissingTypeStubs = false reportMissingTypeStubs = false
strictParameterNoneValue = true strictParameterNoneValue = true
strictDictionaryInference = true strictDictionaryInference = true
@ -63,14 +60,7 @@ exclude_lines = [
show_missing = true show_missing = true
[tool.pylint.basic] [tool.pylint.basic]
good-names = [ good-names = ["pk", "id", "i", "j", "k", "_"]
"pk",
"id",
"i",
"j",
"k",
"_",
]
[tool.pylint.master] [tool.pylint.master]
disable = [ disable = [
@ -85,6 +75,7 @@ disable = [
"protected-access", "protected-access",
"unused-argument", "unused-argument",
"raise-missing-from", "raise-missing-from",
"fixme",
# To preserve django's translation function we need to use %-formatting # To preserve django's translation function we need to use %-formatting
"consider-using-f-string", "consider-using-f-string",
] ]
@ -120,7 +111,7 @@ authors = ["authentik Team <hello@goauthentik.io>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
celery = "*" celery = "*"
channels = {version = "*", extras = ["daphne"]} channels = { version = "*", extras = ["daphne"] }
channels-redis = "*" channels-redis = "*"
codespell = "*" codespell = "*"
colorama = "*" colorama = "*"
@ -147,7 +138,7 @@ gunicorn = "*"
kubernetes = "*" kubernetes = "*"
ldap3 = "*" ldap3 = "*"
lxml = "*" lxml = "*"
opencontainers = {extras = ["reggie"],version = "*"} opencontainers = { extras = ["reggie"], version = "*" }
packaging = "*" packaging = "*"
paramiko = "*" paramiko = "*"
psycopg2-binary = "*" psycopg2-binary = "*"
@ -163,8 +154,8 @@ swagger-spec-validator = "*"
twilio = "*" twilio = "*"
twisted = "*" twisted = "*"
ua-parser = "*" ua-parser = "*"
urllib3 = {extras = ["secure"],version = "*"} urllib3 = { extras = ["secure"], version = "*" }
uvicorn = {extras = ["standard"],version = "*"} uvicorn = { extras = ["standard"], version = "*" }
webauthn = "*" webauthn = "*"
wsproto = "*" wsproto = "*"
xmlsec = "*" xmlsec = "*"
@ -176,7 +167,7 @@ bandit = "*"
black = "*" black = "*"
bump2version = "*" bump2version = "*"
colorama = "*" colorama = "*"
coverage = {extras = ["toml"],version = "*"} coverage = { extras = ["toml"], version = "*" }
importlib-metadata = "*" importlib-metadata = "*"
pylint = "*" pylint = "*"
pylint-django = "*" pylint-django = "*"