diff --git a/cmd/server/server.go b/cmd/server/server.go index f80c98544..8e3c3f28b 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -9,6 +9,7 @@ import ( "github.com/getsentry/sentry-go" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "goauthentik.io/internal/common" "goauthentik.io/internal/config" "goauthentik.io/internal/constants" diff --git a/internal/gounicorn/gounicorn.go b/internal/gounicorn/gounicorn.go index b07a8c61a..0f328c943 100644 --- a/internal/gounicorn/gounicorn.go +++ b/internal/gounicorn/gounicorn.go @@ -1,14 +1,20 @@ package gounicorn import ( + "fmt" "os" "os/exec" + "os/signal" "runtime" + "strconv" + "strings" "syscall" "time" log "github.com/sirupsen/logrus" + "goauthentik.io/internal/config" + "goauthentik.io/internal/utils" ) type GoUnicorn struct { @@ -17,6 +23,7 @@ type GoUnicorn struct { log *log.Entry p *exec.Cmd + pidFile string started bool killed bool alive bool @@ -33,15 +40,36 @@ func New(healthcheck func() bool) *GoUnicorn { HealthyCallback: func() {}, } g.initCmd() + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP, syscall.SIGUSR2) + go func() { + for sig := range c { + if sig == syscall.SIGHUP { + g.log.Info("SIGHUP received, forwarding to gunicorn") + g.Reload() + } else if sig == syscall.SIGUSR2 { + g.log.Info("SIGUSR2 received, restarting gunicorn") + g.Restart() + } + } + }() return g } func (g *GoUnicorn) initCmd() { - command := "gunicorn" - args := []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"} - if config.Get().Debug { - command = "./manage.py" - args = []string{"dev_server"} + command := "./manage.py" + args := []string{"dev_server"} + if !config.Get().Debug { + pidFile, err := os.CreateTemp("", "authentik-gunicorn.*.pid") + if err != nil { + panic(fmt.Errorf("failed to create temporary pid file: %v", err)) + } + g.pidFile = pidFile.Name() + command = "gunicorn" + args = []string{"-c", "./lifecycle/gunicorn.conf.py", "authentik.root.asgi:application"} + if g.pidFile != "" { + args = append(args, "--pid", g.pidFile) + } } g.log.WithField("args", args).WithField("cmd", command).Debug("Starting gunicorn") g.p = exec.Command(command, args...) @@ -55,13 +83,10 @@ func (g *GoUnicorn) IsRunning() bool { } func (g *GoUnicorn) Start() error { - if g.killed { - g.log.Debug("Not restarting gunicorn since we're shutdown") - return nil - } if g.started { g.initCmd() } + g.killed = false g.started = true go g.healthcheck() return g.p.Run() @@ -85,8 +110,76 @@ func (g *GoUnicorn) healthcheck() { } } +func (g *GoUnicorn) Reload() { + g.log.WithField("method", "reload").Info("reloading gunicorn") + err := g.p.Process.Signal(syscall.SIGHUP) + if err != nil { + g.log.WithError(err).Warning("failed to reload gunicorn") + } +} + +func (g *GoUnicorn) Restart() { + g.log.WithField("method", "restart").Info("restart gunicorn") + if g.pidFile == "" { + g.log.Warning("pidfile is non existent, cannot restart") + return + } + + err := g.p.Process.Signal(syscall.SIGUSR2) + if err != nil { + g.log.WithError(err).Warning("failed to restart gunicorn") + return + } + + newPidFile := fmt.Sprintf("%s.2", g.pidFile) + + // Wait for the new PID file to be created + for range time.NewTicker(1 * time.Second).C { + _, err = os.Stat(newPidFile) + if err == nil || !os.IsNotExist(err) { + break + } + g.log.Debugf("waiting for new gunicorn pidfile to appear at %s", newPidFile) + } + if err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + + newPidB, err := os.ReadFile(newPidFile) + if err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + newPidS := strings.TrimSpace(string(newPidB[:])) + newPid, err := strconv.Atoi(newPidS) + if err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + g.log.Warningf("new gunicorn PID is %d", newPid) + + newProcess, err := utils.FindProcess(newPid) + if newProcess == nil || err != nil { + g.log.WithError(err).Warning("failed to find the new gunicorn process, aborting") + return + } + + // The new process has started, let's gracefully kill the old one + g.log.Warning("killing old gunicorn") + err = g.p.Process.Signal(syscall.SIGTERM) + if err != nil { + g.log.Warning("failed to kill old instance of gunicorn") + } + + g.p.Process = newProcess + // No need to close any files and the .2 pid file is deleted by Gunicorn +} + func (g *GoUnicorn) Kill() { - g.killed = true + if !g.started { + return + } var err error if runtime.GOOS == "darwin" { g.log.WithField("method", "kill").Warning("stopping gunicorn") @@ -98,4 +191,11 @@ func (g *GoUnicorn) Kill() { if err != nil { g.log.WithError(err).Warning("failed to stop gunicorn") } + if g.pidFile != "" { + err := os.Remove(g.pidFile) + if err != nil { + g.log.WithError(err).Warning("failed to remove pidfile") + } + } + g.killed = true } diff --git a/internal/utils/process.go b/internal/utils/process.go new file mode 100644 index 000000000..366d0d1c4 --- /dev/null +++ b/internal/utils/process.go @@ -0,0 +1,39 @@ +package utils + +import ( + "errors" + "fmt" + "os" + "syscall" +) + +func FindProcess(pid int) (*os.Process, error) { + if pid <= 0 { + return nil, fmt.Errorf("invalid pid %v", pid) + } + // The error doesn't mean anything on Unix systems, let's just check manually + // that the new gunicorn master has properly started + // https://github.com/golang/go/issues/34396 + proc, err := os.FindProcess(pid) + if err != nil { + return nil, err + } + err = proc.Signal(syscall.Signal(0)) + if err == nil { + return proc, nil + } + if errors.Is(err, os.ErrProcessDone) { + return nil, nil + } + errno, ok := err.(syscall.Errno) + if !ok { + return nil, err + } + switch errno { + case syscall.ESRCH: + return nil, nil + case syscall.EPERM: + return proc, nil + } + return nil, err +} diff --git a/internal/web/web.go b/internal/web/web.go index e0b3a749d..ea7ac1746 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "path" + "time" "github.com/gorilla/handlers" "github.com/gorilla/mux" @@ -109,6 +110,21 @@ func (ws *WebServer) attemptStartBackend() { } err := ws.g.Start() log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting") + if err != nil { + log.WithField("logger", "authentik.router").WithError(err).Error("gunicorn failed to start, restarting") + continue + } + failedChecks := 0 + for range time.NewTicker(30 * time.Second).C { + if !ws.g.IsRunning() { + log.WithField("logger", "authentik.router").Warningf("gunicorn process failed healthcheck %d times", failedChecks) + failedChecks += 1 + } + if failedChecks >= 3 { + log.WithField("logger", "authentik.router").WithError(err).Error("gunicorn process failed healthcheck three times, restarting") + break + } + } } }