From 4b2268f5043b87f9748447448162a133151f902b Mon Sep 17 00:00:00 2001 From: jasper <1157507000@qq.com> Date: Tue, 12 Aug 2025 15:16:56 +0800 Subject: [PATCH] finetune v1 --- ...rain_unizero_multitask_segment_ddp copy.py | 940 ++++++++++++++++++ .../train_unizero_multitask_segment_ddp.py | 49 +- lzero/model/unizero_model_multitask.py | 28 +- .../model/unizero_world_models/transformer.py | 92 +- lzero/model/vit.py | 141 ++- lzero/policy/unizero_multitask.py | 476 +++++++-- ...ri_unizero_multitask_segment_ddp_config.py | 21 +- ...ultitask_segment_ddp_config_debug_naive.py | 492 +++++++++ ..._ddp_config_finetune_SpaceInvaders_full.py | 494 +++++++++ ..._ddp_config_finetune_SpaceInvaders_head.py | 492 +++++++++ ...ne_SpaceInvaders_head_back_encoder_lora.py | 494 +++++++++ ...g_finetune_SpaceInvaders_head_back_lora.py | 490 +++++++++ zoo/atari/config/test.py | 246 +++++ 13 files changed, 4277 insertions(+), 178 deletions(-) create mode 100644 lzero/entry/train_unizero_multitask_segment_ddp copy.py create mode 100644 zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug_naive.py create mode 100644 zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_full.py create mode 100644 zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head.py create mode 100644 zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_encoder_lora.py create mode 100644 zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_lora.py create mode 100644 zoo/atari/config/test.py diff --git a/lzero/entry/train_unizero_multitask_segment_ddp copy.py b/lzero/entry/train_unizero_multitask_segment_ddp copy.py new file mode 100644 index 000000000..bdcd03bde --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp copy.py @@ -0,0 +1,940 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from ding.utils import EasyTimer +import torch.nn.functional as F + +import torch.distributed as dist + +# ------------------------------------------------------------ +# 1. 
额外增加 learner 专用 process-group +# (在 main / learner 初始化时调用一次) +# ------------------------------------------------------------ +def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: + """ + learner_ranks 里只放 **真正执行 backward** 的那些 rank + 例:CUDA_VISIBLE_DEVICES=0,1 → learner_ranks=[0,1] + 返回一个新的 ProcessGroup,后续给 GenericMoCo 使用 + """ + world_pg = dist.group.WORLD + pg = dist.new_group(ranks=learner_ranks, backend='nccl') + if dist.get_rank() in learner_ranks: + torch.cuda.set_device(learner_ranks.index(dist.get_rank())) + return pg + +import concurrent.futures +# ====== UniZero-MT 归一化所需基准分数 (26 Atari100k task_id 对应索引) ====== +# 原始的 RANDOM_SCORES 和 HUMAN_SCORES + + +# global BENCHMARK_NAME +# # BENCHMARK_NAME = "atari" +# BENCHMARK_NAME = "dmc" # TODO +# if BENCHMARK_NAME == "atari": +# RANDOM_SCORES = np.array([ +# 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, +# 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, +# -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 +# ]) +# HUMAN_SCORES = np.array([ +# 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, +# 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, +# 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 +# ]) +# elif BENCHMARK_NAME == "dmc": +# RANDOM_SCORES = np.array([0]*26) +# HUMAN_SCORES = np.array([1000]*26) + + +# # 新顺序对应的原始索引列表 +# # 新顺序: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, +# # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, +# # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, +# # PrivateEye, UpNDown, Qbert, Breakout] +# # 映射为原始数组中的索引(注意:索引均从0开始) +# new_order = [ +# 20, # Pong +# 19, # MsPacman +# 24, # Seaquest +# 6, # Boxing +# 0, # Alien +# 8, # ChopperCommand +# 14, # Hero +# 23, # RoadRunner +# 1, # Amidar +# 2, # Assault +# 3, # Asterix +# 4, # BankHeist +# 5, # BattleZone +# 9, # CrazyClimber +# 10, # DemonAttack +# 11, # Freeway +# 12, # Frostbite +# 13, # Gopher +# 15, # Jamesbond +# 16, # Kangaroo +# 17, # Krull +# 18, # KungFuMaster +# 21, # PrivateEye +# 25, # UpNDown +# 22, # Qbert +# 7 # Breakout +# ] + +# # 根据 new_order 生成新的数组 +# new_RANDOM_SCORES = RANDOM_SCORES[new_order] +# new_HUMAN_SCORES = HUMAN_SCORES[new_order] + +# # 查看重排后的结果 +# print("重排后的 RANDOM_SCORES:") +# print(new_RANDOM_SCORES) +# print("\n重排后的 HUMAN_SCORES:") +# print(new_HUMAN_SCORES) + +# 保存最近一次评估回报:{task_id: eval_episode_return_mean} +from collections import defaultdict +GLOBAL_EVAL_RETURNS: dict[int, float] = defaultdict(lambda: None) +def compute_unizero_mt_normalized_stats( + eval_returns: dict[int, float] +) -> tuple[Optional[float], Optional[float]]: + """ + 由 eval_returns 计算 Human-Normalized Mean 和 Median。 + 若暂无样本,返回 (None, None)。 + """ + normalized = [] + for tid, ret in eval_returns.items(): + if ret is None: + continue + denom = new_HUMAN_SCORES[tid] - new_RANDOM_SCORES[tid] + if denom == 0: + continue + normalized.append((ret - new_RANDOM_SCORES[tid]) / denom) + + if not normalized: + return None, None + arr = np.asarray(normalized, dtype=np.float32) + return float(arr.mean()), float(np.median(arr)) + +# 设置超时时间 (秒) +TIMEOUT = 12000 # 例如200分钟 + +timer = EasyTimer() + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int +) -> Tuple[Optional[bool], Optional[float]]: + """ + Safely执行评估任务,避免超时。 + + Args: + evaluator (Evaluator): 评估器实例。 + learner (BaseLearner): 学习器实例。 + collector (Collector): 
数据收集器实例。 + rank (int): 当前进程的rank。 + world_size (int): 总进程数。 + + Returns: + Tuple[Optional[bool], Optional[float]]: 如果评估成功,返回停止标志和奖励,否则返回(None, None)。 + """ + try: + print(f"=========评估开始 Rank {rank}/{world_size}===========") + # 重置 stop_event,确保每次评估前都处于未设置状态 + evaluator.stop_event.clear() + with concurrent.futures.ThreadPoolExecutor() as executor: + # 提交评估任务 + future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep) + try: + stop, reward = future.result(timeout=TIMEOUT) + except concurrent.futures.TimeoutError: + # 超时,设置 stop_event + evaluator.stop_event.set() + print(f"评估操作在 Rank {rank}/{world_size} 上超时,耗时 {TIMEOUT} 秒。") + return None, None + + print(f"======评估结束 Rank {rank}/{world_size}======") + return stop, reward + except Exception as e: + print(f"Rank {rank}/{world_size} 评估过程中发生错误: {e}") + return None, None + + +def allocate_batch_size( + cfgs: List[dict], + game_buffers, + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: + """ + 根据不同任务的收集剧集数反比分配batch_size,并动态调整batch_size范围以提高训练稳定性和效率。 + + Args: + cfgs (List[dict]): 每个任务的配置列表。 + game_buffers (List[GameBuffer]): 每个任务的重放缓冲区实例列表。 + alpha (float, optional): 控制反比程度的超参数。默认为1.0。 + clip_scale (int, optional): 动态调整的clip比例。默认为1。 + + Returns: + List[int]: 分配后的batch_size列表。 + """ + # 提取每个任务的 collected episodes 数量 + buffer_num_of_collected_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] + + # 获取当前的 world_size 和 rank + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # 收集所有 rank 的 collected episodes 列表 + all_task_num_of_collected_episodes = [None for _ in range(world_size)] + torch.distributed.all_gather_object(all_task_num_of_collected_episodes, buffer_num_of_collected_episodes) + + # 将所有 rank 的 collected episodes 合并为一个大列表 + all_task_num_of_collected_episodes = [ + episode for sublist in all_task_num_of_collected_episodes for episode in sublist + ] + if rank == 0: + print(f'所有任务的 collected episodes: {all_task_num_of_collected_episodes}') + + # 计算每个任务的反比权重 + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_num_of_collected_episodes]) + inv_sum = np.sum(inv_episodes) + + # 计算总的batch_size (所有任务 cfg.policy.batch_size 的和) + total_batch_size = cfgs[0].policy.total_batch_size + + # 动态调整的部分:最小和最大的 batch_size 范围 + avg_batch_size = total_batch_size / world_size + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale + + # 动态调整 alpha,让 batch_size 的变化更加平滑 + task_weights = (inv_episodes / inv_sum) ** alpha + batch_sizes = total_batch_size * task_weights + + # 控制 batch_size 在 [min_batch_size, max_batch_size] 之间 + batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + + # 确保 batch_size 是整数 + batch_sizes = [int(size) for size in batch_sizes] + + return batch_sizes + +import numpy as np + + +def symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog 归一化,减少目标值的幅度差异。 + symlog(x) = sign(x) * log(|x| + 1) + """ + return torch.sign(x) * torch.log(torch.abs(x) + 1) + +def inv_symlog(x: torch.Tensor) -> torch.Tensor: + """ + Symlog 的逆操作,用于恢复原始值。 + inv_symlog(x) = sign(x) * (exp(|x|) - 1) + """ + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + +# 全局最大值和最小值(用于 "run-max-min") +GLOBAL_MAX = -float('inf') +GLOBAL_MIN = float('inf') + +def compute_task_weights( + task_returns: dict, + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, # 是否使用 Softmax + reverse: bool = False, # 正比 (False) 或反比 (True) + clip_min: float = 
1e-2, # 权重的最小值 + clip_max: float = 1.0, # 权重的最大值 +) -> dict: + """ + 改进后的任务权重计算函数,支持多种标准化方式、Softmax 和正反比权重计算,并增加权重范围裁剪功能。 + + Args: + task_returns (dict): 每个任务的字典,键为 task_id,值为评估奖励或损失。 + option (str): 标准化方式,可选值为 "symlog", "max-min", "run-max-min", "rank", "none"。 + epsilon (float): 避免分母为零的小值。 + temperature (float): 控制权重分布的温度系数。 + use_softmax (bool): 是否使用 Softmax 进行权重分配。 + reverse (bool): 若为 True,权重与值反比;若为 False,权重与值正比。 + clip_min (float): 权重的最小值,用于裁剪。 + clip_max (float): 权重的最大值,用于裁剪。 + + Returns: + dict: 每个任务的权重,键为 task_id,值为归一化后的权重。 + """ + import torch + import torch.nn.functional as F + + global GLOBAL_MAX, GLOBAL_MIN + + # 如果输入为空字典,直接返回空结果 + if not task_returns: + return {} + + # Step 1: 对 task_returns 的值构造张量 + task_ids = list(task_returns.keys()) + returns_tensor = torch.tensor(list(task_returns.values()), dtype=torch.float32) + + if option == "symlog": + # 使用 symlog 标准化 + scaled_returns = symlog(returns_tensor) + elif option == "max-min": + # 使用最大最小值归一化 + max_reward = returns_tensor.max().item() + min_reward = returns_tensor.min().item() + scaled_returns = (returns_tensor - min_reward) / (max_reward - min_reward + epsilon) + elif option == "run-max-min": + # 使用全局最大最小值归一化 + GLOBAL_MAX = max(GLOBAL_MAX, returns_tensor.max().item()) + GLOBAL_MIN = min(GLOBAL_MIN, returns_tensor.min().item()) + scaled_returns = (returns_tensor - GLOBAL_MIN) / (GLOBAL_MAX - GLOBAL_MIN + epsilon) + elif option == "rank": + # 使用 rank 标准化 + # Rank 是基于值大小的排名,1 表示最小值,越大排名越高 + sorted_indices = torch.argsort(returns_tensor) + scaled_returns = torch.empty_like(returns_tensor) + rank_values = torch.arange(1, len(returns_tensor) + 1, dtype=torch.float32) # 1 到 N + scaled_returns[sorted_indices] = rank_values + elif option == "none": + # 不进行标准化 + scaled_returns = returns_tensor + else: + raise ValueError(f"Unsupported option: {option}") + + # Step 2: 根据 reverse 确定权重是正比还是反比 + if not reverse: + # 正比:权重与值正相关 + raw_weights = scaled_returns + else: + # 反比:权重与值负相关 + # 避免 scaled_returns 为负数或零 + scaled_returns = torch.clamp(scaled_returns, min=epsilon) + raw_weights = 1.0 / scaled_returns + + # Step 3: 根据是否使用 Softmax 进行权重计算 + if use_softmax: + # 使用 Softmax 进行权重分配 + beta = 1.0 / max(temperature, epsilon) # 确保 temperature 不为零 + logits = -beta * raw_weights + softmax_weights = F.softmax(logits, dim=0).numpy() + weights = dict(zip(task_ids, softmax_weights)) + else: + # 不使用 Softmax,直接计算权重 + # 温度缩放 + scaled_weights = raw_weights ** (1 / max(temperature, epsilon)) # 确保温度不为零 + + # 归一化权重 + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / total_weight + + # 转换为字典 + weights = dict(zip(task_ids, normalized_weights.numpy())) + + # Step 4: Clip 权重范围 + for task_id in weights: + weights[task_id] = max(min(weights[task_id], clip_max), clip_min) + + return weights + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari", + finetune_components=[] +) -> 'Policy': + """ + Overview: + UniZero的训练入口,旨在通过解决MuZero类算法在需要捕捉长期依赖环境中的局限性,提高强化学习代理的规划能力。 + 详细信息请参阅 https://arxiv.org/abs/2406.10667。 + + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): 不同任务的配置列表。 + - seed (:obj:`int`): 随机种子。 + - model (:obj:`Optional[torch.nn.Module]`): torch.nn.Module实例。 + - model_path (:obj:`Optional[str]`): 预训练模型路径,应指向预训练模型的ckpt文件。 + - max_train_iter 
(:obj:`Optional[int]`): 训练中的最大策略更新迭代次数。 + - max_env_step (:obj:`Optional[int]`): 最大收集环境交互步数。 + + Returns: + - policy (:obj:`Policy`): 收敛的策略。 + """ + + # --------------------------------------------------------------- + # ====== UniZero-MT 需要用到的基准分数(与 26 个 Atari100k 任务 id 一一对应)====== + # 原始的 RANDOM_SCORES 和 HUMAN_SCORES + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4, + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + # RANDOM_SCORES = np.array([ + # 148.0 + # ]) + # HUMAN_SCORES = np.array([ + # 1652.0 + # ]) + elif benchmark_name == "dmc": + # RANDOM_SCORES = np.array([0]*26) + # HUMAN_SCORES = np.array([1000]*26) + RANDOM_SCORES = np.zeros(26) + HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported BENCHMARK_NAME: {BENCHMARK_NAME}") + + # 新顺序对应的原始索引列表 + # 新顺序: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, + # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, + # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, + # PrivateEye, UpNDown, Qbert, Breakout] + # 映射为原始数组中的索引(注意:索引均从0开始) + new_order = [ + 20, # Pong + 19, # MsPacman + 24, # Seaquest + 6, # Boxing + 0, # Alien + 8, # ChopperCommand + 14, # Hero + 23, # RoadRunner + 1, # Amidar + 2, # Assault + 3, # Asterix + 4, # BankHeist + 5, # BattleZone + 9, # CrazyClimber + 10, # DemonAttack + 11, # Freeway + 12, # Frostbite + 13, # Gopher + 15, # Jamesbond + 16, # Kangaroo + 17, # Krull + 18, # KungFuMaster + 21, # PrivateEye + 25, # UpNDown + 22, # Qbert + 7 # Breakout + ] + global new_RANDOM_SCORES, new_HUMAN_SCORES + # 根据 new_order 生成新的数组 + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + # 查看重排后的结果 + print("重排后的 RANDOM_SCORES:") + print(new_RANDOM_SCORES) + print("\n重排后的 HUMAN_SCORES:") + print(new_HUMAN_SCORES) + # --------------------------------------------------------------- + + # 初始化温度调度器 + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) # 训练步数达到 10k 时,温度降至 1.0 + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' # 或 'exponential' + ) + + # 获取当前进程的rank和总进程数 + rank = get_rank() + world_size = get_world_size() + + # 任务划分 + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # 确保至少有一个任务 + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: 未分配任务,继续执行。") + # 初始化空列表以避免后续代码报错 + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, 处理任务 {start_idx} 到 {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # 使用第一个任务的配置创建共享的policy + task_id, 
[cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = tasks_per_rank + + # 确保指定的策略类型受支持 + assert create_cfg.policy.type in ['unizero_multitask', + 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + + # 根据CUDA可用性设置设备 + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'配置的设备: {cfg.policy.device}') + + # 编译配置 + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # 创建共享的policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # 加载预训练模型(如果提供) + if model_path is not None: + logging.info(f'开始加载模型: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device),finetune_components=finetune_components) + logging.info(f'完成加载模型: {model_path}') + + # 创建TensorBoard日志记录器 + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # 创建共享的learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + + # 处理当前进程分配到的每个任务 + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # 设置每个任务的随机种子 + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # 创建环境 + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # 创建不同的game buffer、collector和evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + # 调用learner的before_run钩子 + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + # use_task_exploitation_weight = 
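+    # Illustrative (hypothetical) fine-tuning invocation of this entry, showing how
+    # `model_path` and `finetune_components` are meant to be passed together; the ckpt
+    # path below is a placeholder, not a file shipped with this patch:
+    #
+    #   train_unizero_multitask_segment_ddp(
+    #       input_cfg_list,
+    #       seed=0,
+    #       model_path='/path/to/pretrained_unizero_mt.pth.tar',
+    #       finetune_components=['transformer'],  # also unfreeze the transformer backbone
+    #   )
+    #
+    # With an empty `finetune_components` list, the encoder, representation network and
+    # transformer stay frozen (LoRA parameters are never frozen by `_load_state_dict_learn`),
+    # so only the task heads and any LoRA adapters are updated.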
cfg.policy.use_task_exploitation_weight + task_exploitation_weight = None + + # 创建任务奖励字典 + task_returns = {} # {task_id: reward} + + while True: + # 动态调整batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("分配后的 batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 对于当前进程的每个任务,进行数据收集和评估 + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # 记录缓冲区内存使用情况 + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # 默认的epsilon值 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # 判断是否需要进行评估 + # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0 : + # if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): # only for debug + # if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') + + # =========TODO========= + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # 执行安全评估 + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # 判断评估是否成功 + if stop is None or reward is None: + print(f"Rank {rank} 在评估过程中遇到问题,继续训练...") + task_returns[cfg.policy.task_id] = float('inf') # 如果评估失败,将任务难度设为最大值 + else: + # 确保从评估结果中提取 `eval_episode_return_mean` 作为奖励值 + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + print(f"任务 {cfg.policy.task_id} 的评估奖励: {eval_mean_reward}") + task_returns[cfg.policy.task_id] = eval_mean_reward + except Exception as e: + print(f"提取评估奖励时发生错误: {e}") + task_returns[cfg.policy.task_id] = float('inf') # 出现问题时,将奖励设为最大值 + + + print('=' * 20) + print(f'开始收集 Rank {rank} 的任务_id: {cfg.policy.task_id}...') + print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + + + # while replay_buffer.get_num_of_transitions() < cfg.policy.batch_size[cfg.policy.task_id]: + # for ddp training, 避免后面 train 时replay buffer中样本小于batch size 导致ddp hangs + + # 在每次收集之前重置初始数据,这对于多任务设置非常重要 + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + # 收集数据 + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # 更新重放缓冲区 + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + + # # ===== only for debug ===== + # if train_epoch > 2: + # with timer: + # replay_buffer.reanalyze_buffer(2, policy) + # buffer_reanalyze_count += 1 + # logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + # logging.info(f'缓冲区重新分析耗时: {timer.value}') + # # ===== only for debug ===== + + + 
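+            # Clarifying note on the reanalyze schedule implemented below: `buffer_reanalyze_freq`
+            # is interpreted in two regimes, and in both cases a reanalyze pass only runs when the
+            # buffer holds more than `reanalyze_batch_size / reanalyze_partition` unrolled samples.
+            #   * freq >= 1: reanalyze every `update_per_collect // buffer_reanalyze_freq` gradient
+            #     steps inside the training loop (e.g. update_per_collect=80, freq=2 -> every 40 steps).
+            #   * freq < 1: reanalyze every `int(1 / freq)` collect epochs
+            #     (e.g. freq=0.1 -> once every 10 collections). The example numbers are illustrative.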
# 周期性地重新分析缓冲区 + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + logging.info(f'缓冲区重新分析耗时: {timer.value}') + + # 数据收集结束后添加日志 + logging.info(f'Rank {rank}: 完成任务 {cfg.policy.task_id} 的数据收集') + + # 检查是否有足够的数据进行训练 + not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + + print(f"not_enough_data:{not_enough_data}") + # 获取当前温度 + current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) + + # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0 : + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0 : + + # 计算任务权重 + try: + # 汇聚任务奖励 + dist.barrier() + # if cfg.policy.task_complexity_weight: + all_task_returns = [None for _ in range(world_size)] + dist.all_gather_object(all_task_returns, task_returns) + # 合并任务奖励 + merged_task_returns = {} + for returns in all_task_returns: + if returns: + merged_task_returns.update(returns) + + logging.warning(f"Rank {rank}: merged_task_returns: {merged_task_returns}") + + # 计算全局任务权重 + task_weights = compute_task_weights(merged_task_returns, temperature=current_temperature_task_weight) + + # ---------- 维护 UniZero-MT 全局评估结果 ---------- + for tid, ret in merged_task_returns.items(): + GLOBAL_EVAL_RETURNS[tid] = ret # solved 的任务同样更新 + + # 计算 Human-Normalized Mean / Median + uni_mean, uni_median = compute_unizero_mt_normalized_stats(GLOBAL_EVAL_RETURNS) + + if uni_mean is not None: # 至少评估过 1 个任务 + if rank == 0: # 仅在 rank0 写 TensorBoard,防止重复 + tb_logger.add_scalar('UniZero-MT/NormalizedMean', uni_mean, global_step=learner.train_iter) + tb_logger.add_scalar('UniZero-MT/NormalizedMedian', uni_median, global_step=learner.train_iter) + logging.info(f"Rank {rank}: UniZero-MT Norm Mean={uni_mean:.4f}, Median={uni_median:.4f}") + else: + logging.info(f"Rank {rank}: 暂无数据计算 UniZero-MT 归一化指标") + + # 同步任务权重 + dist.broadcast_object_list([task_weights], src=0) + # print(f"rank{rank}, 全局任务权重 (按 task_id 排列): {task_weights}") + # else: + # task_weights = None + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + break + + + # ---------------- 采样完成,准备进入反向 ---------------- + # if dist.is_available() and dist.is_initialized(): + # dist.barrier() # ★★★ 关键同步 ★★★ + + # 学习策略 + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'缓冲区重新分析次数: {buffer_reanalyze_count}') + 
logging.info(f'缓冲区重新分析耗时: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # 追加task_id以区分任务 + train_data_multi_task.append(train_data) + else: + logging.warning( + f'重放缓冲区中的数据不足以采样mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + # learn_kwargs = {'task_exploitation_weight':task_exploitation_weight, 'task_weights':task_weights, } + # learn_kwargs = {'task_weights': task_weights, } + # learn_kwargs = {'task_weights':task_exploitation_weight} + + learn_kwargs = {'task_weights': None,} + # logging.info(f'Rank {rank}: iter {i} one learn step start') + + # 在训练时,DDP会自动同步梯度和参数 + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # logging.error(f'Rank {rank}: one learn step done') + + # 判断是否需要计算task_exploitation_weight + if i == 0: + # 计算任务权重 + try: + dist.barrier() # 等待所有进程同步 + if cfg.policy.use_task_exploitation_weight: # use obs loss now, new polish + # 收集所有任务的 obs_loss + all_obs_loss = [None for _ in range(world_size)] + # 构建当前进程的任务 obs_loss 数据 + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}'] + # 汇聚所有进程的 obs_loss 数据 + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + # 合并所有进程的 obs_loss 数据 + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + # 计算全局任务权重 + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + # temperature=current_temperature_task_weight # TODO + temperature=1, + ) + # 广播任务权重到所有进程 + dist.broadcast_object_list([task_exploitation_weight], src=0) + print(f"rank{rank}, task_exploitation_weight (按 task_id 排列): {task_exploitation_weight}") + else: + logging.warning(f"Rank {rank}: 未能计算全局 obs_loss 任务权重,obs_loss 数据为空。") + task_exploitation_weight = None + else: + task_exploitation_weight = None + # 更新训练参数,使其包含计算后的任务权重 + learn_kwargs['task_weight'] = task_exploitation_weight + except Exception as e: + logging.error(f'Rank {rank}: 同步任务权重失败,错误: {e}') + raise e # 保留异常抛出,便于外部捕获和分析 + + + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + # 更新任务特定的重放缓冲区优先级 + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # 平滑因子 + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # 使用运行均值计算归一化的优先级 + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # 如果需要,可以将归一化的优先级存储回重放缓冲区 + # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities) + + # 记录优先级统计信息 + if cfg.policy.print_task_priority_logs: + print(f"任务 {task_id} - 平均优先级: 
{mean_priority:.8f}, " + f"运行平均优先级: {running_mean_priority:.8f}, " + f"标准差: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 同步所有Rank,确保所有Rank完成训练 + try: + dist.barrier() + logging.info(f'Rank {rank}: 通过训练后的同步障碍') + except Exception as e: + logging.error(f'Rank {rank}: 同步障碍失败,错误: {e}') + break + + # 检查是否需要终止训练 + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # 收集所有进程的train_iter + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: 达到终止条件') + dist.barrier() # 确保所有进程同步 + break + except Exception as e: + logging.error(f'Rank {rank}: 终止检查失败,错误: {e}') + break + + # 调用learner的after_run钩子 + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py index 3fdcfa099..71dfda006 100644 --- a/lzero/entry/train_unizero_multitask_segment_ddp.py +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -367,7 +367,8 @@ def train_unizero_multitask_segment_ddp( model_path: Optional[str] = None, max_train_iter: Optional[int] = int(1e10), max_env_step: Optional[int] = int(1e10), - benchmark_name: str = "atari" + benchmark_name: str = "atari", + finetune_components=[] ) -> 'Policy': """ Overview: @@ -391,15 +392,17 @@ def train_unizero_multitask_segment_ddp( # 原始的 RANDOM_SCORES 和 HUMAN_SCORES if benchmark_name == "atari": RANDOM_SCORES = np.array([ - 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, - 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, - -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + 148.0 # SpaceInvader ]) HUMAN_SCORES = np.array([ - 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, - 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, - 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + 1652.0 # SpaceInvader ]) + # RANDOM_SCORES = np.array([ + # 148.0 + # ]) + # HUMAN_SCORES = np.array([ + # 1652.0 + # ]) elif benchmark_name == "dmc": # RANDOM_SCORES = np.array([0]*26) # HUMAN_SCORES = np.array([1000]*26) @@ -415,33 +418,8 @@ def train_unizero_multitask_segment_ddp( # PrivateEye, UpNDown, Qbert, Breakout] # 映射为原始数组中的索引(注意:索引均从0开始) new_order = [ - 20, # Pong - 19, # MsPacman - 24, # Seaquest - 6, # Boxing - 0, # Alien - 8, # ChopperCommand - 14, # Hero - 23, # RoadRunner - 1, # Amidar - 2, # Assault - 3, # Asterix - 4, # BankHeist - 5, # BattleZone - 9, # CrazyClimber - 10, # DemonAttack - 11, # Freeway - 12, # Frostbite - 13, # Gopher - 15, # Jamesbond - 16, # Kangaroo - 17, # Krull - 18, # KungFuMaster - 21, # PrivateEye - 25, # UpNDown - 22, # Qbert - 7 # Breakout - ] + 0 # SpaceInvader (唯一任务,索引为0) + ] global new_RANDOM_SCORES, new_HUMAN_SCORES # 根据 new_order 生成新的数组 new_RANDOM_SCORES = RANDOM_SCORES[new_order] @@ -521,12 +499,13 @@ def train_unizero_multitask_segment_ddp( # 
编译配置 cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) # 创建共享的policy + cfg.policy.learn.learner.hook.log_show_after_iter=100 policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # 加载预训练模型(如果提供) if model_path is not None: logging.info(f'开始加载模型: {model_path}') - policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device),finetune_components=finetune_components) logging.info(f'完成加载模型: {model_path}') # 创建TensorBoard日志记录器 diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index 9a24d8dfb..6d40b748e 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -123,33 +123,27 @@ def __init__( final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, )) elif world_model_cfg.encoder_type == "vit": + + lora_config={ + 'r': world_model_cfg.get('encoder_lora_r', 0), + 'alpha': world_model_cfg.get('encoder_lora_alpha', 0), + 'dropout': world_model_cfg.get('encoder_lora_dropout', 0), + } + for task_id in range(1): # TODO: one share encoder if world_model_cfg.task_num <=8: - # # vit base - # self.representation_network.append(ViT( - # image_size =observation_shape[1], - # patch_size = 8, - # num_classes = obs_act_embed_dim, - # dim = 768, - # depth = 12, - # heads = 12, - # mlp_dim = 3072, - # dropout = 0.1, - # emb_dropout = 0.1, - # final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, - # )) - # vit small self.representation_network.append(ViT( image_size =observation_shape[1], patch_size = 8, num_classes = obs_act_embed_dim, dim = 768, - depth = 6, - heads = 6, - mlp_dim = 2048, + depth = 12, + heads = 12, + mlp_dim = 3072, dropout = 0.1, emb_dropout = 0.1, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + lora_config=lora_config )) elif world_model_cfg.task_num > 8: # vit base diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index ad1265007..44a3035e2 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -39,6 +39,57 @@ def __init__(self, init=1.0, s_max=1.5): def forward(self): return self.s_max * torch.sigmoid(self.logit) +############################################## +# LoRALinear 实现 +############################################## + +class LoRALinear(nn.Module): + """ + 基础的LoRALinear实现,对标准线性层进行LoRA微调扩展。 + + - 保留原始的weight和bias参数 + - 添加LoRA的A和B矩阵进行低秩分解 + - 前向计算: output = F.linear(x, W, bias) + scaling * lora_B(lora_A(dropout(x))) + """ + def __init__(self, in_features: int, out_features: int, bias: bool = True, + r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r + self.lora_alpha = lora_alpha + self.scaling = lora_alpha / r if r > 0 else 1.0 + self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + + # 初始化基础权重 + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.bias = None + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if bias: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + # 
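+        # Explanatory note on the LoRA initialisation just below: `lora_A` is small random and
+        # `lora_B` is all zeros, so the low-rank update `scaling * lora_B(lora_A(x))` is exactly
+        # zero when fine-tuning starts and the wrapped layer initially reproduces the pretrained
+        # Linear. For scale, r=16 on a 768->768 projection adds 2 * 16 * 768 = 24,576 LoRA
+        # parameters next to 768 * 768 = 589,824 frozen base weights (dimensions are illustrative).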
初始化LoRA参数 + if r > 0: + self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01) + self.lora_B = nn.Parameter(torch.zeros(out_features, r)) + else: + self.lora_A = None + self.lora_B = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + baseline_out = F.linear(x, self.weight, self.bias) + if self.r == 0 or self.lora_A is None or self.lora_B is None: + return baseline_out + + lora_out = F.linear(self.lora_dropout(x), self.lora_A) + lora_out = F.linear(lora_out, self.lora_B) + return baseline_out + self.scaling * lora_out + ############################################## # CurriculumLoRALinear 实现 ############################################## @@ -132,9 +183,10 @@ def set_curriculum_stage(self, stage: int): 同时将 log 出模块信息和状态变化。 """ + # return assert 0 <= stage < self.curriculum_stage_num, f"stage 必须在 [0, {self.curriculum_stage_num-1}] 范围内" self.curriculum_stage = stage - + # 输出 log 信息,展示当前模块(可结合 in_features, out_features 标识) module_id = f"({self.in_features}x{self.out_features})" if stage == 0: @@ -202,7 +254,7 @@ def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Modul - 并且 config 中配置了 curriculum_stage_num > 1 否则,若仅满足基础 LoRA 条件,则返回原有 LoRALinear;否则返回原始的线性层。 """ - if config.lora_r > 0 and (module_label in config.lora_target_modules) and getattr(config, "curriculum_stage_num", 1) > 1: + if False and config.lora_r > 0 and (module_label in config.lora_target_modules) and getattr(config, "curriculum_stage_num", 1) > 1: new_linear = CurriculumLoRALinear( in_features=linear.in_features, out_features=linear.out_features, @@ -217,20 +269,20 @@ def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Modul if linear.bias is not None: new_linear.bias.data.copy_(linear.bias.data) return new_linear - # elif config.lora_r > 0 and (module_label in config.lora_target_modules): - # # 若不使用课程学习,则调用原有 LoRALinear 实现(未展示,此处假设其已定义) - # new_linear = LoRALinear( - # in_features=linear.in_features, - # out_features=linear.out_features, - # bias=(linear.bias is not None), - # r=config.lora_r, - # lora_alpha=config.lora_alpha, - # lora_dropout=config.lora_dropout - # ) - # new_linear.weight.data.copy_(linear.weight.data) - # if linear.bias is not None: - # new_linear.bias.data.copy_(linear.bias.data) - # return new_linear + elif config.lora_r > 0 and (module_label in config.lora_target_modules): + # 若不使用课程学习,则调用原有 LoRALinear 实现(未展示,此处假设其已定义) + new_linear = LoRALinear( + in_features=linear.in_features, + out_features=linear.out_features, + bias=(linear.bias is not None), + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout + ) + new_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + new_linear.bias.data.copy_(linear.bias.data) + return new_linear else: return linear @@ -346,7 +398,13 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None: else: self.use_register_token = False # TODO - + + # if config.lora_r > 0: + # set_curriculum_stage_for_transformer(self,) +# # set_curriculum_stage_for_transformer( +# self.policy._learn_model.world_model.transformer, +# self.stage +# ) def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: """ diff --git a/lzero/model/vit.py b/lzero/model/vit.py index 7009735b0..858dd4d90 100644 --- a/lzero/model/vit.py +++ b/lzero/model/vit.py @@ -27,7 +27,7 @@ def forward(self, x): return self.net(x) class Attention(nn.Module): - def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + def __init__(self, dim, heads = 8, 
dim_head = 64, dropout = 0., lora_config=None): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) @@ -46,13 +46,57 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() + + # LoRA 配置 + if lora_config is None: + lora_config = {} + + self.lora_r = lora_config.get('r', 0) + self.lora_alpha = lora_config.get('alpha', 1) + self.lora_dropout_p = lora_config.get('dropout', 0.0) + self.use_lora = self.lora_r > 0 + + # LoRA 参数(如果启用) + if self.use_lora: + self.scaling = self.lora_alpha / self.lora_r + self.lora_dropout = nn.Dropout(self.lora_dropout_p) + + # 为 q、k、v 分别创建 LoRA 参数 + self.lora_A_q = nn.Parameter(torch.randn(self.lora_r, dim) * 0.01) + self.lora_B_q = nn.Parameter(torch.zeros(inner_dim, self.lora_r)) + + self.lora_A_k = nn.Parameter(torch.randn(self.lora_r, dim) * 0.01) + self.lora_B_k = nn.Parameter(torch.zeros(inner_dim, self.lora_r)) + + self.lora_A_v = nn.Parameter(torch.randn(self.lora_r, dim) * 0.01) + self.lora_B_v = nn.Parameter(torch.zeros(inner_dim, self.lora_r)) def forward(self, x): x = self.norm(x) + # 原有的预训练路径:获得 q、k、v qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + # 如果启用了 LoRA,添加 LoRA 贡献 + if self.use_lora: + x_dropped = self.lora_dropout(x) + + # 计算每个分量的 LoRA 贡献 + lora_q = (x_dropped @ self.lora_A_q.T) @ self.lora_B_q.T # (b, n, inner_dim) + lora_k = (x_dropped @ self.lora_A_k.T) @ self.lora_B_k.T # (b, n, inner_dim) + lora_v = (x_dropped @ self.lora_A_v.T) @ self.lora_B_v.T # (b, n, inner_dim) + + # 重排成多头格式:(b, n, inner_dim) -> (b, h, n, d) + lora_q = rearrange(lora_q, 'b n (h d) -> b h n d', h = self.heads) + lora_k = rearrange(lora_k, 'b n (h d) -> b h n d', h = self.heads) + lora_v = rearrange(lora_v, 'b n (h d) -> b h n d', h = self.heads) + + # 加到对应的 q、k、v 上 + q = q + self.scaling * lora_q + k = k + self.scaling * lora_k + v = v + self.scaling * lora_v + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) @@ -63,13 +107,13 @@ def forward(self, x): return self.to_out(out) class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., lora_config=None): super().__init__() self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ - Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), + Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, lora_config=lora_config), FeedForward(dim, mlp_dim, dropout = dropout) ])) @@ -81,7 +125,7 @@ def forward(self, x): return self.norm(x) class ViT(nn.Module): - def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., final_norm_option_in_encoder='SimNorm'): + def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., final_norm_option_in_encoder='SimNorm', lora_config=None): super().__init__() image_height, image_width = pair(image_size) patch_height, patch_width = pair(patch_size) @@ -103,7 +147,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) - 
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, lora_config=lora_config) self.pool = pool self.last_linear = nn.Linear(dim, num_classes) @@ -143,32 +187,97 @@ def forward(self, img): import random torch.manual_seed(42) random.seed(42) + + # 创建一个带 LoRA 的模型 + print("=== 创建 ViT with LoRA ===") + lora_config = { + 'r': 16, + 'alpha': 32, + 'dropout': 0.1 + } model = ViT( image_size = 64, patch_size = 8, - num_classes =768, + num_classes = 768, dim = 768, depth = 12, heads = 12, mlp_dim = 3072, dropout = 0.1, emb_dropout = 0.1, - final_norm_option_in_encoder="LayerNorm" + final_norm_option_in_encoder="LayerNorm", + lora_config=lora_config ) model = model.cuda() if torch.cuda.is_available() else model + model.eval() dummy = torch.randn(256,3,64,64).to(next(model.parameters()).device) + + print("Total param count:", sum(p.numel() for p in model.parameters())) + + # 统计 LoRA 参数数量 + lora_params = 0 + for name, param in model.named_parameters(): + if 'lora_A' in name or 'lora_B' in name: + lora_params += param.numel() + print("LoRA-only param count:", lora_params) + + # 测试关闭 LoRA (use_lora=False) + print("\n=== 测试关闭 LoRA ===") + for module in model.modules(): + if hasattr(module, 'use_lora'): + module.use_lora = False + with torch.no_grad(): - out = model(dummy) - print("Output shape:", out.shape) # => (10, 768) - print("output[0]", out[0][:50]) # => (1, 50) - - # 简单基准 - import time, contextlib + out_no_lora = model(dummy) + print("No LoRA output shape:", out_no_lora.shape) + + # 测试开启 LoRA (use_lora=True) + print("\n=== 测试开启 LoRA ===") + for module in model.modules(): + if hasattr(module, 'use_lora'): + module.use_lora = True + + with torch.no_grad(): + out_with_lora = model(dummy) + print("With LoRA output shape:", out_with_lora.shape) + + # 验证初始时两个输出相近(因为 LoRA 的 B 矩阵初始化为 0) + print("\nOutput difference (should be very small initially):") + print("Max diff:", torch.max(torch.abs(out_no_lora - out_with_lora)).item()) + print("Mean diff:", torch.mean(torch.abs(out_no_lora - out_with_lora)).item()) + + # 简单基准测试 + print("\n=== 性能测试 ===") + import time warm, rep = 5, 20 - for _ in range(warm): out = model(dummy) + + # 关闭 LoRA 测试 + for module in model.modules(): + if hasattr(module, 'use_lora'): + module.use_lora = False + + for _ in range(warm): + with torch.no_grad(): out = model(dummy) + torch.cuda.synchronize() if torch.cuda.is_available() else None + t0=time.time() + for _ in range(rep): + with torch.no_grad(): out = model(dummy) + torch.cuda.synchronize() if torch.cuda.is_available() else None + no_lora_time = (time.time()-t0)/rep*1000 + print(f"No LoRA latency: {no_lora_time:.2f} ms") + + # 开启 LoRA 测试 + for module in model.modules(): + if hasattr(module, 'use_lora'): + module.use_lora = True + + for _ in range(warm): + with torch.no_grad(): out = model(dummy) torch.cuda.synchronize() if torch.cuda.is_available() else None t0=time.time() for _ in range(rep): - out = model(dummy) + with torch.no_grad(): out = model(dummy) torch.cuda.synchronize() if torch.cuda.is_available() else None - print(f"Average latency: {(time.time()-t0)/rep*1000:.2f} ms") + lora_time = (time.time()-t0)/rep*1000 + print(f"With LoRA latency: {lora_time:.2f} ms") + print(f"LoRA Overhead: {((lora_time-no_lora_time)/no_lora_time*100):.1f}%") diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 031afb9fd..52b86fb06 100644 --- a/lzero/policy/unizero_multitask.py +++ 
b/lzero/policy/unizero_multitask.py @@ -49,7 +49,7 @@ def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: # from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect # from LibMTL.weighting.abstract_weighting import AbsWeighting - +a=0 def generate_task_loss_dict(multi_task_losses, task_name_template, task_id): """ @@ -140,7 +140,7 @@ class UniZeroMTPolicy(UniZeroPolicy): by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. """ - + a=0 # The default_config for UniZero policy. config = dict( type='unizero_multitask', @@ -548,7 +548,289 @@ def _retain_prev_if_zero(self, name: str, self._prev_plasticity_metrics[name] = value return value - + def print_traninable_parameter(self, model): + """ + 打印模型的可训练参数,以树状结构显示 + """ + print("=" * 80) + print("TRAINABLE PARAMETERS TREE STRUCTURE") + print("=" * 80) + + # 统计信息 + total_trainable_params = 0 + total_trainable_size = 0 + + # 按模块组织参数 + module_params = {} + + for name, param in model.named_parameters(): + if param.requires_grad: + # 解析模块层次结构 + parts = name.split('.') + current_dict = module_params + + # 构建嵌套字典结构 + for i, part in enumerate(parts[:-1]): + if part not in current_dict: + current_dict[part] = {} + current_dict = current_dict[part] + + # 最后一层存储参数信息 + param_name = parts[-1] + param_info = { + 'shape': list(param.shape), + 'numel': param.numel(), + 'dtype': str(param.dtype), + 'device': str(param.device), + 'requires_grad': param.requires_grad + } + current_dict[param_name] = param_info + + total_trainable_params += param.numel() + total_trainable_size += param.numel() * param.element_size() + + # 递归打印树状结构 + def print_tree(tree, prefix="", is_last=True, level=0): + nonlocal total_trainable_params + + items = list(tree.items()) + for i, (key, value) in enumerate(items): + is_last_item = (i == len(items) - 1) + + # 选择合适的前缀符号 + if level == 0: + current_prefix = "" + next_prefix = "" + else: + current_prefix = prefix + ("└── " if is_last_item else "├── ") + next_prefix = prefix + (" " if is_last_item else "│ ") + + if isinstance(value, dict) and any(isinstance(v, dict) for v in value.values()): + # 这是一个模块节点 + print(f"{current_prefix}{key}/") + print_tree(value, next_prefix, is_last_item, level + 1) + elif isinstance(value, dict): + # 这是参数信息 + shape_str = "x".join(map(str, value['shape'])) if value['shape'] else "scalar" + size_mb = value['numel'] * 4 / (1024 * 1024) # 假设float32 + print(f"{current_prefix}{key}: {shape_str} ({value['numel']:,} params, {size_mb:.2f}MB) [{value['dtype']}]") + else: + # 这是一个叶子节点的模块 + print(f"{current_prefix}{key}/") + if isinstance(value, dict): + print_tree(value, next_prefix, is_last_item, level + 1) + + # 打印树状结构 + print_tree(module_params) + + # 打印汇总信息 + print("=" * 80) + print("SUMMARY:") + print(f"Total trainable parameters: {total_trainable_params:,}") + print(f"Total trainable size: {total_trainable_size / (1024 * 1024):.2f} MB") + print(f"Total trainable size: {total_trainable_size / (1024 * 1024 * 1024):.4f} GB") + print("=" * 80) + + def print_frozen_parameter(self, model): + """ + 打印模型的冻结参数,以树状结构显示,使用蓝色字体 + """ + # ANSI颜色代码 + BLUE = '\033[94m' + RESET = '\033[0m' + + print(BLUE + "=" * 80 + RESET) + print(BLUE + "FROZEN PARAMETERS TREE STRUCTURE" + RESET) + print(BLUE + "=" * 80 + RESET) + + # 统计信息 + total_frozen_params = 0 + total_frozen_size = 0 + + # 按模块组织参数 + module_params = {} + + for name, param in 
model.named_parameters(): + if not param.requires_grad: # 只处理冻结参数(不可训练) + # 解析模块层次结构 + parts = name.split('.') + current_dict = module_params + + # 构建嵌套字典结构 + for i, part in enumerate(parts[:-1]): + if part not in current_dict: + current_dict[part] = {} + current_dict = current_dict[part] + + # 最后一层存储参数信息 + param_name = parts[-1] + param_info = { + 'shape': list(param.shape), + 'numel': param.numel(), + 'dtype': str(param.dtype), + 'device': str(param.device), + 'requires_grad': param.requires_grad + } + current_dict[param_name] = param_info + + total_frozen_params += param.numel() + total_frozen_size += param.numel() * param.element_size() + + # 递归打印树状结构 + def print_tree(tree, prefix="", is_last=True, level=0): + items = list(tree.items()) + for i, (key, value) in enumerate(items): + is_last_item = (i == len(items) - 1) + + # 选择合适的前缀符号 + if level == 0: + current_prefix = "" + next_prefix = "" + else: + current_prefix = prefix + ("└── " if is_last_item else "├── ") + next_prefix = prefix + (" " if is_last_item else "│ ") + + if isinstance(value, dict) and any(isinstance(v, dict) for v in value.values()): + # 这是一个模块节点 + print(BLUE + f"{current_prefix}{key}/" + RESET) + print_tree(value, next_prefix, is_last_item, level + 1) + elif isinstance(value, dict): + # 这是参数信息 + shape_str = "x".join(map(str, value['shape'])) if value['shape'] else "scalar" + size_mb = value['numel'] * 4 / (1024 * 1024) # 假设float32,与原函数保持一致 + print(BLUE + f"{current_prefix}{key}: {shape_str} ({value['numel']:,} params, {size_mb:.2f}MB) [{value['dtype']}]" + RESET) + else: + # 这是一个叶子节点的模块 + print(BLUE + f"{current_prefix}{key}/" + RESET) + if isinstance(value, dict): + print_tree(value, next_prefix, is_last_item, level + 1) + + # 检查是否有冻结参数 + if not module_params: + print(BLUE + "No frozen parameters found in the model." 
+ RESET) + print(BLUE + "=" * 80 + RESET) + return + + # 打印树状结构 + print_tree(module_params) + + # 打印汇总信息 + print(BLUE + "=" * 80 + RESET) + print(BLUE + "SUMMARY:" + RESET) + print(BLUE + f"Total frozen parameters: {total_frozen_params:,}" + RESET) + print(BLUE + f"Total frozen size: {total_frozen_size / (1024 * 1024):.2f} MB" + RESET) + print(BLUE + f"Total frozen size: {total_frozen_size / (1024 * 1024 * 1024):.4f} GB" + RESET) + print(BLUE + "=" * 80 + RESET) + + def print_lora_pos(self, model): + """ + 打印模型中所有的LoRA参数,以树状结构显示 + 如果参数是冻结的则用蓝色字体,非冻结的用正常字体 + """ + # ANSI颜色代码 + BLUE = '\033[94m' + RESET = '\033[0m' + + print("=" * 80) + print("LORA PARAMETERS TREE STRUCTURE") + print("=" * 80) + + # 统计信息 + total_lora_params = 0 + total_lora_size = 0 + frozen_lora_params = 0 + trainable_lora_params = 0 + + # 按模块组织参数 + module_params = {} + + for name, param in model.named_parameters(): + if "lora" in name.lower(): # 只处理包含lora的参数 + # 解析模块层次结构 + parts = name.split('.') + current_dict = module_params + + # 构建嵌套字典结构 + for i, part in enumerate(parts[:-1]): + if part not in current_dict: + current_dict[part] = {} + current_dict = current_dict[part] + + # 最后一层存储参数信息 + param_name = parts[-1] + param_info = { + 'shape': list(param.shape), + 'numel': param.numel(), + 'dtype': str(param.dtype), + 'device': str(param.device), + 'requires_grad': param.requires_grad, + 'is_frozen': not param.requires_grad + } + current_dict[param_name] = param_info + + total_lora_params += param.numel() + total_lora_size += param.numel() * param.element_size() + + if param.requires_grad: + trainable_lora_params += param.numel() + else: + frozen_lora_params += param.numel() + + # 递归打印树状结构 + def print_tree(tree, prefix="", is_last=True, level=0): + items = list(tree.items()) + for i, (key, value) in enumerate(items): + is_last_item = (i == len(items) - 1) + + # 选择合适的前缀符号 + if level == 0: + current_prefix = "" + next_prefix = "" + else: + current_prefix = prefix + ("└── " if is_last_item else "├── ") + next_prefix = prefix + (" " if is_last_item else "│ ") + + if isinstance(value, dict) and any(isinstance(v, dict) for v in value.values()): + # 这是一个模块节点 + print(f"{current_prefix}{key}/") + print_tree(value, next_prefix, is_last_item, level + 1) + elif isinstance(value, dict): + # 这是参数信息,根据冻结状态选择颜色 + shape_str = "x".join(map(str, value['shape'])) if value['shape'] else "scalar" + size_mb = value['numel'] * 4 / (1024 * 1024) # 假设float32 + status_str = "FROZEN" if value['is_frozen'] else "TRAINABLE" + + if value['is_frozen']: + # 冻结参数用蓝色 + print(BLUE + f"{current_prefix}{key}: {shape_str} ({value['numel']:,} params, {size_mb:.2f}MB) [{value['dtype']}] [{status_str}]" + RESET) + else: + # 非冻结参数用正常颜色 + print(f"{current_prefix}{key}: {shape_str} ({value['numel']:,} params, {size_mb:.2f}MB) [{value['dtype']}] [{status_str}]") + else: + # 这是一个叶子节点的模块 + print(f"{current_prefix}{key}/") + if isinstance(value, dict): + print_tree(value, next_prefix, is_last_item, level + 1) + + # 检查是否有LoRA参数 + if not module_params: + print("No LoRA parameters found in the model.") + print("=" * 80) + return + + # 打印树状结构 + print_tree(module_params) + + # 打印汇总信息 + print("=" * 80) + print("SUMMARY:") + print(f"Total LoRA parameters: {total_lora_params:,}") + print(f"├─ " + BLUE + f"Frozen LoRA parameters: {frozen_lora_params:,}" + RESET) + print(f"└─ Trainable LoRA parameters: {trainable_lora_params:,}") + print(f"Total LoRA size: {total_lora_size / (1024 * 1024):.2f} MB") + print(f"Total LoRA size: {total_lora_size / (1024 * 1024 * 1024):.4f} GB") + 
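+        # Cross-check (illustrative): the totals above should match the one-liner
+        #   sum(p.numel() for n, p in model.named_parameters() if 'lora' in n.lower())
+        # for the LoRA parameter count, with the same expression filtered by `p.requires_grad`
+        # giving the trainable / frozen split.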
print("=" * 80) + #@profile def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_grad=False) -> Dict[str, Union[float, int]]: """ @@ -566,6 +848,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr self._learn_model.train() self._target_model.train() + # print model trainnable parameter self.print_traninable_parameter(self._target_model) + + # self.print_traninable_parameter(self._learn_model) + # self.print_frozen_parameter(self._learn_model) + + # self.print_lora_pos(self._learn_model) + # if a==0: + # self.print_traninable_parameter(self._learn_model) + # a+=1 + obs_loss_multi_task = [] reward_loss_multi_task = [] policy_loss_multi_task = [] @@ -1489,8 +1781,90 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components """ # finetune_components = [] # load-enc-trans_finetune-head # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head - finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head - + # finetune_components = ["representation_network", "encoder",'transformer'] # load-enc-trans_finetune-encoder-head + + # 打印state_dict中所有的lora 参数,以树的结构 + # def print_lora_from_state_dict(state_dict_part: Dict[str, Any], title: str): + # """ + # 打印state_dict中的LoRA参数,以树状结构显示 + # """ + # print("=" * 60) + # print(f"LORA PARAMETERS IN {title.upper()}") + # print("=" * 60) + + # # 收集所有LoRA参数 + # lora_params = {} + # total_lora_count = 0 + # total_lora_size = 0 + + # for key, value in state_dict_part.items(): + # if "lora" in key.lower(): + # # 解析参数层次结构 + # parts = key.split('.') + # current_dict = lora_params + + # # 构建嵌套字典结构 + # for i, part in enumerate(parts[:-1]): + # if part not in current_dict: + # current_dict[part] = {} + # current_dict = current_dict[part] + + # # 最后一层存储参数信息 + # param_name = parts[-1] + # if hasattr(value, 'shape'): + # param_info = { + # 'shape': list(value.shape), + # 'numel': value.numel() if hasattr(value, 'numel') else 0, + # 'dtype': str(value.dtype) + # } + # total_lora_count += param_info['numel'] + # total_lora_size += param_info['numel'] * 4 # 假设float32 + # else: + # param_info = { + # 'shape': 'unknown', + # 'numel': 0, + # 'dtype': str(type(value)) + # } + # current_dict[param_name] = param_info + + # # 递归打印LoRA参数树 + # def print_lora_tree(tree, prefix="", level=0): + # items = list(tree.items()) + # for i, (key, value) in enumerate(items): + # is_last_item = (i == len(items) - 1) + + # if level == 0: + # current_prefix = "" + # next_prefix = "" + # else: + # current_prefix = prefix + ("└── " if is_last_item else "├── ") + # next_prefix = prefix + (" " if is_last_item else "│ ") + + # if isinstance(value, dict) and any(isinstance(v, dict) for v in value.values()): + # # 模块节点 + # print(f"{current_prefix}{key}/") + # print_lora_tree(value, next_prefix, level + 1) + # elif isinstance(value, dict): + # # 参数信息 + # shape_str = "x".join(map(str, value['shape'])) if isinstance(value['shape'], list) else str(value['shape']) + # size_mb = value['numel'] * 4 / (1024 * 1024) if value['numel'] > 0 else 0 + # print(f"{current_prefix}{key}: {shape_str} ({value['numel']:,} params, {size_mb:.2f}MB) [{value['dtype']}]") + + # if lora_params: + # print_lora_tree(lora_params) + # print("─" * 60) + # print(f"Total LoRA parameters in {title}: {total_lora_count:,}") + # print(f"Total LoRA size in {title}: {total_lora_size / (1024 * 1024):.2f} MB") + # else: + # print(f"No LoRA parameters found in {title}") + # print("=" * 60) + + # # 
分别打印learn_model和target_model中的LoRA参数 + # if 'model' in state_dict: + # print_lora_from_state_dict(state_dict['model'], "LEARN_MODEL STATE_DICT") + + + # 定义需要排除的参数前缀,即不加载这些参数 exclude_prefixes = [ '_orig_mod.world_model.head_policy_multi_task.', @@ -1551,96 +1925,24 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, # 包含 "transformer" 则属于 transformer 模块,其它部分可根据需要扩展。 for name, param in self._learn_model.named_parameters(): # 如果参数属于 encoder 且不在需要微调的组件中,则冻结该参数 - if "encoder" in name and "encoder" not in finetune_components: + if ("lora" not in name) and ("encoder" in name and "encoder" not in finetune_components): param.requires_grad = False print(f"Freezing parameter: {name}") - elif "representation_network" in name and "representation_network" not in finetune_components: + elif ("lora" not in name) and ("representation_network" in name and "representation_network" not in finetune_components): param.requires_grad = False print(f"Freezing parameter: {name}") # 如果参数属于 transformer 且不在需要微调的组件中,则冻结该参数 - elif "transformer" in name and "transformer" not in finetune_components: + elif ("lora" not in name) and ("transformer" in name and "transformer" not in finetune_components): + param.requires_grad = False + print(f"Freezing parameter: {name}") + elif ("_orig_mod.world_model.pos_emb" in name or "_orig_mod.world_model.act_embedding_table" in name) and ("transformer" not in finetune_components): param.requires_grad = False print(f"Freezing parameter: {name}") else: # 如果参数属于其他模块,或者包含在 finetune_components 中,则保持默认(或者根据需要调整) print(f"Parameter remains default: {name}") - - # 注意: - # 如果你的模型中嵌套模块更为复杂,可以基于 module 的属性而不是仅仅依靠参数名称进行判断,比如: - # for module in self._learn_model.modules(): - # if isinstance(module, EncoderModule) and "encoder" not in finetune_components: - # for param in module.parameters(): - # param.requires_grad = False - - # # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters ========== - # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: - # """ - # Overview: - # Load the state_dict variable into policy learn mode, excluding multi-task related parameters. - # Arguments: - # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously. - # """ - # # 定义需要排除的参数前缀 - # exclude_prefixes = [ - # '_orig_mod.world_model.head_policy_multi_task.', - # '_orig_mod.world_model.head_value_multi_task.', - # '_orig_mod.world_model.head_rewards_multi_task.', - # '_orig_mod.world_model.head_observations_multi_task.', - # '_orig_mod.world_model.task_emb.' 
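The if/elif chain in the freezing loop above encodes a single rule: a parameter is frozen when it belongs to a backbone part (encoder, representation_network, transformer, or the pos_emb / act_embedding_table tables tied to the transformer) that is not listed in finetune_components, while parameters whose names contain "lora" always stay trainable. A compact equivalent, sketched here only for reference (the explicit chain in the patch additionally logs every decision):

def apply_finetune_freezing(model, finetune_components):
    # Freeze backbone parameters not selected for finetuning; LoRA adapters stay trainable (sketch).
    backbone_keys = {
        'encoder': ('encoder',),
        'representation_network': ('representation_network',),
        'transformer': ('transformer',
                        '_orig_mod.world_model.pos_emb',
                        '_orig_mod.world_model.act_embedding_table'),
    }
    for name, param in model.named_parameters():
        if 'lora' in name:
            continue
        for component, keys in backbone_keys.items():
            if any(k in name for k in keys) and component not in finetune_components:
                param.requires_grad = False
                break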
- # ] - - # # 定义需要排除的具体参数(如果有特殊情况) - # exclude_keys = [ - # '_orig_mod.world_model.task_emb.weight', - # '_orig_mod.world_model.task_emb.bias', # 如果存在则添加 - # # 添加其他需要排除的具体参数名 - # ] - # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: - # """ - # 过滤掉需要排除的参数。 - # """ - # filtered = {} - # for k, v in state_dict_loader.items(): - # if any(k.startswith(prefix) for prefix in exclude_prefixes): - # print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 - # continue - # if k in exclude_keys: - # print(f"Excluding specific parameter: {k}") # 调试用 - # continue - # filtered[k] = v - # return filtered - - # # 过滤并加载 'model' 部分 - # if 'model' in state_dict: - # model_state_dict = state_dict['model'] - # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) - # missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) - # if missing_keys: - # print(f"Missing keys when loading _learn_model: {missing_keys}") - # if unexpected_keys: - # print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") - # else: - # print("No 'model' key found in the state_dict.") - - # # 过滤并加载 'target_model' 部分 - # if 'target_model' in state_dict: - # target_model_state_dict = state_dict['target_model'] - # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) - # missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) - # if missing_keys: - # print(f"Missing keys when loading _target_model: {missing_keys}") - # if unexpected_keys: - # print(f"Unexpected keys when loading _target_model: {unexpected_keys}") - # else: - # print("No 'target_model' key found in the state_dict.") - - # # 不要加载优化器的 state_dict,因为优化器通常不包含模型参数,加载后性能反而变差 - # # if 'optimizer_world_model' in state_dict: - # # optimizer_state_dict = state_dict['optimizer_world_model'] - # # try: - # # self._optimizer_world_model.load_state_dict(optimizer_state_dict) - # # except Exception as e: - # # print(f"Error loading optimizer state_dict: {e}") - # # else: - # # print("No 'optimizer_world_model' key found in the state_dict.") \ No newline at end of file + + # a=1 + # if 'target_model' in state_dict: + # print_lora_from_state_dict(state_dict['target_model'], "TARGET_MODEL STATE_DICT") diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py index bdc5e4f7a..39a2306e1 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -339,8 +339,8 @@ def create_env_manager(): import os - num_games = 8 # 26 # 8 - num_layers = 4 # ==============TODO============== + num_games = 1 # 26 # 8 + num_layers = 8 # ==============TODO============== action_space_size = 18 collector_env_num = 8 num_segments = 8 @@ -350,8 +350,12 @@ def create_env_manager(): max_env_step = int(4e5) reanalyze_ratio = 0.0 + if num_games==1: + env_id_list = [ + 'PongNoFrameskip-v4' + ] - if num_games==3: + elif num_games==3: env_id_list = [ 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' ] @@ -371,8 +375,13 @@ def create_env_manager(): 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ] - - if len(env_id_list) == 8: + if 
len(env_id_list) == 1: + if num_layers == 4: + # effective_batch_size = 1024 # nlayer4 需要设置replay_ratio=0.25对应的upc=40 + effective_batch_size = 512 # nlayer4 需要设置replay_ratio=0.25对应的upc=40 moco + elif num_layers == 8: + effective_batch_size = 16 + elif len(env_id_list) == 8: if num_layers == 4: # effective_batch_size = 1024 # nlayer4 需要设置replay_ratio=0.25对应的upc=40 effective_batch_size = 512 # nlayer4 需要设置replay_ratio=0.25对应的upc=40 moco @@ -388,7 +397,7 @@ def create_env_manager(): elif len(env_id_list) == 18: effective_batch_size = 512 * 3 # 1536 elif len(env_id_list) == 3: - effective_batch_size = 10 # debug + effective_batch_size = 512 # debug else: raise ValueError("不支持的环境数量: {}".format(n)) diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug_naive.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug_naive.py new file mode 100644 index 000000000..30945a10f --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug_naive.py @@ -0,0 +1,492 @@ +from easydict import EasyDict + +import math + +def compute_batch_config(env_id_list, effective_batch_size): + n = len(env_id_list) + + # 根据环境数量设定有效 batch size 和每个环境的最大微 batch size + gpu_num = 1 + max_micro_batch_one_gpu = 400 + max_micro_batch = int(max_micro_batch_one_gpu / (n // gpu_num)) + + + # 计算每个环境理论上应该分得的 batch size + theoretical_env_batch = effective_batch_size / n + + if theoretical_env_batch > max_micro_batch: + # 当每个环境按均分的 batch 大于允许的最大微 batch 时, + # 则令每个环境的实际微 batch size 固定为 max_micro_batch + micro_batch_size = max_micro_batch + # 梯度累计步数 = ceil(每个环境理论 batch size / 最大微 batch size) + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch) + else: + # 否则直接使用计算出的理论 batch size(这里向下取整以保证整数) + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # 为每个环境分配相同的微 batch size + batch_size = [micro_batch_size for _ in range(n)] + + # 打印一些调试信息(也可以记录到 log 中) + print("环境数量: {}".format(n)) + print("有效 total batch size: {}".format(effective_batch_size)) + print("每个环境的理论 batch size: {:.2f}".format(theoretical_env_batch)) + print("每个环境的微 batch size: {}".format(micro_batch_size)) + print("梯度累积步数: {}".format(grad_accumulate_steps)) + + return batch_size, grad_accumulate_steps + + + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + policy=dict( + multi_gpu=False, # Disabled for single GPU training + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + 
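For reference, compute_batch_config at the top of this file splits the effective batch size evenly across environments and falls back to gradient accumulation when the per-environment share exceeds the per-GPU micro-batch cap (400 divided by the number of environments per GPU). Two worked cases with the constants used here:

# n = 1 environment, effective_batch_size = 64 (this config):
#   max_micro_batch = int(400 / 1) = 400, theoretical = 64.0 <= 400
#   -> micro_batch_size = 64, grad_accumulate_steps = 1, i.e. batch_size = [64]
# n = 26 environments, effective_batch_size = 512:
#   max_micro_batch = int(400 / 26) = 15, theoretical = 512 / 26 ≈ 19.7 > 15
#   -> micro_batch_size = 15, grad_accumulate_steps = ceil(19.7 / 15) = 2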
model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + # num_channels=512, # ==============TODO============== + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + # share_head=True, # TODO + share_head=False, # TODO + + # analysis_dormant_ratio_weight_rank=True, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + # num_layers=12, + # num_heads=24, + + # num_layers=4, # TODO======= + num_layers=8, + + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + + encoder_type='vit', + # encoder_type='resnet', + + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + # multiplication_moe_in_transformer=False, + multiplication_moe_in_transformer=True, # TODO======= + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA 参数: + moe_use_lora=False, # TODO + # moe_use_lora=True, # TODO + + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=[], + lora_r=0, # modefied + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + min_stage0_iters=5000000, # 50k + max_stage_iters=20000, # 20k + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + + target_return =target_return_dict[env_id], + balance_pipeline=True, + # task_complexity_weight=False, # TODO + task_complexity_weight=True, # TODO + + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: DEBUG + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # TODO + # update_per_collect=2, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(1e4), + eval_freq=int(1e4), + # eval_freq=int(2), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + 
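# Note on the token budget set in world_model_cfg above: each unroll step contributes an observation
# token and an action token, hence the factor of two in max_tokens and context_length:
#   num_unroll_steps = 10     -> max_blocks = 10, max_tokens = 2 * 10 = 20
#   infer_context_length = 4  -> context_length = 2 * 4 = 8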
reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'fintune_log/SpaceInvaders_naive_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + + cd /cpfs04/user/puyuan/code/LightZero/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /cpfs04/user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250522_cpfs/uz_mt_nlayer4_atari8_vit-small_moe8-lora_balance-totalstage5_stage-50k-20k_s0.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer4_atari26_vit-ln_moe8_balance-totalstage9.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari26_vit-ln_moe8_totalstage5.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer8_atari8_vit-ln_moe8_balance-totalstage5.log + + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari8_no-encoder-grad-scale_cnn-encoder_moe8_totalstage5_20250509.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_cnn-encoder_totalstage9_balance20250505.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari8_vit-base-encoder-ps8_totalstage3_balance_20250501_debug.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_vit-large-encoder-ps8-simnorm_totalstage5_balance20250501.log + + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # ] + # # List of Atari games used for multi-task learning + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + # 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + # 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + # 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + # 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + # ] + + def get_atari_target_return_dict(ratio=1.0): + """ + 根据 
Human 分数和传入的比例参数 ratio 计算每个 Atari 游戏的 target_return。 + + 参数: + ratio: 控制 target_return 大小的比例因子,默认为 1.0 + + 返回: + 包含 Atari 游戏 target_return 的字典,key 为环境名称,value 为计算后的目标分数(整数)。 + """ + human_scores = { + # 8games + 'PongNoFrameskip-v4': 14.6, # 0 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 + 'BoxingNoFrameskip-v4': 12.1, # 3 + 'AlienNoFrameskip-v4': 7127.7, # 4 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 + 'HeroNoFrameskip-v4': 30826.4, # 6 + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 1719.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 8503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 37187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 35829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 4334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 22736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 69571.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + # target score + target_scores = { + # 8games + # 'PongNoFrameskip-v4': 14.6, # 0 expert + 'PongNoFrameskip-v4': 20, # 0 expert + # 'MsPacmanNoFrameskip-v4': 1500.6, # 1 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + # 'SeaquestNoFrameskip-v4': 1000.7, # 2 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 expert + 'BoxingNoFrameskip-v4': 12.1, # 3 expert + # 'AlienNoFrameskip-v4': 1000.7, # 4 + 'AlienNoFrameskip-v4': 7127.7, # 4 expert + # 'ChopperCommandNoFrameskip-v4': 3000.8, # 5 + # 'HeroNoFrameskip-v4': 3082.4, # 6 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 expert + 'HeroNoFrameskip-v4': 30826.4, # 6 expert + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 expert + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 100.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 1503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 12187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 15829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 12736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 1001.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + # --- 经典射击与反应 --- + 'SpaceInvadersNoFrameskip-v4': 1668, + 'RiverRaidNoFrameskip-v4' : 17117.1, + 'BeamRiderNoFrameskip-v4' : 16926.5, + + # --- 物理与惯性控制 --- + 'AsteroidsNoFrameskip-v4' : 47388.7, + 'GravitarNoFrameskip-v4' : 3351.4, + + # --- 探索与长时序规划 (Hard-Exploration) --- + 'PitfallNoFrameskip-v4' : 6463.7, + 'AdventureNoFrameskip-v4' : 0.0, + 'EnduroNoFrameskip-v4' : 860.5, # 长时程任务,有昼夜变化,考验模型的耐力和持续表现 + } + + + # 计算每个游戏的 target_return + # return {env: int(round(score * ratio)) for env, score in human_scores.items()} + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + + global target_return_dict + # global BENCHMARK_NAME + # BENCHMARK_NAME='atari' + + # 示例:以 ratio=1 使用 + target_return_dict = get_atari_target_return_dict(ratio=1) + # target_return_dict = 
get_atari_target_return_dict(ratio=0.5) + num_games = 1 # 26 # 8 + + # 分别定义 Atari 游戏列表(8games 和 26games) + if num_games==1: + env_id_list = [ + 'SpaceInvadersNoFrameskip-v4' + ] + + if num_games==3: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + ] + elif num_games==8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games==26: + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + global curriculum_stage_num + # TODO ============== + # curriculum_stage_num=3 + curriculum_stage_num=1 + # curriculum_stage_num=9 + + action_space_size = 18 + collector_env_num = 2 + num_segments = collector_env_num + n_episode = 8 + evaluator_env_num = 3 + # num_simulations = 50 + num_simulations = 25 + + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + if len(env_id_list) == 1: + effective_batch_size = 64 + elif len(env_id_list) == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + # effective_batch_size = 832 # cnn-encoder + effective_batch_size = 512 # base-vit-encoder + # effective_batch_size = 256 # base-vit-encoder large-vit-encoder + elif len(env_id_list) == 18: + effective_batch_size = 512 * 3 # 1536 + else: + raise ValueError("不支持的环境数量: {}".format(n)) + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + total_batch_size = effective_batch_size # 当前无效 + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # batch_sizes = [4 for _ in range(len(env_id_list))] + + from lzero.entry import train_unizero_multitask_segment_ddp + # finetune_components = [] # load-enc-trans_finetune-head + # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + finetune_components = [] # load-enc-trans_finetune-encoder-head + + for seed in [3]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + # pretrained_model_path = '/fs-computility/niuyazhe/tangjia/github/LightZero/ckpt/ckpt_best.pth.tar' + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, 
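# Note: this "naive" script trains SpaceInvaders from scratch: finetune_components is empty and no
# model_path is passed to train_unizero_multitask_segment_ddp (pretrained_model_path above is commented
# out), unlike the *_finetune_SpaceInvaders_* configs that follow, which load a pretrained checkpoint.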
max_env_step=max_env_step, benchmark_name="atari",finetune_components=finetune_components) + # ======== TODO: only for debug ======== + # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks + + + +# TODO(pu): only for debug,设置环境变量DEBUG=1 +# from train_grpo_rm_colocate import maybe_ipdb +# import torch.distributed as dist +# maybe_ipdb(dist.get_rank()) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_full.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_full.py new file mode 100644 index 000000000..2c08a0e35 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_full.py @@ -0,0 +1,494 @@ +from easydict import EasyDict + +import math + +def compute_batch_config(env_id_list, effective_batch_size): + n = len(env_id_list) + + # 根据环境数量设定有效 batch size 和每个环境的最大微 batch size + gpu_num = 1 + max_micro_batch_one_gpu = 400 + max_micro_batch = int(max_micro_batch_one_gpu / (n // gpu_num)) + + + # 计算每个环境理论上应该分得的 batch size + theoretical_env_batch = effective_batch_size / n + + if theoretical_env_batch > max_micro_batch: + # 当每个环境按均分的 batch 大于允许的最大微 batch 时, + # 则令每个环境的实际微 batch size 固定为 max_micro_batch + micro_batch_size = max_micro_batch + # 梯度累计步数 = ceil(每个环境理论 batch size / 最大微 batch size) + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch) + else: + # 否则直接使用计算出的理论 batch size(这里向下取整以保证整数) + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # 为每个环境分配相同的微 batch size + batch_size = [micro_batch_size for _ in range(n)] + + # 打印一些调试信息(也可以记录到 log 中) + print("环境数量: {}".format(n)) + print("有效 total batch size: {}".format(effective_batch_size)) + print("每个环境的理论 batch size: {:.2f}".format(theoretical_env_batch)) + print("每个环境的微 batch size: {}".format(micro_batch_size)) + print("梯度累积步数: {}".format(grad_accumulate_steps)) + + return batch_size, grad_accumulate_steps + + + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + policy=dict( + multi_gpu=False, # Disabled for single GPU training + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + # num_channels=512, # ==============TODO============== 
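The lora_* (and, in this file, encoder_lora_*) fields configured below control low-rank adapters; with lora_r=0 they are presumably inactive. Purely as a point of reference (a generic LoRA-augmented linear layer, not LightZero's implementation), the adapter adds a trainable low-rank update scaled by alpha / r on top of a frozen base weight:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    # Generic LoRA sketch: y = W x + (alpha / r) * B A x, with W frozen and A, B trainable.
    def __init__(self, in_features, out_features, r=8, alpha=1.0, dropout=0.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        self.base.weight.requires_grad = False
        self.base.bias.requires_grad = False
        self.r = r
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no change at start
            self.scaling = alpha / r
            self.lora_dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = self.base(x)
        if self.r > 0:
            out = out + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return out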
+ continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + # share_head=True, # TODO + share_head=False, # TODO + + # analysis_dormant_ratio_weight_rank=True, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + # num_layers=12, + # num_heads=24, + + # num_layers=4, # TODO======= + num_layers=8, + + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + + encoder_type='vit', + # encoder_type='resnet', + + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + # multiplication_moe_in_transformer=False, + multiplication_moe_in_transformer=True, # TODO======= + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA 参数: + moe_use_lora=False, # TODO + # moe_use_lora=True, # TODO + + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=[], + lora_r=0, # modefied + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + min_stage0_iters=5000000, # 50k + max_stage_iters=20000, # 20k + + encoder_lora_r=0, + encoder_lora_alpha=1, + encoder_lora_dropout=0.1, + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + + target_return =target_return_dict[env_id], + balance_pipeline=True, + # task_complexity_weight=False, # TODO + task_complexity_weight=True, # TODO + + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: DEBUG + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # TODO + # update_per_collect=2, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(1e4), + eval_freq=int(1e4), + # eval_freq=int(2), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, 
collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'fintune_log/SpaceInvaders_full_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + + cd /cpfs04/user/puyuan/code/LightZero/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /cpfs04/user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250522_cpfs/uz_mt_nlayer4_atari8_vit-small_moe8-lora_balance-totalstage5_stage-50k-20k_s0.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer4_atari26_vit-ln_moe8_balance-totalstage9.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari26_vit-ln_moe8_totalstage5.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer8_atari8_vit-ln_moe8_balance-totalstage5.log + + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari8_no-encoder-grad-scale_cnn-encoder_moe8_totalstage5_20250509.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_cnn-encoder_totalstage9_balance20250505.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari8_vit-base-encoder-ps8_totalstage3_balance_20250501_debug.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_vit-large-encoder-ps8-simnorm_totalstage5_balance20250501.log + + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # ] + # # List of Atari games used for multi-task learning + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + # 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + # 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + # 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + # 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + # ] + + def get_atari_target_return_dict(ratio=1.0): + """ + 根据 
Human 分数和传入的比例参数 ratio 计算每个 Atari 游戏的 target_return。 + + 参数: + ratio: 控制 target_return 大小的比例因子,默认为 1.0 + + 返回: + 包含 Atari 游戏 target_return 的字典,key 为环境名称,value 为计算后的目标分数(整数)。 + """ + human_scores = { + # 8games + 'PongNoFrameskip-v4': 14.6, # 0 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 + 'BoxingNoFrameskip-v4': 12.1, # 3 + 'AlienNoFrameskip-v4': 7127.7, # 4 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 + 'HeroNoFrameskip-v4': 30826.4, # 6 + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 1719.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 8503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 37187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 35829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 4334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 22736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 69571.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + # target score + target_scores = { + # 8games + # 'PongNoFrameskip-v4': 14.6, # 0 expert + 'PongNoFrameskip-v4': 20, # 0 expert + # 'MsPacmanNoFrameskip-v4': 1500.6, # 1 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + # 'SeaquestNoFrameskip-v4': 1000.7, # 2 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 expert + 'BoxingNoFrameskip-v4': 12.1, # 3 expert + # 'AlienNoFrameskip-v4': 1000.7, # 4 + 'AlienNoFrameskip-v4': 7127.7, # 4 expert + # 'ChopperCommandNoFrameskip-v4': 3000.8, # 5 + # 'HeroNoFrameskip-v4': 3082.4, # 6 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 expert + 'HeroNoFrameskip-v4': 30826.4, # 6 expert + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 expert + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 100.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 1503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 12187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 15829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 12736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 1001.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + # --- 经典射击与反应 --- + 'SpaceInvadersNoFrameskip-v4': 1669.7, + 'RiverRaidNoFrameskip-v4' : 17117.1, + 'BeamRiderNoFrameskip-v4' : 16926.5, + + # --- 物理与惯性控制 --- + 'AsteroidsNoFrameskip-v4' : 47388.7, + 'GravitarNoFrameskip-v4' : 3351.4, + + # --- 探索与长时序规划 (Hard-Exploration) --- + 'PitfallNoFrameskip-v4' : 6463.7, + 'AdventureNoFrameskip-v4' : 0.0, + 'EnduroNoFrameskip-v4' : 860.5, # 长时程任务,有昼夜变化,考验模型的耐力和持续表现 + } + + + # 计算每个游戏的 target_return + # return {env: int(round(score * ratio)) for env, score in human_scores.items()} + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + + global target_return_dict + # global BENCHMARK_NAME + # BENCHMARK_NAME='atari' + + # 示例:以 ratio=1 使用 + target_return_dict = get_atari_target_return_dict(ratio=1) + # target_return_dict = 
get_atari_target_return_dict(ratio=0.5) + num_games = 1 # 26 # 8 + + # 分别定义 Atari 游戏列表(8games 和 26games) + if num_games==1: + env_id_list = [ + 'SpaceInvadersNoFrameskip-v4' + ] + + if num_games==3: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + ] + elif num_games==8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games==26: + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + global curriculum_stage_num + # TODO ============== + # curriculum_stage_num=3 + curriculum_stage_num=1 + # curriculum_stage_num=9 + + action_space_size = 18 + collector_env_num = 8 + num_segments = collector_env_num + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 25 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + if len(env_id_list) == 1: + effective_batch_size = 64 + elif len(env_id_list) == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + # effective_batch_size = 832 # cnn-encoder + effective_batch_size = 512 # base-vit-encoder + # effective_batch_size = 256 # base-vit-encoder large-vit-encoder + elif len(env_id_list) == 18: + effective_batch_size = 512 * 3 # 1536 + else: + raise ValueError("不支持的环境数量: {}".format(n)) + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + total_batch_size = effective_batch_size # 当前无效 + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # batch_sizes = [4 for _ in range(len(env_id_list))] + + from lzero.entry import train_unizero_multitask_segment_ddp + # finetune_components = [] # load-enc-trans_finetune-head + # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + finetune_components = ["representation_network", "encoder","transformer"] # load-enc-trans_finetune-encoder-head + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + pretrained_model_path = '/fs-computility/niuyazhe/tangjia/github/LightZero/ckpt/ckpt_best.pth.tar' + with DDPContext(): + 
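# This "_full" variant loads the pretrained checkpoint and keeps the whole backbone trainable
# (finetune_components = ["representation_network", "encoder", "transformer"]); the multi-task heads
# listed in exclude_prefixes are not loaded and keep their fresh initialization.
# Optional sanity check before launching (illustrative; the checkpoint path above is machine-specific):
# assert os.path.isfile(pretrained_model_path), f"pretrained checkpoint not found: {pretrained_model_path}"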
train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step, benchmark_name="atari",finetune_components=finetune_components) + # ======== TODO: only for debug ======== + # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks + + + +# TODO(pu): only for debug,设置环境变量DEBUG=1 +# from train_grpo_rm_colocate import maybe_ipdb +# import torch.distributed as dist +# maybe_ipdb(dist.get_rank()) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head.py new file mode 100644 index 000000000..47a7b654b --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head.py @@ -0,0 +1,492 @@ +from easydict import EasyDict + +import math + +def compute_batch_config(env_id_list, effective_batch_size): + n = len(env_id_list) + + # 根据环境数量设定有效 batch size 和每个环境的最大微 batch size + gpu_num = 1 + max_micro_batch_one_gpu = 400 + max_micro_batch = int(max_micro_batch_one_gpu / (n // gpu_num)) + + + # 计算每个环境理论上应该分得的 batch size + theoretical_env_batch = effective_batch_size / n + + if theoretical_env_batch > max_micro_batch: + # 当每个环境按均分的 batch 大于允许的最大微 batch 时, + # 则令每个环境的实际微 batch size 固定为 max_micro_batch + micro_batch_size = max_micro_batch + # 梯度累计步数 = ceil(每个环境理论 batch size / 最大微 batch size) + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch) + else: + # 否则直接使用计算出的理论 batch size(这里向下取整以保证整数) + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # 为每个环境分配相同的微 batch size + batch_size = [micro_batch_size for _ in range(n)] + + # 打印一些调试信息(也可以记录到 log 中) + print("环境数量: {}".format(n)) + print("有效 total batch size: {}".format(effective_batch_size)) + print("每个环境的理论 batch size: {:.2f}".format(theoretical_env_batch)) + print("每个环境的微 batch size: {}".format(micro_batch_size)) + print("梯度累积步数: {}".format(grad_accumulate_steps)) + + return batch_size, grad_accumulate_steps + + + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + policy=dict( + multi_gpu=False, # Disabled for single GPU training + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + 
num_res_blocks=2, + num_channels=256, + # num_channels=512, # ==============TODO============== + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + # share_head=True, # TODO + share_head=False, # TODO + + # analysis_dormant_ratio_weight_rank=True, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + # num_layers=12, + # num_heads=24, + + # num_layers=4, # TODO======= + num_layers=8, + + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + + encoder_type='vit', + # encoder_type='resnet', + + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + # multiplication_moe_in_transformer=False, + multiplication_moe_in_transformer=True, # TODO======= + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA 参数: + moe_use_lora=False, # TODO + # moe_use_lora=True, # TODO + + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=[], + lora_r=0, # modefied + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + min_stage0_iters=5000000, # 50k + max_stage_iters=20000, # 20k + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + + target_return =target_return_dict[env_id], + balance_pipeline=True, + # task_complexity_weight=False, # TODO + task_complexity_weight=True, # TODO + + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: DEBUG + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # TODO + # update_per_collect=2, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(1e4), + eval_freq=int(1e4), + # eval_freq=int(2), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, 
action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'fintune_log/SpaceInvaders_head_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + + cd /cpfs04/user/puyuan/code/LightZero/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /cpfs04/user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250522_cpfs/uz_mt_nlayer4_atari8_vit-small_moe8-lora_balance-totalstage5_stage-50k-20k_s0.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer4_atari26_vit-ln_moe8_balance-totalstage9.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari26_vit-ln_moe8_totalstage5.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer8_atari8_vit-ln_moe8_balance-totalstage5.log + + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari8_no-encoder-grad-scale_cnn-encoder_moe8_totalstage5_20250509.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_cnn-encoder_totalstage9_balance20250505.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari8_vit-base-encoder-ps8_totalstage3_balance_20250501_debug.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_vit-large-encoder-ps8-simnorm_totalstage5_balance20250501.log + + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # ] + # # List of Atari games used for multi-task learning + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + # 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + # 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + # 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + # 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + # ] + + def get_atari_target_return_dict(ratio=1.0): + """ + 根据 
Human 分数和传入的比例参数 ratio 计算每个 Atari 游戏的 target_return。 + + 参数: + ratio: 控制 target_return 大小的比例因子,默认为 1.0 + + 返回: + 包含 Atari 游戏 target_return 的字典,key 为环境名称,value 为计算后的目标分数(整数)。 + """ + human_scores = { + # 8games + 'PongNoFrameskip-v4': 14.6, # 0 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 + 'BoxingNoFrameskip-v4': 12.1, # 3 + 'AlienNoFrameskip-v4': 7127.7, # 4 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 + 'HeroNoFrameskip-v4': 30826.4, # 6 + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 1719.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 8503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 37187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 35829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 4334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 22736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 69571.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + # target score + target_scores = { + # 8games + # 'PongNoFrameskip-v4': 14.6, # 0 expert + 'PongNoFrameskip-v4': 20, # 0 expert + # 'MsPacmanNoFrameskip-v4': 1500.6, # 1 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + # 'SeaquestNoFrameskip-v4': 1000.7, # 2 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 expert + 'BoxingNoFrameskip-v4': 12.1, # 3 expert + # 'AlienNoFrameskip-v4': 1000.7, # 4 + 'AlienNoFrameskip-v4': 7127.7, # 4 expert + # 'ChopperCommandNoFrameskip-v4': 3000.8, # 5 + # 'HeroNoFrameskip-v4': 3082.4, # 6 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 expert + 'HeroNoFrameskip-v4': 30826.4, # 6 expert + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 expert + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 100.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 1503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 12187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 15829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 12736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 1001.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + # --- 经典射击与反应 --- + 'SpaceInvadersNoFrameskip-v4': 1668, + 'RiverRaidNoFrameskip-v4' : 17117.1, + 'BeamRiderNoFrameskip-v4' : 16926.5, + + # --- 物理与惯性控制 --- + 'AsteroidsNoFrameskip-v4' : 47388.7, + 'GravitarNoFrameskip-v4' : 3351.4, + + # --- 探索与长时序规划 (Hard-Exploration) --- + 'PitfallNoFrameskip-v4' : 6463.7, + 'AdventureNoFrameskip-v4' : 0.0, + 'EnduroNoFrameskip-v4' : 860.5, # 长时程任务,有昼夜变化,考验模型的耐力和持续表现 + } + + + # 计算每个游戏的 target_return + # return {env: int(round(score * ratio)) for env, score in human_scores.items()} + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + + global target_return_dict + # global BENCHMARK_NAME + # BENCHMARK_NAME='atari' + + # 示例:以 ratio=1 使用 + target_return_dict = get_atari_target_return_dict(ratio=1) + # target_return_dict = 
get_atari_target_return_dict(ratio=0.5) + num_games = 1 # 26 # 8 + + # 分别定义 Atari 游戏列表(8games 和 26games) + if num_games==1: + env_id_list = [ + 'SpaceInvadersNoFrameskip-v4' + ] + + if num_games==3: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + ] + elif num_games==8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games==26: + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + global curriculum_stage_num + # TODO ============== + # curriculum_stage_num=3 + curriculum_stage_num=1 + # curriculum_stage_num=9 + + action_space_size = 18 + collector_env_num = 8 + num_segments = collector_env_num + n_episode = 8 + evaluator_env_num = 3 + # num_simulations = 50 + num_simulations = 25 + + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + if len(env_id_list) == 1: + effective_batch_size = 64 + elif len(env_id_list) == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + # effective_batch_size = 832 # cnn-encoder + effective_batch_size = 512 # base-vit-encoder + # effective_batch_size = 256 # base-vit-encoder large-vit-encoder + elif len(env_id_list) == 18: + effective_batch_size = 512 * 3 # 1536 + else: + raise ValueError("不支持的环境数量: {}".format(n)) + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + total_batch_size = effective_batch_size # 当前无效 + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # batch_sizes = [4 for _ in range(len(env_id_list))] + + from lzero.entry import train_unizero_multitask_segment_ddp + # finetune_components = [] # load-enc-trans_finetune-head + # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + finetune_components = [] # load-enc-trans_finetune-encoder-head + + for seed in [3]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + pretrained_model_path = '/fs-computility/niuyazhe/tangjia/github/LightZero/ckpt/ckpt_best.pth.tar' + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, 
model_path=pretrained_model_path, max_env_step=max_env_step, benchmark_name="atari",finetune_components=finetune_components) + # ======== TODO: only for debug ======== + # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks + + + +# TODO(pu): only for debug,设置环境变量DEBUG=1 +# from train_grpo_rm_colocate import maybe_ipdb +# import torch.distributed as dist +# maybe_ipdb(dist.get_rank()) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_encoder_lora.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_encoder_lora.py new file mode 100644 index 000000000..4ef9c715e --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_encoder_lora.py @@ -0,0 +1,494 @@ +from easydict import EasyDict + +import math + +def compute_batch_config(env_id_list, effective_batch_size): + n = len(env_id_list) + + # 根据环境数量设定有效 batch size 和每个环境的最大微 batch size + gpu_num = 1 + max_micro_batch_one_gpu = 400 + max_micro_batch = int(max_micro_batch_one_gpu / (n // gpu_num)) + + + # 计算每个环境理论上应该分得的 batch size + theoretical_env_batch = effective_batch_size / n + + if theoretical_env_batch > max_micro_batch: + # 当每个环境按均分的 batch 大于允许的最大微 batch 时, + # 则令每个环境的实际微 batch size 固定为 max_micro_batch + micro_batch_size = max_micro_batch + # 梯度累计步数 = ceil(每个环境理论 batch size / 最大微 batch size) + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch) + else: + # 否则直接使用计算出的理论 batch size(这里向下取整以保证整数) + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # 为每个环境分配相同的微 batch size + batch_size = [micro_batch_size for _ in range(n)] + + # 打印一些调试信息(也可以记录到 log 中) + print("环境数量: {}".format(n)) + print("有效 total batch size: {}".format(effective_batch_size)) + print("每个环境的理论 batch size: {:.2f}".format(theoretical_env_batch)) + print("每个环境的微 batch size: {}".format(micro_batch_size)) + print("梯度累积步数: {}".format(grad_accumulate_steps)) + + return batch_size, grad_accumulate_steps + + + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + policy=dict( + multi_gpu=False, # Disabled for single GPU training + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + 
num_res_blocks=2, + num_channels=256, + # num_channels=512, # ==============TODO============== + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + # share_head=True, # TODO + share_head=False, # TODO + + # analysis_dormant_ratio_weight_rank=True, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + # num_layers=12, + # num_heads=24, + + # num_layers=4, # TODO======= + num_layers=8, + + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + + encoder_type='vit', + # encoder_type='resnet', + + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + # multiplication_moe_in_transformer=False, + multiplication_moe_in_transformer=True, # TODO======= + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA 参数: + moe_use_lora=True, # TODO + # moe_use_lora=True, # TODO + + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=["attn"], + lora_r=64, # modefied + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + min_stage0_iters=0, # 50k + max_stage_iters=200000000, # 20k + + encoder_lora_r=64, + encoder_lora_alpha=1, + encoder_lora_dropout=0.1, + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + + target_return =target_return_dict[env_id], + balance_pipeline=True, + # task_complexity_weight=False, # TODO + task_complexity_weight=True, # TODO + + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: DEBUG + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # TODO + # update_per_collect=2, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(1e4), + eval_freq=int(1e4), + # eval_freq=int(2), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + 
reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'fintune_log/SpaceInvaders_head_backlora_encoder_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + + cd /cpfs04/user/puyuan/code/LightZero/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /cpfs04/user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250522_cpfs/uz_mt_nlayer4_atari8_vit-small_moe8-lora_balance-totalstage5_stage-50k-20k_s0.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer4_atari26_vit-ln_moe8_balance-totalstage9.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari26_vit-ln_moe8_totalstage5.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer8_atari8_vit-ln_moe8_balance-totalstage5.log + + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari8_no-encoder-grad-scale_cnn-encoder_moe8_totalstage5_20250509.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_cnn-encoder_totalstage9_balance20250505.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari8_vit-base-encoder-ps8_totalstage3_balance_20250501_debug.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_vit-large-encoder-ps8-simnorm_totalstage5_balance20250501.log + + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # ] + # # List of Atari games used for multi-task learning + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + # 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + # 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + # 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + # 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + # ] + + def get_atari_target_return_dict(ratio=1.0): + """ + 根据 
Human 分数和传入的比例参数 ratio 计算每个 Atari 游戏的 target_return。 + + 参数: + ratio: 控制 target_return 大小的比例因子,默认为 1.0 + + 返回: + 包含 Atari 游戏 target_return 的字典,key 为环境名称,value 为计算后的目标分数(整数)。 + """ + human_scores = { + # 8games + 'PongNoFrameskip-v4': 14.6, # 0 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 + 'BoxingNoFrameskip-v4': 12.1, # 3 + 'AlienNoFrameskip-v4': 7127.7, # 4 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 + 'HeroNoFrameskip-v4': 30826.4, # 6 + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 1719.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 8503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 37187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 35829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 4334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 22736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 69571.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + # target score + target_scores = { + # 8games + # 'PongNoFrameskip-v4': 14.6, # 0 expert + 'PongNoFrameskip-v4': 20, # 0 expert + # 'MsPacmanNoFrameskip-v4': 1500.6, # 1 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + # 'SeaquestNoFrameskip-v4': 1000.7, # 2 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 expert + 'BoxingNoFrameskip-v4': 12.1, # 3 expert + # 'AlienNoFrameskip-v4': 1000.7, # 4 + 'AlienNoFrameskip-v4': 7127.7, # 4 expert + # 'ChopperCommandNoFrameskip-v4': 3000.8, # 5 + # 'HeroNoFrameskip-v4': 3082.4, # 6 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 expert + 'HeroNoFrameskip-v4': 30826.4, # 6 expert + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 expert + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 100.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 1503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 12187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 15829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 12736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 1001.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + # --- 经典射击与反应 --- + 'SpaceInvadersNoFrameskip-v4': 1669.7, + 'RiverRaidNoFrameskip-v4' : 17117.1, + 'BeamRiderNoFrameskip-v4' : 16926.5, + + # --- 物理与惯性控制 --- + 'AsteroidsNoFrameskip-v4' : 47388.7, + 'GravitarNoFrameskip-v4' : 3351.4, + + # --- 探索与长时序规划 (Hard-Exploration) --- + 'PitfallNoFrameskip-v4' : 6463.7, + 'AdventureNoFrameskip-v4' : 0.0, + 'EnduroNoFrameskip-v4' : 860.5, # 长时程任务,有昼夜变化,考验模型的耐力和持续表现 + } + + + # 计算每个游戏的 target_return + # return {env: int(round(score * ratio)) for env, score in human_scores.items()} + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + + global target_return_dict + # global BENCHMARK_NAME + # BENCHMARK_NAME='atari' + + # 示例:以 ratio=1 使用 + target_return_dict = get_atari_target_return_dict(ratio=1) + # target_return_dict = 
get_atari_target_return_dict(ratio=0.5) + num_games = 1 # 26 # 8 + + # 分别定义 Atari 游戏列表(8games 和 26games) + if num_games==1: + env_id_list = [ + 'SpaceInvadersNoFrameskip-v4' + ] + + if num_games==3: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + ] + elif num_games==8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games==26: + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + global curriculum_stage_num + # TODO ============== + # curriculum_stage_num=3 + curriculum_stage_num=1 + # curriculum_stage_num=9 + + action_space_size = 18 + collector_env_num = 2 + num_segments = collector_env_num + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 25 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + if len(env_id_list) == 1: + effective_batch_size = 64 + elif len(env_id_list) == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + # effective_batch_size = 832 # cnn-encoder + effective_batch_size = 512 # base-vit-encoder + # effective_batch_size = 256 # base-vit-encoder large-vit-encoder + elif len(env_id_list) == 18: + effective_batch_size = 512 * 3 # 1536 + else: + raise ValueError("不支持的环境数量: {}".format(n)) + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + total_batch_size = effective_batch_size # 当前无效 + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # batch_sizes = [4 for _ in range(len(env_id_list))] + + from lzero.entry import train_unizero_multitask_segment_ddp + # finetune_components = [] # load-enc-trans_finetune-head + # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + finetune_components = ["representation_network","encoder"] # load-enc-trans_finetune-encoder-head + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + pretrained_model_path = '/fs-computility/niuyazhe/tangjia/github/LightZero/ckpt/ckpt_best.pth.tar' + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, 
seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step, benchmark_name="atari",finetune_components=finetune_components) + # ======== TODO: only for debug ======== + # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks + + + +# TODO(pu): only for debug,设置环境变量DEBUG=1 +# from train_grpo_rm_colocate import maybe_ipdb +# import torch.distributed as dist +# maybe_ipdb(dist.get_rank()) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_lora.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_lora.py new file mode 100644 index 000000000..6fee9798b --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_finetune_SpaceInvaders_head_back_lora.py @@ -0,0 +1,490 @@ +from easydict import EasyDict + +import math + +def compute_batch_config(env_id_list, effective_batch_size): + n = len(env_id_list) + + # 根据环境数量设定有效 batch size 和每个环境的最大微 batch size + gpu_num = 1 + max_micro_batch_one_gpu = 400 + max_micro_batch = int(max_micro_batch_one_gpu / (n // gpu_num)) + + + # 计算每个环境理论上应该分得的 batch size + theoretical_env_batch = effective_batch_size / n + + if theoretical_env_batch > max_micro_batch: + # 当每个环境按均分的 batch 大于允许的最大微 batch 时, + # 则令每个环境的实际微 batch size 固定为 max_micro_batch + micro_batch_size = max_micro_batch + # 梯度累计步数 = ceil(每个环境理论 batch size / 最大微 batch size) + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch) + else: + # 否则直接使用计算出的理论 batch size(这里向下取整以保证整数) + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # 为每个环境分配相同的微 batch size + batch_size = [micro_batch_size for _ in range(n)] + + # 打印一些调试信息(也可以记录到 log 中) + print("环境数量: {}".format(n)) + print("有效 total batch size: {}".format(effective_batch_size)) + print("每个环境的理论 batch size: {:.2f}".format(theoretical_env_batch)) + print("每个环境的微 batch size: {}".format(micro_batch_size)) + print("梯度累积步数: {}".format(grad_accumulate_steps)) + + return batch_size, grad_accumulate_steps + + + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, + total_batch_size): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + # ===== only for debug ===== + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + policy=dict( + multi_gpu=False, # Disabled for single GPU training + only_use_moco_stats=False, + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO============== + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + 
num_channels=256, + # num_channels=512, # ==============TODO============== + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + + # predict_latent_loss_type='group_kl', # TODO: only for latent state sim_norm + # share_head=True, # TODO + share_head=False, # TODO + + # analysis_dormant_ratio_weight_rank=True, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + + task_embed_option=None, # ==============TODO: none ============== + use_task_embed=False, # ==============TODO============== + + # task_embed_option='concat_task_embed', # ==============TODO: none ============== + # use_task_embed=True, # ==============TODO============== + # task_embed_dim=128, + # # task_embed_dim=96, + + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + # batch_size=64 8games训练时,每张卡大约占 12*3=36G cuda显存 + # num_layers=12, + # num_heads=24, + + # num_layers=4, # TODO======= + num_layers=8, + + num_heads=24, + + # ===== only for debug ===== + # num_layers=1, + # num_heads=8, + + embed_dim=768, + obs_type='image', + env_num=8, + task_num=len(env_id_list), + + encoder_type='vit', + # encoder_type='resnet', + + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + # multiplication_moe_in_transformer=False, + multiplication_moe_in_transformer=True, # TODO======= + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA 参数: + moe_use_lora=True, # TODO + # moe_use_lora=True, # TODO + + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=["attn"], + lora_r=64, # modefied + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + min_stage0_iters=1000000, # 50k + max_stage_iters=100000, # 20k + ), + ), + use_task_exploitation_weight=False, # TODO + # use_task_exploitation_weight=True, # TODO + + target_return =target_return_dict[env_id], + balance_pipeline=True, + # task_complexity_weight=False, # TODO + task_complexity_weight=True, # TODO + + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), # TODO: DEBUG + # train_start_after_envsteps=int(2000), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, # TODO + # update_per_collect=2, # TODO + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + # eval_freq=int(1e4), + eval_freq=int(1e4), + # eval_freq=int(2), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + +def generate_configs(env_id_list, action_space_size, 
collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size): + configs = [] + # ===== only for debug ===== + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-encoder-ps8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250509/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_no-encoder-scale_cnn-encoder_moe8_trans-nlayer8_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + # exp_name_prefix = f'data_lz/data_unizero_atari_mt_balance_20250514/atari_{len(env_id_list)}games_balance-total-stage{curriculum_stage_num}_vit-ln_moe8_trans-nlayer4_brf{buffer_reanalyze_freq}_not-share-head_seed{seed}/' + exp_name_prefix = f'fintune_log/SpaceInvaders_head_backlora_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. 
+ Run the following command to launch the script: + + cd /cpfs04/user/puyuan/code/LightZero/ + python -m torch.distributed.launch --nproc_per_node=6 --master_port=29502 /cpfs04/user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250522_cpfs/uz_mt_nlayer4_atari8_vit-small_moe8-lora_balance-totalstage5_stage-50k-20k_s0.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29502 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer4_atari26_vit-ln_moe8_balance-totalstage9.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari26_vit-ln_moe8_totalstage5.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/20250509/uz_mt_nlayer8_atari8_vit-ln_moe8_balance-totalstage5.log + + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_balance_atari8_no-encoder-grad-scale_cnn-encoder_moe8_totalstage5_20250509.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_cnn-encoder_totalstage9_balance20250505.log + + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari8_vit-base-encoder-ps8_totalstage3_balance_20250501_debug.log + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29503 /fs-computility/ai-shen/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py 2>&1 | tee ./log/uz_mt_atari26_vit-large-encoder-ps8-simnorm_totalstage5_balance20250501.log + + """ + + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # ] + # # List of Atari games used for multi-task learning + # env_id_list = [ + # 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + # 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + # 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + # 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + # 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + # 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + # ] + + def get_atari_target_return_dict(ratio=1.0): + """ + 根据 
Human 分数和传入的比例参数 ratio 计算每个 Atari 游戏的 target_return。 + + 参数: + ratio: 控制 target_return 大小的比例因子,默认为 1.0 + + 返回: + 包含 Atari 游戏 target_return 的字典,key 为环境名称,value 为计算后的目标分数(整数)。 + """ + human_scores = { + # 8games + 'PongNoFrameskip-v4': 14.6, # 0 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 + 'BoxingNoFrameskip-v4': 12.1, # 3 + 'AlienNoFrameskip-v4': 7127.7, # 4 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 + 'HeroNoFrameskip-v4': 30826.4, # 6 + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 1719.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 8503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 37187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 35829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 4334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 22736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 69571.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + } + + # target score + target_scores = { + # 8games + # 'PongNoFrameskip-v4': 14.6, # 0 expert + 'PongNoFrameskip-v4': 20, # 0 expert + # 'MsPacmanNoFrameskip-v4': 1500.6, # 1 + 'MsPacmanNoFrameskip-v4': 6951.6, # 1 + # 'SeaquestNoFrameskip-v4': 1000.7, # 2 + 'SeaquestNoFrameskip-v4': 42054.7, # 2 expert + 'BoxingNoFrameskip-v4': 12.1, # 3 expert + # 'AlienNoFrameskip-v4': 1000.7, # 4 + 'AlienNoFrameskip-v4': 7127.7, # 4 expert + # 'ChopperCommandNoFrameskip-v4': 3000.8, # 5 + # 'HeroNoFrameskip-v4': 3082.4, # 6 + 'ChopperCommandNoFrameskip-v4': 7387.8, # 5 expert + 'HeroNoFrameskip-v4': 30826.4, # 6 expert + 'RoadRunnerNoFrameskip-v4': 7845.0, # 7 expert + # 后续 Atari 26games 的额外项 + 'AmidarNoFrameskip-v4': 100.5, # 8 + 'AssaultNoFrameskip-v4': 742.0, # 9 + 'AsterixNoFrameskip-v4': 1503.3, # 10 + 'BankHeistNoFrameskip-v4': 753.1, # 11 + 'BattleZoneNoFrameskip-v4': 12187.5, # 12 + 'CrazyClimberNoFrameskip-v4': 15829.4, # 13 + 'DemonAttackNoFrameskip-v4': 1971.0, # 14 + 'FreewayNoFrameskip-v4': 29.6, # 15 + 'FrostbiteNoFrameskip-v4': 334.7, # 16 + 'GopherNoFrameskip-v4': 2412.5, # 17 + 'JamesbondNoFrameskip-v4': 302.8, # 18 + 'KangarooNoFrameskip-v4': 3035.0, # 19 + 'KrullNoFrameskip-v4': 2665.5, # 20 + 'KungFuMasterNoFrameskip-v4': 12736.3, # 21 + 'PrivateEyeNoFrameskip-v4': 1001.3, # 22 + 'UpNDownNoFrameskip-v4': 11693.2, # 23 + 'QbertNoFrameskip-v4': 13455.0, # 24 + 'BreakoutNoFrameskip-v4': 30.5, # 25 + # --- 经典射击与反应 --- + 'SpaceInvadersNoFrameskip-v4': 1669.7, + 'RiverRaidNoFrameskip-v4' : 17117.1, + 'BeamRiderNoFrameskip-v4' : 16926.5, + + # --- 物理与惯性控制 --- + 'AsteroidsNoFrameskip-v4' : 47388.7, + 'GravitarNoFrameskip-v4' : 3351.4, + + # --- 探索与长时序规划 (Hard-Exploration) --- + 'PitfallNoFrameskip-v4' : 6463.7, + 'AdventureNoFrameskip-v4' : 0.0, + 'EnduroNoFrameskip-v4' : 860.5, # 长时程任务,有昼夜变化,考验模型的耐力和持续表现 + } + + + # 计算每个游戏的 target_return + # return {env: int(round(score * ratio)) for env, score in human_scores.items()} + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + + global target_return_dict + # global BENCHMARK_NAME + # BENCHMARK_NAME='atari' + + # 示例:以 ratio=1 使用 + target_return_dict = get_atari_target_return_dict(ratio=1) + # target_return_dict = 
get_atari_target_return_dict(ratio=0.5) + num_games = 1 # 26 # 8 + + # 分别定义 Atari 游戏列表(8games 和 26games) + if num_games==1: + env_id_list = [ + 'SpaceInvadersNoFrameskip-v4' + ] + + if num_games==3: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + ] + elif num_games==8: + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + elif num_games==26: + # List of Atari games used for multi-task learning + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + global curriculum_stage_num + # TODO ============== + # curriculum_stage_num=3 + curriculum_stage_num=5 + # curriculum_stage_num=9 + + action_space_size = 18 + collector_env_num = 8 + num_segments = collector_env_num + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 25 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + if len(env_id_list) == 1: + effective_batch_size = 64 + elif len(env_id_list) == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + # effective_batch_size = 832 # cnn-encoder + effective_batch_size = 512 # base-vit-encoder + # effective_batch_size = 256 # base-vit-encoder large-vit-encoder + elif len(env_id_list) == 18: + effective_batch_size = 512 * 3 # 1536 + else: + raise ValueError("不支持的环境数量: {}".format(n)) + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + total_batch_size = effective_batch_size # 当前无效 + + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # buffer_reanalyze_freq = 1 / 50 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ======== TODO: only for debug ======== + # collector_env_num = 2 + # num_segments = 2 + # n_episode = 2 + # evaluator_env_num = 2 + # num_simulations = 1 + # reanalyze_batch_size = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # batch_sizes = [4 for _ in range(len(env_id_list))] + + from lzero.entry import train_unizero_multitask_segment_ddp + # finetune_components = [] # load-enc-trans_finetune-head + # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + finetune_components = [] # load-enc-trans_finetune-encoder-head + + for seed in [0]: + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size) + pretrained_model_path = '/fs-computility/niuyazhe/tangjia/github/LightZero/ckpt/ckpt_best.pth.tar' + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, 
model_path=pretrained_model_path, max_env_step=max_env_step, benchmark_name="atari",finetune_components=finetune_components) + # ======== TODO: only for debug ======== + # train_unizero_multitask_segment_ddp(configs[:2], seed=seed, max_env_step=max_env_step) # train on the first four tasks + + + +# TODO(pu): only for debug,设置环境变量DEBUG=1 +# from train_grpo_rm_colocate import maybe_ipdb +# import torch.distributed as dist +# maybe_ipdb(dist.get_rank()) \ No newline at end of file diff --git a/zoo/atari/config/test.py b/zoo/atari/config/test.py new file mode 100644 index 000000000..6949e5f17 --- /dev/null +++ b/zoo/atari/config/test.py @@ -0,0 +1,246 @@ +from easydict import EasyDict + +main_config = dict( + exp_name='data_unizero_atari_mt_balance_20250625/atari_8games_balance-total-stage5_stage-50k-20k_fix-lora-update-stablescale_vit-small-ln_moe8-lora_trans-nlayer4_brf1e-06_not-share-head_seed0/Pong_seed0', + env=dict( + manager=dict( + episode_num=float('inf'), + max_retry=1, + step_timeout=None, + auto_reset=True, + reset_timeout=None, + retry_type='reset', + retry_waiting_time=0.1, + shared_memory=False, + copy_on_get=True, + context='fork', + wait_num=float('inf'), + step_wait_timeout=None, + connect_timeout=60, + reset_inplace=False, + cfg_type='SyncSubprocessEnvManagerDict', + type='subprocess', + ), + stop_value=1000000, + n_evaluator_episode=3, + full_action_space=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='Atari', + observation_shape=[3, 64, 64], + collect_max_episode_steps=5000, + eval_max_episode_steps=5000, + render_mode_human=False, + save_replay=False, + replay_path=None, + gray_scale=False, + frame_stack_num=1, + frame_skip=4, + episode_life=True, + clip_rewards=True, + channel_last=False, + scale=True, + warp_frame=True, + transform2string=False, + game_wrapper=True, + cfg_type='AtariEnvLightZeroDict', + env_id='PongNoFrameskip-v4', + ), + policy=dict( + model=dict( + model_type='conv', + continuous_action_space=False, + observation_shape=[3, 64, 64], + self_supervised_learning_loss=True, + categorical_distribution=True, + image_channel=3, + frame_stack_num=1, + num_res_blocks=2, + num_channels=256, + support_scale=50, + bias=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='LN', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + harmony_balance=False, + learn={'learner': {'hook': {'save_ckpt_after_iter': 10000}}}, + world_model_cfg={'continuous_action_space': False, 'tokens_per_block': 2, 'max_blocks': 10, 'max_tokens': 20, 'context_length': 8, 'gru_gating': False, 'device': 'cuda', 'analysis_sim_norm': False, 'analysis_dormant_ratio_weight_rank': False, 'action_space_size': 18, 'group_size': 8, 'attention': 'causal', 'num_layers': 4, 'num_heads': 24, 'embed_dim': 768, 'embed_pdrop': 0.1, 'resid_pdrop': 0.1, 'attn_pdrop': 0.1, 'support_size': 101, 'max_cache_size': 5000, 'env_num': 8, 'latent_recon_loss_weight': 0.0, 'perceptual_loss_weight': 0.0, 'policy_entropy_weight': 0.0001, 'predict_latent_loss_type': 'mse', 'obs_type': 'image', 'gamma': 1, 'dormant_threshold': 0.025, 'rotary_emb': False, 'rope_theta': 10000, 'max_seq_len': 8192, 'lora_r': 64, 'analysis_dormant_ratio': False, 'use_global_pooling': False, 'final_norm_option_in_obs_head': 'LayerNorm', 'final_norm_option_in_encoder': 'LayerNorm', 'share_head': False, 'task_embed_option': None, 'use_task_embed': False, 'use_shared_projection': False, 'task_num': 8, 'encoder_type': 'vit', 'use_normal_head': True, 'use_softmoe_head': False, 
'use_moe_head': False, 'num_experts_in_moe_head': 4, 'moe_in_transformer': False, 'multiplication_moe_in_transformer': True, 'n_shared_experts': 1, 'num_experts_per_tok': 1, 'num_experts_of_moe_in_transformer': 8, 'moe_use_lora': False, 'curriculum_stage_num': 5, 'lora_target_modules': ['attn', 'feed_forward'], 'lora_alpha': 1, 'lora_dropout': 0.0, 'lora_scale_init': 1, 'min_stage0_iters': 50000, 'max_stage_iters': 20000}, + action_space_size=18, + ), + learn=dict( + learner=dict( + train_iterations=1000000000, + dataloader=dict( + num_workers=0, + ), + log_policy=True, + hook=dict( + load_ckpt_before_run='', + log_show_after_iter=100, + save_ckpt_after_iter=200000, + save_ckpt_after_run=True, + ), + cfg_type='BaseLearnerDict', + ), + ), + collect=dict( + collector=dict( + deepcopy_obs=False, + transform_obs=False, + collect_print_freq=100, + cfg_type='SampleSerialCollectorDict', + type='sample', + ), + ), + eval=dict( + evaluator=dict( + eval_freq=1000, + render={'render_freq': -1, 'mode': 'train_iter'}, + figure_path=None, + cfg_type='InteractionSerialEvaluatorDict', + stop_value=1000000, + n_episode=3, + ), + ), + other=dict( + replay_buffer=dict( + type='advanced', + replay_buffer_size=4096, + max_use=float('inf'), + max_staleness=float('inf'), + alpha=0.6, + beta=0.4, + anneal_step=100000, + enable_track_used_data=False, + deepcopy=False, + thruput_controller=dict( + push_sample_rate_limit=dict( + max=float('inf'), + min=0, + ), + window_seconds=30, + sample_min_limit_ratio=1, + ), + monitor=dict( + sampled_data_attr=dict( + average_range=5, + print_freq=200, + ), + periodic_thruput=dict( + seconds=60, + ), + ), + cfg_type='AdvancedReplayBufferDict', + ), + ), + on_policy=False, + cuda=True, + multi_gpu=True, + bp_update_sync=True, + traj_len_inf=False, + use_wandb=False, + use_rnd_model=False, + sampled_algo=False, + gumbel_algo=False, + mcts_ctree=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=20, + eval_offline=False, + cal_dormant_ratio=False, + analysis_sim_norm=False, + analysis_dormant_ratio=False, + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + ignore_done=False, + update_per_collect=80, + replay_ratio=0.25, + batch_size=[64, 64, 64, 64, 64, 64, 64, 64], + optim_type='AdamW', + learning_rate=0.0001, + target_update_freq=100, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=0.0001, + momentum=0.9, + grad_clip_value=5, + n_episode=8, + num_segments=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=10, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + policy_entropy_weight=0, + ssl_loss_weight=0, + piecewise_decay_lr_scheduler=False, + threshold_training_steps_for_final_lr=50000, + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=100000, + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + reanalyze_noise=True, + reuse_search=False, + collect_with_pure_policy=False, + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + random_collect_episode_num=0, + eps={'eps_greedy_exploration_in_collect': False, 'type': 'linear', 'start': 1.0, 'end': 0.05, 'decay': 100000}, + cfg_type='UniZeroMTPolicyDict', + eval_freq=10000, + sample_type='transition', + 
target_update_theta=0.05, + cos_lr_scheduler=False, + accumulation_steps=1, + train_start_after_envsteps=0, + lr_piecewise_constant_decay=False, + import_names=['lzero.policy.unizero_multitask'], + only_use_moco_stats=False, + use_moco=False, + grad_correct_params={'MoCo_beta': 0.5, 'MoCo_beta_sigma': 0.5, 'MoCo_gamma': 0.1, 'MoCo_gamma_sigma': 0.5, 'MoCo_rho': 0, 'calpha': 0.5, 'rescale': 1}, + total_task_num=8, + task_num=4, + task_id=0, + use_task_exploitation_weight=False, + target_return=20, + balance_pipeline=True, + task_complexity_weight=True, + total_batch_size=512, + allocated_batch_sizes=False, + print_task_priority_logs=False, + model_path=None, + reanalyze_ratio=0.0, + replay_buffer_size=500000, + buffer_reanalyze_freq=1e-06, + reanalyze_batch_size=160, + reanalyze_partition=0.75, + device='cuda', + ), +) +main_config = EasyDict(main_config) +main_config = main_config +create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict( + cfg_type='SyncSubprocessEnvManagerDict', + type='subprocess', + ), + policy=dict(type='unizero_multitask'), +) +create_config = EasyDict(create_config) +create_config = create_config
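
The compute_batch_config helper defined in each of the new config files splits one effective batch size evenly across tasks and switches to gradient accumulation whenever the per-task share exceeds a hard per-GPU micro-batch cap (400 samples on a single GPU in these configs). A minimal standalone sketch of the same arithmetic, with illustrative names and an extra guard (not present in the original) against having fewer tasks than GPUs:

import math

def split_batch(num_tasks: int, effective_batch_size: int,
                max_micro_batch_per_gpu: int = 400, gpu_num: int = 1):
    # Per-task micro-batch cap: the per-GPU budget shared by the tasks
    # co-located on one GPU (the max(..., 1) guard avoids division by zero
    # when num_tasks < gpu_num).
    max_micro_batch = max_micro_batch_per_gpu // max(num_tasks // gpu_num, 1)
    per_task = effective_batch_size / num_tasks
    if per_task > max_micro_batch:
        micro_batch = max_micro_batch
        grad_acc_steps = math.ceil(per_task / max_micro_batch)
    else:
        micro_batch = int(per_task)
        grad_acc_steps = 1
    return [micro_batch] * num_tasks, grad_acc_steps

# 26 Atari tasks sharing an effective batch of 512:
#   cap = 400 // 26 = 15, per_task ~= 19.7  ->  micro_batch 15, 2 accumulation steps
# 1 task with an effective batch of 64 (the SpaceInvaders finetune case):
#   per_task = 64 <= 400                    ->  micro_batch 64, 1 accumulation step
print(split_batch(26, 512))
print(split_batch(1, 64))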
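
Each variant passes a finetune_components list to train_unizero_multitask_segment_ddp together with a pretrained model_path: [] for head-only finetuning, ['transformer'] for transformer-plus-head, and ["representation_network", "encoder"] for encoder-plus-head. The actual parameter-freezing logic lives in the entry and policy code modified elsewhere in this patch; the sketch below only illustrates how such a component whitelist is commonly applied, under the assumption (hypothetical) that parameter names carry their module prefix, e.g. 'representation_network.*', 'transformer.*', 'head_*':

import torch.nn as nn

def freeze_except(model: nn.Module, finetune_components: list) -> None:
    # Heads stay trainable in every variant above; everything else is frozen
    # unless its name matches one of the requested components.
    always_trainable_prefixes = ('head',)
    for name, param in model.named_parameters():
        trainable = name.startswith(always_trainable_prefixes) or any(
            component in name for component in finetune_components
        )
        param.requires_grad = trainable

# head-only finetune:   freeze_except(world_model, [])
# transformer + head:   freeze_except(world_model, ['transformer'])
# encoder + head:       freeze_except(world_model, ["representation_network", "encoder"])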
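
The _head_back_lora and _head_back_encoder_lora configs enable moe_use_lora with lora_target_modules=["attn"], lora_r=64, lora_alpha=1 and lora_dropout=0.0 (plus the encoder_lora_* fields in the encoder variant); the adapter class itself is defined in lzero/model/unizero_world_models/transformer.py, which this patch also touches. For reference only, a generic LoRA-wrapped linear layer with those hyperparameters follows the standard formulation y = Wx + (alpha / r) * B A x; the patch-specific lora_scale_init knob is not modelled here:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Standard LoRA adapter around a frozen pretrained nn.Linear."""

    def __init__(self, base: nn.Linear, r: int = 64, alpha: float = 1.0, dropout: float = 0.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # pretrained weight W stays frozen
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)   # B = 0, so the adapter starts as a no-op
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(self.dropout(x)))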
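
The curriculum fields differ deliberately between the variants: the head-finetune config keeps min_stage0_iters=5000000 with curriculum_stage_num=1 (staging effectively disabled), the encoder-LoRA config uses min_stage0_iters=0 (LoRA active from the first iteration), and the head_back_lora config uses curriculum_stage_num=5 with min_stage0_iters=1000000. The stage schedule itself is implemented in the transformer and policy code changed elsewhere in this patch; the function below is only one plausible reading of how those three numbers could map a training iteration to a stage index:

def curriculum_stage(train_iter: int, stage_num: int,
                     min_stage0_iters: int, max_stage_iters: int) -> int:
    # Stay in stage 0 until min_stage0_iters, then advance one stage every
    # max_stage_iters iterations, clamped to the last stage.
    if stage_num <= 1 or train_iter < min_stage0_iters:
        return 0
    stage = 1 + (train_iter - min_stage0_iters) // max_stage_iters
    return min(stage, stage_num - 1)

# With stage_num=5, min_stage0_iters=50_000, max_stage_iters=20_000 (the values
# recorded in the dumped test.py config): iter 40k -> 0, 60k -> 1, 100k -> 3, 200k -> 4.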