root: add get_int to config loader instead of casting to int everywhere (#6436)
* root: add get_int to config loader instead of casting to int everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve error handling, add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
10b0c84d97
commit
561e6956fe
|
@ -93,10 +93,10 @@ class ConfigView(APIView):
|
|||
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
|
||||
},
|
||||
"capabilities": self.get_capabilities(),
|
||||
"cache_timeout": int(CONFIG.get("redis.cache_timeout")),
|
||||
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
|
||||
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
|
||||
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
|
||||
"cache_timeout": CONFIG.get_int("redis.cache_timeout"),
|
||||
"cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"),
|
||||
"cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"),
|
||||
"cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ def default_token_key():
|
|||
"""Default token key"""
|
||||
# We use generate_id since the chars in the key should be easy
|
||||
# to use in Emails (for verification) and URLs (for recovery)
|
||||
return generate_id(int(CONFIG.get("default_token_length")))
|
||||
return generate_id(CONFIG.get_int("default_token_length"))
|
||||
|
||||
|
||||
class UserTypes(models.TextChoices):
|
||||
|
|
|
@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
|
|||
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
||||
# was restored.
|
||||
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
|
||||
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows")
|
||||
CACHE_PREFIX = "goauthentik.io/flows/planner/"
|
||||
|
||||
|
||||
|
|
|
@ -213,6 +213,14 @@ class ConfigLoader:
|
|||
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
|
||||
return attr.value
|
||||
|
||||
def get_int(self, path: str, default=0) -> int:
|
||||
"""Wrapper for get that converts value into int"""
|
||||
try:
|
||||
return int(self.get(path, default))
|
||||
except ValueError as exc:
|
||||
self.log("warning", "Failed to parse config as int", path=path, exc=str(exc))
|
||||
return default
|
||||
|
||||
def get_bool(self, path: str, default=False) -> bool:
|
||||
"""Wrapper for get that converts value into boolean"""
|
||||
return str(self.get(path, default)).lower() == "true"
|
||||
|
|
|
@ -79,3 +79,15 @@ class TestConfig(TestCase):
|
|||
config.update_from_file(file2_name)
|
||||
unlink(file_name)
|
||||
unlink(file2_name)
|
||||
|
||||
def test_get_int(self):
|
||||
"""Test get_int"""
|
||||
config = ConfigLoader()
|
||||
config.set("foo", 1234)
|
||||
self.assertEqual(config.get_int("foo"), 1234)
|
||||
|
||||
def test_get_int_invalid(self):
|
||||
"""Test get_int"""
|
||||
config = ConfigLoader()
|
||||
config.set("foo", "bar")
|
||||
self.assertEqual(config.get_int("foo", 1234), 1234)
|
||||
|
|
|
@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
|
|||
LOGGER = get_logger()
|
||||
|
||||
FORK_CTX = get_context("fork")
|
||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
|
||||
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies")
|
||||
PROCESS_CLASS = FORK_CTX.Process
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
|
|||
from authentik.stages.identification.signals import identification_failed
|
||||
|
||||
LOGGER = get_logger()
|
||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
|
||||
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation")
|
||||
|
||||
|
||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||
|
|
|
@ -30,7 +30,7 @@ def get_install_id_raw():
|
|||
user=CONFIG.get("postgresql.user"),
|
||||
password=CONFIG.get("postgresql.password"),
|
||||
host=CONFIG.get("postgresql.host"),
|
||||
port=int(CONFIG.get("postgresql.port")),
|
||||
port=CONFIG.get_int("postgresql.port"),
|
||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
|
|
|
@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False):
|
|||
_redis_url = (
|
||||
f"{_redis_protocol_prefix}:"
|
||||
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||
f"{int(CONFIG.get('redis.port'))}"
|
||||
f"{CONFIG.get_int('redis.port')}"
|
||||
)
|
||||
|
||||
CACHES = {
|
||||
"default": {
|
||||
"BACKEND": "django_redis.cache.RedisCache",
|
||||
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
|
||||
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
|
||||
"TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300),
|
||||
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
||||
"KEY_PREFIX": "authentik_cache",
|
||||
}
|
||||
|
@ -274,7 +274,7 @@ DATABASES = {
|
|||
"NAME": CONFIG.get("postgresql.name"),
|
||||
"USER": CONFIG.get("postgresql.user"),
|
||||
"PASSWORD": CONFIG.get("postgresql.password"),
|
||||
"PORT": int(CONFIG.get("postgresql.port")),
|
||||
"PORT": CONFIG.get_int("postgresql.port"),
|
||||
"SSLMODE": CONFIG.get("postgresql.sslmode"),
|
||||
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
|
||||
"SSLCERT": CONFIG.get("postgresql.sslcert"),
|
||||
|
@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
|
|||
# loads the config directly from CONFIG
|
||||
# See authentik/stages/email/models.py, line 105
|
||||
EMAIL_HOST = CONFIG.get("email.host")
|
||||
EMAIL_PORT = int(CONFIG.get("email.port"))
|
||||
EMAIL_PORT = CONFIG.get_int("email.port")
|
||||
EMAIL_HOST_USER = CONFIG.get("email.username")
|
||||
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
|
||||
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
|
||||
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
|
||||
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
|
||||
EMAIL_TIMEOUT = CONFIG.get_int("email.timeout")
|
||||
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
|
||||
SERVER_EMAIL = DEFAULT_FROM_EMAIL
|
||||
EMAIL_SUBJECT_PREFIX = "[authentik] "
|
||||
|
|
|
@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
|
|||
types_only=False,
|
||||
get_operational_attributes=False,
|
||||
controls=None,
|
||||
paged_size=int(CONFIG.get("ldap.page_size", 50)),
|
||||
paged_size=CONFIG.get_int("ldap.page_size", 50),
|
||||
paged_criticality=False,
|
||||
):
|
||||
"""Search in pages, returns each page"""
|
||||
|
|
|
@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
|||
signatures = []
|
||||
for page in sync_inst.get_objects():
|
||||
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
||||
cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")))
|
||||
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
|
||||
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
|
||||
signatures.append(page_sync)
|
||||
return signatures
|
||||
|
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
|||
@CELERY_APP.task(
|
||||
bind=True,
|
||||
base=MonitoredTask,
|
||||
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
||||
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
||||
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
|
||||
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
|
||||
)
|
||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
||||
"""Synchronization of an LDAP Source"""
|
||||
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
|
||||
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours")
|
||||
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
||||
if not source:
|
||||
# Because the source couldn't be found, we don't have a UID
|
||||
|
|
|
@ -108,12 +108,12 @@ class EmailStage(Stage):
|
|||
CONFIG.refresh("email.password")
|
||||
return self.backend_class(
|
||||
host=CONFIG.get("email.host"),
|
||||
port=int(CONFIG.get("email.port")),
|
||||
port=CONFIG.get_int("email.port"),
|
||||
username=CONFIG.get("email.username"),
|
||||
password=CONFIG.get("email.password"),
|
||||
use_tls=CONFIG.get_bool("email.use_tls", False),
|
||||
use_ssl=CONFIG.get_bool("email.use_ssl", False),
|
||||
timeout=int(CONFIG.get("email.timeout")),
|
||||
timeout=CONFIG.get_int("email.timeout"),
|
||||
)
|
||||
return self.backend_class(
|
||||
host=self.host,
|
||||
|
|
|
@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
|
|||
else:
|
||||
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
|
||||
|
||||
workers = int(CONFIG.get("web.workers", default_workers))
|
||||
threads = int(CONFIG.get("web.threads", 4))
|
||||
workers = CONFIG.get_int("web.workers", default_workers)
|
||||
threads = CONFIG.get_int("web.threads", 4)
|
||||
|
||||
|
||||
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||
|
|
|
@ -56,7 +56,7 @@ if __name__ == "__main__":
|
|||
user=CONFIG.get("postgresql.user"),
|
||||
password=CONFIG.get("postgresql.password"),
|
||||
host=CONFIG.get("postgresql.host"),
|
||||
port=int(CONFIG.get("postgresql.port")),
|
||||
port=CONFIG.get_int("postgresql.port"),
|
||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
|
|
|
@ -28,7 +28,7 @@ while True:
|
|||
user=CONFIG.get("postgresql.user"),
|
||||
password=CONFIG.get("postgresql.password"),
|
||||
host=CONFIG.get("postgresql.host"),
|
||||
port=int(CONFIG.get("postgresql.port")),
|
||||
port=CONFIG.get_int("postgresql.port"),
|
||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||
|
@ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False):
|
|||
REDIS_URL = (
|
||||
f"{REDIS_PROTOCOL_PREFIX}:"
|
||||
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
|
||||
f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
|
|
Reference in a new issue