From 56476626a991a0d68c4761d0fdac9a1d4f95910c Mon Sep 17 00:00:00 2001
From: Olivier Campeau
Date: Tue, 18 Nov 2025 20:30:54 -0500
Subject: [PATCH 1/3] grpc: Add a gRPC proxier to proxy requests

This commit adds a gRPC Proxier to proxy gRPC requests from a client
to a remote server. This Proxier is a building block that will be
needed in an upcoming commit for the Raft proxy, in order to proxy
requests from a follower to a leader.

This proxier is very basic and limited, but is sufficient for what we
need so far. It proxies one request at a time. A new Proxier is needed
for each request.

It is not used in this commit, but will be used in a subsequent commit.
---
 internal/grpc/proxy/grpc_proxier.go           | 242 +++++
 internal/grpc/proxy/grpc_proxier_test.go      | 459 ++++++++++++++++++
 .../grpc/proxy/proxy_test_testhelper_test.go  |   4 +-
 3 files changed, 703 insertions(+), 2 deletions(-)
 create mode 100644 internal/grpc/proxy/grpc_proxier.go
 create mode 100644 internal/grpc/proxy/grpc_proxier_test.go

diff --git a/internal/grpc/proxy/grpc_proxier.go b/internal/grpc/proxy/grpc_proxier.go
new file mode 100644
index 0000000000..9b07e3eb3c
--- /dev/null
+++ b/internal/grpc/proxy/grpc_proxier.go
@@ -0,0 +1,242 @@
+package proxy
+
+import (
+	"context"
+	"io"
+	"strings"
+	"sync"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/types/known/anypb"
+)
+
+const (
+	CommUnary  = "unary"
+	CommStream = "stream"
+)
+
+// readWriteStream is a helper interface that wraps
+// a subset of the grpc.ServerStream interface to
+// facilitate some operations below.
+type readWriteStream interface {
+	SendMsg(m any) error
+	RecvMsg(m any) error
+}
+
+type CommPattern string
+
+// GrpcProxier is an object that allows proxying gRPC requests.
+// It can proxy unary and stream requests.
+// This proxier is simple and limited in scope. It can only
+// be used for one request at a time. For each instance of
+// GrpcProxier, the full method must be specified.
+type GrpcProxier struct {
+	// conn is the gRPC connection to use to connect to the remote server
+	conn *grpc.ClientConn
+
+	// fullMethod is the full method to call on the remote server
+	fullMethod string
+
+	// opts defines the type of communication to open with the remote server,
+	// i.e. unary, full duplex, etc.
+	opts *grpc.StreamDesc
+}
+
+// NewGrpcProxier creates a new GrpcProxier.
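+//
+// A minimal usage sketch (illustrative only; the connection `conn`, the
+// context `ctx` and the request `req` are assumed to exist in the caller):
+//
+//	proxier := NewGrpcProxier(conn, "grpc.testing.TestService/UnaryCall", CommUnary)
+//	resp, err := proxier.ProxyUnary(ctx, req)
+//
+// For streaming methods, pass CommStream and call ProxyStream with the
+// incoming grpc.ServerStream instead.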
+func NewGrpcProxier(conn *grpc.ClientConn, fullMethod string, comm CommPattern) *GrpcProxier {
+	// The stream name must be the short method name.
+	// Example:
+	//   Full method name  = `package.ServiceName/MethodName`
+	//   Short method name = `MethodName`
+	fullMethodParts := strings.Split(fullMethod, "/")
+	shortMethodName := fullMethodParts[len(fullMethodParts)-1]
+
+	opts := &grpc.StreamDesc{
+		StreamName: shortMethodName,
+	}
+
+	if comm == CommStream {
+		opts.ServerStreams = true
+		opts.ClientStreams = true
+	}
+
+	return &GrpcProxier{
+		conn:       conn,
+		fullMethod: fullMethod,
+		opts:       opts,
+	}
+}
+
+func (proxier *GrpcProxier) ProxyStream(incomingStream grpc.ServerStream) error {
+	// Get metadata from client
+	md, ok := metadata.FromIncomingContext(incomingStream.Context())
+	if !ok {
+		md = metadata.New(nil)
+	}
+
+	// Propagate headers to remote server and create outgoing stream
+	outCtx := metadata.NewOutgoingContext(context.Background(), md.Copy())
+	outgoingStream, err := proxier.conn.NewStream(outCtx, proxier.opts, proxier.fullMethod)
+	if err != nil {
+		return err
+	}
+
+	// Buffer length of 2 because we have 2 goroutines running:
+	// 1. Handle communication client <-> remote
+	// 2. Handle communication remote <-> client
+	// As such, there can, at most, be only 2 errors reported.
+	errCh := make(chan error, 2)
+	doneCh := make(chan struct{}, 1)
+
+	wg := sync.WaitGroup{}
+
+	// Handle client -> remote communication
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		defer func() { _ = outgoingStream.CloseSend() }()
+
+		for {
+			select {
+			case <-doneCh:
+				return
+			case <-incomingStream.Context().Done():
+				return
+			default:
+			}
+
+			msg := new(anypb.Any)
+			if recvErr := incomingStream.RecvMsg(msg); recvErr != nil {
+				// If err is EOF, it means the client closed
+				// the send stream, so there is nothing else to
+				// send to the remote, so we in turn close the
+				// outgoing stream by calling CloseSend().
+				// EOF is a non-error, so we don't send an error
+				// through the error channel in that case.
+				if recvErr != io.EOF {
+					errCh <- recvErr
+				}
+				break
+			}
+
+			// Send the message to the remote
+			if sendErr := outgoingStream.SendMsg(msg); sendErr != nil {
+				// Any errors when sending a message should be reported
+				// because here we assume the send stream is still open.
+				errCh <- sendErr
+				break
+			}
+		}
+	}()
+
+	// Handle remote -> client communication
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		defer func() {
+			incomingStream.SetTrailer(outgoingStream.Trailer().Copy())
+		}()
+
+		// This is used to control when to
+		// propagate headers.
+		firstRecvMsg := true
+
+		for {
+			select {
+			case <-doneCh:
+				return
+			case <-outgoingStream.Context().Done():
+				return
+			default:
+			}
+
+			msg := new(anypb.Any)
+			if recvErr := outgoingStream.RecvMsg(msg); recvErr != nil {
+				// If error is EOF, it means the remote server closed
+				// the `send` part of the stream, which means we should
+				// not expect any more new messages from the remote. So
+				// we stop sending messages back to the client.
+				// Else we report the error.
+				if recvErr != io.EOF {
+					errCh <- recvErr
+				}
+				break
+			}
+
+			// Headers must always be set before sending the first message
+			if firstRecvMsg {
+				if hd, hdErr := outgoingStream.Header(); hdErr == nil {
+					_ = incomingStream.SetHeader(hd)
+				}
+				firstRecvMsg = false
+			}
+
+			if sendErr := incomingStream.SendMsg(msg); sendErr != nil {
+				errCh <- sendErr
+				break
+			}
+		}
+	}()
+
+	go func() {
+		defer close(errCh)
+		wg.Wait()
+	}()
+
+	// Return the first error encountered
+	firstErr := <-errCh
+
+	// Once an error has been returned, signal to all goroutines
+	// that they must abort by closing the channel.
+	close(doneCh)
+
+	return firstErr
+}
+
+func (proxier *GrpcProxier) ProxyUnary(incomingCtx context.Context, incomingRequest proto.Message) (any, error) {
+	// Get metadata from client
+	md, ok := metadata.FromIncomingContext(incomingCtx)
+	if !ok {
+		md = metadata.New(nil)
+	}
+
+	// Propagate headers to remote server and create outgoing stream
+	outCtx := metadata.NewOutgoingContext(context.Background(), md.Copy())
+	outgoingStream, err := proxier.conn.NewStream(outCtx, proxier.opts, proxier.fullMethod)
+	if err != nil {
+		return nil, err
+	}
+
+	// Send the request to the remote server
+	if err := outgoingStream.SendMsg(incomingRequest); err != nil && err != io.EOF {
+		return nil, err
+	}
+
+	// Close the stream to signal to the server that the request has been sent
+	if err := outgoingStream.CloseSend(); err != nil {
+		return nil, err
+	}
+
+	// Receive the response into a generic message container
+	// that can be passed back to the client as-is.
+	var response = anypb.Any{}
+	if err := outgoingStream.RecvMsg(&response); err != nil {
+		return nil, err
+	}
+
+	// Send the headers back to the client
+	if headers, err := outgoingStream.Header(); err == nil && headers != nil {
+		if err := grpc.SendHeader(incomingCtx, headers); err != nil {
+			return nil, err
+		}
+	}
+
+	// Send the trailers back to the client
+	if trailers := outgoingStream.Trailer(); trailers != nil {
+		if err := grpc.SetTrailer(incomingCtx, trailers); err != nil {
+			return nil, err
+		}
+	}
+	return &response, nil
+}
diff --git a/internal/grpc/proxy/grpc_proxier_test.go b/internal/grpc/proxy/grpc_proxier_test.go
new file mode 100644
index 0000000000..ea4afd538b
--- /dev/null
+++ b/internal/grpc/proxy/grpc_proxier_test.go
@@ -0,0 +1,459 @@
+package proxy_test
+
+import (
+	"context"
+	"io"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/v18/internal/grpc/proxy"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/interop/grpc_testing"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+)
+
+func TestPipeUnary(t *testing.T) {
+	const fullMethod = "grpc.testing.TestService/UnaryCall"
+	const shortMethod = "UnaryCall"
+
+	tests := []struct {
+		name             string
+		handler          func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error)
+		sender           func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error)
+		expectedResponse *grpc_testing.SimpleResponse
+		expectedError    error
+		doCancel         bool
+	}{
+		{
+			name: "a normal request should succeed",
+			handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) {
+				return &grpc_testing.SimpleResponse{
+					Payload: &grpc_testing.Payload{
+						Body: []byte("hello-world"),
+					},
+				}, nil
+			},
+			sender: func(ctx context.Context, 
client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + return client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }) + }, + expectedResponse: &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello-world"), + }, + }, + }, + { + name: "headers and trailers must be propagated", + handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + err := grpc.SetTrailer(ctx, metadata.Pairs("some", "trailer")) + require.NoError(t, err) + + err = grpc.SetHeader(ctx, metadata.Pairs("some", "header")) + require.NoError(t, err) + + return &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello-world"), + }, + }, nil + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + var headers metadata.MD + var trailers metadata.MD + + response, err := client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }, grpc.Trailer(&trailers), grpc.Header(&headers)) + + require.Equal(t, "trailer", trailers.Get("some")[0]) + require.Equal(t, "header", headers.Get("some")[0]) + + return response, err + }, + expectedResponse: &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello-world"), + }, + }, + }, + { + name: "when an error is returned by the remote, the proxy should return that error back to the client", + handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return nil, status.Error(codes.NotFound, "repo not found") + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + return client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }) + }, + expectedResponse: nil, + expectedError: status.Error(codes.NotFound, "repo not found"), + }, + { + name: "when the context is cancelled, the error should be reported back to the client", + handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return nil, status.Error(codes.NotFound, "repo not found") + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + return client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }) + }, + expectedResponse: nil, + expectedError: status.Error(codes.Canceled, "context canceled"), + doCancel: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + backEndCtx := context.Background() + + // Create remote backend + remoteBackendConn, remoteBackend := newBackendPinger(t, backEndCtx) + remoteBackend.unaryCall = tc.handler + + // Create client backend + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + + unaryProxy := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + return proxy. + NewGrpcProxier(remoteBackendConn, fullMethod, proxy.CommUnary). 
+ ProxyUnary(ctx, req.(*grpc_testing.SimpleRequest)) + } + + clientBackendConn, _ := newBackendPinger(t, clientCtx, grpc.ChainUnaryInterceptor(unaryProxy)) + client := grpc_testing.NewTestServiceClient(clientBackendConn) + + // If the test cases requires we cancel the request, we cancel + // right before making the gRPC call, but after the connection + // has been opened. + if tc.doCancel { + clientCancel() + } + + response, err := tc.sender(clientCtx, client) + require.Equal(t, tc.expectedError, err) + + if tc.expectedResponse == nil { + require.Nil(t, response) + } else { + require.Equal(t, tc.expectedResponse.Payload.Body, response.Payload.Body) + } + }) + } +} + +func TestPipeStream(t *testing.T) { + const fullMethod = "grpc.testing.TestService/FullDuplexCall" + const shortMethod = "FullDuplexCall" + + tests := []struct { + name string + handler func(grpc_testing.TestService_FullDuplexCallServer) error + sender func(ctx context.Context, client grpc_testing.TestServiceClient) ([]string, error) + expectedPayloads []string + expectedError error + doCancel bool + }{ + { + name: "a simple request a and response should succeed", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + msg, err := stream.Recv() + require.NoError(t, err) + + return stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: msg.Payload, + }) + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello"), + }, + }) + require.NoError(t, err) + + response, err := stream.Recv() + require.NoError(t, err) + + return []string{string(response.Payload.Body)}, nil + }, + expectedPayloads: []string{"hello"}, + expectedError: nil, + doCancel: false, + }, + { + name: "when the remote returns an error, it should be propagated to the client", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + msg, err := stream.Recv() + require.NoError(t, err) + + return stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: msg.Payload, + }) + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + if err != nil { + return []string{}, err + } + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello"), + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + return []string{}, err + }, + expectedPayloads: []string{}, + expectedError: status.Error(codes.Canceled, "context canceled"), + doCancel: true, + }, + { + name: "when the remote returns an error, it should be propagated to the client", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + return status.Error(codes.NotFound, "repo not found") + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello"), + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + return []string{}, err + }, + expectedPayloads: []string{}, + expectedError: status.Error(codes.NotFound, "repo not found"), + doCancel: false, + }, + { + name: "concurrent requests 
and response should succeed", + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + payloads := []string{"hello", "hi", "bonjour"} + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + errCh := make(chan error, 2) + + wg := sync.WaitGroup{} + + // Send requests + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + _ = stream.CloseSend() + }() + + for _, p := range payloads { + select { + case <-stream.Context().Done(): + errCh <- stream.Context().Err() + return + default: + } + + err := stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte(p), + }, + }) + if err != nil { + if err != io.EOF { + errCh <- err + } + break + } + } + }() + + // Collect responses + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stream.Context().Done(): + errCh <- stream.Context().Err() + return + default: + } + + res := grpc_testing.StreamingOutputCallResponse{} + err := stream.RecvMsg(&res) + if err != nil { + if err != io.EOF { + errCh <- err + } + break + } + responses = append(responses, string(res.Payload.Body)) + } + }() + + go func() { + defer close(errCh) + wg.Wait() + }() + + err = <-errCh + return responses, err + }, + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + for { + select { + case <-stream.Context().Done(): + return stream.Context().Err() + default: + } + + // First receive a message + req, err := stream.Recv() + if err != nil { + if err == io.EOF { + break + } + return err + } + + // send back the same payload that was received + err = stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: req.Payload, + }) + if err != nil { + if err == io.EOF { + break + } + return err + } + } + return nil + }, + expectedPayloads: []string{"hello", "hi", "bonjour"}, + }, + { + name: "headers and trailers should be propagated", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + // Assert that the headers set in the `sender` func were propagated + // up to here. 
+ md, ok := metadata.FromIncomingContext(stream.Context()) + require.True(t, ok) + require.Equal(t, "data", md["meta"][0]) + + msg, err := stream.Recv() + require.NoError(t, err) + + // Set headers to be sent back to the client + md.Set("hello", "goodbye") + err = stream.SetHeader(md) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: msg.Payload, + }) + require.NoError(t, err) + + stream.SetTrailer(metadata.New(map[string]string{"last": "trailer"})) + return err + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("goodbye"), + }, + }) + require.NoError(t, err) + + err = stream.CloseSend() + require.NoError(t, err) + + response, err := stream.Recv() + require.NoError(t, err) + + // A stream must always be read until EOF + _, err = stream.Recv() + require.ErrorIs(t, err, io.EOF) + + // Assert that headers and trailers sent back from the remote + // are propagated back up to the original client + hd, err := stream.Header() + require.NoError(t, err) + require.Equal(t, "goodbye", hd.Get("hello")[0]) + + // Assert that trailers are propagated + trailers := stream.Trailer() + require.Equal(t, "trailer", trailers.Get("last")[0]) + + return []string{string(response.Payload.Body)}, nil + }, + expectedPayloads: []string{"goodbye"}, + expectedError: nil, + doCancel: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create remote backend + backEndCtx := context.Background() + remoteBackendConn, remoteBackend := newBackendPinger(t, backEndCtx) + remoteBackend.fullDuplexCall = tc.handler + + // Create client backend + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + + streamProxy := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return proxy. + NewGrpcProxier(remoteBackendConn, fullMethod, proxy.CommStream). + ProxyStream(ss) + } + clientBackendConn, _ := newBackendPinger(t, clientCtx, grpc.ChainStreamInterceptor(streamProxy)) + client := grpc_testing.NewTestServiceClient(clientBackendConn) + + // If the test cases requires we cancel the request, we cancel + // right before making the gRPC call, but after the connection + // has been opened. 
+			if tc.doCancel {
+				clientCancel()
+			}
+
+			// Set custom metadata on the outgoing stream so
+			// we can validate later on that headers are propagated
+			clientCtx = metadata.NewOutgoingContext(clientCtx, metadata.New(map[string]string{"meta": "data"}))
+			response, err := tc.sender(clientCtx, client)
+			require.Equal(t, tc.expectedPayloads, response)
+			require.Equal(t, tc.expectedError, err)
+		})
+	}
+}
diff --git a/internal/grpc/proxy/proxy_test_testhelper_test.go b/internal/grpc/proxy/proxy_test_testhelper_test.go
index d7caf003ad..c958c75d56 100644
--- a/internal/grpc/proxy/proxy_test_testhelper_test.go
+++ b/internal/grpc/proxy/proxy_test_testhelper_test.go
@@ -20,10 +20,10 @@ func newListener(tb testing.TB) net.Listener {
 	return listener
 }
 
-func newBackendPinger(tb testing.TB, ctx context.Context) (*grpc.ClientConn, *interceptPinger) {
+func newBackendPinger(tb testing.TB, ctx context.Context, opts ...grpc.ServerOption) (*grpc.ClientConn, *interceptPinger) {
 	ip := &interceptPinger{}
 
-	srvr := grpc.NewServer()
+	srvr := grpc.NewServer(opts...)
 
 	listener := newListener(tb)
 
 	grpc_testing.RegisterTestServiceServer(srvr, ip)
-- 
GitLab

From f4527f8b52fc8ed64584a5933881f866d7447b59 Mon Sep 17 00:00:00 2001
From: Olivier Campeau
Date: Tue, 18 Nov 2025 20:31:15 -0500
Subject: [PATCH 2/3] grpc: Add a StreamBuffer for gRPC streams

This commit introduces another building block needed for the Raft
proxy. This StreamBuffer allows buffering the head message of a stream
(the first message) for further inspection.

It has a similar aim to the `peeker`, but it differs in important ways
in how it is used. The `peeker` holds the first `n` messages of a
stream in a buffer for further inspection, similar to this
StreamBuffer. However, when the time comes to send the stream of
messages to the remote server, the user must make sure to send all
peeked messages first before sending the remaining messages in the
stream. It is a manual process.

Here, the StreamBuffer handles this transparently because it
implements the `grpc.ServerStream` interface, allowing it to be used
as a stream. The implementation of the StreamBuffer takes care of
sending the buffered message before the rest of the stream.

This implementation will greatly simplify the Raft proxy that will be
introduced in a subsequent commit. It implements the
`grpc.ServerStream` interface, so it can be used as any gRPC stream.
---
 internal/grpc/proxy/stream_buffer.go      |  90 ++++++++++
 internal/grpc/proxy/stream_buffer_test.go | 165 ++++++++++++++++++++++
 2 files changed, 255 insertions(+)
 create mode 100644 internal/grpc/proxy/stream_buffer.go
 create mode 100644 internal/grpc/proxy/stream_buffer_test.go

diff --git a/internal/grpc/proxy/stream_buffer.go b/internal/grpc/proxy/stream_buffer.go
new file mode 100644
index 0000000000..1b0215fbc0
--- /dev/null
+++ b/internal/grpc/proxy/stream_buffer.go
@@ -0,0 +1,90 @@
+package proxy
+
+import (
+	"reflect"
+	"sync"
+
+	"google.golang.org/grpc"
+	"google.golang.org/protobuf/proto"
+)
+
+// StreamBuffer allows for buffering the head (first message) of a gRPC stream
+// in order to inspect it before being consumed by the intended recipient.
+// It implements the `grpc.ServerStream` interface so it can be used as any
+// other gRPC stream.
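+//
+// A minimal usage sketch (illustrative only; `stream` is assumed to be the
+// grpc.ServerStream seen by a handler or interceptor, and `firstMsg` a
+// request message of the expected type):
+//
+//	buffered := NewStreamBuffer(stream)
+//	if err := buffered.Head(firstMsg); err != nil {
+//		return err
+//	}
+//	// Inspect firstMsg, then keep using `buffered` as a regular stream:
+//	// RecvMsg replays the buffered head before reading from the wrapped stream.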
+type StreamBuffer struct { + grpc.ServerStream + buffer chan any + mutex *sync.Mutex +} + +// NewStreamBuffer returns a new StreamBuffer +func NewStreamBuffer(stream grpc.ServerStream) *StreamBuffer { + return &StreamBuffer{ + ServerStream: stream, + + // buffer is a channel of 1 because we always only buffer + // the head (first message) of the stream. + buffer: make(chan any, 1), + + // mutex is used because it is possible that Head() and RecvMsg() + // are called concurrently. If that happen, Head() might be in the + // state where the buffer has been emptied and has not been refilled + // back with the message. If RecvMsg() is executed during that time, + // it will see the buffer as empty and will process the next message in the + // stream. That will lead to re-ordering of messages, which we don't want. + mutex: &sync.Mutex{}, + } +} + +// Head returns in `m` the head (first message) in the gRPC stream. +// The head is always the first message not consumed by the intended +// recipient. +// Calling Head() multiple time without RecvMsg being called +// will always return the same result. +// For Head() to return the next message in the stream, RecvMsg +// must be called for the message to be fully consumed by the recipient. +func (s *StreamBuffer) Head(m any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + // If a message is in the buffer, return it and put it back in the + // buffer for the next call to RecvMsg(). We must not lose this message + // since RecvMsg() will always return the message in the buffer if there + // is one before reading from the actual stream. + select { + case msg := <-s.buffer: + reflect.ValueOf(m).Elem().Set(reflect.ValueOf(msg).Elem()) + s.buffer <- msg + return nil + default: + } + + // If the buffer is empty, read the head of the stream + // into `m`. + v := m.(proto.Message) + err := s.ServerStream.RecvMsg(v) + if err != nil { + return err + } + + // Add the message in the buffer + s.buffer <- v + return nil +} + +// RecvMsg is a wrapper around the original stream. It looks into the +// buffer to see if there is a message there. If so, it returns it. +// Else it reads from the stream as normal. +func (s *StreamBuffer) RecvMsg(m any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + select { + case msg := <-s.buffer: + reflect.ValueOf(m).Elem().Set(reflect.ValueOf(msg).Elem()) + return nil + default: + return s.ServerStream.RecvMsg(m) + } +} diff --git a/internal/grpc/proxy/stream_buffer_test.go b/internal/grpc/proxy/stream_buffer_test.go new file mode 100644 index 0000000000..78dd5e07fe --- /dev/null +++ b/internal/grpc/proxy/stream_buffer_test.go @@ -0,0 +1,165 @@ +package proxy_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/proxy" + "gitlab.com/gitlab-org/gitaly/v18/internal/testhelper" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/status" +) + +func TestNewStreamBuffer(t *testing.T) { + tests := []struct { + name string + // This is the payloads of each request sent to the server + sentPayloads []string + + // This function returns orderly collected payload from the StreamBuffer + // from either calls to Head() or RecvMsg() + collectedPayloads func(stream grpc_testing.TestService_FullDuplexCallServer) ([]string, error) + + // This is the expected reads by the StreamBuffer. 
It is the ordered + // reads of either Head() or RecvMsg() + expectedPayloads []string + + // The expected error if any + expectedError error + + // if true, cancel request in flight + cancelRequest bool + }{ + { + name: "should return expected collected payloads", + sentPayloads: []string{ + "1", "2", "3", + }, + collectedPayloads: func(stream grpc_testing.TestService_FullDuplexCallServer) ([]string, error) { + var payloads []string + streamBuffer := proxy.NewStreamBuffer(stream) + req := grpc_testing.StreamingOutputCallRequest{} + + // Fetch head + err := streamBuffer.Head(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch head + err = streamBuffer.Head(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + return payloads, nil + }, + expectedPayloads: []string{ + "1", "1", "2", "3", "3", + }, + expectedError: nil, + }, + { + name: "when context is cancelled, it should return the error", + collectedPayloads: func(stream grpc_testing.TestService_FullDuplexCallServer) ([]string, error) { + var payloads []string + streamBuffer := proxy.NewStreamBuffer(stream) + req := grpc_testing.StreamingOutputCallRequest{} + + // Fetch head + err := streamBuffer.Head(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + return payloads, nil + }, + expectedError: status.Error(codes.Canceled, "context canceled"), + cancelRequest: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := testhelper.Context(t) + ctx, cancel := context.WithCancel(ctx) + clientConn, backend := newBackendPinger(t, ctx) + + var collectedPayloads []string + backend.fullDuplexCall = func(stream grpc_testing.TestService_FullDuplexCallServer) error { + var err error + collectedPayloads, err = tt.collectedPayloads(stream) + if err != nil { + return err + } + return stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("bye"), + }, + }) + } + + clientStream, err := clientConn.NewStream(ctx, &grpc.StreamDesc{ + ServerStreams: true, + ClientStreams: true, + }, "grpc.testing.TestService/FullDuplexCall") + if err != nil { + t.Fatal(err) + } + + defer func() { _ = clientStream.CloseSend() }() + + // Handle the scenarios where the request + // should be canceled + if tt.cancelRequest { + cancel() + } else { + defer cancel() + } + + for _, p := range tt.sentPayloads { + err := clientStream.SendMsg(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte(p), + }, + }) + require.NoError(t, err) + } + + ws := grpc_testing.StreamingOutputCallResponse{} + err = clientStream.RecvMsg(&ws) + require.Equal(t, tt.expectedError, err) + require.Equal(t, tt.expectedPayloads, collectedPayloads) + }) + } +} -- GitLab From 8db2dc2402c4a19988afed7f940e987be3f19167 Mon Sep 17 00:00:00 2001 From: Olivier Campeau Date: Tue, 18 Nov 2025 20:35:46 
-0500 Subject: [PATCH 3/3] raft: Add a gRPC proxy interceptor This commit is still in draft. It introduces a gRPC interceptor to proxy requests from a Raft replica that is a follower to the leader of its Raft group. It is missing some features like authentication but the general flow is there. --- internal/gitaly/storage/raftmgr/replica.go | 14 ++ internal/grpc/middleware/raft/raft_proxy.go | 217 ++++++++++++++++++ .../grpc/middleware/raft/raft_proxy_test.go | 120 ++++++++++ 3 files changed, 351 insertions(+) create mode 100644 internal/grpc/middleware/raft/raft_proxy.go create mode 100644 internal/grpc/middleware/raft/raft_proxy_test.go diff --git a/internal/gitaly/storage/raftmgr/replica.go b/internal/gitaly/storage/raftmgr/replica.go index ae8ef0c812..5251346637 100644 --- a/internal/gitaly/storage/raftmgr/replica.go +++ b/internal/gitaly/storage/raftmgr/replica.go @@ -70,6 +70,12 @@ type RaftReplica interface { // GetCurrentState returns comprehensive current state information including term, index, and Raft state. // This provides an efficient way to get multiple state values with consistent locking. GetCurrentState() *ReplicaState + + // IsLeader returns true if this RaftReplica is the leader of its Raft group + IsLeader() bool + + // LeaderID returns the ID of the leader of this replica's Raft group. + LeaderID() uint64 } // StateString returns the normalized state string (removes "State" prefix and converts to lowercase). @@ -1058,6 +1064,14 @@ func (replica *Replica) AddLearner(ctx context.Context, address, destinationStor }) } +func (replica *Replica) IsLeader() bool { + return replica.leadership.IsLeader() +} + +func (replica *Replica) LeaderID() uint64 { + return replica.leadership.GetLeaderID() +} + func (replica *Replica) proposeMembershipChange( ctx context.Context, changeType, diff --git a/internal/grpc/middleware/raft/raft_proxy.go b/internal/grpc/middleware/raft/raft_proxy.go new file mode 100644 index 0000000000..42e691c60d --- /dev/null +++ b/internal/grpc/middleware/raft/raft_proxy.go @@ -0,0 +1,217 @@ +package raft + +import ( + "context" + "errors" + "fmt" + + "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage" + "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage/raftmgr" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/client" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/protoregistry" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/proxy" + "gitlab.com/gitlab-org/gitaly/v18/internal/log" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +var ErrIgnore = errors.New("ignore") + +type Proxy struct { + connPool *client.Pool + logger log.Logger + node storage.Node +} + +// NewRaftProxyInterceptor creates a new proxy interceptor to proxy Raft request to the +// appropriate replica. This proxy exposes a gRPC interceptor for Unary and Stream. +func NewRaftProxyInterceptor(connPool *client.Pool, node storage.Node, logger log.Logger) (*Proxy, error) { + return &Proxy{ + connPool: connPool, + logger: logger, + node: node, + }, nil +} + +// Unary returns the grpc.UnaryServerInterceptor used for Raft proxying +func (rp *Proxy) Unary() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + mi, err := protoregistry.GitalyProtoPreregistered.LookupMethod(info.FullMethod) + if err != nil { + // Some requests, such as the gRPC HealthCheck, are not registered on + // the Gitaly server. 
Just allow them to go through + return handler(ctx, req) + } + + if mi.Scope == protoregistry.ScopeRepository { + // Parse the raw request received from the client + grpcReq, err := toProtoMessage(mi, req) + if err != nil { + return nil, err + } + + // We only proxy mutator request to the leader replica + // All other requests will be handled by this replica. + // Currently, we do not proxy Accessor request because we have + // no way of knowing if an Accessor request is coming from a + // client, or from the RaftProxy of another node. We want to + // avoid the case where the RaftProxies of all Gitaly servers + // just play ping-pong with a request. + if mi.Operation != protoregistry.OpMutator { + return handler(ctx, req) + } + + // Determine if we should proxy the request + addr, shouldProxy, err := rp.shouldProxy(mi, grpcReq) + if !shouldProxy { + // If we should not proxy the request, look at the error. + // Some errors can be ignored, in that case we still let the + // request go through on this replica. + // Else we return an error. + if err != nil && errors.Is(err, ErrIgnore) { + return handler(ctx, req) + } + return nil, err + } + + // Proxy the request + // First, get a connection from the connection pool + conn, err := rp.connPool.Dial(ctx, addr, "") + if err != nil { + return nil, err + } + + return proxy. + NewGrpcProxier(conn, mi.FullMethodName(), proxy.CommUnary). + ProxyUnary(ctx, grpcReq) + } + return handler(ctx, req) + } +} + +// Stream returns the grpc.StreamServerInterceptor used for Raft proxying +func (rp *Proxy) Stream() grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + mi, err := protoregistry.GitalyProtoPreregistered.LookupMethod(info.FullMethod) + if err != nil { + // Some requests, such as the gRPC HealthCheck, are not registered on + // the Gitaly server. Just allow them to go through + return handler(srv, stream) + } + + if mi.Scope == protoregistry.ScopeRepository { + // We only proxy mutator request to the leader replica + // All other requests will be handled by this replica. + // Currently, we do not proxy Accessor request because we have + // no way of knowing if an Accessor request is coming from a + // client, or from the RaftProxy of another node. We want to + // avoid the case where the RaftProxies of all Gitaly servers + // just play ping-pong with a request. + if mi.Operation != protoregistry.OpMutator { + return handler(srv, stream) + } + + streamBuffer := proxy.NewStreamBuffer(stream) + firstMessage := mi.NewRequest() + if err = streamBuffer.Head(firstMessage); err != nil { + return err + } + + // Determine if we should proxy the request + addr, shouldProxy, err := rp.shouldProxy(mi, firstMessage) + if !shouldProxy { + // If we should not proxy the request, look at the error. + // Some errors can be ignored, in that case we still let the + // request go through on this replica. + // Else we return an error. + if err != nil && errors.Is(err, ErrIgnore) { + return handler(srv, stream) + } + return err + } + + // Proxy the request + // First, get a connection from the connection pool + conn, err := rp.connPool.Dial(stream.Context(), addr, "") + if err != nil { + return err + } + + return proxy. + NewGrpcProxier(conn, mi.FullMethodName(), proxy.CommStream). 
+ ProxyStream(streamBuffer) + } + return handler(srv, stream) + } +} + +func (rp *Proxy) shouldProxy(mi protoregistry.MethodInfo, req proto.Message) (addr string, proxy bool, err error) { + addr = "" + proxy = false + + // Extract the target repository from the request + targetRepo, err := mi.TargetRepo(req) + if err != nil { + return addr, proxy, err + } + + // Get the storage where the repository is stored + repoStorage, err := rp.node.GetStorage(targetRepo.GetStorageName()) + if err != nil { + return addr, proxy, err + } + + mi.NewRequest() + // Get partition ID of the repository + ptnId, err := repoStorage.GetAssignedPartitionID(targetRepo.GetRelativePath()) + if err != nil { + // If the repository does not exist, let this Gitaly server handle the request + if errors.Is(err, storage.ErrPartitionAssignmentNotFound) { + return addr, false, ErrIgnore + } + // Else return an error + return addr, proxy, err + } + + // Get partition key of the repository + ptnKey := raftmgr.NewPartitionKey(targetRepo.StorageName, ptnId) + + // TODO: this is the only way I have found to access the routing table + // but there must be another way, or else it seems we have a design issue. + raftStorage, ok := repoStorage.(*raftmgr.RaftEnabledStorage) + if !ok { + err = fmt.Errorf("cannot cast storage to *raftmgr.RaftEnabledStorage") + return addr, proxy, err + } + + // Fetch the current replica based on the partition key + raftReplica, err := raftStorage.GetReplicaRegistry().GetReplica(ptnKey) + if err != nil { + return + } + + // Do not proxy if current replica is leader + if raftReplica.IsLeader() { + return "", false, nil + } + + // Find the replica ID of the leader for this partition + leaderEntry, err := raftStorage.GetRoutingTable().Translate(ptnKey, raftReplica.LeaderID()) + if err != nil { + return + } + return leaderEntry.Metadata.GetAddress(), true, nil +} + +func toProtoMessage(mi protoregistry.MethodInfo, req interface{}) (proto.Message, error) { + pm, ok := req.(proto.Message) + if !ok { + return nil, fmt.Errorf("cannot cast request into proto.Message{}\n") + } + + payload, err := proto.Marshal(pm) + if err != nil { + return nil, fmt.Errorf("cannot marshal proto.Message{} into byte slice: %v\n", err) + } + return mi.UnmarshalRequestProto(payload) +} diff --git a/internal/grpc/middleware/raft/raft_proxy_test.go b/internal/grpc/middleware/raft/raft_proxy_test.go new file mode 100644 index 0000000000..1e5094f88b --- /dev/null +++ b/internal/grpc/middleware/raft/raft_proxy_test.go @@ -0,0 +1,120 @@ +package raft + +//import ( +// "context" +// "fmt" +// "os" +// "testing" +// +// "github.com/google/uuid" +// "github.com/stretchr/testify/require" +// "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/config" +// "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service" +// hookservice "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service/hook" +// "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service/repository" +// "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service/smarthttp" +// "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage" +// "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/client" +// "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/protoregistry" +// "gitlab.com/gitlab-org/gitaly/v18/internal/log" +// "gitlab.com/gitlab-org/gitaly/v18/internal/testhelper" +// "gitlab.com/gitlab-org/gitaly/v18/internal/testhelper/testcfg" +// "gitlab.com/gitlab-org/gitaly/v18/internal/testhelper/testserver" +// "gitlab.com/gitlab-org/gitaly/v18/proto/go/gitalypb" +// 
"google.golang.org/grpc" +//) +// +//const TestOliStorageName = "oli-one" +//const TestOliRelativePath = "raft-rel-path" +// +//func init() { +// _ = os.Setenv("GITALY_TEST_WAL", "True") +// _ = os.Setenv("GITALY_TEST_RAFT", "True") +//} +//func startRaftGitalyServer(t *testing.T, logger log.Logger, raftServerOpt testserver.GitalyServerOpt) config.Cfg { +// cfgOpts := []testcfg.Option{ +// testcfg.WithStorages(TestOliStorageName), +// testcfg.WithBase(config.Cfg{ +// Raft: config.Raft{ +// Enabled: true, +// ClusterID: uuid.New().String(), +// }, +// Transactions: config.Transactions{ +// Enabled: true, +// MaxInactivePartitions: 10, +// }, +// }), +// } +// cfg := testcfg.Build(t, cfgOpts...) +// +// // Create server +// serverOpts := []testserver.GitalyServerOpt{ +// testserver.WithLogger(logger), +// } +// +// // Add Raft middleware if specified +// if raftServerOpt != nil { +// serverOpts = append(serverOpts, raftServerOpt) +// } +// +// gitalyServer := testserver.StartGitalyServer(t, cfg, func(srv *grpc.Server, deps *service.Dependencies) { +// gitalypb.RegisterRepositoryServiceServer(srv, repository.NewServer(deps)) +// gitalypb.RegisterHookServiceServer(srv, hookservice.NewServer(deps)) +// gitalypb.RegisterSmartHTTPServiceServer(srv, smarthttp.NewServer(deps)) +// }, serverOpts...) +// t.Cleanup(gitalyServer.Shutdown) +// cfg.SocketPath = gitalyServer.Address() +// return cfg +//} +// +//func Test_proxyUnaryRequest(t *testing.T) { +// ctx := testhelper.Context(t) +// logger := testhelper.NewLogger(t) +// +// // Create main Gitaly server +// mainServer := startRaftGitalyServer(t, logger, nil) +// conn, err := client.New(testhelper.Context(t), mainServer.SocketPath) +// require.NoError(t, err) +// +// t.Cleanup(func() { require.NoError(t, conn.Close()) }) +// +// repositoryClient := gitalypb.NewRepositoryServiceClient(conn) +// +// // First create a repository +// _, err = repositoryClient.CreateRepository(ctx, &gitalypb.CreateRepositoryRequest{ +// Repository: &gitalypb.Repository{ +// StorageName: TestOliStorageName, +// RelativePath: TestOliRelativePath, +// }, +// }) +// require.NoError(t, err) +// +// // Create a second Gitaly server with a middleware that will use `raftProxy.proxyUnaryRequest(...)` +// // to proxy the request to the first server above, and then assert that the request was successfully +// // proxied by inspecting the response. 
+// secondServerRaftOpt := testserver.WithRaftProxy(func(node storage.Node) (unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) { +// var raftProxy *RaftProxy +// raftProxy, _ = NewRaftProxyInterceptor(client.NewPool(), node, logger) +// +// unary = func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { +// mi, err := protoregistry.GitalyProtoPreregistered.LookupMethod(info.FullMethod) +// return raftProxy.proxyUnaryRequest(ctx, mi, mainServer.SocketPath, req) +// } +// return unary, nil +// }) +// +// secondServer := startRaftGitalyServer(t, logger, secondServerRaftOpt) +// secondConn, err := client.New(testhelper.Context(t), secondServer.SocketPath) +// require.NoError(t, err) +// t.Cleanup(func() { require.NoError(t, secondConn.Close()) }) +// secondClient := gitalypb.NewRepositoryServiceClient(secondConn) +// +// infoResponse, err := secondClient.RepositoryInfo(ctx, &gitalypb.RepositoryInfoRequest{ +// Repository: &gitalypb.Repository{ +// StorageName: TestOliStorageName, +// RelativePath: TestOliRelativePath, +// }, +// }) +// require.NoError(t, err) +// fmt.Printf("Repository %s info retrieved: %s\n", infoResponse.References.String(), infoResponse.String()) +//} -- GitLab