diff --git a/tests/api/test_session_op.py b/tests/api/test_session_op.py index 2cfeb02d..9612cb21 100644 --- a/tests/api/test_session_op.py +++ b/tests/api/test_session_op.py @@ -1,8 +1,23 @@ +import json +import shutil +from pathlib import Path + import pytest + +pytest.importorskip("fastapi") + +from fastapi import BackgroundTasks from fastapi.testclient import TestClient + from zsim.api import app +from zsim.api_src.routes import session_op from zsim.api_src.services.database.session_db import get_session_db +from zsim.define import results_dir from zsim.models.session.session_create import Session +from zsim.models.session.session_run import SessionRun +from zsim.lib_webui import process_parallel_data as webui_parallel_data +from zsim.lib_webui.process_simulator import generate_parallel_args +from zsim.utils.process_parallel_data import judge_parallel_result, merge_parallel_dmg_data client = TestClient(app) @@ -152,3 +167,160 @@ async def test_delete_session(session_data): response = client.get(f"/api/sessions/{session_data['session_id']}") assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_parallel_run_writes_config_and_merge(session_data, monkeypatch): + session_id = "parallel_test_session" + result_dir_path = Path(results_dir) / session_id + if result_dir_path.exists(): + shutil.rmtree(result_dir_path) + + db = await get_session_db() + await db.delete_session(session_id) + + session_payload = session_data.copy() + session_payload["session_id"] = session_id + session = Session(**session_payload) + await db.add_session(session) + + class DummySimController: + def generate_parallel_args(self, _session, _session_run): + return [None] + + async def put_into_queue(self, *_args, **_kwargs): + return None + + async def execute_simulation_test(self): + return [] + + async def execute_simulation(self): + return [] + + monkeypatch.setattr(session_op, "SimController", DummySimController) + + session_run_payload = { + "stop_tick": 10, + "mode": "parallel", + "common_config": { + "session_id": session_id, + "char_config": [ + {"name": "仪玄"}, + {"name": "耀嘉音"}, + {"name": "扳机"}, + ], + "enemy_config": {"index_id": 11412, "adjustment_id": 22412}, + "apl_path": "zsim/data/APLData/仪玄-耀嘉音-扳机.toml", + }, + "parallel_config": { + "enable": True, + "adjust_char": 1, + "func": "attr_curve", + "func_config": { + "sc_range": [0, 0], + "sc_list": ["scATK_percent"], + "remove_equip_list": [], + }, + }, + } + + session_run = SessionRun(**session_run_payload) + background_tasks = BackgroundTasks() + + try: + response = await session_op.run_session( + session_id, + session_run, + background_tasks, + db, + test_mode=True, + ) + assert response["code"] == 0 + + config_path = result_dir_path / ".parallel_config.json" + assert config_path.exists() + + config_data = json.loads(config_path.read_text(encoding="utf-8")) + assert config_data["enabled"] is True + assert config_data["adjust_sc"]["enabled"] is True + assert config_data["adjust_sc"]["sc_range"] == [0, 0] + assert config_data["adjust_sc"]["sc_list"] == ["scATK_percent"] + assert config_data["adjust_weapon"]["enabled"] is False + assert config_data["func_config"] == session_run_payload["parallel_config"]["func_config"] + + sub_dir = result_dir_path / "attr_curve_sample" + sub_dir.mkdir(parents=True, exist_ok=True) + (sub_dir / "sub.parallel_config.json").write_text( + json.dumps( + { + "adjust_char": "仪玄", + "sc_name": "scATK_percent", + "sc_value": 0, + }, + ensure_ascii=False, + ), + encoding="utf-8", + ) + (sub_dir / "damage_attribution.json").write_text( + json.dumps( + { + "仪玄": { + "direct_damage": 100.0, + "anomaly_damage": 50.0, + } + }, + ensure_ascii=False, + ), + encoding="utf-8", + ) + + assert judge_parallel_result(session_id) is True + + alias_config = json.loads(config_path.read_text(encoding="utf-8")) + alias_config["enable"] = alias_config.pop("enabled") + alias_config["adjust_sc"]["enable"] = alias_config["adjust_sc"].pop("enabled") + alias_config["adjust_weapon"]["enable"] = alias_config["adjust_weapon"].pop( + "enabled" + ) + config_path.write_text( + json.dumps(alias_config, indent=4, ensure_ascii=False), encoding="utf-8" + ) + + assert judge_parallel_result(session_id) is True + assert webui_parallel_data.judge_parallel_result(session_id) is True + + result = await merge_parallel_dmg_data(session_id) + assert result is not None + func, merged_data = result + assert func == "attr_curve" + assert merged_data["仪玄"]["scATK_percent"][0]["result"] == 150.0 + finally: + await db.delete_session(session_id) + if result_dir_path.exists(): + shutil.rmtree(result_dir_path) + + +def test_generate_parallel_args_accepts_enable_alias(): + args = list( + generate_parallel_args( + stop_tick=10, + parallel_cfg={ + "func": "attr_curve", + "adjust_char": 1, + "adjust_sc": { + "enable": True, + "sc_range": [0, 0], + "sc_list": ["攻击力%"], + "remove_equip_list": [], + }, + "adjust_weapon": {"enable": False, "weapon_list": []}, + }, + run_turn_uuid="uuid", + ) + ) + + assert len(args) == 1 + first_arg = args[0] + assert first_arg.func == "attr_curve" + assert first_arg.adjust_char == 1 + assert first_arg.sc_name == "scATK_percent" diff --git a/zsim/api_src/routes/session_op.py b/zsim/api_src/routes/session_op.py index d54b702c..6d77349f 100644 --- a/zsim/api_src/routes/session_op.py +++ b/zsim/api_src/routes/session_op.py @@ -1,11 +1,15 @@ +import json import logging +from pathlib import Path +from typing import Any from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException from zsim.api_src.services.database.session_db import SessionDB, get_session_db from zsim.api_src.services.sim_controller.sim_controller import SimController +from zsim.define import results_dir from zsim.models.session.session_create import Session -from zsim.models.session.session_run import SessionRun +from zsim.models.session.session_run import ParallelCfg, SessionRun logger = logging.getLogger(__name__) router = APIRouter() @@ -69,6 +73,49 @@ async def run_session( background_tasks.add_task(sim_controller.execute_simulation) if session_run.mode == "parallel" and session_run.parallel_config: + result_dir = Path(results_dir) / session.session_id + result_dir.mkdir(parents=True, exist_ok=True) + + parallel_cfg = session_run.parallel_config + parallel_cfg_dump: dict[str, Any] = parallel_cfg.model_dump(mode="json") + + parallel_config_payload: dict[str, Any] = { + "enabled": True, + "func": parallel_cfg.func, + "adjust_char": parallel_cfg.adjust_char, + "func_config": parallel_cfg_dump.get("func_config"), + } + + adjust_sc_config: dict[str, Any] = {"enabled": False} + adjust_weapon_config: dict[str, Any] = {"enabled": False} + + func_cfg_dump = parallel_cfg_dump.get("func_config") + if parallel_cfg.func == "attr_curve": + sc_config = func_cfg_dump if isinstance(func_cfg_dump, dict) else {} + adjust_sc_config.update( + { + "enabled": True, + "sc_range": list(sc_config.get("sc_range", [])), + "sc_list": list(sc_config.get("sc_list", [])), + "remove_equip_list": list(sc_config.get("remove_equip_list", [])), + } + ) + elif parallel_cfg.func == "weapon": + weapon_config = func_cfg_dump if isinstance(func_cfg_dump, dict) else {} + adjust_weapon_config.update( + { + "enabled": True, + "weapon_list": list(weapon_config.get("weapon_list", [])), + } + ) + + parallel_config_payload["adjust_sc"] = adjust_sc_config + parallel_config_payload["adjust_weapon"] = adjust_weapon_config + + config_path = result_dir / ".parallel_config.json" + with config_path.open("w", encoding="utf-8") as f: + json.dump(parallel_config_payload, f, indent=4, ensure_ascii=False) + args_iterator = sim_controller.generate_parallel_args(session, session_run) for sim_cfg in args_iterator: await sim_controller.put_into_queue( diff --git a/zsim/lib_webui/process_parallel_data.py b/zsim/lib_webui/process_parallel_data.py index ef3d54a2..f04bec42 100644 --- a/zsim/lib_webui/process_parallel_data.py +++ b/zsim/lib_webui/process_parallel_data.py @@ -22,6 +22,21 @@ reversed_stats_trans_mapping = {v: k for k, v in stats_trans_mapping.items()} +def _get_enabled_flag(config: dict[str, Any] | None) -> bool: + """Return the enabled flag supporting legacy ``enable`` naming.""" + + if not isinstance(config, dict): + return False + + if "enabled" in config: + return bool(config.get("enabled")) + + if "enable" in config: + return bool(config.get("enable")) + + return False + + def judge_parallel_result(rid: int | str) -> bool: """判断对应的rid是否为并行模式。 @@ -41,8 +56,15 @@ def judge_parallel_result(rid: int | str) -> bool: try: with open(parallel_config_path, "r", encoding="utf-8") as f: - parallel_config: dict = json.load(f) - if not parallel_config.get("enabled", False): + parallel_config: dict[str, Any] = json.load(f) + + enabled_flag = _get_enabled_flag(parallel_config) + if not enabled_flag: + func_name = parallel_config.get("func") + if isinstance(func_name, str) and func_name in {"attr_curve", "weapon"}: + enabled_flag = True + + if not enabled_flag: return False except (json.JSONDecodeError, IOError): # 如果文件读取或解析失败,也视为非并行模式 @@ -94,7 +116,11 @@ async def prepare_parallel_data_and_cache(rid: int | str) -> None: st.error(f"读取或解析并行配置文件 {parallel_config_path} 失败: {e}") return - if parallel_config.get("adjust_sc", {}).get("enabled", False): + adjust_sc_enabled = _get_enabled_flag(parallel_config.get("adjust_sc")) + if not adjust_sc_enabled and parallel_config.get("func") == "attr_curve": + adjust_sc_enabled = True + + if adjust_sc_enabled: merged_sc_file_path = os.path.join(result_dir, "merged_sc_data.json") if os.path.exists(merged_sc_file_path): return @@ -130,7 +156,11 @@ async def merge_parallel_dmg_data( async with aiofiles.open(parallel_config_path, "r", encoding="utf-8") as f: parallel_config: dict = json.loads(await f.read()) - if parallel_config.get("adjust_sc", {}).get("enabled", False): + adjust_sc_enabled = _get_enabled_flag(parallel_config.get("adjust_sc")) + if not adjust_sc_enabled and parallel_config.get("func") == "attr_curve": + adjust_sc_enabled = True + + if adjust_sc_enabled: # 属性收益曲线功能 func = "attr_curve" merged_sc_file_path = os.path.join(result_dir, "merged_sc_data.json") @@ -154,7 +184,11 @@ async def merge_parallel_dmg_data( except Exception as e: st.error(f"合并属性收益曲线数据时出错: {e}") return func, sc_merged_data - elif parallel_config.get("adjust_weapon", {}).get("enabled", False): + adjust_weapon_enabled = _get_enabled_flag(parallel_config.get("adjust_weapon")) + if not adjust_weapon_enabled and parallel_config.get("func") == "weapon": + adjust_weapon_enabled = True + + if adjust_weapon_enabled: # 武器切换功能 func = "weapon" merged_weapon_file_path = os.path.join(result_dir, "merged_weapon_data.json") @@ -178,8 +212,7 @@ async def merge_parallel_dmg_data( st.error(f"合并武器切换数据时出错: {e}") return func, weapon_merged_data - else: - return None + return None def __draw_attr_curve( diff --git a/zsim/lib_webui/process_simulator.py b/zsim/lib_webui/process_simulator.py index d71e6341..8dbb1ac2 100644 --- a/zsim/lib_webui/process_simulator.py +++ b/zsim/lib_webui/process_simulator.py @@ -16,9 +16,24 @@ from .constants import stats_trans_mapping +def _get_enabled_flag(config: dict[str, Any] | None) -> bool: + """Return the enabled flag supporting legacy ``enable`` naming.""" + + if not isinstance(config, dict): + return False + + if "enabled" in config: + return bool(config.get("enabled")) + + if "enable" in config: + return bool(config.get("enable")) + + return False + + def generate_parallel_args( stop_tick: int, - parallel_cfg: dict, + parallel_cfg: dict[str, Any], run_turn_uuid: str, ) -> Iterator[SimCfg]: """生成用于并行模拟的参数。 @@ -33,13 +48,18 @@ def generate_parallel_args( """ # Determine the function based on enabled flags func = None - if parallel_cfg.get("adjust_sc", {}).get("enabled", False): + if _get_enabled_flag(parallel_cfg.get("adjust_sc")): func = "attr_curve" - elif parallel_cfg.get("adjust_weapon", {}).get("enabled", False): + elif _get_enabled_flag(parallel_cfg.get("adjust_weapon")): func = "weapon" + elif isinstance(parallel_cfg.get("func"), str): + func = parallel_cfg["func"] if func == "attr_curve": - adjust_sc_cfg = parallel_cfg["adjust_sc"] + adjust_sc_cfg_raw = parallel_cfg.get("adjust_sc") + if not isinstance(adjust_sc_cfg_raw, dict): + raise ValueError(f"并行配置缺少属性收益参数: {parallel_cfg}") + adjust_sc_cfg = adjust_sc_cfg_raw sc_list = adjust_sc_cfg["sc_list"] sc_range_start, sc_range_end = adjust_sc_cfg["sc_range"] remove_equip_list = adjust_sc_cfg.get( @@ -59,7 +79,10 @@ def generate_parallel_args( ) yield args elif func == "weapon": - adjust_weapon_cfg = parallel_cfg["adjust_weapon"] + adjust_weapon_cfg_raw = parallel_cfg.get("adjust_weapon") + if not isinstance(adjust_weapon_cfg_raw, dict): + raise ValueError(f"并行配置缺少武器参数: {parallel_cfg}") + adjust_weapon_cfg = adjust_weapon_cfg_raw weapon_list = adjust_weapon_cfg["weapon_list"] for weapon in weapon_list: args = ExecWeaponCfg( diff --git a/zsim/utils/process_parallel_data.py b/zsim/utils/process_parallel_data.py index b0741289..ab1420f5 100644 --- a/zsim/utils/process_parallel_data.py +++ b/zsim/utils/process_parallel_data.py @@ -18,6 +18,21 @@ reversed_stats_trans_mapping = {v: k for k, v in stats_trans_mapping.items()} +def _get_enabled_flag(config: dict[str, Any] | None) -> bool: + """返回配置中的启用标志,兼容 enable/enabled 命名。""" + + if not isinstance(config, dict): + return False + + if "enabled" in config: + return bool(config.get("enabled")) + + if "enable" in config: + return bool(config.get("enable")) + + return False + + def judge_parallel_result(rid: int | str) -> bool: """判断对应的rid是否为并行模式。 @@ -37,8 +52,15 @@ def judge_parallel_result(rid: int | str) -> bool: try: with open(parallel_config_path, "r", encoding="utf-8") as f: - parallel_config: dict = json.load(f) - if not parallel_config.get("enabled", False): + parallel_config: dict[str, Any] = json.load(f) + + enabled_flag = _get_enabled_flag(parallel_config) + if not enabled_flag: + func_name = parallel_config.get("func") + if isinstance(func_name, str) and func_name in {"attr_curve", "weapon"}: + enabled_flag = True + + if not enabled_flag: return False except (json.JSONDecodeError, IOError): # 如果文件读取或解析失败,也视为非并行模式 @@ -90,7 +112,11 @@ async def prepare_parallel_data_and_cache(rid: int | str) -> None: print(f"读取或解析并行配置文件 {parallel_config_path} 失败: {e}") return - if parallel_config.get("adjust_sc", {}).get("enabled", False): + adjust_sc_enabled = _get_enabled_flag(parallel_config.get("adjust_sc")) + if not adjust_sc_enabled and parallel_config.get("func") == "attr_curve": + adjust_sc_enabled = True + + if adjust_sc_enabled: merged_sc_file_path = os.path.join(result_dir, "merged_sc_data.json") if os.path.exists(merged_sc_file_path): return @@ -126,7 +152,11 @@ async def merge_parallel_dmg_data( async with aiofiles.open(parallel_config_path, "r", encoding="utf-8") as f: parallel_config: dict = json.loads(await f.read()) - if parallel_config.get("adjust_sc", {}).get("enabled", False): + adjust_sc_enabled = _get_enabled_flag(parallel_config.get("adjust_sc")) + if not adjust_sc_enabled and parallel_config.get("func") == "attr_curve": + adjust_sc_enabled = True + + if adjust_sc_enabled: # 属性收益曲线功能 func = "attr_curve" merged_sc_file_path = os.path.join(result_dir, "merged_sc_data.json") @@ -150,7 +180,11 @@ async def merge_parallel_dmg_data( except Exception as e: print(f"合并属性收益曲线数据时出错: {e}") return func, sc_merged_data - elif parallel_config.get("adjust_weapon", {}).get("enabled", False): + adjust_weapon_enabled = _get_enabled_flag(parallel_config.get("adjust_weapon")) + if not adjust_weapon_enabled and parallel_config.get("func") == "weapon": + adjust_weapon_enabled = True + + if adjust_weapon_enabled: # 武器切换功能 func = "weapon" merged_weapon_file_path = os.path.join(result_dir, "merged_weapon_data.json")