From 64df6cca199667188ba1850f07390fdec8cab399 Mon Sep 17 00:00:00 2001 From: Kai Armstrong Date: Thu, 23 Jan 2025 16:38:11 -0600 Subject: [PATCH 1/2] broken duo chat --- commands/duo/chat/chat.go | 303 ++++++++++++++++++++++++++++++++++++++ commands/duo/duo.go | 2 + go.mod | 6 + go.sum | 4 + 4 files changed, 315 insertions(+) create mode 100644 commands/duo/chat/chat.go diff --git a/commands/duo/chat/chat.go b/commands/duo/chat/chat.go new file mode 100644 index 000000000..329adb37b --- /dev/null +++ b/commands/duo/chat/chat.go @@ -0,0 +1,303 @@ +package chat + +import ( + "bufio" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" + "time" + + "github.com/hasura/go-graphql-client" + "github.com/spf13/cobra" + gitlab "gitlab.com/gitlab-org/api/client-go" + "gitlab.com/gitlab-org/cli/commands/cmdutils" + "gitlab.com/gitlab-org/cli/internal/config" + "gitlab.com/gitlab-org/cli/pkg/iostreams" +) + +type chatOpts struct { + IO *iostreams.IOStreams + HttpClient func() (*gitlab.Client, error) + Config func() (config.Config, error) +} + +type aiCompletionResponse struct { + ID string `json:"id"` + RequestID string `json:"requestId"` + Content string `json:"content"` + Errors []string `json:"errors"` + Role string `json:"role"` + Timestamp time.Time `json:"timestamp"` + ChunkID int `json:"chunkId"` +} + +const ( + spinnerText = "Connecting to GitLab Duo Chat..." + apiUnreachableErr = "Error: API is unreachable." + wsConnectErr = "Error: Failed to connect to WebSocket." + wsReadErr = "Error: Failed to read WebSocket message." + wsSendErr = "Error: Failed to send WebSocket message." + maxRetries = 5 + retryDelay = 5 * time.Second +) + +func NewCmdChat(f *cmdutils.Factory) *cobra.Command { + opts := &chatOpts{ + IO: f.IO, + HttpClient: f.HttpClient, + Config: f.Config, + } + + chatCmd := &cobra.Command{ + Use: "chat", + Short: "Start an interactive chat session with GitLab Duo", + RunE: func(cmd *cobra.Command, args []string) error { + return runChatSession(opts) + }, + } + + return chatCmd +} + +func runChatSession(opts *chatOpts) error { + opts.IO.StartSpinner(spinnerText) + defer opts.IO.StopSpinner("") + + client, err := opts.HttpClient() + if err != nil { + return cmdutils.WrapError(err, "failed to get HTTP client") + } + + baseURL := client.BaseURL() + wsURL := *baseURL + wsURL.Path = "/-/cable" + wsURL.Scheme = "wss" + + opts.IO.LogInfo(fmt.Sprintf("Attempting to connect to WebSocket URL: %s\n", wsURL.String())) + + // Try making an HTTP request to the WebSocket URL + httpURL := wsURL + httpURL.Scheme = "https" + resp, err := http.Get(httpURL.String()) + if err != nil { + opts.IO.LogInfo(fmt.Sprintf("Error making HTTP request to WebSocket URL: %v\n", err)) + } else { + defer resp.Body.Close() + body, _ := ioutil.ReadAll(resp.Body) + opts.IO.LogInfo(fmt.Sprintf("HTTP response status: %s\n", resp.Status)) + opts.IO.LogInfo(fmt.Sprintf("HTTP response body: %s\n", string(body))) + } + + cfg, err := opts.Config() + if err != nil { + return cmdutils.WrapError(err, "failed to get config") + } + token, _ := cfg.Get(baseURL.Host, "token") + + opts.IO.LogInfo(fmt.Sprintf("Using token: %s\n", maskToken(token))) + + headers := http.Header{ + "Origin": {baseURL.String()}, + "Sec-WebSocket-Protocol": {"actioncable-v1-json", "actioncable-unsupported"}, + "Authorization": {"Bearer " + token}, + } + + opts.IO.LogInfo("WebSocket Headers:") + for k, v := range headers { + opts.IO.LogInfo(fmt.Sprintf("%s: %s\n", k, v)) + } + + subscriptionClient := graphql.NewSubscriptionClient(wsURL.String()). + WithConnectionParams(map[string]interface{}{ + "headers": headers, + }). + WithLog(func(args ...interface{}) { + opts.IO.LogInfo(fmt.Sprintf("WebSocket Log: %v\n", args)) + }). + WithWebSocketOptions(graphql.WebsocketOptions{ + HTTPHeader: headers, + }). + WithRetryTimeout(time.Minute). + WithRetryStatusCodes("4000-4999") + + opts.IO.StopSpinner("") + opts.IO.LogInfo("Attempting to establish WebSocket connection...\n") + + subscriptionId := generateUniqueID() + + var subscription struct { + AiCompletionResponse struct { + ID string + RequestID string + Content string + Errors []string + Role string + Timestamp time.Time + Type string + ChunkID int + Extras struct { + Sources []string + TypeName string `json:"__typename"` + } + TypeName string `json:"__typename"` + } `graphql:"aiCompletionResponse(userId: $userId, aiAction: $aiAction, clientSubscriptionId: $clientSubscriptionId)"` + } + + variables := map[string]interface{}{ + "userId": graphql.ID(subscriptionId), + "aiAction": "CHAT", + "clientSubscriptionId": subscriptionId, + "htmlResponse": true, + } + + subID, err := subscriptionClient.Subscribe(&subscription, variables, func(data []byte, err error) error { + if err != nil { + opts.IO.LogInfo(fmt.Sprintf("Subscription error: %v\n", err)) + return nil + } + + opts.IO.LogInfo(fmt.Sprintf("Received data: %s\n", string(data))) + + var response struct { + AiCompletionResponse aiCompletionResponse + } + if err := json.Unmarshal(data, &response); err != nil { + opts.IO.LogInfo(fmt.Sprintf("Error unmarshaling response: %v\n", err)) + return nil + } + + displayFormattedResponse(opts, response.AiCompletionResponse.Content) + return nil + }) + + if err != nil { + return cmdutils.WrapError(err, "failed to set up subscription") + } + + opts.IO.LogInfo("Connected! Type 'exit' or 'quit' to end the session.\n") + + errChan := make(chan error, 1) + go func() { + for i := 0; i < maxRetries; i++ { + if err := subscriptionClient.Run(); err != nil { + opts.IO.LogInfo(fmt.Sprintf("Subscription client error: %v\n", err)) + opts.IO.LogInfo(fmt.Sprintf("Retrying in %v seconds...\n", retryDelay.Seconds())) + time.Sleep(retryDelay) + } else { + break + } + } + errChan <- fmt.Errorf("max retries reached, subscription client stopped") + }() + + reader := bufio.NewReader(opts.IO.In) + for { + select { + case err := <-errChan: + opts.IO.LogInfo(fmt.Sprintf("Subscription client stopped: %v\n", err)) + return err + default: + opts.IO.LogInfo(opts.IO.Color().Bold("You: ")) + input, err := reader.ReadString('\n') + if err != nil { + return cmdutils.WrapError(err, "failed to read input") + } + + input = strings.TrimSpace(input) + if input == "exit" || input == "quit" { + subscriptionClient.Unsubscribe(subID) + subscriptionClient.Close() + return nil + } + + if input == "" { + continue + } + + if err := sendChatMessage(subscriptionClient, input, subscriptionId); err != nil { + opts.IO.LogInfo(fmt.Sprintf("Error sending message: %v\n", err)) + } else { + opts.IO.LogInfo("Message sent successfully\n") + } + } + } +} + +func sendChatMessage(client *graphql.SubscriptionClient, content string, subscriptionId string) error { + mutation := ` + mutation chat($input: AiActionInput!) { + aiAction(input: $input) { + requestId + errors + __typename + } + } + ` + + variables := map[string]interface{}{ + "input": map[string]interface{}{ + "chat": map[string]interface{}{ + "content": content, + }, + "clientSubscriptionId": subscriptionId, + "conversationType": "DUO_CHAT", + "platformOrigin": "cli", + }, + } + + var response struct { + AiAction struct { + RequestID string `json:"requestId"` + Errors []string `json:"errors"` + TypeName string `json:"__typename"` + } `json:"aiAction"` + } + + _, err := client.Exec(mutation, variables, func(message []byte, err error) error { + if err != nil { + return fmt.Errorf("error in Exec callback: %v", err) + } + return json.Unmarshal(message, &response) + }) + + if err != nil { + return fmt.Errorf("failed to execute mutation: %v", err) + } + + if len(response.AiAction.Errors) > 0 { + return fmt.Errorf("mutation errors: %v", response.AiAction.Errors) + } + + return nil +} + +func displayFormattedResponse(opts *chatOpts, content string) { + color := opts.IO.Color() + opts.IO.LogInfo(color.Bold("GitLab Duo: ")) + + paragraphs := strings.Split(content, "\n\n") + for _, paragraph := range paragraphs { + paragraph = strings.ReplaceAll(paragraph, "```", color.Cyan("```")) + paragraph = strings.ReplaceAll(paragraph, "`", color.Cyan("`")) + + for _, pattern := range []string{"Note:", "Important:", "Warning:"} { + paragraph = strings.ReplaceAll(paragraph, pattern, color.Bold(pattern)) + } + + opts.IO.LogInfo(paragraph + "\n") + } + opts.IO.LogInfo("\n") +} + +func generateUniqueID() string { + return fmt.Sprintf("%d", time.Now().UnixNano()) +} + +func maskToken(token string) string { + if len(token) > 8 { + return token[:4] + "..." + token[len(token)-4:] + } + return "****" +} diff --git a/commands/duo/duo.go b/commands/duo/duo.go index 79f7f1292..3a65faaab 100644 --- a/commands/duo/duo.go +++ b/commands/duo/duo.go @@ -3,6 +3,7 @@ package duo import ( "gitlab.com/gitlab-org/cli/commands/cmdutils" duoAskCmd "gitlab.com/gitlab-org/cli/commands/duo/ask" + duoChatCmd "gitlab.com/gitlab-org/cli/commands/duo/chat" "github.com/spf13/cobra" ) @@ -15,6 +16,7 @@ func NewCmdDuo(f *cmdutils.Factory) *cobra.Command { } duoCmd.AddCommand(duoAskCmd.NewCmdAsk(f)) + duoCmd.AddCommand(duoChatCmd.NewCmdChat(f)) return duoCmd } diff --git a/go.mod b/go.mod index 5df45b8a3..4a192663a 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,11 @@ require ( k8s.io/client-go v0.32.1 ) +require ( + github.com/coder/websocket v1.8.12 // indirect + github.com/google/uuid v1.6.0 // indirect +) + require ( al.essio.dev/pkg/shellescape v1.5.1 // indirect github.com/alecthomas/chroma/v2 v2.14.0 // indirect @@ -73,6 +78,7 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect + github.com/hasura/go-graphql-client v0.13.1 github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect diff --git a/go.sum b/go.sum index f99b91324..23745b1a9 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/charmbracelet/x/ansi v0.1.4 h1:IEU3D6+dWwPSgZ6HBH+v6oUuZ/nVawMiWj5831 github.com/charmbracelet/x/ansi v0.1.4/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= github.com/charmbracelet/x/exp/golden v0.0.0-20240715153702-9ba8adf781c4 h1:6KzMkQeAF56rggw2NZu1L+TH7j9+DM1/2Kmh7KUxg1I= github.com/charmbracelet/x/exp/golden v0.0.0-20240715153702-9ba8adf781c4/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= @@ -110,6 +112,8 @@ github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKe github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hasura/go-graphql-client v0.13.1 h1:kKbjhxhpwz58usVl+Xvgah/TDha5K2akNTRQdsEHN6U= +github.com/hasura/go-graphql-client v0.13.1/go.mod h1:k7FF7h53C+hSNFRG3++DdVZWIuHdCaTbI7siTJ//zGQ= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= -- GitLab From 469eb06310968b2706c65facdcd58b40110f06d9 Mon Sep 17 00:00:00 2001 From: Kai Armstrong Date: Mon, 27 Jan 2025 16:22:12 -0600 Subject: [PATCH 2/2] working, but still broken --- commands/duo/chat/chat.go | 734 ++++++++++++++++++++++++++++---------- 1 file changed, 544 insertions(+), 190 deletions(-) diff --git a/commands/duo/chat/chat.go b/commands/duo/chat/chat.go index 329adb37b..b855d673e 100644 --- a/commands/duo/chat/chat.go +++ b/commands/duo/chat/chat.go @@ -2,14 +2,19 @@ package chat import ( "bufio" + "context" "encoding/json" + "errors" "fmt" - "io/ioutil" + "io" + "log" "net/http" + "os" "strings" "time" - "github.com/hasura/go-graphql-client" + "github.com/coder/websocket" + "github.com/google/uuid" "github.com/spf13/cobra" gitlab "gitlab.com/gitlab-org/api/client-go" "gitlab.com/gitlab-org/cli/commands/cmdutils" @@ -17,31 +22,46 @@ import ( "gitlab.com/gitlab-org/cli/pkg/iostreams" ) +const ( + gitlabBaseURL = "https://gitlab.com" + gitlabAPIURL = gitlabBaseURL + "/api/v4" + gitlabGraphQLURL = gitlabBaseURL + "/api/graphql" + gitlabWSURL = "wss://gitlab.com/-/cable" +) + type chatOpts struct { IO *iostreams.IOStreams HttpClient func() (*gitlab.Client, error) Config func() (config.Config, error) } -type aiCompletionResponse struct { - ID string `json:"id"` - RequestID string `json:"requestId"` - Content string `json:"content"` - Errors []string `json:"errors"` - Role string `json:"role"` - Timestamp time.Time `json:"timestamp"` - ChunkID int `json:"chunkId"` +type DuoChatClient struct { + token string + userID string + conn *websocket.Conn + responses chan CompletionResponse + IO *iostreams.IOStreams + debugLogger *log.Logger + debugFile *os.File } -const ( - spinnerText = "Connecting to GitLab Duo Chat..." - apiUnreachableErr = "Error: API is unreachable." - wsConnectErr = "Error: Failed to connect to WebSocket." - wsReadErr = "Error: Failed to read WebSocket message." - wsSendErr = "Error: Failed to send WebSocket message." - maxRetries = 5 - retryDelay = 5 * time.Second -) +type ActionCableMessage struct { + Type string `json:"type,omitempty"` + Command string `json:"command,omitempty"` + Identifier string `json:"identifier,omitempty"` + Message json.RawMessage `json:"message,omitempty"` +} + +type CompletionResponse struct { + ID string `json:"id"` + RequestID string `json:"requestId"` + Content string `json:"content"` + Errors []string `json:"errors"` + Role string `json:"role"` + Timestamp string `json:"timestamp"` + Type *string `json:"type"` + ChunkID *int `json:"chunkId"` +} func NewCmdChat(f *cmdutils.Factory) *cobra.Command { opts := &chatOpts{ @@ -50,254 +70,588 @@ func NewCmdChat(f *cmdutils.Factory) *cobra.Command { Config: f.Config, } + var debug bool + chatCmd := &cobra.Command{ Use: "chat", Short: "Start an interactive chat session with GitLab Duo", RunE: func(cmd *cobra.Command, args []string) error { - return runChatSession(opts) + return runChatSession(opts, debug) }, } + chatCmd.Flags().BoolVar(&debug, "debug", false, "Enable debug logging") + return chatCmd } -func runChatSession(opts *chatOpts) error { - opts.IO.StartSpinner(spinnerText) +func runChatSession(opts *chatOpts, debug bool) error { + opts.IO.StartSpinner("Connecting to GitLab Duo Chat...") defer opts.IO.StopSpinner("") - client, err := opts.HttpClient() + cfg, err := opts.Config() + if err != nil { + return cmdutils.WrapError(err, "failed to get config") + } + token, _ := cfg.Get(gitlabBaseURL, "token") + + client, err := NewDuoChatClient(context.Background(), token, opts.IO, debug) if err != nil { - return cmdutils.WrapError(err, "failed to get HTTP client") + return cmdutils.WrapError(err, "failed to create Duo Chat client") } + defer client.Close() + + subID := uuid.New().String() + if err := client.Subscribe(context.Background(), subID); err != nil { + return cmdutils.WrapError(err, "failed to subscribe") + } + + opts.IO.StopSpinner("") + opts.IO.LogInfo("Connected! Type 'exit' or 'quit' to end the session.\n") + + reader := bufio.NewReader(opts.IO.In) + inputChan := make(chan string, 1) + errChan := make(chan error, 1) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + inputReady := make(chan struct{}, 1) + inputReady <- struct{}{} // Initially allow input + + go func() { + for { + <-inputReady // Wait for input to be allowed + opts.IO.LogInfo("\n" + opts.IO.Color().Green("You: ")) + input, err := reader.ReadString('\n') + if err != nil { + errChan <- cmdutils.WrapError(err, "failed to read input") + return + } + input = strings.TrimSpace(input) + if input != "" { + client.debugLog("Sending input to channel: '%s'", input) + inputChan <- input + } + } + }() + + for { + select { + case input := <-inputChan: + if strings.ToLower(input) == "exit" || strings.ToLower(input) == "quit" { + opts.IO.LogInfo("Ending chat session...\n") + client.debugLog("Exit condition met, returning from runChatSession") + return nil + } + + responseCtx, cancelResponse := context.WithTimeout(ctx, 30*time.Second) + defer cancelResponse() + + if err := client.SendPrompt(responseCtx, input, subID); err != nil { + client.debugLog("Error sending prompt: %v", err) + opts.IO.LogInfo(fmt.Sprintf("Error sending message: %v\n", err)) + continue + } + + client.debugLog("Processing responses...") + + // Disable input while waiting for response + select { + case <-inputReady: + // Input was ready, now disable it + default: + // Input was already disabled, do nothing + } - baseURL := client.BaseURL() - wsURL := *baseURL - wsURL.Path = "/-/cable" - wsURL.Scheme = "wss" + opts.IO.StartSpinner("Waiting for GitLab Duo response...") + opts.IO.LogInfo("\n" + opts.IO.Color().Cyan("GitLab Duo: ")) + + responseDone := make(chan struct{}) + go func() { + defer close(responseDone) + if err := client.ProcessResponses(responseCtx); err != nil { + if err != context.Canceled { + opts.IO.LogInfo(fmt.Sprintf("\nError processing responses: %v\n", err)) + if err := client.reconnect(ctx); err != nil { + opts.IO.LogInfo(fmt.Sprintf("Failed to reconnect: %v\n", err)) + } + } + } + }() + + select { + case <-responseDone: + opts.IO.StopSpinner("") + opts.IO.LogInfo("\n") // Add a newline after GitLab Duo's response + case <-responseCtx.Done(): + opts.IO.StopSpinner("Response timed out") + } + + // Re-enable input after response + select { + case inputReady <- struct{}{}: + // Enable input + default: + // Input was already enabled, do nothing + } + + cancelResponse() + client.debugLog("Finished processing responses") + + case err := <-errChan: + return err - opts.IO.LogInfo(fmt.Sprintf("Attempting to connect to WebSocket URL: %s\n", wsURL.String())) + case <-time.After(100 * time.Millisecond): + // This case prevents the select from blocking indefinitely + continue + } + } +} - // Try making an HTTP request to the WebSocket URL - httpURL := wsURL - httpURL.Scheme = "https" - resp, err := http.Get(httpURL.String()) +func NewDuoChatClient(ctx context.Context, token string, io *iostreams.IOStreams, debug bool) (*DuoChatClient, error) { + userID, err := fetchUserID(ctx, token) if err != nil { - opts.IO.LogInfo(fmt.Sprintf("Error making HTTP request to WebSocket URL: %v\n", err)) - } else { - defer resp.Body.Close() - body, _ := ioutil.ReadAll(resp.Body) - opts.IO.LogInfo(fmt.Sprintf("HTTP response status: %s\n", resp.Status)) - opts.IO.LogInfo(fmt.Sprintf("HTTP response body: %s\n", string(body))) + return nil, fmt.Errorf("fetch user ID: %w", err) } - cfg, err := opts.Config() + conn, err := setupWebSocket(ctx, token) if err != nil { - return cmdutils.WrapError(err, "failed to get config") + return nil, fmt.Errorf("setup websocket: %w", err) } - token, _ := cfg.Get(baseURL.Host, "token") - opts.IO.LogInfo(fmt.Sprintf("Using token: %s\n", maskToken(token))) + var debugLogger *log.Logger + var debugFile *os.File - headers := http.Header{ - "Origin": {baseURL.String()}, - "Sec-WebSocket-Protocol": {"actioncable-v1-json", "actioncable-unsupported"}, - "Authorization": {"Bearer " + token}, + if debug { + debugFile, err = os.OpenFile("duo_chat_debug.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("create debug log file: %w", err) + } + debugLogger = log.New(debugFile, "DEBUG: ", log.Ldate|log.Ltime|log.Lshortfile) } - opts.IO.LogInfo("WebSocket Headers:") - for k, v := range headers { - opts.IO.LogInfo(fmt.Sprintf("%s: %s\n", k, v)) + return &DuoChatClient{ + token: token, + userID: userID, + conn: conn, + responses: make(chan CompletionResponse, 100), + IO: io, + debugLogger: debugLogger, + debugFile: debugFile, + }, nil +} + +func (c *DuoChatClient) Close() error { + if c.debugFile != nil { + if err := c.debugFile.Close(); err != nil { + return fmt.Errorf("close debug file: %w", err) + } } + return c.conn.Close(websocket.StatusNormalClosure, "") +} - subscriptionClient := graphql.NewSubscriptionClient(wsURL.String()). - WithConnectionParams(map[string]interface{}{ - "headers": headers, - }). - WithLog(func(args ...interface{}) { - opts.IO.LogInfo(fmt.Sprintf("WebSocket Log: %v\n", args)) - }). - WithWebSocketOptions(graphql.WebsocketOptions{ - HTTPHeader: headers, - }). - WithRetryTimeout(time.Minute). - WithRetryStatusCodes("4000-4999") +func (c *DuoChatClient) debugLog(format string, v ...interface{}) { + if c.debugLogger != nil { + c.debugLogger.Printf(format, v...) + } +} - opts.IO.StopSpinner("") - opts.IO.LogInfo("Attempting to establish WebSocket connection...\n") - - subscriptionId := generateUniqueID() - - var subscription struct { - AiCompletionResponse struct { - ID string - RequestID string - Content string - Errors []string - Role string - Timestamp time.Time - Type string - ChunkID int - Extras struct { - Sources []string - TypeName string `json:"__typename"` - } - TypeName string `json:"__typename"` - } `graphql:"aiCompletionResponse(userId: $userId, aiAction: $aiAction, clientSubscriptionId: $clientSubscriptionId)"` +func fetchUserID(ctx context.Context, token string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, gitlabAPIURL+"/user", nil) + if err != nil { + return "", fmt.Errorf("create request: %w", err) } - variables := map[string]interface{}{ - "userId": graphql.ID(subscriptionId), - "aiAction": "CHAT", - "clientSubscriptionId": subscriptionId, - "htmlResponse": true, + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("send request: %w", err) } + defer resp.Body.Close() - subID, err := subscriptionClient.Subscribe(&subscription, variables, func(data []byte, err error) error { - if err != nil { - opts.IO.LogInfo(fmt.Sprintf("Subscription error: %v\n", err)) - return nil - } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API error: status=%d body=%s", resp.StatusCode, body) + } - opts.IO.LogInfo(fmt.Sprintf("Received data: %s\n", string(data))) + var userData struct { + ID int `json:"id"` + } - var response struct { - AiCompletionResponse aiCompletionResponse - } - if err := json.Unmarshal(data, &response); err != nil { - opts.IO.LogInfo(fmt.Sprintf("Error unmarshaling response: %v\n", err)) - return nil - } + if err := json.NewDecoder(resp.Body).Decode(&userData); err != nil { + return "", fmt.Errorf("decode response: %w", err) + } - displayFormattedResponse(opts, response.AiCompletionResponse.Content) - return nil - }) + return fmt.Sprintf("gid://gitlab/User/%d", userData.ID), nil +} +func setupWebSocket(ctx context.Context, token string) (*websocket.Conn, error) { + dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + c, _, err := websocket.Dial(dialCtx, gitlabWSURL, &websocket.DialOptions{ + HTTPHeader: http.Header{ + "Authorization": {"Bearer " + token}, + "Origin": {gitlabBaseURL}, + }, + }) if err != nil { - return cmdutils.WrapError(err, "failed to set up subscription") + return nil, fmt.Errorf("websocket dial: %w", err) } - opts.IO.LogInfo("Connected! Type 'exit' or 'quit' to end the session.\n") + return c, nil +} - errChan := make(chan error, 1) - go func() { - for i := 0; i < maxRetries; i++ { - if err := subscriptionClient.Run(); err != nil { - opts.IO.LogInfo(fmt.Sprintf("Subscription client error: %v\n", err)) - opts.IO.LogInfo(fmt.Sprintf("Retrying in %v seconds...\n", retryDelay.Seconds())) - time.Sleep(retryDelay) - } else { - break - } - } - errChan <- fmt.Errorf("max retries reached, subscription client stopped") - }() +func (c *DuoChatClient) Subscribe(ctx context.Context, subID string) error { + if err := c.subscribeMain(ctx); err != nil { + return fmt.Errorf("subscribe main: %w", err) + } - reader := bufio.NewReader(opts.IO.In) - for { - select { - case err := <-errChan: - opts.IO.LogInfo(fmt.Sprintf("Subscription client stopped: %v\n", err)) - return err - default: - opts.IO.LogInfo(opts.IO.Color().Bold("You: ")) - input, err := reader.ReadString('\n') - if err != nil { - return cmdutils.WrapError(err, "failed to read input") - } + if err := c.subscribeStream(ctx, subID); err != nil { + return fmt.Errorf("subscribe stream: %w", err) + } - input = strings.TrimSpace(input) - if input == "exit" || input == "quit" { - subscriptionClient.Unsubscribe(subID) - subscriptionClient.Close() - return nil - } + go c.handleMessages(ctx) + return nil +} - if input == "" { - continue - } +func (c *DuoChatClient) subscribeMain(ctx context.Context) error { + query := map[string]interface{}{ + "channel": "GraphqlChannel", + "query": `subscription aiCompletionResponse($userId: UserID, $aiAction: AiAction, $clientSubscriptionId) { + aiCompletionResponse( + userId: $userId + aiAction: $aiAction + ) { + id requestId content errors role timestamp type chunkId + } + }`, + "variables": map[string]interface{}{ + "aiAction": "CHAT", + "userId": c.userID, + }, + "operationName": "aiCompletionResponse", + "nonce": uuid.New().String(), + } - if err := sendChatMessage(subscriptionClient, input, subscriptionId); err != nil { - opts.IO.LogInfo(fmt.Sprintf("Error sending message: %v\n", err)) - } else { - opts.IO.LogInfo("Message sent successfully\n") - } - } + return c.sendSubscription(ctx, query) +} + +func (c *DuoChatClient) subscribeStream(ctx context.Context, subID string) error { + query := map[string]interface{}{ + "channel": "GraphqlChannel", + "query": `subscription aiCompletionResponseStream($userId: UserID, $clientSubscriptionId: String) { + aiCompletionResponse( + userId: $userId + aiAction: CHAT + clientSubscriptionId: $clientSubscriptionId + ) { + id requestId content errors role timestamp type chunkId + } + }`, + "variables": map[string]interface{}{ + "clientSubscriptionId": subID, + "userId": c.userID, + }, + "operationName": "aiCompletionResponseStream", + "nonce": uuid.New().String(), + } + + return c.sendSubscription(ctx, query) +} + +func (c *DuoChatClient) sendSubscription(ctx context.Context, query map[string]interface{}) error { + identifierBytes, err := json.Marshal(query) + if err != nil { + return fmt.Errorf("marshal query: %w", err) } + + msg := ActionCableMessage{ + Command: "subscribe", + Identifier: string(identifierBytes), + } + + msgBytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal message: %w", err) + } + + return c.conn.Write(ctx, websocket.MessageText, msgBytes) } -func sendChatMessage(client *graphql.SubscriptionClient, content string, subscriptionId string) error { +func (c *DuoChatClient) SendPrompt(ctx context.Context, prompt, subID string) error { mutation := ` - mutation chat($input: AiActionInput!) { - aiAction(input: $input) { + mutation chat($question: String!, $clientSubscriptionId: String) { + aiAction( + input: { + chat: { + content: $question + } + clientSubscriptionId: $clientSubscriptionId + } + ) { requestId errors - __typename } } ` variables := map[string]interface{}{ - "input": map[string]interface{}{ - "chat": map[string]interface{}{ - "content": content, - }, - "clientSubscriptionId": subscriptionId, - "conversationType": "DUO_CHAT", - "platformOrigin": "cli", - }, + "question": prompt, + "clientSubscriptionId": subID, } - var response struct { - AiAction struct { - RequestID string `json:"requestId"` - Errors []string `json:"errors"` - TypeName string `json:"__typename"` - } `json:"aiAction"` + body, err := json.Marshal(map[string]interface{}{ + "query": mutation, + "variables": variables, + }) + if err != nil { + return fmt.Errorf("marshal mutation: %w", err) } - _, err := client.Exec(mutation, variables, func(message []byte, err error) error { - if err != nil { - return fmt.Errorf("error in Exec callback: %v", err) - } - return json.Unmarshal(message, &response) - }) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, gitlabGraphQLURL, strings.NewReader(string(body))) + if err != nil { + return fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) if err != nil { - return fmt.Errorf("failed to execute mutation: %v", err) + return fmt.Errorf("send request: %w", err) } + defer resp.Body.Close() - if len(response.AiAction.Errors) > 0 { - return fmt.Errorf("mutation errors: %v", response.AiAction.Errors) + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("GraphQL error: status=%d body=%s", resp.StatusCode, body) } return nil } -func displayFormattedResponse(opts *chatOpts, content string) { - color := opts.IO.Color() - opts.IO.LogInfo(color.Bold("GitLab Duo: ")) +func (c *DuoChatClient) handleMessages(ctx context.Context) { + pingTicker := time.NewTicker(30 * time.Second) + defer pingTicker.Stop() - paragraphs := strings.Split(content, "\n\n") - for _, paragraph := range paragraphs { - paragraph = strings.ReplaceAll(paragraph, "```", color.Cyan("```")) - paragraph = strings.ReplaceAll(paragraph, "`", color.Cyan("`")) + for { + select { + case <-ctx.Done(): + return + case <-pingTicker.C: + if err := c.sendPing(ctx); err != nil { + c.debugLog("Failed to send ping: %v", err) + } + default: + _, rawMsg, err := c.conn.Read(ctx) + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + c.debugLog("WebSocket connection closed normally") + return + } + c.debugLog("WebSocket read error: %v", err) + if err := c.reconnect(ctx); err != nil { + c.debugLog("Failed to reconnect: %v", err) + return + } + continue + } + + var msg ActionCableMessage + if err := json.Unmarshal(rawMsg, &msg); err != nil { + c.debugLog("Error unmarshaling message: %v", err) + continue + } + + switch msg.Type { + case "ping": + continue + case "welcome": + c.debugLog("Connected to ActionCable") + case "confirm_subscription": + c.debugLog("Subscribed to channel") + case "reject_subscription": + c.debugLog("Subscription rejected") + if err := c.reconnect(ctx); err != nil { + c.debugLog("Failed to reconnect: %v", err) + return + } + default: + if err := c.handleGraphQLMessage(msg.Message); err != nil { + c.debugLog("Handle GraphQL message error: %v", err) + } + } + } + } +} + +func (c *DuoChatClient) handleGraphQLMessage(rawMsg json.RawMessage) error { + var graphqlMsg struct { + Result struct { + Data struct { + AiCompletionResponse CompletionResponse `json:"aiCompletionResponse"` + } `json:"data"` + } `json:"result"` + } + if err := json.Unmarshal(rawMsg, &graphqlMsg); err != nil { + return fmt.Errorf("parse GraphQL message: %w", err) + } + + c.responses <- graphqlMsg.Result.Data.AiCompletionResponse + return nil +} + +func (c *DuoChatClient) ProcessResponses(ctx context.Context) error { + finishedRequests := make(map[string]bool) + pendingResponses := make(map[int]CompletionResponse) + var currentRequestID string + maxChunkID := -1 + var buffer strings.Builder + + responseChan := make(chan CompletionResponse) + errChan := make(chan error) + doneChan := make(chan struct{}) + + go func() { + for { + select { + case response, ok := <-c.responses: + if !ok { + errChan <- errors.New("response channel closed") + return + } + responseChan <- response + case <-ctx.Done(): + return + } + } + }() - for _, pattern := range []string{"Note:", "Important:", "Warning:"} { - paragraph = strings.ReplaceAll(paragraph, pattern, color.Bold(pattern)) + go func() { + defer close(doneChan) + lastActivityTime := time.Now() + + for { + select { + case response := <-responseChan: + c.debugLog("Received chunk: ID=%v, Type=%v, Content=%s", response.ChunkID, response.Type, response.Content) + lastActivityTime = time.Now() + + if response.Role != "ASSISTANT" || finishedRequests[response.RequestID] { + continue + } + + if currentRequestID == "" { + currentRequestID = response.RequestID + } + + if response.RequestID != currentRequestID { + continue + } + + if response.ChunkID != nil && *response.ChunkID > maxChunkID { + for i := maxChunkID + 1; i <= *response.ChunkID; i++ { + if pending, ok := pendingResponses[i]; ok { + buffer.WriteString(pending.Content) + delete(pendingResponses, i) + } else if i == *response.ChunkID { + buffer.WriteString(response.Content) + } + } + maxChunkID = *response.ChunkID + c.IO.LogInfo(buffer.String()) + buffer.Reset() + } else if response.ChunkID == nil { + // This is likely the final message + c.IO.LogInfo(response.Content) + } else { + pendingResponses[*response.ChunkID] = response + } + + if isResponseComplete(response) { + c.debugLog("Response complete, ending processing") + finishedRequests[response.RequestID] = true + c.IO.LogInfo("\n") // Add a newline after the complete response + return + } + + case err := <-errChan: + c.debugLog("Error processing responses: %v", err) + return + + case <-ctx.Done(): + c.debugLog("Context cancelled in ProcessResponses") + return + + case <-time.After(1 * time.Second): + if time.Since(lastActivityTime) > 5*time.Second && len(buffer.String()) > 0 { + c.debugLog("No new chunks received for 5 seconds, assuming completion") + c.IO.LogInfo("\n") // Add a newline after the assume + c.IO.LogInfo("\n") // Add a newline after the assumed complete response + return + } + } } + }() - opts.IO.LogInfo(paragraph + "\n") + select { + case <-doneChan: + return nil + case <-time.After(30 * time.Second): + return errors.New("response timeout") } - opts.IO.LogInfo("\n") } -func generateUniqueID() string { - return fmt.Sprintf("%d", time.Now().UnixNano()) +func isResponseComplete(response CompletionResponse) bool { + return response.ChunkID == nil || + response.Type != nil && *response.Type == "COMPLETE" || + (len(response.Content) > 0 && (strings.HasSuffix(response.Content, ".") || strings.HasSuffix(response.Content, "?") || strings.HasSuffix(response.Content, "!"))) } -func maskToken(token string) string { - if len(token) > 8 { - return token[:4] + "..." + token[len(token)-4:] +func (c *DuoChatClient) sendPing(ctx context.Context) error { + pingMessage := ActionCableMessage{ + Type: "ping", + Message: json.RawMessage("{}"), + } + msgBytes, err := json.Marshal(pingMessage) + if err != nil { + return fmt.Errorf("marshal ping message: %w", err) } - return "****" + return c.conn.Write(ctx, websocket.MessageText, msgBytes) +} + +func (c *DuoChatClient) reconnect(ctx context.Context) error { + maxRetries := 5 + backoff := time.Second + + for i := 0; i < maxRetries; i++ { + c.debugLog("Attempting to reconnect (attempt %d of %d)...", i+1, maxRetries) + + newConn, err := setupWebSocket(ctx, c.token) + if err != nil { + c.debugLog("Reconnection attempt failed: %v", err) + backoff *= 2 // Exponential backoff + time.Sleep(backoff) + continue + } + + c.conn = newConn + + subID := uuid.New().String() + if err := c.Subscribe(ctx, subID); err != nil { + c.debugLog("Failed to resubscribe: %v", err) + c.conn.Close(websocket.StatusInternalError, "") + backoff *= 2 // Exponential backoff + time.Sleep(backoff) + continue + } + + c.debugLog("Reconnected successfully") + return nil + } + + return errors.New("failed to reconnect after maximum retries") } -- GitLab