From 113981c2a27d8faeb3193f1a174a29d99c5427ba Mon Sep 17 00:00:00 2001 From: Eric Ju Date: Wed, 22 Nov 2023 12:32:54 -0400 Subject: [PATCH 1/4] log: Implement gitaly log option and reporter interface In this commit we implement our own go-grpc-middleware v2 reporter interface and logger option in order to customize the logger with fields producers. The basic logic in our reporter interface follows grpc-middleware v2's interceptors/logging/interceptors.go. In PostCall function, we capture the error and feed it into field producer. --- .golangci.yml | 3 + go.mod | 2 +- go.sum | 4 +- internal/log/options.go | 81 +++ internal/log/options_test.go | 41 ++ internal/log/reporter.go | 210 ++++++++ internal/log/reporter_test.go | 490 ++++++++++++++++++ .../testprotomessage/mock_proto_message.pb.go | 158 ++++++ 8 files changed, 986 insertions(+), 3 deletions(-) create mode 100644 internal/log/options.go create mode 100644 internal/log/options_test.go create mode 100644 internal/log/reporter.go create mode 100644 internal/log/reporter_test.go create mode 100644 internal/testhelper/testprotomessage/mock_proto_message.pb.go diff --git a/.golangci.yml b/.golangci.yml index 825be0e444..052d016a98 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -3,6 +3,9 @@ run: # timeout for analysis, e.g. 
30s, 5m, default is 1m timeout: 10m modules-download-mode: readonly + skip-files: + # ignore mock_proto_message.pb.go, since it is a generated file for testing + - internal/testhelper/testprotomessage/mock_proto_message.pb.go # list of useful linters could be found at https://github.com/golangci/awesome-go-linters linters: diff --git a/go.mod b/go.mod index 8e8c8c72a3..0959af34b7 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.1 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 - github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0 + github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/yamux v0.1.2-0.20220728231024-8f49b6f63f18 diff --git a/go.sum b/go.sum index 3b2a1f460b..4491019118 100644 --- a/go.sum +++ b/go.sum @@ -383,8 +383,8 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= -github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0 h1:2cz5kSrxzMYHiWOBbKj8itQm+nRykkB8aMv4ThcHYHA= -github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.0/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.1 h1:HcUWd006luQPljE73d5sk+/VgYPGUReEVz2y1/qylwY= +github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.1/go.mod h1:w9Y7gY31krpLmrVU5ZPG9H7l9fZuRu5/3R3S3FMtVQ4= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92BcuyuQ/YW4NSIpoGtfXNho= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway 
v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= diff --git a/internal/log/options.go b/internal/log/options.go new file mode 100644 index 0000000000..7ebf281dc0 --- /dev/null +++ b/internal/log/options.go @@ -0,0 +1,81 @@ +package log + +import ( + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" +) + +var defaultOptions = &options{ + loggableEvents: []logging.LoggableEvent{logging.StartCall, logging.FinishCall}, + codeFunc: logging.DefaultErrorToCode, + durationFieldFunc: logging.DefaultDurationToFields, + matcher: nil, + // levelFunc depends if it's client or server. + levelFunc: nil, + filedProducers: make([]FieldsProducer, 0), + timestampFormat: time.RFC3339, +} + +type options struct { + levelFunc logging.CodeToLevel + loggableEvents []logging.LoggableEvent + codeFunc logging.ErrorToCode + durationFieldFunc logging.DurationToFields + matcher *selector.Matcher + timestampFormat string + filedProducers []FieldsProducer +} + +// Option is used to customize the interceptor behavior. +// Use the With* functions (e.g. WithTimestampFormat) to create an Option. +type Option func(*options) + +func evaluateServerOpt(opts []Option) *options { + optCopy := &options{} + *optCopy = *defaultOptions + optCopy.levelFunc = logging.DefaultServerCodeToLevel + for _, o := range opts { + o(optCopy) + } + return optCopy +} + +func hasEvent(events []logging.LoggableEvent, event logging.LoggableEvent) bool { + for _, e := range events { + if e == event { + return true + } + } + return false +} + +// WithFiledProducers customizes the log fields with FieldsProducer. +// The fields produced by the producers will be appended to the log fields +func WithFiledProducers(producers ...FieldsProducer) Option { + return func(o *options) { + o.filedProducers = producers + } +} + +// WithTimestampFormat customizes the timestamps emitted in the log fields. 
+func WithTimestampFormat(format string) Option { + return func(o *options) { + o.timestampFormat = format + } +} + +// WithLogOnEvents customizes on what events the gRPC interceptor should log on. +func WithLogOnEvents(events ...logging.LoggableEvent) Option { + return func(o *options) { + o.loggableEvents = events + } +} + +// WithMatcher customizes the matcher used to select the gRPC method calls to log. +func WithMatcher(matcher *selector.Matcher) Option { + return func(o *options) { + o.matcher = matcher + } +} diff --git a/internal/log/options_test.go b/internal/log/options_test.go new file mode 100644 index 0000000000..a9fafbacc7 --- /dev/null +++ b/internal/log/options_test.go @@ -0,0 +1,41 @@ +package log + +import ( + "context" + "testing" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/stretchr/testify/require" +) + +func TestLoggerOptions(t *testing.T) { + t.Run("WithTimestampFormat", func(t *testing.T) { + format := "2006-01-02-gitaly-format-test" + opt := evaluateServerOpt([]Option{WithTimestampFormat(format)}) + require.Equal(t, opt.timestampFormat, format) + }) + + t.Run("WithFiledProducers", func(t *testing.T) { + fieldsProducers := []FieldsProducer{ + func(context.Context, error) Fields { + return Fields{"a": 1} + }, + func(context.Context, error) Fields { + return Fields{"b": "test"} + }, + func(ctx context.Context, err error) Fields { + return Fields{"c": err.Error()} + }, + } + opt := evaluateServerOpt([]Option{WithFiledProducers(fieldsProducers...)}) + require.Equal(t, opt.filedProducers, fieldsProducers) + }) + + t.Run("WithLogEvents", func(t *testing.T) { + events := []logging.LoggableEvent{ + logging.StartCall, logging.PayloadReceived, logging.PayloadSent, logging.FinishCall, + } + opt := evaluateServerOpt([]Option{WithLogOnEvents(events...)}) + require.Equal(t, opt.loggableEvents, events) + }) +} diff --git a/internal/log/reporter.go b/internal/log/reporter.go new file mode 100644 index 
0000000000..4c20d1ec55 --- /dev/null +++ b/internal/log/reporter.go @@ -0,0 +1,210 @@ +package log + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "google.golang.org/grpc/peer" + "google.golang.org/protobuf/proto" +) + +// reporter implements v2 interceptors.Reporter interface see +// https://github.com/grpc-ecosystem/go-grpc-middleware/blob/main/interceptors/reporter.go. +// +// It is used in interceptor servers to add/extract/modify information from/into the grpc call. +// Refer to the interface implementation for more details. +type reporter struct { + interceptors.CallMeta + + ctx context.Context + kind string + startCallLogged bool + + opts *options + fields logging.Fields + logger logging.Logger +} + +// PostCall is called by logging interceptors after a request finishes (Unary) or when a stream handler exits. +// +// Internally, PostCall is called during the inPayload RPC stats in stats.Handler interface's HandleRPC method. +// More details can be found here +// at https://github.com/grpc-ecosystem/go-grpc-middleware/blob/main/interceptors/server.go and +// https://github.com/grpc-ecosystem/go-grpc-middleware/blob/main/interceptors/client.go +func (c *reporter) PostCall(err error, duration time.Duration) { + if !hasEvent(c.opts.loggableEvents, logging.FinishCall) { + return + } + if err == io.EOF { + err = nil + } + + code := c.opts.codeFunc(err) + fields := c.fields.WithUnique(logging.ExtractFields(c.ctx)) + fields = fields.AppendUnique(logging.Fields{"grpc.code", code.String()}) + if err != nil { + fields = fields.AppendUnique(logging.Fields{"grpc.error", fmt.Sprintf("%v", err)}) + } + + // Appending fields from fields producers, this is our custom logic versus + // what is defined originally in the go-grpc-middleware v2 reporter struct in logging package. 
+ for _, fieldsProducer := range c.opts.filedProducers { + for key, val := range fieldsProducer(c.ctx, err) { + fields.Delete(key) + fields = fields.AppendUnique(logging.Fields{key, val}) + } + } + + msg := fmt.Sprintf("finished %s call with code %s", c.CallMeta.Typ, code.String()) + + c.logger.Log(c.ctx, c.opts.levelFunc(code), msg, fields.AppendUnique(c.opts.durationFieldFunc(duration))...) +} + +// PostMsgSend is called during the inPayload RPC stats in stats.Handler interface's HandleRPC method. +// It is the method called after a response is sent in a server interceptor or +// a request is sent in a client interceptor. +// +// This implementation is based on the go-grpc-middleware v2 reporter struct in logging package. +// Because logging's reporter is not exported, we have to copy the implementation here. +// More details can be found here at +// https://github.com/grpc-ecosystem/go-grpc-middleware/blob/47ca7d64b840248d6d2ea5a24af6496712396438/interceptors/logging/interceptors.go#L47 +func (c *reporter) PostMsgSend(payload any, err error, duration time.Duration) { + logLvl := c.opts.levelFunc(c.opts.codeFunc(err)) + fields := c.fields.WithUnique(logging.ExtractFields(c.ctx)) + if err != nil { + fields = fields.AppendUnique(logging.Fields{"grpc.error", fmt.Sprintf("%v", err)}) + } + if !c.startCallLogged && hasEvent(c.opts.loggableEvents, logging.StartCall) { + c.startCallLogged = true + c.logger.Log(c.ctx, logLvl, "started call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...) 
+ } + + if err != nil || !hasEvent(c.opts.loggableEvents, logging.PayloadSent) { + return + } + if c.CallMeta.IsClient { + p, ok := payload.(proto.Message) + if !ok { + c.logger.Log( + c.ctx, + logging.LevelError, + "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + fields.AppendUnique(logging.Fields{"grpc.request.type", fmt.Sprintf("%T", payload)})..., + ) + return + } + + fields = fields.AppendUnique(logging.Fields{"grpc.send.duration", duration.String(), "grpc.request.content", p}) + c.logger.Log(c.ctx, logLvl, "request sent", fields...) + } else { + p, ok := payload.(proto.Message) + if !ok { + c.logger.Log( + c.ctx, + logging.LevelError, + "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + fields.AppendUnique(logging.Fields{"grpc.response.type", fmt.Sprintf("%T", payload)})..., + ) + return + } + + fields = fields.AppendUnique(logging.Fields{"grpc.send.duration", duration.String(), "grpc.response.content", p}) + c.logger.Log(c.ctx, logLvl, "response sent", fields...) + } +} + +// PostMsgReceive is called during the inPayload RPC stats in stats.Handler interface's HandleRPC method. +// It is the method called after a request is received in a server interceptor or +// a response is received in a client interceptor. +// +// This implementation is based on the go-grpc-middleware v2 reporter struct in logging package. +// Because logging's reporter is not exported, we have to copy the implementation here. 
+// More details can be found here at +// https://github.com/grpc-ecosystem/go-grpc-middleware/blob/47ca7d64b840248d6d2ea5a24af6496712396438/interceptors/logging/interceptors.go#L92 +func (c *reporter) PostMsgReceive(payload any, err error, duration time.Duration) { + logLvl := c.opts.levelFunc(c.opts.codeFunc(err)) + fields := c.fields.WithUnique(logging.ExtractFields(c.ctx)) + if err != nil { + fields = fields.AppendUnique(logging.Fields{"grpc.error", fmt.Sprintf("%v", err)}) + } + if !c.startCallLogged && hasEvent(c.opts.loggableEvents, logging.StartCall) { + c.startCallLogged = true + c.logger.Log(c.ctx, logLvl, "started call", fields.AppendUnique(c.opts.durationFieldFunc(duration))...) + } + + if err != nil || !hasEvent(c.opts.loggableEvents, logging.PayloadReceived) { + return + } + if !c.CallMeta.IsClient { + p, ok := payload.(proto.Message) + if !ok { + c.logger.Log( + c.ctx, + logging.LevelError, + "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + fields.AppendUnique(logging.Fields{"grpc.request.type", fmt.Sprintf("%T", payload)})..., + ) + return + } + + fields = fields.AppendUnique(logging.Fields{"grpc.recv.duration", duration.String(), "grpc.request.content", p}) + c.logger.Log(c.ctx, logLvl, "request received", fields...) + } else { + p, ok := payload.(proto.Message) + if !ok { + c.logger.Log( + c.ctx, + logging.LevelError, + "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + fields.AppendUnique(logging.Fields{"grpc.response.type", fmt.Sprintf("%T", payload)})..., + ) + return + } + + fields = fields.AppendUnique(logging.Fields{"grpc.recv.duration", duration.String(), "grpc.response.content", p}) + c.logger.Log(c.ctx, logLvl, "response received", fields...) 
+ } +} + +func reportable(logger logging.Logger, opts *options) interceptors.CommonReportableFunc { + return func(ctx context.Context, c interceptors.CallMeta) (interceptors.Reporter, context.Context) { + fields := logging.Fields{} + kind := logging.KindServerFieldValue + if c.IsClient { + kind = logging.KindClientFieldValue + } + + fields = fields.WithUnique(logging.ExtractFields(ctx)) + + // Appending fields from fields producers + for _, fieldsProducer := range opts.filedProducers { + for key, val := range fieldsProducer(ctx, nil) { + fields.AppendUnique(logging.Fields{key, val}) + } + } + + if !c.IsClient { + if peer, ok := peer.FromContext(ctx); ok { + fields = append(fields, "peer.address", peer.Addr.String()) + } + } + + singleUseFields := logging.Fields{"grpc.start_time", time.Now().Format(opts.timestampFormat)} + if d, ok := ctx.Deadline(); ok { + singleUseFields = singleUseFields.AppendUnique(logging.Fields{"grpc.request.deadline", d.Format(opts.timestampFormat)}) + } + return &reporter{ + CallMeta: c, + ctx: ctx, + startCallLogged: false, + opts: opts, + fields: fields.WithUnique(singleUseFields), + logger: logger, + kind: kind, + }, logging.InjectFields(ctx, fields) + } +} diff --git a/internal/log/reporter_test.go b/internal/log/reporter_test.go new file mode 100644 index 0000000000..7d5ecd67fd --- /dev/null +++ b/internal/log/reporter_test.go @@ -0,0 +1,490 @@ +package log + +import ( + "context" + "fmt" + "io" + "testing" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + grpcmwloggingv2 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitaly/v16/internal/testhelper/testprotomessage" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +func TestLoggerFunc(t *testing.T) { + triggered := false + + attachedFields := grpcmwloggingv2.Fields{"field_key1", "v1", "field_key2", "v2"} + 
expectedFieldMap := ConvertLoggingFields(attachedFields) + + loggerFunc := grpcmwloggingv2.LoggerFunc( + func(c context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { + actual := ConvertLoggingFields(fields) + + require.Equal(t, createContext(), c) + require.Equal(t, "msg-stub", msg) + require.Equal(t, grpcmwloggingv2.LevelDebug, level) + + require.Equal(t, expectedFieldMap, actual) + triggered = true + }) + loggerFunc(createContext(), grpcmwloggingv2.LevelDebug, "msg-stub", attachedFields...) + + require.True(t, triggered) +} + +func TestReporter_PostCall(t *testing.T) { + abortedError := status.Error(codes.Aborted, "testing call aborted") + fieldProducers := []FieldsProducer{ + func(ctx context.Context, err error) Fields { return Fields{"a": 1} }, + func(ctx context.Context, err error) Fields { return Fields{"b": "2"} }, + func(ctx context.Context, err error) Fields { + if err == nil { + return Fields{"c": nil} + } + return Fields{"c": err.Error()} + }, + } + + for _, tc := range []struct { + desc string + err error + loggableEvents []grpcmwloggingv2.LoggableEvent + loggerFuncCalled bool // true if the logger function should be called + fieldsInCtx grpcmwloggingv2.Fields + expectedLevel grpcmwloggingv2.Level + expectedStatusCode codes.Code + expectedFields map[string]any + }{ + { + desc: "no error", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.FinishCall}, + loggerFuncCalled: true, + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedStatusCode: codes.OK, + expectedFields: map[string]any{ + "a": 1, // from the field producers + "b": "2", // from the field producers + "c": nil, // from the field producers + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.code": codes.OK.String(), // added by PostCall + "grpc.time_ms": "31.425", // added by PostCall + "ctx.key1": "v1", // from the context + 
"ctx.key2": "v2", // from the context + }, + }, + { + desc: "EOF error", + err: io.EOF, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.FinishCall}, + loggerFuncCalled: true, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedStatusCode: codes.OK, + expectedFields: map[string]any{ + "a": 1, // from the field producers + "b": "2", // from the field producers + "c": nil, // from the field producers + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.code": codes.OK.String(), // added by PostCall + "grpc.time_ms": "31.425", // added by PostCall + }, + }, + { + desc: "aborted error", + err: abortedError, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.FinishCall}, + loggerFuncCalled: true, + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelWarn, + expectedStatusCode: codes.Aborted, + expectedFields: map[string]any{ + "a": 1, // from the field producers + "b": "2", // from the field producers + "c": abortedError.Error(), // from the field producers + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.code": codes.Aborted.String(), // added by PostCall + "grpc.error": abortedError.Error(), // added by PostCall + "grpc.time_ms": "31.425", // added by PostCall + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + }, + { + desc: "empty loggable events", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{}, + loggerFuncCalled: false, // loggable events are empty, so the logger function should not be called + // the rest of the fields are not important here, because the logger function is not called + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + ctx := createContext() + if len(tc.fieldsInCtx) != 0 { + ctx = grpcmwloggingv2.InjectFields(ctx, tc.fieldsInCtx) + } + opts := evaluateServerOpt([]Option{WithFiledProducers(fieldProducers...), 
WithLogOnEvents(tc.loggableEvents...)}) + // opts.rpcType = rpcTypeUnary + var actualLoggerFuncCalled bool + mockReporter := &reporter{ + CallMeta: interceptors.CallMeta{ + Typ: interceptors.Unary, + }, + ctx: ctx, + kind: "kind-stub", + opts: opts, + fields: grpcmwloggingv2.Fields{"d", 3, "e", "4"}, + logger: grpcmwloggingv2.LoggerFunc( + // Customized logger function to verify the expected log message and fields + func(c context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { + actualLoggerFuncCalled = true + require.Equal(t, + fmt.Sprintf("finished unary call with code %s", tc.expectedStatusCode.String()), + msg) + require.Equal(t, tc.expectedLevel, level) + + actualFields := ConvertLoggingFields(fields) + require.Equal(t, tc.expectedFields, actualFields) + }), + } + mockReporter.PostCall(tc.err, time.Duration(31425926)) + require.Equal(t, tc.loggerFuncCalled, actualLoggerFuncCalled) + }) + } +} + +func TestReporter_PostMsgSend(t *testing.T) { + abortedError := status.Error(codes.Aborted, "testing call aborted") + testProtoMsg := &testprotomessage.MockProtoMessage{Key: "test-message-key", Value: "test-message-value"} + + for _, tc := range []struct { + desc string + err error + loggableEvents []grpcmwloggingv2.LoggableEvent + loggerFuncCalled bool // true if the logger function should be called + payload any + callMeta interceptors.CallMeta + fieldsInCtx grpcmwloggingv2.Fields + expectedLevel grpcmwloggingv2.Level + expectedFields map[string]any + expectedMsg string + }{ + { + desc: "no error with only StartCall event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.StartCall}, + loggerFuncCalled: true, + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.time_ms": "31.425", // added by PostMsgSend + "ctx.key1": "v1", // 
from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "started call", + }, + { + desc: "client side, with only PayloadSent event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadSent}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: true, + }, + payload: proto.Message(testProtoMsg), + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.request.content": testProtoMsg, // added by PostMsgSend + "grpc.send.duration": "31.425926ms", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "request sent", + }, + { + desc: "client side, invalid payload, with only PayloadSent event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadSent}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: true, + }, + payload: "invalid payload", + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelError, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.request.type": "string", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + }, + { + desc: "server side, with only PayloadSent event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadSent}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: false, + }, + payload: proto.Message(testProtoMsg), + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: 
grpcmwloggingv2.LevelInfo, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.response.content": testProtoMsg, // added by PostMsgSend + "grpc.send.duration": "31.425926ms", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "response sent", + }, + { + desc: "server side, invalid payload, with only PayloadSent event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadSent}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: false, + }, + payload: "invalid payload", + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelError, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.response.type": "string", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + }, + { + desc: "aborted error with only StartCall event", + err: abortedError, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.StartCall}, + loggerFuncCalled: true, + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelWarn, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.time_ms": "31.425", // added by PostMsgSend + "grpc.error": abortedError.Error(), // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "started call", + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + ctx := createContext() + if len(tc.fieldsInCtx) != 0 { + ctx = grpcmwloggingv2.InjectFields(ctx, tc.fieldsInCtx) + } + opts := 
evaluateServerOpt([]Option{WithLogOnEvents(tc.loggableEvents...)}) + // opts.rpcType = rpcTypeUnary + var actualLoggerFuncCalled bool + mockReporter := &reporter{ + CallMeta: tc.callMeta, + ctx: ctx, + kind: "kind-stub", + opts: opts, + fields: grpcmwloggingv2.Fields{"d", 3, "e", "4"}, + logger: grpcmwloggingv2.LoggerFunc( + // Customized logger function to verify the expected log message and fields + func(c context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { + actualLoggerFuncCalled = true + require.Equal(t, tc.expectedMsg, msg) + require.Equal(t, tc.expectedLevel, level) + actualFields := ConvertLoggingFields(fields) + require.Equal(t, tc.expectedFields, actualFields) + }), + } + mockReporter.PostMsgSend(tc.payload, tc.err, time.Duration(31425926)) + require.Equal(t, tc.loggerFuncCalled, actualLoggerFuncCalled) + }) + } +} + +func TestReporter_PostMsgReceive(t *testing.T) { + abortedError := status.Error(codes.Aborted, "testing call aborted") + testProtoMsg := &testprotomessage.MockProtoMessage{Key: "test-message-key", Value: "test-message-value"} + + for _, tc := range []struct { + desc string + err error + loggableEvents []grpcmwloggingv2.LoggableEvent + loggerFuncCalled bool // true if the logger function should be called + payload any + callMeta interceptors.CallMeta + fieldsInCtx grpcmwloggingv2.Fields + expectedLevel grpcmwloggingv2.Level + expectedFields map[string]any + expectedMsg string + }{ + { + desc: "no error with only StartCall event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.StartCall}, + loggerFuncCalled: true, + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.time_ms": "31.425", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + 
expectedMsg: "started call", + }, + { + desc: "client side, with only PayloadReceived event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadReceived}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: true, + }, + payload: proto.Message(testProtoMsg), + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.response.content": testProtoMsg, // added by PostMsgSend + "grpc.recv.duration": "31.425926ms", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "response received", + }, + { + desc: "client side, invalid payload, with only PayloadReceived event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadReceived}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: true, + }, + payload: "invalid payload", + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelError, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.response.type": "string", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + }, + { + desc: "server side, with only PayloadReceived event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadReceived}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: false, + }, + payload: proto.Message(testProtoMsg), + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelInfo, + expectedFields: 
map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.request.content": testProtoMsg, // added by PostMsgSend + "grpc.recv.duration": "31.425926ms", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "request received", + }, + { + desc: "server side, invalid payload, with only PayloadSent event", + err: nil, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.PayloadReceived}, + loggerFuncCalled: true, + callMeta: interceptors.CallMeta{ + IsClient: false, + }, + payload: "invalid payload", + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelError, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.request.type": "string", // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "payload is not a google.golang.org/protobuf/proto.Message; programmatic error?", + }, + { + desc: "aborted error with only StartCall event", + err: abortedError, + loggableEvents: []grpcmwloggingv2.LoggableEvent{grpcmwloggingv2.StartCall}, + loggerFuncCalled: true, + fieldsInCtx: grpcmwloggingv2.Fields{"ctx.key1", "v1", "ctx.key2", "v2"}, + expectedLevel: grpcmwloggingv2.LevelWarn, + expectedFields: map[string]any{ + "d": 3, // reporter pre-existing fields + "e": "4", // reporter pre-existing fields + "grpc.time_ms": "31.425", // added by PostMsgSend + "grpc.error": abortedError.Error(), // added by PostMsgSend + "ctx.key1": "v1", // from the context + "ctx.key2": "v2", // from the context + }, + expectedMsg: "started call", + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + ctx := createContext() + if len(tc.fieldsInCtx) != 0 { + ctx = grpcmwloggingv2.InjectFields(ctx, tc.fieldsInCtx) + } + opts := 
evaluateServerOpt([]Option{WithLogOnEvents(tc.loggableEvents...)}) + // opts.rpcType = rpcTypeUnary + var actualLoggerFuncCalled bool + mockReporter := &reporter{ + CallMeta: tc.callMeta, + ctx: ctx, + kind: "kind-stub", + opts: opts, + fields: grpcmwloggingv2.Fields{"d", 3, "e", "4"}, + logger: grpcmwloggingv2.LoggerFunc( + // Customized logger function to verify the expected log message and fields + func(c context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { + actualLoggerFuncCalled = true + require.Equal(t, tc.expectedMsg, msg) + require.Equal(t, tc.expectedLevel, level) + actualFields := ConvertLoggingFields(fields) + require.Equal(t, tc.expectedFields, actualFields) + }), + } + mockReporter.PostMsgReceive(tc.payload, tc.err, time.Duration(31425926)) + require.Equal(t, tc.loggerFuncCalled, actualLoggerFuncCalled) + }) + } +} diff --git a/internal/testhelper/testprotomessage/mock_proto_message.pb.go b/internal/testhelper/testprotomessage/mock_proto_message.pb.go new file mode 100644 index 0000000000..0332d487cb --- /dev/null +++ b/internal/testhelper/testprotomessage/mock_proto_message.pb.go @@ -0,0 +1,158 @@ +// Package testprotomessage provides a mock proto message for testing purposes +package testprotomessage + +import ( + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoimpl" + "reflect" + "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// MockProtoMessage is a test proto message generated by protoc-gen-go. 
+// It is used as a payload in the unit test +// +// syntax = "proto3"; +// package log; +// message TestRecord { +// string key = 1; +// string value = 2; +// } +type MockProtoMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` +} + +func (x *MockProtoMessage) Reset() { + *x = MockProtoMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_test_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MockProtoMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MockProtoMessage) ProtoMessage() {} + +func (x *MockProtoMessage) ProtoReflect() protoreflect.Message { + mi := &file_test_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MockProtoMessage.ProtoReflect.Descriptor instead. 
+func (*MockProtoMessage) Descriptor() ([]byte, []int) { + return file_test_proto_rawDescGZIP(), []int{0} +} + +func (x *MockProtoMessage) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *MockProtoMessage) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +var File_test_proto protoreflect.FileDescriptor + +var file_test_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x6c, 0x6f, + 0x67, 0x22, 0x3a, 0x0a, 0x10, 0x4d, 0x6f, 0x63, 0x6b, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x42, 0x0e, 0x5a, + 0x0c, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6c, 0x6f, 0x67, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_test_proto_rawDescOnce sync.Once + file_test_proto_rawDescData = file_test_proto_rawDesc +) + +func file_test_proto_rawDescGZIP() []byte { + file_test_proto_rawDescOnce.Do(func() { + file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData) + }) + return file_test_proto_rawDescData +} + +var ( + file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) + file_test_proto_goTypes = []interface{}{ + (*MockProtoMessage)(nil), // 0: log.MockProtoMessage + } +) + +var file_test_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_test_proto_init() } +func file_test_proto_init() { + if File_test_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + 
file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MockProtoMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_test_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_test_proto_goTypes, + DependencyIndexes: file_test_proto_depIdxs, + MessageInfos: file_test_proto_msgTypes, + }.Build() + File_test_proto = out.File + file_test_proto_rawDesc = nil + file_test_proto_goTypes = nil + file_test_proto_depIdxs = nil +} -- GitLab From f3c038680e716ccdc53c834afadc923b9c58d268 Mon Sep 17 00:00:00 2001 From: Eric Ju Date: Wed, 22 Nov 2023 13:59:46 -0400 Subject: [PATCH 2/4] go: Change logger interceptor interface to use grpc/middleware v2 In this commit, we changed the logger interface to use grpc/middleware v2 interceptors. The v1 MessageProducer type is replaced by v2's LoggerFunc. A DefaultInterceptorLogger function is added to replace v1's DefaultMessageProducer. 
--- internal/gitaly/server/server.go | 37 +++--- internal/grpc/grpcstats/stats_test.go | 19 +--- .../customfields_handler_test.go | 19 ++-- .../featureflag/featureflag_handler_test.go | 25 ++-- internal/log/logger.go | 38 +++++-- internal/log/middleware.go | 97 ++++++++-------- internal/log/middleware_test.go | 107 +++++------------- internal/log/options_test.go | 6 + internal/praefect/delete_object_pool_test.go | 8 +- internal/praefect/server.go | 37 +++--- .../testserver/structerr_interceptors_test.go | 10 +- 11 files changed, 184 insertions(+), 219 deletions(-) diff --git a/internal/gitaly/server/server.go b/internal/gitaly/server/server.go index 2123a9cc59..989f3319d7 100644 --- a/internal/gitaly/server/server.go +++ b/internal/gitaly/server/server.go @@ -5,7 +5,6 @@ import ( "fmt" "time" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/backchannel" @@ -89,15 +88,7 @@ func (s *GitalyServerFactory) New(external, secure bool, opts ...Option) (*grpc. []grpc.DialOption{client.UnaryInterceptor()}, )) - logMsgProducer := grpcmwlogrus.WithMessageProducer( - gitalylog.MessageProducer( - gitalylog.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), - customfieldshandler.FieldsProducer, - grpcstats.FieldsProducer, - featureflag.FieldsProducer, - structerr.FieldsProducer, - ), - ) + loggerFunc := gitalylog.PropagationMessageProducer(gitalylog.DefaultInterceptorLogger(s.logger)) streamServerInterceptors := []grpc.StreamServerInterceptor{ grpccorrelation.StreamServerCorrelationInterceptor(), // Must be above the metadata handler @@ -105,9 +96,15 @@ func (s *GitalyServerFactory) New(external, secure bool, opts ...Option) (*grpc. 
grpcprometheus.StreamServerInterceptor, customfieldshandler.StreamInterceptor, s.logger.WithField("component", "gitaly.StreamServerInterceptor").StreamServerInterceptor( - grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), - logMsgProducer, - gitalylog.DeciderOption(), + loggerFunc, + gitalylog.WithMatcher(gitalylog.DeciderMatcher()), + gitalylog.WithTimestampFormat(gitalylog.LogTimestampFormat), + gitalylog.WithFiledProducers( + customfieldshandler.FieldsProducer, + grpcstats.FieldsProducer, + featureflag.FieldsProducer, + structerr.FieldsProducer, + ), ), gitalylog.StreamLogDataCatcherServerInterceptor(), sentryhandler.StreamLogHandler(), @@ -119,10 +116,16 @@ func (s *GitalyServerFactory) New(external, secure bool, opts ...Option) (*grpc. requestinfohandler.UnaryInterceptor, grpcprometheus.UnaryServerInterceptor, customfieldshandler.UnaryInterceptor, - s.logger.WithField("component", "gitaly.UnaryServerInterceptor").UnaryServerInterceptor( - grpcmwlogrus.WithTimestampFormat(gitalylog.LogTimestampFormat), - logMsgProducer, - gitalylog.DeciderOption(), + s.logger.WithField("component", "gitaly.StreamServerInterceptor").UnaryServerInterceptor( + loggerFunc, + gitalylog.WithMatcher(gitalylog.DeciderMatcher()), + gitalylog.WithTimestampFormat(gitalylog.LogTimestampFormat), + gitalylog.WithFiledProducers( + customfieldshandler.FieldsProducer, + grpcstats.FieldsProducer, + featureflag.FieldsProducer, + structerr.FieldsProducer, + ), ), gitalylog.UnaryLogDataCatcherServerInterceptor(), sentryhandler.UnaryLogHandler(), diff --git a/internal/grpc/grpcstats/stats_test.go b/internal/grpc/grpcstats/stats_test.go index b3bfa881cf..9f39520235 100644 --- a/internal/grpc/grpcstats/stats_test.go +++ b/internal/grpc/grpcstats/stats_test.go @@ -8,7 +8,6 @@ import ( "sync" "testing" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" 
"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/client" @@ -63,23 +62,15 @@ func TestPayloadBytes(t *testing.T) { }), grpc.ChainUnaryInterceptor( logger.UnaryServerInterceptor( - grpcmwlogrus.WithMessageProducer( - log.MessageProducer( - log.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), - FieldsProducer, - ), - ), + log.PropagationMessageProducer(log.DefaultInterceptorLogger(logger)), + log.WithFiledProducers(FieldsProducer), ), log.UnaryLogDataCatcherServerInterceptor(), ), grpc.ChainStreamInterceptor( logger.StreamServerInterceptor( - grpcmwlogrus.WithMessageProducer( - log.MessageProducer( - log.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), - FieldsProducer, - ), - ), + log.PropagationMessageProducer(log.DefaultInterceptorLogger(logger)), + log.WithFiledProducers(FieldsProducer), ), log.StreamLogDataCatcherServerInterceptor(), ), @@ -152,7 +143,7 @@ func TestPayloadBytes(t *testing.T) { require.EqualValues(t, 8, e.Data["grpc.request.payload_bytes"]) require.EqualValues(t, 8, e.Data["grpc.response.payload_bytes"]) } - if e.Message == "finished streaming call with code OK" { + if e.Message == "finished bidi_stream call with code OK" { stream++ require.EqualValues(t, 16, e.Data["grpc.request.payload_bytes"]) require.EqualValues(t, 16, e.Data["grpc.response.payload_bytes"]) diff --git a/internal/grpc/middleware/customfieldshandler/customfields_handler_test.go b/internal/grpc/middleware/customfieldshandler/customfields_handler_test.go index e92a4ec3b5..bb17c0b8ab 100644 --- a/internal/grpc/middleware/customfieldshandler/customfields_handler_test.go +++ b/internal/grpc/middleware/customfieldshandler/customfields_handler_test.go @@ -6,7 +6,6 @@ import ( "net" "testing" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v16/internal/git/catfile" "gitlab.com/gitlab-org/gitaly/v16/internal/git/gittest" @@ -31,14 +30,18 @@ func 
createNewServer(t *testing.T, cfg config.Cfg, logger log.Logger) *grpc.Serv grpc.ChainStreamInterceptor( StreamInterceptor, logger.StreamServerInterceptor( - grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(log.MessageProducer(grpcmwlogrus.DefaultMessageProducer, FieldsProducer))), + log.DefaultInterceptorLogger(logger), + log.WithTimestampFormat(log.LogTimestampFormat), + log.WithFiledProducers(FieldsProducer), + ), ), grpc.ChainUnaryInterceptor( UnaryInterceptor, logger.UnaryServerInterceptor( - grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(log.MessageProducer(grpcmwlogrus.DefaultMessageProducer, FieldsProducer))), + log.DefaultInterceptorLogger(logger), + log.WithTimestampFormat(log.LogTimestampFormat), + log.WithFiledProducers(FieldsProducer), + ), ), } @@ -145,10 +148,10 @@ func TestInterceptor(t *testing.T) { tt.performRPC(t, ctx, client) logEntries := hook.AllEntries() - require.Len(t, logEntries, 1) + require.Len(t, logEntries, 2) // 1 for the starting RPC call, 1 for finishing it for expectedLogKey, expectedLogValue := range tt.expectedLogData { - require.Contains(t, logEntries[0].Data, expectedLogKey) - require.Equal(t, logEntries[0].Data[expectedLogKey], expectedLogValue) + require.Contains(t, logEntries[1].Data, expectedLogKey) + require.Equal(t, logEntries[1].Data[expectedLogKey], expectedLogValue) } }) } diff --git a/internal/grpc/middleware/featureflag/featureflag_handler_test.go b/internal/grpc/middleware/featureflag/featureflag_handler_test.go index 45eb31d800..fb8514d561 100644 --- a/internal/grpc/middleware/featureflag/featureflag_handler_test.go +++ b/internal/grpc/middleware/featureflag/featureflag_handler_test.go @@ -5,7 +5,6 @@ import ( "net" "testing" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v16/internal/featureflag" 
"gitlab.com/gitlab-org/gitaly/v16/internal/log" @@ -38,15 +37,9 @@ func TestFeatureFlagLogs(t *testing.T) { service := &mockService{} server := grpc.NewServer( grpc.ChainUnaryInterceptor( - grpcmwlogrus.UnaryServerInterceptor( - logger.LogrusEntry(), //nolint:staticcheck - grpcmwlogrus.WithMessageProducer( - log.MessageProducer( - grpcmwlogrus.DefaultMessageProducer, - FieldsProducer, - ), - ), - ), + logger.UnaryServerInterceptor( + log.DefaultInterceptorLogger(logger), + log.WithFiledProducers(FieldsProducer)), ), ) grpc_testing.RegisterTestServiceServer(server, service) @@ -126,10 +119,14 @@ func TestFeatureFlagLogs(t *testing.T) { testhelper.RequireGrpcError(t, tc.returnedErr, err) for _, logEntry := range loggerHook.AllEntries() { - if tc.expectedFields == "" { - require.NotContains(t, logEntry.Data, "feature_flags") - } else { - require.Equal(t, tc.expectedFields, logEntry.Data["feature_flags"]) + // We will have 2 log entries for each RPC call, one for starting and one for finishing, + // and we only want to check the finishing one. 
+ if logEntry.Message != "started call" { + if tc.expectedFields == "" { + require.NotContains(t, logEntry.Data, "feature_flags") + } else { + require.Equal(t, tc.expectedFields, logEntry.Data["feature_flags"]) + } } } }) diff --git a/internal/log/logger.go b/internal/log/logger.go index 6a38ae6328..ddf4c704de 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -2,10 +2,11 @@ package log import ( "context" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" grpcmwloggingv2 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + grpcmwselector "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" "github.com/sirupsen/logrus" "google.golang.org/grpc" ) @@ -29,8 +30,10 @@ type Logger interface { WarnContext(ctx context.Context, msg string) ErrorContext(ctx context.Context, msg string) - StreamServerInterceptor(...grpcmwlogrus.Option) grpc.StreamServerInterceptor - UnaryServerInterceptor(...grpcmwlogrus.Option) grpc.UnaryServerInterceptor + StreamServerInterceptor(grpcmwloggingv2.Logger, ...Option) grpc.StreamServerInterceptor + UnaryServerInterceptor(grpcmwloggingv2.Logger, ...Option) grpc.UnaryServerInterceptor + + ReplaceFields(fields Fields) Logger } // LogrusLogger is an implementation of the Logger interface that is implemented via a `logrus.FieldLogger`. @@ -62,6 +65,11 @@ func (l LogrusLogger) WithFields(fields Fields) Logger { return LogrusLogger{entry: l.entry.WithFields(fields)} } +func (l LogrusLogger) ReplaceFields(fields Fields) Logger { + l.entry.Data = Fields{} + return LogrusLogger{entry: l.entry.WithFields(fields)} +} + // WithError creates a new logger with an appended error field. 
func (l LogrusLogger) WithError(err error) Logger { return LogrusLogger{entry: l.entry.WithError(err)} @@ -92,14 +100,26 @@ func (l LogrusLogger) toContext(ctx context.Context) context.Context { return ctxlogrus.ToContext(ctx, l.entry) } -// StreamServerInterceptor creates a gRPC interceptor that generates log messages for streaming RPC calls. -func (l LogrusLogger) StreamServerInterceptor(opts ...grpcmwlogrus.Option) grpc.StreamServerInterceptor { - return grpcmwlogrus.StreamServerInterceptor(l.entry, opts...) +// StreamServerInterceptor creates a new stream server interceptor. The loggerFunc is the function that will be called to +// log messages; options are the Option slice that can be used to configure the interceptor. +func (l LogrusLogger) StreamServerInterceptor(loggerFunc grpcmwloggingv2.Logger, options ...Option) grpc.StreamServerInterceptor { + o := evaluateServerOpt(options) + interceptor := interceptors.StreamServerInterceptor(reportable(loggerFunc, o)) + if o.matcher != nil { + interceptor = grpcmwselector.StreamServerInterceptor(interceptor, *o.matcher) + } + return interceptor } -// UnaryServerInterceptor creates a gRPC interceptor that generates log messages for unary RPC calls. -func (l LogrusLogger) UnaryServerInterceptor(opts ...grpcmwlogrus.Option) grpc.UnaryServerInterceptor { - return grpcmwlogrus.UnaryServerInterceptor(l.entry, opts...) +// UnaryServerInterceptor creates a new unary server interceptor. The loggerFunc is the function that will be called to +// log messages; options are the Option slice that can be used to configure the interceptor. 
+func (l LogrusLogger) UnaryServerInterceptor(loggerFunc grpcmwloggingv2.Logger, options ...Option) grpc.UnaryServerInterceptor { + o := evaluateServerOpt(options) + interceptor := interceptors.UnaryServerInterceptor(reportable(loggerFunc, o)) + if o.matcher != nil { + interceptor = grpcmwselector.UnaryServerInterceptor(interceptor, *o.matcher) + } + return interceptor } func (l LogrusLogger) log(ctx context.Context, level logrus.Level, msg string) { diff --git a/internal/log/middleware.go b/internal/log/middleware.go index e29fef98e5..3d47148d14 100644 --- a/internal/log/middleware.go +++ b/internal/log/middleware.go @@ -4,13 +4,11 @@ import ( "context" "regexp" - grpcmwlogging "github.com/grpc-ecosystem/go-grpc-middleware/logging" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" grpcmwloggingv2 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" - "github.com/sirupsen/logrus" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" "gitlab.com/gitlab-org/gitaly/v16/internal/helper/env" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/stats" ) @@ -19,26 +17,21 @@ const ( defaultLogRequestMethodDenyPattern = "^/grpc.health.v1.Health/Check$" ) -// DeciderOption returns a Option to support log filtering. +// DeciderMatcher is used as a selector to support log filtering. // If "GITALY_LOG_REQUEST_METHOD_DENY_PATTERN" ENV variable is set, logger will filter out the log whose "fullMethodName" matches it; // If "GITALY_LOG_REQUEST_METHOD_ALLOW_PATTERN" ENV variable is set, logger will only keep the log whose "fullMethodName" matches it; // Under any conditions, the error log will not be filtered out; // If the ENV variables are not set, there will be no additional effects. 
-func DeciderOption() grpcmwlogrus.Option { +// Replacing old DeciderOption +func DeciderMatcher() *selector.Matcher { matcher := methodNameMatcherFromEnv() - - if matcher == nil { - return grpcmwlogrus.WithDecider(grpcmwlogging.DefaultDeciderMethod) - } - - decider := func(fullMethodName string, err error) bool { - if err != nil { + matcherFunc := selector.MatchFunc(func(_ context.Context, callMeta interceptors.CallMeta) bool { + if matcher == nil { return true } - return matcher(fullMethodName) - } - - return grpcmwlogrus.WithDecider(decider) + return matcher(callMeta.FullMethod()) + }) + return &matcherFunc } func methodNameMatcherFromEnv() func(string) bool { @@ -67,27 +60,12 @@ func methodNameMatcherFromEnv() func(string) bool { // the result of RPC handling. type FieldsProducer func(context.Context, error) Fields -// MessageProducer returns a wrapper that extends passed mp to accept additional fields generated -// by each of the fieldsProducers. -func MessageProducer(mp grpcmwlogrus.MessageProducer, fieldsProducers ...FieldsProducer) grpcmwlogrus.MessageProducer { - return func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields Fields) { - for _, fieldsProducer := range fieldsProducers { - for key, val := range fieldsProducer(ctx, err) { - fields[key] = val - } - } - mp(ctx, format, level, code, err, fields) - } -} - type messageProducerHolder struct { logger LogrusLogger - actual grpcmwlogrus.MessageProducer - format string - level logrus.Level - code codes.Code - err error - fields Fields + actual grpcmwloggingv2.LoggerFunc + msg string + level grpcmwloggingv2.Level + fields grpcmwloggingv2.Fields } type messageProducerHolderKey struct{} @@ -106,8 +84,8 @@ func messageProducerPropagationFrom(ctx context.Context) *messageProducerHolder // PropagationMessageProducer catches logging information from the context and populates it // to the special holder that should be present in the context. 
// Should be used only in combination with PerRPCLogHandler. -func PropagationMessageProducer(actual grpcmwlogrus.MessageProducer) grpcmwlogrus.MessageProducer { - return func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields Fields) { +func PropagationMessageProducer(actual grpcmwloggingv2.LoggerFunc) grpcmwloggingv2.LoggerFunc { + return func(ctx context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { mpp := messageProducerPropagationFrom(ctx) if mpp == nil { return @@ -115,10 +93,8 @@ func PropagationMessageProducer(actual grpcmwlogrus.MessageProducer) grpcmwlogru *mpp = messageProducerHolder{ logger: fromContext(ctx), actual: actual, - format: format, + msg: msg, level: level, - code: code, - err: err, fields: fields, } } @@ -159,11 +135,16 @@ func (lh PerRPCLogHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { } if mpp.fields == nil { - mpp.fields = Fields{} + mpp.fields = grpcmwloggingv2.Fields{} } for _, fp := range lh.FieldProducers { - for k, v := range fp(ctx, mpp.err) { - mpp.fields[k] = v + for k, v := range fp(ctx, nil) { + // The message producers can have fields with updated values, for example + // grpc.response.payload_bytes increased from 0 to 100. In this case we need + // update the value of the field instead of appending it. The grpc middleware v2 logging + // fields don't support update, so we need to delete the field and append it again. + mpp.fields.Delete(k) + mpp.fields = mpp.fields.AppendUnique(grpcmwloggingv2.Fields{k, v}) } } // Once again because all interceptors are finished and context doesn't contain @@ -171,7 +152,8 @@ func (lh PerRPCLogHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { // It's needed because github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus.DefaultMessageProducer // extracts logger from the context and use it to write the logs. 
ctx = mpp.logger.toContext(ctx) - mpp.actual(ctx, mpp.format, mpp.level, mpp.code, mpp.err, mpp.fields) + mpp.actual(ctx, mpp.level, mpp.msg, mpp.fields...) + return } } @@ -191,7 +173,9 @@ func UnaryLogDataCatcherServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { mpp := messageProducerPropagationFrom(ctx) if mpp != nil { - mpp.fields = fromContext(ctx).entry.Data + for k, v := range fromContext(ctx).entry.Data { + mpp.fields = mpp.fields.AppendUnique(grpcmwloggingv2.Fields{k, v}) + } } return handler(ctx, req) } @@ -204,7 +188,9 @@ func StreamLogDataCatcherServerInterceptor() grpc.StreamServerInterceptor { ctx := ss.Context() mpp := messageProducerPropagationFrom(ctx) if mpp != nil { - mpp.fields = fromContext(ctx).entry.Data + for k, v := range fromContext(ctx).entry.Data { + mpp.fields = mpp.fields.AppendUnique(grpcmwloggingv2.Fields{k, v}) + } } return handler(srv, ss) } @@ -223,3 +209,20 @@ func ConvertLoggingFields(fields grpcmwloggingv2.Fields) map[string]any { } return fieldsMap } + +// DefaultInterceptorLogger adapts gitaly's logger interface to grpc middleware logger function. 
+func DefaultInterceptorLogger(l Logger) grpcmwloggingv2.LoggerFunc { + return func(c context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { + f := ConvertLoggingFields(fields) + switch level { + case grpcmwloggingv2.LevelDebug: + l.ReplaceFields(f).Debug(msg) + case grpcmwloggingv2.LevelInfo: + l.ReplaceFields(f).Info(msg) + case grpcmwloggingv2.LevelWarn: + l.ReplaceFields(f).Warn(msg) + case grpcmwloggingv2.LevelError: + l.ReplaceFields(f).Error(msg) + } + } +} diff --git a/internal/log/middleware_test.go b/internal/log/middleware_test.go index 4a714b097f..5bc90afa46 100644 --- a/internal/log/middleware_test.go +++ b/internal/log/middleware_test.go @@ -5,87 +5,36 @@ import ( "testing" grpcmw "github.com/grpc-ecosystem/go-grpc-middleware" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" grpcmwloggingv2 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/stats" ) -func TestMessageProducer(t *testing.T) { - triggered := false - - attachedFields := Fields{"e": "stub"} - msgProducer := MessageProducer(func(c context.Context, format string, level logrus.Level, code codes.Code, err error, fields Fields) { - require.Equal(t, createContext(), c) - require.Equal(t, "format-stub", format) - require.Equal(t, logrus.DebugLevel, level) - require.Equal(t, codes.OutOfRange, code) - require.Equal(t, assert.AnError, err) - require.Equal(t, attachedFields, fields) - triggered = true - }) - msgProducer(createContext(), "format-stub", logrus.DebugLevel, codes.OutOfRange, assert.AnError, attachedFields) - - require.True(t, triggered) -} - -func 
TestMessageProducerWithFieldsProducers(t *testing.T) { - triggered := false - - var infoFromCtx struct{} - ctx := createContext() - ctx = context.WithValue(ctx, infoFromCtx, "world") - - fieldsProducer1 := func(context.Context, error) Fields { - return Fields{"a": 1} - } - fieldsProducer2 := func(context.Context, error) Fields { - return Fields{"b": "test"} - } - fieldsProducer3 := func(ctx context.Context, err error) Fields { - return Fields{"c": err.Error()} - } - fieldsProducer4 := func(ctx context.Context, err error) Fields { - return Fields{"d": ctx.Value(infoFromCtx)} - } - attachedFields := Fields{"e": "stub"} - - msgProducer := MessageProducer(func(c context.Context, format string, level logrus.Level, code codes.Code, err error, fields Fields) { - require.Equal(t, Fields{"a": 1, "b": "test", "c": err.Error(), "d": "world", "e": "stub"}, fields) - triggered = true - }, fieldsProducer1, fieldsProducer2, fieldsProducer3, fieldsProducer4) - msgProducer(ctx, "format-stub", logrus.InfoLevel, codes.OK, assert.AnError, attachedFields) - - require.True(t, triggered) -} - func TestPropagationMessageProducer(t *testing.T) { t.Run("empty context", func(t *testing.T) { ctx := createContext() - mp := PropagationMessageProducer(func(context.Context, string, logrus.Level, codes.Code, error, Fields) {}) - mp(ctx, "", logrus.DebugLevel, codes.OK, nil, nil) + mp := PropagationMessageProducer(func(context.Context, grpcmwloggingv2.Level, string, ...any) {}) + mp(ctx, grpcmwloggingv2.LevelDebug, "", nil, nil) }) t.Run("context with holder", func(t *testing.T) { holder := new(messageProducerHolder) ctx := context.WithValue(createContext(), messageProducerHolderKey{}, holder) triggered := false - mp := PropagationMessageProducer(func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields Fields) { + mp := PropagationMessageProducer(func(context.Context, grpcmwloggingv2.Level, string, ...any) { triggered = true }) - mp(ctx, "format-stub", 
logrus.DebugLevel, codes.OutOfRange, assert.AnError, Fields{"a": 1}) - require.Equal(t, "format-stub", holder.format) - require.Equal(t, logrus.DebugLevel, holder.level) - require.Equal(t, codes.OutOfRange, holder.code) - require.Equal(t, assert.AnError, holder.err) - require.Equal(t, Fields{"a": 1}, holder.fields) - holder.actual(ctx, "", logrus.DebugLevel, codes.OK, nil, nil) + mp(ctx, grpcmwloggingv2.LevelDebug, "format-stub", grpcmwloggingv2.Fields{"a", 1}...) + require.Equal(t, "format-stub", holder.msg) + require.Equal(t, grpcmwloggingv2.LevelDebug, holder.level) + require.Equal(t, grpcmwloggingv2.Fields{"a", 1}, holder.fields) + holder.actual(ctx, grpcmwloggingv2.LevelDebug, "", nil, nil) require.True(t, triggered) }) } @@ -98,7 +47,6 @@ func TestPerRPCLogHandler(t *testing.T) { FieldProducers: []FieldsProducer{ func(ctx context.Context, err error) Fields { return Fields{"a": 1} }, func(ctx context.Context, err error) Fields { return Fields{"b": "2"} }, - func(ctx context.Context, err error) Fields { return Fields{"c": err.Error()} }, }, } @@ -127,16 +75,12 @@ func TestPerRPCLogHandler(t *testing.T) { ctx := ctxlogrus.ToContext(createContext(), logrus.NewEntry(newLogger())) ctx = lh.TagRPC(ctx, &stats.RPCTagInfo{}) mpp := ctx.Value(messageProducerHolderKey{}).(*messageProducerHolder) - mpp.format = "message" - mpp.level = logrus.InfoLevel - mpp.code = codes.InvalidArgument - mpp.err = assert.AnError - mpp.actual = func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields Fields) { - assert.Equal(t, "message", format) - assert.Equal(t, logrus.InfoLevel, level) - assert.Equal(t, codes.InvalidArgument, code) - assert.Equal(t, assert.AnError, err) - assert.Equal(t, Fields{"a": 1, "b": "2", "c": mpp.err.Error()}, mpp.fields) + mpp.msg = "message" + mpp.level = grpcmwloggingv2.LevelInfo + mpp.actual = func(ctx context.Context, level grpcmwloggingv2.Level, msg string, fields ...any) { + assert.Equal(t, "message", msg) + 
assert.Equal(t, grpcmwloggingv2.LevelInfo, level) + assert.Equal(t, grpcmwloggingv2.Fields{"a", 1, "b", "2"}, mpp.fields) } lh.HandleRPC(ctx, &stats.End{}) }) @@ -194,7 +138,7 @@ func TestUnaryLogDataCatcherServerInterceptor(t *testing.T) { ctx = ctxlogrus.ToContext(ctx, newLogger().WithField("a", 1)) interceptor := UnaryLogDataCatcherServerInterceptor() _, _ = interceptor(ctx, nil, nil, handlerStub) - assert.Equal(t, Fields{"a": 1}, mpp.fields) + assert.Equal(t, grpcmwloggingv2.Fields{"a", 1}, mpp.fields) }) } @@ -226,7 +170,7 @@ func TestStreamLogDataCatcherServerInterceptor(t *testing.T) { interceptor := StreamLogDataCatcherServerInterceptor() ss := &grpcmw.WrappedServerStream{WrappedContext: ctx} _ = interceptor(nil, ss, nil, func(interface{}, grpc.ServerStream) error { return nil }) - assert.Equal(t, Fields{"a": 1}, mpp.fields) + assert.Equal(t, grpcmwloggingv2.Fields{"a", 1}, mpp.fields) }) } @@ -285,7 +229,10 @@ func TestLogDeciderOption_logByRegexpMatch(t *testing.T) { t.Setenv("GITALY_LOG_REQUEST_METHOD_ALLOW_PATTERN", tc.only) logger, hook := test.NewNullLogger() - interceptor := grpcmwlogrus.UnaryServerInterceptor(logrus.NewEntry(logger), DeciderOption()) + gitalyLogger := FromLogrusEntry(logrus.NewEntry(logger)) + + interceptor := grpcmwloggingv2.UnaryServerInterceptor(DefaultInterceptorLogger(gitalyLogger)) + interceptor = selector.UnaryServerInterceptor(interceptor, *DeciderMatcher()) ctx := createContext() for _, methodName := range methodNames { @@ -301,9 +248,15 @@ func TestLogDeciderOption_logByRegexpMatch(t *testing.T) { } entries := hook.AllEntries() - require.Len(t, entries, len(tc.shouldLogMethods)) - for idx, entry := range entries { - require.Equal(t, entry.Message, "finished unary call with code OK") + finishingCallEntries := make([]*logrus.Entry, 0) + for _, entry := range entries { + if entry.Message == "finished call" { + finishingCallEntries = append(finishingCallEntries, entry) + } + } + + require.Len(t, finishingCallEntries, 
len(tc.shouldLogMethods)) + for idx, entry := range finishingCallEntries { require.Equal(t, entry.Data["grpc.method"], tc.shouldLogMethods[idx]) } }) diff --git a/internal/log/options_test.go b/internal/log/options_test.go index a9fafbacc7..85a9f2f4fe 100644 --- a/internal/log/options_test.go +++ b/internal/log/options_test.go @@ -38,4 +38,10 @@ func TestLoggerOptions(t *testing.T) { opt := evaluateServerOpt([]Option{WithLogOnEvents(events...)}) require.Equal(t, opt.loggableEvents, events) }) + + t.Run("WithMatcher", func(t *testing.T) { + matcher := DeciderMatcher() + opt := evaluateServerOpt([]Option{WithMatcher(matcher)}) + require.Equal(t, opt.matcher, matcher) + }) } diff --git a/internal/praefect/delete_object_pool_test.go b/internal/praefect/delete_object_pool_test.go index c0a984d114..20c1f07c18 100644 --- a/internal/praefect/delete_object_pool_test.go +++ b/internal/praefect/delete_object_pool_test.go @@ -4,7 +4,6 @@ import ( "context" "testing" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v16/internal/git/gittest" "gitlab.com/gitlab-org/gitaly/v16/internal/log" @@ -72,7 +71,8 @@ func TestDeleteObjectPoolHandler(t *testing.T) { logger := testhelper.NewLogger(t) hook := testhelper.AddLoggerHook(logger) praefectSrv := grpc.NewServer(grpc.ChainStreamInterceptor( - logger.StreamServerInterceptor(grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat)), + logger.StreamServerInterceptor(log.DefaultInterceptorLogger(logger), + log.WithTimestampFormat(log.LogTimestampFormat)), )) praefectSrv.RegisterService(&grpc.ServiceDesc{ ServiceName: "gitaly.ObjectPoolService", @@ -106,8 +106,8 @@ func TestDeleteObjectPoolHandler(t *testing.T) { }) require.NoError(t, err) - require.Len(t, hook.AllEntries(), 2, "expected a log entry for failed deletion") - entry := hook.AllEntries()[0] + require.Len(t, hook.AllEntries(), 3, "expected a log entry for failed deletion") + entry := 
hook.AllEntries()[1] require.Equal(t, "failed deleting repository", entry.Message) require.Equal(t, repo.StorageName, entry.Data["virtual_storage"]) require.Equal(t, repo.RelativePath, entry.Data["relative_path"]) diff --git a/internal/praefect/server.go b/internal/praefect/server.go index 4792e46552..c20e2f14bf 100644 --- a/internal/praefect/server.go +++ b/internal/praefect/server.go @@ -7,7 +7,7 @@ package praefect import ( "time" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + grpcmwloggingv2 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/backchannel" @@ -43,17 +43,14 @@ import ( // NewBackchannelServerFactory returns a ServerFactory that serves the RefTransactionServer on the backchannel // connection. func NewBackchannelServerFactory(logger log.Logger, refSvc gitalypb.RefTransactionServer, registry *sidechannel.Registry) backchannel.ServerFactory { - logMsgProducer := log.MessageProducer( - log.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), - structerr.FieldsProducer, - ) - return func() backchannel.Server { lm := listenmux.New(insecure.NewCredentials()) lm.Register(sidechannel.NewServerHandshaker(registry)) srv := grpc.NewServer( grpc.ChainUnaryInterceptor( - commonUnaryServerInterceptors(logger.WithField("component", "backchannel.PraefectServer"), logMsgProducer)..., + commonUnaryServerInterceptors(logger.WithField("component", "backchannel.PraefectServer"), + log.PropagationMessageProducer(log.DefaultInterceptorLogger(logger)), + structerr.FieldsProducer)..., ), grpc.Creds(lm), ) @@ -63,15 +60,16 @@ func NewBackchannelServerFactory(logger log.Logger, refSvc gitalypb.RefTransacti } } -func commonUnaryServerInterceptors(logger log.Logger, messageProducer grpcmwlogrus.MessageProducer) []grpc.UnaryServerInterceptor { 
+func commonUnaryServerInterceptors(logger log.Logger, loggerFunc grpcmwloggingv2.Logger, producers ...log.FieldsProducer) []grpc.UnaryServerInterceptor { return []grpc.UnaryServerInterceptor{ grpccorrelation.UnaryServerCorrelationInterceptor(), // Must be above the metadata handler requestinfohandler.UnaryInterceptor, grpcprometheus.UnaryServerInterceptor, logger.UnaryServerInterceptor( - grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(messageProducer), - log.DeciderOption(), + loggerFunc, + log.WithMatcher(log.DeciderMatcher()), + log.WithTimestampFormat(log.LogTimestampFormat), + log.WithFiledProducers(producers...), ), sentryhandler.UnaryLogHandler(), statushandler.Unary, // Should be below LogHandler @@ -116,13 +114,12 @@ func NewGRPCServer( opt(&serverCfg) } - logMsgProducer := log.MessageProducer( - log.PropagationMessageProducer(grpcmwlogrus.DefaultMessageProducer), - structerr.FieldsProducer, - ) - unaryInterceptors := append( - commonUnaryServerInterceptors(deps.Logger.WithField("component", "praefect.UnaryServerInterceptor"), logMsgProducer), + commonUnaryServerInterceptors( + deps.Logger.WithField("component", "praefect.UnaryServerInterceptor"), + log.PropagationMessageProducer(log.DefaultInterceptorLogger(deps.Logger)), + structerr.FieldsProducer, + ), middleware.MethodTypeUnaryInterceptor(deps.Registry, deps.Logger), auth.UnaryServerInterceptor(deps.Config.Auth), ) @@ -134,9 +131,9 @@ func NewGRPCServer( requestinfohandler.StreamInterceptor, grpcprometheus.StreamServerInterceptor, deps.Logger.WithField("component", "praefect.StreamServerInterceptor").StreamServerInterceptor( - grpcmwlogrus.WithTimestampFormat(log.LogTimestampFormat), - grpcmwlogrus.WithMessageProducer(logMsgProducer), - log.DeciderOption(), + log.DefaultInterceptorLogger(deps.Logger), + log.WithMatcher(log.DeciderMatcher()), + log.WithFiledProducers(structerr.FieldsProducer), ), sentryhandler.StreamLogHandler(), statushandler.Stream, // 
Should be below LogHandler diff --git a/internal/testhelper/testserver/structerr_interceptors_test.go b/internal/testhelper/testserver/structerr_interceptors_test.go index a0453e7ed4..ac79352c63 100644 --- a/internal/testhelper/testserver/structerr_interceptors_test.go +++ b/internal/testhelper/testserver/structerr_interceptors_test.go @@ -7,7 +7,6 @@ import ( "net" "testing" - grpcmwlogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v16/internal/log" "gitlab.com/gitlab-org/gitaly/v16/internal/structerr" @@ -95,14 +94,7 @@ func TestFieldsProducer(t *testing.T) { service := &mockService{} server := grpc.NewServer( grpc.ChainUnaryInterceptor( - logger.UnaryServerInterceptor( - grpcmwlogrus.WithMessageProducer( - log.MessageProducer( - grpcmwlogrus.DefaultMessageProducer, - structerr.FieldsProducer, - ), - ), - ), + logger.UnaryServerInterceptor(log.DefaultInterceptorLogger(logger), log.WithFiledProducers(structerr.FieldsProducer)), ), ) grpc_testing.RegisterTestServiceServer(server, service) -- GitLab From d76feb3e6d006d2f1bffa153b474481c50d5c6d1 Mon Sep 17 00:00:00 2001 From: Eric Ju Date: Wed, 22 Nov 2023 14:14:50 -0400 Subject: [PATCH 3/4] log: Add thread safety to logger interface The `logrus.Entry` type is not thread safe, so our logger, which uses `logrus.Entry` directly, is not thread safe either, and this is breaking the race-go test. In this commit, we wrap `logrus.Entry` in a wrapper that carries a lock. The implementation of the logger interface is also changed to be thread safe.
--- internal/log/logger.go | 92 ++++++++++++++++++++++++++++++-------- internal/log/middleware.go | 4 +- 2 files changed, 75 insertions(+), 21 deletions(-) diff --git a/internal/log/logger.go b/internal/log/logger.go index ddf4c704de..bb68f7eefc 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -2,9 +2,10 @@ package log import ( "context" - "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" + "sync" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" grpcmwloggingv2 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" grpcmwselector "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" "github.com/sirupsen/logrus" @@ -36,14 +37,20 @@ type Logger interface { ReplaceFields(fields Fields) Logger } +// threadSafeEntry is a wrapper around logrus.Entry with lock that provides thread safety. +type threadSafeEntry struct { + entry *logrus.Entry + lock *sync.Mutex +} + // LogrusLogger is an implementation of the Logger interface that is implemented via a `logrus.FieldLogger`. type LogrusLogger struct { - entry *logrus.Entry + entryWrapper *threadSafeEntry } // FromLogrusEntry constructs a new Gitaly-specific logger from a `logrus.Logger`. func FromLogrusEntry(entry *logrus.Entry) LogrusLogger { - return LogrusLogger{entry: entry} + return LogrusLogger{entryWrapper: &threadSafeEntry{entry: entry, lock: &sync.Mutex{}}} } // LogrusEntry returns the `logrus.Entry` that backs this logger. Note that this interface only exists during the @@ -52,52 +59,94 @@ func FromLogrusEntry(entry *logrus.Entry) LogrusLogger { // Deprecated: This will be removed once all callsites have been converted to do something that is independent of the // logrus logger. 
func (l LogrusLogger) LogrusEntry() *logrus.Entry { - return l.entry + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + return l.entryWrapper.entry } // WithField creates a new logger with the given field appended. func (l LogrusLogger) WithField(key string, value any) Logger { - return LogrusLogger{entry: l.entry.WithField(key, value)} + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + return LogrusLogger{ + entryWrapper: &threadSafeEntry{ + l.entryWrapper.entry.WithField(key, value), + l.entryWrapper.lock, + }, + } } -// WithFields creates a new logger with the given fields appended. -func (l LogrusLogger) WithFields(fields Fields) Logger { - return LogrusLogger{entry: l.entry.WithFields(fields)} +// ReplaceFields creates a new logger with old fields truncated and replaced by new fields. +func (l LogrusLogger) ReplaceFields(fields Fields) Logger { + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + l.entryWrapper.entry.Data = Fields{} + return LogrusLogger{ + entryWrapper: &threadSafeEntry{ + l.entryWrapper.entry.WithFields(fields), + l.entryWrapper.lock, + }, + } } -func (l LogrusLogger) ReplaceFields(fields Fields) Logger { - l.entry.Data = Fields{} - return LogrusLogger{entry: l.entry.WithFields(fields)} +// WithFields creates a new logger with the given fields appended. +func (l LogrusLogger) WithFields(fields Fields) Logger { + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + return LogrusLogger{ + entryWrapper: &threadSafeEntry{ + l.entryWrapper.entry.WithFields(fields), + l.entryWrapper.lock, + }, + } } // WithError creates a new logger with an appended error field. 
func (l LogrusLogger) WithError(err error) Logger { - return LogrusLogger{entry: l.entry.WithError(err)} + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + return LogrusLogger{ + entryWrapper: &threadSafeEntry{ + l.entryWrapper.entry.WithError(err), + l.entryWrapper.lock, + }, + } } // Debug writes a log message at debug level. func (l LogrusLogger) Debug(msg string) { - l.entry.Debug(msg) + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + l.entryWrapper.entry.Debug(msg) } // Info writes a log message at info level. func (l LogrusLogger) Info(msg string) { - l.entry.Info(msg) + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + l.entryWrapper.entry.Info(msg) } // Warn writes a log message at warn level. func (l LogrusLogger) Warn(msg string) { - l.entry.Warn(msg) + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + l.entryWrapper.entry.Warn(msg) } // Error writes a log message at error level. func (l LogrusLogger) Error(msg string) { - l.entry.Error(msg) + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() + l.entryWrapper.entry.Error(msg) } // toContext injects the logger into the given context so that it can be retrieved via `FromContext()`. func (l LogrusLogger) toContext(ctx context.Context) context.Context { - return ctxlogrus.ToContext(ctx, l.entry) + if l.entryWrapper == nil { + return ctxlogrus.ToContext(ctx, nil) + } + return ctxlogrus.ToContext(ctx, l.entryWrapper.entry) } // StreamServerInterceptor creates a new stream server interceptor. 
The loggerFunc is the function that will be called to @@ -123,8 +172,10 @@ func (l LogrusLogger) UnaryServerInterceptor(loggerFunc grpcmwloggingv2.Logger, } func (l LogrusLogger) log(ctx context.Context, level logrus.Level, msg string) { + l.entryWrapper.lock.Lock() + defer l.entryWrapper.lock.Unlock() middlewareFields := ConvertLoggingFields(grpcmwloggingv2.ExtractFields(ctx)) - l.entry.WithFields(ctxlogrus.Extract(ctx).Data).WithFields(middlewareFields).Log(level, msg) + l.entryWrapper.entry.WithFields(ctxlogrus.Extract(ctx).Data).WithFields(middlewareFields).Log(level, msg) } // DebugContext logs a new log message at Debug level. Fields added to the context via AddFields will be appended. @@ -151,7 +202,10 @@ func (l LogrusLogger) ErrorContext(ctx context.Context, msg string) { // logger. func fromContext(ctx context.Context) LogrusLogger { return LogrusLogger{ - entry: ctxlogrus.Extract(ctx), + entryWrapper: &threadSafeEntry{ + ctxlogrus.Extract(ctx), + &sync.Mutex{}, + }, } } diff --git a/internal/log/middleware.go b/internal/log/middleware.go index 3d47148d14..075b39eba2 100644 --- a/internal/log/middleware.go +++ b/internal/log/middleware.go @@ -173,7 +173,7 @@ func UnaryLogDataCatcherServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { mpp := messageProducerPropagationFrom(ctx) if mpp != nil { - for k, v := range fromContext(ctx).entry.Data { + for k, v := range fromContext(ctx).entryWrapper.entry.Data { mpp.fields = mpp.fields.AppendUnique(grpcmwloggingv2.Fields{k, v}) } } @@ -188,7 +188,7 @@ func StreamLogDataCatcherServerInterceptor() grpc.StreamServerInterceptor { ctx := ss.Context() mpp := messageProducerPropagationFrom(ctx) if mpp != nil { - for k, v := range fromContext(ctx).entry.Data { + for k, v := range fromContext(ctx).entryWrapper.entry.Data { mpp.fields = mpp.fields.AppendUnique(grpcmwloggingv2.Fields{k, v}) } } -- 
GitLab From d37553f7834ae936fd06f8249724a67d444b25dd Mon Sep 17 00:00:00 2001 From: Eric Ju Date: Wed, 22 Nov 2023 18:09:48 -0400 Subject: [PATCH 4/4] requestinfohandler: Remove grpc middleware v1 tags During the grpc middleware v2 migration, a defect (https://gitlab.com/gitlab-org/gitaly/-/issues/5694) was caused by the v1 interceptor being unable to retrieve v2 fields from the context. To fix that, https://gitlab.com/gitlab-org/gitaly/-/merge_requests/6534 brought the v1 tags back. Now that the v2 interceptor is ready to use, the v1 tags can be removed again. --- .../requestinfohandler/requestinfohandler.go | 11 +---------- .../requestinfohandler/requestinfohandler_test.go | 13 ------------- 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/internal/grpc/middleware/requestinfohandler/requestinfohandler.go b/internal/grpc/middleware/requestinfohandler/requestinfohandler.go index 0d738c98e8..f91fddf5ba 100644 --- a/internal/grpc/middleware/requestinfohandler/requestinfohandler.go +++ b/internal/grpc/middleware/requestinfohandler/requestinfohandler.go @@ -4,7 +4,6 @@ import ( "context" "strings" - grpcmwtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/prometheus/client_golang/prometheus" @@ -197,18 +196,10 @@ func (i *RequestInfo) extractRequestInfo(request any) { } func (i *RequestInfo) injectTags(ctx context.Context) context.Context { - tags := grpcmwtags.NewTags() - for key, value := range i.Tags() { ctx = logging.InjectLogField(ctx, key, value) - tags.Set(key, value) + // tags.Set(key, value) } - - // This maintains backward compatibility for tags in the v1 grpc-go-middleware.
- // This can be removed when the v1 interceptors are removed: - // https://gitlab.com/gitlab-org/gitaly/-/work_items/5661 - ctx = grpcmwtags.SetInContext(ctx, tags) - return ctx } diff --git a/internal/grpc/middleware/requestinfohandler/requestinfohandler_test.go b/internal/grpc/middleware/requestinfohandler/requestinfohandler_test.go index 38e40faebe..5e62707cd7 100644 --- a/internal/grpc/middleware/requestinfohandler/requestinfohandler_test.go +++ b/internal/grpc/middleware/requestinfohandler/requestinfohandler_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - grpcmwtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "github.com/stretchr/testify/require" gitalylog "gitlab.com/gitlab-org/gitaly/v16/internal/log" @@ -284,18 +283,6 @@ func TestGRPCTags(t *testing.T) { "grpc.request.fullMethod": "/gitaly.RepositoryService/OptimizeRepository", }, gitalylog.ConvertLoggingFields(fields)) - legacyFields := grpcmwtags.Extract(ctx).Values() - - require.Equal(t, map[string]any{ - "correlation_id": correlationID, - "grpc.meta.client_name": clientName, - "grpc.meta.deadline_type": "none", - "grpc.meta.method_type": "unary", - "grpc.meta.method_operation": "maintenance", - "grpc.meta.method_scope": "repository", - "grpc.request.fullMethod": "/gitaly.RepositoryService/OptimizeRepository", - }, legacyFields) - return nil, nil }) require.NoError(t, err) -- GitLab