Skip to content

Commit

Permalink
reformatted code
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mole committed Mar 2, 2024
1 parent 1d02fd2 commit b9013aa
Show file tree
Hide file tree
Showing 63 changed files with 5,715 additions and 6,773 deletions.
3 changes: 1 addition & 2 deletions Melodie/algorithms/ga.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import numpy as np

try:
from sko.GA import GA
Expand All @@ -16,8 +17,6 @@ def __init__(self, *args, **kwargs):

class MelodieGA(GA):
def run(self, max_iter=None):
import numpy as np

self.max_iter = max_iter or self.max_iter
best = []
for i in range(self.max_iter):
Expand Down
43 changes: 22 additions & 21 deletions Melodie/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Type,
Iterator,
cast,
Literal
Literal,
)
import pandas as pd

Expand Down Expand Up @@ -196,8 +196,7 @@ def get_params(self, id_chromosome: int) -> Dict[str, Any]:
:param id_chromosome:
:return:
"""
chromosome_value = self.algorithm.chrom2x(
self.algorithm.Chrom)[id_chromosome]
chromosome_value = self.algorithm.chrom2x(self.algorithm.Chrom)[id_chromosome]
env_parameters_dict = {}
for i, param_name in enumerate(self.env_param_names):
env_parameters_dict[param_name] = chromosome_value[i]
Expand All @@ -214,8 +213,7 @@ def target_function_to_cache(
:return:
"""
self.cache[(generation, id_chromosome)
] = env_data["target_function_value"]
self.cache[(generation, id_chromosome)] = env_data["target_function_value"]

def generate_target_function(self) -> Callable[[], float]:
"""
Expand All @@ -226,8 +224,7 @@ def generate_target_function(self) -> Callable[[], float]:

def f(*args):
self._chromosome_counter += 1
value = self.cache[(self._current_generation,
self._chromosome_counter)]
value = self.cache[(self._current_generation, self._chromosome_counter)]
return value

return f
Expand Down Expand Up @@ -260,14 +257,18 @@ def record_agent_properties(
d.update(agent_container_data)

agent_records[container_name].append(d)
self.manager._write_to_table("csv", f"{container_name}_calibrator_result", pd.DataFrame(
agent_records[container_name]))
self.manager._write_to_table(
"csv",
f"{container_name}_calibrator_result",
pd.DataFrame(agent_records[container_name]),
)
environment_record.update(meta_dict)
environment_record.update(env_data)
environment_record.pop("target_function_value")

self.manager._write_to_table(
"csv", "environment_calibrator_result", pd.DataFrame([environment_record]))
"csv", "environment_calibrator_result", pd.DataFrame([environment_record])
)
return agent_records, environment_record

def calc_cov_df(
Expand Down Expand Up @@ -309,20 +310,23 @@ def calc_cov_df(
)
container_agent_record_list.append(cov_records)

self.manager._write_to_table("csv", f"{container_name}_calibrator_result_cov",
pd.DataFrame(container_agent_record_list))
self.manager._write_to_table(
"csv",
f"{container_name}_calibrator_result_cov",
pd.DataFrame(container_agent_record_list),
)
env_record = {}
env_record.update(meta_dict)
for prop_name in (
self.env_param_names + self.recorded_env_properties + ["distance"]
):
mean = env_df[prop_name].mean()
cov = env_df[prop_name].std() / env_df[prop_name].mean()
env_record.update(
{prop_name + "_mean": mean, prop_name + "_cov": cov})
env_record.update({prop_name + "_mean": mean, prop_name + "_cov": cov})

self.manager._write_to_table("csv", "environment_calibrator_result_cov",
pd.DataFrame([env_record]))
self.manager._write_to_table(
"csv", "environment_calibrator_result_cov", pd.DataFrame([env_record])
)

def pre_check(self, meta):
"""
Expand All @@ -338,7 +342,6 @@ def pre_check(self, meta):
)

def run(self, scenario: Scenario, meta: Union[GACalibratorAlgorithmMeta]):

self.pre_check(meta)

for i in range(self.params.generation_num):
Expand Down Expand Up @@ -378,8 +381,7 @@ def run(self, scenario: Scenario, meta: Union[GACalibratorAlgorithmMeta]):
self.target_function_to_cache(env_data, i, chrom)

self.calc_cov_df(
{k: pd.DataFrame(v)
for k, v in agent_records_collector.items()},
{k: pd.DataFrame(v) for k, v in agent_records_collector.items()},
pd.DataFrame(env_records_list),
meta,
)
Expand Down Expand Up @@ -463,8 +465,7 @@ def get_params_scenarios(self) -> List:
:return: A list of dict, and each dict contains parameters.
"""

calibrator_scenarios_table = self.get_dataframe(
"calibrator_params_scenarios")
calibrator_scenarios_table = self.get_dataframe("calibrator_params_scenarios")
assert isinstance(
calibrator_scenarios_table, pd.DataFrame
), "No learning scenarios table specified!"
Expand Down
87 changes: 53 additions & 34 deletions Melodie/data_collector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
import logging
import os
import time
from typing import Callable, Generic, List, TYPE_CHECKING, Dict, NewType, Tuple, Any, Optional, Type, TypeVar, Union, cast
from typing import (
Callable,
Generic,
List,
TYPE_CHECKING,
Dict,
NewType,
Tuple,
Any,
Optional,
Type,
TypeVar,
Union,
cast,
)
import pandas as pd
import collections
import sqlalchemy
Expand All @@ -22,7 +36,8 @@

if TYPE_CHECKING:
from Melodie import Model, Scenario, BaseAgentContainer, AgentList
M = TypeVar('M', bound=Model)

M = TypeVar("M", bound=Model)


class PropertyToCollect:
Expand All @@ -38,12 +53,11 @@ def vectorize_template(obj):


def underline_to_camel(s: str):
return ''.join(word.capitalize() for word in s.split('_'))
return "".join(word.capitalize() for word in s.split("_"))


def vectorizer(attrs):
code = VEC_TEMPLATE.format(exprs=",".join(
[f'obj["{attr}"]' for attr in attrs]))
code = VEC_TEMPLATE.format(exprs=",".join([f'obj["{attr}"]' for attr in attrs]))
d = {}
exec(code, None, d)
return d["vectorize_template"]
Expand All @@ -52,7 +66,7 @@ def vectorizer(attrs):
vectorizers = {}


class DataCollector():
class DataCollector:
"""
Data Collector collects data in the model.
Expand All @@ -63,7 +77,8 @@ class DataCollector():
Before the model running finished, the DataCollector dumps data to dataframe, and save to
database.
"""
_CORE_PROPERTIES_ = ['id_scenario', 'id_run', 'period']

_CORE_PROPERTIES_ = ["id_scenario", "id_run", "period"]

def __init__(self, target="sqlite"):
"""
Expand All @@ -76,16 +91,15 @@ def __init__(self, target="sqlite"):
self.config: Optional[Config] = None
self.model: Optional[Model] = None
self.scenario: Optional["Scenario"] = None
self._agent_properties_to_collect: Dict[str,
List[PropertyToCollect]] = {}
self._agent_properties_collectors: Dict[str, Callable[[
object], object]] = {}
self._agent_properties_to_collect: Dict[str, List[PropertyToCollect]] = {}
self._agent_properties_collectors: Dict[str, Callable[[object], object]] = {}
self._environment_properties_to_collect: List[PropertyToCollect] = []

self.agent_properties_dict: Dict[str, Table] = {}
self.environment_properties_list: Dict[str, Table] = None
self._custom_collectors: Dict[str,
Tuple[Callable[[Model], Dict[str, Any]], List[str]]] = {}
self._custom_collectors: Dict[
str, Tuple[Callable[[Model], Dict[str, Any]], List[str]]
] = {}
self._custom_collected_data: Dict[str, GeneralTable] = {}

