From c18917e3d11b01dd0781bc10e2cdbd0cfe16ab60 Mon Sep 17 00:00:00 2001
From: Sylvester Chin
Date: Tue, 18 Jul 2023 09:18:03 +0800
Subject: [PATCH 1/4] Switch to go-redis client

---
 .gitlab/ci/workhorse.gitlab-ci.yml          |   4 +
 workhorse/go.mod                            |   5 +-
 workhorse/go.sum                            |  13 +-
 workhorse/internal/builds/register.go       |   9 +-
 workhorse/internal/builds/register_test.go  |   3 +-
 workhorse/internal/redis/keywatcher.go      |  84 +++---
 workhorse/internal/redis/keywatcher_test.go | 133 +++++----
 workhorse/internal/redis/redis.go           | 269 +++++---------------
 workhorse/internal/redis/redis_test.go      | 221 +---------------
 workhorse/main.go                           |   6 +-
 10 files changed, 203 insertions(+), 544 deletions(-)

diff --git a/.gitlab/ci/workhorse.gitlab-ci.yml b/.gitlab/ci/workhorse.gitlab-ci.yml
index 00c4dc6c9a9697..283d2e654313a7 100644
--- a/.gitlab/ci/workhorse.gitlab-ci.yml
+++ b/.gitlab/ci/workhorse.gitlab-ci.yml
@@ -11,6 +11,8 @@ workhorse:verify:
 .workhorse:test:
   extends: .workhorse:rules:workhorse
   image: ${REGISTRY_HOST}/${REGISTRY_GROUP}/gitlab-build-images/debian-${DEBIAN_VERSION}-ruby-${RUBY_VERSION}-golang-${GO_VERSION}-rust-${RUST_VERSION}:rubygems-${RUBYGEMS_VERSION}-git-2.36-exiftool-12.60
+  services:
+    - name: redis:7.0-alpine
   variables:
     GITALY_ADDRESS: "tcp://127.0.0.1:8075"
   stage: test
@@ -22,6 +24,8 @@ workhorse:verify:
     - bundle_install_script
     - go version
     - scripts/gitaly-test-build
+    - cp workhorse/config.toml.example workhorse/config.toml
+    - sed -i 's|URL.*$|URL = "redis://redis:6379"|g' workhorse/config.toml
   script:
     - make -C workhorse test
diff --git a/workhorse/go.mod b/workhorse/go.mod
index f5933f7efef303..adedafb4e83839 100644
--- a/workhorse/go.mod
+++ b/workhorse/go.mod
@@ -5,7 +5,6 @@ go 1.18
 require (
 	github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0
 	github.com/BurntSushi/toml v1.3.2
-	github.com/FZambia/sentinel v1.1.1
 	github.com/alecthomas/chroma/v2 v2.8.0
 	github.com/aws/aws-sdk-go v1.44.284
 	github.com/disintegration/imaging v1.6.2
@@ -13,7 +12,6 @@ require (
 	github.com/golang-jwt/jwt/v5 v5.0.0
 	github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f
 	github.com/golang/protobuf v1.5.3
-	github.com/gomodule/redigo v2.0.0+incompatible
 	github.com/gorilla/websocket v1.5.0
 	github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
 	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
@@ -21,7 +19,7 @@ require (
 	github.com/jpillora/backoff v1.0.0
 	github.com/mitchellh/copystructure v1.2.0
 	github.com/prometheus/client_golang v1.16.0
-	github.com/rafaeljusto/redigomock/v3 v3.1.2
+	github.com/redis/go-redis/v9 v9.0.5
 	github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a
 	github.com/sirupsen/logrus v1.9.3
 	github.com/smartystreets/goconvey v1.7.2
@@ -65,6 +63,7 @@ require (
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/client9/reopen v1.0.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
+	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 	github.com/dlclark/regexp2 v1.4.0 // indirect
 	github.com/go-ole/go-ole v1.2.6 // indirect
 	github.com/gogo/protobuf v1.3.2 // indirect
diff --git a/workhorse/go.sum b/workhorse/go.sum
index e39282f5c137d7..622073a56114f2 100644
--- a/workhorse/go.sum
+++ b/workhorse/go.sum
@@ -705,8 +705,6 @@ github.com/DataDog/datadog-go v4.4.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3
 github.com/DataDog/gostackparse v0.5.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM=
 github.com/DataDog/sketches-go v1.0.0 h1:chm5KSXO7kO+ywGWJ0Zs6tdmWU8PBXSbywFVciL6BG4=
 github.com/DataDog/sketches-go v1.0.0/go.mod 
h1:O+XkJHWk9w4hDwY2ZUDU31ZC9sNYlYo8DiFsxjYeo1k= -github.com/FZambia/sentinel v1.1.1 h1:0ovTimlR7Ldm+wR15GgO+8C2dt7kkn+tm3PQS+Qk3Ek= -github.com/FZambia/sentinel v1.1.1/go.mod h1:ytL1Am/RLlAoAXG6Kj5LNuw/TRRQrv2rt2FT26vP5gI= github.com/GoogleCloudPlatform/cloudsql-proxy v1.33.7/go.mod h1:JBp/RvKNOoIkR5BdMSXswBksHcPZ/41sbBV+GhSjgMY= github.com/HdrHistogram/hdrhistogram-go v1.1.0/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= @@ -866,6 +864,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk= +github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= +github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= @@ -1065,6 +1065,8 @@ github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba/go.mod h1:dV8l github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= github.com/dgrijalva/jwt-go v0.0.0-20170104182250-a601269ab70c/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dgryski/go-sip13 v0.0.0-20200911182023-62edffca9245/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/digitalocean/godo v1.78.0/go.mod h1:GBmu8MkjZmNARE7IXRPmkbbnocNN8+uBm0xbEVw2LCs= @@ -1367,9 +1369,6 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8l github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= -github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= -github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= @@ -2040,10 +2039,10 @@ github.com/prometheus/prometheus v0.35.0/go.mod h1:7HaLx5kEPKJ0GDgbODG0fZgXbQ8K/ github.com/prometheus/prometheus v0.44.0 h1:sgn8Fdx+uE5tHQn0/622swlk2XnIj6udoZCnbVjHIgc= 
github.com/prometheus/prometheus v0.44.0/go.mod h1:aPsmIK3py5XammeTguyqTmuqzX/jeCdyOWWobLHNKQg= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= -github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqzmDKyZbC99o= -github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM= github.com/rakyll/embedmd v0.0.0-20171029212350-c8060a0752a2/go.mod h1:7jOTMgqac46PZcF54q6l2hkLEG8op93fZu61KmxWDV4= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.0.5 h1:CuQcn5HIEeK7BgElubPP8CGtE0KakrnbBSTLjathl5o= +github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= diff --git a/workhorse/internal/builds/register.go b/workhorse/internal/builds/register.go index 0a2fe47ed7e348..f45d03ab8edb32 100644 --- a/workhorse/internal/builds/register.go +++ b/workhorse/internal/builds/register.go @@ -2,6 +2,7 @@ package builds import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -55,7 +56,7 @@ var ( type largeBodyError struct{ error } -type WatchKeyHandler func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) +type WatchKeyHandler func(ctx context.Context, key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) type runnerRequest struct { Token string `json:"token,omitempty"` @@ -102,11 +103,11 @@ func proxyRegisterRequest(h http.Handler, w http.ResponseWriter, r *http.Request h.ServeHTTP(w, r) } -func watchForRunnerChange(watchHandler WatchKeyHandler, token, lastUpdate string, duration time.Duration) (redis.WatchKeyStatus, error) { +func watchForRunnerChange(ctx context.Context, watchHandler WatchKeyHandler, token, lastUpdate string, duration time.Duration) (redis.WatchKeyStatus, error) { registerHandlerOpenAtWatching.Inc() defer registerHandlerOpenAtWatching.Dec() - return watchHandler(runnerBuildQueue+token, lastUpdate, duration) + return watchHandler(ctx, runnerBuildQueue+token, lastUpdate, duration) } func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDuration time.Duration) http.Handler { @@ -140,7 +141,7 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati return } - result, err := watchForRunnerChange(watchHandler, runnerRequest.Token, + result, err := watchForRunnerChange(r.Context(), watchHandler, runnerRequest.Token, runnerRequest.LastUpdate, pollingDuration) if err != nil { registerHandlerWatchErrors.Inc() diff --git a/workhorse/internal/builds/register_test.go b/workhorse/internal/builds/register_test.go index d5cbebd500bf6b..97d66517ac95be 100644 --- a/workhorse/internal/builds/register_test.go +++ b/workhorse/internal/builds/register_test.go @@ -2,6 +2,7 @@ package builds import ( "bytes" + "context" "errors" "io" "net/http" @@ -71,7 +72,7 @@ func TestRegisterHandlerMissingData(t *testing.T) { func expectWatcherToBeExecuted(t *testing.T, watchKeyStatus redis.WatchKeyStatus, watchKeyError error, httpStatus int, msgAndArgs ...interface{}) { executed := false - watchKeyHandler := func(key, value string, timeout time.Duration) (redis.WatchKeyStatus, 
error) { + watchKeyHandler := func(ctx context.Context, key, value string, timeout time.Duration) (redis.WatchKeyStatus, error) { executed = true return watchKeyStatus, watchKeyError } diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index cdf6ccd7e83d17..8618e2a015e11c 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -1,16 +1,17 @@ package redis import ( + "context" "errors" "fmt" "strings" "sync" "time" - "github.com/gomodule/redigo/redis" "github.com/jpillora/backoff" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + goredis "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" ) @@ -20,7 +21,8 @@ type KeyWatcher struct { subscribers map[string][]chan string shutdown chan struct{} reconnectBackoff backoff.Backoff - conn *redis.PubSubConn + redisConn *goredis.Client + conn *goredis.PubSub } func NewKeyWatcher() *KeyWatcher { @@ -73,12 +75,12 @@ const channelPrefix = "workhorse:notifications:" func countAction(action string) { totalActions.WithLabelValues(action).Add(1) } -func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { +func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *goredis.PubSub) error { kw.mu.Lock() // We must share kw.conn with the goroutines that call SUBSCRIBE and // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the // connection. - kw.conn = &redis.PubSubConn{Conn: conn} + kw.conn = pubsub kw.mu.Unlock() defer func() { @@ -99,51 +101,49 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { }() for { - switch v := kw.conn.Receive().(type) { - case redis.Message: + msg, err := kw.conn.Receive(ctx) + if err != nil { + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() + return nil + } + + switch msg := msg.(type) { + case *goredis.Subscription: + redisSubscriptions.Set(float64(msg.Count)) + case *goredis.Pong: + // Ignore. + case *goredis.Message: totalMessages.Inc() - receivedBytes.Add(float64(len(v.Data))) - if strings.HasPrefix(v.Channel, channelPrefix) { - kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data)) + receivedBytes.Add(float64(len(msg.Payload))) + if strings.HasPrefix(msg.Channel, channelPrefix) { + kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) } - case redis.Subscription: - redisSubscriptions.Set(float64(v.Count)) - case error: - log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error() - // Intermittent error, return nil so that it doesn't wait before reconnect + default: + log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() return nil } } } -func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) { - conn, err := dialer() - if err != nil { - return nil, err - } - - // Make sure Redis is actually connected - conn.Do("PING") - if err := conn.Err(); err != nil { - conn.Close() - return nil, err - } +func (kw *KeyWatcher) Process(client *goredis.Client) { + log.Info("keywatcher: starting process loop") - return conn, nil -} + ctx := context.Background() // lint:allow context.Background + kw.mu.Lock() + kw.redisConn = client + kw.mu.Unlock() -func (kw *KeyWatcher) Process() { - log.Info("keywatcher: starting process loop") for { - conn, err := dialPubSub(workerDialFunc) - if err != nil { + pubsub := client.Subscribe(ctx, []string{}...) 
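+		// Subscribing with no channel names still yields a usable *goredis.PubSub;
+		// the concrete channels are added later by addSubscription. The Ping
+		// below verifies the connection before entering the receive loop.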
+ if err := pubsub.Ping(ctx); err != nil { log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() time.Sleep(kw.reconnectBackoff.Duration()) continue } + kw.reconnectBackoff.Reset() - if err = kw.receivePubSubStream(conn); err != nil { + if err := kw.receivePubSubStream(ctx, pubsub); err != nil { log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() } } @@ -182,7 +182,7 @@ func (kw *KeyWatcher) notifySubscribers(key, value string) { } } -func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { +func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { kw.mu.Lock() defer kw.mu.Unlock() @@ -195,7 +195,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { if len(kw.subscribers[key]) == 0 { countAction("create-subscription") - if err := kw.conn.Subscribe(channelPrefix + key); err != nil { + if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { return err } } @@ -209,7 +209,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { return nil } -func (kw *KeyWatcher) delSubscription(key string, notify chan string) { +func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { kw.mu.Lock() defer kw.mu.Unlock() @@ -231,7 +231,7 @@ func (kw *KeyWatcher) delSubscription(key string, notify chan string) { delete(kw.subscribers, key) countAction("delete-subscription") if kw.conn != nil { - kw.conn.Unsubscribe(channelPrefix + key) + kw.conn.Unsubscribe(ctx, channelPrefix+key) } } } @@ -251,15 +251,15 @@ const ( WatchKeyStatusNoChange ) -func (kw *KeyWatcher) WatchKey(key, value string, timeout time.Duration) (WatchKeyStatus, error) { +func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { notify := make(chan string, 1) - if err := kw.addSubscription(key, notify); err != nil { + if err := kw.addSubscription(ctx, key, notify); err != nil { return WatchKeyStatusNoChange, err } - defer kw.delSubscription(key, notify) + defer kw.delSubscription(ctx, key, notify) - currentValue, err := GetString(key) - if errors.Is(err, redis.ErrNil) { + currentValue, err := kw.redisConn.Get(ctx, key).Result() + if errors.Is(err, goredis.Nil) { currentValue = "" } else if err != nil { return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go index bae49d81bb1893..bca4ca43a644a6 100644 --- a/workhorse/internal/redis/keywatcher_test.go +++ b/workhorse/internal/redis/keywatcher_test.go @@ -1,40 +1,27 @@ package redis import ( + "context" + "os" "sync" "testing" "time" - "github.com/gomodule/redigo/redis" - "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" ) +var ctx = context.Background() + const ( runnerKey = "runner:build_queue:10" ) -func createSubscriptionMessage(key, data string) []interface{} { - return []interface{}{ - []byte("message"), - []byte(key), - []byte(data), - } -} - -func createSubscribeMessage(key string) []interface{} { - return []interface{}{ - []byte("subscribe"), - []byte(key), - []byte("1"), - } -} -func createUnsubscribeMessage(key string) []interface{} { - return []interface{}{ - []byte("unsubscribe"), - []byte(key), - []byte("1"), - } +func initRdb() { + buf, _ := os.ReadFile("../../config.toml") + cfg, _ := config.LoadConfig(string(buf)) + 
Configure(cfg.Redis) } func (kw *KeyWatcher) countSubscribers(key string) int { @@ -44,17 +31,14 @@ func (kw *KeyWatcher) countSubscribers(key string) int { } // Forces a run of the `Process` loop against a mock PubSubConn. -func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}) { - psc := redigomock.NewConn() - psc.ReceiveWait = true - - channel := channelPrefix + runnerKey - psc.Command("SUBSCRIBE", channel).Expect(createSubscribeMessage(channel)) - psc.Command("UNSUBSCRIBE", channel).Expect(createUnsubscribeMessage(channel)) - psc.AddSubscriptionMessage(createSubscriptionMessage(channel, value)) +func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { + kw.mu.Lock() + kw.redisConn = rdb + psc := kw.redisConn.Subscribe(ctx, []string{}...) + kw.mu.Unlock() errC := make(chan error) - go func() { errC <- kw.receivePubSubStream(psc) }() + go func() { errC <- kw.receivePubSubStream(ctx, psc) }() require.Eventually(t, func() bool { kw.mu.Lock() @@ -66,7 +50,15 @@ func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value strin require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == numWatchers }, time.Second, time.Millisecond) - close(psc.ReceiveNow) + + // send message after listeners are ready + kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) + + // close subscription after all workers are done + wg.Wait() + kw.mu.Lock() + kw.conn.Close() + kw.mu.Unlock() require.NoError(t, <-errC) } @@ -82,6 +74,8 @@ type keyChangeTestCase struct { } func TestKeyChangesInstantReturn(t *testing.T) { + initRdb() + testCases := []keyChangeTestCase{ // WatchKeyStatusAlreadyChanged { @@ -118,20 +112,22 @@ func TestKeyChangesInstantReturn(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - conn, td := setupMockPool() - defer td() - if tc.isKeyMissing { - conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) - } else { - conn.Command("GET", runnerKey).Expect(tc.returnValue) + // setup + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) } + defer func() { + rdb.FlushDB(ctx) + }() + kw := NewKeyWatcher() defer kw.Shutdown() - kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
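+			// addSubscription fails fast when kw.conn is nil, hence the live
+			// PubSub wired up by hand above.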
- val, err := kw.WatchKey(runnerKey, tc.watchValue, tc.timeout) + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) require.NoError(t, err, "Expected no error") require.Equal(t, tc.expectedStatus, val, "Expected value") @@ -140,6 +136,8 @@ func TestKeyChangesInstantReturn(t *testing.T) { } func TestKeyChangesWhenWatching(t *testing.T) { + initRdb() + testCases := []keyChangeTestCase{ // WatchKeyStatusSeenChange { @@ -168,17 +166,15 @@ func TestKeyChangesWhenWatching(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - conn, td := setupMockPool() - defer td() - - if tc.isKeyMissing { - conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) - } else { - conn.Command("GET", runnerKey).Expect(tc.returnValue) + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) } kw := NewKeyWatcher() defer kw.Shutdown() + defer func() { + rdb.FlushDB(ctx) + }() wg := &sync.WaitGroup{} wg.Add(1) @@ -187,19 +183,20 @@ func TestKeyChangesWhenWatching(t *testing.T) { go func() { defer wg.Done() <-ready - val, err := kw.WatchKey(runnerKey, tc.watchValue, time.Second) + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) require.NoError(t, err, "Expected no error") require.Equal(t, tc.expectedStatus, val, "Expected value") }() - kw.processMessages(t, 1, tc.processedValue, ready) - wg.Wait() + kw.processMessages(t, 1, tc.processedValue, ready, wg) }) } } func TestKeyChangesParallel(t *testing.T) { + initRdb() + testCases := []keyChangeTestCase{ { desc: "massively parallel, sees change with key existing", @@ -221,19 +218,14 @@ func TestKeyChangesParallel(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { runTimes := 100 - conn, td := setupMockPool() - defer td() - - getCmd := conn.Command("GET", runnerKey) - - for i := 0; i < runTimes; i++ { - if tc.isKeyMissing { - getCmd = getCmd.ExpectError(redis.ErrNil) - } else { - getCmd = getCmd.Expect(tc.returnValue) - } + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) } + defer func() { + rdb.FlushDB(ctx) + }() + wg := &sync.WaitGroup{} wg.Add(runTimes) ready := make(chan struct{}) @@ -245,35 +237,34 @@ func TestKeyChangesParallel(t *testing.T) { go func() { defer wg.Done() <-ready - val, err := kw.WatchKey(runnerKey, tc.watchValue, time.Second) + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) require.NoError(t, err, "Expected no error") require.Equal(t, tc.expectedStatus, val, "Expected value") }() } - kw.processMessages(t, runTimes, tc.processedValue, ready) - wg.Wait() + kw.processMessages(t, runTimes, tc.processedValue, ready, wg) }) } } func TestShutdown(t *testing.T) { - conn, td := setupMockPool() - defer td() + initRdb() kw := NewKeyWatcher() - kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
defer kw.Shutdown() - conn.Command("GET", runnerKey).Expect("something") + rdb.Set(ctx, runnerKey, "something", 0) wg := &sync.WaitGroup{} wg.Add(2) go func() { defer wg.Done() - val, err := kw.WatchKey(runnerKey, "something", 10*time.Second) + val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) require.NoError(t, err, "Expected no error") require.Equal(t, WatchKeyStatusNoChange, val, "Expected value not to change") @@ -295,7 +286,7 @@ func TestShutdown(t *testing.T) { var err error done := make(chan struct{}) go func() { - val, err = kw.WatchKey(runnerKey, "something", 10*time.Second) + val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) close(done) }() diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index 03118cfcef65f4..7aa6be9b6795e2 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -1,24 +1,20 @@ package redis import ( + "context" "fmt" - "net" - "net/url" "time" - "github.com/FZambia/sentinel" - "github.com/gomodule/redigo/redis" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "gitlab.com/gitlab-org/labkit/log" + goredis "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( - pool *redis.Pool - sntnl *sentinel.Sentinel + rdb *goredis.Client ) const ( @@ -36,12 +32,6 @@ const ( // If you _actually_ hit this timeout often, you should consider turning of // redis-support since it's not necessary at that point... defaultIdleTimeout = 3 * time.Minute - // KeepAlivePeriod is to keep a TCP connection open for an extended period of - // time without being killed. This is used both in the pool, and in the - // worker-connection. - // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more - // information. - defaultKeepAlivePeriod = 5 * time.Minute ) var ( @@ -61,216 +51,91 @@ var ( ) ) -func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { - if len(urls) == 0 { - return nil - } - var addrs []string - for _, url := range urls { - h := url.URL.String() - log.WithFields(log.Fields{ - "scheme": url.URL.Scheme, - "host": url.URL.Host, - }).Printf("redis: using sentinel") - addrs = append(addrs, h) - } - return &sentinel.Sentinel{ - Addrs: addrs, - MasterName: master, - Dial: func(addr string) (redis.Conn, error) { - // This timeout is recommended for Sentinel-support according to the guidelines. - // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel - // For every address it should try to connect to the Sentinel, - // using a short timeout (in the order of a few hundreds of milliseconds). - timeout := 500 * time.Millisecond - url := helper.URLMustParse(addr) +// this Limiter effectively acts as a middleware to track dial successes and errors +type observabilityLimiter struct{} - var c redis.Conn - var err error - options := []redis.DialOption{ - redis.DialConnectTimeout(timeout), - redis.DialReadTimeout(timeout), - redis.DialWriteTimeout(timeout), - } +func (o observabilityLimiter) Allow() error { return nil } - if url.Scheme == "redis" || url.Scheme == "rediss" { - c, err = redis.DialURL(addr, options...) - } else { - c, err = redis.Dial("tcp", url.Host, options...) 
-			}
-
-			if err != nil {
-				errorCounter.WithLabelValues("dial", "sentinel").Inc()
-				return nil, err
-			}
-			return c, nil
-		},
-	}
+func (o observabilityLimiter) ReportResult(result error) {
+	if result != nil { errorCounter.WithLabelValues("dial", "redis").Inc() }
 }
 
-var poolDialFunc func() (redis.Conn, error)
-var workerDialFunc func() (redis.Conn, error)
-
-func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption {
-	return []redis.DialOption{
-		redis.DialReadTimeout(defaultReadTimeout),
-		redis.DialWriteTimeout(defaultWriteTimeout),
-	}
+func GetRedisClient() *goredis.Client {
+	return rdb
 }
 
-func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption {
-	var dopts []redis.DialOption
-	if setTimeouts {
-		dopts = timeoutDialOptions(cfg)
-	}
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig) {
 	if cfg == nil {
-		return dopts
-	}
-	if cfg.Password != "" {
-		dopts = append(dopts, redis.DialPassword(cfg.Password))
-	}
-	if cfg.DB != nil {
-		dopts = append(dopts, redis.DialDatabase(*cfg.DB))
-	}
-	return dopts
-}
-
-func keepAliveDialer(network, address string) (net.Conn, error) {
-	addr, err := net.ResolveTCPAddr(network, address)
-	if err != nil {
-		return nil, err
+		return
 	}
-	tc, err := net.DialTCP(network, nil, addr)
-	if err != nil {
-		return nil, err
+	maxIdle := defaultMaxIdle
+	if cfg.MaxIdle != nil {
+		maxIdle = *cfg.MaxIdle
 	}
-	if err := tc.SetKeepAlive(true); err != nil {
-		return nil, err
+	maxActive := defaultMaxActive
+	if cfg.MaxActive != nil {
+		maxActive = *cfg.MaxActive
 	}
-	if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil {
-		return nil, err
+	db := 0
+	if cfg.DB != nil {
+		db = *cfg.DB
 	}
-	return tc, nil
-}
 
-type redisDialerFunc func() (redis.Conn, error)
+	limiter := observabilityLimiter{}
 
-func sentinelDialer(dopts []redis.DialOption) redisDialerFunc {
-	return func() (redis.Conn, error) {
-		address, err := sntnl.MasterAddr()
-		if err != nil {
-			errorCounter.WithLabelValues("master", "sentinel").Inc()
-			return nil, err
-		}
-		dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
-		conn, err := redisDial("tcp", address, dopts...)
-		if err != nil {
-			return nil, err
-		}
-		if !sentinel.TestRole(conn, "master") {
-			conn.Close()
-			return nil, fmt.Errorf("%s is not redis master", address)
-		}
-		return conn, nil
+	onConnectHook := func(ctx context.Context, cn *goredis.Conn) error {
+		totalConnections.Inc()
+		return nil
 	}
-}
 
-func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc {
-	return func() (redis.Conn, error) {
-		if url.Scheme == "unix" {
-			return redisDial(url.Scheme, url.Path, dopts...)
+	if len(cfg.Sentinel) > 0 {
+		sentinels := make([]string, len(cfg.Sentinel))
+		for i := range cfg.Sentinel {
+			sentinelDetails := cfg.Sentinel[i]
+			sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port())
 		}
-		dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
-
-		// redis.DialURL only works with redis[s]:// URLs
-		if url.Scheme == "redis" || url.Scheme == "rediss" {
-			return redisURLDial(url, dopts...)
-		}
-
-		return redisDial(url.Scheme, url.Host, dopts...)
-	}
-}
-
-func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) {
-	log.WithFields(log.Fields{
-		"scheme":  url.Scheme,
-		"address": url.Host,
-	}).Printf("redis: dialing")
-
-	return redis.DialURL(url.String(), options...)
-} - -func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) { - log.WithFields(log.Fields{ - "network": network, - "address": address, - }).Printf("redis: dialing") - - return redis.Dial(network, address, options...) -} - -func countDialer(dialer redisDialerFunc) redisDialerFunc { - return func() (redis.Conn, error) { - c, err := dialer() + // This timeout is recommended for Sentinel-support according to the guidelines. + // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel + // For every address it should try to connect to the Sentinel, + // using a short timeout (in the order of a few hundreds of milliseconds). + timeout := 500 * time.Millisecond + + rdb = goredis.NewFailoverClient(&goredis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + DB: db, + + PoolSize: maxActive, + MaxIdleConns: maxIdle, + ConnMaxIdleTime: defaultIdleTimeout, + + DialTimeout: timeout, + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + + OnConnect: onConnectHook, + }) + } else { + opt, err := goredis.ParseURL(cfg.URL.String()) if err != nil { - errorCounter.WithLabelValues("dial", "redis").Inc() - } else { - totalConnections.Inc() + return } - return c, err - } -} -// DefaultDialFunc should always used. Only exception is for unit-tests. -func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { - dopts := dialOptionsBuilder(cfg, setReadTimeout) - if sntnl != nil { - return countDialer(sentinelDialer(dopts)) - } - return countDialer(defaultDialer(dopts, cfg.URL.URL)) -} + opt.DB = db + opt.Password = cfg.Password -// Configure redis-connection -func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) { - if cfg == nil { - return - } - maxIdle := defaultMaxIdle - if cfg.MaxIdle != nil { - maxIdle = *cfg.MaxIdle - } - maxActive := defaultMaxActive - if cfg.MaxActive != nil { - maxActive = *cfg.MaxActive - } - sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) - workerDialFunc = dialFunc(cfg, false) - poolDialFunc = dialFunc(cfg, true) - pool = &redis.Pool{ - MaxIdle: maxIdle, // Keep at most X hot connections - MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited - IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed - Dial: poolDialFunc, - Wait: true, - } -} + opt.PoolSize = maxActive + opt.MaxIdleConns = maxIdle + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout -// Get a connection for the Redis-pool -func Get() redis.Conn { - if pool != nil { - return pool.Get() - } - return nil -} + opt.OnConnect = onConnectHook + opt.Limiter = limiter -// GetString fetches the value of a key in Redis as a string -func GetString(key string) (string, error) { - conn := Get() - if conn == nil { - return "", fmt.Errorf("redis: could not get connection from pool") + rdb = goredis.NewClient(opt) } - defer conn.Close() - - return redis.String(conn.Do("GET", key)) } diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go index 64b3a842a54e50..63adb9d90288db 100644 --- a/workhorse/internal/redis/redis_test.go +++ b/workhorse/internal/redis/redis_test.go @@ -1,228 +1,27 @@ package redis import ( - "net" + "os" "testing" - "time" - "github.com/gomodule/redigo/redis" - "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" 
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) -func mockRedisServer(t *testing.T, connectReceived *bool) string { - ln, err := net.Listen("tcp", "127.0.0.1:0") - - require.Nil(t, err) - - go func() { - defer ln.Close() - conn, err := ln.Accept() - require.Nil(t, err) - *connectReceived = true - conn.Write([]byte("OK\n")) - }() - - return ln.Addr().String() -} - -// Setup a MockPool for Redis -// -// Returns a teardown-function and the mock-connection -func setupMockPool() (*redigomock.Conn, func()) { - conn := redigomock.NewConn() - cfg := &config.RedisConfig{URL: config.TomlURL{}} - Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) { - return func() (redis.Conn, error) { - return conn, nil - } - }) - return conn, func() { - pool = nil - } -} - -func TestDefaultDialFunc(t *testing.T) { - testCases := []struct { - scheme string - }{ - { - scheme: "tcp", - }, - { - scheme: "redis", - }, - } - - for _, tc := range testCases { - t.Run(tc.scheme, func(t *testing.T) { - connectReceived := false - a := mockRedisServer(t, &connectReceived) - - parsedURL := helper.URLMustParse(tc.scheme + "://" + a) - cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} - - dialer := DefaultDialFunc(cfg, true) - conn, err := dialer() - - require.Nil(t, err) - conn.Receive() - - require.True(t, connectReceived) - }) - } -} - func TestConfigureNoConfig(t *testing.T) { - pool = nil - Configure(nil, nil) - require.Nil(t, pool, "Pool should be nil") -} - -func TestConfigureMinimalConfig(t *testing.T) { - cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} - Configure(cfg, DefaultDialFunc) - - require.NotNil(t, pool, "Pool should not be nil") - require.Equal(t, 1, pool.MaxIdle) - require.Equal(t, 1, pool.MaxActive) - require.Equal(t, 3*time.Minute, pool.IdleTimeout) - - pool = nil -} - -func TestConfigureFullConfig(t *testing.T) { - i, a := 4, 10 - cfg := &config.RedisConfig{ - URL: config.TomlURL{}, - Password: "", - MaxIdle: &i, - MaxActive: &a, - } - Configure(cfg, DefaultDialFunc) - - require.NotNil(t, pool, "Pool should not be nil") - require.Equal(t, i, pool.MaxIdle) - require.Equal(t, a, pool.MaxActive) - require.Equal(t, 3*time.Minute, pool.IdleTimeout) - - pool = nil -} - -func TestGetConnFail(t *testing.T) { - conn := Get() - require.Nil(t, conn, "Expected `conn` to be nil") -} - -func TestGetConnPass(t *testing.T) { - _, teardown := setupMockPool() - defer teardown() - conn := Get() - require.NotNil(t, conn, "Expected `conn` to be non-nil") + rdb = nil + Configure(nil) + require.Nil(t, rdb, "rdb client should be nil") } -func TestGetStringPass(t *testing.T) { - conn, teardown := setupMockPool() - defer teardown() - conn.Command("GET", "foobar").Expect("baz") - str, err := GetString("foobar") +func TestConfigureValidConfig(t *testing.T) { + buf, _ := os.ReadFile("../../config.toml") + cfg, _ := config.LoadConfig(string(buf)) - require.NoError(t, err, "Expected `err` to be nil") - var value string - require.IsType(t, value, str, "Expected value to be a string") - require.Equal(t, "baz", str, "Expected it to be equal") -} - -func TestGetStringFail(t *testing.T) { - _, err := GetString("foobar") - require.Error(t, err, "Expected error when not connected to redis") -} - -func TestSentinelConnNoSentinel(t *testing.T) { - s := sentinelConn("", []config.TomlURL{}) - - require.Nil(t, s, "Sentinel without urls should return nil") -} - -func TestSentinelConnDialURL(t *testing.T) { - testCases := 
[]struct { - scheme string - }{ - { - scheme: "tcp", - }, - { - scheme: "redis", - }, - } - - for _, tc := range testCases { - t.Run(tc.scheme, func(t *testing.T) { - connectReceived := false - a := mockRedisServer(t, &connectReceived) - - addrs := []string{tc.scheme + "://" + a} - var sentinelUrls []config.TomlURL - - for _, a := range addrs { - parsedURL := helper.URLMustParse(a) - sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) - } - - s := sentinelConn("foobar", sentinelUrls) - require.Equal(t, len(addrs), len(s.Addrs)) - - for i := range addrs { - require.Equal(t, addrs[i], s.Addrs[i]) - } - - conn, err := s.Dial(s.Addrs[0]) + Configure(cfg.Redis) - require.Nil(t, err) - conn.Receive() - - require.True(t, connectReceived) - }) - } -} - -func TestSentinelConnTwoURLs(t *testing.T) { - addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} - var sentinelUrls []config.TomlURL - - for _, a := range addrs { - parsedURL := helper.URLMustParse(a) - sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) - } - - s := sentinelConn("foobar", sentinelUrls) - require.Equal(t, len(addrs), len(s.Addrs)) - - for i := range addrs { - require.Equal(t, addrs[i], s.Addrs[i]) - } -} - -func TestDialOptionsBuildersPassword(t *testing.T) { - dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false) - require.Equal(t, 1, len(dopts)) -} - -func TestDialOptionsBuildersSetTimeouts(t *testing.T) { - dopts := dialOptionsBuilder(nil, true) - require.Equal(t, 2, len(dopts)) -} - -func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { - dopts := dialOptionsBuilder(nil, true) - require.Equal(t, 2, len(dopts)) -} + require.NotNil(t, rdb, "Pool should not be nil") -func TestDialOptionsBuildersSelectDB(t *testing.T) { - db := 3 - dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false) - require.Equal(t, 1, len(dopts)) + rdb = nil } diff --git a/workhorse/main.go b/workhorse/main.go index ca9b86de528a73..6d79f17850a45e 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -224,9 +224,9 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) keyWatcher := redis.NewKeyWatcher() - if cfg.Redis != nil { - redis.Configure(cfg.Redis, redis.DefaultDialFunc) - go keyWatcher.Process() + redis.Configure(cfg.Redis) + if rdb := redis.GetRedisClient(); rdb != nil { + go keyWatcher.Process(rdb) } if err := cfg.RegisterGoCloudURLOpeners(); err != nil { -- GitLab From 99729d843dfb56451564e7c7a170488a73da3cdd Mon Sep 17 00:00:00 2001 From: Sylvester Chin Date: Tue, 25 Jul 2023 20:28:29 +0800 Subject: [PATCH 2/4] Restructure MR to enable feature flag toggle --- workhorse/go.mod | 3 + workhorse/go.sum | 9 + workhorse/internal/goredis/goredis.go | 123 +++++++ workhorse/internal/goredis/goredis_test.go | 27 ++ workhorse/internal/goredis/keywatcher.go | 236 ++++++++++++++ workhorse/internal/goredis/keywatcher_test.go | 301 ++++++++++++++++++ workhorse/internal/redis/keywatcher.go | 103 +++--- workhorse/internal/redis/keywatcher_test.go | 122 +++---- workhorse/internal/redis/redis.go | 273 ++++++++++++---- workhorse/internal/redis/redis_test.go | 221 ++++++++++++- workhorse/main.go | 26 +- 11 files changed, 1255 insertions(+), 189 deletions(-) create mode 100644 workhorse/internal/goredis/goredis.go create mode 100644 workhorse/internal/goredis/goredis_test.go create mode 100644 workhorse/internal/goredis/keywatcher.go create mode 100644 workhorse/internal/goredis/keywatcher_test.go diff --git a/workhorse/go.mod b/workhorse/go.mod 
index adedafb4e83839..da79087170ef7d 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -55,6 +55,7 @@ require ( github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 // indirect github.com/DataDog/datadog-go v4.4.0+incompatible // indirect github.com/DataDog/sketches-go v1.0.0 // indirect + github.com/FZambia/sentinel v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.0 // indirect github.com/beevik/ntp v1.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -69,6 +70,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20230406165453-00490a63f317 // indirect github.com/google/s2a-go v0.1.4 // indirect @@ -97,6 +99,7 @@ require ( github.com/prometheus/common v0.42.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect github.com/prometheus/prometheus v0.44.0 // indirect + github.com/rafaeljusto/redigomock/v3 v3.1.2 // indirect github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 // indirect github.com/shabbyrobe/gocovmerge v0.0.0-20190829150210-3e036491d500 // indirect github.com/shirou/gopsutil/v3 v3.21.12 // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index 622073a56114f2..abb06459bc00e9 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -705,6 +705,8 @@ github.com/DataDog/datadog-go v4.4.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/gostackparse v0.5.0/go.mod h1:lTfqcJKqS9KnXQGnyQMCugq3u1FP6UZMfWR0aitKFMM= github.com/DataDog/sketches-go v1.0.0 h1:chm5KSXO7kO+ywGWJ0Zs6tdmWU8PBXSbywFVciL6BG4= github.com/DataDog/sketches-go v1.0.0/go.mod h1:O+XkJHWk9w4hDwY2ZUDU31ZC9sNYlYo8DiFsxjYeo1k= +github.com/FZambia/sentinel v1.1.1 h1:0ovTimlR7Ldm+wR15GgO+8C2dt7kkn+tm3PQS+Qk3Ek= +github.com/FZambia/sentinel v1.1.1/go.mod h1:ytL1Am/RLlAoAXG6Kj5LNuw/TRRQrv2rt2FT26vP5gI= github.com/GoogleCloudPlatform/cloudsql-proxy v1.33.7/go.mod h1:JBp/RvKNOoIkR5BdMSXswBksHcPZ/41sbBV+GhSjgMY= github.com/HdrHistogram/hdrhistogram-go v1.1.0/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= github.com/HdrHistogram/hdrhistogram-go v1.1.2 h1:5IcZpTvzydCQeHzK4Ef/D5rrSqwxob0t8PQPMybUNFM= @@ -1369,6 +1371,11 @@ github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8l github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= +github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= +github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= @@ -2039,6 +2046,8 @@ github.com/prometheus/prometheus v0.35.0/go.mod 
h1:7HaLx5kEPKJ0GDgbODG0fZgXbQ8K/
 github.com/prometheus/prometheus v0.44.0 h1:sgn8Fdx+uE5tHQn0/622swlk2XnIj6udoZCnbVjHIgc=
 github.com/prometheus/prometheus v0.44.0/go.mod h1:aPsmIK3py5XammeTguyqTmuqzX/jeCdyOWWobLHNKQg=
 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
+github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqzmDKyZbC99o=
+github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM=
 github.com/rakyll/embedmd v0.0.0-20171029212350-c8060a0752a2/go.mod h1:7jOTMgqac46PZcF54q6l2hkLEG8op93fZu61KmxWDV4=
 github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/redis/go-redis/v9 v9.0.5 h1:CuQcn5HIEeK7BgElubPP8CGtE0KakrnbBSTLjathl5o=
diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go
new file mode 100644
index 00000000000000..403538e1bdd467
--- /dev/null
+++ b/workhorse/internal/goredis/goredis.go
@@ -0,0 +1,123 @@
+package goredis
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	redis "github.com/redis/go-redis/v9"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+	internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+	_ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
+)
+
+var (
+	rdb *redis.Client
+)
+
+const (
+	// Max Idle Connections in the pool.
+	defaultMaxIdle = 1
+	// Max Active Connections in the pool.
+	defaultMaxActive = 1
+	// Timeout for Read operations on the pool. 1 second is technically overkill,
+	// it's just for sanity.
+	defaultReadTimeout = 1 * time.Second
+	// Timeout for Write operations on the pool. 1 second is technically overkill,
+	// it's just for sanity.
+	defaultWriteTimeout = 1 * time.Second
+	// Timeout before killing Idle connections in the pool. 3 minutes seemed good.
+	// If you _actually_ hit this timeout often, you should consider turning off
+	// redis-support since it's not necessary at that point...
+	defaultIdleTimeout = 3 * time.Minute
+)
+
+// this Limiter effectively acts as a middleware to track dial successes and errors
+type observabilityLimiter struct{}
+
+func (o observabilityLimiter) Allow() error { return nil }
+
+func (o observabilityLimiter) ReportResult(result error) {
+	if result != nil { internalredis.ErrorCounter.WithLabelValues("dial", "redis").Inc() }
+}
+
+func GetRedisClient() *redis.Client {
+	return rdb
+}
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig) {
+	if cfg == nil {
+		return
+	}
+	maxIdle := defaultMaxIdle
+	if cfg.MaxIdle != nil {
+		maxIdle = *cfg.MaxIdle
+	}
+	maxActive := defaultMaxActive
+	if cfg.MaxActive != nil {
+		maxActive = *cfg.MaxActive
+	}
+	db := 0
+	if cfg.DB != nil {
+		db = *cfg.DB
+	}
+
+	limiter := observabilityLimiter{}
+
+	onConnectHook := func(ctx context.Context, cn *redis.Conn) error {
+		internalredis.TotalConnections.Inc()
+		return nil
+	}
+
+	if len(cfg.Sentinel) > 0 {
+		sentinels := make([]string, len(cfg.Sentinel))
+		for i := range cfg.Sentinel {
+			sentinelDetails := cfg.Sentinel[i]
+			sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port())
+		}
+
+		// This timeout is recommended for Sentinel-support according to the guidelines.
+		// https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel
+		// For every address it should try to connect to the Sentinel,
+		// using a short timeout (in the order of a few hundreds of milliseconds).
+ timeout := 500 * time.Millisecond + + rdb = redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + DB: db, + + PoolSize: maxActive, + MaxIdleConns: maxIdle, + ConnMaxIdleTime: defaultIdleTimeout, + + DialTimeout: timeout, + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + + OnConnect: onConnectHook, + }) + } else { + opt, err := redis.ParseURL(cfg.URL.String()) + if err != nil { + return + } + + opt.DB = db + opt.Password = cfg.Password + + opt.PoolSize = maxActive + opt.MaxIdleConns = maxIdle + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout + + opt.OnConnect = onConnectHook + opt.Limiter = limiter + + rdb = redis.NewClient(opt) + } +} diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go new file mode 100644 index 00000000000000..9ba1b424d22cc6 --- /dev/null +++ b/workhorse/internal/goredis/goredis_test.go @@ -0,0 +1,27 @@ +package goredis + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" +) + +func TestConfigureNoConfig(t *testing.T) { + rdb = nil + Configure(nil) + require.Nil(t, rdb, "rdb client should be nil") +} + +func TestConfigureValidConfig(t *testing.T) { + buf, _ := os.ReadFile("../../config.toml") + cfg, _ := config.LoadConfig(string(buf)) + + Configure(cfg.Redis) + + require.NotNil(t, rdb, "Pool should not be nil") + + rdb = nil +} diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go new file mode 100644 index 00000000000000..741bfb17652394 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher.go @@ -0,0 +1,236 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/jpillora/backoff" + "github.com/redis/go-redis/v9" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +type KeyWatcher struct { + mu sync.Mutex + subscribers map[string][]chan string + shutdown chan struct{} + reconnectBackoff backoff.Backoff + redisConn *redis.Client + conn *redis.PubSub +} + +func NewKeyWatcher() *KeyWatcher { + return &KeyWatcher{ + shutdown: make(chan struct{}), + reconnectBackoff: backoff.Backoff{ + Min: 100 * time.Millisecond, + Max: 60 * time.Second, + Factor: 2, + Jitter: true, + }, + } +} + +const channelPrefix = "workhorse:notifications:" + +func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) } + +func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { + kw.mu.Lock() + // We must share kw.conn with the goroutines that call SUBSCRIBE and + // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the + // connection. + kw.conn = pubsub + kw.mu.Unlock() + + defer func() { + kw.mu.Lock() + defer kw.mu.Unlock() + kw.conn.Close() + kw.conn = nil + + // Reset kw.subscribers because it is tied to Redis server side state of + // kw.conn and we just closed that connection. 
+ for _, chans := range kw.subscribers { + for _, ch := range chans { + close(ch) + internalredis.KeyWatchers.Dec() + } + } + kw.subscribers = nil + }() + + for { + msg, err := kw.conn.Receive(ctx) + if err != nil { + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() + return nil + } + + switch msg := msg.(type) { + case *redis.Subscription: + internalredis.RedisSubscriptions.Set(float64(msg.Count)) + case *redis.Pong: + // Ignore. + case *redis.Message: + internalredis.TotalMessages.Inc() + internalredis.ReceivedBytes.Add(float64(len(msg.Payload))) + if strings.HasPrefix(msg.Channel, channelPrefix) { + kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) + } + default: + log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() + return nil + } + } +} + +func (kw *KeyWatcher) Process(client *redis.Client) { + log.Info("keywatcher: starting process loop") + + ctx := context.Background() // lint:allow context.Background + kw.mu.Lock() + kw.redisConn = client + kw.mu.Unlock() + + for { + pubsub := client.Subscribe(ctx, []string{}...) + if err := pubsub.Ping(ctx); err != nil { + log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() + time.Sleep(kw.reconnectBackoff.Duration()) + continue + } + + kw.reconnectBackoff.Reset() + + if err := kw.receivePubSubStream(ctx, pubsub); err != nil { + log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() + } + } +} + +func (kw *KeyWatcher) Shutdown() { + log.Info("keywatcher: shutting down") + + kw.mu.Lock() + defer kw.mu.Unlock() + + select { + case <-kw.shutdown: + // already closed + default: + close(kw.shutdown) + } +} + +func (kw *KeyWatcher) notifySubscribers(key, value string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chanList, ok := kw.subscribers[key] + if !ok { + countAction("drop-message") + return + } + + countAction("deliver-message") + for _, c := range chanList { + select { + case c <- value: + default: + } + } +} + +func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { + kw.mu.Lock() + defer kw.mu.Unlock() + + if kw.conn == nil { + // This can happen because CI long polling is disabled in this Workhorse + // process. It can also be that we are waiting for the pubsub connection + // to be established. Either way it is OK to fail fast. + return errors.New("no redis connection") + } + + if len(kw.subscribers[key]) == 0 { + countAction("create-subscription") + if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { + return err + } + } + + if kw.subscribers == nil { + kw.subscribers = make(map[string][]chan string) + } + kw.subscribers[key] = append(kw.subscribers[key], notify) + internalredis.KeyWatchers.Inc() + + return nil +} + +func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chans, ok := kw.subscribers[key] + if !ok { + // This can happen if the pubsub connection dropped while we were + // waiting. + return + } + + for i, c := range chans { + if notify == c { + kw.subscribers[key] = append(chans[:i], chans[i+1:]...) 
+			internalredis.KeyWatchers.Dec()
+			break
+		}
+	}
+	if len(kw.subscribers[key]) == 0 {
+		delete(kw.subscribers, key)
+		countAction("delete-subscription")
+		if kw.conn != nil {
+			kw.conn.Unsubscribe(ctx, channelPrefix+key)
+		}
+	}
+}
+
+func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) {
+	notify := make(chan string, 1)
+	if err := kw.addSubscription(ctx, key, notify); err != nil {
+		return internalredis.WatchKeyStatusNoChange, err
+	}
+	defer kw.delSubscription(ctx, key, notify)
+
+	currentValue, err := kw.redisConn.Get(ctx, key).Result()
+	if errors.Is(err, redis.Nil) {
+		currentValue = ""
+	} else if err != nil {
+		return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err)
+	}
+	if currentValue != value {
+		return internalredis.WatchKeyStatusAlreadyChanged, nil
+	}
+
+	select {
+	case <-kw.shutdown:
+		log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown")
+		return internalredis.WatchKeyStatusNoChange, nil
+	case currentValue := <-notify:
+		if currentValue == "" {
+			return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET failed")
+		}
+		if currentValue == value {
+			return internalredis.WatchKeyStatusNoChange, nil
+		}
+		return internalredis.WatchKeyStatusSeenChange, nil
+	case <-time.After(timeout):
+		return internalredis.WatchKeyStatusTimeout, nil
+	}
+}
diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go
new file mode 100644
index 00000000000000..b64262dc9c814c
--- /dev/null
+++ b/workhorse/internal/goredis/keywatcher_test.go
@@ -0,0 +1,301 @@
+package goredis
+
+import (
+	"context"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
+)
+
+var ctx = context.Background()
+
+const (
+	runnerKey = "runner:build_queue:10"
+)
+
+func initRdb() {
+	buf, _ := os.ReadFile("../../config.toml")
+	cfg, _ := config.LoadConfig(string(buf))
+	Configure(cfg.Redis)
+}
+
+func (kw *KeyWatcher) countSubscribers(key string) int {
+	kw.mu.Lock()
+	defer kw.mu.Unlock()
+	return len(kw.subscribers[key])
+}
+
+// Forces a run of the `Process` loop against a real PubSub connection.
+func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) {
+	kw.mu.Lock()
+	kw.redisConn = rdb
+	psc := kw.redisConn.Subscribe(ctx, []string{}...)
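+	// Mirrors the wiring in Process: the client is assigned under the mutex,
+	// and receivePubSubStream installs the PubSub created here as kw.conn.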
+ kw.mu.Unlock() + + errC := make(chan error) + go func() { errC <- kw.receivePubSubStream(ctx, psc) }() + + require.Eventually(t, func() bool { + kw.mu.Lock() + defer kw.mu.Unlock() + return kw.conn != nil + }, time.Second, time.Millisecond) + close(ready) + + require.Eventually(t, func() bool { + return kw.countSubscribers(runnerKey) == numWatchers + }, time.Second, time.Millisecond) + + // send message after listeners are ready + kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) + + // close subscription after all workers are done + wg.Wait() + kw.mu.Lock() + kw.conn.Close() + kw.mu.Unlock() + + require.NoError(t, <-errC) +} + +type keyChangeTestCase struct { + desc string + returnValue string + isKeyMissing bool + watchValue string + processedValue string + expectedStatus redis.WatchKeyStatus + timeout time.Duration +} + +func TestKeyChangesInstantReturn(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + // WatchKeyStatusAlreadyChanged + { + desc: "sees change with key existing and changed", + returnValue: "somethingelse", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + { + desc: "sees change with key non-existing", + isKeyMissing: true, + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + // WatchKeyStatusTimeout + { + desc: "sees timeout with key existing and unchanged", + returnValue: "something", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + { + desc: "sees timeout with key non-existing and unchanged", + isKeyMissing: true, + watchValue: "", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + + // setup + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + kw := NewKeyWatcher() + defer kw.Shutdown() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
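+			// A live PubSub is attached by hand because addSubscription
+			// fails fast with "no redis connection" when kw.conn is nil.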
+ + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }) + } +} + +func TestKeyChangesWhenWatching(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + // WatchKeyStatusSeenChange + { + desc: "sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "sees change with key non-existing, when watching empty value", + isKeyMissing: true, + watchValue: "", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + // WatchKeyStatusNoChange + { + desc: "sees no change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusNoChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + kw := NewKeyWatcher() + defer kw.Shutdown() + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(1) + ready := make(chan struct{}) + + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + + kw.processMessages(t, 1, tc.processedValue, ready, wg) + }) + } +} + +func TestKeyChangesParallel(t *testing.T) { + initRdb() + + testCases := []keyChangeTestCase{ + { + desc: "massively parallel, sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "massively parallel, sees change with key existing, watching missing keys", + isKeyMissing: true, + watchValue: "", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + runTimes := 100 + + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(runTimes) + ready := make(chan struct{}) + + kw := NewKeyWatcher() + defer kw.Shutdown() + + for i := 0; i < runTimes; i++ { + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + } + + kw.processMessages(t, runTimes, tc.processedValue, ready, wg) + }) + } +} + +func TestShutdown(t *testing.T) { + initRdb() + + kw := NewKeyWatcher() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
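+ // As in the instant-return tests, the pub/sub connection is wired up by
+ // hand; this test exercises Shutdown rather than the Process loop.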
+ defer kw.Shutdown() + + rdb.Set(ctx, runnerKey, "something", 0) + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + }() + + go func() { + defer wg.Done() + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond) + + kw.Shutdown() + }() + + wg.Wait() + + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond) + + // Adding a key after the shutdown should result in an immediate response + var val redis.WatchKeyStatus + var err error + done := make(chan struct{}) + go func() { + val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + close(done) + }() + + select { + case <-done: + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for WatchKey") + } +} diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index 8618e2a015e11c..8f1772a91958c4 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -8,10 +8,10 @@ import ( "sync" "time" + "github.com/gomodule/redigo/redis" "github.com/jpillora/backoff" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - goredis "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" ) @@ -21,8 +21,7 @@ type KeyWatcher struct { subscribers map[string][]chan string shutdown chan struct{} reconnectBackoff backoff.Backoff - redisConn *goredis.Client - conn *goredis.PubSub + conn *redis.PubSubConn } func NewKeyWatcher() *KeyWatcher { @@ -38,32 +37,32 @@ func NewKeyWatcher() *KeyWatcher { } var ( - keyWatchers = promauto.NewGauge( + KeyWatchers = promauto.NewGauge( prometheus.GaugeOpts{ Name: "gitlab_workhorse_keywatcher_keywatchers", Help: "The number of keys that is being watched by gitlab-workhorse", }, ) - redisSubscriptions = promauto.NewGauge( + RedisSubscriptions = promauto.NewGauge( prometheus.GaugeOpts{ Name: "gitlab_workhorse_keywatcher_redis_subscriptions", Help: "Current number of keywatcher Redis pubsub subscriptions", }, ) - totalMessages = promauto.NewCounter( + TotalMessages = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_total_messages", Help: "How many messages gitlab-workhorse has received in total on pubsub.", }, ) - totalActions = promauto.NewCounterVec( + TotalActions = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_actions_total", Help: "Counts of various keywatcher actions", }, []string{"action"}, ) - receivedBytes = promauto.NewCounter( + ReceivedBytes = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_received_bytes_total", Help: "How many bytes of messages gitlab-workhorse has received in total on pubsub.", @@ -73,14 +72,14 @@ var ( const channelPrefix = "workhorse:notifications:" -func countAction(action string) { totalActions.WithLabelValues(action).Add(1) } +func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) } -func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *goredis.PubSub) error { +func (kw *KeyWatcher) 
receivePubSubStream(conn redis.Conn) error { kw.mu.Lock() // We must share kw.conn with the goroutines that call SUBSCRIBE and // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the // connection. - kw.conn = pubsub + kw.conn = &redis.PubSubConn{Conn: conn} kw.mu.Unlock() defer func() { @@ -94,56 +93,58 @@ func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *goredis.P for _, chans := range kw.subscribers { for _, ch := range chans { close(ch) - keyWatchers.Dec() + KeyWatchers.Dec() } } kw.subscribers = nil }() for { - msg, err := kw.conn.Receive(ctx) - if err != nil { - log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() - return nil - } - - switch msg := msg.(type) { - case *goredis.Subscription: - redisSubscriptions.Set(float64(msg.Count)) - case *goredis.Pong: - // Ignore. - case *goredis.Message: - totalMessages.Inc() - receivedBytes.Add(float64(len(msg.Payload))) - if strings.HasPrefix(msg.Channel, channelPrefix) { - kw.notifySubscribers(msg.Channel[len(channelPrefix):], string(msg.Payload)) + switch v := kw.conn.Receive().(type) { + case redis.Message: + TotalMessages.Inc() + ReceivedBytes.Add(float64(len(v.Data))) + if strings.HasPrefix(v.Channel, channelPrefix) { + kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data)) } - default: - log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() + case redis.Subscription: + RedisSubscriptions.Set(float64(v.Count)) + case error: + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error() + // Intermittent error, return nil so that it doesn't wait before reconnect return nil } } } -func (kw *KeyWatcher) Process(client *goredis.Client) { - log.Info("keywatcher: starting process loop") +func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) { + conn, err := dialer() + if err != nil { + return nil, err + } - ctx := context.Background() // lint:allow context.Background - kw.mu.Lock() - kw.redisConn = client - kw.mu.Unlock() + // Make sure Redis is actually connected + conn.Do("PING") + if err := conn.Err(); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} +func (kw *KeyWatcher) Process() { + log.Info("keywatcher: starting process loop") for { - pubsub := client.Subscribe(ctx, []string{}...) 
- if err := pubsub.Ping(ctx); err != nil { + conn, err := dialPubSub(workerDialFunc) + if err != nil { log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() time.Sleep(kw.reconnectBackoff.Duration()) continue } - kw.reconnectBackoff.Reset() - if err := kw.receivePubSubStream(ctx, pubsub); err != nil { + if err = kw.receivePubSubStream(conn); err != nil { log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() } } @@ -182,7 +183,7 @@ func (kw *KeyWatcher) notifySubscribers(key, value string) { } } -func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { +func (kw *KeyWatcher) addSubscription(key string, notify chan string) error { kw.mu.Lock() defer kw.mu.Unlock() @@ -195,7 +196,7 @@ func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify ch if len(kw.subscribers[key]) == 0 { countAction("create-subscription") - if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { + if err := kw.conn.Subscribe(channelPrefix + key); err != nil { return err } } @@ -204,12 +205,12 @@ func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify ch kw.subscribers = make(map[string][]chan string) } kw.subscribers[key] = append(kw.subscribers[key], notify) - keyWatchers.Inc() + KeyWatchers.Inc() return nil } -func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { +func (kw *KeyWatcher) delSubscription(key string, notify chan string) { kw.mu.Lock() defer kw.mu.Unlock() @@ -223,7 +224,7 @@ func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify ch for i, c := range chans { if notify == c { kw.subscribers[key] = append(chans[:i], chans[i+1:]...) - keyWatchers.Dec() + KeyWatchers.Dec() break } } @@ -231,7 +232,7 @@ func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify ch delete(kw.subscribers, key) countAction("delete-subscription") if kw.conn != nil { - kw.conn.Unsubscribe(ctx, channelPrefix+key) + kw.conn.Unsubscribe(channelPrefix + key) } } } @@ -251,15 +252,15 @@ const ( WatchKeyStatusNoChange ) -func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { +func (kw *KeyWatcher) WatchKey(_ context.Context, key, value string, timeout time.Duration) (WatchKeyStatus, error) { notify := make(chan string, 1) - if err := kw.addSubscription(ctx, key, notify); err != nil { + if err := kw.addSubscription(key, notify); err != nil { return WatchKeyStatusNoChange, err } - defer kw.delSubscription(ctx, key, notify) + defer kw.delSubscription(key, notify) - currentValue, err := kw.redisConn.Get(ctx, key).Result() - if errors.Is(err, goredis.Nil) { + currentValue, err := GetString(key) + if errors.Is(err, redis.ErrNil) { currentValue = "" } else if err != nil { return WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %v", err) diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go index bca4ca43a644a6..33daa07569fa0e 100644 --- a/workhorse/internal/redis/keywatcher_test.go +++ b/workhorse/internal/redis/keywatcher_test.go @@ -1,15 +1,14 @@ package redis import ( - "context" - "os" "sync" "testing" "time" + "context" + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" ) var ctx = context.Background() @@ -18,10 +17,27 @@ const ( runnerKey = "runner:build_queue:10" ) -func initRdb() { - buf, _ 
:= os.ReadFile("../../config.toml") - cfg, _ := config.LoadConfig(string(buf)) - Configure(cfg.Redis) +func createSubscriptionMessage(key, data string) []interface{} { + return []interface{}{ + []byte("message"), + []byte(key), + []byte(data), + } +} + +func createSubscribeMessage(key string) []interface{} { + return []interface{}{ + []byte("subscribe"), + []byte(key), + []byte("1"), + } +} +func createUnsubscribeMessage(key string) []interface{} { + return []interface{}{ + []byte("unsubscribe"), + []byte(key), + []byte("1"), + } } func (kw *KeyWatcher) countSubscribers(key string) int { @@ -31,14 +47,17 @@ func (kw *KeyWatcher) countSubscribers(key string) int { } // Forces a run of the `Process` loop against a mock PubSubConn. -func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { - kw.mu.Lock() - kw.redisConn = rdb - psc := kw.redisConn.Subscribe(ctx, []string{}...) - kw.mu.Unlock() +func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}) { + psc := redigomock.NewConn() + psc.ReceiveWait = true + + channel := channelPrefix + runnerKey + psc.Command("SUBSCRIBE", channel).Expect(createSubscribeMessage(channel)) + psc.Command("UNSUBSCRIBE", channel).Expect(createUnsubscribeMessage(channel)) + psc.AddSubscriptionMessage(createSubscriptionMessage(channel, value)) errC := make(chan error) - go func() { errC <- kw.receivePubSubStream(ctx, psc) }() + go func() { errC <- kw.receivePubSubStream(psc) }() require.Eventually(t, func() bool { kw.mu.Lock() @@ -50,15 +69,7 @@ func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value strin require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == numWatchers }, time.Second, time.Millisecond) - - // send message after listeners are ready - kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) - - // close subscription after all workers are done - wg.Wait() - kw.mu.Lock() - kw.conn.Close() - kw.mu.Unlock() + close(psc.ReceiveNow) require.NoError(t, <-errC) } @@ -74,8 +85,6 @@ type keyChangeTestCase struct { } func TestKeyChangesInstantReturn(t *testing.T) { - initRdb() - testCases := []keyChangeTestCase{ // WatchKeyStatusAlreadyChanged { @@ -112,20 +121,18 @@ func TestKeyChangesInstantReturn(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { + conn, td := setupMockPool() + defer td() - // setup - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) + if tc.isKeyMissing { + conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) + } else { + conn.Command("GET", runnerKey).Expect(tc.returnValue) } - defer func() { - rdb.FlushDB(ctx) - }() - kw := NewKeyWatcher() defer kw.Shutdown() - kw.redisConn = rdb - kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) 
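+ // A redigomock-backed PubSubConn gives addSubscription a connection to
+ // issue SUBSCRIBE against without a real Redis server.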
+ kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) @@ -136,8 +143,6 @@ func TestKeyChangesInstantReturn(t *testing.T) { } func TestKeyChangesWhenWatching(t *testing.T) { - initRdb() - testCases := []keyChangeTestCase{ // WatchKeyStatusSeenChange { @@ -166,15 +171,17 @@ func TestKeyChangesWhenWatching(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) + conn, td := setupMockPool() + defer td() + + if tc.isKeyMissing { + conn.Command("GET", runnerKey).ExpectError(redis.ErrNil) + } else { + conn.Command("GET", runnerKey).Expect(tc.returnValue) } kw := NewKeyWatcher() defer kw.Shutdown() - defer func() { - rdb.FlushDB(ctx) - }() wg := &sync.WaitGroup{} wg.Add(1) @@ -189,14 +196,13 @@ func TestKeyChangesWhenWatching(t *testing.T) { require.Equal(t, tc.expectedStatus, val, "Expected value") }() - kw.processMessages(t, 1, tc.processedValue, ready, wg) + kw.processMessages(t, 1, tc.processedValue, ready) + wg.Wait() }) } } func TestKeyChangesParallel(t *testing.T) { - initRdb() - testCases := []keyChangeTestCase{ { desc: "massively parallel, sees change with key existing", @@ -218,13 +224,18 @@ func TestKeyChangesParallel(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { runTimes := 100 - if !tc.isKeyMissing { - rdb.Set(ctx, runnerKey, tc.returnValue, 0) - } + conn, td := setupMockPool() + defer td() - defer func() { - rdb.FlushDB(ctx) - }() + getCmd := conn.Command("GET", runnerKey) + + for i := 0; i < runTimes; i++ { + if tc.isKeyMissing { + getCmd = getCmd.ExpectError(redis.ErrNil) + } else { + getCmd = getCmd.Expect(tc.returnValue) + } + } wg := &sync.WaitGroup{} wg.Add(runTimes) @@ -244,20 +255,21 @@ func TestKeyChangesParallel(t *testing.T) { }() } - kw.processMessages(t, runTimes, tc.processedValue, ready, wg) + kw.processMessages(t, runTimes, tc.processedValue, ready) + wg.Wait() }) } } func TestShutdown(t *testing.T) { - initRdb() + conn, td := setupMockPool() + defer td() kw := NewKeyWatcher() - kw.redisConn = rdb - kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) + kw.conn = &redis.PubSubConn{Conn: redigomock.NewConn()} defer kw.Shutdown() - rdb.Set(ctx, runnerKey, "something", 0) + conn.Command("GET", runnerKey).Expect("something") wg := &sync.WaitGroup{} wg.Add(2) diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index 7aa6be9b6795e2..c79e1e56b3ad16 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -1,20 +1,24 @@ package redis import ( - "context" "fmt" + "net" + "net/url" "time" + "github.com/FZambia/sentinel" + "github.com/gomodule/redigo/redis" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - goredis "github.com/redis/go-redis/v9" + "gitlab.com/gitlab-org/labkit/log" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( - rdb *goredis.Client + pool *redis.Pool + sntnl *sentinel.Sentinel ) const ( @@ -32,17 +36,23 @@ const ( // If you _actually_ hit this timeout often, you should consider turning of // redis-support since it's not necessary at that point... defaultIdleTimeout = 3 * time.Minute + // KeepAlivePeriod is to keep a TCP connection open for an extended period of + // time without being killed. 
This is used both in the pool, and in the + // worker-connection. + // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more + // information. + defaultKeepAlivePeriod = 5 * time.Minute ) var ( - totalConnections = promauto.NewCounter( + TotalConnections = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_redis_total_connections", Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", }, ) - errorCounter = promauto.NewCounterVec( + ErrorCounter = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "gitlab_workhorse_redis_errors", Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", @@ -51,91 +61,216 @@ var ( ) ) -// this Limiter effectively acts as a middleware to track dial successes and errors -type observabilityLimiter struct{} +func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { + if len(urls) == 0 { + return nil + } + var addrs []string + for _, url := range urls { + h := url.URL.String() + log.WithFields(log.Fields{ + "scheme": url.URL.Scheme, + "host": url.URL.Host, + }).Printf("redis: using sentinel") + addrs = append(addrs, h) + } + return &sentinel.Sentinel{ + Addrs: addrs, + MasterName: master, + Dial: func(addr string) (redis.Conn, error) { + // This timeout is recommended for Sentinel-support according to the guidelines. + // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel + // For every address it should try to connect to the Sentinel, + // using a short timeout (in the order of a few hundreds of milliseconds). + timeout := 500 * time.Millisecond + url := helper.URLMustParse(addr) -func (o observabilityLimiter) Allow() error { return nil } + var c redis.Conn + var err error + options := []redis.DialOption{ + redis.DialConnectTimeout(timeout), + redis.DialReadTimeout(timeout), + redis.DialWriteTimeout(timeout), + } -func (o observabilityLimiter) ReportResult(result error) { - errorCounter.WithLabelValues("dial", "redis").Inc() + if url.Scheme == "redis" || url.Scheme == "rediss" { + c, err = redis.DialURL(addr, options...) + } else { + c, err = redis.Dial("tcp", url.Host, options...) 
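+ // Sentinel addresses without a redis[s]:// scheme (for example tcp://)
+ // take this plain TCP dial; DialURL above handles redis:// and rediss://.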
+ } + + if err != nil { + ErrorCounter.WithLabelValues("dial", "sentinel").Inc() + return nil, err + } + return c, nil + }, + } } -func GetRedisClient() *goredis.Client { - return rdb +var poolDialFunc func() (redis.Conn, error) +var workerDialFunc func() (redis.Conn, error) + +func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { + return []redis.DialOption{ + redis.DialReadTimeout(defaultReadTimeout), + redis.DialWriteTimeout(defaultWriteTimeout), + } } -// Configure redis-connection -func Configure(cfg *config.RedisConfig) { +func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption { + var dopts []redis.DialOption + if setTimeouts { + dopts = timeoutDialOptions(cfg) + } if cfg == nil { - return + return dopts } - maxIdle := defaultMaxIdle - if cfg.MaxIdle != nil { - maxIdle = *cfg.MaxIdle + if cfg.Password != "" { + dopts = append(dopts, redis.DialPassword(cfg.Password)) } - maxActive := defaultMaxActive - if cfg.MaxActive != nil { - maxActive = *cfg.MaxActive - } - db := 0 if cfg.DB != nil { - db = *cfg.DB + dopts = append(dopts, redis.DialDatabase(*cfg.DB)) } + return dopts +} - limiter := observabilityLimiter{} +func keepAliveDialer(network, address string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + tc, err := net.DialTCP(network, nil, addr) + if err != nil { + return nil, err + } + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + if err := tc.SetKeepAlivePeriod(defaultKeepAlivePeriod); err != nil { + return nil, err + } + return tc, nil +} - onConnectHook := func(ctx context.Context, cn *goredis.Conn) error { - totalConnections.Inc() - return nil +type redisDialerFunc func() (redis.Conn, error) + +func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { + return func() (redis.Conn, error) { + address, err := sntnl.MasterAddr() + if err != nil { + ErrorCounter.WithLabelValues("master", "sentinel").Inc() + return nil, err + } + dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) + conn, err := redisDial("tcp", address, dopts...) + if err != nil { + return nil, err + } + if !sentinel.TestRole(conn, "master") { + conn.Close() + return nil, fmt.Errorf("%s is not redis master", address) + } + return conn, nil } +} - if len(cfg.Sentinel) > 0 { - sentinels := make([]string, len(cfg.Sentinel)) - for i := range cfg.Sentinel { - sentinelDetails := cfg.Sentinel[i] - sentinels[i] = fmt.Sprintf("%s:%s", sentinelDetails.Hostname(), sentinelDetails.Port()) +func defaultDialer(dopts []redis.DialOption, url url.URL) redisDialerFunc { + return func() (redis.Conn, error) { + if url.Scheme == "unix" { + return redisDial(url.Scheme, url.Path, dopts...) } - // This timeout is recommended for Sentinel-support according to the guidelines. - // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel - // For every address it should try to connect to the Sentinel, - // using a short timeout (in the order of a few hundreds of milliseconds). 
- timeout := 500 * time.Millisecond
-
- rdb = goredis.NewFailoverClient(&goredis.FailoverOptions{
- MasterName: cfg.SentinelMaster,
- SentinelAddrs: sentinels,
- Password: cfg.Password,
- DB: db,
-
- PoolSize: maxActive,
- MaxIdleConns: maxIdle,
- ConnMaxIdleTime: defaultIdleTimeout,
-
- DialTimeout: timeout,
- ReadTimeout: defaultReadTimeout,
- WriteTimeout: defaultWriteTimeout,
-
- OnConnect: onConnectHook,
- })
- } else {
- opt, err := goredis.ParseURL(cfg.URL.String())
+ dopts = append(dopts, redis.DialNetDial(keepAliveDialer))
+
+ // redis.DialURL only works with redis[s]:// URLs
+ if url.Scheme == "redis" || url.Scheme == "rediss" {
+ return redisURLDial(url, dopts...)
+ }
+
+ return redisDial(url.Scheme, url.Host, dopts...)
+ }
+}
+
+func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "scheme": url.Scheme,
+ "address": url.Host,
+ }).Printf("redis: dialing")
+
+ return redis.DialURL(url.String(), options...)
+}
+
+func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) {
+ log.WithFields(log.Fields{
+ "network": network,
+ "address": address,
+ }).Printf("redis: dialing")
+
+ return redis.Dial(network, address, options...)
+}
+
+func countDialer(dialer redisDialerFunc) redisDialerFunc {
+ return func() (redis.Conn, error) {
+ c, err := dialer()
 if err != nil {
- return
+ ErrorCounter.WithLabelValues("dial", "redis").Inc()
+ } else {
+ TotalConnections.Inc()
 }
+ return c, err
+ }
+}
-
- opt.DB = db
- opt.Password = cfg.Password
+
+// DefaultDialFunc should always be used. The only exception is unit tests.
+func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) {
+ dopts := dialOptionsBuilder(cfg, setReadTimeout)
+ if sntnl != nil {
+ return countDialer(sentinelDialer(dopts))
+ }
+ return countDialer(defaultDialer(dopts, cfg.URL.URL))
+}
-
- opt.PoolSize = maxActive
- opt.MaxIdleConns = maxIdle
- opt.ConnMaxIdleTime = defaultIdleTimeout
- opt.ReadTimeout = defaultReadTimeout
- opt.WriteTimeout = defaultWriteTimeout
+
+// Configure redis-connection
+func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) {
+ if cfg == nil {
+ return
+ }
+ maxIdle := defaultMaxIdle
+ if cfg.MaxIdle != nil {
+ maxIdle = *cfg.MaxIdle
+ }
+ maxActive := defaultMaxActive
+ if cfg.MaxActive != nil {
+ maxActive = *cfg.MaxActive
+ }
+ sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel)
+ workerDialFunc = dialFunc(cfg, false)
+ poolDialFunc = dialFunc(cfg, true)
+ pool = &redis.Pool{
+ MaxIdle: maxIdle, // Keep at most X hot connections
+ MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited
+ IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed
+ Dial: poolDialFunc,
+ Wait: true,
+ }
+}
-
- opt.OnConnect = onConnectHook
- opt.Limiter = limiter
+
+// Get a connection for the Redis-pool
+func Get() redis.Conn {
+ if pool != nil {
+ return pool.Get()
+ }
+ return nil
+}
-
- rdb = goredis.NewClient(opt)
+// GetString fetches the value of a key in Redis as a string
+func GetString(key string) (string, error) {
+ conn := Get()
+ if conn == nil {
+ return "", fmt.Errorf("redis: could not get connection from pool")
 }
+ defer conn.Close()
+
+ return redis.String(conn.Do("GET", key))
 }
diff --git a/workhorse/internal/redis/redis_test.go b/workhorse/internal/redis/redis_test.go
index 63adb9d90288db..64b3a842a54e50 100644
--- a/workhorse/internal/redis/redis_test.go
+++ 
b/workhorse/internal/redis/redis_test.go @@ -1,27 +1,228 @@ package redis import ( - "os" + "net" "testing" + "time" + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) +func mockRedisServer(t *testing.T, connectReceived *bool) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + + require.Nil(t, err) + + go func() { + defer ln.Close() + conn, err := ln.Accept() + require.Nil(t, err) + *connectReceived = true + conn.Write([]byte("OK\n")) + }() + + return ln.Addr().String() +} + +// Setup a MockPool for Redis +// +// Returns a teardown-function and the mock-connection +func setupMockPool() (*redigomock.Conn, func()) { + conn := redigomock.NewConn() + cfg := &config.RedisConfig{URL: config.TomlURL{}} + Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) { + return func() (redis.Conn, error) { + return conn, nil + } + }) + return conn, func() { + pool = nil + } +} + +func TestDefaultDialFunc(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "tcp", + }, + { + scheme: "redis", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + parsedURL := helper.URLMustParse(tc.scheme + "://" + a) + cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} + + dialer := DefaultDialFunc(cfg, true) + conn, err := dialer() + + require.Nil(t, err) + conn.Receive() + + require.True(t, connectReceived) + }) + } +} + func TestConfigureNoConfig(t *testing.T) { - rdb = nil - Configure(nil) - require.Nil(t, rdb, "rdb client should be nil") + pool = nil + Configure(nil, nil) + require.Nil(t, pool, "Pool should be nil") +} + +func TestConfigureMinimalConfig(t *testing.T) { + cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} + Configure(cfg, DefaultDialFunc) + + require.NotNil(t, pool, "Pool should not be nil") + require.Equal(t, 1, pool.MaxIdle) + require.Equal(t, 1, pool.MaxActive) + require.Equal(t, 3*time.Minute, pool.IdleTimeout) + + pool = nil +} + +func TestConfigureFullConfig(t *testing.T) { + i, a := 4, 10 + cfg := &config.RedisConfig{ + URL: config.TomlURL{}, + Password: "", + MaxIdle: &i, + MaxActive: &a, + } + Configure(cfg, DefaultDialFunc) + + require.NotNil(t, pool, "Pool should not be nil") + require.Equal(t, i, pool.MaxIdle) + require.Equal(t, a, pool.MaxActive) + require.Equal(t, 3*time.Minute, pool.IdleTimeout) + + pool = nil +} + +func TestGetConnFail(t *testing.T) { + conn := Get() + require.Nil(t, conn, "Expected `conn` to be nil") +} + +func TestGetConnPass(t *testing.T) { + _, teardown := setupMockPool() + defer teardown() + conn := Get() + require.NotNil(t, conn, "Expected `conn` to be non-nil") } -func TestConfigureValidConfig(t *testing.T) { - buf, _ := os.ReadFile("../../config.toml") - cfg, _ := config.LoadConfig(string(buf)) +func TestGetStringPass(t *testing.T) { + conn, teardown := setupMockPool() + defer teardown() + conn.Command("GET", "foobar").Expect("baz") + str, err := GetString("foobar") - Configure(cfg.Redis) + require.NoError(t, err, "Expected `err` to be nil") + var value string + require.IsType(t, value, str, "Expected value to be a string") + require.Equal(t, "baz", str, "Expected it to be equal") +} + +func TestGetStringFail(t *testing.T) { + _, err := GetString("foobar") + require.Error(t, err, "Expected error when 
not connected to redis") +} + +func TestSentinelConnNoSentinel(t *testing.T) { + s := sentinelConn("", []config.TomlURL{}) + + require.Nil(t, s, "Sentinel without urls should return nil") +} + +func TestSentinelConnDialURL(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "tcp", + }, + { + scheme: "redis", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + addrs := []string{tc.scheme + "://" + a} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + s := sentinelConn("foobar", sentinelUrls) + require.Equal(t, len(addrs), len(s.Addrs)) + + for i := range addrs { + require.Equal(t, addrs[i], s.Addrs[i]) + } + + conn, err := s.Dial(s.Addrs[0]) - require.NotNil(t, rdb, "Pool should not be nil") + require.Nil(t, err) + conn.Receive() + + require.True(t, connectReceived) + }) + } +} + +func TestSentinelConnTwoURLs(t *testing.T) { + addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + s := sentinelConn("foobar", sentinelUrls) + require.Equal(t, len(addrs), len(s.Addrs)) + + for i := range addrs { + require.Equal(t, addrs[i], s.Addrs[i]) + } +} + +func TestDialOptionsBuildersPassword(t *testing.T) { + dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false) + require.Equal(t, 1, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeouts(t *testing.T) { + dopts := dialOptionsBuilder(nil, true) + require.Equal(t, 2, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { + dopts := dialOptionsBuilder(nil, true) + require.Equal(t, 2, len(dopts)) +} - rdb = nil +func TestDialOptionsBuildersSelectDB(t *testing.T) { + db := 3 + dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false) + require.Equal(t, 1, len(dopts)) } diff --git a/workhorse/main.go b/workhorse/main.go index 6d79f17850a45e..ab06c83eb5b3cb 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -17,8 +17,10 @@ import ( "gitlab.com/gitlab-org/labkit/monitoring" "gitlab.com/gitlab-org/labkit/tracing" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing" "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/secret" @@ -224,9 +226,24 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) keyWatcher := redis.NewKeyWatcher() - redis.Configure(cfg.Redis) - if rdb := redis.GetRedisClient(); rdb != nil { - go keyWatcher.Process(rdb) + goredisKeyWatcher := goredis.NewKeyWatcher() + + var watchKeyFn builds.WatchKeyHandler + + if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { + goredis.Configure(cfg.Redis) + if rdb := goredis.GetRedisClient(); rdb != nil { + go goredisKeyWatcher.Process(rdb) + } + + watchKeyFn = goredisKeyWatcher.WatchKey + } else { + if cfg.Redis != nil { + redis.Configure(cfg.Redis, redis.DefaultDialFunc) + go keyWatcher.Process() + } + + watchKeyFn = keyWatcher.WatchKey } if err := 
cfg.RegisterGoCloudURLOpeners(); err != nil { @@ -241,7 +258,7 @@ func run(boot bootConfig, cfg config.Config) error { gitaly.InitializeSidechannelRegistry(accessLogger) - up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, keyWatcher.WatchKey)) + up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, watchKeyFn)) done := make(chan os.Signal, 1) signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) @@ -276,6 +293,7 @@ func run(boot bootConfig, cfg config.Config) error { defer cancel() keyWatcher.Shutdown() + goredisKeyWatcher.Shutdown() return srv.Shutdown(ctx) } } -- GitLab From 392f4f06023b47082f7c7fe5af24dc3cee0c7f91 Mon Sep 17 00:00:00 2001 From: Sylvester Chin Date: Thu, 27 Jul 2023 16:18:44 +0800 Subject: [PATCH 3/4] Add logs and run fmt --- workhorse/go.mod | 6 +++--- workhorse/go.sum | 2 -- workhorse/internal/goredis/goredis.go | 2 +- workhorse/internal/redis/keywatcher_test.go | 2 +- workhorse/main.go | 4 ++++ 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/workhorse/go.mod b/workhorse/go.mod index da79087170ef7d..ceb9603b58d3be 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 github.com/BurntSushi/toml v1.3.2 + github.com/FZambia/sentinel v1.1.1 github.com/alecthomas/chroma/v2 v2.8.0 github.com/aws/aws-sdk-go v1.44.284 github.com/disintegration/imaging v1.6.2 @@ -12,6 +13,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/golang/protobuf v1.5.3 + github.com/gomodule/redigo v2.0.0+incompatible github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 @@ -19,6 +21,7 @@ require ( github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.2.0 github.com/prometheus/client_golang v1.16.0 + github.com/rafaeljusto/redigomock/v3 v3.1.2 github.com/redis/go-redis/v9 v9.0.5 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/sirupsen/logrus v1.9.3 @@ -55,7 +58,6 @@ require ( github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 // indirect github.com/DataDog/datadog-go v4.4.0+incompatible // indirect github.com/DataDog/sketches-go v1.0.0 // indirect - github.com/FZambia/sentinel v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.0 // indirect github.com/beevik/ntp v1.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -70,7 +72,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/gomodule/redigo v2.0.0+incompatible // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20230406165453-00490a63f317 // indirect github.com/google/s2a-go v0.1.4 // indirect @@ -99,7 +100,6 @@ require ( github.com/prometheus/common v0.42.0 // indirect github.com/prometheus/procfs v0.10.1 // indirect github.com/prometheus/prometheus v0.44.0 // indirect - github.com/rafaeljusto/redigomock/v3 v3.1.2 // indirect github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46 // indirect github.com/shabbyrobe/gocovmerge v0.0.0-20190829150210-3e036491d500 // indirect github.com/shirou/gopsutil/v3 v3.21.12 // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index abb06459bc00e9..88d5171bacf2fe 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -1372,8 +1372,6 @@ github.com/golang/snappy v0.0.1/go.mod 
h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= -github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= -github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go index 403538e1bdd467..82f18ccef2e7d0 100644 --- a/workhorse/internal/goredis/goredis.go +++ b/workhorse/internal/goredis/goredis.go @@ -8,8 +8,8 @@ import ( redis "github.com/redis/go-redis/v9" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" - internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" ) var ( diff --git a/workhorse/internal/redis/keywatcher_test.go b/workhorse/internal/redis/keywatcher_test.go index 33daa07569fa0e..3abc1bf1107ddd 100644 --- a/workhorse/internal/redis/keywatcher_test.go +++ b/workhorse/internal/redis/keywatcher_test.go @@ -1,10 +1,10 @@ package redis import ( + "context" "sync" "testing" "time" - "context" "github.com/gomodule/redigo/redis" "github.com/rafaeljusto/redigomock/v3" diff --git a/workhorse/main.go b/workhorse/main.go index ab06c83eb5b3cb..b94961670db041 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -231,6 +231,8 @@ func run(boot bootConfig, cfg config.Config) error { var watchKeyFn builds.WatchKeyHandler if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { + log.Info("Using redis/go-redis") + goredis.Configure(cfg.Redis) if rdb := goredis.GetRedisClient(); rdb != nil { go goredisKeyWatcher.Process(rdb) @@ -238,6 +240,8 @@ func run(boot bootConfig, cfg config.Config) error { watchKeyFn = goredisKeyWatcher.WatchKey } else { + log.Info("Using gomodule/redigo") + if cfg.Redis != nil { redis.Configure(cfg.Redis, redis.DefaultDialFunc) go keyWatcher.Process() -- GitLab From 804522e39a44ca98bb886d02197a9fb4c4c63555 Mon Sep 17 00:00:00 2001 From: Sylvester Chin Date: Mon, 31 Jul 2023 20:48:45 +0800 Subject: [PATCH 4/4] Move new keywatcher initialization into envvar-scoped block --- workhorse/main.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/workhorse/main.go b/workhorse/main.go index b94961670db041..698fe1b5233ee8 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -226,13 +226,14 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) keyWatcher := redis.NewKeyWatcher() - goredisKeyWatcher := goredis.NewKeyWatcher() var watchKeyFn builds.WatchKeyHandler + var goredisKeyWatcher *goredis.KeyWatcher if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { log.Info("Using redis/go-redis") + goredisKeyWatcher = goredis.NewKeyWatcher() goredis.Configure(cfg.Redis) if rdb := goredis.GetRedisClient(); rdb != nil { go goredisKeyWatcher.Process(rdb) @@ -296,8 +297,11 @@ func run(boot bootConfig, cfg config.Config) error { ctx, cancel := 
context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background defer cancel() + if goredisKeyWatcher != nil { + goredisKeyWatcher.Shutdown() + } + keyWatcher.Shutdown() - goredisKeyWatcher.Shutdown() return srv.Shutdown(ctx) } } -- GitLab