remove logging to increase speed, add more caching to policy and rewriter

This commit is contained in:
Jens Langhammer 2019-04-13 17:22:03 +02:00
parent 9b5b03647b
commit dda41af5c8
5 changed files with 51 additions and 60 deletions

View File

@ -12,8 +12,8 @@ from django.utils.http import urlencode
from passbook.app_gw.models import ApplicationGatewayProvider from passbook.app_gw.models import ApplicationGatewayProvider
from passbook.app_gw.proxy.exceptions import InvalidUpstream from passbook.app_gw.proxy.exceptions import InvalidUpstream
from passbook.app_gw.proxy.response import get_django_response from passbook.app_gw.proxy.response import get_django_response
from passbook.app_gw.proxy.rewrite import Rewriter
from passbook.app_gw.proxy.utils import encode_items, normalize_request_headers from passbook.app_gw.proxy.utils import encode_items, normalize_request_headers
from passbook.app_gw.rewrite import Rewriter
from passbook.core.models import Application from passbook.core.models import Application
from passbook.core.policies import PolicyEngine from passbook.core.policies import PolicyEngine
@ -30,7 +30,7 @@ HTTP = urllib3.PoolManager(
cert_reqs='CERT_REQUIRED', cert_reqs='CERT_REQUIRED',
ca_certs=certifi.where()) ca_certs=certifi.where())
IGNORED_HOSTS = cache.get(IGNORED_HOSTNAMES_KEY, []) IGNORED_HOSTS = cache.get(IGNORED_HOSTNAMES_KEY, [])
POLICY_CACHE = {}
class RequestHandler: class RequestHandler:
"""Forward requests""" """Forward requests"""
@ -41,6 +41,8 @@ class RequestHandler:
def __init__(self, app_gw, request): def __init__(self, app_gw, request):
self.app_gw = app_gw self.app_gw = app_gw
self.request = request self.request = request
if self.app_gw.pk not in POLICY_CACHE:
POLICY_CACHE[self.app_gw.pk] = self.app_gw.application.policies.all()
@staticmethod @staticmethod
def find_app_gw_for_request(request): def find_app_gw_for_request(request):
@ -49,7 +51,7 @@ class RequestHandler:
# This saves us having to query the database on each request # This saves us having to query the database on each request
host_header = request.META.get('HTTP_HOST') host_header = request.META.get('HTTP_HOST')
if host_header in IGNORED_HOSTS: if host_header in IGNORED_HOSTS:
LOGGER.debug("%s is ignored", host_header) # LOGGER.debug("%s is ignored", host_header)
return False return False
# Look through all ApplicationGatewayProviders and check hostnames # Look through all ApplicationGatewayProviders and check hostnames
matches = ApplicationGatewayProvider.objects.filter( matches = ApplicationGatewayProvider.objects.filter(
@ -59,7 +61,7 @@ class RequestHandler:
# No matching Providers found, add host header to ignored list # No matching Providers found, add host header to ignored list
IGNORED_HOSTS.append(host_header) IGNORED_HOSTS.append(host_header)
cache.set(IGNORED_HOSTNAMES_KEY, IGNORED_HOSTS) cache.set(IGNORED_HOSTNAMES_KEY, IGNORED_HOSTS)
LOGGER.debug("Ignoring %s", host_header) # LOGGER.debug("Ignoring %s", host_header)
return False return False
# At this point we're certain there's a matching ApplicationGateway # At this point we're certain there's a matching ApplicationGateway
if len(matches) > 1: if len(matches) > 1:
@ -72,7 +74,8 @@ class RequestHandler:
if app_gw: if app_gw:
return app_gw return app_gw
except Application.DoesNotExist: except Application.DoesNotExist:
LOGGER.debug("ApplicationGateway not associated with Application") pass
# LOGGER.debug("ApplicationGateway not associated with Application")
return True return True
def _get_upstream(self): def _get_upstream(self):
@ -97,10 +100,10 @@ class RequestHandler:
return upstream return upstream
def _format_path_to_redirect(self): def _format_path_to_redirect(self):
LOGGER.debug("Path before: %s", self.request.get_full_path()) # LOGGER.debug("Path before: %s", self.request.get_full_path())
rewriter = Rewriter(self.app_gw, self.request) rewriter = Rewriter(self.app_gw, self.request)
after = rewriter.build() after = rewriter.build()
LOGGER.debug("Path after: %s", after) # LOGGER.debug("Path after: %s", after)
return after return after
def get_proxy_request_headers(self): def get_proxy_request_headers(self):
@ -126,7 +129,7 @@ class RequestHandler:
if not self.app_gw.authentication_header: if not self.app_gw.authentication_header:
return request_headers return request_headers
request_headers[self.app_gw.authentication_header] = self.request.user.get_username() request_headers[self.app_gw.authentication_header] = self.request.user.get_username()
LOGGER.info("%s set", self.app_gw.authentication_header) # LOGGER.debug("%s set", self.app_gw.authentication_header)
return request_headers return request_headers
@ -136,7 +139,7 @@ class RequestHandler:
return False return False
if not self.request.user.is_authenticated: if not self.request.user.is_authenticated:
return False return False
policy_engine = PolicyEngine(self.app_gw.application.policies.all()) policy_engine = PolicyEngine(POLICY_CACHE[self.app_gw.pk])
policy_engine.for_user(self.request.user).with_request(self.request).build() policy_engine.for_user(self.request.user).with_request(self.request).build()
passing, _messages = policy_engine.result passing, _messages = policy_engine.result
@ -150,14 +153,14 @@ class RequestHandler:
def _created_proxy_response(self, path): def _created_proxy_response(self, path):
request_payload = self.request.body request_payload = self.request.body
LOGGER.debug("Request headers: %s", self._request_headers) # LOGGER.debug("Request headers: %s", self._request_headers)
request_url = self.get_upstream() + path request_url = self.get_upstream() + path
LOGGER.debug("Request URL: %s", request_url) # LOGGER.debug("Request URL: %s", request_url)
if self.request.GET: if self.request.GET:
request_url += '?' + self.get_encoded_query_params() request_url += '?' + self.get_encoded_query_params()
LOGGER.debug("Request URL: %s", request_url) # LOGGER.debug("Request URL: %s", request_url)
http = HTTP http = HTTP
if not self.app_gw.upstream_ssl_verification: if not self.app_gw.upstream_ssl_verification:
@ -172,8 +175,8 @@ class RequestHandler:
body=request_payload, body=request_payload,
decode_content=False, decode_content=False,
preload_content=False) preload_content=False)
LOGGER.debug("Proxy response header: %s", # LOGGER.debug("Proxy response header: %s",
proxy_response.getheaders()) # proxy_response.getheaders())
except urllib3.exceptions.HTTPError as error: except urllib3.exceptions.HTTPError as error:
LOGGER.exception(error) LOGGER.exception(error)
raise raise
@ -195,8 +198,8 @@ class RequestHandler:
location = location.replace(upstream_host_http, request_host) location = location.replace(upstream_host_http, request_host)
location = location.replace(upstream_host_https, request_host) location = location.replace(upstream_host_https, request_host)
proxy_response.headers['Location'] = location proxy_response.headers['Location'] = location
LOGGER.debug("Proxy response LOCATION: %s", # LOGGER.debug("Proxy response LOCATION: %s",
proxy_response.headers['Location']) # proxy_response.headers['Location'])
def _set_content_type(self, proxy_response): def _set_content_type(self, proxy_response):
content_type = proxy_response.headers.get('Content-Type') content_type = proxy_response.headers.get('Content-Type')
@ -204,8 +207,8 @@ class RequestHandler:
content_type = (mimetypes.guess_type(self.request.path)[0] or content_type = (mimetypes.guess_type(self.request.path)[0] or
self.app_gw.default_content_type) self.app_gw.default_content_type)
proxy_response.headers['Content-Type'] = content_type proxy_response.headers['Content-Type'] = content_type
LOGGER.debug("Proxy response CONTENT-TYPE: %s", # LOGGER.debug("Proxy response CONTENT-TYPE: %s",
proxy_response.headers['Content-Type']) # proxy_response.headers['Content-Type'])
def get_response(self): def get_response(self):
"""Pass request to upstream and return response""" """Pass request to upstream and return response"""
@ -218,5 +221,5 @@ class RequestHandler:
self._set_content_type(proxy_response) self._set_content_type(proxy_response)
response = get_django_response(proxy_response, strict_cookies=False) response = get_django_response(proxy_response, strict_cookies=False)
LOGGER.debug("RESPONSE RETURNED: %s", response) # LOGGER.debug("RESPONSE RETURNED: %s", response)
return response return response

