Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions tests/api/test_session_op.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
import json
import shutil
from pathlib import Path

import pytest

pytest.importorskip("fastapi")

from fastapi import BackgroundTasks
from fastapi.testclient import TestClient

from zsim.api import app
from zsim.api_src.routes import session_op
from zsim.api_src.services.database.session_db import get_session_db
from zsim.define import results_dir
from zsim.models.session.session_create import Session
from zsim.models.session.session_run import SessionRun
from zsim.lib_webui import process_parallel_data as webui_parallel_data
from zsim.lib_webui.process_simulator import generate_parallel_args
from zsim.utils.process_parallel_data import judge_parallel_result, merge_parallel_dmg_data

client = TestClient(app)

Expand Down Expand Up @@ -152,3 +167,160 @@ async def test_delete_session(session_data):

response = client.get(f"/api/sessions/{session_data['session_id']}")
assert response.status_code == 404


@pytest.mark.asyncio
async def test_parallel_run_writes_config_and_merge(session_data, monkeypatch):
session_id = "parallel_test_session"
result_dir_path = Path(results_dir) / session_id
if result_dir_path.exists():
shutil.rmtree(result_dir_path)

db = await get_session_db()
await db.delete_session(session_id)

session_payload = session_data.copy()
session_payload["session_id"] = session_id
session = Session(**session_payload)
await db.add_session(session)

class DummySimController:
def generate_parallel_args(self, _session, _session_run):
return [None]

async def put_into_queue(self, *_args, **_kwargs):
return None

async def execute_simulation_test(self):
return []

async def execute_simulation(self):
return []

monkeypatch.setattr(session_op, "SimController", DummySimController)

session_run_payload = {
"stop_tick": 10,
"mode": "parallel",
"common_config": {
"session_id": session_id,
"char_config": [
{"name": "仪玄"},
{"name": "耀嘉音"},
{"name": "扳机"},
],
"enemy_config": {"index_id": 11412, "adjustment_id": 22412},
"apl_path": "zsim/data/APLData/仪玄-耀嘉音-扳机.toml",
},
"parallel_config": {
"enable": True,
"adjust_char": 1,
"func": "attr_curve",
"func_config": {
"sc_range": [0, 0],
"sc_list": ["scATK_percent"],
"remove_equip_list": [],
},
},
}

session_run = SessionRun(**session_run_payload)
background_tasks = BackgroundTasks()

try:
response = await session_op.run_session(
session_id,
session_run,
background_tasks,
db,
test_mode=True,
)
assert response["code"] == 0

config_path = result_dir_path / ".parallel_config.json"
assert config_path.exists()

config_data = json.loads(config_path.read_text(encoding="utf-8"))
assert config_data["enabled"] is True
assert config_data["adjust_sc"]["enabled"] is True
assert config_data["adjust_sc"]["sc_range"] == [0, 0]
assert config_data["adjust_sc"]["sc_list"] == ["scATK_percent"]
assert config_data["adjust_weapon"]["enabled"] is False
assert config_data["func_config"] == session_run_payload["parallel_config"]["func_config"]

sub_dir = result_dir_path / "attr_curve_sample"
sub_dir.mkdir(parents=True, exist_ok=True)
(sub_dir / "sub.parallel_config.json").write_text(
json.dumps(
{
"adjust_char": "仪玄",
"sc_name": "scATK_percent",
"sc_value": 0,
},
ensure_ascii=False,
),
encoding="utf-8",
)
(sub_dir / "damage_attribution.json").write_text(
json.dumps(
{
"仪玄": {
"direct_damage": 100.0,
"anomaly_damage": 50.0,
}
},
ensure_ascii=False,
),
encoding="utf-8",
)

assert judge_parallel_result(session_id) is True

alias_config = json.loads(config_path.read_text(encoding="utf-8"))
alias_config["enable"] = alias_config.pop("enabled")
alias_config["adjust_sc"]["enable"] = alias_config["adjust_sc"].pop("enabled")
alias_config["adjust_weapon"]["enable"] = alias_config["adjust_weapon"].pop(
"enabled"
)
config_path.write_text(
json.dumps(alias_config, indent=4, ensure_ascii=False), encoding="utf-8"
)

assert judge_parallel_result(session_id) is True
assert webui_parallel_data.judge_parallel_result(session_id) is True

result = await merge_parallel_dmg_data(session_id)
assert result is not None
func, merged_data = result
assert func == "attr_curve"
assert merged_data["仪玄"]["scATK_percent"][0]["result"] == 150.0
finally:
await db.delete_session(session_id)
if result_dir_path.exists():
shutil.rmtree(result_dir_path)


def test_generate_parallel_args_accepts_enable_alias():
args = list(
generate_parallel_args(
stop_tick=10,
parallel_cfg={
"func": "attr_curve",
"adjust_char": 1,
"adjust_sc": {
"enable": True,
"sc_range": [0, 0],
"sc_list": ["攻击力%"],
"remove_equip_list": [],
},
"adjust_weapon": {"enable": False, "weapon_list": []},
},
run_turn_uuid="uuid",
)
)

assert len(args) == 1
first_arg = args[0]
assert first_arg.func == "attr_curve"
assert first_arg.adjust_char == 1
assert first_arg.sc_name == "scATK_percent"
49 changes: 48 additions & 1 deletion zsim/api_src/routes/session_op.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import json
import logging
from pathlib import Path
from typing import Any

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException

from zsim.api_src.services.database.session_db import SessionDB, get_session_db
from zsim.api_src.services.sim_controller.sim_controller import SimController
from zsim.define import results_dir
from zsim.models.session.session_create import Session
from zsim.models.session.session_run import SessionRun
from zsim.models.session.session_run import ParallelCfg, SessionRun

