root: add get_int to config loader instead of casting to int everywhere (#6436)

* root: add get_int to config loader instead of casting to int everywhere

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* improve error handling, add test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-07-31 19:34:59 +02:00 committed by GitHub
parent 10b0c84d97
commit 561e6956fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 46 additions and 26 deletions

View File

@ -93,10 +93,10 @@ class ConfigView(APIView):
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
}, },
"capabilities": self.get_capabilities(), "capabilities": self.get_capabilities(),
"cache_timeout": int(CONFIG.get("redis.cache_timeout")), "cache_timeout": CONFIG.get_int("redis.cache_timeout"),
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")), "cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"),
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")), "cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"),
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")), "cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"),
} }
) )

View File

@ -60,7 +60,7 @@ def default_token_key():
"""Default token key""" """Default token key"""
# We use generate_id since the chars in the key should be easy # We use generate_id since the chars in the key should be easy
# to use in Emails (for verification) and URLs (for recovery) # to use in Emails (for verification) and URLs (for recovery)
return generate_id(int(CONFIG.get("default_token_length"))) return generate_id(CONFIG.get_int("default_token_length"))
class UserTypes(models.TextChoices): class UserTypes(models.TextChoices):

View File

@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored. # was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored" PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows")) CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows")
CACHE_PREFIX = "goauthentik.io/flows/planner/" CACHE_PREFIX = "goauthentik.io/flows/planner/"

View File

@ -213,6 +213,14 @@ class ConfigLoader:
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
return attr.value return attr.value
def get_int(self, path: str, default=0) -> int:
"""Wrapper for get that converts value into int"""
try:
return int(self.get(path, default))
except ValueError as exc:
self.log("warning", "Failed to parse config as int", path=path, exc=str(exc))
return default
def get_bool(self, path: str, default=False) -> bool: def get_bool(self, path: str, default=False) -> bool:
"""Wrapper for get that converts value into boolean""" """Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true" return str(self.get(path, default)).lower() == "true"

View File

@ -79,3 +79,15 @@ class TestConfig(TestCase):
config.update_from_file(file2_name) config.update_from_file(file2_name)
unlink(file_name) unlink(file_name)
unlink(file2_name) unlink(file2_name)
def test_get_int(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", 1234)
self.assertEqual(config.get_int("foo"), 1234)
def test_get_int_invalid(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", "bar")
self.assertEqual(config.get_int("foo", 1234), 1234)

View File

@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger() LOGGER = get_logger()
FORK_CTX = get_context("fork") FORK_CTX = get_context("fork")
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies")) CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies")
PROCESS_CLASS = FORK_CTX.Process PROCESS_CLASS = FORK_CTX.Process

View File

@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed from authentik.stages.identification.signals import identification_failed
LOGGER = get_logger() LOGGER = get_logger()
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation")) CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation")
def update_score(request: HttpRequest, identifier: str, amount: int): def update_score(request: HttpRequest, identifier: str, amount: int):

View File

@ -30,7 +30,7 @@ def get_install_id_raw():
user=CONFIG.get("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")), port=CONFIG.get_int("postgresql.port"),
sslmode=CONFIG.get("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),

View File

@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False):
_redis_url = ( _redis_url = (
f"{_redis_protocol_prefix}:" f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.get('redis.port'))}" f"{CONFIG.get_int('redis.port')}"
) )
CACHES = { CACHES = {
"default": { "default": {
"BACKEND": "django_redis.cache.RedisCache", "BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)), "TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache", "KEY_PREFIX": "authentik_cache",
} }
@ -274,7 +274,7 @@ DATABASES = {
"NAME": CONFIG.get("postgresql.name"), "NAME": CONFIG.get("postgresql.name"),
"USER": CONFIG.get("postgresql.user"), "USER": CONFIG.get("postgresql.user"),
"PASSWORD": CONFIG.get("postgresql.password"), "PASSWORD": CONFIG.get("postgresql.password"),
"PORT": int(CONFIG.get("postgresql.port")), "PORT": CONFIG.get_int("postgresql.port"),
"SSLMODE": CONFIG.get("postgresql.sslmode"), "SSLMODE": CONFIG.get("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
"SSLCERT": CONFIG.get("postgresql.sslcert"), "SSLCERT": CONFIG.get("postgresql.sslcert"),
@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# loads the config directly from CONFIG # loads the config directly from CONFIG
# See authentik/stages/email/models.py, line 105 # See authentik/stages/email/models.py, line 105
EMAIL_HOST = CONFIG.get("email.host") EMAIL_HOST = CONFIG.get("email.host")
EMAIL_PORT = int(CONFIG.get("email.port")) EMAIL_PORT = CONFIG.get_int("email.port")
EMAIL_HOST_USER = CONFIG.get("email.username") EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_HOST_PASSWORD = CONFIG.get("email.password") EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout")) EMAIL_TIMEOUT = CONFIG.get_int("email.timeout")
DEFAULT_FROM_EMAIL = CONFIG.get("email.from") DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] " EMAIL_SUBJECT_PREFIX = "[authentik] "

View File

@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
types_only=False, types_only=False,
get_operational_attributes=False, get_operational_attributes=False,
controls=None, controls=None,
paged_size=int(CONFIG.get("ldap.page_size", 50)), paged_size=CONFIG.get_int("ldap.page_size", 50),
paged_criticality=False, paged_criticality=False,
): ):
"""Search in pages, returns each page""" """Search in pages, returns each page"""

View File

@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
signatures = [] signatures = []
for page in sync_inst.get_objects(): for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours"))) cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
signatures.append(page_sync) signatures.append(page_sync)
return signatures return signatures
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task( @CELERY_APP.task(
bind=True, bind=True,
base=MonitoredTask, base=MonitoredTask,
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
) )
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source""" """Synchronization of an LDAP Source"""
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours")) self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours")
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source: if not source:
# Because the source couldn't be found, we don't have a UID # Because the source couldn't be found, we don't have a UID

