Skip to content

Commit

Permalink
changed the dataloader parameter of trainer/calibrator to optional
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mole committed Mar 2, 2024
1 parent bfe4828 commit b7bcc38
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 39 deletions.
9 changes: 3 additions & 6 deletions Melodie/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __init__(
config: "Config",
scenario_cls: "Optional[Type[Scenario]]",
model_cls: "Optional[Type[Model]]",
data_loader_cls: Type["DataLoader"],
data_loader_cls: Type["DataLoader"] = None,
processors=1,
):
"""
Expand Down Expand Up @@ -433,7 +433,6 @@ def __init__(
self.model: Optional[Model] = None

self.current_algorithm_meta = GACalibratorAlgorithmMeta()
self.df_loader_cls = data_loader_cls

def setup(self):
"""
Expand All @@ -454,9 +453,7 @@ def generate_scenarios(self) -> List["Scenario"]:
:return: A list of generated scenarios.
"""
return self.data_loader.generate_scenarios_from_dataframe(
"calibrator_scenarios"
)
return self.data_loader.generate_scenarios("Calibrator")

def get_params_scenarios(self) -> List:
"""
Expand All @@ -465,7 +462,7 @@ def get_params_scenarios(self) -> List:
:return: A list of dict, and each dict contains parameters.
"""

calibrator_scenarios_table = self.get_dataframe("calibrator_params_scenarios")
calibrator_scenarios_table = self.get_dataframe("CalibratorParamsScenarios")
assert isinstance(
calibrator_scenarios_table, pd.DataFrame
), "No learning scenarios table specified!"
Expand Down
6 changes: 5 additions & 1 deletion Melodie/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ def load_dataframe(self, df_info: Union[str, "DataFrameInfo"], df_name=""):
assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC
assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC
if isinstance(df_info, str):
df_name = df_name if df_name != "" else os.path.basename(df_info)
df_name = (
df_name
if df_name != ""
else os.path.splitext(os.path.basename(df_info))[0]
)
info = DataFrameInfo(df_name, {}, df_info)
return self._load_dataframe(info)
else:
Expand Down
4 changes: 2 additions & 2 deletions Melodie/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def set_params(self, params: Dict[str, Any], asserts_key_exist=True):
if paramName not in self.__dict__.keys():
if asserts_key_exist:
raise ValueError(msg)
else:
warnings.warn(msg)
# else:
# warnings.warn(msg)

setattr(self, paramName, paramValue)
24 changes: 18 additions & 6 deletions Melodie/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Union, Type, List
from typing import Optional, Union, Type, List, TypeVar

from MelodieInfra import (
create_db_conn,
Expand Down Expand Up @@ -31,6 +31,12 @@

logger = logging.getLogger(__name__)

EnvironmentType = TypeVar("EnvironmentType", bound=Environment)
AgentType = TypeVar("AgentType", bound=Agent)
GridType = TypeVar("GridType", bound=Grid)
SpotType = TypeVar("SpotType", bound=Spot)
NetworkType = TypeVar("NetworkType", bound=Network)


class ModelRunRoutine:
"""
Expand Down Expand Up @@ -139,8 +145,8 @@ def create_db_conn(self) -> "DBConn":

def create_agent_list(
self,
agent_class: Type["Agent"],
):
agent_class: Type[AgentType],
) -> AgentList[AgentType]:
"""
Create an agent list object. A model could contain multiple ``AgentList``s.
Expand All @@ -149,7 +155,7 @@ def create_agent_list(
"""
return AgentList(agent_class, model=self)