logger = logging.getLogger(__name__)
router = APIRouter()
Expand Down Expand Up @@ -69,6 +73,49 @@ async def run_session(
background_tasks.add_task(sim_controller.execute_simulation)

if session_run.mode == "parallel" and session_run.parallel_config:
result_dir = Path(results_dir) / session.session_id
result_dir.mkdir(parents=True, exist_ok=True)

parallel_cfg = session_run.parallel_config
parallel_cfg_dump: dict[str, Any] = parallel_cfg.model_dump(mode="json")

parallel_config_payload: dict[str, Any] = {
"enabled": True,
"func": parallel_cfg.func,
"adjust_char": parallel_cfg.adjust_char,
"func_config": parallel_cfg_dump.get("func_config"),
}

adjust_sc_config: dict[str, Any] = {"enabled": False}
adjust_weapon_config: dict[str, Any] = {"enabled": False}

func_cfg_dump = parallel_cfg_dump.get("func_config")
if parallel_cfg.func == "attr_curve":
sc_config = func_cfg_dump if isinstance(func_cfg_dump, dict) else {}
adjust_sc_config.update(
{
"enabled": True,
"sc_range": list(sc_config.get("sc_range", [])),
"sc_list": list(sc_config.get("sc_list", [])),
"remove_equip_list": list(sc_config.get("remove_equip_list", [])),
}
)
elif parallel_cfg.func == "weapon":
weapon_config = func_cfg_dump if isinstance(func_cfg_dump, dict) else {}
adjust_weapon_config.update(
{
"enabled": True,
"weapon_list": list(weapon_config.get("weapon_list", [])),
}
)

parallel_config_payload["adjust_sc"] = adjust_sc_config
parallel_config_payload["adjust_weapon"] = adjust_weapon_config

config_path = result_dir / ".parallel_config.json"
with config_path.open("w", encoding="utf-8") as f:
json.dump(parallel_config_payload, f, indent=4, ensure_ascii=False)

args_iterator = sim_controller.generate_parallel_args(session, session_run)
for sim_cfg in args_iterator:
await sim_controller.put_into_queue(
Expand Down
47 changes: 40 additions & 7 deletions zsim/lib_webui/process_parallel_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@
reversed_stats_trans_mapping = {v: k for k, v in stats_trans_mapping.items()}


def _get_enabled_flag(config: dict[str, Any] | None) -> bool:
"""Return the enabled flag supporting legacy ``enable`` naming."""

if not isinstance(config, dict):
return False

if "enabled" in config:
return bool(config.get("enabled"))

if "enable" in config:
return bool(config.get("enable"))

return False


def judge_parallel_result(rid: int | str) -> bool:
"""判断对应的rid是否为并行模式。

Expand All @@ -41,8 +56,15 @@ def judge_parallel_result(rid: int | str) -> bool:

try:
with open(parallel_config_path, "r", encoding="utf-8") as f:
parallel_config: dict = json.load(f)
if not parallel_config.get("enabled", False):
parallel_config: dict[str, Any] = json.load(f)

enabled_flag = _get_enabled_flag(parallel_config)
if not enabled_flag:
func_name = parallel_config.get("func")
if isinstance(func_name, str) and func_name in {"attr_curve", "weapon"}:
enabled_flag = True

if not enabled_flag:
return False
except (json.JSONDecodeError, IOError):
# 如果文件读取或解析失败,也视为非并行模式
Expand Down Expand Up @@ -94,7 +116,11 @@ async def prepare_parallel_data_and_cache(rid: int | str) -> None:
st.error(f"读取或解析并行配置文件 {parallel_config_path} 失败: {e}")
return

if parallel_config.get("adjust_sc", {}).get("enabled", False):
adjust_sc_enabled = _get_enabled_flag(parallel_config.get("adjust_sc"))
if not adjust_sc_enabled and parallel_config.get("func") == "attr_curve":
adjust_sc_enabled = True

if adjust_sc_enabled:
merged_sc_file_path = os.path.join(result_dir, "merged_sc_data.json")
if os.path.exists(merged_sc_file_path):
return
Expand Down Expand Up @@ -130,7 +156,11 @@ async def merge_parallel_dmg_data(
async with aiofiles.open(parallel_config_path, "r", encoding="utf-8") as f:
parallel_config: dict = json.loads(await f.read())

if parallel_config.get("adjust_sc", {}).get("enabled", False):
adjust_sc_enabled = _get_enabled_flag(parallel_config.get("adjust_sc"))
if not adjust_sc_enabled and parallel_config.get("func") == "attr_curve":
adjust_sc_enabled = True

if adjust_sc_enabled:
# 属性收益曲线功能
func = "attr_curve"
merged_sc_file_path = os.path.join(result_dir, "merged_sc_data.json")
Expand All @@ -154,7 +184,11 @@ async def merge_parallel_dmg_data(
except Exception as e:
st.error(f"合并属性收益曲线数据时出错: {e}")
return func, sc_merged_data
elif parallel_config.get("adjust_weapon", {}).get("enabled", False):
adjust_weapon_enabled = _get_enabled_flag(parallel_config.get("adjust_weapon"))
if not adjust_weapon_enabled and parallel_config.get("func") == "weapon":
adjust_weapon_enabled = True

if adjust_weapon_enabled:
# 武器切换功能
func = "weapon"
merged_weapon_file_path = os.path.join(result_dir, "merged_weapon_data.json")
Expand All @@ -178,8 +212,7 @@ async def merge_parallel_dmg_data(
st.error(f"合并武器切换数据时出错: {e}")
return func, weapon_merged_data

else:
return None
return None


def __draw_attr_curve(
Expand Down
Loading
Loading