diff --git a/authentik/core/management/commands/dev_server.py b/authentik/core/management/commands/dev_server.py
new file mode 100644
index 000000000..ea53167cd
--- /dev/null
+++ b/authentik/core/management/commands/dev_server.py
@@ -0,0 +1,9 @@
+"""custom runserver command"""
+from daphne.management.commands.runserver import Command as RunServer
+
+
+class Command(RunServer):
+    """custom runserver command, which doesn't show the misleading django startup message"""
+
+    def on_bind(self, server_port):
+        pass
diff --git a/authentik/lib/logging.py b/authentik/lib/logging.py
index 379507e74..682475230 100644
--- a/authentik/lib/logging.py
+++ b/authentik/lib/logging.py
@@ -1,7 +1,112 @@
 """logging helpers"""
+import logging
 from logging import Logger
 from os import getpid
 
+import structlog
+
+from authentik.lib.config import CONFIG
+
+LOG_PRE_CHAIN = [
+    # Add the log level and a timestamp to the event_dict if the log entry
+    # is not from structlog.
+    structlog.stdlib.add_log_level,
+    structlog.stdlib.add_logger_name,
+    structlog.processors.TimeStamper(),
+    structlog.processors.StackInfoRenderer(),
+]
+
+
+def get_log_level():
+    """Get log level, clamp trace to debug"""
+    level = CONFIG.get("log_level").upper()
+    # We could add a custom level to stdlib logging and structlog, but it's not easy or clean
+    # https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog
+    # Additionally, the entire code uses debug as highest level
+    # so that would have to be re-written too
+    if level == "TRACE":
+        level = "DEBUG"
+    return level
+
+
+def structlog_configure():
+    """Configure structlog itself"""
+    structlog.configure_once(
+        processors=[
+            structlog.stdlib.add_log_level,
+            structlog.stdlib.add_logger_name,
+            structlog.contextvars.merge_contextvars,
+            add_process_id,
+            structlog.stdlib.PositionalArgumentsFormatter(),
+            structlog.processors.TimeStamper(fmt="iso", utc=False),
+            structlog.processors.StackInfoRenderer(),
+            structlog.processors.dict_tracebacks,
+            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
+        ],
+        logger_factory=structlog.stdlib.LoggerFactory(),
+        wrapper_class=structlog.make_filtering_bound_logger(
+            getattr(logging, get_log_level(), logging.WARNING)
+        ),
+        cache_logger_on_first_use=True,
+    )
+
+
+def get_logger_config():
+    """Configure python stdlib's logging"""
+    debug = CONFIG.get_bool("debug")
+    global_level = get_log_level()
+    base_config = {
+        "version": 1,
+        "disable_existing_loggers": False,
+        "formatters": {
+            "json": {
+                "()": structlog.stdlib.ProcessorFormatter,
+                "processor": structlog.processors.JSONRenderer(sort_keys=True),
+                "foreign_pre_chain": LOG_PRE_CHAIN + [structlog.processors.dict_tracebacks],
+            },
+            "console": {
+                "()": structlog.stdlib.ProcessorFormatter,
+                "processor": structlog.dev.ConsoleRenderer(colors=debug),
+                "foreign_pre_chain": LOG_PRE_CHAIN,
+            },
+        },
+        "handlers": {
+            "console": {
+                "level": "DEBUG",
+                "class": "logging.StreamHandler",
+                "formatter": "console" if debug else "json",
+            },
+        },
+        "loggers": {},
+    }
+
+    handler_level_map = {
+        "": global_level,
+        "authentik": global_level,
+        "django": "WARNING",
+        "django.request": "ERROR",
+        "celery": "WARNING",
+        "selenium": "WARNING",
+        "docker": "WARNING",
+        "urllib3": "WARNING",
+        "websockets": "WARNING",
+        "daphne": "WARNING",
+        "kubernetes": "INFO",
+        "asyncio": "WARNING",
+        "redis": "WARNING",
+        "silk": "INFO",
+        "fsevents": "WARNING",
+        "uvicorn": "WARNING",
+        "gunicorn": "INFO",
+    }
+    for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = { + "handlers": ["console"], + "level": level, + "propagate": False, + } + return base_config + def add_process_id(logger: Logger, method_name: str, event_dict): """Add the current process ID""" diff --git a/authentik/root/middleware.py b/authentik/root/middleware.py index 590884c92..8f97c3c9e 100644 --- a/authentik/root/middleware.py +++ b/authentik/root/middleware.py @@ -172,7 +172,7 @@ class ChannelsLoggingMiddleware: LOGGER.info( scope["path"], scheme="ws", - remote=scope.get("client", [""])[0], + remote=headers.get(b"x-forwarded-for", b"").decode(), user_agent=headers.get(b"user-agent", b"").decode(), **kwargs, ) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 496ef96cb..af8f85dbf 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -1,25 +1,21 @@ """root settings for authentik""" import importlib -import logging import os from hashlib import sha512 from pathlib import Path from urllib.parse import quote_plus -import structlog from celery.schedules import crontab from sentry_sdk import set_tag from authentik import ENV_GIT_HASH_KEY, __version__ from authentik.lib.config import CONFIG -from authentik.lib.logging import add_process_id +from authentik.lib.logging import get_logger_config, structlog_configure from authentik.lib.sentry import sentry_init from authentik.lib.utils.reflection import get_env from authentik.stages.password import BACKEND_APP_PASSWORD, BACKEND_INBUILT, BACKEND_LDAP -LOGGER = structlog.get_logger() - BASE_DIR = Path(__file__).absolute().parent.parent.parent STATICFILES_DIRS = [BASE_DIR / Path("web")] MEDIA_ROOT = BASE_DIR / Path("media") @@ -368,90 +364,9 @@ MEDIA_URL = "/media/" TEST = False TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" -# We can't check TEST here as its set later by the test runner -LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG" -# We could add a custom level to stdlib logging and structlog, but it's not easy or clean -# https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog -# Additionally, the entire code uses debug as highest level so that would have to be re-written too -if LOG_LEVEL == "TRACE": - LOG_LEVEL = "DEBUG" -structlog.configure_once( - processors=[ - structlog.stdlib.add_log_level, - structlog.stdlib.add_logger_name, - structlog.contextvars.merge_contextvars, - add_process_id, - structlog.stdlib.PositionalArgumentsFormatter(), - structlog.processors.TimeStamper(fmt="iso", utc=False), - structlog.processors.StackInfoRenderer(), - structlog.processors.dict_tracebacks, - structlog.stdlib.ProcessorFormatter.wrap_for_formatter, - ], - logger_factory=structlog.stdlib.LoggerFactory(), - wrapper_class=structlog.make_filtering_bound_logger( - getattr(logging, LOG_LEVEL, logging.WARNING) - ), - cache_logger_on_first_use=True, -) - -LOG_PRE_CHAIN = [ - # Add the log level and a timestamp to the event_dict if the log entry - # is not from structlog. 
-    structlog.stdlib.add_log_level,
-    structlog.stdlib.add_logger_name,
-    structlog.processors.TimeStamper(),
-    structlog.processors.StackInfoRenderer(),
-]
-
-LOGGING = {
-    "version": 1,
-    "disable_existing_loggers": False,
-    "formatters": {
-        "json": {
-            "()": structlog.stdlib.ProcessorFormatter,
-            "processor": structlog.processors.JSONRenderer(sort_keys=True),
-            "foreign_pre_chain": LOG_PRE_CHAIN + [structlog.processors.dict_tracebacks],
-        },
-        "console": {
-            "()": structlog.stdlib.ProcessorFormatter,
-            "processor": structlog.dev.ConsoleRenderer(colors=DEBUG),
-            "foreign_pre_chain": LOG_PRE_CHAIN,
-        },
-    },
-    "handlers": {
-        "console": {
-            "level": "DEBUG",
-            "class": "logging.StreamHandler",
-            "formatter": "console" if DEBUG else "json",
-        },
-    },
-    "loggers": {},
-}
-
-_LOGGING_HANDLER_MAP = {
-    "": LOG_LEVEL,
-    "authentik": LOG_LEVEL,
-    "django": "WARNING",
-    "django.request": "ERROR",
-    "celery": "WARNING",
-    "selenium": "WARNING",
-    "docker": "WARNING",
-    "urllib3": "WARNING",
-    "websockets": "WARNING",
-    "daphne": "WARNING",
-    "kubernetes": "INFO",
-    "asyncio": "WARNING",
-    "redis": "WARNING",
-    "silk": "INFO",
-    "fsevents": "WARNING",
-}
-for handler_name, level in _LOGGING_HANDLER_MAP.items():
-    LOGGING["loggers"][handler_name] = {
-        "handlers": ["console"],
-        "level": level,
-        "propagate": False,
-    }
+structlog_configure()
+LOGGING = get_logger_config()
 
 
 _DISALLOWED_ITEMS = [
diff --git a/cmd/server/server.go b/cmd/server/server.go
index b734a99d4..f80c98544 100644
--- a/cmd/server/server.go
+++ b/cmd/server/server.go
@@ -13,7 +13,6 @@ import (
 	"goauthentik.io/internal/config"
 	"goauthentik.io/internal/constants"
 	"goauthentik.io/internal/debug"
-	"goauthentik.io/internal/gounicorn"
 	"goauthentik.io/internal/outpost/ak"
 	"goauthentik.io/internal/outpost/proxyv2"
 	sentryutils "goauthentik.io/internal/utils/sentry"
@@ -22,8 +21,6 @@ import (
 	"goauthentik.io/internal/web/tenant_tls"
 )
 
-var running = true
-
 var rootCmd = &cobra.Command{
 	Use:   "authentik",
 	Short: "Start authentik instance",
@@ -63,40 +60,25 @@ var rootCmd = &cobra.Command{
 		ex := common.Init()
 		defer common.Defer()
 
-		u, _ := url.Parse("http://localhost:8000")
-
-		g := gounicorn.New()
-		defer func() {
-			l.Info("shutting down gunicorn")
-			g.Kill()
-		}()
-		ws := web.NewWebServer(g)
-		g.HealthyCallback = func() {
-			if !config.Get().Outposts.DisableEmbeddedOutpost {
-				go attemptProxyStart(ws, u)
-			}
+		u, err := url.Parse(fmt.Sprintf("http://%s", config.Get().Listen.HTTP))
+		if err != nil {
+			panic(err)
+		}
+
+		ws := web.NewWebServer()
+		ws.Core().HealthyCallback = func() {
+			if config.Get().Outposts.DisableEmbeddedOutpost {
+				return
+			}
+			go attemptProxyStart(ws, u)
 		}
-		go web.RunMetricsServer()
-		go attemptStartBackend(g)
 		ws.Start()
 		<-ex
-		running = false
 		l.Info("shutting down webserver")
 		go ws.Shutdown()
-
 	},
 }
 
-func attemptStartBackend(g *gounicorn.GoUnicorn) {
-	for {
-		if !running {
-			return
-		}
-		err := g.Start()
-		log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting")
-	}
-}
-
 func attemptProxyStart(ws *web.WebServer, u *url.URL) {
 	maxTries := 100
 	attempt := 0
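For orientation, a minimal sketch of the new lifecycle from the entrypoint's perspective. It only uses identifiers introduced in this diff, and since the goauthentik.io import paths are internal it only builds inside the authentik repo; the real Run function additionally waits on the signal channel and shuts down cleanly.

```go
package main

import (
	"fmt"
	"net/url"

	"goauthentik.io/internal/config"
	"goauthentik.io/internal/web"
)

func main() {
	// upstream URL is now derived from the configured listener, not hardcoded
	u, err := url.Parse(fmt.Sprintf("http://%s", config.Get().Listen.HTTP))
	if err != nil {
		panic(err)
	}
	// NewWebServer constructs its own gounicorn instance internally;
	// Core() exposes it so callers can still hook the first-healthy event
	ws := web.NewWebServer()
	ws.Core().HealthyCallback = func() {
		fmt.Println("backend ready at", u)
	}
	ws.Start() // starts metrics server, backend supervisor and listeners
	select {}  // block; the real entrypoint waits for a shutdown signal
}
```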
diff --git a/internal/gounicorn/gounicorn.go b/internal/gounicorn/gounicorn.go
index 5cf65a733..b07a8c61a 100644
--- a/internal/gounicorn/gounicorn.go
+++ b/internal/gounicorn/gounicorn.go
@@ -1,7 +1,6 @@
 package gounicorn
 
 import (
-	"net/http"
 	"os"
 	"os/exec"
 	"runtime"
@@ -10,10 +9,10 @@ import (
 
 	log "github.com/sirupsen/logrus"
 	"goauthentik.io/internal/config"
-	"goauthentik.io/internal/utils/web"
 )
 
 type GoUnicorn struct {
+	Healthcheck     func() bool
 	HealthyCallback func()
 
 	log     *log.Entry
@@ -23,9 +22,10 @@ type GoUnicorn struct {
 	alive   bool
 }
 
-func New() *GoUnicorn {
+func New(healthcheck func() bool) *GoUnicorn {
 	logger := log.WithField("logger", "authentik.router.unicorn")
 	g := &GoUnicorn{
+		Healthcheck: healthcheck,
 		log:     logger,
 		started: false,
 		killed:  false,
@@ -41,7 +41,7 @@ func (g *GoUnicorn) initCmd() {
 	args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"}
 	if config.Get().Debug {
 		command = "./manage.py"
-		args = []string{"runserver"}
+		args = []string{"dev_server"}
 	}
 	g.log.WithField("args", args).WithField("cmd", command).Debug("Starting gunicorn")
 	g.p = exec.Command(command, args...)
@@ -69,22 +69,11 @@ func (g *GoUnicorn) Start() error {
 
 func (g *GoUnicorn) healthcheck() {
 	g.log.Debug("starting healthcheck")
-	h := &http.Client{
-		Transport: web.NewUserAgentTransport("goauthentik.io/proxy/healthcheck", http.DefaultTransport),
-	}
-	check := func() bool {
-		res, err := h.Get("http://localhost:8000/-/health/live/")
-		if err == nil && res.StatusCode == 204 {
-			g.alive = true
-			return true
-		}
-		return false
-	}
-
 	// Default healthcheck is every 1 second on startup
 	// once we've been healthy once, increase to 30 seconds
 	for range time.Tick(time.Second) {
-		if check() {
+		if g.Healthcheck() {
+			g.alive = true
 			g.log.Info("backend is alive, backing off with healthchecks")
 			g.HealthyCallback()
 			break
@@ -92,7 +81,7 @@ func (g *GoUnicorn) healthcheck() {
 		g.log.Debug("backend not alive yet")
 	}
 	for range time.Tick(30 * time.Second) {
-		check()
+		g.Healthcheck()
 	}
 }
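The supervision pattern above is easy to miss in diff form, so here is a distilled, stdlib-only stand-in (not authentik's API): the health probe is injected as a plain func() bool, polled every second until it first succeeds, then a callback fires once and polling backs off to 30 seconds.

```go
package main

import (
	"fmt"
	"time"
)

// core is a hypothetical stand-in for GoUnicorn: the probe is injected
// at construction time, mirroring gounicorn.New(healthcheck).
type core struct {
	Healthcheck     func() bool
	HealthyCallback func()
}

func (c *core) supervise() {
	// poll every second until the backend first reports healthy
	for range time.Tick(time.Second) {
		if c.Healthcheck() {
			c.HealthyCallback()
			break
		}
	}
	// then back off to a 30 second interval
	for range time.Tick(30 * time.Second) {
		c.Healthcheck()
	}
}

func main() {
	start := time.Now()
	c := &core{
		// pretend the backend takes ~3 seconds to come up
		Healthcheck:     func() bool { return time.Since(start) > 3*time.Second },
		HealthyCallback: func() { fmt.Println("backend is alive, backing off") },
	}
	c.supervise() // blocks forever; run in a goroutine in real code
}
```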
backend") } - rp := &httputil.ReverseProxy{Director: director} + rp := &httputil.ReverseProxy{ + Director: director, + Transport: ws.upstreamHttpClient().Transport, + } rp.ErrorHandler = ws.proxyErrorHandler rp.ModifyResponse = ws.proxyModifyResponse ws.m.PathPrefix("/outpost.goauthentik.io").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { @@ -43,14 +45,14 @@ func (ws *WebServer) configureProxy() { }).Observe(float64(elapsed)) return } - ws.proxyErrorHandler(rw, r, fmt.Errorf("proxy not running")) + ws.proxyErrorHandler(rw, r, errors.New("proxy not running")) }) ws.m.Path("/-/health/live/").HandlerFunc(sentry.SentryNoSample(func(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(204) })) ws.m.PathPrefix("/").HandlerFunc(sentry.SentryNoSample(func(rw http.ResponseWriter, r *http.Request) { - if !ws.p.IsRunning() { - ws.proxyErrorHandler(rw, r, fmt.Errorf("authentik core not running yet")) + if !ws.g.IsRunning() { + ws.proxyErrorHandler(rw, r, errors.New("authentik starting")) return } before := time.Now() @@ -82,17 +84,14 @@ func (ws *WebServer) proxyErrorHandler(rw http.ResponseWriter, req *http.Request ws.log.WithError(err).Warning("failed to proxy to backend") rw.WriteHeader(http.StatusBadGateway) em := fmt.Sprintf("failed to connect to authentik backend: %v", err) - if !ws.p.IsRunning() { - em = "authentik starting..." - } // return json if the client asks for json if req.Header.Get("Accept") == "application/json" { - eem, _ := json.Marshal(map[string]string{ + err = json.NewEncoder(rw).Encode(map[string]string{ "error": em, }) - em = string(eem) + } else { + _, err = rw.Write([]byte(em)) } - _, err = rw.Write([]byte(em)) if err != nil { ws.log.WithError(err).Warning("failed to write error message") } diff --git a/internal/web/web.go b/internal/web/web.go index 97a1eb068..e0b3a749d 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -3,8 +3,12 @@ package web import ( "context" "errors" + "fmt" "net" "net/http" + "net/url" + "os" + "path" "github.com/gorilla/handlers" "github.com/gorilla/mux" @@ -26,13 +30,18 @@ type WebServer struct { ProxyServer *proxyv2.ProxyServer TenantTLS *tenant_tls.Watcher + g *gounicorn.GoUnicorn + gr bool m *mux.Router lh *mux.Router log *log.Entry - p *gounicorn.GoUnicorn + uc *http.Client + ul *url.URL } -func NewWebServer(g *gounicorn.GoUnicorn) *WebServer { +const UnixSocketName = "authentik-core.sock" + +func NewWebServer() *WebServer { l := log.WithField("logger", "authentik.router") mainHandler := mux.NewRouter() mainHandler.Use(web.ProxyHeaders()) @@ -40,23 +49,80 @@ func NewWebServer(g *gounicorn.GoUnicorn) *WebServer { loggingHandler := mainHandler.NewRoute().Subrouter() loggingHandler.Use(web.NewLoggingHandler(l, nil)) + tmp := os.TempDir() + socketPath := path.Join(tmp, "authentik-core.sock") + + // create http client to talk to backend, normal client if we're in debug more + // and a client that connects to our socket when in non debug mode + var upstreamClient *http.Client + if config.Get().Debug { + upstreamClient = http.DefaultClient + } else { + upstreamClient = &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, + }, + } + } + + u, _ := url.Parse("http://localhost:8000") + ws := &WebServer{ m: mainHandler, lh: loggingHandler, log: l, - p: g, + gr: true, + uc: upstreamClient, + ul: u, } ws.configureStatic() ws.configureProxy() + ws.g = gounicorn.New(func() bool { + req, err := http.NewRequest(http.MethodGet, 
fmt.Sprintf("%s/-/health/live/", ws.ul.String()), nil) + if err != nil { + ws.log.WithError(err).Warning("failed to create request for healthcheck") + return false + } + req.Header.Set("User-Agent", "goauthentik.io/router/healthcheck") + res, err := ws.upstreamHttpClient().Do(req) + if err == nil && res.StatusCode == 204 { + return true + } + return false + }) return ws } func (ws *WebServer) Start() { + go ws.runMetricsServer() + go ws.attemptStartBackend() go ws.listenPlain() go ws.listenTLS() } +func (ws *WebServer) attemptStartBackend() { + for { + if !ws.gr { + return + } + err := ws.g.Start() + log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting") + } +} + +func (ws *WebServer) Core() *gounicorn.GoUnicorn { + return ws.g +} + +func (ws *WebServer) upstreamHttpClient() *http.Client { + return ws.uc +} + func (ws *WebServer) Shutdown() { + ws.log.Info("shutting down gunicorn") + ws.g.Kill() ws.stop <- struct{}{} } diff --git a/lifecycle/gunicorn.conf.py b/lifecycle/gunicorn.conf.py index 9359196fc..b26e23fb0 100644 --- a/lifecycle/gunicorn.conf.py +++ b/lifecycle/gunicorn.conf.py @@ -7,12 +7,12 @@ from pathlib import Path from tempfile import gettempdir from typing import TYPE_CHECKING -import structlog from kubernetes.config.incluster_config import SERVICE_HOST_ENV_NAME from prometheus_client.values import MultiProcessValue from authentik import get_full_version from authentik.lib.config import CONFIG +from authentik.lib.logging import get_logger_config from authentik.lib.utils.http import get_http_session from authentik.lib.utils.reflection import get_env from authentik.root.install_id import get_install_id_raw @@ -21,57 +21,23 @@ from lifecycle.worker import DjangoUvicornWorker if TYPE_CHECKING: from gunicorn.arbiter import Arbiter -bind = "127.0.0.1:8000" - _tmp = Path(gettempdir()) worker_class = "lifecycle.worker.DjangoUvicornWorker" worker_tmp_dir = str(_tmp.joinpath("authentik_worker_tmp")) prometheus_tmp_dir = str(_tmp.joinpath("authentik_prometheus_tmp")) -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") -os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir) - makedirs(worker_tmp_dir, exist_ok=True) makedirs(prometheus_tmp_dir, exist_ok=True) +bind = f"unix://{str(_tmp.joinpath('authentik-core.sock'))}" + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") +os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir) + max_requests = 1000 max_requests_jitter = 50 -_debug = CONFIG.get_bool("DEBUG", False) - -logconfig_dict = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "json": { - "()": structlog.stdlib.ProcessorFormatter, - "processor": structlog.processors.JSONRenderer(), - "foreign_pre_chain": [ - structlog.stdlib.add_log_level, - structlog.stdlib.add_logger_name, - structlog.processors.TimeStamper(), - structlog.processors.StackInfoRenderer(), - ], - }, - "console": { - "()": structlog.stdlib.ProcessorFormatter, - "processor": structlog.dev.ConsoleRenderer(colors=True), - "foreign_pre_chain": [ - structlog.stdlib.add_log_level, - structlog.stdlib.add_logger_name, - structlog.processors.TimeStamper(), - structlog.processors.StackInfoRenderer(), - ], - }, - }, - "handlers": { - "console": {"class": "logging.StreamHandler", "formatter": "json" if _debug else "console"}, - }, - "loggers": { - "uvicorn": {"handlers": ["console"], "level": "WARNING", "propagate": False}, - "gunicorn": {"handlers": ["console"], "level": "INFO", 
"propagate": False}, - }, -} +logconfig_dict = get_logger_config() # if we're running in kubernetes, use fixed workers because we can scale with more pods # otherwise (assume docker-compose), use as much as we can