diff --git a/internal/gitaly/storage/raftmgr/replica.go b/internal/gitaly/storage/raftmgr/replica.go index ae8ef0c8128306bde5672e137ea97748b883d49f..52513466374aff0efdb344bd46f5af6da4e40a3a 100644 --- a/internal/gitaly/storage/raftmgr/replica.go +++ b/internal/gitaly/storage/raftmgr/replica.go @@ -70,6 +70,12 @@ type RaftReplica interface { // GetCurrentState returns comprehensive current state information including term, index, and Raft state. // This provides an efficient way to get multiple state values with consistent locking. GetCurrentState() *ReplicaState + + // IsLeader returns true if this RaftReplica is the leader of its Raft group + IsLeader() bool + + // LeaderID returns the ID of the leader of this replica's Raft group. + LeaderID() uint64 } // StateString returns the normalized state string (removes "State" prefix and converts to lowercase). @@ -1058,6 +1064,14 @@ func (replica *Replica) AddLearner(ctx context.Context, address, destinationStor }) } +func (replica *Replica) IsLeader() bool { + return replica.leadership.IsLeader() +} + +func (replica *Replica) LeaderID() uint64 { + return replica.leadership.GetLeaderID() +} + func (replica *Replica) proposeMembershipChange( ctx context.Context, changeType, diff --git a/internal/grpc/middleware/raft/raft_proxy.go b/internal/grpc/middleware/raft/raft_proxy.go new file mode 100644 index 0000000000000000000000000000000000000000..42e691c60df2d4cdd246f9e824c8596a980995b5 --- /dev/null +++ b/internal/grpc/middleware/raft/raft_proxy.go @@ -0,0 +1,217 @@ +package raft + +import ( + "context" + "errors" + "fmt" + + "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage" + "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage/raftmgr" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/client" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/protoregistry" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/proxy" + "gitlab.com/gitlab-org/gitaly/v18/internal/log" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +var ErrIgnore = errors.New("ignore") + +type Proxy struct { + connPool *client.Pool + logger log.Logger + node storage.Node +} + +// NewRaftProxyInterceptor creates a new proxy interceptor to proxy Raft request to the +// appropriate replica. This proxy exposes a gRPC interceptor for Unary and Stream. +func NewRaftProxyInterceptor(connPool *client.Pool, node storage.Node, logger log.Logger) (*Proxy, error) { + return &Proxy{ + connPool: connPool, + logger: logger, + node: node, + }, nil +} + +// Unary returns the grpc.UnaryServerInterceptor used for Raft proxying +func (rp *Proxy) Unary() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + mi, err := protoregistry.GitalyProtoPreregistered.LookupMethod(info.FullMethod) + if err != nil { + // Some requests, such as the gRPC HealthCheck, are not registered on + // the Gitaly server. Just allow them to go through + return handler(ctx, req) + } + + if mi.Scope == protoregistry.ScopeRepository { + // Parse the raw request received from the client + grpcReq, err := toProtoMessage(mi, req) + if err != nil { + return nil, err + } + + // We only proxy mutator request to the leader replica + // All other requests will be handled by this replica. + // Currently, we do not proxy Accessor request because we have + // no way of knowing if an Accessor request is coming from a + // client, or from the RaftProxy of another node. 
We want to + // avoid the case where the RaftProxies of all Gitaly servers + // just play ping-pong with a request. + if mi.Operation != protoregistry.OpMutator { + return handler(ctx, req) + } + + // Determine if we should proxy the request + addr, shouldProxy, err := rp.shouldProxy(mi, grpcReq) + if !shouldProxy { + // If we should not proxy the request, look at the error. + // Some errors can be ignored, in that case we still let the + // request go through on this replica. + // Else we return an error. + if err != nil && errors.Is(err, ErrIgnore) { + return handler(ctx, req) + } + return nil, err + } + + // Proxy the request + // First, get a connection from the connection pool + conn, err := rp.connPool.Dial(ctx, addr, "") + if err != nil { + return nil, err + } + + return proxy. + NewGrpcProxier(conn, mi.FullMethodName(), proxy.CommUnary). + ProxyUnary(ctx, grpcReq) + } + return handler(ctx, req) + } +} + +// Stream returns the grpc.StreamServerInterceptor used for Raft proxying +func (rp *Proxy) Stream() grpc.StreamServerInterceptor { + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + mi, err := protoregistry.GitalyProtoPreregistered.LookupMethod(info.FullMethod) + if err != nil { + // Some requests, such as the gRPC HealthCheck, are not registered on + // the Gitaly server. Just allow them to go through + return handler(srv, stream) + } + + if mi.Scope == protoregistry.ScopeRepository { + // We only proxy mutator request to the leader replica + // All other requests will be handled by this replica. + // Currently, we do not proxy Accessor request because we have + // no way of knowing if an Accessor request is coming from a + // client, or from the RaftProxy of another node. We want to + // avoid the case where the RaftProxies of all Gitaly servers + // just play ping-pong with a request. + if mi.Operation != protoregistry.OpMutator { + return handler(srv, stream) + } + + streamBuffer := proxy.NewStreamBuffer(stream) + firstMessage := mi.NewRequest() + if err = streamBuffer.Head(firstMessage); err != nil { + return err + } + + // Determine if we should proxy the request + addr, shouldProxy, err := rp.shouldProxy(mi, firstMessage) + if !shouldProxy { + // If we should not proxy the request, look at the error. + // Some errors can be ignored, in that case we still let the + // request go through on this replica. + // Else we return an error. + if err != nil && errors.Is(err, ErrIgnore) { + return handler(srv, stream) + } + return err + } + + // Proxy the request + // First, get a connection from the connection pool + conn, err := rp.connPool.Dial(stream.Context(), addr, "") + if err != nil { + return err + } + + return proxy. + NewGrpcProxier(conn, mi.FullMethodName(), proxy.CommStream). 
+			ProxyStream(streamBuffer)
+		}
+		return handler(srv, stream)
+	}
+}
+
+func (rp *Proxy) shouldProxy(mi protoregistry.MethodInfo, req proto.Message) (addr string, doProxy bool, err error) {
+	// Extract the target repository from the request
+	targetRepo, err := mi.TargetRepo(req)
+	if err != nil {
+		return addr, doProxy, err
+	}
+
+	// Get the storage where the repository is stored
+	repoStorage, err := rp.node.GetStorage(targetRepo.GetStorageName())
+	if err != nil {
+		return addr, doProxy, err
+	}
+
+	// Get the partition ID of the repository
+	ptnID, err := repoStorage.GetAssignedPartitionID(targetRepo.GetRelativePath())
+	if err != nil {
+		// If the repository does not exist, let this Gitaly server handle the request
+		if errors.Is(err, storage.ErrPartitionAssignmentNotFound) {
+			return addr, false, ErrIgnore
+		}
+		// Else return an error
+		return addr, doProxy, err
+	}
+
+	// Get the partition key of the repository
+	ptnKey := raftmgr.NewPartitionKey(targetRepo.GetStorageName(), ptnID)
+
+	// TODO: casting is currently the only known way to access the routing table;
+	// if no alternative exists, this may point to a design issue.
+	raftStorage, ok := repoStorage.(*raftmgr.RaftEnabledStorage)
+	if !ok {
+		err = fmt.Errorf("cannot cast storage to *raftmgr.RaftEnabledStorage")
+		return addr, doProxy, err
+	}
+
+	// Fetch the current replica based on the partition key
+	raftReplica, err := raftStorage.GetReplicaRegistry().GetReplica(ptnKey)
+	if err != nil {
+		return
+	}
+
+	// Do not proxy if the current replica is the leader
+	if raftReplica.IsLeader() {
+		return "", false, nil
+	}
+
+	// Find the replica ID of the leader for this partition
+	leaderEntry, err := raftStorage.GetRoutingTable().Translate(ptnKey, raftReplica.LeaderID())
+	if err != nil {
+		return
+	}
+	return leaderEntry.Metadata.GetAddress(), true, nil
+}
+
+func toProtoMessage(mi protoregistry.MethodInfo, req interface{}) (proto.Message, error) {
+	pm, ok := req.(proto.Message)
+	if !ok {
+		return nil, fmt.Errorf("request does not implement proto.Message")
+	}
+
+	payload, err := proto.Marshal(pm)
+	if err != nil {
+		return nil, fmt.Errorf("cannot marshal request into byte slice: %w", err)
+	}
+	return mi.UnmarshalRequestProto(payload)
+}
diff --git a/internal/grpc/middleware/raft/raft_proxy_test.go b/internal/grpc/middleware/raft/raft_proxy_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..1e5094f88b86b9c7c99d1b74e5281a42b1ee5c61
--- /dev/null
+++ b/internal/grpc/middleware/raft/raft_proxy_test.go
@@ -0,0 +1,120 @@
+package raft
+
+//import (
+//	"context"
+//	"fmt"
+//	"os"
+//	"testing"
+//
+//	"github.com/google/uuid"
+//	"github.com/stretchr/testify/require"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/config"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service"
+//	hookservice "gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service/hook"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service/repository"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/service/smarthttp"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/grpc/client"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/grpc/protoregistry"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/log"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/testhelper"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/testhelper/testcfg"
+//	"gitlab.com/gitlab-org/gitaly/v18/internal/testhelper/testserver"
+//
"gitlab.com/gitlab-org/gitaly/v18/proto/go/gitalypb" +// "google.golang.org/grpc" +//) +// +//const TestOliStorageName = "oli-one" +//const TestOliRelativePath = "raft-rel-path" +// +//func init() { +// _ = os.Setenv("GITALY_TEST_WAL", "True") +// _ = os.Setenv("GITALY_TEST_RAFT", "True") +//} +//func startRaftGitalyServer(t *testing.T, logger log.Logger, raftServerOpt testserver.GitalyServerOpt) config.Cfg { +// cfgOpts := []testcfg.Option{ +// testcfg.WithStorages(TestOliStorageName), +// testcfg.WithBase(config.Cfg{ +// Raft: config.Raft{ +// Enabled: true, +// ClusterID: uuid.New().String(), +// }, +// Transactions: config.Transactions{ +// Enabled: true, +// MaxInactivePartitions: 10, +// }, +// }), +// } +// cfg := testcfg.Build(t, cfgOpts...) +// +// // Create server +// serverOpts := []testserver.GitalyServerOpt{ +// testserver.WithLogger(logger), +// } +// +// // Add Raft middleware if specified +// if raftServerOpt != nil { +// serverOpts = append(serverOpts, raftServerOpt) +// } +// +// gitalyServer := testserver.StartGitalyServer(t, cfg, func(srv *grpc.Server, deps *service.Dependencies) { +// gitalypb.RegisterRepositoryServiceServer(srv, repository.NewServer(deps)) +// gitalypb.RegisterHookServiceServer(srv, hookservice.NewServer(deps)) +// gitalypb.RegisterSmartHTTPServiceServer(srv, smarthttp.NewServer(deps)) +// }, serverOpts...) +// t.Cleanup(gitalyServer.Shutdown) +// cfg.SocketPath = gitalyServer.Address() +// return cfg +//} +// +//func Test_proxyUnaryRequest(t *testing.T) { +// ctx := testhelper.Context(t) +// logger := testhelper.NewLogger(t) +// +// // Create main Gitaly server +// mainServer := startRaftGitalyServer(t, logger, nil) +// conn, err := client.New(testhelper.Context(t), mainServer.SocketPath) +// require.NoError(t, err) +// +// t.Cleanup(func() { require.NoError(t, conn.Close()) }) +// +// repositoryClient := gitalypb.NewRepositoryServiceClient(conn) +// +// // First create a repository +// _, err = repositoryClient.CreateRepository(ctx, &gitalypb.CreateRepositoryRequest{ +// Repository: &gitalypb.Repository{ +// StorageName: TestOliStorageName, +// RelativePath: TestOliRelativePath, +// }, +// }) +// require.NoError(t, err) +// +// // Create a second Gitaly server with a middleware that will use `raftProxy.proxyUnaryRequest(...)` +// // to proxy the request to the first server above, and then assert that the request was successfully +// // proxied by inspecting the response. 
+// secondServerRaftOpt := testserver.WithRaftProxy(func(node storage.Node) (unary grpc.UnaryServerInterceptor, stream grpc.StreamServerInterceptor) { +// var raftProxy *RaftProxy +// raftProxy, _ = NewRaftProxyInterceptor(client.NewPool(), node, logger) +// +// unary = func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { +// mi, err := protoregistry.GitalyProtoPreregistered.LookupMethod(info.FullMethod) +// return raftProxy.proxyUnaryRequest(ctx, mi, mainServer.SocketPath, req) +// } +// return unary, nil +// }) +// +// secondServer := startRaftGitalyServer(t, logger, secondServerRaftOpt) +// secondConn, err := client.New(testhelper.Context(t), secondServer.SocketPath) +// require.NoError(t, err) +// t.Cleanup(func() { require.NoError(t, secondConn.Close()) }) +// secondClient := gitalypb.NewRepositoryServiceClient(secondConn) +// +// infoResponse, err := secondClient.RepositoryInfo(ctx, &gitalypb.RepositoryInfoRequest{ +// Repository: &gitalypb.Repository{ +// StorageName: TestOliStorageName, +// RelativePath: TestOliRelativePath, +// }, +// }) +// require.NoError(t, err) +// fmt.Printf("Repository %s info retrieved: %s\n", infoResponse.References.String(), infoResponse.String()) +//} diff --git a/internal/grpc/proxy/grpc_proxier.go b/internal/grpc/proxy/grpc_proxier.go new file mode 100644 index 0000000000000000000000000000000000000000..9b07e3eb3c1784c9380c521d3de63089b95b57ed --- /dev/null +++ b/internal/grpc/proxy/grpc_proxier.go @@ -0,0 +1,242 @@ +package proxy + +import ( + "context" + "io" + "strings" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" +) + +const ( + CommUnary = "unary" + CommStream = "stream" +) + +// readWriteStream is a helper interface that wraps +// a subset of the grpc.ServerStream interface to +// facilitate some operations below. +type readWriteStream interface { + SendMsg(m any) error + RecvMsg(m any) error +} + +type CommPattern string + +// GrpcProxier is an object that allows proxying gRPC requests. +// It can proxy unary and stream requests. +// This proxier is simple and limited in scope. It can only +// be used for one request at a time. For each instance of +// GrpcProxier, the full method must be specified. +type GrpcProxier struct { + // conn is the gRPC connection to use to connect to the remote server + conn *grpc.ClientConn + + // fullMethod is the full method to call on the remote server + fullMethod string + + // opts defines the type of communication to open with the remote server + // ie: unary, full duplex, etc. + opts *grpc.StreamDesc +} + +// NewGrpcProxier creates a new GrpcProxier. 
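+// The comm argument selects the communication pattern used towards the remote:
+// CommStream configures the outgoing grpc.StreamDesc as a bidirectional stream,
+// while CommUnary keeps it as a single request/response exchange.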
+func NewGrpcProxier(conn *grpc.ClientConn, fullMethod string, comm CommPattern) *GrpcProxier {
+	// The stream name must be the short method name.
+	// Example:
+	// Full method name = `package.ServiceName/MethodName`
+	// Short method name = `MethodName`
+	fullMethodParts := strings.Split(fullMethod, "/")
+	shortMethodName := fullMethodParts[len(fullMethodParts)-1]
+
+	opts := &grpc.StreamDesc{
+		StreamName: shortMethodName,
+	}
+
+	if comm == CommStream {
+		opts.ServerStreams = true
+		opts.ClientStreams = true
+	}
+
+	return &GrpcProxier{
+		conn:       conn,
+		fullMethod: fullMethod,
+		opts:       opts,
+	}
+}
+
+func (proxier *GrpcProxier) ProxyStream(incomingStream grpc.ServerStream) error {
+	// Get metadata from the client
+	md, ok := metadata.FromIncomingContext(incomingStream.Context())
+	if !ok {
+		md = metadata.New(nil)
+	}
+
+	// Propagate headers to the remote server and create the outgoing stream
+	outCtx := metadata.NewOutgoingContext(context.Background(), md.Copy())
+	outgoingStream, err := proxier.conn.NewStream(outCtx, proxier.opts, proxier.fullMethod)
+	if err != nil {
+		return err
+	}
+
+	// Buffer length of 2 because we have 2 goroutines running:
+	// 1. Handle communication client <-> remote
+	// 2. Handle communication remote <-> client
+	// As such, there can, at most, be 2 errors reported.
+	errCh := make(chan error, 2)
+	doneCh := make(chan struct{}, 1)
+
+	wg := sync.WaitGroup{}
+
+	// Handle client -> remote communication
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		defer func() { _ = outgoingStream.CloseSend() }()
+
+		for {
+			select {
+			case <-doneCh:
+				return
+			case <-incomingStream.Context().Done():
+				return
+			default:
+			}
+
+			msg := new(anypb.Any)
+			if recvErr := incomingStream.RecvMsg(msg); recvErr != nil {
+				// If recvErr is io.EOF, the client closed its send stream,
+				// so there is nothing more to forward to the remote; the
+				// deferred CloseSend() then closes the outgoing stream in turn.
+				// EOF is not an error, so we don't send it through the
+				// error channel in that case.
+				if recvErr != io.EOF {
+					errCh <- recvErr
+				}
+				break
+			}
+
+			// Send the message to the remote
+			if sendErr := outgoingStream.SendMsg(msg); sendErr != nil {
+				// Any error when sending a message is reported because
+				// at this point we assume the send stream is still open.
+				errCh <- sendErr
+				break
+			}
+		}
+	}()
+
+	// Handle remote -> client communication
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		defer func() {
+			incomingStream.SetTrailer(outgoingStream.Trailer().Copy())
+		}()
+
+		// firstRecvMsg controls when headers are propagated back to the client.
+		firstRecvMsg := true
+
+		for {
+			select {
+			case <-doneCh:
+				return
+			case <-outgoingStream.Context().Done():
+				return
+			default:
+			}
+
+			msg := new(anypb.Any)
+			if recvErr := outgoingStream.RecvMsg(msg); recvErr != nil {
+				// If the error is io.EOF, the remote server closed the send
+				// side of its stream, so no more messages will arrive and we
+				// stop forwarding messages back to the client.
+				// Otherwise we report the error.
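+				// Note that a gRPC status error returned by the remote (for
+				// example codes.NotFound) also surfaces here as a non-EOF
+				// error; reporting it through errCh lets ProxyStream return
+				// it to the original client unchanged.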
+				if recvErr != io.EOF {
+					errCh <- recvErr
+				}
+				break
+			}
+
+			// Headers must always be set before sending the first message
+			if firstRecvMsg {
+				if hd, hdErr := outgoingStream.Header(); hdErr == nil {
+					_ = incomingStream.SetHeader(hd)
+				}
+				firstRecvMsg = false
+			}
+
+			if sendErr := incomingStream.SendMsg(msg); sendErr != nil {
+				errCh <- sendErr
+				break
+			}
+		}
+	}()
+
+	go func() {
+		defer close(errCh)
+		wg.Wait()
+	}()
+
+	// Return the first error encountered
+	firstErr := <-errCh
+
+	// Once an error has been returned, signal to all goroutines
+	// that they must abort by closing the channel.
+	close(doneCh)
+
+	return firstErr
+}
+
+func (proxier *GrpcProxier) ProxyUnary(incomingCtx context.Context, incomingRequest proto.Message) (any, error) {
+	// Get metadata from the client
+	md, ok := metadata.FromIncomingContext(incomingCtx)
+	if !ok {
+		md = metadata.New(nil)
+	}
+
+	// Propagate headers to the remote server and create the outgoing stream
+	outCtx := metadata.NewOutgoingContext(context.Background(), md.Copy())
+	outgoingStream, err := proxier.conn.NewStream(outCtx, proxier.opts, proxier.fullMethod)
+	if err != nil {
+		return nil, err
+	}
+
+	// Send the request to the remote server
+	if err := outgoingStream.SendMsg(incomingRequest); err != nil && err != io.EOF {
+		return nil, err
+	}
+
+	// Close the stream to signal to the server that the request has been sent
+	if err := outgoingStream.CloseSend(); err != nil {
+		return nil, err
+	}
+
+	// Receive the response into an anypb.Any. This gives us a generic
+	// proto.Message that carries the remote's reply back to the client
+	// without knowing its concrete response type.
+	var response = anypb.Any{}
+	if err := outgoingStream.RecvMsg(&response); err != nil {
+		return nil, err
+	}
+
+	// Send the headers back to the client
+	if headers, err := outgoingStream.Header(); err == nil && headers != nil {
+		if err := grpc.SendHeader(incomingCtx, headers); err != nil {
+			return nil, err
+		}
+	}
+
+	// Send the trailers back to the client
+	if trailers := outgoingStream.Trailer(); trailers != nil {
+		if err := grpc.SetTrailer(incomingCtx, trailers); err != nil {
+			return nil, err
+		}
+	}
+	return &response, nil
+}
diff --git a/internal/grpc/proxy/grpc_proxier_test.go b/internal/grpc/proxy/grpc_proxier_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..ea4afd538be23edc71b31ee55b6cd539bd8f9731
--- /dev/null
+++ b/internal/grpc/proxy/grpc_proxier_test.go
@@ -0,0 +1,459 @@
+package proxy_test
+
+import (
+	"context"
+	"io"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"gitlab.com/gitlab-org/gitaly/v18/internal/grpc/proxy"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/interop/grpc_testing"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
+)
+
+func TestPipeUnary(t *testing.T) {
+	const fullMethod = "grpc.testing.TestService/UnaryCall"
+	const shortMethod = "UnaryCall"
+
+	tests := []struct {
+		name             string
+		handler          func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error)
+		sender           func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error)
+		expectedResponse *grpc_testing.SimpleResponse
+		expectedError    error
+		doCancel         bool
+	}{
+		{
+			name: "a normal request should succeed",
+			handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) {
+				return &grpc_testing.SimpleResponse{
+					Payload: &grpc_testing.Payload{
+						Body: []byte("hello-world"),
+ }, + }, nil + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + return client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }) + }, + expectedResponse: &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello-world"), + }, + }, + }, + { + name: "headers and trailers must be propagated", + handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + err := grpc.SetTrailer(ctx, metadata.Pairs("some", "trailer")) + require.NoError(t, err) + + err = grpc.SetHeader(ctx, metadata.Pairs("some", "header")) + require.NoError(t, err) + + return &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello-world"), + }, + }, nil + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + var headers metadata.MD + var trailers metadata.MD + + response, err := client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }, grpc.Trailer(&trailers), grpc.Header(&headers)) + + require.Equal(t, "trailer", trailers.Get("some")[0]) + require.Equal(t, "header", headers.Get("some")[0]) + + return response, err + }, + expectedResponse: &grpc_testing.SimpleResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello-world"), + }, + }, + }, + { + name: "when an error is returned by the remote, the proxy should return that error back to the client", + handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return nil, status.Error(codes.NotFound, "repo not found") + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + return client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }) + }, + expectedResponse: nil, + expectedError: status.Error(codes.NotFound, "repo not found"), + }, + { + name: "when the context is cancelled, the error should be reported back to the client", + handler: func(ctx context.Context, request *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return nil, status.Error(codes.NotFound, "repo not found") + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (*grpc_testing.SimpleResponse, error) { + return client.UnaryCall(ctx, &grpc_testing.SimpleRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("some-request"), + }, + }) + }, + expectedResponse: nil, + expectedError: status.Error(codes.Canceled, "context canceled"), + doCancel: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + backEndCtx := context.Background() + + // Create remote backend + remoteBackendConn, remoteBackend := newBackendPinger(t, backEndCtx) + remoteBackend.unaryCall = tc.handler + + // Create client backend + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + + unaryProxy := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + return proxy. + NewGrpcProxier(remoteBackendConn, fullMethod, proxy.CommUnary). 
+ ProxyUnary(ctx, req.(*grpc_testing.SimpleRequest)) + } + + clientBackendConn, _ := newBackendPinger(t, clientCtx, grpc.ChainUnaryInterceptor(unaryProxy)) + client := grpc_testing.NewTestServiceClient(clientBackendConn) + + // If the test cases requires we cancel the request, we cancel + // right before making the gRPC call, but after the connection + // has been opened. + if tc.doCancel { + clientCancel() + } + + response, err := tc.sender(clientCtx, client) + require.Equal(t, tc.expectedError, err) + + if tc.expectedResponse == nil { + require.Nil(t, response) + } else { + require.Equal(t, tc.expectedResponse.Payload.Body, response.Payload.Body) + } + }) + } +} + +func TestPipeStream(t *testing.T) { + const fullMethod = "grpc.testing.TestService/FullDuplexCall" + const shortMethod = "FullDuplexCall" + + tests := []struct { + name string + handler func(grpc_testing.TestService_FullDuplexCallServer) error + sender func(ctx context.Context, client grpc_testing.TestServiceClient) ([]string, error) + expectedPayloads []string + expectedError error + doCancel bool + }{ + { + name: "a simple request a and response should succeed", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + msg, err := stream.Recv() + require.NoError(t, err) + + return stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: msg.Payload, + }) + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello"), + }, + }) + require.NoError(t, err) + + response, err := stream.Recv() + require.NoError(t, err) + + return []string{string(response.Payload.Body)}, nil + }, + expectedPayloads: []string{"hello"}, + expectedError: nil, + doCancel: false, + }, + { + name: "when the remote returns an error, it should be propagated to the client", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + msg, err := stream.Recv() + require.NoError(t, err) + + return stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: msg.Payload, + }) + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + if err != nil { + return []string{}, err + } + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello"), + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + return []string{}, err + }, + expectedPayloads: []string{}, + expectedError: status.Error(codes.Canceled, "context canceled"), + doCancel: true, + }, + { + name: "when the remote returns an error, it should be propagated to the client", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + return status.Error(codes.NotFound, "repo not found") + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("hello"), + }, + }) + require.NoError(t, err) + + _, err = stream.Recv() + return []string{}, err + }, + expectedPayloads: []string{}, + expectedError: status.Error(codes.NotFound, "repo not found"), + doCancel: false, + }, + { + name: "concurrent requests 
and response should succeed", + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + payloads := []string{"hello", "hi", "bonjour"} + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + errCh := make(chan error, 2) + + wg := sync.WaitGroup{} + + // Send requests + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + _ = stream.CloseSend() + }() + + for _, p := range payloads { + select { + case <-stream.Context().Done(): + errCh <- stream.Context().Err() + return + default: + } + + err := stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte(p), + }, + }) + if err != nil { + if err != io.EOF { + errCh <- err + } + break + } + } + }() + + // Collect responses + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stream.Context().Done(): + errCh <- stream.Context().Err() + return + default: + } + + res := grpc_testing.StreamingOutputCallResponse{} + err := stream.RecvMsg(&res) + if err != nil { + if err != io.EOF { + errCh <- err + } + break + } + responses = append(responses, string(res.Payload.Body)) + } + }() + + go func() { + defer close(errCh) + wg.Wait() + }() + + err = <-errCh + return responses, err + }, + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + for { + select { + case <-stream.Context().Done(): + return stream.Context().Err() + default: + } + + // First receive a message + req, err := stream.Recv() + if err != nil { + if err == io.EOF { + break + } + return err + } + + // send back the same payload that was received + err = stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: req.Payload, + }) + if err != nil { + if err == io.EOF { + break + } + return err + } + } + return nil + }, + expectedPayloads: []string{"hello", "hi", "bonjour"}, + }, + { + name: "headers and trailers should be propagated", + handler: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + // Assert that the headers set in the `sender` func were propagated + // up to here. 
+ md, ok := metadata.FromIncomingContext(stream.Context()) + require.True(t, ok) + require.Equal(t, "data", md["meta"][0]) + + msg, err := stream.Recv() + require.NoError(t, err) + + // Set headers to be sent back to the client + md.Set("hello", "goodbye") + err = stream.SetHeader(md) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: msg.Payload, + }) + require.NoError(t, err) + + stream.SetTrailer(metadata.New(map[string]string{"last": "trailer"})) + return err + }, + sender: func(ctx context.Context, client grpc_testing.TestServiceClient) (responses []string, err error) { + stream, err := client.FullDuplexCall(ctx) + require.NoError(t, err) + + err = stream.Send(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte("goodbye"), + }, + }) + require.NoError(t, err) + + err = stream.CloseSend() + require.NoError(t, err) + + response, err := stream.Recv() + require.NoError(t, err) + + // A stream must always be read until EOF + _, err = stream.Recv() + require.ErrorIs(t, err, io.EOF) + + // Assert that headers and trailers sent back from the remote + // are propagated back up to the original client + hd, err := stream.Header() + require.NoError(t, err) + require.Equal(t, "goodbye", hd.Get("hello")[0]) + + // Assert that trailers are propagated + trailers := stream.Trailer() + require.Equal(t, "trailer", trailers.Get("last")[0]) + + return []string{string(response.Payload.Body)}, nil + }, + expectedPayloads: []string{"goodbye"}, + expectedError: nil, + doCancel: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Create remote backend + backEndCtx := context.Background() + remoteBackendConn, remoteBackend := newBackendPinger(t, backEndCtx) + remoteBackend.fullDuplexCall = tc.handler + + // Create client backend + clientCtx, clientCancel := context.WithCancel(context.Background()) + defer clientCancel() + + streamProxy := func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + return proxy. + NewGrpcProxier(remoteBackendConn, fullMethod, proxy.CommStream). + ProxyStream(ss) + } + clientBackendConn, _ := newBackendPinger(t, clientCtx, grpc.ChainStreamInterceptor(streamProxy)) + client := grpc_testing.NewTestServiceClient(clientBackendConn) + + // If the test cases requires we cancel the request, we cancel + // right before making the gRPC call, but after the connection + // has been opened. 
+ if tc.doCancel { + clientCancel() + } + + // Set a custom metadata on the outgoing stream so + // we can validate later on that headers are propagated + clientCtx = metadata.NewOutgoingContext(clientCtx, metadata.New(map[string]string{"meta": "data"})) + response, err := tc.sender(clientCtx, client) + require.Equal(t, tc.expectedPayloads, response) + require.Equal(t, tc.expectedError, err) + }) + } +} diff --git a/internal/grpc/proxy/proxy_test_testhelper_test.go b/internal/grpc/proxy/proxy_test_testhelper_test.go index d7caf003adea5065b1eaa1ace46afd2c4678efbc..c958c75d56803a1f54fc2c945e7f0ace1d89b7f8 100644 --- a/internal/grpc/proxy/proxy_test_testhelper_test.go +++ b/internal/grpc/proxy/proxy_test_testhelper_test.go @@ -20,10 +20,10 @@ func newListener(tb testing.TB) net.Listener { return listener } -func newBackendPinger(tb testing.TB, ctx context.Context) (*grpc.ClientConn, *interceptPinger) { +func newBackendPinger(tb testing.TB, ctx context.Context, opts ...grpc.ServerOption) (*grpc.ClientConn, *interceptPinger) { ip := &interceptPinger{} - srvr := grpc.NewServer() + srvr := grpc.NewServer(opts...) listener := newListener(tb) grpc_testing.RegisterTestServiceServer(srvr, ip) diff --git a/internal/grpc/proxy/stream_buffer.go b/internal/grpc/proxy/stream_buffer.go new file mode 100644 index 0000000000000000000000000000000000000000..1b0215fbc046215473b1472d134a755cb46b973f --- /dev/null +++ b/internal/grpc/proxy/stream_buffer.go @@ -0,0 +1,90 @@ +package proxy + +import ( + "reflect" + "sync" + + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" +) + +// StreamBuffer allows for buffering the head (first message) of a gRPC stream +// in order to inspect it before being consumed by the intended recipient. +// It implements the `grpc.ServerStream` interface so it can be used as any +// other gRPC stream. +type StreamBuffer struct { + grpc.ServerStream + buffer chan any + mutex *sync.Mutex +} + +// NewStreamBuffer returns a new StreamBuffer +func NewStreamBuffer(stream grpc.ServerStream) *StreamBuffer { + return &StreamBuffer{ + ServerStream: stream, + + // buffer is a channel of 1 because we always only buffer + // the head (first message) of the stream. + buffer: make(chan any, 1), + + // mutex is used because it is possible that Head() and RecvMsg() + // are called concurrently. If that happen, Head() might be in the + // state where the buffer has been emptied and has not been refilled + // back with the message. If RecvMsg() is executed during that time, + // it will see the buffer as empty and will process the next message in the + // stream. That will lead to re-ordering of messages, which we don't want. + mutex: &sync.Mutex{}, + } +} + +// Head returns in `m` the head (first message) in the gRPC stream. +// The head is always the first message not consumed by the intended +// recipient. +// Calling Head() multiple time without RecvMsg being called +// will always return the same result. +// For Head() to return the next message in the stream, RecvMsg +// must be called for the message to be fully consumed by the recipient. +func (s *StreamBuffer) Head(m any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + // If a message is in the buffer, return it and put it back in the + // buffer for the next call to RecvMsg(). We must not lose this message + // since RecvMsg() will always return the message in the buffer if there + // is one before reading from the actual stream. 
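+	// The reflect-based assignment in the buffered case below copies the
+	// buffered message into the caller-provided value, mirroring what
+	// RecvMsg does with its argument.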
+ select { + case msg := <-s.buffer: + reflect.ValueOf(m).Elem().Set(reflect.ValueOf(msg).Elem()) + s.buffer <- msg + return nil + default: + } + + // If the buffer is empty, read the head of the stream + // into `m`. + v := m.(proto.Message) + err := s.ServerStream.RecvMsg(v) + if err != nil { + return err + } + + // Add the message in the buffer + s.buffer <- v + return nil +} + +// RecvMsg is a wrapper around the original stream. It looks into the +// buffer to see if there is a message there. If so, it returns it. +// Else it reads from the stream as normal. +func (s *StreamBuffer) RecvMsg(m any) error { + s.mutex.Lock() + defer s.mutex.Unlock() + + select { + case msg := <-s.buffer: + reflect.ValueOf(m).Elem().Set(reflect.ValueOf(msg).Elem()) + return nil + default: + return s.ServerStream.RecvMsg(m) + } +} diff --git a/internal/grpc/proxy/stream_buffer_test.go b/internal/grpc/proxy/stream_buffer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..78dd5e07fe2143cadd38c82bf6fc386f8fce8194 --- /dev/null +++ b/internal/grpc/proxy/stream_buffer_test.go @@ -0,0 +1,165 @@ +package proxy_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/proxy" + "gitlab.com/gitlab-org/gitaly/v18/internal/testhelper" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/status" +) + +func TestNewStreamBuffer(t *testing.T) { + tests := []struct { + name string + // This is the payloads of each request sent to the server + sentPayloads []string + + // This function returns orderly collected payload from the StreamBuffer + // from either calls to Head() or RecvMsg() + collectedPayloads func(stream grpc_testing.TestService_FullDuplexCallServer) ([]string, error) + + // This is the expected reads by the StreamBuffer. 
It is the ordered + // reads of either Head() or RecvMsg() + expectedPayloads []string + + // The expected error if any + expectedError error + + // if true, cancel request in flight + cancelRequest bool + }{ + { + name: "should return expected collected payloads", + sentPayloads: []string{ + "1", "2", "3", + }, + collectedPayloads: func(stream grpc_testing.TestService_FullDuplexCallServer) ([]string, error) { + var payloads []string + streamBuffer := proxy.NewStreamBuffer(stream) + req := grpc_testing.StreamingOutputCallRequest{} + + // Fetch head + err := streamBuffer.Head(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch head + err = streamBuffer.Head(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + return payloads, nil + }, + expectedPayloads: []string{ + "1", "1", "2", "3", "3", + }, + expectedError: nil, + }, + { + name: "when context is cancelled, it should return the error", + collectedPayloads: func(stream grpc_testing.TestService_FullDuplexCallServer) ([]string, error) { + var payloads []string + streamBuffer := proxy.NewStreamBuffer(stream) + req := grpc_testing.StreamingOutputCallRequest{} + + // Fetch head + err := streamBuffer.Head(&req) + if err != nil { + return nil, err + } + payloads = append(payloads, string(req.Payload.Body)) + + // Fetch message + err = streamBuffer.RecvMsg(&req) + if err != nil { + return nil, err + } + return payloads, nil + }, + expectedError: status.Error(codes.Canceled, "context canceled"), + cancelRequest: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := testhelper.Context(t) + ctx, cancel := context.WithCancel(ctx) + clientConn, backend := newBackendPinger(t, ctx) + + var collectedPayloads []string + backend.fullDuplexCall = func(stream grpc_testing.TestService_FullDuplexCallServer) error { + var err error + collectedPayloads, err = tt.collectedPayloads(stream) + if err != nil { + return err + } + return stream.Send(&grpc_testing.StreamingOutputCallResponse{ + Payload: &grpc_testing.Payload{ + Body: []byte("bye"), + }, + }) + } + + clientStream, err := clientConn.NewStream(ctx, &grpc.StreamDesc{ + ServerStreams: true, + ClientStreams: true, + }, "grpc.testing.TestService/FullDuplexCall") + if err != nil { + t.Fatal(err) + } + + defer func() { _ = clientStream.CloseSend() }() + + // Handle the scenarios where the request + // should be canceled + if tt.cancelRequest { + cancel() + } else { + defer cancel() + } + + for _, p := range tt.sentPayloads { + err := clientStream.SendMsg(&grpc_testing.StreamingOutputCallRequest{ + Payload: &grpc_testing.Payload{ + Body: []byte(p), + }, + }) + require.NoError(t, err) + } + + ws := grpc_testing.StreamingOutputCallResponse{} + err = clientStream.RecvMsg(&ws) + require.Equal(t, tt.expectedError, err) + require.Equal(t, tt.expectedPayloads, collectedPayloads) + }) + } +}
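
A minimal sketch of how the new interceptors could be wired into a gRPC server. This wiring is illustrative only: newProxiedServer and the idea of registering both interceptors on the same server are assumptions, not part of this change, and the connPool, node and logger dependencies are assumed to be constructed elsewhere.

package raftproxyexample // illustrative only

import (
	"google.golang.org/grpc"

	"gitlab.com/gitlab-org/gitaly/v18/internal/gitaly/storage"
	"gitlab.com/gitlab-org/gitaly/v18/internal/grpc/client"
	raftmw "gitlab.com/gitlab-org/gitaly/v18/internal/grpc/middleware/raft"
	"gitlab.com/gitlab-org/gitaly/v18/internal/log"
)

// newProxiedServer (hypothetical) builds a gRPC server whose repository-scoped
// mutator calls are redirected to the partition leader when this node is a follower.
func newProxiedServer(connPool *client.Pool, node storage.Node, logger log.Logger) (*grpc.Server, error) {
	raftProxy, err := raftmw.NewRaftProxyInterceptor(connPool, node, logger)
	if err != nil {
		return nil, err
	}

	// Accessor calls and unregistered methods (such as health checks) pass
	// straight through to the local handlers; only mutators are proxied.
	return grpc.NewServer(
		grpc.ChainUnaryInterceptor(raftProxy.Unary()),
		grpc.ChainStreamInterceptor(raftProxy.Stream()),
	), nil
}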