Commit

added csv output mode
h-mole committed Jan 19, 2024
1 parent d5c7a94 commit 1f3b2e2
Showing 6 changed files with 69 additions and 27 deletions.
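
For orientation, a minimal sketch of how the new CSV output mode might be switched on from user code. The argument names project_name, output_folder and data_output_type are taken from attributes that appear in the diff below; the actual Config constructor may require additional arguments, so treat this as an assumption rather than the documented API.

from Melodie import Config   # assumed import path

config = Config(
    project_name="demo",            # output tables land in {output_folder}/{project_name}/
    output_folder="data/output",
    data_output_type="csv",         # "csv" routes results to CSV files instead of the database
)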
11 changes: 6 additions & 5 deletions Melodie/data_collector.py
@@ -3,7 +3,7 @@
 import time
 from typing import Callable, List, TYPE_CHECKING, Dict, Tuple, Any, Optional, Type
 import pandas as pd
-
+import collections
 import sqlalchemy
 
 from MelodieInfra import (
@@ -205,7 +205,8 @@ def append_agent_properties_by_records(
             props.update({k: agent_attrs_dict[k] for k in prop_names})
 
             row_cls = TableRow.subcls_from_dict(props)
-            self.agent_properties_dict[container_name] = Table(row_cls)
+            self.agent_properties_dict[container_name] = Table(
+                row_cls, ['id_scenario', 'id_run', 'period', 'id'] + prop_names)
             self._agent_properties_collectors[
                 container_name
             ] = objs_to_table_row_vectorizer(row_cls, prop_names)
@@ -231,7 +232,8 @@ def append_environment_properties(self, period: int):
                 self.env_property_names()))
         if self.environment_properties_list is None:
             row_cls = TableRow.subcls_from_dict(env_dic)
-            self.environment_properties_list = Table(row_cls)
+            self.environment_properties_list = Table(
+                row_cls, ['id_scenario', 'id_run', 'period']+self.env_property_names())
         self.environment_properties_list.append_from_dicts([env_dic])
 
     @property
@@ -309,8 +311,7 @@ def _write_list_to_table(self, engine, table_name: str, data: Table):
         :return:
         """
         if self.model.config.data_output_type == "csv":
-            base_path = os.path.join(self.model.config.output_folder,
-                                     self.model.config.project_name)
+            base_path = self.model.config.output_tables_path()
             if not os.path.exists(base_path):
                 os.makedirs(base_path)
             path = os.path.join(base_path, table_name+".csv")
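
To show the same flow outside the framework, here is a self-contained sketch of the "write a table to {output_tables_path}/{table_name}.csv, creating the folder on demand and appending on later calls" pattern used above. The helper name and arguments are hypothetical; only the control flow mirrors the diff.

import csv
import os


def write_rows_csv(base_path, table_name, header, rows):
    """Append rows to {base_path}/{table_name}.csv, writing the header only when the file is new."""
    os.makedirs(base_path, exist_ok=True)        # plays the role of the os.path.exists / os.makedirs pair above
    path = os.path.join(base_path, table_name + ".csv")
    is_new_file = not os.path.exists(path)       # same check Table.to_file now performs
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if is_new_file:
            writer.writerow(header)              # header once, like the "if is_new_file" branch in to_file
        writer.writerows(rows)
    return path


# Two calls for two periods: the second appends rows without repeating the header.
write_rows_csv("data/output/demo", "env", ["id_scenario", "id_run", "period"], [[0, 0, 0]])
write_rows_csv("data/output/demo", "env", ["id_scenario", "id_run", "period"], [[0, 0, 1]])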
2 changes: 1 addition & 1 deletion Melodie/data_loader.py
@@ -286,7 +286,7 @@ def generate_scenarios_from_dataframe(self, df_name: str) -> List["Scenario"]:
         for i, row in enumerate(scenarios_dataframe.iter_dicts()):
             scenario = self.scenario_cls()
             scenario.manager = self.manager
-
+            scenario._setup()
             for col_name in cols:
                 value = row[col_name]
                 scenario.__dict__[col_name] = value
6 changes: 6 additions & 0 deletions Melodie/model.py
@@ -121,6 +121,11 @@ def create(self):
         """
         pass
 
+    def load(self):
+        """
+        An initialization method for injecting data.
+        """
+
     def setup(self):
         """
         General method for model setup, which is called after ``Model.create()``
@@ -300,6 +305,7 @@ def _setup(self):
         :return:
         """
+        self.load()
         self.create()
         self.setup()
         for component_to_init in self.initialization_queue:
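
A minimal sketch of how a user model could pick up the new load() hook, assuming a hypothetical Model subclass; only the call order load() → create() → setup() comes from the _setup() change above.

from Melodie import Model   # assumed import path


class MyModel(Model):       # hypothetical user model
    def load(self):
        # Runs first inside Model._setup(), so anything injected here is
        # already available when create() and setup() build the model.
        self.injected_params = {"initial_agents": 100}   # hypothetical data

    def create(self):
        pass   # containers could rely on self.injected_params here

    def setup(self):
        pass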
22 changes: 15 additions & 7 deletions Melodie/simulator.py
@@ -5,6 +5,7 @@
 import copy
 import logging
 import os
+import shutil
 import threading
 import time
 from multiprocessing import Pool
@@ -40,7 +41,7 @@ def __init__(
         model_cls: Type["Model"],
         data_loader_cls: Optional[Type[DataLoader]] = None,
     ):
-        self.config: Optional[Config] = config
+        self.config: Config = config
         self.scenario_cls = scenario_cls
         self.model_cls = model_cls
 
@@ -94,7 +95,15 @@ def subworker_prerun(self):
 
         self.scenarios = self.generate_scenarios()
 
-    def pre_run(self, clear_db=True):
+    def clear_output_tables(self):
+        """
+        Clear all output tables
+        """
+        output_path = self.config.output_tables_path()
+        shutil.rmtree(output_path)
+        os.mkdir(output_path)
+
+    def pre_run(self, clear_output_data=True):
         """
         `pre_run` means this function should be executed before ``run``, to initialize the scenario
         parameters.
@@ -103,8 +112,10 @@ def pre_run(self, clear_db=True):
         :return:
         """
         assert self.config is not None, MelodieExceptions.MLD_INTL_EXC
-        if clear_db:
+        if clear_output_data:
             create_db_conn(self.config).clear_database()
+            self.clear_output_tables()
+
         if self.df_loader_cls is not None:
             self.data_loader: DataLoader = self.df_loader_cls(
                 self, self.config, self.scenario_cls
@@ -116,11 +127,8 @@ def subworker_prerun(self):
             raise MelodieExceptions.Scenario.NoValidScenarioGenerated(
                 self.scenarios)
 
-        for scenario in self.scenarios:
-            scenario._setup()
-
     @abc.abstractmethod
-    def generate_scenarios(self):
+    def generate_scenarios(self) -> List[Scenario]:
         """
         Abstract method for generation of scenarios.
         """
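
The same reset pattern in isolation, as a runnable sketch; note that the existence guard is added here for safety and is not part of the committed method, and the path is a hypothetical example.

import os
import shutil


def reset_output_dir(output_path):
    """Recreate an empty output directory, mirroring clear_output_tables() above."""
    if os.path.exists(output_path):    # guard added in this sketch; the method above assumes the directory exists
        shutil.rmtree(output_path)
    os.mkdir(output_path)


reset_output_dir("data/output/demo")   # invoked from pre_run() when clear_output_data=True

On the simulator itself, passing clear_output_data=False to pre_run() skips both clear_database() and this reset, so output from earlier runs is kept.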
11 changes: 10 additions & 1 deletion MelodieInfra/config/config.py
@@ -37,7 +37,7 @@ def __init__(
         self.visualizer_port = kwargs.get("visualizer_port", 8765)
         self.parallel_port = kwargs.get("parallel_port", 12233)
         self.data_output_type = data_output_type
-
+
         if database_config is None:
             self.database_config = SQLiteDBConfig(
                 os.path.join(self.output_folder,
@@ -91,3 +91,12 @@ def from_dict(d: Dict[str, Any]):
             database_config=db_conf,
         )
         return c
+
+    def output_tables_path(self):
+        """
+        Get the path to store the tables output from the model. It is `data/output/{project_name}`
+        If output to database and using sqlite, the output directory will be `data/output`
+        """
+        return os.path.join(self.output_folder,
+                            self.project_name)
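
A tiny runnable check of what output_tables_path() resolves to, assuming output_folder="data/output" and project_name="demo" as in the docstring; Config construction itself is omitted.

import os

output_folder = "data/output"   # assumed value of Config.output_folder
project_name = "demo"           # assumed value of Config.project_name
print(os.path.join(output_folder, project_name))   # -> data/output/demo (data/output\demo on Windows)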
44 changes: 31 additions & 13 deletions MelodieInfra/table/table_objects.py
@@ -1,3 +1,4 @@
+import os
 from typing import (
     Any,
     Callable,
@@ -42,8 +43,8 @@ def subcls_from_dict(cls, dic: Dict[str, Union[int, float, str, bool]]):
             "_TMP_ROW",
             (TableRow,),
             {
-                k: ColumnMeta(k, Column(py_types_to_sa_types[type(v)]()))
-                for k, v in dic.items()
+                k: ColumnMeta(k, Column(py_types_to_sa_types[type(dic[k])]()))
+                for k in dic.keys()
             },
         )
 
@@ -60,7 +61,8 @@ def get_annotations(cls) -> Dict[str, Union[Type[int], Type[str], Type[float]]]:
 
     @classmethod
     def get_aliases(cls):
-        attr_names = list(cls.__dict__.keys()) + list(cls.get_annotations().keys())
+        attr_names = list(cls.__dict__.keys()) + \
+            list(cls.get_annotations().keys())
         aliases = {}
         for attr_name in attr_names:
             if hasattr(cls, attr_name) and isinstance(
@@ -77,7 +79,8 @@ def get_datatypes(cls):
         """
         Get the datatype represented in database.
         """
-        attr_names = list(cls.__dict__.keys()) + list(cls.get_annotations().keys())
+        attr_names = list(cls.__dict__.keys()) + \
+            list(cls.get_annotations().keys())
         attr_names = [
             attr_name
             for attr_name in list(set(attr_names))
@@ -89,15 +92,17 @@
             # attr_name in cls.get_annotations()
             # ), f'Attribute "{attr_name}" of class {cls.__name__} must be annotated!'
             if not hasattr(cls, attr_name):
-                dtype_ = py_types_to_sa_types[cls.get_annotations()[attr_name]]()
+                dtype_ = py_types_to_sa_types[cls.get_annotations()[
+                    attr_name]]()
                 dtype = Column(dtype_)
             else:
                 meta = getattr(cls, attr_name)
                 assert isinstance(meta, ColumnMeta)
                 if meta.dtype is not None:
                     dtype = meta.dtype
                 else:
-                    dtype_ = py_types_to_sa_types[cls.get_annotations()[attr_name]]()
+                    dtype_ = py_types_to_sa_types[cls.get_annotations()[
+                        attr_name]]()
                     dtype = Column(dtype_)
             assert isinstance(dtype, Column)
             col_dtypes[attr_name] = dtype
@@ -132,24 +137,32 @@ class TR(TableRow):
 class Table(TableBase, Generic[TableRowGeneric]):
     data: List[TableRowGeneric]
 
-    def __init__(self, row_type: Type[TableRowGeneric]) -> None:
+    def __init__(self, row_type: Type[TableRowGeneric], columns_order: Optional[List[str]] = None) -> None:
         super().__init__()
         self.row_cls: Type[TableRow] = row_type
         self._db_model_cls: Type = None
         self.row_types: Dict[str, Column] = {}
 
+        column_names = []
         if callable(row_type) and issubclass(row_type, TableRow):
             self.row_types = row_type.get_datatypes()
-            for v in self.row_types.values():
+            for k, v in self.row_types.items():
                 assert isinstance(v, Column), v
+                column_names.append(k)
         else:
             raise NotImplementedError(
                 f"Cannot recognize table row type {type(row_type)}"
             )
+        if columns_order is not None:
+            self.columns_order = columns_order
+            assert set([row for row in self.row_types.keys()]) == set(
+                self.columns_order), f"columns_order {self.columns_order} should contain the same names as row_types: {self.row_types.keys()}"
+        else:
+            self.columns_order = [row for row in self.row_types.keys()]
 
     @property
     def columns(self):
-        return [row for row in self.row_types.keys()]
+        return self.columns_order
 
     def clear(self):
         self.data = []
@@ -182,22 +195,27 @@ def from_file(file_name: str, row_types: Type[TableRowGeneric], encoding="utf-8"
         aliases = table.row_cls.get_aliases()
         for row_data in rows_iter:
             table_row_obj: TableRow = table.row_cls.from_dict(
-                table, {col: row_data[i] for i, col in enumerate(columns)}, aliases
+                table, {col: row_data[i]
+                        for i, col in enumerate(columns)}, aliases
             )
             table.data.append(table_row_obj)
         return table
 
     def to_file(self, file_name: str, encoding="utf-8"):
-        writer = TableWriter(file_name, text_encoding=encoding).write()
+        is_new_file = True if not os.path.exists(file_name) else False
+        writer = TableWriter(file_name, text_encoding=encoding,
+                             append=not is_new_file).write()
         headers = self.columns
-        writer.send(headers)
+        if is_new_file:
+            writer.send(headers)
         for row_data in self.data:
             writer.send([row_data.__dict__[k] for k in headers])
         writer.close()
 
     def to_database(self, engine, table_name: str):
         conn = DatabaseConnector(engine)
-        conn.write_table(table_name, self.row_types, [d.__dict__ for d in self.data])
+        conn.write_table(table_name, self.row_types, [
+            d.__dict__ for d in self.data])
 
     def to_file_with_codegen(self, file_name: str, encoding="utf-8"):
         writer = TableWriter(file_name, text_encoding=encoding).write()
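
Finally, a sketch of the new columns_order argument and the append-style to_file(), using only calls visible in this diff (TableRow.subcls_from_dict, Table, append_from_dicts, to_file). The import path is assumed from the file location, and the record values are made up; treat the snippet as illustrative rather than verified against the test suite.

from MelodieInfra.table.table_objects import Table, TableRow   # assumed import path

record = {"id_scenario": 0, "id_run": 0, "period": 0, "wealth": 1.5}   # hypothetical row
row_cls = TableRow.subcls_from_dict(record)

# Fix the column order explicitly; it must contain exactly the keys of the row type.
table = Table(row_cls, ["id_scenario", "id_run", "period", "wealth"])
table.append_from_dicts([record])

table.to_file("demo.csv")   # new file: header plus one row
table.to_file("demo.csv")   # existing file: rows are appended, header is not repeated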