View File

@ -2,6 +2,7 @@
from passbook.app_gw.models import RewriteRule from passbook.app_gw.models import RewriteRule
RULE_CACHE = {}
class Context: class Context:
"""Empty class which we dynamically add attributes to""" """Empty class which we dynamically add attributes to"""
@ -15,6 +16,9 @@ class Rewriter:
def __init__(self, application, request): def __init__(self, application, request):
self.__application = application self.__application = application
self.__request = request self.__request = request
if self.__application.pk not in RULE_CACHE:
RULE_CACHE[self.__application.pk] = RewriteRule.objects.filter(
provider__in=[self.__application])
def __build_context(self, matches): def __build_context(self, matches):
"""Build object with .0, .1, etc as groups and give access to request""" """Build object with .0, .1, etc as groups and give access to request"""
@ -27,7 +31,7 @@ class Rewriter:
def build(self): def build(self):
"""Run all rules over path and return final path""" """Run all rules over path and return final path"""
path = self.__request.get_full_path() path = self.__request.get_full_path()
for rule in RewriteRule.objects.filter(provider__in=[self.__application]): for rule in RULE_CACHE[self.__application.pk]:
matches = rule.compiled_matcher.search(path) matches = rule.compiled_matcher.search(path)
if not matches: if not matches:
continue continue

