From 53511f3655a5eed9976164fbd88d14df3490000c Mon Sep 17 00:00:00 2001 From: Bob Van Landuyt Date: Thu, 7 Mar 2019 10:58:37 +0100 Subject: [PATCH 1/2] Detect user based on key, username or id This allows gitlab-shell to be called with an argument of the format `key-123` or `username-name`. When called in this way, `gitlab-shell` will call the GitLab internal API. If the API responds with user information, it will print a welcome message including the username. If the API responds with a successful but empty response, gitlab-shell will print a welcome message for an anonymous user. If the API response includes an error message in JSON, this message will be printed to stderr. If the API call fails, an error message including the status code will be printed to stderr. --- go/internal/command/discover/discover.go | 37 ++++- go/internal/command/discover/discover_test.go | 130 +++++++++++++++++ go/internal/gitlabnet/client.go | 77 ++++++++++ go/internal/gitlabnet/client_test.go | 131 ++++++++++++++++++ go/internal/gitlabnet/discover/client.go | 72 ++++++++++ go/internal/gitlabnet/discover/client_test.go | 86 ++++++++++++ go/internal/gitlabnet/socketclient.go | 46 ++++++ .../gitlabnet/testserver/testserver.go | 56 ++++++++ spec/gitlab_shell_gitlab_shell_spec.rb | 33 +++-- 9 files changed, 659 insertions(+), 9 deletions(-) create mode 100644 go/internal/command/discover/discover_test.go create mode 100644 go/internal/gitlabnet/client.go create mode 100644 go/internal/gitlabnet/client_test.go create mode 100644 go/internal/gitlabnet/discover/client.go create mode 100644 go/internal/gitlabnet/discover/client_test.go create mode 100644 go/internal/gitlabnet/socketclient.go create mode 100644 go/internal/gitlabnet/testserver/testserver.go diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go index 63a7a32c..ab04cbd4 100644 --- a/go/internal/command/discover/discover.go +++ b/go/internal/command/discover/discover.go @@ -2,9 +2,12 @@ package discover import ( "fmt" + "io" + "os" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" ) type Command struct { @@ -12,6 +15,38 @@ type Command struct { Args *commandargs.CommandArgs } +var ( + output io.Writer = os.Stdout +) + func (c *Command) Execute() error { - return fmt.Errorf("No feature is implemented yet") + response, err := c.getUserInfo() + if err != nil { + return fmt.Errorf("Failed to get username: %v", err) + } + + if response.IsAnonymous() { + fmt.Fprintf(output, "Welcome to GitLab, Anonymous!\n") + } else { + fmt.Fprintf(output, "Welcome to GitLab, @%s!\n", response.Username) + } + + return nil +} + +func (c *Command) getUserInfo() (*discover.Response, error) { + client, err := discover.NewClient(c.Config) + if err != nil { + return nil, err + } + + if c.Args.GitlabKeyId != "" { + return client.GetByKeyId(c.Args.GitlabKeyId) + } else if c.Args.GitlabUsername != "" { + return client.GetByUsername(c.Args.GitlabUsername) + } else { + // There was no 'who' information, this matches the ruby error + // message. + return nil, fmt.Errorf("who='' is invalid") + } } diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go new file mode 100644 index 00000000..752e76e4 --- /dev/null +++ b/go/internal/command/discover/discover_test.go @@ -0,0 +1,130 @@ +package discover + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +var ( + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key_id") == "1" || r.URL.Query().Get("username") == "alex-doe" { + body := map[string]interface{}{ + "id": 2, + "username": "alex-doe", + "name": "Alex Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_message" { + body := map[string]string{ + "message": "Forbidden!", + } + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken" { + w.WriteHeader(http.StatusInternalServerError) + } else { + fmt.Fprint(w, "null") + } + }, + }, + } +) + +func TestExecute(t *testing.T) { + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.CommandArgs + expectedOutput string + }{ + { + desc: "With a known username", + arguments: &commandargs.CommandArgs{GitlabUsername: "alex-doe"}, + expectedOutput: "Welcome to GitLab, @alex-doe!\n", + }, + { + desc: "With a known key id", + arguments: &commandargs.CommandArgs{GitlabKeyId: "1"}, + expectedOutput: "Welcome to GitLab, @alex-doe!\n", + }, + { + desc: "With an unknown key", + arguments: &commandargs.CommandArgs{GitlabKeyId: "-1"}, + expectedOutput: "Welcome to GitLab, Anonymous!\n", + }, + { + desc: "With an unknown username", + arguments: &commandargs.CommandArgs{GitlabUsername: "unknown"}, + expectedOutput: "Welcome to GitLab, Anonymous!\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + buffer := &bytes.Buffer{} + output = buffer + cmd := &Command{Config: testConfig, Args: tc.arguments} + + err := cmd.Execute() + + assert.NoError(t, err) + assert.Equal(t, tc.expectedOutput, buffer.String()) + }) + } +} + +func TestFailingExecute(t *testing.T) { + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + defer cleanup() + + testCases := []struct { + desc string + arguments *commandargs.CommandArgs + expectedError string + }{ + { + desc: "With missing arguments", + arguments: &commandargs.CommandArgs{}, + expectedError: "Failed to get username: who='' is invalid", + }, + { + desc: "When the API returns an error", + arguments: &commandargs.CommandArgs{GitlabUsername: "broken_message"}, + expectedError: "Failed to get username: Forbidden!", + }, + { + desc: "When the API fails", + arguments: &commandargs.CommandArgs{GitlabUsername: "broken"}, + expectedError: "Failed to get username: Internal API error (500)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cmd := &Command{Config: testConfig, Args: tc.arguments} + + err := cmd.Execute() + + assert.EqualError(t, err, tc.expectedError) + }) + } + +} diff --git a/go/internal/gitlabnet/client.go b/go/internal/gitlabnet/client.go new file mode 100644 index 00000000..abc218f1 --- /dev/null +++ b/go/internal/gitlabnet/client.go @@ -0,0 +1,77 @@ +package gitlabnet + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" +) + +const ( + internalApiPath = "/api/v4/internal" + secretHeaderName = "Gitlab-Shared-Secret" +) + +type GitlabClient interface { + Get(path string) (*http.Response, error) + // TODO: implement posts + // Post(path string) (http.Response, error) +} + +type ErrorResponse struct { + Message string `json:"message"` +} + +func GetClient(config *config.Config) (GitlabClient, error) { + url := config.GitlabUrl + if strings.HasPrefix(url, UnixSocketProtocol) { + return buildSocketClient(config), nil + } + + return nil, fmt.Errorf("Unsupported protocol") +} + +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + + if !strings.HasPrefix(path, internalApiPath) { + path = internalApiPath + path + } + return path +} + +func parseError(resp *http.Response) error { + if resp.StatusCode >= 200 && resp.StatusCode <= 299 { + return nil + } + defer resp.Body.Close() + parsedResponse := &ErrorResponse{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return fmt.Errorf("Internal API error (%v)", resp.StatusCode) + } else { + return fmt.Errorf(parsedResponse.Message) + } + +} + +func doRequest(client *http.Client, config *config.Config, request *http.Request) (*http.Response, error) { + encodedSecret := base64.StdEncoding.EncodeToString([]byte(config.Secret)) + request.Header.Set(secretHeaderName, encodedSecret) + + response, err := client.Do(request) + if err != nil { + return nil, fmt.Errorf("Internal API unreachable") + } + + if err := parseError(response); err != nil { + return nil, err + } + + return response, nil +} diff --git a/go/internal/gitlabnet/client_test.go b/go/internal/gitlabnet/client_test.go new file mode 100644 index 00000000..f69f2843 --- /dev/null +++ b/go/internal/gitlabnet/client_test.go @@ -0,0 +1,131 @@ +package gitlabnet + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" +) + +func TestClients(t *testing.T) { + requests := []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/hello", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello") + }, + }, + { + Path: "/api/v4/internal/auth", + Handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, r.Header.Get(secretHeaderName)) + }, + }, + { + Path: "/api/v4/internal/error", + Handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + body := map[string]string{ + "message": "Don't do that", + } + json.NewEncoder(w).Encode(body) + }, + }, + { + Path: "/api/v4/internal/broken", + Handler: func(w http.ResponseWriter, r *http.Request) { + panic("Broken") + }, + }, + } + testConfig := &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket, Secret: "sssh, it's a secret"} + + testCases := []struct { + desc string + client GitlabClient + server func([]testserver.TestRequestHandler) (func(), error) + }{ + { + desc: "Socket client", + client: buildSocketClient(testConfig), + server: testserver.StartSocketHttpServer, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + cleanup, err := tc.server(requests) + defer cleanup() + require.NoError(t, err) + + testBrokenRequest(t, tc.client) + testSuccessfulGet(t, tc.client) + testMissing(t, tc.client) + testErrorMessage(t, tc.client) + testAuthenticationHeader(t, tc.client) + }) + } +} + +func testSuccessfulGet(t *testing.T, client GitlabClient) { + t.Run("Successful get", func(t *testing.T) { + response, err := client.Get("/hello") + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + assert.NoError(t, err) + assert.Equal(t, string(responseBody), "Hello") + }) +} + +func testMissing(t *testing.T, client GitlabClient) { + t.Run("Missing error", func(t *testing.T) { + response, err := client.Get("/missing") + assert.EqualError(t, err, "Internal API error (404)") + assert.Nil(t, response) + }) +} + +func testErrorMessage(t *testing.T, client GitlabClient) { + t.Run("Error with message", func(t *testing.T) { + response, err := client.Get("/error") + assert.EqualError(t, err, "Don't do that") + assert.Nil(t, response) + }) +} + +func testBrokenRequest(t *testing.T, client GitlabClient) { + t.Run("Broken request", func(t *testing.T) { + response, err := client.Get("/broken") + assert.EqualError(t, err, "Internal API unreachable") + assert.Nil(t, response) + }) +} + +func testAuthenticationHeader(t *testing.T, client GitlabClient) { + t.Run("Authentication headers", func(t *testing.T) { + response, err := client.Get("/auth") + defer response.Body.Close() + + require.NoError(t, err) + require.NotNil(t, response) + + responseBody, err := ioutil.ReadAll(response.Body) + require.NoError(t, err) + + header, err := base64.StdEncoding.DecodeString(string(responseBody)) + require.NoError(t, err) + assert.Equal(t, "sssh, it's a secret", string(header)) + }) +} diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go new file mode 100644 index 00000000..4e65d259 --- /dev/null +++ b/go/internal/gitlabnet/discover/client.go @@ -0,0 +1,72 @@ +package discover + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" +) + +type Client struct { + config *config.Config + client gitlabnet.GitlabClient +} + +type Response struct { + UserId int64 `json:"id"` + Name string `json:"name"` + Username string `json:"username"` +} + +func NewClient(config *config.Config) (*Client, error) { + client, err := gitlabnet.GetClient(config) + if err != nil { + return nil, fmt.Errorf("Error creating http client: %v", err) + } + + return &Client{config: config, client: client}, nil +} + +func (c *Client) GetByKeyId(keyId string) (*Response, error) { + params := url.Values{} + params.Add("key_id", keyId) + + return c.getResponse(params) +} + +func (c *Client) GetByUsername(username string) (*Response, error) { + params := url.Values{} + params.Add("username", username) + + return c.getResponse(params) +} + +func (c *Client) parseResponse(resp *http.Response) (*Response, error) { + defer resp.Body.Close() + parsedResponse := &Response{} + + if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { + return nil, err + } else { + return parsedResponse, nil + } + +} + +func (c *Client) getResponse(params url.Values) (*Response, error) { + path := "/discover?" + params.Encode() + response, err := c.client.Get(path) + + if err != nil { + return nil, err + } + + return c.parseResponse(response) +} + +func (r *Response) IsAnonymous() bool { + return r.UserId < 1 +} diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go new file mode 100644 index 00000000..6c87d07b --- /dev/null +++ b/go/internal/gitlabnet/discover/client_test.go @@ -0,0 +1,86 @@ +package discover + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + testConfig *config.Config + requests []testserver.TestRequestHandler +) + +func init() { + testConfig = &config.Config{GitlabUrl: "http+unix://" + testserver.TestSocket} + requests = []testserver.TestRequestHandler{ + { + Path: "/api/v4/internal/discover", + Handler: func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("key_id") == "1" { + body := map[string]interface{}{ + "id": 2, + "username": "alex-doe", + "name": "Alex Doe", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "jane-doe" { + body := map[string]interface{}{ + "id": 1, + "username": "jane-doe", + "name": "Jane Doe", + } + json.NewEncoder(w).Encode(body) + } else { + fmt.Fprint(w, "null") + + } + + }, + }, + } +} + +func TestGetByKeyId(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByKeyId("1") + assert.NoError(t, err) + assert.Equal(t, &Response{UserId: 2, Username: "alex-doe", Name: "Alex Doe"}, result) +} + +func TestGetByUsername(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByUsername("jane-doe") + assert.NoError(t, err) + assert.Equal(t, &Response{UserId: 1, Username: "jane-doe", Name: "Jane Doe"}, result) +} + +func TestMissingUser(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + result, err := client.GetByUsername("missing") + assert.NoError(t, err) + assert.True(t, result.IsAnonymous()) +} + +func setup(t *testing.T) (*Client, func()) { + cleanup, err := testserver.StartSocketHttpServer(requests) + require.NoError(t, err) + + client, err := NewClient(testConfig) + require.NoError(t, err) + + return client, cleanup +} diff --git a/go/internal/gitlabnet/socketclient.go b/go/internal/gitlabnet/socketclient.go new file mode 100644 index 00000000..3bd7c70f --- /dev/null +++ b/go/internal/gitlabnet/socketclient.go @@ -0,0 +1,46 @@ +package gitlabnet + +import ( + "context" + "net" + "net/http" + "strings" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" +) + +const ( + // We need to set the base URL to something starting with HTTP, the host + // itself is ignored as we're talking over a socket. + socketBaseUrl = "http://unix" + UnixSocketProtocol = "http+unix://" +) + +type GitlabSocketClient struct { + httpClient *http.Client + config *config.Config +} + +func buildSocketClient(config *config.Config) *GitlabSocketClient { + path := strings.TrimPrefix(config.GitlabUrl, UnixSocketProtocol) + httpClient := &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", path) + }, + }, + } + + return &GitlabSocketClient{httpClient: httpClient, config: config} +} + +func (c *GitlabSocketClient) Get(path string) (*http.Response, error) { + path = normalizePath(path) + + request, err := http.NewRequest("GET", socketBaseUrl+path, nil) + if err != nil { + return nil, err + } + + return doRequest(c.httpClient, c.config, request) +} diff --git a/go/internal/gitlabnet/testserver/testserver.go b/go/internal/gitlabnet/testserver/testserver.go new file mode 100644 index 00000000..9640fd7d --- /dev/null +++ b/go/internal/gitlabnet/testserver/testserver.go @@ -0,0 +1,56 @@ +package testserver + +import ( + "io/ioutil" + "log" + "net" + "net/http" + "os" + "path" + "path/filepath" +) + +var ( + tempDir, _ = ioutil.TempDir("", "gitlab-shell-test-api") + TestSocket = path.Join(tempDir, "internal.sock") +) + +type TestRequestHandler struct { + Path string + Handler func(w http.ResponseWriter, r *http.Request) +} + +func StartSocketHttpServer(handlers []TestRequestHandler) (func(), error) { + if err := os.MkdirAll(filepath.Dir(TestSocket), 0700); err != nil { + return nil, err + } + + socketListener, err := net.Listen("unix", TestSocket) + if err != nil { + return nil, err + } + + server := http.Server{ + Handler: buildHandler(handlers), + // We'll put this server through some nasty stuff we don't want + // in our test output + ErrorLog: log.New(ioutil.Discard, "", 0), + } + go server.Serve(socketListener) + + return cleanupSocket, nil +} + +func cleanupSocket() { + os.RemoveAll(tempDir) +} + +func buildHandler(handlers []TestRequestHandler) http.Handler { + h := http.NewServeMux() + + for _, handler := range handlers { + h.HandleFunc(handler.Path, handler.Handler) + } + + return h +} diff --git a/spec/gitlab_shell_gitlab_shell_spec.rb b/spec/gitlab_shell_gitlab_shell_spec.rb index 11692d35..cb3fd9cc 100644 --- a/spec/gitlab_shell_gitlab_shell_spec.rb +++ b/spec/gitlab_shell_gitlab_shell_spec.rb @@ -30,12 +30,19 @@ describe 'bin/gitlab-shell' do @server = HTTPUNIXServer.new(BindAddress: tmp_socket_path) @server.mount_proc('/api/v4/internal/discover') do |req, res| - if req.query['key_id'] == '100' || - req.query['user_id'] == '10' || - req.query['username'] == 'someuser' + identifier = req.query['key_id'] || req.query['username'] || req.query['user_id'] + known_identifiers = %w(10 someuser 100) + if known_identifiers.include?(identifier) res.status = 200 res.content_type = 'application/json' res.body = '{"id":1, "name": "Some User", "username": "someuser"}' + elsif identifier == 'broken_message' + res.status = 401 + res.body = '{"message": "Forbidden!"}' + elsif identifier && identifier != 'broken' + res.status = 200 + res.content_type = 'application/json' + res.body = 'null' else res.status = 500 end @@ -145,11 +152,7 @@ describe 'bin/gitlab-shell' do ) end - it_behaves_like 'results with keys' do - before do - pending - end - end + it_behaves_like 'results with keys' it 'outputs "Only ssh allowed"' do _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-someuser"], env: {}) @@ -157,6 +160,20 @@ describe 'bin/gitlab-shell' do expect(stderr).to eq("Only ssh allowed\n") expect(status).not_to be_success end + + it 'returns an error message when the API call fails with a message' do + _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken_message"]) + + expect(stderr).to match(/Failed to get username: Forbidden!/) + expect(status).not_to be_success + end + + it 'returns an error message when the API call fails without a message' do + _, stderr, status = run!(["-c/usr/share/webapps/gitlab-shell/bin/gitlab-shell", "username-broken"]) + + expect(stderr).to match(/Failed to get username: Internal API error \(500\)/) + expect(status).not_to be_success + end end def run!(args, env: {'SSH_CONNECTION' => 'fake'}) -- GitLab From 83c0f18e1de04b3bad9c424084e738e911c47336 Mon Sep 17 00:00:00 2001 From: Bob Van Landuyt Date: Thu, 14 Mar 2019 14:01:42 +0100 Subject: [PATCH 2/2] Wrap Stderr & Stdout in a reporter struct The reporter struct can be used for passing around and reporting to the io.Writer of choice. --- go/cmd/gitlab-shell/main.go | 19 +++--- go/internal/command/command.go | 3 +- go/internal/command/discover/discover.go | 13 ++-- go/internal/command/discover/discover_test.go | 11 ++-- go/internal/command/fallback/fallback.go | 4 +- go/internal/command/reporting/reporter.go | 8 +++ go/internal/gitlabnet/discover/client.go | 10 ++- go/internal/gitlabnet/discover/client_test.go | 65 ++++++++++++++++--- 8 files changed, 96 insertions(+), 37 deletions(-) create mode 100644 go/internal/command/reporting/reporter.go diff --git a/go/cmd/gitlab-shell/main.go b/go/cmd/gitlab-shell/main.go index 07623b43..2ed319da 100644 --- a/go/cmd/gitlab-shell/main.go +++ b/go/cmd/gitlab-shell/main.go @@ -7,25 +7,28 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) var ( - binDir string - rootDir string + binDir string + rootDir string + reporter *reporting.Reporter ) func init() { binDir = filepath.Dir(os.Args[0]) rootDir = filepath.Dir(binDir) + reporter = &reporting.Reporter{Out: os.Stdout, ErrOut: os.Stderr} } // rubyExec will never return. It either replaces the current process with a // Ruby interpreter, or outputs an error and kills the process. func execRuby() { cmd := &fallback.Command{} - if err := cmd.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "Failed to exec: %v\n", err) + if err := cmd.Execute(reporter); err != nil { + fmt.Fprintf(reporter.ErrOut, "Failed to exec: %v\n", err) os.Exit(1) } } @@ -35,7 +38,7 @@ func main() { // warning as this isn't something we can sustain indefinitely config, err := config.NewFromDir(rootDir) if err != nil { - fmt.Fprintln(os.Stderr, "Failed to read config, falling back to gitlab-shell-ruby") + fmt.Fprintln(reporter.ErrOut, "Failed to read config, falling back to gitlab-shell-ruby") execRuby() } @@ -43,14 +46,14 @@ func main() { if err != nil { // For now this could happen if `SSH_CONNECTION` is not set on // the environment - fmt.Fprintf(os.Stderr, "%v\n", err) + fmt.Fprintf(reporter.ErrOut, "%v\n", err) os.Exit(1) } // The command will write to STDOUT on execution or replace the current // process in case of the `fallback.Command` - if err = cmd.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) + if err = cmd.Execute(reporter); err != nil { + fmt.Fprintf(reporter.ErrOut, "%v\n", err) os.Exit(1) } } diff --git a/go/internal/command/command.go b/go/internal/command/command.go index cb2acdcf..d4649de0 100644 --- a/go/internal/command/command.go +++ b/go/internal/command/command.go @@ -4,11 +4,12 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/discover" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/fallback" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" ) type Command interface { - Execute() error + Execute(*reporting.Reporter) error } func New(arguments []string, config *config.Config) (Command, error) { diff --git a/go/internal/command/discover/discover.go b/go/internal/command/discover/discover.go index ab04cbd4..8ad2868c 100644 --- a/go/internal/command/discover/discover.go +++ b/go/internal/command/discover/discover.go @@ -2,10 +2,9 @@ package discover import ( "fmt" - "io" - "os" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/discover" ) @@ -15,20 +14,16 @@ type Command struct { Args *commandargs.CommandArgs } -var ( - output io.Writer = os.Stdout -) - -func (c *Command) Execute() error { +func (c *Command) Execute(reporter *reporting.Reporter) error { response, err := c.getUserInfo() if err != nil { return fmt.Errorf("Failed to get username: %v", err) } if response.IsAnonymous() { - fmt.Fprintf(output, "Welcome to GitLab, Anonymous!\n") + fmt.Fprintf(reporter.Out, "Welcome to GitLab, Anonymous!\n") } else { - fmt.Fprintf(output, "Welcome to GitLab, @%s!\n", response.Username) + fmt.Fprintf(reporter.Out, "Welcome to GitLab, @%s!\n", response.Username) } return nil diff --git a/go/internal/command/discover/discover_test.go b/go/internal/command/discover/discover_test.go index 752e76e4..ec6f931a 100644 --- a/go/internal/command/discover/discover_test.go +++ b/go/internal/command/discover/discover_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/commandargs" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" ) @@ -78,11 +79,10 @@ func TestExecute(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - buffer := &bytes.Buffer{} - output = buffer cmd := &Command{Config: testConfig, Args: tc.arguments} + buffer := &bytes.Buffer{} - err := cmd.Execute() + err := cmd.Execute(&reporting.Reporter{Out: buffer}) assert.NoError(t, err) assert.Equal(t, tc.expectedOutput, buffer.String()) @@ -120,11 +120,12 @@ func TestFailingExecute(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { cmd := &Command{Config: testConfig, Args: tc.arguments} + buffer := &bytes.Buffer{} - err := cmd.Execute() + err := cmd.Execute(&reporting.Reporter{Out: buffer}) + assert.Empty(t, buffer.String()) assert.EqualError(t, err, tc.expectedError) }) } - } diff --git a/go/internal/command/fallback/fallback.go b/go/internal/command/fallback/fallback.go index a136657d..a2c73edc 100644 --- a/go/internal/command/fallback/fallback.go +++ b/go/internal/command/fallback/fallback.go @@ -4,6 +4,8 @@ import ( "os" "path/filepath" "syscall" + + "gitlab.com/gitlab-org/gitlab-shell/go/internal/command/reporting" ) type Command struct{} @@ -12,7 +14,7 @@ var ( binDir = filepath.Dir(os.Args[0]) ) -func (c *Command) Execute() error { +func (c *Command) Execute(_ *reporting.Reporter) error { rubyCmd := filepath.Join(binDir, "gitlab-shell-ruby") execErr := syscall.Exec(rubyCmd, os.Args, os.Environ()) return execErr diff --git a/go/internal/command/reporting/reporter.go b/go/internal/command/reporting/reporter.go new file mode 100644 index 00000000..74bca590 --- /dev/null +++ b/go/internal/command/reporting/reporter.go @@ -0,0 +1,8 @@ +package reporting + +import "io" + +type Reporter struct { + Out io.Writer + ErrOut io.Writer +} diff --git a/go/internal/gitlabnet/discover/client.go b/go/internal/gitlabnet/discover/client.go index 4e65d259..8df78fb5 100644 --- a/go/internal/gitlabnet/discover/client.go +++ b/go/internal/gitlabnet/discover/client.go @@ -45,7 +45,6 @@ func (c *Client) GetByUsername(username string) (*Response, error) { } func (c *Client) parseResponse(resp *http.Response) (*Response, error) { - defer resp.Body.Close() parsedResponse := &Response{} if err := json.NewDecoder(resp.Body).Decode(parsedResponse); err != nil { @@ -53,7 +52,6 @@ func (c *Client) parseResponse(resp *http.Response) (*Response, error) { } else { return parsedResponse, nil } - } func (c *Client) getResponse(params url.Values) (*Response, error) { @@ -64,7 +62,13 @@ func (c *Client) getResponse(params url.Values) (*Response, error) { return nil, err } - return c.parseResponse(response) + defer response.Body.Close() + parsedResponse, err := c.parseResponse(response) + if err != nil { + return nil, fmt.Errorf("Parsing failed") + } + + return parsedResponse, nil } func (r *Response) IsAnonymous() bool { diff --git a/go/internal/gitlabnet/discover/client_test.go b/go/internal/gitlabnet/discover/client_test.go index 6c87d07b..e88cedd8 100644 --- a/go/internal/gitlabnet/discover/client_test.go +++ b/go/internal/gitlabnet/discover/client_test.go @@ -7,6 +7,7 @@ import ( "testing" "gitlab.com/gitlab-org/gitlab-shell/go/internal/config" + "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/go/internal/gitlabnet/testserver" "github.com/stretchr/testify/assert" @@ -25,24 +26,32 @@ func init() { Path: "/api/v4/internal/discover", Handler: func(w http.ResponseWriter, r *http.Request) { if r.URL.Query().Get("key_id") == "1" { - body := map[string]interface{}{ - "id": 2, - "username": "alex-doe", - "name": "Alex Doe", + body := &Response{ + UserId: 2, + Username: "alex-doe", + Name: "Alex Doe", } json.NewEncoder(w).Encode(body) } else if r.URL.Query().Get("username") == "jane-doe" { - body := map[string]interface{}{ - "id": 1, - "username": "jane-doe", - "name": "Jane Doe", + body := &Response{ + UserId: 1, + Username: "jane-doe", + Name: "Jane Doe", } json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_message" { + w.WriteHeader(http.StatusForbidden) + body := &gitlabnet.ErrorResponse{ + Message: "Not allowed!", + } + json.NewEncoder(w).Encode(body) + } else if r.URL.Query().Get("username") == "broken_json" { + w.Write([]byte("{ \"message\": \"broken json!\"")) + } else if r.URL.Query().Get("username") == "broken_empty" { + w.WriteHeader(http.StatusForbidden) } else { fmt.Fprint(w, "null") - } - }, }, } @@ -75,6 +84,42 @@ func TestMissingUser(t *testing.T) { assert.True(t, result.IsAnonymous()) } +func TestErrorResponses(t *testing.T) { + client, cleanup := setup(t) + defer cleanup() + + testCases := []struct { + desc string + fakeUsername string + expectedError string + }{ + { + desc: "A response with an error message", + fakeUsername: "broken_message", + expectedError: "Not allowed!", + }, + { + desc: "A response with bad JSON", + fakeUsername: "broken_json", + expectedError: "Parsing failed", + }, + { + desc: "An error response without message", + fakeUsername: "broken_empty", + expectedError: "Internal API error (403)", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + resp, err := client.GetByUsername(tc.fakeUsername) + + assert.EqualError(t, err, tc.expectedError) + assert.Nil(t, resp) + }) + } +} + func setup(t *testing.T) (*Client, func()) { cleanup, err := testserver.StartSocketHttpServer(requests) require.NoError(t, err) -- GitLab