policies: detect when running in a daemon process and run policies sync

This commit is contained in:
Jens Langhammer 2021-01-17 19:59:58 +01:00
parent 65c9d4bf4c
commit e6f897c7e6
1 changed files with 9 additions and 6 deletions

View File

@ -1,6 +1,6 @@
"""authentik policy engine""" """authentik policy engine"""
from enum import Enum from enum import Enum
from multiprocessing import Pipe, set_start_method from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from typing import Iterator, List, Optional from typing import Iterator, List, Optional
@ -16,9 +16,7 @@ from authentik.policies.process import PolicyProcess, cache_key
from authentik.policies.types import PolicyRequest, PolicyResult from authentik.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
# This is only really needed for macOS, because Python 3.8 changed the default to spawn CURRENT_PROCESS = current_process()
# spawn causes issues with objects that aren't picklable, and also the django setup
set_start_method("fork")
class PolicyProcessInfo: class PolicyProcessInfo:
@ -117,14 +115,19 @@ class PolicyEngine:
LOGGER.debug("P_ENG: Evaluating policy", policy=binding.policy) LOGGER.debug("P_ENG: Evaluating policy", policy=binding.policy)
our_end, task_end = Pipe(False) our_end, task_end = Pipe(False)
task = PolicyProcess(binding, self.request, task_end) task = PolicyProcess(binding, self.request, task_end)
task.daemon = False
LOGGER.debug("P_ENG: Starting Process", policy=binding.policy) LOGGER.debug("P_ENG: Starting Process", policy=binding.policy)
task.start() if CURRENT_PROCESS._config.get("daemon"):
task.run()
else:
task.start()
self.__processes.append( self.__processes.append(
PolicyProcessInfo(process=task, connection=our_end, binding=binding) PolicyProcessInfo(process=task, connection=our_end, binding=binding)
) )
# If all policies are cached, we have an empty list here. # If all policies are cached, we have an empty list here.
for proc_info in self.__processes: for proc_info in self.__processes:
proc_info.process.join(proc_info.binding.timeout) if proc_info.process.is_alive():
proc_info.process.join(proc_info.binding.timeout)
# Only call .recv() if no result is saved, otherwise we just deadlock here # Only call .recv() if no result is saved, otherwise we just deadlock here
if not proc_info.result: if not proc_info.result:
proc_info.result = proc_info.connection.recv() proc_info.result = proc_info.connection.recv()