Skip to content

Commit 81c9e7b

Browse files
authored
Adds codegen agent --id pull to pull branches locally (#1198)
1 parent bb6343b commit 81c9e7b

File tree

1 file changed

+241
-4
lines changed
  • src/codegen/cli/commands/agent

1 file changed

+241
-4
lines changed

src/codegen/cli/commands/agent/main.py

Lines changed: 241 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Agent command for creating remote agent runs."""
22

33
import json
4+
from pathlib import Path
45

56
import requests
67
import typer
@@ -13,6 +14,9 @@
1314
from codegen.cli.auth.token_manager import get_current_org_name, get_current_token
1415
from codegen.cli.rich.spinners import create_spinner
1516
from codegen.cli.utils.org import resolve_org_id
17+
from codegen.git.repo_operator.local_git_repo import LocalGitRepo
18+
from codegen.git.repo_operator.repo_operator import RepoOperator
19+
from codegen.git.schemas.repo_config import RepoConfig
1620

1721
console = Console()
1822

@@ -144,28 +148,33 @@ def agent_callback(ctx: typer.Context):
144148
raise typer.Exit()
145149

146150

147-
# For backward compatibility, also allow `codegen agent --prompt "..."` and `codegen agent --id X --json`
151+
# For backward compatibility, also allow `codegen agent --prompt "..."`, `codegen agent --id X --json`, and `codegen agent --id X pull`
148152
def agent(
153+
action: str = typer.Argument(None, help="Action to perform: 'pull' to checkout PR branch"),
149154
prompt: str | None = typer.Option(None, "--prompt", "-p", help="The prompt to send to the agent"),
150-
agent_id: int | None = typer.Option(None, "--id", help="Agent run ID to fetch"),
155+
agent_id: int | None = typer.Option(None, "--id", help="Agent run ID to fetch or pull"),
151156
as_json: bool = typer.Option(False, "--json", help="Output raw JSON response"),
152157
org_id: int | None = typer.Option(None, help="Organization ID (defaults to CODEGEN_ORG_ID/REPOSITORY_ORG_ID or auto-detect)"),
153158
model: str | None = typer.Option(None, help="Model to use for this agent run (optional)"),
154159
repo_id: int | None = typer.Option(None, help="Repository ID to use for this agent run (optional)"),
155160
):
156-
"""Create a new agent run with the given prompt, or fetch an existing agent run by ID."""
161+
"""Create a new agent run with the given prompt, fetch an existing agent run by ID, or pull PR branch."""
157162
if prompt:
158163
# If prompt is provided, create the agent run
159164
create(prompt=prompt, org_id=org_id, model=model, repo_id=repo_id)
165+
elif agent_id and action == "pull":
166+
# If agent ID and pull action provided, pull the PR branch
167+
pull(agent_id=agent_id, org_id=org_id)
160168
elif agent_id:
161169
# If agent ID is provided, fetch the agent run
162170
get(agent_id=agent_id, as_json=as_json, org_id=org_id)
163171
else:
164-
# If neither prompt nor agent_id, show help
172+
# If none of the above, show help
165173
console.print("[red]Error:[/red] Either --prompt or --id is required")
166174
console.print("Usage:")
167175
console.print(" [cyan]codegen agent --prompt 'Your prompt here'[/cyan] # Create agent run")
168176
console.print(" [cyan]codegen agent --id 123 --json[/cyan] # Fetch agent run as JSON")
177+
console.print(" [cyan]codegen agent --id 123 pull[/cyan] # Pull PR branch")
169178
raise typer.Exit(1)
170179

171180

@@ -232,3 +241,231 @@ def get(
232241
except Exception as e:
233242
console.print(f"[red]Unexpected error:[/red] {e}")
234243
raise typer.Exit(1)
244+
245+
246+
@agent_app.command()
247+
def pull(
248+
agent_id: int = typer.Option(..., "--id", help="Agent run ID to pull PR branch for"),
249+
org_id: int | None = typer.Option(None, help="Organization ID (defaults to CODEGEN_ORG_ID/REPOSITORY_ORG_ID or auto-detect)"),
250+
):
251+
"""Fetch and checkout the PR branch for an agent run."""
252+
token = get_current_token()
253+
if not token:
254+
console.print("[red]Error:[/red] Not authenticated. Please run 'codegen login' first.")
255+
raise typer.Exit(1)
256+
257+
resolved_org_id = resolve_org_id(org_id)
258+
if resolved_org_id is None:
259+
console.print("[red]Error:[/red] Organization ID not provided. Pass --org-id, set CODEGEN_ORG_ID, or REPOSITORY_ORG_ID.")
260+
raise typer.Exit(1)
261+
262+
# Check if we're in a git repository
263+
try:
264+
current_repo = LocalGitRepo(Path.cwd())
265+
if not current_repo.has_remote():
266+
console.print("[red]Error:[/red] Current directory is not a git repository with remotes.")
267+
raise typer.Exit(1)
268+
except Exception:
269+
console.print("[red]Error:[/red] Current directory is not a valid git repository.")
270+
raise typer.Exit(1)
271+
272+
# Fetch agent run data
273+
spinner = create_spinner(f"Fetching agent run {agent_id}...")
274+
spinner.start()
275+
276+
try:
277+
headers = {"Authorization": f"Bearer {token}"}
278+
url = f"{API_ENDPOINT.rstrip('/')}/v1/organizations/{resolved_org_id}/agent/run/{agent_id}"
279+
response = requests.get(url, headers=headers)
280+
response.raise_for_status()
281+
agent_data = response.json()
282+
except requests.HTTPError as e:
283+
org_name = get_current_org_name()
284+
org_display = f"{org_name} ({resolved_org_id})" if org_name else f"organization {resolved_org_id}"
285+
286+
if e.response.status_code == 404:
287+
console.print(f"[red]Error:[/red] Agent run {agent_id} not found in {org_display}.")
288+
elif e.response.status_code == 403:
289+
console.print(f"[red]Error:[/red] Access denied to agent run {agent_id} in {org_display}. Check your permissions.")
290+
else:
291+
console.print(f"[red]Error:[/red] HTTP {e.response.status_code}: {e}")
292+
raise typer.Exit(1)
293+
except requests.RequestException as e:
294+
console.print(f"[red]Error fetching agent run:[/red] {e}")
295+
raise typer.Exit(1)
296+
finally:
297+
spinner.stop()
298+
299+
# Check if agent run has PRs
300+
github_prs = agent_data.get("github_pull_requests", [])
301+
if not github_prs:
302+
console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has no associated pull requests.")
303+
raise typer.Exit(1)
304+
305+
if len(github_prs) > 1:
306+
console.print(f"[yellow]Warning:[/yellow] Agent run {agent_id} has multiple PRs. Using the first one.")
307+
308+
pr = github_prs[0]
309+
pr_url = pr.get("url")
310+
head_branch_name = pr.get("head_branch_name")
311+
312+
if not pr_url:
313+
console.print("[red]Error:[/red] PR URL not found in agent run data.")
314+
raise typer.Exit(1)
315+
316+
if not head_branch_name:
317+
# Try to extract branch name from PR URL as fallback
318+
# GitHub PR URLs often follow patterns like:
319+
# https://github.com/owner/repo/pull/123
320+
# We can use GitHub API to get the branch name
321+
console.print("[yellow]Info:[/yellow] HEAD branch name not in API response, attempting to fetch from GitHub...")
322+
try:
323+
# Extract owner, repo, and PR number from PR URL manually
324+
# Expected format: https://github.com/owner/repo/pull/123
325+
if not pr_url.startswith("https://github.com/"):
326+
msg = f"Only GitHub URLs are supported, got: {pr_url}"
327+
raise ValueError(msg)
328+
329+
# Remove the GitHub base and split the path
330+
path_parts = pr_url.replace("https://github.com/", "").split("/")
331+
if len(path_parts) < 4 or path_parts[2] != "pull":
332+
msg = f"Invalid GitHub PR URL format: {pr_url}"
333+
raise ValueError(msg)
334+
335+
owner = path_parts[0]
336+
repo = path_parts[1]
337+
pr_number = path_parts[3]
338+
339+
# Use GitHub API to get PR details
340+
import requests as github_requests
341+
342+
github_api_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}"
343+
344+
github_response = github_requests.get(github_api_url)
345+
if github_response.status_code == 200:
346+
pr_data = github_response.json()
347+
head_branch_name = pr_data.get("head", {}).get("ref")
348+
if head_branch_name:
349+
console.print(f"[green]✓ Found branch name from GitHub API:[/green] {head_branch_name}")
350+
else:
351+
console.print("[red]Error:[/red] Could not extract branch name from GitHub API response.")
352+
raise typer.Exit(1)
353+
else:
354+
console.print(f"[red]Error:[/red] Failed to fetch PR details from GitHub API (status: {github_response.status_code})")
355+
console.print("[yellow]Tip:[/yellow] The PR may be private or the GitHub API rate limit may be exceeded.")
356+
raise typer.Exit(1)
357+
except Exception as e:
358+
console.print(f"[red]Error:[/red] Could not fetch branch name from GitHub: {e}")
359+
console.print("[yellow]Tip:[/yellow] The backend may need to be updated to include branch information.")
360+
raise typer.Exit(1)
361+
362+
# Parse PR URL to get repository information
363+
try:
364+
# Extract owner and repo from PR URL manually
365+
# Expected format: https://github.com/owner/repo/pull/123
366+
if not pr_url.startswith("https://github.com/"):
367+
msg = f"Only GitHub URLs are supported, got: {pr_url}"
368+
raise ValueError(msg)
369+
370+
# Remove the GitHub base and split the path
371+
path_parts = pr_url.replace("https://github.com/", "").split("/")
372+
if len(path_parts) < 4 or path_parts[2] != "pull":
373+
msg = f"Invalid GitHub PR URL format: {pr_url}"
374+
raise ValueError(msg)
375+
376+
owner = path_parts[0]
377+
repo = path_parts[1]
378+
pr_repo_full_name = f"{owner}/{repo}"
379+
except Exception as e:
380+
console.print(f"[red]Error:[/red] Could not parse PR URL: {pr_url} - {e}")
381+
raise typer.Exit(1)
382+
383+
# Check if current repository matches PR repository
384+
current_repo_full_name = current_repo.full_name
385+
if not current_repo_full_name:
386+
console.print("[red]Error:[/red] Could not determine current repository name.")
387+
raise typer.Exit(1)
388+
389+
if current_repo_full_name.lower() != pr_repo_full_name.lower():
390+
console.print("[red]Error:[/red] Repository mismatch!")
391+
console.print(f" Current repo: [cyan]{current_repo_full_name}[/cyan]")
392+
console.print(f" PR repo: [cyan]{pr_repo_full_name}[/cyan]")
393+
console.print("[yellow]Tip:[/yellow] Navigate to the correct repository directory first.")
394+
raise typer.Exit(1)
395+
396+
# Perform git operations with safety checks
397+
try:
398+
repo_config = RepoConfig.from_repo_path(str(Path.cwd()))
399+
repo_operator = RepoOperator(repo_config)
400+
401+
# Safety check: warn if repository has uncommitted changes
402+
if repo_operator.git_cli.is_dirty():
403+
console.print("[yellow]⚠️ Warning:[/yellow] You have uncommitted changes in your repository.")
404+
console.print("These changes may be lost when switching branches.")
405+
406+
# Get user confirmation
407+
confirm = typer.confirm("Do you want to continue? Your changes may be lost.")
408+
if not confirm:
409+
console.print("[yellow]Operation cancelled.[/yellow]")
410+
raise typer.Exit(0)
411+
412+
console.print("[blue]Proceeding with branch checkout...[/blue]")
413+
414+
console.print(f"[blue]Repository match confirmed:[/blue] {current_repo_full_name}")
415+
console.print(f"[blue]Fetching and checking out branch:[/blue] {head_branch_name}")
416+
417+
# Fetch the branch from remote
418+
fetch_spinner = create_spinner("Fetching latest changes from remote...")
419+
fetch_spinner.start()
420+
try:
421+
fetch_result = repo_operator.fetch_remote("origin")
422+
if fetch_result.name != "SUCCESS":
423+
console.print(f"[yellow]Warning:[/yellow] Fetch result: {fetch_result.name}")
424+
except Exception as e:
425+
console.print(f"[red]Error during fetch:[/red] {e}")
426+
raise
427+
finally:
428+
fetch_spinner.stop()
429+
430+
# Check if the branch already exists locally
431+
local_branches = [b.name for b in repo_operator.git_cli.branches]
432+
if head_branch_name in local_branches:
433+
console.print(f"[yellow]Info:[/yellow] Local branch '{head_branch_name}' already exists. It will be reset to match the remote.")
434+
435+
# Checkout the remote branch
436+
checkout_spinner = create_spinner(f"Checking out branch {head_branch_name}...")
437+
checkout_spinner.start()
438+
try:
439+
checkout_result = repo_operator.checkout_remote_branch(head_branch_name)
440+
if checkout_result.name == "SUCCESS":
441+
console.print(f"[green]✓ Successfully checked out branch:[/green] {head_branch_name}")
442+
elif checkout_result.name == "NOT_FOUND":
443+
console.print(f"[red]Error:[/red] Branch {head_branch_name} not found on remote.")
444+
console.print("[yellow]Tip:[/yellow] The branch may have been deleted or renamed.")
445+
raise typer.Exit(1)
446+
else:
447+
console.print(f"[yellow]Warning:[/yellow] Checkout result: {checkout_result.name}")
448+
except Exception as e:
449+
console.print(f"[red]Error during checkout:[/red] {e}")
450+
raise
451+
finally:
452+
checkout_spinner.stop()
453+
454+
# Display success info
455+
console.print(
456+
Panel(
457+
f"[green]✓ Successfully pulled PR branch![/green]\n\n"
458+
f"[cyan]Agent Run:[/cyan] {agent_id}\n"
459+
f"[cyan]Repository:[/cyan] {current_repo_full_name}\n"
460+
f"[cyan]Branch:[/cyan] {head_branch_name}\n"
461+
f"[cyan]PR URL:[/cyan] {pr_url}",
462+
title="🌿 [bold]Branch Checkout Complete[/bold]",
463+
border_style="green",
464+
box=box.ROUNDED,
465+
padding=(1, 2),
466+
)
467+
)
468+
469+
except Exception as e:
470+
console.print(f"[red]Error during git operations:[/red] {e}")
471+
raise typer.Exit(1)

0 commit comments

Comments
 (0)