def create_environment(self, env_class: Type["Environment"]):
def create_environment(self, env_class: Type[EnvironmentType]) -> EnvironmentType:
"""
Create the environment of model. Notice that a model has only one environment.
Expand All @@ -162,7 +168,11 @@ def create_environment(self, env_class: Type["Environment"]):
self.initialization_queue.append(env)
return env

def create_grid(self, grid_cls: Type["Grid"] = None, spot_cls: Type["Spot"] = None):
def create_grid(
self,
grid_cls: Optional[Type[GridType]] = None,
spot_cls: Optional[Type[SpotType]] = None,
) -> GridType:
"""
Create a grid.
Expand All @@ -177,7 +187,9 @@ def create_grid(self, grid_cls: Type["Grid"] = None, spot_cls: Type["Spot"] = No
return grid

def create_network(
self, network_cls: Type["Network"] = None, edge_cls: Type["Edge"] = None
self,
network_cls: Optional[Type[NetworkType]] = None,
edge_cls: Type[Edge] = None,
):
"""
Create the network of model.
Expand Down
1 change: 0 additions & 1 deletion Melodie/scenario_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def _setup(self, data: dict = None):
if data is not None:
for col_name in data.keys():
setattr(self, col_name, data[col_name])
# self.load()
self.load_data()
self.setup_data()

Expand Down
4 changes: 2 additions & 2 deletions Melodie/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def __init__(
config: "Config",
scenario_cls: "Optional[Type[Scenario]]",
model_cls: "Optional[Type[Model]]",
data_loader_cls: "Optional[Type[DataLoader]]",
data_loader_cls: "Optional[Type[DataLoader]]" = None,
processors: int = 1,
):
"""
Expand Down Expand Up @@ -743,7 +743,7 @@ def generate_trainer_params_list(
:return: A list of trainer parameters.
"""

trainer_params_table = self.get_dataframe("trainer_params_scenarios")
trainer_params_table = self.get_dataframe("TrainerParamsScenarios")
assert isinstance(
trainer_params_table, pd.DataFrame
), "No learning scenarios table specified!"
Expand Down
3 changes: 1 addition & 2 deletions MelodieInfra/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List
from .types import Optional, TYPE_CHECKING, Dict, Any
from typing import List, Optional, TYPE_CHECKING, Dict, Any


class Element:
Expand Down
18 changes: 11 additions & 7 deletions MelodieInfra/core/agent_list.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# import random
import logging
import random
from typing import Any, Callable, TYPE_CHECKING

from .types import (
from typing import (
Any,
Callable,
TYPE_CHECKING,
Generic,
ClassVar,
List,
Dict,
Union,
Set,
TypeVar,
Type,
)

import pandas as pd
Expand Down Expand Up @@ -39,7 +42,7 @@ def __next__(self):
return next_item


class BaseAgentContainer:
class BaseAgentContainer(Generic[AgentGeneric]):
"""
The base class that contains agents
"""
Expand Down Expand Up @@ -68,16 +71,17 @@ def to_list(self, column_names: List[str] = None) -> List[Dict]:
:param column_names: property names
"""
raise NotImplementedError

def get_agent(self, agent_id: int) -> "AgentGeneric":
raise NotImplementedError


class AgentList(BaseAgentContainer):
def __init__(self, agent_class: "ClassVar[AgentGeneric]", model: "Model") -> None:
class AgentList(BaseAgentContainer, Generic[AgentGeneric]):
def __init__(self, agent_class: "Type[AgentGeneric]", model: "Model") -> None:
super().__init__()
self.scenario = model.scenario
self.agent_class: "ClassVar[AgentGeneric]" = agent_class
self.agent_class: "Type[AgentGeneric]" = agent_class
self.model = model
self.indices = {}
self.agents: List[AgentGeneric] = []
Expand Down
2 changes: 1 addition & 1 deletion MelodieInfra/core/environment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .types import List, Dict, Optional
from typing import List, Dict, Optional
from .agent import Element


Expand Down
2 changes: 1 addition & 1 deletion MelodieInfra/core/grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from MelodieInfra.core.agent_list import AgentList
from .api import floor, randint, _random as random, iterable, lru_cache
from .types import ClassVar, Dict, List, Tuple
from typing import ClassVar, Dict, List, Tuple

from .agent import Agent

Expand Down
1 change: 0 additions & 1 deletion MelodieInfra/core/types.py

This file was deleted.

18 changes: 10 additions & 8 deletions MelodieInfra/parallel/parallel_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ def sub_routine_trainer(
chrom, d, agent_params = worker.get_task()
logger.debug(f"processor {proc_id} got chrom {chrom}")
scenario = scenario_cls()
scenario._setup()
scenario.set_params(d)
model = model_cls(config, scenario)
scenario.manager = trainer
scenario._setup(d)
# scenario.set_params(d)
model = model_cls(config, scenario)

model.create()
model._setup()
# {category: [{id: 0, param1: 1, param2: 2, ...}]}
Expand Down Expand Up @@ -217,9 +218,10 @@ def sub_routine_calibrator(
logger.debug(f"processor {proc_id} got chrom {chrom}")
scenario: Scenario = scenario_cls()
scenario.manager = calibrator
scenario._setup()
scenario.set_params(d)
scenario.set_params(env_params)
# scenario.set_params(d, asserts_key_exist=False)
# scenario.set_params(env_params, asserts_key_exist=False)
scenario._setup(d)

model = model_cls(config, scenario)
model.create()
model._setup()
Expand Down Expand Up @@ -290,8 +292,8 @@ def sub_routine_simulator(
logger.debug(f"processor {proc_id} got id_run {id_run}")
scenario: Scenario = scenario_cls()
scenario.manager = simulator
scenario._setup()
scenario.set_params(d, asserts_key_exist=False)
# scenario.set_params(d, asserts_key_exist=False)
scenario._setup(d)
# scenario.set_params(env_params)
model = model_cls(config, scenario, run_id_in_scenario=id_run)

Expand Down
22 changes: 22 additions & 0 deletions temp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
1. simulator/calibrator/trainer的输出表格名称,风格与dating market相同

- Result_<驼峰型>
- Result_Simulator_Environment
- Result_Simulator_Men
- Result_Simulator_Women

- Result_Calibrator_EnvironmentCov
- Result_Calibrator_Environment

- Result_Trainer_Players.csv (players为agentlist的变量名)
- Result_Trainer_PlayersCov.csv (players为agentlist的变量名)
- Result_Trainer_Environment.csv
- Result_Trainer_EnvironmentCov.csv

2. input下面的SimulatorScenarios.xlsx,
CalibratorScenarios, TrainerScenarios

3. save_dataframe函数,放在Melodie里面。

4. load_dataframe函数要对calibrator和trainer也生效
- 修改例子的内容,去掉dataloader
2 changes: 1 addition & 1 deletion tests/procedures/test_calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ class DFLoader(DataLoader):
@pytest.mark.timeout(15)
def test_calibrator():
calibrator = CovidCalibrator(
cfg_for_calibrator, CovidScenario, CovidModel, DFLoader, processors=2
cfg_for_calibrator, CovidScenario, CovidModel, processors=2
)
calibrator.run()

0 comments on commit b7bcc38

Please sign in to comment.