Skip to content

Commit

Permalink
added shortcut to register dataframe in scenario
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mole committed Jan 21, 2024
1 parent 1f3b2e2 commit ff54e91
Show file tree
Hide file tree
Showing 16 changed files with 245 additions and 77 deletions.
84 changes: 72 additions & 12 deletions Melodie/data_collector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import time
from typing import Callable, List, TYPE_CHECKING, Dict, Tuple, Any, Optional, Type
from typing import Callable, Generic, List, TYPE_CHECKING, Dict, NewType, Tuple, Any, Optional, Type, TypeVar, Union, cast
import pandas as pd
import collections
import sqlalchemy
Expand All @@ -11,6 +11,7 @@
MelodieExceptions,
is_pypy,
Table,
GeneralTable,
TableRow,
objs_to_table_row_vectorizer,
)
Expand All @@ -20,6 +21,7 @@

if TYPE_CHECKING:
from Melodie import Model, Scenario, BaseAgentContainer, AgentList
M = TypeVar('M', bound=Model)


class PropertyToCollect:
Expand All @@ -34,6 +36,10 @@ def vectorize_template(obj):
"""


def underline_to_camel(s: str):
    """
    Convert an underscore-separated name into CamelCase.

    Only the first character of every word is upper-cased; the remaining characters are kept
    exactly as written. ``str.capitalize`` is deliberately not used because it lower-cases the
    rest of the word (``"agentList"`` would become ``"Agentlist"``).

    :param s: The underscore-separated name, e.g. ``"agent_list"``.
    :return: The CamelCase form, e.g. ``"AgentList"``. Empty words produced by leading, trailing
        or doubled underscores contribute nothing.
    """
    return ''.join(word[:1].upper() + word[1:] for word in s.split('_'))


def vectorizer(attrs):
code = VEC_TEMPLATE.format(exprs=",".join(
[f'obj["{attr}"]' for attr in attrs]))
Expand All @@ -45,7 +51,7 @@ def vectorizer(attrs):
vectorizers = {}


class DataCollector:
class DataCollector():
"""
Data Collector collects data in the model.
Expand All @@ -56,6 +62,7 @@ class DataCollector:
Before the model running finished, the DataCollector dumps data to dataframe, and save to
database.
"""
_CORE_PROPERTIES_ = ['id_scenario', 'id_run', 'period']

def __init__(self, target="sqlite"):
"""
Expand All @@ -75,6 +82,9 @@ def __init__(self, target="sqlite"):

self.agent_properties_dict: Dict[str, Table] = {}
self.environment_properties_list: Dict[str, Table] = None
self._custom_collectors: Dict[str,
Tuple[Callable[[Model], Dict[str, Any]], List[str]]] = {}
self._custom_collected_data: Dict[str, GeneralTable] = {}

self._time_elapsed = 0

Expand Down Expand Up @@ -132,6 +142,18 @@ def add_environment_property(self, property_name: str, as_type: Type = None):
PropertyToCollect(property_name, as_type)
)

def add_custom_collector(self, table_name: str, row_collector:
                         "Callable[[M], Union[Dict[str, Any], List[Dict[str, Any]]]]", columns: List[str]):
    """
    Add a custom data collector to generate a standalone table.

    Registering a second collector under the same `table_name` replaces the previous one.

    :param table_name: The name of table storing the data collected.
    :param row_collector: A callable taking the `Model` and returning either one `dict`
        (a single row of the table) or a `list` of such dicts (several rows).
    :param columns: The column names of the table. The list is stored as passed, not copied.
    :return: None
    """
    # The collector is stored type-erased (cast to Any) next to its column names.
    self._custom_collectors[table_name] = (
        cast(Any, row_collector), columns)

