From 561e6956fe4f93ccc29326f8b00734816f1d948d Mon Sep 17 00:00:00 2001 From: Jens L Date: Mon, 31 Jul 2023 19:34:59 +0200 Subject: [PATCH] root: add get_int to config loader instead of casting to int everywhere (#6436) * root: add get_int to config loader instead of casting to int everywhere Signed-off-by: Jens Langhammer * improve error handling, add test Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- authentik/api/v3/config.py | 8 ++++---- authentik/core/models.py | 2 +- authentik/flows/planner.py | 2 +- authentik/lib/config.py | 8 ++++++++ authentik/lib/tests/test_config.py | 12 ++++++++++++ authentik/policies/process.py | 2 +- authentik/policies/reputation/signals.py | 2 +- authentik/root/install_id.py | 2 +- authentik/root/settings.py | 10 +++++----- authentik/sources/ldap/sync/base.py | 2 +- authentik/sources/ldap/tasks.py | 8 ++++---- authentik/stages/email/models.py | 4 ++-- lifecycle/gunicorn.conf.py | 4 ++-- lifecycle/migrate.py | 2 +- lifecycle/wait_for_db.py | 4 ++-- 15 files changed, 46 insertions(+), 26 deletions(-) diff --git a/authentik/api/v3/config.py b/authentik/api/v3/config.py index dc9bc37df..bbc676647 100644 --- a/authentik/api/v3/config.py +++ b/authentik/api/v3/config.py @@ -93,10 +93,10 @@ class ConfigView(APIView): "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), }, "capabilities": self.get_capabilities(), - "cache_timeout": int(CONFIG.get("redis.cache_timeout")), - "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")), - "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")), - "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")), + "cache_timeout": CONFIG.get_int("redis.cache_timeout"), + "cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"), + "cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"), + "cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"), } ) diff --git a/authentik/core/models.py b/authentik/core/models.py index 5eafdfd25..c735f6ecb 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -60,7 +60,7 @@ def default_token_key(): """Default token key""" # We use generate_id since the chars in the key should be easy # to use in Emails (for verification) and URLs (for recovery) - return generate_id(int(CONFIG.get("default_token_length"))) + return generate_id(CONFIG.get_int("default_token_length")) class UserTypes(models.TextChoices): diff --git a/authentik/flows/planner.py b/authentik/flows/planner.py index 962273b6a..4cf3c6aad 100644 --- a/authentik/flows/planner.py +++ b/authentik/flows/planner.py @@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source" # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan # was restored. PLAN_CONTEXT_IS_RESTORED = "is_restored" -CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows")) +CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows") CACHE_PREFIX = "goauthentik.io/flows/planner/" diff --git a/authentik/lib/config.py b/authentik/lib/config.py index 70e615e11..043f77460 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -213,6 +213,14 @@ class ConfigLoader: attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) return attr.value + def get_int(self, path: str, default=0) -> int: + """Wrapper for get that converts value into int""" + try: + return int(self.get(path, default)) + except ValueError as exc: + self.log("warning", "Failed to parse config as int", path=path, exc=str(exc)) + return default + def get_bool(self, path: str, default=False) -> bool: """Wrapper for get that converts value into boolean""" return str(self.get(path, default)).lower() == "true" diff --git a/authentik/lib/tests/test_config.py b/authentik/lib/tests/test_config.py index 6a5765632..d95ff8fb8 100644 --- a/authentik/lib/tests/test_config.py +++ b/authentik/lib/tests/test_config.py @@ -79,3 +79,15 @@ class TestConfig(TestCase): config.update_from_file(file2_name) unlink(file_name) unlink(file2_name) + + def test_get_int(self): + """Test get_int""" + config = ConfigLoader() + config.set("foo", 1234) + self.assertEqual(config.get_int("foo"), 1234) + + def test_get_int_invalid(self): + """Test get_int""" + config = ConfigLoader() + config.set("foo", "bar") + self.assertEqual(config.get_int("foo", 1234), 1234) diff --git a/authentik/policies/process.py b/authentik/policies/process.py index fe05d8571..3c6710a51 100644 --- a/authentik/policies/process.py +++ b/authentik/policies/process.py @@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult LOGGER = get_logger() FORK_CTX = get_context("fork") -CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies")) +CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies") PROCESS_CLASS = FORK_CTX.Process diff --git a/authentik/policies/reputation/signals.py b/authentik/policies/reputation/signals.py index af78e6109..2ee6045df 100644 --- a/authentik/policies/reputation/signals.py +++ b/authentik/policies/reputation/signals.py @@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation from authentik.stages.identification.signals import identification_failed LOGGER = get_logger() -CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation")) +CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation") def update_score(request: HttpRequest, identifier: str, amount: int): diff --git a/authentik/root/install_id.py b/authentik/root/install_id.py index a2b4fc544..6275e8f5b 100644 --- a/authentik/root/install_id.py +++ b/authentik/root/install_id.py @@ -30,7 +30,7 @@ def get_install_id_raw(): user=CONFIG.get("postgresql.user"), password=CONFIG.get("postgresql.password"), host=CONFIG.get("postgresql.host"), - port=int(CONFIG.get("postgresql.port")), + port=CONFIG.get_int("postgresql.port"), sslmode=CONFIG.get("postgresql.sslmode"), sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslcert=CONFIG.get("postgresql.sslcert"), diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 29294a5a1..0302b3116 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False): _redis_url = ( f"{_redis_protocol_prefix}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" - f"{int(CONFIG.get('redis.port'))}" + f"{CONFIG.get_int('redis.port')}" ) CACHES = { "default": { "BACKEND": "django_redis.cache.RedisCache", "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", - "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)), + "TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300), "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "KEY_PREFIX": "authentik_cache", } @@ -274,7 +274,7 @@ DATABASES = { "NAME": CONFIG.get("postgresql.name"), "USER": CONFIG.get("postgresql.user"), "PASSWORD": CONFIG.get("postgresql.password"), - "PORT": int(CONFIG.get("postgresql.port")), + "PORT": CONFIG.get_int("postgresql.port"), "SSLMODE": CONFIG.get("postgresql.sslmode"), "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), "SSLCERT": CONFIG.get("postgresql.sslcert"), @@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False): # loads the config directly from CONFIG # See authentik/stages/email/models.py, line 105 EMAIL_HOST = CONFIG.get("email.host") -EMAIL_PORT = int(CONFIG.get("email.port")) +EMAIL_PORT = CONFIG.get_int("email.port") EMAIL_HOST_USER = CONFIG.get("email.username") EMAIL_HOST_PASSWORD = CONFIG.get("email.password") EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) -EMAIL_TIMEOUT = int(CONFIG.get("email.timeout")) +EMAIL_TIMEOUT = CONFIG.get_int("email.timeout") DEFAULT_FROM_EMAIL = CONFIG.get("email.from") SERVER_EMAIL = DEFAULT_FROM_EMAIL EMAIL_SUBJECT_PREFIX = "[authentik] " diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index b544d70c9..b1d12df7f 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -93,7 +93,7 @@ class BaseLDAPSynchronizer: types_only=False, get_operational_attributes=False, controls=None, - paged_size=int(CONFIG.get("ldap.page_size", 50)), + paged_size=CONFIG.get_int("ldap.page_size", 50), paged_criticality=False, ): """Search in pages, returns each page""" diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index 0fec1ab0a..39aeb4c3c 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> signatures = [] for page in sync_inst.get_objects(): page_cache_key = CACHE_KEY_PREFIX + str(uuid4()) - cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours"))) + cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours")) page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key) signatures.append(page_sync) return signatures @@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> @CELERY_APP.task( bind=True, base=MonitoredTask, - soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), - task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), + soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), + task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"), ) def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): """Synchronization of an LDAP Source""" - self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours")) + self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours") source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() if not source: # Because the source couldn't be found, we don't have a UID diff --git a/authentik/stages/email/models.py b/authentik/stages/email/models.py index 1beff50b6..fcc63868c 100644 --- a/authentik/stages/email/models.py +++ b/authentik/stages/email/models.py @@ -108,12 +108,12 @@ class EmailStage(Stage): CONFIG.refresh("email.password") return self.backend_class( host=CONFIG.get("email.host"), - port=int(CONFIG.get("email.port")), + port=CONFIG.get_int("email.port"), username=CONFIG.get("email.username"), password=CONFIG.get("email.password"), use_tls=CONFIG.get_bool("email.use_tls", False), use_ssl=CONFIG.get_bool("email.use_ssl", False), - timeout=int(CONFIG.get("email.timeout")), + timeout=CONFIG.get_int("email.timeout"), ) return self.backend_class( host=self.host, diff --git a/lifecycle/gunicorn.conf.py b/lifecycle/gunicorn.conf.py index 89dcb784f..9359196fc 100644 --- a/lifecycle/gunicorn.conf.py +++ b/lifecycle/gunicorn.conf.py @@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ: else: default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers -workers = int(CONFIG.get("web.workers", default_workers)) -threads = int(CONFIG.get("web.threads", 4)) +workers = CONFIG.get_int("web.workers", default_workers) +threads = CONFIG.get_int("web.threads", 4) def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): diff --git a/lifecycle/migrate.py b/lifecycle/migrate.py index fe4957652..b87cc2338 100755 --- a/lifecycle/migrate.py +++ b/lifecycle/migrate.py @@ -56,7 +56,7 @@ if __name__ == "__main__": user=CONFIG.get("postgresql.user"), password=CONFIG.get("postgresql.password"), host=CONFIG.get("postgresql.host"), - port=int(CONFIG.get("postgresql.port")), + port=CONFIG.get_int("postgresql.port"), sslmode=CONFIG.get("postgresql.sslmode"), sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslcert=CONFIG.get("postgresql.sslcert"), diff --git a/lifecycle/wait_for_db.py b/lifecycle/wait_for_db.py index fd5c5c848..f464f56ba 100755 --- a/lifecycle/wait_for_db.py +++ b/lifecycle/wait_for_db.py @@ -28,7 +28,7 @@ while True: user=CONFIG.get("postgresql.user"), password=CONFIG.get("postgresql.password"), host=CONFIG.get("postgresql.host"), - port=int(CONFIG.get("postgresql.port")), + port=CONFIG.get_int("postgresql.port"), sslmode=CONFIG.get("postgresql.sslmode"), sslrootcert=CONFIG.get("postgresql.sslrootcert"), sslcert=CONFIG.get("postgresql.sslcert"), @@ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False): REDIS_URL = ( f"{REDIS_PROTOCOL_PREFIX}:" f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" - f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}" + f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}" ) while True: try: