policies/engine: Add sanity test to ensure result count matches policy count

This commit is contained in:
Jens Langhammer 2020-12-19 23:40:25 +01:00
parent e62333dfb3
commit efc849e760

View file

@ -47,6 +47,8 @@ class PolicyEngine:
__cached_policies: List[PolicyResult] __cached_policies: List[PolicyResult]
__processes: List[PolicyProcessInfo] __processes: List[PolicyProcessInfo]
__expected_result_count: int
def __init__( def __init__(
self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None
): ):
@ -59,6 +61,7 @@ class PolicyEngine:
self.__cached_policies = [] self.__cached_policies = []
self.__processes = [] self.__processes = []
self.use_cache = True self.use_cache = True
self.__expected_result_count = 0
def _iter_bindings(self) -> Iterator[PolicyBinding]: def _iter_bindings(self) -> Iterator[PolicyBinding]:
"""Make sure all Policies are their respective classes""" """Make sure all Policies are their respective classes"""
@ -79,6 +82,8 @@ class PolicyEngine:
span.set_data("pbm", self.__pbm) span.set_data("pbm", self.__pbm)
span.set_data("request", self.request) span.set_data("request", self.request)
for binding in self._iter_bindings(): for binding in self._iter_bindings():
self.__expected_result_count += 1
self._check_policy_type(binding.policy) self._check_policy_type(binding.policy)
key = cache_key(binding, self.request) key = cache_key(binding, self.request)
cached_policy = cache.get(key, None) cached_policy = cache.get(key, None)
@ -112,10 +117,13 @@ class PolicyEngine:
process_results: List[PolicyResult] = [ process_results: List[PolicyResult] = [
x.result for x in self.__processes if x.result x.result for x in self.__processes if x.result
] ]
all_results = list(process_results + self.__cached_policies)
final_result = PolicyResult(False) final_result = PolicyResult(False)
final_result.messages = [] final_result.messages = []
final_result.source_results = list(process_results + self.__cached_policies) final_result.source_results = all_results
for result in process_results + self.__cached_policies: if len(all_results) < self.__expected_result_count: # pragma: no cover
raise AssertionError("Got less results than polices")
for result in all_results:
LOGGER.debug( LOGGER.debug(
"P_ENG: result", passing=result.passing, messages=result.messages "P_ENG: result", passing=result.passing, messages=result.messages
) )