root: add get_int to config loader instead of casting to int everywhere (#6436)

* root: add get_int to config loader instead of casting to int everywhere

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* improve error handling, add test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-07-31 19:34:59 +02:00 committed by GitHub
parent 10b0c84d97
commit 561e6956fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 46 additions and 26 deletions

View file

@ -93,10 +93,10 @@ class ConfigView(APIView):
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
},
"capabilities": self.get_capabilities(),
"cache_timeout": int(CONFIG.get("redis.cache_timeout")),
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
"cache_timeout": CONFIG.get_int("redis.cache_timeout"),
"cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"),
"cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"),
"cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"),
}
)

View file

@ -60,7 +60,7 @@ def default_token_key():
"""Default token key"""
# We use generate_id since the chars in the key should be easy
# to use in Emails (for verification) and URLs (for recovery)
return generate_id(int(CONFIG.get("default_token_length")))
return generate_id(CONFIG.get_int("default_token_length"))
class UserTypes(models.TextChoices):

View file

@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
# was restored.
PLAN_CONTEXT_IS_RESTORED = "is_restored"
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows")
CACHE_PREFIX = "goauthentik.io/flows/planner/"

View file

@ -213,6 +213,14 @@ class ConfigLoader:
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
return attr.value
def get_int(self, path: str, default=0) -> int:
"""Wrapper for get that converts value into int"""
try:
return int(self.get(path, default))
except ValueError as exc:
self.log("warning", "Failed to parse config as int", path=path, exc=str(exc))
return default
def get_bool(self, path: str, default=False) -> bool:
"""Wrapper for get that converts value into boolean"""
return str(self.get(path, default)).lower() == "true"

View file

@ -79,3 +79,15 @@ class TestConfig(TestCase):
config.update_from_file(file2_name)
unlink(file_name)
unlink(file2_name)
def test_get_int(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", 1234)
self.assertEqual(config.get_int("foo"), 1234)
def test_get_int_invalid(self):
"""Test get_int"""
config = ConfigLoader()
config.set("foo", "bar")
self.assertEqual(config.get_int("foo", 1234), 1234)

View file

@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
LOGGER = get_logger()
FORK_CTX = get_context("fork")
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies")
PROCESS_CLASS = FORK_CTX.Process

View file

@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
from authentik.stages.identification.signals import identification_failed
LOGGER = get_logger()
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation")
def update_score(request: HttpRequest, identifier: str, amount: int):

View file

@ -30,7 +30,7 @@ def get_install_id_raw():
user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")),
port=CONFIG.get_int("postgresql.port"),
sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"),

View file

@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False):
_redis_url = (
f"{_redis_protocol_prefix}:"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.get('redis.port'))}"
f"{CONFIG.get_int('redis.port')}"
)
CACHES = {
"default": {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
"TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300),
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
"KEY_PREFIX": "authentik_cache",
}
@ -274,7 +274,7 @@ DATABASES = {
"NAME": CONFIG.get("postgresql.name"),
"USER": CONFIG.get("postgresql.user"),
"PASSWORD": CONFIG.get("postgresql.password"),
"PORT": int(CONFIG.get("postgresql.port")),
"PORT": CONFIG.get_int("postgresql.port"),
"SSLMODE": CONFIG.get("postgresql.sslmode"),
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
"SSLCERT": CONFIG.get("postgresql.sslcert"),
@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
# loads the config directly from CONFIG
# See authentik/stages/email/models.py, line 105
EMAIL_HOST = CONFIG.get("email.host")
EMAIL_PORT = int(CONFIG.get("email.port"))
EMAIL_PORT = CONFIG.get_int("email.port")
EMAIL_HOST_USER = CONFIG.get("email.username")
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
EMAIL_TIMEOUT = CONFIG.get_int("email.timeout")
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
SERVER_EMAIL = DEFAULT_FROM_EMAIL
EMAIL_SUBJECT_PREFIX = "[authentik] "

View file

@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
types_only=False,
get_operational_attributes=False,
controls=None,
paged_size=int(CONFIG.get("ldap.page_size", 50)),
paged_size=CONFIG.get_int("ldap.page_size", 50),
paged_criticality=False,
):
"""Search in pages, returns each page"""

View file

@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
signatures = []
for page in sync_inst.get_objects():
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")))
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
signatures.append(page_sync)
return signatures
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@CELERY_APP.task(
bind=True,
base=MonitoredTask,
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
)
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source"""
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours")
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source:
# Because the source couldn't be found, we don't have a UID

View file

@ -108,12 +108,12 @@ class EmailStage(Stage):
CONFIG.refresh("email.password")
return self.backend_class(
host=CONFIG.get("email.host"),
port=int(CONFIG.get("email.port")),
port=CONFIG.get_int("email.port"),
username=CONFIG.get("email.username"),
password=CONFIG.get("email.password"),
use_tls=CONFIG.get_bool("email.use_tls", False),
use_ssl=CONFIG.get_bool("email.use_ssl", False),
timeout=int(CONFIG.get("email.timeout")),
timeout=CONFIG.get_int("email.timeout"),
)
return self.backend_class(
host=self.host,

View file

@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
else:
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
workers = int(CONFIG.get("web.workers", default_workers))
threads = int(CONFIG.get("web.threads", 4))
workers = CONFIG.get_int("web.workers", default_workers)
threads = CONFIG.get_int("web.threads", 4)
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):

View file

@ -56,7 +56,7 @@ if __name__ == "__main__":
user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")),
port=CONFIG.get_int("postgresql.port"),
sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"),

View file

@ -28,7 +28,7 @@ while True:
user=CONFIG.get("postgresql.user"),
password=CONFIG.get("postgresql.password"),
host=CONFIG.get("postgresql.host"),
port=int(CONFIG.get("postgresql.port")),
port=CONFIG.get_int("postgresql.port"),
sslmode=CONFIG.get("postgresql.sslmode"),
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
sslcert=CONFIG.get("postgresql.sslcert"),
@ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False):
REDIS_URL = (
f"{REDIS_PROTOCOL_PREFIX}:"
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}"
)
while True:
try: