diff --git a/internal/git/pktline/read_monitor.go b/internal/git/pktline/read_monitor.go
index 177a2ec3d63f253275775498248eae14dfc0a708..1de015f49a243e21f147051afe4c76a93e298da5 100644
--- a/internal/git/pktline/read_monitor.go
+++ b/internal/git/pktline/read_monitor.go
@@ -8,6 +8,7 @@ import (
 	"sync"
 
 	"gitlab.com/gitlab-org/gitaly/v16/internal/helper"
+	"gitlab.com/gitlab-org/gitaly/v16/internal/log"
 )
 
 // ReadMonitor monitors an io.Reader, waiting for a specified packet. If the
@@ -28,6 +29,7 @@ import (
 // fetch, for instance, so tighter limits can be placed on it, leading to a
 // better mitigation.
 type ReadMonitor struct {
+	logger     log.Logger
 	pr         *os.File
 	pw         *os.File
 	underlying io.Reader
@@ -42,13 +44,14 @@ type ReadMonitor struct {
 //
 // The returned function will release allocated resources. You must make sure to call this
 // function.
-func NewReadMonitor(ctx context.Context, r io.Reader) (*os.File, *ReadMonitor, func(), error) {
+func NewReadMonitor(ctx context.Context, r io.Reader, logger log.Logger) (*os.File, *ReadMonitor, func(), error) {
 	pr, pw, err := os.Pipe()
 	if err != nil {
 		return nil, nil, nil, err
 	}
 
 	return pr, &ReadMonitor{
+		logger:     logger,
 		pr:         pr,
 		pw:         pw,
 		underlying: r,
@@ -85,6 +88,11 @@ func (m *ReadMonitor) Monitor(ctx context.Context, pkt []byte, timeout helper.Ti
 		}
 	}
 
+	if err := scanner.Err(); err != nil {
+		m.logger.WithError(err).ErrorContext(ctx, "failed scanning stream for specified packet")
+		stopOnce.Do(timeout.Stop)
+	}
+
 	// Complete the read loop, then signal completion on pr by closing pw
 	_, _ = io.Copy(io.Discard, teeReader)
 	_ = m.pw.Close()
diff --git a/internal/git/pktline/read_monitor_test.go b/internal/git/pktline/read_monitor_test.go
index f66dff56924097debb6b49f76d2dcaf9c64cbc60..36acd689c87d5d5011614735e33d87e1e53edd13 100644
--- a/internal/git/pktline/read_monitor_test.go
+++ b/internal/git/pktline/read_monitor_test.go
@@ -3,6 +3,7 @@ package pktline
 import (
 	"bytes"
 	"context"
+	"errors"
 	"io"
 	"os"
 	"strings"
@@ -14,6 +15,15 @@ import (
 	"gitlab.com/gitlab-org/gitaly/v16/internal/testhelper"
 )
 
+// errGenReader is an io.Reader whose every Read call returns zero bytes and err.
+type errGenReader struct {
+	err error
+}
+
+func (e *errGenReader) Read(p []byte) (int, error) {
+	return 0, e.err
+}
+
 func TestReadMonitorTimeout(t *testing.T) {
 	waitPipeR, waitPipeW := io.Pipe()
 	defer waitPipeW.Close()
@@ -24,7 +34,9 @@ func TestReadMonitorTimeout(t *testing.T) {
 		waitPipeR, // this pipe reader lets us block the multi reader
 	)
 
-	r, monitor, cleanup, err := NewReadMonitor(ctx, in)
+	logger := testhelper.NewLogger(t)
+
+	r, monitor, cleanup, err := NewReadMonitor(ctx, in, logger)
 	require.NoError(t, err)
 
 	timeoutTicker := helper.NewManualTicker()
@@ -63,7 +75,9 @@ func TestReadMonitorSuccess(t *testing.T) {
 		strings.NewReader(postTimeoutPayload),
 	)
 
-	r, monitor, cleanup, err := NewReadMonitor(ctx, in)
+	logger := testhelper.NewLogger(t)
+
+	r, monitor, cleanup, err := NewReadMonitor(ctx, in, logger)
 	require.NoError(t, err)
 	defer cleanup()
 
@@ -95,3 +109,47 @@ func TestReadMonitorSuccess(t *testing.T) {
 
 	require.NoError(t, ctx.Err())
 }
+
+func TestReadMonitorReadError(t *testing.T) {
+	ctx, cancel := context.WithCancel(testhelper.Context(t))
+
+	preErrPayload := "000ftest string"
+	expectedErr := errors.New("read error")
+
+	in := io.MultiReader(
+		strings.NewReader(preErrPayload),
+		&errGenReader{err: expectedErr},
+	)
+
+	logger := testhelper.NewLogger(t)
+	hook := testhelper.AddLoggerHook(logger)
+
+	r, monitor, cleanup, err := NewReadMonitor(ctx, in, logger)
+	require.NoError(t, err)
+	defer cleanup()
+
+	timeoutTicker := helper.NewManualTicker()
+
+	stopCh := make(chan struct{})
+	timeoutTicker.StopFunc = func() {
+		close(stopCh)
+	}
+
+	go monitor.Monitor(ctx, PktFlush(), timeoutTicker, cancel)
+
+	// Read the payload preceding the injected read error, then hit the error
+	scanner := NewScanner(r)
+	require.True(t, scanner.Scan())
+	require.Equal(t, preErrPayload, scanner.Text())
+	require.False(t, scanner.Scan())
+
+	// Timer stopped on read error
+	<-stopCh
+
+	// Ensure the read error is logged properly
+	var logs []string
+	for _, entry := range hook.AllEntries() {
+		logs = append(logs, entry.Message)
+	}
+	require.Contains(t, logs, "failed scanning stream for specified packet")
+}
diff --git a/internal/gitaly/service/ssh/upload_command.go b/internal/gitaly/service/ssh/upload_command.go
index 68d2af2ca22bc2d03850da158cc6e96da4494608..8eb433540d0a9959260bb144ed956ab97008d1eb 100644
--- a/internal/gitaly/service/ssh/upload_command.go
+++ b/internal/gitaly/service/ssh/upload_command.go
@@ -43,7 +43,7 @@ func (s *server) runUploadCommand(
 	// Use large copy buffer to reduce the number of system calls
 	stdout = &largeBufferReaderFrom{Writer: stdoutCounter}
 
-	stdinPipe, monitor, cleanup, err := pktline.NewReadMonitor(ctx, stdin)
+	stdinPipe, monitor, cleanup, err := pktline.NewReadMonitor(ctx, stdin, s.logger)
 	if err != nil {
 		return fmt.Errorf("create monitor: %w", err)
 	}