diff --git a/.gitlab/ci/workhorse.gitlab-ci.yml b/.gitlab/ci/workhorse.gitlab-ci.yml index 5b128ef6170679826fcf91ff683e80874d090198..cedcde27b7eb79ca17ca3b1b63424d9fe9101cb2 100644 --- a/.gitlab/ci/workhorse.gitlab-ci.yml +++ b/.gitlab/ci/workhorse.gitlab-ci.yml @@ -11,6 +11,8 @@ workhorse:verify: .workhorse:test: extends: .workhorse:rules:workhorse image: ${REGISTRY_HOST}/${REGISTRY_GROUP}/gitlab-build-images/debian-${DEBIAN_VERSION}-ruby-${RUBY_VERSION}-golang-${GO_VERSION}-rust-${RUST_VERSION}:rubygems-${RUBYGEMS_VERSION}-git-2.36-exiftool-12.60 + services: + - name: redis:${REDIS_VERSION}-alpine variables: GITALY_ADDRESS: "tcp://127.0.0.1:8075" stage: test @@ -22,6 +24,8 @@ workhorse:verify: - bundle_install_script - go version - scripts/gitaly-test-build + - cp workhorse/config.toml.example workhorse/config.toml + - sed -i 's|URL.*$|URL = "redis://redis:6379"|g' workhorse/config.toml script: - make -C workhorse test @@ -30,6 +34,7 @@ workhorse:test go: parallel: matrix: - GO_VERSION: ["1.18", "1.19", "1.20"] + REDIS_VERSION: ["7.0", "6.2"] script: - make -C workhorse test-coverage coverage: '/\d+.\d+%/' @@ -43,11 +48,15 @@ workhorse:test fips: parallel: matrix: - GO_VERSION: ["1.18", "1.19", "1.20"] + REDIS_VERSION: ["7.0", "6.2"] image: ${REGISTRY_HOST}/${REGISTRY_GROUP}/gitlab-build-images/ubi-${UBI_VERSION}-ruby-${RUBY_VERSION}-golang-${GO_VERSION}-rust-${RUST_VERSION}:rubygems-${RUBYGEMS_VERSION}-git-2.36-exiftool-12.60 variables: FIPS_MODE: 1 workhorse:test race: extends: .workhorse:test + parallel: + matrix: + - REDIS_VERSION: ["7.0", "6.2"] script: - make -C workhorse test-race diff --git a/workhorse/go.mod b/workhorse/go.mod index 17ae3ce12ec07cc018afeded3381db3572476323..966657152ecdce09cbd35c95df0678a74b56e9c1 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -22,6 +22,7 @@ require ( github.com/mitchellh/copystructure v1.2.0 github.com/prometheus/client_golang v1.16.0 github.com/rafaeljusto/redigomock/v3 v3.1.2 + github.com/redis/go-redis/v9 
v9.0.5 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/sirupsen/logrus v1.9.3 github.com/smartystreets/goconvey v1.7.2 @@ -65,6 +66,7 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/client9/reopen v1.0.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.4.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index f3ceee8b5e8fd5eb42e94263fe3b706fc2ce40da..ecdd58148879912bf3fc6fa3d40e148fb5284d89 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -869,6 +869,8 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/bshuster-repo/logrus-logstash-hook v0.4.1/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk= +github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= +github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= github.com/buger/jsonparser v0.0.0-20180808090653-f4dd9f5a6b44/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8= @@ -1068,6 +1070,8 @@ github.com/denverdino/aliyungo v0.0.0-20190125010748-a747050bb1ba/go.mod h1:dV8l github.com/devigned/tab v0.1.1/go.mod h1:XG9mPq0dFghrYvoBF3xdRrJzSTX1b7IQrvaL9mzjeJY= github.com/dgrijalva/jwt-go v0.0.0-20170104182250-a601269ab70c/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod 
h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dgryski/go-sip13 v0.0.0-20200911182023-62edffca9245/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/digitalocean/godo v1.78.0/go.mod h1:GBmu8MkjZmNARE7IXRPmkbbnocNN8+uBm0xbEVw2LCs= @@ -2047,6 +2051,8 @@ github.com/rafaeljusto/redigomock/v3 v3.1.2 h1:B4Y0XJQiPjpwYmkH55aratKX1VfR+JRqz github.com/rafaeljusto/redigomock/v3 v3.1.2/go.mod h1:F9zPqz8rMriScZkPtUiLJoLruYcpGo/XXREpeyasREM= github.com/rakyll/embedmd v0.0.0-20171029212350-c8060a0752a2/go.mod h1:7jOTMgqac46PZcF54q6l2hkLEG8op93fZu61KmxWDV4= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.0.5 h1:CuQcn5HIEeK7BgElubPP8CGtE0KakrnbBSTLjathl5o= +github.com/redis/go-redis/v9 v9.0.5/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= diff --git a/workhorse/internal/goredis/goredis.go b/workhorse/internal/goredis/goredis.go new file mode 100644 index 0000000000000000000000000000000000000000..cd25c7ca60ef3da014a8236a1d61ed92054b967b --- /dev/null +++ b/workhorse/internal/goredis/goredis.go @@ -0,0 +1,185 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + redis "github.com/redis/go-redis/v9" +
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + _ "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +var ( + rdb *redis.Client + // errSentinelMasterAddr mirrors the go-redis error for unreachable Sentinels: https://github.com/redis/go-redis/blob/c7399b6a17d7d3e2a57654528af91349f2468529/sentinel.go#L626 + errSentinelMasterAddr error = errors.New("redis: all sentinels specified in configuration are unreachable") +) + +const ( + // Max Idle Connections in the pool. + defaultMaxIdle = 1 + // Max Active Connections in the pool. + defaultMaxActive = 1 + // Timeout for Read operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultReadTimeout = 1 * time.Second + // Timeout for Write operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultWriteTimeout = 1 * time.Second + // Timeout before killing Idle connections in the pool. 3 minutes seemed good. + // If you _actually_ hit this timeout often, you should consider turning off + // redis-support since it's not necessary at that point... + defaultIdleTimeout = 3 * time.Minute +) + +// createDialer returns a dialer based on https://github.com/redis/go-redis/blob/b1103e3d436b6fe98813ecbbe1f99dc8d59b06c9/options.go#L214 +// that counts dial errors and new non-Sentinel connections in Prometheus metrics. +func createDialer(sentinels []string) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + var isSentinel bool + for _, sentinelAddr := range sentinels { + if sentinelAddr == addr { + isSentinel = true + break + } + } + + dialTimeout := 5 * time.Second // go-redis default + destination := "redis" + if isSentinel { + // This timeout is recommended for Sentinel-support according to the guidelines.
+ // https://redis.io/topics/sentinel-clients#redis-service-discovery-via-sentinel + // For every address it should try to connect to the Sentinel, + // using a short timeout (in the order of a few hundreds of milliseconds). + destination = "sentinel" + dialTimeout = 500 * time.Millisecond + } + + netDialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: 5 * time.Minute, + } + + conn, err := netDialer.DialContext(ctx, network, addr) + if err != nil { + internalredis.ErrorCounter.WithLabelValues("dial", destination).Inc() + } else if !isSentinel { + internalredis.TotalConnections.Inc() + } + + return conn, err + } +} + +// sentinelInstrumentationHook implements redis.Hook; it counts failures to resolve the Sentinel master address. +type sentinelInstrumentationHook struct{} + +func (s sentinelInstrumentationHook) DialHook(next redis.DialHook) redis.DialHook { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := next(ctx, network, addr) + if err != nil && err.Error() == errSentinelMasterAddr.Error() { + // check for non-dialer error + internalredis.ErrorCounter.WithLabelValues("master", "sentinel").Inc() + } + return conn, err + } +} + +func (s sentinelInstrumentationHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook { + return func(ctx context.Context, cmd redis.Cmder) error { + return next(ctx, cmd) + } +} + +func (s sentinelInstrumentationHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook { + return func(ctx context.Context, cmds []redis.Cmder) error { + return next(ctx, cmds) + } +} + +// GetRedisClient returns the client set up by Configure, or nil if there is none. +func GetRedisClient() *redis.Client { + return rdb +} + +// Configure creates the package-level Redis client from cfg, using Sentinel when cfg.Sentinel is set. A nil cfg is a no-op. 
+func Configure(cfg *config.RedisConfig) error { + if cfg == nil { + return nil + } + + var err error + if len(cfg.Sentinel) > 0 { + rdb = configureSentinel(cfg) + } else { + rdb, err = configureRedis(cfg) + } + + return err +} + +func configureRedis(cfg *config.RedisConfig) (*redis.Client, error) { + redisURL := cfg.URL.URL // copy, so rewriting the scheme does not mutate cfg + if redisURL.Scheme == "tcp" { + redisURL.Scheme = "redis"
+ } + + opt, err := redis.ParseURL(redisURL.String()) + if err != nil { + return nil, fmt.Errorf("goredis: parse redis URL: %w", err) + } + + opt.DB = getOrDefault(cfg.DB, 0) + opt.Password = cfg.Password + + opt.PoolSize = getOrDefault(cfg.MaxActive, defaultMaxActive) + opt.MaxIdleConns = getOrDefault(cfg.MaxIdle, defaultMaxIdle) + opt.ConnMaxIdleTime = defaultIdleTimeout + opt.ReadTimeout = defaultReadTimeout + opt.WriteTimeout = defaultWriteTimeout + + opt.Dialer = createDialer([]string{}) + + return redis.NewClient(opt), nil +} + +func configureSentinel(cfg *config.RedisConfig) *redis.Client { + sentinels := make([]string, len(cfg.Sentinel)) + for i := range cfg.Sentinel { + sentinelDetails := cfg.Sentinel[i] + sentinels[i] = net.JoinHostPort(sentinelDetails.Hostname(), sentinelDetails.Port()) // brackets IPv6 hosts, unlike "%s:%s" + } + + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.SentinelMaster, + SentinelAddrs: sentinels, + Password: cfg.Password, + DB: getOrDefault(cfg.DB, 0), + + PoolSize: getOrDefault(cfg.MaxActive, defaultMaxActive), + MaxIdleConns: getOrDefault(cfg.MaxIdle, defaultMaxIdle), + ConnMaxIdleTime: defaultIdleTimeout, + + ReadTimeout: defaultReadTimeout, + WriteTimeout: defaultWriteTimeout, + + Dialer: createDialer(sentinels), + }) + + client.AddHook(sentinelInstrumentationHook{}) + + return client +} + +// getOrDefault returns *ptr, or val when ptr is nil. +func getOrDefault(ptr *int, val int) int { + if ptr != nil { + return *ptr + } + return val +} diff --git a/workhorse/internal/goredis/goredis_test.go b/workhorse/internal/goredis/goredis_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fc1ecaa227d68b7a50aaf334e36671c2e52038fd --- /dev/null +++ b/workhorse/internal/goredis/goredis_test.go @@ -0,0 +1,107 @@ +package goredis + +import ( + "context" + "net" + "testing" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" +) + 
+func mockRedisServer(t *testing.T, connectReceived *bool) string { + // go-redis
does not deal with port 0 + ln, err := net.Listen("tcp", "127.0.0.1:6389") + + require.Nil(t, err) + + go func() { + defer ln.Close() + conn, err := ln.Accept() + require.Nil(t, err) + *connectReceived = true + conn.Write([]byte("OK\n")) + }() + + return ln.Addr().String() +} + +func TestConfigureNoConfig(t *testing.T) { + rdb = nil + Configure(nil) + require.Nil(t, rdb, "rdb client should be nil") +} + +func TestConfigureValidConfigX(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "redis", + }, + { + scheme: "tcp", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + parsedURL := helper.URLMustParse(tc.scheme + "://" + a) + cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} + + Configure(cfg) + + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived) + + rdb = nil + }) + } +} + +func TestConnectToSentinel(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "redis", + }, + { + scheme: "tcp", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + addrs := []string{tc.scheme + "://" + a} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + cfg := &config.RedisConfig{Sentinel: sentinelUrls} + Configure(cfg) + + require.NotNil(t, GetRedisClient().Conn(), "Pool should not be nil") + + // goredis initialise connections lazily + rdb.Ping(context.Background()) + require.True(t, connectReceived) + + rdb = nil + }) + } +} diff --git a/workhorse/internal/goredis/keywatcher.go b/workhorse/internal/goredis/keywatcher.go new file mode 100644 index
0000000000000000000000000000000000000000..741bfb17652394fb738e2d953abf0d8b70502c36 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher.go @@ -0,0 +1,252 @@ +package goredis + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/jpillora/backoff" + "github.com/redis/go-redis/v9" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" + internalredis "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +// KeyWatcher lets callers block until a Redis key changes, using pubsub + +// notifications published on channelPrefix+key. +type KeyWatcher struct { + mu sync.Mutex + subscribers map[string][]chan string + shutdown chan struct{} + reconnectBackoff backoff.Backoff + redisConn *redis.Client + conn *redis.PubSub +} + +// NewKeyWatcher returns a KeyWatcher whose reconnect backoff grows from 100ms to 60s with jitter. +func NewKeyWatcher() *KeyWatcher { + return &KeyWatcher{ + shutdown: make(chan struct{}), + reconnectBackoff: backoff.Backoff{ + Min: 100 * time.Millisecond, + Max: 60 * time.Second, + Factor: 2, + Jitter: true, + }, + } +} + +const channelPrefix = "workhorse:notifications:" + +func countAction(action string) { internalredis.TotalActions.WithLabelValues(action).Add(1) } + +// receivePubSubStream shares pubsub as kw.conn and dispatches its messages to + +// subscribers until a receive fails. On exit it closes the connection and all + +// subscriber channels. It currently always returns nil. +func (kw *KeyWatcher) receivePubSubStream(ctx context.Context, pubsub *redis.PubSub) error { + kw.mu.Lock() + // We must share kw.conn with the goroutines that call SUBSCRIBE and + // UNSUBSCRIBE because Redis pubsub subscriptions are tied to the + // connection. + kw.conn = pubsub + kw.mu.Unlock() + + defer func() { + kw.mu.Lock() + defer kw.mu.Unlock() + kw.conn.Close() + kw.conn = nil + + // Reset kw.subscribers because it is tied to Redis server side state of + // kw.conn and we just closed that connection. 
+ for _, chans := range kw.subscribers { + for _, ch := range chans { + close(ch) + internalredis.KeyWatchers.Dec() + } + } + kw.subscribers = nil + }() + + for { + // Read from the pubsub argument directly: kw.conn is shared state guarded by kw.mu. + msg, err := pubsub.Receive(ctx) + if err != nil { + log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", err)).Error() + return nil + } + + switch msg := msg.(type) { + case *redis.Subscription: + internalredis.RedisSubscriptions.Set(float64(msg.Count)) + case *redis.Pong: + // Ignore.
+ case *redis.Message: + internalredis.TotalMessages.Inc() + internalredis.ReceivedBytes.Add(float64(len(msg.Payload))) + if strings.HasPrefix(msg.Channel, channelPrefix) { + kw.notifySubscribers(msg.Channel[len(channelPrefix):], msg.Payload) + } + default: + log.WithError(fmt.Errorf("keywatcher: unknown: %T", msg)).Error() + return nil + } + } +} + +// Process keeps a pubsub connection to client open and dispatches its + +// messages, reconnecting with backoff. It never returns. +func (kw *KeyWatcher) Process(client *redis.Client) { + log.Info("keywatcher: starting process loop") + + ctx := context.Background() // lint:allow context.Background + kw.mu.Lock() + kw.redisConn = client + kw.mu.Unlock() + + for { + pubsub := client.Subscribe(ctx, []string{}...) + if err := pubsub.Ping(ctx); err != nil { + log.WithError(fmt.Errorf("keywatcher: %v", err)).Error() + _ = pubsub.Close() // release the failed connection before retrying + time.Sleep(kw.reconnectBackoff.Duration()) + continue + } + + kw.reconnectBackoff.Reset() + + if err := kw.receivePubSubStream(ctx, pubsub); err != nil { + log.WithError(fmt.Errorf("keywatcher: receivePubSubStream: %v", err)).Error() + } + } +} + +// Shutdown makes waiting WatchKey calls stop and return + +// WatchKeyStatusNoChange. It is safe to call more than once. +func (kw *KeyWatcher) Shutdown() { + log.Info("keywatcher: shutting down") + + kw.mu.Lock() + defer kw.mu.Unlock() + + select { + case <-kw.shutdown: + // already closed + default: + close(kw.shutdown) + } +} + +func (kw *KeyWatcher) notifySubscribers(key, value string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chanList, ok := kw.subscribers[key] + if !ok { + countAction("drop-message") + return + } + + countAction("deliver-message") + for _, c := range chanList { + select { + case c <- value: + default: + } + } +} + +func (kw *KeyWatcher) addSubscription(ctx context.Context, key string, notify chan string) error { + kw.mu.Lock() + defer kw.mu.Unlock() + + if kw.conn == nil { + // This can happen because CI long polling is disabled in 
this Workhorse + // process. It can also be that we are waiting for the pubsub connection + // to be established. Either way it is OK to fail fast.
+ return errors.New("no redis connection") + } + + if len(kw.subscribers[key]) == 0 { + countAction("create-subscription") + if err := kw.conn.Subscribe(ctx, channelPrefix+key); err != nil { + return err + } + } + + if kw.subscribers == nil { + kw.subscribers = make(map[string][]chan string) + } + kw.subscribers[key] = append(kw.subscribers[key], notify) + internalredis.KeyWatchers.Inc() + + return nil +} + +func (kw *KeyWatcher) delSubscription(ctx context.Context, key string, notify chan string) { + kw.mu.Lock() + defer kw.mu.Unlock() + + chans, ok := kw.subscribers[key] + if !ok { + // This can happen if the pubsub connection dropped while we were + // waiting. + return + } + + for i, c := range chans { + if notify == c { + kw.subscribers[key] = append(chans[:i], chans[i+1:]...) + internalredis.KeyWatchers.Dec() + break + } + } + if len(kw.subscribers[key]) == 0 { + delete(kw.subscribers, key) + countAction("delete-subscription") + if kw.conn != nil { + kw.conn.Unsubscribe(ctx, channelPrefix+key) + } + } +} + +// WatchKey reports whether key changes from value. It returns + +// WatchKeyStatusAlreadyChanged if the current value already differs, otherwise it + +// waits up to timeout for a pubsub notification about key. +func (kw *KeyWatcher) WatchKey(ctx context.Context, key, value string, timeout time.Duration) (internalredis.WatchKeyStatus, error) { + notify := make(chan string, 1) + if err := kw.addSubscription(ctx, key, notify); err != nil { + return internalredis.WatchKeyStatusNoChange, err + } + defer kw.delSubscription(ctx, key, notify) + + currentValue, err := kw.redisConn.Get(ctx, key).Result() + if errors.Is(err, redis.Nil) { + currentValue = "" + } else if err != nil { + return internalredis.WatchKeyStatusNoChange, fmt.Errorf("keywatcher: redis GET: %w", err) + } + if currentValue != value { + return internalredis.WatchKeyStatusAlreadyChanged, nil + } + + select { + case <-kw.shutdown: + log.WithFields(log.Fields{"key": key}).Info("stopping watch due to shutdown") + return 
internalredis.WatchKeyStatusNoChange, nil + case currentValue := <-notify: + // "" means notify was closed because the pubsub connection dropped, or an empty value was published. + if currentValue == "" { + return internalredis.WatchKeyStatusNoChange, errors.New("keywatcher: redis GET failed")
+ } + if currentValue == value { + return internalredis.WatchKeyStatusNoChange, nil + } + return internalredis.WatchKeyStatusSeenChange, nil + case <-time.After(timeout): + return internalredis.WatchKeyStatusTimeout, nil + } +} diff --git a/workhorse/internal/goredis/keywatcher_test.go b/workhorse/internal/goredis/keywatcher_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b64262dc9c814c4165dfd84472f513d4b652e2b7 --- /dev/null +++ b/workhorse/internal/goredis/keywatcher_test.go @@ -0,0 +1,307 @@ +package goredis + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" +) + +var ctx = context.Background() + +const ( + runnerKey = "runner:build_queue:10" +) + +// initRdb configures rdb from ../../config.toml and fails the test if that is not possible. +func initRdb(t *testing.T) { + t.Helper() + buf, err := os.ReadFile("../../config.toml") + require.NoError(t, err) + cfg, err := config.LoadConfig(string(buf)) + require.NoError(t, err) + require.NoError(t, Configure(cfg.Redis)) + require.NotNil(t, rdb, "redis client should be configured") +} + +func (kw *KeyWatcher) countSubscribers(key string) int { + kw.mu.Lock() + defer kw.mu.Unlock() + return len(kw.subscribers[key]) +} + +// processMessages runs receivePubSubStream against the Redis client in rdb, publishes value once numWatchers watchers have subscribed, and closes the subscription after wg is done. 
+func (kw *KeyWatcher) processMessages(t *testing.T, numWatchers int, value string, ready chan<- struct{}, wg *sync.WaitGroup) { + kw.mu.Lock() + kw.redisConn = rdb + psc := kw.redisConn.Subscribe(ctx, []string{}...)
+ kw.mu.Unlock() + + errC := make(chan error) + go func() { errC <- kw.receivePubSubStream(ctx, psc) }() + + require.Eventually(t, func() bool { + kw.mu.Lock() + defer kw.mu.Unlock() + return kw.conn != nil + }, time.Second, time.Millisecond) + close(ready) + + require.Eventually(t, func() bool { + return kw.countSubscribers(runnerKey) == numWatchers + }, time.Second, time.Millisecond) + + // send message after listeners are ready + kw.redisConn.Publish(ctx, channelPrefix+runnerKey, value) + + // close subscription after all workers are done + wg.Wait() + kw.mu.Lock() + kw.conn.Close() + kw.mu.Unlock() + + require.NoError(t, <-errC) +} + +type keyChangeTestCase struct { + desc string + returnValue string + isKeyMissing bool + watchValue string + processedValue string + expectedStatus redis.WatchKeyStatus + timeout time.Duration +} + +func TestKeyChangesInstantReturn(t *testing.T) { + initRdb(t) + + testCases := []keyChangeTestCase{ + // WatchKeyStatusAlreadyChanged + { + desc: "sees change with key existing and changed", + returnValue: "somethingelse", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + { + desc: "sees change with key non-existing", + isKeyMissing: true, + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusAlreadyChanged, + timeout: time.Second, + }, + // WatchKeyStatusTimeout + { + desc: "sees timeout with key existing and unchanged", + returnValue: "something", + watchValue: "something", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + { + desc: "sees timeout with key non-existing and unchanged", + isKeyMissing: true, + watchValue: "", + expectedStatus: redis.WatchKeyStatusTimeout, + timeout: time.Millisecond, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + + // setup + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { +
rdb.FlushDB(ctx) + }() + + kw := NewKeyWatcher() + defer kw.Shutdown() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...) + + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, tc.timeout) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }) + } +} + +func TestKeyChangesWhenWatching(t *testing.T) { + initRdb(t) + + testCases := []keyChangeTestCase{ + // WatchKeyStatusSeenChange + { + desc: "sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + { + desc: "sees change with key non-existing, when watching empty value", + isKeyMissing: true, + watchValue: "", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + // WatchKeyStatusNoChange + { + desc: "sees no change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "something", + expectedStatus: redis.WatchKeyStatusNoChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + kw := NewKeyWatcher() + defer kw.Shutdown() + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(1) + ready := make(chan struct{}) + + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + + kw.processMessages(t, 1, tc.processedValue, ready, wg) + }) + } +} + +func TestKeyChangesParallel(t *testing.T) { + initRdb(t) + + testCases := []keyChangeTestCase{ + { + desc: "massively parallel, sees change with key existing", + returnValue: "something", + watchValue: "something", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, +
{ + desc: "massively parallel, sees change with key existing, watching missing keys", + isKeyMissing: true, + watchValue: "", + processedValue: "somethingelse", + expectedStatus: redis.WatchKeyStatusSeenChange, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + runTimes := 100 + + if !tc.isKeyMissing { + rdb.Set(ctx, runnerKey, tc.returnValue, 0) + } + + defer func() { + rdb.FlushDB(ctx) + }() + + wg := &sync.WaitGroup{} + wg.Add(runTimes) + ready := make(chan struct{}) + + kw := NewKeyWatcher() + defer kw.Shutdown() + + for i := 0; i < runTimes; i++ { + go func() { + defer wg.Done() + <-ready + val, err := kw.WatchKey(ctx, runnerKey, tc.watchValue, time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, tc.expectedStatus, val, "Expected value") + }() + } + + kw.processMessages(t, runTimes, tc.processedValue, ready, wg) + }) + } +} + +func TestShutdown(t *testing.T) { + initRdb(t) + + kw := NewKeyWatcher() + kw.redisConn = rdb + kw.conn = kw.redisConn.Subscribe(ctx, []string{}...)
+ defer kw.Shutdown() + + rdb.Set(ctx, runnerKey, "something", 0) + defer rdb.FlushDB(ctx) + + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + val, err := kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + }() + + go func() { + defer wg.Done() + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 1 }, 10*time.Second, time.Millisecond) + + kw.Shutdown() + }() + + wg.Wait() + + require.Eventually(t, func() bool { return kw.countSubscribers(runnerKey) == 0 }, 10*time.Second, time.Millisecond) + + // Adding a key after the shutdown should result in an immediate response + var val redis.WatchKeyStatus + var err error + done := make(chan struct{}) + go func() { + val, err = kw.WatchKey(ctx, runnerKey, "something", 10*time.Second) + close(done) + }() + + select { + case <-done: + require.NoError(t, err, "Expected no error") + require.Equal(t, redis.WatchKeyStatusNoChange, val, "Expected value not to change") + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for WatchKey") + } +} diff --git a/workhorse/internal/redis/keywatcher.go b/workhorse/internal/redis/keywatcher.go index 2fd0753c3c95217693f3719c15dd8d22a09f49c2..8f1772a91958c404ed83c993aa45052a2a9ac626 100644 --- a/workhorse/internal/redis/keywatcher.go +++ b/workhorse/internal/redis/keywatcher.go @@ -37,32 +37,32 @@ func NewKeyWatcher() *KeyWatcher { } var ( - keyWatchers = promauto.NewGauge( + KeyWatchers = promauto.NewGauge( prometheus.GaugeOpts{ Name: "gitlab_workhorse_keywatcher_keywatchers", Help: "The number of keys that is being watched by gitlab-workhorse", }, ) - redisSubscriptions = promauto.NewGauge( + RedisSubscriptions = promauto.NewGauge( prometheus.GaugeOpts{ Name: "gitlab_workhorse_keywatcher_redis_subscriptions", Help: "Current number of keywatcher Redis pubsub subscriptions", }, ) - 
totalMessages =
promauto.NewCounter( + TotalMessages = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_total_messages", Help: "How many messages gitlab-workhorse has received in total on pubsub.", }, ) - totalActions = promauto.NewCounterVec( + TotalActions = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_actions_total", Help: "Counts of various keywatcher actions", }, []string{"action"}, ) - receivedBytes = promauto.NewCounter( + ReceivedBytes = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywatcher_received_bytes_total", Help: "How many bytes of messages gitlab-workhorse has received in total on pubsub.", @@ -72,7 +72,7 @@ var ( const channelPrefix = "workhorse:notifications:" -func countAction(action string) { totalActions.WithLabelValues(action).Add(1) } +func countAction(action string) { TotalActions.WithLabelValues(action).Add(1) } func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { kw.mu.Lock() @@ -93,7 +93,7 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { for _, chans := range kw.subscribers { for _, ch := range chans { close(ch) - keyWatchers.Dec() + KeyWatchers.Dec() } } kw.subscribers = nil @@ -102,13 +102,13 @@ func (kw *KeyWatcher) receivePubSubStream(conn redis.Conn) error { for { switch v := kw.conn.Receive().(type) { case redis.Message: - totalMessages.Inc() - receivedBytes.Add(float64(len(v.Data))) + TotalMessages.Inc() + ReceivedBytes.Add(float64(len(v.Data))) if strings.HasPrefix(v.Channel, channelPrefix) { kw.notifySubscribers(v.Channel[len(channelPrefix):], string(v.Data)) } case redis.Subscription: - redisSubscriptions.Set(float64(v.Count)) + RedisSubscriptions.Set(float64(v.Count)) case error: log.WithError(fmt.Errorf("keywatcher: pubsub receive: %v", v)).Error() // Intermittent error, return nil so that it doesn't wait before reconnect @@ -205,7 +205,7 @@ func (kw *KeyWatcher) addSubscription(key string, notify chan string) error 
{ kw.subscribers = make(map[string][]chan string) } kw.subscribers[key] = append(kw.subscribers[key], notify) - keyWatchers.Inc() + KeyWatchers.Inc() return nil } @@ -224,7 +224,7 @@ func (kw *KeyWatcher) delSubscription(key string, notify chan string) { for i, c := range chans { if notify == c { kw.subscribers[key] = append(chans[:i], chans[i+1:]...) - keyWatchers.Dec() + KeyWatchers.Dec() break } } diff --git a/workhorse/internal/redis/redis.go b/workhorse/internal/redis/redis.go index 03118cfcef65f4259d38fe09c1a151844c4986e2..c79e1e56b3ad169d2da9b00ebd713e9a53d2c12d 100644 --- a/workhorse/internal/redis/redis.go +++ b/workhorse/internal/redis/redis.go @@ -45,14 +45,14 @@ const ( ) var ( - totalConnections = promauto.NewCounter( + TotalConnections = promauto.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_redis_total_connections", Help: "How many connections gitlab-workhorse has opened in total. Can be used to track Redis connection rate for this process", }, ) - errorCounter = promauto.NewCounterVec( + ErrorCounter = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "gitlab_workhorse_redis_errors", Help: "Counts different types of Redis errors encountered by workhorse, by type and destination (redis, sentinel)", @@ -100,7 +100,7 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { } if err != nil { - errorCounter.WithLabelValues("dial", "sentinel").Inc() + ErrorCounter.WithLabelValues("dial", "sentinel").Inc() return nil, err } return c, nil @@ -159,7 +159,7 @@ func sentinelDialer(dopts []redis.DialOption) redisDialerFunc { return func() (redis.Conn, error) { address, err := sntnl.MasterAddr() if err != nil { - errorCounter.WithLabelValues("master", "sentinel").Inc() + ErrorCounter.WithLabelValues("master", "sentinel").Inc() return nil, err } dopts = append(dopts, redis.DialNetDial(keepAliveDialer)) @@ -214,9 +214,9 @@ func countDialer(dialer redisDialerFunc) redisDialerFunc { return func() (redis.Conn, error) { c, err 
:= dialer() if err != nil { - errorCounter.WithLabelValues("dial", "redis").Inc() + ErrorCounter.WithLabelValues("dial", "redis").Inc() } else { - totalConnections.Inc() + TotalConnections.Inc() } return c, err } diff --git a/workhorse/main.go b/workhorse/main.go index ca9b86de528a73f78d3c9e0808634d628cb1c04b..9ba213d47d355cf61b68a17fb4ff65ecb6beb6fc 100644 --- a/workhorse/main.go +++ b/workhorse/main.go @@ -17,8 +17,10 @@ import ( "gitlab.com/gitlab-org/labkit/monitoring" "gitlab.com/gitlab-org/labkit/tracing" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/builds" "gitlab.com/gitlab-org/gitlab/workhorse/internal/config" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" + "gitlab.com/gitlab-org/gitlab/workhorse/internal/goredis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing" "gitlab.com/gitlab-org/gitlab/workhorse/internal/redis" "gitlab.com/gitlab-org/gitlab/workhorse/internal/secret" @@ -224,9 +226,32 @@ func run(boot bootConfig, cfg config.Config) error { secret.SetPath(boot.secretPath) keyWatcher := redis.NewKeyWatcher() - if cfg.Redis != nil { - redis.Configure(cfg.Redis, redis.DefaultDialFunc) - go keyWatcher.Process() + + var watchKeyFn builds.WatchKeyHandler + var goredisKeyWatcher *goredis.KeyWatcher + + if os.Getenv("GITLAB_WORKHORSE_FF_GO_REDIS_ENABLED") == "true" { + log.Info("Using redis/go-redis") + + goredisKeyWatcher = goredis.NewKeyWatcher() + if err := goredis.Configure(cfg.Redis); err != nil { + log.WithError(err).Error("unable to configure redis client") + } + + if rdb := goredis.GetRedisClient(); rdb != nil { + go goredisKeyWatcher.Process(rdb) + } + + watchKeyFn = goredisKeyWatcher.WatchKey + } else { + log.Info("Using gomodule/redigo") + + if cfg.Redis != nil { + redis.Configure(cfg.Redis, redis.DefaultDialFunc) + go keyWatcher.Process() + } + + watchKeyFn = keyWatcher.WatchKey } if err := cfg.RegisterGoCloudURLOpeners(); err != nil { @@ -241,7 +266,7 @@ func run(boot bootConfig, cfg config.Config) error { 
gitaly.InitializeSidechannelRegistry(accessLogger) - up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, keyWatcher.WatchKey)) + up := wrapRaven(upstream.NewUpstream(cfg, accessLogger, watchKeyFn)) done := make(chan os.Signal, 1) signal.Notify(done, syscall.SIGINT, syscall.SIGTERM) @@ -275,6 +300,10 @@ func run(boot bootConfig, cfg config.Config) error { ctx, cancel := context.WithTimeout(context.Background(), cfg.ShutdownTimeout.Duration) // lint:allow context.Background defer cancel() + if goredisKeyWatcher != nil { + goredisKeyWatcher.Shutdown() + } + keyWatcher.Shutdown() return srv.Shutdown(ctx) }