diff --git a/pkg/github/check_runs.go b/pkg/github/check_runs.go new file mode 100644 index 0000000000..c516350efc --- /dev/null +++ b/pkg/github/check_runs.go @@ -0,0 +1,211 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/google/go-github/v87/github" + "github.com/modelcontextprotocol/go-sdk/mcp" + + ghErrors "github.com/github/github-mcp-server/pkg/errors" + "github.com/github/github-mcp-server/pkg/utils" +) + +const ( + checkRunsSourceChecksAPI = "checks_api" + checkRunsSourceWorkflowRuns = "workflow_runs" + checkRunsSourceCommitStatuses = "commit_statuses" +) + +func isAccessDenied(resp *github.Response) bool { + return resp != nil && resp.StatusCode == http.StatusForbidden +} + +func checkRunsAccessErrMsg(base, owner, repo string) string { + return fmt.Sprintf("%s. Check runs require the Checks API (checks:read for GitHub Apps, repo scope for classic PATs). "+ + "When using hosted MCP, the GitHub App installation must include Checks: Read permission. "+ + "Fallbacks using workflow runs and commit statuses were also unavailable for %s/%s.", + base, owner, repo) +} + +func GetPullRequestCheckRuns(ctx context.Context, client *github.Client, owner, repo string, pullNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { + headSHA, errResult, err := getPullRequestHeadSHA(ctx, client, owner, repo, pullNumber) + if errResult != nil || err != nil { + return errResult, err + } + + result, resp, err := fetchCheckRunsFromChecksAPI(ctx, client, owner, repo, headSHA, pagination) + if err == nil { + return marshalCheckRunsResult(result) + } + if !isAccessDenied(resp) { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get check runs", resp, err), nil + } + closeResponse(resp) + + // Checks API is unavailable (common on hosted MCP without checks:read). Try fallbacks. + workflowFallback, workflowResp, workflowErr := fetchCheckRunsFromWorkflowRuns(ctx, client, owner, repo, headSHA, pagination) + if workflowErr == nil { + return marshalCheckRunsResult(workflowFallback) + } + if !isAccessDenied(workflowResp) { + return ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get check runs", workflowResp, workflowErr), nil + } + closeResponse(workflowResp) + + statusFallback, statusResp, statusErr := fetchCheckRunsFromCommitStatuses(ctx, client, owner, repo, headSHA, pagination) + if statusErr == nil { + return marshalCheckRunsResult(statusFallback) + } + closeResponse(statusResp) + + return ghErrors.NewGitHubAPIErrorResponse(ctx, + checkRunsAccessErrMsg("failed to get check runs", owner, repo), + resp, + err, + ), nil +} + +func getPullRequestHeadSHA(ctx context.Context, client *github.Client, owner, repo string, pullNumber int) (string, *mcp.CallToolResult, error) { + pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) + if err != nil { + return "", ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request", resp, err), nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, readErr := readResponseBody(resp) + if readErr != nil { + return "", nil, readErr + } + return "", ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request", resp, body), nil + } + + return pr.GetHead().GetSHA(), nil, nil +} + +func fetchCheckRunsFromChecksAPI(ctx context.Context, client *github.Client, owner, repo, headSHA string, pagination PaginationParams) (MinimalCheckRunsResult, *github.Response, error) { + opts := &github.ListCheckRunsOptions{ + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } + + checkRuns, resp, err := client.Checks.ListCheckRunsForRef(ctx, owner, repo, headSHA, opts) + if err != nil { + return MinimalCheckRunsResult{}, resp, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, readErr := readResponseBody(resp) + if readErr != nil { + return MinimalCheckRunsResult{}, resp, readErr + } + return MinimalCheckRunsResult{}, resp, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) + } + + minimalCheckRuns := make([]MinimalCheckRun, 0, len(checkRuns.CheckRuns)) + for _, checkRun := range checkRuns.CheckRuns { + minimalCheckRuns = append(minimalCheckRuns, convertToMinimalCheckRun(checkRun)) + } + + return MinimalCheckRunsResult{ + TotalCount: checkRuns.GetTotal(), + CheckRuns: minimalCheckRuns, + Source: checkRunsSourceChecksAPI, + }, resp, nil +} + +func fetchCheckRunsFromWorkflowRuns(ctx context.Context, client *github.Client, owner, repo, headSHA string, pagination PaginationParams) (MinimalCheckRunsResult, *github.Response, error) { + opts := &github.ListWorkflowRunsOptions{ + HeadSHA: headSHA, + ListOptions: github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + }, + } + + runs, resp, err := client.Actions.ListRepositoryWorkflowRuns(ctx, owner, repo, opts) + if err != nil { + return MinimalCheckRunsResult{}, resp, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, readErr := readResponseBody(resp) + if readErr != nil { + return MinimalCheckRunsResult{}, resp, readErr + } + return MinimalCheckRunsResult{}, resp, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) + } + + minimalCheckRuns := make([]MinimalCheckRun, 0, len(runs.WorkflowRuns)) + for _, run := range runs.WorkflowRuns { + minimalCheckRuns = append(minimalCheckRuns, convertWorkflowRunToMinimalCheckRun(run)) + } + + return MinimalCheckRunsResult{ + TotalCount: runs.GetTotalCount(), + CheckRuns: minimalCheckRuns, + Source: checkRunsSourceWorkflowRuns, + }, resp, nil +} + +func fetchCheckRunsFromCommitStatuses(ctx context.Context, client *github.Client, owner, repo, headSHA string, pagination PaginationParams) (MinimalCheckRunsResult, *github.Response, error) { + opts := &github.ListOptions{ + PerPage: pagination.PerPage, + Page: pagination.Page, + } + + statuses, resp, err := client.Repositories.ListStatuses(ctx, owner, repo, headSHA, opts) + if err != nil { + return MinimalCheckRunsResult{}, resp, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, readErr := readResponseBody(resp) + if readErr != nil { + return MinimalCheckRunsResult{}, resp, readErr + } + return MinimalCheckRunsResult{}, resp, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) + } + + minimalCheckRuns := make([]MinimalCheckRun, 0, len(statuses)) + for _, status := range statuses { + minimalCheckRuns = append(minimalCheckRuns, convertCommitStatusToMinimalCheckRun(status)) + } + + return MinimalCheckRunsResult{ + TotalCount: len(minimalCheckRuns), + CheckRuns: minimalCheckRuns, + Source: checkRunsSourceCommitStatuses, + }, resp, nil +} + +func marshalCheckRunsResult(result MinimalCheckRunsResult) (*mcp.CallToolResult, error) { + r, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + return utils.NewToolResultText(string(r)), nil +} + +func closeResponse(resp *github.Response) { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } +} + +func readResponseBody(resp *github.Response) ([]byte, error) { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return body, nil +} diff --git a/pkg/github/check_runs_test.go b/pkg/github/check_runs_test.go new file mode 100644 index 0000000000..4c8abd75bd --- /dev/null +++ b/pkg/github/check_runs_test.go @@ -0,0 +1,69 @@ +package github + +import ( + "testing" + "time" + + "github.com/google/go-github/v87/github" + "github.com/stretchr/testify/assert" +) + +func Test_convertWorkflowRunToMinimalCheckRun(t *testing.T) { + startedAt := time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC) + updatedAt := time.Date(2026, 1, 2, 3, 10, 0, 0, time.UTC) + + run := &github.WorkflowRun{ + ID: github.Ptr(int64(42)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("failure"), + HTMLURL: github.Ptr("https://github.com/o/r/actions/runs/42"), + RunStartedAt: &github.Timestamp{Time: startedAt}, + UpdatedAt: &github.Timestamp{Time: updatedAt}, + } + + result := convertWorkflowRunToMinimalCheckRun(run) + + assert.Equal(t, int64(42), result.ID) + assert.Equal(t, "CI", result.Name) + assert.Equal(t, "completed", result.Status) + assert.Equal(t, "failure", result.Conclusion) + assert.Equal(t, "https://github.com/o/r/actions/runs/42", result.HTMLURL) + assert.Equal(t, "2026-01-02T03:04:05Z", result.StartedAt) + assert.Equal(t, "2026-01-02T03:10:00Z", result.CompletedAt) +} + +func Test_convertCommitStatusToMinimalCheckRun(t *testing.T) { + createdAt := time.Date(2026, 2, 1, 12, 0, 0, 0, time.UTC) + updatedAt := time.Date(2026, 2, 1, 12, 5, 0, 0, time.UTC) + + status := &github.RepoStatus{ + ID: github.Ptr(int64(9)), + Context: github.Ptr("ci/build"), + State: github.Ptr("success"), + TargetURL: github.Ptr("https://ci.example.com/build/9"), + CreatedAt: &github.Timestamp{Time: createdAt}, + UpdatedAt: &github.Timestamp{Time: updatedAt}, + } + + result := convertCommitStatusToMinimalCheckRun(status) + + assert.Equal(t, int64(9), result.ID) + assert.Equal(t, "ci/build", result.Name) + assert.Equal(t, "completed", result.Status) + assert.Equal(t, "success", result.Conclusion) + assert.Equal(t, "https://ci.example.com/build/9", result.DetailsURL) +} + +func Test_convertCommitStatusToMinimalCheckRun_pending(t *testing.T) { + status := &github.RepoStatus{ + ID: github.Ptr(int64(1)), + Context: github.Ptr("ci/build"), + State: github.Ptr("pending"), + } + + result := convertCommitStatusToMinimalCheckRun(status) + + assert.Equal(t, "in_progress", result.Status) + assert.Empty(t, result.Conclusion) +} diff --git a/pkg/github/minimal_types.go b/pkg/github/minimal_types.go index 5200be297f..bfd26b61e4 100644 --- a/pkg/github/minimal_types.go +++ b/pkg/github/minimal_types.go @@ -1665,6 +1665,8 @@ type MinimalCheckRun struct { type MinimalCheckRunsResult struct { TotalCount int `json:"total_count"` CheckRuns []MinimalCheckRun `json:"check_runs"` + // Source indicates which API provided the data: "checks_api", "workflow_runs", or "commit_statuses". + Source string `json:"source,omitempty"` } // convertToMinimalCheckRun converts a GitHub API CheckRun to MinimalCheckRun @@ -1688,6 +1690,56 @@ func convertToMinimalCheckRun(checkRun *github.CheckRun) MinimalCheckRun { return minimalCheckRun } +func convertWorkflowRunToMinimalCheckRun(run *github.WorkflowRun) MinimalCheckRun { + status := run.GetStatus() + conclusion := run.GetConclusion() + + minimalCheckRun := MinimalCheckRun{ + ID: run.GetID(), + Name: run.GetName(), + Status: status, + Conclusion: conclusion, + HTMLURL: run.GetHTMLURL(), + DetailsURL: run.GetHTMLURL(), + } + + if run.RunStartedAt != nil { + minimalCheckRun.StartedAt = run.RunStartedAt.Format("2006-01-02T15:04:05Z") + } + if run.UpdatedAt != nil && status == "completed" { + minimalCheckRun.CompletedAt = run.UpdatedAt.Format("2006-01-02T15:04:05Z") + } + + return minimalCheckRun +} + +func convertCommitStatusToMinimalCheckRun(status *github.RepoStatus) MinimalCheckRun { + state := status.GetState() + conclusion := state + checkStatus := "completed" + if state == "pending" { + checkStatus = "in_progress" + conclusion = "" + } + + minimalCheckRun := MinimalCheckRun{ + ID: status.GetID(), + Name: status.GetContext(), + Status: checkStatus, + Conclusion: conclusion, + DetailsURL: status.GetTargetURL(), + } + + if status.CreatedAt != nil { + minimalCheckRun.StartedAt = status.CreatedAt.Format("2006-01-02T15:04:05Z") + } + if status.UpdatedAt != nil { + minimalCheckRun.CompletedAt = status.UpdatedAt.Format("2006-01-02T15:04:05Z") + } + + return minimalCheckRun +} + func convertToMinimalReviewThreadsResponse(query reviewThreadsQuery) MinimalReviewThreadsResponse { threads := query.Repository.PullRequest.ReviewThreads diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index 05028850d7..62890b6584 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -278,71 +278,6 @@ func GetPullRequestStatus(ctx context.Context, client *github.Client, owner, rep return utils.NewToolResultText(string(r)), nil } -func GetPullRequestCheckRuns(ctx context.Context, client *github.Client, owner, repo string, pullNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { - // First get the PR to get the head SHA - pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get pull request", - resp, - err, - ), nil - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request", resp, body), nil - } - - // Get check runs for the head SHA - opts := &github.ListCheckRunsOptions{ - ListOptions: github.ListOptions{ - PerPage: pagination.PerPage, - Page: pagination.Page, - }, - } - - checkRuns, resp, err := client.Checks.ListCheckRunsForRef(ctx, owner, repo, *pr.Head.SHA, opts) - if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, - "failed to get check runs", - resp, - err, - ), nil - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get check runs", resp, body), nil - } - - // Convert to minimal check runs to reduce context usage - minimalCheckRuns := make([]MinimalCheckRun, 0, len(checkRuns.CheckRuns)) - for _, checkRun := range checkRuns.CheckRuns { - minimalCheckRuns = append(minimalCheckRuns, convertToMinimalCheckRun(checkRun)) - } - - minimalResult := MinimalCheckRunsResult{ - TotalCount: checkRuns.GetTotal(), - CheckRuns: minimalCheckRuns, - } - - r, err := json.Marshal(minimalResult) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } - - return utils.NewToolResultText(string(r)), nil -} - func GetPullRequestFiles(ctx context.Context, client *github.Client, owner, repo string, pullNumber int, pagination PaginationParams) (*mcp.CallToolResult, error) { opts := &github.ListOptions{ PerPage: pagination.PerPage, diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index aff71e4c1a..72876e3747 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -1468,12 +1468,35 @@ func Test_GetPullRequestCheckRuns(t *testing.T) { }, } + mockWorkflowRuns := &github.WorkflowRuns{ + TotalCount: github.Ptr(1), + WorkflowRuns: []*github.WorkflowRun{ + { + ID: github.Ptr(int64(99)), + Name: github.Ptr("CI"), + Status: github.Ptr("completed"), + Conclusion: github.Ptr("success"), + HTMLURL: github.Ptr("https://github.com/owner/repo/actions/runs/99"), + }, + }, + } + + mockCommitStatuses := []*github.RepoStatus{ + { + ID: github.Ptr(int64(7)), + Context: github.Ptr("ci/travis"), + State: github.Ptr("success"), + TargetURL: github.Ptr("https://travis-ci.org/owner/repo/builds/1"), + }, + } + tests := []struct { name string mockedClient *http.Client requestArgs map[string]any expectError bool expectedCheckRuns *github.ListCheckRunsResults + expectedSource string expectedErrMsg string }{ { @@ -1490,6 +1513,49 @@ func Test_GetPullRequestCheckRuns(t *testing.T) { }, expectError: false, expectedCheckRuns: mockCheckRuns, + expectedSource: checkRunsSourceChecksAPI, + }, + { + name: "falls back to workflow runs when checks API returns 403", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, mockPR), + GetReposCommitsCheckRunsByOwnerByRepoByRef: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Resource not accessible by personal access token"}`)) + }), + GetReposActionsRunsByOwnerByRepo: mockResponse(t, http.StatusOK, mockWorkflowRuns), + }), + requestArgs: map[string]any{ + "method": "get_check_runs", + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }, + expectError: false, + expectedSource: checkRunsSourceWorkflowRuns, + }, + { + name: "falls back to commit statuses when checks and workflow runs return 403", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, mockPR), + GetReposCommitsCheckRunsByOwnerByRepoByRef: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Resource not accessible by personal access token"}`)) + }), + GetReposActionsRunsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Resource not accessible by integration"}`)) + }), + GetReposCommitsStatusesByOwnerByRepoByRef: mockResponse(t, http.StatusOK, mockCommitStatuses), + }), + requestArgs: map[string]any{ + "method": "get_check_runs", + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }, + expectError: false, + expectedSource: checkRunsSourceCommitStatuses, }, { name: "PR fetch fails", @@ -1509,7 +1575,7 @@ func Test_GetPullRequestCheckRuns(t *testing.T) { expectedErrMsg: "failed to get pull request", }, { - name: "check runs fetch fails", + name: "check runs fetch fails with non-403 error", mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, mockPR), GetReposCommitsCheckRunsByOwnerByRepoByRef: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -1526,6 +1592,32 @@ func Test_GetPullRequestCheckRuns(t *testing.T) { expectError: true, expectedErrMsg: "failed to get check runs", }, + { + name: "returns permission guidance when all sources are denied", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: mockResponse(t, http.StatusOK, mockPR), + GetReposCommitsCheckRunsByOwnerByRepoByRef: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Resource not accessible by personal access token"}`)) + }), + GetReposActionsRunsByOwnerByRepo: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Resource not accessible by integration"}`)) + }), + GetReposCommitsStatusesByOwnerByRepoByRef: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "Resource not accessible by integration"}`)) + }), + }), + requestArgs: map[string]any{ + "method": "get_check_runs", + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + }, + expectError: true, + expectedErrMsg: "Checks API (checks:read for GitHub Apps", + }, } for _, tc := range tests { @@ -1565,12 +1657,17 @@ func Test_GetPullRequestCheckRuns(t *testing.T) { var returnedCheckRuns MinimalCheckRunsResult err = json.Unmarshal([]byte(textContent.Text), &returnedCheckRuns) require.NoError(t, err) - assert.Equal(t, *tc.expectedCheckRuns.Total, returnedCheckRuns.TotalCount) - assert.Len(t, returnedCheckRuns.CheckRuns, len(tc.expectedCheckRuns.CheckRuns)) - for i, checkRun := range returnedCheckRuns.CheckRuns { - assert.Equal(t, *tc.expectedCheckRuns.CheckRuns[i].Name, checkRun.Name) - assert.Equal(t, *tc.expectedCheckRuns.CheckRuns[i].Status, checkRun.Status) - assert.Equal(t, *tc.expectedCheckRuns.CheckRuns[i].Conclusion, checkRun.Conclusion) + if tc.expectedSource != "" { + assert.Equal(t, tc.expectedSource, returnedCheckRuns.Source) + } + if tc.expectedCheckRuns != nil { + assert.Equal(t, *tc.expectedCheckRuns.Total, returnedCheckRuns.TotalCount) + assert.Len(t, returnedCheckRuns.CheckRuns, len(tc.expectedCheckRuns.CheckRuns)) + for i, checkRun := range returnedCheckRuns.CheckRuns { + assert.Equal(t, *tc.expectedCheckRuns.CheckRuns[i].Name, checkRun.Name) + assert.Equal(t, *tc.expectedCheckRuns.CheckRuns[i].Status, checkRun.Status) + assert.Equal(t, *tc.expectedCheckRuns.CheckRuns[i].Conclusion, checkRun.Conclusion) + } } }) }