diff --git a/internal/cli/gitaly/serve.go b/internal/cli/gitaly/serve.go
index 9db91a758a7d795426ba5de4df219aa433a24081..beee3407a2e6c194f3e3db84752d487459be581b 100644
--- a/internal/cli/gitaly/serve.go
+++ b/internal/cli/gitaly/serve.go
@@ -253,7 +253,16 @@ func run(cfg config.Cfg, logger log.Logger) error {
 		}
 		prometheus.MustRegister(gitlabClient)
 
-		hookManager = hook.NewManager(cfg, locator, logger, gitCmdFactory, transactionManager, gitlabClient, hook.NewTransactionRegistry(txRegistry))
+		hookManager = hook.NewManager(
+			cfg,
+			locator,
+			logger,
+			gitCmdFactory,
+			transactionManager,
+			gitlabClient,
+			hook.NewTransactionRegistry(txRegistry),
+			hook.NewProcReceiveRegistry(),
+		)
 	}
 
 	conns := client.NewPool(
diff --git a/internal/cli/gitaly/subcmd_check.go b/internal/cli/gitaly/subcmd_check.go
index 5ca3419d14eee60d0d23312df476af405ee0de3a..7f1cbb3978480606ce4a832a484935e5c8744ded 100644
--- a/internal/cli/gitaly/subcmd_check.go
+++ b/internal/cli/gitaly/subcmd_check.go
@@ -74,5 +74,14 @@ func checkAPI(cfg config.Cfg, logger log.Logger) (*gitlab.CheckInfo, error) {
 	}
 	defer cleanup()
 
-	return hook.NewManager(cfg, config.NewLocator(cfg), logger, gitCmdFactory, nil, gitlabAPI, hook.NewTransactionRegistry(storagemgr.NewTransactionRegistry())).Check(context.Background())
+	return hook.NewManager(
+		cfg,
+		config.NewLocator(cfg),
+		logger,
+		gitCmdFactory,
+		nil,
+		gitlabAPI,
+		hook.NewTransactionRegistry(storagemgr.NewTransactionRegistry()),
+		hook.NewProcReceiveRegistry(),
+	).Check(context.Background())
 }
diff --git a/internal/gitaly/hook/manager.go b/internal/gitaly/hook/manager.go
index dd4b2a5d5ac7f1eb422d9ca2f983bb42a5405fc4..1cdd18e3366b7859e04450ee772b88ac51ed5a2a 100644
--- a/internal/gitaly/hook/manager.go
+++ b/internal/gitaly/hook/manager.go
@@ -78,13 +78,14 @@ func NewTransactionRegistry(txRegistry *storagemgr.TransactionRegistry) Transact
 // GitLabHookManager is a hook manager containing Git hook business logic. It
 // uses the GitLab API to authenticate and track ongoing hook calls.
 type GitLabHookManager struct {
-	cfg           config.Cfg
-	locator       storage.Locator
-	logger        log.Logger
-	gitCmdFactory git.CommandFactory
-	txManager     transaction.Manager
-	gitlabClient  gitlab.Client
-	txRegistry    TransactionRegistry
+	cfg                 config.Cfg
+	locator             storage.Locator
+	logger              log.Logger
+	gitCmdFactory       git.CommandFactory
+	txManager           transaction.Manager
+	gitlabClient        gitlab.Client
+	txRegistry          TransactionRegistry
+	procReceiveRegistry *ProcReceiveRegistry
 }
 
 // NewManager returns a new hook manager
@@ -96,14 +97,16 @@ func NewManager(
 	txManager transaction.Manager,
 	gitlabClient gitlab.Client,
 	txRegistry TransactionRegistry,
+	procReceiveRegistry *ProcReceiveRegistry,
 ) *GitLabHookManager {
 	return &GitLabHookManager{
-		cfg:           cfg,
-		locator:       locator,
-		logger:        logger,
-		gitCmdFactory: gitCmdFactory,
-		txManager:     txManager,
-		gitlabClient:  gitlabClient,
-		txRegistry:    txRegistry,
+		cfg:                 cfg,
+		locator:             locator,
+		logger:              logger,
+		gitCmdFactory:       gitCmdFactory,
+		txManager:           txManager,
+		gitlabClient:        gitlabClient,
+		txRegistry:          txRegistry,
+		procReceiveRegistry: procReceiveRegistry,
 	}
 }
