diff --git a/internal/backup/pipeline.go b/internal/backup/pipeline.go index 8e4e15784d2a82a24900253f4057a2b6fff56c2c..5672493acdd070781c26bc6f6c9aa5f39c8e387b 100644 --- a/internal/backup/pipeline.go +++ b/internal/backup/pipeline.go @@ -66,13 +66,6 @@ type Command interface { Execute(context.Context) error } -// Pipeline executes a series of commands and encapsulates error handling for -// the caller. -type Pipeline interface { - Handle(context.Context, Command) - Done() error -} - // CreateCommand creates a backup for a repository type CreateCommand struct { strategy Strategy @@ -131,136 +124,122 @@ func (cmd RestoreCommand) Execute(ctx context.Context) error { return cmd.strategy.Restore(ctx, &cmd.request) } -// PipelineErrors represents a summary of errors by repository -type PipelineErrors []error +// commandErrors represents a summary of errors by repository +// +//nolint:errname +type commandErrors struct { + errs []error + mu sync.Mutex +} // AddError adds an error associated with a repository to the summary. -func (e *PipelineErrors) AddError(repo *gitalypb.Repository, err error) { +func (c *commandErrors) AddError(repo *gitalypb.Repository, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if repo.GetGlProjectPath() != "" { err = fmt.Errorf("%s (%s): %w", repo.GetRelativePath(), repo.GetGlProjectPath(), err) } else { err = fmt.Errorf("%s: %w", repo.GetRelativePath(), err) } - *e = append(*e, err) + c.errs = append(c.errs, err) } -func (e PipelineErrors) Error() string { +func (c *commandErrors) Error() string { var builder strings.Builder - _, _ = fmt.Fprintf(&builder, "%d failures encountered:\n", len(e)) - for _, err := range e { + _, _ = fmt.Fprintf(&builder, "%d failures encountered:\n", len(c.errs)) + for _, err := range c.errs { _, _ = fmt.Fprintf(&builder, " - %s\n", err.Error()) } return builder.String() } -// LoggingPipeline outputs logging for each command executed -type LoggingPipeline struct { - log log.Logger - - mu sync.Mutex - errs PipelineErrors +type contextCommand struct { + Command Command + Context context.Context } -// NewLoggingPipeline creates a new logging pipeline -func NewLoggingPipeline(log log.Logger) *LoggingPipeline { - return &LoggingPipeline{ - log: log, - } -} +// Pipeline is a pipeline for running backup and restore jobs. +type Pipeline struct { + log log.Logger -// Handle takes a command to process. Commands are logged and executed immediately. -func (p *LoggingPipeline) Handle(ctx context.Context, cmd Command) { - log := p.cmdLogger(cmd) - log.Info(fmt.Sprintf("started %s", cmd.Name())) + parallel int + parallelStorage int - if err := cmd.Execute(ctx); err != nil { - if errors.Is(err, ErrSkipped) { - log.Warn(fmt.Sprintf("skipped %s", cmd.Name())) - } else { - log.WithError(err).Error(fmt.Sprintf("%s failed", cmd.Name())) - p.addError(cmd.Repository(), err) - } - return - } + // totalWorkers allows the total number of parallel jobs to be + // limited. This allows us to create the required workers for + // each storage, while still limiting the absolute parallelism. + totalWorkers chan struct{} - log.Info(fmt.Sprintf("completed %s", cmd.Name())) -} + workerWg sync.WaitGroup + workersByStorage map[string]chan *contextCommand + workersByStorageMu sync.Mutex -func (p *LoggingPipeline) addError(repo *gitalypb.Repository, err error) { - p.mu.Lock() - defer p.mu.Unlock() + // done signals that no more commands will be provided to the Pipeline via + // Handle(), and the pipeline should wait for workers to complete and exit. 
+ done chan struct{} - p.errs.AddError(repo, err) + pipelineError error + cmdErrors *commandErrors } -// Done indicates that the pipeline is complete and returns any accumulated errors -func (p *LoggingPipeline) Done() error { - if len(p.errs) > 0 { - return fmt.Errorf("pipeline: %w", p.errs) +// NewPipeline creates a pipeline that executes backup and restore jobs. +// The pipeline executes sequentially by default, but can be made concurrent +// by calling WithConcurrency() after initialisation. +func NewPipeline(log log.Logger, opts ...PipelineOption) (*Pipeline, error) { + p := &Pipeline{ + log: log, + // Default to no concurrency. + parallel: 1, + parallelStorage: 0, + done: make(chan struct{}), + workersByStorage: make(map[string]chan *contextCommand), + cmdErrors: &commandErrors{}, } - return nil -} -func (p *LoggingPipeline) cmdLogger(cmd Command) log.Logger { - return p.log.WithFields(log.Fields{ - "command": cmd.Name(), - "storage_name": cmd.Repository().StorageName, - "relative_path": cmd.Repository().RelativePath, - "gl_project_path": cmd.Repository().GlProjectPath, - }) -} + for _, opt := range opts { + if err := opt(p); err != nil { + return nil, err + } + } -type contextCommand struct { - Command Command - Context context.Context + return p, nil } -// ParallelPipeline is a pipeline that executes commands in parallel -type ParallelPipeline struct { - next Pipeline - parallel int - parallelStorage int +// PipelineOption represents an optional configuration parameter for the Pipeline. +type PipelineOption func(*Pipeline) error - wg sync.WaitGroup - workerSlots chan struct{} - done chan struct{} +// WithConcurrency configures the pipeline to run backup and restore jobs concurrently. +// total defines the absolute maximum number of jobs that the pipeline should execute +// concurrently. perStorage defines the number of jobs per Gitaly storage that the +// pipeline should attempt to execute concurrently. +// +// For example, in a Gitaly deployment with 2 storages, WithConcurrency(3, 2) means +// that at most 3 jobs will execute concurrently, despite 2 concurrent jobs being allowed +// per storage (2*2=4). +func WithConcurrency(total, perStorage int) PipelineOption { + return func(p *Pipeline) error { + if total == 0 && perStorage == 0 { + return errors.New("total and perStorage cannot both be 0") + } - mu sync.Mutex - requests map[string]chan *contextCommand - err error -} + p.parallel = total + p.parallelStorage = perStorage -// NewParallelPipeline creates a new ParallelPipeline where all commands are -// passed onto `next` to be processed, `parallel` is the maximum number of -// parallel backups that will run and `parallelStorage` is the maximum number -// of parallel backups that will run per storage. Since the number of storages -// is unknown at initialisation, workers are created lazily as new storage -// names are encountered. -// -// Note: When both `parallel` and `parallelStorage` are zero or less no workers -// are created and the pipeline will block forever. -func NewParallelPipeline(next Pipeline, parallel, parallelStorage int) *ParallelPipeline { - var workerSlots chan struct{} - if parallel > 0 && parallelStorage > 0 { - // workerSlots allows the total number of parallel jobs to be - // limited. This allows us to create the required workers for - // each storage, while still limiting the absolute parallelism. 
- workerSlots = make(chan struct{}, parallel) - } - return &ParallelPipeline{ - next: next, - parallel: parallel, - parallelStorage: parallelStorage, - workerSlots: workerSlots, - done: make(chan struct{}), - requests: make(map[string]chan *contextCommand), + if total > 0 && perStorage > 0 { + // When both values are provided, we ensure that total limits + // the global concurrency. + p.totalWorkers = make(chan struct{}, total) + } + + return nil } } -// Handle queues a request to create a backup. Commands are processed by -// n-workers per storage. -func (p *ParallelPipeline) Handle(ctx context.Context, cmd Command) { - ch := p.getStorage(cmd.Repository().StorageName) +// Handle queues a request to create a backup. Commands either processed sequentially +// or concurrently, if WithConcurrency() was called. +func (p *Pipeline) Handle(ctx context.Context, cmd Command) { + ch := p.getWorker(cmd.Repository().StorageName) select { case <-ctx.Done(): @@ -272,49 +251,53 @@ func (p *ParallelPipeline) Handle(ctx context.Context, cmd Command) { } } -// Done waits for any in progress calls to `next` to complete then reports any -// accumulated errors -func (p *ParallelPipeline) Done() error { +// Done waits for any in progress jobs to complete then reports any accumulated errors +func (p *Pipeline) Done() error { close(p.done) - p.wg.Wait() - if err := p.next.Done(); err != nil { - return err + p.workerWg.Wait() + + if p.pipelineError != nil { + return fmt.Errorf("pipeline: %w", p.pipelineError) } - if p.err != nil { - return fmt.Errorf("pipeline: %w", p.err) + + if len(p.cmdErrors.errs) > 0 { + return fmt.Errorf("pipeline: %w", p.cmdErrors) } + return nil } -// getStorage finds the channel associated with a storage. When no channel is +// getWorker finds the channel associated with a storage. When no channel is // found, one is created and n-workers are started to process requests. -func (p *ParallelPipeline) getStorage(storage string) chan<- *contextCommand { - p.mu.Lock() - defer p.mu.Unlock() +// If parallelStorage is 0, a channel is created against a pseudo-storage to +// enforce the number of total concurrent jobs. 
+func (p *Pipeline) getWorker(storage string) chan<- *contextCommand { + p.workersByStorageMu.Lock() + defer p.workersByStorageMu.Unlock() workers := p.parallelStorage - if p.parallelStorage < 1 { + if p.parallelStorage == 0 { // if the workers are not limited by storage, then pretend there is a single storage with `parallel` workers storage = "" workers = p.parallel } - ch, ok := p.requests[storage] + ch, ok := p.workersByStorage[storage] if !ok { ch = make(chan *contextCommand) - p.requests[storage] = ch + p.workersByStorage[storage] = ch for i := 0; i < workers; i++ { - p.wg.Add(1) + p.workerWg.Add(1) go p.worker(ch) } } return ch } -func (p *ParallelPipeline) worker(ch <-chan *contextCommand) { - defer p.wg.Done() +func (p *Pipeline) worker(ch <-chan *contextCommand) { + defer p.workerWg.Done() for { select { case <-p.done: @@ -325,35 +308,58 @@ func (p *ParallelPipeline) worker(ch <-chan *contextCommand) { } } -func (p *ParallelPipeline) processCommand(ctx context.Context, cmd Command) { +func (p *Pipeline) processCommand(ctx context.Context, cmd Command) { p.acquireWorkerSlot() defer p.releaseWorkerSlot() - p.next.Handle(ctx, cmd) + log := p.cmdLogger(cmd) + log.Info(fmt.Sprintf("started %s", cmd.Name())) + + if err := cmd.Execute(ctx); err != nil { + if errors.Is(err, ErrSkipped) { + log.Warn(fmt.Sprintf("skipped %s", cmd.Name())) + } else { + log.WithError(err).Error(fmt.Sprintf("%s failed", cmd.Name())) + p.addError(cmd.Repository(), err) + } + return + } + + log.Info(fmt.Sprintf("completed %s", cmd.Name())) } -func (p *ParallelPipeline) setErr(err error) { - p.mu.Lock() - defer p.mu.Unlock() - if p.err != nil { +func (p *Pipeline) setErr(err error) { + if p.pipelineError != nil { return } - p.err = err + p.pipelineError = err +} + +func (p *Pipeline) addError(repo *gitalypb.Repository, err error) { + p.cmdErrors.AddError(repo, err) +} + +func (p *Pipeline) cmdLogger(cmd Command) log.Logger { + return p.log.WithFields(log.Fields{ + "command": cmd.Name(), + "storage_name": cmd.Repository().StorageName, + "relative_path": cmd.Repository().RelativePath, + "gl_project_path": cmd.Repository().GlProjectPath, + }) } // acquireWorkerSlot queues the worker until a slot is available. -// It never blocks if `parallel` or `parallelStorage` are 0 -func (p *ParallelPipeline) acquireWorkerSlot() { - if p.workerSlots == nil { +func (p *Pipeline) acquireWorkerSlot() { + if p.totalWorkers == nil { return } - p.workerSlots <- struct{}{} + p.totalWorkers <- struct{}{} } // releaseWorkerSlot releases the worker slot. 
-func (p *ParallelPipeline) releaseWorkerSlot() { - if p.workerSlots == nil { +func (p *Pipeline) releaseWorkerSlot() { + if p.totalWorkers == nil { return } - <-p.workerSlots + <-p.totalWorkers } diff --git a/internal/backup/pipeline_test.go b/internal/backup/pipeline_test.go index fe3c82616c6c205a3939bcf32537667fd6c8bb1c..04c539a5f4cd80984b51ba6eb29063b1c39b429e 100644 --- a/internal/backup/pipeline_test.go +++ b/internal/backup/pipeline_test.go @@ -3,7 +3,7 @@ package backup import ( "context" "fmt" - "sync/atomic" + "sync" "testing" "time" @@ -15,59 +15,83 @@ import ( "gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb" ) -func TestLoggingPipeline(t *testing.T) { +func TestPipeline(t *testing.T) { t.Parallel() - testPipeline(t, func() Pipeline { - return NewLoggingPipeline(testhelper.SharedLogger(t)) - }) -} - -func TestParallelPipeline(t *testing.T) { - t.Parallel() - - testPipeline(t, func() Pipeline { - return NewParallelPipeline(NewLoggingPipeline(testhelper.SharedLogger(t)), 2, 0) + // Sequential + testPipeline(t, func() *Pipeline { + p, err := NewPipeline(testhelper.SharedLogger(t)) + require.NoError(t, err) + return p }) + // Concurrent t.Run("parallelism", func(t *testing.T) { for _, tc := range []struct { - parallel int - parallelStorage int - expectedMaxParallel int64 + parallel int + parallelStorage int + expectedMaxParallel int + expectedMaxStorageParallel int }{ { - parallel: 2, - parallelStorage: 0, - expectedMaxParallel: 2, + parallel: 2, + parallelStorage: 0, + expectedMaxParallel: 2, + expectedMaxStorageParallel: 2, }, { - parallel: 2, - parallelStorage: 3, - expectedMaxParallel: 2, + parallel: 2, + parallelStorage: 3, + expectedMaxParallel: 2, + expectedMaxStorageParallel: 2, }, { - parallel: 0, - parallelStorage: 3, - expectedMaxParallel: 6, // 2 storages * 3 workers per storage + parallel: 0, + parallelStorage: 3, + expectedMaxParallel: 6, // 2 storages * 3 workers per storage + expectedMaxStorageParallel: 3, + }, + { + parallel: 3, + parallelStorage: 2, + expectedMaxParallel: 3, + expectedMaxStorageParallel: 2, }, } { t.Run(fmt.Sprintf("parallel:%d,parallelStorage:%d", tc.parallel, tc.parallelStorage), func(t *testing.T) { - var calls int64 + var mu sync.Mutex + // callsPerStorage tracks the number of concurrent jobs running for each storage. + callsPerStorage := map[string]int{ + "storage1": 0, + "storage2": 0, + } + strategy := MockStrategy{ CreateFunc: func(ctx context.Context, req *CreateRequest) error { - currentCalls := atomic.AddInt64(&calls, 1) - defer atomic.AddInt64(&calls, -1) - - assert.LessOrEqual(t, currentCalls, tc.expectedMaxParallel) + mu.Lock() + callsPerStorage[req.Repository.StorageName]++ + allCalls := 0 + for _, v := range callsPerStorage { + allCalls += v + } + // We ensure that the concurrency for each storage is not above the + // parallelStorage threshold, and also that the total number of concurrent + // jobs is not above the parallel threshold. 
+ require.LessOrEqual(t, callsPerStorage[req.Repository.StorageName], tc.expectedMaxStorageParallel) + require.LessOrEqual(t, allCalls, tc.expectedMaxParallel) + mu.Unlock() + defer func() { + mu.Lock() + callsPerStorage[req.Repository.StorageName]-- + mu.Unlock() + }() time.Sleep(time.Millisecond) return nil }, } - var p Pipeline - p = NewLoggingPipeline(testhelper.SharedLogger(t)) - p = NewParallelPipeline(p, tc.parallel, tc.parallelStorage) + p, err := NewPipeline(testhelper.SharedLogger(t), WithConcurrency(tc.parallel, tc.parallelStorage)) + require.NoError(t, err) ctx := testhelper.Context(t) for i := 0; i < 10; i++ { @@ -81,9 +105,8 @@ func TestParallelPipeline(t *testing.T) { t.Run("context done", func(t *testing.T) { var strategy MockStrategy - var p Pipeline - p = NewLoggingPipeline(testhelper.SharedLogger(t)) - p = NewParallelPipeline(p, 0, 0) // make sure worker channels always block + p, err := NewPipeline(testhelper.SharedLogger(t)) + require.NoError(t, err) ctx, cancel := context.WithCancel(testhelper.Context(t)) @@ -92,8 +115,7 @@ func TestParallelPipeline(t *testing.T) { p.Handle(ctx, NewCreateCommand(strategy, CreateRequest{Repository: &gitalypb.Repository{StorageName: "default"}})) - err := p.Done() - require.EqualError(t, err, "pipeline: context canceled") + require.EqualError(t, p.Done(), "pipeline: context canceled") }) } @@ -124,7 +146,7 @@ func (s MockStrategy) RemoveAllRepositories(ctx context.Context, req *RemoveAllR return nil } -func testPipeline(t *testing.T, init func() Pipeline) { +func testPipeline(t *testing.T, init func() *Pipeline) { strategy := MockStrategy{ CreateFunc: func(_ context.Context, req *CreateRequest) error { switch req.Repository.StorageName { @@ -277,7 +299,7 @@ func TestPipelineError(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - err := PipelineErrors{} + err := &commandErrors{} for _, repo := range tc.repos { err.AddError(repo, assert.AnError) diff --git a/internal/cli/gitalybackup/create.go b/internal/cli/gitalybackup/create.go index a67cc75da13dc9023d672fba2ae8d6dd4d7fe24b..288e5a08d0ca91ccc245cd9c6b700fee017d6582 100644 --- a/internal/cli/gitalybackup/create.go +++ b/internal/cli/gitalybackup/create.go @@ -142,10 +142,13 @@ func (cmd *createSubcommand) run(ctx context.Context, logger log.Logger, stdin i manager = backup.NewManager(sink, locator, pool) } - var pipeline backup.Pipeline - pipeline = backup.NewLoggingPipeline(logger) + var opts []backup.PipelineOption if cmd.parallel > 0 || cmd.parallelStorage > 0 { - pipeline = backup.NewParallelPipeline(pipeline, cmd.parallel, cmd.parallelStorage) + opts = append(opts, backup.WithConcurrency(cmd.parallel, cmd.parallelStorage)) + } + pipeline, err := backup.NewPipeline(logger, opts...) 
+ if err != nil { + return fmt.Errorf("create pipeline: %w", err) } decoder := json.NewDecoder(stdin) diff --git a/internal/cli/gitalybackup/restore.go b/internal/cli/gitalybackup/restore.go index 06a20bb6ede750af444520647d6040dc30a7cc42..de9e2cd3d313320e55a196d69c852ef3a7295d93 100644 --- a/internal/cli/gitalybackup/restore.go +++ b/internal/cli/gitalybackup/restore.go @@ -17,11 +17,8 @@ import ( ) type restoreRequest struct { - storage.ServerInfo - StorageName string `json:"storage_name"` - RelativePath string `json:"relative_path"` - GlProjectPath string `json:"gl_project_path"` - AlwaysCreate bool `json:"always_create"` + serverRepository + AlwaysCreate bool `json:"always_create"` } type restoreSubcommand struct { @@ -149,10 +146,13 @@ func (cmd *restoreSubcommand) run(ctx context.Context, logger log.Logger, stdin } } - var pipeline backup.Pipeline - pipeline = backup.NewLoggingPipeline(logger) + var opts []backup.PipelineOption if cmd.parallel > 0 || cmd.parallelStorage > 0 { - pipeline = backup.NewParallelPipeline(pipeline, cmd.parallel, cmd.parallelStorage) + opts = append(opts, backup.WithConcurrency(cmd.parallel, cmd.parallelStorage)) + } + pipeline, err := backup.NewPipeline(logger, opts...) + if err != nil { + return fmt.Errorf("create pipeline: %w", err) } decoder := json.NewDecoder(stdin)
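
Usage sketch (illustrative, not part of the patch): the two CLI call sites above follow the same pattern, so the snippet below consolidates it, showing how a caller of the new API builds a pipeline, optionally enables concurrency, queues create commands, and collects errors via Done(). NewPipeline, WithConcurrency, PipelineOption, Handle, Done, NewCreateCommand and CreateRequest are the identifiers introduced or used by this change; the runBackups wrapper, its parameters, and the import paths other than gitalypb are assumptions for the example, and it would have to live inside the Gitaly module to import the internal packages.

// Hypothetical package co-located with the CLI code; names are for illustration only.
package gitalybackupexample

import (
	"context"
	"fmt"

	"gitlab.com/gitlab-org/gitaly/v16/internal/backup"
	"gitlab.com/gitlab-org/gitaly/v16/internal/log"
	"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
)

// runBackups queues one create job per repository and waits for completion.
func runBackups(ctx context.Context, logger log.Logger, strategy backup.Strategy, repos []*gitalypb.Repository) error {
	var opts []backup.PipelineOption

	// Mirror the CLI behaviour: only enable concurrency when requested;
	// otherwise the pipeline runs jobs sequentially.
	parallel, parallelStorage := 3, 2
	if parallel > 0 || parallelStorage > 0 {
		opts = append(opts, backup.WithConcurrency(parallel, parallelStorage))
	}

	pipeline, err := backup.NewPipeline(logger, opts...)
	if err != nil {
		return fmt.Errorf("create pipeline: %w", err)
	}

	for _, repo := range repos {
		pipeline.Handle(ctx, backup.NewCreateCommand(strategy, backup.CreateRequest{Repository: repo}))
	}

	// Done blocks until all queued jobs finish and reports any accumulated errors.
	return pipeline.Done()
}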
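For readers of the concurrency logic in getWorker(), acquireWorkerSlot() and releaseWorkerSlot(): the stripped-down, standalone sketch below reproduces the same scheme with invented names, using the total=3, perStorage=2 combination exercised by the new test case. Each storage gets perStorage workers fed from its own channel, while a shared buffered channel acts as a semaphore capping the number of jobs running at once across all storages. This is a simplified illustration of the technique, not the Pipeline implementation itself.

// Simplified illustration of per-storage workers bounded by a global semaphore.
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	const total, perStorage = 3, 2

	totalSlots := make(chan struct{}, total) // global concurrency limit
	var wg sync.WaitGroup

	jobsByStorage := map[string][]string{
		"storage1": {"repo-a", "repo-b", "repo-c"},
		"storage2": {"repo-d", "repo-e", "repo-f"},
	}

	for storage, repos := range jobsByStorage {
		queue := make(chan string)

		// Start perStorage workers for this storage, mirroring getWorker().
		for i := 0; i < perStorage; i++ {
			wg.Add(1)
			go func(storage string) {
				defer wg.Done()
				for repo := range queue {
					totalSlots <- struct{}{} // acquire a global slot
					fmt.Printf("backing up %s/%s\n", storage, repo)
					time.Sleep(10 * time.Millisecond)
					<-totalSlots // release the slot
				}
			}(storage)
		}

		// Feed the storage's queue, then close it so the workers exit.
		go func(repos []string) {
			for _, repo := range repos {
				queue <- repo
			}
			close(queue)
		}(repos)
	}

	wg.Wait()
}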
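Finally, a hypothetical test sketch, written as if it sat alongside the cases in pipeline_test.go (same package, reusing its existing imports), showing how per-repository failures recorded in commandErrors surface from Done() as a single wrapped "pipeline: ..." error. The test name, storage names and repository paths are invented for the example.

// TestPipelineReportsFailures is an illustrative sketch, not part of the patch.
func TestPipelineReportsFailures(t *testing.T) {
	strategy := MockStrategy{
		CreateFunc: func(_ context.Context, req *CreateRequest) error {
			// Fail only repositories on the "broken" storage.
			if req.Repository.StorageName == "broken" {
				return assert.AnError
			}
			return nil
		},
	}

	p, err := NewPipeline(testhelper.SharedLogger(t))
	require.NoError(t, err)

	ctx := testhelper.Context(t)
	p.Handle(ctx, NewCreateCommand(strategy, CreateRequest{Repository: &gitalypb.Repository{StorageName: "broken", RelativePath: "a.git"}}))
	p.Handle(ctx, NewCreateCommand(strategy, CreateRequest{Repository: &gitalypb.Repository{StorageName: "default", RelativePath: "b.git"}}))

	// One job failed, so Done() returns the aggregated "pipeline: ..." error
	// whose message lists the failing repository's relative path.
	require.Error(t, p.Done())
}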