root: add get_int to config loader instead of casting to int everywhere (#6436)
* root: add get_int to config loader instead of casting to int everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve error handling, add test Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
10b0c84d97
commit
561e6956fe
|
@ -93,10 +93,10 @@ class ConfigView(APIView):
|
||||||
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
|
"traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)),
|
||||||
},
|
},
|
||||||
"capabilities": self.get_capabilities(),
|
"capabilities": self.get_capabilities(),
|
||||||
"cache_timeout": int(CONFIG.get("redis.cache_timeout")),
|
"cache_timeout": CONFIG.get_int("redis.cache_timeout"),
|
||||||
"cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")),
|
"cache_timeout_flows": CONFIG.get_int("redis.cache_timeout_flows"),
|
||||||
"cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")),
|
"cache_timeout_policies": CONFIG.get_int("redis.cache_timeout_policies"),
|
||||||
"cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")),
|
"cache_timeout_reputation": CONFIG.get_int("redis.cache_timeout_reputation"),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -60,7 +60,7 @@ def default_token_key():
|
||||||
"""Default token key"""
|
"""Default token key"""
|
||||||
# We use generate_id since the chars in the key should be easy
|
# We use generate_id since the chars in the key should be easy
|
||||||
# to use in Emails (for verification) and URLs (for recovery)
|
# to use in Emails (for verification) and URLs (for recovery)
|
||||||
return generate_id(int(CONFIG.get("default_token_length")))
|
return generate_id(CONFIG.get_int("default_token_length"))
|
||||||
|
|
||||||
|
|
||||||
class UserTypes(models.TextChoices):
|
class UserTypes(models.TextChoices):
|
||||||
|
|
|
@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source"
|
||||||
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
# Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan
|
||||||
# was restored.
|
# was restored.
|
||||||
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
PLAN_CONTEXT_IS_RESTORED = "is_restored"
|
||||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows"))
|
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_flows")
|
||||||
CACHE_PREFIX = "goauthentik.io/flows/planner/"
|
CACHE_PREFIX = "goauthentik.io/flows/planner/"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -213,6 +213,14 @@ class ConfigLoader:
|
||||||
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
|
attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default))
|
||||||
return attr.value
|
return attr.value
|
||||||
|
|
||||||
|
def get_int(self, path: str, default=0) -> int:
|
||||||
|
"""Wrapper for get that converts value into int"""
|
||||||
|
try:
|
||||||
|
return int(self.get(path, default))
|
||||||
|
except ValueError as exc:
|
||||||
|
self.log("warning", "Failed to parse config as int", path=path, exc=str(exc))
|
||||||
|
return default
|
||||||
|
|
||||||
def get_bool(self, path: str, default=False) -> bool:
|
def get_bool(self, path: str, default=False) -> bool:
|
||||||
"""Wrapper for get that converts value into boolean"""
|
"""Wrapper for get that converts value into boolean"""
|
||||||
return str(self.get(path, default)).lower() == "true"
|
return str(self.get(path, default)).lower() == "true"
|
||||||
|
|
|
@ -79,3 +79,15 @@ class TestConfig(TestCase):
|
||||||
config.update_from_file(file2_name)
|
config.update_from_file(file2_name)
|
||||||
unlink(file_name)
|
unlink(file_name)
|
||||||
unlink(file2_name)
|
unlink(file2_name)
|
||||||
|
|
||||||
|
def test_get_int(self):
|
||||||
|
"""Test get_int"""
|
||||||
|
config = ConfigLoader()
|
||||||
|
config.set("foo", 1234)
|
||||||
|
self.assertEqual(config.get_int("foo"), 1234)
|
||||||
|
|
||||||
|
def test_get_int_invalid(self):
|
||||||
|
"""Test get_int"""
|
||||||
|
config = ConfigLoader()
|
||||||
|
config.set("foo", "bar")
|
||||||
|
self.assertEqual(config.get_int("foo", 1234), 1234)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
FORK_CTX = get_context("fork")
|
FORK_CTX = get_context("fork")
|
||||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies"))
|
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_policies")
|
||||||
PROCESS_CLASS = FORK_CTX.Process
|
PROCESS_CLASS = FORK_CTX.Process
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation
|
||||||
from authentik.stages.identification.signals import identification_failed
|
from authentik.stages.identification.signals import identification_failed
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation"))
|
CACHE_TIMEOUT = CONFIG.get_int("redis.cache_timeout_reputation")
|
||||||
|
|
||||||
|
|
||||||
def update_score(request: HttpRequest, identifier: str, amount: int):
|
def update_score(request: HttpRequest, identifier: str, amount: int):
|
||||||
|
|
|
@ -30,7 +30,7 @@ def get_install_id_raw():
|
||||||
user=CONFIG.get("postgresql.user"),
|
user=CONFIG.get("postgresql.user"),
|
||||||
password=CONFIG.get("postgresql.password"),
|
password=CONFIG.get("postgresql.password"),
|
||||||
host=CONFIG.get("postgresql.host"),
|
host=CONFIG.get("postgresql.host"),
|
||||||
port=int(CONFIG.get("postgresql.port")),
|
port=CONFIG.get_int("postgresql.port"),
|
||||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||||
|
|
|
@ -190,14 +190,14 @@ if CONFIG.get_bool("redis.tls", False):
|
||||||
_redis_url = (
|
_redis_url = (
|
||||||
f"{_redis_protocol_prefix}:"
|
f"{_redis_protocol_prefix}:"
|
||||||
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||||
f"{int(CONFIG.get('redis.port'))}"
|
f"{CONFIG.get_int('redis.port')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
CACHES = {
|
CACHES = {
|
||||||
"default": {
|
"default": {
|
||||||
"BACKEND": "django_redis.cache.RedisCache",
|
"BACKEND": "django_redis.cache.RedisCache",
|
||||||
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
|
"LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}",
|
||||||
"TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)),
|
"TIMEOUT": CONFIG.get_int("redis.cache_timeout", 300),
|
||||||
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
"OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"},
|
||||||
"KEY_PREFIX": "authentik_cache",
|
"KEY_PREFIX": "authentik_cache",
|
||||||
}
|
}
|
||||||
|
@ -274,7 +274,7 @@ DATABASES = {
|
||||||
"NAME": CONFIG.get("postgresql.name"),
|
"NAME": CONFIG.get("postgresql.name"),
|
||||||
"USER": CONFIG.get("postgresql.user"),
|
"USER": CONFIG.get("postgresql.user"),
|
||||||
"PASSWORD": CONFIG.get("postgresql.password"),
|
"PASSWORD": CONFIG.get("postgresql.password"),
|
||||||
"PORT": int(CONFIG.get("postgresql.port")),
|
"PORT": CONFIG.get_int("postgresql.port"),
|
||||||
"SSLMODE": CONFIG.get("postgresql.sslmode"),
|
"SSLMODE": CONFIG.get("postgresql.sslmode"),
|
||||||
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
|
"SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"),
|
||||||
"SSLCERT": CONFIG.get("postgresql.sslcert"),
|
"SSLCERT": CONFIG.get("postgresql.sslcert"),
|
||||||
|
@ -293,12 +293,12 @@ if CONFIG.get_bool("postgresql.use_pgbouncer", False):
|
||||||
# loads the config directly from CONFIG
|
# loads the config directly from CONFIG
|
||||||
# See authentik/stages/email/models.py, line 105
|
# See authentik/stages/email/models.py, line 105
|
||||||
EMAIL_HOST = CONFIG.get("email.host")
|
EMAIL_HOST = CONFIG.get("email.host")
|
||||||
EMAIL_PORT = int(CONFIG.get("email.port"))
|
EMAIL_PORT = CONFIG.get_int("email.port")
|
||||||
EMAIL_HOST_USER = CONFIG.get("email.username")
|
EMAIL_HOST_USER = CONFIG.get("email.username")
|
||||||
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
|
EMAIL_HOST_PASSWORD = CONFIG.get("email.password")
|
||||||
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
|
EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False)
|
||||||
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
|
EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False)
|
||||||
EMAIL_TIMEOUT = int(CONFIG.get("email.timeout"))
|
EMAIL_TIMEOUT = CONFIG.get_int("email.timeout")
|
||||||
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
|
DEFAULT_FROM_EMAIL = CONFIG.get("email.from")
|
||||||
SERVER_EMAIL = DEFAULT_FROM_EMAIL
|
SERVER_EMAIL = DEFAULT_FROM_EMAIL
|
||||||
EMAIL_SUBJECT_PREFIX = "[authentik] "
|
EMAIL_SUBJECT_PREFIX = "[authentik] "
|
||||||
|
|
|
@ -93,7 +93,7 @@ class BaseLDAPSynchronizer:
|
||||||
types_only=False,
|
types_only=False,
|
||||||
get_operational_attributes=False,
|
get_operational_attributes=False,
|
||||||
controls=None,
|
controls=None,
|
||||||
paged_size=int(CONFIG.get("ldap.page_size", 50)),
|
paged_size=CONFIG.get_int("ldap.page_size", 50),
|
||||||
paged_criticality=False,
|
paged_criticality=False,
|
||||||
):
|
):
|
||||||
"""Search in pages, returns each page"""
|
"""Search in pages, returns each page"""
|
||||||
|
|
|
@ -59,7 +59,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
||||||
signatures = []
|
signatures = []
|
||||||
for page in sync_inst.get_objects():
|
for page in sync_inst.get_objects():
|
||||||
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
page_cache_key = CACHE_KEY_PREFIX + str(uuid4())
|
||||||
cache.set(page_cache_key, page, 60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")))
|
cache.set(page_cache_key, page, 60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"))
|
||||||
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
|
page_sync = ldap_sync.si(source.pk, class_to_path(sync), page_cache_key)
|
||||||
signatures.append(page_sync)
|
signatures.append(page_sync)
|
||||||
return signatures
|
return signatures
|
||||||
|
@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
||||||
@CELERY_APP.task(
|
@CELERY_APP.task(
|
||||||
bind=True,
|
bind=True,
|
||||||
base=MonitoredTask,
|
base=MonitoredTask,
|
||||||
soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
soft_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
|
||||||
task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")),
|
task_time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours"),
|
||||||
)
|
)
|
||||||
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str):
|
||||||
"""Synchronization of an LDAP Source"""
|
"""Synchronization of an LDAP Source"""
|
||||||
self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours"))
|
self.result_timeout_hours = CONFIG.get_int("ldap.task_timeout_hours")
|
||||||
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
||||||
if not source:
|
if not source:
|
||||||
# Because the source couldn't be found, we don't have a UID
|
# Because the source couldn't be found, we don't have a UID
|
||||||
|
|
|
@ -108,12 +108,12 @@ class EmailStage(Stage):
|
||||||
CONFIG.refresh("email.password")
|
CONFIG.refresh("email.password")
|
||||||
return self.backend_class(
|
return self.backend_class(
|
||||||
host=CONFIG.get("email.host"),
|
host=CONFIG.get("email.host"),
|
||||||
port=int(CONFIG.get("email.port")),
|
port=CONFIG.get_int("email.port"),
|
||||||
username=CONFIG.get("email.username"),
|
username=CONFIG.get("email.username"),
|
||||||
password=CONFIG.get("email.password"),
|
password=CONFIG.get("email.password"),
|
||||||
use_tls=CONFIG.get_bool("email.use_tls", False),
|
use_tls=CONFIG.get_bool("email.use_tls", False),
|
||||||
use_ssl=CONFIG.get_bool("email.use_ssl", False),
|
use_ssl=CONFIG.get_bool("email.use_ssl", False),
|
||||||
timeout=int(CONFIG.get("email.timeout")),
|
timeout=CONFIG.get_int("email.timeout"),
|
||||||
)
|
)
|
||||||
return self.backend_class(
|
return self.backend_class(
|
||||||
host=self.host,
|
host=self.host,
|
||||||
|
|
|
@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ:
|
||||||
else:
|
else:
|
||||||
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
|
default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers
|
||||||
|
|
||||||
workers = int(CONFIG.get("web.workers", default_workers))
|
workers = CONFIG.get_int("web.workers", default_workers)
|
||||||
threads = int(CONFIG.get("web.threads", 4))
|
threads = CONFIG.get_int("web.threads", 4)
|
||||||
|
|
||||||
|
|
||||||
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||||
|
|
|
@ -56,7 +56,7 @@ if __name__ == "__main__":
|
||||||
user=CONFIG.get("postgresql.user"),
|
user=CONFIG.get("postgresql.user"),
|
||||||
password=CONFIG.get("postgresql.password"),
|
password=CONFIG.get("postgresql.password"),
|
||||||
host=CONFIG.get("postgresql.host"),
|
host=CONFIG.get("postgresql.host"),
|
||||||
port=int(CONFIG.get("postgresql.port")),
|
port=CONFIG.get_int("postgresql.port"),
|
||||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||||
|
|
|
@ -28,7 +28,7 @@ while True:
|
||||||
user=CONFIG.get("postgresql.user"),
|
user=CONFIG.get("postgresql.user"),
|
||||||
password=CONFIG.get("postgresql.password"),
|
password=CONFIG.get("postgresql.password"),
|
||||||
host=CONFIG.get("postgresql.host"),
|
host=CONFIG.get("postgresql.host"),
|
||||||
port=int(CONFIG.get("postgresql.port")),
|
port=CONFIG.get_int("postgresql.port"),
|
||||||
sslmode=CONFIG.get("postgresql.sslmode"),
|
sslmode=CONFIG.get("postgresql.sslmode"),
|
||||||
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
sslrootcert=CONFIG.get("postgresql.sslrootcert"),
|
||||||
sslcert=CONFIG.get("postgresql.sslcert"),
|
sslcert=CONFIG.get("postgresql.sslcert"),
|
||||||
|
@ -47,7 +47,7 @@ if CONFIG.get_bool("redis.tls", False):
|
||||||
REDIS_URL = (
|
REDIS_URL = (
|
||||||
f"{REDIS_PROTOCOL_PREFIX}:"
|
f"{REDIS_PROTOCOL_PREFIX}:"
|
||||||
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:"
|
||||||
f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}"
|
f"{CONFIG.get_int('redis.port')}/{CONFIG.get('redis.db')}"
|
||||||
)
|
)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
Reference in New Issue