From ff54e91c44871db468a956d3add0548178859fd9 Mon Sep 17 00:00:00 2001 From: hzy14610046011 <1295752786@qq.com> Date: Sun, 21 Jan 2024 19:01:56 +0800 Subject: [PATCH] added shortcut to register dataframe in scenario --- Melodie/data_collector.py | 84 +++++++++++++++--- Melodie/data_loader.py | 89 +++++++++++++++++--- Melodie/scenario_manager.py | 14 +-- Melodie/simulator.py | 26 ++---- MelodieInfra/config/config.py | 3 +- MelodieInfra/db/db.py | 3 +- MelodieInfra/exceptions/exceptions.py | 2 +- MelodieInfra/parallel/parallel_manager.py | 6 +- MelodieInfra/table/table_general.py | 18 ++-- MelodieInfra/table/table_objects.py | 12 +-- tests/infra/resources/.gitignore | 1 + tests/infra/resources/output/.gitignore | 2 - tests/infra/resources/output/placeholder.txt | 1 - tests/infra/test_data_collector.py | 28 +++++- tests/infra/test_data_collector_csv.py | 16 +++- tests/infra/test_typing.py | 17 ++++ 16 files changed, 245 insertions(+), 77 deletions(-) create mode 100644 tests/infra/resources/.gitignore delete mode 100644 tests/infra/resources/output/.gitignore delete mode 100644 tests/infra/resources/output/placeholder.txt create mode 100644 tests/infra/test_typing.py diff --git a/Melodie/data_collector.py b/Melodie/data_collector.py index 2162393c..583e758a 100644 --- a/Melodie/data_collector.py +++ b/Melodie/data_collector.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import Callable, List, TYPE_CHECKING, Dict, Tuple, Any, Optional, Type +from typing import Callable, Generic, List, TYPE_CHECKING, Dict, NewType, Tuple, Any, Optional, Type, TypeVar, Union, cast import pandas as pd import collections import sqlalchemy @@ -11,6 +11,7 @@ MelodieExceptions, is_pypy, Table, + GeneralTable, TableRow, objs_to_table_row_vectorizer, ) @@ -20,6 +21,7 @@ if TYPE_CHECKING: from Melodie import Model, Scenario, BaseAgentContainer, AgentList + M = TypeVar('M', bound=Model) class PropertyToCollect: @@ -34,6 +36,10 @@ def vectorize_template(obj): """ +def underline_to_camel(s: str): + return ''.join(word.capitalize() for word in s.split('_')) + + def vectorizer(attrs): code = VEC_TEMPLATE.format(exprs=",".join( [f'obj["{attr}"]' for attr in attrs])) @@ -45,7 +51,7 @@ def vectorizer(attrs): vectorizers = {} -class DataCollector: +class DataCollector(): """ Data Collector collects data in the model. @@ -56,6 +62,7 @@ class DataCollector: Before the model running finished, the DataCollector dumps data to dataframe, and save to database. """ + _CORE_PROPERTIES_ = ['id_scenario', 'id_run', 'period'] def __init__(self, target="sqlite"): """ @@ -75,6 +82,9 @@ def __init__(self, target="sqlite"): self.agent_properties_dict: Dict[str, Table] = {} self.environment_properties_list: Dict[str, Table] = None + self._custom_collectors: Dict[str, + Tuple[Callable[[Model], Dict[str, Any]], List[str]]] = {} + self._custom_collected_data: Dict[str, GeneralTable] = {} self._time_elapsed = 0 @@ -132,6 +142,18 @@ def add_environment_property(self, property_name: str, as_type: Type = None): PropertyToCollect(property_name, as_type) ) + def add_custom_collector(self, table_name: str, row_collector: + "Callable[[M], Union[Dict[str, Any], List[Dict[str, Any]]]]", columns: List[str]): + """ + Add a custom data collector to generate a standalone table. + + :param table_name: The name of table storing the data collected. + :param row_collector: A callable function, returning a `dict` computed from `Model` forming one + row in the table. + """ + self._custom_collectors[table_name] = ( + cast(Any, row_collector), columns) + def env_property_names(self) -> List[str]: """ Get the environment property names to collect @@ -177,6 +199,33 @@ def collect_agent_properties(self, period: int): container_name, agent_property_names[container_name], container, period ) + def collect_custom_properties(self, period: int): + """ + Collect custom properties by calling custom callbacks. + + :param period: Current simulation step + :return: None + """ + for collector_name in self._custom_collectors.keys(): + self.collect_single_custom_property(collector_name, period) + + def collect_single_custom_property(self, collector_name: str, period: int): + collector_func, column_names = self._custom_collectors[collector_name] + if collector_name not in self._custom_collected_data: + self._custom_collected_data[collector_name] = GeneralTable( + {k: None for k in column_names}, column_names) + assert self.model is not None + data = collector_func(self.model) + if isinstance(data, list): + for item in data: + assert isinstance(item, dict) + self._custom_collected_data[collector_name].data.append(item) + elif isinstance(data, dict): + self._custom_collected_data[collector_name].data.append(data) + else: + raise NotImplementedError( + "Data collector function should return a list or dict") + def append_agent_properties_by_records( self, container_name: str, @@ -190,6 +239,7 @@ def append_agent_properties_by_records( :return: None """ + assert self.model is not None id_run, id_scenario = self.model.run_id_in_scenario, self.model.scenario.id if container_name not in self.agent_properties_dict: if len(container) == 0: @@ -206,7 +256,7 @@ def append_agent_properties_by_records( row_cls = TableRow.subcls_from_dict(props) self.agent_properties_dict[container_name] = Table( - row_cls, ['id_scenario', 'id_run', 'period', 'id'] + prop_names) + row_cls, self._CORE_PROPERTIES_+['id'] + prop_names) self._agent_properties_collectors[ container_name ] = objs_to_table_row_vectorizer(row_cls, prop_names) @@ -223,6 +273,9 @@ def append_agent_properties_by_records( props_list.append(row) def append_environment_properties(self, period: int): + assert self.model is not None + assert self.model.environment is not None + assert self.model.scenario is not None env_dic = { "id_scenario": self.model.scenario.id, "id_run": self.model.run_id_in_scenario, @@ -233,7 +286,7 @@ def append_environment_properties(self, period: int): if self.environment_properties_list is None: row_cls = TableRow.subcls_from_dict(env_dic) self.environment_properties_list = Table( - row_cls, ['id_scenario', 'id_run', 'period']+self.env_property_names()) + row_cls, self._CORE_PROPERTIES_+self.env_property_names()) self.environment_properties_list.append_from_dicts([env_dic]) @property @@ -263,6 +316,7 @@ def collect(self, period: int) -> None: t0 = time.time() self.append_environment_properties(period) self.collect_agent_properties(period) + self.collect_custom_properties(period) t1 = time.time() self._time_elapsed += t1 - t0 @@ -304,7 +358,7 @@ def get_single_agent_data(self, agent_container_name: str, agent_id: int): container_data = self.agent_properties_dict[agent_container_name] return list(filter(lambda item: item["id"] == agent_id, container_data)) - def _write_list_to_table(self, engine, table_name: str, data: Table): + def _write_list_to_table(self, engine, table_name: str, data: Union[Table, GeneralTable]): """ Write a list of dict into database. @@ -338,11 +392,6 @@ def save(self): DBConn.ENVIRONMENT_RESULT_TABLE, self.environment_properties_list, ) - # self._write_list_to_table( - # connection.get_engine(), - # DBConn.ENVIRONMENT_RESULT_TABLE, - # self.environment_properties_list, - # ) self.environment_properties_list = None write_db_time += time.time() - _t @@ -350,10 +399,21 @@ def save(self): _t = time.time() self._write_list_to_table( connection.get_engine(), - container_name + "_result", + # "agent_list" -> "AgentList" + # "agent_list" -> "Agent_list" + # "Agent" + "Result_"+underline_to_camel(container_name), self.agent_properties_dict[container_name], ) - # print("wrote agent properties!", container_name+"_result") + write_db_time += time.time() - _t + + for custom_table_name in self._custom_collected_data.keys(): + _t = time.time() + self._write_list_to_table( + connection.get_engine(), + custom_table_name, + self._custom_collected_data[custom_table_name] + ) write_db_time += time.time() - _t self.agent_properties_dict = {} diff --git a/Melodie/data_loader.py b/Melodie/data_loader.py index dc997df5..f8fd9b8e 100644 --- a/Melodie/data_loader.py +++ b/Melodie/data_loader.py @@ -3,6 +3,7 @@ import os import shutil from typing import Optional, Dict, List, Union, Callable, Type, TYPE_CHECKING +import numpy as np import sqlalchemy import cloudpickle @@ -23,6 +24,19 @@ logger = logging.getLogger(__name__) +def first_char_upper(s: str) -> str: + if len(s) >= 1: + return s[0].upper()+s[1:] + else: + return s + +# TODO: move this function to utils + + +def underline_to_camel(s): + return ''.join(first_char_upper(word) for word in s.split('_')) + + class DataFrameInfo: """ DataFrameInfo provides standard format for input tables as parameters. @@ -134,8 +148,51 @@ def __init__( self.registered_matrices: Dict[str, "np.ndarray"] = {} self.manager = manager self.manager.data_loader = self + self.load_scenarios() self.setup() + def load_scenarios(self): + for file_name in os.listdir(self.config.input_folder): + camel_case = underline_to_camel(os.path.splitext(file_name)[0]) + + if camel_case in ("SimulatorScenarios", "TrainerScenarios", "CalibratorScenarios"): + self.load_dataframe(file_name, camel_case) + + def load_dataframe(self, df_info: Union[str, "DataFrameInfo"], df_name=""): + """ + Load a data frame from table file. + + :df_info: The file name of that containing the data frame, or pass a `DataFrameInfo` + """ + from .data_loader import DataFrameInfo + + assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC + assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC + if isinstance(df_info, str): + df_name = df_name if df_name != "" else os.path.basename(df_info) + info = DataFrameInfo(df_name, {}, df_info) + return self._load_dataframe(info) + else: + return self._load_dataframe(df_info) + + def load_matrix(self, mat_info: Union[str, "MatrixInfo"], mat_name="") -> np.ndarray: + """ + Load a matrix from table file. + + :mat_info: The file name of that containing the matrix, or pass a `DataFrameInfo` + """ + from .data_loader import MatrixInfo + + assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC + assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC + if isinstance(mat_info, str): + mat_name = mat_name if mat_name != "" else os.path.basename( + mat_info) + info = MatrixInfo(mat_name, None, mat_info) + return self.manager.data_loader._load_matrix(info) + else: + return self.manager.data_loader._load_matrix(mat_info) + def setup(self): pass @@ -209,7 +266,7 @@ def _load_dataframe_cached(self, file_path_abs: str) -> "pd.DataFrame": cloudpickle.dump(table, f) return table - def load_dataframe(self, df_info: "DataFrameInfo") -> "pd.DataFrame": + def _load_dataframe(self, df_info: "DataFrameInfo") -> "pd.DataFrame": """ Register static table. The static table will be copied into database. @@ -219,24 +276,27 @@ def load_dataframe(self, df_info: "DataFrameInfo") -> "pd.DataFrame": """ table: Optional["pd.DataFrame"] - - MelodieExceptions.Data.TableNameInvalid(df_info.df_name) + if df_info.df_name in self.registered_dataframes: + return self.registered_dataframes[df_info.df_name] assert df_info.file_name is not None file_path_abs = os.path.join( self.config.input_folder, df_info.file_name) - # if df_info.engine == "pandas": + table = self._load_dataframe_cached(file_path_abs) self.registered_dataframes[df_info.df_name] = table return table - def load_matrix(self, matrix_info: "MatrixInfo") -> "np.ndarray": + def _load_matrix(self, matrix_info: "MatrixInfo") -> "np.ndarray": """ Register static matrix. :return: None """ + if matrix_info.mat_name in self.registered_matrices: + return self.registered_matrices[matrix_info.mat_name] + assert matrix_info.file_name is not None _, ext = os.path.splitext(matrix_info.file_name) file_path_abs = os.path.join( @@ -255,7 +315,7 @@ def load_matrix(self, matrix_info: "MatrixInfo") -> "np.ndarray": def dataframe_generator( self, - df_info: DataFrameInfo, + df_info: Union[str, DataFrameInfo], rows_in_scenario: Union[int, Callable[[Scenario], int]], ) -> DataFrameGenerator: """ @@ -304,9 +364,18 @@ def generate_scenarios(self, manager_type: str) -> List["Scenario"]: :param manager_type: The type of scenario manager, a ``str`` in "simulator", "trainer" or "calibrator". :return: A list of scenarios. """ - if manager_type not in {"simulator", "trainer", "calibrator"}: - MelodieExceptions.Program.Variable.VariableNotInSet( + if manager_type not in {"Simulator", "Trainer", "Calibrator"}: + raise MelodieExceptions.Program.Variable.VariableNotInSet( "manager_type", manager_type, { - "simulator", "trainer", "calibrator"} + "Simulator", "Trainer", "Calibrator"} ) - return self.generate_scenarios_from_dataframe(f"{manager_type}_scenarios") + + df_name = f"{manager_type}Scenarios" + + if df_name in self.registered_dataframes: + return self.generate_scenarios_from_dataframe(df_name) + elif underline_to_camel(df_name) in self.registered_dataframes: + return self.generate_scenarios_from_dataframe(underline_to_camel(df_name)) + else: + raise NotImplementedError( + f"{manager_type}/{underline_to_camel(df_name)} is not supported!") diff --git a/Melodie/scenario_manager.py b/Melodie/scenario_manager.py index e4c076ee..ffe0501f 100644 --- a/Melodie/scenario_manager.py +++ b/Melodie/scenario_manager.py @@ -72,15 +72,10 @@ def load_dataframe(self, df_info: Union[str, "DataFrameInfo"]): :df_info: The file name of that containing the data frame, or pass a `DataFrameInfo` """ - from .data_loader import DataFrameInfo assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC - if isinstance(df_info, str): - info = DataFrameInfo(os.path.basename(df_info), {}, df_info) - return self.manager.data_loader.load_dataframe(info) - else: - return self.manager.data_loader.load_dataframe(df_info) + return self.manager.data_loader.load_dataframe(df_info) def load_matrix(self, mat_info: Union[str, "MatrixInfo"]) -> np.ndarray: """ @@ -88,15 +83,10 @@ def load_matrix(self, mat_info: Union[str, "MatrixInfo"]) -> np.ndarray: :mat_info: The file name of that containing the matrix, or pass a `DataFrameInfo` """ - from .data_loader import MatrixInfo assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC - if isinstance(mat_info, str): - info = MatrixInfo(os.path.basename(mat_info), None, mat_info) - return self.manager.data_loader.load_matrix(info) - else: - return self.manager.data_loader.load_matrix(mat_info) + return self.manager.data_loader.load_matrix(mat_info) def to_dict(self): """ diff --git a/Melodie/simulator.py b/Melodie/simulator.py index 277911c0..6561f9a7 100644 --- a/Melodie/simulator.py +++ b/Melodie/simulator.py @@ -19,9 +19,10 @@ MelodieExceptions, show_prettified_warning, MelodieGlobalConfig, + ) -from .data_loader import DataLoader +from .data_loader import DataLoader, DataFrameInfo from .model import Model from .scenario_manager import Scenario from .visualizer import BaseVisualizer, MelodieModelReset @@ -100,7 +101,8 @@ def clear_output_tables(self): Clear all output tables """ output_path = self.config.output_tables_path() - shutil.rmtree(output_path) + if os.path.exists(output_path): + shutil.rmtree(output_path) os.mkdir(output_path) def pre_run(self, clear_output_data=True): @@ -209,7 +211,8 @@ def generate_scenarios(self) -> List["Scenario"]: """ if self.data_loader is None: raise MelodieExceptions.Data.NoDataframeLoaderDefined() - return self.data_loader.generate_scenarios("simulator") + + return self.data_loader.generate_scenarios("Simulator") def run_model( self, config, scenario, id_run, model_class: Type["Model"], visualizer=None @@ -222,6 +225,8 @@ def run_model( ) t0 = time.time() + scenario.id_run = id_run + model: Model = model_class( config, scenario, run_id_in_scenario=id_run, visualizer=visualizer ) @@ -234,20 +239,6 @@ def run_model( self.visualizer.params_manager._initialized = True else: self.visualizer.params_manager.modify_scenario(scenario) - # with open('test.json', 'w') as f: - # json.dump({ - # 'model': self.visualizer.params_manager.to_json(), - # 'params-values': self.visualizer.params_manager.to_value_json() - # }, f, indent=4) - # with open('123.json') as f: - # self.visualizer.params_manager.from_json(json.load(f)) - # with open('out.json', 'w') as f: - # json.dump({ - # 'model': self.visualizer.params_manager.to_json(), - # 'params-values': self.visualizer.params_manager.to_value_json() - # }, f, indent=4) - # self.visualizer.params_manager.modify_scenario(scenario) - # print('aaa', self.visualizer.params_manager.to_value_json()[0]) visualizer.start() else: model._setup() @@ -302,7 +293,6 @@ def run(self): assert self.scenarios is not None, MelodieExceptions.MLD_INTL_EXC for scenario_index, scenario in enumerate(self.scenarios): for id_run in range(scenario.run_num): - # TODO: scenario.id_run = id_run might cause wrong result when parallel executing. self.run_model( self.config, scenario, id_run, self.model_cls, visualizer=None ) diff --git a/MelodieInfra/config/config.py b/MelodieInfra/config/config.py index ea2b4dea..3b34bb66 100644 --- a/MelodieInfra/config/config.py +++ b/MelodieInfra/config/config.py @@ -98,5 +98,4 @@ def output_tables_path(self): If output to database and using sqlite, the output directory will be `data/output` """ - return os.path.join(self.output_folder, - self.project_name) + return os.path.join(self.output_folder) diff --git a/MelodieInfra/db/db.py b/MelodieInfra/db/db.py index 0c910993..89c8a9b9 100644 --- a/MelodieInfra/db/db.py +++ b/MelodieInfra/db/db.py @@ -25,8 +25,7 @@ class DBConn: table_dtypes: Dict[str, TABLE_DTYPES] = {} existing_connections: Dict[str, "DBConn"] = {} - SCENARIO_TABLE = "simulator_scenarios" - ENVIRONMENT_RESULT_TABLE = "environment_result" + ENVIRONMENT_RESULT_TABLE = "Result_Environment" def __init__( self, diff --git a/MelodieInfra/exceptions/exceptions.py b/MelodieInfra/exceptions/exceptions.py index 0bd76b9f..5144b27a 100644 --- a/MelodieInfra/exceptions/exceptions.py +++ b/MelodieInfra/exceptions/exceptions.py @@ -121,7 +121,7 @@ def VariableNotInSet(var_desc: str, var_value: Any, allowed_set: Set[Any]): """ return MelodieException( 1012, - f"Variable {var_desc} is {var_value}, not in allowed set {allowed_set} ", + f"Variable {var_desc} is {repr(var_value)}, not in allowed set {allowed_set} ", ) class Function: diff --git a/MelodieInfra/parallel/parallel_manager.py b/MelodieInfra/parallel/parallel_manager.py index 242f1e54..f282f1a2 100644 --- a/MelodieInfra/parallel/parallel_manager.py +++ b/MelodieInfra/parallel/parallel_manager.py @@ -96,10 +96,14 @@ def run(self, role: str): self.th_server.setDaemon(True) self.th_server.start() for core_id in range(self.cores): + python_path = os.environ.get('PYTHONPATH', "") + # paths = paths + print("python_path", python_path, ":".join(sys.path)) p = subprocess.Popen( [ sys.executable, - os.path.join(os.path.dirname(__file__), "parallel_worker.py"), + os.path.join(os.path.dirname(__file__), + "parallel_worker.py"), "--core_id", str(core_id), "--workdirs", diff --git a/MelodieInfra/table/table_general.py b/MelodieInfra/table/table_general.py index b81d9b14..2c661688 100644 --- a/MelodieInfra/table/table_general.py +++ b/MelodieInfra/table/table_general.py @@ -13,20 +13,27 @@ def vectorize_template(obj): """ -RowType = Dict[str, TypeEngine] +RowType = Dict[str, Optional[TypeEngine]] class GeneralTable(TableBase): data: List[dict] - def __init__(self, row_type: RowType) -> None: + def __init__(self, row_type: RowType, columns: Optional[List[str]] = None) -> None: super().__init__() self._row_type = row_type self._db_model_cls: Type = None self.row_types: Dict[str, Column] = {} + if row_type is not None: - for prop_name, prop_value in row_type.items(): - self.row_types[prop_name] = Column(prop_value) + for prop_name, prop_type in row_type.items(): + self.row_types[prop_name] = Column(prop_type) + + if columns is None: + self._columns = list(row_type.keys()) + else: + assert set(columns) == set(row_type.keys()) + self._columns = columns def create_empty(self): return GeneralTable(self._row_type) @@ -84,7 +91,8 @@ def from_dicts(row_type: RowType, dicts: List[dict], copy=True): assert ( len(dicts) > 0 ), "Initial data must have at least one row for Melodie to detect data type." - row_type = {k: py_types_to_sa_types[type(v)] for k, v in dicts[0].items()} + row_type = {k: py_types_to_sa_types[type( + v)] for k, v in dicts[0].items()} table = GeneralTable(row_type) if not copy: for dic in dicts: diff --git a/MelodieInfra/table/table_objects.py b/MelodieInfra/table/table_objects.py index 4167c20f..20dd1e2c 100644 --- a/MelodieInfra/table/table_objects.py +++ b/MelodieInfra/table/table_objects.py @@ -137,7 +137,7 @@ class TR(TableRow): class Table(TableBase, Generic[TableRowGeneric]): data: List[TableRowGeneric] - def __init__(self, row_type: Type[TableRowGeneric], columns_order: Optional[List[str]] = None) -> None: + def __init__(self, row_type: Type[TableRowGeneric], columns: Optional[List[str]] = None) -> None: super().__init__() self.row_cls: Type[TableRow] = row_type self._db_model_cls: Type = None @@ -153,16 +153,16 @@ def __init__(self, row_type: Type[TableRowGeneric], columns_order: Optional[List raise NotImplementedError( f"Cannot recognize table row type {type(row_type)}" ) - if columns_order is not None: - self.columns_order = columns_order + if columns is not None: + self._columns = columns assert set([row for row in self.row_types.keys()]) == set( - self.columns_order), f"columns_order {self.columns_order} should contain the same names as row_types: {self.row_types.keys()}" + self._columns), f"columns_order {self._columns} should contain the same names as row_types: {self.row_types.keys()}" else: - self.columns_order = [row for row in self.row_types.keys()] + self._columns = [row for row in self.row_types.keys()] @property def columns(self): - return self.columns_order + return self._columns def clear(self): self.data = [] diff --git a/tests/infra/resources/.gitignore b/tests/infra/resources/.gitignore new file mode 100644 index 00000000..4029495d --- /dev/null +++ b/tests/infra/resources/.gitignore @@ -0,0 +1 @@ +/output \ No newline at end of file diff --git a/tests/infra/resources/output/.gitignore b/tests/infra/resources/output/.gitignore deleted file mode 100644 index 3bb50f71..00000000 --- a/tests/infra/resources/output/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.csv -/temp_* \ No newline at end of file diff --git a/tests/infra/resources/output/placeholder.txt b/tests/infra/resources/output/placeholder.txt deleted file mode 100644 index 106a1928..00000000 --- a/tests/infra/resources/output/placeholder.txt +++ /dev/null @@ -1 +0,0 @@ -this file acts as a place holder for git to contain the parent folder \ No newline at end of file diff --git a/tests/infra/test_data_collector.py b/tests/infra/test_data_collector.py index 0bb3c6e2..df73f0db 100644 --- a/tests/infra/test_data_collector.py +++ b/tests/infra/test_data_collector.py @@ -1,6 +1,7 @@ # -*- coding:utf-8 -*- +import os import random from typing import List @@ -14,9 +15,26 @@ Simulator, Model, DataLoader, + Config +) +cfg_for_temp = Config( + "temp_db_created", + os.path.dirname(__file__), + input_folder=os.path.join(os.path.dirname( + __file__), "resources", "excels"), + output_folder=os.path.join(os.path.dirname( + __file__), "resources", "output"), + data_output_type="sqlite" +) +cfg_for_calibrator = Config( + "temp_db_calibrator", + os.path.dirname(__file__), + input_folder=os.path.join(os.path.dirname( + __file__), "resources", "excels"), + output_folder=os.path.join(os.path.dirname( + __file__), "resources", "output"), + data_output_type="sqlite" ) -from tests.infra.config import cfg_for_temp, cfg_for_calibrator - AGENT_NUM_1 = 10 AGENT_NUM_2 = 20 @@ -54,9 +72,11 @@ def setup(self): # params_df_3 = pd.DataFrame( # [{"a": 1.0, "b": 1, "productivity": 0} for i in range(20)] # ) - self.agent_list1 = self.create_agent_container(TestAgent, 10, params_df_1) + self.agent_list1 = self.create_agent_container( + TestAgent, 10, params_df_1) self.agent_list1.setup_agents(10, params_df_1) - self.agent_list2 = self.create_agent_container(TestAgent, 20, params_df_2) + self.agent_list2 = self.create_agent_container( + TestAgent, 20, params_df_2) self.agent_list2.setup_agents(20, params_df_2) self.environment = self.create_environment(TestEnv) self.data_collector = self.create_data_collector(DataCollector1) diff --git a/tests/infra/test_data_collector_csv.py b/tests/infra/test_data_collector_csv.py index 2e3dd43c..33004b90 100644 --- a/tests/infra/test_data_collector_csv.py +++ b/tests/infra/test_data_collector_csv.py @@ -70,7 +70,8 @@ def setup(self): TestAgent, 20, params_df_2) self.agent_list2.setup_agents(20, params_df_2) self.environment = self.create_environment(TestEnv) - self.data_collector = self.create_data_collector(DataCollector1) + self.data_collector: DataCollector1[DCTestModel] = self.create_data_collector( + DataCollector1) class Simulator4Test(Simulator): @@ -100,6 +101,16 @@ def setup(self): self.add_agent_property("agent_list1", "a") self.add_agent_property("agent_list2", "b") + def get_data(model: DCTestModel): + return {"agent1_num": len(model.agent_list1), "agent2_num": len(model.agent_list2)} + + def get_data2(model: DCTestModel): + return [{"agent1_num": len(model.agent_list1), "agent2_num": len(model.agent_list2)} for i in range(2)] + self.add_custom_collector("my_collector", get_data, [ + 'agent1_num', "agent2_num"]) + self.add_custom_collector("my_collector2", get_data2, [ + 'agent1_num', "agent2_num"]) + def test_model_run(): global data_collector @@ -118,4 +129,7 @@ def test_model_run(): dc.collect(1) assert len(dc.agent_properties_dict["agent_list1"]) == AGENT_NUM_1 * 2 assert len(dc.agent_properties_dict["agent_list2"]) == AGENT_NUM_2 * 2 + + assert len(dc._custom_collected_data['my_collector'].data) == 2 + assert len(dc._custom_collected_data['my_collector2'].data) == 4 dc.save() diff --git a/tests/infra/test_typing.py b/tests/infra/test_typing.py new file mode 100644 index 00000000..d376a1b8 --- /dev/null +++ b/tests/infra/test_typing.py @@ -0,0 +1,17 @@ +from typing import Callable, Dict, Any + +class Model: + pass + +class DCTestModel(Model): + pass + +def add_custom_collector(row_collector: Callable[[Model], Dict[str, Any]]) -> None: + # 函数实现 + pass + +def my_collector(model: DCTestModel) -> Dict[str, int]: + # 自定义收集器实现 + pass + +add_custom_collector(my_collector)