self._time_elapsed = 0
Expand Down Expand Up @@ -124,8 +138,7 @@ def add_agent_property(
:return:
"""
if not hasattr(self.model, container_name):
raise AttributeError(
f"Model has no agent container '{container_name}'")
raise AttributeError(f"Model has no agent container '{container_name}'")
if container_name not in self._agent_properties_to_collect.keys():
self._agent_properties_to_collect[container_name] = []
self._agent_properties_to_collect[container_name].append(
Expand All @@ -144,17 +157,20 @@ def add_environment_property(self, property_name: str, as_type: Type = None):
PropertyToCollect(property_name, as_type)
)

def add_custom_collector(self, table_name: str, row_collector:
"Callable[[M], Union[Dict[str, Any], List[Dict[str, Any]]]]", columns: List[str]):
def add_custom_collector(
self,
table_name: str,
row_collector: "Callable[[M], Union[Dict[str, Any], List[Dict[str, Any]]]]",
columns: List[str],
):
"""
Add a custom data collector to generate a standalone table.
:param table_name: The name of table storing the data collected.
:param row_collector: A callable function, returning a `dict` computed from `Model` forming one
:param row_collector: A callable function, returning a `dict` computed from `Model` forming one
row in the table.
"""
self._custom_collectors[table_name] = (
cast(Any, row_collector), columns)
self._custom_collectors[table_name] = (cast(Any, row_collector), columns)

def env_property_names(self) -> List[str]:
"""
Expand Down Expand Up @@ -183,8 +199,7 @@ def agent_containers(self) -> List[Tuple[str, "BaseAgentContainer"]]:
"""
containers = []
for container_name in self._agent_properties_to_collect.keys():
containers.append(
(container_name, getattr(self.model, container_name)))
containers.append((container_name, getattr(self.model, container_name)))
return containers

def collect_agent_properties(self, period: int):
Expand Down Expand Up @@ -215,7 +230,8 @@ def collect_single_custom_property(self, collector_name: str, period: int):
collector_func, column_names = self._custom_collectors[collector_name]
if collector_name not in self._custom_collected_data:
self._custom_collected_data[collector_name] = GeneralTable(
{k: None for k in column_names}, column_names)
{k: None for k in column_names}, column_names
)
assert self.model is not None
data = collector_func(self.model)
if isinstance(data, list):
Expand All @@ -226,7 +242,8 @@ def collect_single_custom_property(self, collector_name: str, period: int):
self._custom_collected_data[collector_name].data.append(data)
else:
raise NotImplementedError(
"Data collector function should return a list or dict")
"Data collector function should return a list or dict"
)

def append_agent_properties_by_records(
self,
Expand All @@ -245,8 +262,7 @@ def append_agent_properties_by_records(
id_run, id_scenario = self.model.run_id_in_scenario, self.model.scenario.id
if container_name not in self.agent_properties_dict:
if len(container) == 0:
raise ValueError(
f"No property collected for container {container}!")
raise ValueError(f"No property collected for container {container}!")
agent_attrs_dict = container.random_sample(1)[0].__dict__
props = {
"id_scenario": 0,
Expand All @@ -258,7 +274,8 @@ def append_agent_properties_by_records(

row_cls = TableRow.subcls_from_dict(props)
self.agent_properties_dict[container_name] = Table(
row_cls, self._CORE_PROPERTIES_+['id'] + prop_names)
row_cls, self._CORE_PROPERTIES_ + ["id"] + prop_names
)
self._agent_properties_collectors[
container_name
] = objs_to_table_row_vectorizer(row_cls, prop_names)
Expand All @@ -283,12 +300,12 @@ def append_environment_properties(self, period: int):
"id_run": self.model.run_id_in_scenario,
"period": period,
}
env_dic.update(self.model.environment.to_dict(
self.env_property_names()))
env_dic.update(self.model.environment.to_dict(self.env_property_names()))
if self.environment_properties_list is None:
row_cls = TableRow.subcls_from_dict(env_dic)
self.environment_properties_list = Table(
row_cls, self._CORE_PROPERTIES_+self.env_property_names())
row_cls, self._CORE_PROPERTIES_ + self.env_property_names()
)
self.environment_properties_list.append_from_dicts([env_dic])

@property
Expand Down Expand Up @@ -360,7 +377,9 @@ def get_single_agent_data(self, agent_container_name: str, agent_id: int):
container_data = self.agent_properties_dict[agent_container_name]
return list(filter(lambda item: item["id"] == agent_id, container_data))

def _write_list_to_table(self, engine, table_name: str, data: Union[Table, GeneralTable]):
def _write_list_to_table(
self, engine, table_name: str, data: Union[Table, GeneralTable]
):
"""
Write a list of dict into database.
Expand All @@ -370,7 +389,7 @@ def _write_list_to_table(self, engine, table_name: str, data: Union[Table, Gener
base_path = self.model.config.output_tables_path()
if not os.path.exists(base_path):
os.makedirs(base_path)
path = os.path.join(base_path, table_name+".csv")
path = os.path.join(base_path, table_name + ".csv")
data.to_file(path)
else:
data.to_database(engine, table_name)
Expand Down Expand Up @@ -404,7 +423,7 @@ def save(self):
# "agent_list" -> "AgentList"
# "agent_list" -> "Agent_list"
# "Agent"
"Result_"+underline_to_camel(container_name),
"Result_" + underline_to_camel(container_name),
self.agent_properties_dict[container_name],
)
write_db_time += time.time() - _t
Expand All @@ -414,7 +433,7 @@ def save(self):
self._write_list_to_table(
connection.get_engine(),
custom_table_name,
self._custom_collected_data[custom_table_name]
self._custom_collected_data[custom_table_name],
)
write_db_time += time.time() - _t
self.agent_properties_dict = {}
Expand Down
Loading

0 comments on commit b9013aa

Please sign in to comment.