diff --git a/internal/gitaly/hook/postreceive_test.go b/internal/gitaly/hook/postreceive_test.go
index 8080358312e6ff12d668130cddd472196678b0d7..31fc8948add2143751bb7c4035b4b2e09bf751f7 100644
--- a/internal/gitaly/hook/postreceive_test.go
+++ b/internal/gitaly/hook/postreceive_test.go
@@ -84,7 +84,7 @@ func TestPostReceive_customHook(t *testing.T) {
 	txManager := transaction.NewTrackingManager()
 	hookManager := NewManager(cfg, locator, testhelper.SharedLogger(t), gitCmdFactory, txManager, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	receiveHooksPayload := &git.UserDetails{
 		UserID:   "1234",
@@ -381,7 +381,16 @@ func TestPostReceive_gitlab(t *testing.T) {
 		},
 	}
 
-	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), transaction.NewManager(cfg, testhelper.SharedLogger(t), backchannel.NewRegistry()), &gitlabAPI, NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	hookManager := NewManager(
+		cfg,
+		config.NewLocator(cfg),
+		testhelper.SharedLogger(t),
+		gittest.NewCommandFactory(t, cfg),
+		transaction.NewManager(cfg, testhelper.SharedLogger(t), backchannel.NewRegistry()),
+		&gitlabAPI,
+		NewTransactionRegistry(storagemgr.NewTransactionRegistry()),
+		NewProcReceiveRegistry(),
+	)
 
 	gittest.WriteCustomHook(t, repoPath, "post-receive", []byte("#!/bin/sh\necho hook called\n"))
 
@@ -419,7 +428,7 @@ func TestPostReceive_quarantine(t *testing.T) {
 
 	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), nil, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	gittest.WriteCustomHook(t, repoPath, "post-receive", []byte(fmt.Sprintf(
 		`#!/bin/sh
diff --git a/internal/gitaly/hook/prereceive_test.go b/internal/gitaly/hook/prereceive_test.go
index 77f6bb4a15d7c8a129e6e2bb32fd2ecfdad370a5..1bfab944786f73b892871f7f5b0dc617f1d00a26 100644
--- a/internal/gitaly/hook/prereceive_test.go
+++ b/internal/gitaly/hook/prereceive_test.go
@@ -43,7 +43,7 @@ func TestPrereceive_customHooks(t *testing.T) {
 	txManager := transaction.NewTrackingManager()
 	hookManager := NewManager(cfg, locator, testhelper.SharedLogger(t), gitCmdFactory, txManager, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	receiveHooksPayload := &git.UserDetails{
 		UserID:   "1234",
@@ -228,7 +228,7 @@ func TestPrereceive_quarantine(t *testing.T) {
 
 	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), nil, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	//nolint:gitaly-linters
 	gittest.WriteCustomHook(t, repoPath, "pre-receive", []byte(fmt.Sprintf(
@@ -419,7 +419,16 @@ func TestPrereceive_gitlab(t *testing.T) {
 		},
 	}
 
-	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), transaction.NewManager(cfg, testhelper.SharedLogger(t), backchannel.NewRegistry()), &gitlabAPI, NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	hookManager := NewManager(
+		cfg,
+		config.NewLocator(cfg),
+		testhelper.SharedLogger(t),
+		gittest.NewCommandFactory(t, cfg),
+		transaction.NewManager(cfg, testhelper.SharedLogger(t), backchannel.NewRegistry()),
+		&gitlabAPI,
+		NewTransactionRegistry(storagemgr.NewTransactionRegistry()),
+		NewProcReceiveRegistry(),
+	)
 
 	gittest.WriteCustomHook(t, repoPath, "pre-receive", []byte("#!/bin/sh\necho called\n"))
 
diff --git a/internal/gitaly/hook/procreceive.go b/internal/gitaly/hook/procreceive.go
new file mode 100644
index 0000000000000000000000000000000000000000..44b5ae7c6810f6dc931b7ea1d4dca8cd4db990b8
--- /dev/null
+++ b/internal/gitaly/hook/procreceive.go
@@ -0,0 +1,206 @@
+package hook
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"sync"
+
+	"gitlab.com/gitlab-org/gitaly/v16/internal/git"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/git/pktline"
+	"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
+)
+
+// ProcReceiveHook is used to intercept git-receive-pack(1)'s execute-commands code.
+// This allows us to intercept the reference updates and avoid writing directly to
+// the disk. The intercepted updates are then bundled into `procReceiveHookInvocation`
+// and added to the registry. The RPC which invoked git-receive-pack(1) in the first
+// place picks up the invocation from the registry and accepts/rejects individual
+// references.
+//
+// Once the invocation has been registered, the hook does not return until the
+// invocation is closed or ctx is done: stdout has to stay usable while the RPC
+// reports the status of each reference update back to git-receive-pack(1).
+func (m *GitLabHookManager) ProcReceiveHook(ctx context.Context, repo *gitalypb.Repository, env []string, stdin io.Reader, stdout, stderr io.Writer) error {
+	payload, err := git.HooksPayloadFromEnv(env)
+	if err != nil {
+		return fmt.Errorf("extracting hooks payload: %w", err)
+	}
+
+	// This hook only works when there is a transaction present.
+	if payload.TransactionID == 0 {
+		return fmt.Errorf("no transaction found in payload")
+	}
+
+	scanner := pktline.NewScanner(stdin)
+
+	// Version and feature negotiation.
+	if !scanner.Scan() {
+		return fmt.Errorf("expected input: %w", errOrUnexpectedEOF(scanner.Err()))
+	}
+
+	data, err := pktline.Payload(scanner.Bytes())
+	if err != nil {
+		return fmt.Errorf("receiving header: %w", err)
+	}
+
+	after, ok := bytes.CutPrefix(data, []byte("version=1\000"))
+	if !ok {
+		return fmt.Errorf("unsupported version: %s", data)
+	}
+
+	featureRequests, err := parseFeatureRequest(after)
+	if err != nil {
+		return fmt.Errorf("parsing feature request: %w", err)
+	}
+
+	if !scanner.Scan() {
+		return fmt.Errorf("expected input: %w", errOrUnexpectedEOF(scanner.Err()))
+	}
+
+	if !pktline.IsFlush(scanner.Bytes()) {
+		return fmt.Errorf("expected pkt flush")
+	}
+
+	if _, err := pktline.WriteString(stdout, fmt.Sprintf("version=1\000%s", featureRequests)); err != nil {
+		return fmt.Errorf("writing version: %w", err)
+	}
+
+	if err := pktline.WriteFlush(stdout); err != nil {
+		return fmt.Errorf("flushing version: %w", err)
+	}
+
+	updates := []ReferenceUpdate{}
+	flushed := false
+	for scanner.Scan() {
+		line := scanner.Bytes()
+
+		// When all reference updates are transmitted, we expect a flush.
+		if pktline.IsFlush(line) {
+			flushed = true
+			break
+		}
+
+		data, err := pktline.Payload(line)
+		if err != nil {
+			return fmt.Errorf("receiving reference update: %w", err)
+		}
+
+		update, err := parseRefUpdate(data)
+		if err != nil {
+			return fmt.Errorf("parse reference update: %w", err)
+		}
+		updates = append(updates, update)
+	}
+
+	// Input that ends, or fails to scan, before the terminating flush is truncated.
+	if !flushed {
+		return fmt.Errorf("reading reference updates: %w", errOrUnexpectedEOF(scanner.Err()))
+	}
+
+	// done is closed by the invocation's Close so that this hook can return.
+	// closeOnce makes repeated Close calls safe.
+	done := make(chan struct{})
+	var closeOnce sync.Once
+
+	invocation := newProcReceiveHookInvocation(
+		featureRequests.atomic,
+		payload.TransactionID,
+		updates,
+		func(referenceName git.ReferenceName) error {
+			if _, err := pktline.WriteString(stdout, fmt.Sprintf("ok %s", referenceName)); err != nil {
+				return fmt.Errorf("write ref %s ok: %w", referenceName, err)
+			}
+
+			return nil
+		},
+		func(referenceName git.ReferenceName, reason string) error {
+			if _, err := pktline.WriteString(stdout, fmt.Sprintf("ng %s %s", referenceName, reason)); err != nil {
+				return fmt.Errorf("write ref %s ng: %w", referenceName, err)
+			}
+
+			return nil
+		},
+		func() error {
+			var flushErr error
+			closeOnce.Do(func() {
+				// Release the hook even if the final flush fails.
+				defer close(done)
+
+				if err := pktline.WriteFlush(stdout); err != nil {
+					flushErr = fmt.Errorf("flushing updates: %w", err)
+				}
+			})
+
+			return flushErr
+		},
+	)
+
+	m.procReceiveRegistry.set(invocation)
+
+	// Keep stdout alive until the RPC has reported the status of every
+	// reference update and closed the invocation.
+	select {
+	case <-done:
+		return nil
+	case <-ctx.Done():
+		return fmt.Errorf("waiting for invocation to be closed: %w", ctx.Err())
+	}
+}
+
+// errOrUnexpectedEOF returns err when it is non-nil. A scanner reports a nil
+// error when its input simply ended; since the protocol still expected more
+// data at that point, that case is reported as io.ErrUnexpectedEOF.
+func errOrUnexpectedEOF(err error) error {
+	if err != nil {
+		return err
+	}
+
+	return io.ErrUnexpectedEOF
+}
+
+// parseRefUpdate parses a single "<old-oid> <new-oid> <ref>" command sent by
+// git-receive-pack(1).
+func parseRefUpdate(data []byte) (ReferenceUpdate, error) {
+	split := bytes.Split(data, []byte(" "))
+	if len(split) != 3 {
+		return ReferenceUpdate{}, fmt.Errorf("unknown ref update format: %q", data)
+	}
+
+	return ReferenceUpdate{
+		Ref:    git.ReferenceName(split[2]),
+		OldOID: git.ObjectID(split[0]),
+		NewOID: git.ObjectID(split[1]),
+	}, nil
+}
+
+// procReceiveFeatureRequests holds the features requested by
+// git-receive-pack(1) during version negotiation that this hook supports.
+type procReceiveFeatureRequests struct {
+	atomic bool
+}
+
+// String renders the supported features to echo back in the version response.
+func (r *procReceiveFeatureRequests) String() string {
+	s := ""
+	if r.atomic {
+		s = "atomic"
+	}
+
+	return s
+}
+
+// parseFeatureRequest parses the space-separated features requested in the
+// version line. Features other than "atomic" are ignored.
+func parseFeatureRequest(data []byte) (*procReceiveFeatureRequests, error) {
+	var featureRequests procReceiveFeatureRequests
+
+	for _, feature := range bytes.Split(data, []byte(" ")) {
+		if bytes.Equal(feature, []byte("atomic")) {
+			featureRequests.atomic = true
+		}
+	}
+
+	return &featureRequests, nil
+}
diff --git a/internal/gitaly/hook/procreceive_registry.go b/internal/gitaly/hook/procreceive_registry.go
new file mode 100644
index 0000000000000000000000000000000000000000..4a27c4a7acdb4cb2e6eac3920928199f8b552ba3
--- /dev/null
+++ b/internal/gitaly/hook/procreceive_registry.go
@@ -0,0 +1,168 @@
+package hook
+
+import (
+	"sync"
+
+	"gitlab.com/gitlab-org/gitaly/v16/internal/git"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage"
+)
+
+// ReferenceUpdate denotes a single reference update to be made.
+type ReferenceUpdate struct {
+	Ref    git.ReferenceName
+	OldOID git.ObjectID
+	NewOID git.ObjectID
+}
+
+// ProcReceiveHookInvocation is an interface which provides abstraction
+// around the proc-receive invocation provided by the ProcReceiveRegistry.
+// The interface allows the users to obtain reference updates, meta
+// information around these updates and functions to accept or reject
+// individual updates.
+type ProcReceiveHookInvocation interface {
+	// Atomic denotes whether the push was atomic.
+	Atomic() bool
+
+	// ReferenceUpdates provides the reference updates to be made.
+	ReferenceUpdates() []ReferenceUpdate
+
+	// AcceptUpdate writes to the stream that the reference was accepted.
+	AcceptUpdate(referenceName git.ReferenceName) error
+	// RejectUpdate writes to the stream the reference was rejected and
+	// the reason why.
+	RejectUpdate(referenceName git.ReferenceName, reason string) error
+
+	// Close must be called on the invocation to clean up. Calling it also
+	// signals to the `ProcReceiveHook` handler that it can exit as the streams
+	// are no longer needed.
+	Close() error
+}
+
+// procReceiveHookInvocation implements ProcReceiveHookInvocation by delegating
+// to the callbacks it was constructed with.
+type procReceiveHookInvocation struct {
+	acceptUpdateFn   func(referenceName git.ReferenceName) error
+	rejectUpdateFn   func(referenceName git.ReferenceName, reason string) error
+	closeFn          func() error
+	referenceUpdates []ReferenceUpdate
+	id               storage.TransactionID
+	atomic           bool
+}
+
+var _ ProcReceiveHookInvocation = (*procReceiveHookInvocation)(nil)
+
+func newProcReceiveHookInvocation(
+	atomic bool,
+	id storage.TransactionID,
+	referenceUpdates []ReferenceUpdate,
+	acceptUpdateFn func(referenceName git.ReferenceName) error,
+	rejectUpdateFn func(referenceName git.ReferenceName, reason string) error,
+	closeFn func() error,
+) *procReceiveHookInvocation {
+	return &procReceiveHookInvocation{
+		atomic:           atomic,
+		id:               id,
+		referenceUpdates: referenceUpdates,
+		acceptUpdateFn:   acceptUpdateFn,
+		rejectUpdateFn:   rejectUpdateFn,
+		closeFn:          closeFn,
+	}
+}
+
+// Atomic denotes whether the push was atomic.
+func (i *procReceiveHookInvocation) Atomic() bool {
+	return i.atomic
+}
+
+// ReferenceUpdates provides the reference updates to be made.
+func (i *procReceiveHookInvocation) ReferenceUpdates() []ReferenceUpdate {
+	return i.referenceUpdates
+}
+
+// AcceptUpdate writes to the stream that the reference was accepted.
+func (i *procReceiveHookInvocation) AcceptUpdate(referenceName git.ReferenceName) error {
+	return i.acceptUpdateFn(referenceName)
+}
+
+// RejectUpdate writes to the stream the reference was rejected and
+// the reason why.
+func (i *procReceiveHookInvocation) RejectUpdate(referenceName git.ReferenceName, reason string) error {
+	return i.rejectUpdateFn(referenceName, reason)
+}
+
+// Close must be called on the invocation to clean up. Calling it also
+// signals to the `ProcReceiveHook` handler that it can exit as the streams
+// are no longer needed.
+func (i *procReceiveHookInvocation) Close() error {
+	return i.closeFn()
+}
+
+// ProcReceiveRegistry is the registry which provides the proc-receive hook
+// invocation mechanism against a provided transaction ID.
+//
+// The registry allows RPCs to communicate with the git-proc-receive hook and
+// receive information about the reference updates to be performed, then the RPCs
+// can interact with the transaction manager and accept or reject reference
+// updates, this information is relayed to the proc-receive hook which will
+// relay the information to the user.
+type ProcReceiveRegistry struct {
+	// mu guards subs and invocations.
+	mu sync.Mutex
+
+	// subs holds the channel of the caller blocked in Get for a transaction ID.
+	subs map[storage.TransactionID]chan ProcReceiveHookInvocation
+
+	// invocations holds invocations that were set before Get asked for them.
+	invocations map[storage.TransactionID]ProcReceiveHookInvocation
+}
+
+// NewProcReceiveRegistry creates a new registry by allocating the required
+// variables.
+func NewProcReceiveRegistry() *ProcReceiveRegistry {
+	return &ProcReceiveRegistry{
+		subs:        make(map[storage.TransactionID]chan ProcReceiveHookInvocation),
+		invocations: make(map[storage.TransactionID]ProcReceiveHookInvocation),
+	}
+}
+
+// Get is a blocking call which allows the user to obtain the invocation for
+// a particular transaction ID. If the proc-receive hook is yet to add the
+// invocation, the call blocks indefinitely until it is available.
+//
+// Once an invocation is retrieved, it is deleted from the internal state.
+// At most one caller should wait on a given transaction ID at a time: a later
+// call replaces the earlier caller's subscription, which then never returns.
+func (r *ProcReceiveRegistry) Get(id storage.TransactionID) ProcReceiveHookInvocation {
+	r.mu.Lock()
+
+	if invocation, ok := r.invocations[id]; ok {
+		delete(r.invocations, id)
+		r.mu.Unlock()
+
+		return invocation
+	}
+
+	// Buffered so that set never blocks while holding the lock.
+	ch := make(chan ProcReceiveHookInvocation, 1)
+	r.subs[id] = ch
+	r.mu.Unlock()
+
+	return <-ch
+}
+
+// set adds an invocation against its transaction ID. If a caller is already
+// waiting for it in Get, the invocation is handed over directly and is not
+// retained; otherwise it is stored until Get retrieves it.
+func (r *ProcReceiveRegistry) set(invocation *procReceiveHookInvocation) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if listener, ok := r.subs[invocation.id]; ok {
+		delete(r.subs, invocation.id)
+		listener <- invocation
+
+		return
+	}
+
+	r.invocations[invocation.id] = invocation
+}
diff --git a/internal/gitaly/hook/procreceive_registry_test.go b/internal/gitaly/hook/procreceive_registry_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..49a60a16a1ad53c72d22432e9a05eabf29dd2294
--- /dev/null
+++ b/internal/gitaly/hook/procreceive_registry_test.go
@@ -0,0 +1,123 @@
+package hook
+
+import (
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage"
+	"golang.org/x/sync/errgroup"
+)
+
+func TestProcReceiveRegistry(t *testing.T) {
+	t.Parallel()
+
+	newInvocation := func(id storage.TransactionID) *procReceiveHookInvocation {
+		return newProcReceiveHookInvocation(false, id, nil, nil, nil, nil)
+	}
+
+	t.Run("invocation added and received", func(t *testing.T) {
+		t.Parallel()
+
+		registry := NewProcReceiveRegistry()
+		invocation := newInvocation(1)
+		registry.set(invocation)
+		receivedInvocation := registry.Get(1)
+		require.Equal(t, invocation, receivedInvocation)
+	})
+
+	t.Run("invocation added after receiver", func(t *testing.T) {
+		t.Parallel()
+
+		registry := NewProcReceiveRegistry()
+		invocation := newInvocation(1)
+
+		go func() {
+			registry.set(invocation)
+		}()
+
+		receivedInvocation := registry.Get(1)
+		require.Equal(t, invocation, receivedInvocation)
+	})
+
+	t.Run("invocation handed to waiting receiver is not retained", func(t *testing.T) {
+		t.Parallel()
+
+		registry := NewProcReceiveRegistry()
+		invocation := newInvocation(1)
+
+		received := make(chan ProcReceiveHookInvocation, 1)
+		go func() {
+			received <- registry.Get(1)
+		}()
+
+		// Wait until the receiver has subscribed so that set hands the
+		// invocation over directly instead of storing it.
+		require.Eventually(t, func() bool {
+			registry.mu.Lock()
+			defer registry.mu.Unlock()
+
+			_, subscribed := registry.subs[1]
+			return subscribed
+		}, time.Minute, time.Millisecond)
+
+		registry.set(invocation)
+		require.Equal(t, invocation, <-received)
+
+		registry.mu.Lock()
+		defer registry.mu.Unlock()
+		require.Empty(t, registry.subs)
+		require.Empty(t, registry.invocations)
+	})
+
+	t.Run("invocation not received", func(t *testing.T) {
+		t.Parallel()
+
+		registry := NewProcReceiveRegistry()
+		invocation := newInvocation(1)
+
+		// Shouldn't block and should finish the test.
+		registry.set(invocation)
+	})
+
+	t.Run("invocation added twice", func(t *testing.T) {
+		t.Parallel()
+
+		registry := NewProcReceiveRegistry()
+		invocation := newInvocation(1)
+
+		// set() is idempotent, so makes no difference.
+		registry.set(invocation)
+		registry.set(invocation)
+		require.Equal(t, invocation, registry.Get(1))
+	})
+
+	t.Run("multiple invocations", func(t *testing.T) {
+		t.Parallel()
+
+		registry := NewProcReceiveRegistry()
+		invocations := []*procReceiveHookInvocation{newInvocation(1), newInvocation(2), newInvocation(3)}
+
+		var group errgroup.Group
+		for _, invocation := range invocations {
+			invocation := invocation
+
+			group.Go(func() error {
+				if received := registry.Get(invocation.id); received != invocation {
+					return fmt.Errorf("invalid invocation: %d", invocation.id)
+				}
+				return nil
+			})
+		}
+
+		group.Go(func() error {
+			for _, invocation := range invocations {
+				registry.set(invocation)
+			}
+			return nil
+		})
+
+		require.NoError(t, group.Wait())
+	})
+}
diff --git a/internal/gitaly/hook/procreceive_test.go b/internal/gitaly/hook/procreceive_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..a5cf5823c9568775a2af8033a6a0f3add340bdcc
--- /dev/null
+++ b/internal/gitaly/hook/procreceive_test.go
@@ -0,0 +1,305 @@
+package hook
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/featureflag"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/git"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/git/gittest"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/git/pktline"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage/storagemgr"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/transaction"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/gitlab"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/testhelper"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/testhelper/testcfg"
+)
+
+func TestProcReceiveHook(t *testing.T) {
+	t.Parallel()
+
+	ctx := testhelper.Context(t)
+	cfg := testcfg.Build(t)
+
+	repo, _ := gittest.CreateRepository(t, ctx, cfg, gittest.CreateRepositoryConfig{
+		SkipCreationViaService: true,
+	})
+
+	gitCmdFactory := gittest.NewCommandFactory(t, cfg)
+	locator := config.NewLocator(cfg)
+
+	txManager := transaction.NewTrackingManager()
+
+	receiveHooksPayload := &git.UserDetails{
+		UserID:   "1234",
+		Username: "user",
+		Protocol: "web",
+	}
+
+	payload, err := git.NewHooksPayload(
+		cfg,
+		repo,
+		gittest.DefaultObjectHash,
+		nil,
+		receiveHooksPayload,
+		git.PreReceiveHook,
+		featureflag.FromContext(ctx),
+		1,
+	).Env()
+	require.NoError(t, err)
+
+	type setupData struct {
+		env             []string
+		ctx             context.Context
+		stdin           string
+		expectedErr     error
+		expectedStdout  string
+		expectedUpdates []ReferenceUpdate
+		expectedAtomic  bool
+		invocationSteps func(invocation ProcReceiveHookInvocation) error
+	}
+
+	for _, tc := range []struct {
+		desc  string
+		setup func(t *testing.T, ctx context.Context) setupData
+	}{
+		{
+			desc: "no payload",
+			setup: func(t *testing.T, ctx context.Context) setupData {
+				return setupData{
+					env:         []string{},
+					ctx:         ctx,
+					expectedErr: fmt.Errorf("extracting hooks payload: %w", errors.New("no hooks payload found in environment")),
+				}
+			},
+		},
+		{
+			desc: "invalid version",
+			setup: func(t *testing.T, ctx context.Context) setupData {
+				var stdin bytes.Buffer
+				_, err := pktline.WriteString(&stdin, "version=2")
+				require.NoError(t, err)
+
+				return setupData{
+					env:         []string{payload},
+					ctx:         ctx,
+					stdin:       stdin.String(),
+					expectedErr: errors.New("unsupported version: version=2"),
+				}
+			},
+		},
+		{
+			desc: "missing flush after reference updates",
+			setup: func(t *testing.T, ctx context.Context) setupData {
+				var stdin bytes.Buffer
+				_, err := pktline.WriteString(&stdin, "version=1\000push-options")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+				_, err = pktline.WriteString(&stdin, fmt.Sprintf("%s %s %s",
+					gittest.DefaultObjectHash.ZeroOID, gittest.DefaultObjectHash.EmptyTreeOID, "refs/heads/main"))
+				require.NoError(t, err)
+
+				return setupData{
+					env:         []string{payload},
+					ctx:         ctx,
+					stdin:       stdin.String(),
+					expectedErr: fmt.Errorf("reading reference updates: %w", io.ErrUnexpectedEOF),
+				}
+			},
+		},
+		{
+			desc: "single reference with atomic",
+			setup: func(t *testing.T, ctx context.Context) setupData {
+				var stdin bytes.Buffer
+				_, err := pktline.WriteString(&stdin, "version=1\000push-options atomic")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+				_, err = pktline.WriteString(&stdin, fmt.Sprintf("%s %s %s",
+					gittest.DefaultObjectHash.ZeroOID, gittest.DefaultObjectHash.EmptyTreeOID, "refs/heads/main"))
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+
+				var stdout bytes.Buffer
+				_, err = pktline.WriteString(&stdout, "version=1\000atomic")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdout))
+				_, err = pktline.WriteString(&stdout, "ok refs/heads/main")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdout))
+
+				return setupData{
+					env:            []string{payload},
+					ctx:            ctx,
+					stdin:          stdin.String(),
+					expectedStdout: stdout.String(),
+					expectedAtomic: true,
+					expectedUpdates: []ReferenceUpdate{
+						{
+							Ref:    "refs/heads/main",
+							OldOID: gittest.DefaultObjectHash.ZeroOID,
+							NewOID: gittest.DefaultObjectHash.EmptyTreeOID,
+						},
+					},
+					invocationSteps: func(invocation ProcReceiveHookInvocation) error {
+						require.NoError(t, invocation.AcceptUpdate("refs/heads/main"))
+						return invocation.Close()
+					},
+				}
+			},
+		},
+		{
+			desc: "single reference without atomic",
+			setup: func(t *testing.T, ctx context.Context) setupData {
+				var stdin bytes.Buffer
+				_, err := pktline.WriteString(&stdin, "version=1\000push-options")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+				_, err = pktline.WriteString(&stdin, fmt.Sprintf("%s %s %s",
+					gittest.DefaultObjectHash.ZeroOID, gittest.DefaultObjectHash.EmptyTreeOID, "refs/heads/main"))
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+
+				var stdout bytes.Buffer
+				_, err = pktline.WriteString(&stdout, "version=1\000")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdout))
+				_, err = pktline.WriteString(&stdout, "ok refs/heads/main")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdout))
+
+				return setupData{
+					env:            []string{payload},
+					ctx:            ctx,
+					stdin:          stdin.String(),
+					expectedStdout: stdout.String(),
+					expectedUpdates: []ReferenceUpdate{
+						{
+							Ref:    "refs/heads/main",
+							OldOID: gittest.DefaultObjectHash.ZeroOID,
+							NewOID: gittest.DefaultObjectHash.EmptyTreeOID,
+						},
+					},
+					invocationSteps: func(invocation ProcReceiveHookInvocation) error {
+						require.NoError(t, invocation.AcceptUpdate("refs/heads/main"))
+						return invocation.Close()
+					},
+				}
+			},
+		},
+		{
+			desc: "multiple references",
+			setup: func(t *testing.T, ctx context.Context) setupData {
+				var stdin bytes.Buffer
+				_, err := pktline.WriteString(&stdin, "version=1\000push-options")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+				_, err = pktline.WriteString(&stdin, fmt.Sprintf("%s %s %s",
+					gittest.DefaultObjectHash.ZeroOID, gittest.DefaultObjectHash.EmptyTreeOID, "refs/heads/main"))
+				require.NoError(t, err)
+				_, err = pktline.WriteString(&stdin, fmt.Sprintf("%s %s %s",
+					gittest.DefaultObjectHash.ZeroOID, gittest.DefaultObjectHash.EmptyTreeOID, "refs/heads/branch"))
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdin))
+
+				var stdout bytes.Buffer
+				_, err = pktline.WriteString(&stdout, "version=1\000")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdout))
+				_, err = pktline.WriteString(&stdout, "ok refs/heads/main")
+				require.NoError(t, err)
+				_, err = pktline.WriteString(&stdout, "ng refs/heads/branch for fun")
+				require.NoError(t, err)
+				require.NoError(t, pktline.WriteFlush(&stdout))
+
+				return setupData{
+					env:            []string{payload},
+					ctx:            ctx,
+					stdin:          stdin.String(),
+					expectedStdout: stdout.String(),
+					expectedUpdates: []ReferenceUpdate{
+						{
+							Ref:    "refs/heads/main",
+							OldOID: gittest.DefaultObjectHash.ZeroOID,
+							NewOID: gittest.DefaultObjectHash.EmptyTreeOID,
+						},
+						{
+							Ref:    "refs/heads/branch",
+							OldOID: gittest.DefaultObjectHash.ZeroOID,
+							NewOID: gittest.DefaultObjectHash.EmptyTreeOID,
+						},
+					},
+					invocationSteps: func(invocation ProcReceiveHookInvocation) error {
+						require.NoError(t, invocation.AcceptUpdate("refs/heads/main"))
+						require.NoError(t, invocation.RejectUpdate("refs/heads/branch", "for fun"))
+						return invocation.Close()
+					},
+				}
+			},
+		},
+	} {
+		tc := tc
+
+		t.Run(tc.desc, func(t *testing.T) {
+			ctx, cancel := context.WithCancel(ctx)
+			defer cancel()
+
+			setup := tc.setup(t, ctx)
+
+			// Every case gets its own registry so that a failing case cannot leave
+			// a stale invocation or subscription behind for the next one.
+			procReceiveRegistry := NewProcReceiveRegistry()
+			hookManager := NewManager(
+				cfg,
+				locator,
+				testhelper.SharedLogger(t),
+				gitCmdFactory,
+				txManager,
+				gitlab.NewMockClient(t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive),
+				NewTransactionRegistry(storagemgr.NewTransactionRegistry()),
+				procReceiveRegistry,
+			)
+
+			// The hook blocks until the invocation is closed, so run it in the
+			// background and drive the invocation from here.
+			var stdout, stderr bytes.Buffer
+			hookErr := make(chan error, 1)
+			go func() {
+				hookErr <- hookManager.ProcReceiveHook(setup.ctx, repo, setup.env, strings.NewReader(setup.stdin), &stdout, &stderr)
+			}()
+
+			invocations := make(chan ProcReceiveHookInvocation, 1)
+			go func() {
+				invocations <- procReceiveRegistry.Get(1)
+			}()
+
+			var invocation ProcReceiveHookInvocation
+			select {
+			case err := <-hookErr:
+				// The hook exited without registering an invocation. Hand the
+				// pending Get a placeholder so that its goroutine terminates.
+				procReceiveRegistry.set(newProcReceiveHookInvocation(false, 1, nil, nil, nil, nil))
+				<-invocations
+
+				require.Equal(t, setup.expectedErr, err)
+				return
+			case invocation = <-invocations:
+			}
+			require.Nil(t, setup.expectedErr, "hook registered an invocation instead of failing")
+
+			require.Equal(t, setup.expectedAtomic, invocation.Atomic())
+			require.Equal(t, setup.expectedUpdates, invocation.ReferenceUpdates())
+
+			require.NoError(t, setup.invocationSteps(invocation))
+			// Closing the invocation must allow the hook to return.
+			require.NoError(t, <-hookErr)
+			require.Equal(t, setup.expectedStdout, stdout.String())
+		})
+	}
+}
diff --git a/internal/gitaly/hook/transactions_test.go b/internal/gitaly/hook/transactions_test.go
index 568cef0f8f54481ceedbcece99b77967ed634faf..8bcaa97f38d797b8310aaa19a8c4f6de77ff7db0 100644
--- a/internal/gitaly/hook/transactions_test.go
+++ b/internal/gitaly/hook/transactions_test.go
@@ -38,7 +38,7 @@ func TestHookManager_stopCalled(t *testing.T) {
 	var mockTxMgr transaction.MockManager
 	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), &mockTxMgr, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	hooksPayload, err := git.NewHooksPayload(
 		cfg,
@@ -144,7 +144,7 @@ func TestHookManager_contextCancellationCancelsVote(t *testing.T) {
 
 	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), &mockTxMgr, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	hooksPayload, err := git.NewHooksPayload(
 		cfg,
diff --git a/internal/gitaly/hook/update_test.go b/internal/gitaly/hook/update_test.go
index 8d8ab8e30a5fb088926f047afb2fdf2fe71f7c89..8f446dd1638e256495efd402bf0705b12e747454 100644
--- a/internal/gitaly/hook/update_test.go
+++ b/internal/gitaly/hook/update_test.go
@@ -41,7 +41,7 @@ func TestUpdate_customHooks(t *testing.T) {
 	txManager := transaction.NewTrackingManager()
 	hookManager := NewManager(cfg, locator, testhelper.SharedLogger(t), gitCmdFactory, txManager, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	receiveHooksPayload := &git.UserDetails{
 		UserID:   "1234",
@@ -258,7 +258,7 @@ func TestUpdate_quarantine(t *testing.T) {
 
 	hookManager := NewManager(cfg, config.NewLocator(cfg), testhelper.SharedLogger(t), gittest.NewCommandFactory(t, cfg), nil, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), NewTransactionRegistry(storagemgr.NewTransactionRegistry()), NewProcReceiveRegistry())
 
 	//nolint:gitaly-linters
 	gittest.WriteCustomHook(t, repoPath, "update", []byte(fmt.Sprintf(
diff --git a/internal/gitaly/server/auth_test.go b/internal/gitaly/server/auth_test.go
index 1c8ff4d0d0cb60d3389fdd93c004b8a4446ba08e..835442db9de0f67524740da8784eaf753cf037e6 100644
--- a/internal/gitaly/server/auth_test.go
+++ b/internal/gitaly/server/auth_test.go
@@ -195,7 +195,7 @@ func runServer(t *testing.T, cfg config.Cfg) string {
 	gitCmdFactory := gittest.NewCommandFactory(t, cfg)
 	hookManager := hook.NewManager(cfg, locator, logger, gitCmdFactory, txManager, gitlab.NewMockClient(
 		t, gitlab.MockAllowed, gitlab.MockPreReceive, gitlab.MockPostReceive,
-	), hook.NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+	), hook.NewTransactionRegistry(storagemgr.NewTransactionRegistry()), hook.NewProcReceiveRegistry())
 	catfileCache := catfile.NewCache(cfg)
 	t.Cleanup(catfileCache.Stop)
 	diskCache := cache.New(cfg, locator, logger)
diff --git a/internal/gitaly/service/hook/testhelper_test.go b/internal/gitaly/service/hook/testhelper_test.go
index 90064ec67e2ff48f9b45b913be7f7bcfe4c87cd6..16b9d1f142fa1f50556926818430fe53ae52f0b7 100644
--- a/internal/gitaly/service/hook/testhelper_test.go
+++ b/internal/gitaly/service/hook/testhelper_test.go
@@ -59,7 +59,7 @@ func runHooksServerWithTransactionRegistry(tb testing.TB, cfg config.Cfg, opts [
 
 	return testserver.RunGitalyServer(tb, cfg, func(srv *grpc.Server, deps *service.Dependencies) {
 		if txRegistry != nil {
-			deps.GitalyHookManager = gitalyhook.NewManager(deps.GetCfg(), deps.GetLocator(), deps.GetLogger(), deps.GetGitCmdFactory(), deps.GetTxManager(), deps.GetGitlabClient(), txRegistry)
+			deps.GitalyHookManager = gitalyhook.NewManager(deps.GetCfg(), deps.GetLocator(), deps.GetLogger(), deps.GetGitCmdFactory(), deps.GetTxManager(), deps.GetGitlabClient(), txRegistry, gitalyhook.NewProcReceiveRegistry())
 		}
 
 		hookServer := NewServer(deps)
diff --git a/internal/gitaly/service/operations/merge_branch_test.go b/internal/gitaly/service/operations/merge_branch_test.go
index aaf450194b4553a17f062981344d23d18b8b82be..e0e6817226e83b1bd3ed31e7fcb1d8ee7c6d1962 100644
--- a/internal/gitaly/service/operations/merge_branch_test.go
+++ b/internal/gitaly/service/operations/merge_branch_test.go
@@ -1148,7 +1148,9 @@ func testUserMergeBranchAllowed(t *testing.T, ctx context.Context) {
 			},
 			gitlab.MockPreReceive,
 			gitlab.MockPostReceive,
-		), hook.NewTransactionRegistry(storagemgr.NewTransactionRegistry()))
+		), hook.NewTransactionRegistry(storagemgr.NewTransactionRegistry()),
+			hook.NewProcReceiveRegistry(),
+		)
 
 		ctx, cfg, client := setupOperationsServiceWithCfg(
 			t, ctx, cfg,
diff --git a/internal/testhelper/testserver/gitaly.go b/internal/testhelper/testserver/gitaly.go
index 9af5d742b551eab809ed3ff56ff4062dd5c48ed8..d7210b6b816ed7b28213fd4431db41b2f770aa32 100644
--- a/internal/testhelper/testserver/gitaly.go
+++ b/internal/testhelper/testserver/gitaly.go
@@ -324,7 +324,16 @@ func (gsd *gitalyServerDeps) createDependencies(tb testing.TB, cfg config.Cfg) *
 	}
 
 	if gsd.hookMgr == nil {
-		gsd.hookMgr = hook.NewManager(cfg, gsd.locator, gsd.logger, gsd.gitCmdFactory, gsd.txMgr, gsd.gitlabClient, hook.NewTransactionRegistry(gsd.transactionRegistry))
+		gsd.hookMgr = hook.NewManager(
+			cfg,
+			gsd.locator,
+			gsd.logger,
+			gsd.gitCmdFactory,
+			gsd.txMgr,
+			gsd.gitlabClient,
+			hook.NewTransactionRegistry(gsd.transactionRegistry),
+			hook.NewProcReceiveRegistry(),
+		)
 	}
 
 	if gsd.catfileCache == nil {