From 3b5e1c7b34ca548c51b366bef8becb5157342013 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sun, 13 Dec 2020 17:46:34 +0100 Subject: [PATCH] core: cleanup channels code, fix error when server side close --- authentik/core/channels.py | 8 +++----- authentik/outposts/channels.py | 14 +++++++------- authentik/outposts/tasks.py | 7 ++++++- authentik/root/asgi.py | 11 ++++++++++- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/authentik/core/channels.py b/authentik/core/channels.py index 31be6ffd0..ac359d1e6 100644 --- a/authentik/core/channels.py +++ b/authentik/core/channels.py @@ -1,4 +1,5 @@ """Channels base classes""" +from channels.exceptions import DenyConnection from channels.generic.websocket import JsonWebsocketConsumer from structlog import get_logger @@ -17,16 +18,13 @@ class AuthJsonConsumer(JsonWebsocketConsumer): headers = dict(self.scope["headers"]) if b"authorization" not in headers: LOGGER.warning("WS Request without authorization header") - self.close() - return False + raise DenyConnection() raw_header = headers[b"authorization"] token = token_from_header(raw_header) if not token: LOGGER.warning("Failed to authenticate") - self.close() - return False + raise DenyConnection() self.user = token.user - return True diff --git a/authentik/outposts/channels.py b/authentik/outposts/channels.py index cebe16049..22eec3abb 100644 --- a/authentik/outposts/channels.py +++ b/authentik/outposts/channels.py @@ -2,8 +2,9 @@ from dataclasses import asdict, dataclass, field from datetime import datetime from enum import IntEnum -from typing import Any, Dict +from typing import Any, Dict, Optional +from channels.exceptions import DenyConnection from dacite import from_dict from dacite.data import Data from guardian.shortcuts import get_objects_for_user @@ -39,18 +40,16 @@ class WebsocketMessage: class OutpostConsumer(AuthJsonConsumer): """Handler for Outposts that connect over websockets for health checks and live updates""" - outpost: Outpost + outpost: Optional[Outpost] = None def connect(self): - if not super().connect(): - return + super().connect() uuid = self.scope["url_route"]["kwargs"]["pk"] outpost = get_objects_for_user( self.user, "authentik_outposts.view_outpost" ).filter(pk=uuid) if not outpost.exists(): - self.close() - return + raise DenyConnection() self.accept() self.outpost = outpost.first() OutpostState( @@ -60,7 +59,8 @@ class OutpostConsumer(AuthJsonConsumer): # pylint: disable=unused-argument def disconnect(self, close_code): - OutpostState.for_channel(self.outpost, self.channel_name).delete() + if self.outpost: + OutpostState.for_channel(self.outpost, self.channel_name).delete() LOGGER.debug("removed channel from cache", channel_name=self.channel_name) def receive_json(self, content: Data): diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index d5246c83b..d40efcbcc 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -97,7 +97,12 @@ def outpost_token_ensurer(self: MonitoredTask): all_outposts = Outpost.objects.all() for outpost in all_outposts: _ = outpost.token - self.set_status(TaskResult(TaskResultStatus.SUCCESSFUL, f"Successfully checked {len(all_outposts)} Outposts.")) + self.set_status( + TaskResult( + TaskResultStatus.SUCCESSFUL, + [f"Successfully checked {len(all_outposts)} Outposts."], + ) + ) @CELERY_APP.task() diff --git a/authentik/root/asgi.py b/authentik/root/asgi.py index 454daffba..93e9257b0 100644 --- a/authentik/root/asgi.py +++ b/authentik/root/asgi.py @@ -105,7 +105,16 @@ class ASGILogger: # https://code.djangoproject.com/ticket/31508 # https://github.com/encode/uvicorn/issues/266 return - await self.app(scope, receive, send_hooked) + try: + await self.app(scope, receive, send_hooked) + except TypeError as exc: + # https://github.com/encode/uvicorn/issues/244 + if exc.args == ( + "An asyncio.Future, a coroutine or an awaitable is required", + ): + pass + else: + raise exc def _get_ip(self) -> str: client_ip = None