outposts: fix update signal not being sent to correct instances
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
349a5b2d00
commit
56f1204c9b
|
@ -40,7 +40,7 @@ class WebsocketMessage:
|
||||||
class OutpostConsumer(AuthJsonConsumer):
|
class OutpostConsumer(AuthJsonConsumer):
|
||||||
"""Handler for Outposts that connect over websockets for health checks and live updates"""
|
"""Handler for Outposts that connect over websockets for health checks and live updates"""
|
||||||
|
|
||||||
outpost: Optional[Outpost] = None
|
outpost: Outpost
|
||||||
|
|
||||||
last_uid: Optional[str] = None
|
last_uid: Optional[str] = None
|
||||||
|
|
||||||
|
@ -64,7 +64,9 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def disconnect(self, close_code):
|
def disconnect(self, close_code):
|
||||||
if self.outpost and self.last_uid:
|
if self.outpost and self.last_uid:
|
||||||
OutpostState.for_channel(self.outpost, self.last_uid).delete()
|
state = OutpostState.for_instance_uid(self.outpost, self.last_uid)
|
||||||
|
state.channel_ids.remove(self.channel_name)
|
||||||
|
state.save()
|
||||||
LOGGER.debug(
|
LOGGER.debug(
|
||||||
"removed outpost instance from cache",
|
"removed outpost instance from cache",
|
||||||
outpost=self.outpost,
|
outpost=self.outpost,
|
||||||
|
@ -75,12 +77,10 @@ class OutpostConsumer(AuthJsonConsumer):
|
||||||
msg = from_dict(WebsocketMessage, content)
|
msg = from_dict(WebsocketMessage, content)
|
||||||
uid = msg.args.get("uuid", self.channel_name)
|
uid = msg.args.get("uuid", self.channel_name)
|
||||||
self.last_uid = uid
|
self.last_uid = uid
|
||||||
state = OutpostState(
|
state = OutpostState.for_instance_uid(self.outpost, uid)
|
||||||
uid=uid,
|
if self.channel_name not in state.channel_ids:
|
||||||
channel_id=self.channel_name,
|
state.channel_ids.append(self.channel_name)
|
||||||
last_seen=datetime.now(),
|
state.last_seen = datetime.now()
|
||||||
_outpost=self.outpost,
|
|
||||||
)
|
|
||||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||||
state.version = msg.args.get("version", None)
|
state.version = msg.args.get("version", None)
|
||||||
state.build_hash = msg.args.get("buildHash", "")
|
state.build_hash = msg.args.get("buildHash", "")
|
||||||
|
|
|
@ -409,7 +409,7 @@ class OutpostState:
|
||||||
"""Outpost instance state, last_seen and version"""
|
"""Outpost instance state, last_seen and version"""
|
||||||
|
|
||||||
uid: str
|
uid: str
|
||||||
channel_id: str
|
channel_ids: list[str] = field(default_factory=list)
|
||||||
last_seen: Optional[datetime] = field(default=None)
|
last_seen: Optional[datetime] = field(default=None)
|
||||||
version: Optional[str] = field(default=None)
|
version: Optional[str] = field(default=None)
|
||||||
version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION)
|
version_should: Union[Version, LegacyVersion] = field(default=OUR_VERSION)
|
||||||
|
@ -432,21 +432,20 @@ class OutpostState:
|
||||||
keys = cache.keys(f"{outpost.state_cache_prefix}_*")
|
keys = cache.keys(f"{outpost.state_cache_prefix}_*")
|
||||||
states = []
|
states = []
|
||||||
for key in keys:
|
for key in keys:
|
||||||
channel = key.replace(f"{outpost.state_cache_prefix}_", "")
|
instance_uid = key.replace(f"{outpost.state_cache_prefix}_", "")
|
||||||
states.append(OutpostState.for_channel(outpost, channel))
|
states.append(OutpostState.for_instance_uid(outpost, instance_uid))
|
||||||
return states
|
return states
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def for_channel(outpost: Outpost, channel: str) -> "OutpostState":
|
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
|
||||||
"""Get state for a single channel"""
|
"""Get state for a single instance"""
|
||||||
key = f"{outpost.state_cache_prefix}_{channel}"
|
key = f"{outpost.state_cache_prefix}_{uid}"
|
||||||
default_data = {"uid": channel, "channel_id": channel}
|
default_data = {"uid": uid, "channel_ids": []}
|
||||||
data = cache.get(key, default_data)
|
data = cache.get(key, default_data)
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
cache.delete(key)
|
cache.delete(key)
|
||||||
data = default_data
|
data = default_data
|
||||||
state = from_dict(OutpostState, data)
|
state = from_dict(OutpostState, data)
|
||||||
state.uid = channel
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
state._outpost = outpost
|
state._outpost = outpost
|
||||||
return state
|
return state
|
||||||
|
|
|
@ -202,8 +202,11 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
||||||
if not layer: # pragma: no cover
|
if not layer: # pragma: no cover
|
||||||
layer = get_channel_layer()
|
layer = get_channel_layer()
|
||||||
for state in OutpostState.for_outpost(outpost):
|
for state in OutpostState.for_outpost(outpost):
|
||||||
LOGGER.debug("sending update", channel=state.channel_id, outpost=outpost)
|
for channel in state.channel_ids:
|
||||||
async_to_sync(layer.send)(state.channel_id, {"type": "event.update"})
|
LOGGER.debug(
|
||||||
|
"sending update", channel=channel, instance=state.uid, outpost=outpost
|
||||||
|
)
|
||||||
|
async_to_sync(layer.send)(channel, {"type": "event.update"})
|
||||||
|
|
||||||
|
|
||||||
@CELERY_APP.task()
|
@CELERY_APP.task()
|
||||||
|
|
|
@ -207,7 +207,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
}
|
}
|
||||||
p.ClearCSRFCookie(rw, req)
|
p.ClearCSRFCookie(rw, req)
|
||||||
if c.Value != nonce {
|
if c.Value != nonce {
|
||||||
p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack")
|
p.logger.WithField("is", c.Value).WithField("should", nonce).WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack")
|
||||||
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed")
|
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue