diff --git a/workhorse/.golangci.yml b/workhorse/.golangci.yml index e9e793727453210df4bcc4fb1665d3bb452351a2..4cea46a3949668b9da776c1dd2f228a4a09369c2 100644 --- a/workhorse/.golangci.yml +++ b/workhorse/.golangci.yml @@ -155,6 +155,8 @@ linters-settings: - github.com/grpc-ecosystem/go-grpc-prometheus - github.com/mitchellh/copystructure - github.com/jpillora/backoff + - github.com/alicebob/miniredis/v2 + - github.com/sony/gobreaker/v2 dupl: # tokens count to trigger issue, 150 by default threshold: 100 diff --git a/workhorse/go.mod b/workhorse/go.mod index 4f134dddf7c023c593089c97affcc3b7df9ef8cb..f1683d3b85f77b2325a0e36757ccd7a212eb5134 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -9,6 +9,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.0 github.com/BurntSushi/toml v1.4.0 github.com/alecthomas/chroma/v2 v2.14.0 + github.com/alicebob/miniredis/v2 v2.34.0 github.com/aws/aws-sdk-go-v2 v1.32.3 github.com/aws/aws-sdk-go-v2/config v1.28.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.42 @@ -27,6 +28,7 @@ require ( github.com/redis/go-redis/v9 v9.7.3 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a github.com/sirupsen/logrus v1.9.3 + github.com/sony/gobreaker/v2 v2.1.0 github.com/stretchr/testify v1.10.0 gitlab.com/gitlab-org/gitaly/v16 v16.11.0-rc1.0.20250408053233-c6d43513e93c gitlab.com/gitlab-org/labkit v1.23.2 @@ -62,6 +64,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.48.1 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/aws/aws-sdk-go v1.55.5 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.6 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 // indirect @@ -93,6 +96,7 @@ require ( github.com/go-logr/logr v1.4.2 // indirect 
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/go-redsync/redsync/v4 v4.13.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect @@ -105,6 +109,8 @@ require ( github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.1 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/yamux v0.1.2-0.20220728231024-8f49b6f63f18 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.18.0 // indirect @@ -134,6 +140,7 @@ require ( github.com/tklauser/numcpus v0.3.0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/uber/jaeger-lib v2.4.1+incompatible // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect gitlab.com/gitlab-org/go/reopen v1.0.0 // indirect go.etcd.io/raft/v3 v3.6.0 // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index e68d114bca1fea1575846d13431b4d68483b8f75..78ee643283668f1dea18c5c3b19b55fbd965a951 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -120,6 +120,10 @@ github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46 github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= +github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= 
+github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0= +github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/aws/aws-sdk-go v1.44.256/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= @@ -274,6 +278,14 @@ github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-redis/redis v6.15.9+incompatible h1:K0pv1D7EQUjfyoMql+r/jZqCLizCGKFlFgcHWWmHQjg= +github.com/go-redis/redis v6.15.9+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v7 v7.4.1 h1:PASvf36gyUpr2zdOUS/9Zqc80GbM+9BDyiJSJDDOrTI= +github.com/go-redis/redis/v7 v7.4.1/go.mod h1:JDNMw23GTyLNC4GZu9njt15ctBQVn7xjRfnwdHj/Dcg= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-redsync/redsync/v4 v4.13.0 h1:49X6GJfnbLGaIpBBREM/zA4uIMDXKAh1NDkvQ1EkZKA= +github.com/go-redsync/redsync/v4 v4.13.0/go.mod h1:HMW4Q224GZQz6x1Xc7040Yfgacukdzu7ifTDAKiyErQ= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= @@ -322,6 +334,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= 
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws= +github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= @@ -398,6 +412,11 @@ github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.1/go.mod h1:qOchhhIlmRcqk/O github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4= @@ -527,6 +546,8 @@ github.com/prometheus/prometheus v0.54.0 h1:6+VmEkohHcofl3W5LyRlhw1Lfm575w/aX6ZF github.com/prometheus/prometheus v0.54.0/go.mod 
h1:xlLByHhk2g3ycakQGrMaU8K7OySZx98BzeCR99991NY= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= +github.com/redis/rueidis v1.0.19 h1:s65oWtotzlIFN8eMPhyYwxlwLR1lUdhza2KtWprKYSo= +github.com/redis/rueidis v1.0.19/go.mod h1:8B+r5wdnjwK3lTFml5VtxjzGOQAC+5UmujoD12pDrEo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -554,6 +575,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= +github.com/sony/gobreaker/v2 v2.1.0 h1:av2BnjtRmVPWBvy5gSFPytm1J8BmN5AGhq875FfGKDM= +github.com/sony/gobreaker/v2 v2.1.0/go.mod h1:dO3Q/nCzxZj6ICjH6J/gM0r4oAwBMVLY8YAQf+NTtUg= github.com/spf13/afero v0.0.0-20170901052352-ee1bd8ee15a1/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/afero v1.2.1/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cast v1.1.0/go.mod h1:r2rcYCSwa1IExKTDiTfzaxqT2FNHs8hODu4LnUfgKEg= @@ -577,6 +600,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM= +github.com/stvp/tempredis 
v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8= github.com/tinylib/msgp v1.1.2 h1:gWmO7n0Ys2RBEb7GPYB9Ujq8Mk5p2U08lRnmMcGy6BQ= github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= github.com/tklauser/go-sysconf v0.3.4/go.mod h1:Cl2c8ZRWfHD5IrfHo9VN+FX9kCFjIOyVklgXycLB6ek= @@ -597,6 +622,8 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPRg= github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= gitlab.com/gitlab-org/gitaly/v16 v16.11.0-rc1.0.20250408053233-c6d43513e93c h1:xwidECyV4uYBsKqKaAg2wwUrwpCwtfbbisQ3PwlZOoI= diff --git a/workhorse/internal/circuitbreaker/roundtripper.go b/workhorse/internal/circuitbreaker/roundtripper.go new file mode 100644 index 0000000000000000000000000000000000000000..f8354ac0226609aef3ef39a18deb791ff61a1a17 --- /dev/null +++ b/workhorse/internal/circuitbreaker/roundtripper.go @@ -0,0 +1,157 @@ +/* +Package circuitbreaker provides a custom HTTP wrapper roundTripper that implements a circuitbreaker. 
+*/
+package circuitbreaker
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"time"
+
+	redis "github.com/redis/go-redis/v9"
+	"github.com/sony/gobreaker/v2"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/log"
+)
+
+const (
+	Timeout             = 60 * time.Second  // Timeout is the duration to transition to half-open when open
+	Interval            = 180 * time.Second // Interval is the duration to clear consecutive failures (and other gobreaker.Counts) when closed
+	MaxRequests         = 1                 // MaxRequests is the maximum number of requests allowed through when half-open
+	ConsecutiveFailures = 5                 // ConsecutiveFailures is the number of consecutive failures that must be exceeded to open the circuit breaker when closed
+)
+
+// roundTripper sends requests through delegate, guarded by a distributed circuit breaker whose
+// state is kept in store.
+type roundTripper struct {
+	delegate http.RoundTripper
+	store    *gobreaker.RedisStore
+}
+
+// NewRoundTripper returns a new RoundTripper that wraps the provided RoundTripper with a circuit breaker.
+// When cfg is nil or its URL cannot be parsed, delegate is returned unwrapped.
+func NewRoundTripper(delegate http.RoundTripper, cfg *config.RedisConfig) http.RoundTripper {
+	if cfg == nil {
+		return delegate
+	}
+
+	opt, err := redis.ParseURL(cfg.URL.String())
+	if err != nil {
+		log.WithError(err).Info("gobreaker: failed to parse redis URL")
+		return delegate
+	}
+
+	// NOTE(review): only opt.Addr is passed to the store; any password, DB, TLS or network
+	// settings parsed from the URL are not applied. Confirm this is intended.
+	return &roundTripper{
+		delegate: delegate,
+		store:    gobreaker.NewRedisStore(opt.Addr),
+	}
+}
+
+// RoundTrip wraps the provided delegate RoundTripper with a circuit breaker.
+//
+// Requests without a circuit breaker key (see getRedisKey), or whose circuit breaker cannot be
+// created, go straight to the delegate. Upstream 429 responses count as failures; while the
+// breaker rejects requests, a synthesized 429 response is returned instead of calling the delegate.
+//
+// As required by the http.RoundTripper contract, a non-nil response is always returned with a
+// nil error, including upstream 429 responses.
+func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	cb, err := newCircuitBreaker(req, r.store)
+	if err != nil {
+		return r.delegate.RoundTrip(req)
+	}
+
+	delegateCalled := false
+	result, executeErr := cb.Execute(func() (any, error) {
+		delegateCalled = true
+
+		res, roundTripErr := r.delegate.RoundTrip(req)
+		if roundTripErr != nil {
+			return nil, roundTripErr
+		}
+
+		// Do not close res.Body here: the response is handed back to the caller, who owns it.
+		// A 429 is reported to the breaker as an error so it counts as a failure.
+		return res, responseToError(res)
+	})
+
+	if res, ok := result.(*http.Response); ok && res != nil {
+		return res, nil
+	}
+
+	if errors.Is(executeErr, gobreaker.ErrOpenState) || errors.Is(executeErr, gobreaker.ErrTooManyRequests) {
+		return tooManyRequestsResponse(req), nil
+	}
+
+	if executeErr != nil && !delegateCalled {
+		// The circuit breaker failed before the request was sent (e.g. its store is unavailable).
+		// Fail open, as is already done when the circuit breaker cannot be created.
+		log.WithError(executeErr).Info("gobreaker: failed to execute circuit breaker")
+		return r.delegate.RoundTrip(req)
+	}
+
+	return nil, executeErr
+}
+
+// tooManyRequestsResponse builds the response returned while the circuit breaker rejects requests.
+// Retry-After is expressed in delay-seconds, as RFC 9110 requires.
+func tooManyRequestsResponse(req *http.Request) *http.Response {
+	errorMsg := "This endpoint has been requested too many times. Try again later."
+	resp := &http.Response{
+		StatusCode:    http.StatusTooManyRequests,
+		Body:          io.NopCloser(bytes.NewBufferString(errorMsg)),
+		ContentLength: int64(len(errorMsg)),
+		Header:        make(http.Header),
+		Request:       req,
+	}
+
+	resp.Header.Set("Retry-After", strconv.Itoa(int(Timeout/time.Second)))
+
+	return resp
+}
+
+// newCircuitBreaker builds a distributed circuit breaker named after the key derived from req
+// (see getRedisKey), with its state kept in store. Only errors returned by the wrapped function
+// count as failures, and it trips once consecutive failures exceed ConsecutiveFailures.
+func newCircuitBreaker(req *http.Request, store *gobreaker.RedisStore) (*gobreaker.DistributedCircuitBreaker[any], error) {
+	key, err := getRedisKey(req)
+	if err != nil {
+		return nil, err
+	}
+
+	st := gobreaker.Settings{
+		Name:        key,
+		MaxRequests: MaxRequests,
+		Interval:    Interval,
+		Timeout:     Timeout,
+		OnStateChange: func(name string, from gobreaker.State, to gobreaker.State) {
+			log.WithFields(log.Fields{"name": name, "from": from.String(), "to": to.String()}).Info("gobreaker: state change")
+		},
+		ReadyToTrip: func(counts gobreaker.Counts) bool {
+			return counts.ConsecutiveFailures > ConsecutiveFailures
+		},
+		IsSuccessful: func(err error) bool {
+			return err == nil
+		},
+	}
+
+	return gobreaker.NewDistributedCircuitBreaker[any](store, st)
+}
+
+// getRedisKey derives the circuit breaker key from the "key_id" string in the request's JSON body.
+// The body is read in full and replaced with an in-memory copy so the delegate can still send it.
+func getRedisKey(req *http.Request) (string, error) {
+	if req.Body == nil {
+		return "", errors.New("gobreaker: missing request body")
+	}
+
+	bodyBytes, err := io.ReadAll(req.Body)
+	if err != nil {
+		log.WithError(err).Info("gobreaker: failed to read request body")
+		return "", err
+	}
+
+	// Close the original body now; a deferred closure over req.Body would run after the
+	// reassignment below and close the replacement instead, leaking the original.
+	_ = req.Body.Close()
+	req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
+
+	// Ssh key_id is present in the JSON body for git ssh requests, and uniquely identifies a user
+	var jsonBody map[string]any
+	if err := json.Unmarshal(bodyBytes, &jsonBody); err == nil {
+		if id, ok := jsonBody["key_id"].(string); ok && id != "" {
+			return "gobreaker:key_id:" + id, nil
+		}
+	}
+
+	return "", errors.New("gobreaker: key not found")
+}
+
+// responseToError returns an error for a 429 Too Many Requests response, so that IsSuccessful
+// counts it as a circuit breaker failure, and nil for any other status. A 429 body is read,
+// closed, and replaced with an in-memory copy so the caller can still read it.
+func responseToError(res *http.Response) error {
+	if res.StatusCode != http.StatusTooManyRequests {
+		return nil
+	}
+
+	if res.Body == nil {
+		return errors.New(http.StatusText(res.StatusCode))
+	}
+
+	body, err := io.ReadAll(res.Body)
+	// Close the original body before replacing it; a deferred closure over res.Body would
+	// close the replacement instead and leak the original.
+	_ = res.Body.Close()
+	// Keep whatever was read so the caller still receives a readable body.
+	res.Body = io.NopCloser(bytes.NewReader(body))
+	if err != nil {
+		return fmt.Errorf("failed to read response body: %w", err)
+	}
+
+	return errors.New(string(body))
+}
diff --git a/workhorse/internal/circuitbreaker/roundtripper_test.go b/workhorse/internal/circuitbreaker/roundtripper_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..1acd338bc3d82e5fdf746e654e03a5e1be60f1ba
--- /dev/null
+++ b/workhorse/internal/circuitbreaker/roundtripper_test.go
@@ -0,0 +1,345 @@
+package circuitbreaker
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+	"testing"
+
+	"github.com/alicebob/miniredis/v2"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
+)
+
+// mockRoundTripper implements http.RoundTripper for testing
+type mockRoundTripper struct {
+	response *http.Response
+	err      error
+}
+
+const (
+	delegateBody = "delegate body"
+)
+
+func (m *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
+	return m.response, m.err
+}
+
+// closeTrackingBody behaves like a network body: reads fail once it has been closed.
+type closeTrackingBody struct {
+	reader io.Reader
+	closed bool
+}
+
+func (b *closeTrackingBody) Read(p []byte) (int, error) {
+	if b.closed {
+		return 0, errors.New("read on closed body")
+	}
+	return b.reader.Read(p)
+}
+
+func (b *closeTrackingBody) Close() error {
+	b.closed = true
+	return nil
+}
+
+func TestRoundTripCircuitBreaker(t *testing.T) {
+	redisConfig, cleanup := setupRedisConfig(t)
+	defer cleanup()
+
+	testCases := []struct {
+		name       string
+		statusCode int
+		shouldTrip bool
+	}{
+		{"429 Too Many Requests", http.StatusTooManyRequests, true},
+		{"200 OK", http.StatusOK, false},
+		{"500 Internal Server Error", http.StatusInternalServerError, false},
+		{"403 Forbidden", http.StatusForbidden, false},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			delegateResponseHeader := http.Header{
+				tc.name: []string{tc.name},
+			}
+			mockRT := &mockRoundTripper{
+				response: &http.Response{
+					StatusCode: tc.statusCode,
+					Body:       io.NopCloser(bytes.NewBufferString(tc.name)),
+					Header:     delegateResponseHeader,
+				},
+			}
+			rt := NewRoundTripper(mockRT, redisConfig)
+
+			reqBody, err := json.Marshal(map[string]string{"key_id": "test-user-" + tc.name})
+			require.NoError(t, err)
+			req, err := http.NewRequest("POST", "http://example.com", bytes.NewBuffer(reqBody))
+			require.NoError(t, err)
+
+			// Make enough requests to trip the circuit breaker
+			for range ConsecutiveFailures + 1 {
+				resp, err := rt.RoundTrip(req)
+				// Upstream responses, including 429s, must be returned without an error.
+				require.NoError(t, err)
+
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				resp.Body.Close()
+				resp.Body = io.NopCloser(bytes.NewBuffer(body))
+
+				assert.Equal(t, tc.statusCode, resp.StatusCode)
+				assert.Equal(t, delegateResponseHeader, resp.Header)
+				assert.Equal(t, tc.name, string(body))
+				resp.Body.Close()
+			}
+
+			// Check if the circuit breaker tripped
+			resp, err := rt.RoundTrip(req)
+			require.NoError(t, err)
+
+			if tc.shouldTrip {
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				resp.Body.Close()
+				resp.Body = io.NopCloser(bytes.NewBuffer(body))
+
+				// Retry-After is delay-seconds (RFC 9110): "60" for a 60s Timeout.
+				circuitBreakerHeader := http.Header{
+					"Retry-After": []string{"60"},
+				}
+				assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+				assert.Equal(t, "This endpoint has been requested too many times. Try again later.", string(body))
+				assert.Equal(t, circuitBreakerHeader, resp.Header)
+			} else {
+				assert.Equal(t, tc.statusCode, resp.StatusCode)
+			}
+		})
+	}
+}
+
+func TestRoundTripKeepsResponseBodyReadable(t *testing.T) {
+	redisConfig, cleanup := setupRedisConfig(t)
+	defer cleanup()
+
+	testCases := []struct {
+		name       string
+		statusCode int
+	}{
+		{"200 OK", http.StatusOK},
+		{"429 Too Many Requests", http.StatusTooManyRequests},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			upstreamBody := &closeTrackingBody{reader: strings.NewReader(delegateBody)}
+			mockRT := &mockRoundTripper{
+				response: &http.Response{StatusCode: tc.statusCode, Body: upstreamBody},
+			}
+			rt := NewRoundTripper(mockRT, redisConfig)
+
+			reqBody, err := json.Marshal(map[string]string{"key_id": "readable-body-" + tc.name})
+			require.NoError(t, err)
+			req, err := http.NewRequest("POST", "http://example.com", bytes.NewBuffer(reqBody))
+			require.NoError(t, err)
+
+			resp, err := rt.RoundTrip(req)
+			require.NoError(t, err)
+			assert.Equal(t, tc.statusCode, resp.StatusCode)
+
+			// Only a 429 body is consumed, closed and replaced by the circuit breaker.
+			assert.Equal(t, tc.statusCode == http.StatusTooManyRequests, upstreamBody.closed)
+
+			body, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			assert.Equal(t, delegateBody, string(body))
+			require.NoError(t, resp.Body.Close())
+		})
+	}
+}
+
+func TestRedisConfigErrors(t *testing.T) {
+	mockRT := &mockRoundTripper{
+		response: &http.Response{
+			StatusCode: http.StatusOK,
+			Body:       io.NopCloser(bytes.NewBufferString(delegateBody)),
+		},
+	}
+
+	testCases := []struct {
+		name        string
+		redisConfig *config.RedisConfig
+	}{
+		{
+			name:        "Nil Redis config",
+			redisConfig: nil,
+		},
+		{
+			name: "Invalid Redis URL",
+			redisConfig: func() *config.RedisConfig {
+				invalidURL, _ := url.Parse("invalid://localhost:6379")
+				return &config.RedisConfig{
+					URL: config.TomlURL{URL: *invalidURL},
+				}
+			}(),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			rt := NewRoundTripper(mockRT, tc.redisConfig)
+
+			req, err := http.NewRequest("GET", "http://example.com", nil)
+			require.NoError(t, err)
+
+			resp, err := rt.RoundTrip(req)
+			require.NoError(t, err)
+
+			body, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			resp.Body.Close()
+			resp.Body = io.NopCloser(bytes.NewBuffer(body))
+
+			// Should use delegate directly in both cases
+			assert.Equal(t, http.StatusOK, resp.StatusCode)
+			assert.Equal(t, delegateBody, string(body))
+		})
+	}
+}
+
+func TestCircuitBreakerNilRedisKey(t *testing.T) {
+	redisConfig, cleanup := setupRedisConfig(t)
+	defer cleanup()
+
+	errorResp := delegateErrorResponse()
+	mockRT := &mockRoundTripper{response: errorResp}
+	errorResp.Body.Close()
+	rt := NewRoundTripper(mockRT, redisConfig)
+
+	reqBody, err := json.Marshal(map[string]string{"not_a_key_id": "test-value"})
+	require.NoError(t, err)
+
+	req, err := http.NewRequest("POST", "http://example.com", bytes.NewBuffer(reqBody))
+	require.NoError(t, err)
+
+	testCircuitBreakerResponse(t, rt, req, delegateBody)
+}
+
+func TestCircuitBreakerRedisKeyException(t *testing.T) {
+	redisConfig, cleanup := setupRedisConfig(t)
+	defer cleanup()
+
+	errorResp := delegateErrorResponse()
+	mockRT := &mockRoundTripper{response: errorResp}
+	errorResp.Body.Close()
+	rt := NewRoundTripper(mockRT, redisConfig)
+
+	req, err := http.NewRequest("POST", "http://example.com", &errorReader{})
+	require.NoError(t, err)
+
+	testCircuitBreakerResponse(t, rt, req, delegateBody)
+}
+
+func delegateErrorResponse() *http.Response {
+	return &http.Response{
+		StatusCode: http.StatusTooManyRequests,
+		Body:       io.NopCloser(bytes.NewBufferString(delegateBody)),
+	}
+}
+
+type errorReader struct{}
+
+func (e *errorReader) Read(_ []byte) (n int, err error) {
+	return 0, errors.New("simulated read error")
+}
+
+func testCircuitBreakerResponse(t *testing.T, rt http.RoundTripper, req *http.Request, expectedBody string) {
+	for range ConsecutiveFailures + 2 {
+		resp, err := rt.RoundTrip(req)
+		require.NoError(t, err)
+
+		body, err := io.ReadAll(resp.Body)
+		require.NoError(t, err)
+		resp.Body.Close()
+		resp.Body = io.NopCloser(bytes.NewBuffer(body))
+
+		assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
+		assert.Equal(t, expectedBody, string(body))
+	}
+}
+
+func TestGetRedisKey(t *testing.T) {
+	tests := []struct {
+		name     string
+		body     string
+		expected string
+	}{
+		{
+			name:     "with key_id",
+			body:     `{"key_id":"123456"}`,
+			expected: "gobreaker:key_id:123456",
+		},
+		{
+			name:     "without key_id",
+			body:     `{"something":"else"}`,
+			expected: "",
+		},
+		{
+			name:     "invalid json",
+			body:     `not json`,
+			expected: "",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			req, err := http.NewRequest("POST", "http://example.com", strings.NewReader(tc.body))
+			require.NoError(t, err)
+
+			key, _ := getRedisKey(req)
+			assert.Equal(t, tc.expected, key)
+
+			// Verify body can still be read
+			body, err := io.ReadAll(req.Body)
+			require.NoError(t, err)
+			assert.Equal(t, tc.body, string(body))
+		})
+	}
+}
+
+func TestGetRedisKeyClosesOriginalBody(t *testing.T) {
+	original := &closeTrackingBody{reader: strings.NewReader(`{"key_id":"123456"}`)}
+	req, err := http.NewRequest("POST", "http://example.com", original)
+	require.NoError(t, err)
+
+	key, err := getRedisKey(req)
+	require.NoError(t, err)
+	assert.Equal(t, "gobreaker:key_id:123456", key)
+	assert.True(t, original.closed, "original request body should be closed after being buffered")
+}
+
+// Create a miniredis instance
+func setupRedisConfig(t *testing.T) (*config.RedisConfig, func()) {
+	s, err := miniredis.Run()
+	require.NoError(t, err)
+
+	redisURL, err := url.Parse("redis://" + s.Addr())
+	require.NoError(t, err)
+	redisConfig := &config.RedisConfig{
+		URL: config.TomlURL{URL: *redisURL},
+	}
+
+	cleanup := func() {
+		s.Close()
+	}
+
+	return redisConfig, cleanup
+}