This repository has been archived on 2024-05-31. You can view files and clone it, but cannot push or open issues or pull requests.
authentik/authentik/outposts/consumer.py
Philipp Kolberg 9db9ad3d66
root: Restructure broker / cache / channel / result configuration (#7097)
* Initial commit

* Remove any remaining mentions of Redis URL

This is handled in https://github.com/goauthentik/authentik/pull/5395

* Allow setting broker transport options

This enables usage of other brokers that require additional settings

* Remove remaining reference to Redis URL

This functionality is not part of this PR

* Reset default TLS requirements to none

* Fix linter errors

* Move dict from base64 encoded json to config.py

Additionally add tests

* Replace ast.literal_eval with json.loads

* Use default channel and cache backend configuration

If more customization is desired users shall look at goauthentik.io/docs/installation/configuration#custom-python-settings

* Send config deprecation notification to all superusers

* Remove duplicate method

* Add configuration explanation

For channel layer settings

* Use Event for deprecation warning

* Fix remove duplicated method

* Add missing comma

* Update authentik/lib/config.py

Signed-off-by: Jens L. <jens@beryju.org>

* Fix Event deprecation handling

---------

Signed-off-by: Jens L. <jens@beryju.org>
Co-authored-by: Jens L <jens@beryju.org>
2023-11-10 15:44:37 +01:00

140 lines
4.7 KiB
Python

"""Outpost websocket handler"""
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import IntEnum
from typing import Any, Optional
from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection
from dacite.core import from_dict
from dacite.data import Data
from guardian.shortcuts import get_objects_for_user
from structlog.stdlib import BoundLogger, get_logger
from authentik.core.channels import AuthJsonConsumer
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
class WebsocketMessageInstruction(IntEnum):
"""Commands which can be triggered over Websocket"""
# Simple message used by either side when a message is acknowledged
ACK = 0
# Message used by outposts to report their alive status
HELLO = 1
# Message sent by us to trigger an Update
TRIGGER_UPDATE = 2
# Provider specific message
PROVIDER_SPECIFIC = 3
@dataclass(slots=True)
class WebsocketMessage:
"""Complete Websocket Message that is being sent"""
instruction: int
args: dict[str, Any] = field(default_factory=dict)
class OutpostConsumer(AuthJsonConsumer):
"""Handler for Outposts that connect over websockets for health checks and live updates"""
outpost: Optional[Outpost] = None
logger: BoundLogger
last_uid: Optional[str] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.logger = get_logger()
def connect(self):
super().connect()
uuid = self.scope["url_route"]["kwargs"]["pk"]
outpost = (
get_objects_for_user(self.user, "authentik_outposts.view_outpost")
.filter(pk=uuid)
.first()
)
if not outpost:
raise DenyConnection()
self.logger = self.logger.bind(outpost=outpost)
try:
self.accept()
except RuntimeError as exc:
self.logger.warning("runtime error during accept", exc=exc)
raise DenyConnection()
self.outpost = outpost
self.last_uid = self.channel_name
async_to_sync(self.channel_layer.group_add)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).inc()
def disconnect(self, code):
if self.outpost:
async_to_sync(self.channel_layer.group_discard)(
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
)
if self.outpost and self.last_uid:
GAUGE_OUTPOSTS_CONNECTED.labels(
outpost=self.outpost.name,
uid=self.last_uid,
expected=self.outpost.config.kubernetes_replicas,
).dec()
def receive_json(self, content: Data, **kwargs):
msg = from_dict(WebsocketMessage, content)
uid = msg.args.get("uuid", self.channel_name)
self.last_uid = uid
if not self.outpost:
raise DenyConnection()
state = OutpostState.for_instance_uid(self.outpost, uid)
state.last_seen = datetime.now()
state.hostname = msg.args.pop("hostname", "")
if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.pop("version", None)
state.build_hash = msg.args.pop("buildHash", "")
state.args = msg.args
elif msg.instruction == WebsocketMessageInstruction.ACK:
return
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
outpost=self.outpost.name,
uid=self.last_uid or "",
version=state.version or "",
).set_to_current_time()
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
self.send_json(asdict(response))
def event_update(self, event): # pragma: no cover
"""Event handler which is called by post_save signals, Send update instruction"""
self.send_json(
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
)
def event_provider_specific(self, event):
"""Event handler which can be called by provider-specific
implementations to send specific messages to the outpost"""
self.send_json(
asdict(
WebsocketMessage(
instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
)
)
)