View File

@ -1,7 +1,5 @@
"""passbook core policy engine""" """passbook core policy engine"""
import cProfile # from logging import getLogger
from logging import getLogger
from amqp.exceptions import UnexpectedFrame from amqp.exceptions import UnexpectedFrame
from celery import group from celery import group
from celery.exceptions import TimeoutError as CeleryTimeoutError from celery.exceptions import TimeoutError as CeleryTimeoutError
@ -11,19 +9,7 @@ from ipware import get_client_ip
from passbook.core.celery import CELERY_APP from passbook.core.celery import CELERY_APP
from passbook.core.models import Policy, User from passbook.core.models import Policy, User
# LOGGER = getLogger(__name__)
def profileit(func):
def wrapper(*args, **kwargs):
datafn = func.__name__ + ".profile" # Name the data file sensibly
prof = cProfile.Profile()
retval = prof.runcall(func, *args, **kwargs)
prof.dump_stats(datafn)
return retval
return wrapper
LOGGER = getLogger(__name__)
def _cache_key(policy, user): def _cache_key(policy, user):
return "%s#%s" % (policy.uuid, user.pk) return "%s#%s" % (policy.uuid, user.pk)
@ -37,8 +23,8 @@ def _policy_engine_task(user_pk, policy_pk, **kwargs):
user_obj = User.objects.get(pk=user_pk) user_obj = User.objects.get(pk=user_pk)
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(user_obj, key, value) setattr(user_obj, key, value)
LOGGER.debug("Running policy `%s`#%s for user %s...", policy_obj.name, # LOGGER.debug("Running policy `%s`#%s for user %s...", policy_obj.name,
policy_obj.pk.hex, user_obj) # policy_obj.pk.hex, user_obj)
policy_result = policy_obj.passes(user_obj) policy_result = policy_obj.passes(user_obj)
# Handle policy result correctly if result, message or just result # Handle policy result correctly if result, message or just result
message = None message = None
@ -47,10 +33,10 @@ def _policy_engine_task(user_pk, policy_pk, **kwargs):
# Invert result if policy.negate is set # Invert result if policy.negate is set
if policy_obj.negate: if policy_obj.negate:
policy_result = not policy_result policy_result = not policy_result
LOGGER.debug("Policy %r#%s got %s", policy_obj.name, policy_obj.pk.hex, policy_result) # LOGGER.debug("Policy %r#%s got %s", policy_obj.name, policy_obj.pk.hex, policy_result)
cache_key = _cache_key(policy_obj, user_obj) cache_key = _cache_key(policy_obj, user_obj)
cache.set(cache_key, (policy_obj.action, policy_result, message)) cache.set(cache_key, (policy_obj.action, policy_result, message))
LOGGER.debug("Cached entry as %s", cache_key) # LOGGER.debug("Cached entry as %s", cache_key)
return policy_obj.action, policy_result, message return policy_obj.action, policy_result, message
class PolicyEngine: class PolicyEngine:
@ -79,7 +65,6 @@ class PolicyEngine:
self.__request = request self.__request = request
return self return self
@profileit
def build(self): def build(self):
"""Build task group""" """Build task group"""
if not self.__user: if not self.__user:
@ -96,16 +81,16 @@ class PolicyEngine:
for policy in self.policies: for policy in self.policies:
cached_policy = cache.get(_cache_key(policy, self.__user), None) cached_policy = cache.get(_cache_key(policy, self.__user), None)
if cached_policy: if cached_policy:
LOGGER.warning("Taking result from cache for %s", policy.pk.hex) # LOGGER.debug("Taking result from cache for %s", policy.pk.hex)
cached_policies.append(cached_policy) cached_policies.append(cached_policy)
else: else:
LOGGER.warning("Evaluating policy %s", policy.pk.hex) # LOGGER.debug("Evaluating policy %s", policy.pk.hex)
signatures.append(_policy_engine_task.signature( signatures.append(_policy_engine_task.signature(
args=(self.__user.pk, policy.pk.hex), args=(self.__user.pk, policy.pk.hex),
kwargs=kwargs, kwargs=kwargs,
time_limit=policy.timeout)) time_limit=policy.timeout))
self.__get_timeout += policy.timeout self.__get_timeout += policy.timeout
LOGGER.warning("Set total policy timeout to %r", self.__get_timeout) # LOGGER.debug("Set total policy timeout to %r", self.__get_timeout)
# If all policies are cached, we have an empty list here. # If all policies are cached, we have an empty list here.
if signatures: if signatures:
self.__group = group(signatures)() self.__group = group(signatures)()
@ -134,7 +119,7 @@ class PolicyEngine:
for policy_action, policy_result, policy_message in result: for policy_action, policy_result, policy_message in result:
passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \ passing = (policy_action == Policy.ACTION_ALLOW and policy_result) or \
(policy_action == Policy.ACTION_DENY and not policy_result) (policy_action == Policy.ACTION_DENY and not policy_result)
LOGGER.debug('Action=%s, Result=%r => %r', policy_action, policy_result, passing) # LOGGER.debug('Action=%s, Result=%r => %r', policy_action, policy_result, passing)
if policy_message: if policy_message:
messages.append(policy_message) messages.append(policy_message)
if not passing: if not passing:

View File

@ -299,7 +299,7 @@ with CONFIG.cd('log'):
}, },
'django': { 'django': {
'handlers': ['queue'], 'handlers': ['queue'],
'level': 'DEBUG', 'level': 'INFO',
'propagate': True, 'propagate': True,
}, },
'tasks': { 'tasks': {
@ -324,7 +324,7 @@ with CONFIG.cd('log'):
}, },
'daphne': { 'daphne': {
'handlers': ['queue'], 'handlers': ['queue'],
'level': 'DEBUG', 'level': 'INFO',
'propagate': True, 'propagate': True,
} }
} }

View File

@ -1,38 +1,37 @@
"""QueueListener that can be configured from logging.dictConfig"""
from atexit import register from atexit import register
from logging.config import ConvertingDict, ConvertingList, valid_ident from logging.config import ConvertingList
from logging.handlers import QueueHandler, QueueListener from logging.handlers import QueueHandler, QueueListener
from queue import Queue from queue import Queue
from django.conf import settings
def _resolve_handlers(_list):
def _resolve_handlers(l): """Evaluates ConvertingList by iterating over it"""
# import pudb; pu.db if not isinstance(_list, ConvertingList):
if not isinstance(l, ConvertingList): return _list
return l
# Indexing the list performs the evaluation. # Indexing the list performs the evaluation.
return [l[i] for i in range(len(l))] return [_list[i] for i in range(len(_list))]
class QueueListenerHandler(QueueHandler): class QueueListenerHandler(QueueHandler):
"""QueueListener that can be configured from logging.dictConfig"""
def __init__(self, handlers, respect_handler_level=False, auto_run=True, queue=Queue(-1)): def __init__(self, handlers, auto_run=True, queue=Queue(-1)):
super().__init__(queue) super().__init__(queue)
handlers = _resolve_handlers(handlers) handlers = _resolve_handlers(handlers)
self._listener = QueueListener( self._listener = QueueListener(
self.queue, self.queue,
*handlers, *handlers,
respect_handler_level=respect_handler_level) respect_handler_level=True)
if auto_run: if auto_run:
self.start() self.start()
register(self.stop) register(self.stop)
def start(self): def start(self):
"""start background thread"""
self._listener.start() self._listener.start()
def stop(self): def stop(self):
"""stop background thread"""
self._listener.stop() self._listener.stop()
def emit(self, record):
return super().emit(record)