diff --git a/workhorse/internal/ai_assist/duoworkflow/client.go b/workhorse/internal/ai_assist/duoworkflow/client.go index 38dccfec9a348e6e3112e6b0c8ec74893244247d..49b6d192a1b00db59918ce51a29f97468720e642 100644 --- a/workhorse/internal/ai_assist/duoworkflow/client.go +++ b/workhorse/internal/ai_assist/duoworkflow/client.go @@ -4,7 +4,7 @@ package duoworkflow import ( "context" "crypto/tls" - "fmt" + "errors" "time" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -26,18 +26,31 @@ import ( const MaxMessageSize = 4 * 1024 * 1024 // 4MB // ErrServerUnavailable is returned when the workflow server cannot be reached. -var ErrServerUnavailable = fmt.Errorf("server is unavailable") +var ErrServerUnavailable = errors.New("server is unavailable") // Client is a gRPC client for the Duo Workflow service. type Client struct { grpcConn *grpc.ClientConn grpcClient pb.DuoWorkflowClient - headers map[string]string + md metadata.MD } // NewClient creates a new Duo Workflow client with the specified server address, // headers, and security settings. 
func NewClient(serverURI string, headers map[string]string, secure bool) (*Client, error) { + // Configured based on https://grpc.io/docs/guides/service-config/ + serviceConfig := `{ + "methodConfig": [{ + "name": [{"service": "DuoWorkflow"}], + "retryPolicy": { + "maxAttempts": 4, + "initialBackoff": "0.1s", + "maxBackoff": "1s", + "backoffMultiplier": 2, + "retryableStatusCodes": [ "UNAVAILABLE" ] + } + }] + }` opts := []grpc.DialOption{ grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 20 * time.Second, // send pings every 20 seconds if there is no activity @@ -50,6 +63,12 @@ func NewClient(serverURI string, headers map[string]string, secure bool) (*Clien grpccorrelation.WithClientName("gitlab-duo-workflow"), ), ), + grpc.WithDefaultServiceConfig(serviceConfig), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(MaxMessageSize), + grpc.MaxCallSendMsgSize(MaxMessageSize), + grpc.WaitForReady(true), + ), } if secure { @@ -58,23 +77,6 @@ func NewClient(serverURI string, headers map[string]string, secure bool) (*Clien opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } - // Configured based on https://grpc.io/docs/guides/service-config/ - serviceConfig := `{ - "methodConfig": [{ - "name": [{"service": "DuoWorkflow"}], - "retryPolicy": { - "maxAttempts": 4, - "initialBackoff": "0.1s", - "maxBackoff": "1s", - "backoffMultiplier": 2, - "retryableStatusCodes": [ "UNAVAILABLE" ] - } - }] - }` - - callOptions := grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxMessageSize), grpc.MaxCallSendMsgSize(MaxMessageSize)) - opts = append(opts, grpc.WithDefaultServiceConfig(serviceConfig), callOptions) - conn, err := grpc.NewClient(serverURI, opts...) 
if err != nil { return nil, err @@ -83,18 +85,17 @@ func NewClient(serverURI string, headers map[string]string, secure bool) (*Clien return &Client{ grpcConn: conn, grpcClient: pb.NewDuoWorkflowClient(conn), - headers: headers, + md: metadata.New(headers), }, nil } // ExecuteWorkflow initiates a new workflow execution stream with the server. func (c *Client) ExecuteWorkflow(ctx context.Context) (pb.DuoWorkflow_ExecuteWorkflowClient, error) { - ctx = metadata.NewOutgoingContext(ctx, metadata.New(c.headers)) + ctx = metadata.NewOutgoingContext(ctx, c.md) stream, err := c.grpcClient.ExecuteWorkflow(ctx) if err != nil { - st, ok := status.FromError(err) - if ok && st.Code() == codes.Unavailable { + if status.Code(err) == codes.Unavailable { return nil, ErrServerUnavailable } return nil, err diff --git a/workhorse/internal/ai_assist/duoworkflow/handler.go b/workhorse/internal/ai_assist/duoworkflow/handler.go index 7ac0415651de8f18a9298c8b1105a7d2c00fbbdc..787f81771fcb4a6a14156b11f6601a8b038adbfb 100644 --- a/workhorse/internal/ai_assist/duoworkflow/handler.go +++ b/workhorse/internal/ai_assist/duoworkflow/handler.go @@ -11,29 +11,49 @@ import ( "github.com/gorilla/websocket" ) -var upgrader = websocket.Upgrader{} +const ( + // maxControlPayload is the maximum length of a control frame payload. + // See https://tools.ietf.org/html/rfc6455#section-5.5. + maxControlPayload = 125 + // maxCloseReason is the maximum length of a Close frame reason: the first 2 bytes of the payload hold the status code. + maxCloseReason = maxControlPayload - 2 +) // Handler creates an HTTP handler for Duo Workflow WebSocket connections. func Handler(rails *api.API) http.Handler { + var u websocket.Upgrader return rails.PreAuthorizeHandler(func(w http.ResponseWriter, r *http.Request, a *api.Response) { - conn, err := upgrader.Upgrade(w, r, nil) + // 1. Create gRPC client for DWS. + // This shouldn't fail, but if it does, we can return a simple HTTP error response so this is the first thing we do. 
+ // We should maintain a per-URI pool of clients or, ideally, configure Workhorse with a single DWS URI. + client, err := NewClient(a.DuoWorkflow.ServiceURI, a.DuoWorkflow.Headers, a.DuoWorkflow.Secure) if err != nil { - fail.Request(w, r, fmt.Errorf("failed to upgrade: %v", err)) + fail.Request(w, r, fmt.Errorf("failed to initialize client: %v", err)) return } - - runner, err := newRunner(conn, rails, r, a.DuoWorkflow) - if err != nil { - fail.Request(w, r, fmt.Errorf("failed to initialize agent platform client: %v", err)) - if closeErr := conn.Close(); closeErr != nil { - log.WithRequest(r).WithError(closeErr).Error("failed to close connection") + defer func() { + err = client.Close() + if err != nil { + log.WithRequest(r).WithError(err).Error("Failed to close gRPC client") } + }() + // 2. Accept the WebSocket upgrade request. + conn, err := u.Upgrade(w, r, nil) + if err != nil { + fail.Request(w, r, fmt.Errorf("failed to upgrade: %v", err)) return } - defer func() { _ = runner.Close() }() + defer func() { + err = conn.Close() + if err != nil { + log.WithRequest(r).WithError(err).Error("Failed to close WebSocket connection") + } + }() - if err := runner.Execute(r.Context()); err != nil { - log.WithRequest(r).WithError(err).Error() - } + // 3. Construct the runner + rnr := newRunner(conn, rails, r, client, a.DuoWorkflow.Headers["x-gitlab-oauth-token"]) + + // 4. Execute the logic. 
+ rnr.Execute(r.Context()) }, "") } diff --git a/workhorse/internal/ai_assist/duoworkflow/runner.go b/workhorse/internal/ai_assist/duoworkflow/runner.go index 21fd62c52659af541b6952b64cfa12992047c40b..e3e70a893883eb8c5c6eb561f094189afe489851 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner.go @@ -2,25 +2,20 @@ package duoworkflow import ( "context" - "errors" "fmt" "io" "net/http" - "sync" "time" - pb "gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" + pb "gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract" "github.com/gorilla/websocket" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" ) -const wsCloseTimeout = 5 * time.Second - var marshaler = protojson.MarshalOptions{ UseProtoNames: true, EmitUnpopulated: true, @@ -33,9 +28,8 @@ var unmarshaler = protojson.UnmarshalOptions{ type websocketConn interface { ReadMessage() (int, []byte, error) WriteMessage(int, []byte) error - WriteControl(int, []byte, time.Time) error + WriteControl(messageType int, data []byte, deadline time.Time) error SetReadDeadline(time.Time) error - Close() error } type workflowStream interface { @@ -49,123 +43,150 @@ type runner struct { token string originalReq *http.Request conn websocketConn - wf workflowStream client *Client - sendMu sync.Mutex } -func newRunner(conn websocketConn, rails *api.API, r *http.Request, cfg *api.DuoWorkflow) (*runner, error) { - client, err := NewClient(cfg.ServiceURI, cfg.Headers, cfg.Secure) - if err != nil { - return nil, fmt.Errorf("failed to initialize client: %v", err) - } - - wf, err := client.ExecuteWorkflow(r.Context()) - if err != nil { - return nil, fmt.Errorf("failed to initialize stream: %v", err) - } - +func newRunner(conn websocketConn, rails *api.API, r 
*http.Request, client *Client, token string) *runner { return &runner{ rails: rails, - token: cfg.Headers["x-gitlab-oauth-token"], + token: token, originalReq: r, conn: conn, - wf: wf, client: client, - }, nil + } } -func (r *runner) Execute(ctx context.Context) error { - errCh := make(chan error, 2) +func (r *runner) Execute(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + wf, err := r.client.ExecuteWorkflow(ctx) + if err != nil { + r.websocketWriteCloseMessage(websocket.CloseInternalServerErr, fmt.Sprintf("ExecuteWorkflow RPC: %v", err)) + return + } + + toSend := make(chan *pb.ClientEvent) // nil means "call CloseSend()" + aaErrCh := make(chan error, 1) // buffered so the send below never blocks and the goroutine's deferred cancel() can wake the select loop + wsErrCh := make(chan error, 1) + var sendError error go func() { - for { - if err := r.handleWebSocketMessage(); err != nil { - errCh <- err - return + wsErrCh <- r.handleWebSocketMessages(ctx, toSend) + }() + go func() { + defer cancel() + aaErrCh <- r.handleAgentActions(ctx, toSend, wf) + }() + defer func() { // We need to send the Close control message if it hasn't been sent. + // We are here because there was an error or context was canceled. + // Cancel the context if it's the latter to signal handleAgentActions() to abort. + cancel() + readAborted := false + finalErr := sendError + aaErr := <-aaErrCh // wait for the handleAgentActions() goroutine to exit + if finalErr == nil { + finalErr = aaErr + } + if finalErr == nil { + finalErr = <-wsErrCh // if there was no error, wait for the handleWebSocketMessages() to exit + } else { + // If there was an error talking to DWS, handleWebSocketMessages() may be blocked trying to read + // from the incoming WebSocket connection as it is unaware of the error. + // Abort it by setting a deadline in the past. 
See doc for SetDeadline() at https://pkg.go.dev/net#Conn + readAborted = true + readDeadErr := r.conn.SetReadDeadline(time.Unix(0, 1)) + if readDeadErr != nil { + log.WithRequest(r.originalReq).WithError(readDeadErr).Error("SetReadDeadline() failed") // unlikely } + <-wsErrCh // wait for handleWebSocketMessages() to return, ignore the error as we have one already in finalErr. } - }() - go func() { - for { - action, err := r.wf.Recv() - if err != nil { - if err == io.EOF { - errCh <- nil // Expected error when a workflow ends - } else { - errCh <- fmt.Errorf("duoworkflow: failed to read a gRPC message: %v", err) - } + if finalErr == nil { // Clean close. Send the Close control message. + r.websocketWriteCloseMessage(websocket.CloseNormalClosure, "") + } else { + if readAborted { + // Cannot write anything in response, just reset the connection. + // The client will get an I/O error and know that something went wrong on the server. + } else { + // Send the client the error if we didn't abort the WebSocket read operation. Use finalErr (the outer err is always nil here) and a close code that is allowed on the wire (1006 must not be sent in a Close frame). + r.websocketWriteCloseMessage(websocket.CloseInternalServerErr, finalErr.Error()) + } + } + }() + done := ctx.Done() + for { + select { + case <-done: + return // context error will be picked up by Send() and Recv() on the gRPC stream. + case event := <-toSend: + if event == nil { // WebSocket closed normally + // Signal DWS we are done talking to it. + // The handleAgentActions() goroutine will get io.EOF from Recv(). + // We should wait for that for clean termination. 
+ _ = wf.CloseSend() // always returns nil error return } - - if err := r.handleAgentAction(ctx, action); err != nil { - errCh <- err + sendError = wf.Send(event) + if sendError != nil { + if sendError == io.EOF { + sendError = nil // let Recv() get the actual error + } return } } - }() - - return <-errCh -} - -func (r *runner) Close() error { - r.sendMu.Lock() - defer r.sendMu.Unlock() - - return errors.Join(r.wf.CloseSend(), r.client.Close(), r.closeWebSocketConnection()) + } } -func (r *runner) closeWebSocketConnection() error { - deadline := time.Now().Add(wsCloseTimeout) - if err := r.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), deadline); err != nil { - // If we can't send the close message, just close the connection - closeErr := r.conn.Close() - if closeErr != nil { - return fmt.Errorf("failed to send close message and failed to close connection: %w", closeErr) +func (r *runner) handleWebSocketMessages(ctx context.Context, toSend chan<- *pb.ClientEvent) error { + done := ctx.Done() + for { + mt, message, err := r.conn.ReadMessage() + if err != nil { + if websocket.IsCloseError(err, websocket.CloseNormalClosure) { + select { + case <-done: + case toSend <- nil: // propagate normal closure + } + return nil + } + return fmt.Errorf("handleWebSocketMessages: failed to read a message: %v", err) } - return fmt.Errorf("failed to send close message: %w", err) - } - - if err := r.conn.SetReadDeadline(deadline); err != nil { - closeErr := r.conn.Close() - if closeErr != nil { - return fmt.Errorf("failed to set read deadline and failed to close connection: %w", closeErr) + if mt != websocket.BinaryMessage { + // ReadMessage() uses NextReader(), which is documented to only return BinaryMessage or TextMessage. + // websocket.TextMessage would be an unexpected type because we are dealing with binary data here, not utf-8. 
+ return fmt.Errorf("handleWebSocketMessages: unexpected message type: %d", mt) + } + response := &pb.ClientEvent{} + if err = unmarshaler.Unmarshal(message, response); err != nil { + return fmt.Errorf("handleWebSocketMessages: failed to unmarshal a message: %v", err) + } + select { + case <-done: + return nil + case toSend <- response: } - return fmt.Errorf("failed to set read deadline: %w", err) - } - - if err := r.conn.Close(); err != nil { - return fmt.Errorf("failed to close connection: %w", err) } - - return nil } -func (r *runner) handleWebSocketMessage() error { - _, message, err := r.conn.ReadMessage() - if err != nil { - return fmt.Errorf("handleWebSocketMessage: failed to read a WS message: %v", err) - } - - response := &pb.ClientEvent{} - if err = unmarshaler.Unmarshal(message, response); err != nil { - return fmt.Errorf("handleWebSocketMessage: failed to unmarshal a WS message: %v", err) - } - - if err = r.threadSafeSend(response); err != nil { - if err == io.EOF { - // ignore EOF to let Recv() fail and return a meaningful message - return nil +func (r *runner) handleAgentActions(ctx context.Context, toSend chan<- *pb.ClientEvent, wf pb.DuoWorkflow_ExecuteWorkflowClient) error { + for { + action, err := wf.Recv() + if err != nil { + if err == io.EOF { // Expected error when a workflow ends + return nil + } + return fmt.Errorf("handleAgentActions: failed to read a gRPC message: %v", err) } - return fmt.Errorf("handleWebSocketMessage: failed to write a gRPC message: %v", err) + err = r.handleAgentAction(ctx, toSend, action) + if err != nil { + return err + } } - - return nil } -func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error { +func (r *runner) handleAgentAction(ctx context.Context, toSend chan<- *pb.ClientEvent, action *pb.Action) error { switch action.Action.(type) { case *pb.Action_RunHTTPRequest: handler := &runHTTPActionHandler{ @@ -179,7 +200,6 @@ func (r *runner) handleAgentAction(ctx context.Context, action 
*pb.Action) error if err != nil { return fmt.Errorf("handleAgentAction: failed to perform API call: %v", err) } - statusCode := event.GetActionResponse().GetHttpResponse().StatusCode log.WithContextFields(r.originalReq.Context(), log.Fields{ "path": action.GetRunHTTPRequest().Path, @@ -188,13 +208,12 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error "event_type": fmt.Sprintf("%T", event.Response), "action_response_type": fmt.Sprintf("%T", event.GetActionResponse().GetResponseType()), }).Info("Sending HTTP response event") - if err := r.threadSafeSend(event); err != nil { - return fmt.Errorf("handleAgentAction: failed to send gRPC message: %v", err) - } - log.WithContextFields(r.originalReq.Context(), log.Fields{ - "path": action.GetRunHTTPRequest().Path, - }).Info("Successfully sent HTTP response event") + select { + case <-ctx.Done(): + return ctx.Err() + case toSend <- event: + } default: message, err := marshaler.Marshal(action) if err != nil { @@ -209,8 +228,31 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error return nil } -func (r *runner) threadSafeSend(event *pb.ClientEvent) error { - r.sendMu.Lock() - defer r.sendMu.Unlock() - return r.wf.Send(event) +func (r *runner) websocketWriteCloseMessage(closeCode int, text string) { + if len(text) > maxCloseReason { + text = text[:maxCloseReason-3] + "..." + } + deadline := time.Now().Add(5 * time.Second) + err := r.conn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(closeCode, text), + deadline, + ) + switch err { + case nil: // We sent the Close control frame first, now we need the remote peer to ack it. 
+ err = r.conn.SetReadDeadline(deadline) + if err != nil { + log.WithRequest(r.originalReq).WithError(err).Error("SetReadDeadline() failed") // unlikely + err = nil + } + for err == nil { // Drain connection until error + _, _, err = r.conn.ReadMessage() + } + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + log.WithRequest(r.originalReq).WithError(err).Error("Error from WebSocket client") + } + case websocket.ErrCloseSent: // The library had sent the Close control frame already, no need to log this error. + default: // Log everything else. + log.WithRequest(r.originalReq).WithError(err).Error("Failed to write close message") + } } diff --git a/workhorse/internal/ai_assist/duoworkflow/runner_test.go b/workhorse/internal/ai_assist/duoworkflow/runner_test.go index dba71376d8c1fbd3f45e743dab8c3a18d692aff5..e8b317cce660b3c3997708526ec9f56f44b8b6d8 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner_test.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner_test.go @@ -136,7 +136,6 @@ func Test_newRunner(t *testing.T) { require.Equal(t, "oauth-token-123", runner.token) require.Equal(t, req, runner.originalReq) require.Equal(t, mockConn, runner.conn) - require.NotNil(t, runner.wf) require.NotNil(t, runner.client) require.Equal(t, apiClient, runner.rails) @@ -159,7 +158,7 @@ func TestRunner_Execute(t *testing.T) { wsMessages: [][]byte{[]byte(`{"type": "test"}`), []byte(`{"type": "test2"}`)}, wfBlockCh: make(chan bool), sendEventsCount: 2, - expectedErrMsg: "handleWebSocketMessage: failed to read a WS message: EOF", + expectedErrMsg: "handleWebSocketMessages: failed to read a WS message: EOF", }, { name: "wf actions", @@ -233,7 +232,7 @@ func TestRunner_Execute_with_errors(t *testing.T) { name: "websocket read error", wsReadError: errors.New("read error"), wfBlockCh: make(chan bool), - expectedErrMsg: "handleWebSocketMessage: failed to read a WS message: read error", + expectedErrMsg: "handleWebSocketMessages: failed 
to read a WS message: read error", }, { name: "workflow recv error", @@ -295,18 +294,18 @@ func TestRunner_handleWebSocketMessage(t *testing.T) { { name: "read error", readError: errors.New("read error"), - expectedErrMsg: "handleWebSocketMessage: failed to read a WS message: read error", + expectedErrMsg: "handleWebSocketMessages: failed to read a WS message: read error", }, { name: "invalid json", message: []byte("invalid json"), - expectedErrMsg: "handleWebSocketMessage: failed to unmarshal a WS message: proto:", + expectedErrMsg: "handleWebSocketMessages: failed to unmarshal a WS message: proto:", }, { name: "send error", message: []byte(`{"type": "test"}`), sendError: errors.New("send error"), - expectedErrMsg: "handleWebSocketMessage: failed to write a gRPC message: send error", + expectedErrMsg: "handleWebSocketMessages: failed to write a gRPC message: send error", }, { name: "send EOF error", @@ -343,7 +342,7 @@ func TestRunner_handleWebSocketMessage(t *testing.T) { wf: mockWf, } - err := r.handleWebSocketMessage() + err := r.handleWebSocketMessages(nil, nil) if tt.expectedErrMsg != "" { require.Error(t, err)