diff --git a/pkg/github/__toolsnaps__/pull_request_read.snap b/pkg/github/__toolsnaps__/pull_request_read.snap index d70f77e1e0..87b901a373 100644 --- a/pkg/github/__toolsnaps__/pull_request_read.snap +++ b/pkg/github/__toolsnaps__/pull_request_read.snap @@ -11,7 +11,7 @@ "type": "string" }, "method": { - "description": "Action to specify what pull request data needs to be retrieved from GitHub. \nPossible options: \n 1. get - Get details of a specific pull request.\n 2. get_diff - Get the diff of a pull request.\n 3. get_status - Get combined commit status of a head commit in a pull request.\n 4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned.\n 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Use cursor-based pagination (perPage, after) to control results.\n 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. Use with pagination parameters to control the number of results returned.\n 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned.\n 8. get_check_runs - Get check runs for the head commit of a pull request. Check runs are the individual CI/CD jobs and checks that run on the PR.\n", + "description": "Action to specify what pull request data needs to be retrieved from GitHub. \nPossible options: \n 1. get - Get details of a specific pull request.\n 2. get_diff - Get the diff of a pull request.\n 3. get_status - Get combined commit status of a head commit in a pull request.\n 4. get_files - Get the list of files changed in a pull request. 
Use with pagination parameters to control the number of results returned.\n 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Review comments include structured code suggestions when available, including Copilot-generated \"Suggest\" changesets (via thread partial) and human-authored suggestion code blocks in the comment body. Use cursor-based pagination (perPage, after) to control results.\n 6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. Use with pagination parameters to control the number of results returned.\n 7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned.\n 8. get_check_runs - Get check runs for the head commit of a pull request. Check runs are the individual CI/CD jobs and checks that run on the PR.\n", "enum": [ "get", "get_diff", diff --git a/pkg/github/minimal_types.go b/pkg/github/minimal_types.go index 5200be297f..469947b7bc 100644 --- a/pkg/github/minimal_types.go +++ b/pkg/github/minimal_types.go @@ -1567,13 +1567,14 @@ type MinimalPageInfo struct { // MinimalReviewComment is the trimmed output type for PR review comment objects. 
type MinimalReviewComment struct {
-	Body      string `json:"body,omitempty"`
-	Path      string `json:"path"`
-	Line      *int   `json:"line,omitempty"`
-	Author    string `json:"author,omitempty"`
-	CreatedAt string `json:"created_at,omitempty"`
-	UpdatedAt string `json:"updated_at,omitempty"`
-	HTMLURL   string `json:"html_url"`
+	Body        string                    `json:"body,omitempty"`
+	Path        string                    `json:"path"`
+	Line        *int                      `json:"line,omitempty"`
+	Author      string                    `json:"author,omitempty"`
+	CreatedAt   string                    `json:"created_at,omitempty"`
+	UpdatedAt   string                    `json:"updated_at,omitempty"`
+	HTMLURL     string                    `json:"html_url"`
+	// Suggestions lists the structured code suggestions attached to this
+	// comment; the field is omitted from the JSON output when empty.
+	Suggestions []MinimalReviewSuggestion `json:"suggestions,omitempty"`
 }
 
 // MinimalReviewThread is the trimmed output type for PR review thread objects.
diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go
index 05028850d7..bd6b433c7d 100644
--- a/pkg/github/pullrequests.go
+++ b/pkg/github/pullrequests.go
@@ -35,7 +35,7 @@ Possible options:
  2. get_diff - Get the diff of a pull request.
  3. get_status - Get combined commit status of a head commit in a pull request.
  4. get_files - Get the list of files changed in a pull request. Use with pagination parameters to control the number of results returned.
- 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Use cursor-based pagination (perPage, after) to control results.
+ 5. get_review_comments - Get review threads on a pull request. Each thread contains logically grouped review comments made on the same code location during pull request reviews. Returns threads with metadata (isResolved, isOutdated, isCollapsed) and their associated comments. Review comments include structured code suggestions when available, including Copilot-generated "Suggest" changesets (via thread partial) and human-authored suggestion code blocks in the comment body. Use cursor-based pagination (perPage, after) to control results.
  6. get_reviews - Get the reviews on a pull request. When asked for review comments, use get_review_comments method. Use with pagination parameters to control the number of results returned.
  7. get_comments - Get comments on a pull request. Use this if user doesn't specifically want review comments. Use with pagination parameters to control the number of results returned.
  8. get_check_runs - Get check runs for the head commit of a pull request. Check runs are the individual CI/CD jobs and checks that run on the PR.
@@ -482,7 +482,13 @@ func GetPullRequestReviewComments(ctx context.Context, gqlClient *githubv4.Clien
 		}
 	}
 
-	return MarshalledTextResult(convertToMinimalReviewThreadsResponse(query)), nil
+	response := convertToMinimalReviewThreadsResponse(query)
+
+	// NOTE(review): an error from GetClient is dropped here, so the response is
+	// returned without suggestions and without any signal — confirm that this
+	// best-effort behavior is intended.
+	if client, err := deps.GetClient(ctx); err == nil {
+		enrichReviewThreadsWithSuggestions(ctx, client, owner, repo, pullNumber, response.ReviewThreads)
+	}
+
+	return MarshalledTextResult(response), nil
 }
 
 func GetPullRequestReviews(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) {
diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go
index aff71e4c1a..151f2a6cc1 100644
--- a/pkg/github/pullrequests_test.go
+++ b/pkg/github/pullrequests_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"net/http"
+	"net/http/httptest"
 	"testing"
 	"time"
 
@@ -2056,6 +2057,103 @@ func Test_GetPullRequestComments(t *testing.T) {
 	}
 }
 
+// Test_GetPullRequestCommentsWithSuggestions drives the pull_request_read tool
+// with method "get_review_comments" against a mocked GraphQL response and a
+// local HTTP server, and expects the single returned comment to carry exactly
+// one suggestion whose source is suggestionSourceAutomated.
+func Test_GetPullRequestCommentsWithSuggestions(t *testing.T) {
+	t.Parallel()
+
+	// The server plays the role of the web host: it asserts the thread path that
+	// is requested and answers with the HTML fixture.
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "/owner/repo/pull/42/threads/1964378741", r.URL.Path)
+		_, _ = w.Write([]byte(automatedSuggestionHTMLFixture))
+	}))
+	defer server.Close()
+
+	// Both enterprise URLs point at the test server, so the REST client's
+	// requests are sent there.
+	restClient, err := github.NewClient(github.WithHTTPClient(server.Client()), github.WithEnterpriseURLs(server.URL+"/", server.URL+"/"))
+	require.NoError(t, err)
+
+	// GraphQL mock: one review thread holding one comment whose body contains no
+	// suggestion block, so any suggestion in the result comes from the HTTP fetch.
+	gqlHTTPClient := githubv4mock.NewMockedHTTPClient(
+		githubv4mock.NewQueryMatcher(
+			reviewThreadsQuery{},
+			map[string]any{
+				"owner":             githubv4.String("owner"),
+				"repo":              githubv4.String("repo"),
+				"prNum":             githubv4.Int(42),
+				"first":             githubv4.Int(30),
+				"commentsPerThread": githubv4.Int(100),
+				"after":             (*githubv4.String)(nil),
+			},
+			githubv4mock.DataResponse(map[string]any{
+				"repository": map[string]any{
+					"pullRequest": map[string]any{
+						"reviewThreads": map[string]any{
+							"nodes": []map[string]any{
+								{
+									"id":          "PRRT_kwDORGz4i851Fgp1",
+									"isResolved":  false,
+									"isOutdated":  false,
+									"isCollapsed": false,
+									"comments": map[string]any{
+										"totalCount": 1,
+										"nodes": []map[string]any{
+											{
+												"id":   "PRRC_kwDORGz4i86v72Xc",
+												"body": "Consider adding validation.",
+												"path": "glmocr/cli.py",
+												"line": 10,
+												"author": map[string]any{
+													"login": "copilot-pull-request-reviewer",
+												},
+												"createdAt": "2024-01-01T12:00:00Z",
+												"updatedAt": "2024-01-01T12:00:00Z",
+												"url":       "https://github.com/owner/repo/pull/42#discussion_r101",
+											},
+										},
+									},
+								},
+							},
+							"pageInfo": map[string]any{
+								"hasNextPage":     false,
+								"hasPreviousPage": false,
+								"startCursor":     "cursor1",
+								"endCursor":       "cursor2",
+							},
+							"totalCount": 1,
+						},
+					},
+				},
+			}),
+		),
+	)
+
+	serverTool := PullRequestRead(translations.NullTranslationHelper)
+	deps := BaseDeps{
+		Client:    restClient,
+		GQLClient: githubv4.NewClient(gqlHTTPClient),
+	}
+	handler := serverTool.Handler(deps)
+
+	request := createMCPRequest(map[string]any{
+		"method":     "get_review_comments",
+		"owner":      "owner",
+		"repo":       "repo",
+		"pullNumber": float64(42),
+	})
+
+	result, err := handler(ContextWithDeps(context.Background(), deps), &request)
+	require.NoError(t, err)
+	require.False(t, result.IsError)
+
+	// Decode the tool's text result back into the response type before asserting.
+	textContent := getTextResult(t, result)
+	var response MinimalReviewThreadsResponse
+	require.NoError(t, json.Unmarshal([]byte(textContent.Text), &response))
+	require.Len(t, response.ReviewThreads, 1)
+	require.Len(t, response.ReviewThreads[0].Comments, 1)
+
+	suggestions := response.ReviewThreads[0].Comments[0].Suggestions
+	require.Len(t, suggestions, 1)
+	assert.Equal(t, suggestionSourceAutomated, suggestions[0].Source)
+	assert.Equal(t, "glmocr/cli.py", suggestions[0].Path)
+	assert.Contains(t, suggestions[0].Suggestion, "import re")
+}
+
 func Test_GetPullRequestReviews(t *testing.T) {
 	// Verify tool definition once
 	serverTool := PullRequestRead(translations.NullTranslationHelper)
diff --git a/pkg/github/review_suggestions.go b/pkg/github/review_suggestions.go
new file mode 100644
index 0000000000..c4d42b9349
--- /dev/null
+++ b/pkg/github/review_suggestions.go
@@ -0,0 +1,286 @@
+package github
+
+import (
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"regexp"
+	"strings"
+
+	gogithub "github.com/google/go-github/v87/github"
+)
+
+// Values stored in MinimalReviewSuggestion.Source: "body" marks a suggestion
+// parsed from a comment body, "automated" one built from an automated
+// suggestion payload.
+const (
+	suggestionSourceBody      = "body"
+	suggestionSourceAutomated = "automated"
+)
+
+// suggestionBlockPattern matches a fenced block that opens with ```suggestion,
+// optional whitespace and a newline, and lazily captures everything up to the
+// next ``` fence; (?s) lets the capture span multiple lines.
+var suggestionBlockPattern = regexp.MustCompile("(?s)```suggestion\\s*\n(.*?)```")
+
+// MinimalReviewSuggestion is a structured code suggestion attached to a review comment.
+type MinimalReviewSuggestion struct {
+	// Path is the diff entry's file path for automated suggestions; suggestions
+	// parsed from a comment body leave it empty.
+	Path string `json:"path,omitempty"`
+	// Suggestion is the proposed text.
+	Suggestion string `json:"suggestion"`
+	// StartLine and EndLine are the first and last line numbers seen while an
+	// automated suggestion is built; they stay nil when no number is available.
+	StartLine *int `json:"start_line,omitempty"`
+	EndLine   *int `json:"end_line,omitempty"`
+	// Source is suggestionSourceBody or suggestionSourceAutomated.
+	Source string `json:"source,omitempty"`
+}
+
+// automatedDiffLine is one diff line inside an automated suggestion payload.
+// buildSuggestionFromDiffLines handles the Type values "HUNK", "ADDITION" and
+// "CONTEXT"; lines of any other type are ignored there.
+type automatedDiffLine struct {
+	Text  string `json:"text"`
+	Type  string `json:"type"`
+	Left  *int   `json:"left"`
+	Right *int   `json:"right"`
+}
+
+// automatedDiffEntry groups the diff lines reported for one path.
+type automatedDiffEntry struct {
+	Path      string              `json:"path"`
+	DiffLines []automatedDiffLine `json:"diffLines"`
+}
+
+// automatedSuggestionPayload mirrors the JSON path
+// props.comment.automatedComment.suggestion.diffEntries; every other field of
+// the decoded document is ignored.
+type automatedSuggestionPayload struct {
+	Props struct {
+		Comment struct {
+			AutomatedComment struct {
+				Suggestion struct {
+					DiffEntries []automatedDiffEntry `json:"diffEntries"`
+				} `json:"suggestion"`
+			} `json:"automatedComment"`
+		} `json:"comment"`
+	} `json:"props"`
+}
+
+// decodeNodeDatabaseID extracts the numeric database ID encoded in a GitHub GraphQL node ID.
+//
+// The text after the first "_" is base64url-decoded. When the decoded bytes
+// form a MessagePack array of unsigned integers (the layout of the node IDs
+// exercised by this package's tests), the last element is returned at its full
+// width, so IDs above 2^32-1 are no longer truncated. Any other payload falls
+// back to reading the final four bytes as a big-endian integer.
+//
+// An error is returned when nodeID has no "_" separator, the payload is not
+// valid base64url, or the fallback has fewer than four bytes to read.
+func decodeNodeDatabaseID(nodeID string) (int64, error) {
+	_, payload, ok := strings.Cut(nodeID, "_")
+	if !ok || payload == "" {
+		return 0, fmt.Errorf("invalid node ID: %q", nodeID)
+	}
+
+	// Strip any "=" padding so the unpadded decoder accepts both forms.
+	raw, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(payload, "="))
+	if err != nil {
+		return 0, fmt.Errorf("decode node ID %q: %w", nodeID, err)
+	}
+
+	// id>>63 == 0 keeps the conversion to int64 from wrapping negative.
+	if id, ok := lastMsgpackUint(raw); ok && id>>63 == 0 {
+		return int64(id), nil
+	}
+
+	if len(raw) < 4 {
+		return 0, fmt.Errorf("node ID payload too short: %q", nodeID)
+	}
+
+	tail := raw[len(raw)-4:]
+	return int64(tail[0])<<24 | int64(tail[1])<<16 | int64(tail[2])<<8 | int64(tail[3]), nil
+}
+
+// lastMsgpackUint interprets raw as a MessagePack fixarray whose elements are
+// all unsigned integers and returns the final element. ok is false when raw
+// does not have exactly that shape.
+func lastMsgpackUint(raw []byte) (last uint64, ok bool) {
+	if len(raw) == 0 || raw[0]&0xf0 != 0x90 {
+		return 0, false
+	}
+
+	count := int(raw[0] & 0x0f)
+	if count == 0 {
+		return 0, false
+	}
+
+	rest := raw[1:]
+	for i := 0; i < count; i++ {
+		value, size, valid := readMsgpackUint(rest)
+		if !valid {
+			return 0, false
+		}
+		last, rest = value, rest[size:]
+	}
+
+	// Trailing bytes mean this is not the array shape we understand.
+	return last, len(rest) == 0
+}
+
+// readMsgpackUint decodes one MessagePack unsigned integer (positive fixint,
+// uint8, uint16, uint32 or uint64) from the start of b and reports how many
+// bytes it occupied. ok is false for any other type tag or a short buffer.
+func readMsgpackUint(b []byte) (value uint64, size int, ok bool) {
+	if len(b) == 0 {
+		return 0, 0, false
+	}
+
+	width := 0
+	switch tag := b[0]; {
+	case tag <= 0x7f:
+		return uint64(tag), 1, true
+	case tag == 0xcc:
+		width = 1
+	case tag == 0xcd:
+		width = 2
+	case tag == 0xce:
+		width = 4
+	case tag == 0xcf:
+		width = 8
+	default:
+		return 0, 0, false
+	}
+
+	if len(b) < 1+width {
+		return 0, 0, false
+	}
+	for _, c := range b[1 : 1+width] {
+		value = value<<8 | uint64(c)
+	}
+	return value, 1 + width, true
+}
+
+// parseSuggestionsFromBody returns one suggestion per ```suggestion block found
+// in body, in order of appearance, each tagged suggestionSourceBody. It returns
+// nil when body contains no such block.
+func parseSuggestionsFromBody(body string) []MinimalReviewSuggestion {
+	matches := suggestionBlockPattern.FindAllStringSubmatch(body, -1)
+	if len(matches) == 0 {
+		return nil
+	}
+
+	suggestions := make([]MinimalReviewSuggestion, 0, len(matches))
+	for _, match := range matches {
+		suggestions = append(suggestions, MinimalReviewSuggestion{
+			// Trim CR as well as LF so CRLF bodies do not leave a stray "\r".
+			Suggestion: strings.TrimRight(match[1], "\r\n"),
+			Source:     suggestionSourceBody,
+		})
+	}
+	return suggestions
+}
+
+// suggestionsFromAutomatedPayload converts each diff entry of payload into a
+// suggestion tagged suggestionSourceAutomated. Entries whose diff lines produce
+// no text are skipped; nil is returned when the payload has no diff entries.
+func suggestionsFromAutomatedPayload(payload automatedSuggestionPayload) []MinimalReviewSuggestion {
+	diffEntries := payload.Props.Comment.AutomatedComment.Suggestion.DiffEntries
+	if len(diffEntries) == 0 {
+		return nil
+	}
+
+	suggestions := make([]MinimalReviewSuggestion, 0, len(diffEntries))
+	for _, entry := range diffEntries {
+		suggestionText, startLine, endLine := buildSuggestionFromDiffLines(entry.DiffLines)
+		if suggestionText == "" {
+			continue
+		}
+		suggestions = append(suggestions, MinimalReviewSuggestion{
+			Path:       entry.Path,
+			Suggestion: suggestionText,
+			StartLine:  startLine,
+			EndLine:    endLine,
+			Source:     suggestionSourceAutomated,
+		})
+	}
+	return suggestions
+}
+
+// buildSuggestionFromDiffLines joins the text of all "ADDITION" and "CONTEXT"
+// lines with "\n" and returns it together with the first and last non-nil
+// Right values seen on those lines. "HUNK" lines and lines of any other type
+// are skipped. It returns "", nil, nil when the joined text is empty.
+func buildSuggestionFromDiffLines(lines []automatedDiffLine) (string, *int, *int) {
+	var builder strings.Builder
+	var startLine, endLine *int
+	// wrote tracks whether a line was already emitted. The builder length cannot
+	// be used for that: it stays zero after an empty first line, which would
+	// silently drop leading blank lines from the suggestion.
+	wrote := false
+
+	for _, line := range lines {
+		switch line.Type {
+		case "HUNK":
+			continue
+		case "ADDITION", "CONTEXT":
+			if wrote {
+				builder.WriteByte('\n')
+			}
+			builder.WriteString(line.Text)
+			wrote = true
+			if line.Right != nil {
+				if startLine == nil {
+					startLine = line.Right
+				}
+				endLine = line.Right
+			}
+		}
+	}
+
+	if builder.Len() == 0 {
+		return "", nil, nil
+	}
+	return builder.String(), startLine, endLine
+}
+
+// webBaseURLFromClient derives the web (HTML) base URL from the client's API
+// base URL:
+//   - nil client, or a base URL that does not parse or has no host: https://github.com
+//   - host api.github.com: https://github.com
+//   - host api.<rest>: https://<rest>
+//   - any other host: the same scheme and host with path, query and fragment cleared
+func webBaseURLFromClient(client *gogithub.Client) (*url.URL, error) {
+	if client == nil {
+		return url.Parse("https://github.com")
+	}
+
+	apiURL, err := url.Parse(client.BaseURL())
+	if err != nil || apiURL.Hostname() == "" {
+		return url.Parse("https://github.com")
+	}
+
+	host := apiURL.Hostname()
+	switch {
+	case host == "api.github.com":
+		return url.Parse("https://github.com")
+	case strings.HasPrefix(host, "api."):
+		webHost := strings.TrimPrefix(host, "api.")
+		return url.Parse("https://" + webHost)
+	default:
+		// The whole path is discarded, so there is no need to strip an
+		// "/api/v3" suffix from it first.
+		webURL := *apiURL
+		webURL.Path = ""
+		webURL.RawQuery = ""
+		webURL.Fragment = ""
+		return &webURL, nil
+	}
+}
+
+// fetchAutomatedSuggestionsForThread requests the HTML of one review thread
+// from the web host derived from client and parses automated suggestions out
+// of it.
+//
+// It returns an error when client is nil, threadNodeID cannot be decoded, the
+// request cannot be built or sent, the response status is not 200, or the body
+// cannot be read. At most 10 MiB of the response body is read.
+func fetchAutomatedSuggestionsForThread(
+	ctx context.Context,
+	client *gogithub.Client,
+	owner, repo string,
+	pullNumber int,
+	threadNodeID string,
+) ([]MinimalReviewSuggestion, error) {
+	// The client is used below to send the request, so reject nil up front
+	// instead of only tolerating it while the URL is built.
+	if client == nil {
+		return nil, fmt.Errorf("fetch automated suggestions: nil GitHub client")
+	}
+
+	threadDBID, err := decodeNodeDatabaseID(threadNodeID)
+	if err != nil {
+		return nil, err
+	}
+
+	webBase, err := webBaseURLFromClient(client)
+	if err != nil {
+		return nil, err
+	}
+
+	// owner and repo come from the caller; escape them so characters such as
+	// "/", "?" or "#" cannot change which URL is requested.
+	threadURL := fmt.Sprintf("%s/%s/%s/pull/%d/threads/%d?rendering_on_files_tab=true",
+		strings.TrimRight(webBase.String(), "/"), url.PathEscape(owner), url.PathEscape(repo), pullNumber, threadDBID)
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, threadURL, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Accept", "text/html")
+
+	resp, err := client.Client().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		// Only a short prefix of the body is kept for the error message.
+		body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
+		return nil, fmt.Errorf("thread partial request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
+	}
+
+	html, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
+	if err != nil {
+		return nil, err
+	}
+
+	return parseAutomatedSuggestionsFromHTML(string(html))
+}
+
+// parseAutomatedSuggestionsFromHTML looks for JSON payloads embedded in html
+// and returns the suggestions of the first payload that yields any. In the
+// code visible here, payloads that fail to unmarshal are skipped and the
+// returned error is always nil.
+//
+// NOTE(review): the marker literal and the start of the scan loop are not
+// visible in this view of the file, so the function body is left untouched.
+func parseAutomatedSuggestionsFromHTML(html string) ([]MinimalReviewSuggestion, error) {
+	const marker = `")
+		if contentEnd == -1 {
+			break
+		}
+
+		var payload automatedSuggestionPayload
+		if err := json.Unmarshal([]byte(html[contentStart:contentStart+contentEnd]), &payload); err == nil {
+			if suggestions := suggestionsFromAutomatedPayload(payload); len(suggestions) > 0 {
+				return suggestions, nil
+			}
+		}
+
+		start = contentStart + contentEnd
+	}
+
+	return nil, nil
+}
+
+// enrichReviewThreadsWithSuggestions fills in the Suggestions of the comments
+// in threads, in place.
+//
+// For every thread that has comments it first appends the suggestions parsed
+// from each comment body, then fetches the thread's automated suggestions and
+// appends them to the first comment whose author contains "copilot"
+// (case-insensitive), or to the first comment when there is none. Fetch
+// failures are ignored, so enrichment is best-effort.
+func enrichReviewThreadsWithSuggestions(
+	ctx context.Context,
+	client *gogithub.Client,
+	owner, repo string,
+	pullNumber int,
+	threads []MinimalReviewThread,
+) {
+	for i := range threads {
+		thread := &threads[i]
+		if len(thread.Comments) == 0 {
+			continue
+		}
+
+		for j := range thread.Comments {
+			if suggestions := parseSuggestionsFromBody(thread.Comments[j].Body); len(suggestions) > 0 {
+				thread.Comments[j].Suggestions = append(thread.Comments[j].Suggestions, suggestions...)
+			}
+		}
+
+		// Once the context is done every fetch would fail, so skip the network
+		// call but keep parsing the bodies of the remaining threads.
+		if ctx.Err() != nil {
+			continue
+		}
+
+		automatedSuggestions, err := fetchAutomatedSuggestionsForThread(ctx, client, owner, repo, pullNumber, thread.ID)
+		if err != nil || len(automatedSuggestions) == 0 {
+			continue
+		}
+
+		targetIdx := 0
+		for j := range thread.Comments {
+			if strings.Contains(strings.ToLower(thread.Comments[j].Author), "copilot") {
+				targetIdx = j
+				break
+			}
+		}
+
+		thread.Comments[targetIdx].Suggestions = append(thread.Comments[targetIdx].Suggestions, automatedSuggestions...)
+	}
+}
diff --git a/pkg/github/review_suggestions_test.go b/pkg/github/review_suggestions_test.go
new file mode 100644
index 0000000000..cd3b7dd93a
--- /dev/null
+++ b/pkg/github/review_suggestions_test.go
@@ -0,0 +1,163 @@
+package github
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	gogithub "github.com/google/go-github/v87/github"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestDecodeNodeDatabaseID(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		nodeID  string
+		want    int64
+		wantErr bool
+	}{
+		{
+			name:   "pull request review thread",
+			nodeID: "PRRT_kwDORGz4i851Fgp1",
+			want:   1964378741,
+		},
+		{
+			name:   "pull request review thread with url-safe padding char",
+			nodeID: "PRRT_kwDORGz4i851Fgo-",
+			want:   1964378686,
+		},
+		{
+			name:   "pull request review comment",
+			nodeID: "PRRC_kwDORGz4i86v72Xc",
+			want:   2951701980,
+		},
+		{
+			// Same array layout as above, but the last element is a MessagePack
+			// uint64 (0xcf) holding 0x0000000100000001.
+			name:   "database id wider than 32 bits",
+			nodeID: "PRRT_kwDORGz4i88AAAABAAAAAQ",
+			want:   4294967297,
+		},
+		{
+			name:    "invalid node id",
+			nodeID:  "invalid",
+			wantErr: true,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			got, err := decodeNodeDatabaseID(tc.nodeID)
+			if tc.wantErr {
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, tc.want, got)
+		})
+	}
+}
+
+func TestParseSuggestionsFromBody(t *testing.T) {
+	t.Parallel()
+
+	body := "Please update this.\n\n```suggestion\nimport pytest\n\npytest.importorskip(\"torch\")\n```\n"
+	suggestions := parseSuggestionsFromBody(body)
+	require.Len(t, suggestions, 1)
+	assert.Equal(t, suggestionSourceBody, suggestions[0].Source)
+	assert.Equal(t, "import pytest\n\npytest.importorskip(\"torch\")", suggestions[0].Suggestion)
+}
+
+func TestBuildSuggestionFromDiffLinesKeepsLeadingBlankLine(t *testing.T) {
+	t.Parallel()
+
+	first, second := 7, 8
+	text, start, end := buildSuggestionFromDiffLines([]automatedDiffLine{
+		{Type: "HUNK", Text: "@@ -7,1 +7,2 @@"},
+		{Type: "ADDITION", Text: "", Right: &first},
+		{Type: "ADDITION", Text: "return nil", Right: &second},
+	})
+
+	assert.Equal(t, "\nreturn nil", text)
+	require.NotNil(t, start)
+	require.NotNil(t, end)
+	assert.Equal(t, 7, *start)
+	assert.Equal(t, 8, *end)
+}
+
+func TestParseAutomatedSuggestionsFromHTML(t *testing.T) {
+	t.Parallel()
+
+	html := `
` + automatedSuggestionHTMLFixture + ``
+	suggestions, err := parseAutomatedSuggestionsFromHTML(html)
+	require.NoError(t, err)
+	require.Len(t, suggestions, 1)
+	assert.Equal(t, suggestionSourceAutomated, suggestions[0].Source)
+	assert.Equal(t, "glmocr/cli.py", suggestions[0].Path)
+	assert.Contains(t, suggestions[0].Suggestion, "import re")
+	require.NotNil(t, suggestions[0].StartLine)
+	assert.Equal(t, 10, *suggestions[0].StartLine)
+}
+
+func TestParseAutomatedSuggestionsFromHTMLWithDeletions(t *testing.T) {
+	t.Parallel()
+
+	html := `` + automatedSuggestionWithDeletionsFixture + ``
+	suggestions, err := parseAutomatedSuggestionsFromHTML(html)
+	require.NoError(t, err)
+	require.Len(t, suggestions, 1)
+
+	s := suggestions[0]
+	assert.Equal(t, suggestionSourceAutomated, s.Source)
+	assert.Equal(t, "glmocr/tests/test_layout_device.py", s.Path)
+	assert.NotContains(t, s.Suggestion, "from glmocr.layout.layout_detector import PPDocLayoutDetector")
+	assert.Contains(t, s.Suggestion, "from glmocr import layout as layout_mod")
+	assert.Contains(t, s.Suggestion, "pytest.skip")
+	require.NotNil(t, s.StartLine)
+	assert.Equal(t, 132, *s.StartLine)
+}
+
+func TestFetchAutomatedSuggestionsForThread(t *testing.T) {
+	t.Parallel()
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "/owner/repo/pull/42/threads/1964378741", r.URL.Path)
+		assert.Equal(t, "rendering_on_files_tab=true", r.URL.RawQuery)
+		_, _ = w.Write([]byte(automatedSuggestionHTMLFixture))
+	}))
+	defer server.Close()
+
+	client, err := gogithub.NewClient(gogithub.WithHTTPClient(server.Client()), gogithub.WithEnterpriseURLs(server.URL+"/", server.URL+"/"))
+	require.NoError(t, err)
+
+	suggestions, err := fetchAutomatedSuggestionsForThread(
+		context.Background(),
+		client,
+		"owner",
+		"repo",
+		42,
+		"PRRT_kwDORGz4i851Fgp1",
+	)
+	require.NoError(t, err)
+	require.Len(t, suggestions, 1)
+	assert.Equal(t, "glmocr/cli.py", suggestions[0].Path)
+}
+
+// TestEnrichReviewThreadsWithSuggestions checks that a single comment ends up
+// with two suggestions in a fixed order: the one parsed from its body first,
+// then the automated one obtained over HTTP.
+func TestEnrichReviewThreadsWithSuggestions(t *testing.T) {
+	t.Parallel()
+
+	// The server answers every request with the same fixture, whatever the path.
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		_, _ = w.Write([]byte(automatedSuggestionHTMLFixture))
+	}))
+	defer server.Close()
+
+	client, err := gogithub.NewClient(gogithub.WithHTTPClient(server.Client()), gogithub.WithEnterpriseURLs(server.URL+"/", server.URL+"/"))
+	require.NoError(t, err)
+
+	// One thread with one comment whose body carries a suggestion block.
+	threads := []MinimalReviewThread{
+		{
+			ID: "PRRT_kwDORGz4i851Fgp1",
+			Comments: []MinimalReviewComment{
+				{
+					Body:   "Consider adding validation.\n```suggestion\nvalidated = True\n```",
+					Author: "copilot-pull-request-reviewer",
+					Path:   "glmocr/cli.py",
+				},
+			},
+		},
+	}
+
+	// The slice is enriched in place; there is no return value to inspect.
+	enrichReviewThreadsWithSuggestions(context.Background(), client, "owner", "repo", 42, threads)
+
+	require.Len(t, threads[0].Comments[0].Suggestions, 2)
+	assert.Equal(t, suggestionSourceBody, threads[0].Comments[0].Suggestions[0].Source)
+	assert.Equal(t, "validated = True", threads[0].Comments[0].Suggestions[0].Suggestion)
+	assert.Equal(t, suggestionSourceAutomated, threads[0].Comments[0].Suggestions[1].Source)
+	assert.Equal(t, "glmocr/cli.py", threads[0].Comments[0].Suggestions[1].Path)
+}
+
+// automatedSuggestionHTMLFixture is the response body written by the test
+// server in this test.
+// NOTE(review): this literal and the one below appear empty in this view of
+// the file — confirm the fixture contents against the repository.
+const automatedSuggestionHTMLFixture = ``
+
+// Fixture derived from a real Copilot review thread partial (zai-org/GLM-OCR#131).
+const automatedSuggestionWithDeletionsFixture = ``