diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index d6dd3f96..2d53db33 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -1606,3 +1606,127 @@ func newGQLIntPtr(i *int32) *githubv4.Int { gi := githubv4.Int(*i) return &gi } + +// SetPRStatus creates a tool to set pull request status between draft and ready-for-review states. +// This uses the GraphQL API because the REST API does not support changing PR draft status. +func SetPRStatus(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("set_pr_status", + mcp.WithDescription(t("TOOL_SET_PR_STATUS_DESCRIPTION", "Set pull request status between draft and ready-for-review states. Use this to change a pull request from draft to ready-for-review or vice versa.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_SET_PR_STATUS_USER_TITLE", "Set pull request status"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("pullNumber", + mcp.Required(), + mcp.Description("Pull request number"), + ), + mcp.WithString("status", + mcp.Required(), + mcp.Description("Target status for the pull request"), + mcp.Enum("draft", "ready_for_review"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var params struct { + Owner string + Repo string + PullNumber int32 + Status string + } + if err := mapstructure.Decode(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Validate status parameter + if params.Status != "draft" && params.Status != "ready_for_review" { + return mcp.NewToolResultError("status must be either 'draft' or 'ready_for_review'"), nil + } + + // Get the GraphQL client + client, err := getGQLClient(ctx) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GraphQL client: %v", err)), nil + } + + // First, we need to get the GraphQL ID of the pull request and its current status + var getPullRequestQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + vars := map[string]any{ + "owner": githubv4.String(params.Owner), + "repo": githubv4.String(params.Repo), + "prNum": githubv4.Int(params.PullNumber), + } + + if err := client.Query(ctx, &getPullRequestQuery, vars); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request: %v", err)), nil + } + + currentIsDraft := bool(getPullRequestQuery.Repository.PullRequest.IsDraft) + targetIsDraft := params.Status == "draft" + + // Check if the PR is already in the target state + if currentIsDraft == targetIsDraft { + if targetIsDraft { + return mcp.NewToolResultText("Pull request is already in draft state"), nil + } else { + return mcp.NewToolResultText("Pull request is already marked as ready for review"), nil + } + } + + // Perform the appropriate mutation based on target status + if targetIsDraft { + // Convert to draft + var convertToDraftMutation struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID // Required by GraphQL schema, but not used in response + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + } + + input := githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: getPullRequestQuery.Repository.PullRequest.ID, + } + + if err := client.Mutate(ctx, &convertToDraftMutation, input, nil); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to convert pull request to draft: %v", err)), nil + } + + return mcp.NewToolResultText("Pull request successfully converted to draft"), nil + } else { + // Mark as ready for review + var markReadyForReviewMutation struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID // Required by GraphQL schema, but not used in response + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + } + + input := githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: getPullRequestQuery.Repository.PullRequest.ID, + } + + if err := client.Mutate(ctx, &markReadyForReviewMutation, input, nil); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to mark pull request as ready for review: %v", err)), nil + } + + return mcp.NewToolResultText("Pull request successfully marked as ready for review"), nil + } + } +} diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index 6202ec16..5576eff7 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -2291,3 +2291,389 @@ func getLatestPendingReviewQuery(p getLatestPendingReviewQueryParams) githubv4mo ), ) } + +func TestSetPRStatus(t *testing.T) { + t.Parallel() + + // Verify tool definition once + mockClient := githubv4.NewClient(nil) + tool, _ := SetPRStatus(stubGetGQLClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "set_pr_status", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "pullNumber") + assert.Contains(t, tool.InputSchema.Properties, "status") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "status"}) + + // Verify enum validation for status parameter + statusProperty := tool.InputSchema.Properties["status"].(map[string]any) + assert.Contains(t, statusProperty, "enum") + enumValues := statusProperty["enum"].([]string) + assert.ElementsMatch(t, enumValues, []string{"draft", "ready_for_review"}) + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectToolError bool + expectedToolErrMsg string + expectedMessage string + }{ + { + name: "successful draft to ready conversion", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse( + map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDODKw3uc6WYN1T", + "isDraft": true, + }, + }, + }, + ), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"), + }, + nil, + githubv4mock.DataResponse(map[string]any{}), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "ready_for_review", + }, + expectToolError: false, + expectedMessage: "Pull request successfully marked as ready for review", + }, + { + name: "successful ready to draft conversion", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse( + map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDODKw3uc6WYN1T", + "isDraft": false, + }, + }, + }, + ), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"), + }, + nil, + githubv4mock.DataResponse(map[string]any{}), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "draft", + }, + expectToolError: false, + expectedMessage: "Pull request successfully converted to draft", + }, + { + name: "no change - already draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse( + map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDODKw3uc6WYN1T", + "isDraft": true, + }, + }, + }, + ), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "draft", + }, + expectToolError: false, + expectedMessage: "Pull request is already in draft state", + }, + { + name: "no change - already ready", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse( + map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDODKw3uc6WYN1T", + "isDraft": false, + }, + }, + }, + ), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "ready_for_review", + }, + expectToolError: false, + expectedMessage: "Pull request is already marked as ready for review", + }, + { + name: "invalid status enum", + mockedClient: githubv4mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "invalid_status", + }, + expectToolError: true, + expectedToolErrMsg: "status must be either 'draft' or 'ready_for_review'", + }, + { + name: "GraphQL query failure", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(999), + }, + githubv4mock.ErrorResponse("pull request not found"), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(999), + "status": "ready_for_review", + }, + expectToolError: true, + expectedToolErrMsg: "failed to get pull request: pull request not found", + }, + { + name: "GraphQL mutation failure - mark ready", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse( + map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDODKw3uc6WYN1T", + "isDraft": true, + }, + }, + }, + ), + ), + githubv4mock.NewMutationMatcher( + struct { + MarkPullRequestReadyForReview struct { + PullRequest struct { + ID githubv4.ID + } + } `graphql:"markPullRequestReadyForReview(input: $input)"` + }{}, + githubv4.MarkPullRequestReadyForReviewInput{ + PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"), + }, + nil, + githubv4mock.ErrorResponse("insufficient permissions"), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "ready_for_review", + }, + expectToolError: true, + expectedToolErrMsg: "failed to mark pull request as ready for review: insufficient permissions", + }, + { + name: "GraphQL mutation failure - convert to draft", + mockedClient: githubv4mock.NewMockedHTTPClient( + githubv4mock.NewQueryMatcher( + struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + IsDraft githubv4.Boolean + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + }{}, + map[string]any{ + "owner": githubv4.String("owner"), + "repo": githubv4.String("repo"), + "prNum": githubv4.Int(42), + }, + githubv4mock.DataResponse( + map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "id": "PR_kwDODKw3uc6WYN1T", + "isDraft": false, + }, + }, + }, + ), + ), + githubv4mock.NewMutationMatcher( + struct { + ConvertPullRequestToDraft struct { + PullRequest struct { + ID githubv4.ID + } + } `graphql:"convertPullRequestToDraft(input: $input)"` + }{}, + githubv4.ConvertPullRequestToDraftInput{ + PullRequestID: githubv4.ID("PR_kwDODKw3uc6WYN1T"), + }, + nil, + githubv4mock.ErrorResponse("insufficient permissions"), + ), + ), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumber": float64(42), + "status": "draft", + }, + expectToolError: true, + expectedToolErrMsg: "failed to convert pull request to draft: insufficient permissions", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Setup client with mock + client := githubv4.NewClient(tc.mockedClient) + _, handler := SetPRStatus(stubGetGQLClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + require.NoError(t, err) + + textContent := getTextResult(t, result) + + if tc.expectToolError { + require.True(t, result.IsError) + assert.Contains(t, textContent.Text, tc.expectedToolErrMsg) + return + } + + // Verify success message + require.False(t, result.IsError) + assert.Equal(t, tc.expectedMessage, textContent.Text) + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 9c1ab34a..298c8c97 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -72,6 +72,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), toolsets.NewServerTool(CreatePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequest(getClient, t)), + toolsets.NewServerTool(SetPRStatus(getGQLClient, t)), toolsets.NewServerTool(RequestCopilotReview(getClient, t)), // Reviews