outposts: fix update signal not being sent to correct instances

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-05-20 15:23:18 +02:00
parent 349a5b2d00
commit 56f1204c9b
4 changed files with 21 additions and 19 deletions

View file

@ -40,7 +40,7 @@ class WebsocketMessage:
class OutpostConsumer(AuthJsonConsumer): class OutpostConsumer(AuthJsonConsumer):
"""Handler for Outposts that connect over websockets for health checks and live updates""" """Handler for Outposts that connect over websockets for health checks and live updates"""
outpost: Optional[Outpost] = None outpost: Outpost
last_uid: Optional[str] = None last_uid: Optional[str] = None
@ -64,7 +64,9 @@ class OutpostConsumer(AuthJsonConsumer):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def disconnect(self, close_code): def disconnect(self, close_code):
if self.outpost and self.last_uid: if self.outpost and self.last_uid:
OutpostState.for_channel(self.outpost, self.last_uid).delete() state = OutpostState.for_instance_uid(self.outpost, self.last_uid)
state.channel_ids.remove(self.channel_name)
state.save()
LOGGER.debug( LOGGER.debug(
"removed outpost instance from cache", "removed outpost instance from cache",
outpost=self.outpost, outpost=self.outpost,
@ -75,12 +77,10 @@ class OutpostConsumer(AuthJsonConsumer):
msg = from_dict(WebsocketMessage, content) msg = from_dict(WebsocketMessage, content)
uid = msg.args.get("uuid", self.channel_name) uid = msg.args.get("uuid", self.channel_name)
self.last_uid = uid self.last_uid = uid
state = OutpostState( state = OutpostState.for_instance_uid(self.outpost, uid)
uid=uid, if self.channel_name not in state.channel_ids:
channel_id=self.channel_name, state.channel_ids.append(self.channel_name)
last_seen=datetime.now(), state.last_seen = datetime.now()
_outpost=self.outpost,
)
if msg.instruction == WebsocketMessageInstruction.HELLO: if msg.instruction == WebsocketMessageInstruction.HELLO:
state.version = msg.args.get("version", None) state.version = msg.args.get("version", None)
state.build_hash = msg.args.get("buildHash", "") state.build_hash = msg.args.get("buildHash", "")

View file

@ -409,7 +409,7 @@ class OutpostState:
"""Outpost instance state, last_seen and version""" """Outpost instance state, last_seen and version"""
uid: str uid: str
channel_id: str channel_ids: list[str] = field(default_factory=list)
last_seen: Optional[datetime] = field(default=None) last_seen: Optional[datetime] = field(default=None)
version: Optional[str] = field(default=None) version: Optional[str] = field(default=None)
version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION) version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION)
@ -432,21 +432,20 @@ class OutpostState:
keys = cache.keys(f"{outpost.state_cache_prefix}_*") keys = cache.keys(f"{outpost.state_cache_prefix}_*")
states = [] states = []
for key in keys: for key in keys:
channel = key.replace(f"{outpost.state_cache_prefix}_", "") instance_uid = key.replace(f"{outpost.state_cache_prefix}_", "")
states.append(OutpostState.for_channel(outpost, channel)) states.append(OutpostState.for_instance_uid(outpost, instance_uid))
return states return states
@staticmethod @staticmethod
def for_channel(outpost: Outpost, channel: str) -> "OutpostState": def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
"""Get state for a single channel""" """Get state for a single instance"""
key = f"{outpost.state_cache_prefix}_{channel}" key = f"{outpost.state_cache_prefix}_{uid}"
default_data = {"uid": channel, "channel_id": channel} default_data = {"uid": uid, "channel_ids": []}
data = cache.get(key, default_data) data = cache.get(key, default_data)
if isinstance(data, str): if isinstance(data, str):
cache.delete(key) cache.delete(key)
data = default_data data = default_data
state = from_dict(OutpostState, data) state = from_dict(OutpostState, data)
state.uid = channel
# pylint: disable=protected-access # pylint: disable=protected-access
state._outpost = outpost state._outpost = outpost
return state return state

View file

@ -202,8 +202,11 @@ def _outpost_single_update(outpost: Outpost, layer=None):
if not layer: # pragma: no cover if not layer: # pragma: no cover
layer = get_channel_layer() layer = get_channel_layer()
for state in OutpostState.for_outpost(outpost): for state in OutpostState.for_outpost(outpost):
LOGGER.debug("sending update", channel=state.channel_id, outpost=outpost) for channel in state.channel_ids:
async_to_sync(layer.send)(state.channel_id, {"type": "event.update"}) LOGGER.debug(
"sending update", channel=channel, instance=state.uid, outpost=outpost
)
async_to_sync(layer.send)(channel, {"type": "event.update"})
@CELERY_APP.task() @CELERY_APP.task()

View file

@ -207,7 +207,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
} }
p.ClearCSRFCookie(rw, req) p.ClearCSRFCookie(rw, req)
if c.Value != nonce { if c.Value != nonce {
p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack") p.logger.WithField("is", c.Value).WithField("should", nonce).WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack")
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed") p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed")
return return
} }