root: replace boj/redistore with vendored version of rbcervilla/redisstore (#6988)

* root: replace boj/redistore with vendored version of rbcervilla/redisstore

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* setup env for go tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-09-26 18:56:37 +02:00 committed by GitHub
parent 90aa5409cd
commit c93c6ee6f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 418 additions and 38 deletions

View file

@ -39,6 +39,8 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Setup authentik env
uses: ./.github/actions/setup
- name: Generate API - name: Generate API
run: make gen-client-go run: make gen-client-go
- name: Go unittests - name: Go unittests

4
go.mod
View file

@ -6,7 +6,6 @@ require (
beryju.io/ldap v0.1.0 beryju.io/ldap v0.1.0
github.com/Netflix/go-env v0.0.0-20210215222557-e437a7e7f9fb github.com/Netflix/go-env v0.0.0-20210215222557-e437a7e7f9fb
github.com/coreos/go-oidc v2.2.1+incompatible github.com/coreos/go-oidc v2.2.1+incompatible
github.com/garyburd/redigo v1.6.4
github.com/getsentry/sentry-go v0.24.1 github.com/getsentry/sentry-go v0.24.1
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.6 github.com/go-ldap/ldap/v3 v3.4.6
@ -23,6 +22,7 @@ require (
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484 github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
github.com/pires/go-proxyproto v0.7.0 github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.16.0 github.com/prometheus/client_golang v1.16.0
github.com/redis/go-redis/v9 v9.2.0
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
@ -30,7 +30,6 @@ require (
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.12.0 golang.org/x/oauth2 v0.12.0
golang.org/x/sync v0.3.0 golang.org/x/sync v0.3.0
gopkg.in/boj/redistore.v1 v1.0.0-20160128113310-fc113767cd6b
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
layeh.com/radius v0.0.0-20210819152912-ad72663a72ab layeh.com/radius v0.0.0-20210819152912-ad72663a72ab
) )
@ -41,6 +40,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/felixge/httpsnoop v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.1 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.5 // indirect github.com/go-asn1-ber/asn1-ber v1.5.5 // indirect
github.com/go-http-utils/fresh v0.0.0-20161124030543-7231e26a4b27 // indirect github.com/go-http-utils/fresh v0.0.0-20161124030543-7231e26a4b27 // indirect

10
go.sum
View file

@ -48,6 +48,8 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@ -63,14 +65,14 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/garyburd/redigo v1.6.4 h1:LFu2R3+ZOPgSMWMOL+saa/zXRjw0ID2G8FepO53BGlg=
github.com/garyburd/redigo v1.6.4/go.mod h1:rTb6epsqigu3kYKBnaF028A7Tf/Aw5s0cqA47doKKqw=
github.com/getsentry/sentry-go v0.24.1 h1:W6/0GyTy8J6ge6lVCc94WB6Gx2ZuLrgopnn9w8Hiwuk= github.com/getsentry/sentry-go v0.24.1 h1:W6/0GyTy8J6ge6lVCc94WB6Gx2ZuLrgopnn9w8Hiwuk=
github.com/getsentry/sentry-go v0.24.1/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/getsentry/sentry-go v0.24.1/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY=
github.com/go-asn1-ber/asn1-ber v1.5.5 h1:MNHlNMBDgEKD4TcKr36vQN68BA00aDfjIt3/bD50WnA= github.com/go-asn1-ber/asn1-ber v1.5.5 h1:MNHlNMBDgEKD4TcKr36vQN68BA00aDfjIt3/bD50WnA=
@ -286,6 +288,8 @@ github.com/prometheus/common v0.42.0 h1:EKsfXEYo4JpWMHH5cg+KOUWeuJSov1Id8zGR8eeI
github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc= github.com/prometheus/common v0.42.0/go.mod h1:xBwqVerjNdUDjgODMpudtOMwlOwf2SaTr1yjz4b7Zbc=
github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg= github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+PymziUAg=
github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM=
github.com/redis/go-redis/v9 v9.2.0 h1:zwMdX0A4eVzse46YN18QhuDiM4uf3JmkOB4VZrdt5uI=
github.com/redis/go-redis/v9 v9.2.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -638,8 +642,6 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/boj/redistore.v1 v1.0.0-20160128113310-fc113767cd6b h1:U/Uqd1232+wrnHOvWNaxrNqn/kFnr4yu4blgPtQt0N8=
gopkg.in/boj/redistore.v1 v1.0.0-20160128113310-fc113767cd6b/go.mod h1:fgfIZMlsafAHpspcks2Bul+MWUNw/2dyQmjC2faKjtg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View file

@ -280,7 +280,7 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
"id_token_hint": []string{cc.RawToken}, "id_token_hint": []string{cc.RawToken},
} }
redirect += "?" + uv.Encode() redirect += "?" + uv.Encode()
err = a.Logout(cc.Sub) err = a.Logout(r.Context(), cc.Sub)
if err != nil { if err != nil {
a.log.WithError(err).Warning("failed to logout of other sessions") a.log.WithError(err).Warning("failed to logout of other sessions")
} }