def env_property_names(self) -> List[str]:
"""
Get the environment property names to collect
Expand Down Expand Up @@ -177,6 +199,33 @@ def collect_agent_properties(self, period: int):
container_name, agent_property_names[container_name], container, period
)

def collect_custom_properties(self, period: int):
    """
    Collect custom properties by calling every registered custom collector once.

    Collectors run in the order they were registered.

    :param period: Current simulation step
    :return: None
    """
    for collector_name in self._custom_collectors:
        self.collect_single_custom_property(collector_name, period)

def collect_single_custom_property(self, collector_name: str, period: int):
    """
    Run one registered custom collector and append the rows it returns to its table.

    The rows are stored exactly as returned by the collector; nothing is added to them here.

    :param collector_name: The table name the collector was registered under.
    :param period: Current simulation step. Not used by this method itself.
    :return: None
    :raises KeyError: If no collector is registered under `collector_name`.
    :raises NotImplementedError: If the collector returns something other than a dict or a list.
    :raises TypeError: If the collector returns a list containing a non-dict item. Nothing is
        appended in that case.
    """
    collector_func, column_names = self._custom_collectors[collector_name]
    if collector_name not in self._custom_collected_data:
        # The table is created lazily, the first time its collector runs.
        self._custom_collected_data[collector_name] = GeneralTable(
            {k: None for k in column_names}, column_names)
    assert self.model is not None
    data = collector_func(self.model)
    if isinstance(data, dict):
        rows = [data]
    elif isinstance(data, list):
        rows = data
    else:
        raise NotImplementedError(
            "Data collector function should return a list or dict")

    # Validate every row before storing any of them, so a bad item cannot leave the
    # table half-written. An explicit raise is used because asserts vanish under -O.
    for item in rows:
        if not isinstance(item, dict):
            raise TypeError(
                "Data collector function should return dict rows, got "
                f"{type(item).__name__}")

    table_rows = self._custom_collected_data[collector_name].data
    for item in rows:
        table_rows.append(item)

def append_agent_properties_by_records(
self,
container_name: str,
Expand All @@ -190,6 +239,7 @@ def append_agent_properties_by_records(
:return: None
"""
assert self.model is not None
id_run, id_scenario = self.model.run_id_in_scenario, self.model.scenario.id
if container_name not in self.agent_properties_dict:
if len(container) == 0:
Expand All @@ -206,7 +256,7 @@ def append_agent_properties_by_records(

row_cls = TableRow.subcls_from_dict(props)
self.agent_properties_dict[container_name] = Table(
row_cls, ['id_scenario', 'id_run', 'period', 'id'] + prop_names)
row_cls, self._CORE_PROPERTIES_+['id'] + prop_names)
self._agent_properties_collectors[
container_name
] = objs_to_table_row_vectorizer(row_cls, prop_names)
Expand All @@ -223,6 +273,9 @@ def append_agent_properties_by_records(
props_list.append(row)

def append_environment_properties(self, period: int):
assert self.model is not None
assert self.model.environment is not None
assert self.model.scenario is not None
env_dic = {
"id_scenario": self.model.scenario.id,
"id_run": self.model.run_id_in_scenario,
Expand All @@ -233,7 +286,7 @@ def append_environment_properties(self, period: int):
if self.environment_properties_list is None:
row_cls = TableRow.subcls_from_dict(env_dic)
self.environment_properties_list = Table(
row_cls, ['id_scenario', 'id_run', 'period']+self.env_property_names())
row_cls, self._CORE_PROPERTIES_+self.env_property_names())
self.environment_properties_list.append_from_dicts([env_dic])

@property
Expand Down Expand Up @@ -263,6 +316,7 @@ def collect(self, period: int) -> None:
t0 = time.time()
self.append_environment_properties(period)
self.collect_agent_properties(period)
self.collect_custom_properties(period)
t1 = time.time()

self._time_elapsed += t1 - t0
Expand Down Expand Up @@ -304,7 +358,7 @@ def get_single_agent_data(self, agent_container_name: str, agent_id: int):
container_data = self.agent_properties_dict[agent_container_name]
return list(filter(lambda item: item["id"] == agent_id, container_data))

def _write_list_to_table(self, engine, table_name: str, data: Table):
def _write_list_to_table(self, engine, table_name: str, data: Union[Table, GeneralTable]):
"""
Write a list of dict into database.
Expand Down Expand Up @@ -338,22 +392,28 @@ def save(self):
DBConn.ENVIRONMENT_RESULT_TABLE,
self.environment_properties_list,
)
# self._write_list_to_table(
# connection.get_engine(),
# DBConn.ENVIRONMENT_RESULT_TABLE,
# self.environment_properties_list,
# )
self.environment_properties_list = None
write_db_time += time.time() - _t

for container_name in self.agent_properties_dict.keys():
_t = time.time()
self._write_list_to_table(
connection.get_engine(),
container_name + "_result",
# "agent_list" -> "AgentList"
# "agent_list" -> "Agent_list"
# "Agent"
"Result_"+underline_to_camel(container_name),
self.agent_properties_dict[container_name],
)
# print("wrote agent properties!", container_name+"_result")
write_db_time += time.time() - _t

for custom_table_name in self._custom_collected_data.keys():
_t = time.time()
self._write_list_to_table(
connection.get_engine(),
custom_table_name,
self._custom_collected_data[custom_table_name]
)
write_db_time += time.time() - _t
self.agent_properties_dict = {}

Expand Down
89 changes: 79 additions & 10 deletions Melodie/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import shutil
from typing import Optional, Dict, List, Union, Callable, Type, TYPE_CHECKING
import numpy as np

import sqlalchemy
import cloudpickle
Expand All @@ -23,6 +24,19 @@
logger = logging.getLogger(__name__)


def first_char_upper(s: str) -> str:
    """
    Return `s` with its first character upper-cased and the rest left unchanged.

    Unlike ``str.capitalize``, the remaining characters are not lower-cased.

    :param s: Any string; an empty string is returned as is.
    :return: The string with an upper-case first character.
    """
    # Slicing instead of indexing makes the empty string fall out naturally.
    return s[:1].upper() + s[1:]

# TODO: move this function to utils


def underline_to_camel(s: str) -> str:
    """
    Convert an underscore-separated name into CamelCase.

    Each word gets its first character upper-cased by `first_char_upper`; the rest of the word
    is kept as written, so a name that is already CamelCase is returned unchanged.

    :param s: The underscore-separated name, e.g. ``"simulator_scenarios"``.
    :return: The CamelCase form, e.g. ``"SimulatorScenarios"``.
    """
    return ''.join(first_char_upper(word) for word in s.split('_'))


class DataFrameInfo:
"""
DataFrameInfo provides standard format for input tables as parameters.
Expand Down Expand Up @@ -134,8 +148,51 @@ def __init__(
self.registered_matrices: Dict[str, "np.ndarray"] = {}
self.manager = manager
self.manager.data_loader = self
self.load_scenarios()
self.setup()

def load_scenarios(self):
    """
    Load the scenario tables found in the input folder.

    Every entry of ``self.config.input_folder`` whose name without extension converts to
    ``SimulatorScenarios``, ``TrainerScenarios`` or ``CalibratorScenarios`` (so both
    ``simulator_scenarios.*`` and ``SimulatorScenarios.*`` match) is loaded with
    `load_dataframe` and registered under that CamelCase name.

    :return: None
    """
    scenario_table_names = (
        "SimulatorScenarios", "TrainerScenarios", "CalibratorScenarios")
    # os.listdir returns entries in arbitrary order. Sorting makes it deterministic which
    # file is loaded first when several files map to the same table name.
    for file_name in sorted(os.listdir(self.config.input_folder)):
        camel_case = underline_to_camel(os.path.splitext(file_name)[0])

        if camel_case in scenario_table_names:
            self.load_dataframe(file_name, camel_case)

def load_dataframe(self, df_info: Union[str, "DataFrameInfo"], df_name=""):
    """
    Load a data frame from a table file.

    :param df_info: The name of the file containing the data frame, or a `DataFrameInfo`.
    :param df_name: The name to register the data frame under when `df_info` is a file name.
        Defaults to the base name of that file. Ignored when a `DataFrameInfo` is passed.
    :return: The result of `_load_dataframe` for the resolved `DataFrameInfo`.
    """
    # DataFrameInfo is defined in this module, so no function-level import is needed.
    assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC
    assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC
    if isinstance(df_info, str):
        df_name = df_name if df_name != "" else os.path.basename(df_info)
        df_info = DataFrameInfo(df_name, {}, df_info)
    return self._load_dataframe(df_info)

def load_matrix(self, mat_info: Union[str, "MatrixInfo"], mat_name="") -> "np.ndarray":
    """
    Load a matrix from a table file.

    :param mat_info: The name of the file containing the matrix, or a `MatrixInfo`.
    :param mat_name: The name to register the matrix under when `mat_info` is a file name.
        Defaults to the base name of that file. Ignored when a `MatrixInfo` is passed.
    :return: The result of the manager's data loader `_load_matrix` for the resolved
        `MatrixInfo`.
    """
    # MatrixInfo is expected to live in this module, so no function-level import is needed.
    assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC
    assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC
    if isinstance(mat_info, str):
        mat_name = mat_name if mat_name != "" else os.path.basename(
            mat_info)
        mat_info = MatrixInfo(mat_name, None, mat_info)
    return self.manager.data_loader._load_matrix(mat_info)

def setup(self):
pass

Expand Down Expand Up @@ -209,7 +266,7 @@ def _load_dataframe_cached(self, file_path_abs: str) -> "pd.DataFrame":
cloudpickle.dump(table, f)
return table

def load_dataframe(self, df_info: "DataFrameInfo") -> "pd.DataFrame":
def _load_dataframe(self, df_info: "DataFrameInfo") -> "pd.DataFrame":
"""
Register static table. The static table will be copied into database.
Expand All @@ -219,24 +276,27 @@ def load_dataframe(self, df_info: "DataFrameInfo") -> "pd.DataFrame":
"""

table: Optional["pd.DataFrame"]

MelodieExceptions.Data.TableNameInvalid(df_info.df_name)
if df_info.df_name in self.registered_dataframes:
return self.registered_dataframes[df_info.df_name]
assert df_info.file_name is not None
file_path_abs = os.path.join(
self.config.input_folder, df_info.file_name)
# if df_info.engine == "pandas":

table = self._load_dataframe_cached(file_path_abs)

self.registered_dataframes[df_info.df_name] = table

return table

def load_matrix(self, matrix_info: "MatrixInfo") -> "np.ndarray":
def _load_matrix(self, matrix_info: "MatrixInfo") -> "np.ndarray":
"""
Register static matrix.
:return: None
"""
if matrix_info.mat_name in self.registered_matrices:
return self.registered_matrices[matrix_info.mat_name]

assert matrix_info.file_name is not None
_, ext = os.path.splitext(matrix_info.file_name)
file_path_abs = os.path.join(
Expand All @@ -255,7 +315,7 @@ def load_matrix(self, matrix_info: "MatrixInfo") -> "np.ndarray":

def dataframe_generator(
self,
df_info: DataFrameInfo,
df_info: Union[str, DataFrameInfo],
rows_in_scenario: Union[int, Callable[[Scenario], int]],
) -> DataFrameGenerator:
"""
Expand Down Expand Up @@ -304,9 +364,18 @@ def generate_scenarios(self, manager_type: str) -> List["Scenario"]:
:param manager_type: The type of scenario manager, a ``str`` in "simulator", "trainer" or "calibrator".
:return: A list of scenarios.
"""
if manager_type not in {"simulator", "trainer", "calibrator"}:
MelodieExceptions.Program.Variable.VariableNotInSet(
if manager_type not in {"Simulator", "Trainer", "Calibrator"}:
raise MelodieExceptions.Program.Variable.VariableNotInSet(
"manager_type", manager_type, {
"simulator", "trainer", "calibrator"}
"Simulator", "Trainer", "Calibrator"}
)
return self.generate_scenarios_from_dataframe(f"{manager_type}_scenarios")

df_name = f"{manager_type}Scenarios"

if df_name in self.registered_dataframes:
return self.generate_scenarios_from_dataframe(df_name)
elif underline_to_camel(df_name) in self.registered_dataframes:
return self.generate_scenarios_from_dataframe(underline_to_camel(df_name))
else:
raise NotImplementedError(
f"{manager_type}/{underline_to_camel(df_name)} is not supported!")
14 changes: 2 additions & 12 deletions Melodie/scenario_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,21 @@ def load_dataframe(self, df_info: Union[str, "DataFrameInfo"]):
:df_info: The file name of that containing the data frame, or pass a `DataFrameInfo`
"""
from .data_loader import DataFrameInfo

assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC
assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC
if isinstance(df_info, str):
info = DataFrameInfo(os.path.basename(df_info), {}, df_info)
return self.manager.data_loader.load_dataframe(info)
else:
return self.manager.data_loader.load_dataframe(df_info)
return self.manager.data_loader.load_dataframe(df_info)

def load_matrix(self, mat_info: Union[str, "MatrixInfo"]) -> np.ndarray:
"""
Load a matrix from table file.
:mat_info: The file name of that containing the matrix, or pass a `DataFrameInfo`
"""
from .data_loader import MatrixInfo

assert self.manager is not None, MelodieExceptions.MLD_INTL_EXC
assert self.manager.data_loader is not None, MelodieExceptions.MLD_INTL_EXC
if isinstance(mat_info, str):
info = MatrixInfo(os.path.basename(mat_info), None, mat_info)
return self.manager.data_loader.load_matrix(info)
else:
return self.manager.data_loader.load_matrix(mat_info)
return self.manager.data_loader.load_matrix(mat_info)

def to_dict(self):
"""
Expand Down
Loading

0 comments on commit ff54e91

Please sign in to comment.