|
1 | 1 | """Agent command for creating remote agent runs."""
|
2 | 2 |
|
3 | 3 | import json
|
| 4 | +from pathlib import Path |
4 | 5 |
|
5 | 6 | import requests
|
6 | 7 | import typer
|
|
13 | 14 | from codegen.cli.auth.token_manager import get_current_org_name, get_current_token
|
14 | 15 | from codegen.cli.rich.spinners import create_spinner
|
15 | 16 | from codegen.cli.utils.org import resolve_org_id
|
| 17 | +from codegen.git.repo_operator.local_git_repo import LocalGitRepo |
| 18 | +from codegen.git.repo_operator.repo_operator import RepoOperator |
| 19 | +from codegen.git.schemas.repo_config import RepoConfig |
16 | 20 |
|
17 | 21 | console = Console()
|
18 | 22 |
|
@@ -144,28 +148,33 @@ def agent_callback(ctx: typer.Context):
|
144 | 148 | raise typer.Exit()
|
145 | 149 |
|
146 | 150 |
|
147 |
| -# For backward compatibility, also allow `codegen agent --prompt "..."` and `codegen agent --id X --json` |
| 151 | +# For backward compatibility, also allow `codegen agent --prompt "..."`, `codegen agent --id X --json`, and `codegen agent --id X pull` |
148 | 152 | def agent(
|
| 153 | + action: str = typer.Argument(None, help="Action to perform: 'pull' to checkout PR branch"), |
149 | 154 | prompt: str | None = typer.Option(None, "--prompt", "-p", help="The prompt to send to the agent"),
|
150 |
| - agent_id: int | None = typer.Option(None, "--id", help="Agent run ID to fetch"), |
| 155 | + agent_id: int | None = typer.Option(None, "--id", help="Agent run ID to fetch or pull"), |
151 | 156 | as_json: bool = typer.Option(False, "--json", help="Output raw JSON response"),
|
152 | 157 | org_id: int | None = typer.Option(None, help="Organization ID (defaults to CODEGEN_ORG_ID/REPOSITORY_ORG_ID or auto-detect)"),
|
153 | 158 | model: str | None = typer.Option(None, help="Model to use for this agent run (optional)"),
|
154 | 159 | repo_id: int | None = typer.Option(None, help="Repository ID to use for this agent run (optional)"),
|
155 | 160 | ):
|
156 |
| - """Create a new agent run with the given prompt, or fetch an existing agent run by ID.""" |
| 161 | + """Create a new agent run with the given prompt, fetch an existing agent run by ID, or pull PR branch.""" |
157 | 162 | if prompt:
|
158 | 163 | # If prompt is provided, create the agent run
|
159 | 164 | create(prompt=prompt, org_id=org_id, model=model, repo_id=repo_id)
|
| 165 | + elif agent_id and action == "pull": |
| 166 | + # If agent ID and pull action provided, pull the PR branch |
| 167 | + pull(agent_id=agent_id, org_id=org_id) |
160 | 168 | elif agent_id:
|
161 | 169 | # If agent ID is provided, fetch the agent run
|
162 | 170 | get(agent_id=agent_id, as_json=as_json, org_id=org_id)
|
163 | 171 | else:
|
164 |
| - # If neither prompt nor agent_id, show help |
| 172 | + # If none of the above, show help |
165 | 173 | console.print("[red]Error:[/red] Either --prompt or --id is required")
|
166 | 174 | console.print("Usage:")
|
167 | 175 | console.print(" [cyan]codegen agent --prompt 'Your prompt here'[/cyan] # Create agent run")
|
168 | 176 | console.print(" [cyan]codegen agent --id 123 --json[/cyan] # Fetch agent run as JSON")
|
| 177 | + console.print(" [cyan]codegen agent --id 123 pull[/cyan] # Pull PR branch") |
169 | 178 | raise typer.Exit(1)
|
170 | 179 |
|
171 | 180 |
|
@@ -232,3 +241,231 @@ def get(
|
232 | 241 | except Exception as e:
|
233 | 242 | console.print(f"[red]Unexpected error:[/red] {e}")
|
234 | 243 | raise typer.Exit(1)
|
| 244 | + |
| 245 | + |
| 246 | +@agent_app.command() |
| 247 | +def pull( |
| 248 | + agent_id: int = typer.Option(..., "--id", help="Agent run ID to pull PR branch for"), |
| 249 | + org_id: int | None = typer.Option(None, help="Organization ID (defaults to CODEGEN_ORG_ID/REPOSITORY_ORG_ID or auto-detect)"), |
| 250 | +): |
| 251 | + """Fetch and checkout the PR branch for an agent run.""" |
| 252 | + token = get_current_token() |
| 253 | + if not token: |
| 254 | + console.print("[red]Error:[/red] Not authenticated. Please run 'codegen login' first.") |
| 255 | + raise typer.Exit(1) |
| 256 | + |
| 257 | + resolved_org_id = resolve_org_id(org_id) |
| 258 | + if resolved_org_id is None: |
| 259 | + console.print("[red]Error:[/red] Organization ID not provided. Pass --org-id, set CODEGEN_ORG_ID, or REPOSITORY_ORG_ID.") |
| 260 | + raise typer.Exit(1) |
| 261 | + |
| 262 | + # Check if we're in a git repository |
| 263 | + try: |
| 264 | + current_repo = LocalGitRepo(Path.cwd()) |
| 265 | + if not current_repo.has_remote(): |
| 266 | + console.print("[red]Error:[/red] Current directory is not a git repository with remotes.") |
| 267 | + raise typer.Exit(1) |
| 268 | + except Exception: |
| 269 | + console.print("[red]Error:[/red] Current directory is not a valid git repository.") |
| 270 | + raise typer.Exit(1) |
| 271 | + |
| 272 | + # Fetch agent run data |
| 273 | + spinner = create_spinner(f"Fetching agent run {agent_id}...") |
| 274 | + spinner.start() |
| 275 | + |
| 276 | + try: |
| 277 | + headers = {"Authorization": f"Bearer {token}"} |
| 278 | + url = f"{API_ENDPOINT.rstrip('/')}/v1/organizations/{resolved_org_id}/agent/run/{agent_id}" |
| 279 | + response = requests.get(url, headers=headers) |
| 280 | + response.raise_for_status() |
| 281 | + agent_data = response.json() |
| 282 | + except requests.HTTPError as e: |
| 283 | + org_name = get_current_org_name() |
| 284 | + org_display = f"{org_name} ({resolved_org_id})" if org_name else f"organization {resolved_org_id}" |
| 285 | + |
| 286 | + if e.response.status_code == 404: |
| 287 | + console.print(f"[red]Error:[/red] Agent run {agent_id} not found in {org_display}.") |
| 288 | + elif e.response.status_code == 403: |
| 289 | + console.print(f"[red]Error:[/red] Access denied to agent run {agent_id} in {org_display}. Check your permissions.") |
| 290 | + else: |
| 291 | + console.print(f"[red]Error:[/red] HTTP {e.response.status_code}: {e}") |
| 292 | + raise typer.Exit(1) |
| 293 | + except requests.RequestException as e: |
| 294 | + console.print(f"[red]Error fetching agent run:[/red] {e}") |
| 295 | + raise typer.Exit(1) |
| 296 | + finally: |
| 297 | + spinner.stop() |
| 298 | + |
| 299 | + # Check if agent run has PRs |
| 300 | + github_prs = agent_data.get("github_pull_requests", []) |
| 301 | + if not github_prs: |
| 302 | + console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has no associated pull requests.") |
| 303 | + raise typer.Exit(1) |
| 304 | + |
| 305 | + if len(github_prs) > 1: |
| 306 | + console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has multiple PRs. Using the first one.") |
| 307 | + |
| 308 | + pr = github_prs[0] |
| 309 | + pr_url = pr.get("url") |
| 310 | + head_branch_name = pr.get("head_branch_name") |
| 311 | + |
| 312 | + if not pr_url: |
| 313 | + console.print("[red]Error:[/red] PR URL not found in agent run data.") |
| 314 | + raise typer.Exit(1) |
| 315 | + |
| 316 | + if not head_branch_name: |
| 317 | + # Try to extract branch name from PR URL as fallback |
| 318 | + # GitHub PR URLs often follow patterns like: |
| 319 | + # https://github.com/owner/repo/pull/123 |
| 320 | + # We can use GitHub API to get the branch name |
| 321 | + console.print("[yellow]Info:[/yellow] HEAD branch name not in API response, attempting to fetch from GitHub...") |
| 322 | + try: |
| 323 | + # Extract owner, repo, and PR number from PR URL manually |
| 324 | + # Expected format: https://github.com/owner/repo/pull/123 |
| 325 | + if not pr_url.startswith("https://github.com/"): |
| 326 | + msg = f"Only GitHub URLs are supported, got: {pr_url}" |
| 327 | + raise ValueError(msg) |
| 328 | + |
| 329 | + # Remove the GitHub base and split the path |
| 330 | + path_parts = pr_url.replace("https://github.com/", "").split("/") |
| 331 | + if len(path_parts) < 4 or path_parts[2] != "pull": |
| 332 | + msg = f"Invalid GitHub PR URL format: {pr_url}" |
| 333 | + raise ValueError(msg) |
| 334 | + |
| 335 | + owner = path_parts[0] |
| 336 | + repo = path_parts[1] |
| 337 | + pr_number = path_parts[3] |
| 338 | + |
| 339 | + # Use GitHub API to get PR details |
| 340 | + import requests as github_requests |
| 341 | + |
| 342 | + github_api_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}" |
| 343 | + |
| 344 | + github_response = github_requests.get(github_api_url) |
| 345 | + if github_response.status_code == 200: |
| 346 | + pr_data = github_response.json() |
| 347 | + head_branch_name = pr_data.get("head", {}).get("ref") |
| 348 | + if head_branch_name: |
| 349 | + console.print(f"[green]✓ Found branch name from GitHub API:[/green] {head_branch_name}") |
| 350 | + else: |
| 351 | + console.print("[red]Error:[/red] Could not extract branch name from GitHub API response.") |
| 352 | + raise typer.Exit(1) |
| 353 | + else: |
| 354 | + console.print(f"[red]Error:[/red] Failed to fetch PR details from GitHub API (status: {github_response.status_code})") |
| 355 | + console.print("[yellow]Tip:[/yellow] The PR may be private or the GitHub API rate limit may be exceeded.") |
| 356 | + raise typer.Exit(1) |
| 357 | + except Exception as e: |
| 358 | + console.print(f"[red]Error:[/red] Could not fetch branch name from GitHub: {e}") |
| 359 | + console.print("[yellow]Tip:[/yellow] The backend may need to be updated to include branch information.") |
| 360 | + raise typer.Exit(1) |
| 361 | + |
| 362 | + # Parse PR URL to get repository information |
| 363 | + try: |
| 364 | + # Extract owner and repo from PR URL manually |
| 365 | + # Expected format: https://github.com/owner/repo/pull/123 |
| 366 | + if not pr_url.startswith("https://github.com/"): |
| 367 | + msg = f"Only GitHub URLs are supported, got: {pr_url}" |
| 368 | + raise ValueError(msg) |
| 369 | + |
| 370 | + # Remove the GitHub base and split the path |
| 371 | + path_parts = pr_url.replace("https://github.com/", "").split("/") |
| 372 | + if len(path_parts) < 4 or path_parts[2] != "pull": |
| 373 | + msg = f"Invalid GitHub PR URL format: {pr_url}" |
| 374 | + raise ValueError(msg) |
| 375 | + |
| 376 | + owner = path_parts[0] |
| 377 | + repo = path_parts[1] |
| 378 | + pr_repo_full_name = f"{owner}/{repo}" |
| 379 | + except Exception as e: |
| 380 | + console.print(f"[red]Error:[/red] Could not parse PR URL: {pr_url} - {e}") |
| 381 | + raise typer.Exit(1) |
| 382 | + |
| 383 | + # Check if current repository matches PR repository |
| 384 | + current_repo_full_name = current_repo.full_name |
| 385 | + if not current_repo_full_name: |
| 386 | + console.print("[red]Error:[/red] Could not determine current repository name.") |
| 387 | + raise typer.Exit(1) |
| 388 | + |
| 389 | + if current_repo_full_name.lower() != pr_repo_full_name.lower(): |
| 390 | + console.print("[red]Error:[/red] Repository mismatch!") |
| 391 | + console.print(f" Current repo: [cyan]{current_repo_full_name}[/cyan]") |
| 392 | + console.print(f" PR repo: [cyan]{pr_repo_full_name}[/cyan]") |
| 393 | + console.print("[yellow]Tip:[/yellow] Navigate to the correct repository directory first.") |
| 394 | + raise typer.Exit(1) |
| 395 | + |
| 396 | + # Perform git operations with safety checks |
| 397 | + try: |
| 398 | + repo_config = RepoConfig.from_repo_path(str(Path.cwd())) |
| 399 | + repo_operator = RepoOperator(repo_config) |
| 400 | + |
| 401 | + # Safety check: warn if repository has uncommitted changes |
| 402 | + if repo_operator.git_cli.is_dirty(): |
| 403 | + console.print("[yellow]⚠️ Warning:[/yellow] You have uncommitted changes in your repository.") |
| 404 | + console.print("These changes may be lost when switching branches.") |
| 405 | + |
| 406 | + # Get user confirmation |
| 407 | + confirm = typer.confirm("Do you want to continue? Your changes may be lost.") |
| 408 | + if not confirm: |
| 409 | + console.print("[yellow]Operation cancelled.[/yellow]") |
| 410 | + raise typer.Exit(0) |
| 411 | + |
| 412 | + console.print("[blue]Proceeding with branch checkout...[/blue]") |
| 413 | + |
| 414 | + console.print(f"[blue]Repository match confirmed:[/blue] {current_repo_full_name}") |
| 415 | + console.print(f"[blue]Fetching and checking out branch:[/blue] {head_branch_name}") |
| 416 | + |
| 417 | + # Fetch the branch from remote |
| 418 | + fetch_spinner = create_spinner("Fetching latest changes from remote...") |
| 419 | + fetch_spinner.start() |
| 420 | + try: |
| 421 | + fetch_result = repo_operator.fetch_remote("origin") |
| 422 | + if fetch_result.name != "SUCCESS": |
| 423 | + console.print(f"[yellow]Warning:[/yellow] Fetch result: {fetch_result.name}") |
| 424 | + except Exception as e: |
| 425 | + console.print(f"[red]Error during fetch:[/red] {e}") |
| 426 | + raise |
| 427 | + finally: |
| 428 | + fetch_spinner.stop() |
| 429 | + |
| 430 | + # Check if the branch already exists locally |
| 431 | + local_branches = [b.name for b in repo_operator.git_cli.branches] |
| 432 | + if head_branch_name in local_branches: |
| 433 | + console.print(f"[yellow]Info:[/yellow] Local branch '{head_branch_name}' already exists. It will be reset to match the remote.") |
| 434 | + |
| 435 | + # Checkout the remote branch |
| 436 | + checkout_spinner = create_spinner(f"Checking out branch {head_branch_name}...") |
| 437 | + checkout_spinner.start() |
| 438 | + try: |
| 439 | + checkout_result = repo_operator.checkout_remote_branch(head_branch_name) |
| 440 | + if checkout_result.name == "SUCCESS": |
| 441 | + console.print(f"[green]✓ Successfully checked out branch:[/green] {head_branch_name}") |
| 442 | + elif checkout_result.name == "NOT_FOUND": |
| 443 | + console.print(f"[red]Error:[/red] Branch {head_branch_name} not found on remote.") |
| 444 | + console.print("[yellow]Tip:[/yellow] The branch may have been deleted or renamed.") |
| 445 | + raise typer.Exit(1) |
| 446 | + else: |
| 447 | + console.print(f"[yellow]Warning:[/yellow] Checkout result: {checkout_result.name}") |
| 448 | + except Exception as e: |
| 449 | + console.print(f"[red]Error during checkout:[/red] {e}") |
| 450 | + raise |
| 451 | + finally: |
| 452 | + checkout_spinner.stop() |
| 453 | + |
| 454 | + # Display success info |
| 455 | + console.print( |
| 456 | + Panel( |
| 457 | + f"[green]✓ Successfully pulled PR branch![/green]\n\n" |
| 458 | + f"[cyan]Agent Run:[/cyan] {agent_id}\n" |
| 459 | + f"[cyan]Repository:[/cyan] {current_repo_full_name}\n" |
| 460 | + f"[cyan]Branch:[/cyan] {head_branch_name}\n" |
| 461 | + f"[cyan]PR URL:[/cyan] {pr_url}", |
| 462 | + title="🌿 [bold]Branch Checkout Complete[/bold]", |
| 463 | + border_style="green", |
| 464 | + box=box.ROUNDED, |
| 465 | + padding=(1, 2), |
| 466 | + ) |
| 467 | + ) |
| 468 | + |
| 469 | + except Exception as e: |
| 470 | + console.print(f"[red]Error during git operations:[/red] {e}") |
| 471 | + raise typer.Exit(1) |
0 commit comments