View File

@ -108,12 +108,12 @@ class EmailStage(Stage):
CONFIG.refresh("email.password") CONFIG.refresh("email.password")
return self.backend_class( return self.backend_class(
host=CONFIG.get("email.host"), host=CONFIG.get("email.host"),
port=int(CONFIG.get("email.port")), port=CONFIG.get_int("email.port"),
username=CONFIG.get("email.username"), username=CONFIG.get("email.username"),
password=CONFIG.get("email.password"), password=CONFIG.get("email.password"),
use_tls=CONFIG.get_bool("email.use_tls", False), use_tls=CONFIG.get_bool("email.use_tls", False),
use_ssl=CONFIG.get_bool("email.use_ssl", False), use_ssl=CONFIG.get_bool("email.use_ssl", False),
timeout=int(CONFIG.get("email.timeout")), timeout=CONFIG.get_int("email.timeout"),
) )
return self.backend_class( return self.backend_class(
host=self.host, host=self.host,

View File

@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
else: else:
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
workers = int(CONFIG.get("web.workers", default_workers)) workers = CONFIG.get_int("web.workers", default_workers)
threads = int(CONFIG.get("web.threads", 4)) threads = CONFIG.get_int("web.threads", 4)
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):

View File

@ -56,7 +56,7 @@ if __name__ == "__main__":
user=CONFIG.get("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")), port=CONFIG.get_int("postgresql.port"),
sslmode=CONFIG.get("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),

View File

@ -28,7 +28,7 @@ while True:
user=CONFIG.get("postgresql.user"), user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"), password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"), host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")), port=CONFIG.get_int("postgresql.port"),
sslmode=CONFIG.get("postgresql.sslmode"), sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"), sslcert=CONFIG.get("postgresql.sslcert"),
@ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False):
REDIS_URL = ( REDIS_URL = (
f"{REDIS_PROTOCOL_PREFIX}:" f"{REDIS_PROTOCOL_PREFIX}:"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}" f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}"
) )
while True: while True:
try: try: