diff --git a/cmd/server/main.go b/cmd/server/main.go index a4797c86a..3bca268bd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -34,28 +34,15 @@ func main() { }) go debug.EnableDebugServer() l := log.WithField("logger", "authentik.root") - config.DefaultConfig() - err := config.LoadConfig("./authentik/lib/default.yml") - if err != nil { - l.WithError(err).Warning("failed to load default config") - } - err = config.LoadConfig("./local.env.yml") - if err != nil { - l.WithError(err).Debug("no local config to load") - } - err = config.FromEnv() - if err != nil { - l.WithError(err).Debug("failed to environment variables") - } - config.ConfigureLogger() + config.Get().Setup("./authentik/lib/default.yml", "./local.env.yml") - if config.G.ErrorReporting.Enabled { + if config.Get().ErrorReporting.Enabled { err := sentry.Init(sentry.ClientOptions{ - Dsn: config.G.ErrorReporting.DSN, + Dsn: config.Get().ErrorReporting.DSN, AttachStacktrace: true, - TracesSampler: sentryutils.SamplerFunc(config.G.ErrorReporting.SampleRate), + TracesSampler: sentryutils.SamplerFunc(config.Get().ErrorReporting.SampleRate), Release: fmt.Sprintf("authentik@%s", constants.VERSION), - Environment: config.G.ErrorReporting.Environment, + Environment: config.Get().ErrorReporting.Environment, HTTPTransport: webutils.NewUserAgentTransport(constants.UserAgent(), http.DefaultTransport), IgnoreErrors: []string{ http.ErrAbortHandler.Error(), @@ -74,7 +61,7 @@ func main() { g := gounicorn.NewGoUnicorn() ws := web.NewWebServer(g) g.HealthyCallback = func() { - if !config.G.Web.DisableEmbeddedOutpost { + if !config.Get().Web.DisableEmbeddedOutpost { go attemptProxyStart(ws, u) } } @@ -105,7 +92,7 @@ func attemptProxyStart(ws *web.WebServer, u *url.URL) { l := log.WithField("logger", "authentik.server") for { l.Debug("attempting to init outpost") - ac := ak.NewAPIController(*u, config.G.SecretKey) + ac := ak.NewAPIController(*u, config.Get().SecretKey) if ac == nil { attempt += 1 time.Sleep(1 * time.Second) diff --git a/internal/config/config.go b/internal/config/config.go index 5e98718ce..b1ccd6a8b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,9 @@ package config import ( "fmt" "io/ioutil" + "net/url" + "os" + "reflect" "strings" env "github.com/Netflix/go-env" @@ -11,10 +14,17 @@ import ( "gopkg.in/yaml.v2" ) -var G Config +var cfg *Config -func DefaultConfig() { - G = Config{ +func Get() *Config { + if cfg == nil { + cfg = defaultConfig() + } + return cfg +} + +func defaultConfig() *Config { + return &Config{ Debug: false, Web: WebConfig{ Listen: "localhost:9000", @@ -32,7 +42,18 @@ func DefaultConfig() { } } -func LoadConfig(path string) error { +func (c *Config) Setup(paths ...string) { + for _, path := range paths { + err := c.LoadConfig(path) + if err != nil { + log.WithError(err).Info("failed to load config, skipping") + } + } + c.fromEnv() + c.configureLogger() +} + +func (c *Config) LoadConfig(path string) error { raw, err := ioutil.ReadFile(path) if err != nil { return fmt.Errorf("Failed to load config file: %w", err) @@ -42,28 +63,83 @@ func LoadConfig(path string) error { if err != nil { return fmt.Errorf("Failed to parse YAML: %w", err) } - if err := mergo.Merge(&G, nc, mergo.WithOverride); err != nil { + if err := mergo.Merge(c, nc, mergo.WithOverride); err != nil { return fmt.Errorf("failed to overlay config: %w", err) } + c.walkScheme(c) log.WithField("path", path).Debug("Loaded config") return nil } -func FromEnv() error { +func (c *Config) fromEnv() error { nc := Config{} _, err := env.UnmarshalFromEnviron(&nc) if err != nil { return fmt.Errorf("failed to load environment variables: %w", err) } - if err := mergo.Merge(&G, nc, mergo.WithOverride); err != nil { + if err := mergo.Merge(c, nc, mergo.WithOverride); err != nil { return fmt.Errorf("failed to overlay config: %w", err) } + c.walkScheme(c) log.Debug("Loaded config from environment") return nil } -func ConfigureLogger() { - switch strings.ToLower(G.LogLevel) { +func (c *Config) walkScheme(v interface{}) { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return + } + + rv = rv.Elem() + if rv.Kind() != reflect.Struct { + return + } + + t := rv.Type() + for i := 0; i < rv.NumField(); i++ { + valueField := rv.Field(i) + switch valueField.Kind() { + case reflect.Struct: + if !valueField.Addr().CanInterface() { + continue + } + + iface := valueField.Addr().Interface() + c.walkScheme(iface) + } + + typeField := t.Field(i) + if typeField.Type.Kind() != reflect.String { + continue + } + valueField.SetString(c.parseScheme(valueField.String())) + } +} + +func (c *Config) parseScheme(rawVal string) string { + u, err := url.Parse(rawVal) + if err != nil { + return rawVal + } + if u.Scheme == "env" { + e, ok := os.LookupEnv(u.Host) + if ok { + return e + } + return u.RawQuery + } else if u.Scheme == "file" { + d, err := ioutil.ReadFile(u.Path) + if err != nil { + return u.RawQuery + } + return string(d) + } + return rawVal +} + +func (c *Config) configureLogger() { + switch strings.ToLower(c.LogLevel) { case "trace": log.SetLevel(log.TraceLevel) case "debug": @@ -83,7 +159,7 @@ func ConfigureLogger() { log.FieldKeyTime: "timestamp", } - if G.Debug { + if c.Debug { log.SetFormatter(&log.TextFormatter{FieldMap: fm}) } else { log.SetFormatter(&log.JSONFormatter{FieldMap: fm, DisableHTMLEscape: true}) diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 000000000..012bf3917 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,39 @@ +package config + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigEnv(t *testing.T) { + os.Setenv("AUTHENTIK_SECRET_KEY", "bar") + cfg = nil + Get().fromEnv() + assert.Equal(t, "bar", Get().SecretKey) +} + +func TestConfigEnv_Scheme(t *testing.T) { + os.Setenv("foo", "bar") + os.Setenv("AUTHENTIK_SECRET_KEY", "env://foo") + cfg = nil + Get().fromEnv() + assert.Equal(t, "bar", Get().SecretKey) +} + +func TestConfigEnv_File(t *testing.T) { + file, err := os.CreateTemp("", "") + if err != nil { + log.Fatal(err) + } + defer os.Remove(file.Name()) + file.Write([]byte("bar")) + + os.Setenv("AUTHENTIK_SECRET_KEY", fmt.Sprintf("file://%s", file.Name())) + cfg = nil + Get().fromEnv() + assert.Equal(t, "bar", Get().SecretKey) +} diff --git a/internal/gounicorn/gounicorn.go b/internal/gounicorn/gounicorn.go index 15d14fcf8..4f593550a 100644 --- a/internal/gounicorn/gounicorn.go +++ b/internal/gounicorn/gounicorn.go @@ -39,7 +39,7 @@ func NewGoUnicorn() *GoUnicorn { func (g *GoUnicorn) initCmd() { command := "gunicorn" args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"} - if config.G.Debug { + if config.Get().Debug { command = "./manage.py" args = []string{"runserver"} } diff --git a/internal/outpost/proxyv2/application/session.go b/internal/outpost/proxyv2/application/session.go index 78c3fb75c..b46906594 100644 --- a/internal/outpost/proxyv2/application/session.go +++ b/internal/outpost/proxyv2/application/session.go @@ -15,8 +15,8 @@ import ( func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL) sessions.Store { var store sessions.Store - if config.G.Redis.Host != "" { - rs, err := redistore.NewRediStoreWithDB(10, "tcp", fmt.Sprintf("%s:%d", config.G.Redis.Host, config.G.Redis.Port), config.G.Redis.Password, strconv.Itoa(config.G.Redis.OutpostSessionDB), []byte(*p.CookieSecret)) + if config.Get().Redis.Host != "" { + rs, err := redistore.NewRediStoreWithDB(10, "tcp", fmt.Sprintf("%s:%d", config.Get().Redis.Host, config.Get().Redis.Port), config.Get().Redis.Password, strconv.Itoa(config.Get().Redis.OutpostSessionDB), []byte(*p.CookieSecret)) if err != nil { panic(err) } diff --git a/internal/web/metrics.go b/internal/web/metrics.go index b102aab64..2831abd4f 100644 --- a/internal/web/metrics.go +++ b/internal/web/metrics.go @@ -37,7 +37,7 @@ func RunMetricsServer() { l.WithError(err).Warning("failed to get upstream metrics") return } - re.SetBasicAuth("monitor", config.G.SecretKey) + re.SetBasicAuth("monitor", config.Get().SecretKey) res, err := http.DefaultClient.Do(re) if err != nil { l.WithError(err).Warning("failed to get upstream metrics") @@ -54,10 +54,10 @@ func RunMetricsServer() { return } }) - l.WithField("listen", config.G.Web.ListenMetrics).Info("Starting Metrics server") - err := http.ListenAndServe(config.G.Web.ListenMetrics, m) + l.WithField("listen", config.Get().Web.ListenMetrics).Info("Starting Metrics server") + err := http.ListenAndServe(config.Get().Web.ListenMetrics, m) if err != nil { l.WithError(err).Warning("Failed to start metrics server") } - l.WithField("listen", config.G.Web.ListenMetrics).Info("Stopping Metrics server") + l.WithField("listen", config.Get().Web.ListenMetrics).Info("Stopping Metrics server") } diff --git a/internal/web/sentry_proxy.go b/internal/web/sentry_proxy.go index 92d6e46fc..328307608 100644 --- a/internal/web/sentry_proxy.go +++ b/internal/web/sentry_proxy.go @@ -14,7 +14,7 @@ type SentryRequest struct { } func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) { - if !config.G.ErrorReporting.Enabled { + if !config.Get().ErrorReporting.Enabled { ws.log.Debug("error reporting disabled") rw.WriteHeader(http.StatusBadRequest) return @@ -37,8 +37,8 @@ func (ws *WebServer) APISentryProxy(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(http.StatusBadRequest) return } - if sd.DSN != config.G.ErrorReporting.DSN { - ws.log.WithField("have", sd.DSN).WithField("expected", config.G.ErrorReporting.DSN).Debug("invalid DSN") + if sd.DSN != config.Get().ErrorReporting.DSN { + ws.log.WithField("have", sd.DSN).WithField("expected", config.Get().ErrorReporting.DSN).Debug("invalid DSN") rw.WriteHeader(http.StatusBadRequest) return } diff --git a/internal/web/static.go b/internal/web/static.go index 7e7cf6772..534a839b0 100644 --- a/internal/web/static.go +++ b/internal/web/static.go @@ -17,7 +17,7 @@ func (ws *WebServer) configureStatic() { indexLessRouter := statRouter.NewRoute().Subrouter() indexLessRouter.Use(web.DisableIndex) // Media files, always local - fs := http.FileServer(http.Dir(config.G.Paths.Media)) + fs := http.FileServer(http.Dir(config.Get().Paths.Media)) distFs := http.FileServer(http.Dir("./web/dist")) distHandler := http.StripPrefix("/static/dist/", distFs) authentikHandler := http.StripPrefix("/static/authentik/", http.FileServer(http.Dir("./web/authentik"))) diff --git a/internal/web/tls.go b/internal/web/tls.go index 18dce9d1b..38f6c93cd 100644 --- a/internal/web/tls.go +++ b/internal/web/tls.go @@ -41,7 +41,7 @@ func (ws *WebServer) listenTLS() { GetCertificate: ws.GetCertificate(), } - ln, err := net.Listen("tcp", config.G.Web.ListenTLS) + ln, err := net.Listen("tcp", config.Get().Web.ListenTLS) if err != nil { ws.log.WithError(err).Fatalf("failed to listen (TLS)") return @@ -50,7 +50,7 @@ func (ws *WebServer) listenTLS() { defer proxyListener.Close() tlsListener := tls.NewListener(proxyListener, tlsConfig) - ws.log.WithField("listen", config.G.Web.ListenTLS).Info("Starting HTTPS server") + ws.log.WithField("listen", config.Get().Web.ListenTLS).Info("Starting HTTPS server") ws.serve(tlsListener) - ws.log.WithField("listen", config.G.Web.ListenTLS).Info("Stopping HTTPS server") + ws.log.WithField("listen", config.Get().Web.ListenTLS).Info("Stopping HTTPS server") } diff --git a/internal/web/web.go b/internal/web/web.go index 79fd4fdf3..50198dd6e 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -68,16 +68,16 @@ func (ws *WebServer) Shutdown() { } func (ws *WebServer) listenPlain() { - ln, err := net.Listen("tcp", config.G.Web.Listen) + ln, err := net.Listen("tcp", config.Get().Web.Listen) if err != nil { ws.log.WithError(err).Fatal("failed to listen") } proxyListener := &proxyproto.Listener{Listener: ln} defer proxyListener.Close() - ws.log.WithField("listen", config.G.Web.Listen).Info("Starting HTTP server") + ws.log.WithField("listen", config.Get().Web.Listen).Info("Starting HTTP server") ws.serve(proxyListener) - ws.log.WithField("listen", config.G.Web.Listen).Info("Stopping HTTP server") + ws.log.WithField("listen", config.Get().Web.Listen).Info("Stopping HTTP server") } func (ws *WebServer) serve(listener net.Listener) {