247 changes: 20 additions & 227 deletions lzero/mcts/buffer/game_buffer_priorzero.py
@@ -18,158 +18,6 @@
from typing import List, Any, Union, Tuple
from lzero.mcts.buffer.game_buffer_unizero import UniZeroGameBuffer


class PriorZeroGameBuffer(UniZeroGameBuffer):
"""
[PRIORZERO-MODIFIED]
Enhanced GameBuffer that provides game_segments for LLM policy training.

Modifications:
1. sample() returns the sampled game_segments as the third element of train_data (see the usage sketch after this class)
2. Efficient implementation using existing game_segment_list from _make_batch
3. No additional memory overhead (returns references, not copies)
"""

def __init__(self, cfg):
"""Initialize PriorZero Game Buffer."""
super().__init__(cfg)

# [PRIORZERO-NEW] Cache for the last sampled game segments
# This avoids re-sampling when we need game segments
self._last_sampled_game_segments = None
self._last_sampled_batch_indices = None

def sample(
self,
batch_size: int,
policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
"""
[PRIORZERO-MODIFIED]
Sample data and prepare current_batch, target_batch, AND game_segments.

Returns:
train_data: [current_batch, target_batch, game_segments]
- current_batch: [obs, action, target_action, mask, indices, weights, make_time, timestep]
- target_batch: [rewards, values, policies]
- game_segments: List of GameSegment objects used in this batch

Note:
game_segments are returned for LLM training (SFT/RFT).
They contain:
- mcts_policy_segment: MCTS visit distributions (for SFT supervision)
- raw_obs_segment: Raw text observations (for LLM prompts)
- reward_segment: Environment rewards (for RFT)
- search_value_segment: MCTS search values (for analysis)
"""
policy._target_model.to(self._cfg.device)
policy._target_model.eval()

# ======================================================================
# [PRIORZERO-KEY] Sample data and extract game_segments
# ======================================================================
# obtain the current_batch and prepare target context
reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
batch_size, self._cfg.reanalyze_ratio
)

# [PRIORZERO-NEW] Extract game_segments from the sampling process
# These were already collected during _make_batch; here we simply retrieve the cached references
game_segments = self._last_sampled_game_segments

# Defensive check: ensure game_segments match batch_size
if game_segments is None or len(game_segments) != len(current_batch[4]): # current_batch[4] is batch_index_list
# Fallback: create empty list if something went wrong
import logging
logging.warning(
f"[PriorZeroBuffer] game_segments mismatch: "
f"expected {len(current_batch[4])}, got {len(game_segments) if game_segments else None}. "
f"Falling back to empty list (SFT/RFT will be skipped)."
)
game_segments = []

# ======================================================================
# Standard UniZero processing (unchanged)
# ======================================================================
# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]

# target reward, target value
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1] # current_batch[2] is batch_target_action
)

# target policy
batch_target_policies_re = self._compute_target_policy_reanalyzed(
policy_re_context, policy._target_model, current_batch[1], current_batch[-1]
) # current_batch[1] is batch_action
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
)

# Fuse batch_target_policies_re and batch_target_policies_non_re into batch_target_policies
if 0 < self._cfg.reanalyze_ratio < 1:
batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
elif self._cfg.reanalyze_ratio == 1:
batch_target_policies = batch_target_policies_re
elif self._cfg.reanalyze_ratio == 0:
batch_target_policies = batch_target_policies_non_re

target_batch = [batch_rewards, batch_target_values, batch_target_policies]

# ======================================================================
# [PRIORZERO-KEY] Return current_batch, target_batch, AND game_segments
# ======================================================================
train_data = [current_batch, target_batch, game_segments]
return train_data

def _sample_orig_data(self, batch_size: int) -> Tuple[Any]:
"""
[PRIORZERO-MODIFIED]
Override to cache game_segments during sampling.

This avoids double sampling by caching the result for sample() to use.
"""
# Call parent implementation
result = super()._sample_orig_data(batch_size)

# Cache the game_segment_list (first element of result tuple)
game_segment_list = result[0]
self._last_sampled_game_segments = game_segment_list
self._last_sampled_batch_indices = result[2] # batch_index_list

return result

def _sample_orig_data_episode(self, batch_size: int) -> Tuple[Any]:
"""
[PRIORZERO-MODIFIED]
Override to cache game_segments during episode sampling.

This avoids double sampling by caching the result for sample() to use.
"""
# Call parent implementation
result = super()._sample_orig_data_episode(batch_size)

# Cache the game_segment_list (first element of result tuple)
game_segment_list = result[0]
self._last_sampled_game_segments = game_segment_list
self._last_sampled_batch_indices = result[2] # batch_index_list

return result

def clear(self):
"""
[PRIORZERO-MODIFIED]
Clear buffer and cached game segments.
"""
super().clear()
self._last_sampled_game_segments = None
self._last_sampled_batch_indices = None
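
For orientation, here is a minimal usage sketch of the three-element train_data contract described in the docstrings above. Only the sample() signature and the GameSegment field names (raw_obs_segment, mcts_policy_segment, reward_segment, search_value_segment) come from the code in this file; the helper name and the idea of collecting per-segment LLM examples are illustrative assumptions, not part of this PR.

def unpack_priorzero_batch(buffer, policy, batch_size: int = 32):
    """Illustrative helper (not from this PR): consume PriorZeroGameBuffer.sample()."""
    # sample() returns [current_batch, target_batch, game_segments]
    current_batch, target_batch, game_segments = buffer.sample(batch_size, policy)

    llm_examples = []
    for seg in game_segments:
        llm_examples.append({
            "prompts": seg.raw_obs_segment,             # raw text observations -> LLM prompts
            "sft_targets": seg.mcts_policy_segment,     # MCTS visit distributions -> SFT supervision
            "rft_rewards": seg.reward_segment,          # environment rewards -> RFT signal
            "search_values": seg.search_value_segment,  # MCTS search values, kept for analysis
        })
    # current_batch / target_batch feed the usual UniZero update unchanged.
    return current_batch, target_batch, llm_examples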


# ==============================================================================
# Optimized Alternative (Avoids Double Sampling)
# ==============================================================================

class PriorZeroGameBufferOptimized(UniZeroGameBuffer):
"""
[PRIORZERO-OPTIMIZED]
@@ -195,16 +43,14 @@ def sample(self, batch_size: int, policy) -> List[Any]:
batch_size, self._cfg.reanalyze_ratio
)

# Get cached game segments (set by our overridden _make_batch)
game_segments = self._cached_game_segments or []

obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list = current_batch
# Standard processing
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[2], current_batch[-1]
reward_value_context, policy._target_model, current_batch[2], timestep_list
)

batch_target_policies_re = self._compute_target_policy_reanalyzed(
policy_re_context, policy._target_model, current_batch[1], current_batch[-1]
policy_re_context, policy._target_model, current_batch[1], timestep_list
)
batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
policy_non_re_context, self.action_space_size
@@ -219,7 +65,7 @@ def sample(self, batch_size: int, policy) -> List[Any]:

target_batch = [batch_rewards, batch_target_values, batch_target_policies]

return [current_batch, target_batch, game_segments]
return [current_batch, target_batch]

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
"""
@@ -243,6 +89,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
# Rest of the code is identical to parent's _make_batch
batch_size = len(batch_index_list)
obs_list, action_list, mask_list = [], [], []
raw_obs_list, history_obs_list = [], []
action_logprob_list = []
timestep_list = []
bootstrap_action_list = []

@@ -272,6 +120,16 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
raw_obs_list.append(game_segment_list[i].get_unroll_raw_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))
history_obs_list.append(game_segment_list[i].get_unroll_histroy_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))
action_logprob_list.append(game_segment_list[i].get_unroll_action_logprob(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
))

action_list.append(actions_tmp)
mask_list.append(mask_tmp)
timestep_list.append(timestep_tmp)
@@ -291,6 +149,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

current_batch.append(raw_obs_list)
current_batch.append(history_obs_list)
current_batch.append(action_logprob_list)

total_transitions = self.get_num_of_transitions()

@@ -318,72 +180,3 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
policy_non_re_context = None

return reward_value_context, policy_re_context, policy_non_re_context, current_batch
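
Because _make_batch appends raw_obs_list, history_obs_list and action_logprob_list after the first eight entries are converted to numpy arrays, current_batch now has eleven elements and current_batch[-1] is no longer timestep_list, which is why sample() above passes timestep_list explicitly. A minimal sketch of how a consumer might split the extended batch follows; the helper name is hypothetical and only the field order comes from this file.

def split_current_batch(current_batch):
    """Illustrative helper (not from this PR): separate the UniZero arrays from
    the PriorZero extras in the extended 11-element current_batch."""
    (obs_list, action_list, bootstrap_action_list, mask_list,
     batch_index_list, weights_list, make_time_list, timestep_list,
     raw_obs_list, history_obs_list, action_logprob_list) = current_batch

    unizero_part = [obs_list, action_list, bootstrap_action_list, mask_list,
                    batch_index_list, weights_list, make_time_list, timestep_list]
    priorzero_extras = {
        "raw_obs": raw_obs_list,                # per-window raw text observations
        "history_obs": history_obs_list,        # per-window observation histories
        "action_logprob": action_logprob_list,  # per-window action log-probabilities
    }
    return unizero_part, priorzero_extras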


# ==============================================================================
# Factory Function
# ==============================================================================

def create_priorzero_buffer(cfg, optimized: bool = True):
"""
Factory function to create PriorZero game buffer.

Args:
cfg: Configuration dict
optimized: If True, use optimized version (recommended)

Returns:
buffer: PriorZero game buffer instance
"""
if optimized:
return PriorZeroGameBufferOptimized(cfg)
else:
return PriorZeroGameBuffer(cfg)


if __name__ == "__main__":
print("="*80)
print("PriorZero Game Buffer - Unit Tests")
print("="*80)

# Create mock config
class MockConfig:
def __init__(self):
self.device = 'cpu'
self.env_type = 'not_board_games'
self.game_segment_length = 200
self.num_unroll_steps = 5
self.td_steps = 5
self.batch_size = 32
self.use_priority = False
self.reanalyze_ratio = 0.0
self.sample_type = 'transition'
self.replay_buffer_size = 10000
self.model = type('obj', (object,), {
'model_type': 'mlp',
'action_space_size': 10,
'observation_shape': 128,
})()

cfg = MockConfig()

# Test both versions
for name, buffer_class in [
("Standard", PriorZeroGameBuffer),
("Optimized", PriorZeroGameBufferOptimized)
]:
print(f"\nTesting {name} Buffer:")
print("-" * 40)

buffer = buffer_class(cfg)
print(f"✓ Buffer created: {type(buffer).__name__}")
print(f" - sample_type: {buffer.sample_type}")
print(f" - action_space_size: {buffer.action_space_size}")

# Note: Full testing would require mock GameSegments and Policy
# For now, just verify instantiation
print(f"✓ {name} buffer initialized successfully")

print("\n" + "="*80)
print("✓ All tests passed!")
print("="*80)
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
@@ -2064,7 +2064,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
value_priority=value_priority,
intermediate_tensor_x=intermediate_tensor_x,
obs_embeddings=detached_obs_embeddings,  # <-- newly added
)
), inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, outputs.logits_value.shape[-1])).detach()


# TODO: test correctness
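
With this hunk, compute_loss returns a second value: the value logits mapped back through inverse_scalar_transform_handle and detached. Any caller that previously received a single losses object must now unpack two values. The sketch below is an assumed call-site adjustment, not code from this PR; the wrapper name and argument set are placeholders.

def compute_losses_and_values(world_model, batch, target_tokenizer, inverse_scalar_transform_handle):
    """Illustrative wrapper (not from this PR) for the new two-value return."""
    losses, predicted_values = world_model.compute_loss(
        batch, target_tokenizer, inverse_scalar_transform_handle
    )
    # predicted_values is already detached and inverse-transformed, so it can be
    # logged or consumed by PriorZero-side components without affecting gradients.
    return losses, predicted_values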
10 changes: 0 additions & 10 deletions lzero/worker/muzero_segment_collector.py
@@ -477,16 +477,6 @@ def collect(
if self.policy_config.use_ture_chance_label_in_chance_encoder:
append_kwargs['chance'] = self.chance_dict_tmp[env_id]

# [PRIORZERO-NEW] Add raw_obs_text if available in obs (not info!)
# Jericho env puts raw_obs_text in the obs dictionary
if env_id == 0 and collected_step < 5: # Debug first few steps
print(f"[OBS_DEBUG] Step {collected_step} env {env_id}: obs keys = {list(obs.keys())}")
print(f"[OBS_DEBUG] obs type = {type(obs)}")
if 'raw_obs_text' in obs:
print(f"[OBS_DEBUG] Found raw_obs_text: {str(obs['raw_obs_text'])[:100]}...")
else:
print(f"[OBS_DEBUG] NO raw_obs_text in obs!")

if 'raw_obs_text' in obs:
append_kwargs['raw_obs_text'] = obs['raw_obs_text']
elif 'raw_obs_text' in info:
1 change: 1 addition & 0 deletions zoo/jericho/envs/jericho_env.py
@@ -344,6 +344,7 @@ def step(self, action: Union[int, np.ndarray, str], return_str: bool = False) ->
previous_obs: Optional[str] = self.last_observation if (self.remove_stuck_actions and self.last_observation is not None) else None

observation, reward, done, info = self._env.step(action_str)
info['action_str'] = action_str

self._timestep += 1
if not self.for_unizero:
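
The added info['action_str'] makes the text command that was actually executed visible to downstream consumers. A hypothetical reader of that field (for example, code assembling LLM prompt histories in the collector) might look like the sketch below; the function and variable names are placeholders, not part of this PR.

def record_executed_action(info: dict, prompt_history: list) -> None:
    """Hypothetical consumer (not from this PR) of the new info['action_str'] field."""
    action_text = info.get('action_str')    # e.g. "open mailbox"; absent before this change
    if action_text is not None:
        prompt_history.append(action_text)  # keep the executed command for later prompts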