View file

@ -1,23 +1,23 @@
package application package application
import ( import (
"context"
"fmt" "fmt"
"math" "math"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path" "path"
"strconv"
"strings" "strings"
"github.com/garyburd/redigo/redis"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/redis/go-redis/v9"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/config" "goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/proxyv2/codecs" "goauthentik.io/internal/outpost/proxyv2/codecs"
"goauthentik.io/internal/outpost/proxyv2/constants" "goauthentik.io/internal/outpost/proxyv2/constants"
"gopkg.in/boj/redistore.v1" "goauthentik.io/internal/outpost/proxyv2/redisstore"
) )
const RedisKeyPrefix = "authentik_proxy_session_" const RedisKeyPrefix = "authentik_proxy_session_"
@ -30,20 +30,26 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
maxAge = int(*t) + 1 maxAge = int(*t) + 1
} }
if a.isEmbedded { if a.isEmbedded {
rs, err := redistore.NewRediStoreWithDB(10, "tcp", fmt.Sprintf("%s:%d", config.Get().Redis.Host, config.Get().Redis.Port), config.Get().Redis.Password, strconv.Itoa(config.Get().Redis.DB)) client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Get().Redis.Host, config.Get().Redis.Port),
// Username: config.Get().Redis.Password,
Password: config.Get().Redis.Password,
DB: config.Get().Redis.DB,
})
// New default RedisStore
rs, err := redisstore.NewRedisStore(context.Background(), client)
if err != nil { if err != nil {
panic(err) panic(err)
} }
rs.Codecs = codecs.CodecsFromPairs(maxAge, []byte(*p.CookieSecret))
rs.SetMaxLength(math.MaxInt)
rs.SetKeyPrefix(RedisKeyPrefix)
rs.Options.HttpOnly = true rs.KeyPrefix(RedisKeyPrefix)
if strings.ToLower(externalHost.Scheme) == "https" { rs.Options(sessions.Options{
rs.Options.Secure = true HttpOnly: strings.ToLower(externalHost.Scheme) == "https",
} Domain: *p.CookieDomain,
rs.Options.Domain = *p.CookieDomain SameSite: http.SameSiteLaxMode,
rs.Options.SameSite = http.SameSiteLaxMode })
a.log.Trace("using redis session backend") a.log.Trace("using redis session backend")
return rs return rs
} }
@ -80,7 +86,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec {
return cs return cs
} }
func (a *Application) Logout(sub string) error { func (a *Application) Logout(ctx context.Context, sub string) error {
if _, ok := a.sessions.(*sessions.FilesystemStore); ok { if _, ok := a.sessions.(*sessions.FilesystemStore); ok {
files, err := os.ReadDir(os.TempDir()) files, err := os.ReadDir(os.TempDir())
if err != nil { if err != nil {
@ -120,31 +126,22 @@ func (a *Application) Logout(sub string) error {
} }
} }
} }
if rs, ok := a.sessions.(*redistore.RediStore); ok { if rs, ok := a.sessions.(*redisstore.RedisStore); ok {
pool := rs.Pool.Get() client := rs.Client()
defer pool.Close() defer client.Close()
rep, err := pool.Do("KEYS", fmt.Sprintf("%s*", RedisKeyPrefix)) keys, err := client.Keys(ctx, fmt.Sprintf("%s*", RedisKeyPrefix)).Result()
if err != nil { if err != nil {
return err return err
} }
keys, err := redis.Strings(rep, err) serializer := redisstore.GobSerializer{}
if err != nil {
return err
}
serializer := redistore.GobSerializer{}
for _, key := range keys { for _, key := range keys {
v, err := pool.Do("GET", key) v, err := client.Get(ctx, key).Result()
if err != nil { if err != nil {
a.log.WithError(err).Warning("failed to get value") a.log.WithError(err).Warning("failed to get value")
continue continue
} }
b, err := redis.Bytes(v, err)
if err != nil {
a.log.WithError(err).Warning("failed to load value")
continue
}
s := sessions.Session{} s := sessions.Session{}
err = serializer.Deserialize(b, &s) err = serializer.Deserialize([]byte(v), &s)
if err != nil { if err != nil {
a.log.WithError(err).Warning("failed to deserialize") a.log.WithError(err).Warning("failed to deserialize")
continue continue
@ -156,7 +153,7 @@ func (a *Application) Logout(sub string) error {
claims := c.(Claims) claims := c.(Claims)
if claims.Sub == sub { if claims.Sub == sub {
a.log.WithField("key", key).Trace("deleting session") a.log.WithField("key", key).Trace("deleting session")
_, err := pool.Do("DEL", key) _, err := client.Del(ctx, key).Result()
if err != nil { if err != nil {
a.log.WithError(err).Warning("failed to delete key") a.log.WithError(err).Warning("failed to delete key")
continue continue

View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 Ruben Cervilla
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,200 @@
package redisstore
import (
"bytes"
"context"
"crypto/rand"
"encoding/base32"
"encoding/gob"
"errors"
"io"
"net/http"
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/redis/go-redis/v9"
)
// RedisStore stores gorilla sessions in Redis
type RedisStore struct {
// client to connect to redis
client redis.UniversalClient
// default options to use when a new session is created
options sessions.Options
// key prefix with which the session will be stored
keyPrefix string
// key generator
keyGen KeyGenFunc
// session serializer
serializer SessionSerializer
}
// KeyGenFunc defines a function used by store to generate a key
type KeyGenFunc func() (string, error)
// NewRedisStore returns a new RedisStore with default configuration
func NewRedisStore(ctx context.Context, client redis.UniversalClient) (*RedisStore, error) {
rs := &RedisStore{
options: sessions.Options{
Path: "/",
MaxAge: 86400 * 30,
},
client: client,
keyPrefix: "session:",
keyGen: generateRandomKey,
serializer: GobSerializer{},
}
return rs, rs.client.Ping(ctx).Err()
}
func (s *RedisStore) Client() redis.UniversalClient {
return s.client
}
// Get returns a session for the given name after adding it to the registry.
func (s *RedisStore) Get(r *http.Request, name string) (*sessions.Session, error) {
return sessions.GetRegistry(r).Get(s, name)
}
// New returns a session for the given name without adding it to the registry.
func (s *RedisStore) New(r *http.Request, name string) (*sessions.Session, error) {
session := sessions.NewSession(s, name)
opts := s.options
session.Options = &opts
session.IsNew = true
c, err := r.Cookie(name)
if err != nil {
return session, nil
}
session.ID = c.Value
err = s.load(r.Context(), session)
if err == nil {
session.IsNew = false
} else if err == redis.Nil {
err = nil // no data stored
}
return session, err
}
// Save adds a single session to the response.
//
// If the Options.MaxAge of the session is <= 0 then the session file will be
// deleted from the store. With this process it enforces the properly
// session cookie handling so no need to trust in the cookie management in the
// web browser.
func (s *RedisStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
// Delete if max-age is <= 0
if session.Options.MaxAge <= 0 {
if err := s.delete(r.Context(), session); err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
return nil
}
if session.ID == "" {
id, err := s.keyGen()
if err != nil {
return errors.New("redisstore: failed to generate session id")
}
session.ID = id
}
if err := s.save(r.Context(), session); err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options))
return nil
}
// Options set options to use when a new session is created
func (s *RedisStore) Options(opts sessions.Options) {
s.options = opts
}
// KeyPrefix sets the key prefix to store session in Redis
func (s *RedisStore) KeyPrefix(keyPrefix string) {
s.keyPrefix = keyPrefix
}
// KeyGen sets the key generator function
func (s *RedisStore) KeyGen(f KeyGenFunc) {
s.keyGen = f
}
// Serializer sets the session serializer to store session
func (s *RedisStore) Serializer(ss SessionSerializer) {
s.serializer = ss
}
// Close closes the Redis store
func (s *RedisStore) Close() error {
return s.client.Close()
}
// save writes session in Redis
func (s *RedisStore) save(ctx context.Context, session *sessions.Session) error {
b, err := s.serializer.Serialize(session)
if err != nil {
return err
}
return s.client.Set(ctx, s.keyPrefix+session.ID, b, time.Duration(session.Options.MaxAge)*time.Second).Err()
}
// load reads session from Redis
func (s *RedisStore) load(ctx context.Context, session *sessions.Session) error {
cmd := s.client.Get(ctx, s.keyPrefix+session.ID)
if cmd.Err() != nil {
return cmd.Err()
}
b, err := cmd.Bytes()
if err != nil {
return err
}
return s.serializer.Deserialize(b, session)
}
// delete deletes session in Redis
func (s *RedisStore) delete(ctx context.Context, session *sessions.Session) error {
return s.client.Del(ctx, s.keyPrefix+session.ID).Err()
}
// SessionSerializer provides an interface for serialize/deserialize a session
type SessionSerializer interface {
Serialize(s *sessions.Session) ([]byte, error)
Deserialize(b []byte, s *sessions.Session) error
}
// Gob serializer
type GobSerializer struct{}
func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(s.Values)
if err == nil {
return buf.Bytes(), nil
}
return nil, err
}
func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error {
dec := gob.NewDecoder(bytes.NewBuffer(d))
return dec.Decode(&s.Values)
}
// generateRandomKey returns a new random key
func generateRandomKey() (string, error) {
k := make([]byte, 64)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return "", err
}
return strings.TrimRight(base32.StdEncoding.EncodeToString(k), "="), nil
}

View file

@ -0,0 +1,158 @@
package redisstore
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/sessions"
"github.com/redis/go-redis/v9"
)
const (
redisAddr = "localhost:6379"
)
func TestNew(t *testing.T) {
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
store, err := NewRedisStore(context.Background(), client)
if err != nil {
t.Fatal("failed to create redis store", err)
}
req, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
t.Fatal("failed to create request", err)
}
session, err := store.New(req, "hello")
if err != nil {
t.Fatal("failed to create session", err)
}
if session.IsNew == false {
t.Fatal("session is not new")
}
}
func TestOptions(t *testing.T) {
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
store, err := NewRedisStore(context.Background(), client)
if err != nil {
t.Fatal("failed to create redis store", err)
}
opts := sessions.Options{
Path: "/path",
MaxAge: 99999,
}
store.Options(opts)
req, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
t.Fatal("failed to create request", err)
}
session, err := store.New(req, "hello")
if err != nil {
t.Fatal("failed to create store", err)
}
if session.Options.Path != opts.Path || session.Options.MaxAge != opts.MaxAge {
t.Fatal("failed to set options")
}
}
func TestSave(t *testing.T) {
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
store, err := NewRedisStore(context.Background(), client)
if err != nil {
t.Fatal("failed to create redis store", err)
}
req, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
t.Fatal("failed to create request", err)
}
w := httptest.NewRecorder()
session, err := store.New(req, "hello")
if err != nil {
t.Fatal("failed to create session", err)
}
session.Values["key"] = "value"
err = session.Save(req, w)
if err != nil {
t.Fatal("failed to save: ", err)
}
}
func TestDelete(t *testing.T) {
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
store, err := NewRedisStore(context.Background(), client)
if err != nil {
t.Fatal("failed to create redis store", err)
}
req, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
t.Fatal("failed to create request", err)
}
w := httptest.NewRecorder()
session, err := store.New(req, "hello")
if err != nil {
t.Fatal("failed to create session", err)
}
session.Values["key"] = "value"
err = session.Save(req, w)
if err != nil {
t.Fatal("failed to save session: ", err)
}
session.Options.MaxAge = -1
err = session.Save(req, w)
if err != nil {
t.Fatal("failed to delete session: ", err)
}
}
func TestClose(t *testing.T) {
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
})
cmd := client.Ping(context.Background())
err := cmd.Err()
if err != nil {
t.Fatal("connection is not opened")
}
store, err := NewRedisStore(context.Background(), client)
if err != nil {
t.Fatal("failed to create redis store", err)
}
err = store.Close()
if err != nil {
t.Fatal("failed to close")
}
cmd = client.Ping(context.Background())
if cmd.Err() == nil {
t.Fatal("connection is properly closed")
}
}