diff --git a/.gitignore b/.gitignore index ca10ad9..6cdb04c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ __pycache__/ .env .venv/ venv/ +.tools/ +.uv-cache/ .pytest_cache/ .mypy_cache/ diff --git a/agents/paperorchestra/figure_orchestra.py b/agents/paperorchestra/figure_orchestra.py index cb649af..9460b41 100644 --- a/agents/paperorchestra/figure_orchestra.py +++ b/agents/paperorchestra/figure_orchestra.py @@ -18,6 +18,19 @@ def _safe_filename(text: str) -> str: return cleaned[:80] or "figure" +def _is_motivation_or_overview_figure(fig: dict[str, Any]) -> bool: + text = " ".join( + str(fig.get(key) or "") + for key in ("figure_id", "title", "objective", "caption", "data_source") + ).lower() + return any(token in text for token in ("motivation", "overview", "teaser", "problem-method-result", "problem method result")) + + +def _banana_motivation_overview_enabled() -> bool: + raw = os.getenv("DEEPGRAPH_PAPERBANANA_MOTIVATION_OVERVIEW", "true").strip().lower() + return raw in {"1", "true", "yes", "on"} + + def _default_plot_plan(metric_name: str) -> list[dict[str, Any]]: return [ { @@ -281,6 +294,139 @@ def _render_framework_diagram(fig: dict[str, Any], state: dict, out_path: Path) plt.close(fig_obj) +def _draw_small_glyph(ax: Any, x: float, y: float, kind: int, color: str = "#64748b", alpha: float = 0.85, size: float = 1.0) -> None: + import matplotlib.patches as patches + + if kind % 5 == 0: + ax.add_patch(patches.Circle((x, y), 0.010 * size, facecolor="none", edgecolor=color, linewidth=0.9, alpha=alpha)) + elif kind % 5 == 1: + ax.add_patch(patches.RegularPolygon((x, y), 3, radius=0.014 * size, orientation=0.52, facecolor="none", edgecolor=color, linewidth=0.9, alpha=alpha)) + elif kind % 5 == 2: + ax.add_patch(patches.Rectangle((x - 0.010 * size, y - 0.010 * size), 0.020 * size, 0.020 * size, facecolor="none", edgecolor=color, linewidth=0.9, alpha=alpha)) + elif kind % 5 == 3: + ax.plot([x - 0.012 * size, x + 0.012 * size], [y - 0.012 * size, y + 0.012 * size], color=color, linewidth=0.9, alpha=alpha) + ax.plot([x - 0.012 * size, x + 0.012 * size], [y + 0.012 * size, y - 0.012 * size], color=color, linewidth=0.9, alpha=alpha) + else: + ax.add_patch(patches.RegularPolygon((x, y), 6, radius=0.012 * size, facecolor="none", edgecolor=color, linewidth=0.9, alpha=alpha)) + + +def _render_symbolic_motivation(fig: dict[str, Any], state: dict, out_path: Path) -> None: + plt = _setup_matplotlib() + import matplotlib.patches as patches + import numpy as np + + rng = np.random.default_rng(7) + fig_obj, ax = plt.subplots(figsize=(7.2, 4.05)) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.axis("off") + + # Sparse problem field: many easy cases, a few uncertain/high-value cases. + easy = rng.normal(loc=(0.19, 0.52), scale=(0.075, 0.17), size=(54, 2)) + hard = rng.normal(loc=(0.37, 0.52), scale=(0.035, 0.13), size=(10, 2)) + for idx, (x, y) in enumerate(easy): + if 0.05 < x < 0.35 and 0.12 < y < 0.88: + _draw_small_glyph(ax, float(x), float(y), idx, color="#94a3b8", alpha=0.65, size=0.85) + for idx, (x, y) in enumerate(hard): + if 0.28 < x < 0.48 and 0.14 < y < 0.86: + ax.add_patch(patches.Circle((float(x), float(y)), 0.020, facecolor="#d9f0ee", edgecolor="#0f766e", linewidth=0.9, alpha=0.95)) + _draw_small_glyph(ax, float(x), float(y), idx, color="#0f766e", alpha=1.0, size=0.82) + + # Faint wasted-compute band and missed-value void, expressed without labels. + for offset, alpha in [(0.00, 0.11), (0.022, 0.07), (-0.022, 0.07)]: + ax.add_patch( + patches.Arc( + (0.25, 0.50 + offset), + 0.46, + 0.58, + theta1=-38, + theta2=42, + linewidth=1.1, + color="#f59e0b", + alpha=alpha, + ) + ) + ax.add_patch(patches.Circle((0.33, 0.22), 0.055, facecolor="#f8fafc", edgecolor="#cbd5e1", linewidth=0.9, alpha=0.8)) + ax.plot([0.302, 0.358], [0.22, 0.22], color="#cbd5e1", linewidth=1.0) + + # Selective aperture as dominant focal anchor. + center = (0.56, 0.52) + ax.add_patch(patches.Circle(center, 0.175, facecolor="#ffffff", edgecolor="#0b3b63", linewidth=3.0)) + ax.add_patch(patches.Circle(center, 0.145, facecolor="#ecfeff", edgecolor="#5eead4", linewidth=1.3, alpha=0.65)) + ax.add_patch(patches.Wedge(center, 0.175, 38, 92, width=0.024, facecolor="#f59e0b", edgecolor="none", alpha=0.75)) + ax.add_patch(patches.Wedge(center, 0.175, 190, 250, width=0.024, facecolor="#0f766e", edgecolor="none", alpha=0.75)) + for idx, (x, y) in enumerate([(0.52, 0.58), (0.57, 0.46), (0.61, 0.57), (0.55, 0.52)]): + _draw_small_glyph(ax, x, y, idx, color="#0b3b63", alpha=0.95, size=1.0) + for angle in np.linspace(0.2, 2.8, 7): + ax.plot([0.42, center[0] - 0.13 * np.cos(angle)], [0.30 + 0.04 * np.sin(angle), center[1] - 0.13 * np.sin(angle)], color="#bae6fd", linewidth=0.8, alpha=0.65) + + # Clean resolved set: intentionally simple, no labels. + resolved_x = [0.78, 0.84, 0.90] + for idx, x in enumerate(resolved_x): + _draw_small_glyph(ax, x, 0.57 - idx * 0.035, idx + 2, color="#0b3b63", alpha=1.0, size=1.5) + ax.add_patch(patches.Circle((x + 0.026, 0.57 - idx * 0.035), 0.007, facecolor="#f59e0b", edgecolor="none")) + ax.plot([0.68, 0.74], [0.52, 0.55], color="#0b3b63", linewidth=1.2, alpha=0.65) + + fig_obj.tight_layout(pad=0.0) + _save_native_matplotlib_figure(fig_obj, out_path) + plt.close(fig_obj) + + +def _render_symbolic_overview(fig: dict[str, Any], state: dict, out_path: Path) -> None: + plt = _setup_matplotlib() + import matplotlib.patches as patches + import numpy as np + + rng = np.random.default_rng(11) + fig_obj, ax = plt.subplots(figsize=(7.2, 4.05)) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.axis("off") + + # Left evidence/problem manifold. + for idx, (x, y) in enumerate(rng.normal(loc=(0.22, 0.55), scale=(0.055, 0.16), size=(28, 2))): + if 0.10 < x < 0.35 and 0.16 < y < 0.86: + _draw_small_glyph(ax, float(x), float(y), idx, color="#64748b", alpha=0.55, size=0.78) + for y in [0.36, 0.50, 0.64]: + ax.plot([0.32, 0.42], [y, 0.52], color="#bae6fd", linewidth=0.9, alpha=0.65) + + # Central conservative gate / aperture. + center = (0.52, 0.52) + for r, c, lw, alpha in [(0.22, "#0b3b63", 2.6, 1.0), (0.18, "#94a3b8", 1.2, 0.9), (0.13, "#5eead4", 1.1, 0.75)]: + ax.add_patch(patches.Circle(center, r, facecolor="none", edgecolor=c, linewidth=lw, alpha=alpha)) + ax.add_patch(patches.Wedge(center, 0.22, 82, 118, width=0.040, facecolor="#ffffff", edgecolor="none")) + ax.add_patch(patches.Wedge(center, 0.18, 252, 292, width=0.030, facecolor="#ffffff", edgecolor="none")) + ax.add_patch(patches.RegularPolygon(center, 6, radius=0.060, orientation=0.52, facecolor="#d9f0ee", edgecolor="#0f766e", linewidth=1.2, alpha=0.95)) + for angle in np.linspace(0, 2 * np.pi, 12, endpoint=False): + x = center[0] + 0.105 * np.cos(angle) + y = center[1] + 0.105 * np.sin(angle) + ax.add_patch(patches.Circle((x, y), 0.006, facecolor="#0b3b63", edgecolor="none", alpha=0.85)) + ax.plot([center[0], x], [center[1], y], color="#94a3b8", linewidth=0.55, alpha=0.45) + + # Cost / confidence / utility cues as tiny side motifs. + for idx, x in enumerate([0.46, 0.485, 0.51]): + ax.add_patch(patches.Rectangle((x, 0.25 + idx * 0.018), 0.028, 0.006, facecolor="#f59e0b", edgecolor="none", alpha=0.80)) + ax.plot([0.43, 0.62], [0.27, 0.27], color="#0b3b63", linewidth=1.0, alpha=0.75) + ax.add_patch(patches.Arc((0.61, 0.33), 0.070, 0.045, theta1=0, theta2=180, color="#64748b", linewidth=0.9)) + ax.add_patch(patches.Circle((0.595, 0.328), 0.005, facecolor="#f59e0b", edgecolor="none")) + + # Reasoning field and resolved symbols, not a chain of boxes. + for idx, angle in enumerate(np.linspace(0, 2 * np.pi, 18, endpoint=False)): + rr = 0.070 + 0.030 * (idx % 3) + x = 0.72 + rr * np.cos(angle) + y = 0.53 + rr * np.sin(angle) + ax.plot([0.72, x], [0.53, y], color="#0b3b63", linewidth=0.65, alpha=0.65) + ax.add_patch(patches.Circle((x, y), 0.006, facecolor="#5eead4" if idx % 2 else "#0b3b63", edgecolor="none", alpha=0.9)) + ax.add_patch(patches.Circle((0.72, 0.53), 0.045, facecolor="#ffffff", edgecolor="#0b3b63", linewidth=1.4)) + for idx, x in enumerate([0.86, 0.91, 0.955]): + _draw_small_glyph(ax, x, 0.55 - idx * 0.025, idx + 1, color="#0b3b63", alpha=0.98, size=1.35) + ax.add_patch(patches.Circle((x, 0.55 - idx * 0.025), 0.004, facecolor="#f59e0b", edgecolor="none")) + + fig_obj.tight_layout(pad=0.0) + _save_native_matplotlib_figure(fig_obj, out_path) + plt.close(fig_obj) + + def _render_constraint_diagram(fig: dict[str, Any], state: dict, out_path: Path) -> None: plt = _setup_matplotlib() @@ -612,6 +758,14 @@ def render_native_figure( text = " ".join(str(fig.get(k) or "") for k in ("figure_id", "title", "plot_type", "objective", "caption")).lower() rows = _metric_points(iterations) try: + if _is_motivation_or_overview_figure(fig): + if "motivation" in text: + _render_symbolic_motivation(fig, state, out_path) + renderer = "symbolic_motivation" + else: + _render_symbolic_overview(fig, state, out_path) + renderer = "symbolic_overview" + return _native_asset(fid=fid, fig=fig, out_path=out_path, kind="diagram", renderer=renderer, objective=objective) if "benchmark" in text or "method comparison" in text: _render_benchmark_method_panel(fig, state, out_path) return _native_asset(fid=fid, fig=fig, out_path=out_path, kind="plot", renderer="benchmark_method_panel", objective=objective) @@ -735,7 +889,7 @@ def _run_external_diagram( ensure_ascii=False, ) command = paperbanana_cmd.format( - output=_shell_quote(str(out_path)), + output=_shell_quote(str(out_path.resolve())), spec=_shell_quote(spec), ) try: @@ -822,8 +976,13 @@ def run_figure_orchestra( objective = str(fig.get("objective") or title) plot_type = str(fig.get("plot_type") or "plot").lower() if plot_type == "diagram": + force_banana = ( + _banana_motivation_overview_enabled() + and paperbanana_cmd + and _is_motivation_or_overview_figure(fig) + ) prefer_ai = os.getenv("DEEPGRAPH_PAPERBANANA_PREFER_AI", "").strip().lower() in {"1", "true", "yes"} - if allow_external_diagrams and prefer_ai and paperbanana_cmd: + if force_banana or (allow_external_diagrams and prefer_ai and paperbanana_cmd): asset = _run_external_diagram( fig, figures_dir=figures_dir, @@ -901,6 +1060,50 @@ def run_postwriting_api_figure_stage( for token in ("framework", "overview", "method", "problem", "gating", "architecture") ) ] + plan_text = " ".join( + " ".join(str(row.get(key) or "") for key in ("figure_id", "title", "objective")).lower() + for row in diagram_plan + ) + if "motivation" not in plan_text: + diagram_plan.insert( + 0, + { + "figure_id": "fig_motivation_symbolic", + "plot_type": "diagram", + "title": "Motivation", + "objective": ( + "Create a symbolic motivation figure from the manuscript draft and caption intent: " + "show why the problem matters, what existing methods miss, and what selective reasoning changes. " + "No in-image title, no Fig. caption, no text labels, and no flowchart." + ), + "caption": ( + "Motivation figure contrasting indiscriminate reasoning with selective, evidence-aware reasoning " + "using abstract scientific symbols rather than a process diagram." + ), + "data_source": "postwriting manuscript draft plus figure caption intent", + "aspect_ratio": "16:9", + }, + ) + if "overview" not in plan_text and "framework" not in plan_text: + diagram_plan.insert( + 1 if diagram_plan else 0, + { + "figure_id": "fig_overview_symbolic", + "plot_type": "diagram", + "title": "Overview", + "objective": ( + "Create a symbolic method overview from the manuscript draft and caption intent: " + "represent the central mechanism, evidence flow, and final research claim as an integrated " + "scientific illustration. No in-image title, no Fig. caption, no text labels, and no flowchart." + ), + "caption": ( + "Overview figure summarizing the proposed mechanism and evidence structure in an abstract " + "camera-ready visual language." + ), + "data_source": "postwriting manuscript draft plus figure caption intent", + "aspect_ratio": "16:9", + }, + ) if not diagram_plan: pa = state.get("problem_awareness") if isinstance(state.get("problem_awareness"), dict) else {} diagram_plan = [ diff --git a/scripts/paperbanana_wrapper.py b/scripts/paperbanana_wrapper.py index f207a1f..7149591 100755 --- a/scripts/paperbanana_wrapper.py +++ b/scripts/paperbanana_wrapper.py @@ -295,6 +295,21 @@ def _build_caption(spec: dict[str, Any]) -> str: return title or objective or "Framework overview" +def _is_motivation_overview_spec(spec: dict[str, Any]) -> bool: + fig = spec.get("figure") or {} + text = " ".join( + str(part or "") + for part in ( + fig.get("figure_id"), + fig.get("title"), + fig.get("objective"), + spec.get("state_title"), + spec.get("problem_statement"), + ) + ).lower() + return any(token in text for token in ("motivation", "overview", "teaser", "problem-method-result", "problem method result")) + + def _check_credentials() -> tuple[bool, str]: image_protocol = (_env_first("DEEPGRAPH_PAPERBANANA_IMAGE_PROTOCOL") or "").strip().lower() if image_protocol == "openai_compatible" and _env_first("DEEPGRAPH_PAPERBANANA_IMAGE_API_KEY") and _openai_image_base_url(): @@ -312,6 +327,81 @@ def _check_credentials() -> tuple[bool, str]: def _image_prompt(spec: dict[str, Any], *, caption: str, content: str) -> str: fig = spec.get("figure") or {} + if _is_motivation_overview_spec(spec): + fig_text = " ".join(str(fig.get(key) or "") for key in ("figure_id", "title", "objective")).lower() + figure_role = "motivation" if "motivation" in fig_text else "overview" + cleaned_caption = _clip(caption, 700) + title = str(fig.get("title") or "").strip() + if title and cleaned_caption.lower().startswith(f"{title.lower()}."): + cleaned_caption = cleaned_caption[len(title) + 1 :].strip() + context_lines = [ + f"Paper title: {_clip(spec.get('state_title'), 200)}", + f"Method name: {_clip(spec.get('method_name'), 200)}", + f"Problem statement: {_clip(spec.get('problem_statement'), 800)}", + f"Existing weakness: {_clip(spec.get('existing_weakness'), 600)}", + f"Method summary: {_clip(spec.get('method_summary'), 1200)}", + ] + contributions = spec.get("contributions") or [] + if contributions: + context_lines.extend(["Key contributions:", _list_block(contributions, limit=6)]) + cleaned_context = "\n".join(line for line in context_lines if line.strip()) + if figure_role == "motivation": + schema = ( + "This figure should function as the paper's motivation figure. " + "It should look like a publication framework figure with clear regions, concise labels, and a visible scientific contrast. " + "The reader should immediately understand what the problem is, why it matters, what is insufficient in the current setting, and what key contrast motivates the proposed direction." + ) + else: + schema = ( + "This figure should function as the paper's overview figure. " + "It should summarize the main mechanism or conceptual structure of the method in one unified framework diagram. " + "Use a structured multi-region layout with grouped modules, concise labels, arrows, and a clear semantic flow. " + "The reader should understand the core idea at a glance." + ) + return "\n".join( + [ + "You are an experienced scientific figure designer preparing a camera-ready figure for a machine learning paper.", + "Carefully read the paper context and the figure intent, fully understand the research content, and produce a figure suitable for academic publication.", + schema, + "The figure should be understandable at a glance, even before the reader studies the full paper.", + "Prefer a wide publication-style framework layout with 3 to 5 clearly separated functional regions across the canvas. Each region should have an obvious role in the scientific story.", + "Do not make it a plain left-to-right pipeline or a generic flowchart. Create a denser scientific composition with local substructures, internal comparisons, and grouped modules.", + "Avoid three equally sized vertical slabs with a single arrow passing through them. Prefer one dominant dense working area plus one or two supporting grouped regions, or an asymmetric multi-cluster arrangement.", + "Take inspiration from strong editorial scientific figures: use asymmetric layout, a clear focal region, supporting side clusters, fan-in or fan-out connectors, and local density variation rather than uniform columns.", + "A good pattern is: one contextual cluster, one bridge or interface cluster, and one dense main analytical cluster, with a small integrated legend or semantic key in a corner if needed.", + "Do not center the figure around one giant symbolic object. The main structure should come from grouped panels, modules, and connections rather than a single metaphor shape.", + "Do not place a large title at the top of the image.", + "Never add a standalone figure heading such as Figure 1, Motivation, Overview, System Overview, Framework, or any caption-like sentence anywhere in the image.", + "Do not create giant comparison banners such as Traditional X vs Proposed Y across the top. If a comparison is necessary, express it with local grouped modules and small embedded labels only.", + "Do not add a detached bottom takeaway box, key insight box, or summary strip outside the main composition.", + "Do not place a bottom caption, footnote, or explanatory paragraph inside the image.", + "Short in-figure labels are allowed and encouraged when they improve scientific clarity. Use concise framework-style labels, module names, arrow labels, and compact legends when necessary.", + "Use a disciplined text hierarchy: small integrated panel headers, short module labels inside boxes, and very short arrow labels. Avoid giant all-caps banner text spanning the full canvas.", + "If region names are needed, embed them inside the relevant panel and keep them secondary. Do not place oversized text floating above large regions.", + "All visible text should use Times New Roman or a very close academic serif font. Avoid sans-serif, poster-like display fonts, handwritten styles, or playful typography.", + "Design it like a strong conference framework figure: organized blocks, grouped regions, rounded rectangles when useful, arrows or connectors where they clarify logic, and a composition that feels authored rather than templated.", + "Do not over-expand low-level implementation detail. Keep the abstraction at the right level for a publication figure.", + "Do not let generic background context occupy too much of the canvas. Large low-information background panels are discouraged. Reserve most visual emphasis for the method logic and the main scientific contrast.", + "Use semantic consistency: similar roles should share consistent color, shape, visual weight, iconography, and placement logic. Assign a small palette of 3 to 4 semantic colors and reuse them consistently.", + "Each region should contain meaningful internal structure: nested boxes, grouped items, small comparisons, aligned rows, or compact examples. Avoid large empty washes with only one object inside.", + "Use a few necessary concrete icons or data thumbnails when they improve comprehension, such as document, message, cache, embedding, user, model, dataset, or output icons. They should be clean, intentional, and tied to real modules rather than decorative.", + "When reusing visual elements, introduce controlled differences so repeated modules are not mechanically identical. Avoid long stacks of near-duplicate cards or repeated clipart blocks.", + "Allow multiple meaningful visual elements if the figure needs them, but every element must have a clear role in the scientific explanation.", + "Avoid decorative concept art, giant symbolic brains, clouds, funnels, waves, logo walls, random icon piles, or visually flashy but semantically empty motifs.", + "Avoid collage-like card stacking. The figure should feel like an engineered layout, not a pile of decorative tiles.", + "Avoid generic stock shapes that scream AI-generated infographic, such as oversized trapezoid encoders, giant ribbon arrows, or repeated empty neural-network clipart unless grounded in a precise scientific role.", + "Do not turn the figure into a toy infographic. It should read like a polished framework figure from a top ML paper.", + "Use whitespace well, but do not oversimplify the figure into a vague sparse composition. A richer multi-block framework figure is acceptable if it improves clarity.", + "Create visual texture through meaningful structure: nested containers, aligned micro-elements, varied line weights, subtle shadows, and local detail. Do not rely on huge gradients or oversized empty background areas for style.", + "A compact legend strip or semantic key is allowed when useful. If used, make it small, integrated, and tucked into a corner or margin. Never let the legend become a bottom-wide banner.", + "Favor non-uniform occupancy: let important regions be denser and larger, and let supporting regions be smaller and more compact. Avoid evenly distributing empty space across the canvas.", + "The visual style should resemble a modern academic framework diagram template: white background, restrained local tinting, grouped panels, rounded modules, controlled outlines, balanced spacing, moderate line weights, arrows with clear direction, and clean readable labels.", + "The output should look like a serious NeurIPS/ICLR/ICML figure prepared for publication.", + f"Figure-specific intent: {_clip(fig.get('objective') or caption, 700)}", + f"Figure caption context from the paper: {cleaned_caption}", + f"Paper context to read and use: {_clip(cleaned_context, 2200)}", + ] + ).strip() labels = [ str(part) for part in [ @@ -405,6 +495,76 @@ def _run_openai_compatible_image_generation( return _write_openai_image_response(output_path=output_path, body=body) +def _run_gemini_native_image_generation( + *, + output_path: Path, + prompt: str, + aspect_ratio: str, +) -> int: + api_key = _env_first("DEEPGRAPH_PAPERBANANA_IMAGE_API_KEY", "GEMINI_NATIVE_API_KEY") + base_url = _normalize_gemini_native_base_url( + _env_first("DEEPGRAPH_PAPERBANANA_IMAGE_BASE_URL", "GEMINI_NATIVE_BASE_URL") + ) + model = _env_first("DEEPGRAPH_PAPERBANANA_IMAGE_MODEL", "IMAGE_GEN_MODEL_NAME") or "gemini-2.5-flash-image" + if not api_key or not base_url: + print("Gemini-native image generation is missing API key or base URL.", file=sys.stderr) + return 3 + + payload = { + "contents": [ + { + "role": "user", + "parts": [{"text": prompt}], + } + ], + "generationConfig": { + "responseModalities": ["TEXT", "IMAGE"], + "imageConfig": {"aspectRatio": aspect_ratio}, + }, + } + url = f"{base_url}/v1beta/models/{model}:generateContent" + request = urllib.request.Request( + url, + data=json.dumps(payload, ensure_ascii=False).encode("utf-8"), + headers={ + "Authorization": f"Bearer {api_key}", + "x-goog-api-key": api_key, + "Content-Type": "application/json", + "User-Agent": "DeepGraph-PaperBanana-Wrapper/1.0", + }, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=600) as response: + body = json.loads(response.read().decode("utf-8", errors="replace")) + except Exception as exc: + print(f"Gemini-native image generation failed: {exc}", file=sys.stderr) + return 4 + + candidates = body.get("candidates") if isinstance(body, dict) else None + for candidate in candidates or []: + content = candidate.get("content") if isinstance(candidate, dict) else None + parts = content.get("parts") if isinstance(content, dict) else None + for part in parts or []: + if not isinstance(part, dict): + continue + inline_data = part.get("inlineData") or part.get("inline_data") + if not isinstance(inline_data, dict): + continue + data = inline_data.get("data") + if not data: + continue + try: + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_bytes(base64.b64decode(str(data))) + return 0 + except Exception as exc: + print(f"Gemini-native image base64 decode failed: {exc}", file=sys.stderr) + return 4 + print("Gemini-native image generation response had no inline image data.", file=sys.stderr) + return 4 + + def _image_attempt_count() -> int: raw = _env_first("DEEPGRAPH_PAPERBANANA_IMAGE_ATTEMPTS", "DEEPGRAPH_PAPERBANANA_IMAGE_RETRIES") try: @@ -586,6 +746,12 @@ def main() -> int: output_path.parent.mkdir(parents=True, exist_ok=True) if not PAPERBANANA_PYTHON.exists() or not PAPERBANANA_ENTRY.exists(): + if provider == "gemini_native": + return _run_gemini_native_image_generation( + output_path=output_path, + prompt=_image_prompt(spec, caption=caption, content=content), + aspect_ratio=aspect_ratio, + ) if provider in {"openai", "openrouter", "openai_compatible_image"}: return _run_openai_compatible_image_generation( output_path=output_path, diff --git a/tests/test_evidence_planner.py b/tests/test_evidence_planner.py index fb1402b..efa1425 100644 --- a/tests/test_evidence_planner.py +++ b/tests/test_evidence_planner.py @@ -1,6 +1,7 @@ import tempfile import unittest from pathlib import Path +from unittest import mock from agents.evidence_planner import build_evidence_plan from agents.paperorchestra.figure_orchestra import run_figure_orchestra @@ -75,6 +76,66 @@ def test_disabled_visualization_skips_default_figure_generation(self): ) self.assertEqual(manifest["assets"], []) + def test_motivation_overview_diagram_uses_banana_by_default(self): + outline = { + "plotting_plan": [ + { + "figure_id": "fig_motivation_overview", + "plot_type": "diagram", + "title": "Motivation overview", + "objective": "Motivation and overview for selective reasoning.", + } + ] + } + with tempfile.TemporaryDirectory() as tmpdir: + manifest = run_figure_orchestra( + outline=outline, + state={"title": "Selective reasoning"}, + iterations=[], + figures_dir=Path(tmpdir), + baseline=None, + metric_name="accuracy", + paperbanana_cmd="printf x > {output}", + ) + self.assertGreaterEqual(manifest["generated_count"], 1) + asset = next( + row for row in manifest["assets"] + if row.get("figure_id") == "fig_motivation_overview" + ) + self.assertEqual(asset["notes"], "paperbanana_ok") + self.assertTrue(Path(asset["path"]).exists()) + + def test_motivation_overview_diagram_can_opt_out_to_native(self): + outline = { + "plotting_plan": [ + { + "figure_id": "fig_motivation_overview", + "plot_type": "diagram", + "title": "Motivation overview", + "objective": "Motivation and overview for selective reasoning.", + } + ] + } + with tempfile.TemporaryDirectory() as tmpdir, mock.patch.dict( + "os.environ", + {"DEEPGRAPH_PAPERBANANA_MOTIVATION_OVERVIEW": "false"}, + ): + manifest = run_figure_orchestra( + outline=outline, + state={"title": "Selective reasoning"}, + iterations=[], + figures_dir=Path(tmpdir), + baseline=None, + metric_name="accuracy", + paperbanana_cmd="printf x > {output}", + ) + asset = next( + row for row in manifest["assets"] + if row.get("figure_id") == "fig_motivation_overview" + ) + self.assertEqual(asset["notes"], "native_symbolic_motivation") + self.assertTrue(Path(asset["path"]).exists()) + if __name__ == "__main__": unittest.main()