diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go index 3c28d64f06cef04c4123191c3b7f53f8020598b4..be9dfa92d8a69fb2ef385aa9ec870df30893f106 100644 --- a/internal/gitaly/server/server.go +++ b/internal/gitaly/server/server.go @@ -14,6 +14,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/grpcstats" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/listenmux" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/cache" + "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/clientcontexthandler" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/customfieldshandler" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/featureflag" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/loghandler" @@ -105,6 +106,7 @@ func (s *GitalyServerFactory) New(external, secure bool, opts ...Option) (*grpc. streamServerInterceptors := []grpc.StreamServerInterceptor{ grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler + clientcontexthandler.StreamInterceptor, // Must be below the correlation interceptor, whose ID it copies into the client context requestinfohandler.StreamInterceptor, grpcprometheus.StreamServerInterceptor, customfieldshandler.StreamInterceptor, @@ -119,6 +121,7 @@ func (s *GitalyServerFactory) New(external, secure bool, opts ...Option) (*grpc.
} unaryServerInterceptors := []grpc.UnaryServerInterceptor{ grpccorrelation.UnaryServerCorrelationInterceptor(), // Must be above the metadata handler + clientcontexthandler.UnaryInterceptor, // Must be below the correlation interceptor, whose ID it copies into the client context requestinfohandler.UnaryInterceptor, grpcprometheus.UnaryServerInterceptor, customfieldshandler.UnaryInterceptor, diff --git a/internal/gitaly/service/smarthttp/receive_pack_test.go b/internal/gitaly/service/smarthttp/receive_pack_test.go index 0a7824394fd7ca87a840dbbafdbbe8206b19a193..d37ec6218fa26064ef8e9330c20b00972286cac9 100644 --- a/internal/gitaly/service/smarthttp/receive_pack_test.go +++ b/internal/gitaly/service/smarthttp/receive_pack_test.go @@ -3,9 +3,12 @@ package smarthttp import ( "bytes" "context" + "encoding/base64" + "encoding/json" "errors" "fmt" "io" + "net/http" "path/filepath" "strings" "testing" @@ -18,6 +21,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v16/internal/git/gittest" "gitlab.com/gitlab-org/gitaly/v16/internal/git/pktline" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config" + "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config/prometheus" gitalyhook "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/hook" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage/storagemgr" @@ -34,6 +38,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" "gitlab.com/gitlab-org/gitaly/v16/streamio" "google.golang.org/grpc" + grpcmetadata "google.golang.org/grpc/metadata" ) func TestPostReceivePack_successful(t *testing.T) { @@ -121,6 +126,9 @@ func TestPostReceivePack_successful(t *testing.T) { } client := newSmartHTTPClient(t, server.Address(), cfg.Auth.Token) + + correlationID := "correlation123" + ctx = grpcmetadata.AppendToOutgoingContext(ctx, "X-GitLab-Correlation-ID", correlationID) // Expected to be adopted as the request's correlation ID by the server's correlation interceptor stream, err := client.PostReceivePack(ctx) require.NoError(t, err) @@ -195,6 +203,14 @@ func TestPostReceivePack_successful(t *testing.T) { transactionID = 3 } + expectedGitalyClientContext :=
map[string]string{ + "correlation_id": correlationID, + } + + marshalled, err := json.Marshal(expectedGitalyClientContext) + require.NoError(t, err) + expectedGitalyClientContextEncoded := base64.StdEncoding.EncodeToString(marshalled) + require.Equal(t, gitcmd.HooksPayload{ ObjectFormat: gittest.DefaultObjectHash.Format, RuntimeDir: cfg.RuntimeDir, @@ -205,8 +221,9 @@ func TestPostReceivePack_successful(t *testing.T) { Username: "user", Protocol: "http", }, - RequestedHooks: expectedHooks, - TransactionID: transactionID, + RequestedHooks: expectedHooks, + TransactionID: transactionID, + GitalyClientContext: []byte(expectedGitalyClientContextEncoded), }, payload) require.Equal(t, 1, preReceiveCount) @@ -941,6 +958,109 @@ func TestPostReceivePack_notAllowed(t *testing.T) { require.Equal(t, 1, refTransactionServer.called) } +func handleAllowed(tb testing.TB, options gitlab.TestServerOptions) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + // This handler runs on the mock server's goroutine, so failures are + // reported with tb.Errorf: require calls FailNow, which must only be + // called from the goroutine running the test. + defer r.Body.Close() + + params := struct { + ClientContext []byte `json:"gitaly_client_context_bin"` + }{} + if err := json.NewDecoder(r.Body).Decode(&params); err != nil { + tb.Errorf("decoding /allowed request body: %v", err) + http.Error(w, "couldn't decode request body", http.StatusBadRequest) + return + } + + gitalyClientContextDecoded, err := base64.StdEncoding.DecodeString(string(params.ClientContext)) + if err != nil { + tb.Errorf("decoding gitaly client context: %v", err) + http.Error(w, "couldn't decode gitaly client context", http.StatusBadRequest) + return + } + + var clientCtx map[string]string + if err := json.Unmarshal(gitalyClientContextDecoded, &clientCtx); err != nil { + tb.Errorf("unmarshalling gitaly client context: %v", err) + http.Error(w, "couldn't unmarshal gitaly client context", http.StatusBadRequest) + return + } + + if _, ok := clientCtx["correlation_id"]; !ok { + tb.Errorf("correlation id missing from client context %v", clientCtx) + http.Error(w, "correlation id missing from client context", http.StatusBadRequest) + return + } + + // Reply the way an accepting /allowed endpoint does so the push proceeds. + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{"status":true}`)); err != nil { + tb.Errorf("writing /allowed response: %v", err) + } + } +} + +func TestPostReceivePack_clientContext(t *testing.T) { + t.Parallel() + + const ( + secretToken = "secret token" + glRepository = "some_repo" + glID = "key-123" + ) + + ctx := testhelper.Context(t) + cfg := testcfg.Build(t) + + var cleanup func() + cfg.Gitlab.URL, cleanup = gitlab.NewTestServer( + t, + gitlab.TestServerOptions{ + HandleAllowed: handleAllowed, + }, + ) + defer cleanup() + + gitlabClient, err :=
gitlab.NewHTTPClient( + testhelper.NewLogger(t), + cfg.Gitlab, + config.TLS{}, + prometheus.Config{}, + ) + + require.NoError(t, err) + gitalyServer := startSmartHTTPServerWithOptions(t, cfg, nil, []testserver.GitalyServerOpt{testserver.WithGitLabClient( + gitlabClient, + )}) + + cfg.SocketPath = gitalyServer.Address() + cfg.GitlabShell.Dir = testhelper.TempDir(t) + cfg.Auth.Token = "abc123" + cfg.Gitlab.SecretFile = gitlab.WriteShellSecretFile(t, cfg.GitlabShell.Dir, secretToken) + + testcfg.BuildGitalyHooks(t, cfg) + + repo, repoPath := gittest.CreateRepository(t, ctx, cfg) + gittest.WriteCommit(t, cfg, repoPath, gittest.WithBranch(git.DefaultBranch)) + push := setupSimplePush(t, ctx, cfg, repoPath, git.DefaultRef) + + gittest.WriteCheckNewObjectExistsHook(t, repoPath) + + client := newSmartHTTPClient(t, cfg.SocketPath, cfg.Auth.Token) + + stream, err := client.PostReceivePack(ctx) + require.NoError(t, err) + + push.perform(t, stream, &gitalypb.PostReceivePackRequest{ // The client-context assertions run in handleAllowed, registered above as the mock GitLab server's /allowed handler. NOTE(review): nothing here asserts that /allowed was actually called; consider counting invocations. + Repository: repo, + GlId: glID, + GlRepository: glRepository, + }) +} + type refUpdate struct { ref git.ReferenceName from, to git.ObjectID diff --git a/internal/gitaly/service/ssh/receive_pack_test.go b/internal/gitaly/service/ssh/receive_pack_test.go index 5f2a309f2898535375e61a44284a159fe5172a0b..acb86af50bb98f2b6f9af225543ce486b4694ba0 100644 --- a/internal/gitaly/service/ssh/receive_pack_test.go +++ b/internal/gitaly/service/ssh/receive_pack_test.go @@ -3,6 +3,8 @@ package ssh import ( "bytes" "context" + "encoding/base64" + "encoding/json" "fmt" "io" "os" @@ -34,6 +36,7 @@ import ( "gitlab.com/gitlab-org/gitaly/v16/internal/transaction/txinfo" "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" "gitlab.com/gitlab-org/gitaly/v16/streamio" + "gitlab.com/gitlab-org/labkit/correlation" "google.golang.org/protobuf/encoding/protojson" ) @@ -203,6 +206,8 @@ func TestReceivePack_success(t *testing.T) { ctx = featureflag.ContextWithFeatureFlag(ctx, featureFlag, true) } + correlationID :=
"correlation123" + ctx = correlation.ContextWithCorrelation(ctx, correlationID) // forwarded to gitaly-ssh as CORRELATION_ID by sshPushCommand lHead, rHead, err := setupRepoAndPush(t, ctx, cfg, &gitalypb.SSHReceivePackRequest{ Repository: remoteRepo, GlId: "123", @@ -265,6 +270,14 @@ func TestReceivePack_success(t *testing.T) { transactionID = 6 } + expectedGitalyClientContext := map[string]string{ + "correlation_id": correlationID, + } + + marshalled, err := json.Marshal(expectedGitalyClientContext) + require.NoError(t, err) + expectedGitalyClientContextEncoded := base64.StdEncoding.EncodeToString(marshalled) + require.Equal(t, gitcmd.HooksPayload{ ObjectFormat: gittest.DefaultObjectHash.Format, RuntimeDir: cfg.RuntimeDir, @@ -275,8 +288,9 @@ func TestReceivePack_success(t *testing.T) { Username: "user", Protocol: "ssh", }, - RequestedHooks: expectedHooks, - TransactionID: transactionID, + RequestedHooks: expectedHooks, + TransactionID: transactionID, + GitalyClientContext: []byte(expectedGitalyClientContextEncoded), }, payload) require.Equal(t, 1, preReceiveCount) @@ -880,6 +894,10 @@ func sshPushCommand(t *testing.T, ctx context.Context, cfg config.Cfg, repo repo fmt.Sprintf("GIT_SSH_COMMAND=%s receive-pack", cfg.BinaryPath("gitaly-ssh")), ) + if correlationID := correlation.ExtractFromContext(ctx); correlationID != "" { // presumably read by gitaly-ssh to set the outgoing correlation ID — verify + cmd.Env = append(cmd.Env, fmt.Sprintf("CORRELATION_ID=%s", correlationID)) + } + return cmd } diff --git a/internal/gitlab/test_server.go b/internal/gitlab/test_server.go index 81845ccf0eef6de709e2001d1aa381f9f8336451..74e41b3227ecfbdbca75bd8c0c373881712ffe9d 100644 --- a/internal/gitlab/test_server.go +++ b/internal/gitlab/test_server.go @@ -54,19 +54,42 @@ type TestServerOptions struct { GlRepository string ClientCertificate *testhelper.Certificate ServerCertificate *testhelper.Certificate + HandlePreReceive internalAPIHandler // replaces the default /pre_receive handler when non-nil + HandleAllowed internalAPIHandler // replaces the default /allowed handler when non-nil + HandlePostReceive internalAPIHandler // replaces the default /post_receive handler when non-nil + HandleLFS internalAPIHandler // replaces the default /lfs handler when non-nil + HandleCheck internalAPIHandler // replaces the default /check handler when non-nil } +type internalAPIHandler
func(testing.TB, TestServerOptions) func(http.ResponseWriter, *http.Request) // builds the HTTP handler for one internal API endpoint from the test and the (defaulted) server options + // NewTestServer returns a mock gitlab server that responds to the hook api endpoints func NewTestServer(tb testing.TB, options TestServerOptions) (url string, cleanup func()) { tb.Helper() + if options.HandlePreReceive == nil { // fall back to the built-in handler for each endpoint the caller did not override + options.HandlePreReceive = handlePreReceive + } + if options.HandleAllowed == nil { + options.HandleAllowed = handleAllowed + } + if options.HandlePostReceive == nil { + options.HandlePostReceive = handlePostReceive + } + if options.HandleLFS == nil { + options.HandleLFS = handleLfs + } + if options.HandleCheck == nil { + options.HandleCheck = handleCheck + } + mux := http.NewServeMux() prefix := strings.TrimRight(options.RelativeURLRoot, "/") + "/api/v4/internal" - mux.Handle(prefix+"/allowed", http.HandlerFunc(handleAllowed(tb, options))) - mux.Handle(prefix+"/pre_receive", http.HandlerFunc(handlePreReceive(tb, options))) - mux.Handle(prefix+"/post_receive", http.HandlerFunc(handlePostReceive(options))) - mux.Handle(prefix+"/check", http.HandlerFunc(handleCheck(tb, options))) - mux.Handle(prefix+"/lfs", http.HandlerFunc(handleLfs(tb, options))) + mux.Handle(prefix+"/allowed", http.HandlerFunc(options.HandleAllowed(tb, options))) + mux.Handle(prefix+"/pre_receive", http.HandlerFunc(options.HandlePreReceive(tb, options))) + mux.Handle(prefix+"/post_receive", http.HandlerFunc(options.HandlePostReceive(tb, options))) + mux.Handle(prefix+"/check", http.HandlerFunc(options.HandleCheck(tb, options))) + mux.Handle(prefix+"/lfs", http.HandlerFunc(options.HandleLFS(tb, options))) var tlsCfg *tls.Config if options.ClientCertificate != nil { @@ -376,7 +399,7 @@ func handlePreReceive(tb testing.TB, options TestServerOptions) func(w http.Resp } } -func handlePostReceive(options TestServerOptions) func(w http.ResponseWriter, r *http.Request) { +func handlePostReceive(tb testing.TB, options TestServerOptions) func(w http.ResponseWriter, r *http.Request) { // tb was added so this function matches internalAPIHandler return func(w
http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, "couldn't parse form", http.StatusBadRequest) diff --git a/internal/grpc/metadata/metadata.go b/internal/grpc/metadata/metadata.go index ab9499c1efd08ed6993f3d1cf8196d076391a745..22446a7e183f142236887517f3fdf95b53d7c58a 100644 --- a/internal/grpc/metadata/metadata.go +++ b/internal/grpc/metadata/metadata.go @@ -7,6 +7,7 @@ import ( ) // ClientContextMetadataKey is the key used by rails to propagate client context back to internal APIs +// Its value is a base64-encoded JSON object. const ClientContextMetadataKey = "gitaly-client-context-bin" // IncomingToOutgoing creates an outgoing context out of an incoming context with the same storage metadata @@ -49,3 +50,15 @@ func AppendToIncomingContext(ctx context.Context, key, value string) context.Con md.Append(key, value) return metadata.NewIncomingContext(ctx, md) } + +// ReplaceInIncomingContext returns a context derived from ctx whose incoming metadata maps key to value only, replacing any values previously set for key and creating the metadata if ctx has none. +func ReplaceInIncomingContext(ctx context.Context, key, value string) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(nil) + } + + md.Set(key, value) + + return metadata.NewIncomingContext(ctx, md) +} diff --git a/internal/grpc/middleware/clientcontexthandler/clientcontexthandler.go b/internal/grpc/middleware/clientcontexthandler/clientcontexthandler.go new file mode 100644 index 0000000000000000000000000000000000000000..f5ba91b527f925a0ecf2164b4fc9990c88847ba8 --- /dev/null +++ b/internal/grpc/middleware/clientcontexthandler/clientcontexthandler.go @@ -0,0 +1,84 @@ +package clientcontexthandler + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + + grpcmw "github.com/grpc-ecosystem/go-grpc-middleware" + "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/metadata" + "gitlab.com/gitlab-org/labkit/correlation" + "google.golang.org/grpc" +) + +// addCorrelationIDToClientContext decodes the gitaly client context stored in +// the incoming metadata under metadata.ClientContextMetadataKey, sets its +// "correlation_id" field to the correlation ID extracted from ctx (the empty +// string if there is none), and returns a context whose incoming metadata +// holds the re-encoded client context in place of the original value. +// +// The client context is a base64-encoded JSON object. A missing or empty +// value, or a JSON null, is treated as an empty object. An error is returned +// if the value is not valid base64 or decodes to JSON other than an object. +func addCorrelationIDToClientContext(ctx context.Context) (context.Context, error) { +
encoded := metadata.GetValue(ctx, metadata.ClientContextMetadataKey) + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("decoding gitaly client context: %w", err) + } + + var clientContext map[string]interface{} + if len(decoded) > 0 { + if err := json.Unmarshal(decoded, &clientContext); err != nil { + return nil, fmt.Errorf("unmarshalling gitaly client context: %w", err) + } + } + + // Nothing was sent, or the payload was a JSON null, which json.Unmarshal + // turns into a nil map. Assigning into a nil map panics, so allocate one. + if clientContext == nil { + clientContext = make(map[string]interface{}, 1) + } + + clientContext["correlation_id"] = correlation.ExtractFromContext(ctx) + + marshalled, err := json.Marshal(clientContext) + if err != nil { + return nil, fmt.Errorf("marshalling gitaly client context: %w", err) + } + + return metadata.ReplaceInIncomingContext(ctx, metadata.ClientContextMetadataKey, base64.StdEncoding.EncodeToString(marshalled)), nil +} + +// UnaryInterceptor is a gRPC unary server interceptor that sets the +// "correlation_id" field of the gitaly client context in the incoming metadata +// to the request's correlation ID. It reads the ID from the context, so it must +// run after the correlation interceptor. Requests whose client context cannot +// be decoded are rejected with an error before the handler runs. +func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + ctx, err := addCorrelationIDToClientContext(ctx) + if err != nil { + return nil, err + } + + return handler(ctx, req) +} + +// StreamInterceptor is the streaming counterpart of UnaryInterceptor: it sets +// the "correlation_id" field of the gitaly client context in the incoming +// metadata and passes the handler a stream whose Context() carries the update. +// It must run after the correlation interceptor. Streams whose client context +// cannot be decoded are rejected with an error before the handler runs. +func StreamInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + ctx, err := addCorrelationIDToClientContext(stream.Context()) + if err != nil { + return err + } + + wrapped := grpcmw.WrapServerStream(stream) + wrapped.WrappedContext = ctx + + return handler(srv, wrapped) +}