diff --git a/predicators/approaches/agent_session_mixin.py b/predicators/agent_sdk/agent_session_mixin.py similarity index 92% rename from predicators/approaches/agent_session_mixin.py rename to predicators/agent_sdk/agent_session_mixin.py index f90697340..325974882 100644 --- a/predicators/approaches/agent_session_mixin.py +++ b/predicators/agent_sdk/agent_session_mixin.py @@ -8,7 +8,8 @@ import os from typing import Any, Dict, List, Optional, Set, Union -from predicators.agent_sdk.session_manager import AgentSessionManager +from predicators.agent_sdk.session_manager import AgentSessionManager, \ + run_query_sync from predicators.agent_sdk.tools import ToolContext, create_mcp_tools, \ get_allowed_tool_list from predicators.explorers import create_explorer @@ -127,12 +128,18 @@ def _ensure_agent_session(self) -> None: tools=tools, ) + extra_names = [ + getattr(t, "name", "") + for t in self._tool_context.extra_mcp_tools + ] self._agent_session = AgentSessionManager( system_prompt=self._get_agent_system_prompt(), mcp_server=mcp_server, log_dir=self._get_log_dir(), model_name=CFG.agent_sdk_model_name, - allowed_tools=get_allowed_tool_list(tool_names), + allowed_tools=get_allowed_tool_list(tool_names, + extra_names=extra_names + or None), ) if self._agent_session_id is not None: @@ -179,26 +186,18 @@ def _query_agent_sync(self, message: str) -> List[Dict[str, Any]]: """Synchronous wrapper for async agent query.""" self._ensure_agent_session() assert self._agent_session is not None - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - import nest_asyncio # type: ignore[import-untyped,import-not-found] # pylint: disable=import-outside-toplevel - nest_asyncio.apply() - return loop.run_until_complete( - self._agent_session.query(message)) - return loop.run_until_complete(self._agent_session.query(message)) - except RuntimeError: - return asyncio.run(self._agent_session.query(message)) + return run_query_sync(self._agent_session, message) def _create_agent_explorer( self, predicates: Set[Predicate], options: Set[ParameterizedOption], + name: str = "agent_plan", ) -> BaseExplorer: """Create an agent explorer with tool_context and agent_session.""" self._ensure_agent_session() return create_explorer( - "agent", + name, predicates, options, self._types, # type: ignore[attr-defined] diff --git a/predicators/agent_sdk/bilevel_sketch.py b/predicators/agent_sdk/bilevel_sketch.py new file mode 100644 index 000000000..25135af86 --- /dev/null +++ b/predicators/agent_sdk/bilevel_sketch.py @@ -0,0 +1,417 @@ +"""Shared helpers for bilevel plan-sketch construction and refinement. + +Extracted from ``AgentBilevelApproach`` so both the approach (at solve +time) and ``AgentBilevelExplorer`` (at exploration time) can build plan +sketches, parse subgoal annotations, and run backtracking refinement +against an arbitrary ``_OptionModelBase``. + +The helpers are pure module-level functions — they take their +dependencies (option_model, predicates, rng, settings) explicitly so +neither approaches nor explorers need to subclass one another. +""" +import dataclasses +import logging +import re +from typing import Callable, List, Optional, Sequence, Set, Tuple, cast + +import numpy as np + +from predicators import utils +from predicators.option_model import _OptionModelBase +from predicators.planning import run_backtracking_refinement +from predicators.structs import GroundAtom, Object, ParameterizedOption, \ + Predicate, State, Task, Type, _Option + + +@dataclasses.dataclass +class SketchStep: + """One step in an agent-produced plan sketch. + + ``subgoal_atoms`` / ``subgoal_neg_atoms`` are optional: ``None`` + means "no subgoal constraint at this step"; an empty set means "the + annotation was present but contained no atoms of that polarity". + """ + option: ParameterizedOption + objects: Sequence[Object] + subgoal_atoms: Optional[Set[GroundAtom]] + subgoal_neg_atoms: Optional[Set[GroundAtom]] = None + + +def strip_code_fences(text: str) -> str: + """Strip markdown code fences wrapping plan text.""" + lines = text.split('\n') + while lines and lines[0].strip().startswith('```'): + lines.pop(0) + while lines and lines[-1].strip().startswith('```'): + lines.pop() + return '\n'.join(lines) + + +def sample_params(option: ParameterizedOption, + rng: np.random.Generator) -> np.ndarray: + """Sample continuous parameters uniformly from the option's box.""" + if option.params_space.shape[0] == 0: + return np.array([], dtype=np.float32) + low = option.params_space.low + high = option.params_space.high + return rng.uniform(low, high).astype(np.float32) + + +def build_solve_prompt( + task: Task, + *, + all_predicates: Set[Predicate], + all_options: Set[ParameterizedOption], + trajectory_summary: str = "", + tool_names: Optional[Sequence[str]] = None, +) -> str: + """Build the bilevel solve/explore prompt asking for a plan sketch. + + Mirrors ``AgentBilevelApproach._build_solve_prompt`` but takes + dependencies explicitly so explorers can reuse it. + """ + init_state = task.init + objects = list(init_state) + + obj_strs = [] + for obj in sorted(objects, key=lambda o: o.name): + obj_strs.append(f" {obj.name}: {obj.type.name}") + + goal_strs = [str(a) for a in sorted(task.goal, key=str)] + + option_strs = [] + for opt in sorted(all_options, key=lambda o: o.name): + type_sig = ", ".join(t.name for t in opt.types) + params_dim = opt.params_space.shape[0] + if params_dim > 0: + low = opt.params_space.low.tolist() + high = opt.params_space.high.tolist() + if opt.params_description: + desc = ", ".join(opt.params_description) + param_info = (f" [auto-searched params: {desc}, " + f"range {low} to {high}]") + else: + param_info = (f" [auto-searched: {params_dim}d, " + f"range {low} to {high}]") + else: + param_info = "" + option_strs.append(f" {opt.name}({type_sig}){param_info}") + + atoms = utils.abstract(init_state, all_predicates) + atom_strs = [str(a) for a in sorted(atoms, key=str)] + + state_str = init_state.dict_str(indent=2) + + tools_str = "" + if tool_names: + tool_list = "\n".join(f" - {t}" for t in tool_names) + tools_str = f"\n## Available Tools\n{tool_list}\n" + + goal_nl_section = "" + if task.goal_nl: + goal_nl_section = f"\n## Goal Description\n{task.goal_nl}\n" + + pred_strs = [] + for pred in sorted(all_predicates, key=lambda p: p.name): + type_sig = ", ".join(t.name for t in pred.types) + pred_strs.append(f" {pred.name}({type_sig})") + + prompt = f"""You are solving a task. \ +Generate a plan sketch to achieve the goal. +{goal_nl_section} +## Goal Atoms +{chr(10).join(goal_strs)} + +## Initial State Atoms +{chr(10).join(atom_strs)} + +## Initial State Features +{state_str} + +## Objects +{chr(10).join(obj_strs)} + +## Available Options +{chr(10).join(option_strs)} + +## Available Predicates (for subgoal annotations) +{chr(10).join(pred_strs)} +{trajectory_summary}{tools_str} +## Instructions +Use your available tools to inspect the environment before producing the plan. + +Generate a plan SKETCH — the sequence of options with object arguments, but \ +WITHOUT continuous parameters. Continuous parameters will be found \ +automatically by a backtracking search procedure. + +Optionally annotate subgoal atoms that should hold after each step. This \ +helps the search verify progress. Use `-> {{atoms}}` after each step. + +After any action whose desired subgoal depends on a delayed process (e.g. \ +water filling, dominoes cascading, heating), insert a Wait action. For Wait \ +steps, annotate with the atoms the process should produce — this tells the \ +system exactly when the Wait should end rather than terminating on any \ +incidental atom change. Use `NOT Pred(...)` for atoms that should become false. + +Output the plan sketch with one option per line in this format: + OptionName(obj1:type1, obj2:type2) -> \ +{{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}} + Wait(robot:Robot) -> {{Boiled(water:water_type)}} + Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}} + +Always use typed references (obj:type) in both option arguments AND subgoal \ +atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \ +only check that the option executed successfully (non-zero actions). + +Output ONLY the plan sketch lines at the end, after any analysis.""" + + return prompt + + +def parse_subgoal_annotations( + text: str, + predicates: Set[Predicate], + objects: Sequence[Object], + option_names: Set[str], +) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. + + Returns a list parallel to the option lines in ``text``. Each entry + is ``None`` for a line with no annotation, or ``(positive_atoms, + negative_atoms)`` otherwise. + """ + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + + subgoal_re = re.compile(r'->\s*\{([^}]*)\}') + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] + + for line in text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + continue + + sg_match = subgoal_re.search(stripped) + if not sg_match: + results.append(None) + continue + + atoms_text = sg_match.group(1) + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + for atom_match in atom_re.finditer(atoms_text): + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) + obj_names = [ + n.strip().split(':')[0] for n in atom_match.group(3).split(',') + ] + + if pred_name not in pred_map: + logging.warning(f"Unknown predicate in subgoal: {pred_name}") + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError as e: + logging.warning(f"Unknown object in subgoal: {e}") + continue + if len(objs) != len(pred.types): + logging.warning(f"Arity mismatch for {pred_name}: expected " + f"{len(pred.types)}, got {len(objs)}") + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + + if pos_atoms or neg_atoms: + results.append((pos_atoms, neg_atoms)) + else: + results.append(None) + + return results + + +def parse_sketch_from_text( + plan_text: str, + task: Task, + *, + predicates: Set[Predicate], + options: Set[ParameterizedOption], + types: Set[Type], +) -> List[SketchStep]: + """Parse plan-sketch text into ``SketchStep``s. + + Applies ``strip_code_fences`` first, then delegates option-plan + parsing to ``utils.parse_model_output_into_option_plan`` and subgoal + annotation parsing to ``parse_subgoal_annotations``. + """ + cleaned_text = strip_code_fences(plan_text) + objects = list(task.init) + option_names = {o.name for o in options} + + parsed = utils.parse_model_output_into_option_plan( + cleaned_text, objects, types, options, parse_continuous_params=False) + + if not parsed: + return [] + + subgoals = parse_subgoal_annotations(cleaned_text, predicates, objects, + option_names) + + sketch: List[SketchStep] = [] + for i, (option, objs, _) in enumerate(parsed): + sg = subgoals[i] if i < len(subgoals) else None + if sg is not None: + pos, neg = sg + sketch.append( + SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + SketchStep(option=option, objects=objs, subgoal_atoms=None)) + return sketch + + +def refine_sketch( + task: Task, + sketch: List[SketchStep], + option_model: _OptionModelBase, + *, + predicates: Set[Predicate], + timeout: float, + rng: np.random.Generator, + max_samples_per_step: int, + check_subgoals: bool, + check_final_goal: bool = True, + truncate_on_subgoal_fail: bool = False, + log_state: bool = False, + run_id: str = "bilevel", + on_step_fail: Optional[Callable[[int, List[Optional[_Option]], str], + None]] = None, +) -> Tuple[List[_Option], bool, int]: + """Backtracking search over continuous parameters for a plan sketch. + + Returns ``(refined_plan, success, total_samples)``. On success the + plan is fully refined; on failure it is the longest prefix of + refined options (``None`` entries dropped). + + ``check_subgoals`` gates per-step subgoal-atom validation. + ``check_final_goal`` gates the task-goal check on the final step. + ``truncate_on_subgoal_fail`` (explorer mode) lets backtracking run + to exhaustion with subgoal checks enabled, then — if the search + fails — returns the consistent plan prefix captured at the deepest + subgoal failure seen during backtracking (inclusive of the failing + step). Use this to build *experiment* plans that probe a single + mental-model disagreement: upstream steps get their standard + backtracking retries, but once the deepest unresolvable subgoal is + identified, subsequent sketch steps are dropped (they would be + built on a false mental-model state). + + Wait steps inject ``wait_target_atoms`` / ``wait_target_neg_atoms`` + from the sketch's subgoal annotations into ``grounded.memory`` so + that ``WaitOption`` terminates on the intended atom change rather + than the first incidental one. + """ + if not sketch: + return [], False, 0 + + n = len(sketch) + max_tries = [ + max_samples_per_step if step.option.params_space.shape[0] > 0 else 1 + for step in sketch + ] + # Snapshot of the deepest subgoal failure seen during backtracking. + # Tracks (idx, plan_prefix_snapshot). Updated whenever on_step_fail + # reports a subgoal failure at a strictly deeper index than before. + # The snapshot is taken at the moment of failure, so it is a + # *consistent* trajectory: run_backtracking_refinement has already + # written plan[idx] for that attempt and the prefix plan[:idx+1] + # reflects the exact grounded options that led to this failure. + deepest_subgoal_fail_idx: List[int] = [-1] + deepest_subgoal_fail_prefix: List[List[Optional[_Option]]] = [[]] + + def sample_fn(idx: int, state: State, + rng_: np.random.Generator) -> _Option: + step = sketch[idx] + if log_state: + step_name = (f"{step.option.name}" + f"({', '.join(o.name for o in step.objects)})") + logging.debug(f"[{run_id}] State before {step_name}:\n" + f"{state.pretty_str()}") + params = sample_params(step.option, rng_) + grounded = step.option.ground(list(step.objects), params) + if grounded.name == "Wait": + if step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + if step.subgoal_neg_atoms is not None: + grounded.memory["wait_target_neg_atoms"] = \ + step.subgoal_neg_atoms + return grounded + + def validate_fn(idx: int, _pre_state: State, _option: _Option, + post_state: State, _num_actions: int) -> Tuple[bool, str]: + step = sketch[idx] + if check_subgoals and step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + missing = step.subgoal_atoms - current_atoms + return False, (f"subgoal missing: " + f"{{{', '.join(str(a) for a in missing)}}}") + if check_final_goal and idx == n - 1: + if not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + def wrapped_on_step_fail(idx: int, cur_plan: List[Optional[_Option]], + fail_reason: str) -> None: + # run_backtracking_refinement calls this BEFORE clearing + # plan[idx] (planning.py lines 592-599), so cur_plan[0..idx] is + # still populated with the grounded options that produced this + # exact failure trajectory. Record the deepest subgoal failure + # seen so far along with a consistent snapshot of the prefix. + if (truncate_on_subgoal_fail + and fail_reason.startswith("subgoal missing") + and idx > deepest_subgoal_fail_idx[0]): + deepest_subgoal_fail_idx[0] = idx + deepest_subgoal_fail_prefix[0] = list(cur_plan[:idx + 1]) + if on_step_fail is not None: + on_step_fail(idx, cur_plan, fail_reason) + + plan, success, total_samples = run_backtracking_refinement( + init_state=task.init, + option_model=option_model, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=rng, + timeout=timeout, + on_step_fail=wrapped_on_step_fail, + ) + + logging.info( + f"[{run_id}] Refinement {'succeeded' if success else 'failed'}: " + f"{total_samples} samples for {n} steps.") + + if (truncate_on_subgoal_fail and not success + and deepest_subgoal_fail_idx[0] >= 0): + snapshot = deepest_subgoal_fail_prefix[0] + refined = [p for p in snapshot if p is not None] + logging.info(f"[{run_id}] Truncating at deepest subgoal failure " + f"(step {deepest_subgoal_fail_idx[0]}): " + f"{len(refined)}/{n} steps in experiment plan.") + return cast(List[_Option], refined), False, total_samples + + refined = [p for p in plan if p is not None] + if success: + return cast(List[_Option], refined), True, total_samples + return refined, False, total_samples diff --git a/predicators/agent_sdk/session_manager.py b/predicators/agent_sdk/session_manager.py index f56063a25..84c6ce880 100644 --- a/predicators/agent_sdk/session_manager.py +++ b/predicators/agent_sdk/session_manager.py @@ -1,4 +1,5 @@ """Agent session lifecycle management for Claude SDK.""" +import asyncio import datetime import json import logging @@ -211,3 +212,20 @@ def save_session_info(self) -> None: with open(path, "w", encoding="utf-8") as f: json.dump(info, f, indent=2) logging.info("Saved session info to %s", path) + + +def run_query_sync(session: Any, message: str) -> List[Dict[str, Any]]: + """Synchronously run ``session.query(message)``. + + Reuses a running event loop via nest_asyncio when one is active, + otherwise falls back to ``asyncio.run``. + """ + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import nest_asyncio # type: ignore[import-untyped,import-not-found] # pylint: disable=import-outside-toplevel + nest_asyncio.apply() + return loop.run_until_complete(session.query(message)) + return loop.run_until_complete(session.query(message)) + except RuntimeError: + return asyncio.run(session.query(message)) diff --git a/predicators/agent_sdk/tools.py b/predicators/agent_sdk/tools.py index 01ea16fc3..685e73202 100644 --- a/predicators/agent_sdk/tools.py +++ b/predicators/agent_sdk/tools.py @@ -72,7 +72,10 @@ PLANNING_TOOL_NAMES + SCENE_TOOL_NAMES) -def get_allowed_tool_list(tool_names: Optional[List[str]] = None) -> List[str]: +def get_allowed_tool_list( + tool_names: Optional[List[str]] = None, + extra_names: Optional[List[str]] = None, +) -> List[str]: """Compute the allowed_tools list for the agent SDK. Args: @@ -82,6 +85,8 @@ def get_allowed_tool_list(tool_names: Optional[List[str]] = None) -> List[str]: prefix = f"mcp__{MCP_SERVER_NAME}__" names = ALL_TOOL_NAMES if tool_names is None else \ [n for n in tool_names if n in set(ALL_TOOL_NAMES)] + if extra_names: + names = list(names) + list(extra_names) return [f"{prefix}{n}" for n in names] @@ -114,6 +119,12 @@ class ToolContext: turn_id: int = 0 # current query/turn within the session test_call_id: int = 0 # incremented per test_option_plan call visualized_state: Optional[State] = None # last state from visualize_state + extra_mcp_tools: list = field(default_factory=list) # injected by subclass + # Populated by AgentBilevelExplorer so learning approaches can diff + # mental-model subgoals against real trajectories. + # TODO(sim-learning): consume these in learn_from_interaction_results. + last_sketch_subgoals: Optional[Any] = None + last_sketch_options: Optional[Any] = None def _text_result(text: str) -> Dict[str, Any]: @@ -166,7 +177,7 @@ def _render_pybullet_image( from PIL import Image as PILImage if state is not None: - ctx.env._reset_state(state) # pylint: disable=protected-access + ctx.env._set_state(state) # pylint: disable=protected-access video = ctx.env.render() if not video: @@ -1767,7 +1778,7 @@ async def annotate_scene(args: Dict[str, Any]) -> Dict[str, Any]: render_state = ctx.visualized_state or (ctx.current_task.init if ctx.current_task else None) if render_state is not None: - ctx.env._reset_state(render_state) # pylint: disable=protected-access + ctx.env._set_state(render_state) # pylint: disable=protected-access physics_id = ctx.env._physics_client_id # pylint: disable=protected-access annotations = args.get("annotations", []) @@ -1945,5 +1956,232 @@ async def visualize_state(args: Dict[str, Any]) -> Dict[str, Any]: "visualize_state": visualize_state, } if tool_names is None: - return list(_all.values()) - return [_all[n] for n in tool_names if n in _all] + tools = list(_all.values()) + else: + tools = [_all[n] for n in tool_names if n in _all] + tools.extend(ctx.extra_mcp_tools) + return tools + + +# ── Sim-learning tools ─────────────────────────────────────────── + + +def create_synthesis_tools( + exec_ns: Dict[str, Any], + base_pred_triples: list, + inferred_process_features: Dict[str, List[str]], + save_dir: Optional[str] = None, +) -> list: + """Create MCP tools for the sim-learning synthesis agent. + + Returns ``[run_python, evaluate_simulator, test_simulator]``. + + * ``run_python`` — executes arbitrary Python in a persistent + namespace pre-loaded with trajectory data. + * ``evaluate_simulator`` — fits parameters via MCMC on + ``PROCESS_RULES`` / ``PARAM_SPECS`` defined in the namespace. + * ``test_simulator`` — tests predictions vs observations. + + Both eval/test read ``PROCESS_FEATURES`` from ``exec_ns`` on each + call, falling back to ``inferred_process_features`` if the agent + hasn't declared it yet. + + Args: + exec_ns: Persistent namespace for ``run_python``. Should + contain ``trajectories``, ``np``, ``ParamSpec``. + base_pred_triples: ``(s_base, action, s_next_obs)`` triples + with the base step already advanced — eval/test consume + ``s_base`` directly so no live env is needed. + inferred_process_features: Data-driven default scope used + until the agent defines ``PROCESS_FEATURES`` in exec_ns. + save_dir: Directory to save simulator source code to. + Each ``run_python`` call appends code to + ``save_dir/simulator_code.py``. + """ + import io # pylint: disable=import-outside-toplevel + import sys # pylint: disable=import-outside-toplevel + import traceback # pylint: disable=import-outside-toplevel,redefined-outer-name,reimported + + from claude_agent_sdk import \ + tool # pylint: disable=import-outside-toplevel + + from predicators.approaches.agent_sim_learning_approach import \ + AgentSimLearningApproach # pylint: disable=import-outside-toplevel + + _run_count = [0] # mutable counter in closure + + def _text(msg: str) -> Dict[str, Any]: + return {"type": "text", "text": msg} + + # ── run_python ────────────────────────────────────────── + + @tool( + "run_python", + "Execute Python code with trajectory data in scope. " + "Available variables: trajectories (List[LowLevelTrajectory])," + " np, ParamSpec. print() output is returned. " + "The namespace persists across calls.", + { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute.", + } + }, + "required": ["code"], + }, + ) + async def run_python(args: Dict[str, Any]) -> Dict[str, Any]: + code = args["code"] + old_stdout = sys.stdout + sys.stdout = captured = io.StringIO() + try: + exec(code, exec_ns) # pylint: disable=exec-used + except Exception: # pylint: disable=broad-except + tb = traceback.format_exc() + return _text(f"Error:\n{tb}") + finally: + sys.stdout = old_stdout + + # Save each successful run_python call as a versioned file; + # _load_simulator_from_file replays these in order. + if save_dir is not None: + _run_count[0] += 1 + os.makedirs(save_dir, exist_ok=True) + filename = f"{_run_count[0]:03d}_run_python.py" + filepath = os.path.join(save_dir, filename) + with open(filepath, "w", encoding="utf-8") as f: + f.write(code) + + output = captured.getvalue() + return _text(output or "(no output)") + + # ── evaluate_simulator ────────────────────────────────── + + @tool( + "evaluate_simulator", + "Fit parameters using PROCESS_RULES and PARAM_SPECS " + "from the run_python namespace. Reports SSE and fitted " + "parameter values.", + { + "type": "object", + "properties": {} + }, + ) + async def evaluate_simulator(_args: Dict[str, Any]) -> Dict[str, Any]: + rules = exec_ns.get("PROCESS_RULES") + specs = exec_ns.get("PARAM_SPECS") + if not isinstance(rules, list) or not rules: + return _text("Error: PROCESS_RULES not defined. Use " + "run_python to define it first.") + if not isinstance(specs, list) or not specs: + return _text("Error: PARAM_SPECS not defined. Use " + "run_python to define it first.") + + declared = exec_ns.get("PROCESS_FEATURES") + process_features = (declared if isinstance(declared, dict) else + inferred_process_features) + scope_note = ("PROCESS_FEATURES" if isinstance(declared, dict) else + "inferred (PROCESS_FEATURES not declared)") + + try: + fitted_params, sse = ( + AgentSimLearningApproach._fit_parameters( # pylint: disable=protected-access + rules, specs, base_pred_triples, process_features)) + except Exception as e: # pylint: disable=broad-except + return _text(f"Error: fit_params failed:\n{e}") + + lines = [ + f"SSE: {sse:.6f} on " + f"{len(base_pred_triples)} step transitions " + f"(scope: {scope_note}).", + "", + "Fitted parameters:", + ] + for name, val in fitted_params.items(): + lines.append(f" {name}: {val:.6f}") + + return _text("\n".join(lines)) + + # ── test_simulator ────────────────────────────────────── + + @tool( + "test_simulator", + "Test PROCESS_RULES predictions vs observations on " + "step transitions. Shows mismatches.", + { + "type": "object", + "properties": { + "max_transitions": { + "type": "integer", + "description": "Max transitions to test (default 100).", + }, + "tolerance": { + "type": + "number", + "description": + "Absolute tolerance for mismatch " + "(default 1e-4).", + }, + }, + }, + ) + async def test_simulator(args: Dict[str, Any]) -> Dict[str, Any]: + rules = exec_ns.get("PROCESS_RULES") + specs = exec_ns.get("PARAM_SPECS") + if not isinstance(rules, list) or not rules: + return _text("Error: PROCESS_RULES not defined.") + + declared = exec_ns.get("PROCESS_FEATURES") + process_features = (declared if isinstance(declared, dict) else + inferred_process_features) + + max_n = args.get("max_transitions", 100) + tol = args.get("tolerance", 1e-4) + pairs = base_pred_triples[:max_n] + + # Use init params if not yet fitted. + if specs: + t_params = {s.name: s.init_value for s in specs} + else: + t_params = {} + + lines: list = [] + n_tested = 0 + n_mismatch = 0 + + for base_state, _action, s_next_obs in pairs: + updates: Dict = {} + for rule in rules: + updates = rule(base_state, updates, t_params) + + entry: list = [] + for obj in base_state: + type_name = obj.type.name + for feat in process_features.get(type_name, []): + if obj in updates and feat in updates[obj]: + pred = updates[obj][feat] + pred = (pred.item() + if hasattr(pred, "item") else float(pred)) + else: + pred = base_state.get(obj, feat) + obs = s_next_obs.get(obj, feat) + err = abs(pred - obs) + if err > tol: + entry.append(f" {obj.name}.{feat}: " + f"pred={pred:.6f} obs={obs:.6f} " + f"err={err:.6f}") + + n_tested += 1 + if entry: + n_mismatch += 1 + lines.append(f"Step {n_tested}:") + lines.extend(entry) + lines.append("") + + lines.append(f"Tested {n_tested} steps: {n_mismatch} mismatches, " + f"{n_tested - n_mismatch} correct.") + return _text("\n".join(lines)) + + return [run_python, evaluate_simulator, test_simulator] diff --git a/predicators/approaches/agent_abstraction_learning_approach.py b/predicators/approaches/agent_abstraction_learning_approach.py index b76fcf7de..bf24a5def 100644 --- a/predicators/approaches/agent_abstraction_learning_approach.py +++ b/predicators/approaches/agent_abstraction_learning_approach.py @@ -13,10 +13,10 @@ from gym.spaces import Box from predicators import utils +from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin from predicators.agent_sdk.proposal_parser import ProposalBundle, \ build_exec_context, exec_code_safely from predicators.approaches.agent_planner_approach import AgentPlannerApproach -from predicators.approaches.agent_session_mixin import AgentSessionMixin from predicators.approaches.pp_online_process_learning_approach import \ OnlineProcessLearningAndPlanningApproach from predicators.approaches.pp_predicate_invention_approach import \ @@ -477,7 +477,7 @@ def _build_solve_prompt(self, task: Task) -> str: def _create_explorer(self) -> BaseExplorer: """Create explorer, passing agent context if using agent explorer.""" - if CFG.explorer == "agent": + if CFG.explorer == "agent_plan": all_trajs = (self._offline_dataset.trajectories + self._online_dataset.trajectories) self._sync_tool_context(all_trajs) diff --git a/predicators/approaches/agent_bilevel_approach.py b/predicators/approaches/agent_bilevel_approach.py index 98b8c1df8..1baf550a1 100644 --- a/predicators/approaches/agent_bilevel_approach.py +++ b/predicators/approaches/agent_bilevel_approach.py @@ -10,17 +10,17 @@ python predicators/main.py --env pybullet_domino \ --approach agent_bilevel --seed 0 \ --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent + --num_online_learning_cycles 1 --explorer agent_plan """ -import dataclasses import logging -import re import time -from typing import Callable, List, Optional, Sequence, Set, Tuple, cast +from typing import Callable, List, Optional, Sequence, Set, Tuple import numpy as np from predicators import utils +from predicators.agent_sdk import bilevel_sketch +from predicators.agent_sdk.bilevel_sketch import SketchStep as _SketchStep from predicators.approaches import ApproachFailure from predicators.approaches.agent_planner_approach import AgentPlannerApproach from predicators.planning import run_backtracking_refinement @@ -29,16 +29,6 @@ ParameterizedOption, Predicate, State, Task, _Option -@dataclasses.dataclass -class _SketchStep: - """One step in an agent-produced plan sketch.""" - option: ParameterizedOption - objects: Sequence[Object] - subgoal_atoms: Optional[Set[GroundAtom]] # None = no subgoal constraint - # Atoms that must be FALSE after this step. - subgoal_neg_atoms: Optional[Set[GroundAtom]] = None - - class AgentBilevelApproach(AgentPlannerApproach): """Bilevel planning: agent proposes discrete skeleton, search refines continuous parameters. @@ -90,114 +80,13 @@ def _get_agent_system_prompt(self) -> str: def _build_solve_prompt(self, task: Task) -> str: """Build prompt asking for a plan sketch without continuous params.""" - init_state = task.init - objects = list(init_state) - - # Objects - obj_strs = [] - for obj in sorted(objects, key=lambda o: o.name): - obj_strs.append(f" {obj.name}: {obj.type.name}") - - # Goal - goal_strs = [str(a) for a in sorted(task.goal, key=str)] - - # Options (show params_space info so agent understands what's tunable) - option_strs = [] - for opt in sorted(self._get_all_options(), key=lambda o: o.name): - type_sig = ", ".join(t.name for t in opt.types) - params_dim = opt.params_space.shape[0] - if params_dim > 0: - low = opt.params_space.low.tolist() - high = opt.params_space.high.tolist() - if opt.params_description: - desc = ", ".join(opt.params_description) - param_info = (f" [auto-searched params: {desc}, " - f"range {low} to {high}]") - else: - param_info = (f" [auto-searched: {params_dim}d, " - f"range {low} to {high}]") - else: - param_info = "" - option_strs.append(f" {opt.name}({type_sig}){param_info}") - - # Current atoms - atoms = utils.abstract(init_state, self._get_all_predicates()) - atom_strs = [str(a) for a in sorted(atoms, key=str)] - - # Trajectory summary - traj_summary = self._build_trajectory_summary() - - # State features - state_str = init_state.dict_str(indent=2) - - # Available tools - tool_names = self._get_agent_tool_names() - tools_str = "" - if tool_names: - tool_list = "\n".join(f" - {t}" for t in tool_names) - tools_str = f"\n## Available Tools\n{tool_list}\n" - - # Natural language goal - goal_nl_section = "" - if task.goal_nl: - goal_nl_section = f"\n## Goal Description\n{task.goal_nl}\n" - - # Available predicates for subgoal annotations - pred_strs = [] - for pred in sorted(self._get_all_predicates(), key=lambda p: p.name): - type_sig = ", ".join(t.name for t in pred.types) - pred_strs.append(f" {pred.name}({type_sig})") - - prompt = f"""You are solving a task. \ -Generate a plan sketch to achieve the goal. -{goal_nl_section} -## Goal Atoms -{chr(10).join(goal_strs)} - -## Initial State Atoms -{chr(10).join(atom_strs)} - -## Initial State Features -{state_str} - -## Objects -{chr(10).join(obj_strs)} - -## Available Options -{chr(10).join(option_strs)} - -## Available Predicates (for subgoal annotations) -{chr(10).join(pred_strs)} -{traj_summary}{tools_str} -## Instructions -Use your available tools to inspect the environment before producing the plan. - -Generate a plan SKETCH — the sequence of options with object arguments, but \ -WITHOUT continuous parameters. Continuous parameters will be found \ -automatically by a backtracking search procedure. - -Optionally annotate subgoal atoms that should hold after each step. This \ -helps the search verify progress. Use `-> {{atoms}}` after each step. - -After any action whose desired subgoal depends on a delayed process (e.g. \ -water filling, dominoes cascading, heating), insert a Wait action. For Wait \ -steps, annotate with the atoms the process should produce — this tells the \ -system exactly when the Wait should end rather than terminating on any \ -incidental atom change. Use `NOT Pred(...)` for atoms that should become false. - -Output the plan sketch with one option per line in this format: - OptionName(obj1:type1, obj2:type2) -> \ -{{Pred(obj1:type1), Pred2(obj1:type1, obj2:type2)}} - Wait(robot:Robot) -> {{Boiled(water:water_type)}} - Wait(robot:Robot) -> {{NOT Touching(a:block, b:block)}} - -Always use typed references (obj:type) in both option arguments AND subgoal \ -atoms. The `-> {{atoms}}` part is optional. If you omit it, the search will \ -only check that the option executed successfully (non-zero actions). - -Output ONLY the plan sketch lines at the end, after any analysis.""" - - return prompt + return bilevel_sketch.build_solve_prompt( + task, + all_predicates=self._get_all_predicates(), + all_options=self._get_all_options(), + trajectory_summary=self._build_trajectory_summary(), + tool_names=self._get_agent_tool_names(), + ) # ------------------------------------------------------------------ # # Solving @@ -261,140 +150,39 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: def _query_agent_for_plan_sketch(self, task: Task) -> List[_SketchStep]: """Query agent for a plan sketch and parse it.""" - prompt = self._build_solve_prompt(task) - responses = self._query_agent_sync(prompt) - plan_text = self._extract_option_plan_text(responses) + sketch_file = CFG.agent_bilevel_plan_sketch_file + if sketch_file: + with open(sketch_file, "r", encoding="utf-8") as f: + plan_text = f.read().strip() + logging.info("Loaded plan sketch from file: %s", sketch_file) + else: + prompt = self._build_solve_prompt(task) + responses = self._query_agent_sync(prompt) + plan_text = self._extract_option_plan_text(responses) if not plan_text: - n_responses = len(responses) - types = [r.get("type") for r in responses] - raise ApproachFailure( - f"Agent returned empty plan text. " - f"Got {n_responses} responses with types: {types}") - - cleaned_text = self._strip_code_fences(plan_text) - - # Phase 1: parse options + objects (no continuous params) - objects = list(task.init) - parsed = utils.parse_model_output_into_option_plan( - cleaned_text, - objects, - self._types, - self._get_all_options(), - parse_continuous_params=False) - - if not parsed: + raise ApproachFailure("Agent returned empty plan text.") + + sketch = bilevel_sketch.parse_sketch_from_text( + plan_text, + task, + predicates=self._get_all_predicates(), + options=self._get_all_options(), + types=self._types, + ) + + if not sketch: option_names = sorted(o.name for o in self._get_all_options()) raise ApproachFailure(f"Parsed empty plan sketch from agent.\n" f" Plan text:\n{plan_text}\n" f" Available option names: {option_names}") - # Phase 2: parse subgoal annotations from raw text - subgoals = self._parse_subgoal_annotations(cleaned_text, - self._get_all_predicates(), - objects) - - # Zip into sketch steps - sketch = [] - for i, (option, objs, _) in enumerate(parsed): - sg = subgoals[i] if i < len(subgoals) else None - if sg is not None: - pos, neg = sg - sketch.append( - _SketchStep(option=option, - objects=objs, - subgoal_atoms=pos if pos else None, - subgoal_neg_atoms=neg if neg else None)) - else: - sketch.append( - _SketchStep(option=option, - objects=objs, - subgoal_atoms=None)) - logging.info(f"[{self._run_id}] Agent produced sketch with " f"{len(sketch)} steps, " f"{sum(1 for s in sketch if s.subgoal_atoms)} " f"with subgoals.") return sketch - def _parse_subgoal_annotations( - self, - text: str, - predicates: Set[Predicate], - objects: Sequence[Object], - ) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: - """Parse ``-> {Pred(...), NOT Pred(...)}`` annotations from plan text. - - Returns a list parallel to the option lines. Entries are None - for lines without annotations. Each non-None entry is - ``(positive_atoms, negative_atoms)``. - """ - pred_map = {p.name: p for p in predicates} - obj_map = {o.name: o for o in objects} - - # Regex: match -> { ... } after the option line - subgoal_re = re.compile(r'->\s*\{([^}]*)\}') - # Regex: match individual atoms, optionally prefixed with NOT - atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') - - results: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] - option_names = {o.name for o in self._get_all_options()} - - for line in text.split('\n'): - stripped = line.strip() - if not stripped: - continue - # Check if this line starts with a valid option name - first_token = stripped.split('(')[0] - if first_token not in option_names: - continue - - # This is an option line — check for subgoal annotation - sg_match = subgoal_re.search(stripped) - if not sg_match: - results.append(None) - continue - - atoms_text = sg_match.group(1) - pos_atoms: Set[GroundAtom] = set() - neg_atoms: Set[GroundAtom] = set() - for atom_match in atom_re.finditer(atoms_text): - is_neg = atom_match.group(1) is not None - pred_name = atom_match.group(2) - # Handle both "obj" and "obj:type" formats - obj_names = [ - n.strip().split(':')[0] - for n in atom_match.group(3).split(',') - ] - - if pred_name not in pred_map: - logging.warning(f"Unknown predicate in subgoal: " - f"{pred_name}") - continue - pred = pred_map[pred_name] - try: - objs = [obj_map[n] for n in obj_names] - except KeyError as e: - logging.warning(f"Unknown object in subgoal: {e}") - continue - if len(objs) != len(pred.types): - logging.warning( - f"Arity mismatch for {pred_name}: expected " - f"{len(pred.types)}, got {len(objs)}") - continue - atom = GroundAtom(pred, objs) - if is_neg: - neg_atoms.add(atom) - else: - pos_atoms.add(atom) - - if pos_atoms or neg_atoms: - results.append((pos_atoms, neg_atoms)) - else: - results.append(None) - - return results - # ------------------------------------------------------------------ # # Backtracking refinement # ------------------------------------------------------------------ # @@ -411,86 +199,37 @@ def _refine_sketch( grounded options that achieves the task goal. On failure, ``plan`` is the longest partial refinement found. - Delegates to ``run_backtracking_refinement`` for the core loop. + Delegates to ``bilevel_sketch.refine_sketch``. """ - if not sketch: - return [], False - - rng = np.random.default_rng(CFG.seed) - max_samples = CFG.agent_bilevel_max_samples_per_step - check_subgoals = CFG.agent_bilevel_check_subgoals - n = len(sketch) - max_tries = [ - max_samples if step.option.params_space.shape[0] > 0 else 1 - for step in sketch - ] - predicates = self._get_all_predicates() - - def sample_fn(idx: int, state: State, - rng_: np.random.Generator) -> _Option: - step = sketch[idx] - if CFG.agent_bilevel_log_state: - step_name = (f"{step.option.name}" - f"({', '.join(o.name for o in step.objects)})") - logging.debug(f" State before {step_name}:\n" - f"{state.pretty_str()}") - params = self._sample_params(step.option, state, rng_) - grounded = step.option.ground(step.objects, params) - if grounded.name == "Wait": - if step.subgoal_atoms is not None: - grounded.memory["wait_target_atoms"] = \ - step.subgoal_atoms - if step.subgoal_neg_atoms is not None: - grounded.memory["wait_target_neg_atoms"] = \ - step.subgoal_neg_atoms - return grounded - - def validate_fn(idx: int, _pre_state: State, _option: _Option, - post_state: State, - _num_actions: int) -> Tuple[bool, str]: - step = sketch[idx] - if check_subgoals and step.subgoal_atoms is not None: - current_atoms = utils.abstract(post_state, predicates) - if not step.subgoal_atoms.issubset(current_atoms): - missing = step.subgoal_atoms - current_atoms - return False, (f"subgoal missing: " - f"{{{', '.join(str(a) for a in missing)}}}") - if idx == n - 1: - if not task.goal_holds(post_state): - return False, "goal not reached" - return True, "" - - plan, success, total_samples = run_backtracking_refinement( - init_state=task.init, - option_model=self._option_model, - n_steps=n, - max_tries=max_tries, - sample_fn=sample_fn, - validate_fn=validate_fn, - rng=rng, + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + self._option_model, + predicates=self._get_all_predicates(), timeout=timeout, + rng=np.random.default_rng(CFG.seed), + max_samples_per_step=CFG.agent_bilevel_max_samples_per_step, + check_subgoals=CFG.agent_bilevel_check_subgoals, + log_state=CFG.agent_bilevel_log_state, + run_id=self._run_id, ) - - logging.info(f"Refinement {'succeeded' if success else 'failed'}: " - f"{total_samples} samples for {n} steps.") - - filtered = [p for p in plan if p is not None] - if success: - return cast(List[_Option], filtered), True - return filtered, False + return plan, success def _sample_params(self, option: ParameterizedOption, _state: State, rng: np.random.Generator) -> np.ndarray: - """Sample continuous parameters for an option. + """Sample continuous parameters for an option.""" + return bilevel_sketch.sample_params(option, rng) - Currently uniform random; hook point for future learned - samplers. - """ - if option.params_space.shape[0] == 0: - return np.array([], dtype=np.float32) - low = option.params_space.low - high = option.params_space.high - return rng.uniform(low, high).astype(np.float32) + def _parse_subgoal_annotations( + self, + text: str, + predicates: Set[Predicate], + objects: Sequence[Object], + ) -> List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]]: + """Shim over ``bilevel_sketch.parse_subgoal_annotations``.""" + option_names = {o.name for o in self._get_all_options()} + return bilevel_sketch.parse_subgoal_annotations( + text, predicates, objects, option_names) # ------------------------------------------------------------------ # # Forward validation diff --git a/predicators/approaches/agent_closed_loop_approach.py b/predicators/approaches/agent_closed_loop_approach.py index 8e38ebf33..1bf7805b1 100644 --- a/predicators/approaches/agent_closed_loop_approach.py +++ b/predicators/approaches/agent_closed_loop_approach.py @@ -9,7 +9,7 @@ python predicators/main.py --env pybullet_domino \ --approach agent_closed_loop --seed 0 \ --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent + --num_online_learning_cycles 1 --explorer agent_plan """ import logging from typing import Callable, List diff --git a/predicators/approaches/agent_planner_approach.py b/predicators/approaches/agent_planner_approach.py index f178fc76b..cfa164737 100644 --- a/predicators/approaches/agent_planner_approach.py +++ b/predicators/approaches/agent_planner_approach.py @@ -1,6 +1,6 @@ """Agent planner approach: fixed-vocabulary open-loop planning. -Combines online trajectory collection (via AgentExplorer) with open-loop +Combines online trajectory collection (via AgentPlanExplorer) with open-loop option plan generation (via Claude Agent SDK). No predicate/process/type invention — just stores trajectories and generates plans. @@ -8,7 +8,7 @@ python predicators/main.py --env pybullet_domino \ --approach agent_planner --seed 0 \ --num_train_tasks 1 --num_test_tasks 1 \ - --num_online_learning_cycles 1 --explorer agent + --num_online_learning_cycles 1 --explorer agent_plan """ import datetime import inspect as _inspect @@ -22,8 +22,8 @@ from gym.spaces import Box from predicators import utils +from predicators.agent_sdk.agent_session_mixin import AgentSessionMixin from predicators.approaches import ApproachFailure -from predicators.approaches.agent_session_mixin import AgentSessionMixin from predicators.approaches.base_approach import BaseApproach from predicators.explorers import create_explorer from predicators.explorers.base_explorer import BaseExplorer @@ -37,7 +37,7 @@ class AgentPlannerApproach(AgentSessionMixin, BaseApproach): """Fixed-vocabulary open-loop planning via Claude Agent SDK. - - Collects trajectories online using AgentExplorer + - Collects trajectories online using AgentPlanExplorer - At solve time, queries the agent for an option plan - No predicate/process/type invention """ @@ -60,13 +60,13 @@ def __init__(self, else: self._option_model = create_option_model(CFG.option_model_name) # Let the option model terminate Wait on atom change using the - # approach's predicates (which may include invented ones). + # approach's predicates (which may include invented ones). Looked + # up lazily so the lambda picks up predicates invented after + # __init__. if CFG.wait_option_terminate_on_atom_change: - preds = self._get_all_predicates() cast( # pylint: disable=protected-access - Any, self._option_model - )._abstract_function = \ - lambda s, _p=preds: utils.abstract(s, _p) + Any, self._option_model)._abstract_function = ( + lambda s: utils.abstract(s, self._get_all_predicates())) self._online_learning_cycle = 0 self._requests_train_task_idxs: Optional[List[int]] = None self._run_id = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") @@ -705,10 +705,13 @@ def _parse_and_ground_plan(self, plan_text: str, task: Task) -> list: def _create_explorer(self) -> BaseExplorer: """Create explorer for interaction requests.""" - if CFG.explorer == "agent": + if CFG.explorer in ("agent_plan", "agent_bilevel"): self._sync_tool_context() - return self._create_agent_explorer(self._get_all_predicates(), - self._get_all_options()) + return self._create_agent_explorer( + self._get_all_predicates(), + self._get_all_options(), + name=CFG.explorer, + ) return create_explorer( CFG.explorer, self._get_all_predicates(), diff --git a/predicators/approaches/agent_sim_learning_approach.py b/predicators/approaches/agent_sim_learning_approach.py new file mode 100644 index 000000000..f840e2781 --- /dev/null +++ b/predicators/approaches/agent_sim_learning_approach.py @@ -0,0 +1,654 @@ +"""Agent sim-learning approach: learns a simulator program online. + +Extends AgentBilevelApproach to learn process dynamics via an +agent-synthesized step-level simulator with parameterized process +rules. Parameters are fitted via emcee ensemble MCMC (training.py). + +The approach creates a base oracle (PyBullet with process +dynamics disabled) and composes it with the learned step-level +dynamics into a single simulator function, plugged into a standard +_OracleOptionModel for true per-step interleaving. + +Example command:: + + python predicators/main.py --env pybullet_boil \ + --approach agent_sim_learning --seed 0 \ + --num_train_tasks 10 --num_test_tasks 5 \ + --num_online_learning_cycles 5 --explorer agent_plan +""" + +import inspect +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple + +import numpy as np +import pybullet +from gym.spaces import Box + +from predicators import utils +from predicators.agent_sdk.tools import create_synthesis_tools +from predicators.approaches.agent_bilevel_approach import AgentBilevelApproach +from predicators.code_sim_learning.training import ParamSpec, compute_sse, \ + fit_params, log_sse_breakdown +from predicators.code_sim_learning.utils import LearnedSimulator, \ + apply_rules, merge_updates, read_simulator_components +from predicators.envs import create_new_env +from predicators.ground_truth_models import get_gt_simulator +from predicators.option_model import _OptionModelBase, _OracleOptionModel +from predicators.settings import CFG +from predicators.structs import Action, Dataset, InteractionResult, \ + LowLevelTrajectory, ParameterizedOption, Predicate, State, Task, Type + +logger = logging.getLogger(__name__) + +# ── Approach ───────────────────────────────────────────────────── + + +class AgentSimLearningApproach(AgentBilevelApproach): + """Bilevel planning with a learned step-level simulator. + + During online learning: + 1. Collect trajectories (inherited from AgentBilevelApproach) + 2. Segment into option-level transitions + 3. Synthesize parameterized process rules via Claude agent + 4. Fit rule parameters via emcee ensemble MCMC + 5. Compose with base oracle into a combined simulator + 6. Build _OracleOptionModel with the combined simulator + + During solving: + - Uses the learned model for plan validation in backtracking + refinement. + """ + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + *args: Any, + option_model: Optional[_OptionModelBase] = None, + **kwargs: Any) -> None: + # Build the base env and pass the option model in so the parent + # __init__ doesn't spin up its own full-process env, which + # would fight this one for the PyBullet GUI client. + self._base_env = create_new_env(CFG.env, + do_cache=False, + use_gui=CFG.option_model_use_gui, + skip_process_dynamics=True) + if option_model is None: + option_model = _OracleOptionModel(initial_options, + self._base_env.simulate) + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + *args, + option_model=option_model, + **kwargs) + self._learned_simulator: Optional[LearnedSimulator] = None + # Loss-scope mask for parameter fitting (compute_sse). + self._process_features: Dict[str, List[str]] = {} + self._process_rules: Optional[List] = None + self._fitted_params: Optional[Dict[str, float]] = None + self._fit_sse: float = float("inf") + self._learning_mode: bool = False + + @classmethod + def get_name(cls) -> str: + return "agent_sim_learning" + + # ── Agent session hooks ────────────────────────────────────── + + def _get_agent_system_prompt(self) -> str: + if self._learning_mode: + return self._build_synthesis_system_prompt() + return super()._get_agent_system_prompt() + + # ── Learning ──────────────────────────────────────────────── + + def learn_from_offline_dataset(self, dataset: Dataset) -> None: + super().learn_from_offline_dataset(dataset) + self._learn_simulator(dataset.trajectories) + + def learn_from_interaction_results( + self, results: Sequence[InteractionResult]) -> None: + super().learn_from_interaction_results(results) + self._learn_simulator(self._online_trajectories) + + def _learn_simulator(self, trajectories: List[LowLevelTrajectory]) -> None: + """Synthesize rules, fit parameters, and build the option model.""" + # Two parallel triple lists drive the rest of this method: + # * obs_triples — raw (s_t, a, s_{t+1}) from the data. + # * base_pred_triples — same triples but s_t replaced by the + # base sim's one-step prediction. The rules run on top of + # that prediction; SSE compares against s_{t+1}. + obs_triples = self._extract_obs_triples(trajectories) + if not obs_triples: + logger.warning("No step transitions; skipping simulator learning.") + return + # Headless env for the pre-compute: reusing the GUI base_env + # corrupts its visual-shape state after a few hundred steps. + fit_env = create_new_env(CFG.env, + do_cache=False, + use_gui=False, + skip_process_dynamics=True) + logger.info("Pre-computing base states for %d transitions.", + len(obs_triples)) + base_pred_triples = self._compute_base_pred_triples( + obs_triples, fit_env) + inferred_hint = self._infer_process_features_from_residuals( + obs_triples, base_pred_triples) + logger.info("Process features (data-driven hint): %s", inferred_hint) + + self._synthesize_with_agent(trajectories, obs_triples, + base_pred_triples, inferred_hint) + + if self._process_rules is not None and self._fitted_params is not None: + rules, params = self._process_rules, self._fitted_params + self._learned_simulator = LearnedSimulator( + step_fn=lambda s, _r=rules, _p=params: # type: ignore[misc] + apply_rules(s, _r, _p), + name="agent_synthesized") + elif self._learned_simulator is None: + logger.warning("Synthesis produced no simulator, skipping.") + return + + combined_sim = self._build_combined_simulator(self._learned_simulator) + self._option_model = self._build_option_model(combined_sim) + logger.info("Built learned option model (SSE: %.6f).", self._fit_sse) + + def _build_option_model( + self, + simulator_fn: Callable[[State, Action], State], + ) -> _OracleOptionModel: + """Wrap a simulator function in an OracleOptionModel. + + Uses ``self._get_all_options()`` rather than + ``get_gt_options(CFG.env)`` to avoid spawning a second cached + PyBullet env via ``get_or_create_env``. + """ + model = _OracleOptionModel(self._get_all_options(), simulator_fn) + if CFG.wait_option_terminate_on_atom_change: + model._abstract_function = ( # pylint: disable=protected-access + lambda s: utils.abstract(s, self._get_all_predicates())) + return model + + # ── Agent-based synthesis ──────────────────────────────────── + + def _synthesize_with_agent( + self, + trajectories: List[LowLevelTrajectory], + obs_triples: List[Tuple[State, Action, State]], + base_pred_triples: List[Tuple[State, Action, State]], + inferred_hint: Dict[str, List[str]], + ) -> None: + """Synthesize PROCESS_RULES, PARAM_SPECS, PROCESS_FEATURES via agent. + + ``inferred_hint`` is passed to the agent as a starting point and + used as the eval/test scope until it declares its own + ``PROCESS_FEATURES``. CFG flags + ``agent_sim_learn_oracle_sim_program`` and + ``agent_sim_learn_oracle_sim_params`` short-circuit the agent + and/or MCMC by loading the GT simulator instead. + """ + + if CFG.agent_sim_learn_oracle_sim_program: + rules, specs, process_features = get_gt_simulator(CFG.env) + self._log_feature_set_diff(inferred_hint, process_features, + "inferred", "oracle") + if not CFG.agent_sim_learn_oracle_sim_params: + rng = np.random.default_rng(CFG.seed) + noise_scale = CFG.agent_sim_learn_oracle_sim_param_noise_scale + if noise_scale < 0.0: + raise ValueError( + "agent_sim_learn_oracle_sim_param_noise_scale must " + "be non-negative.") + perturbed = [] + for s in specs: + val = s.init_value * (1.0 + + float(rng.normal(0, noise_scale))) + if s.lo is not None: + val = max(s.lo, val) + if s.hi is not None: + val = min(s.hi, val) + perturbed.append(ParamSpec(s.name, val, lo=s.lo, hi=s.hi)) + specs = perturbed + logger.info("Loaded oracle sim program (%d rules, %d params).", + len(rules), len(specs)) + else: + base = self._tool_context.sandbox_dir or self._get_log_dir() + save_dir = os.path.join(base, "simulator_code") + + exec_ns: Dict[str, Any] = { + "trajectories": trajectories, + "np": np, + "ParamSpec": ParamSpec, + } + + tools = create_synthesis_tools(exec_ns, + base_pred_triples, + inferred_hint, + save_dir=save_dir) + self._tool_context.extra_mcp_tools = tools + self._learning_mode = True + + # Fresh session so the synthesis prompt + tools take effect. + self._close_agent_session() + self._ensure_agent_session() + + structs_ref = self._write_structs_reference() + + n_trajs = len(trajectories) + message = f"""\ +Synthesize a process dynamics simulator for this environment. \ +There are {n_trajs} trajectories ({len(obs_triples)} step \ +transitions) available. + +Data-structure source code is at: {structs_ref} + +A residual scan between the base simulator's prediction and the \ +observed next state suggests these features carry process dynamics \ +(starting hint, may include base-sim jitter — refine as you go): +{inferred_hint} + +Read the data-structures file first, then explore the trajectory \ +data with `run_python` and define PROCESS_RULES, PARAM_SPECS, and \ +PROCESS_FEATURES.""" + + try: + self._query_agent_sync(message) + finally: + self._tool_context.extra_mcp_tools = [] + self._learning_mode = False + self._close_agent_session() + + rules, specs, declared = self._load_simulator_from_file( + save_dir, trajectories) + if rules is None or specs is None: + return + assert declared is not None, ( + "Agent did not declare PROCESS_FEATURES; " + "synthesis output is incomplete.") + process_features = declared + self._log_feature_set_diff(inferred_hint, process_features, + "inferred", "declared") + logger.info("Agent synthesized %d rules, %d params.", len(rules), + len(specs)) + + self._process_rules = rules + self._process_features = process_features + + _noise_sigma = 0.05 # matches fit_params default + if CFG.agent_sim_learn_oracle_sim_params: + self._fitted_params = {s.name: s.init_value for s in specs} + oracle_sim_fn = lambda s, a, p: apply_rules( # noqa: E731 + s, rules, p) + self._fit_sse = compute_sse(oracle_sim_fn, base_pred_triples, + self._fitted_params, process_features) + fit_ll = -0.5 * self._fit_sse / (_noise_sigma**2) + logger.info("Oracle params — SSE: %.6f log-likelihood: %.2f", + self._fit_sse, fit_ll) + for name, val in sorted(self._fitted_params.items()): + logger.info(" %-30s %.4f", name, val) + log_sse_breakdown(oracle_sim_fn, + base_pred_triples, + self._fitted_params, + process_features, + label="oracle") + else: + self._fitted_params, self._fit_sse = self._fit_parameters( + rules, specs, base_pred_triples, process_features) + if CFG.code_sim_learning_num_mcmc_steps == 0: + logger.info("Skipped MCMC; using %d initial params.", + len(specs)) + else: + logger.info("Fitted %d params.", len(specs)) + + # ── Parameter fitting ──────────────────────────────────────── + + @staticmethod + def _fit_parameters( + rules: List, + specs: List[ParamSpec], + base_pred_triples: List[Tuple[State, Action, State]], + process_features: Dict[str, List[str]], + ) -> Tuple[Dict[str, float], float]: + """Fit parameters for the synthesized rules via MCMC. + + ``base_pred_triples`` must already have the base step applied; + precomputing avoids re-running it inside the MCMC inner loop. + """ + + def sim_fn(state: State, _action: Action, params: Dict[str, + float]) -> Dict: + return apply_rules(state, rules, params) + + noise_sigma = 0.05 # matches fit_params default + init_params = {s.name: s.init_value for s in specs} + pre_sse = compute_sse(sim_fn, base_pred_triples, init_params, + process_features) + pre_ll = -0.5 * pre_sse / (noise_sigma**2) + logger.info("Before fitting — SSE: %.6f log-likelihood: %.2f", + pre_sse, pre_ll) + log_sse_breakdown(sim_fn, + base_pred_triples, + init_params, + process_features, + label="before") + + result = fit_params( + simulator_fn=sim_fn, + transitions=base_pred_triples, + param_specs=specs, + process_features=process_features, + ) + + fitted_params = result.point_estimate + post_sse = compute_sse(sim_fn, base_pred_triples, fitted_params, + process_features) + post_ll = -0.5 * post_sse / (noise_sigma**2) + logger.info("After fitting — SSE: %.6f log-likelihood: %.2f", + post_sse, post_ll) + log_sse_breakdown(sim_fn, + base_pred_triples, + fitted_params, + process_features, + label="after") + + for name in sorted(fitted_params): + init_val = init_params[name] + fit_val = fitted_params[name] + delta = fit_val - init_val + pct = (delta / init_val * 100) if init_val != 0 else float("nan") + logger.info(" %-30s %.4f -> %.4f (Δ=%.4f, %+.1f%%)", name, + init_val, fit_val, delta, pct) + + return fitted_params, post_sse + + # ── Process-feature inference ──────────────────────────────── + + @staticmethod + def _compute_base_pred_triples( + obs_triples: List[Tuple[State, Action, State]], + base_env: Any, + ) -> List[Tuple[State, Action, State]]: + """Replace each ``s_t`` with the base sim's one-step prediction.""" + return [(base_env.simulate(s, a), a, s_next) + for s, a, s_next in obs_triples] + + @staticmethod + def _infer_process_features_from_residuals( + obs_triples: List[Tuple[State, Action, State]], + base_pred_triples: List[Tuple[State, Action, State]], + abs_tol: float = 1e-4, + rel_tol: float = 1e-3, + min_hits: int = 3, + ) -> Dict[str, List[str]]: + """Features whose base-sim prediction diverges from observation. + + Flags ``(type, feat)`` if ``|pred - obs| > rel_tol*|obs| + abs_tol`` + on at least ``min_hits`` triples. The ``min_hits`` floor keeps + one-off PyBullet jitter from leaking base-handled features into the set. + """ + hits: Dict[Tuple[str, str], int] = {} + for (s_t, _, _), (s_base, _, s_obs) in zip(obs_triples, + base_pred_triples): + for obj in s_t: + for feat in obj.type.feature_names: + pred = float(s_base.get(obj, feat)) + obs = float(s_obs.get(obj, feat)) + if abs(pred - obs) > rel_tol * abs(obs) + abs_tol: + key = (obj.type.name, feat) + hits[key] = hits.get(key, 0) + 1 + out: Dict[str, List[str]] = {} + for (t, f), n in hits.items(): + if n >= min_hits: + out.setdefault(t, []).append(f) + return {t: sorted(fs) for t, fs in out.items()} + + @staticmethod + def _log_feature_set_diff( + a: Dict[str, List[str]], + b: Dict[str, List[str]], + a_label: str, + b_label: str, + ) -> None: + """Log set-difference between two {type: [feats]} maps.""" + a_pairs = {(t, f) for t, fs in a.items() for f in fs} + b_pairs = {(t, f) for t, fs in b.items() for f in fs} + only_a = sorted(a_pairs - b_pairs) + only_b = sorted(b_pairs - a_pairs) + common = a_pairs & b_pairs + logger.info( + "Feature-set diff: %s vs %s (%d common, %d only-%s, %d only-%s)", + a_label, b_label, len(common), len(only_a), a_label, len(only_b), + b_label) + if only_a: + logger.info(" only in %s: %s", a_label, only_a) + if only_b: + logger.info(" only in %s: %s", b_label, only_b) + + @staticmethod + def _load_simulator_from_file( + save_dir: str, + trajectories: Optional[List[LowLevelTrajectory]] = None, + ) -> Tuple[Optional[List], Optional[List[ParamSpec]], Optional[Dict[ + str, List[str]]]]: + """Load PROCESS_RULES, PARAM_SPECS, PROCESS_FEATURES from saved files. + + Execs all ``NNN_run_python.py`` files in ``save_dir`` in order + into one namespace. Returns ``(None, None, None)`` if rules or + specs are missing; ``features`` may be ``None`` independently, + in which case the caller asserts (PROCESS_FEATURES is required + from the agent). + """ + if not os.path.isdir(save_dir): + logger.warning("No simulator code dir at %s.", save_dir) + return None, None, None + + files = sorted(f for f in os.listdir(save_dir) + if f.endswith(".py") and f[0].isdigit()) + if not files: + logger.warning("No code files in %s.", save_dir) + return None, None, None + + ns: Dict[str, Any] = { + "np": np, + "ParamSpec": ParamSpec, + "trajectories": trajectories or [], + } + for fname in files: + fpath = os.path.join(save_dir, fname) + with open(fpath, "r", encoding="utf-8") as f: + code = f.read() + try: + exec(code, ns) # pylint: disable=exec-used + except Exception: # pylint: disable=broad-except + logger.warning("Failed to exec %s, skipping.", + fpath, + exc_info=True) + + rules, specs, features = read_simulator_components(ns) + if rules is None: + logger.warning("Saved code did not define PROCESS_RULES.") + return None, None, None + if specs is None: + logger.warning("Saved code did not define PARAM_SPECS.") + return None, None, None + + logger.info("Loaded %d rules, %d param specs from %d files in %s.", + len(rules), len(specs), len(files), save_dir) + return rules, specs, features + + # ── Static helpers ─────────────────────────────────────────── + + def _write_structs_reference(self) -> str: + """Write key struct sources to the sandbox; return the agent-visible + path.""" + # pylint: disable=import-outside-toplevel,reimported + from predicators.structs import Action as _Action + from predicators.structs import LowLevelTrajectory as _LLT + from predicators.structs import Object as _Object + from predicators.structs import State as _State + from predicators.structs import Type as _Type + + source = "\n\n".join( + inspect.getsource(cls) + for cls in [_Type, _Object, _State, _Action, _LLT]) + + base = self._tool_context.sandbox_dir or self._get_log_dir() + ref_dir = os.path.join(base, "reference") + os.makedirs(ref_dir, exist_ok=True) + ref_path = os.path.join(ref_dir, "structs.py") + with open(ref_path, "w", encoding="utf-8") as f: + f.write(source) + + # Agent sees the sandbox-mounted path, not the host path. + if self._tool_context.sandbox_dir: + return "/sandbox/reference/structs.py" + return ref_path + + @staticmethod + def _extract_obs_triples( + trajectories: List[LowLevelTrajectory], + ) -> List[Tuple[State, Action, State]]: + """Extract observed (s_t, action_t, s_{t+1}) triples.""" + triples: List[Tuple[State, Action, State]] = [] + for traj in trajectories: + for i in range(len(traj.actions)): + triples.append( + (traj.states[i], traj.actions[i], traj.states[i + 1])) + return triples + + def _recreate_base_env(self) -> None: + """Reconnect after a PyBullet physics-server crash.""" + try: + client_id = self._base_env._physics_client_id # type: ignore[attr-defined] # pylint: disable=protected-access + pybullet.disconnect(client_id) + except Exception: # pylint: disable=broad-except # client may already be dead + pass + logging.warning( + "PyBullet physics client crashed; recreating base env " + "(use_gui=%s).", CFG.option_model_use_gui) + self._base_env = create_new_env(CFG.env, + do_cache=False, + use_gui=CFG.option_model_use_gui, + skip_process_dynamics=True) + + def _build_combined_simulator( + self, + learned_simulator: LearnedSimulator, + ) -> Callable[[State, Action], State]: + """Compose base env with learned step-level dynamics. + + Captures ``self`` so the closure can recreate ``_base_env`` and + retry once on a PyBullet crash (common on macOS Metal + GUI). + """ + + def combined_simulate(state: State, action: Action) -> State: + try: + base_state = self._base_env.simulate(state, action) + except pybullet.error as e: + logging.warning( + "PyBullet error in combined_simulate (%s); " + "recreating base env and retrying.", e) + self._recreate_base_env() + base_state = self._base_env.simulate(state, action) + updates = learned_simulator.predict_step(base_state) + if not updates: + return base_state + return merge_updates(base_state, updates) + + return combined_simulate + + @staticmethod + def _build_synthesis_system_prompt() -> str: + """Build the system prompt for the synthesis agent.""" + return """\ +You are synthesizing a parameterized process dynamics simulator for a \ +robotic manipulation environment. + +A separate base physics engine (PyBullet) handles robot movement, grasping, \ +and rigid body physics. Your simulator handles **process dynamics**: features \ +that change due to ongoing physical or causal processes (e.g., water filling, \ +heat transfer) that the base sim doesn't model. + +## Tools + +- `run_python(code)` — execute Python in a persistent namespace. `print()` \ +output is returned. The namespace persists across calls. +- `evaluate_simulator` — fit parameters using PROCESS_RULES and PARAM_SPECS \ +from the namespace. Reports SSE. +- `test_simulator` — test predictions vs observations on step transitions. \ +Shows mismatches. + +### Pre-loaded variables + +- `trajectories`: List[LowLevelTrajectory] — the collected trajectory data +- `np`, `ParamSpec` — standard imports + +### Data structures + +The trajectory data uses classes from `predicators.structs` (Type, Object, \ +State, Action, LowLevelTrajectory). Their source code is provided as a \ +reference file — Read the path given in the first message. + +## Goal + +Define three variables in the `run_python` namespace: + +- `PROCESS_RULES`: list of rule functions +- `PARAM_SPECS`: list of ParamSpec objects +- `PROCESS_FEATURES`: `Dict[str, List[str]]` — for each object type, \ +the feature names your rules predict. This is treated as the truth: \ +the loss only penalises mismatches on these features, and at test \ +time the learned simulator only overwrites these features on top of \ +the base sim's prediction. Be honest — listing features your rules \ +don't actually update will inflate the loss without giving MCMC \ +anything to optimise. + +Parameters are fitted automatically after the session ends. + +### Process rule signature + +```python +def rule(state, updates, params): + \"\"\"Apply one process for a single simulation step. + + Args: + state: Current env state. + updates: Dict[Object, Dict[str, value]] accumulated from prior rules. + params: Dict[str, float] of learned parameters. + + Returns: + The (possibly modified) updates dict. + \"\"\" +``` + +### ParamSpec + +```python +ParamSpec(name: str, init_value: float) +``` + +## Workflow + +1. Explore the trajectory data with `run_python`: types, features, \ +state changes over time +2. Identify which features change due to process dynamics (not the base sim) +3. Define `PROCESS_RULES` and `PARAM_SPECS` in the namespace via `run_python` +4. Call `evaluate_simulator` to fit parameters and check SSE +5. Call `test_simulator` to see prediction mismatches +6. Iterate if needed + +## Tips + +- Each trajectory is a sequence of states from one episode. Compare \ +consecutive states to see per-step changes. +- Group objects by type: \ +`groups = {}; for o in state: groups.setdefault(o.type.name, []).append(o)` +- Accumulate updates: `updates.setdefault(obj, {})[feat] = new_value` +""" diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 33ba29167..cc9d7ce36 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -5,7 +5,7 @@ """ import abc import logging -from typing import Any, Callable, List, Optional, Set, Tuple +from typing import Any, Callable, List, Optional, Set, Tuple, cast from gym.spaces import Box @@ -47,6 +47,20 @@ def __init__(self, if option_model is None: option_model = create_option_model(CFG.option_model_name) self._option_model = option_model + # Let the option model terminate Wait on atom change. Without + # this, Wait runs to max_num_steps_option_rollout during + # refinement and the step is rejected for "exceeded individual + # horizon", even when the expected atoms have already become + # true. Mirrors AgentPlannerApproach.__init__. + # Looked up lazily so subclasses whose _get_current_predicates + # depends on attributes set after super().__init__() (e.g. + # GrammarSearchInventionApproach._learned_predicates) don't break, + # and so predicates invented later are reflected at call time. + if CFG.wait_option_terminate_on_atom_change: + cast( # pylint: disable=protected-access + Any, self._option_model)._abstract_function = ( + lambda s: utils.abstract(s, self._get_current_predicates()) + ) self._num_calls = 0 self._last_plan: List[_Option] = [] # used if plan WITH sim self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim diff --git a/predicators/approaches/process_planning_approach.py b/predicators/approaches/process_planning_approach.py index 40a8e644f..65770ce06 100644 --- a/predicators/approaches/process_planning_approach.py +++ b/predicators/approaches/process_planning_approach.py @@ -119,7 +119,9 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: self._last_option_plan = option_plan self._last_process_plan = process_plan # pylint: enable=attribute-defined-outside-init - policy = utils.option_plan_to_policy(option_plan) + policy = utils.option_plan_to_policy( + option_plan, + abstract_function=lambda s: utils.abstract(s, preds)) self._save_metrics(metrics, processes, preds) diff --git a/predicators/code_sim_learning/__init__.py b/predicators/code_sim_learning/__init__.py new file mode 100644 index 000000000..5fba924ac --- /dev/null +++ b/predicators/code_sim_learning/__init__.py @@ -0,0 +1 @@ +"""Compositional world modeling via code.""" diff --git a/predicators/code_sim_learning/training.py b/predicators/code_sim_learning/training.py new file mode 100644 index 000000000..92ac98217 --- /dev/null +++ b/predicators/code_sim_learning/training.py @@ -0,0 +1,487 @@ +"""Training utilities for the sim-learning approach. + +Parameter fitting via emcee (affine-invariant ensemble MCMC). +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np + +from predicators.settings import CFG +from predicators.structs import Action, State + +logger = logging.getLogger(__name__) + +# Step-level simulator: (State, Action, params_dict) -> {Object: {feat: val}} +StepSimulatorFn = Callable[[State, Action, Dict[str, float]], Dict] + + +@dataclass +class ParamSpec: + """Specification for a single learnable parameter.""" + + name: str + init_value: float + lo: Optional[float] = None + hi: Optional[float] = None + + +@dataclass +class FitResult: + """Result of parameter fitting.""" + + names: List[str] + samples: np.ndarray # (num_samples, num_params) + log_probs: np.ndarray # (num_samples,) + + @property + def point_estimate(self) -> Dict[str, float]: + """MAP (sample with highest log-probability).""" + best_idx = int(np.argmax(self.log_probs)) + return { + n: float(self.samples[best_idx, i]) + for i, n in enumerate(self.names) + } + + +def compute_sse( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + params: Dict[str, float], + process_features: Dict[str, List[str]], +) -> float: + """Sum of squared errors between predicted and observed process features. + + Returns the total (un-normalized) SSE so that the Gaussian + log-likelihood ``-0.5 * SSE / noise_sigma**2`` is the correct + iid-observation form. Dividing by count would silently rescale the + per-observation noise by sqrt(count), making the chain insensitive + to parameter changes. + """ + total_se = 0.0 + + for s_t, action, s_next_obs in transitions: + updates = simulator_fn(s_t, action, params) + + for obj, feat_dict in updates.items(): + type_name = obj.type.name + allowed_feats = process_features.get(type_name, []) + for feat_name, pred_val in feat_dict.items(): + if feat_name not in allowed_feats: + continue + v = pred_val.item() if hasattr(pred_val, 'item') else pred_val + obs_val = float(s_next_obs.get(obj, feat_name)) + total_se += (v - obs_val)**2 + + # Penalize unpredicted features (model predicts no change). + for obj in s_t: + type_name = obj.type.name + for feat_name in process_features.get(type_name, []): + if obj in updates and feat_name in updates[obj]: + continue + pred_val = float(s_t.get(obj, feat_name)) + obs_val = float(s_next_obs.get(obj, feat_name)) + total_se += (pred_val - obs_val)**2 + + return total_se + + +def compute_residuals( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + params: Dict[str, float], + process_features: Dict[str, List[str]], +) -> np.ndarray: + """Per-feature residuals (predicted - observed) as a flat vector. + + Used by Levenberg-Marquardt, which needs the residual *vector* + rather than scalar SSE so it can build J = dr/dtheta. Iteration + order is deterministic so the same theta produces the same vector + across calls (required for finite-difference Jacobians). + """ + residuals: List[float] = [] + for s_t, action, s_next_obs in transitions: + updates = simulator_fn(s_t, action, params) + for obj in s_t: + type_name = obj.type.name + for feat_name in process_features.get(type_name, []): + if obj in updates and feat_name in updates[obj]: + raw = updates[obj][feat_name] + pred = raw.item() if hasattr(raw, 'item') else float(raw) + else: + pred = float(s_t.get(obj, feat_name)) + obs = float(s_next_obs.get(obj, feat_name)) + residuals.append(pred - obs) + return np.asarray(residuals, dtype=float) + + +def log_sse_breakdown( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + params: Dict[str, float], + process_features: Dict[str, List[str]], + label: str = "", +) -> None: + """Log per-(type, feature) SSE so we can see which features dominate. + + Splits each feature's residual into two buckets: + * ``pred`` — transitions where the rule produced an update + (residual is sim's prediction error) + * ``no_pred`` — transitions where no rule fired + (residual is whatever the env changed on its own; + large values here mean the model is missing a + process for this feature) + """ + bucket: Dict[Tuple[str, str], Dict[str, float]] = {} + + def _slot(key: Tuple[str, str]) -> Dict[str, float]: + if key not in bucket: + bucket[key] = { + "sse_pred": 0.0, + "n_pred": 0, + "sse_no_pred": 0.0, + "n_no_pred": 0, + "max_abs_err": 0.0, + } + return bucket[key] + + for s_t, action, s_next_obs in transitions: + updates = simulator_fn(s_t, action, params) + + for obj, feat_dict in updates.items(): + type_name = obj.type.name + allowed_feats = process_features.get(type_name, []) + for feat_name, pred_val in feat_dict.items(): + if feat_name not in allowed_feats: + continue + v = pred_val.item() if hasattr(pred_val, 'item') else pred_val + obs_val = float(s_next_obs.get(obj, feat_name)) + err = float(v) - obs_val + slot = _slot((type_name, feat_name)) + slot["sse_pred"] += err * err + slot["n_pred"] += 1 + slot["max_abs_err"] = max(slot["max_abs_err"], abs(err)) + + for obj in s_t: + type_name = obj.type.name + for feat_name in process_features.get(type_name, []): + if obj in updates and feat_name in updates[obj]: + continue + pred_val = float(s_t.get(obj, feat_name)) + obs_val = float(s_next_obs.get(obj, feat_name)) + err = pred_val - obs_val + slot = _slot((type_name, feat_name)) + slot["sse_no_pred"] += err * err + slot["n_no_pred"] += 1 + slot["max_abs_err"] = max(slot["max_abs_err"], abs(err)) + + if not bucket: + return + + total = sum(s["sse_pred"] + s["sse_no_pred"] for s in bucket.values()) + header = f"SSE breakdown{(' — ' + label) if label else ''} " \ + f"(total {total:.4f}):" + logger.info(header) + logger.info(" %-22s %10s %6s %10s %6s %10s", "type.feature", + "sse_pred", "n_pred", "sse_no_pred", "n_nop", "max|err|") + rows = sorted( + bucket.items(), + key=lambda kv: -(kv[1]["sse_pred"] + kv[1]["sse_no_pred"]), + ) + for (type_name, feat_name), s in rows: + logger.info( + " %-22s %10.4f %6d %10.4f %6d %10.4f", + f"{type_name}.{feat_name}", + s["sse_pred"], + int(s["n_pred"]), + s["sse_no_pred"], + int(s["n_no_pred"]), + s["max_abs_err"], + ) + + +def fit_map_lm( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + param_specs: List[ParamSpec], + process_features: Dict[str, List[str]], + max_nfev: int = 200, +) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Find a MAP estimate via Levenberg-Marquardt (trust-region reflective). + + Returns ``(theta_map, jacobian_at_optimum)``. Jacobian is ``None`` + only if the residual vector is empty or LM raises; in those cases + callers should treat the diagnostic as unavailable. + + How LM finds the MAP here: + * ``compute_residuals`` returns r(theta) = (s_{t+1}_obs - sim(s_t, a; + theta)) flattened over transitions and the features named in + ``process_features``. Minimizing 0.5 * ||r||^2 is exactly MLE + under iid Gaussian observation noise; with the broad Gaussian + prior used elsewhere in this module being effectively flat near + init, the least-squares minimizer coincides with the MAP. + * ``scipy.optimize.least_squares(method='trf')`` runs a + Levenberg-Marquardt step inside a trust region with box + constraints (``lo``/``hi`` from ``param_specs``). At each step + it numerically estimates the Jacobian J = dr/dtheta, solves the + damped normal equations (J^T J + lambda I) dtheta = -J^T r, and + adapts lambda based on whether the step reduces SSE. + * On exit, ``result.x`` is theta_map and ``result.jac`` is J at + the optimum. J^T J / sigma^2 is the Gauss-Newton approximation + to the negative log-likelihood Hessian — the input + ``log_hessian_identifiability`` eigendecomposes to flag flat + directions. + + Two callers (see ``fit_simulator_params``): + * Hessian identifiability diagnostic — eigendecompose J^T J. + * MCMC warm start — center emcee walkers on theta_map (and short- + circuit to it directly when ``num_mcmc_steps == 0``). + """ + from scipy.optimize import \ + least_squares # pylint: disable=import-outside-toplevel + + names = [s.name for s in param_specs] + init = np.array([s.init_value for s in param_specs], dtype=float) + lo = np.array([s.lo if s.lo is not None else 1e-6 for s in param_specs]) + hi = np.array([s.hi if s.hi is not None else np.inf for s in param_specs]) + # Nudge init strictly into the interior so trf doesn't reject it. + init = np.maximum(init, lo + 1e-9) + safe_hi = np.where(np.isfinite(hi), hi - 1e-9, np.inf) + init = np.minimum(init, safe_hi) + + def residuals_fn(theta: np.ndarray) -> np.ndarray: + params = {n: float(theta[i]) for i, n in enumerate(names)} + return compute_residuals(simulator_fn, transitions, params, + process_features) + + init_residuals = residuals_fn(init) + if init_residuals.size == 0: + logger.warning("No residuals to fit (empty process_features); " + "skipping LM diagnostic.") + return init, None + + sse_init = float(np.sum(init_residuals**2)) + + try: + result = least_squares(residuals_fn, + init, + method='trf', + bounds=(lo, hi), + max_nfev=max_nfev) + except Exception as exc: # pylint: disable=broad-except + logger.warning("LM diagnostic raised %s; skipping Hessian log.", exc) + return init, None + + sse_lm = float(2.0 * result.cost) + delta = {names[i]: float(result.x[i] - init[i]) for i in range(len(names))} + logger.info( + "LM diagnostic fit: SSE %.4f -> %.4f in %d fn-evals (status=%d, %s).", + sse_init, sse_lm, result.nfev, result.status, + "converged" if result.success else "max-evals") + logger.info("LM theta_map - init: %s", + {k: f"{v:+.4f}" + for k, v in delta.items()}) + + jac = np.asarray(result.jac, dtype=float) + if jac.size == 0: + return np.asarray(result.x, dtype=float), None + return np.asarray(result.x, dtype=float), jac + + +def log_hessian_identifiability( + jacobian: np.ndarray, + param_names: List[str], + noise_sigma: float, + prior_sigma: np.ndarray, + top_k: int = 3, +) -> None: + """Eigendecompose the Hessian at the MAP and log identifiability. + + Under a Laplace approximation, the Hessian of the negative + log-posterior is the inverse posterior covariance. Its eigenvectors + are *combinations* of parameters (not individual params), and the + eigenvalues say how tightly the data constrains each combination: + + * Large eigenvalue -> stiff direction: data pins this down. + * Small eigenvalue -> sloppy direction: data is silent here. + + Sloppy directions point to parameter combinations no optimizer can + recover from the current data — typically structural rule-pair + degeneracy or under-excited input trajectories. The Gauss-Newton + approximation H ~= J^T J / sigma^2 + diag(1/prior_sigma^2) reuses + the LM Jacobian, so this analysis costs effectively nothing once + LM has run. + """ + H_data = jacobian.T @ jacobian / (noise_sigma**2) + H_prior = np.diag(1.0 / prior_sigma**2) + H = H_data + H_prior + + eigvals, eigvecs = np.linalg.eigh(H) # ascending + + cond = float(eigvals[-1] / max(eigvals[0], 1e-30)) + logger.info("Hessian eigenanalysis (cond %.2e, %d params):", cond, + len(param_names)) + + def _format(vec: np.ndarray) -> str: + order = np.argsort(-np.abs(vec)) + parts = [] + for j in order[:4]: + if abs(vec[j]) < 0.05: + break + parts.append(f"{vec[j]:+.2f} {param_names[j]}") + return " ".join(parts) if parts else "(uniform)" + + n = len(eigvals) + k = min(top_k, n) + stiff_idx = list(range(n - 1, n - 1 - k, -1)) + stiff_set = set(stiff_idx) + sloppy_idx = [i for i in range(k) if i not in stiff_set] + + logger.info(" Stiff (well-constrained):") + for i in stiff_idx: + logger.info(" lambda = %10.3e : %s", eigvals[i], + _format(eigvecs[:, i])) + + if sloppy_idx: + logger.info(" Sloppy (under-constrained):") + for i in sloppy_idx: + logger.info(" lambda = %10.3e : %s", eigvals[i], + _format(eigvecs[:, i])) + + +def fit_params( + simulator_fn: StepSimulatorFn, + transitions: List[Tuple[State, Action, State]], + param_specs: List[ParamSpec], + process_features: Dict[str, List[str]], + num_walkers: int = 32, + num_steps: Optional[int] = None, + burn_in: int = 200, + noise_sigma: float = 0.05, + prior_sigma_scale: float = 1.0, +) -> FitResult: + """Fit simulator parameters via emcee ensemble MCMC. + + Gradient-free — handles all parameter types (rates, thresholds, + capacities) uniformly. Returns full posterior with uncertainty. + + Args: + simulator_fn: Simulator(state, action, params_dict) -> updates. + Should run the base sim internally if needed. + transitions: List of (s_t, action, s_{t+1}_obs) triples. + param_specs: Parameter specifications (name, init_value). + process_features: {type_name: [feat_names]} to fit. + num_walkers: Number of ensemble walkers (>= 2*ndim). + num_steps: Total MCMC steps per walker. If None, defaults to + CFG.code_sim_learning_num_mcmc_steps. If 0, skip training and + use initial parameter values directly. + burn_in: Steps to discard as burn-in. + noise_sigma: Observation noise std dev for likelihood. + prior_sigma_scale: Prior width as multiple of init_value. + + Returns: + FitResult with posterior samples and log-probabilities. + """ + names = [s.name for s in param_specs] + init_values = np.array([s.init_value for s in param_specs]) + if num_steps is None: + num_steps = CFG.code_sim_learning_num_mcmc_steps + if num_steps < 0: + raise ValueError("code_sim_learning_num_mcmc_steps must be " + "non-negative.") + prior_sigma = init_values * prior_sigma_scale + + # Optional one-shot LM fit. Two independent uses: + # * Hessian diagnostic — eigendecompose J^T J at the MAP. + # * Warm start — center MCMC walkers on theta_map (and short-circuit + # to it directly when num_steps == 0). + walker_center = init_values + if (CFG.code_sim_learning_log_hessian_identifiability + or CFG.code_sim_learning_warm_start_with_lm): + theta_map, jac = fit_map_lm(simulator_fn, transitions, param_specs, + process_features) + if (CFG.code_sim_learning_log_hessian_identifiability + and jac is not None and jac.size > 0): + log_hessian_identifiability(jac, names, noise_sigma, prior_sigma) + if CFG.code_sim_learning_warm_start_with_lm: + walker_center = np.asarray(theta_map, dtype=float) + logger.info("Warm-starting MCMC walkers from LM MAP estimate.") + lm_params = { + n: float(walker_center[i]) + for i, n in enumerate(names) + } + lm_sse = compute_sse(simulator_fn, transitions, lm_params, + process_features) + lm_ll = -0.5 * lm_sse / (noise_sigma**2) + logger.info( + "After LM warm start — SSE: %.6f log-likelihood: %.2f", + lm_sse, lm_ll) + log_sse_breakdown(simulator_fn, + transitions, + lm_params, + process_features, + label="lm-warm-start") + + if num_steps == 0: + if CFG.code_sim_learning_warm_start_with_lm: + logger.info("Skipping emcee; using LM warm-start parameters.") + else: + logger.info("Skipping emcee; using initial parameter values.") + return FitResult(names, walker_center[None, :], np.zeros(1)) + + import emcee # type: ignore[import-untyped] # pylint: disable=import-outside-toplevel + + ndim = len(param_specs) + num_walkers = max(num_walkers, 2 * ndim + 2) + burn_in = min(burn_in, max(num_steps - 1, 0)) + + def log_posterior(theta: np.ndarray) -> float: + # Reject negative values + if np.any(theta <= 0): + return -np.inf + params = {n: float(theta[i]) for i, n in enumerate(names)} + # Broad Gaussian prior centered on init values + log_prior = -0.5 * np.sum(((theta - init_values) / prior_sigma)**2) + # Likelihood + sse = compute_sse(simulator_fn, transitions, params, process_features) + return log_prior + (-0.5 * sse / (noise_sigma**2)) + + # Initialize walkers across the prior support (sigma = half the prior + # width). A tight ball around init traps the chain on flat plateaus + # of the likelihood (e.g., when threshold-based rules don't fire), + # because emcee stretch moves scale with the swarm's spread. + p0 = walker_center + 0.5 * prior_sigma * np.random.randn(num_walkers, ndim) + p0 = np.clip(p0, 1e-6, None) + + sampler = emcee.EnsembleSampler(num_walkers, ndim, log_posterior) + + logger.info("Running emcee: %d walkers, %d steps, %d burn-in.", + num_walkers, num_steps, burn_in) + + # Run with periodic progress reports. + report_interval = max(1, num_steps // 5) + report_interval = 100 + for i, _result in enumerate(sampler.sample(p0, iterations=num_steps), + start=1): + if i % report_interval == 0 or i == num_steps: + best_lp = sampler.get_log_prob()[:i].max() + logger.info(" emcee step %d/%d (best log-prob: %.2f)", i, + num_steps, best_lp) + for h in logger.handlers + logging.getLogger().handlers: + h.flush() + + # Discard burn-in, flatten chains. + samples = sampler.get_chain(discard=burn_in, flat=True) + log_probs = sampler.get_log_prob(discard=burn_in, flat=True) + + result = FitResult(names=names, samples=samples, log_probs=log_probs) + + logger.info("emcee done. Posterior mean: %s", + {k: f"{v:.4f}" + for k, v in result.point_estimate.items()}) + + return result diff --git a/predicators/code_sim_learning/utils.py b/predicators/code_sim_learning/utils.py new file mode 100644 index 000000000..830a1e1ed --- /dev/null +++ b/predicators/code_sim_learning/utils.py @@ -0,0 +1,148 @@ +"""Utilities for the code sim-learning module. + +Core primitives for process-dynamics simulation: + +* ``apply_rules`` — run a list of rule functions on a state, return + feature updates (``ProcessUpdate``). +* ``merge_updates`` — overwrite features in a ``State`` with values + from a ``ProcessUpdate``. +* ``simulate_step`` — full pipeline: base → rules → merge. +* ``read_simulator_components`` — pull the ``PROCESS_RULES``, + ``PARAM_SPECS``, ``PROCESS_FEATURES`` triple out of a namespace + (oracle module globals or agent-synthesized exec namespace). +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple + +from predicators.structs import Action, Object, State + +logger = logging.getLogger(__name__) + +# Type alias: {Object: {feature_name: new_value}} +ProcessUpdate = Dict[Object, Dict[str, float]] + +# ── Primitives ──────────────────────────────────────────────────── + + +def apply_rules(state: State, rules: List, + params: Dict[str, float]) -> ProcessUpdate: + """Apply process rules sequentially and return feature updates. + + Each rule has signature ``rule(state, updates, params) -> updates``. + Values are normalised to plain floats (rules may return numpy + scalars). + """ + updates: ProcessUpdate = {} + for rule in rules: + updates = rule(state, updates, params) + return { + obj: {feat: float(val) + for feat, val in feat_dict.items()} + for obj, feat_dict in updates.items() + } + + +def merge_updates( + base_state: State, + updates: ProcessUpdate, +) -> State: + """Overwrite features in *base_state* with values from *updates*.""" + if not updates: + return base_state + + new_data = {} + for obj in base_state: + arr = base_state[obj].copy() + if obj in updates: + for feat_name, new_val in updates[obj].items(): + idx = obj.type.feature_names.index(feat_name) + arr[idx] = new_val + new_data[obj] = arr + + merged = base_state.copy() + merged.data = new_data + return merged + + +def simulate_step( + state: State, + action: Action, + base_env: Any, + rules: List, + params: Dict[str, float], +) -> State: + """Full simulation pipeline: base → rules → merge.""" + base_state = base_env.simulate(state, action) + updates = apply_rules(base_state, rules, params) + if not updates: + return base_state + return merge_updates(base_state, updates) + + +# ── Module-namespace loader ─────────────────────────────────────── + + +def read_simulator_components( + ns: Mapping[str, Any], +) -> Tuple[Optional[List], Optional[List], Optional[Dict[str, List[str]]]]: + """Pull the simulator triple from a namespace (module or exec dict). + + Looks for three names by convention: + + * ``PROCESS_RULES`` — non-empty list of rule functions. + * ``PARAM_SPECS`` — list of ``ParamSpec``, **or** a zero-arg + callable returning such a list. The callable form lets oracle + modules defer CFG-dependent values until consumption time, so the + module can be imported before CFG is finalized; the agent's + saved-file form normally just uses a list. + * ``PROCESS_FEATURES`` — ``{type_name: [feature_names]}`` dict. + + Returns ``(rules, specs, features)`` with ``None`` for any + missing-or-malformed component; callers decide how to react. + """ + rules = ns.get("PROCESS_RULES") + if not isinstance(rules, list) or not rules: + rules = None + + specs = ns.get("PARAM_SPECS") + if callable(specs): + specs = specs() + if not isinstance(specs, list) or not specs: + specs = None + + features = ns.get("PROCESS_FEATURES") + if features is not None and not isinstance(features, dict): + features = None + + return rules, specs, features + + +# ── LearnedSimulator ────────────────────────────────────────────── + + +class LearnedSimulator: + """Wraps a step-level simulator function (handwritten or LLM-synthesized). + + The function predicts process dynamics — features like water_volume, + heat_level, spilled_level that aren't captured by rigid body + physics. + """ + + StepFn = Callable[[State], ProcessUpdate] + + def __init__(self, + step_fn: StepFn, + name: str = "learned_simulator") -> None: + self._step_fn = step_fn + self.name = name + + def predict_step(self, state: State) -> ProcessUpdate: + """Predict process feature updates for a single timestep.""" + try: + return self._step_fn(state) + except Exception as e: # pylint: disable=broad-except + logger.warning("Simulator '%s' step raised: %s", self.name, e) + return {} diff --git a/predicators/cogman.py b/predicators/cogman.py index e35e27eb6..d573d2ad8 100644 --- a/predicators/cogman.py +++ b/predicators/cogman.py @@ -78,6 +78,7 @@ def step(self, observation: Observation) -> Optional[Action]: self._episode_state_history.append(state) if self._termination_fn is not None and self._termination_fn(state): logging.info("[CogMan] Termination triggered.") + logging.debug("[CogMan] step returning None: termination_fn fired") return None # Check if we should replan. if self._exec_monitor.step(state): @@ -227,8 +228,9 @@ def run_episode_and_get_observations( metrics["policy_call_time"] = 0.0 metrics["num_options_executed"] = 0.0 exception_raised_in_step = False + step_num = -1 if not (terminate_on_goal_reached and env.goal_reached()): - for _ in range(max_num_steps): + for step_num in range(max_num_steps): monitor_observed = False exception_raised_in_step = False try: @@ -236,6 +238,7 @@ def run_episode_and_get_observations( act = cogman.step(obs) metrics["policy_call_time"] += time.perf_counter() - start_time if act is None: + logging.debug("[CogMan] loop break: act is None") break if act.has_option() and act.get_option() != curr_option: curr_option = act.get_option() @@ -264,9 +267,14 @@ def run_episode_and_get_observations( any(issubclass(type(e), c) for c in exceptions_to_break_on): if monitor_observed: exception_raised_in_step = True + logging.debug( + f"[CogMan] loop break: exception in break_on set: {e}") break if CFG.terminate_on_goal_reached_and_option_terminated and \ env.goal_reached(): + logging.debug( + f"[CogMan] loop break: goal_reached+option_terminated " + f"(exception: {e})") break if monitor is not None and not monitor_observed: monitor.observe(obs, None) @@ -277,7 +285,18 @@ def run_episode_and_get_observations( return traj, solved, metrics raise e if terminate_on_goal_reached and env.goal_reached(): + logging.debug("[CogMan] loop break: terminate_on_goal_reached") break + else: + option_str = (None + if curr_option is None else curr_option.simple_str()) + logging.info( + "[CogMan] Reached max_num_steps=%d while executing " + "option %s.", max_num_steps, option_str) + logging.debug("[CogMan] Final loop step index before horizon: %d", + step_num) + logging.debug("[CogMan] Atoms at horizon: %s", + sorted(utils.abstract(obs, env.predicates))) if monitor is not None and not exception_raised_in_step: monitor.observe(obs, None) cogman.finish_episode(obs) diff --git a/predicators/envs/__init__.py b/predicators/envs/__init__.py index 66a497845..2510edd60 100644 --- a/predicators/envs/__init__.py +++ b/predicators/envs/__init__.py @@ -1,6 +1,7 @@ """Handle creation of environments.""" import logging +from typing import Any from predicators import utils from predicators.envs.base_env import BaseEnv @@ -14,7 +15,8 @@ def create_new_env(name: str, do_cache: bool = True, - use_gui: bool = False) -> BaseEnv: + use_gui: bool = False, + **kwargs: Any) -> BaseEnv: """Create a new instance of an environment from its name. If do_cache is True, then cache this env instance so that it can @@ -22,7 +24,7 @@ def create_new_env(name: str, """ for cls in utils.get_all_subclasses(BaseEnv): if not cls.__abstractmethods__ and cls.get_name() == name: - env = cls(use_gui) + env = cls(use_gui, **kwargs) break else: raise NotImplementedError(f"Unknown env: {name}") diff --git a/predicators/envs/base_env.py b/predicators/envs/base_env.py index f62ce3e30..a88eae29e 100644 --- a/predicators/envs/base_env.py +++ b/predicators/envs/base_env.py @@ -198,7 +198,15 @@ def get_test_tasks(self) -> List[EnvironmentTask]: @property def _current_state(self) -> State: - """Default for environments where states are observations.""" + """Typed accessor for _current_observation when it is a State. + + _current_observation is the raw Observation (which may not be a + State in vision-based envs). _current_state provides a + convenience accessor with a type assertion for the common case + where observations are States. Use _current_observation for + assignment (it is the backing field); use _current_state for + reads when you need a State. + """ assert isinstance(self._current_observation, State) return self._current_observation diff --git a/predicators/envs/pybullet_ants.py b/predicators/envs/pybullet_ants.py index 4d22d6b07..d02063333 100644 --- a/predicators/envs/pybullet_ants.py +++ b/predicators/envs/pybullet_ants.py @@ -5,9 +5,9 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.objects import create_object, \ - sample_collision_free_2d_positions, update_object + create_pybullet_block, sample_collision_free_2d_positions, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -91,7 +91,8 @@ class PyBulletAntsEnv(PyBulletEnv): def __init__(self, use_gui: bool = False, - debug_layout: bool = True) -> None: + debug_layout: bool = True, + **kwargs: Any) -> None: # Create single robot self._robot = Object("robot", self._robot_type) @@ -113,7 +114,7 @@ def __init__(self, if CFG.ants_ants_attracted_to_points: self._ants_to_xy: Dict[Object, Tuple[float, float]] = {} - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) self._debug_layout = debug_layout # Define predicates if needed (some are placeholders) @@ -174,7 +175,7 @@ def initialize_pybullet( food_ids = [] for _ in range(cls.num_food): fid = create_pybullet_block( - color=(0.5, 0.5, 0.5, 1.0), # We’ll override color later + color=(0.5, 0.5, 0.5, 1.0), # We'll override color later half_extents=cls.food_half_extents, mass=cls.food_mass, friction=0.5, @@ -215,10 +216,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: # If we support robot picking up food blocks, return those IDs. return [f.id for f in self._blocks] - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._food_type: if feature == "attractive": @@ -229,32 +227,30 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: - - if CFG.ants_ants_attracted_to_points: - self._ant_to_xy = {} # type: ignore[no-redef] - for ant_obj in state.get_objects(self._ant_type): - self._ants_to_xy[ant_obj] = (self._train_rng.uniform( - self.one_third_x, self.two_third_x), - self._train_rng.uniform( - self.y_lb, self.y_ub)) - - # Hide irrelevant objects + def _set_domain_specific_state(self, state: State) -> None: + """Hide unused objects, set attraction points, food colors, and ant + target references.""" oov_x, oov_y = self._out_of_view_xy block_objs = state.get_objects(self._food_type) for i in range(len(block_objs), len(self._blocks)): - # Hide the remaining blocks update_object(self._blocks[i].id, position=(oov_x, oov_y, self.z_lb), physics_client_id=self._physics_client_id) ant_objs = state.get_objects(self._ant_type) for i in range(len(ant_objs), len(self._ants)): - # Hide the remaining ants update_object(self._ants[i].id, position=(oov_x, oov_y, self.z_lb), physics_client_id=self._physics_client_id) + if CFG.ants_ants_attracted_to_points: + self._ant_to_xy = {} # type: ignore[no-redef] + for ant_obj in state.get_objects(self._ant_type): + self._ants_to_xy[ant_obj] = (self._train_rng.uniform( + self.one_third_x, self.two_third_x), + self._train_rng.uniform( + self.y_lb, self.y_ub)) + for food in state.get_objects(self._food_type): r = state.get(food, "r") g = state.get(food, "g") @@ -265,7 +261,6 @@ def _reset_custom_env_state(self, state: State) -> None: physics_client_id=self._physics_client_id) food.attractive = attractive - # Set ant's attractive food for ant_obj in state.get_objects(self._ant_type): food_id = state.get(ant_obj, "target_food") for food_obj in state.get_objects(self._food_type): @@ -273,25 +268,10 @@ def _reset_custom_env_state(self, state: State) -> None: ant_obj.target_food = food_obj break - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - """Override to (1) do usual robot step, (2) move ants toward attracted - food with noise, and then (3) return the final state.""" - # Step the robot normally - next_state = super().step(action, render_obs=render_obs) - - # Move ants. For each ant, find a target food - # object that is “attractive.” If there’s more - # than one attractive block, pick the one it’s - # “assigned” to, or the first in the list. Then - # move a small step toward it with noise. - self._update_ant_positions(next_state) - - final_state = self._get_state() - self._current_observation = final_state - return final_state + def _domain_specific_step(self) -> None: + """Move ants toward attracted food with noise.""" + state = self._get_state() + self._update_ant_positions(state) def _update_ant_positions(self, state: State) -> None: """For each ant, move it a small step toward its assigned attractive @@ -304,7 +284,7 @@ def _update_ant_positions(self, state: State) -> None: if CFG.ants_ants_attracted_to_points: fx, fy = self._ants_to_xy[ant_obj] else: - # Retrieve this ant’s assigned food + # Retrieve this ant's assigned food target_food_obj = None for food_obj in state.get_objects(self._food_type): if food_obj.id == state.get(ant_obj, "target_food"): @@ -533,7 +513,7 @@ def _make_tasks( # pylint: disable=redefined-outer-name env = PyBulletAntsEnv(use_gui=True) rng = np.random.default_rng(CFG.seed) task = env._make_tasks(1, rng)[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 6b69ee4ad..4206875c6 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -15,8 +15,9 @@ import numpy as np import pybullet as p -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, Array, ConceptPredicate, \ @@ -87,7 +88,7 @@ class PyBulletBalanceEnv(PyBulletEnv): _num_blocks_train = CFG.balance_num_blocks_train _num_blocks_test = CFG.balance_num_blocks_test - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Types # bbox_features = ["bbox_left", "bbox_right", # "bbox_upper", "bbox_lower"] @@ -115,7 +116,7 @@ def __init__(self, use_gui: bool = False) -> None: self._prev_diff = 0 - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._DirectlyOn = Predicate( @@ -320,10 +321,7 @@ def get_name(cls) -> str: # ------------------------------------------------------------------------- # State Management: Get, (Re)Set, Step - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._block_type: visual_data = p.getVisualShapeData( @@ -348,14 +346,9 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - state = super().step(action, render_obs=render_obs) - + def _domain_specific_step(self) -> None: + state = self._get_state() self._update_balance_beam(state) - # Turn machine on if self._PressingButton_holds(state, [self._robot, self._machine]): if self._Balanced_holds(state, [self._plate1, self._plate3]): @@ -363,29 +356,28 @@ def step( # pylint: disable=redefined-outer-name -1, rgbaColor=self._button_color_on, physicsClientId=self._physics_client_id) - self._current_observation = self._get_state() - state = self._current_observation.copy() - return state + def _set_domain_specific_state(self, state: State) -> None: + """Set block placement, balance beam, block colors, ID mapping, and + button color.""" + block_objs = state.get_objects(self._block_type) - def _reset_custom_env_state(self, state: State) -> None: - """Replace the old `_reset_state` environment-specific logic. + # Put unused blocks out of view + h = self._block_size + oov_x, oov_y = self._out_of_view_xy + for i in range(len(block_objs), len(self._blocks)): + p.resetBasePositionAndOrientation( + self._blocks[i].id, [oov_x, oov_y, i * h], + self._default_orn, + physicsClientId=self._physics_client_id) + + self._prev_diff = 0 + self._update_balance_beam(state) - The base `_reset_state` has already handled standard features - for objects that appear in _get_all_objects(), so here we just - do custom domain-specific tasks: setting plates/blocks if we - aren't letting the base class handle them, updating button - color, and running the beam-balancing update. - """ - # block objs in the state - block_objs = state.get_objects(self._block_type) self._block_id_to_block.clear() - # Suppose we want to manually update each block's color or remove them - # if not used. For example: for i, block_obj in enumerate(block_objs): self._block_id_to_block[block_obj.id] = block_obj - # Manually set color if needed: r = state.get(block_obj, "color_r") g = state.get(block_obj, "color_g") b = state.get(block_obj, "color_b") @@ -394,20 +386,7 @@ def _reset_custom_env_state(self, state: State) -> None: rgbaColor=(r, g, b, 1.0), physicsClientId=self._physics_client_id) - # For blocks beyond the number actually in the state, put them out of - # view: - h = self._block_size - oov_x, oov_y = self._out_of_view_xy - for i in range(len(block_objs), len(self._blocks)): - p.resetBasePositionAndOrientation( - self._blocks[i].id, [oov_x, oov_y, i * h], - self._default_orn, - physicsClientId=self._physics_client_id) - - self._prev_diff = 0 # reset difference - self._update_balance_beam(state) - - # Update button color for whether the machine is on + # Update button color if self._MachineOn_holds(state, [self._machine, self._robot]): button_color = self._button_color_on else: @@ -961,7 +940,7 @@ def _table_xy_is_clear(self, x: float, y: float, CFG.num_test_tasks = 1 env = PyBulletBalanceEnv(use_gui=True) task = env._generate_test_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_barrier.py b/predicators/envs/pybullet_barrier.py index c1a7f3132..c0e98ebe4 100644 --- a/predicators/envs/pybullet_barrier.py +++ b/predicators/envs/pybullet_barrier.py @@ -15,9 +15,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -90,7 +91,7 @@ class PyBulletBarrierEnv(PyBulletEnv): _barrier_type = Type("barrier", ["x", "y", "rot", "height"], sim_features=["id", "base_z"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) self._switches: List[Object] = [ @@ -102,7 +103,7 @@ def __init__(self, use_gui: bool = False) -> None: for i in range(self.num_barriers) ] - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._SwitchOn = Predicate("SwitchOn", [self._switch_type], @@ -217,7 +218,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of objects that can be held (none in this env).""" return [] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._switch_type and feature == "is_on": return float(self._is_switch_on(obj)) @@ -229,10 +230,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return current_z - obj.base_z raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - del state # Unused - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """Reset environment state from a State object.""" # Set switch states and positions for switch in self._switches: @@ -474,7 +472,7 @@ def _make_tasks(self, num_tasks: int, CFG.num_train_tasks = 1 env = PyBulletBarrierEnv(use_gui=True) task = env._generate_train_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access print("PyBullet Barrier Environment Test") print("Barriers should animate when switches are toggled.") diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index 0aaa0afe2..d6b5f09ce 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -9,8 +9,9 @@ from predicators import utils from predicators.envs.blocks import BlocksEnv -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, Object, State @@ -26,8 +27,8 @@ class PyBulletBlocksEnv(PyBulletEnv, BlocksEnv): _table_pose: ClassVar[Pose3D] = (1.35, 0.75, table_height / 2) _table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.) - def __init__(self, use_gui: bool = False) -> None: - super().__init__(use_gui) + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: + super().__init__(use_gui, **kwargs) # Store references self._table_id: int = -1 # self._block_ids: List[int] = [] @@ -93,21 +94,14 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: for blk, blk_id in zip(self._blocks, self._block_ids): blk.id = blk_id - def _create_task_specific_objects(self, state: State) -> None: - """No additional environment assets needed per-task.""" - - def _reset_custom_env_state(self, state: State) -> None: - """After the parent `_reset_state()` has reset the robot, set the block - positions/colors and handle constraints for any 'held' block.""" + def _set_domain_specific_state(self, state: State) -> None: + """Set block positions, grasp constraints, out-of-view placement, ID + mapping, and block colors.""" block_objs = state.get_objects(self._block_type) - self._block_id_to_block.clear() # Place the relevant blocks for i, block_obj in enumerate(block_objs): - block_id = self._block_ids[i] # re-use the i-th block ID - self._block_id_to_block[block_id] = block_obj - - # Position/orientation from the state's block features + block_id = self._block_ids[i] bx = state.get(block_obj, "pose_x") by = state.get(block_obj, "pose_y") bz = state.get(block_obj, "pose_z") @@ -116,19 +110,9 @@ def _reset_custom_env_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) - # Update color - r = state.get(block_obj, "color_r") - g = state.get(block_obj, "color_g") - b = state.get(block_obj, "color_b") - p.changeVisualShape(block_id, - linkIndex=-1, - rgbaColor=(r, g, b, 1.0), - physicsClientId=self._physics_client_id) - # If there is a held block, create the constraint held_block = self._get_held_block(state) if held_block is not None: - # Force grasp the relevant block self._force_grasp_object(held_block) # Teleport any leftover blocks out of view @@ -141,7 +125,20 @@ def _reset_custom_env_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) - def _extract_feature(self, obj: Object, feature: str) -> float: + self._block_id_to_block.clear() + + for i, block_obj in enumerate(block_objs): + block_id = self._block_ids[i] + self._block_id_to_block[block_id] = block_obj + r = state.get(block_obj, "color_r") + g = state.get(block_obj, "color_g") + b = state.get(block_obj, "color_b") + p.changeVisualShape(block_id, + linkIndex=-1, + rgbaColor=(r, g, b, 1.0), + physicsClientId=self._physics_client_id) + + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Called by the parent class when constructing the `PyBulletState`. We read off the relevant block or robot features from PyBullet. @@ -204,17 +201,13 @@ def _extract_feature(self, obj: Object, feature: str) -> float: f"{feature}") def step(self, action: Action, render_obs: bool = False) -> State: - self._prev_held_obj_id = self._held_obj_id - # Otherwise, proceed with normal PyBullet step - next_state = super().step(action, render_obs=render_obs) + return super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: if CFG.blocks_high_towers_are_unstable: - self._apply_force_to_high_towers(next_state) - next_state = self._get_state() - self._current_observation = next_state - - return next_state + state = self._get_state() + self._apply_force_to_high_towers(state) def _extract_robot_state(self, state: State) -> np.ndarray: """As needed, parse from the robot's `pose_x`, `pose_y`, `pose_z`, @@ -233,6 +226,16 @@ def _extract_robot_state(self, state: State) -> np.ndarray: qx, qy, qz, qw = self.get_robot_ee_home_orn() return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) + def _get_robot_state_dict(self) -> Dict[str, float]: + rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() + fingers = self._fingers_joint_to_state(self._pybullet_robot, rf) + return { + "pose_x": rx, + "pose_y": ry, + "pose_z": rz, + "fingers": fingers, + } + def _get_object_ids_for_held_check(self) -> List[int]: """Return the IDs of blocks for which we might be checking 'held' contact.""" @@ -272,7 +275,7 @@ def _force_grasp_object(self, block: Object) -> None: """Manually create a fixed constraint for a block that is marked 'held' in the State. - Called from _reset_custom_env_state(). + Called from _set_domain_specific_state(). """ # Find block's pybullet ID block_id = None diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 76561aabb..1731ac0d1 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -9,9 +9,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, DerivedPredicate, EnvironmentTask, \ @@ -173,7 +174,7 @@ def water_fill_speed(self) -> float: _human_type = Type("human", ["happiness_level"], sim_features=["id", "happiness_level"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Create the robot as an Object self._robot = Object("robot", self._robot_type) @@ -212,7 +213,7 @@ def __init__(self, use_gui: bool = False) -> None: # Keep track of the spilled water block (None if no spill yet) self._spilled_water_id: Optional[int] = None - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Optionally, define some relevant predicates self._JugFilled = Predicate("JugFilled", [self._jug_type], @@ -491,11 +492,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: jug_ids = [j.id for j in self._jugs if j.id is not None] return jug_ids - def _create_task_specific_objects(self, state: State) -> None: - """If you wanted additional objects depending on a given state, add - them here.""" - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Map from environment object + feature name -> a float feature in the State.""" # Faucet @@ -558,8 +555,8 @@ def _extract_feature(self, obj: Object, feature: str) -> float: # Otherwise, rely on defaults (like the base PyBulletEnv) for x,y,z,... raise ValueError(f"Unknown feature {feature} for object {obj}.") - def _reset_custom_env_state(self, state: State) -> None: - """Called in _reset_state to do any environment-specific resetting. + def _set_domain_specific_state(self, state: State) -> None: + """Called in _set_state to do any environment-specific resetting. This environment only supports resetting the state at the beginning, because the state dict doesn't include all features @@ -570,7 +567,7 @@ def _reset_custom_env_state(self, state: State) -> None: for i, burner_obj in enumerate(burners): on_val = state.get(burner_obj, "is_on") burner_obj.switch_id = self._burner_switches[i].id - burner_obj.prev_on = 0.0 # Initialize prev_on to 0 + burner_obj.prev_on = 0.0 self._set_switch_on(self._burner_switches[i].id, bool(on_val > 0.5)) @@ -588,6 +585,8 @@ def _reset_custom_env_state(self, state: State) -> None: liquid_id = self._create_liquid_for_jug(jug, state) self._jug_to_liquid_id[jug] = liquid_id + self._update_liquid_colors(state) + # Update jug body colors from state for jug in jugs: if jug.id is not None: @@ -600,7 +599,7 @@ def _reset_custom_env_state(self, state: State) -> None: # Faucet on/off self._faucet.switch_id = self._faucet_switch.id - self._faucet.prev_on = 0.0 # Initialize prev_on to 0 + self._faucet.prev_on = 0.0 f_on = state.get(self._faucet, "is_on") self._set_switch_on(self._faucet_switch.id, bool(f_on > 0.5)) @@ -615,7 +614,6 @@ def _reset_custom_env_state(self, state: State) -> None: self._faucet._spilled_level = -self.water_fill_speed * 20 spilled_level = max(0.0, self._faucet._spilled_level) # pylint: enable=protected-access - # If there's already some spillage in the state, recreate a block if spilled_level > 0.0: self._spilled_water_id = self._create_spilled_water_block( spilled_level, state) @@ -627,17 +625,14 @@ def _reset_custom_env_state(self, state: State) -> None: # Move irrelevant jugs and burners out of the way oov_x, oov_y = self._out_of_view_xy - jugs = state.get_objects(self._jug_type) for i in range(len(jugs), len(self._jugs)): update_object(self._jugs[i].id, position=(oov_x, oov_y, 0.0), physics_client_id=self._physics_client_id) - burners = state.get_objects(self._burner_type) for i in range(len(burners), len(self._burners)): update_object(self._burners[i].id, position=(oov_x, oov_y, 0.0), physics_client_id=self._physics_client_id) - # Also move the corresponding switch update_object(self._burner_switches[i].id, position=(oov_x, oov_y, self.switch_height), physics_client_id=self._physics_client_id) @@ -648,34 +643,15 @@ def _reset_custom_env_state(self, state: State) -> None: # ------------------------------------------------------------------------- # Step Logic # ------------------------------------------------------------------------- - def step(self, action: Action, render_obs: bool = False) -> State: - """Execute a low-level action (robot controls), then handle water - filling/spillage and heating.""" - # First let the base environment perform the usual PyBullet step - next_state = super().step(action, render_obs=False) - - # 1) Handle faucet filling/spillage - self._handle_faucet_logic(next_state) - - # 2) Handle burner heating - self._handle_heating_logic(next_state) - - # 3) Update jug colors based on their 'heat' - self._update_jug_colors(next_state) - - # 4) Update burner colors based on their on/off state - self._update_burner_colors(next_state) - - # 5) Update the human's happiness level - self._update_human_happiness(next_state) - - # 6) Update prev_on states for next step - self._update_prev_on_states(next_state) - - # Re-read final state - final_state = self.get_observation(render=render_obs) - self._current_observation = final_state - return final_state + def _domain_specific_step(self) -> None: + """Handle water filling/spillage, heating, and happiness.""" + state = self._get_state() + self._handle_faucet_logic(state) + self._handle_heating_logic(state) + self._update_liquid_colors(state) + self._update_burner_colors(state) + self._update_human_happiness(state) + self._update_prev_on_states(state) def _handle_faucet_logic(self, state: State) -> None: """If faucet is on, fill any jug that is properly aligned; otherwise, @@ -790,7 +766,7 @@ def _handle_heating_logic(self, state: State) -> None: new_heat = min(1.0, old_heat + self.heating_speed) jug_obj.heat_level = new_heat - def _update_jug_colors(self, state: State) -> None: + def _update_liquid_colors(self, state: State) -> None: """Simple linear interpolation from blue (0.0) to red (1.0) based on jug.heat.""" jugs = state.get_objects(self._jug_type) @@ -927,7 +903,7 @@ def _set_switch_on(self, switch_id: int, power_on: bool) -> None: j_id, physicsClientId=self._physics_client_id) j_min, j_max = info[8], info[9] - target_val = j_max if power_on else j_min + target_val = (j_max if power_on else j_min) * self.switch_joint_scale p.resetJointState(switch_id, j_id, target_val, @@ -1388,6 +1364,8 @@ def _create_liquid_for_jug( cx = state.get(jug, "x") cy = state.get(jug, "y") cz = self.z_lb + liquid_height / 2 + 0.02 # sits on table + jug_rot = state.get(jug, "rot") + orientation = p.getQuaternionFromEuler([0.0, 0.0, jug_rot]) color = self.water_color return create_pybullet_block(color=color, @@ -1395,6 +1373,7 @@ def _create_liquid_for_jug( mass=0.01, friction=0.5, position=(cx, cy, cz), + orientation=orientation, physics_client_id=self._physics_client_id) @@ -1445,7 +1424,7 @@ def _main() -> None: # pylint: disable=too-many-locals burner1, faucet) for task in tasks: - env._reset_state(task.init) + env._set_state(task.init) for _ in range(20000): action = Action( np.array(env._pybullet_robot.initial_joint_positions)) diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index 6c1f414cc..4155c7a9d 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -104,7 +104,7 @@ class PyBulletCircuitEnv(PyBulletEnv): _c_battery_type = Type("c_battery", ["x", "y", "z", "yaw", "pitch", "roll"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) @@ -120,7 +120,7 @@ def __init__(self, use_gui: bool = False) -> None: self._c_battery1 = Object("c_battery1", self._c_battery_type) self._c_battery2 = Object("c_battery2", self._c_battery_type) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._Holding = Predicate("Holding", @@ -297,7 +297,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of wires (assuming the robot can pick them up).""" return [self._wire1.id, self._wire2.id] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._light_type and feature == "is_on": return int(self._is_bulb_on(obj.id)) @@ -305,10 +305,10 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return int(self._is_switch_on()) raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: + """Set switch position and bulb on/off state.""" + is_switch_on = state.get(self._battery, "is_on") + self._set_switch_on(self._battery, is_switch_on) is_light_on = state.get(self._light, "is_on") if is_light_on: @@ -316,29 +316,23 @@ def _reset_custom_env_state(self, state: State) -> None: else: self._turn_bulb_off() - is_switch_on = state.get(self._battery, "is_on") - self._set_switch_on(self._battery, is_switch_on) - - def step(self, action: Action, render_obs: bool = False) -> State: - """Process a single action step. - - If the battery is connected to the light, turn the bulb on. - """ - next_state = super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: + """If the battery is connected to the light, turn the bulb on.""" + state = self._get_state() # Check basic conditions for turning on the bulb - switch_on = self._SwitchedOn_holds(next_state, [self._battery]) + switch_on = self._SwitchedOn_holds(state, [self._battery]) basic_conditions = switch_on and ( - CFG.circuit_light_doesnt_need_battery or self._CircuitClosed_holds( - next_state, [self._light, self._battery])) + CFG.circuit_light_doesnt_need_battery + or self._CircuitClosed_holds(state, [self._light, self._battery])) # Additional condition: if not using battery_in_box mode, # both C batteries must be in the battery box if not CFG.circuit_battery_in_box and self._c_battery1 is not None \ and self._c_battery2 is not None: both_batteries_in_box = ( - self._InBatteryBox_holds(next_state, [self._c_battery1]) - and self._InBatteryBox_holds(next_state, [self._c_battery2])) + self._InBatteryBox_holds(state, [self._c_battery1]) + and self._InBatteryBox_holds(state, [self._c_battery2])) can_turn_on = basic_conditions and both_batteries_in_box else: can_turn_on = basic_conditions @@ -348,13 +342,8 @@ def step(self, action: Action, render_obs: bool = False) -> State: else: self._turn_bulb_off() - final_state = self._get_state() - # Draw debug lines to visualize battery box region - self._draw_battery_box_debug_lines(final_state) - - self._current_observation = final_state - return final_state + self._draw_battery_box_debug_lines(state) # ------------------------------------------------------------------------- # Predicates @@ -775,7 +764,7 @@ def _main() -> None: CFG.num_train_tasks = 1 env = PyBulletCircuitEnv(use_gui=True) task = env._generate_train_tasks()[0] - env._reset_state(task.init) + env._set_state(task.init) while True: action = Action( diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index 73f429322..364318200 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -217,7 +217,7 @@ def pour_z_offset(cls) -> float: _camera_pitch: ClassVar[float] _camera_target: ClassVar[Pose3D] - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: if CFG.coffee_render_grid_world: # Camera parameters for grid world PyBulletCoffeeEnv._camera_distance = 3 @@ -238,7 +238,7 @@ def __init__(self, use_gui: bool = False) -> None: # PyBulletCoffeeEnv._camera_pitch = 0 # even lower PyBulletCoffeeEnv._camera_target = (0.75, 1.25, 0.42) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Create the cups lazily because they can change size and color. # self._cup_id_to_cup: Dict[int, Object] = {} @@ -254,6 +254,11 @@ def __init__(self, use_gui: bool = False) -> None: self._machine_plugged_in_id: Optional[int] = None self._last_jug_liquid_level: float = 0.0 + # Captured in step() before kinematics, consumed by + # _domain_specific_step() to detect twisting motions. + self._pre_step_ee_rpy: Tuple[float, float, float] = (0.0, 0.0, 0.0) + self._last_action: Action = Action(np.zeros(0, dtype=np.float32)) + @property def oracle_proposed_predicates(self) -> Set[Predicate]: """Return the predicates that the oracle can propose.""" @@ -315,14 +320,6 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: def get_name(cls) -> str: return "pybullet_coffee" - def _create_task_specific_objects(self, state: State) -> None: - """Remove/rebuild cups, liquids, and cords so each new task can have - different cups and states.""" - self._remake_jug_liquid(state) - self._remake_cup_liquids(state) - self._remake_cups(state) - self._remake_cord() - def _remake_cups(self, state: State) -> None: """Re-load cup URDFs with appropriate scaling and color for each new cup.""" @@ -403,16 +400,19 @@ def _remake_cord(self) -> None: self._physics_client_id) self._plug.id = self._cord_ids[-1] - def _reset_custom_env_state(self, state: State) -> None: - """Handles extra coffee-specific reset steps: spawning cups from - scratch, adding liquid visuals, adjusting jug fill color, toggling the - machine button, etc. + def _set_domain_specific_state(self, state: State) -> None: + """Reset liquid visuals, cup geometry, cord, and button colors.""" + self._remake_jug_liquid(state) + self._remake_cups(state) + for cup in state.get_objects(self._cup_type): + self._reset_single_object(cup, state) + self._remake_cup_liquids(state) + self._remake_cord() + if CFG.coffee_machine_has_plug: + for plug in state.get_objects(self._plug_type): + self._reset_single_object(plug, state) - The base `_reset_state` has already done the standard - position/orientation resets for objects in `_get_all_objects()`. - """ # Machine button color - # Check if the machine is on and the jug is in place: if self._MachineOn_holds(state, [self._machine]) and \ self._JugInMachine_holds(state, [self._jug, self._machine]): button_color = self.button_color_on @@ -439,7 +439,7 @@ def _reset_custom_env_state(self, state: State) -> None: rgbaColor=plate_color, physicsClientId=self._physics_client_id) - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._jug_type: if feature == "is_filled": @@ -480,21 +480,19 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") def step(self, action: Action, render_obs: bool = False) -> State: - # Save current end-effector roll-pitch-yaw for later comparison - current_ee_rpy = self._pybullet_robot.forward_kinematics( + # Save pre-kinematics state for _domain_specific_step. + self._pre_step_ee_rpy = self._pybullet_robot.forward_kinematics( self._pybullet_robot.get_joints()).rpy - state = super().step(action, render_obs=render_obs) - # self._update_jug_liquid_position() + self._last_action = action + return super().step(action, render_obs=render_obs) + + def _domain_specific_step(self) -> None: + state = self._get_state() if CFG.coffee_machine_has_plug: self._check_and_apply_plug_in_constraint(state) self._handle_machine_on_and_jug_filling(state) self._handle_pouring(state) - self._handle_twisting(state, current_ee_rpy, action) - # Refresh current observation - self._current_observation = self._get_state(_render_obs=False) - state = self._current_observation.copy() - - return state + self._handle_twisting(state, self._pre_step_ee_rpy, self._last_action) def _update_jug_liquid_position(self) -> None: """If the jug is filled, move its liquid to match the jug's pose. @@ -1275,7 +1273,7 @@ def _main() -> None: env = PyBulletCoffeeEnv(use_gui=True) rng = np.random.default_rng(CFG.seed) task = env._make_tasks(1, rng)[0] # type: ignore[attr-defined] # pylint: disable=no-member - env._reset_state(task.init) + env._set_state(task.init) while True: # Robot does nothing diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index 24ea5d5d0..97d288157 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -13,9 +13,10 @@ from predicators import utils from predicators.envs.cover import CoverEnv -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import update_object +from predicators.pybullet_helpers.objects import create_pybullet_block, \ + update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, Object, State @@ -58,10 +59,10 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv): float]]] = [(0, 0, 0, 1.), (1, 1, 1, 1.)] - def __init__(self, use_gui: bool = False) -> None: - super().__init__(use_gui) + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: + super().__init__(use_gui, **kwargs) # Store block/target IDs (from initialize_pybullet) so that we can - # reset their positions in _reset_custom_env_state(). + # reset their positions in _set_domain_specific_state(). self._table_id: int = -1 # self._block_ids: list[int] = [] # self._target_ids: list[int] = [] @@ -151,10 +152,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: for tgt, tgt_id in zip(self._targets, pybullet_bodies["target_ids"]): tgt.id = tgt_id - def _create_task_specific_objects(self, state: State) -> None: - """No domain-specific extra creation needed here.""" - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """After the parent class has reset the robot, handle the block/target positions. @@ -299,24 +297,13 @@ def _extract_robot_state(self, state: State) -> np.ndarray: return np.array([rx, ry, rz, qx, qy, qz, qw, fingers], dtype=np.float32) - def _extract_feature(self, obj: Object, feature: str) -> float: - """Domain-specific feature extraction for blocks, targets, and the - (robot).""" - # # 1) If it's the robot - # if obj.type == self._robot_type: - # # The parent's _get_robot_state_dict() will set x,y,z,fingers - # # We can handle additional features here: - # rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() - # if feature == "hand": - # # Re-normalize the y coordinate - # return (ry - self.y_lb) / (self.y_ub - self.y_lb) - # elif feature == "pose_x": - # return rx - # elif feature == "pose_z": - # return rz - # raise ValueError(f"Unknown robot feature: {feature}") - - # 2) If it's a block + def _get_robot_state_dict(self) -> Dict[str, float]: + rx, ry, rz, _, _, _, _, _rf = self._pybullet_robot.get_state() + hand = (ry - self.y_lb) / (self.y_ub - self.y_lb) + return {"hand": hand, "pose_x": rx, "pose_z": rz} + + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: + """Domain-specific feature extraction for blocks and targets.""" if obj.type == self._block_type: block_id = obj.id if feature == "is_block": @@ -383,19 +370,15 @@ def _extract_feature(self, obj: Object, feature: str) -> float: # Step logic (unchanged except for removing direct calls to _get_state()) # ----------------------------------------------------------------------- def step(self, action: Action, render_obs: bool = False) -> State: - """Override to handle the Cover domain's 'hand region' constraint - before calling the parent's step().""" - # Check if the pick/place position satisfies the hand constraints + """Check hand region constraint before kinematics.""" if not self._satisfies_hand_contraints(action): - # Constraint violated => no-op return self._current_state.copy() + return super().step(action, render_obs=render_obs) - # Otherwise, proceed with normal PyBullet step - next_state = super().step(action, render_obs=render_obs) - + def _domain_specific_step(self) -> None: if CFG.cover_blocks_change_color_when_cover: - self._change_block_color_when_cover(next_state) - return next_state + state = self._get_state() + self._change_block_color_when_cover(state) def _change_block_color_when_cover(self, state: State) -> None: """If a block is now covering a target, change it's color to diff --git a/predicators/envs/pybullet_domino/components/ball_component.py b/predicators/envs/pybullet_domino/components/ball_component.py index a3fccb2d4..9d0e44677 100644 --- a/predicators/envs/pybullet_domino/components/ball_component.py +++ b/predicators/envs/pybullet_domino/components/ball_component.py @@ -14,9 +14,8 @@ from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent -from predicators.envs.pybullet_env import create_pybullet_block, \ - create_pybullet_sphere -from predicators.pybullet_helpers.objects import update_object +from predicators.pybullet_helpers.objects import create_pybullet_block, \ + create_pybullet_sphere, update_object from predicators.settings import CFG from predicators.structs import Object, Predicate, State, Type diff --git a/predicators/envs/pybullet_domino/components/domino_component.py b/predicators/envs/pybullet_domino/components/domino_component.py index 54d9cca85..8375ffba3 100644 --- a/predicators/envs/pybullet_domino/components/domino_component.py +++ b/predicators/envs/pybullet_domino/components/domino_component.py @@ -18,9 +18,9 @@ from predicators import utils from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent -from predicators.envs.pybullet_env import create_pybullet_block from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, update_object from predicators.settings import CFG from predicators.structs import Object, Predicate, State, Type diff --git a/predicators/envs/pybullet_domino/components/stairs_component.py b/predicators/envs/pybullet_domino/components/stairs_component.py index ff966467c..24e32cc00 100644 --- a/predicators/envs/pybullet_domino/components/stairs_component.py +++ b/predicators/envs/pybullet_domino/components/stairs_component.py @@ -12,7 +12,7 @@ from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent -from predicators.envs.pybullet_env import create_pybullet_block +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.structs import Object, State, Type diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/composed_env.py index a30846dba..34aa3da41 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/composed_env.py @@ -102,7 +102,8 @@ class PyBulletDominoComposedEnv(PyBulletEnv): def __init__(self, components: List[DominoEnvComponent], - use_gui: bool = False) -> None: + use_gui: bool = False, + **kwargs: Any) -> None: """Initialize the composed domino environment. Args: @@ -134,7 +135,7 @@ def __init__(self, # Wire up fan -> ball wind connection if both present # (done after PyBullet init in _store_pybullet_bodies) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) def _create_robot_predicates(self) -> None: """Create robot-specific predicates.""" @@ -277,10 +278,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: ids.extend(comp.get_object_ids_for_held_check()) return ids - def _create_task_specific_objects(self, state: State) -> None: - """Create any task-specific objects (not used in current impl).""" - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract state feature for an object.""" # Try each component for comp in self._components: @@ -290,32 +288,23 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: - """Reset environment to match the given state.""" - # Update ball component's state reference for is_hit feature - if self._ball_component is not None: - self._ball_component.set_current_state(state) - - # Reset each component + def _set_domain_specific_state(self, state: State) -> None: + """Reset each component and update ball state reference.""" for comp in self._components: comp.reset_state(state) - def step(self, action: Action, render_obs: bool = False) -> State: - """Execute action and run component physics updates.""" - super().step(action, render_obs=render_obs) + if self._ball_component is not None: + self._ball_component.set_current_state(state) - # Run component step functions (e.g., fan wind simulation) + def _domain_specific_step(self) -> None: + """Run component physics updates (e.g., fan wind simulation).""" for comp in self._components: comp.step() - final_state = self._get_state() - self._current_observation = final_state - # Update ball component's state reference if self._ball_component is not None: - self._ball_component.set_current_state(final_state) - - return final_state + state = self._get_state() + self._ball_component.set_current_state(state) # ========================================================================= # PREDICATE HOLD FUNCTIONS @@ -416,7 +405,7 @@ def _make_tasks(self, class PyBulletDominoEnvNew(PyBulletDominoComposedEnv): """Backward-compatible domino environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -438,7 +427,7 @@ def __init__(self, use_gui: bool = False) -> None: num_pivots_max=max_pivots, workspace_bounds=workspace_bounds) - super().__init__(components=[domino_comp], use_gui=use_gui) + super().__init__(components=[domino_comp], use_gui=use_gui, **kwargs) @classmethod def get_name(cls) -> str: @@ -448,7 +437,7 @@ def get_name(cls) -> str: class PyBulletDominoFanEnvNew(PyBulletDominoComposedEnv): """Backward-compatible domino + fan + ball environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -478,7 +467,8 @@ def __init__(self, use_gui: bool = False) -> None: table_height=self.table_height) super().__init__(components=[domino_comp, fan_comp, ball_comp], - use_gui=use_gui) + use_gui=use_gui, + **kwargs) @classmethod def get_name(cls) -> str: @@ -504,7 +494,7 @@ def goal_predicates(self) -> Set[Predicate]: class PyBulletDominoFanRampEnv(PyBulletDominoComposedEnv): """Domino + fan + ball + ramp environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -539,7 +529,8 @@ def __init__(self, use_gui: bool = False) -> None: super().__init__( components=[domino_comp, fan_comp, ball_comp, ramp_comp], - use_gui=use_gui) + use_gui=use_gui, + **kwargs) @classmethod def get_name(cls) -> str: @@ -565,7 +556,7 @@ def goal_predicates(self) -> Set[Predicate]: class PyBulletDominoFanRampStairsEnv(PyBulletDominoComposedEnv): """Domino + fan + ball + ramp + stairs environment class.""" - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: workspace_bounds = { "x_lb": self.x_lb, "x_ub": self.x_ub, @@ -607,7 +598,8 @@ def __init__(self, use_gui: bool = False) -> None: super().__init__(components=[ domino_comp, fan_comp, ball_comp, ramp_comp, stairs_comp ], - use_gui=use_gui) + use_gui=use_gui, + **kwargs) # Store reference to stairs component self._stairs_component = stairs_comp @@ -699,7 +691,7 @@ def goal_predicates(self) -> Set[Predicate]: print(f"{'=' * 60}") # Reset to initial state - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access print("\nGoal atoms:") for atom in task.goal: diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index b07e31b39..c788bedb0 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -7,14 +7,30 @@ For a comprehensive guide on creating new PyBullet environments, see: docs/pybullet_env_guide.md -Quick reference - required methods to implement: +Main public API: + reset(train_or_test, task_idx) — reset env to a task, returns observation + simulate(state, action) — forward-simulate without touching real env + step(action) — _step_base (robot control, physics, grasps) + → _domain_specific_step (water filling, heating, etc.) + → get_observation. Domain dynamics are skipped when + skip_process_dynamics=True is passed to the constructor. + get_observation() — read PyBullet state, optionally attach images/masks + +State synchronization: + _set_state(state) — write a State into PyBullet (robot pose, object + poses, grasp constraints). Delegates domain-specific setup to + _set_domain_specific_state(). + _get_state() — read PyBullet into a PyBulletState. Delegates + domain-specific features to _get_domain_specific_feature(). + +Required overrides in subclasses: - get_name() -> str - initialize_pybullet(using_gui) -> (physics_id, robot, bodies_dict) - _store_pybullet_bodies(bodies_dict) - _get_object_ids_for_held_check() -> List[int] - - _create_task_specific_objects(state) - - _reset_custom_env_state(state) - - _extract_feature(obj, feature) -> float + - _set_domain_specific_state(state) + - _get_domain_specific_feature(obj, feature) -> float + - _domain_specific_step() (optional, default no-op) """ import abc @@ -31,6 +47,7 @@ from predicators.envs import BaseEnv from predicators.pybullet_helpers.camera import create_gui_connection from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion +from predicators.pybullet_helpers.joint import JointPositions from predicators.pybullet_helpers.link import get_link_state from predicators.pybullet_helpers.objects import update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot, \ @@ -94,6 +111,11 @@ class PyBulletEnv(BaseEnv): _out_of_view_xy: ClassVar[Sequence[float]] = [10.0, 10.0] _default_orn: ClassVar[Sequence[float]] = [0.0, 0.0, 0.0, 1.0] + # Object types that have no PyBullet body — features managed + # entirely by _get_domain_specific_feature(). + _VIRTUAL_OBJECT_TYPES: ClassVar[frozenset] = frozenset( + {"loc", "angle", "human", "side", "direction"}) + # Camera parameters. _camera_distance: ClassVar[float] = 0.8 _camera_yaw: ClassVar[float] = 90.0 @@ -102,7 +124,9 @@ class PyBulletEnv(BaseEnv): _camera_fov: ClassVar[float] = 60 _debug_text_position: ClassVar[Pose3D] = (1.65, 0.25, 0.75) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, + use_gui: bool = False, + skip_process_dynamics: bool = False) -> None: super().__init__(use_gui) # Forward declaration: subclasses must define _robot @@ -115,30 +139,20 @@ def __init__(self, use_gui: bool = False) -> None: self._held_obj_to_base_link: Optional[Any] = None self._held_obj_id: Optional[int] = None + # When True, _domain_specific_step() is skipped in step(). + # Used by sim-learning to create base-sim-only envs. + self._skip_domain_specific_dynamics: bool = skip_process_dynamics + # Set up all the static PyBullet content. self._physics_client_id, self._pybullet_robot, pybullet_bodies = \ self.initialize_pybullet(self.using_gui) self._store_pybullet_bodies(pybullet_bodies) - # What are they used for?? - # It's used in get_state, reset_state and labeling state. - # Should be populated at reset or reset state. + # Populated by reset() / _set_state(); used by _get_state(), + # _set_state(), and render_segmented_obj() for iteration. self._objects: List[Object] = [] - def get_extra_collision_ids(self) -> Sequence[int]: - """Return extra PyBullet body IDs to treat as collision obstacles. - - Override in subclasses for bodies not tracked as state Objects - (e.g. liquid blocks in Grow). - """ - return () - - def get_object_by_id(self, obj_id: int) -> Object: - """Get object by id.""" - for obj in self._objects: - if obj.id == obj_id: - return obj - raise ValueError(f"Object with ID {obj_id} not found") + # ── Setup & Initialization ────────────────────────────────── @classmethod def initialize_pybullet( @@ -175,11 +189,11 @@ def initialize_pybullet( loading. - Task-specific objects that need to be loaded with different sizes or other properties should be handled in the - `_create_task_specific_objects` method, which is called during each + `_set_domain_specific_state` method, which is called during each task's reset. - Subclasses may override this method to load additional assets. In the subclass, register all object IDs here and move them out of view - in the `reset_custom_env_state` method. + in the `_set_domain_specific_state` method. """ # Skip test coverage because GUI is too expensive to use in unit tests # and cannot be used in headless mode. @@ -221,6 +235,10 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: @classmethod def _create_pybullet_robot( cls, physics_client_id: int) -> SingleArmPyBulletRobot: + """Instantiate the robot model. + + Called by initialize_pybullet(). + """ robot_ee_orn = cls.get_robot_ee_home_orn() ee_home = Pose((cls.robot_init_x, cls.robot_init_y, cls.robot_init_z), robot_ee_orn) @@ -234,194 +252,357 @@ def _create_pybullet_robot( physics_client_id, ee_home, base_pose) - def _extract_robot_state(self, state: State) -> Array: - """Given a State, extract the robot state, to be passed into - self._pybullet_robot.reset_state(). + @classmethod + def get_robot_ee_home_orn(cls) -> Quaternion: + """Return the default end-effector orientation for this env. - This should be the same type as the return value of - self._pybullet_robot.get_state(). + Used by initialize_pybullet() to set the robot's home pose, and + by oracle options to compute motion-planning targets. """ + robot_ee_orns = CFG.pybullet_robot_ee_orns[cls.get_name()] + return robot_ee_orns[CFG.pybullet_robot] - # EE Position - def get_pos_feature( - state: State, - feature_name: str) -> float: # type: ignore[no-untyped-def] - if feature_name in self._robot.type.feature_names: - return state.get(self._robot, feature_name) - if f"pose_{feature_name}" in self._robot.type.feature_names: - return state.get(self._robot, f"pose_{feature_name}") - raise ValueError(f"Cannot find robot pos '{feature_name}'") - - rx = get_pos_feature(state, "x") - ry = get_pos_feature(state, "y") - rz = get_pos_feature(state, "z") - - # EE Orientation - _, default_tilt, default_wrist = p.getEulerFromQuaternion( - self.get_robot_ee_home_orn()) - if "tilt" in self._robot.type.feature_names: - tilt = state.get(self._robot, "tilt") - else: - tilt = default_tilt - if "wrist" in self._robot.type.feature_names: - wrist = state.get(self._robot, "wrist") - else: - wrist = default_wrist - qx, qy, qz, qw = p.getQuaternionFromEuler([0.0, tilt, wrist]) - - # Fingers - f = state.get(self._robot, "fingers") - f = self._fingers_state_to_joint(self._pybullet_robot, f) - - return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) - - @abc.abstractmethod - def _get_object_ids_for_held_check(self) -> List[int]: - """Return a list of pybullet IDs corresponding to objects in the - simulator that should be checked when determining whether one is - held.""" - raise NotImplementedError("Override me!") - - def _get_expected_finger_normals(self) -> Dict[int, Array]: - # Get the current state of the robot, including the orientation - # quaternion - _rx, _ry, _rz, qx, qy, qz, qw, _rf = self._pybullet_robot.get_state() - - # Convert the quaternion to a rotation matrix - rotation_matrix = p.getMatrixFromQuaternion([qx, qy, qz, qw]) - rotation_matrix = np.array(rotation_matrix).reshape(3, 3) - - # Define the initial normal vectors for the fingers - if CFG.pybullet_robot == "panda": - # gripper rotated 90deg so parallel to x-axis - normal = np.array([1., 0., 0.], dtype=np.float32) - elif CFG.pybullet_robot in {"fetch", "mobile_fetch"}: - # gripper parallel to y-axis - normal = np.array([0., 1., 0.], dtype=np.float32) - else: # pragma: no cover - # Shouldn't happen unless we introduce a new robot. - raise ValueError(f"Unknown robot {CFG.pybullet_robot}") - - # Transform the normal vectors using the rotation matrix - transformed_normal = rotation_matrix.dot(normal) - transformed_normal_neg = rotation_matrix.dot(-1 * normal) - - return { - self._pybullet_robot.left_finger_id: transformed_normal, - self._pybullet_robot.right_finger_id: transformed_normal_neg, - } - - @classmethod - def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, - finger_state: float) -> float: - """Map the fingers in the given *State* to joint values for - PyBullet.""" - # If open_fingers is undefined, use 1.0 as the default. - subs = { - cls.open_fingers: pybullet_robot.open_fingers, - cls.closed_fingers: pybullet_robot.closed_fingers, - } - match = min(subs, key=lambda k: abs(k - finger_state)) - return subs[match] - - @classmethod - def _fingers_joint_to_state(cls, pybullet_robot: SingleArmPyBulletRobot, - finger_joint: float) -> float: - """Inverse of _fingers_state_to_joint().""" - subs = { - pybullet_robot.open_fingers: cls.open_fingers, - pybullet_robot.closed_fingers: cls.closed_fingers, - } - match = min(subs, key=lambda k: abs(k - finger_joint)) - return subs[match] + # ── Public API & Properties ───────────────────────────────── @property def action_space(self) -> Box: return self._pybullet_robot.action_space - def simulate(self, state: State, action: Action) -> State: - # Optimization: check if we're already in the right state. - # self._current_observation is None at the beginning - # state is not allclose to self._current_state when the state has been - # updated, so it first calls _reset_state to update the pybullet state - if self._current_observation is None or \ - not state.allclose(self._current_state): - self._current_observation = state - self._reset_state(state) - return self.step(action) + def get_extra_collision_ids(self) -> Sequence[int]: + """Return extra PyBullet body IDs to treat as collision obstacles. - def render_state_plt( - self, - state: State, - task: EnvironmentTask, - action: Optional[Action] = None, - caption: Optional[str] = None) -> matplotlib.figure.Figure: - raise NotImplementedError("This env does not use Matplotlib") + Called by the motion planner (skill factories) when computing + collision-free paths. Override in subclasses for bodies not + tracked as state Objects (e.g. liquid blocks in Grow). + """ + return () - def render_state(self, - state: State, - task: EnvironmentTask, - action: Optional[Action] = None, - caption: Optional[str] = None) -> Video: - raise NotImplementedError("A PyBullet environment cannot render " - "arbitrary states.") + def get_object_by_id(self, obj_id: int) -> Object: + """Look up an Object by its PyBullet body ID. + + Used by agent tools and skill factories to map from a PyBullet + collision/contact result back to the predicators Object. + """ + for obj in self._objects: + if obj.id == obj_id: + return obj + raise ValueError(f"Object with ID {obj_id} not found") + + # ── Core Loop (Reset / Simulate / Step) ───────────────────── def reset(self, train_or_test: str, task_idx: int, render: bool = False) -> Observation: state = super().reset(train_or_test, task_idx) - self._reset_state(state) + self._set_state(state) observation = self.get_observation(render=render) return observation - def _reset_state(self, state: State) -> None: - """Reset the PyBullet state to match the given state. + def simulate(self, state: State, action: Action) -> State: + """Apply an action to a state using the PyBullet simulator. + + Called by the option model during bilevel planning to forward- + simulate candidate action sequences without touching the real + environment. + + The _set_state guard handles two cases: + - Skipped (common): during a sequential rollout the option model + calls simulate(s1, a1) -> s2, then simulate(s2, a2) -> s3, etc. + After each call, _current_state already equals the next input + state, so _set_state is unnecessary. + - Taken: when the planner jumps to a different state (e.g. trying + a new skeleton or backtracking), or on the very first call + before any reset() (_current_observation is None). + """ + if self._current_observation is None or \ + not state.allclose(self._current_state): + self._set_state(state) + return self.step(action) + + def step(self, action: Action, render_obs: bool = False) -> Observation: + """Execute one environment step with the given action. - Used in initialization (reset(), _add_pybullet_state_to_tasks()) - and bilevel planning (when creating the option model)). + Flow: base sim → domain-specific dynamics → observation. + Subclasses override ``_domain_specific_step`` (not this method) + to add post-base-sim dynamics (water filling, heating, etc.). """ - self._objects = list(state.data) - # 1) Clear old constraint if we had a held object - if self._held_constraint_id is not None: + self._step_base(action) + if not self._skip_domain_specific_dynamics: + self._domain_specific_step() + observation = self.get_observation( + render=CFG.rgb_observation or render_obs) + self._current_observation = observation + return observation + + def _step_base(self, action: Action) -> None: + """Run robot control, physics stepping, and grasp management.""" + # Send the action to the robot. + target_joint_positions, base_delta = self._split_action(action) + if base_delta.size: + self._apply_base_delta(base_delta) + self._pybullet_robot.set_motors(target_joint_positions.tolist()) + + # If we are setting the robot joints directly, and if there is a held + # object, we need to reset the pose of the held object directly. This + # is because the PyBullet constraints don't seem to play nicely with + # resetJointState (the robot will sometimes drop the object). + if CFG.pybullet_control_mode == "reset" and \ + self._held_obj_id is not None: + world_to_base_link = get_link_state( + self._pybullet_robot.robot_id, + self._pybullet_robot.end_effector_id, + physics_client_id=self._physics_client_id).com_pose + base_link_to_held_obj = p.invertTransform( + *self._held_obj_to_base_link) + world_to_held_obj = p.multiplyTransforms(world_to_base_link[0], + world_to_base_link[1], + base_link_to_held_obj[0], + base_link_to_held_obj[1]) + p.resetBasePositionAndOrientation( + self._held_obj_id, + world_to_held_obj[0], + world_to_held_obj[1], + physicsClientId=self._physics_client_id) + + # Step the simulation here before adding or removing constraints + # because detect_held_object() should use the updated state. + if CFG.pybullet_control_mode != "reset": + for _ in range(CFG.pybullet_sim_steps_per_action): + p.stepSimulation(physicsClientId=self._physics_client_id) + + # If not currently holding something, and fingers are closing, check + # for a new grasp. + if self._held_constraint_id is None and self._fingers_closing(action): + self._held_obj_id = self._detect_held_object() + if self._held_obj_id is not None: + self._create_grasp_constraint() + + # If placing, remove the grasp constraint. + if self._held_constraint_id is not None and \ + self._fingers_opening(action): p.removeConstraint(self._held_constraint_id, physicsClientId=self._physics_client_id) self._held_constraint_id = None - self._held_obj_to_base_link = None - self._held_obj_id = None + self._held_obj_id = None + + def _domain_specific_step(self) -> None: + """Apply domain-specific dynamics after the base sim. + + Override in subclasses to add post-base-sim effects (water + filling, heating, balance beam physics, etc.). Skipped when + ``skip_process_dynamics=True`` is passed to the constructor. + """ + + # ── State Write (State → PyBullet) ────────────────────────── + + def _set_state(self, state: State) -> None: + """State -> PyBullet: write the requested State into the simulator. + + Per-component diff: each piece of the State (robot pose, each + object pose, held-object identity) is compared against the live + PyBullet world and only re-written when it actually differs. + This lets sequential rollouts (option model, learned process + simulators) advance without snapping the arm or rebuilding the + grasp constraint when only a subset of features changed — which + is what eliminates the visible robot jitter during combined + base+learned simulator calls. It also lets a learned rule move + an *unheld* object without disturbing the arm or any other body. + + Call sites: + - reset() / _add_pybullet_state_to_tasks(): initialization + - simulate(): option-model / bilevel-planning rollouts + - external callers (skill factories, agent tools, tests) + """ + # Cohort change or the very first call forces a full reset: + # per-component compares assume the same set of bodies. + full_reset = (self._current_observation is None + or set(self._objects) != set(state.data)) + + # Keep _current_observation in sync so step() can read it + # (e.g. for finger-delta computation). + self._current_observation = state + self._objects = list(state.data) - # 2) Reset robot pose - self._pybullet_robot.reset_state(self._extract_robot_state(state)) + wrote_anything = False - # I want to have a step that creates task specific objects before reset - # their positions, what should I call this? - self._create_task_specific_objects(state) + # 1) Robot pose diff. Skipping this branch when the live joints + # already match the requested pose is what eliminates arm + # jitter: resetJointState would otherwise hard-snap the arm + # on every simulate() call in a sequential rollout. + robot_changed = full_reset or not self._robot_matches_state(state) - # 3) Reset all known objects (position, orientation, etc.) + # 2) Object pose diff. Identify which non-virtual object bodies + # have moved relative to PyBullet. + objects_to_reset: List[Object] = [] for obj in self._objects: - if obj.type.name in [ - "robot", "loc", "angle", "human", "side", "direction" - ]: + if obj.type.name == "robot" or \ + obj.type.name in self._VIRTUAL_OBJECT_TYPES or \ + obj.id is None: continue - self._reset_single_object(obj, state) + if full_reset or not self._object_pose_matches_state(obj, state): + objects_to_reset.append(obj) + + # 3) Held-object identity diff. The grasp constraint must be + # torn down and rebuilt whenever: + # - the held identity changes (including held → unheld and + # unheld → held), + # - the held object's recorded pose changes (the offset to + # the gripper moves), or + # - the gripper itself moves (resetJointState bypasses the + # constraint, so a kept constraint would leave the held + # body behind). + new_held_id = self._held_obj_id_in_state(state) + held_obj_moved = (self._held_obj_id is not None + and any(o.id == self._held_obj_id + for o in objects_to_reset)) + rebuild_constraint = (full_reset or new_held_id != self._held_obj_id + or (self._held_obj_id is not None and + (robot_changed or held_obj_moved))) + + # Tear down before robot/object resets so the held body is free + # while we move things around. + if rebuild_constraint: + if self._held_constraint_id is not None: + p.removeConstraint(self._held_constraint_id, + physicsClientId=self._physics_client_id) + wrote_anything = True + self._held_constraint_id = None + self._held_obj_to_base_link = None + self._held_obj_id = None - # 4) Let the subclass do any additional specialized resetting - self._reset_custom_env_state(state) + if robot_changed: + # Prefer exact joint positions when the State carries them in + # simulator_state — IK from (x, y, z, tilt, wrist) drops + # wrist roll, which corrupts the held-object offset that + # _create_grasp_constraint records below. + joint_positions = self._extract_robot_joint_positions(state) + self._pybullet_robot.reset_state(self._extract_robot_state(state), + joint_positions=joint_positions) + wrote_anything = True + + for obj in objects_to_reset: + self._reset_single_object(obj, state) + wrote_anything = True + + # Recreate the constraint after objects are repositioned so the + # recorded base_link → object offset matches the new pose. + if rebuild_constraint and new_held_id is not None: + self._held_obj_id = new_held_id + self._create_grasp_constraint() + wrote_anything = True + + # 4) Subclass-specific state always runs (idempotent and cheap). + self._set_domain_specific_state(state) + + # 5) Reconstruction check — only when we actually wrote + # something kinematic. Only raise for envs that override + # _get_state(). + if wrote_anything: + reconstructed = self._get_state() + if not reconstructed.allclose(state): + if type(self)._get_state is not PyBulletEnv._get_state: + raise ValueError("Could not reconstruct state.") + logging.warning( + "Could not reconstruct state exactly in reset.") + + def _robot_matches_state(self, state: State, atol: float = 1e-3) -> bool: + """True if PyBullet's live robot pose already equals state's. + + Compares at the joint level. The EE-quaternion path that + ``_extract_robot_state`` builds always uses ``roll=0``, so any + non-zero wrist roll in the live PyBullet pose would spuriously + fail an EE-pose comparison and trigger a full robot reset on + every simulate() call (visible jitter). + + ``atol`` matches ``State.allclose``'s feature tolerance: a looser + check would let the fast-path skip a reset even when the live EE + pose differs from the requested state by more than allclose + accepts (e.g. when a caller hands us + ``initial_joint_positions`` as a hint and the live joints are + only 1e-2 close). + + Returns False when ``state`` has no joint_positions — the only + live caller in that situation is + ``_add_pybullet_state_to_tasks``, where forcing a reset is + exactly the desired behavior. + """ + jp = self._extract_robot_joint_positions(state) + if jp is None: + return False + try: + cur_jp = self._pybullet_robot.get_joints() + except (KeyError, ValueError): + return False + return bool(np.allclose(jp, cur_jp, atol=atol)) + + def _object_pose_matches_state(self, + obj: Object, + state: State, + atol: float = 1e-2) -> bool: + """True if PyBullet's live pose for ``obj`` equals state[obj].""" + if obj.id is None: + return True + try: + features = obj.type.feature_names + (px, py, pz), orn = p.getBasePositionAndOrientation( + obj.id, physicsClientId=self._physics_client_id) + if "x" in features and \ + not np.isclose(state.get(obj, "x"), px, atol=atol): + return False + if "y" in features and \ + not np.isclose(state.get(obj, "y"), py, atol=atol): + return False + if "z" in features and \ + not np.isclose(state.get(obj, "z"), pz, atol=atol): + return False + if {"rot", "yaw", "roll", "pitch"} & set(features): + roll, pitch, yaw = p.getEulerFromQuaternion(orn) + if "rot" in features and not np.isclose( + state.get(obj, "rot"), yaw, atol=atol): + return False + if "yaw" in features and not np.isclose( + state.get(obj, "yaw"), yaw, atol=atol): + return False + if "roll" in features and not np.isclose( + state.get(obj, "roll"), roll, atol=atol): + return False + if "pitch" in features and not np.isclose( + state.get(obj, "pitch"), pitch, atol=atol): + return False + return True + except (KeyError, ValueError): + return False + + def _held_obj_id_in_state(self, state: State) -> Optional[int]: + """Which PyBullet body id is marked is_held > 0.5 in ``state``. + + Returns None if no object is held in ``state``. Mirrors the per- + object logic in _reset_single_object before constraint + management was hoisted out into _set_state. + """ + for obj in state.data: + if obj.id is None: + continue + if "is_held" not in obj.type.feature_names: + continue + try: + if state.get(obj, "is_held") > 0.5: + return obj.id + except (KeyError, ValueError): + continue + return None - # 5) Check for reconstruction mismatch. - # Only raise for envs that override _get_state(). - reconstructed = self._get_state() - if not reconstructed.allclose(state): - if type(self)._get_state is not PyBulletEnv._get_state: - raise ValueError("Could not reconstruct state.") - logging.warning("Could not reconstruct state exactly in reset.") + def _reset_single_object(self, obj: Object, state: State) -> None: + """Teleport a single physical object to match the given State. - @abc.abstractmethod - def _create_task_specific_objects(self, state: State) -> None: - raise NotImplementedError("Override me!") + Pose only — grasp-constraint management is centralized in + _set_state so teardown/rebuild stays in one place. - def _reset_single_object(self, obj: Object, state: State) -> None: - """Shared logic for setting position/orientation and constraints.""" + Called by _set_state() for every non-robot, non-virtual object + whose pose differs from PyBullet (or for all such objects on a + full reset). + """ # Skip objects without pybullet IDs (handled by subclass). if obj.id is None: return @@ -430,8 +611,6 @@ def _reset_single_object(self, obj: Object, state: State) -> None: features = obj.type.feature_names cur_x, cur_y, cur_z = p.getBasePositionAndOrientation( obj.id, physicsClientId=self._physics_client_id)[0] - # except: - # breakpoint() px = state.get(obj, "x") if "x" in obj.type.feature_names else cur_x py = state.get(obj, "y") if "y" in obj.type.feature_names else cur_y pz = state.get(obj, "z") if "z" in obj.type.feature_names else cur_z @@ -447,111 +626,142 @@ def _reset_single_object(self, obj: Object, state: State) -> None: else: orn = self._default_orn # e.g. (0,0,0,1) - # 2) Update the object’s position/orientation in PyBullet + # 2) Update the object's position/orientation in PyBullet update_object(obj.id, (px, py, pz), orn, physics_client_id=self._physics_client_id) - # 3) If there's an is_held feature, reattach constraints if needed - if "is_held" in features: - if state.get(obj, "is_held") > 0.5: - # attach constraint - self._held_obj_id = obj.id - self._create_grasp_constraint() - # _create_grasp_constraint already correctly computes - # and stores _held_obj_to_base_link. - @abc.abstractmethod - def _reset_custom_env_state(self, state: State) -> None: - """Hook for environment-specific resetting (colors, water, etc.). + def _set_domain_specific_state(self, state: State) -> None: + """Set simulator state for features that the base class doesn't handle. + + — e.g. switch on/off, liquid levels, button colors, balance beam + positions. - Subclasses can override or extend this if needed. + Called at the end of _set_state(), after the base class has + already set robot joints, object poses, and grasp constraints. + Subclasses must override. """ raise NotImplementedError("Override me!") - def _get_state(self, _render_obs: bool = False) -> State: - """Reads the PyBullet scene into a `State` (PyBulletState). It takes - care of: + def _extract_robot_state(self, state: State) -> Array: + """State -> robot array: extract robot features for PyBullet. - * robot features [x, y, z, tilt, wrist, fingers] - * object features [x, y, z, rot, is_held] - the other feature extractors should be implemented in the subclasses via - `_extract_feature`. + Converts the robot's features in a State into the array format + expected by self._pybullet_robot.reset_state() + (same format as self._pybullet_robot.get_state()). + + Called by _set_state() to position the robot. """ - state_dict: Dict[Object, Dict[str, float]] = {} - # --- 1) Robot --- - robot_state = self._get_robot_state_dict() - state_dict[self._robot] = robot_state + # EE Position + def get_pos_feature( + state: State, + feature_name: str) -> float: # type: ignore[no-untyped-def] + if feature_name in self._robot.type.feature_names: + return state.get(self._robot, feature_name) + if f"pose_{feature_name}" in self._robot.type.feature_names: + return state.get(self._robot, f"pose_{feature_name}") + raise ValueError(f"Cannot find robot pos '{feature_name}'") - # --- 2) Other Objects --- - for obj in self._objects: - if obj.type.name in ["robot"]: - continue + rx = get_pos_feature(state, "x") + ry = get_pos_feature(state, "y") + rz = get_pos_feature(state, "z") - obj_features = obj.type.feature_names - obj_dict = {} + # EE Orientation + _, default_tilt, default_wrist = p.getEulerFromQuaternion( + self.get_robot_ee_home_orn()) + if "tilt" in self._robot.type.feature_names: + tilt = state.get(self._robot, "tilt") + else: + tilt = default_tilt + if "wrist" in self._robot.type.feature_names: + wrist = state.get(self._robot, "wrist") + else: + wrist = default_wrist + qx, qy, qz, qw = p.getQuaternionFromEuler([0.0, tilt, wrist]) - if obj.type.name in ["loc", "angle", "human", "side", "direction"]: - for feature in obj_features: - obj_dict[feature] = self._extract_feature(obj, feature) - state_dict[obj] = obj_dict - continue + # Fingers + f = state.get(self._robot, "fingers") + f = self._fingers_state_to_joint(self._pybullet_robot, f) - # Basic features - try: - (px, py, pz), orn = p.getBasePositionAndOrientation( - obj.id, physicsClientId=self._physics_client_id) - except Exception as e: - raise RuntimeError(f"Failed to get pose for object {obj.name} " - f"(id={obj.id})") from e - if "x" in obj_features: - obj_dict["x"] = px - if "y" in obj_features: - obj_dict["y"] = py - if "z" in obj_features: - obj_dict["z"] = pz - if "rot" in obj_features or "yaw" in obj_features or \ - "roll" in obj_features or "pitch" in obj_features: - roll, pitch, yaw = p.getEulerFromQuaternion(orn) - if "rot" in obj_features: - obj_dict["rot"] = yaw - if "yaw" in obj_features: - obj_dict["yaw"] = yaw - if "roll" in obj_features: - obj_dict["roll"] = roll - if "pitch" in obj_features: - obj_dict["pitch"] = pitch - if "is_held" in obj_features: - obj_dict["is_held"] = 1.0 if obj.id == self._held_obj_id \ - else 0.0 - - if "r" in obj_features or "b" in obj_features or \ - "g" in obj_features: - # Note: also handle color_r, color_b, ... - visual_data = p.getVisualShapeData( - obj.id, physicsClientId=self._physics_client_id)[0] - (r, g, b, _a) = visual_data[7] - obj_dict["r"] = r - obj_dict["g"] = g - obj_dict["b"] = b - - # Additional features - for feature in obj_features: - if feature not in [ - "x", "y", "z", "rot", "yaw", "roll", "pitch", - "is_held", "r", "g", "b" - ]: - obj_dict[feature] = self._extract_feature(obj, feature) + return np.array([rx, ry, rz, qx, qy, qz, qw, f], dtype=np.float32) - state_dict[obj] = obj_dict + def _extract_robot_joint_positions( + self, state: State) -> Optional[JointPositions]: + """Pull arm joint positions out of a State's simulator_state. - # Convert to a PyBulletState - # try: - state = utils.create_state_from_dict(state_dict) - # except: - # breakpoint() - joint_positions = self._pybullet_robot.get_joints() + Returns None when the State doesn't carry them (plain State, or + a PyBulletState whose simulator_state has a different shape than + this robot's arm). Callers fall back to IK in that case. + """ + sim_state = getattr(state, "simulator_state", None) + jp: Any + if isinstance(sim_state, dict): + jp = sim_state.get("joint_positions") + elif sim_state is None: + return None + else: + # PyBulletState also accepts simulator_state passed as a raw + # joint-positions sequence (see PyBulletState.joint_positions + # and tests/envs/test_pybullet_blocks.py:69-70). + jp = sim_state + if jp is None: + return None + try: + jp_list = list(jp) + except TypeError: + return None + if len(jp_list) != len(self._pybullet_robot.arm_joints): + return None + return cast(JointPositions, jp_list) + + @classmethod + def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot, + finger_state: float) -> float: + """Map finger value in a State (e.g. open_fingers=0.04) to the + corresponding PyBullet joint position. + + Called by _extract_robot_state() when writing State -> PyBullet. + """ + # If open_fingers is undefined, use 1.0 as the default. + subs = { + cls.open_fingers: pybullet_robot.open_fingers, + cls.closed_fingers: pybullet_robot.closed_fingers, + } + match = min(subs, key=lambda k: abs(k - finger_state)) + return subs[match] + + # ── State Read (PyBullet → State) ─────────────────────────── + + # Features handled by _get_object_state_dict via PyBullet queries. + _PYBULLET_FEATURES: ClassVar[frozenset] = frozenset({ + "x", "y", "z", "rot", "yaw", "roll", "pitch", "is_held", "r", "g", "b" + }) + + def _get_state(self, _render_obs: bool = False) -> State: + """PyBullet -> State: read the simulator into a PyBulletState. + + Queries PyBullet for the current scene (joint positions, body + poses, visual data, etc.) and packs the values into the + agent-facing State representation. + + Handles common features (robot pose, object x/y/z/rot/is_held, + color); subclass-specific features are delegated to + `_get_domain_specific_feature`. + + Called by get_observation() (after reset/step) and by + _set_state() to verify reconstruction fidelity. + """ + state_dict: Dict[Object, Dict[str, float]] = {} + state_dict[self._robot] = self._get_robot_state_dict() + for obj in self._objects: + if obj.type.name == "robot": + continue + state_dict[obj] = self._get_object_state_dict(obj) + + state = utils.create_state_from_dict(state_dict) + joint_positions = self._pybullet_robot.get_joints() pyb_state = PyBulletState(state.data, simulator_state={ "joint_positions": joint_positions, @@ -562,237 +772,162 @@ def _get_state(self, _render_obs: bool = False) -> State: }) return pyb_state - @abc.abstractmethod - def _extract_feature(self, obj: Object, feature: str) -> float: - """Called in _get_state() to extract a feature from an object.""" - raise NotImplementedError("Override me!") - def _get_robot_state_dict(self) -> Dict[str, float]: - """Get dict state of the robot.""" - r_dict = {} + """Build a feature dict for the robot from PyBullet state. + + Called by _get_state() to populate the robot entry in the State. + Subclasses with non-standard robot features (e.g. cover's + normalized hand, blocks' pose_x/y/z) should override this. + """ + rx, ry, rz, qx, qy, qz, qw, rf = self._pybullet_robot.get_state() + r_dict: Dict[str, float] = {"x": rx, "y": ry, "z": rz, "fingers": rf} + _, tilt, wrist = p.getEulerFromQuaternion([qx, qy, qz, qw]) r_features = self._robot.type.feature_names - if CFG.env == "pybullet_cover": - rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() - hand = (ry - self.y_lb) / (self.y_ub - self.y_lb) - r_dict.update({"hand": hand, "pose_x": rx, "pose_z": rz}) - elif CFG.env == "pybullet_blocks": - rx, ry, rz, _, _, _, _, rf = self._pybullet_robot.get_state() - fingers = self._fingers_joint_to_state(self._pybullet_robot, rf) - r_dict.update({ - "pose_x": rx, - "pose_y": ry, - "pose_z": rz, - "fingers": fingers - }) - else: - rx, ry, rz, qx, qy, qz, qw, rf = self._pybullet_robot.get_state() - r_dict.update({"x": rx, "y": ry, "z": rz, "fingers": rf}) - _, tilt, wrist = p.getEulerFromQuaternion([qx, qy, qz, qw]) - if "tilt" in r_features: - r_dict["tilt"] = tilt - if "wrist" in r_features: - r_dict["wrist"] = wrist + if "tilt" in r_features: + r_dict["tilt"] = tilt + if "wrist" in r_features: + r_dict["wrist"] = wrist return r_dict - def render(self, - action: Optional[Action] = None, - caption: Optional[str] = None) -> Video: # pragma: no cover - # Skip test coverage because GUI is too expensive to use in unit tests - # and cannot be used in headless mode. - del action, caption # unused - - view_matrix = p.computeViewMatrixFromYawPitchRoll( - cameraTargetPosition=self._camera_target, - distance=self._camera_distance, - yaw=self._camera_yaw, - pitch=self._camera_pitch, - roll=0, - upAxisIndex=2, - physicsClientId=self._physics_client_id) - - width = CFG.pybullet_camera_width - height = CFG.pybullet_camera_height - - proj_matrix = p.computeProjectionMatrixFOV( - fov=self._camera_fov, - aspect=float(width / height), - nearVal=0.1, - farVal=100.0, - physicsClientId=self._physics_client_id) - - (_, _, px, _, - _) = p.getCameraImage(width=width, - height=height, - viewMatrix=view_matrix, - projectionMatrix=proj_matrix, - renderer=p.ER_BULLET_HARDWARE_OPENGL, - physicsClientId=self._physics_client_id) - - rgb_array = np.array(px).reshape((height, width, 4)) - rgb_array = rgb_array[:, :, :3] - return [rgb_array] - - def render_segmented_obj( - self, - action: Optional[Action] = None, - caption: Optional[str] = None, - ) -> Tuple[Image.Image, Dict[Object, Mask]]: - """Render the scene and the segmented objects in the scene.""" - del action, caption # unused - # if not self.using_gui: - # raise Exception( - # "Rendering only works with GUI on. See " - # "https://github.com/bulletphysics/bullet3/issues/1157") + def _get_object_state_dict(self, obj: Object) -> Dict[str, float]: + """Build a feature dict for a single non-robot object. - view_matrix = p.computeViewMatrixFromYawPitchRoll( - cameraTargetPosition=self._camera_target, - distance=self._camera_distance, - yaw=self._camera_yaw, - pitch=self._camera_pitch, - roll=0, - upAxisIndex=2, - physicsClientId=self._physics_client_id) - - width = CFG.pybullet_camera_width - height = CFG.pybullet_camera_height - - proj_matrix = p.computeProjectionMatrixFOV( - fov=60, - aspect=float(width / height), - nearVal=0.1, - farVal=100.0, - physicsClientId=self._physics_client_id) - - # Initialize an empty dictionary - mask_dict: Dict[Object, Mask] = {} - - # Get the original image and segmentation mask - (_, _, rgbImg, _, - segImg) = p.getCameraImage(width=width, - height=height, - viewMatrix=view_matrix, - projectionMatrix=proj_matrix, - renderer=p.ER_BULLET_HARDWARE_OPENGL, - physicsClientId=self._physics_client_id) - - # Convert to numpy arrays - original_image: np.ndarray = np.array(rgbImg, dtype=np.uint8).reshape( - (height, width, 4)) - seg_image = np.array(segImg).reshape((height, width)) - - state_img = Image.fromarray( # type: ignore[no-untyped-call] - original_image[:, :, :3]) - - # Iterate over all bodies to be labeled - for obj in self._objects: - body_id = obj.id - mask = seg_image == body_id - mask_dict[obj] = mask + Virtual objects (loc, angle, etc.) delegate all features to + _get_domain_specific_feature. Physical objects get + pose/color/is_held from PyBullet; the rest are delegated. + """ + obj_features = obj.type.feature_names + obj_dict: Dict[str, float] = {} - return state_img, mask_dict + if obj.type.name in self._VIRTUAL_OBJECT_TYPES: + for feature in obj_features: + obj_dict[feature] = \ + self._get_domain_specific_feature(obj, feature) + return obj_dict + + # Physical object — query PyBullet for pose + try: + (px, py, pz), orn = p.getBasePositionAndOrientation( + obj.id, physicsClientId=self._physics_client_id) + except Exception as e: + raise RuntimeError(f"Failed to get pose for object {obj.name} " + f"(id={obj.id})") from e + if "x" in obj_features: + obj_dict["x"] = px + if "y" in obj_features: + obj_dict["y"] = py + if "z" in obj_features: + obj_dict["z"] = pz + + if {"rot", "yaw", "roll", "pitch"} & set(obj_features): + roll, pitch, yaw = p.getEulerFromQuaternion(orn) + if "rot" in obj_features: + obj_dict["rot"] = yaw + if "yaw" in obj_features: + obj_dict["yaw"] = yaw + if "roll" in obj_features: + obj_dict["roll"] = roll + if "pitch" in obj_features: + obj_dict["pitch"] = pitch + + if "is_held" in obj_features: + obj_dict["is_held"] = 1.0 if obj.id == self._held_obj_id else 0.0 + + if {"r", "g", "b"} & set(obj_features): + visual_data = p.getVisualShapeData( + obj.id, physicsClientId=self._physics_client_id)[0] + (r, g, b, _a) = visual_data[7] + obj_dict["r"] = r + obj_dict["g"] = g + obj_dict["b"] = b + + # Remaining features delegated to subclass + for feature in obj_features: + if feature not in self._PYBULLET_FEATURES: + obj_dict[feature] = \ + self._get_domain_specific_feature( + obj, feature) + + return obj_dict - def get_observation(self, render: bool = False) -> Observation: - """Get the current observation of this environment. + @abc.abstractmethod + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: + """Return a single feature value for a non-robot object. - Currently, this just return a copy of the state and optionally a - rendered image. + Called by _get_object_state_dict() for: + - All features of virtual objects (those in _VIRTUAL_OBJECT_TYPES) + - Non-standard features of physical objects (anything not in + _PYBULLET_FEATURES, e.g. is_on, growth, water_height) """ - self._current_observation = self._get_state() - assert isinstance(self._current_observation, PyBulletState) - state_copy = self._current_observation.copy() - - if render: - state_copy.add_images_and_masks(*self.render_segmented_obj()) + raise NotImplementedError("Override me!") - return state_copy + @classmethod + def _fingers_joint_to_state(cls, pybullet_robot: SingleArmPyBulletRobot, + finger_joint: float) -> float: + """Inverse of _fingers_state_to_joint(). - def step(self, action: Action, render_obs: bool = False) -> Observation: - """Execute one environment step with the given action. + Called by _get_robot_state_dict() when reading PyBullet -> + State. + """ + subs = { + pybullet_robot.open_fingers: cls.open_fingers, + pybullet_robot.closed_fingers: cls.closed_fingers, + } + match = min(subs, key=lambda k: abs(k - finger_joint)) + return subs[match] - This method handles: - 1. Robot joint control by converting action to target positions - 2. Management of held objects and grasping constraints - 3. Physics simulation stepping - 4. Object grasp detection and constraint creation/removal - 5. `self._current_observation` update + # ── Grasp Detection & Constraint Management ───────────────── - Args: - action (Action): The action to execute, containing target joint - positions - render_obs (bool, optional): Whether to include RGB observation. - Defaults to False. + @abc.abstractmethod + def _get_object_ids_for_held_check(self) -> List[int]: + """Return PyBullet body IDs of objects that can be grasped. - Returns: - Observation: Updated environment observation after executing the - action. May include an image if render_obs=True or - CFG.rgb_observation=True. + Called by _detect_held_object() (inside step()) to decide which + bodies to check for finger contact. Subclasses return only the + IDs of graspable objects (e.g. blocks, not tables). """ - # Send the action to the robot. - target_joint_positions, base_delta = self._split_action(action) - if base_delta.size: - self._apply_base_delta(base_delta) - self._pybullet_robot.set_motors(target_joint_positions.tolist()) + raise NotImplementedError("Override me!") - # If we are setting the robot joints directly, and if there is a held - # object, we need to reset the pose of the held object directly. This - # is because the PyBullet constraints don't seem to play nicely with - # resetJointState (the robot will sometimes drop the object). - if CFG.pybullet_control_mode == "reset" and \ - self._held_obj_id is not None: - world_to_base_link = get_link_state( - self._pybullet_robot.robot_id, - self._pybullet_robot.end_effector_id, - physics_client_id=self._physics_client_id).com_pose - base_link_to_held_obj = p.invertTransform( - *self._held_obj_to_base_link) - world_to_held_obj = p.multiplyTransforms(world_to_base_link[0], - world_to_base_link[1], - base_link_to_held_obj[0], - base_link_to_held_obj[1]) - p.resetBasePositionAndOrientation( - self._held_obj_id, - world_to_held_obj[0], - world_to_held_obj[1], - physicsClientId=self._physics_client_id) + def _get_expected_finger_normals(self) -> Dict[int, Array]: + """Compute the expected inward-facing normal for each finger. - # Step the simulation here before adding or removing constraints - # because detect_held_object() should use the updated state. - if CFG.pybullet_control_mode != "reset": - for _ in range(CFG.pybullet_sim_steps_per_action): - p.stepSimulation(physicsClientId=self._physics_client_id) + Called by _detect_held_object() to distinguish objects between + the fingers (valid grasp) from objects touching the outside. + """ + _rx, _ry, _rz, qx, qy, qz, qw, _rf = self._pybullet_robot.get_state() - # If not currently holding something, and fingers are closing, check - # for a new grasp. - if self._held_constraint_id is None and self._fingers_closing(action): - # logging.debug("Finger closing") - # Detect if an object is held. If so, create a grasp constraint. - self._held_obj_id = self._detect_held_object() - # logging.debug(f"Detected held object: {self._held_obj_id}") - # breakpoint() - if self._held_obj_id is not None: - self._create_grasp_constraint() + # Convert the quaternion to a rotation matrix + rotation_matrix = p.getMatrixFromQuaternion([qx, qy, qz, qw]) + rotation_matrix = np.array(rotation_matrix).reshape(3, 3) - # If placing, remove the grasp constraint. - if self._held_constraint_id is not None and \ - self._fingers_opening(action): - p.removeConstraint(self._held_constraint_id, - physicsClientId=self._physics_client_id) - self._held_constraint_id = None - # logging.debug("Finger opening") - self._held_obj_id = None + # Define the initial normal vectors for the fingers + if CFG.pybullet_robot == "panda": + # gripper rotated 90deg so parallel to x-axis + normal = np.array([1., 0., 0.], dtype=np.float32) + elif CFG.pybullet_robot in {"fetch", "mobile_fetch"}: + # gripper parallel to y-axis + normal = np.array([0., 1., 0.], dtype=np.float32) + else: # pragma: no cover + # Shouldn't happen unless we introduce a new robot. + raise ValueError(f"Unknown robot {CFG.pybullet_robot}") - # Depending on the observation mode, either return object-centric state - # or object_centric + rgb observation - observation = self.get_observation(render=CFG.rgb_observation or\ - render_obs) + # Transform the normal vectors using the rotation matrix + transformed_normal = rotation_matrix.dot(normal) + transformed_normal_neg = rotation_matrix.dot(-1 * normal) - return observation + return { + self._pybullet_robot.left_finger_id: transformed_normal, + self._pybullet_robot.right_finger_id: transformed_normal_neg, + } def _detect_held_object(self) -> Optional[int]: - """Return the PyBullet object ID of the held object if one exists. + """Return the PyBullet body ID of the grasped object, or None. - If multiple objects are within the grasp tolerance, return the - one that is closest. + Called by step() when fingers are closing and no object is + currently held. Checks contact between each finger and every + graspable body (from _get_object_ids_for_held_check()), using + contact-normal alignment to reject touches on the outside of the + gripper. If multiple objects qualify, returns the closest. """ expected_finger_normals = self._get_expected_finger_normals() closest_held_obj = None @@ -836,6 +971,12 @@ def _detect_held_object(self) -> Optional[int]: return closest_held_obj def _create_grasp_constraint(self) -> None: + """Create a fixed PyBullet constraint between the end-effector and + _held_obj_id so the object moves with the gripper. + + Called by step() after _detect_held_object() finds a grasp, and + by _reset_single_object() when restoring a held state. + """ assert self._held_obj_id is not None base_link_to_world = np.r_[p.invertTransform( *p.getLinkState(self._pybullet_robot.robot_id, @@ -860,32 +1001,50 @@ def _create_grasp_constraint(self) -> None: physicsClientId=self._physics_client_id) def _fingers_closing(self, action: Action) -> bool: - """Check whether this action is working toward closing the fingers.""" + """True if this action's finger target is below current position. + + Called by step() to decide whether to check for a new grasp. + """ f_delta = self._action_to_finger_delta(action) return f_delta < -self._finger_action_tol def _fingers_opening(self, action: Action) -> bool: - """Check whether this action is working toward opening the fingers.""" + """True if this action's finger target is above current position. + + Called by step() to decide whether to release a held object. + """ f_delta = self._action_to_finger_delta(action) - # logging.debug(f"Finger delta: {f_delta}") return f_delta > self._finger_action_tol def _get_finger_position(self, state: State) -> float: - # Arbitrarily use the left finger as reference. + """Return the current left-finger joint position from state. + + Called by _action_to_finger_delta() to compute the delta between + current and target finger positions. + """ state = cast(utils.PyBulletState, state) finger_joint_idx = self._pybullet_robot.left_finger_joint_idx return state.joint_positions[finger_joint_idx] def _action_to_finger_delta(self, action: Action) -> float: + """Compute (target - current) finger joint position. + + Called by _fingers_closing() and _fingers_opening(). + """ assert isinstance(self._current_observation, State) finger_position = self._get_finger_position(self._current_observation) joint_positions, _ = self._split_action(action) target = joint_positions[self._pybullet_robot.left_finger_joint_idx] - # logging.debug(f"Finger position: {finger_position}, target: {target}") return target - finger_position + # ── Action Helpers ────────────────────────────────────────── + def _split_action(self, action: Action) -> Tuple[np.ndarray, np.ndarray]: - """Split an action into joint targets and an optional base delta.""" + """Split an action into (arm_joint_targets, base_delta). + + Called by step() and _action_to_finger_delta(). For robots + without a mobile base, base_delta is an empty array. + """ action_arr = action.arr base_dim = int(getattr(self._pybullet_robot, "base_action_dim", 0)) if base_dim > 0: @@ -901,7 +1060,10 @@ def _split_action(self, action: Action) -> Tuple[np.ndarray, np.ndarray]: return action_arr, np.zeros(0, dtype=action_arr.dtype) def _apply_base_delta(self, base_delta: np.ndarray) -> None: - """Apply a delta (dx, dy, dtheta) to the robot base if supported.""" + """Apply a delta (dx, dy, dtheta) to the robot base. + + Called by step() for mobile robots (e.g. mobile_fetch). + """ robot = self._pybullet_robot assert hasattr(robot, 'get_base_pose'), \ "Robot does not support base pose operations" @@ -916,181 +1078,134 @@ def _apply_base_delta(self, base_delta: np.ndarray) -> None: ) robot.set_base_pose(new_pose) # type: ignore[attr-defined] + # ── Rendering & Observation ───────────────────────────────── + + def _get_camera_matrices(self) -> Tuple[Any, Any, int, int]: + """Return (view_matrix, proj_matrix, width, height) for rendering. + + Called by render() and render_segmented_obj(). + """ + view_matrix = p.computeViewMatrixFromYawPitchRoll( + cameraTargetPosition=self._camera_target, + distance=self._camera_distance, + yaw=self._camera_yaw, + pitch=self._camera_pitch, + roll=0, + upAxisIndex=2, + physicsClientId=self._physics_client_id) + width = CFG.pybullet_camera_width + height = CFG.pybullet_camera_height + proj_matrix = p.computeProjectionMatrixFOV( + fov=self._camera_fov, + aspect=float(width / height), + nearVal=0.1, + farVal=100.0, + physicsClientId=self._physics_client_id) + return view_matrix, proj_matrix, width, height + + def render(self, + action: Optional[Action] = None, + caption: Optional[str] = None) -> Video: # pragma: no cover + # Skip test coverage because GUI is too expensive to use in unit tests + # and cannot be used in headless mode. + del action, caption # unused + view_matrix, proj_matrix, width, height = self._get_camera_matrices() + (_, _, px, _, + _) = p.getCameraImage(width=width, + height=height, + viewMatrix=view_matrix, + projectionMatrix=proj_matrix, + renderer=p.ER_BULLET_HARDWARE_OPENGL, + physicsClientId=self._physics_client_id) + rgb_array = np.array(px).reshape((height, width, 4)) + rgb_array = rgb_array[:, :, :3] + return [rgb_array] + + def render_segmented_obj( + self, + action: Optional[Action] = None, + caption: Optional[str] = None, + ) -> Tuple[Image.Image, Dict[Object, Mask]]: + """Render the scene and return per-object segmentation masks. + + Called by get_observation(render=True) to attach RGB images and + masks to the observation (used for VLM predicate grounding). + """ + del action, caption # unused + view_matrix, proj_matrix, width, height = self._get_camera_matrices() + (_, _, rgbImg, _, + segImg) = p.getCameraImage(width=width, + height=height, + viewMatrix=view_matrix, + projectionMatrix=proj_matrix, + renderer=p.ER_BULLET_HARDWARE_OPENGL, + physicsClientId=self._physics_client_id) + original_image: np.ndarray = np.array(rgbImg, dtype=np.uint8).reshape( + (height, width, 4)) + seg_image = np.array(segImg).reshape((height, width)) + state_img = Image.fromarray( # type: ignore[no-untyped-call] + original_image[:, :, :3]) + mask_dict: Dict[Object, Mask] = {} + for obj in self._objects: + mask_dict[obj] = (seg_image == obj.id) + return state_img, mask_dict + + def render_state_plt( + self, + state: State, + task: EnvironmentTask, + action: Optional[Action] = None, + caption: Optional[str] = None) -> matplotlib.figure.Figure: + raise NotImplementedError("This env does not use Matplotlib") + + def render_state(self, + state: State, + task: EnvironmentTask, + action: Optional[Action] = None, + caption: Optional[str] = None) -> Video: + raise NotImplementedError("A PyBullet environment cannot render " + "arbitrary states.") + + def get_observation(self, render: bool = False) -> Observation: + """Get the current observation of this environment. + + Reads the current state from pybullet, updates + _current_observation (the backing field), and returns a copy + optionally with rendered images. + """ + state = self._get_state() + assert isinstance(state, PyBulletState) + self._current_observation = state + obs = state.copy() + + if render: + obs.add_images_and_masks(*self.render_segmented_obj()) + + return obs + + # ── Task Utilities ────────────────────────────────────────── + def _add_pybullet_state_to_tasks( self, tasks: List[EnvironmentTask]) -> List[EnvironmentTask]: - """Converts the task initial states into PyBulletStates. + """Convert plain-State tasks into PyBulletState tasks. - This is used in generating tasks. + Called by _generate_train/test_tasks() in subclasses. Sets up + the simulator for each task's init state so that joint positions + and (optionally) rendered images are captured into the task. """ pybullet_tasks = [] for task in tasks: # Reset the robot. init = task.init - # Extract the joints. - # YC: Probably need to reset_state here so I can then get an - # observation, would it work without the reset_state? - # Attempt 2: First reset it. - self._current_observation = init - self._reset_state(init) + self._set_state(init) # Cast _current_observation from type State to PybulletState joint_positions = self._pybullet_robot.get_joints() self._current_observation = utils.PyBulletState( init.data.copy(), simulator_state=joint_positions) - # Attempt 1: Let's try to get a rendering directly first pybullet_init = self.get_observation(render=CFG.render_init_state) - pybullet_init.option_history = [ - ] # useful for vlm predicate grounding - # # + pybullet_init.option_history = [] pybullet_task = EnvironmentTask(pybullet_init, task.goal, goal_nl=task.goal_nl) pybullet_tasks.append(pybullet_task) return pybullet_tasks - - @classmethod - def get_robot_ee_home_orn(cls) -> Quaternion: - """Public for use by oracle options.""" - robot_ee_orns = CFG.pybullet_robot_ee_orns[cls.get_name()] - return robot_ee_orns[CFG.pybullet_robot] - - -def create_pybullet_block( - color: Tuple[float, float, float, float], - half_extents: Tuple[float, float, float], - mass: float, - friction: float, - position: Pose3D = (0.0, 0.0, 0.0), - orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), - physics_client_id: int = 0, - add_top_triangle: bool = False, -) -> int: - """A generic utility for creating a new block. - - Returns the PyBullet ID of the newly created block. - """ - # The poses here are not important because they are overwritten by - - # Create the collision shape. - collision_id = p.createCollisionShape(p.GEOM_BOX, - halfExtents=half_extents, - physicsClientId=physics_client_id) - - # Create the visual_shape. - visual_id = p.createVisualShape(p.GEOM_BOX, - halfExtents=half_extents, - rgbaColor=color, - physicsClientId=physics_client_id) - - # Create the body. - block_id = p.createMultiBody(baseMass=mass, - baseCollisionShapeIndex=collision_id, - baseVisualShapeIndex=visual_id, - basePosition=position, - baseOrientation=orientation, - physicsClientId=physics_client_id) - p.changeDynamics( - block_id, - linkIndex=-1, # -1 for the base - lateralFriction=friction, - spinningFriction=friction, - rollingFriction=friction, - physicsClientId=physics_client_id) - - if add_top_triangle: - # 1. Create the triangle's visual shape - triangle_size = min(half_extents[0], half_extents[1]) - triangle_vertices = [ - [triangle_size, 0, 0], # Tip pointing in +X - [-triangle_size, triangle_size, 0], # Back left - [-triangle_size, -triangle_size, 0] # Back right - ] - triangle_visual_id = p.createVisualShape( - p.GEOM_MESH, - vertices=triangle_vertices, - indices=[0, 1, 2], # <-- FIX: Added this line - rgbaColor=[1, 1, 0, - 1], # <-- CHANGE: Set to yellow (R=1, G=1, B=0, A=1) - physicsClientId=physics_client_id) - - # 2. Re-create the body, but this time WITH a link for the triangle - p.removeBody( - block_id, - physicsClientId=physics_client_id) # Remove the old simple block - - block_id = p.createMultiBody( - baseMass=mass, - baseCollisionShapeIndex=collision_id, - baseVisualShapeIndex=visual_id, - basePosition=position, - baseOrientation=orientation, - # --- Link Parameters for the Triangle --- - linkMasses=[0], # Massless link - linkCollisionShapeIndices=[-1], # No collision for the link - linkVisualShapeIndices=[triangle_visual_id - ], # Visual shape for the link - # Position the link's origin on top of the block's base - linkPositions=[[0, 0, half_extents[2] + 0.001]], - linkOrientations=[[0, 0, 0, 1]], # No relative rotation - linkInertialFramePositions=[[0, 0, 0]], - linkInertialFrameOrientations=[[0, 0, 0, 1]], - linkParentIndices=[0], # Link is attached to the base (index 0) - linkJointTypes=[p.JOINT_FIXED], # Link is fixed to the base - linkJointAxis=[[0, 0, - 1]], # Axis for the joint (not relevant for fixed) - physicsClientId=physics_client_id) - - # Re-apply dynamics to the new multi-body object - p.changeDynamics( - block_id, - linkIndex=-1, # -1 for the base - lateralFriction=friction, - spinningFriction=friction, - physicsClientId=physics_client_id) - - return block_id - - -def create_pybullet_sphere( - color: Tuple[float, float, float, float], - radius: float, - mass: float, - friction: float, - position: Pose3D = (0.0, 0.0, 0.0), - orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), - physics_client_id: int = 0, -) -> int: - """A generic utility for creating a new sphere. - - Returns the PyBullet ID of the newly created sphere. - """ - # Create the collision shape. - collision_id = p.createCollisionShape(p.GEOM_SPHERE, - radius=radius, - physicsClientId=physics_client_id) - - # Create the visual shape. - visual_id = p.createVisualShape(p.GEOM_SPHERE, - radius=radius, - rgbaColor=color, - physicsClientId=physics_client_id) - - # Create the body. - sphere_id = p.createMultiBody(baseMass=mass, - baseCollisionShapeIndex=collision_id, - baseVisualShapeIndex=visual_id, - basePosition=position, - baseOrientation=orientation, - physicsClientId=physics_client_id) - p.changeDynamics( - sphere_id, - linkIndex=-1, # -1 for the base - lateralFriction=friction, - spinningFriction=friction, - physicsClientId=physics_client_id) - - return sphere_id diff --git a/predicators/envs/pybullet_fan.py b/predicators/envs/pybullet_fan.py index d4acbdfec..7876d9cdd 100644 --- a/predicators/envs/pybullet_fan.py +++ b/predicators/envs/pybullet_fan.py @@ -6,10 +6,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block, \ - create_pybullet_sphere +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, create_pybullet_sphere, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -257,7 +257,7 @@ def get_configuration_dict(cls) -> Dict[str, Any]: # ------------------------------------------------------------------------- # Environment initialization # ------------------------------------------------------------------------- - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: self._robot = Object("robot", self._robot_type) # Fans - create one fan object per side instead of multiple @@ -300,7 +300,7 @@ def __init__(self, use_gui: bool = False) -> None: # Target self._target = Object("target", self._target_type) - super().__init__(use_gui=use_gui) + super().__init__(use_gui=use_gui, **kwargs) # Define new predicates if desired self._FanOn = Predicate( @@ -610,7 +610,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: self._target.id = pybullet_bodies["target_id"] # Initialize boundary wall IDs list (will be populated - # in _reset_custom_env_state) + # in _set_domain_specific_state) # pylint: disable=attribute-defined-outside-init self._boundary_wall_ids: List[int] = [] @@ -620,10 +620,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: def _get_object_ids_for_held_check(self) -> List[int]: return [] - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: for switch_obj in self._switches: is_on_val = state.get(switch_obj, "is_on") self._set_switch_on(switch_obj.id, bool(is_on_val > 0.5)) @@ -838,7 +835,7 @@ def _position_fans_on_sides(self) -> None: orientation=p.getQuaternionFromEuler(rot), physics_client_id=self._physics_client_id) - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._fan_type: if feature == "facing_side": @@ -875,18 +872,12 @@ def _extract_feature(self, obj: Object, feature: str) -> float: # ------------------------------------------------------------------------- # Step # ------------------------------------------------------------------------- - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - """Execute a low-level action, then spin fans & blow the ball.""" - super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: + """Spin fans & blow the ball.""" self._simulate_fans() - final_state = self._get_state() - self._current_observation = final_state + state = self._get_state() # Draw a debug line at the ball's position - bx, by = final_state.get(self._ball, - "x"), final_state.get(self._ball, "y") + bx, by = state.get(self._ball, "x"), state.get(self._ball, "y") p.addUserDebugLine( [bx, by, self.table_height], [bx, by, self.table_height + self.debug_line_height], @@ -894,7 +885,6 @@ def step( # pylint: disable=redefined-outer-name lifeTime=self. debug_line_lifetime, # short lifetime so each step refreshes physicsClientId=self._physics_client_id) - return final_state # ------------------------------------------------------------------------- # Fan Simulation @@ -1633,7 +1623,7 @@ def _has_valid_path(self, start_pos: Tuple[int, CFG.fan_train_num_walls_per_task, _rng) for _task in _tasks: - env._reset_state(_task.init) # pylint: disable=protected-access + env._set_state(_task.init) # pylint: disable=protected-access for _ in range(5000): _action = Action( np.array(env._pybullet_robot # pylint: disable=protected-access diff --git a/predicators/envs/pybullet_float.py b/predicators/envs/pybullet_float.py index fef0830d3..3e566609e 100644 --- a/predicators/envs/pybullet_float.py +++ b/predicators/envs/pybullet_float.py @@ -13,10 +13,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, \ - sample_collision_free_2d_positions, update_object + create_pybullet_block, sample_collision_free_2d_positions, update_object from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ Predicate, State, Type @@ -120,7 +120,7 @@ class PyBulletFloatEnv(PyBulletEnv): _block_type = Type("block", ["x", "y", "z", "in_water", "is_held"], sim_features=["id", "is_light"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: self._robot = Object("robot", self._robot_type) self._vessel = Object("vessel", self._vessel_type) self._block0 = Object("block0", self._block_type) @@ -128,7 +128,7 @@ def __init__(self, use_gui: bool = False) -> None: self._block2 = Object("block2", self._block_type) self._blocks = [self._block0, self._block1, self._block2] - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) self._InWater = Predicate("InWater", [self._block_type], self._InWater_holds) @@ -229,10 +229,7 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: def _get_object_ids_for_held_check(self) -> List[int]: return [block_obj.id for block_obj in self._blocks] - def _create_task_specific_objects(self, state: State) -> None: - pass - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._block_type: # if feature == "is_light": @@ -255,10 +252,11 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return self._current_water_height raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: - - # Initialize water level + def _set_domain_specific_state(self, state: State) -> None: + """Set water height and redraw water bodies, block colors, and + displacement tracking.""" self._current_water_height = state.get(self._vessel, "water_height") + # Clear old water for wid in self._water_ids.values(): if wid is not None: @@ -267,17 +265,9 @@ def _reset_custom_env_state(self, state: State) -> None: # Reset blocks for blk in self._blocks: - # Set block's color based on is_light - # update_object(blk.id, - # color=PyBulletFloatEnv.block_color_light \ - # if state.get(blk, "is_light") > 0.5 - # else PyBulletFloatEnv.block_color_heavy, - # physics_client_id=self._physics_client_id) - # Set block's color randomly update_object(blk.id, color=self._train_rng.choice(self._obj_colors), physics_client_id=self._physics_client_id) - # Re-initialize displacing to False self._block_is_displacing[blk] = False # Re-draw water @@ -293,21 +283,13 @@ def _reset_custom_env_state(self, state: State) -> None: color=[0.5, 0.5, 1, 0.5], physics_client_id=self._physics_client_id) - def step( # pylint: disable=redefined-outer-name - self, - action: Action, - render_obs: bool = False) -> State: - next_state = super().step(action, render_obs=render_obs) - # Check if blocks entering/exiting water changed its level - changed = self._update_water_level_if_needed(next_state) + def _domain_specific_step(self) -> None: + """Update water level and float light blocks.""" + state = self._get_state() + changed = self._update_water_level_if_needed(state) if changed: self._create_or_update_water(force_redraw=True) - # Keep light blocks floating on water surface - self._float_light_blocks(next_state) - - final_state = self._get_state() - self._current_observation = final_state - return final_state + self._float_light_blocks(state) def _float_light_blocks(self, state: State) -> None: """Force each light, unheld block in a container compartment to float @@ -617,7 +599,7 @@ def _make_tasks(self, num_tasks: int, CFG.pybullet_sim_steps_per_action = 1 env = PyBulletFloatEnv(use_gui=True) task = env._make_tasks(1, np.random.default_rng(0))[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: action = Action(np.array(env._pybullet_robot.initial_joint_positions)) # pylint: disable=protected-access diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index e1bc394a0..2d4f2f9ed 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -14,9 +14,10 @@ from predicators import utils from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object, update_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -109,7 +110,7 @@ class PyBulletGrowEnv(PyBulletEnv): _jug_type = Type("jug", ["x", "y", "z", "rot", "is_held", "r", "g", "b"], sim_features=["id", "init_x", "init_y", "init_z"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Create the single robot Object self._robot = Object("robot", self._robot_type) @@ -132,7 +133,7 @@ def __init__(self, use_gui: bool = False) -> None: # For tracking the "liquid bodies" we create for each cup self._cup_to_liquid_id: Dict[Object, Optional[int]] = {} - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Define Predicates self._Grown = Predicate("Grown", [self._cup_type], self._Grown_holds) @@ -265,10 +266,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: jug_ids = [jug.id for jug in self._jugs if jug.id is not None] return jug_ids - def _create_task_specific_objects(self, state: State) -> None: - """No extra objects to create beyond cups and jugs.""" - - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" # For growth, we look up the height of the liquid body if obj.type == self._cup_type and feature == "growth": @@ -285,9 +283,29 @@ def _extract_feature(self, obj: Object, feature: str) -> float: raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: - """Called in _reset_state to handle any custom resetting.""" - # Remove existing "liquid bodies" + def _set_domain_specific_state(self, state: State) -> None: + """Set out-of-view positioning, jug init positions, liquid bodies, and + cup/jug colors.""" + cups = state.get_objects(self._cup_type) + jugs = state.get_objects(self._jug_type) + + # Store jug initial positions + for jug in jugs: + jug.init_x = state.get(jug, "x") + jug.init_y = state.get(jug, "y") + jug.init_z = state.get(jug, "z") + + oov_x, oov_y = self._out_of_view_xy + for i in range(len(cups), len(self._cups)): + update_object(self._cups[i].id, + position=(oov_x, oov_y, 0.0), + physics_client_id=self._physics_client_id) + for i in range(len(jugs), len(self._jugs)): + update_object(self._jugs[i].id, + position=(oov_x, oov_y, 0.0), + physics_client_id=self._physics_client_id) + + # Remove existing liquid bodies for liquid_id in self._cup_to_liquid_id.values(): if liquid_id is not None: p.removeBody(liquid_id, @@ -295,13 +313,11 @@ def _reset_custom_env_state(self, state: State) -> None: self._cup_to_liquid_id.clear() # Recreate the liquid bodies as needed - cups = state.get_objects(self._cup_type) for cup in cups: liquid_id = self._create_pybullet_liquid_for_cup(cup, state) self._cup_to_liquid_id[cup] = liquid_id - # Also update the PyBullet color on each cup/jug to match the (r,g,b) in - # the state + # Update colors for cup in cups: if cup.id is not None: r = state.get(cup, "r") @@ -310,7 +326,6 @@ def _reset_custom_env_state(self, state: State) -> None: update_object(cup.id, color=(r, g, b, 1.0), physics_client_id=self._physics_client_id) - jugs = state.get_objects(self._jug_type) for jug in jugs: if jug.id is not None: r = state.get(jug, "r") @@ -319,34 +334,14 @@ def _reset_custom_env_state(self, state: State) -> None: update_object(jug.id, color=(r, g, b, 1.0), physics_client_id=self._physics_client_id) - # set the sim_feature position to the initial position - jug.init_x = state.get(jug, "x") - jug.init_y = state.get(jug, "y") - jug.init_z = state.get(jug, "z") - - oov_x, oov_y = self._out_of_view_xy - for i in range(len(cups), len(self._cups)): - update_object(self._cups[i].id, - position=(oov_x, oov_y, 0.0), - physics_client_id=self._physics_client_id) - for i in range(len(jugs), len(self._jugs)): - update_object(self._jugs[i].id, - position=(oov_x, oov_y, 0.0), - physics_client_id=self._physics_client_id) # ------------------------------------------------------------------------- # Pouring logic - def step(self, action: Action, render_obs: bool = False) -> State: - """Let parent handle the robot stepping, then apply custom pouring - logic.""" - next_state = super().step(action, render_obs=render_obs) - - self._handle_pouring(next_state) - - final_state = self._get_state() - self._current_observation = final_state.copy() - return final_state + def _domain_specific_step(self) -> None: + """Apply custom pouring logic.""" + state = self._get_state() + self._handle_pouring(state) def _handle_pouring(self, state: State) -> None: if self._held_obj_id is None: @@ -724,7 +719,7 @@ def _create_pybullet_liquid_for_cup( _rng = np.random.default_rng(CFG.seed) _task = env._get_tasks( # pylint: disable=protected-access 1, CFG.grow_num_cups_test, CFG.grow_num_jugs_test, _rng)[0] - env._reset_state(_task.init) # pylint: disable=protected-access + env._set_state(_task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_laser.py b/predicators/envs/pybullet_laser.py index 6a71cda18..0639de35a 100644 --- a/predicators/envs/pybullet_laser.py +++ b/predicators/envs/pybullet_laser.py @@ -121,7 +121,7 @@ class PyBulletLaserEnv(PyBulletEnv): ["x", "y", "z", "rot", "split_mirror", "is_held"]) _target_type = Type("target", ["x", "y", "z", "rot", "is_hit"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Create environment objects (logic-level) self._robot = Object("robot", self._robot_type) self._station = Object("station", self._station_type) @@ -140,7 +140,7 @@ def __init__(self, use_gui: bool = False) -> None: ] # Initialize PyBullet - super().__init__(use_gui=use_gui) + super().__init__(use_gui=use_gui, **kwargs) # Define predicates # Example: "StationOn" checks whether the station is toggled on @@ -282,14 +282,11 @@ def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: # ------------------------------------------------------------------------- # State Reading/Writing # ------------------------------------------------------------------------- - def _create_task_specific_objects(self, state: State) -> None: - pass - def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of wires (assuming the robot can pick them up).""" return [m.id for m in self._normal_mirrors + self._split_mirrors] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._station_type: if feature == "is_on": @@ -302,18 +299,11 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return 1.0 if self._is_target_hit(obj) else 0.0 raise ValueError(f"Unknown feature {feature} for object {obj}") - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: + """Set target/mirror positioning, station switch, and remove old laser + beams.""" oov_x, oov_y = self._out_of_view_xy - lasers_copy = _laser_ids.copy() - for beam_id, creation_time, client_id in lasers_copy: - p.removeBody(beam_id, physicsClientId=client_id) - # Remove the beam from the list - _laser_ids.remove((beam_id, creation_time, client_id)) - logging.debug(f"[reset] removing beam_id: {beam_id} " - f"in sim{client_id}, remaining beams " - f"{[bid for bid, _, _ in _laser_ids]}") - # Move targets out of view if needed target_objs = state.get_objects(self._target_type) for i in range(len(target_objs), len(self._targets)): @@ -344,27 +334,29 @@ def _reset_custom_env_state(self, state: State) -> None: switch_on = state.get(self._station, "is_on") > 0.5 self._set_station_powered_on(switch_on) + lasers_copy = _laser_ids.copy() + for beam_id, creation_time, client_id in lasers_copy: + p.removeBody(beam_id, physicsClientId=client_id) + _laser_ids.remove((beam_id, creation_time, client_id)) + logging.debug(f"[reset] removing beam_id: {beam_id} " + f"in sim{client_id}, remaining beams " + f"{[bid for bid, _, _ in _laser_ids]}") + # ------------------------------------------------------------------------- # Step # ------------------------------------------------------------------------- - def step(self, action: Action, render_obs: bool = False) -> State: - next_state = super().step(action, render_obs=render_obs) - - # After any motion, we simulate the laser - self._simulate_laser(next_state) + def _domain_specific_step(self) -> None: + state = self._get_state() + self._simulate_laser(state) lasers_copy = _laser_ids.copy() for beam_id, creation_time, client_id in lasers_copy: if time.time() - creation_time > self._laser_life_time: p.removeBody(beam_id, physicsClientId=client_id) - # Remove the beam from the list _laser_ids.remove((beam_id, creation_time, client_id)) logging.debug(f"[step] removing beam_id: {beam_id} " f"in sim{client_id}, remaining beams " f"{[bid for bid, _, _ in _laser_ids]}") - final_state = self._get_state() - self._current_observation = final_state - return final_state # ------------------------------------------------------------------------- # Laser Simulation @@ -822,7 +814,7 @@ def create_laser_cylinder(start: Any, CFG.laser_zero_reflection_angle = True env = PyBulletLaserEnv(use_gui=True) task = env._make_tasks(1, np.random.default_rng(CFG.seed), True)[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: # Robot does nothing diff --git a/predicators/envs/pybullet_magic_bin.py b/predicators/envs/pybullet_magic_bin.py index 583fe1294..aec2d27a0 100644 --- a/predicators/envs/pybullet_magic_bin.py +++ b/predicators/envs/pybullet_magic_bin.py @@ -16,9 +16,10 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block +from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.pybullet_helpers.objects import create_object +from predicators.pybullet_helpers.objects import create_object, \ + create_pybullet_block from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ @@ -85,7 +86,7 @@ class PyBulletMagicBinEnv(PyBulletEnv): sim_features=["id", "joint_id", "joint_scale"]) _bin_type = Type("bin", ["x", "y", "z", "rot"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) self._blocks: List[Object] = [ @@ -95,7 +96,7 @@ def __init__(self, use_gui: bool = False) -> None: self._switch = Object("switch", self._switch_type) self._bin = Object("bin", self._bin_type) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Predicates self._HandEmpty = Predicate("HandEmpty", [self._robot_type], @@ -235,7 +236,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of objects that can be held (blocks).""" return [block.id for block in self._blocks] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._switch_type and feature == "is_on": return float(self._is_switch_on()) @@ -246,10 +247,7 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return float(pos[0] > 5.0) # Out of view if x > 5 raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - del state # Unused - - def _reset_custom_env_state(self, state: State) -> None: + def _set_domain_specific_state(self, state: State) -> None: """Reset environment state from a State object.""" # Set switch state switch_on = state.get(self._switch, "is_on") > 0.5 @@ -267,12 +265,8 @@ def _reset_custom_env_state(self, state: State) -> None: self._default_orn, physicsClientId=self._physics_client_id) - def step(self, action: Action, render_obs: bool = False) -> State: - """Process a single action step.""" - # Execute the action - super().step(action, render_obs=render_obs) - - # Check magic bin logic: if switch is on and block is in bin, vanish it + def _domain_specific_step(self) -> None: + """If switch is on and block is in bin, vanish it.""" if self._is_switch_on(): bin_pos, _ = p.getBasePositionAndOrientation( self._bin.id, physicsClientId=self._physics_client_id) @@ -303,11 +297,6 @@ def step(self, action: Action, render_obs: bool = False) -> State: self._default_orn, physicsClientId=self._physics_client_id) - # Get updated state - final_state = self._get_state() - self._current_observation = final_state - return final_state - # ------------------------------------------------------------------------- # Switch helpers def _is_switch_on(self) -> bool: @@ -481,7 +470,7 @@ def _make_tasks(self, num_tasks: int, CFG.num_train_tasks = 1 env = PyBulletMagicBinEnv(use_gui=True) task = env._generate_train_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access print("PyBullet Magic Bin Environment Test") print("Blocks should vanish when in bin with switch ON.") diff --git a/predicators/envs/pybullet_switch.py b/predicators/envs/pybullet_switch.py index bd5ac59d1..cefcaa4ef 100644 --- a/predicators/envs/pybullet_switch.py +++ b/predicators/envs/pybullet_switch.py @@ -89,17 +89,18 @@ class PyBulletSwitchEnv(PyBulletEnv): sim_features=["id", "joint_id", "joint_scale", "color_count"]) _light_type = Type("light", ["x", "y", "z", "rot", "is_on", "color_index"]) - def __init__(self, use_gui: bool = False) -> None: + def __init__(self, use_gui: bool = False, **kwargs: Any) -> None: # Objects self._robot = Object("robot", self._robot_type) self._power_switch = Object("power_switch", self._power_switch_type) self._color_switch = Object("color_switch", self._color_switch_type) self._light = Object("light", self._light_type) - super().__init__(use_gui) + super().__init__(use_gui, **kwargs) # Track previous switch states for edge detection self._prev_color_switch_on: bool = False + self._pre_step_color_count: int = 0 # Predicates self._PowerOn = Predicate("PowerOn", [self._power_switch_type], @@ -223,7 +224,7 @@ def _get_object_ids_for_held_check(self) -> List[int]: """Return IDs of objects that can be held (none in this env).""" return [] - def _extract_feature(self, obj: Object, feature: str) -> float: + def _get_domain_specific_feature(self, obj: Object, feature: str) -> float: """Extract features for creating the State object.""" if obj.type == self._light_type and feature == "is_on": return float(self._is_power_switch_on()) @@ -236,42 +237,33 @@ def _extract_feature(self, obj: Object, feature: str) -> float: return float(self._is_switch_on(self._color_switch)) raise ValueError(f"Unknown feature {feature} for object {obj}") - def _create_task_specific_objects(self, state: State) -> None: - del state # Unused - - def _reset_custom_env_state(self, state: State) -> None: - """Reset environment state from a State object.""" - # Set power switch state + def _set_domain_specific_state(self, state: State) -> None: + """Set switch positions, tracking vars, color count, and light + visual.""" power_on = state.get(self._power_switch, "is_on") > 0.5 self._set_switch_state(self._power_switch, power_on) - # Set color switch state color_switch_on = state.get(self._color_switch, "is_on") > 0.5 self._set_switch_state(self._color_switch, color_switch_on) - # Track previous color switch state for edge detection self._prev_color_switch_on = color_switch_on - # Initialize color_count from light's color_index color_index = int(state.get(self._light, "color_index")) self._color_switch.color_count = color_index - # Update light visual self._update_light_visual(power_on, color_index) def step(self, action: Action, render_obs: bool = False) -> State: - """Process a single action step.""" - # Get current color_count from sim_feature - prev_color_count = self._color_switch.color_count - - # Execute the action - super().step(action, render_obs=render_obs) + """Save pre-step color count before kinematics.""" + self._pre_step_color_count = self._color_switch.color_count + return super().step(action, render_obs=render_obs) + def _domain_specific_step(self) -> None: # Detect color switch toggle (OFF -> ON transition) curr_color_switch_on = self._is_switch_on(self._color_switch) if not self._prev_color_switch_on and curr_color_switch_on: # Rising edge detected - increment color count - self._color_switch.color_count = prev_color_count + 1 + self._color_switch.color_count = self._pre_step_color_count + 1 self._prev_color_switch_on = curr_color_switch_on @@ -285,11 +277,6 @@ def step(self, action: Action, render_obs: bool = False) -> State: # Update light visual self._update_light_visual(power_on, color_index) - # Get updated state with correct light values - final_state = self._get_state() - self._current_observation = final_state - return final_state - # ------------------------------------------------------------------------- # Switch helpers def _is_switch_on(self, switch_obj: Object) -> bool: @@ -465,7 +452,7 @@ def _make_tasks(self, num_tasks: int, CFG.num_train_tasks = 1 env = PyBulletSwitchEnv(use_gui=True) task = env._generate_train_tasks()[0] # pylint: disable=protected-access - env._reset_state(task.init) # pylint: disable=protected-access + env._set_state(task.init) # pylint: disable=protected-access while True: _joints = env._pybullet_robot.initial_joint_positions # pylint: disable=protected-access diff --git a/predicators/explorers/__init__.py b/predicators/explorers/__init__.py index 560c840d6..644138648 100644 --- a/predicators/explorers/__init__.py +++ b/predicators/explorers/__init__.py @@ -109,7 +109,7 @@ def create_explorer( action_space, train_tasks, max_steps_before_termination, nsrts, maple_q_function) - elif name == "agent": + elif name in ("agent_plan", "agent_bilevel"): assert tool_context is not None assert agent_session is not None explorer = cls(initial_predicates, initial_options, types, diff --git a/predicators/explorers/agent_bilevel_explorer.py b/predicators/explorers/agent_bilevel_explorer.py new file mode 100644 index 000000000..8c50db54c --- /dev/null +++ b/predicators/explorers/agent_bilevel_explorer.py @@ -0,0 +1,232 @@ +"""Agent bilevel explorer: sketch → refine against mental model → execute real. + +Produces a plan *sketch* via a Claude agent, runs backtracking refinement +against the approach's currently-learned option model (read from +``tool_context.option_model``), then rolls the refined plan out in the +real environment. When the mental model disagrees with reality (e.g. a +subgoal atom the mental model expected after a Wait doesn't actually +hold), the resulting trajectory provides a targeted learning signal for +online simulator synthesis. + +Parallels ``AgentPlanExplorer`` for session plumbing and +``AgentBilevelApproach`` for the sketch/refine workflow. +""" + +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +import numpy as np +from gym.spaces import Box + +from predicators import utils +from predicators.agent_sdk import bilevel_sketch +from predicators.agent_sdk.session_manager import AgentSessionManager, \ + run_query_sync +from predicators.agent_sdk.tools import ToolContext +from predicators.explorers.base_explorer import BaseExplorer +from predicators.settings import CFG +from predicators.structs import Action, ExplorationStrategy, \ + ParameterizedOption, Predicate, State, Task, Type + + +class AgentBilevelExplorer(BaseExplorer): + """Queries a Claude agent for a plan sketch, refines it, and executes.""" + + def __init__(self, predicates: Set[Predicate], + options: Set[ParameterizedOption], types: Set[Type], + action_space: Box, train_tasks: List[Task], + max_steps_before_termination: int, tool_context: ToolContext, + agent_session: AgentSessionManager) -> None: + super().__init__(predicates, options, types, action_space, train_tasks, + max_steps_before_termination) + self._tool_context = tool_context + self._agent_session = agent_session + + @classmethod + def get_name(cls) -> str: + return "agent_bilevel" + + # ------------------------------------------------------------------ # + # Exploration strategy + # ------------------------------------------------------------------ # + + def _get_exploration_strategy(self, train_task_idx: int, + timeout: int) -> ExplorationStrategy: + task = self._train_tasks[train_task_idx] + # The approach syncs tool_context.option_model right before + # constructing this explorer, so reading here picks up the most + # recently learned model. + option_model = self._tool_context.option_model + assert option_model is not None, \ + "agent_bilevel explorer needs a synced option_model" + + try: + prompt = bilevel_sketch.build_solve_prompt( + task, + all_predicates=self._predicates, + all_options=self._options, + trajectory_summary=self._build_trajectory_summary(), + tool_names=self._agent_tool_names(), + ) + responses = run_query_sync(self._agent_session, prompt) + plan_text = self._extract_option_plan_text(responses) + if not plan_text: + raise ValueError("agent returned empty plan text") + + sketch = bilevel_sketch.parse_sketch_from_text( + plan_text, + task, + predicates=self._predicates, + options=self._options, + types=self._types, + ) + if not sketch: + raise ValueError("parsed empty plan sketch") + + self._tool_context.last_sketch_subgoals = [ + (s.subgoal_atoms, s.subgoal_neg_atoms) for s in sketch + ] + self._tool_context.last_sketch_options = [ + (s.option.name, [o.name for o in s.objects]) for s in sketch + ] + + # Explorer mode: keep subgoal validation ON so the mental + # model can tell us which step it can't predict, but when + # that happens, truncate the plan at that step (inclusive) + # instead of backtracking. Steps beyond the first + # disagreement are built on a false mental-model state, so + # executing them in the real env adds noise rather than + # signal. The truncated plan — Pick → ... → first failing + # step — is the experiment we want to run. Final-goal check + # is also off: the explorer isn't trying to solve the task + # in the mental model. + plan, success, _ = bilevel_sketch.refine_sketch( + task, + sketch, + option_model, + predicates=self._predicates, + timeout=float(timeout), + rng=np.random.default_rng(CFG.seed), + max_samples_per_step=CFG. + agent_bilevel_explorer_max_samples_per_step, + check_subgoals=True, + check_final_goal=False, + truncate_on_subgoal_fail=True, + log_state=CFG.agent_bilevel_log_state, + run_id="agent_bilevel_explorer", + ) + logging.info( + f"agent_bilevel explorer: sketch has {len(sketch)} steps, " + f"refined {len(plan)} " + f"({'success' if success else 'partial'}).") + if plan: + plan_strs = [] + for i, opt in enumerate(plan): + obj_s = ", ".join(o.name for o in opt.objects) + par_s = ", ".join(f"{p:.4f}" for p in opt.params) + plan_strs.append(f" {i}: {opt.name}({obj_s})[{par_s}]") + logging.info("agent_bilevel explorer: experiment plan:\n%s", + "\n".join(plan_strs)) + + if plan: + policy = utils.option_plan_to_policy( + plan, + abstract_function=lambda s: utils.abstract( + s, self._predicates)) + return self._wrap_policy(policy), lambda _: False + + logging.info("agent_bilevel explorer: refinement produced zero " + "steps, falling back to random.") + except Exception as e: # pylint: disable=broad-except + logging.warning(f"agent_bilevel explorer failed: {e}. " + "Falling back to random options.") + + if not CFG.agent_explorer_fallback_to_random: + raise utils.RequestActPolicyFailure( + "agent_bilevel explorer failed and fallback disabled.") + return self._random_options_fallback() + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + + def _wrap_policy( + self, policy: Callable[[State], + Action]) -> Callable[[State], Action]: + """Convert OptionExecutionFailure into RequestActPolicyFailure. + + This lets the main loop cleanly terminate the episode when the + refined plan finishes or fails mid-execution (which is exactly + the disagreement signal we want to collect). + """ + + def _wrapped(state: State) -> Action: + try: + return policy(state) + except utils.OptionExecutionFailure as e: + raise utils.RequestActPolicyFailure(e.args[0], e.info) from e + + return _wrapped + + def _random_options_fallback(self) -> ExplorationStrategy: + """Fall back to random option sampling.""" + + def fallback_policy(state: State) -> Action: + del state + raise utils.RequestActPolicyFailure( + "Random option sampling failed!") + + policy = utils.create_random_option_policy(self._options, self._rng, + fallback_policy) + return policy, lambda _: False + + def _agent_tool_names(self) -> Optional[List[str]]: + """Return tool names exposed by the current session, if any.""" + return getattr(self._agent_session, "tool_names", None) + + def _build_trajectory_summary(self) -> str: + """Summarize trajectory data for the agent.""" + all_trajs = (self._tool_context.offline_trajectories + + self._tool_context.online_trajectories) + if not all_trajs: + return "" + + max_trajs = CFG.agent_sdk_max_trajectories_in_context + recent = all_trajs[-max_trajs:] + lines = [ + f"\n## Trajectory Summary ({len(all_trajs)} total, " + f"showing last {len(recent)})" + ] + + for i, traj in enumerate(recent): + n_steps = len(traj.actions) + init_atoms = utils.abstract(traj.states[0], self._predicates) + final_atoms = utils.abstract(traj.states[-1], self._predicates) + new_atoms = final_atoms - init_atoms + lost_atoms = init_atoms - final_atoms + lines.append(f"\nTrajectory {i}: {n_steps} steps") + if new_atoms: + lines.append( + " Gained: " + + f"{', '.join(str(a) for a in sorted(new_atoms, key=str))}") + if lost_atoms: + lines.append( + " Lost: " + + f"{', '.join(str(a) for a in sorted(lost_atoms, key=str))}" + ) + + return "\n".join(lines) + + def _extract_option_plan_text(self, responses: List[Dict[str, + Any]]) -> str: + """Extract plan text from the last assistant text response.""" + last_text_parts: List[str] = [] + for resp in responses: + if resp.get("type") == "assistant": + parts = [ + block.get("text", "") for block in resp.get("content", []) + if isinstance(block, dict) and block.get("type") == "text" + ] + if parts: + last_text_parts = parts + return "\n".join(last_text_parts) diff --git a/predicators/explorers/agent_explorer.py b/predicators/explorers/agent_plan_explorer.py similarity index 90% rename from predicators/explorers/agent_explorer.py rename to predicators/explorers/agent_plan_explorer.py index 31b675ab4..46fb2f98b 100644 --- a/predicators/explorers/agent_explorer.py +++ b/predicators/explorers/agent_plan_explorer.py @@ -1,6 +1,12 @@ -"""An explorer that queries a Claude agent to generate option plans.""" +"""Agent plan explorer: Claude agent generates grounded option plans. + +Produces fully-grounded option plans (including continuous parameters) +and rolls them out in the real environment. Unlike +``AgentBilevelExplorer``, it does not run backtracking refinement +against a learned option model — the agent is expected to provide +complete parameters itself. +""" -import asyncio import logging from typing import Any, Dict, List, Set @@ -8,7 +14,8 @@ from gym.spaces import Box from predicators import utils -from predicators.agent_sdk.session_manager import AgentSessionManager +from predicators.agent_sdk.session_manager import AgentSessionManager, \ + run_query_sync from predicators.agent_sdk.tools import ToolContext from predicators.explorers.base_explorer import BaseExplorer from predicators.settings import CFG @@ -16,8 +23,8 @@ ParameterizedOption, Predicate, State, Task, Type -class AgentExplorer(BaseExplorer): - """Queries a Claude agent to produce option plans for exploration.""" +class AgentPlanExplorer(BaseExplorer): + """Queries a Claude agent to produce grounded option plans.""" def __init__(self, predicates: Set[Predicate], options: Set[ParameterizedOption], types: Set[Type], @@ -31,14 +38,14 @@ def __init__(self, predicates: Set[Predicate], @classmethod def get_name(cls) -> str: - return "agent" + return "agent_plan" def _get_exploration_strategy(self, train_task_idx: int, timeout: int) -> ExplorationStrategy: task = self._train_tasks[train_task_idx] try: prompt = self._build_exploration_prompt(train_task_idx) - responses = self._query_agent_sync(prompt) + responses = run_query_sync(self._agent_session, prompt) plan_text = self._extract_option_plan_text(responses) if plan_text: option_plan = self._parse_and_ground_plan(plan_text, task) @@ -185,20 +192,6 @@ def _build_trajectory_summary(self) -> str: return "\n".join(lines) - def _query_agent_sync(self, message: str) -> List[Dict[str, Any]]: - """Synchronous wrapper for async agent query.""" - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # pylint: disable-next=import-outside-toplevel - import nest_asyncio # type: ignore[import-untyped] - nest_asyncio.apply() - return loop.run_until_complete( - self._agent_session.query(message)) - return loop.run_until_complete(self._agent_session.query(message)) - except RuntimeError: - return asyncio.run(self._agent_session.query(message)) - def _extract_option_plan_text(self, responses: List[Dict[str, Any]]) -> str: """Extract plan text from the last assistant text response. diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index df53c14d1..54b6155d9 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -1,5 +1,6 @@ """Implements ground-truth NSRTs and options.""" import abc +import sys from pathlib import Path from typing import Dict, List, Sequence, Set @@ -68,6 +69,24 @@ def get_processes( raise NotImplementedError("Override me!") +class GroundTruthSimulatorFactory(abc.ABC): + """Parent class for ground-truth process-dynamics simulator programs. + + The factory itself only pins an env-name binding. The actual + simulator components (``PROCESS_RULES``, ``PARAM_SPECS``, + ``PROCESS_FEATURES``) live as module-level globals on the same file + as the subclass, matching the contract used by agent-synthesized + simulators. ``get_gt_simulator`` reads them via + ``read_simulator_components``. + """ + + @classmethod + @abc.abstractmethod + def get_env_names(cls) -> Set[str]: + """Get the env names that this factory builds simulators for.""" + raise NotImplementedError("Override me!") + + class GroundTruthLDLBridgePolicyFactory(abc.ABC): """Ground-truth policies implemented with LDLs saved in text files.""" @@ -241,6 +260,38 @@ def get_gt_processes(env_name: str, return final_processes +def get_gt_simulator(env_name: str) -> tuple: + """Load ground-truth process rules and param specs for an env. + + Returns ``(rules, param_specs, process_features)``: *rules* is the + list of process rule functions, *param_specs* is the list of + ``ParamSpec`` objects whose ``init_value`` is the GT value, and + *process_features* is the ``{type_name: [feat_names]}`` mapping that + scopes which features the rules predict. + + Locates the right module via the ``GroundTruthSimulatorFactory`` + registry (env-name binding) and reads the three components from + that module's globals via ``read_simulator_components``. This + mirrors the loader used for agent-synthesized simulators. + """ + # Local import to avoid pulling code_sim_learning into ground_truth_models + # at import time. + # pylint: disable=import-outside-toplevel + from predicators.code_sim_learning.utils import read_simulator_components + + for cls in utils.get_all_subclasses(GroundTruthSimulatorFactory): + if not cls.__abstractmethods__ and env_name in cls.get_env_names(): + module = sys.modules[cls.__module__] + rules, specs, features = read_simulator_components(vars(module)) + if rules is None or specs is None or features is None: + raise RuntimeError( + f"GT simulator module {cls.__module__} is missing one " + "of PROCESS_RULES / PARAM_SPECS / PROCESS_FEATURES.") + return rules, specs, features + raise NotImplementedError("Ground-truth simulator not implemented for " + f"env: {env_name}") + + def get_gt_ldl_bridge_policy(env_name: str, types: Set[Type], predicates: Set[Predicate], options: Set[ParameterizedOption], diff --git a/predicators/ground_truth_models/boil/__init__.py b/predicators/ground_truth_models/boil/__init__.py index cde72a21a..12fb982f8 100644 --- a/predicators/ground_truth_models/boil/__init__.py +++ b/predicators/ground_truth_models/boil/__init__.py @@ -1,5 +1,6 @@ """Ground-truth models for coffee environment and variants.""" +from .gt_simulator import PyBulletBoilGroundTruthSimulatorFactory from .nsrts import PyBulletBoilGroundTruthNSRTFactory from .options import PyBulletBoilGroundTruthOptionFactory from .processes import PyBulletBoilGroundTruthProcessFactory @@ -7,5 +8,6 @@ __all__ = [ "PyBulletBoilGroundTruthNSRTFactory", "PyBulletBoilGroundTruthOptionFactory", - "PyBulletBoilGroundTruthProcessFactory" + "PyBulletBoilGroundTruthProcessFactory", + "PyBulletBoilGroundTruthSimulatorFactory", ] diff --git a/predicators/ground_truth_models/boil/gt_simulator.py b/predicators/ground_truth_models/boil/gt_simulator.py new file mode 100644 index 000000000..b971d9992 --- /dev/null +++ b/predicators/ground_truth_models/boil/gt_simulator.py @@ -0,0 +1,254 @@ +"""Ground-truth simulator program for pybullet_boil process dynamics. + +Reproduces the custom step logic from pybullet_boil.py as composable +process rules using plain numpy/float arithmetic. + +Parameter-dependent gates (alignment thresholds, capacity caps, fill +height) are softened with sigmoid weights so the residual is +differentiable in those parameters. The primary consumer is the +Levenberg-Marquardt fit (and its Hessian identifiability diagnostic), +which builds a finite-difference Jacobian and would see J ~ 0 almost +everywhere with hard indicators. Smoothing also keeps MCMC walkers +from stalling on flat-likelihood plateaus, but emcee is gradient-free +and benefits less directly. State-dependent gates (faucet on/off, jug +held) remain hard since they don't enter the parameter likelihood. +""" + +from __future__ import annotations + +from typing import Dict, List + +import numpy as np + +from predicators.code_sim_learning.training import ParamSpec +from predicators.code_sim_learning.utils import ProcessUpdate +from predicators.ground_truth_models import GroundTruthSimulatorFactory +from predicators.settings import CFG +from predicators.structs import Object, State + +# Constants matching pybullet_boil.py exactly. Note: water_fill_speed is +# derived from CFG at spec-build time (env uses +# CFG.boil_water_fill_speed * water_height_to_level_ratio). +HEATING_SPEED = 0.03 +HAPPINESS_SPEED = 0.05 +MAX_JUG_WATER_CAPACITY = 1.3 +WATER_FILLED_HEIGHT = 0.8 +MAX_WATER_SPILL_WIDTH = 0.3 +FAUCET_ALIGN_THRESHOLD = 0.1 +BURNER_ALIGN_THRESHOLD = 0.05 +FAUCET_X_LEN = 0.15 +_WATER_HEIGHT_TO_LEVEL_RATIO = 10 + +# Smoothing scale for parameter-dependent gates. Small enough that gates +# are ~99% saturated when the operand is one threshold-width into the +# active region, large enough to give MCMC a usable gradient near the +# cliff. 0.02 is in the right ballpark for both spatial thresholds +# (~0.05–0.15 m) and water-level thresholds (~0.3–1.3). +_SOFT_EPS = 0.02 + + +def _sigmoid(z: float) -> float: + """Numerically-stable scalar sigmoid.""" + if z >= 0: + return 1.0 / (1.0 + np.exp(-z)) + ez = np.exp(z) + return ez / (1.0 + ez) + + +def _build_param_specs() -> List[ParamSpec]: + """Build at call time so CFG-driven values match the current run.""" + water_fill_speed = (CFG.boil_water_fill_speed * + _WATER_HEIGHT_TO_LEVEL_RATIO) + return [ + ParamSpec("water_fill_speed", water_fill_speed, lo=0.0), + ParamSpec("heating_speed", HEATING_SPEED, lo=0.0), + ParamSpec("happiness_speed", HAPPINESS_SPEED, lo=0.0), + ParamSpec("max_jug_water_capacity", MAX_JUG_WATER_CAPACITY, lo=0.0), + ParamSpec("water_filled_height", WATER_FILLED_HEIGHT, lo=0.0), + ParamSpec("max_water_spill_width", MAX_WATER_SPILL_WIDTH, lo=0.0), + ParamSpec("faucet_x_len", FAUCET_X_LEN, lo=0.0), + ParamSpec("faucet_align_threshold", FAUCET_ALIGN_THRESHOLD, lo=0.0), + ParamSpec("burner_align_threshold", BURNER_ALIGN_THRESHOLD, lo=0.0), + ] + + +# Module-level globals consumed by ``read_simulator_components`` (the +# same contract used by agent-synthesized simulator files). +# ``PARAM_SPECS`` is bound to the *callable* rather than its result so +# CFG-dependent defaults are evaluated when the loader pulls the value, +# after CFG has been finalized. +PARAM_SPECS = _build_param_specs + +PROCESS_FEATURES: Dict[str, List[str]] = { + "jug": ["water_volume", "heat_level"], + "faucet": ["spilled_level"], + "human": ["happiness_level"], +} + +# Backward-compat alias for tests that import a static, eagerly-built +# spec list (uses CFG defaults at import time). +BOIL_PARAM_SPECS: List[ParamSpec] = _build_param_specs() + +Params = Dict[str, float] + + +def _objs_by_type(state: State) -> Dict[str, List[Object]]: + """Group state objects by type name.""" + groups: Dict[str, List[Object]] = {} + for o in state: + groups.setdefault(o.type.name, []).append(o) + return groups + + +def _water_filling(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """Faucet on + jug aligned → fill jug; otherwise spill. + + Alignment and capacity gates are soft (sigmoid-weighted) so the + residual is differentiable in ``faucet_align_threshold``, + ``faucet_x_len``, and ``max_jug_water_capacity`` — needed for the LM + Jacobian (and downstream Hessian diagnostic) to be informative. + """ + objs = _objs_by_type(state) + for faucet in objs.get("faucet", []): + if state.get(faucet, "is_on") <= 0.5: + continue + + fx = float(state.get(faucet, "x")) + fy = float(state.get(faucet, "y")) + frot = float(state.get(faucet, "rot")) + out_x = fx + params["faucet_x_len"] * np.cos(frot) + out_y = fy - params["faucet_x_len"] * np.sin(frot) + + # Closest non-held jug picks up the catch (matches the + # original "first aligned wins" semantics for single-jug tasks). + best_jug, best_dist = None, float("inf") + for jug in objs.get("jug", []): + if state.get(jug, "is_held") > 0.5: + continue + jx = float(state.get(jug, "x")) + jy = float(state.get(jug, "y")) + d = float(np.hypot(out_x - jx, out_y - jy)) + if d < best_dist: + best_jug, best_dist = jug, d + + catch_w = 0.0 + if best_jug is not None: + water = float(state.get(best_jug, "water_volume")) + align_w = _sigmoid( + (params["faucet_align_threshold"] - best_dist) / _SOFT_EPS) + cap_w = _sigmoid( + (params["max_jug_water_capacity"] - water) / _SOFT_EPS) + catch_w = align_w * cap_w + new_water = water + catch_w * params["water_fill_speed"] + updates.setdefault(best_jug, {})["water_volume"] = new_water + + # Uncaught water spills (clamped at max_water_spill_width). + spill = float(state.get(faucet, "spilled_level")) + new_spill = min(params["max_water_spill_width"], + spill + (1.0 - catch_w) * params["water_fill_speed"]) + updates.setdefault(faucet, {})["spilled_level"] = new_spill + + return updates + + +def _heating(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """Burner on + jug with water aligned → heat jug. + + Alignment gate is soft so the residual is differentiable in + ``burner_align_threshold`` (LM's finite-difference Jacobian needs + this; MCMC also avoids flat-likelihood plateaus as a side effect). + The heat cap at 1.0 stays hard since 1.0 is a constant boundary, not + a learned parameter. + """ + objs = _objs_by_type(state) + for burner in objs.get("burner", []): + if state.get(burner, "is_on") <= 0.5: + continue + bx = float(state.get(burner, "x")) + by = float(state.get(burner, "y")) + + for jug in objs.get("jug", []): + if state.get(jug, "is_held") > 0.5: + continue + if state.get(jug, "water_volume") <= 0.0: + continue + jx = float(state.get(jug, "x")) + jy = float(state.get(jug, "y")) + dist = float(np.hypot(bx - jx, by - jy)) + + align_w = _sigmoid( + (params["burner_align_threshold"] - dist) / _SOFT_EPS) + heat = float(state.get(jug, "heat_level")) + new_heat = min(1.0, heat + align_w * params["heating_speed"]) + updates.setdefault(jug, {})["heat_level"] = new_heat + + return updates + + +def _happiness(state: State, updates: ProcessUpdate, + params: Params) -> ProcessUpdate: + """Jug filled + boiled + no spill + burner off → human happy. + + The water-filled gate is soft on ``water_filled_height`` so the + residual is differentiable in that parameter for LM (and emcee gets + a non-flat likelihood as a side effect). The heat>=1.0 gate stays + hard (1.0 is a constant cap, not a learned parameter). Spill / + burner-on gates are state-dependent. + """ + objs = _objs_by_type(state) + faucets = objs.get("faucet", []) + burners = objs.get("burner", []) + + def _get_val(obj: Object, feat: str) -> float: + val = updates.get(obj, {}).get(feat, None) + if val is not None: + return float(val) if hasattr(val, 'item') else val + return float(state.get(obj, feat)) + + # Spilled-level prediction can be a tiny positive number under soft + # semantics even when the env reports zero, so treat anything below + # the smoothing scale as "no spill" to avoid spuriously gating + # happiness off. + any_spill = any(_get_val(f, "spilled_level") > _SOFT_EPS for f in faucets) + any_burner_on = any(state.get(b, "is_on") > 0.5 for b in burners) + + if any_spill or any_burner_on: + return updates + + for jug in objs.get("jug", []): + water = _get_val(jug, "water_volume") + heat = _get_val(jug, "heat_level") + if heat < 1.0: + continue + filled_w = _sigmoid( + (water - params["water_filled_height"]) / _SOFT_EPS) + for human in objs.get("human", []): + h = float(state.get(human, "happiness_level")) + new_h = min(1.0, h + filled_w * params["happiness_speed"]) + updates.setdefault(human, {})["happiness_level"] = new_h + + return updates + + +PROCESS_RULES = [_water_filling, _heating, _happiness] + + +def get_gt_process_features() -> Dict[str, List[str]]: + """Backward-compat accessor; prefer the ``PROCESS_FEATURES`` global.""" + return dict(PROCESS_FEATURES) + + +class PyBulletBoilGroundTruthSimulatorFactory(GroundTruthSimulatorFactory): + """GT process-dynamics simulator for pybullet_boil. + + The actual simulator components (``PROCESS_RULES``, ``PARAM_SPECS``, + ``PROCESS_FEATURES``) live as module globals above; this class only + pins the env-name binding so ``get_gt_simulator`` can locate the + right module via the factory registry. + """ + + @classmethod + def get_env_names(cls) -> set: + return {"pybullet_boil"} diff --git a/predicators/ground_truth_models/boil/options.py b/predicators/ground_truth_models/boil/options.py index 769edbcbd..59b2ccd48 100644 --- a/predicators/ground_truth_models/boil/options.py +++ b/predicators/ground_truth_models/boil/options.py @@ -88,7 +88,7 @@ def _get_options_skill_factories( # --------------------------------------------------------------- # Helper: find the switch object associated with a faucet/burner. - # The env sets obj.switch_id in _reset_state. + # The env sets obj.switch_id in _set_state. # --------------------------------------------------------------- def _get_switch_pose( state: State, diff --git a/predicators/ground_truth_models/skill_factories/base.py b/predicators/ground_truth_models/skill_factories/base.py index 8cdf73f48..64ef19541 100644 --- a/predicators/ground_truth_models/skill_factories/base.py +++ b/predicators/ground_truth_models/skill_factories/base.py @@ -543,7 +543,7 @@ def _plan_with_simulator( new_state_data, simulator_state=pb_state.simulator_state) # 3. Reset simulator to current state - sim._reset_state(remapped_state) # pylint: disable=protected-access + sim._set_state(remapped_state) # pylint: disable=protected-access # 4. Collect collision body IDs (exclude held objects and # non-physical types) and find the held object. diff --git a/predicators/main.py b/predicators/main.py index 0fd55c6e3..a50591fd4 100644 --- a/predicators/main.py +++ b/predicators/main.py @@ -65,6 +65,53 @@ "Please add `export PYTHONHASHSEED=0` to your bash profile!" +def main() -> None: + """Main entry point for running approaches in environments.""" + script_start = time.perf_counter() + + # Parse & validate args + args = utils.parse_args() + utils.update_config(args) + str_args = " ".join(sys.argv) + + # Setup logging and directories + utils.configure_logging() + os.makedirs(CFG.results_dir, exist_ok=True) + os.makedirs(CFG.eval_trajectories_dir, exist_ok=True) + + # Log initial info + utils.log_initial_info(str_args) + + # Setup environment and tasks + env, approach_train_tasks, train_tasks = setup_environment() + + # Setup predicates + included_preds, excluded_preds = utils.parse_config_excluded_predicates( + env) + preds = utils.replace_goals_with_agent_specific_goals( + included_preds, excluded_preds, + env) if CFG.approach != "oracle" else included_preds + + # Create approach + approach = setup_approach(env, preds, approach_train_tasks) + + # Create dataset and cognitive manager + offline_dataset = create_offline_dataset(env, train_tasks, preds, approach) + execution_monitor = create_execution_monitor(CFG.execution_monitor) + cogman = CogMan(approach, create_perceiver(CFG.perceiver), + execution_monitor) + + # Run pipeline + _run_pipeline(env, cogman, approach_train_tasks, offline_dataset) + + # Log completion + script_time = time.perf_counter() - script_start + logging.info(f"\n\nMain script terminated in {script_time:.5f} seconds") + + +# ── Setup helpers ──────────────────────────────────────────────── + + def setup_environment() -> Tuple[BaseEnv, List[Task], List[Task]]: """Create and setup the environment and tasks. @@ -141,48 +188,7 @@ def create_offline_dataset(env: BaseEnv, train_tasks: List[Task], preds: set, return None -def main() -> None: - """Main entry point for running approaches in environments.""" - script_start = time.perf_counter() - - # Parse & validate args - args = utils.parse_args() - utils.update_config(args) - str_args = " ".join(sys.argv) - - # Setup logging and directories - utils.configure_logging() - os.makedirs(CFG.results_dir, exist_ok=True) - os.makedirs(CFG.eval_trajectories_dir, exist_ok=True) - - # Log initial info - utils.log_initial_info(str_args) - - # Setup environment and tasks - env, approach_train_tasks, train_tasks = setup_environment() - - # Setup predicates - included_preds, excluded_preds = utils.parse_config_excluded_predicates( - env) - preds = utils.replace_goals_with_agent_specific_goals( - included_preds, excluded_preds, - env) if CFG.approach != "oracle" else included_preds - - # Create approach - approach = setup_approach(env, preds, approach_train_tasks) - - # Create dataset and cognitive manager - offline_dataset = create_offline_dataset(env, train_tasks, preds, approach) - execution_monitor = create_execution_monitor(CFG.execution_monitor) - cogman = CogMan(approach, create_perceiver(CFG.perceiver), - execution_monitor) - - # Run pipeline - _run_pipeline(env, cogman, approach_train_tasks, offline_dataset) - - # Log completion - script_time = time.perf_counter() - script_start - logging.info(f"\n\nMain script terminated in {script_time:.5f} seconds") +# ── Pipeline ───────────────────────────────────────────────────── def _run_pipeline(env: BaseEnv, diff --git a/predicators/option_model.py b/predicators/option_model.py index 9af23cd51..788f85b4e 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -11,6 +11,7 @@ from typing import Callable, Optional, Set, Tuple import numpy as np +import pybullet from predicators import utils from predicators.envs import create_new_env @@ -20,6 +21,28 @@ ParameterizedOption, State, _Option +def _check_wait_termination(option: _Option, state: State, last_state: State, + abstract_fn: Callable[[State], Set]) -> bool: + """Check if a Wait option should terminate based on target atoms or atom + change. + + Returns True if it should terminate. + """ + result = utils.check_wait_target_atoms(option, state, abstract_fn) + if result is True: + logging.info("Wait terminating: target atoms satisfied") + return True + if result is None: + cur_atoms = abstract_fn(state) + prev_atoms = abstract_fn(last_state) + if cur_atoms != prev_atoms: + logging.info(f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms - prev_atoms)} " + f"Del: {sorted(prev_atoms - cur_atoms)}") + return True + return False + + def create_option_model(name: str, use_gui: Optional[bool] = None) -> _OptionModelBase: """Create an option model given its name. @@ -115,78 +138,34 @@ def get_next_state_and_num_actions(self, state: State, # if it does. This is a helpful optimization for planning with # fine-grained options over long horizons. # Note: mypy complains if this is None instead of DefaultState. - if CFG.option_model_terminate_on_repeat: - last_state = DefaultState - - def _terminal(s: State) -> bool: - nonlocal last_state - if option_copy.terminal(s): - logging.debug("Option reached terminal state.") - return True - if last_state is not DefaultState and last_state.allclose(s): - logging.debug("Option got stuck.") - raise utils.OptionExecutionFailure( - f"Option '{option_copy.name}' got stuck: the " - f"policy's action did not change the state. " - f"This usually means the first motion phase " - f"produced a no-op (e.g. IK returned current " - f"joints, or finger command matched current " - f"finger state).") - # Terminate Wait on target atoms or any atom change. - if (CFG.wait_option_terminate_on_atom_change - and option_copy.name == "Wait" - and last_state is not DefaultState - and self._abstract_function is not None): - result = utils.check_wait_target_atoms( - option_copy, s, self._abstract_function) - if result is True: - logging.info( - "Wait terminating: target atoms satisfied") - last_state = s - return True - if result is None: - cur_atoms = self._abstract_function(s) - prev_atoms = self._abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") - last_state = s - return True - last_state = s - return False - else: + last_state = DefaultState + + def _terminal(s: State) -> bool: + nonlocal last_state + if option_copy.terminal(s): + logging.debug("Option reached terminal state.") + return True + if (CFG.option_model_terminate_on_repeat + and last_state is not DefaultState + and last_state.allclose(s)): + logging.debug("Option got stuck.") + raise utils.OptionExecutionFailure( + f"Option '{option_copy.name}' got stuck: the " + f"policy's action did not change the state. " + f"This usually means the first motion phase " + f"produced a no-op (e.g. IK returned current " + f"joints, or finger command matched current " + f"finger state).") if (CFG.wait_option_terminate_on_atom_change and option_copy.name == "Wait" - and self._abstract_function is not None): - last_state_ref = [DefaultState] - abstract_fn = self._abstract_function - - def _terminal(s: State) -> bool: - if option_copy.terminal(s): - return True - if last_state_ref[0] is not DefaultState: - result = utils.check_wait_target_atoms( - option_copy, s, abstract_fn) - if result is True: - logging.info( - "Wait terminating: target atoms satisfied") - return True - if result is None: - cur_atoms = abstract_fn(s) - prev_atoms = abstract_fn(last_state_ref[0]) - if cur_atoms != prev_atoms: - logging.info( - f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms - prev_atoms)} " - f"Del: {sorted(prev_atoms - cur_atoms)}") - return True - last_state_ref[0] = s - return False - else: - # mypy complains without the lambda, pylint complains with it! - _terminal = lambda s: option_copy.terminal(s) # pylint: disable=unnecessary-lambda + and last_state is not DefaultState + and self._abstract_function is not None + and _check_wait_termination(option_copy, s, last_state, + self._abstract_function)): + logging.debug("Wait option terminating early.") + return True + last_state = s + return False try: traj = utils.run_policy_with_simulator( @@ -195,9 +174,9 @@ def _terminal(s: State) -> bool: state, _terminal, max_num_steps=CFG.max_num_steps_option_rollout) - except utils.OptionExecutionFailure as e: - # If there is a failure during the execution of the option, treat - # this as a noop. + except (utils.OptionExecutionFailure, pybullet.error) as e: + # Treat PyBullet physics engine errors the same as planned + # execution failures (e.g. GUI/Metal crash on macOS). self.last_execution_failure = str(e) return state, 0 # Note that in the case of using a PyBullet environment, the diff --git a/predicators/planning.py b/predicators/planning.py index 162e69443..4aaf9fc80 100644 --- a/predicators/planning.py +++ b/predicators/planning.py @@ -660,6 +660,12 @@ def sample_fn(idx: int, state: State, discovered_failures[idx] = None metrics["num_samples"] += 1 option = skeleton[idx].sample_option(state, task.goal, rng_) + # Inject Wait target atoms so Wait terminates as soon as the + # expected atoms hold rather than running to + # max_num_steps_option_rollout. Without this, refinement keeps + # hitting "exceeded individual horizon" even when heating / + # filling / etc. has already completed. + utils.inject_wait_targets_for_option(option, idx, atoms_sequence) logging.info(f"Running option {option}") return option @@ -688,7 +694,17 @@ def validate_fn(idx: int, pre_state: State, _option: _Option, for atom in atoms_sequence[idx + 1] if atom.predicate.name != _NOT_CAUSES_FAILURE } - if all(a.holds(post_state) for a in expected_atoms): + # Use utils.abstract to evaluate atoms so that + # DerivedPredicates (which need a Set[GroundAtom], not a + # State) are handled correctly. + preds: Set[Predicate] = set() + for a in expected_atoms: + preds.add(a.predicate) + aux = getattr(a.predicate, "auxiliary_predicates", None) + if aux: + preds.update(aux) + current_atoms = utils.abstract(post_state, preds) + if expected_atoms.issubset(current_atoms): return True, "" return False, "expected atoms not hold" # No atoms check — verify goal on final step. diff --git a/predicators/pybullet_helpers/objects.py b/predicators/pybullet_helpers/objects.py index 42941e9c1..6b226deac 100644 --- a/predicators/pybullet_helpers/objects.py +++ b/predicators/pybullet_helpers/objects.py @@ -157,3 +157,108 @@ def create_geom(px: float, py: float) -> _Geom2D: else: # We successfully placed all shapes return positions + + +def create_pybullet_block( + color: Tuple[float, float, float, float], + half_extents: Tuple[float, float, float], + mass: float, + friction: float, + position: Pose3D = (0.0, 0.0, 0.0), + orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), + physics_client_id: int = 0, + add_top_triangle: bool = False, +) -> int: + """Create a box-shaped PyBullet body and return its ID.""" + collision_id = p.createCollisionShape(p.GEOM_BOX, + halfExtents=half_extents, + physicsClientId=physics_client_id) + visual_id = p.createVisualShape(p.GEOM_BOX, + halfExtents=half_extents, + rgbaColor=color, + physicsClientId=physics_client_id) + block_id = p.createMultiBody(baseMass=mass, + baseCollisionShapeIndex=collision_id, + baseVisualShapeIndex=visual_id, + basePosition=position, + baseOrientation=orientation, + physicsClientId=physics_client_id) + p.changeDynamics(block_id, + linkIndex=-1, + lateralFriction=friction, + spinningFriction=friction, + rollingFriction=friction, + physicsClientId=physics_client_id) + + if add_top_triangle: + triangle_size = min(half_extents[0], half_extents[1]) + triangle_vertices = [ + [triangle_size, 0, 0], + [-triangle_size, triangle_size, 0], + [-triangle_size, -triangle_size, 0], + ] + triangle_visual_id = p.createVisualShape( + p.GEOM_MESH, + vertices=triangle_vertices, + indices=[0, 1, 2], + rgbaColor=[1, 1, 0, 1], + physicsClientId=physics_client_id) + + p.removeBody(block_id, physicsClientId=physics_client_id) + + block_id = p.createMultiBody( + baseMass=mass, + baseCollisionShapeIndex=collision_id, + baseVisualShapeIndex=visual_id, + basePosition=position, + baseOrientation=orientation, + linkMasses=[0], + linkCollisionShapeIndices=[-1], + linkVisualShapeIndices=[triangle_visual_id], + linkPositions=[[0, 0, half_extents[2] + 0.001]], + linkOrientations=[[0, 0, 0, 1]], + linkInertialFramePositions=[[0, 0, 0]], + linkInertialFrameOrientations=[[0, 0, 0, 1]], + linkParentIndices=[0], + linkJointTypes=[p.JOINT_FIXED], + linkJointAxis=[[0, 0, 1]], + physicsClientId=physics_client_id) + + p.changeDynamics(block_id, + linkIndex=-1, + lateralFriction=friction, + spinningFriction=friction, + physicsClientId=physics_client_id) + + return block_id + + +def create_pybullet_sphere( + color: Tuple[float, float, float, float], + radius: float, + mass: float, + friction: float, + position: Pose3D = (0.0, 0.0, 0.0), + orientation: Quaternion = (0.0, 0.0, 0.0, 1.0), + physics_client_id: int = 0, +) -> int: + """Create a sphere-shaped PyBullet body and return its ID.""" + collision_id = p.createCollisionShape(p.GEOM_SPHERE, + radius=radius, + physicsClientId=physics_client_id) + visual_id = p.createVisualShape(p.GEOM_SPHERE, + radius=radius, + rgbaColor=color, + physicsClientId=physics_client_id) + sphere_id = p.createMultiBody(baseMass=mass, + baseCollisionShapeIndex=collision_id, + baseVisualShapeIndex=visual_id, + basePosition=position, + baseOrientation=orientation, + physicsClientId=physics_client_id) + p.changeDynamics(sphere_id, + linkIndex=-1, + lateralFriction=friction, + spinningFriction=friction, + physicsClientId=physics_client_id) + return sphere_id diff --git a/predicators/pybullet_helpers/robots/single_arm.py b/predicators/pybullet_helpers/robots/single_arm.py index a0ae333c4..5e32c7812 100644 --- a/predicators/pybullet_helpers/robots/single_arm.py +++ b/predicators/pybullet_helpers/robots/single_arm.py @@ -239,11 +239,20 @@ def initial_joint_positions(self) -> JointPositions: joint_positions[self.right_finger_joint_idx] = self.open_fingers return joint_positions - def reset_state(self, robot_state: Array) -> None: + def reset_state( + self, + robot_state: Array, + joint_positions: Optional[JointPositions] = None, + ) -> None: """Reset the robot state to match the input state. The robot_state corresponds to the State vector for the robot - object. + object. If joint_positions is provided, the arm joints are set + directly from it; otherwise IK is run from the EE pose, which + loses information not encoded in (x, y, z, tilt, wrist) — most + importantly wrist roll. Preserving exact joints is required for + held-object grasps to round-trip through state save/restore + without geometric drift. """ rx, ry, rz, qx, qy, qz, qw, rf = robot_state p.resetBasePositionAndOrientation( @@ -252,6 +261,19 @@ def reset_state(self, robot_state: Array) -> None: self._base_pose.orientation, physicsClientId=self.physics_client_id, ) + target = np.array([rx, ry, rz, qx, qy, qz, qw, rf], dtype=np.float32) + if joint_positions is not None: + # arm_joints includes fingers, so set_joints already + # restored both — skip the snapped-finger overwrite below + # so continuous finger values round-trip cleanly. + self.set_joints(list(joint_positions)) + # Some callers attach nominal joints to plain states as a reset + # hint. Preserve exact joints only when they really reconstruct the + # requested EE pose; otherwise fall back to IK, matching the legacy + # reset behavior. + if np.allclose(self.get_state()[:7], target[:7], atol=1e-3): + return + # First, reset the joint values to initial joint positions, # so that IK is consistent (less sensitive to initialization). self.set_joints(self.initial_joint_positions) @@ -261,7 +283,7 @@ def reset_state(self, robot_state: Array) -> None: pose = Pose((rx, ry, rz), (qx, qy, qz, qw)) self.inverse_kinematics(pose, validate=True) - # Handle setting the robot finger joints. + # IK does not touch fingers, so snap them from the EE state. for finger_id in [self.left_finger_id, self.right_finger_id]: p.resetJointState(self.robot_id, finger_id, diff --git a/predicators/settings.py b/predicators/settings.py index caefb43be..248b8c63e 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -1015,6 +1015,31 @@ class GlobalSettings: agent_bilevel_check_subgoals = True # check subgoal atoms after each step # log state pretty_str before/after each step agent_bilevel_log_state = False + agent_bilevel_plan_sketch_file = "" # load sketch from file instead of LLM + # Agent bilevel explorer settings. Separate from the solve-path budget + # above because the explorer runs full backtracking while looking for + # the deepest subgoal-failure to truncate at, and each exhausted + # upstream step multiplies the cost. + agent_bilevel_explorer_max_samples_per_step = 50 + + # Code sim-learning parameter fitting settings. + # Set to 0 to skip MCMC and use initial parameter values directly. + code_sim_learning_num_mcmc_steps = 500 + # Diagnostic: log the Hessian eigendecomposition at the MAP to + # spot unidentifiable parameter combinations. Adds ~5-15s per fit. + code_sim_learning_log_hessian_identifiability = False + # If True, run an LM fit and center MCMC walkers on its MAP estimate + # instead of init_values. Adds ~5-15s per fit. + code_sim_learning_warm_start_with_lm = True + + # Sim-learning oracle flags (for ablation / debugging). + # When True, load GT process rules instead of running agent synthesis. + # Parameters init_values are perturbed so MCMC still has work to do. + agent_sim_learn_oracle_sim_program = False + # Relative scale for perturbing oracle parameter init_values before MCMC. + agent_sim_learn_oracle_sim_param_noise_scale = 0.2 + # When True, use GT parameter values directly, skipping MCMC fitting. + agent_sim_learn_oracle_sim_params = False @classmethod def get_arg_specific_settings(cls, args: Dict[str, Any]) -> Dict[str, Any]: diff --git a/predicators/utils.py b/predicators/utils.py index 7181522b0..48b8590bb 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1684,6 +1684,40 @@ def strip_wait_annotations(text: str) -> str: return re.sub(r'\s*->\s*\{[^}]*\}', '', text) +def _format_wait_target_debug( + state: State, target_atoms: Set[GroundAtom], + abstract_function: Callable[[State], Set[GroundAtom]]) -> str: + """Format state details for debugging why Wait has not terminated.""" + cur_atoms = abstract_function(state) + missing_targets = target_atoms - cur_atoms + target_objects = sorted( + { + ent + for atom in target_atoms + for ent in atom.entities if isinstance(ent, Object) + }, + key=lambda o: o.name) + object_details = [] + for obj in target_objects: + feature_values = [] + for feature_name in obj.type.feature_names: + value = state.get(obj, feature_name) + if isinstance(value, float): + value_str = f"{value:.4f}" + else: + value_str = str(value) + feature_values.append(f"{feature_name}={value_str}") + object_details.append(f"{obj}: " + ", ".join(feature_values)) + details = [ + f"Targets: {sorted(target_atoms)}", + f"Missing: {sorted(missing_targets)}", + f"cur_atoms: {sorted(cur_atoms)}", + ] + if object_details: + details.append(f"target_objects: {'; '.join(object_details)}") + return "; ".join(details) + + def option_policy_to_policy( option_policy: Callable[[State], _Option], max_option_steps: Optional[int] = None, @@ -1728,11 +1762,25 @@ def _policy(state: State) -> Action: and cur_option.name == "Wait": assert abstract_function is not None assert last_state is not None + target_atoms = cur_option.memory.get("wait_target_atoms") result = check_wait_target_atoms(cur_option, state, abstract_function) if result is True: - logging.debug("Wait terminating: target atoms satisfied") + cur_atoms = abstract_function(state) + logging.debug("Wait terminating: target atoms satisfied. " + f"Targets: {target_atoms}, " + f"cur_atoms: {sorted(cur_atoms)}, " + f"num_option_steps={num_cur_option_steps}") wait_terminate = True + elif result is False: + assert target_atoms is not None + if num_cur_option_steps <= 1 or num_cur_option_steps % 25 == 0: + wait_debug = _format_wait_target_debug( + state, target_atoms, abstract_function) + logging.debug( + "Wait continuing: target atoms not yet satisfied. " + "%s, num_option_steps=%d", wait_debug, + num_cur_option_steps) elif result is None: # No targets specified: fall back to any-atom-change cur_atoms = abstract_function(state) @@ -1766,6 +1814,8 @@ def _policy(state: State) -> Action: raise OptionExecutionFailure( "Unsound option policy.", info={"last_failed_option": last_option}) + logging.debug(f"[option_policy] Started option {cur_option.name}, " + f"initiable=True") num_cur_option_steps = 0 num_cur_option_steps += 1 @@ -1783,13 +1833,20 @@ def option_plan_to_policy( ) -> Callable[[State], Action]: """Create a policy that executes a sequence of options in order.""" queue = list(plan) # don't modify plan, just in case + total_options = len(queue) def _option_policy(state: State) -> _Option: del state # not used if not queue: + logging.info("Option plan exhausted after %d options.", + total_options) raise OptionExecutionFailure("Option plan exhausted!") option = queue.pop(0) - logging.info(f"Executing option {option.simple_str()}") + option_num = total_options - len(queue) + next_option = None if not queue else queue[0].simple_str() + logging.info("Executing option %d/%d: %s (remaining=%d, next=%s)", + option_num, total_options, option.simple_str(), + len(queue), next_option) return option return option_policy_to_policy( diff --git a/scripts/configs/predicatorv3/agents.yaml b/scripts/configs/predicatorv3/agents.yaml index d31968051..6fd77ef5c 100644 --- a/scripts/configs/predicatorv3/agents.yaml +++ b/scripts/configs/predicatorv3/agents.yaml @@ -3,5 +3,63 @@ --- includes: - common.yaml - - approaches/agents.yaml - envs/all.yaml +APPROACHES: + # agent_planner: + # NAME: "agent_planner" + # FLAGS: + # explorer: "agent_plan" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_sdk_max_agent_turns_per_iteration: 50 + # agent_planner_use_scratchpad: False + # agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True + # option_model_use_gui: True + # agent_bilevel: + # NAME: "agent_bilevel" + # FLAGS: + # explorer: "agent_plan" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # option_model_terminate_on_repeat: False + # agent_sdk_max_agent_turns_per_iteration: 50 + # agent_planner_use_scratchpad: False + # agent_planner_use_visualize_state: True + # agent_planner_use_annotate_scene: True + # option_model_use_gui: True + # agent_bilevel_log_state: False + # agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + agent_sim_learning: + NAME: "agent_sim_learning" + FLAGS: + explorer: "agent_bilevel" + demonstrator: "oracle_process_planning" + terminate_on_goal_reached_and_option_terminated: True + agent_sdk_use_local_sandbox: True + option_model_terminate_on_repeat: False + agent_sdk_max_agent_turns_per_iteration: 50 + agent_planner_use_scratchpad: False + agent_planner_use_visualize_state: True + agent_planner_use_annotate_scene: True + option_model_use_gui: True + agent_bilevel_log_state: False + agent_bilevel_plan_sketch_file: "tests/approaches/test_data/boil_plan_sketch.txt" + skip_test_until_last_ite_or_early_stopping: False + agent_sim_learn_oracle_sim_program: True + agent_sim_learn_oracle_sim_params: False + agent_sim_learn_oracle_sim_param_noise_scale: 1.0 # 1.0 allows successful planning but insatisficing plan; 0.8 produces satisficing plan + code_sim_learning_num_mcmc_steps: 0 + code_sim_learning_warm_start_with_lm: True + # agent_option_learning: + # NAME: "agent_option_learning" + # FLAGS: + # explorer: "agent_plan" + # option_learner: "agent" + # demonstrator: "oracle_process_planning" + # terminate_on_goal_reached_and_option_terminated: True + # agent_sdk_use_local_sandbox: True + # agent_sdk_max_agent_turns_per_iteration: 50 diff --git a/scripts/configs/predicatorv3/approaches/agents.yaml b/scripts/configs/predicatorv3/approaches/agents.yaml deleted file mode 100644 index c43ca6125..000000000 --- a/scripts/configs/predicatorv3/approaches/agents.yaml +++ /dev/null @@ -1,37 +0,0 @@ -APPROACHES: - # agent_planner: - # NAME: "agent_planner" - # FLAGS: - # explorer: "agent" - # demonstrator: "oracle_process_planning" - # terminate_on_goal_reached_and_option_terminated: True - # agent_sdk_use_local_sandbox: True - # option_model_terminate_on_repeat: False - # agent_sdk_max_agent_turns_per_iteration: 50 - # agent_planner_use_scratchpad: False - # agent_planner_use_visualize_state: True - # agent_planner_use_annotate_scene: True - # option_model_use_gui: True - agent_bilevel: - NAME: "agent_bilevel" - FLAGS: - explorer: "agent" - demonstrator: "oracle_process_planning" - terminate_on_goal_reached_and_option_terminated: True - agent_sdk_use_local_sandbox: True - option_model_terminate_on_repeat: False - agent_sdk_max_agent_turns_per_iteration: 50 - agent_planner_use_scratchpad: False - agent_planner_use_visualize_state: True - agent_planner_use_annotate_scene: True - option_model_use_gui: True - agent_bilevel_log_state: False - # agent_option_learning: - # NAME: "agent_option_learning" - # FLAGS: - # explorer: "agent" - # option_learner: "agent" - # demonstrator: "oracle_process_planning" - # terminate_on_goal_reached_and_option_terminated: True - # agent_sdk_use_local_sandbox: True - # agent_sdk_max_agent_turns_per_iteration: 50 diff --git a/scripts/configs/predicatorv3/approaches/oracle.yaml b/scripts/configs/predicatorv3/approaches/oracle.yaml deleted file mode 100644 index 7501a44b3..000000000 --- a/scripts/configs/predicatorv3/approaches/oracle.yaml +++ /dev/null @@ -1,15 +0,0 @@ -APPROACHES: - oracle: - NAME: "oracle_process_planning" - FLAGS: - demonstrator: "oracle_process_planning" - terminate_on_goal_reached_and_option_terminated: True - bilevel_plan_without_sim: True - # human_interaction: - # NAME: "human_interaction" - # FLAGS: - # human_interaction_approach_use_scripted_option: True - # human_interaction_approach_use_all_options: True - # scripted_option_dir: "scripted_option_policies" - # skill_phase_use_motion_planning: True - # terminate_on_goal_reached_and_option_terminated: True diff --git a/scripts/configs/predicatorv3/common.yaml b/scripts/configs/predicatorv3/common.yaml index c4d2a9ab4..581e5dd43 100644 --- a/scripts/configs/predicatorv3/common.yaml +++ b/scripts/configs/predicatorv3/common.yaml @@ -1,15 +1,15 @@ ARGS: - "debug" # - "use_gui" - - "make_failure_videos" - - "make_test_videos" + # - "make_failure_videos" + # - "make_test_videos" # - "make_demo_videos" # - "make_demo_images" # support images # - "make_failure_images" # query images # - "make_test_images" # query images # - "save_atoms" FLAGS: - max_initial_demos: 0 + max_initial_demos: 1 num_online_learning_cycles: 0 online_nsrt_learning_requests_per_cycle: 1 skill_phase_use_motion_planning: True diff --git a/scripts/configs/predicatorv3/oracle.yaml b/scripts/configs/predicatorv3/oracle.yaml index 1253eb4c1..45abe8371 100644 --- a/scripts/configs/predicatorv3/oracle.yaml +++ b/scripts/configs/predicatorv3/oracle.yaml @@ -3,5 +3,19 @@ --- includes: - common.yaml - - approaches/oracle.yaml - envs/all.yaml +APPROACHES: + oracle: + NAME: "oracle_process_planning" + FLAGS: + demonstrator: "oracle_process_planning" + terminate_on_goal_reached_and_option_terminated: True + bilevel_plan_without_sim: True + # human_interaction: + # NAME: "human_interaction" + # FLAGS: + # human_interaction_approach_use_scripted_option: True + # human_interaction_approach_use_all_options: True + # scripted_option_dir: "scripted_option_policies" + # skill_phase_use_motion_planning: True + # terminate_on_goal_reached_and_option_terminated: True diff --git a/scripts/configs/predicatorv3/predicator_v3.yaml b/scripts/configs/predicatorv3/predicator_v3.yaml index 29f0a5398..9678225af 100644 --- a/scripts/configs/predicatorv3/predicator_v3.yaml +++ b/scripts/configs/predicatorv3/predicator_v3.yaml @@ -18,7 +18,7 @@ APPROACHES: # agent_planner: # NAME: "agent_planner" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True # # agent_sdk_use_docker_sandbox: True @@ -32,7 +32,7 @@ APPROACHES: # agent_bilevel: # NAME: "agent_bilevel" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True # # agent_sdk_use_docker_sandbox: True @@ -46,7 +46,7 @@ APPROACHES: # agent_option_learning: # NAME: "agent_option_learning" # FLAGS: - # explorer: "agent" + # explorer: "agent_plan" # option_learner: "agent" # demonstrator: "oracle_process_planning" # terminate_on_goal_reached_and_option_terminated: True @@ -60,7 +60,7 @@ APPROACHES: # terminate_on_goal_reached_and_option_terminated: True # bilevel_plan_without_sim: True # max_initial_demos: 0 - # explorer: "agent" + # explorer: "agent_plan" # num_online_learning_cycles: 4 # online_nsrt_learning_requests_per_cycle: 1 ENVS: diff --git a/scripts/run_blocks_perception.py b/scripts/run_blocks_perception.py index 82b8e2693..585d4d067 100644 --- a/scripts/run_blocks_perception.py +++ b/scripts/run_blocks_perception.py @@ -98,9 +98,9 @@ from predicators import utils from predicators.envs.pybullet_blocks import PyBulletBlocksEnv -from predicators.envs.pybullet_env import create_pybullet_block from predicators.pybullet_helpers.camera import create_gui_connection from predicators.pybullet_helpers.geometry import Pose3D +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import \ create_single_arm_pybullet_robot from predicators.settings import CFG diff --git a/setup.py b/setup.py index 30408cddf..812788624 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ "psutil", "claude-agent-sdk", "nest_asyncio", + "emcee", ], include_package_data=True, extras_require={ diff --git a/tests/approaches/test_agent_bilevel_approach.py b/tests/approaches/test_agent_bilevel_approach.py index 4d399883d..57808f594 100644 --- a/tests/approaches/test_agent_bilevel_approach.py +++ b/tests/approaches/test_agent_bilevel_approach.py @@ -1,5 +1,6 @@ """Tests for AgentBilevelApproach -- parsing and refinement logic.""" # pylint: disable=protected-access,import-outside-toplevel +import os from unittest.mock import MagicMock, patch import numpy as np @@ -12,6 +13,8 @@ from predicators.structs import Action, GroundAtom, Object, \ ParameterizedOption, Predicate, State, Task, Type +_TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "test_data") + # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @@ -804,6 +807,32 @@ def test_no_valid_options_raises(self): with pytest.raises(ApproachFailure, match="Parsed empty"): approach._query_agent_for_plan_sketch(task) + def test_sketch_from_file(self): + """Load sketch from a saved text file via CFG option.""" + approach, _, task = _make_approach() + sketch_path = os.path.join(_TEST_DATA_DIR, "simple_plan_sketch.txt") + + utils.reset_config({ + "env": "cover", + "approach": "agent_bilevel", + "num_train_tasks": 1, + "num_test_tasks": 1, + "seed": 42, + "agent_bilevel_plan_sketch_file": sketch_path, + }) + + sketch = approach._query_agent_for_plan_sketch(task) + + assert len(sketch) == 2 + assert sketch[0].option.name == "Pick" + assert list(sketch[0].objects) == [_block0] + assert sketch[0].subgoal_atoms is not None + assert GroundAtom(_Holding, [_block0]) in sketch[0].subgoal_atoms + assert sketch[1].option.name == "Place" + assert list(sketch[1].objects) == [_block0, _block1] + assert sketch[1].subgoal_atoms is not None + assert GroundAtom(_On, [_block0, _block1]) in sketch[1].subgoal_atoms + # --------------------------------------------------------------------------- # Tests: _sample_params diff --git a/tests/approaches/test_agent_sim_learning_approach.py b/tests/approaches/test_agent_sim_learning_approach.py new file mode 100644 index 000000000..f5e808700 --- /dev/null +++ b/tests/approaches/test_agent_sim_learning_approach.py @@ -0,0 +1,371 @@ +"""Integration test: GT simulator + backtracking refinement solves boil. + +Verifies that given a correct plan sketch (from a real agent run) and a +ground-truth simulator program, the hybrid learned option model +(PyBullet + learned process dynamics) can find continuous parameters +that solve a pybullet_boil task. +""" +# pylint: disable=protected-access +import logging +import os +import re +from typing import List, Optional, Sequence, Set, Tuple + +import numpy as np +import pytest + +from predicators import utils +from predicators.approaches.agent_bilevel_approach import _SketchStep +from predicators.code_sim_learning.utils import LearnedSimulator, \ + apply_rules, merge_updates +from predicators.envs import create_new_env +from predicators.ground_truth_models import get_gt_options +from predicators.ground_truth_models.boil.gt_simulator import \ + BOIL_PARAM_SPECS, PROCESS_RULES +from predicators.option_model import _OracleOptionModel +from predicators.planning import run_backtracking_refinement +from predicators.structs import GroundAtom, Object, ParameterizedOption, \ + Predicate + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def _setup_env(): + """Create boil env and return (env, task, options_dict, objects_dict).""" + utils.reset_config({ + "env": "pybullet_boil", + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 1, + "boil_goal": "simple", + "boil_num_jugs_train": [1], + "boil_num_jugs_test": [1], + "boil_num_burner_train": [1], + "boil_num_burner_test": [1], + "option_model_use_gui": False, + "wait_option_terminate_on_atom_change": True, + }) + env = create_new_env("pybullet_boil", do_cache=False, use_gui=False) + task = [t.task for t in env.get_test_tasks()][0] + options = get_gt_options(env.get_name()) + options_dict = {o.name: o for o in options} + objects_dict = {obj.name: obj for obj in task.init} + return env, task, options_dict, objects_dict + + +def _build_oracle_model(env): + """Build an oracle option model.""" + options = get_gt_options(env.get_name()) + oracle = _OracleOptionModel(options, env.simulate) + preds = env.predicates + oracle._abstract_function = lambda s: utils.abstract(s, preds) + return oracle + + +def _build_kinematics_only_oracle(env): + """Build an oracle that only handles kinematics (no process dynamics). + + Creates a separate env instance with process dynamics disabled, so + that water filling, heating, and happiness are not simulated. + """ + base_env = create_new_env("pybullet_boil", + do_cache=False, + use_gui=False, + skip_process_dynamics=True) + options = get_gt_options(base_env.get_name()) + oracle = _OracleOptionModel(options, base_env.simulate) + preds = env.predicates + oracle._abstract_function = lambda s: utils.abstract(s, preds) + return oracle + + +def _build_combined_model(env): + """Build a combined model: base-sim-only env + GT step-level dynamics. + + Mirrors AgentSimLearningApproach: wraps GT rules in a + LearnedSimulator via apply_rules and composes with a base-sim-only + env. + """ + base_env = create_new_env("pybullet_boil", + do_cache=False, + use_gui=False, + skip_process_dynamics=True) + gt_params = {s.name: s.init_value for s in BOIL_PARAM_SPECS} + rules = PROCESS_RULES + + simulator = LearnedSimulator( + step_fn=lambda s, _r=rules, _p=gt_params: apply_rules(s, _r, _p), + name="gt_combined") + + def combined_simulate(state, action): + kin_state = base_env.simulate(state, action) + updates = simulator.predict_step(kin_state) + if not updates: + return kin_state + return merge_updates(kin_state, updates) + + options = get_gt_options(env.get_name()) + model = _OracleOptionModel(options, combined_simulate) + preds = env.predicates + model._abstract_function = lambda s: utils.abstract(s, preds) + return model + + +def _parse_sketch_from_file( + sketch_file: str, + options: Set[ParameterizedOption], + types: Set, + predicates: Set[Predicate], + objects: Sequence[Object], +) -> List[_SketchStep]: + """Parse a plan sketch from a text file, same as agent_bilevel_approach.""" + with open(sketch_file, "r", encoding="utf-8") as f: + plan_text = f.read().strip() + + # Phase 1: parse options + objects (no continuous params) + parsed = utils.parse_model_output_into_option_plan( + plan_text, objects, types, options, parse_continuous_params=False) + assert parsed, f"Parsed empty plan sketch from {sketch_file}" + + # Phase 2: parse subgoal annotations + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + option_names = {o.name for o in options} + subgoal_re = re.compile(r'->\s*\{([^}]*)\}') + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + subgoals: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] + for line in plan_text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + continue + sg_match = subgoal_re.search(stripped) + if not sg_match: + subgoals.append(None) + continue + atoms_text = sg_match.group(1) + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + for atom_match in atom_re.finditer(atoms_text): + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) + obj_names = [ + n.strip().split(':')[0] for n in atom_match.group(3).split(',') + ] + if pred_name not in pred_map: + continue + pred = pred_map[pred_name] + try: + objs: Sequence[Object] = [obj_map[n] for n in obj_names] + except KeyError: + continue + if len(objs) != len(pred.types): + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + if pos_atoms or neg_atoms: + subgoals.append((pos_atoms, neg_atoms)) + else: + subgoals.append(None) + + # Zip into sketch steps + sketch = [] + for i, (option, objs, _) in enumerate(parsed): + sg = subgoals[i] if i < len(subgoals) else None + if sg is not None: + pos, neg = sg + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + _SketchStep(option=option, objects=objs, subgoal_atoms=None)) + return sketch + + +def _informed_place_params(pre_state, sketch, step_idx, rng, n): + """Sample Place params biased toward the contextual target.""" + step = sketch[step_idx] + low = step.option.params_space.low + high = step.option.params_space.high + eps = 1e-4 + + next_step = sketch[step_idx + 1] if step_idx + 1 < n else None + + if next_step and "Faucet" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "faucet": + fx = pre_state.get(obj, "x") + fy = pre_state.get(obj, "y") + frot = pre_state.get(obj, "rot") + # The jug has a physics offset after drop, so target + # slightly past the faucet output to compensate. + out_x = fx + 0.15 * np.cos(frot) + out_y = fy - 0.15 * np.sin(frot) + # Target near faucet output x but lower y (IK-reachable). + x = np.clip(out_x + rng.normal(0, 0.02), low[0] + eps, + high[0] - eps) + y = np.clip(out_y - 0.05 + rng.normal(0, 0.03), low[1] + eps, + high[1] - eps) + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + # Negative yaw helps place jug closer to faucet output. + yaw = np.clip(rng.normal(-0.3, 0.5), low[3] + eps, + high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + if next_step and "Burner" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "burner": + bx = pre_state.get(obj, "x") + by = pre_state.get(obj, "y") + x = np.clip(bx + rng.normal(0, 0.05), low[0] + eps, + high[0] - eps) + y = np.clip(by + rng.normal(0, 0.05), low[1] + eps, + high[1] - eps) + # Bias z toward low end for reliable IK. + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + yaw = rng.uniform(low[3] + eps, high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + return rng.uniform(low + eps, high - eps).astype(np.float32) + + +def _refine(task, + sketch, + option_model, + predicates, + seed=0, + max_samples=200, + timeout=600.0): + """Run backtracking refinement with informed Place sampling.""" + rng = np.random.default_rng(seed) + n = len(sketch) + max_tries = [ + max_samples if step.option.params_space.shape[0] > 0 else 1 + for step in sketch + ] + + def sample_fn(idx, state, rng_): + step = sketch[idx] + if step.option.params_space.shape[0] == 0: + params = np.array([], dtype=np.float32) + elif step.option.name == "Place": + params = _informed_place_params(state, sketch, idx, rng_, n) + else: + low = step.option.params_space.low + high = step.option.params_space.high + params = rng_.uniform(low, high).astype(np.float32) + grounded = step.option.ground(step.objects, params) + if grounded.name == "Wait" and step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + return grounded + + def validate_fn(idx, _pre, _opt, post_state, _n_acts): + step = sketch[idx] + if step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + missing = step.subgoal_atoms - current_atoms + return False, f"subgoal missing: {missing}" + if idx == n - 1 and not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + plan, success, total_samples = run_backtracking_refinement( + init_state=task.init, + option_model=option_model, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=validate_fn, + rng=rng, + timeout=timeout, + ) + logger.info("Refinement: %s, %d total samples", + "success" if success else "failed", total_samples) + return [p for p in plan if p is not None], success + + +SKETCH_FILE = os.path.join(os.path.dirname(__file__), "test_data", + "boil_plan_sketch.txt") + + +@pytest.mark.parametrize("model_type", ["oracle", "combined"]) +def test_boil_sketch_refinement(model_type): + """Test that backtracking refinement solves the first test task.""" + env, task, _options_dict, _objects_dict = _setup_env() + predicates = env.predicates + options = get_gt_options(env.get_name()) + + if model_type == "oracle": + option_model = _build_oracle_model(env) + else: + option_model = _build_combined_model(env) + + sketch = _parse_sketch_from_file(SKETCH_FILE, options, env.types, + predicates, list(task.init)) + plan, success = _refine(task, + sketch, + option_model, + predicates, + max_samples=500, + timeout=1200.0) + + logger.info("Model=%s, success=%s, plan_len=%d", model_type, success, + len(plan)) + if success: + for i, opt in enumerate(plan): + objs = ", ".join(o.name for o in opt.objects) + params = ", ".join(f"{p:.3f}" for p in opt.params) + logger.info(" %d: %s(%s)[%s]", i, opt.name, objs, params) + + assert success, (f"Refinement failed with {model_type} model. " + f"Partial plan: {len(plan)} steps.") + + # Forward validation: re-execute the plan in the oracle model (full + # env dynamics) to verify the plan actually solves the task. + # Always uses the oracle regardless of which model found the plan. + oracle_model = _build_oracle_model(env) + n = len(plan) + + def fwd_sample_fn(i, _s, _r): + return plan[i] + + def fwd_validate_fn(i, _s, _o, post, _n): + if i == n - 1 and not task.goal_holds(post): + return False, "goal not reached" + return True, "" + + _, fwd_success, _ = run_backtracking_refinement( + init_state=task.init, + option_model=oracle_model, + n_steps=n, + max_tries=[1] * n, + sample_fn=fwd_sample_fn, + validate_fn=fwd_validate_fn, + rng=np.random.default_rng(0), + timeout=600.0, + ) + if fwd_success: + logger.info("Forward validation passed for %s model.", model_type) + else: + logger.warning( + "Forward validation failed for %s model " + "(PyBullet state reconstruction is imperfect).", model_type) + + +if __name__ == "__main__": + import sys + _model = sys.argv[1] if len(sys.argv) > 1 else "oracle" + test_boil_sketch_refinement(_model) diff --git a/tests/approaches/test_data/boil_plan_sketch.txt b/tests/approaches/test_data/boil_plan_sketch.txt new file mode 100644 index 000000000..3553f3af8 --- /dev/null +++ b/tests/approaches/test_data/boil_plan_sketch.txt @@ -0,0 +1,10 @@ +PickJug(robot:robot, jug0:jug) -> {Holding(robot:robot, jug0:jug)} +Place(robot:robot) -> {JugAtFaucet(jug0:jug, faucet:faucet)} +SwitchFaucetOn(robot:robot, faucet:faucet) -> {FaucetOn(faucet:faucet)} +Wait(robot:robot) -> {JugFilled(jug0:jug)} +SwitchFaucetOff(robot:robot, faucet:faucet) -> {FaucetOff(faucet:faucet)} +PickJug(robot:robot, jug0:jug) -> {Holding(robot:robot, jug0:jug)} +Place(robot:robot) -> {JugAtBurner(jug0:jug, burner0:burner)} +SwitchBurnerOn(robot:robot, burner0:burner) -> {BurnerOn(burner0:burner)} +Wait(robot:robot) -> {WaterBoiled(jug0:jug)} +SwitchBurnerOff(robot:robot, burner0:burner) -> {BurnerOff(burner0:burner)} diff --git a/tests/approaches/test_data/simple_plan_sketch.txt b/tests/approaches/test_data/simple_plan_sketch.txt new file mode 100644 index 000000000..c14ff2dd5 --- /dev/null +++ b/tests/approaches/test_data/simple_plan_sketch.txt @@ -0,0 +1,2 @@ +Pick(block0:block) -> {Holding(block0:block)} +Place(block0:block, block1:block) -> {On(block0:block, block1:block)} diff --git a/tests/code_sim_learning/test_param_fitting.py b/tests/code_sim_learning/test_param_fitting.py new file mode 100644 index 000000000..742f795d9 --- /dev/null +++ b/tests/code_sim_learning/test_param_fitting.py @@ -0,0 +1,322 @@ +"""Test parameter fitting recovers GT simulator parameters. + +Uses step-level transitions from a real oracle trajectory (boil env), +then fits from perturbed initial values via emcee. +""" + +import logging +import os +import re +from typing import Dict, List, Optional, Sequence, Set, Tuple + +import numpy as np + +import predicators.approaches # noqa: F401 # pylint: disable=unused-import +from predicators import utils +from predicators.approaches.agent_bilevel_approach import _SketchStep +from predicators.code_sim_learning.training import ParamSpec, fit_params +from predicators.envs import create_new_env +from predicators.ground_truth_models import get_gt_options +from predicators.ground_truth_models.boil.gt_simulator import \ + BOIL_PARAM_SPECS, PROCESS_RULES, get_gt_process_features +from predicators.option_model import _OracleOptionModel +from predicators.planning import run_backtracking_refinement +from predicators.structs import Action, GroundAtom, LowLevelTrajectory, \ + Object, ParameterizedOption, Predicate, State + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Ground-truth parameter values (from BOIL_PARAM_SPECS). +GT_PARAMS = {s.name: s.init_value for s in BOIL_PARAM_SPECS} + +SKETCH_FILE = os.path.join(os.path.dirname(__file__), "..", "approaches", + "test_data", "boil_plan_sketch.txt") + + +def _setup_env(): + """Create boil env and return (env, task, options, predicates).""" + utils.reset_config({ + "env": "pybullet_boil", + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 1, + "boil_goal": "simple", + "boil_num_jugs_train": [1], + "boil_num_jugs_test": [1], + "boil_num_burner_train": [1], + "boil_num_burner_test": [1], + "option_model_use_gui": False, + "wait_option_terminate_on_atom_change": True, + }) + env = create_new_env("pybullet_boil", do_cache=False, use_gui=False) + task = [t.task for t in env.get_train_tasks()][0] + options = get_gt_options(env.get_name()) + return env, task, options + + +def _build_oracle_model(env): + """Build an oracle option model.""" + options = get_gt_options(env.get_name()) + oracle = _OracleOptionModel(options, env.simulate) + preds = env.predicates + oracle._abstract_function = lambda s: utils.abstract(s, preds) # pylint: disable=protected-access + return oracle + + +def _parse_sketch_from_file( + sketch_file: str, + options: Set[ParameterizedOption], + types: Set, + predicates: Set[Predicate], + objects: Sequence[Object], +) -> List[_SketchStep]: + """Parse a plan sketch from a text file.""" + with open(sketch_file, "r", encoding="utf-8") as f: + plan_text = f.read().strip() + + parsed = utils.parse_model_output_into_option_plan( + plan_text, objects, types, options, parse_continuous_params=False) + assert parsed, f"Parsed empty plan sketch from {sketch_file}" + + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + option_names = {o.name for o in options} + subgoal_re = re.compile(r'->\s*\{([^}]*)\}') + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + subgoals: List[Optional[Tuple[Set[GroundAtom], Set[GroundAtom]]]] = [] + for line in plan_text.split('\n'): + stripped = line.strip() + if not stripped: + continue + first_token = stripped.split('(')[0] + if first_token not in option_names: + continue + sg_match = subgoal_re.search(stripped) + if not sg_match: + subgoals.append(None) + continue + atoms_text = sg_match.group(1) + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + for atom_match in atom_re.finditer(atoms_text): + is_neg = atom_match.group(1) is not None + pred_name = atom_match.group(2) + obj_names = [ + n.strip().split(':')[0] for n in atom_match.group(3).split(',') + ] + if pred_name not in pred_map: + continue + pred = pred_map[pred_name] + try: + objs: Sequence[Object] = [obj_map[n] for n in obj_names] + except KeyError: + continue + if len(objs) != len(pred.types): + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + if pos_atoms or neg_atoms: + subgoals.append((pos_atoms, neg_atoms)) + else: + subgoals.append(None) + + sketch = [] + for i, (option, objs, _) in enumerate(parsed): + sg = subgoals[i] if i < len(subgoals) else None + if sg is not None: + pos, neg = sg + sketch.append( + _SketchStep(option=option, + objects=objs, + subgoal_atoms=pos if pos else None, + subgoal_neg_atoms=neg if neg else None)) + else: + sketch.append( + _SketchStep(option=option, objects=objs, subgoal_atoms=None)) + return sketch + + +def _informed_place_params(pre_state, sketch, step_idx, rng, n): + """Sample Place params biased toward the contextual target.""" + step = sketch[step_idx] + low = step.option.params_space.low + high = step.option.params_space.high + eps = 1e-4 + + next_step = sketch[step_idx + 1] if step_idx + 1 < n else None + + if next_step and "Faucet" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "faucet": + fx = pre_state.get(obj, "x") + fy = pre_state.get(obj, "y") + frot = pre_state.get(obj, "rot") + out_x = fx + 0.15 * np.cos(frot) + out_y = fy - 0.15 * np.sin(frot) + x = np.clip(out_x + rng.normal(0, 0.02), low[0] + eps, + high[0] - eps) + y = np.clip(out_y - 0.05 + rng.normal(0, 0.03), low[1] + eps, + high[1] - eps) + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + yaw = np.clip(rng.normal(-0.3, 0.5), low[3] + eps, + high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + if next_step and "Burner" in next_step.option.name: + for obj in pre_state: + if obj.type.name == "burner": + bx = pre_state.get(obj, "x") + by = pre_state.get(obj, "y") + x = np.clip(bx + rng.normal(0, 0.05), low[0] + eps, + high[0] - eps) + y = np.clip(by + rng.normal(0, 0.05), low[1] + eps, + high[1] - eps) + z = np.clip(low[2] + 0.02 + abs(rng.normal(0, 0.01)), + low[2] + eps, high[2] - eps) + yaw = rng.uniform(low[3] + eps, high[3] - eps) + return np.array([x, y, z, yaw], dtype=np.float32) + + return rng.uniform(low + eps, high - eps).astype(np.float32) + + +def _generate_oracle_transitions( + env, + task, + options, + oracle, +) -> List[Tuple[State, Action, State]]: + """Generate (s, a, s') triples by running the oracle on the boil task. + + Parses the plan sketch, runs backtracking refinement to find + continuous parameters, then replays the plan through the oracle + model to collect step-level transitions with real actions. + """ + predicates = env.predicates + sketch = _parse_sketch_from_file(SKETCH_FILE, options, env.types, + predicates, list(task.init)) + n = len(sketch) + rng = np.random.default_rng(0) + max_tries = [ + 500 if step.option.params_space.shape[0] > 0 else 1 for step in sketch + ] + + def sample_fn(idx, state, rng_): + step = sketch[idx] + if step.option.params_space.shape[0] == 0: + params = np.array([], dtype=np.float32) + elif step.option.name == "Place": + params = _informed_place_params(state, sketch, idx, rng_, n) + else: + low = step.option.params_space.low + high = step.option.params_space.high + params = rng_.uniform(low, high).astype(np.float32) + grounded = step.option.ground(step.objects, params) + if grounded.name == "Wait" and step.subgoal_atoms is not None: + grounded.memory["wait_target_atoms"] = step.subgoal_atoms + return grounded + + def validate_fn(idx, _pre, _opt, post_state, _n_acts): + step = sketch[idx] + if step.subgoal_atoms is not None: + current_atoms = utils.abstract(post_state, predicates) + if not step.subgoal_atoms.issubset(current_atoms): + return False, "subgoal missing" + if idx == n - 1 and not task.goal_holds(post_state): + return False, "goal not reached" + return True, "" + + # Collect trajectories during refinement (not replay, since + # PyBullet state reconstruction is imperfect). + step_trajectories: Dict[int, LowLevelTrajectory] = {} + + orig_validate = validate_fn + + def collecting_validate_fn(idx, pre, opt, post_state, n_acts): + ok, reason = orig_validate(idx, pre, opt, post_state, n_acts) + if ok and oracle.last_trajectory is not None: + step_trajectories[idx] = oracle.last_trajectory + return ok, reason + + _plan, success, _ = run_backtracking_refinement( + init_state=task.init, + option_model=oracle, + n_steps=n, + max_tries=max_tries, + sample_fn=sample_fn, + validate_fn=collecting_validate_fn, + rng=rng, + timeout=1200.0, + ) + assert success, "Need a successful plan to generate transitions" + + # Extract step-level transitions from collected trajectories. + transitions: List[Tuple[State, Action, State]] = [] + for idx in sorted(step_trajectories.keys()): + traj = step_trajectories[idx] + for i in range(len(traj.actions)): + transitions.append( + (traj.states[i], traj.actions[i], traj.states[i + 1])) + + logger.info("Collected %d step-level transitions from oracle.", + len(transitions)) + return transitions + + +def test_emcee_recovers_rate_params(): + """Fit perturbed rate params from oracle-generated data.""" + np.random.seed(42) + env, task, options = _setup_env() + oracle = _build_oracle_model(env) + transitions = _generate_oracle_transitions(env, task, options, oracle) + process_features = get_gt_process_features() + + logger.info("Generated %d oracle transitions.", len(transitions)) + + def simulator_fn(state, _action, params): + updates = {} + for rule in PROCESS_RULES: + updates = rule(state, updates, params) + return updates + + # Perturb rate params (50%), keep others at true. + param_specs = [] + for s in BOIL_PARAM_SPECS: + if s.name in ("water_fill_speed", "heating_speed", "happiness_speed"): + param_specs.append(ParamSpec(s.name, s.init_value * 0.5)) + else: + param_specs.append(s) + + result = fit_params( + simulator_fn=simulator_fn, + transitions=transitions, + param_specs=param_specs, + process_features=process_features, + num_walkers=32, + num_steps=500, + burn_in=200, + noise_sigma=0.05, + ) + + fitted = result.point_estimate + logger.info("Fitted params (posterior mean):") + for name, val in fitted.items(): + true_val = GT_PARAMS[name] + rel_err = abs(val - true_val) / max(true_val, 1e-8) + logger.info(" %s: fitted=%.4f, true=%.4f, rel_err=%.1f%%", name, val, + true_val, rel_err * 100) + + for name in ["water_fill_speed", "heating_speed", "happiness_speed"]: + true_val = GT_PARAMS[name] + fitted_val = fitted[name] + rel_err = abs(fitted_val - true_val) / true_val + assert rel_err < 0.3, ( + f"{name}: fitted={fitted_val:.4f}, true={true_val:.4f}, " + f"rel_err={rel_err:.1%}") + + logger.info("All rate parameter recovery checks passed.") diff --git a/tests/code_sim_learning/test_training.py b/tests/code_sim_learning/test_training.py new file mode 100644 index 000000000..4f294c3a3 --- /dev/null +++ b/tests/code_sim_learning/test_training.py @@ -0,0 +1,23 @@ +"""Tests for code sim-learning training utilities.""" + +import numpy as np + +from predicators import utils +from predicators.code_sim_learning.training import ParamSpec, fit_params + + +def test_fit_params_can_skip_training_with_cfg(): + """Test that CFG can disable parameter fitting.""" + utils.reset_config({"code_sim_learning_num_mcmc_steps": 0}) + param_specs = [ParamSpec("rate", 2.5), ParamSpec("threshold", 0.7)] + + result = fit_params( + simulator_fn=lambda _s, _a, _p: {}, + transitions=[], + param_specs=param_specs, + process_features={}, + ) + + assert result.point_estimate == {"rate": 2.5, "threshold": 0.7} + np.testing.assert_allclose(result.samples, np.array([[2.5, 0.7]])) + np.testing.assert_allclose(result.log_probs, np.array([0.0])) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 11d03ca22..fdf922884 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,5 +1,6 @@ """Test cases for dataset generation.""" import os +import pickle as pkl import shutil from contextlib import nullcontext as does_not_raise @@ -304,6 +305,11 @@ def _ensure_cover_demo_data_exists(): this data file existing (for truncation and extension). When pytest- split distributes parametrized cases across groups, the generating case may not run first, so we ensure it here. + + Earlier tests (test_demo_dataset's max_initial_demos / impossible- + goal blocks) write a partial dataset under this same filename, so a + bare ``os.path.exists`` check is not enough — we also have to verify + the file actually carries 7 trajectories before trusting it. """ saved_cfg = { "env": CFG.env, @@ -323,7 +329,12 @@ def _ensure_cover_demo_data_exists(): }) dataset_fname, _ = utils.create_dataset_filename_str( saving_ground_atoms=False) - if not os.path.exists(dataset_fname): + has_full_dataset = False + if os.path.exists(dataset_fname): + with open(dataset_fname, "rb") as f: + existing = pkl.load(f) + has_full_dataset = len(existing.trajectories) == 7 + if not has_full_dataset: env = CoverEnv() train_tasks = [t.task for t in env.get_train_tasks()] predicates, _ = utils.parse_config_excluded_predicates(env) @@ -385,6 +396,43 @@ def test_demo_dataset_loading(num_train_tasks, load_data, demonstrator, assert "Cannot load data" in str(e) +def test_ensure_cover_demo_data_regenerates_partial_file(): + """A partial cover demo file under the 7-task name must be regenerated. + + Earlier tests in test_demo_dataset can write a 3-trajectory dataset + under ``cover__demo__oracle__7__...`` (e.g. the max_initial_demos + block). When pytest-split lands a downstream test that depends on a + 7-trajectory file (test_demo_dataset_loading[10-True-oracle-...]) in + a different shard, that downstream test loads the truncated file and + the load+extend path produces the wrong total. Lock in the helper's + "validate count, not just existence" contract. + """ + # Compute the 7-task filename in the default data_dir, since the + # helper resets data_dir during its reset_config call. + utils.reset_config({ + "env": "cover", + "approach": "random_actions", + "offline_data_method": "demo", + "offline_data_planning_timeout": 500, + "option_learner": "no_learning", + "num_train_tasks": 7, + "load_data": False, + "demonstrator": "oracle", + }) + dataset_fname, _ = utils.create_dataset_filename_str( + saving_ground_atoms=False) + os.makedirs(os.path.dirname(dataset_fname) or ".", exist_ok=True) + # Stage a stale empty dataset under the 7-task filename to simulate + # the leftover from earlier tests' partial writes. + stub = Dataset([]) + with open(dataset_fname, "wb") as f: + pkl.dump(stub, f) + _ensure_cover_demo_data_exists() + with open(dataset_fname, "rb") as f: + regenerated = pkl.load(f) + assert len(regenerated.trajectories) == 7 + + def _ensure_blocks_demo_data_exists(): """Generate the 10-task blocks demo dataset if it doesn't exist. diff --git a/tests/envs/test_pybullet_blocks.py b/tests/envs/test_pybullet_blocks.py index 40922fed6..739334493 100644 --- a/tests/envs/test_pybullet_blocks.py +++ b/tests/envs/test_pybullet_blocks.py @@ -70,7 +70,7 @@ def set_state(self, state): simulator_state=joint_positions) self._current_observation = state_with_sim self._current_task = None - self._reset_state(state_with_sim) + self._set_state(state_with_sim) def get_state(self): """Expose get_state().""" @@ -405,6 +405,46 @@ def test_pybullet_blocks_putontable_corners(env): assert abs(state.get(block, "pose_y") - by) < 1e-2 +def test_robot_matches_state_atol_forces_reset_on_small_drift(env): + """A small joint drift (~5e-3) must NOT be treated as "already there". + + Locks in the _robot_matches_state atol regression: with the prior + 1e-2 tolerance, a caller-supplied initial_joint_positions hint was + silently accepted whenever the live joints were within 1e-2 of + initial, leaving the EE pose ~3e-3 off the requested state — past + the 1e-3 State.allclose threshold. The fast-path must agree with + State.allclose precision. + """ + robot = env.robot + block = Object("block0", env.block_type) + bx = (env.x_lb + env.x_ub) / 2 + by = (env.y_lb + env.y_ub) / 2 + bz = env.table_height + 0.5 * env.block_size + rx, ry, rz = env.robot_init_x, env.robot_init_y, env.robot_init_z + rf = env.open_fingers + init_state = State({ + robot: np.array([rx, ry, rz, rf]), + block: np.array([bx, by, bz, 0.0, 1.0, 0.0, 0.0]), + }) + # First, get the env into the requested init pose. + env.set_state(init_state) + initial_joints = list(env._pybullet_robot.initial_joint_positions) # pylint: disable=protected-access + # Nudge the live joints by ~5e-3 (within old 1e-2 atol, outside new + # 1e-3 atol) so the fast-path *would* incorrectly accept under the + # old tolerance. + drifted_joints = [j + 5e-3 for j in initial_joints] + env._pybullet_robot.set_joints(drifted_joints) # pylint: disable=protected-access + # State carries the original initial joints as a "should be here" hint. + hint_state = utils.PyBulletState(init_state.data, + simulator_state=initial_joints) + # The fast-path comparison must reject the drift. + assert not env._robot_matches_state(hint_state) # pylint: disable=protected-access + # And calling _set_state must actually move the robot back to the + # requested EE pose at State.allclose precision (atol=1e-3). + env._set_state(hint_state) # pylint: disable=protected-access + assert env.get_state().allclose(init_state) + + def test_pybullet_blocks_close_pick_place(env): """Test a tricky case where we attempt to pick and place immediately next to a pile of blocks. diff --git a/tests/envs/test_pybullet_cover.py b/tests/envs/test_pybullet_cover.py index fe012bd94..376b88d71 100644 --- a/tests/envs/test_pybullet_cover.py +++ b/tests/envs/test_pybullet_cover.py @@ -43,7 +43,7 @@ def set_state(self, state): simulator_state=joint_positions) self._current_observation = state_with_sim self._current_task = None - self._reset_state(state_with_sim) + self._set_state(state_with_sim) def get_state(self): """Expose get_state().""" diff --git a/tests/explorers/test_agent_bilevel_explorer.py b/tests/explorers/test_agent_bilevel_explorer.py new file mode 100644 index 000000000..0db0dc237 --- /dev/null +++ b/tests/explorers/test_agent_bilevel_explorer.py @@ -0,0 +1,331 @@ +"""Tests for AgentBilevelExplorer.""" +# pylint: disable=protected-access + +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest +from gym.spaces import Box + +from predicators import utils +from predicators.agent_sdk.tools import ToolContext +from predicators.explorers import create_explorer +from predicators.explorers.agent_bilevel_explorer import AgentBilevelExplorer +from predicators.explorers.base_explorer import BaseExplorer +from predicators.structs import Action, GroundAtom, Object, \ + ParameterizedOption, Predicate, State, Task, Type + +# --------------------------------------------------------------------------- +# Fixtures (parallel the bilevel approach tests) +# --------------------------------------------------------------------------- + +_block_type = Type("block", ["x", "y", "held"]) +_robot_type = Type("robot", ["x", "y"]) + +_block0 = Object("block0", _block_type) +_block1 = Object("block1", _block_type) +_robot = Object("robot0", _robot_type) + +_Holding = Predicate("Holding", [_block_type], + lambda s, o: s.get(o[0], "held") > 0.5) +_On = Predicate("On", [_block_type, _block_type], + lambda s, o: abs(s.get(o[0], "x") - s.get(o[1], "x")) < 0.1) +_HandEmpty = Predicate("HandEmpty", [_robot_type], lambda s, o: True) + +_ALL_PREDICATES = {_Holding, _On, _HandEmpty} +_ALL_TYPES = {_block_type, _robot_type} + + +def _noop_policy(_s, _m, _o, _p): + return Action(np.zeros(1, dtype=np.float32)) + + +def _always_true(_s, _m, _o, _p): + return True + + +def _always_false(_s, _m, _o, _p): + return False + + +_Pick = ParameterizedOption( + "Pick", + types=[_block_type], + params_space=Box(low=np.array([0.0], dtype=np.float32), + high=np.array([1.0], dtype=np.float32)), + policy=_noop_policy, + initiable=_always_true, + terminal=_always_false, +) + +_Place = ParameterizedOption( + "Place", + types=[_block_type, _block_type], + params_space=Box(low=np.array([0.0, 0.0], dtype=np.float32), + high=np.array([1.0, 1.0], dtype=np.float32)), + policy=_noop_policy, + initiable=_always_true, + terminal=_always_false, +) + +_Wait = ParameterizedOption( + "Wait", + types=[_robot_type], + params_space=Box(low=np.array([], dtype=np.float32), + high=np.array([], dtype=np.float32)), + policy=_noop_policy, + initiable=_always_true, + terminal=_always_false, +) + +_ALL_OPTIONS = {_Pick, _Place, _Wait} + + +def _make_state(overrides=None): + data = { + _block0: np.array([0.1, 0.2, 0.0], dtype=np.float32), + _block1: np.array([0.5, 0.6, 0.0], dtype=np.float32), + _robot: np.array([0.0, 0.0], dtype=np.float32), + } + if overrides: + for obj, vals in overrides.items(): + data[obj] = np.array(vals, dtype=np.float32) + return State(data) + + +def _make_task(): + state = _make_state() + goal = {GroundAtom(_On, [_block0, _block1])} + return Task(state, goal) + + +def _assistant_response(text: str): + return [{ + "type": "assistant", + "content": [{ + "type": "text", + "text": text + }], + }] + + +def _make_explorer(option_model, query_impl): + """Build an AgentBilevelExplorer with stubbed session + tool_context.""" + tool_context = ToolContext( + types=_ALL_TYPES, + predicates=_ALL_PREDICATES, + options=_ALL_OPTIONS, + train_tasks=[_make_task()], + option_model=option_model, + ) + agent_session = MagicMock() + agent_session.query = query_impl + agent_session.tool_names = None + explorer = AgentBilevelExplorer( + predicates=_ALL_PREDICATES, + options=_ALL_OPTIONS, + types=_ALL_TYPES, + action_space=Box(low=-1, high=1, shape=(1, )), + train_tasks=[_make_task()], + max_steps_before_termination=50, + tool_context=tool_context, + agent_session=agent_session, + ) + return explorer, tool_context + + +def _reset_config(**overrides): + base = { + "env": "cover", + "approach": "agent_bilevel", + "num_train_tasks": 1, + "num_test_tasks": 1, + "seed": 42, + "agent_bilevel_max_samples_per_step": 5, + "agent_bilevel_explorer_max_samples_per_step": 5, + "agent_bilevel_max_retries": 0, + "agent_bilevel_check_subgoals": True, + "agent_bilevel_log_state": False, + "agent_explorer_fallback_to_random": True, + "agent_sdk_max_trajectories_in_context": 5, + } + base.update(overrides) + utils.reset_config(base) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_factory_registration(): + """AgentBilevelExplorer is reachable through create_explorer.""" + _reset_config() + tool_context = ToolContext( + types=_ALL_TYPES, + predicates=_ALL_PREDICATES, + options=_ALL_OPTIONS, + train_tasks=[_make_task()], + option_model=MagicMock(), + ) + agent_session = MagicMock() + explorer = create_explorer( + "agent_bilevel", + _ALL_PREDICATES, + _ALL_OPTIONS, + _ALL_TYPES, + Box(low=-1, high=1, shape=(1, )), + [_make_task()], + tool_context=tool_context, + agent_session=agent_session, + ) + assert isinstance(explorer, BaseExplorer) + assert isinstance(explorer, AgentBilevelExplorer) + + +def test_happy_path_returns_policy_and_stashes_subgoals(): + """Canned sketch → refined plan → policy and stashed subgoals.""" + _reset_config() + + goal_state = _make_state({_block0: [0.5, 0.6, 0.0]}) + option_model = MagicMock() + option_model.get_next_state_and_num_actions.return_value = (goal_state, 3) + + plan_text = ("Pick(block0:block)\n" + "Place(block0:block, block1:block) -> " + "{On(block0:block, block1:block)}\n") + query = AsyncMock(return_value=_assistant_response(plan_text)) + + explorer, tool_context = _make_explorer(option_model, query) + policy, term_fn = explorer._get_exploration_strategy(0, timeout=5) + + assert callable(policy) + assert term_fn(_make_state()) is False + assert tool_context.last_sketch_subgoals is not None + assert len(tool_context.last_sketch_subgoals) == 2 + # Second step's positive subgoal should be {On(block0, block1)}. + pos2, _neg2 = tool_context.last_sketch_subgoals[1] + assert pos2 == {GroundAtom(_On, [_block0, _block1])} + assert tool_context.last_sketch_options == [ + ("Pick", ["block0"]), + ("Place", ["block0", "block1"]), + ] + assert query.await_count == 1 + + +def test_wait_memory_injection_on_refine(): + """Wait step with subgoal should have wait_target_atoms injected.""" + _reset_config() + + captured: list = [] + + def side_effect(_state, option): + captured.append(option) + return (_make_state({_block0: [0.5, 0.6, 0.0]}), 3) + + option_model = MagicMock() + option_model.get_next_state_and_num_actions.side_effect = side_effect + + plan_text = ("Wait(robot0:robot) -> {On(block0:block, block1:block)}\n") + query = AsyncMock(return_value=_assistant_response(plan_text)) + explorer, _ = _make_explorer(option_model, query) + + explorer._get_exploration_strategy(0, timeout=5) + assert captured, "option_model was not invoked" + wait_opt = captured[0] + assert wait_opt.name == "Wait" + assert "wait_target_atoms" in wait_opt.memory + assert wait_opt.memory["wait_target_atoms"] == { + GroundAtom(_On, [_block0, _block1]) + } + + +def test_plan_truncates_at_deepest_subgoal_failure_after_backtracking(): + """Regression: explorer returns the prefix up to (and including) the + deepest step whose subgoal backtracking couldn't satisfy. + + Reproduces the boil-task bug: the agent sketches ``Pick → Wait(Holding) + → Place`` and the mental model's Wait does NOT produce ``Holding``. + Backtracking runs normally — it retries Pick with different params + and re-runs Wait each time — but since the mental model simply can't + produce Holding under any params, Wait's subgoal keeps failing. + After exhaustion, the explorer returns ``[Pick, Wait]`` with the last + grounded attempts. Place is NEVER executed because refinement never + gets past Wait. + """ + _reset_config() + + # Mental model post-state: Holding(block0) NEVER holds (held=0). + no_holding_state = _make_state({_block0: [0.1, 0.2, 0.0]}) + option_model = MagicMock() + option_model.get_next_state_and_num_actions.return_value = ( + no_holding_state, 3) + + plan_text = ("Pick(block0:block)\n" + "Wait(robot0:robot) -> {Holding(block0:block)}\n" + "Place(block0:block, block1:block) -> " + "{On(block0:block, block1:block)}\n") + query = AsyncMock(return_value=_assistant_response(plan_text)) + explorer, tool_context = _make_explorer(option_model, query) + + policy, _ = explorer._get_exploration_strategy(0, timeout=5) + assert callable(policy) + + # All three sketch steps recorded in metadata — the SKETCH is the full + # agent output; the TRUNCATION only applies to the refined plan. + assert tool_context.last_sketch_options == [ + ("Pick", ["block0"]), + ("Wait", ["robot0"]), + ("Place", ["block0", "block1"]), + ] + + executed_names = [ + call.args[1].name + for call in option_model.get_next_state_and_num_actions.call_args_list + ] + # Pick and Wait were each executed at least once (backtracking likely + # retried Pick multiple times). + assert "Pick" in executed_names + assert "Wait" in executed_names + # Place must NEVER be executed in the mental model: backtracking never + # got past the Wait subgoal failure, so Place never reached sample_fn. + assert "Place" not in executed_names, ( + "Place must not be executed in the mental model — refinement " + f"should have stalled at Wait's unsatisfiable subgoal, got " + f"{executed_names}") + # Pick has params (5 max_samples_per_step in test config), Wait has none. + # Each backtracking cycle runs Pick + Wait once, so we expect roughly + # 2 * max_samples_per_step mental-model calls — confirm backtracking + # actually exercised the upstream retries (at least 2 Picks). + assert executed_names.count("Pick") >= 2, ( + "Backtracking should have retried Pick at least twice before " + f"giving up, got {executed_names}") + + +def test_fallback_when_query_fails_and_flag_on(): + """Agent raises → random options fallback when flag enabled.""" + _reset_config(agent_explorer_fallback_to_random=True) + + option_model = MagicMock() + + async def failing_query(_msg): + raise RuntimeError("boom") + + explorer, _ = _make_explorer(option_model, failing_query) + policy, term_fn = explorer._get_exploration_strategy(0, timeout=5) + assert callable(policy) + assert term_fn(_make_state()) is False + + +def test_fallback_disabled_raises(): + """Agent raises → RequestActPolicyFailure when fallback flag off.""" + _reset_config(agent_explorer_fallback_to_random=False) + + option_model = MagicMock() + + async def failing_query(_msg): + raise RuntimeError("boom") + + explorer, _ = _make_explorer(option_model, failing_query) + with pytest.raises(utils.RequestActPolicyFailure): + explorer._get_exploration_strategy(0, timeout=5) diff --git a/tests/explorers/test_glib_explorer.py b/tests/explorers/test_glib_explorer.py index 89a70d507..5c9af5376 100644 --- a/tests/explorers/test_glib_explorer.py +++ b/tests/explorers/test_glib_explorer.py @@ -11,18 +11,26 @@ @pytest.mark.parametrize("target_predicate", ["Covers", "Holding"]) def test_glib_explorer(target_predicate): """Tests for GLIBExplorer class.""" + # Bump glib_num_babbles so we reliably sample at least one goal + # containing the target predicate. Default 10 babbles from cover's + # 7-atom dynamic universe gives a ~3.5% chance of zero Holding + # samples, which surfaces as a flake when test ordering shifts the + # shared explorer-RNG counter (predicators/explorers/base_explorer.py:15). utils.reset_config({ "env": "cover", "explorer": "glib", "cover_initial_holding_prob": 0.0, + "glib_num_babbles": 100, }) env = CoverEnv() options = get_gt_options(env.get_name()) nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] - # For testing purposes, score everything except target predicate low. - score_fn = lambda atoms: target_predicate in str(atoms) + # Filter out non-target goals so the explorer never falls through to + # plan toward a different predicate when target goals fail. + score_fn = lambda atoms: 1.0 if target_predicate in str(atoms) \ + else -float("inf") explorer = create_explorer("glib", env.predicates, get_gt_options(env.get_name()), diff --git a/tests/pybullet_helpers/test_motion_planning.py b/tests/pybullet_helpers/test_motion_planning.py index f471ff83d..7eb04e37f 100644 --- a/tests/pybullet_helpers/test_motion_planning.py +++ b/tests/pybullet_helpers/test_motion_planning.py @@ -6,12 +6,12 @@ import pybullet as p from predicators import utils -from predicators.envs.pybullet_env import create_pybullet_block from predicators.pybullet_helpers.camera import create_gui_connection from predicators.pybullet_helpers.geometry import Pose from predicators.pybullet_helpers.joint import JointPositions from predicators.pybullet_helpers.link import get_link_state from predicators.pybullet_helpers.motion_planning import run_motion_planning +from predicators.pybullet_helpers.objects import create_pybullet_block from predicators.pybullet_helpers.robots import \ create_single_arm_pybullet_robot diff --git a/tests/pybullet_helpers/test_objects.py b/tests/pybullet_helpers/test_objects.py new file mode 100644 index 000000000..fd743c3d1 --- /dev/null +++ b/tests/pybullet_helpers/test_objects.py @@ -0,0 +1,112 @@ +"""Unit tests for predicators.pybullet_helpers.objects.""" +import numpy as np +import pytest + +from predicators.pybullet_helpers.objects import \ + sample_collision_free_2d_positions +from predicators.utils import Circle, Rectangle + + +def test_sample_collision_free_2d_positions_circles_no_overlap(): + """Sampled circles never overlap with each other.""" + rng = np.random.default_rng(0) + radius = 0.05 + positions = sample_collision_free_2d_positions( + num_samples=8, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[radius], + rng=rng, + ) + assert len(positions) == 8 + circles = [Circle(x, y, radius) for x, y in positions] + for i, c1 in enumerate(circles): + for c2 in circles[i + 1:]: + assert not c1.intersects(c2) + + +def test_sample_collision_free_2d_positions_within_bounds(): + """Sampled positions stay inside the requested x/y range.""" + rng = np.random.default_rng(0) + positions = sample_collision_free_2d_positions( + num_samples=5, + x_range=(-0.5, 0.5), + y_range=(2.0, 3.0), + shape_type="circle", + shape_params=[0.05], + rng=rng, + ) + for x, y in positions: + assert -0.5 <= x <= 0.5 + assert 2.0 <= y <= 3.0 + + +def test_sample_collision_free_2d_positions_rectangles_no_overlap(): + """Sampled rectangles never overlap with each other.""" + rng = np.random.default_rng(1) + w, h, theta = 0.05, 0.05, 0.0 + positions = sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="rectangle", + shape_params=[w, h, theta], + rng=rng, + ) + assert len(positions) == 4 + rects = [Rectangle(x, y, w, h, theta) for x, y in positions] + for i, r1 in enumerate(rects): + for r2 in rects[i + 1:]: + assert not r1.intersects(r2) + + +def test_sample_collision_free_2d_positions_reproducible(): + """Same seed produces the same positions.""" + pos_a = sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[0.05], + rng=np.random.default_rng(123), + ) + pos_b = sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[0.05], + rng=np.random.default_rng(123), + ) + assert pos_a == pos_b + + +def test_sample_collision_free_2d_positions_impossible_raises(): + """Asking for more shapes than fit raises RuntimeError.""" + # 4 disks of radius 0.5 cannot fit non-overlapping in [0,1]^2. + rng = np.random.default_rng(0) + with pytest.raises(RuntimeError, match="Max tries exceeded"): + sample_collision_free_2d_positions( + num_samples=4, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="circle", + shape_params=[0.5], + rng=rng, + max_tries_total=200, + ) + + +def test_sample_collision_free_2d_positions_invalid_shape_raises(): + """An unknown shape_type raises ValueError.""" + rng = np.random.default_rng(0) + with pytest.raises(ValueError, match="Unsupported shape_type"): + sample_collision_free_2d_positions( + num_samples=1, + x_range=(0.0, 1.0), + y_range=(0.0, 1.0), + shape_type="triangle", + shape_params=[0.05], + rng=rng, + ) diff --git a/tests/test_skill_factories_integration.py b/tests/test_skill_factories_integration.py index 40a685fec..54f56cde9 100644 --- a/tests/test_skill_factories_integration.py +++ b/tests/test_skill_factories_integration.py @@ -78,7 +78,7 @@ def set_state(self, state: Any) -> None: simulator_state=joint_positions) self._current_observation = state_with_sim self._current_task = None - self._reset_state(state_with_sim) # type: ignore[attr-defined] + self._set_state(state_with_sim) # type: ignore[attr-defined] def get_state(self) -> Any: """Get state."""