4 changes: 4 additions & 0 deletions src/agents/repository_analysis_agent/agent.py
@@ -38,6 +38,10 @@ async def execute(self, **kwargs) -> AgentResult:
await analyze_pr_history(state, request.max_prs)
await analyze_contributing_guidelines(state)

# Only generate recommendations if we have basic repository data
if not state.repository_features.language:
raise ValueError("Unable to determine repository language - cannot generate appropriate rules")

state.recommendations = _default_recommendations(state)
validate_recommendations(state)
response = summarize_analysis(state, request)
74 changes: 60 additions & 14 deletions src/agents/repository_analysis_agent/nodes.py
@@ -32,6 +32,9 @@ async def analyze_repository_structure(state: RepositoryAnalysisState) -> None:
installation_id = state.installation_id

repo_data = await github_client.get_repository(repo, installation_id=installation_id)
if not repo_data:
raise ValueError(f"Could not fetch repository data for {repo}")

workflows = await github_client.list_directory_any_auth(
repo_full_name=repo, path=".github/workflows", installation_id=installation_id
)
@@ -42,7 +45,7 @@ async def analyze_repository_structure(state: RepositoryAnalysisState) -> None:
has_codeowners=bool(await github_client.get_file_content(repo, ".github/CODEOWNERS", installation_id)),
has_workflows=bool(workflows),
workflow_count=len(workflows or []),
- language=(repo_data or {}).get("language"),
+ language=repo_data.get("language"),
contributor_count=len(contributors),
pr_count=0,
)
@@ -54,8 +57,14 @@ async def analyze_pr_history(state: RepositoryAnalysisState, max_prs: int) -> None:
installation_id = state.installation_id
prs = await github_client.list_pull_requests(repo, installation_id=installation_id, state="all", per_page=max_prs)

if prs is None:
# If PR listing fails, continue with empty samples rather than failing
state.pr_samples = []
state.repository_features.pr_count = 0
return

samples: list[PullRequestSample] = []
- for pr in prs or []:
+ for pr in prs:
samples.append(
PullRequestSample(
number=pr.get("number", 0),
@@ -215,19 +224,27 @@ def _default_recommendations(

Currently, validators like `author_team_is` and `file_patterns` operate independently.
"""
import logging
Review comment (medium):

According to PEP 8, imports belong at the top of the file. Please move `import logging` to the top level of the module to improve readability and follow standard Python conventions.

Member reply:

@naaa760 the import should be at the top.
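A minimal sketch of the conventional placement (module path illustrative; this mirrors the diff's own `logging.getLogger(__name__)` call, just moved to module level):

    # top of src/agents/repository_analysis_agent/nodes.py
    import logging

    logger = logging.getLogger(__name__)  # created once at import time, shared by every function in the module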

logger = logging.getLogger(__name__)

recommendations: list[RuleRecommendation] = []

# Get language-specific patterns based on repository analysis
- source_patterns, test_patterns = _get_language_specific_patterns(state.repository_features.language)
+ language = state.repository_features.language
+ source_patterns, test_patterns = _get_language_specific_patterns(language)

logger.info(f"Generating recommendations for {state.repository_full_name}: language={language}, pr_count={state.repository_features.pr_count}")

# Analyze PR history for bad habits
pr_issues = _analyze_pr_bad_habits(state)

# Require tests when source code changes.
# This is especially important if we detect missing tests in PR history
- test_reasoning = f"Default guardrail for code changes without tests. Patterns adapted for {state.repository_features.language or 'multi-language'} repository."
+ test_reasoning = f"Repository analysis for {state.repository_full_name}. Language: {language or 'unknown'}. Patterns adapted for {language or 'multi-language'} repository."
if pr_issues.get("missing_tests", 0) > 0:
test_reasoning += f" Detected {pr_issues['missing_tests']} recent PRs without test files."
if state.contributing_analysis.content and state.contributing_analysis.requires_tests:
test_reasoning += " Contributing guidelines explicitly require tests."

# Build YAML rule with proper indentation
# parameters: is at column 0, source_patterns: at column 2, list items at column 4
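The `_get_language_specific_patterns` helper is not shown in this diff; as a hypothetical sketch of its shape (illustrative globs, the real mapping may differ), it could look like:

    def _get_language_specific_patterns(language: str | None) -> tuple[list[str], list[str]]:
        # Map the repository's primary language to (source_patterns, test_patterns);
        # fall back to broad multi-language globs when the language is unknown.
        patterns = {
            "Python": (["src/**/*.py", "**/*.py"], ["tests/**/*.py", "**/test_*.py"]),
            "TypeScript": (["src/**/*.ts", "src/**/*.tsx"], ["**/*.test.ts", "**/*.spec.ts"]),
            "Go": (["**/*.go"], ["**/*_test.go"]),
        }
        return patterns.get(language or "", (["src/**"], ["tests/**", "test/**"]))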
@@ -246,20 +263,32 @@ def _default_recommendations(
{test_patterns_yaml}
"""

confidence = 0.74
if pr_issues.get("missing_tests", 0) > 0:
confidence = 0.85
if state.contributing_analysis.content and state.contributing_analysis.requires_tests:
confidence = min(0.95, confidence + 0.1)

recommendations.append(
RuleRecommendation(
yaml_rule=yaml_content.strip(),
- confidence=0.74 if pr_issues.get("missing_tests", 0) == 0 else 0.85,
+ confidence=confidence,
reasoning=test_reasoning,
strategy_used="hybrid",
)
)

# Require description in PR body.
# Increase confidence if we detect short titles in PR history (indicator of missing context)
- desc_reasoning = "Encourage context for reviewers; lightweight default."
+ desc_reasoning = f"Repository analysis for {state.repository_full_name}."
if pr_issues.get("short_titles", 0) > 0:
desc_reasoning += f" Detected {pr_issues['short_titles']} PRs with very short titles (likely missing context)."
else:
desc_reasoning += " Encourages context for reviewers; lightweight default."

desc_confidence = 0.68
if pr_issues.get("short_titles", 0) > 0:
desc_confidence = 0.80

recommendations.append(
RuleRecommendation(
@@ -274,20 +303,37 @@ def _default_recommendations(
min_description_length: 50
"""
).strip(),
- confidence=0.68 if pr_issues.get("short_titles", 0) == 0 else 0.80,
+ confidence=desc_confidence,
reasoning=desc_reasoning,
strategy_used="static",
)
)

- # If contributing guidelines require tests, increase confidence
- if state.contributing_analysis.content is not None and state.contributing_analysis.requires_tests:
- # Find the test rule and boost its confidence
- for rec in recommendations:
- if "tests" in rec.yaml_rule.lower():
- rec.confidence = min(0.95, rec.confidence + 0.1)
- rec.reasoning += " Contributing guidelines explicitly require tests."
# Add a repository-specific rule if we detect specific patterns
if state.repository_features.has_workflows:
workflow_rule = textwrap.dedent(
f"""
description: "Protect CI/CD workflows"
enabled: true
severity: high
event_types:
- pull_request
parameters:
file_patterns:
- ".github/workflows/**"
"""
).strip()

recommendations.append(
RuleRecommendation(
yaml_rule=workflow_rule,
confidence=0.90,
reasoning=f"Repository {state.repository_full_name} has {state.repository_features.workflow_count} workflows that should be protected.",
strategy_used="static",
)
)

logger.info(f"Generated {len(recommendations)} recommendations for {state.repository_full_name}")
return recommendations


99 changes: 92 additions & 7 deletions src/api/recommendations.py
@@ -48,7 +49,9 @@ async def recommend_rules(
if not request.repository_full_name or "/" not in request.repository_full_name:
raise HTTPException(status_code=400, detail="Invalid repository name format. Expected 'owner/repo'")

- cache_key = f"repo_analysis:{request.repository_full_name}"
+ # Include authentication context in cache key to ensure different access levels get different results
+ auth_context = request.installation_id or request.user_token or "anonymous"
+ cache_key = f"repo_analysis:{request.repository_full_name}:{auth_context}"
cached_result = await get_cache(cache_key)

if cached_result:
@@ -57,6 +59,7 @@ async def recommend_rules(
"cache_hit",
operation="repository_analysis",
subject_ids=[request.repository_full_name],
auth_context=auth_context,
cached=True,
)
return RepositoryAnalysisResponse(**cached_result)
@@ -85,6 +88,8 @@ async def recommend_rules(
decision="failed",
error=result.message,
)
# Clear any cached results for this repository to ensure fresh analysis on retry
await set_cache(cache_key, None, ttl=0)
Review comment (critical):

There is a critical issue with this cache-clearing attempt. The set_cache function checks `if ttl:`, which evaluates to False for ttl=0, so this call caches a None value for the default duration (1 hour) instead of clearing the entry, preventing fresh analysis on retries. The proper fix is for set_cache to check `if ttl is not None:`. As a workaround within this file, use a very small positive TTL so the entry expires almost immediately.

Suggested change:

    - await set_cache(cache_key, None, ttl=0)
    + await set_cache(cache_key, None, ttl=1)

Member reply:

@naaa760 any thoughts?
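To make the failure mode concrete, here is a minimal, runnable sketch of the buggy guard and the `if ttl is not None:` fix, using an in-memory store as a stand-in for the real cache backend (the store, the eviction logic, and the 1-hour default are assumptions for illustration):

    import time
    from typing import Any

    _store: dict[str, tuple[Any, float]] = {}
    DEFAULT_TTL = 3600  # assumed 1-hour default, per the comment above

    async def set_cache(key: str, value: Any, ttl: int | None = None) -> None:
        # Buggy guard: `if ttl:` is False for ttl=0, so ttl=0 silently falls
        # back to DEFAULT_TTL and caches the value for a full hour.
        # Fixed guard: `ttl is not None` distinguishes "not given" from 0.
        if ttl is not None and ttl <= 0:
            _store.pop(key, None)  # treat a non-positive TTL as "clear this entry"
            return
        expires_at = time.monotonic() + (ttl if ttl is not None else DEFAULT_TTL)
        _store[key] = (value, expires_at)

    async def get_cache(key: str) -> Any | None:
        entry = _store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.monotonic() >= expires_at:
            _store.pop(key, None)  # lazily evict expired entries
            return None
        return value

With that guard in place, `await set_cache(cache_key, None, ttl=0)` genuinely clears the entry and the ttl=1 workaround becomes unnecessary.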

raise HTTPException(status_code=500, detail=result.message)

analysis_response = result.data.get("analysis_response")
@@ -161,6 +166,18 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
branch=request.branch_name,
existing_sha=existing_branch_sha,
)
# Verify the branch points to the correct base
if existing_branch_sha != base_sha:
log_structured(
logger,
"branch_sha_mismatch",
operation="proceed_with_pr",
subject_ids=[repo],
branch=request.branch_name,
existing_sha=existing_branch_sha,
expected_sha=base_sha,
warning="Branch exists but points to different SHA than base branch",
)
else:
# Create new branch
created_ref = await github_client.create_git_ref(repo, request.branch_name, base_sha, **auth_ctx)
@@ -182,6 +199,15 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
"The branch may already exist or you may not have permission to create branches."
),
)
log_structured(
logger,
"branch_created",
operation="proceed_with_pr",
subject_ids=[repo],
branch=request.branch_name,
base_branch=base_branch,
new_sha=created_ref.get("object", {}).get("sha"),
)

file_result = await github_client.create_or_update_file(
repo_full_name=repo,
@@ -209,6 +235,17 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
),
)

commit_sha = (file_result.get("commit") or {}).get("sha")
log_structured(
logger,
"file_created",
operation="proceed_with_pr",
subject_ids=[repo],
branch=request.branch_name,
file_path=request.file_path,
commit_sha=commit_sha,
)

pr = await github_client.create_pull_request(
repo_full_name=repo,
title=request.pr_title,
Expand Down Expand Up @@ -237,16 +274,63 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWith
)

pr_url = pr.get("html_url", "")
- if not pr_url:
+ pr_number = pr.get("number")
+ if not pr_url or not pr_number:
log_structured(
logger,
"pr_url_missing",
"pr_creation_incomplete",
operation="proceed_with_pr",
subject_ids=[repo],
- pr_data=pr,
- error="PR created but html_url is missing",
+ pr_url=pr_url,
+ pr_number=pr_number,
+ error="PR creation response missing required fields",
)
raise HTTPException(status_code=500, detail="PR was created but response is incomplete")

# Validate the PR URL is a proper GitHub URL format
if not pr_url.startswith("https://github.com/") or "/pull/" not in pr_url:
log_structured(
logger,
"pr_url_invalid",
operation="proceed_with_pr",
subject_ids=[repo],
pr_url=pr_url,
pr_number=pr_number,
error="PR URL is not a valid GitHub pull request URL",
)
raise HTTPException(status_code=500, detail="PR was created but returned invalid URL format")

# Validate PR number is reasonable
if not isinstance(pr_number, int) or pr_number <= 0:
log_structured(
logger,
"pr_number_invalid",
operation="proceed_with_pr",
subject_ids=[repo],
pr_url=pr_url,
pr_number=pr_number,
error="PR number is invalid",
)
raise HTTPException(status_code=500, detail="PR was created but returned invalid PR number")

# Final validation before returning success
final_pr_url = pr.get("html_url", "")
final_pr_number = pr.get("number")
Review comment (medium):

The variables final_pr_url and final_pr_number are redundant: pr_url and pr_number were already assigned from the same source on lines 276-277 and have already been validated. Remove these lines and use pr_url and pr_number in the subsequent logic (lines 322-333, 342-343, and 347) to avoid duplication and improve clarity.

Member reply:

@naaa760 here too...


# Double-check URL format one more time
expected_url_pattern = f"https://github.com/{repo}/pull/{final_pr_number}"
if final_pr_url != expected_url_pattern:
log_structured(
logger,
"pr_url_mismatch",
operation="proceed_with_pr",
subject_ids=[repo],
expected_url=expected_url_pattern,
actual_url=final_pr_url,
pr_number=final_pr_number,
warning="PR URL doesn't match expected pattern",
)
Review comment on lines +319 to 332 (high):

This is a great series of validation steps to ensure the PR URL is correct. However, in this final check, if pr_url does not match expected_url_pattern, the code only logs a warning.

A mismatch here, even when the URL looks valid, is a strong indicator of a problem (e.g., an unexpected API response or a repository rename). It could lead to returning a URL that points to the wrong place or requires a redirect.

To make this more robust and avoid confusing behavior for the user, I recommend raising an HTTPException in this case, as the other validation failures do, so that only a URL confirmed to match expectations is returned:

    if pr_url != expected_url_pattern:
        log_structured(
            logger,
            "pr_url_mismatch",
            operation="proceed_with_pr",
            subject_ids=[repo],
            expected_url=expected_url_pattern,
            actual_url=pr_url,
            pr_number=pr_number,
            error="PR URL doesn't match expected pattern",
        )
        raise HTTPException(
            status_code=500, detail=f"PR URL mismatch: expected {expected_url_pattern} but got {pr_url}"
        )

Member reply:

@naaa760 please check this too ^

- raise HTTPException(status_code=500, detail="PR was created but URL is missing")

log_structured(
logger,
@@ -255,11 +339,12 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
subject_ids=[repo],
decision="success",
branch=request.branch_name,
- pr_number=pr.get("number"),
+ pr_number=final_pr_number,
+ pr_url=final_pr_url,
)

return ProceedWithPullRequestResponse(
- pull_request_url=pr.get("html_url", ""),
+ pull_request_url=final_pr_url,
branch_name=request.branch_name,
base_branch=base_branch,
file_path=request.file_path,