diff --git a/ee/app/services/ai/duo_workflows/create_oauth_access_token_service.rb b/ee/app/services/ai/duo_workflows/create_oauth_access_token_service.rb index f4699abf28741ad26ee7665d64bbb0268c64b8ce..555a7f1a314a1ee1b8b1cd779512b956f7e05c73 100644 --- a/ee/app/services/ai/duo_workflows/create_oauth_access_token_service.rb +++ b/ee/app/services/ai/duo_workflows/create_oauth_access_token_service.rb @@ -32,7 +32,7 @@ def create_oauth_access_token expires_in: 2.hours, resource_owner_id: current_user.id, organization: @organization, - scopes: oauth_application.scopes.to_s + scopes: (oauth_application.scopes + ['mcp']).to_s ) end diff --git a/lib/api/mcp/base.rb b/lib/api/mcp/base.rb index 545d985263abd47b12dcb20071e4676f615fa8a1..37bbffb7c0368d31c8a1c753c699ce8415d792fa 100644 --- a/lib/api/mcp/base.rb +++ b/lib/api/mcp/base.rb @@ -44,7 +44,7 @@ class Base < ::API::Base before do authenticate! not_found! unless Feature.enabled?(:mcp_server, current_user) - forbidden! unless access_token&.scopes&.map(&:to_s) == [Gitlab::Auth::MCP_SCOPE.to_s] + forbidden! unless access_token&.scopes&.include?(Gitlab::Auth::MCP_SCOPE.to_s) end helpers do @@ -154,7 +154,7 @@ def format_jsonrpc_response(result) # See: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#listening-for-messages-from-the-server get do - status :not_implemented + status :method_not_allowed end end end diff --git a/workhorse/go.mod b/workhorse/go.mod index 6e9a3e7526b196aacafd912582f162466ebb1247..b7f2f3a76213ff21acbc3a8004b08f9f02239b8d 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -41,6 +41,8 @@ require ( google.golang.org/protobuf v1.36.6 ) +require github.com/modelcontextprotocol/go-sdk v0.5.0 + require ( cel.dev/expr v0.23.0 // indirect cloud.google.com/go v0.115.1 // indirect @@ -101,6 +103,7 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect + github.com/google/jsonschema-go v0.2.3 // indirect github.com/google/pprof v0.0.0-20240711041743-f6c9dda6c6da // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/google/uuid v1.6.0 // indirect @@ -119,6 +122,7 @@ require ( github.com/lightstep/lightstep-tracer-go v0.25.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/modelcontextprotocol/go-sdk v0.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/oklog/ulid/v2 v2.0.2 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect @@ -141,6 +145,7 @@ require ( github.com/tklauser/numcpus v0.3.0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/uber/jaeger-lib v2.4.1+incompatible // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/zeebo/errs v1.4.0 // indirect gitlab.com/gitlab-org/go/reopen v1.0.0 // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index 2daee9121108cd65e39fddbc04dde6062af075ec..037b89f18e60decc46a2ff343859d60f007d8eb2 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -361,6 +361,8 @@ github.com/google/go-replayers/httpreplay v1.2.0 h1:VM1wEyyjaoU53BwrOnaf9VhAyQQE github.com/google/go-replayers/httpreplay v1.2.0/go.mod h1:WahEFFZZ7a1P4VM1qEeHy+tME4bwyqPcwWbNlUI1Mcg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.2.3 h1:dkP3B96OtZKKFvdrUSaDkL+YDx8Uw9uC4Y+eukpCnmM= +github.com/google/jsonschema-go v0.2.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -493,6 +495,8 @@ github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zx github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= +github.com/modelcontextprotocol/go-sdk v0.5.0 h1:WXRHx/4l5LF5MZboeIJYn7PMFCrMNduGGVapYWFgrF8= +github.com/modelcontextprotocol/go-sdk v0.5.0/go.mod h1:degUj7OVKR6JcYbDF+O99Fag2lTSTbamZacbGTRTSGU= github.com/montanaflynn/stats v0.7.0 h1:r3y12KyNxj/Sb/iOE46ws+3mS1+MZca1wlHQFPsY/JU= github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -618,6 +622,8 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/workhorse/internal/ai_assist/duoworkflow/actions.go b/workhorse/internal/ai_assist/duoworkflow/actions.go index 5f3cdbfb3ce566c775fa7d06ee1f94b0256ce677..0a62d1f565bcef6eb84ff4d7367b1dfe683ee665 100644 --- a/workhorse/internal/ai_assist/duoworkflow/actions.go +++ b/workhorse/internal/ai_assist/duoworkflow/actions.go @@ -3,24 +3,18 @@ package duoworkflow import ( "bytes" "context" - "fmt" "io" - "net" "net/http" "net/url" - "strings" pb "gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" ) const actionResponseBodyLimit = 1024 * 1024 // 1 megabyte type runHTTPActionHandler struct { - rails *api.API - token string - originalReq *http.Request + baseURL *url.URL + railsClient *http.Client action *pb.Action } @@ -37,27 +31,13 @@ func (a *runHTTPActionHandler) Execute(ctx context.Context) (*pb.ClientEvent, er return nil, err } - reqURL := a.rails.URL.ResolveReference(actionURL).String() + reqURL := a.baseURL.ResolveReference(actionURL).String() req, err := http.NewRequestWithContext(ctx, action.Method, reqURL, &bodyBuffer) if err != nil { return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", a.token)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "Agent-Flow-via-GitLab-Workhorse") - - if clientIP, _, splitHostErr := net.SplitHostPort(a.originalReq.RemoteAddr); splitHostErr == nil { - // If we aren't the first proxy retain prior X-Forwarded-For information as a comma+space separated list and fold multiple headers into one. - var header string - if prior, ok := a.originalReq.Header["X-Forwarded-For"]; ok { - header = strings.Join(prior, ", ") + ", " + clientIP - } else { - header = clientIP - } - req.Header.Set("X-Forwarded-For", header) - } - response, err := a.rails.Client.Do(req) + response, err := a.railsClient.Do(req) if err != nil { return nil, err } diff --git a/workhorse/internal/ai_assist/duoworkflow/runner.go b/workhorse/internal/ai_assist/duoworkflow/runner.go index 3b155f21b80a4128adcbf3846908d283111fe35d..4bd0eb387f6e38a0482f90a1b9208b3cd04f0cf2 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner.go @@ -5,8 +5,12 @@ import ( "errors" "fmt" "io" + "net" + "net/url" + "strings" "net/http" "sync" + "encoding/json" pb "gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract" @@ -16,6 +20,7 @@ import ( "github.com/gorilla/websocket" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "github.com/modelcontextprotocol/go-sdk/mcp" ) var marshaler = protojson.MarshalOptions{ @@ -39,15 +44,41 @@ type workflowStream interface { } type runner struct { - rails *api.API - token string - originalReq *http.Request + baseURL *url.URL + railsClient *http.Client conn websocketConn wf workflowStream client *Client + mcpSession *mcp.ClientSession sendMu sync.Mutex } +type roundTripper struct { + next http.RoundTripper + token string + originalReq *http.Request +} + +func (t *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + r.Header.Set("Authorization", fmt.Sprintf("Bearer %v", t.token)) + r.Header.Set("Content-Type", "application/json") + r.Header.Set("User-Agent", "Agent-Flow-via-GitLab-Workhorse") + + if t.originalReq != nil { + if clientIP, _, splitHostErr := net.SplitHostPort(t.originalReq.RemoteAddr); splitHostErr == nil { + var header string + if prior, ok := t.originalReq.Header["X-Forwarded-For"]; ok { + header = strings.Join(prior, ", ") + ", " + clientIP + } else { + header = clientIP + } + r.Header.Set("X-Forwarded-For", header) + } + } + + return t.next.RoundTrip(r) +} + func newRunner(conn websocketConn, rails *api.API, r *http.Request, cfg *api.DuoWorkflow) (*runner, error) { client, err := NewClient(cfg.ServiceURI, cfg.Headers, cfg.Secure) if err != nil { @@ -59,13 +90,31 @@ func newRunner(conn websocketConn, rails *api.API, r *http.Request, cfg *api.Duo return nil, fmt.Errorf("failed to initialize stream: %v", err) } + railsClient := &http.Client{ + Transport: &roundTripper{ + next: rails.Client.Transport, + token: cfg.Headers["x-gitlab-oauth-token"], + originalReq: r, + }, + } + + mcpClient := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil) + transport := &mcp.StreamableClientTransport{ + Endpoint: rails.URL.JoinPath("api/v4/mcp").String(), + HTTPClient: railsClient, + } + mcpSession, err := mcpClient.Connect(r.Context(), transport, nil) + if err != nil { + return nil, fmt.Errorf("failed to initialize MCP client: %v", err) + } + return &runner{ - rails: rails, - token: cfg.Headers["x-gitlab-oauth-token"], - originalReq: r, + baseURL: rails.URL, + railsClient: railsClient, conn: conn, wf: wf, client: client, + mcpSession: mcpSession, }, nil } @@ -74,7 +123,7 @@ func (r *runner) Execute(ctx context.Context) error { go func() { for { - if err := r.handleWebSocketMessage(); err != nil { + if err := r.handleWebSocketMessage(ctx); err != nil { errCh <- err return } @@ -107,21 +156,47 @@ func (r *runner) Close() error { r.sendMu.Lock() defer r.sendMu.Unlock() - return errors.Join(r.wf.CloseSend(), r.client.Close()) + return errors.Join(r.wf.CloseSend(), r.client.Close(), r.mcpSession.Close()) } -func (r *runner) handleWebSocketMessage() error { +func (r *runner) handleWebSocketMessage(ctx context.Context) error { _, message, err := r.conn.ReadMessage() if err != nil { return fmt.Errorf("handleWebSocketMessage: failed to read a WS message: %v", err) } - response := &pb.ClientEvent{} - if err = unmarshaler.Unmarshal(message, response); err != nil { + clientEvent := &pb.ClientEvent{} + if err = unmarshaler.Unmarshal(message, clientEvent); err != nil { return fmt.Errorf("handleWebSocketMessage: failed to unmarshal a WS message: %v", err) } - if err = r.threadSafeSend(response); err != nil { + if startReq := clientEvent.GetStartRequest(); startReq != nil { + toolsResult, err := r.mcpSession.ListTools(ctx, &mcp.ListToolsParams{}) + if err != nil { + return fmt.Errorf("failed to list tools: %w", err) + } + + for _, tool := range toolsResult.Tools { + schemaBytes, err := json.Marshal(tool.InputSchema) + if err != nil { + return fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) + } + + mcpTool := &pb.McpTool{ + Name: tool.Name, + Description: tool.Description, + InputSchema: string(schemaBytes), + } + + startReq.McpTools = append(startReq.McpTools, mcpTool) + } + } + + fmt.Println("---------------------------------------------------------------") + fmt.Println(clientEvent) + fmt.Println("---------------------------------------------------------------") + + if err = r.threadSafeSend(clientEvent); err != nil { if err == io.EOF { // ignore EOF to let Recv() fail and return a meaningful message return nil @@ -137,9 +212,8 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error switch action.Action.(type) { case *pb.Action_RunHTTPRequest: handler := &runHTTPActionHandler{ - rails: r.rails, - token: r.token, - originalReq: r.originalReq, + baseURL: r.baseURL, + railsClient: r.railsClient, action: action, } @@ -149,7 +223,7 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error } statusCode := event.GetActionResponse().GetHttpResponse().StatusCode - log.WithContextFields(r.originalReq.Context(), log.Fields{ + log.WithContextFields(ctx, log.Fields{ "path": action.GetRunHTTPRequest().Path, "status_code": statusCode, "payload_size": proto.Size(event), @@ -159,10 +233,51 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error if err := r.threadSafeSend(event); err != nil { return fmt.Errorf("handleAgentAction: failed to send gRPC message: %v", err) } - log.WithContextFields(r.originalReq.Context(), log.Fields{ + log.WithContextFields(ctx, log.Fields{ "path": action.GetRunHTTPRequest().Path, }).Info("Successfully sent HTTP response event") + case *pb.Action_RunMCPTool: + action := action.GetRunMCPTool() + + var args map[string]any + if err := json.Unmarshal([]byte(action.Args), &args); err != nil { + return fmt.Errorf("handleAgentAction: failed to unmarshal MCP args: %v", err) + } + + params := &mcp.CallToolParams{ + Name: action.Name, + Arguments: args, + } + + res, err := r.mcpSession.CallTool(ctx, params) + if err != nil { + return fmt.Errorf("handleAgentAction: failed to call MCP tool: %v", err) + } + + if res.IsError { + return fmt.Errorf("handleAgentAction: MCP tool failed: %v", err) + } + + if len(res.Content) > 0 { + c := res.Content[0] + clientEvent := &pb.ClientEvent{ + Response: &pb.ClientEvent_ActionResponse{ + ActionResponse: &pb.ActionResponse{ + ResponseType: &pb.ActionResponse_HttpResponse{ + HttpResponse: &pb.HttpResponse{ + Body: c.(*mcp.TextContent).Text, + StatusCode: 200, + }, + }, + Response: c.(*mcp.TextContent).Text, + }, + }, + } + if err := r.threadSafeSend(clientEvent); err != nil { + return fmt.Errorf("handleAgentAction: failed to send gRPC message: %v", err) + } + } default: message, err := marshaler.Marshal(action) if err != nil {