diff --git a/ee/app/services/ai/duo_workflows/mcp_config_service.rb b/ee/app/services/ai/duo_workflows/mcp_config_service.rb
new file mode 100644
index 0000000000000000000000000000000000000000..7d14a25631cd4101042ca12f16507396718e2e7e
--- /dev/null
+++ b/ee/app/services/ai/duo_workflows/mcp_config_service.rb
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+module Ai
+  module DuoWorkflows
+    # Builds the MCP (Model Context Protocol) server configuration
+    # that Rails hands to Workhorse as `DuoWorkflow.McpServers`.
+    class McpConfigService
+      GITLAB_ENABLED_TOOLS = ['get_issue'].freeze
+
+      # @param current_user [User] actor for the `mcp_client` feature flag check
+      # @param gitlab_token [String] sent as a Bearer token to the GitLab MCP server
+      def initialize(current_user, gitlab_token)
+        @current_user = current_user
+        @gitlab_token = gitlab_token
+      end
+
+      # Returns the configuration of the supported MCP servers, keyed by server name:
+      #
+      #   {
+      #     server_name: {
+      #       URL: "<server endpoint>",
+      #       Headers: { "<header name>": "<header value>" },
+      #       Tools: ["<tool name>", ...]
+      #     }
+      #   }
+      #
+      # Workhorse treats `Tools` as an allow-list: when the key is omitted every tool the
+      # server lists is exposed, otherwise only the named tools are (so `[]` exposes none).
+      #
+      # The GitLab entry has no `URL`: for the `gitlab` server name Workhorse connects to
+      # the instance's own `/api/v4/mcp` endpoint. Other servers must set a URL, e.g.
+      #
+      #   {
+      #     gitlab: gitlab_mcp_server,
+      #     context7: { URL: "https://mcp.context7.com/mcp" }
+      #   }
+      #
+      # The list can later be extended with user-provided configurations on the
+      # namespace/project/user levels.
+      #
+      # @return [Hash, nil] nil when the `mcp_client` feature flag is disabled
+      def execute
+        return unless mcp_client_enabled?
+
+        { gitlab: gitlab_mcp_server }
+      end
+
+      # Names of the tools enabled on the GitLab MCP server (the workflows API joins them into
+      # the `x-gitlab-enabled-mcp-server-tools` header), or [] when the feature flag is disabled.
+      #
+      # @return [Array<String>]
+      def gitlab_enabled_tools
+        return [] unless mcp_client_enabled?
+
+        GITLAB_ENABLED_TOOLS
+      end
+
+      private
+
+      attr_reader :gitlab_token, :current_user
+
+      def mcp_client_enabled?
+        Feature.enabled?(:mcp_client, current_user)
+      end
+
+      def gitlab_mcp_server
+        {
+          Headers: {
+            Authorization: "Bearer #{gitlab_token}"
+          },
+          Tools: GITLAB_ENABLED_TOOLS
+        }
+      end
+    end
+  end
+end
diff --git a/ee/config/feature_flags/gitlab_com_derisk/mcp_client.yml b/ee/config/feature_flags/gitlab_com_derisk/mcp_client.yml
new file mode 100644
index 0000000000000000000000000000000000000000..f196d2f48a20dc87f6a877859ad757e9483260a6
--- /dev/null
+++ b/ee/config/feature_flags/gitlab_com_derisk/mcp_client.yml
@@ -0,0 +1,10 @@
+---
+name: mcp_client
+description: Enable Workhorse MCP client
+feature_issue_url:
https://gitlab.com/gitlab-org/gitlab/-/issues/561296 +introduced_by_url: https://gitlab.com/gitlab-org/gitlab/-/merge_requests/206445 +rollout_issue_url: https://gitlab.com/gitlab-org/gitlab/-/issues/572340 +milestone: '18.5' +group: group::agent foundations +type: gitlab_com_derisk +default_enabled: false diff --git a/ee/lib/api/ai/duo_workflows/workflows.rb b/ee/lib/api/ai/duo_workflows/workflows.rb index d4a03de76e6b1a9ffc916b475009cd9aab34d6cb..ef6c30af7142f6abf079318411fa9bd9f72fc69b 100644 --- a/ee/lib/api/ai/duo_workflows/workflows.rb +++ b/ee/lib/api/ai/duo_workflows/workflows.rb @@ -297,9 +297,12 @@ def create_workflow_params .new(current_user, :duo_agent_platform, root_namespace) .execute.payload + gitlab_token = gitlab_oauth_token.plaintext_token + mcp_config_service = ::Ai::DuoWorkflows::McpConfigService.new(current_user, gitlab_token) grpc_headers = Gitlab::DuoWorkflow::Client.cloud_connector_headers(user: current_user).merge( - 'x-gitlab-oauth-token' => gitlab_oauth_token.plaintext_token, - 'x-gitlab-unidirectional-streaming' => 'enabled' + 'x-gitlab-oauth-token' => gitlab_token, + 'x-gitlab-unidirectional-streaming' => 'enabled', + 'x-gitlab-enabled-mcp-server-tools' => mcp_config_service.gitlab_enabled_tools.join(',') ).merge(model_metadata_headers) grpc_headers['x-gitlab-project-id'] ||= params[:project_id].presence @@ -318,7 +321,8 @@ def create_workflow_params DuoWorkflow: { Headers: grpc_headers, ServiceURI: Gitlab::DuoWorkflow::Client.url_for(feature_setting: feature_setting, user: current_user), - Secure: Gitlab::DuoWorkflow::Client.secure? 
+ Secure: Gitlab::DuoWorkflow::Client.secure?, + McpServers: mcp_config_service.execute } } end diff --git a/ee/spec/requests/api/ai/duo_workflows/workflows_spec.rb b/ee/spec/requests/api/ai/duo_workflows/workflows_spec.rb index 36908e72e64975ad87346628a1d614639e497591..dff8b9b413dae632cf6facea02f69c8507306864 100644 --- a/ee/spec/requests/api/ai/duo_workflows/workflows_spec.rb +++ b/ee/spec/requests/api/ai/duo_workflows/workflows_spec.rb @@ -891,6 +891,7 @@ expect(response).to have_gitlab_http_status(:ok) expect(response.media_type).to eq(Gitlab::Workhorse::INTERNAL_API_CONTENT_TYPE) + enabled_mcp_tools = ::Ai::DuoWorkflows::McpConfigService::GITLAB_ENABLED_TOOLS expect(json_response['DuoWorkflow']['Headers']).to include( 'x-gitlab-oauth-token' => 'oauth_token', 'authorization' => 'Bearer token', @@ -898,10 +899,19 @@ 'x-gitlab-enabled-feature-flags' => anything, 'x-gitlab-instance-id' => anything, 'x-gitlab-version' => Gitlab.version_info.to_s, - 'x-gitlab-unidirectional-streaming' => 'enabled' + 'x-gitlab-unidirectional-streaming' => 'enabled', + 'x-gitlab-enabled-mcp-server-tools' => enabled_mcp_tools.join(',') ) expect(json_response['DuoWorkflow']['Secure']).to eq(true) + expect(json_response['DuoWorkflow']['McpServers']).to eq({ + "gitlab" => { + "Headers" => { + "Authorization" => "Bearer oauth_token" + }, + "Tools" => enabled_mcp_tools + } + }) end it_behaves_like 'ServiceURI has the right value', false diff --git a/ee/spec/services/ai/duo_workflows/mcp_config_service_spec.rb b/ee/spec/services/ai/duo_workflows/mcp_config_service_spec.rb new file mode 100644 index 0000000000000000000000000000000000000000..208e0f3e3ffce2e9d635f911bafd5f0e45dafa18 --- /dev/null +++ b/ee/spec/services/ai/duo_workflows/mcp_config_service_spec.rb @@ -0,0 +1,97 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe Ai::DuoWorkflows::McpConfigService, feature_category: :duo_agent_platform do + let_it_be(:user) { create(:user) } + let(:gitlab_token) { 
'test_gitlab_token_12345' } + + subject(:service) { described_class.new(user, gitlab_token) } + + describe '#execute' do + it 'returns configuration hash with gitlab server' do + result = service.execute + + expect(result).to be_a(Hash) + expect(result).to have_key(:gitlab) + end + + it 'includes gitlab MCP server configuration' do + result = service.execute + + expect(result[:gitlab]).to match( + Headers: { + Authorization: "Bearer #{gitlab_token}" + }, + Tools: described_class::GITLAB_ENABLED_TOOLS + ) + end + + it 'includes proper authorization header with token' do + result = service.execute + + expect(result[:gitlab][:Headers][:Authorization]).to eq("Bearer #{gitlab_token}") + end + + it 'includes enabled tools list' do + result = service.execute + + expect(result[:gitlab][:Tools]).to eq(['get_issue']) + end + + context 'when mcp_client feature flag is disabled' do + before do + stub_feature_flags(mcp_client: false) + end + + it 'returns nil' do + result = service.execute + + expect(result).to be_nil + end + end + + context 'with different gitlab tokens' do + it 'uses the provided token in authorization header' do + custom_token = 'custom_token_xyz' + custom_service = described_class.new(user, custom_token) + + result = custom_service.execute + + expect(result[:gitlab][:Headers][:Authorization]).to eq("Bearer #{custom_token}") + end + end + end + + describe '#gitlab_enabled_tools' do + it 'returns array of enabled tools' do + result = service.gitlab_enabled_tools + + expect(result).to eq(['get_issue']) + end + + it 'returns the GITLAB_ENABLED_TOOLS constant' do + result = service.gitlab_enabled_tools + + expect(result).to eq(described_class::GITLAB_ENABLED_TOOLS) + end + + context 'when mcp_client feature flag is disabled' do + before do + stub_feature_flags(mcp_client: false) + end + + it 'returns empty array' do + result = service.gitlab_enabled_tools + + expect(result).to eq([]) + end + end + end + + describe 'constant GITLAB_ENABLED_TOOLS' do + it 'is 
defined with expected tools' do + expect(described_class::GITLAB_ENABLED_TOOLS).to eq(['get_issue']) + end + end +end diff --git a/workhorse/.golangci.yml b/workhorse/.golangci.yml index 59918f1c13f07159aacb49cb3eca6e05b8904a98..3a5c55db164c634e503b751edf62c12e297456cd 100644 --- a/workhorse/.golangci.yml +++ b/workhorse/.golangci.yml @@ -131,6 +131,7 @@ linters: - github.com/jpillora/backoff - github.com/sony/gobreaker/v2 - gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb + - github.com/modelcontextprotocol/go-sdk - github.com/go-redsync/redsync - github.com/go-redsync/redsync/v4/redis/goredis/v9 # gRPC and Protocol Buffers diff --git a/workhorse/go.mod b/workhorse/go.mod index d738a31e56a3c9af796046dc734b9630127b6ea0..6b7c4da8b566d69284a0ae9e018f6dce29a8480b 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -24,6 +24,7 @@ require ( github.com/johannesboyne/gofakes3 v0.0.0-20240701191259-edd0227ffc37 github.com/jpillora/backoff v1.0.0 github.com/mitchellh/copystructure v1.2.0 + github.com/modelcontextprotocol/go-sdk v0.7.0 github.com/prometheus/client_golang v1.23.0 github.com/redis/go-redis/v9 v9.10.0 github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a @@ -101,6 +102,7 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect + github.com/google/jsonschema-go v0.3.0 // indirect github.com/google/pprof v0.0.0-20240711041743-f6c9dda6c6da // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/google/uuid v1.6.0 // indirect @@ -141,6 +143,7 @@ require ( github.com/tklauser/numcpus v0.3.0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/uber/jaeger-lib v2.4.1+incompatible // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect github.com/zeebo/errs v1.4.0 // indirect 
gitlab.com/gitlab-org/go/reopen v1.0.0 // indirect diff --git a/workhorse/go.sum b/workhorse/go.sum index a39bdc6d9470b733138aaf68c0d8ffadbafaf87f..92436f252bc4441a9dd8edabfa8c2b8b0eda9c18 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -361,6 +361,8 @@ github.com/google/go-replayers/httpreplay v1.2.0 h1:VM1wEyyjaoU53BwrOnaf9VhAyQQE github.com/google/go-replayers/httpreplay v1.2.0/go.mod h1:WahEFFZZ7a1P4VM1qEeHy+tME4bwyqPcwWbNlUI1Mcg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -493,6 +495,8 @@ github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zx github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g= github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28= +github.com/modelcontextprotocol/go-sdk v0.7.0 h1:XEQfn3bDx2cAdSUKty3tYEMll5dtRgBUDX88Q65fai0= +github.com/modelcontextprotocol/go-sdk v0.7.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs= github.com/montanaflynn/stats v0.7.0 h1:r3y12KyNxj/Sb/iOE46ws+3mS1+MZca1wlHQFPsY/JU= github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -618,6 +622,8 @@ github.com/uber/jaeger-lib 
v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/workhorse/internal/ai_assist/duoworkflow/actions_test.go b/workhorse/internal/ai_assist/duoworkflow/actions_test.go index fd10225181ab731e4a1bffab4231726ec14b9ffa..eeca61a9d792a34368e814bdf39133090bb2c359 100644 --- a/workhorse/internal/ai_assist/duoworkflow/actions_test.go +++ b/workhorse/internal/ai_assist/duoworkflow/actions_test.go @@ -18,6 +18,8 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" ) +const testRemoteAddr = "192.0.2.1:1234" + // errorReader is a mock reader that always returns an error type errorReader struct{} @@ -73,7 +75,7 @@ func TestRunHttpActionHandler_Execute(t *testing.T) { } originalReq := httptest.NewRequest("GET", "/ws", nil) - originalReq.RemoteAddr = "192.0.2.1:1234" + originalReq.RemoteAddr = testRemoteAddr handler := &runHTTPActionHandler{ rails: &api.API{ @@ -119,7 +121,7 @@ func TestRunHttpActionHandler_Execute(t *testing.T) { } originalReq := httptest.NewRequest("GET", "/ws", nil) - originalReq.RemoteAddr = "192.0.2.1:1234" + originalReq.RemoteAddr = testRemoteAddr originalReq.Header.Set("X-Forwarded-For", "127.0.0.1:3000") handler := &runHTTPActionHandler{ diff --git a/workhorse/internal/ai_assist/duoworkflow/mcp.go 
b/workhorse/internal/ai_assist/duoworkflow/mcp.go
new file mode 100644
index 0000000000000000000000000000000000000000..29ce42494e3f22e7249dbe05c299270b886c66cc
--- /dev/null
+++ b/workhorse/internal/ai_assist/duoworkflow/mcp.go
@@ -0,0 +1,356 @@
+package duoworkflow
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"slices"
+	"strings"
+
+	pb "gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/log"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
+
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+
+	"gitlab.com/gitlab-org/gitlab/workhorse/internal/transport"
+)
+
+// gitlabServerName is the MCP server name whose endpoint is derived from the Rails URL
+// (<rails URL>/api/v4/mcp) instead of being read from the server config.
+const gitlabServerName = "gitlab"
+
+// serverSession is an open MCP client session plus the config it was created from.
+type serverSession struct {
+	name    string
+	cfg     api.McpServerConfig
+	session *mcp.ClientSession
+}
+
+// toolSession maps an exposed, server-prefixed tool to its original name on the MCP server
+// and to the session used to call it.
+type toolSession struct {
+	originalName string
+	session      *mcp.ClientSession
+}
+
+// mcpManager abstracts the MCP tool operations; *manager implements it.
+type mcpManager interface {
+	HasTool(string) bool
+	CallTool(context.Context, *pb.Action) (*pb.ClientEvent, error)
+	Tools() []*pb.McpTool
+	Close() error
+}
+
+var _ mcpManager = (*manager)(nil)
+
+// manager holds the MCP sessions opened by newMcpManager and the tools they expose.
+// HasTool, Tools, CallTool and Close are safe to call on a nil *manager.
+type manager struct {
+	tools              []*pb.McpTool
+	toolSessionsByName map[string]*toolSession
+	serverSessions     []*serverSession
+}
+
+// roundTripper decorates the HTTP requests sent to one MCP server; see RoundTrip.
+type roundTripper struct {
+	next        http.RoundTripper
+	headers     map[string]string
+	originalReq *http.Request
+}
+
+// limitedReadCloser reads through an io.LimitedReader but closes the underlying body.
+type limitedReadCloser struct {
+	io.LimitedReader
+	closer io.Closer
+}
+
+func (lrc *limitedReadCloser) Close() error {
+	return lrc.closer.Close()
+}
+
+// RoundTrip sets the configured headers, a Workhorse User-Agent and an X-Forwarded-For
+// chain for the original client on the outgoing request, then caps the response body at
+// ActionResponseBodyLimit bytes (a longer body is truncated, not rejected).
+//
+// The headers are set on a clone because the http.RoundTripper contract forbids
+// modifying the caller's request.
+func (t *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
+	req := r.Clone(r.Context())
+
+	for name, value := range t.headers {
+		req.Header.Set(name, value)
+	}
+	req.Header.Set("User-Agent", "GitLab-Workhorse-Mcp-Client")
+
+	if forwardedFor, ok := t.forwardedFor(); ok {
+		req.Header.Set("X-Forwarded-For", forwardedFor)
+	}
+
+	resp, err := t.next.RoundTrip(req)
+	if resp != nil && resp.Body != nil {
+		resp.Body = &limitedReadCloser{
+			LimitedReader: io.LimitedReader{
+				R: resp.Body,
+				N: ActionResponseBodyLimit,
+			},
+			closer: resp.Body,
+		}
+	}
+
+	return resp, err
+}
+
+// forwardedFor returns the X-Forwarded-For value for outgoing requests: the values already
+// present on the original request, if any, followed by the original client's IP. ok is
+// false when there is no original request or its RemoteAddr is not in host:port form.
+func (t *roundTripper) forwardedFor() (string, bool) {
+	if t.originalReq == nil {
+		return "", false
+	}
+
+	clientIP, _, err := net.SplitHostPort(t.originalReq.RemoteAddr)
+	if err != nil {
+		return "", false
+	}
+
+	if prior, ok := t.originalReq.Header["X-Forwarded-For"]; ok {
+		return strings.Join(prior, ", ") + ", " + clientIP, true
+	}
+
+	return clientIP, true
+}
+
+// newMcpManager connects to every configured MCP server and collects the tools they expose.
+//
+// Initialization is best effort: servers that fail to connect or to list their tools are
+// skipped and their errors are joined into the returned error, while a manager for the
+// remaining servers is still returned. Only an empty servers map yields a nil manager.
+func newMcpManager(rails *api.API, r *http.Request, servers map[string]api.McpServerConfig) (*manager, error) {
+	if len(servers) == 0 {
+		return nil, errors.New("the list of server configs is empty")
+	}
+
+	var errs []error
+	var sessions []*serverSession
+
+	for serverName, serverCfg := range servers {
+		session, err := buildSession(rails, r, serverName, serverCfg)
+		if err != nil {
+			errs = append(errs, fmt.Errorf("failed to initialize MCP session %s: %v", serverName, err))
+			continue
+		}
+
+		sessions = append(sessions, session)
+	}
+
+	mgr := &manager{
+		toolSessionsByName: make(map[string]*toolSession),
+		serverSessions:     sessions,
+	}
+
+	if err := mgr.buildTools(r.Context()); err != nil {
+		errs = append(errs, err)
+	}
+
+	return mgr, errors.Join(errs...)
+}
+
+// buildSession opens a streamable-HTTP MCP client session to a single server.
+//
+// The "gitlab" server is reached at <rails URL>/api/v4/mcp through the Rails client
+// transport (serverCfg.URL is ignored for it); any other server is reached at
+// serverCfg.URL through transport.NewRestrictedTransport().
+func buildSession(rails *api.API, r *http.Request, serverName string, serverCfg api.McpServerConfig) (*serverSession, error) {
+	var endpoint string
+	var nextTransport http.RoundTripper
+
+	if serverName == gitlabServerName {
+		if rails == nil {
+			return nil, errors.New("rails API is not configured")
+		}
+		endpoint = rails.URL.JoinPath("api/v4/mcp").String()
+		nextTransport = rails.Client.Transport
+	} else {
+		endpoint = serverCfg.URL
+		nextTransport = transport.NewRestrictedTransport()
+	}
+
+	clientTransport := &mcp.StreamableClientTransport{
+		Endpoint: endpoint,
+		HTTPClient: &http.Client{
+			Transport: &roundTripper{
+				next:        nextTransport,
+				headers:     serverCfg.Headers,
+				originalReq: r,
+			},
+		},
+	}
+
+	client := mcp.NewClient(&mcp.Implementation{Name: "mcp-client", Version: "v1.0.0"}, nil)
+
+	session, err := client.Connect(r.Context(), clientTransport, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	return &serverSession{name: serverName, cfg: serverCfg, session: session}, nil
+}
+
+// buildTools lists the tools of every connected server and registers the permitted ones
+// under the name "<server>_<tool>".
+//
+// A nil cfg.Tools permits every listed tool; a non-nil list permits only the tools it
+// names, so an empty list permits none. Failures are collected and returned joined while
+// the other servers and tools are still registered.
+func (m *manager) buildTools(ctx context.Context) error {
+	var errs []error
+
+	for _, s := range m.serverSessions {
+		toolsResult, err := s.session.ListTools(ctx, &mcp.ListToolsParams{})
+		if err != nil {
+			errs = append(errs, fmt.Errorf("failed to list tools %s: %v", s.name, err))
+			continue
+		}
+
+		for _, tool := range toolsResult.Tools {
+			// Filter first so that tools which are not exposed are never marshalled and
+			// cannot contribute spurious errors.
+			if s.cfg.Tools != nil && !slices.Contains(*s.cfg.Tools, tool.Name) {
+				continue
+			}
+
+			schemaBytes, err := json.Marshal(tool.InputSchema)
+			if err != nil {
+				errs = append(errs, fmt.Errorf("failed to marshal input schema, server: %s, tool: %s, error: %v", s.name, tool.Name, err))
+				continue
+			}
+
+			prefixedName := s.name + "_" + tool.Name
+
+			m.tools = append(m.tools, &pb.McpTool{
+				Name:        prefixedName,
+				Description: tool.Description,
+				InputSchema: string(schemaBytes),
+			})
+			m.toolSessionsByName[prefixedName] = &toolSession{
+				originalName: tool.Name,
+				session:      s.session,
+			}
+		}
+	}
+
+	return errors.Join(errs...)
+}
+
+// HasTool reports whether a tool with the given server-prefixed name is registered.
+// It returns false on a nil manager.
+func (m *manager) HasTool(name string) bool {
+	if m == nil {
+		return false
+	}
+
+	_, ok := m.toolSessionsByName[name]
+	return ok
+}
+
+// Tools returns the registered tools under their server-prefixed names (nil on a nil manager).
+func (m *manager) Tools() []*pb.McpTool {
+	if m == nil {
+		return nil
+	}
+
+	return m.tools
+}
+
+// CallTool runs the action's RunMCPTool on the server that exposes it and wraps the outcome
+// in a PlainTextResponse ActionResponse carrying the action's request ID.
+//
+// Only the first content item of the result is used and only text content is forwarded;
+// when the server marks the result as an error, the text goes into the Error field. An error
+// is returned when the action carries no RunMCPTool, the tool is unknown (always the case on
+// a nil manager), the arguments do not decode into a JSON object, or the call fails.
+func (m *manager) CallTool(ctx context.Context, action *pb.Action) (*pb.ClientEvent, error) {
+	mcpTool := action.GetRunMCPTool()
+	if mcpTool == nil {
+		return nil, errors.New("CallTool: action does not contain an MCP tool call")
+	}
+
+	log.WithContextFields(ctx, log.Fields{
+		"name":       mcpTool.Name,
+		"args_size":  len(mcpTool.Args),
+		"request_id": action.RequestID,
+	}).Info("Calling an MCP tool")
+
+	// HasTool is nil-safe, so a nil manager reports the tool as unknown instead of panicking.
+	if !m.HasTool(mcpTool.Name) {
+		return nil, fmt.Errorf("CallTool: unknown tool: %v", mcpTool.Name)
+	}
+	target := m.toolSessionsByName[mcpTool.Name]
+
+	var arguments map[string]any
+	if err := json.Unmarshal([]byte(mcpTool.Args), &arguments); err != nil {
+		return nil, fmt.Errorf("CallTool: failed to unmarshal MCP args: %v", err)
+	}
+
+	res, err := target.session.CallTool(ctx, &mcp.CallToolParams{
+		Name:      target.originalName,
+		Arguments: arguments,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("CallTool: failed to call MCP tool: %v", err)
+	}
+
+	content := "MCP tool response is empty"
+	if len(res.Content) > 0 {
+		if textContent, ok := res.Content[0].(*mcp.TextContent); ok {
+			content = textContent.Text
+		} else {
+			log.WithContextFields(ctx, log.Fields{
+				"name":         mcpTool.Name,
+				"request_id":   action.RequestID,
+				"content_type": fmt.Sprintf("%T", res.Content[0]),
+			}).Info("MCP tool response content type not supported")
+			content = "MCP tool response content type not supported"
+		}
+	}
+
+	response := &pb.PlainTextResponse{}
+	if res.IsError {
+		response.Error = content
+	} else {
+		response.Response = content
+	}
+
+	return &pb.ClientEvent{
+		Response: &pb.ClientEvent_ActionResponse{
+			ActionResponse: &pb.ActionResponse{
+				RequestID: action.RequestID,
+				ResponseType: &pb.ActionResponse_PlainTextResponse{
+					PlainTextResponse: response,
+				},
+			},
+		},
+	}, nil
+}
+
+// Close closes every server session and returns the joined close errors; it is a no-op on a
+// nil manager.
+func (m *manager) Close() error {
+	if m == nil {
+		return nil
+	}
+
+	var errs []error
+	for _, s := range m.serverSessions {
+		errs = append(errs, s.session.Close())
+	}
+
+	return errors.Join(errs...)
+}
diff --git a/workhorse/internal/ai_assist/duoworkflow/mcp_test.go b/workhorse/internal/ai_assist/duoworkflow/mcp_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c1833854c9f2b0cf80cc60fbb45d7efc94e51df2 --- /dev/null +++ b/workhorse/internal/ai_assist/duoworkflow/mcp_test.go @@ -0,0 +1,865 @@ +package duoworkflow + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + pb "gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract" + + "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" +) + +func TestRoundTripper_RoundTrip(t *testing.T) { + t.Run("adds custom headers", func(t *testing.T) { + var capturedRequest *http.Request + var transportFunc = func(req *http.Request) (*http.Response, error) { + capturedRequest = req + return &http.Response{ + StatusCode: 200, + Body: http.NoBody, + Header: make(http.Header), + }, nil + } + + rt := &roundTripper{ + next: &mockTransportFunc{fn: transportFunc}, + headers: map[string]string{ + "Authorization": "Bearer test-token", + "X-Custom-Header": "custom-value", + }, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + r, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + + require.NotNil(t, capturedRequest) + assert.Equal(t, "Bearer test-token", capturedRequest.Header.Get("Authorization")) + assert.Equal(t, "custom-value",
capturedRequest.Header.Get("X-Custom-Header")) + assert.Equal(t, "GitLab-Workhorse-Mcp-Client", capturedRequest.Header.Get("User-Agent")) + }) + + t.Run("sets X-Forwarded-For from original request", func(t *testing.T) { + var capturedRequest *http.Request + var transportFunc = func(req *http.Request) (*http.Response, error) { + capturedRequest = req + return &http.Response{ + StatusCode: 200, + Body: http.NoBody, + Header: make(http.Header), + }, nil + } + + originalReq := httptest.NewRequest("GET", "/test", nil) + originalReq.RemoteAddr = "192.0.2.1:1234" + + rt := &roundTripper{ + next: &mockTransportFunc{fn: transportFunc}, + headers: map[string]string{}, + originalReq: originalReq, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + r, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + + require.NotNil(t, capturedRequest) + assert.Equal(t, "192.0.2.1", capturedRequest.Header.Get("X-Forwarded-For")) + }) + + t.Run("appends to existing X-Forwarded-For", func(t *testing.T) { + var capturedRequest *http.Request + var transportFunc = func(req *http.Request) (*http.Response, error) { + capturedRequest = req + return &http.Response{ + StatusCode: 200, + Body: http.NoBody, + Header: make(http.Header), + }, nil + } + + originalReq := httptest.NewRequest("GET", "/test", nil) + originalReq.RemoteAddr = "192.0.2.1:1234" + originalReq.Header.Set("X-Forwarded-For", "10.0.0.1") + + rt := &roundTripper{ + next: &mockTransportFunc{fn: transportFunc}, + headers: map[string]string{}, + originalReq: originalReq, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + r, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + + require.NotNil(t, capturedRequest) + assert.Equal(t, "10.0.0.1, 192.0.2.1", capturedRequest.Header.Get("X-Forwarded-For")) + }) + + t.Run("handles invalid RemoteAddr gracefully", func(t *testing.T) { + var capturedRequest *http.Request + var 
transportFunc = func(req *http.Request) (*http.Response, error) { + capturedRequest = req + return &http.Response{ + StatusCode: 200, + Body: http.NoBody, + Header: make(http.Header), + }, nil + } + + originalReq := httptest.NewRequest("GET", "/test", nil) + originalReq.RemoteAddr = "invalid-address" + + rt := &roundTripper{ + next: &mockTransportFunc{fn: transportFunc}, + headers: map[string]string{}, + originalReq: originalReq, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + r, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NoError(t, r.Body.Close()) + + require.NotNil(t, capturedRequest) + assert.Empty(t, capturedRequest.Header.Get("X-Forwarded-For")) + }) + + t.Run("wraps response body with limited reader", func(t *testing.T) { + responseBody := "This is a test response body that should be limited" + var transportFunc = func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(responseBody)), + Header: make(http.Header), + }, nil + } + + rt := &roundTripper{ + next: &mockTransportFunc{fn: transportFunc}, + headers: map[string]string{}, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + _, ok := resp.Body.(*limitedReadCloser) + require.True(t, ok, "Response body should be wrapped with limitedReadCloser") + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, responseBody, string(body)) + }) + + t.Run("limits response body to ActionResponseBodyLimit", func(t *testing.T) { + largeBody := strings.Repeat("x", ActionResponseBodyLimit+1000) + var transportFunc = func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(largeBody)), + Header: make(http.Header), + }, nil + } + + rt := &roundTripper{ + next: &mockTransportFunc{fn: 
transportFunc}, + headers: map[string]string{}, + } + + req := httptest.NewRequest("GET", "http://example.com", nil) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Len(t, body, ActionResponseBodyLimit) + assert.Equal(t, strings.Repeat("x", ActionResponseBodyLimit), string(body)) + }) +} + +func TestNewMcpManager(t *testing.T) { + t.Run("successful initialization with single server", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{ + {Name: "test_tool", Description: "A test tool"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + Tools: &([]string{"test_tool"}), + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.NoError(t, err) + require.NotNil(t, mgr) + assert.Len(t, mgr.tools, 1) + assert.Equal(t, "test-server_test_tool", mgr.tools[0].Name) + assert.Len(t, mgr.toolSessionsByName, 1) + assert.Len(t, mgr.serverSessions, 1) + }) + + t.Run("successful initialization with multiple servers", func(t *testing.T) { + mcpServer1 := setupMockMcpServer(t, "", []mcpTool{ + {Name: "tool1", Description: "Tool 1"}, + }) + mcpServer2 := setupMockMcpServer(t, "", []mcpTool{ + {Name: "tool2", Description: "Tool 2"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "server1": { + URL: mcpServer1.URL, + Headers: map[string]string{}, + Tools: &([]string{"tool1"}), + }, + "server2": { + URL: mcpServer2.URL, + Headers: map[string]string{}, + Tools: &([]string{"tool2"}), + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.NoError(t, err) + require.NotNil(t, mgr) + assert.Len(t, mgr.tools, 2) + assert.Len(t, mgr.toolSessionsByName, 2) + assert.Len(t, mgr.serverSessions, 2) + }) + + t.Run("filters tools based on 
config", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{ + {Name: "allowed_tool", Description: "Allowed tool"}, + {Name: "blocked_tool", Description: "Blocked tool"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + Tools: &([]string{"allowed_tool"}), + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.NoError(t, err) + require.NotNil(t, mgr) + assert.Len(t, mgr.tools, 1) + assert.Equal(t, "test-server_allowed_tool", mgr.tools[0].Name) + }) + + t.Run("includes all tools when Tools config is empty", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{ + {Name: "tool1", Description: "Tool 1"}, + {Name: "tool2", Description: "Tool 2"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.NoError(t, err) + require.NotNil(t, mgr) + assert.Len(t, mgr.tools, 2) + }) + + t.Run("filters all tools when Tools config is empty", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "gitlab", []mcpTool{ + {Name: "tool1", Description: "Tool 1"}, + {Name: "tool2", Description: "Tool 2"}, + }) + + apiURL, err := url.Parse(mcpServer.URL) + require.NoError(t, err) + + rails := api.NewAPI(apiURL, "test-version", http.DefaultTransport) + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "gitlab": { + URL: mcpServer.URL, + Headers: map[string]string{}, + Tools: &([]string{}), + }, + } + + mgr, err := newMcpManager(rails, req, servers) + + require.NoError(t, err) + require.NotNil(t, mgr) + assert.Empty(t, mgr.tools) + }) + + t.Run("returns error when servers map is empty", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + + servers := 
map[string]api.McpServerConfig{} + + mgr, err := newMcpManager(nil, req, servers) + + require.Error(t, err) + require.Nil(t, mgr) + assert.Contains(t, err.Error(), "the list of server configs is empty") + }) + + t.Run("continues with partial success when one server fails", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{ + {Name: "tool1", Description: "Tool 1"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "good-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + Tools: &([]string{"tool1"}), + }, + "bad-server": { + URL: "http://localhost:1", // Use a port that's likely to be refused + Headers: map[string]string{}, + Tools: &([]string{"tool2"}), + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.Error(t, err) + require.NotNil(t, mgr) + assert.Len(t, mgr.tools, 1) + assert.Contains(t, err.Error(), "failed to initialize MCP session bad-server") + }) +} + +func TestManager_HasTool(t *testing.T) { + t.Run("returns true for existing tool", func(t *testing.T) { + mgr := &manager{ + toolSessionsByName: map[string]*toolSession{ + "test_tool": {}, + }, + } + + assert.True(t, mgr.HasTool("test_tool")) + }) + + t.Run("returns false for non-existing tool", func(t *testing.T) { + mgr := &manager{ + toolSessionsByName: map[string]*toolSession{ + "test_tool": {}, + }, + } + + assert.False(t, mgr.HasTool("other_tool")) + }) + + t.Run("returns false when manager is nil", func(t *testing.T) { + var mgr *manager + assert.False(t, mgr.HasTool("test_tool")) + }) +} + +func TestManager_Tools(t *testing.T) { + t.Run("returns tools list", func(t *testing.T) { + tools := []*pb.McpTool{ + {Name: "tool1", Description: "Tool 1"}, + {Name: "tool2", Description: "Tool 2"}, + } + + mgr := &manager{ + tools: tools, + } + + result := mgr.Tools() + assert.Equal(t, tools, result) + assert.Len(t, result, 2) + }) + + t.Run("returns nil when manager is nil", func(t *testing.T) { + var 
mgr *manager + assert.Nil(t, mgr.Tools()) + }) + + t.Run("returns empty list when no tools", func(t *testing.T) { + mgr := &manager{ + tools: []*pb.McpTool{}, + } + + result := mgr.Tools() + assert.NotNil(t, result) + assert.Empty(t, result) + }) +} + +func TestManager_CallTool(t *testing.T) { + t.Run("successfully calls tool with valid arguments", func(t *testing.T) { + mcpServer := setupMockMcpServerWithCallHandler(t, "", []mcpTool{ + {Name: "test_tool", Description: "A test tool"}, + }, func(name string, args map[string]any) (string, bool, error) { + assert.Equal(t, "test_tool", name) + assert.Equal(t, "123", args["issue_id"]) + return `{"id": 123, "title": "Test Issue"}`, false, nil + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + }, + } + + mgr, err := newMcpManager(nil, req, servers) + require.NoError(t, err) + + action := &pb.Action{ + RequestID: "req-mcp-123", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "test-server_test_tool", + Args: `{"issue_id": "123"}`, + }, + }, + } + + result, err := mgr.CallTool(context.Background(), action) + + require.NoError(t, err) + require.NotNil(t, result) + plainTextResp := result.GetActionResponse().GetPlainTextResponse() + assert.Contains(t, plainTextResp.Response, "Test Issue") + assert.Empty(t, plainTextResp.Error) + }) + + t.Run("returns error for unknown tool", func(t *testing.T) { + mgr := &manager{ + toolSessionsByName: map[string]*toolSession{}, + } + + action := &pb.Action{ + RequestID: "req-mcp-123", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "unknown_tool", + Args: `{}`, + }, + }, + } + + result, err := mgr.CallTool(context.Background(), action) + + require.Error(t, err) + require.Nil(t, result) + assert.Contains(t, err.Error(), "unknown tool") + }) + + t.Run("returns error for invalid JSON arguments", func(t *testing.T) { + mgr := 
&manager{ + toolSessionsByName: map[string]*toolSession{ + "test_tool": {}, + }, + } + + action := &pb.Action{ + RequestID: "req-mcp-123", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "test_tool", + Args: `invalid json`, + }, + }, + } + + result, err := mgr.CallTool(context.Background(), action) + + require.Error(t, err) + require.Nil(t, result) + assert.Contains(t, err.Error(), "failed to unmarshal MCP args") + }) + + t.Run("handles MCP error response", func(t *testing.T) { + mcpServer := setupMockMcpServerWithCallHandler(t, "", []mcpTool{ + {Name: "test_tool", Description: "A test tool"}, + }, func(_ string, _ map[string]any) (string, bool, error) { + return "Tool execution failed", true, nil + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + }, + } + + mgr, err := newMcpManager(nil, req, servers) + require.NoError(t, err) + + action := &pb.Action{ + RequestID: "req-mcp-123", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "test-server_test_tool", + Args: `{}`, + }, + }, + } + + result, err := mgr.CallTool(context.Background(), action) + + require.NoError(t, err) + require.NotNil(t, result) + plainTextResp := result.GetActionResponse().GetPlainTextResponse() + assert.Equal(t, "Tool execution failed", plainTextResp.Error) + assert.Empty(t, plainTextResp.Response) + }) + + t.Run("handles empty MCP response", func(t *testing.T) { + mcpServer := setupMockMcpServerWithCallHandler(t, "", []mcpTool{ + {Name: "test_tool", Description: "A test tool"}, + }, func(_ string, _ map[string]any) (string, bool, error) { + return "", false, nil + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + }, + } + + mgr, err := newMcpManager(nil, req, servers) + require.NoError(t, err) + + 
action := &pb.Action{ + RequestID: "req-mcp-123", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "test-server_test_tool", + Args: `{}`, + }, + }, + } + + result, err := mgr.CallTool(context.Background(), action) + + require.NoError(t, err) + require.NotNil(t, result) + plainTextResp := result.GetActionResponse().GetPlainTextResponse() + assert.Equal(t, "MCP tool response is empty", plainTextResp.Response) + }) +} + +func TestManager_Close(t *testing.T) { + t.Run("closes all sessions successfully", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{ + {Name: "test_tool", Description: "A test tool"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + }, + } + + mgr, err := newMcpManager(nil, req, servers) + require.NoError(t, err) + + err = mgr.Close() + assert.NoError(t, err) + }) + + t.Run("returns nil when manager is nil", func(t *testing.T) { + var mgr *manager + err := mgr.Close() + assert.NoError(t, err) + }) +} + +// Helper types and functions for testing + +type mockTransportFunc struct { + fn func(*http.Request) (*http.Response, error) +} + +func (m *mockTransportFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return m.fn(req) +} + +type mcpTool struct { + Name string + Description string +} + +func setupMockMcpServer(t *testing.T, name string, tools []mcpTool) *httptest.Server { + return setupMockMcpServerWithCallHandler(t, name, tools, nil) +} + +func setupMockMcpServerWithCallHandler(t *testing.T, name string, tools []mcpTool, callHandler func(string, map[string]any) (string, bool, error)) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleMcpRequest(t, w, r, name, tools, callHandler) + })) + t.Cleanup(func() { + server.Close() + }) + return server +} + +func handleMcpRequest(t *testing.T, w 
http.ResponseWriter, r *http.Request, name string, tools []mcpTool, callHandler func(string, map[string]any) (string, bool, error)) { + w.Header().Set("Content-Type", "application/json") + + if name == "gitlab" { + assert.Contains(t, r.URL.Path, "/api/v4/mcp") + } + + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + var request map[string]any + err := json.NewDecoder(r.Body).Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + method, ok := request["method"].(string) + if !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + switch method { + case "initialize": + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "serverInfo": map[string]any{ + "name": "test-server", + "version": "1.0.0", + }, + }, + } + json.NewEncoder(w).Encode(response) + + case "notifications/initialized": + w.WriteHeader(http.StatusOK) + + case "tools/list": + var toolsList []map[string]any + for _, tool := range tools { + toolsList = append(toolsList, map[string]any{ + "name": tool.Name, + "description": tool.Description, + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{}, + }, + }) + } + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]any{ + "tools": toolsList, + }, + } + json.NewEncoder(w).Encode(response) + + case "tools/call": + if callHandler == nil { + w.WriteHeader(http.StatusNotImplemented) + return + } + params, ok := request["params"].(map[string]any) + if !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + name, _ := params["name"].(string) + arguments, _ := params["arguments"].(map[string]any) + content, isError, err := callHandler(name, arguments) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": 
request["id"], + "error": map[string]any{ + "code": -32603, + "message": err.Error(), + }, + }) + return + } + var resultContent []map[string]any + if content != "" { + resultContent = append(resultContent, map[string]any{ + "type": "text", + "text": content, + }) + } + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]any{ + "content": resultContent, + "isError": isError, + }, + } + json.NewEncoder(w).Encode(response) + + default: + w.WriteHeader(http.StatusNotImplemented) + } +} + +func TestBuildSession(t *testing.T) { + t.Run("builds session for a server", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{}) + + req := httptest.NewRequest("GET", "/test", nil) + + serverCfg := api.McpServerConfig{ + URL: mcpServer.URL, + Headers: map[string]string{}, + } + + session, err := buildSession(nil, req, "server-name", serverCfg) + + require.NoError(t, err) + require.NotNil(t, session) + assert.Equal(t, "server-name", session.name) + assert.NotNil(t, session.session) + }) + + t.Run("returns error on connection failure", func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + + serverCfg := api.McpServerConfig{ + URL: "http://localhost:1", // Use a port that's likely to be refused + Headers: map[string]string{}, + } + + session, err := buildSession(nil, req, "test-server", serverCfg) + + require.Error(t, err) + require.Nil(t, session) + }) +} + +func TestManager_buildTools(t *testing.T) { + t.Run("successfully builds tools from server", func(t *testing.T) { + mcpServer := setupMockMcpServer(t, "", []mcpTool{ + {Name: "tool1", Description: "Tool 1"}, + {Name: "tool2", Description: "Tool 2"}, + }) + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: mcpServer.URL, + Headers: map[string]string{}, + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.NoError(t, err) + require.NotNil(t, mgr) + assert.Len(t, 
mgr.tools, 2) + assert.Equal(t, "test-server_tool1", mgr.tools[0].Name) + assert.Equal(t, "test-server_tool2", mgr.tools[1].Name) + }) + + t.Run("handles error from ListTools", func(t *testing.T) { + // Create a server that returns an error for tools/list + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var request map[string]any + json.NewDecoder(r.Body).Decode(&request) + + switch request["method"].(string) { + case "initialize": + response := map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{}, + "serverInfo": map[string]any{ + "name": "test-server", + "version": "1.0.0", + }, + }, + } + json.NewEncoder(w).Encode(response) + case "tools/list": + w.WriteHeader(http.StatusInternalServerError) + } + })) + defer server.Close() + + req := httptest.NewRequest("GET", "/test", nil) + + servers := map[string]api.McpServerConfig{ + "test-server": { + URL: server.URL, + Headers: map[string]string{}, + Tools: &([]string{}), + }, + } + + mgr, err := newMcpManager(nil, req, servers) + + require.Error(t, err) + require.NotNil(t, mgr) + assert.Contains(t, err.Error(), "failed to list tools") + }) +} diff --git a/workhorse/internal/ai_assist/duoworkflow/runner.go b/workhorse/internal/ai_assist/duoworkflow/runner.go index bb4ba461941d9992b7a2dfaed590400bc4868fad..d9508172610d2178277ebc0116772cfc93ccbe5c 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner.go @@ -55,6 +55,7 @@ type runner struct { wf workflowStream client *Client sendMu sync.Mutex + mcpManager mcpManager } func newRunner(conn websocketConn, rails *api.API, r *http.Request, cfg *api.DuoWorkflow) (*runner, error) { @@ -70,6 +71,12 @@ func newRunner(conn websocketConn, rails *api.API, r *http.Request, cfg *api.Duo return nil, fmt.Errorf("failed to 
initialize stream: %v", err) } + mcpManager, err := newMcpManager(rails, r, cfg.McpServers) + if err != nil { + // Log the error while the feature is in development + log.WithRequest(r).WithError(err).Info("failed to initialize MCP server(s)") + } + return &runner{ rails: rails, token: cfg.Headers["x-gitlab-oauth-token"], @@ -77,6 +84,7 @@ func newRunner(conn websocketConn, rails *api.API, r *http.Request, cfg *api.Duo conn: conn, wf: wf, client: client, + mcpManager: mcpManager, }, nil } @@ -154,7 +162,7 @@ func (r *runner) Close() error { r.sendMu.Lock() defer r.sendMu.Unlock() - return errors.Join(r.wf.CloseSend(), r.client.Close(), r.closeWebSocketConnection()) + return errors.Join(r.wf.CloseSend(), r.client.Close(), r.closeWebSocketConnection(), r.mcpManager.Close()) } func (r *runner) closeWebSocketConnection() error { @@ -189,6 +197,10 @@ func (r *runner) handleWebSocketMessage(message []byte) error { return fmt.Errorf("handleWebSocketMessage: failed to unmarshal a WS message: %v", err) } + if startReq := response.GetStartRequest(); startReq != nil { + startReq.McpTools = append(startReq.McpTools, r.mcpManager.Tools()...) 
+ } + log.WithContextFields(r.originalReq.Context(), log.Fields{ "payload_size": proto.Size(response), "event_type": fmt.Sprintf("%T", response.Response), @@ -221,8 +233,8 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error if err != nil { return fmt.Errorf("handleAgentAction: failed to perform API call: %v", err) } - statusCode := event.GetActionResponse().GetHttpResponse().StatusCode + log.WithContextFields(r.originalReq.Context(), log.Fields{ "path": action.GetRunHTTPRequest().Path, "method": action.GetRunHTTPRequest().Method, @@ -232,13 +244,39 @@ func (r *runner) handleAgentAction(ctx context.Context, action *pb.Action) error "action_response_type": fmt.Sprintf("%T", event.GetActionResponse().GetResponseType()), "request_id": action.GetRequestID(), }).Info("Sending HTTP response event") + if err := r.threadSafeSend(event); err != nil { return fmt.Errorf("handleAgentAction: failed to send gRPC message: %v", err) } + log.WithContextFields(r.originalReq.Context(), log.Fields{ "path": action.GetRunHTTPRequest().Path, }).Info("Successfully sent HTTP response event") + case *pb.Action_RunMCPTool: + mcpTool := action.GetRunMCPTool() + // If a tool is not recongnized, propagate the message to the client + // It's possible when a user has local MCP servers configured in IDE + if !r.mcpManager.HasTool(mcpTool.Name) { + return r.sendActionToWs(action) + } + event, err := r.mcpManager.CallTool(ctx, action) + if err != nil { + return fmt.Errorf("handleAgentAction: failed to call MCP tool: %v", err) + } + + log.WithContextFields(ctx, log.Fields{ + "request_id": action.GetRequestID(), + "name": mcpTool.Name, + "args_size": len(mcpTool.Args), + "payload_size": proto.Size(event), + "event_type": fmt.Sprintf("%T", event.Response), + "action_response_type": fmt.Sprintf("%T", event.GetActionResponse().GetResponseType()), + }).Info("Sending MCP tool response") + + if err := r.threadSafeSend(event); err != nil { + return 
fmt.Errorf("handleAgentAction: failed to send gRPC message: %v", err) + } default: return r.sendActionToWs(action) } diff --git a/workhorse/internal/ai_assist/duoworkflow/runner_test.go b/workhorse/internal/ai_assist/duoworkflow/runner_test.go index a51a0defff37f2cfed1daad82d9d81d6ec6f9439..01bf75c5fbcd022f8cd4c6bd4e736906d5d27b4b 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner_test.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner_test.go @@ -116,6 +116,56 @@ func (m *mockWorkflowStream) CloseSend() error { return nil } +type mockMcpManager struct { + tools []*pb.McpTool + hasToolResult bool + callToolResult *pb.ClientEvent + callToolError error + closeError error + callToolInvocations []struct { + name string + args string + } +} + +func (m *mockMcpManager) HasTool(_ string) bool { + if m == nil { + return false + } + + return m.hasToolResult +} + +func (m *mockMcpManager) Tools() []*pb.McpTool { + if m == nil { + return nil + } + + return m.tools +} + +func (m *mockMcpManager) CallTool(_ context.Context, action *pb.Action) (*pb.ClientEvent, error) { + mcpTool := action.GetRunMCPTool() + + m.callToolInvocations = append(m.callToolInvocations, struct { + name string + args string + }{name: mcpTool.Name, args: mcpTool.Args}) + + if m.callToolError != nil { + return nil, m.callToolError + } + return m.callToolResult, nil +} + +func (m *mockMcpManager) Close() error { + if m == nil { + return nil + } + + return m.closeError +} + func Test_newRunner(t *testing.T) { server := setupTestServer(t) mockConn := &mockWebSocketConn{} @@ -337,7 +387,9 @@ func TestRunner_handleWebSocketMessage(t *testing.T) { name string message []byte sendError error + mcpManager *mockMcpManager expectedErrMsg string + expectMcpTools bool }{ { name: "invalid json", @@ -361,6 +413,24 @@ func TestRunner_handleWebSocketMessage(t *testing.T) { message: []byte(`{"type": "test"}`), expectedErrMsg: "", }, + { + name: "start request with mcp tools", + message: 
[]byte(`{"startRequest": {"goal": "test goal", "mcpTools": [{"name": "get_issue"}]}}`), + mcpManager: &mockMcpManager{ + tools: []*pb.McpTool{ + {Name: "test_tool", Description: "A test tool"}, + }, + }, + expectMcpTools: true, + expectedErrMsg: "", + }, + { + name: "start request without mcp manager", + message: []byte(`{"startRequest": {"goal": "test goal"}}`), + mcpManager: nil, + expectMcpTools: false, + expectedErrMsg: "", + }, } for _, tt := range tests { @@ -379,6 +449,7 @@ func TestRunner_handleWebSocketMessage(t *testing.T) { originalReq: &http.Request{}, conn: &mockWebSocketConn{}, wf: mockWf, + mcpManager: tt.mcpManager, } err := r.handleWebSocketMessage(tt.message) @@ -388,6 +459,16 @@ func TestRunner_handleWebSocketMessage(t *testing.T) { require.Contains(t, err.Error(), tt.expectedErrMsg) } else { require.NoError(t, err) + + if tt.expectMcpTools { + require.Len(t, mockWf.sendEvents, 1) + startReq := mockWf.sendEvents[0].GetStartRequest() + require.NotNil(t, startReq) + require.Len(t, startReq.McpTools, 2) + assert.Equal(t, "get_issue", startReq.McpTools[0].Name) + assert.Equal(t, "test_tool", startReq.McpTools[1].Name) + assert.Equal(t, "A test tool", startReq.McpTools[1].Description) + } } }) } @@ -399,9 +480,11 @@ func TestRunner_handleAgentAction(t *testing.T) { action *pb.Action wsWriteError error wfSendError error + mcpManager *mockMcpManager expectedErrMsg string shouldCallWS bool shouldCallWF bool + shouldCallMcp bool }{ { name: "successful HTTP request action", @@ -475,6 +558,82 @@ func TestRunner_handleAgentAction(t *testing.T) { }, shouldCallWS: true, }, + { + name: "MCP tool action with mcp manager", + action: &pb.Action{ + RequestID: "req-mcp-123", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "gitlab_get_issue", + Args: `{"issue_id": "123"}`, + }, + }, + }, + mcpManager: &mockMcpManager{ + hasToolResult: true, + callToolResult: &pb.ClientEvent{ + Response: &pb.ClientEvent_ActionResponse{ + ActionResponse: 
&pb.ActionResponse{ + ResponseType: &pb.ActionResponse_PlainTextResponse{ + PlainTextResponse: &pb.PlainTextResponse{ + Response: `{"id": 123, "title": "Test Issue"}`, + }, + }, + }, + }, + }, + }, + shouldCallMcp: true, + shouldCallWF: true, + }, + { + name: "MCP tool action without mcp manager", + action: &pb.Action{ + RequestID: "req-mcp-no-manager", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "gitlab_get_issue", + Args: `{"issue_id": "123"}`, + }, + }, + }, + mcpManager: nil, + shouldCallWS: true, + }, + { + name: "MCP tool action with tool not recognized", + action: &pb.Action{ + RequestID: "req-mcp-unknown", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "unknown_tool", + Args: `{"param": "value"}`, + }, + }, + }, + mcpManager: &mockMcpManager{ + hasToolResult: false, + }, + shouldCallWS: true, + }, + { + name: "MCP tool action with call error", + action: &pb.Action{ + RequestID: "req-mcp-error", + Action: &pb.Action_RunMCPTool{ + RunMCPTool: &pb.RunMCPTool{ + Name: "gitlab_get_issue", + Args: `{"issue_id": "123"}`, + }, + }, + }, + mcpManager: &mockMcpManager{ + hasToolResult: true, + callToolError: errors.New("mcp call failed"), + }, + shouldCallMcp: true, + expectedErrMsg: "handleAgentAction: failed to call MCP tool: mcp call failed", + }, } for _, tt := range tests { @@ -507,6 +666,7 @@ func TestRunner_handleAgentAction(t *testing.T) { originalReq: &http.Request{}, conn: mockConn, wf: mockWf, + mcpManager: tt.mcpManager, } ctx := context.Background() @@ -528,12 +688,21 @@ func TestRunner_handleAgentAction(t *testing.T) { if tt.shouldCallWF { require.Len(t, sendEvents, 1, "Expected one workflow event to be sent") - response := sendEvents[0].Response.(*pb.ClientEvent_ActionResponse).ActionResponse - responseBody := response.ResponseType.(*pb.ActionResponse_HttpResponse).HttpResponse.Body - require.JSONEq(t, `[{"id": 123, "name": "test-project"}]`, responseBody) + if tt.action.GetRunHTTPRequest() != nil { + 
response := mockWf.sendEvents[0].Response.(*pb.ClientEvent_ActionResponse).ActionResponse + responseBody := response.ResponseType.(*pb.ActionResponse_HttpResponse).HttpResponse.Body + require.JSONEq(t, `[{"id": 123, "name": "test-project"}]`, responseBody) + } } else { require.Empty(t, sendEvents) } + + if tt.shouldCallMcp { + require.Len(t, tt.mcpManager.callToolInvocations, 1, "Expected MCP tool to be called") + mcpAction := tt.action.GetRunMCPTool() + require.Equal(t, mcpAction.Name, tt.mcpManager.callToolInvocations[0].name) + require.Equal(t, mcpAction.Args, tt.mcpManager.callToolInvocations[0].args) + } }) } } @@ -601,3 +770,64 @@ func TestRunner_closeWebSocketConnection(t *testing.T) { }) } } + +func TestRunner_sendActionToWs(t *testing.T) { + tests := []struct { + name string + action *pb.Action + writeError error + expectedErrMsg string + }{ + { + name: "successful send", + action: &pb.Action{ + RequestID: "req-123", + Action: &pb.Action_RunCommand{ + RunCommand: &pb.RunCommandAction{ + Program: "ls", + }, + }, + }, + expectedErrMsg: "", + }, + { + name: "write error", + action: &pb.Action{ + RequestID: "req-456", + Action: &pb.Action_RunCommand{ + RunCommand: &pb.RunCommandAction{ + Program: "ls", + }, + }, + }, + writeError: errors.New("write failed"), + expectedErrMsg: "sendActionToWs: failed to send WS message: write failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockConn := &mockWebSocketConn{ + writeError: tt.writeError, + } + + testURL, _ := url.Parse("http://example.com") + r := &runner{ + rails: &api.API{ + Client: &http.Client{}, + URL: testURL, + }, + conn: mockConn, + } + + err := r.sendActionToWs(tt.action) + + if tt.expectedErrMsg != "" { + require.EqualError(t, err, tt.expectedErrMsg) + } else { + require.NoError(t, err) + require.Len(t, mockConn.writeMessages, 1) + } + }) + } +} diff --git a/workhorse/internal/api/api.go b/workhorse/internal/api/api.go index 
3026e22beddf7546597136329f4664e44fa28203..ec562a3b0c0ec496942bf7b5a0588fe1009b61fa 100644 --- a/workhorse/internal/api/api.go +++ b/workhorse/internal/api/api.go @@ -142,11 +142,19 @@ type RemoteObject struct { ObjectStorage *ObjectStorageParams } +// McpServerConfig holds configuration for MCP servers configured in GitLab Rails +type McpServerConfig struct { + URL string + Headers map[string]string + Tools *[]string +} + // DuoWorkflow holds configuration for the Duo Workflow service. type DuoWorkflow struct { Headers map[string]string ServiceURI string Secure bool + McpServers map[string]McpServerConfig } // Response represents a structure containing various GitLab-related environment variables.