From 5d60c7282ca6d004554a95dbd54ef59d2995d809 Mon Sep 17 00:00:00 2001 From: kailashr-nv Date: Wed, 5 Nov 2025 21:24:33 +0000 Subject: [PATCH] async writing for mimic datagen --- run_mimic_async.sh | 12 + .../isaaclab_mimic/generate_dataset_async.py | 184 ++++ .../isaaclab/managers/recorder_manager.py | 268 +++-- .../isaaclab/utils/datasets/episode_data.py | 7 +- .../isaaclab_mimic/async_writer.py | 960 ++++++++++++++++++ .../datagen/async_writer_recorder.py | 174 ++++ .../isaaclab_mimic/datagen/generation.py | 45 +- .../isaaclab_mimic/io_functions.py | 486 +++++++++ 8 files changed, 2018 insertions(+), 118 deletions(-) create mode 100755 run_mimic_async.sh create mode 100644 scripts/imitation_learning/isaaclab_mimic/generate_dataset_async.py create mode 100644 source/isaaclab_mimic/isaaclab_mimic/async_writer.py create mode 100644 source/isaaclab_mimic/isaaclab_mimic/datagen/async_writer_recorder.py create mode 100644 source/isaaclab_mimic/isaaclab_mimic/io_functions.py diff --git a/run_mimic_async.sh b/run_mimic_async.sh new file mode 100755 index 00000000000..d79f962db8d --- /dev/null +++ b/run_mimic_async.sh @@ -0,0 +1,12 @@ +#!/bin/bash +./isaaclab.sh -p scripts/imitation_learning/isaaclab_mimic/generate_dataset_async.py \ +--enable_pinocchio \ +--enable_cameras \ +--rendering_mode balanced \ +--task Isaac-NutPour-GR1T2-Pink-IK-Abs-Mimic-v0 \ +--generation_num_trials 100 \ +--num_envs 30 \ +--headless \ +--input_file ./datasets/annotated_dataset.hdf5 \ +--output_file ./datasets/async_generated_dataset_gr1_nut_pouring_new.hdf5 \ +--early_cpu_offload \ No newline at end of file diff --git a/scripts/imitation_learning/isaaclab_mimic/generate_dataset_async.py b/scripts/imitation_learning/isaaclab_mimic/generate_dataset_async.py new file mode 100644 index 00000000000..1779eab5d80 --- /dev/null +++ b/scripts/imitation_learning/isaaclab_mimic/generate_dataset_async.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Main data generation script. +""" + + +"""Launch Isaac Sim Simulator first.""" + +import argparse + +from isaaclab.app import AppLauncher +import time + +# add argparse arguments +parser = argparse.ArgumentParser(description="Generate demonstrations for Isaac Lab environments.") +parser.add_argument("--task", type=str, default=None, help="Name of the task.") +parser.add_argument("--generation_num_trials", type=int, help="Number of demos to be generated.", default=None) +parser.add_argument( + "--num_envs", type=int, default=1, help="Number of environments to instantiate for generating datasets." 
+) +parser.add_argument("--input_file", type=str, default=None, required=True, help="File path to the source dataset file.") +parser.add_argument( + "--output_file", + type=str, + default="./datasets/output_dataset.hdf5", + help="File path to export recorded and generated episodes.", +) +parser.add_argument( + "--pause_subtask", + action="store_true", + help="pause after every subtask during generation for debugging - only useful with render flag", +) +parser.add_argument( + "--enable_pinocchio", + action="store_true", + default=False, + help="Enable Pinocchio.", +) +parser.add_argument( + "--early_cpu_offload", + action="store_true", + default=False, + help="Enable early cpu offload.", +) +# append AppLauncher cli args +AppLauncher.add_app_launcher_args(parser) +# parse the arguments +args_cli = parser.parse_args() + +if args_cli.enable_pinocchio: + # Import pinocchio before AppLauncher to force the use of the version installed by IsaacLab and not the one installed by Isaac Sim + # pinocchio is required by the Pink IK controllers and the GR1T2 retargeter + import pinocchio # noqa: F401 + +# launch the simulator +app_launcher = AppLauncher(args_cli) +simulation_app = app_launcher.app + +"""Rest everything follows.""" + +import asyncio +import gymnasium as gym +import inspect +import numpy as np +import random +import torch + +import omni + +from isaaclab.envs import ManagerBasedRLMimicEnv + +import isaaclab_mimic.envs # noqa: F401 + +if args_cli.enable_pinocchio: + import isaaclab_mimic.envs.pinocchio_envs # noqa: F401 +from isaaclab_mimic.datagen.generation import env_loop, setup_async_generation, setup_env_config +from isaaclab_mimic.datagen.utils import get_env_name_from_dataset, setup_output_paths + +import isaaclab_tasks # noqa: F401 + + +def main(): + num_envs = args_cli.num_envs + + # Setup output paths and get env name + output_dir, output_file_name = setup_output_paths(args_cli.output_file) + task_name = args_cli.task + if task_name: + task_name = args_cli.task.split(":")[-1] + env_name = task_name or get_env_name_from_dataset(args_cli.input_file) + + # Configure environment + env_cfg, success_term = setup_env_config( + env_name=env_name, + output_dir=output_dir, + output_file_name=output_file_name, + num_envs=num_envs, + device=args_cli.device, + generation_num_trials=args_cli.generation_num_trials, + use_async_writer = True, + early_cpu_offload=args_cli.early_cpu_offload, + ) + + # create environment + env = gym.make(env_name, cfg=env_cfg).unwrapped + + if not isinstance(env, ManagerBasedRLMimicEnv): + raise ValueError("The environment should be derived from ManagerBasedRLMimicEnv") + + # check if the mimic API from this environment contains decprecated signatures + if "action_noise_dict" not in inspect.signature(env.target_eef_pose_to_action).parameters: + omni.log.warn( + f'The "noise" parameter in the "{env_name}" environment\'s mimic API "target_eef_pose_to_action", ' + "is deprecated. Please update the API to take action_noise_dict instead." 
+ ) + + # set seed for generation + random.seed(env.cfg.datagen_config.seed) + np.random.seed(env.cfg.datagen_config.seed) + torch.manual_seed(env.cfg.datagen_config.seed) + + # reset before starting + env.reset() + + # Setup and run async data generation + async_components = setup_async_generation( + env=env, + num_envs=args_cli.num_envs, + input_file=args_cli.input_file, + success_term=success_term, + pause_subtask=args_cli.pause_subtask, + ) + + try: + data_gen_tasks = asyncio.ensure_future(asyncio.gather(*async_components["tasks"])) + start = time.time() + env_loop( + env, + async_components["reset_queue"], + async_components["action_queue"], + async_components["info_pool"], + async_components["event_loop"], + ) + end = time.time() + + print(f"total elapsed for env loop for {num_envs} envs and {args_cli.generation_num_trials} trials: {end - start}") + except asyncio.CancelledError: + print("Tasks were cancelled.") + finally: + # Cancel all async tasks when env_loop finishes + + + data_gen_tasks.cancel() + + + try: + # Wait for tasks to be cancelled + async_components["event_loop"].run_until_complete(data_gen_tasks) + except asyncio.CancelledError: + print("Remaining async tasks cancelled and cleaned up.") + except Exception as e: + print(f"Error cancelling remaining async tasks: {e}") + + # finish async writes before cancelling remaining tasks + + # hacky way to get the async writer term + async_term = next((t for t in env.recorder_manager._async_export_terms if hasattr(t, "flush_async")), None) + try: + async_components["event_loop"].run_until_complete(async_term.flush_async()) + except Exception as e: + print(f"Error flushing async writer: {e}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nProgram interrupted by user. Exiting...") + # close sim app + simulation_app.close() diff --git a/source/isaaclab/isaaclab/managers/recorder_manager.py b/source/isaaclab/isaaclab/managers/recorder_manager.py index 48f66598c28..1bc05058dfb 100644 --- a/source/isaaclab/isaaclab/managers/recorder_manager.py +++ b/source/isaaclab/isaaclab/managers/recorder_manager.py @@ -12,6 +12,11 @@ from collections.abc import Sequence from prettytable import PrettyTable from typing import TYPE_CHECKING +import asyncio +import threading +import queue + +import carb.profiler from isaaclab.utils import configclass from isaaclab.utils.datasets import EpisodeData, HDF5DatasetFileHandler @@ -122,7 +127,7 @@ def record_post_step(self) -> tuple[str | None, torch.Tensor | dict | None]: Please refer to the `record_pre_reset` function for more details. """ return None, None - + def record_post_physics_decimation_step(self) -> tuple[str | None, torch.Tensor | dict | None]: """Record data after the physics step is executed in the decimation loop. 
@@ -152,6 +157,8 @@ def __init__(self, cfg: object, env: ManagerBasedEnv): super().__init__(cfg, env) + print(self._terms) + # Do nothing if no active recorder terms are provided if len(self.active_terms) == 0: return @@ -159,10 +166,13 @@ def __init__(self, cfg: object, env: ManagerBasedEnv): if not isinstance(cfg, RecorderManagerBaseCfg): raise TypeError("Configuration for the recorder manager is not of type RecorderManagerBaseCfg.") + # early offload flag from env cfg (CPU storage in EpisodeData) + self.early_offload = bool(getattr(env.cfg, "early_cpu_offload", False)) + # create episode data buffer indexed by environment id self._episodes: dict[int, EpisodeData] = dict() for env_id in range(env.num_envs): - self._episodes[env_id] = EpisodeData() + self._episodes[env_id] = EpisodeData(early_cpu_offload=self.early_offload) env_name = getattr(env.cfg, "env_name", None) @@ -183,6 +193,23 @@ def __init__(self, cfg: object, env: ManagerBasedEnv): self._exported_successful_episode_count = {} self._exported_failed_episode_count = {} + # Async writing setup + self._async_export_terms = [t for t in self._terms.values() if hasattr(t, "schedule_async_write_for_episode")] + self.use_async_writing = len(self._async_export_terms) > 0 + self._async_loop = None + self._async_loop_thread = None + + # + if self.use_async_writing: + self._async_loop = asyncio.new_event_loop() + self._async_loop_thread = threading.Thread(target=self._async_loop.run_forever, daemon=True) + self._async_loop_thread.start() + # producer/consumer for non-blocking async writes + self._writer_queue: queue.Queue[tuple[int, EpisodeData] | None] = queue.Queue(maxsize=256) + self._writer_consumer_running: bool = True + self._writer_consumer_thread = threading.Thread(target=self._run_writer_consumer, daemon=True) + self._writer_consumer_thread.start() + def __str__(self) -> str: """Returns: A string representation for recorder manager.""" msg = f" contains {len(self._term_names)} active terms.\n" @@ -212,6 +239,24 @@ def __del__(self): if self._failed_episode_dataset_file_handler is not None: self._failed_episode_dataset_file_handler.close() + # shutdown background asyncio loop if started + if hasattr(self, "_async_loop") and self._async_loop is not None: + try: + # stop consumer + if hasattr(self, "_writer_consumer_running") and self._writer_consumer_running: + self._writer_consumer_running = False + try: + self._writer_queue.put_nowait(None) + except Exception: + pass + if hasattr(self, "_writer_consumer_thread") and self._writer_consumer_thread is not None: + self._writer_consumer_thread.join(timeout=1.0) + self._async_loop.call_soon_threadsafe(self._async_loop.stop) + if hasattr(self, "_async_loop_thread") and self._async_loop_thread is not None: + self._async_loop_thread.join(timeout=1.0) + except Exception: + pass + """ Properties. 
""" @@ -281,7 +326,7 @@ def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor] term.reset(env_ids=env_ids) for env_id in env_ids: - self._episodes[env_id] = EpisodeData() + self._episodes[env_id] = EpisodeData(early_cpu_offload=self.early_offload) # nothing to log here return {} @@ -327,8 +372,9 @@ def add_to_episodes(self, key: str, value: torch.Tensor | dict, env_ids: Sequenc for value_index, env_id in enumerate(env_ids): if env_id not in self._episodes: - self._episodes[env_id] = EpisodeData() + self._episodes[env_id] = EpisodeData(early_cpu_offload=self.early_offload) self._episodes[env_id].env_id = env_id + #print(env_id, key, value[value_index]) self._episodes[env_id].add(key, value[value_index]) def set_success_to_episodes(self, env_ids: Sequence[int] | None, success_values: torch.Tensor): @@ -367,18 +413,10 @@ def record_post_step(self) -> None: if len(self.active_terms) == 0: return - for term in self._terms.values(): - key, value = term.record_post_step() - self.add_to_episodes(key, value) - def record_post_physics_decimation_step(self) -> None: - """Trigger recorder terms for post-physics step functions in the decimation loop.""" - # Do nothing if no active recorder terms are provided - if len(self.active_terms) == 0: - return for term in self._terms.values(): - key, value = term.record_post_physics_decimation_step() + key, value = term.record_post_step() self.add_to_episodes(key, value) def record_pre_reset(self, env_ids: Sequence[int] | None, force_export_or_skip=None) -> None: @@ -425,99 +463,110 @@ def record_post_reset(self, env_ids: Sequence[int] | None) -> None: key, value = term.record_post_reset(env_ids) self.add_to_episodes(key, value, env_ids) - def get_ep_meta(self) -> dict: - """Get the episode metadata.""" - if not hasattr(self._env.cfg, "get_ep_meta"): - # Add basic episode metadata - ep_meta = dict() - ep_meta["sim_args"] = { - "dt": self._env.cfg.sim.dt, - "decimation": self._env.cfg.decimation, - "render_interval": self._env.cfg.sim.render_interval, - "num_envs": self._env.cfg.scene.num_envs, - } - return ep_meta - - # Add custom episode metadata if available - ep_meta = self._env.cfg.get_ep_meta() - return ep_meta - - def export_episodes(self, env_ids: Sequence[int] | None = None, demo_ids: Sequence[int] | None = None) -> None: - """Concludes and exports the episodes for the given environment ids. - - Args: - env_ids: The environment ids. Defaults to None, in which case - all environments are considered. - demo_ids: Custom identifiers for the exported episodes. - If provided, episodes will be named "demo_{demo_id}" in the dataset. - Should have the same length as env_ids if both are provided. - If None, uses the default sequential naming scheme. Defaults to None. 
- """ + def record_post_physics_decimation_step(self) -> None: + """Trigger recorder terms for post-physics step functions in the decimation loop.""" # Do nothing if no active recorder terms are provided if len(self.active_terms) == 0: return + for term in self._terms.values(): + key, value = term.record_post_physics_decimation_step() + self.add_to_episodes(key, value) + + @carb.profiler.profile + def export_episodes(self, env_ids: Sequence[int] | None = None) -> None: + #print("recorder manager exporting episodes") + + carb.profiler.begin(10, "export_episodes") + if len(self.active_terms) == 0: + return + if env_ids is None: env_ids = list(range(self._env.num_envs)) if isinstance(env_ids, torch.Tensor): env_ids = env_ids.tolist() - - # Handle demo_ids processing - if demo_ids is not None: - if isinstance(demo_ids, torch.Tensor): - demo_ids = demo_ids.tolist() - if len(demo_ids) != len(env_ids): - raise ValueError(f"Length of demo_ids ({len(demo_ids)}) must match length of env_ids ({len(env_ids)})") - # Check for duplicate demo_ids - if len(set(demo_ids)) != len(demo_ids): - duplicates = [x for i, x in enumerate(demo_ids) if demo_ids.index(x) != i] - raise ValueError(f"demo_ids must be unique. Found duplicates: {list(set(duplicates))}") - - # Export episode data through dataset exporter - need_to_flush = False - - if any(env_id in self._episodes and not self._episodes[env_id].is_empty() for env_id in env_ids): - ep_meta = self.get_ep_meta() - if self._dataset_file_handler is not None: - self._dataset_file_handler.add_env_args(ep_meta) - if self._failed_episode_dataset_file_handler is not None: - self._failed_episode_dataset_file_handler.add_env_args(ep_meta) - - for i, env_id in enumerate(env_ids): - if env_id in self._episodes and not self._episodes[env_id].is_empty(): - self._episodes[env_id].pre_export() - - episode_succeeded = self._episodes[env_id].success - target_dataset_file_handler = None - if (self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_ALL) or ( - self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_SUCCEEDED_ONLY and episode_succeeded - ): - target_dataset_file_handler = self._dataset_file_handler - elif self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_SUCCEEDED_FAILED_IN_SEPARATE_FILES: + if self.use_async_writing: + # enqueue episode snapshots and return immediately + for env_id in env_ids: + if env_id in self._episodes and not self._episodes[env_id].is_empty(): + episode = self._episodes[env_id] + + episode_succeeded = episode.success if episode_succeeded: + print("adding successful episode to async writing queue") + self._exported_successful_episode_count[env_id] = ( + self._exported_successful_episode_count.get(env_id, 0) + 1 + ) + # snapshot and enqueue + #snapshot = self._snapshot_episode_for_async(episode, env_id) + self._writer_queue.put_nowait((env_id, episode)) + + + else: + self._exported_failed_episode_count[env_id] = ( + self._exported_failed_episode_count.get(env_id, 0) + 1 + ) + self._episodes[env_id] = EpisodeData(early_cpu_offload=self.early_offload) + return + + + else: + # non async + # Export episode data through dataset exporter + need_to_flush = False + for env_id in env_ids: + if env_id in self._episodes and not self._episodes[env_id].is_empty(): + episode_succeeded = self._episodes[env_id].success + target_dataset_file_handler = None + if (self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_ALL) or ( + self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_SUCCEEDED_ONLY and episode_succeeded + ): 
target_dataset_file_handler = self._dataset_file_handler + elif self.cfg.dataset_export_mode == DatasetExportMode.EXPORT_SUCCEEDED_FAILED_IN_SEPARATE_FILES: + if episode_succeeded: + target_dataset_file_handler = self._dataset_file_handler + else: + target_dataset_file_handler = self._failed_episode_dataset_file_handler + if target_dataset_file_handler is not None: + print(f"writing episode to {target_dataset_file_handler.filename}") + target_dataset_file_handler.write_episode(self._episodes[env_id]) + need_to_flush = True + # Update episode count + if episode_succeeded: + self._exported_successful_episode_count[env_id] = ( + self._exported_successful_episode_count.get(env_id, 0) + 1 + ) else: - target_dataset_file_handler = self._failed_episode_dataset_file_handler - if target_dataset_file_handler is not None: - # Use corresponding demo_id if provided, otherwise None - current_demo_id = demo_ids[i] if demo_ids is not None else None - target_dataset_file_handler.write_episode(self._episodes[env_id], current_demo_id) - need_to_flush = True - # Update episode count - if episode_succeeded: - self._exported_successful_episode_count[env_id] = ( - self._exported_successful_episode_count.get(env_id, 0) + 1 - ) - else: - self._exported_failed_episode_count[env_id] = self._exported_failed_episode_count.get(env_id, 0) + 1 - # Reset the episode buffer for the given environment after export - self._episodes[env_id] = EpisodeData() - - if need_to_flush: - if self._dataset_file_handler is not None: - self._dataset_file_handler.flush() - if self._failed_episode_dataset_file_handler is not None: - self._failed_episode_dataset_file_handler.flush() + self._exported_failed_episode_count[env_id] = self._exported_failed_episode_count.get(env_id, 0) + 1 + # Reset the episode buffer for the given environment after export + early_offload = bool(getattr(self._env.cfg, "early_cpu_offload", False)) + self._episodes[env_id] = EpisodeData(early_cpu_offload=early_offload) + + if need_to_flush: + if self._dataset_file_handler is not None: + self._dataset_file_handler.flush() + if self._failed_episode_dataset_file_handler is not None: + self._failed_episode_dataset_file_handler.flush() + + carb.profiler.end(10) + + + + + async def async_export_episodes(self, env_ids: Sequence[int] | None = None) -> None: + """Deprecated. Use export_episodes which enqueues to background consumer.""" + if env_ids is None: + env_ids = list(range(self._env.num_envs)) + for env_id in env_ids: + if env_id in self._episodes and not self._episodes[env_id].is_empty(): + snapshot = self._snapshot_episode_for_async(self._episodes[env_id], env_id) + self._writer_queue.put((env_id, snapshot)) + self._episodes[env_id] = EpisodeData(early_cpu_offload=self.early_offload) + return + + + + """ Helper functions. 
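The async export path above enqueues finished episodes instead of writing them inline; the hunk below adds the matching consumer (_run_writer_consumer) that drains the queue on a daemon thread and performs the blocking writes, using a None sentinel for shutdown. A minimal, self-contained sketch of that producer/consumer hand-off (names and payloads are illustrative only, not part of the patch):

import queue
import threading


def consumer(q: "queue.Queue[tuple[int, dict] | None]") -> None:
    while True:
        item = q.get()
        if item is None:  # shutdown sentinel
            break
        env_id, episode = item
        # stand-in for the synchronous HDF5 write performed by the recorder term
        print(f"writing episode from env {env_id} with {len(episode)} keys")


def main() -> None:
    q: "queue.Queue[tuple[int, dict] | None]" = queue.Queue(maxsize=256)
    worker = threading.Thread(target=consumer, args=(q,), daemon=True)
    worker.start()

    # producer side: hand off a snapshot and return immediately
    for env_id in range(3):
        q.put_nowait((env_id, {"actions": [0.0], "obs": [0.0]}))

    q.put_nowait(None)  # signal shutdown
    worker.join(timeout=1.0)


if __name__ == "__main__":
    main()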
@@ -557,3 +606,32 @@ def _prepare_terms(self): # add term name and parameters self._term_names.append(term_name) self._terms[term_name] = term + + # background consumer thread for async writes + # receive (env_id, episode) from writer queue of export_episodes + def _run_writer_consumer(self): + """Background consumer thread: synchronous writes in a single thread.""" + while getattr(self, "_writer_consumer_running", False): + item = self._writer_queue.get() + if item is None: + break + _env_id, episode = item + for term in self._async_export_terms: + term.schedule_sync_write_for_episode(episode) + + + + def _snapshot_episode_for_async(self, episode: EpisodeData, env_id: int) -> EpisodeData: + """Create a CPU snapshot of episode data to avoid races in async path.""" + def _clone_tree(node): + if isinstance(node, torch.Tensor): + return node.detach().to("cpu").clone() + if isinstance(node, dict): + return {k: _clone_tree(v) for k, v in node.items()} + return node + snap = EpisodeData(early_cpu_offload=True) + snap.data = _clone_tree(episode.data) + snap.seed = episode.seed + snap.success = episode.success + snap.env_id = env_id + return snap diff --git a/source/isaaclab/isaaclab/utils/datasets/episode_data.py b/source/isaaclab/isaaclab/utils/datasets/episode_data.py index 31971b6181c..91a5b4a3c55 100644 --- a/source/isaaclab/isaaclab/utils/datasets/episode_data.py +++ b/source/isaaclab/isaaclab/utils/datasets/episode_data.py @@ -16,7 +16,7 @@ class EpisodeData: """Class to store episode data.""" - def __init__(self) -> None: + def __init__(self, early_cpu_offload: bool = False) -> None: """Initializes episode data class.""" self._data = dict() self._next_action_index = 0 @@ -25,6 +25,7 @@ def __init__(self) -> None: self._seed = None self._env_id = None self._success = None + self._early_cpu_offload = early_cpu_offload @property def data(self): @@ -106,6 +107,10 @@ def add(self, key: str, value: torch.Tensor | dict): self.add(f"{key}/{sub_key}", sub_value) return + # optionally offload tensors to CPU immediately + if isinstance(value, torch.Tensor) and self._early_cpu_offload: + value = value.detach().to("cpu") + sub_keys = key.split("/") current_dataset_pointer = self._data for sub_key_index in range(len(sub_keys)): diff --git a/source/isaaclab_mimic/isaaclab_mimic/async_writer.py b/source/isaaclab_mimic/isaaclab_mimic/async_writer.py new file mode 100644 index 00000000000..1c98b4c7295 --- /dev/null +++ b/source/isaaclab_mimic/isaaclab_mimic/async_writer.py @@ -0,0 +1,960 @@ +import os +from typing import List, Dict +import json +import threading + +import carb +import numpy as np +import h5py +import pandas as pd +import torch +from collections import defaultdict + +import omni.kit +import omni.usd +import omni.replicator.core as rep +import asyncio + +from omni.syntheticdata.scripts.SyntheticData import SyntheticData + +from omni.replicator.core import functional as F +from omni.replicator.core import AnnotatorRegistry +from omni.replicator.core import BackendDispatch +from omni.replicator.core.backends import BackendGroup, BaseBackend +from omni.replicator.core.utils import skeleton_data_utils +from omni.replicator.core import Writer +from omni.replicator.core.writers_default.tools import colorize_distance, colorize_normals + +from .io_functions import write_dataframe_hdf5 + + + + +# Helpers for writing hdf5 file from rl_games dataframe + + +def parse_column_structure(df_columns): + """ + Parse DataFrame column names to determine the nested group structure of the hdf5 file + + return: + - 
structure: dict mapping main groups to their subgroups and columns + """ + structure = defaultdict(list) + + for col in df_columns: + if '/' in col: + # column has subgroup structure: "main_group/subgroup/column" + + # ie. obs/right_eef_pos creates one level of nestin g + # obs/datagen_info/eef_pose/left creates 3 levels of nesting + + parts = col.split('/') + main_group = parts[0] + subgroup_path = '/'.join(parts[1:]) + structure[main_group].append(subgroup_path) + else: + # root column goes directly under "demo group" + structure['root'].append(col) + + return dict(structure) + + +# pretty print structure for debugging +def print_structure(d, indent=0): + for key, value in d.items(): + if isinstance(value, dict): + print(" " * indent + f"{key}/") + print_structure(value, indent + 1) + else: + print(" " * indent + f"{key}: {value.shape}") + +# take in the parsed column structure from above and create nested datasets in hdf5 file +# tree structure -- every group is either a dataset in its immediate group or a subgroup +def create_nested_datasets(demo_group, df, structure): + """ + Create nested datasets in HDF5 based on the parsed structure. + """ + for main_group, subgroups in structure.items(): + if main_group == 'root': + # root groups + for col_name in subgroups: + data_series = df[col_name] + if isinstance(data_series.iloc[0], torch.Tensor): + stacked_data = torch.stack(data_series.tolist()).numpy() + else: + stacked_data = np.stack(data_series.values) + demo_group.create_dataset(col_name, data=stacked_data) + else: + # subgroups + group_obj = demo_group.create_group(main_group) + + # Organize columns by their immediate subgroup + subgroup_dict = defaultdict(list) + for col_path in subgroups: + parts = col_path.split('/') + if len(parts) == 1: + # direct dataset + subgroup_dict['root'].append((parts[0], col_path)) + else: + # nested subgroup + subgroup_dict[parts[0]].append(('/'.join(parts[1:]), col_path)) + + # Create datasets and nested subgroups + for immediate_subgroup, column_info in subgroup_dict.items(): + if immediate_subgroup == 'root': + # Create datasets directly in the main group + for dataset_name, col_path in column_info: + data_series = df[f"{main_group}/{col_path}"] + if isinstance(data_series.iloc[0], torch.Tensor): + stacked_data = torch.stack(data_series.tolist()).numpy() + else: + stacked_data = np.stack(data_series.values) + group_obj.create_dataset(dataset_name, data=stacked_data) + else: + + subgroup_obj = group_obj.create_group(immediate_subgroup) + for nested_path, col_path in column_info: + + + data_series = df[f"{main_group}/{col_path}"] + if isinstance(data_series.iloc[0], torch.Tensor): + stacked_data = torch.stack(data_series.tolist()).numpy() + else: + stacked_data = np.stack(data_series.values) + subgroup_obj.create_dataset(nested_path, data=stacked_data) + +def dataframe_to_nested_hdf5(df, hdf5_path, demo_name="demo_0", env_args: Dict | None = None): + """ + convert dataframe with nested column naming to HDF5 file with nested group structure + + @param + - df: DataFrame with columns like "obs/right_eef_pos", "actions", "states/articulation" etc. 
+ - hdf5_path: Path to save the HDF5 file + - demo_name: Name for this demo subgroup + """ + + # entry function for dataframe -> hdf5 demo conversion + structure = parse_column_structure(df.columns) + + + with h5py.File(hdf5_path, 'w') as f: + dataset_group = f.create_group('data') + if env_args is not None: + dataset_group.attrs['env_args'] = json.dumps(env_args) + demo_group = dataset_group.create_group(demo_name) + # also mirror env_args at file root for easier discovery + if env_args is not None: + f.attrs['env_args'] = json.dumps(env_args) + + create_nested_datasets(demo_group, df, structure) + + # todo: incorporate replicator backend + +def add_demo_with_nested_structure(df, hdf5_path, demo_name, env_args: Dict | None = None): + """ + Add a new demo with nested structure to an existing HDF5 file. + """ + + structure = parse_column_structure(df.columns) + + with h5py.File(hdf5_path, 'a') as f: + # prefer 'data' group, fall back to existing 'dataset' if present, else create 'data' + if 'data' in f: + dataset_group = f['data'] + elif 'dataset' in f: + dataset_group = f['dataset'] + else: + dataset_group = f.create_group('data') + if env_args is not None: + # set or update env_args on the chosen top-level group + dataset_group.attrs['env_args'] = json.dumps(env_args) + if demo_name in dataset_group: + #print(f"[WARNING] Demo {demo_name} already exists in the HDF5 file. Overwriting...") + raise ValueError(f"Demo {demo_name} already exists in the HDF5 file.") + del dataset_group[demo_name] + + + demo_group = dataset_group.create_group(demo_name) + create_nested_datasets(demo_group, df, structure) + +def read_nested_demo(hdf5_path, demo_name="demo_0"): + """ + Read a demo with nested structure from HDF5 file. + + Returns a nested dictionary of PyTorch tensors. + """ + def h5_to_dict(group): + """Recursively convert HDF5 group to dictionary.""" + data_dict = {} + for key, item in group.items(): + if isinstance(item, h5py.Group): + data_dict[key] = h5_to_dict(item) + else: + data_dict[key] = torch.from_numpy(item[:]) + return data_dict + + with h5py.File(hdf5_path, 'r') as f: + demo_group = f[f'data/{demo_name}'] + return h5_to_dict(demo_group) + +# Example usage: + + +def dataframe_to_hdf5(df, hdf5_path, obs_column='obs', actions_column='actions', + initial_state_column='initial_state', states_column='states'): + """ + Convert a DataFrame of trajectory data to a nested HDF5 file. + + Parameters: + - df: DataFrame containing trajectory data + - hdf5_path: Path to save the HDF5 file + - obs_column: Column name for observations + - actions_column: Column name for actions + - initial_state_column: Column name for initial states + - states_column: Column name for states + """ + + with h5py.File(hdf5_path, 'w') as f: + dataset_group = f.create_group('dataset') + + + demo_group = dataset_group.create_group(demo_name) + + # Create obs group and dataset from the series + obs_group = demo_group.create_group('obs') + obs_series = df[obs_column] + +__version__ = '0.0.2' +class AsyncWriter(Writer): + """async writer taken from basic writer implementation at https://gitlab-master.nvidia.com/omniverse/synthetic-data/omni.replicator/-/blob/develop/source/extensions/omni.replicator.core/python/scripts/writers_default/basicwriter.py?ref_type=heads + + + + Args: + output_dir: + Output directory string that indicates the directory to save the results. + s3_bucket: + The S3 Bucket name to write to. If not provided, disk backend will be used instead. Default: ``None``. 
+ This backend requires that AWS credentials are set up in ``~/.aws/credentials``. + See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration + s3_region: + If provided, this is the region the S3 bucket will be set to. Default: ``us-east-1`` + s3_endpoint: + If provided, this endpoint URL will be used instead of the default. + semantic_types: + List of semantic types to consider when filtering annotator data. Default: ``["class"]`` + rgb: + Boolean value that indicates whether the ``rgb``/``LdrColor`` annotator will be activated + and the data will be written or not. Default: ``False``. + bounding_box_2d_tight: + Boolean value that indicates whether the ``bounding_box_2d_tight`` annotator will be activated + and the data will be written or not. Default: ``False``. + bounding_box_2d_loose: + Boolean value that indicates whether the ``bounding_box_2d_loose`` annotator will be activated + and the data will be written or not. Default: ``False``. + semantic_segmentation: + Boolean value that indicates whether the ``semantic_segmentation`` annotator will be activated + and the data will be written or not. Default: ``False``. + instance_id_segmentation: + Boolean value that indicates whether the ``instance_id_segmentation`` annotator will be activated + and the data will be written or not. Default: ``False``. + instance_segmentation: + Boolean value that indicates whether the ``instance_segmentation`` annotator will be activated + and the data will be written or not. Default: ``False``. + distance_to_camera: + Boolean value that indicates whether the ``distance_to_camera`` annotator will be activated + and the data will be written or not. Default: ``False``. + distance_to_image_plane: + Boolean value that indicates whether the ``distance_to_image_plane`` annotator will be activated + and the data will be written or not. Default: ``False``. + bounding_box_3d: + Boolean value that indicates whether the ``bounding_box_3d`` annotator will be activated + and the data will be written or not. Default: ``False``. + occlusion: + Boolean value that indicates whether the ``occlusion`` annotator will be activated + and the data will be written or not. Default: ``False``. + normals: + Boolean value that indicates whether the ``normals`` annotator will be activated + and the data will be written or not. Default: ``False``. + motion_vectors: + Boolean value that indicates whether the ``motion_vectors`` annotator will be activated + and the data will be written or not. Default: ``False``. + camera_params: + Boolean value that indicates whether the ``camera_params`` annotator will be activated + and the data will be written or not. Default: ``False``. + pointcloud: + Boolean value that indicates whether the ``pointcloud`` annotator will be activated + and the data will be written or not. Default: ``False``. + pointcloud_include_unlabelled: + If ``True``, pointcloud annotator will capture any prim in the camera's perspective, not matter if it has + semantics or not. If ``False``, only prims with semantics will be captured. + Defaults to ``False``. + image_output_format: + String that indicates the format of saved RGB images. Default: ``"png"`` + colorize_semantic_segmentation: + If ``True``, semantic segmentation is converted to an image where semantic IDs are mapped to colors + and saved as a uint8 4 channel PNG image. If ``False``, the output is saved as a ``uint32`` PNG image. + Defaults to ``True``. 
+ colorize_instance_id_segmentation: + If ``True``, instance id segmentation is converted to an image where instance IDs are mapped to colors. + and saved as a uint8 4 channel PNG image. If ``False``, the output is saved as a ``uint32`` PNG image. + Defaults to ``True``. + colorize_instance_segmentation: + If ``True``, instance segmentation is converted to an image where instance are mapped to colors. + and saved as a uint8 4 channel PNG image. If ``False``, the output is saved as a ``uint32`` PNG image. + Defaults to ``True``. + colorize_depth: + If ``True``, will output an additional PNG image for depth for visualization + Defaults to ``False``. + frame_padding: + Pad the frame number with leading zeroes. Default: ``4`` + semantic_filter_predicate: + A string specifying a semantic filter predicate as a disjunctive normal form of semantic type, labels. + + Examples : + "typeA : labelA & !labelB | labelC , typeB: labelA ; typeC: labelD" + "typeA : * ; * : labelA" + use_common_output_dir: + If ``True``, output for each annotator coming from multiple render products are saved under a common directory + with the render product as the filename prefix (eg. __.). + If ``False``, multiple render product outputs are placed into their own directory + (eg. /_.). Setting is ignored if using the writer with + a single render product. Defaults to ``False``. + backend: Optionally pass a backend to use. If specified, `output_dir` and `s3_<>` arguments may be omitted. If + both are provided, the backends will be grouped. + + + Example: + >>> import omni.replicator.core as rep + >>> import carb + >>> camera = rep.create.camera() + >>> render_product = rep.create.render_product(camera, (1024, 1024)) + >>> writer = rep.WriterRegistry.get("BasicWriter") + >>> tmp_dir = carb.tokens.get_tokens_interface().resolve("${temp}/rgb") + >>> writer.initialize(output_dir=tmp_dir, rgb=True) + >>> writer.attach([render_product]) + >>> rep.orchestrator.run() + """ + + def __init__( + self, + output_dir: str = None, + s3_bucket: str = None, + s3_region: str = None, + s3_endpoint: str = None, + semantic_types: List[str] = None, + rgb: bool = False, + bounding_box_2d_tight: bool = False, + bounding_box_2d_loose: bool = False, + semantic_segmentation: bool = False, + instance_id_segmentation: bool = False, + instance_segmentation: bool = False, + distance_to_camera: bool = False, + distance_to_image_plane: bool = False, + bounding_box_3d: bool = False, + occlusion: bool = False, + normals: bool = False, + motion_vectors: bool = False, + camera_params: bool = False, + pointcloud: bool = False, + pointcloud_include_unlabelled: bool = False, + image_output_format: str = "png", + colorize_semantic_segmentation: bool = True, + colorize_instance_id_segmentation: bool = True, + colorize_instance_segmentation: bool = True, + colorize_depth: bool = False, + skeleton_data: bool = False, + frame_padding: int = 4, + semantic_filter_predicate: str = None, + use_common_output_dir: bool = False, + backend: BaseBackend = None, + ): + self._output_dir = output_dir + self.data_structure = "annotator" + self.use_common_output_dir = use_common_output_dir + self._backend = None + if s3_bucket: + self._backend = BackendDispatch( + key_prefix=output_dir, + bucket=s3_bucket, + region=s3_region, + endpoint_url=s3_endpoint, + ) + elif output_dir: + self._backend = BackendDispatch(output_dir=output_dir) + + if backend and self._backend: + self._backend = BackendGroup([backend, *self._backend._backends]) + elif backend: + self._backend = backend + + if 
not self._backend: + raise ValueError("No `backend`, `output_dir` or `s3_` parameter specified, unable to initialize writer.") + + self.backend = self._backend + self._frame_id = 0 + self._sequence_id = 0 + self._image_output_format = image_output_format + self._output_data_format = {} + self.annotators = [] + self.version = __version__ + self._frame_padding = frame_padding + + self.colorize_semantic_segmentation = colorize_semantic_segmentation + self.colorize_instance_id_segmentation = colorize_instance_id_segmentation + self.colorize_instance_segmentation = colorize_instance_segmentation + self.colorize_depth = colorize_depth + + self.num_demos_written = 0 + + # persistent HDF5 file handles keyed by absolute path + self._file_map: Dict[str, h5py.File] = {} + self._file_lock = threading.Lock() + # env args cache + if not hasattr(self, '_env_args'): + self._env_args = None + + is_default_semantic_filter = semantic_filter_predicate is None + # Specify the semantic types that will be included in output + if semantic_types is not None: + if semantic_filter_predicate is None: + semantic_filter_predicate = ":*; ".join(semantic_types) + ":*" + else: + raise ValueError( + "`semantic_types` and `semantic_filter_predicate` are mutually exclusive. Please choose only one." + ) + elif is_default_semantic_filter: + semantic_filter_predicate = "class:*" + + # Set the global semantic filter predicate + # FIXME: don't set the global semantic filter predicate after support of multiple instances of annotators + if semantic_filter_predicate is not None: + SyntheticData.Get().set_instance_mapping_semantic_filter(semantic_filter_predicate) + + # RGB + if rgb: + self.annotators.append(AnnotatorRegistry.get_annotator("rgb")) + + # Bounding Box 2D + if bounding_box_2d_tight: + if is_default_semantic_filter: + self.annotators.append("bounding_box_2d_tight_fast") + else: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "bounding_box_2d_tight_fast", + init_params={ + "semanticFilter": semantic_filter_predicate, + }, + ) + ) + + if bounding_box_2d_loose: + if is_default_semantic_filter: + self.annotators.append("bounding_box_2d_loose_fast") + else: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "bounding_box_2d_loose_fast", + init_params={ + "semanticFilter": semantic_filter_predicate, + }, + ) + ) + + # Semantic Segmentation + if semantic_segmentation: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "semantic_segmentation", + init_params={ + "colorize": colorize_semantic_segmentation, + "semanticFilter": semantic_filter_predicate, + }, + ) + ) + + # Instance Segmentation + if instance_id_segmentation: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "instance_id_segmentation_fast", init_params={"colorize": colorize_instance_id_segmentation} + ) + ) + + # Instance Segmentation + if instance_segmentation: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "instance_segmentation_fast", + init_params={ + "colorize": colorize_instance_segmentation, + "semanticFilter": semantic_filter_predicate, + }, + ) + ) + + # Depth + if distance_to_camera: + self.annotators.append(AnnotatorRegistry.get_annotator("distance_to_camera")) + + if distance_to_image_plane: + self.annotators.append(AnnotatorRegistry.get_annotator("distance_to_image_plane")) + + # Bounding Box 3D + if bounding_box_3d: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "bounding_box_3d_fast", + init_params={ + "semanticFilter": semantic_filter_predicate, + }, + ) + ) + + # 
Motion Vectors + if motion_vectors: + self.annotators.append(AnnotatorRegistry.get_annotator("motion_vectors")) + + # Occlusion + if occlusion: + self.annotators.append(AnnotatorRegistry.get_annotator("occlusion")) + + # Normals + if normals: + self.annotators.append(AnnotatorRegistry.get_annotator("normals")) + + # Camera Params + if camera_params: + self.annotators.append(AnnotatorRegistry.get_annotator("camera_params")) + + # Pointcloud + if pointcloud: + self.annotators.append( + AnnotatorRegistry.get_annotator( + "pointcloud", init_params={"includeUnlabelled": pointcloud_include_unlabelled} + ) + ) + + # Skeleton Data + if skeleton_data: + self.annotators.append( + AnnotatorRegistry.get_annotator("skeleton_data", init_params={"useSkelJoints": False}) + ) + + backend_type = "S3" if s3_bucket else "Disk" + + + def _write_trajectory_data_hdf5(self, data : pd.DataFrame, out_file : str, debug = False): + filepath = os.path.join(self._output_dir, out_file) + + def _get_or_create_file_handle(filepath: str) -> h5py.File: + """Get or create persistent file handle for the given filepath.""" + with self._file_lock: + f = self._file_map.get(filepath) + if f is None: + os.makedirs(os.path.dirname(filepath), exist_ok=True) + f = h5py.File(filepath, 'a') + self._file_map[filepath] = f + return f + + demo_name = f"demo_{self.num_demos_written}" + env_args = getattr(self, '_env_args', None) + + # Use backend-compatible write function with persistent file handle getter + write_dataframe_hdf5( + path=filepath, + data=data, + backend_instance=self._backend, + demo_name=demo_name, + env_args=env_args, + file_handle_getter=_get_or_create_file_handle, + ) + + self.num_demos_written += 1 + + + if debug: + loaded_data = read_nested_demo(filepath, demo_name=f"demo_{self.num_demos_written-1}") + print_structure(loaded_data) + + + + + # for now dont actually use this + def write(self, data : pd.DataFrame): + return self._write_trajectory_data_hdf5(data, f"trajectory_data.hdf5", debug=True) + + async def write_trajectory_data_async( + self, + data: pd.DataFrame, + out_file: str = "trajectory_data.hdf5", + debug: bool = False, + ): + + await asyncio.to_thread(self._write_trajectory_data_hdf5, data, out_file, debug) + + # --- Env args support (to mirror HDF5DatasetFileHandler) --- + def set_env_args(self, env_args: Dict): + if not hasattr(self, '_env_args') or self._env_args is None: + self._env_args = {} + self._env_args.update(env_args) + + def close(self): + # flush and close all open files + with self._file_lock: + for fp, f in list(self._file_map.items()): + try: + f.flush() + except Exception: + pass + try: + f.close() + except Exception: + pass + self._file_map.pop(fp, None) + + async def _write(self, data: dict): + """Write function called from the OgnWriter node on every frame to process annotator output. + + Args: + data: A dictionary containing the annotator data for the current frame. 
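+                Based on how this method consumes it, the dict is expected to contain
+                "trigger_outputs" (trigger name -> call count) and "annotators"
+                (annotator name -> render product name -> annotator output).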
+ """ + # Check for on_time triggers + # For each on_time trigger, prefix the output frame number with the trigger counts + sequence_id = "" + for trigger_name, call_count in data["trigger_outputs"].items(): + if "on_time" in trigger_name: + sequence_id = f"{call_count}_{sequence_id}" + if sequence_id != self._sequence_id: + self._frame_id = 0 + self._sequence_id = sequence_id + + for annotator_name, annotator_data in data["annotators"].items(): + # Shorten fast annotator names + if annotator_name.endswith("_fast"): + annotator_name = annotator_name[:-5] + + is_multi_rp = len(annotator_data) > 1 + for render_product_name, anno_rp_data in annotator_data.items(): + if is_multi_rp: + if self.use_common_output_dir: + output_path = ( + os.path.join(annotator_name, render_product_name) + "_" + ) # Add render product as prefix + else: + output_path = ( + os.path.join(render_product_name, annotator_name) + os.path.sep + ) # Legacy behaviour + else: + output_path = "" + + if annotator_name == "rgb" or annotator_name.startswith("Aug"): + self._write_rgb(anno_rp_data, output_path) + + elif annotator_name == "normals": + self._write_normals(anno_rp_data, output_path) + + elif annotator_name == "distance_to_camera": + self._write_distance_to_camera(anno_rp_data, output_path) + + elif annotator_name == "distance_to_image_plane": + self._write_distance_to_image_plane(anno_rp_data, output_path) + + elif annotator_name.startswith("semantic_segmentation"): + self._write_semantic_segmentation(anno_rp_data, output_path) + + elif annotator_name.startswith("instance_id_segmentation"): + self._write_instance_id_segmentation(anno_rp_data, output_path) + + elif annotator_name.startswith("instance_segmentation"): + self._write_instance_segmentation(anno_rp_data, output_path) + + elif annotator_name.startswith("motion_vectors"): + self._write_motion_vectors(anno_rp_data, output_path) + + elif annotator_name.startswith("occlusion"): + self._write_occlusion(anno_rp_data, output_path) + + elif annotator_name.startswith("bounding_box_3d"): + self._write_bounding_box_data(anno_rp_data, "3d", output_path) + + elif annotator_name.startswith("bounding_box_2d_loose"): + self._write_bounding_box_data(anno_rp_data, "2d_loose", output_path) + + elif annotator_name.startswith("bounding_box_2d_tight"): + self._write_bounding_box_data(anno_rp_data, "2d_tight", output_path) + + elif annotator_name.startswith("camera_params"): + self._write_camera_params(anno_rp_data, output_path) + + elif annotator_name.startswith("pointcloud"): + self._write_pointcloud(anno_rp_data, output_path) + + elif annotator_name.startswith("skeleton_data"): + self._write_skeleton(anno_rp_data, output_path) + + elif annotator_name not in ["camera", "resolution"]: + carb.log_warn(f"Unknown {annotator_name=}") + + self._frame_id += 1 + + def _write_rgb(self, anno_rp_data: dict, output_path: str): + file_path = ( + f"{output_path}rgb_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.{self._image_output_format}" + ) + self._backend.schedule(F.write_image, data=anno_rp_data["data"], path=file_path) + + def _write_normals(self, anno_rp_data: dict, output_path: str): + normals_data = anno_rp_data["data"] + file_path = f"{output_path}normals_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.png" + colorized_normals_data = colorize_normals(normals_data) + self._backend.schedule(F.write_image, data=colorized_normals_data, path=file_path) + + def _write_distance_to_camera(self, anno_rp_data: dict, output_path: str): + dist_to_cam_data = 
anno_rp_data["data"] + file_path = f"{output_path}distance_to_camera_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + self._backend.schedule(F.write_np, data=dist_to_cam_data, path=file_path) + if self.colorize_depth: + file_path = ( + f"{output_path}distance_to_camera_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.png" + ) + self._backend.schedule( + F.write_image, data=colorize_distance(dist_to_cam_data, near=None, far=None), path=file_path + ) + + + + + + def _write_distance_to_image_plane(self, anno_rp_data: dict, output_path: str): + dis_to_img_plane_data = anno_rp_data["data"] + file_path = ( + f"{output_path}distance_to_image_plane_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + ) + self._backend.schedule(F.write_np, data=dis_to_img_plane_data, path=file_path) + if self.colorize_depth: + file_path = ( + f"{output_path}distance_to_image_plane_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.png" + ) + self._backend.schedule( + F.write_image, data=colorize_distance(dis_to_img_plane_data, near=None, far=None), path=file_path + ) + + def _write_semantic_segmentation(self, anno_rp_data: dict, output_path: str): + semantic_seg_data = anno_rp_data["data"] + height, width = semantic_seg_data.shape[:2] + + file_path = f"{output_path}semantic_segmentation_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.png" + if self.colorize_semantic_segmentation: + semantic_seg_data = semantic_seg_data.view(np.uint8).reshape(height, width, -1) + self._backend.schedule(F.write_image, data=semantic_seg_data, path=file_path) + else: + semantic_seg_data = semantic_seg_data.view(np.uint32).reshape(height, width) + self._backend.schedule(F.write_image, data=semantic_seg_data, path=file_path) + + id_to_labels = anno_rp_data["idToLabels"] + file_path = ( + f"{output_path}semantic_segmentation_labels_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + ) + + self._backend.schedule(F.write_json, data={str(k): v for k, v in id_to_labels.items()}, path=file_path) + + def _write_instance_id_segmentation(self, anno_rp_data: dict, output_path: str): + instance_seg_data = anno_rp_data["data"] + height, width = instance_seg_data.shape[:2] + + file_path = ( + f"{output_path}instance_id_segmentation_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.png" + ) + if self.colorize_instance_id_segmentation: + instance_seg_data = instance_seg_data.view(np.uint8).reshape(height, width, -1) + self._backend.schedule(F.write_image, data=instance_seg_data, path=file_path) + else: + instance_seg_data = instance_seg_data.view(np.uint32).reshape(height, width) + self._backend.schedule(F.write_image, data=instance_seg_data, path=file_path) + + id_to_labels = anno_rp_data["idToLabels"] + file_path = f"{output_path}instance_id_segmentation_mapping_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + self._backend.schedule(F.write_json, data={str(k): v for k, v in id_to_labels.items()}, path=file_path) + + def _write_instance_segmentation(self, anno_rp_data: dict, output_path: str): + instance_seg_data = anno_rp_data["data"] + height, width = instance_seg_data.shape[:2] + + file_path = f"{output_path}instance_segmentation_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.png" + if self.colorize_instance_segmentation: + instance_seg_data = instance_seg_data.view(np.uint8).reshape(height, width, -1) + self._backend.schedule(F.write_image, data=instance_seg_data, path=file_path) + else: + instance_seg_data = 
instance_seg_data.view(np.uint32).reshape(height, width) + self._backend.schedule(F.write_image, data=instance_seg_data, path=file_path) + + id_to_labels = anno_rp_data["idToLabels"] + file_path = f"{output_path}instance_segmentation_mapping_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + self._backend.schedule(F.write_json, data={str(k): v for k, v in id_to_labels.items()}, path=file_path) + + id_to_semantics = anno_rp_data["idToSemantics"] + file_path = f"{output_path}instance_segmentation_semantics_mapping_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + self._backend.schedule(F.write_json, data={str(k): v for k, v in id_to_semantics.items()}, path=file_path) + + def _write_motion_vectors(self, anno_rp_data: dict, output_path: str): + motion_vec_data = anno_rp_data["data"] + file_path = f"{output_path}motion_vectors_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + self._backend.schedule(F.write_np, data=motion_vec_data, path=file_path) + + def _write_occlusion(self, anno_rp_data: dict, output_path: str): + occlusion_data = anno_rp_data["data"] + + file_path = f"{output_path}occlusion_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + self._backend.schedule(F.write_np, data=occlusion_data, path=file_path) + + def _write_bounding_box_data(self, anno_rp_data: dict, bbox_type: str, output_path: str): + bbox_data = anno_rp_data["data"] + id_to_labels = anno_rp_data["idToLabels"] + prim_paths = anno_rp_data["primPaths"] + + file_path = ( + f"{output_path}bounding_box_{bbox_type}_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + ) + self._backend.schedule(F.write_np, data=bbox_data, path=file_path) + + labels_file_path = f"{output_path}bounding_box_{bbox_type}_labels_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + self._backend.schedule(F.write_json, data=id_to_labels, path=labels_file_path) + + labels_file_path = f"{output_path}bounding_box_{bbox_type}_prim_paths_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + self._backend.schedule(F.write_json, data=prim_paths, path=labels_file_path) + + def _write_camera_params(self, anno_rp_data: dict, output_path: str): + camera_data = anno_rp_data + serializable_data = {} + + for key, val in camera_data.items(): + if isinstance(val, np.ndarray): + serializable_data[key] = val.tolist() + else: + serializable_data[key] = val + + file_path = f"{output_path}camera_params_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + self._backend.schedule(F.write_json, data=serializable_data, path=file_path) + + def _write_pointcloud(self, anno_rp_data: dict, output_path: str): + pointcloud_data = anno_rp_data["data"] + pointcloud_rgb = anno_rp_data["pointRgb"].reshape(-1, 4) + pointcloud_normals = anno_rp_data["pointNormals"].reshape(-1, 4) + pointcloud_semantic = anno_rp_data["pointSemantic"] + pointcloud_instance = anno_rp_data["pointInstance"] + + file_path = f"{output_path}pointcloud_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + self._backend.schedule(F.write_np, data=pointcloud_data, path=file_path) + + rgb_file_path = f"{output_path}pointcloud_rgb_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + self._backend.schedule(F.write_np, data=pointcloud_rgb, path=rgb_file_path) + + normals_file_path = ( + f"{output_path}pointcloud_normals_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + ) + self._backend.schedule(F.write_np, data=pointcloud_normals, path=normals_file_path) + + 
semantic_file_path = ( + f"{output_path}pointcloud_semantic_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + ) + self._backend.schedule(F.write_np, data=pointcloud_semantic, path=semantic_file_path) + + instance_file_path = ( + f"{output_path}pointcloud_instance_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.npy" + ) + self._backend.schedule(F.write_np, data=pointcloud_instance, path=instance_file_path) + + def _write_skeleton(self, anno_rp_data: dict, output_path: str): + # "skeletonData" is deprecated + # skeleton = json.loads(anno_rp_data["skeletonData"]) + + skeleton_dict = {} + + skel_name = anno_rp_data["skelName"] + skel_path = anno_rp_data["skelPath"] + asset_path = anno_rp_data["assetPath"] + animation_variant = anno_rp_data["animationVariant"] + skeleton_parents = skeleton_data_utils.get_skeleton_parents( + anno_rp_data["numSkeletons"], anno_rp_data["skeletonParents"], anno_rp_data["skeletonParentsSizes"] + ) + rest_global_translations = skeleton_data_utils.get_rest_global_translations( + anno_rp_data["numSkeletons"], + anno_rp_data["restGlobalTranslations"], + anno_rp_data["restGlobalTranslationsSizes"], + ) + rest_local_translations = skeleton_data_utils.get_rest_local_translations( + anno_rp_data["numSkeletons"], + anno_rp_data["restLocalTranslations"], + anno_rp_data["restLocalTranslationsSizes"], + ) + rest_local_rotations = skeleton_data_utils.get_rest_local_rotations( + anno_rp_data["numSkeletons"], + anno_rp_data["restLocalRotations"], + anno_rp_data["restLocalRotationsSizes"], + ) + global_translations = skeleton_data_utils.get_global_translations( + anno_rp_data["numSkeletons"], + anno_rp_data["globalTranslations"], + anno_rp_data["globalTranslationsSizes"], + ) + local_rotations = skeleton_data_utils.get_local_rotations( + anno_rp_data["numSkeletons"], anno_rp_data["localRotations"], anno_rp_data["localRotationsSizes"] + ) + translations_2d = skeleton_data_utils.get_translations_2d( + anno_rp_data["numSkeletons"], anno_rp_data["translations2d"], anno_rp_data["translations2dSizes"] + ) + skeleton_joints = skeleton_data_utils.get_skeleton_joints(anno_rp_data["skeletonJoints"]) + joint_occlusions = skeleton_data_utils.get_joint_occlusions( + anno_rp_data["numSkeletons"], anno_rp_data["jointOcclusions"], anno_rp_data["jointOcclusionsSizes"] + ) + occlusion_types = skeleton_data_utils.get_occlusion_types( + anno_rp_data["numSkeletons"], anno_rp_data["occlusionTypes"], anno_rp_data["occlusionTypesSizes"] + ) + in_view = anno_rp_data["inView"] + + for skel_num in range(anno_rp_data["numSkeletons"]): + skeleton_dict[f"skeleton_{skel_num}"] = {} + skeleton_dict[f"skeleton_{skel_num}"]["skel_name"] = skel_name[skel_num] + skeleton_dict[f"skeleton_{skel_num}"]["skel_path"] = skel_path[skel_num] + skeleton_dict[f"skeleton_{skel_num}"]["asset_path"] = asset_path[skel_num] + skeleton_dict[f"skeleton_{skel_num}"]["animation_variant"] = animation_variant[skel_num] + skeleton_dict[f"skeleton_{skel_num}"]["skeleton_parents"] = ( + skeleton_parents[skel_num].tolist() if skeleton_parents else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["rest_global_translations"] = ( + rest_global_translations[skel_num].tolist() if rest_global_translations else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["rest_local_translations"] = ( + rest_local_translations[skel_num].tolist() if rest_local_translations else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["rest_local_rotations"] = ( + rest_local_rotations[skel_num].tolist() if rest_local_rotations else [] + ) + 
skeleton_dict[f"skeleton_{skel_num}"]["global_translations"] = ( + global_translations[skel_num].tolist() if global_translations else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["local_rotations"] = ( + local_rotations[skel_num].tolist() if local_rotations else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["translations_2d"] = ( + translations_2d[skel_num].tolist() if translations_2d else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["skeleton_joints"] = ( + skeleton_joints[skel_num] if skeleton_joints else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["joint_occlusions"] = ( + joint_occlusions[skel_num].tolist() if joint_occlusions else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["occlusion_types"] = ( + occlusion_types[skel_num] if occlusion_types else [] + ) + skeleton_dict[f"skeleton_{skel_num}"]["in_view"] = bool(in_view[skel_num]) if in_view.any() else False + + file_path = f"{output_path}skeleton_{self._sequence_id}{self._frame_id:0{self._frame_padding}}.json" + + self._backend.schedule(F.write_json, data=skeleton_dict, path=file_path) + + diff --git a/source/isaaclab_mimic/isaaclab_mimic/datagen/async_writer_recorder.py b/source/isaaclab_mimic/isaaclab_mimic/datagen/async_writer_recorder.py new file mode 100644 index 00000000000..f0676d0f745 --- /dev/null +++ b/source/isaaclab_mimic/isaaclab_mimic/datagen/async_writer_recorder.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Dict, List, Sequence, Tuple + +import pandas as pd +import torch +import asyncio +import threading + +from isaaclab.managers.recorder_manager import RecorderTerm +from isaaclab.utils.datasets import EpisodeData + +from isaaclab_mimic.async_writer import AsyncWriter + + +def _flatten_episode_dict(episode_data: Dict[str, Any], prefix: str = "") -> Dict[str, torch.Tensor]: + """Flattens nested dict tensors in EpisodeData.data into a flat dict with '/'-joined keys. + + Returns a mapping key -> tensor shaped (T, ...). All tensors remain as torch tensors (on CPU). 
+ """ + flat: Dict[str, torch.Tensor] = {} + for key, value in episode_data.items(): + full_key = f"{prefix}/{key}" if prefix else key + if isinstance(value, dict): + flat.update(_flatten_episode_dict(value, prefix=full_key)) + elif isinstance(value, torch.Tensor): + # Ensure on CPU for downstream numpy conversion + flat[full_key] = value.detach().to("cpu") + else: + # Ignore unsupported types silently + continue + return flat + + +def _episode_to_dataframe_from_dict(data_dict: Dict[str, Any]) -> pd.DataFrame: + """Converts a nested episode data dict to a pandas DataFrame (one row per timestep).""" + if not data_dict: + return pd.DataFrame() + + flat = _flatten_episode_dict(data_dict) + + # Determine timesteps (prefer 'actions' length, else max length across tensors) + def _tensor_len(t: torch.Tensor) -> int: + return int(t.shape[0]) if t.dim() > 0 else 1 + + timesteps = 0 + if "actions" in flat: + timesteps = _tensor_len(flat["actions"]) + else: + for t in flat.values(): + timesteps = max(timesteps, _tensor_len(t)) + + if timesteps == 0: + return pd.DataFrame() + + columns: Dict[str, List[Any]] = {} + for k, t in flat.items(): + if _tensor_len(t) == timesteps: + # Per-step series + columns[k] = [t[i] for i in range(timesteps)] + else: + # Do not broadcast single-step values like initial_state; keep single entry + single = t[0] if t.dim() > 0 else t + if k.endswith("initial_state") or k.split("/")[-1] == "initial_state": + columns[k] = [single] + else: + # For non-initial_state, retain previous behavior (broadcast) to keep writer compatibility + columns[k] = [single for _ in range(timesteps)] + + return pd.DataFrame(columns) + + +# take snapshot of episode data, onto cpu & cloned so live buffers arent used +def _snapshot_episode_data(episode: EpisodeData) -> Dict[str, Any]: + + def _clone_tree(node: Any) -> Any: + if isinstance(node, torch.Tensor): + return node.detach().to("cpu").clone() + if isinstance(node, dict): + return {k: _clone_tree(v) for k, v in node.items()} + return node + + return _clone_tree(episode.data) + + + + +class AsyncWriterRecorder(RecorderTerm): + + def __init__(self, cfg, env): + super().__init__(cfg, env) + if AsyncWriter is None: + raise RuntimeError("AsyncWriter could not be imported; cannot initialize AsyncWriterRecorder.") + + + rm_cfg = getattr(env.cfg, "recorders", None) + if rm_cfg is None: + + self._output_dir = "/tmp/isaaclab/logs" + self._out_file = "dataset.hdf5" + else: + self._output_dir = rm_cfg.dataset_export_dir_path + self._out_file = f"{rm_cfg.dataset_filename}.hdf5" + + self._writer = AsyncWriter(output_dir=self._output_dir) + # Mirror HDF5DatasetFileHandler defaults: set env_name and type=2 + env_name = getattr(env.cfg, "env_name", "") + try: + self._writer.set_env_args({"env_name": env_name, "type": 2}) + except Exception: + pass + + # Track pending async write tasks to allow explicit draining/close + self._pending_tasks: set[asyncio.Task] = set() + + def record_pre_reset(self, env_ids: Sequence[int] | None) -> Tuple[str | None, torch.Tensor | dict | None]: + if env_ids is None: + env_ids = list(range(self._env.num_envs)) + + + for env_id in env_ids: + episode = self._env.recorder_manager.get_episode(env_id) + if episode is None or episode.is_empty(): + continue + self.schedule_async_write_for_episode(episode) + # Clear episode buffer synchronously + from isaaclab.utils.datasets import EpisodeData as _EpisodeData + self._env.recorder_manager._episodes[env_id] = _EpisodeData() + + + return None, None + + + async def 
schedule_async_write_for_episode(self, episode: EpisodeData) -> None: + snapshot = episode.data + if not snapshot: + return + + async def _do_write(data_snapshot: Dict[str, Any]): + df = await asyncio.to_thread(_episode_to_dataframe_from_dict, data_snapshot) + if not df.empty: + await self._writer.write_trajectory_data_async(df, self._out_file, debug=False) + + task = asyncio.create_task(_do_write(snapshot)) + self._pending_tasks.add(task) + task.add_done_callback(lambda t: self._pending_tasks.discard(t)) + # Return immediately; manager may await this coroutine, but work continues in background + return + + # Synchronous writer for single-writer-thread consumer + def schedule_sync_write_for_episode(self, episode: EpisodeData) -> None: + snapshot = _snapshot_episode_data(episode) + if not snapshot: + return + df = _episode_to_dataframe_from_dict(snapshot) + if not df.empty: + # direct, blocking write on the writer thread + self._writer._write_trajectory_data_hdf5(df, self._out_file, debug=False) + + def close(self) -> None: + try: + if hasattr(self, "_writer") and self._writer is not None: + self._writer.close() + except Exception: + pass + + async def flush_async(self) -> None: + if not self._pending_tasks: + return + await asyncio.gather(*list(self._pending_tasks), return_exceptions=True) + + + diff --git a/source/isaaclab_mimic/isaaclab_mimic/datagen/generation.py b/source/isaaclab_mimic/isaaclab_mimic/datagen/generation.py index 6abdc088170..50261fc7b99 100644 --- a/source/isaaclab_mimic/isaaclab_mimic/datagen/generation.py +++ b/source/isaaclab_mimic/isaaclab_mimic/datagen/generation.py @@ -11,9 +11,11 @@ from isaaclab.envs import ManagerBasedRLMimicEnv from isaaclab.envs.mdp.recorders.recorders_cfg import ActionStateRecorderManagerCfg from isaaclab.managers import DatasetExportMode, TerminationTermCfg +from isaaclab.managers.manager_term_cfg import RecorderTermCfg from isaaclab_mimic.datagen.data_generator import DataGenerator from isaaclab_mimic.datagen.datagen_info_pool import DataGenInfoPool +from isaaclab_mimic.datagen.async_writer_recorder import AsyncWriterRecorder from isaaclab_tasks.utils.parse_cfg import parse_env_cfg @@ -31,7 +33,7 @@ async def run_data_generator( data_generator: DataGenerator, success_term: TerminationTermCfg, pause_subtask: bool = False, - motion_planner: Any = None, + motion_planner: Any | None = None, ): """Run mimic data generation from the given data generator in the specified environment index. @@ -43,17 +45,17 @@ async def run_data_generator( data_generator: The data generator instance to use. success_term: The success termination term to use. pause_subtask: Whether to pause the subtask during generation. - motion_planner: The motion planner to use. """ global num_success, num_failures, num_attempts while True: + + results = await data_generator.generate( env_id=env_id, success_term=success_term, env_reset_queue=env_reset_queue, env_action_queue=env_action_queue, pause_subtask=pause_subtask, - motion_planner=motion_planner, ) if bool(results["success"]): num_success += 1 @@ -68,6 +70,7 @@ def env_loop( env_action_queue: asyncio.Queue, shared_datagen_info_pool: DataGenInfoPool, asyncio_event_loop: asyncio.AbstractEventLoop, + ): """Main asyncio loop for the environment. 
@@ -81,6 +84,8 @@ def env_loop( global num_success, num_failures, num_attempts env_id_tensor = torch.tensor([0], dtype=torch.int64, device=env.device) prev_num_attempts = 0 + + print("STARTING ENV LOOP") # simulate environment -- run everything in inference mode with contextlib.suppress(KeyboardInterrupt) and torch.inference_mode(): while True: @@ -141,6 +146,8 @@ def setup_env_config( num_envs: int, device: str, generation_num_trials: int | None = None, + use_async_writer: bool = False, + early_cpu_offload: bool = False, ) -> tuple[Any, Any]: """Configure the environment for data generation. @@ -151,7 +158,8 @@ def setup_env_config( num_envs: Number of environments to run device: Device to run on generation_num_trials: Optional override for number of trials - + use_async_writer: Whether to use async writer + early_cpu_offload: Whether to use early cpu offload (episode data moved to CPU aggressively) Returns: tuple containing: - env_cfg: The environment configuration @@ -180,25 +188,27 @@ def setup_env_config( env_cfg.observations.policy.concatenate_terms = False # Setup recorders + + env_cfg.early_cpu_offload = early_cpu_offload env_cfg.recorders = ActionStateRecorderManagerCfg() env_cfg.recorders.dataset_export_dir_path = output_dir env_cfg.recorders.dataset_filename = output_file_name - if env_cfg.datagen_config.generation_keep_failed: - env_cfg.recorders.dataset_export_mode = DatasetExportMode.EXPORT_SUCCEEDED_FAILED_IN_SEPARATE_FILES + if use_async_writer: + env_cfg.recorders.dataset_export_mode = DatasetExportMode.EXPORT_NONE + env_cfg.recorders.export_in_record_pre_reset = False + env_cfg.recorders.async_writer = RecorderTermCfg(class_type=AsyncWriterRecorder) else: env_cfg.recorders.dataset_export_mode = DatasetExportMode.EXPORT_SUCCEEDED_ONLY + env_cfg.recorders.export_in_record_pre_reset = True + env_cfg.recorders.async_writer = None + return env_cfg, success_term def setup_async_generation( - env: Any, - num_envs: int, - input_file: str, - success_term: Any, - pause_subtask: bool = False, - motion_planners: Any = None, + env: Any, num_envs: int, input_file: str, success_term: Any, pause_subtask: bool = False ) -> dict[str, Any]: """Setup async data generation tasks. @@ -208,7 +218,6 @@ def setup_async_generation( input_file: Path to input dataset file success_term: Success termination condition pause_subtask: Whether to pause after subtasks - motion_planners: Motion planner instances for all environments Returns: List of asyncio tasks for data generation @@ -225,17 +234,9 @@ def setup_async_generation( data_generator = DataGenerator(env=env, src_demo_datagen_info_pool=shared_datagen_info_pool) data_generator_asyncio_tasks = [] for i in range(num_envs): - env_motion_planner = motion_planners[i] if motion_planners else None task = asyncio_event_loop.create_task( run_data_generator( - env, - i, - env_reset_queue, - env_action_queue, - data_generator, - success_term, - pause_subtask=pause_subtask, - motion_planner=env_motion_planner, + env, i, env_reset_queue, env_action_queue, data_generator, success_term, pause_subtask=pause_subtask ) ) data_generator_asyncio_tasks.append(task) diff --git a/source/isaaclab_mimic/isaaclab_mimic/io_functions.py b/source/isaaclab_mimic/isaaclab_mimic/io_functions.py new file mode 100644 index 00000000000..070f1741af9 --- /dev/null +++ b/source/isaaclab_mimic/isaaclab_mimic/io_functions.py @@ -0,0 +1,486 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r''' +The io_functions.py module provides a set of utility functions, which can be easily utilized to manage I/O operations +within different backends. Additional functions can be registered with this module to seamlessly expand the capabilities +of backends to handle a variety of I/O tasks. + +Example: + +.. code:: python + + def my_io_function(backend_instance, **kwargs) + """My IO function + + Parameters: + backend_instance: An instance of the backend derived from BaseBackend and registered with BackendRegistry, used + for executing the I/O operation. For instance, in the statement + ``backend_instance.schedule(io_functions.my_io_function, kwargs)``, backend_instance is the instance of the + backend used to schedule the my_function operation with specified keyword arguments. + **kwargs: Keyword arguments that can be employed to provide supplementary information or customization for the + I/O operation. These arguments can be specified for I/O functions that require specific configurations, such + as ``write_jpeg``. + """ +''' + +import io +import json +import os +import pickle +import platform +import threading +from contextlib import nullcontext +from typing import Union, Dict, Callable, Optional + +import numpy as np +import warp as wp +import h5py +import pandas as pd +import torch +from collections import defaultdict +from omni.replicator.core.bindings._omni_replicator_exrwriter import load_exr_from_stream, save_exr_to_stream +from PIL import Image + +from omni.replicator.core.backends import BaseBackend, DiskBackend + + +def _to_pil_image(data): + if isinstance(data, wp.array): + data = data.numpy() + + if isinstance(data, np.ndarray): + if data.shape[-1] > 3 and len(data.shape) == 3: + data = Image.fromarray(data, "RGBA") + elif data.shape[-1] == 3 and len(data.shape) == 3: + data = Image.fromarray(data, "RGB") + elif data.shape[-1] == 1 and len(data.shape) == 3: + data = Image.fromarray(data[:, :, 0], "L") + else: + if data.dtype == np.uint16: + data = Image.fromarray(data, "I;16") + else: + data = Image.fromarray(data) + + if not isinstance(data, Image.Image): + raise ValueError(f"Expected image data to be a numpy ndarray, warp array or PIL.Image, got {type(data)}") + + return data + + +def write_image( + path: str, data: Union[np.ndarray, wp.array, Image.Image], backend_instance: BaseBackend = DiskBackend, **kwargs +) -> None: + """ + Write image data to a specified path. + Supported image extensions include: [jpeg, jpg, png, exr] + + Args: + path: Write path URI + data: Image data + backend_instance: Backend to use to write. Defaults to ``DiskBackend``. + kwargs: Specify additional save parameters, typically specific to the image file type. 
+ """ + if isinstance(data, wp.array): + data = data.numpy() + + ext = os.path.splitext(path)[-1][1:] + if ext.lower() not in ["jpeg", "jpg", "png", "exr"]: + raise ValueError(f"Could not write image to path `{path}`, image extension `{ext}` is not supported.") + + if ext.lower() in ["jpeg", "jpg", "png"]: + data = _to_pil_image(data) + + if ext.lower() in ["jpeg", "jpg"]: + data = data.convert("RGB") + write_jpeg(path, data, backend_instance=backend_instance, **kwargs) + else: + write_png(path, data, backend_instance=backend_instance, **kwargs) + + elif ext.lower() == "exr": + write_exr(path, data, backend_instance=backend_instance, **kwargs) + + +def write_jpeg( + path: str, + data: Union[np.ndarray, wp.array], + backend_instance: BaseBackend = DiskBackend, + quality: int = 75, + progressive: bool = False, + optimize: bool = False, + **kwargs, +) -> None: + """ + Write image data to JPEG. + + Args: + path: Write path URI + data: Data to write + backend_instance: Backend to use to write. Defaults to ``DiskBackend``. Defaults to ``DiskBackend``. + quality: The image quality, on a scale from 0 (worst) to 95 (best), or the string keep. The default is 75. + Values above 95 should be avoided; 100 disables portions of the JPEG compression algorithm, and results in + large files with hardly any gain in image quality. The value keep is only valid for JPEG files and will + retain the original image quality level, subsampling, and qtables. + progressive: Indicates that this image should be stored as a progressive JPEG file. + optimize: Reduce file size, may be slower. Indicates that the encoder should make an extra pass over the image + in order to select optimal encoder settings. + kwargs: Additional parameters may be specified and can be found within the PILLOW documentation: + https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#jpeg-saving + """ + data = _to_pil_image(data) + buf = io.BytesIO() + data.save(buf, format="jpeg", quality=quality, optimize=optimize, progressive=progressive, **kwargs) + backend_instance.write_blob(path, buf.getvalue()) + + +def write_png( + path: str, + data: Union[np.ndarray, wp.array], + backend_instance: BaseBackend = DiskBackend, + compress_level: int = 3, + **kwargs, +) -> None: + """ + Write image data to PNG. + + + Args: + path: Write path URI + data: Data to write + backend_instance: Backend to use to write. Defaults to ``DiskBackend``. + compress_level: Specifies ZLIB compression level. Compression is specified as a value between [0, 9] where 1 is + fastest and 9 provides the best compression. A value of 0 provides no compression. Defaults to 3. + **kwargs: Additional parameters may be specified and can be found within the PILLOW documentation: + https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#png-saving + """ + data = _to_pil_image(data) + buf = io.BytesIO() + data.save(buf, format="png", compress_level=compress_level, **kwargs) + backend_instance.write_blob(path, buf.getvalue()) + + +def _write_exr_imageio( + path: str, data: Union[np.ndarray, wp.array], backend_instance: BaseBackend = DiskBackend, exr_flag=None, **kwargs +) -> None: + """ + Write data to EXR. + + Args: + path: Write path URI + data: Data to write + backend_instance: Backend to use to write. Defaults to ``DiskBackend``. 
+        exr_flag from FIF_EXR:
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_DEFAULT: Save data as half with piz-based wavelet compression
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_FLOAT: Save data as float instead of as half (not recommended)
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_NONE: Save with no compression
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_ZIP: Save with zlib compression, in blocks of 16 scan lines
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_PIZ: Save with piz-based wavelet compression
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_PXR24: Save with lossy 24-bit float compression
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_B44: Save with lossy 44% float compression - goes to 22% when
+              combined with EXR_LC
+            - imageio.plugins.freeimage.IO_FLAGS.EXR_LC: Save images with one luminance and two chroma channels, rather
+              than as RGB (lossy compression)
+    """
+    import imageio
+
+    if isinstance(data, wp.array):
+        data = data.numpy()
+
+    # Download freeimage dll, will only download once if not present
+    # from https://imageio.readthedocs.io/en/v2.8.0/format_exr-fi.html#exr-fi
+    imageio.plugins.freeimage.download()
+    if exr_flag is None and platform.machine() != "aarch64":
+        # Flag for x86_64, not supported on ARM at the moment, tracked in OMPE-46846
+        exr_flag = imageio.plugins.freeimage.IO_FLAGS.EXR_ZIP
+
+        exr_bytes = imageio.imwrite(
+            imageio.RETURN_BYTES,
+            data,
+            format="exr",
+            flags=exr_flag,
+        )
+    else:
+        exr_bytes = imageio.imwrite(
+            imageio.RETURN_BYTES,
+            data,
+            format="exr",
+        )
+    backend_instance.write_blob(path, exr_bytes)
+
+
+def write_exr(
+    path: str,
+    data: Union[np.ndarray, wp.array],
+    backend_instance: BaseBackend = DiskBackend,
+    half_precision: bool = False,
+    **kwargs,
+) -> None:
+    """
+    Write data to EXR.
+
+    Args:
+        path: Write path URI
+        data: Data to write
+        backend_instance: Backend to use to write. Defaults to ``DiskBackend``.
+        half_precision: bool, optional
+            Save data as half precision instead of full precision. Default to False.
+        **kwargs: If "exr_flag" is provided, legacy imageio implementation is used.
+    """
+    if "exr_flag" in kwargs:
+        return _write_exr_imageio(path, data, backend_instance, kwargs["exr_flag"])
+
+    if isinstance(data, wp.array):
+        data = data.numpy()
+
+    buf = io.BytesIO()
+    save_exr_to_stream(buf, data, half_precision)
+    backend_instance.write_blob(path, buf.getvalue())
+
+
+def write_json(
+    path: str,
+    data,
+    backend_instance: BaseBackend = DiskBackend,
+    encoding: str = "utf-8",
+    errors: str = "strict",
+    **kwargs,
+) -> None:
+    """
+    Write json data to a specified path.
+
+    Args:
+        path: Write path URI
+        data: Data to write
+        backend_instance: Backend to use to write. Defaults to ``DiskBackend``.
+        encoding: This parameter specifies the encoding to be used. For a list of all encoding schemes, please visit:
+            https://docs.python.org/3/library/codecs.html#standard-encodings
+        errors: This parameter specifies an error handling scheme when encoding the json string data. The default for
+            errors is 'strict' which means that the encoding errors raise a UnicodeError. Other possible values are
+            'ignore', 'replace', 'xmlcharrefreplace', 'backslashreplace' and any other name registered via
+            codecs.register_error().
+        **kwargs: Additional JSON encoding parameters may be supplied. See
+            https://docs.python.org/3/library/json.html#json.dump for full list.
+ """ + + buf = io.BytesIO() + buf.write( + json.dumps( + data, + **kwargs, + ).encode(encoding, errors=errors) + ) + backend_instance.write_blob(path, buf.getvalue()) + + +def write_pickle( + path: str, data: Union[np.ndarray, wp.array], backend_instance: BaseBackend = DiskBackend, **kwargs +) -> None: + """ + Write pickle data to a specified path. + + Args: + path: Write path URI + data: Data to write + backend_instance: Backend to use to write. Defaults to ``DiskBackend``. + **kwargs: Additional Pickle encoding parameters may be supplied. See + https://docs.python.org/3/library/pickle.html#pickle.Pickler for full list. + """ + buf = io.BytesIO() + pickle.dump(data, buf, **kwargs) + backend_instance.write_blob(path, buf.getvalue()) + + +def write_np( + path: str, + data: Union[np.ndarray, wp.array], + backend_instance: BaseBackend = DiskBackend, + allow_pickle: bool = True, + fix_imports: bool = True, +) -> None: + """ + Write numpy data to a specified path. + Save parameters are detailed here: https://numpy.org/doc/stable/reference/generated/numpy.save.html + + Args: + path: Write path URI + data: Data to write + backend_instance: Backend to use to write. Defaults to ``DiskBackend``. + allow_pickle : bool, optional + Allow saving object arrays using Python pickles. Reasons for disallowing + pickles include security (loading pickled data can execute arbitrary + code) and portability (pickled objects may not be loadable on different + Python installations, for example if the stored objects require libraries + that are not available, and not all pickled data is compatible between + Python 2 and Python 3). + Default to True. + fix_imports : bool, optional + Only useful in forcing objects in object arrays on Python 3 to be + pickled in a Python 2 compatible way. If ``fix_imports`` is True, pickle + will try to map the new Python 3 names to the old module names used in + Python 2, so that the pickle data stream is readable with Python 2. 
Defaults + to True + """ + if isinstance(data, wp.array): + data = data.numpy() + + buf = io.BytesIO() + np.save(buf, data, allow_pickle=allow_pickle, fix_imports=fix_imports) + backend_instance.write_blob(path, buf.getvalue()) + + +def _parse_column_structure(df_columns): + """Parse DataFrame column names to determine the nested group structure of the hdf5 file.""" + structure = defaultdict(list) + for col in df_columns: + if '/' in col: + parts = col.split('/') + main_group = parts[0] + subgroup_path = '/'.join(parts[1:]) + structure[main_group].append(subgroup_path) + else: + structure['root'].append(col) + return dict(structure) + + +def _create_nested_datasets(demo_group, df, structure): + """Create nested datasets in HDF5 based on the parsed structure.""" + for main_group, subgroups in structure.items(): + if main_group == 'root': + for col_name in subgroups: + data_series = df[col_name] + if isinstance(data_series.iloc[0], torch.Tensor): + stacked_data = torch.stack(data_series.tolist()).numpy() + else: + stacked_data = np.stack(data_series.values) + demo_group.create_dataset(col_name, data=stacked_data) + else: + group_obj = demo_group.create_group(main_group) + subgroup_dict = defaultdict(list) + for col_path in subgroups: + parts = col_path.split('/') + if len(parts) == 1: + subgroup_dict['root'].append((parts[0], col_path)) + else: + subgroup_dict[parts[0]].append(('/'.join(parts[1:]), col_path)) + + for immediate_subgroup, column_info in subgroup_dict.items(): + if immediate_subgroup == 'root': + for dataset_name, col_path in column_info: + data_series = df[f"{main_group}/{col_path}"] + if isinstance(data_series.iloc[0], torch.Tensor): + stacked_data = torch.stack(data_series.tolist()).numpy() + else: + stacked_data = np.stack(data_series.values) + group_obj.create_dataset(dataset_name, data=stacked_data) + else: + subgroup_obj = group_obj.create_group(immediate_subgroup) + for nested_path, col_path in column_info: + data_series = df[f"{main_group}/{col_path}"] + if isinstance(data_series.iloc[0], torch.Tensor): + stacked_data = torch.stack(data_series.tolist()).numpy() + else: + stacked_data = np.stack(data_series.values) + subgroup_obj.create_dataset(nested_path, data=stacked_data) + + +def write_dataframe_hdf5( + path: str, + data: pd.DataFrame, + backend_instance: BaseBackend = DiskBackend, + demo_name: str = "demo_0", + env_args: Optional[Dict] = None, + file_handle_getter: Optional[Callable[[str], h5py.File]] = None, + file_lock: Optional[threading.Lock] = None, + **kwargs, +) -> None: + """ + Write DataFrame data to HDF5 file with nested structure. + + Args: + path: Write path URI for the HDF5 file + data: DataFrame containing trajectory data with nested column naming (e.g., "obs/right_eef_pos") + backend_instance: Backend to use for directory creation and path management. Defaults to ``DiskBackend``. + demo_name: Name for this demo subgroup. Defaults to "demo_0". + env_args: Optional dictionary of environment arguments to store as attributes. + file_handle_getter: Optional callable that takes a filepath and returns an open h5py.File handle. + If provided, this handle will be used for writing. If None, a new file will be opened. + **kwargs: Additional parameters (unused, for compatibility with backend pattern). 
+ """ + lock_ctx = file_lock if file_lock is not None else nullcontext() + + with lock_ctx: + structure = _parse_column_structure(data.columns) + + if file_handle_getter is not None: + # Use provided file handle (for persistent handles) + f = file_handle_getter(path) + else: + # Ensure directory exists when possible + dir_path = os.path.dirname(path) + if hasattr(backend_instance, "make_dirs"): + if dir_path: + backend_instance.make_dirs(dir_path) + elif dir_path: + os.makedirs(dir_path, exist_ok=True) + f = h5py.File(path, 'a') + + group = f['data'] if 'data' in f else f.create_group('data') + if env_args is not None: + try: + group.attrs['env_args'] = json.dumps(env_args) + f.attrs['env_args'] = json.dumps(env_args) + except Exception: + pass + + if demo_name in group: + if file_handle_getter is None: + f.close() + raise ValueError(f"Demo {demo_name} already exists in the HDF5 file.") + + demo_group = group.create_group(demo_name) + _create_nested_datasets(demo_group, data, structure) + + try: + f.flush() + except Exception: + pass + + if file_handle_getter is None: + f.close() + + +def read_exr( + path: str, + backend_instance: BaseBackend = DiskBackend, +) -> np.ndarray: + """Read an EXR image and return it as a NumPy ``ndarray``. + + Args: + path (str): Path to the EXR file to read. + backend_instance (BaseBackend, optional): Backend to use when reading the + file. If an *instance* of a backend is supplied, its + :py:meth:`read_blob` method is used and the image is decoded from + memory. If a backend *class* (e.g. ``DiskBackend``) is given, the + path is read directly from disk. Defaults to ``DiskBackend``. + + Returns: + numpy.ndarray: The decoded image data. The array shape is + ``(H, W)`` for single-channel images or ``(H, W, C)`` for multi-channel + images. The dtype matches the source file (typically ``float32`` or + ``float16``). + """ + exr_bytes = backend_instance.read_blob(path) + buf = io.BytesIO(exr_bytes) + return load_exr_from_stream(buf)