From 56f1204c9b00e4bb3ee5b837e92005dd9a62e2df Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Thu, 20 May 2021 15:23:18 +0200 Subject: [PATCH] outposts: fix update signal not being sent to correct instances Signed-off-by: Jens Langhammer --- authentik/outposts/channels.py | 16 ++++++++-------- authentik/outposts/models.py | 15 +++++++-------- authentik/outposts/tasks.py | 7 +++++-- outpost/pkg/proxy/oauth.go | 2 +- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/authentik/outposts/channels.py b/authentik/outposts/channels.py index f29a49e22..59f8430b6 100644 --- a/authentik/outposts/channels.py +++ b/authentik/outposts/channels.py @@ -40,7 +40,7 @@ class WebsocketMessage: class OutpostConsumer(AuthJsonConsumer): """Handler for Outposts that connect over websockets for health checks and live updates""" - outpost: Optional[Outpost] = None + outpost: Outpost last_uid: Optional[str] = None @@ -64,7 +64,9 @@ class OutpostConsumer(AuthJsonConsumer): # pylint: disable=unused-argument def disconnect(self, close_code): if self.outpost and self.last_uid: - OutpostState.for_channel(self.outpost, self.last_uid).delete() + state = OutpostState.for_instance_uid(self.outpost, self.last_uid) + state.channel_ids.remove(self.channel_name) + state.save() LOGGER.debug( "removed outpost instance from cache", outpost=self.outpost, @@ -75,12 +77,10 @@ class OutpostConsumer(AuthJsonConsumer): msg = from_dict(WebsocketMessage, content) uid = msg.args.get("uuid", self.channel_name) self.last_uid = uid - state = OutpostState( - uid=uid, - channel_id=self.channel_name, - last_seen=datetime.now(), - _outpost=self.outpost, - ) + state = OutpostState.for_instance_uid(self.outpost, uid) + if self.channel_name not in state.channel_ids: + state.channel_ids.append(self.channel_name) + state.last_seen = datetime.now() if msg.instruction == WebsocketMessageInstruction.HELLO: state.version = msg.args.get("version", None) state.build_hash = msg.args.get("buildHash", "") diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 35f6d68cd..3c5d61424 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -409,7 +409,7 @@ class OutpostState: """Outpost instance state, last_seen and version""" uid: str - channel_id: str + channel_ids: list[str] = field(default_factory=list) last_seen: Optional[datetime] = field(default=None) version: Optional[str] = field(default=None) version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION) @@ -432,21 +432,20 @@ class OutpostState: keys = cache.keys(f"{outpost.state_cache_prefix}_*") states = [] for key in keys: - channel = key.replace(f"{outpost.state_cache_prefix}_", "") - states.append(OutpostState.for_channel(outpost, channel)) + instance_uid = key.replace(f"{outpost.state_cache_prefix}_", "") + states.append(OutpostState.for_instance_uid(outpost, instance_uid)) return states @staticmethod - def for_channel(outpost: Outpost, channel: str) -> "OutpostState": - """Get state for a single channel""" - key = f"{outpost.state_cache_prefix}_{channel}" - default_data = {"uid": channel, "channel_id": channel} + def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState": + """Get state for a single instance""" + key = f"{outpost.state_cache_prefix}_{uid}" + default_data = {"uid": uid, "channel_ids": []} data = cache.get(key, default_data) if isinstance(data, str): cache.delete(key) data = default_data state = from_dict(OutpostState, data) - state.uid = channel # pylint: disable=protected-access state._outpost = outpost return state diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 2768e8374..957482e21 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -202,8 +202,11 @@ def _outpost_single_update(outpost: Outpost, layer=None): if not layer: # pragma: no cover layer = get_channel_layer() for state in OutpostState.for_outpost(outpost): - LOGGER.debug("sending update", channel=state.channel_id, outpost=outpost) - async_to_sync(layer.send)(state.channel_id, {"type": "event.update"}) + for channel in state.channel_ids: + LOGGER.debug( + "sending update", channel=channel, instance=state.uid, outpost=outpost + ) + async_to_sync(layer.send)(channel, {"type": "event.update"}) @CELERY_APP.task() diff --git a/outpost/pkg/proxy/oauth.go b/outpost/pkg/proxy/oauth.go index 9ee0bcadd..96e72a9b5 100644 --- a/outpost/pkg/proxy/oauth.go +++ b/outpost/pkg/proxy/oauth.go @@ -207,7 +207,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { } p.ClearCSRFCookie(rw, req) if c.Value != nonce { - p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack") + p.logger.WithField("is", c.Value).WithField("should", nonce).WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack") p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") return }