fixed multiprocessing problem in trainer
h-mole committed Jan 21, 2024
1 parent ff54e91 commit 565b175
Showing 17 changed files with 12,519 additions and 149 deletions.
6 changes: 6 additions & 0 deletions Melodie/scenario_manager.py
@@ -59,6 +59,12 @@ def load(self):
def _setup(self):
self.load()
self.setup()

def initialize(self):
"""
Has the same effect as calling `_setup`, and must be called when generating scenarios.
"""
self._setup()

def setup(self):
"""
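For reference, here is a minimal, self-contained sketch of the lifecycle the new initialize() method wraps. It mirrors the hunk above rather than Melodie's real Scenario class; ScenarioSketch and its attributes are illustrative placeholders.

class ScenarioSketch:
    def load(self):
        # stands in for loading the scenario's input data
        self.loaded = True

    def setup(self):
        # stands in for user-defined scenario parameters
        self.periods = 100

    def _setup(self):
        self.load()
        self.setup()

    def initialize(self):
        # same effect as calling _setup(); call it when generating scenarios by hand
        self._setup()


s = ScenarioSketch()
s.initialize()
assert s.loaded and s.periods == 100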
22 changes: 17 additions & 5 deletions Melodie/trainer.py
@@ -1,5 +1,7 @@
import copy
import logging
import time
import sys
from typing import (
Dict,
Tuple,
@@ -14,7 +16,7 @@
)

from MelodieInfra import Config, MelodieExceptions, create_db_conn

from .utils import run_profile
from .algorithms import AlgorithmParameters
from .algorithms.ga import MelodieGA
from MelodieInfra.core import AgentList, Agent
@@ -346,6 +348,7 @@ def record_agent_properties(
d.update(agent_container_data)
d.pop("target_function_value")
agent_records[container_name].append(d)

create_db_conn(self.manager.config).write_dataframe(
f"{container_name}_trainer_result",
pd.DataFrame(agent_records[container_name]),
@@ -444,7 +447,7 @@ def run(self, scenario: Scenario, meta: Union[GATrainerAlgorithmMeta]):
f"Scenario {scenario.id} Path {meta.id_path} Generation {i + 1}/{self.params.generation_num}"
f"======================="
)

t0 = time.time()
for id_chromosome in range(self.params.strategy_population):
params = self.get_agent_params(id_chromosome)
self.parallel_manager.put_task(
@@ -461,32 +464,41 @@ def run(self, scenario: Scenario, meta: Union[GATrainerAlgorithmMeta]):

for _id_chromosome in range(self.params.strategy_population):
# v = result_queue.get()

dt = 0

(
chrom,
agents_data,
env_data,
) = (
self.parallel_manager.get_result()
) # cloudpickle.loads(base64.b64decode(v))
t00 = time.time()
# def f():
print("got result!", _id_chromosome, time.time(), file=sys.stderr)
meta.id_chromosome = chrom
agent_records, env_record = self.record_agent_properties(
agents_data, env_data, meta
)
print(agent_records, env_record, agents_data, env_data)
for container_name, records in agent_records.items():
agent_records_collector[container_name] += records
env_records_list.append(env_record)
self.target_function_to_cache(agents_data, i, chrom)

# run_profile(f)
# f()
print("iter!", time.time() - t00, file=sys.stderr)
t1 = time.time()
print(t1 - t0)
self.calc_cov_df(
{k: pd.DataFrame(v) for k, v in agent_records_collector.items()},
pd.DataFrame(env_records_list),
meta,
)

for key, algorithm in self.algorithms_dict.items():
self._chromosome_counter = -1
algorithm.run(1)
# print("AAAAAAAAAAAAAAAA")


class RelatedAgentContainerModel:
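The run() hunk above dispatches one task per chromosome to self.parallel_manager and then blocks until the same number of results has been collected, timing the whole generation with t0/t1. Below is a rough sketch of that dispatch/collect pattern using standard-library multiprocessing queues in place of the parallel manager; the worker logic and fitness values are placeholders, not Melodie's.

import time
from multiprocessing import Process, Queue


def evaluate(task_queue, result_queue):
    # stand-in worker: pull tasks until a None sentinel arrives
    while True:
        task = task_queue.get()
        if task is None:
            break
        chrom, params = task
        result_queue.put((chrom, {"fitness": sum(params)}))


def run_generation(population):
    tasks, results = Queue(), Queue()
    workers = [Process(target=evaluate, args=(tasks, results)) for _ in range(2)]
    for w in workers:
        w.start()
    t0 = time.time()
    for chrom in range(population):                          # dispatch phase
        tasks.put((chrom, [chrom, chrom + 1]))
    collected = [results.get() for _ in range(population)]   # collect phase
    for _ in workers:
        tasks.put(None)
    for w in workers:
        w.join()
    print("generation time:", time.time() - t0)
    return collected


if __name__ == "__main__":
    print(run_generation(4))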
15 changes: 11 additions & 4 deletions MelodieInfra/db/db.py
@@ -1,5 +1,6 @@
import logging
import os
import time
from typing import Dict, TYPE_CHECKING, Optional, List, Tuple

import sqlalchemy
@@ -140,7 +141,8 @@ def clear_database(self):
Clear the database, deleting all tables.
"""
if database_exists(self.connection.url):
logger.info(f"Database contains tables: {self.connection.table_names()}.")
logger.info(
f"Database contains tables: {self.connection.table_names()}.")
table_names = list(self.connection.table_names())
for table_name in table_names:
self.connection.execute(f"drop table {table_name}")
@@ -164,20 +166,23 @@ def write_dataframe(
:param if_exists: A string in {'replace', 'fail', 'append'}.
:return:
"""

if isinstance(data_frame, TableBase):
data_frame.to_database(self.connection, table_name)
else:
t0 = time.time()
if data_types is None:
data_types = DBConn.get_table_dtypes(table_name)
logger.debug(f"datatype of table `{table_name}` is: {data_types}")
t2 = time.time()
data_frame.to_sql(
table_name,
self.connection,
index=False,
dtype=data_types,
if_exists=if_exists,
)
t1 = time.time()
print("t1-t0", t1-t0, t2-t0, data_frame.shape)

def read_dataframe(
self,
@@ -210,7 +215,8 @@ def read_dataframe(
where_condition_phrase = ""
condition_phrases = []
if conditions is not None:
condition_phrases.extend([item[0] + item[1] for item in conditions])
condition_phrases.extend([item[0] + item[1]
for item in conditions])
if id_scenario is not None:
condition_phrases.append(f"id_scenario={id_scenario}")
if id_run is not None:
@@ -230,7 +236,8 @@
import traceback

traceback.print_exc()
raise MelodieExceptions.Data.AttemptingReadingFromUnexistedTable(table_name)
raise MelodieExceptions.Data.AttemptingReadingFromUnexistedTable(
table_name)

def drop_table(self, table_name: str):
"""
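The write_dataframe() hunk above times a pandas to_sql call that is given an explicit dtype mapping. A hedged, standalone illustration of that write path against an in-memory SQLite engine follows; the table name and columns are invented for the example.

import time

import pandas as pd
import sqlalchemy

engine = sqlalchemy.create_engine("sqlite:///:memory:")
df = pd.DataFrame({"agent_id": [0, 1, 2], "utility": [0.1, 0.5, 0.9]})
dtypes = {"agent_id": sqlalchemy.Integer(), "utility": sqlalchemy.Float()}

t0 = time.time()
df.to_sql("agent_trainer_result", engine, index=False,
          dtype=dtypes, if_exists="append")
print("write took", round(time.time() - t0, 4), "s for shape", df.shape)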
5 changes: 1 addition & 4 deletions MelodieInfra/parallel/parallel_manager.py
@@ -96,9 +96,6 @@ def run(self, role: str):
self.th_server.setDaemon(True)
self.th_server.start()
for core_id in range(self.cores):
python_path = os.environ.get('PYTHONPATH', "")
# paths = paths
print("python_path", python_path, ":".join(sys.path))
p = subprocess.Popen(
[
sys.executable,
@@ -111,7 +108,7 @@
"--role",
role,
],
env={"PYTHONPATH": ":".join(sys.path)},
# env={"PYTHONPATH": ";".join(sys.path)},
)
self.processes.append(p)

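Removing the env= override is likely the heart of the multiprocessing fix: a dict passed as env= to subprocess.Popen replaces the child's environment wholesale (losing PATH, HOME, and so on), and the hard-coded ':' separator would break on Windows in any case, so the worker processes now simply inherit the parent environment. If extra import paths are ever needed again, a safer sketch (my suggestion, not part of this commit) extends a copy of os.environ instead:

import os
import subprocess
import sys

child_env = os.environ.copy()
# os.pathsep is ":" on POSIX and ";" on Windows, unlike a hard-coded separator
child_env["PYTHONPATH"] = os.pathsep.join(sys.path)

p = subprocess.Popen(
    [sys.executable, "-c", "import sys; print(sys.path[:3])"],
    env=child_env,  # keeps PATH, HOME, etc. while adding the import paths
)
p.wait()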
33 changes: 25 additions & 8 deletions MelodieInfra/parallel/parallel_worker.py
@@ -8,6 +8,7 @@
from typing import Dict, Tuple, Any, Type, Union, TYPE_CHECKING

import cloudpickle
from Melodie.utils.profiler import run_profile

from MelodieInfra.config.global_configs import MelodieGlobalConfig

@@ -49,7 +50,8 @@ def get_task(self):
return tuple(task)

def put_result(self, result):
return self.conn.root.put_result(result)
ret = self.conn.root.put_result(result)
return ret

def close(self):
self.conn.close()
@@ -108,6 +110,8 @@ def sub_routine_trainer(
:param config_raw:
:return:
"""
# TODO: Have to set this path!

from Melodie import Config, Trainer, Environment, AgentList, Agent
import logging

@@ -117,17 +121,23 @@
try:
config = Config.from_dict(config_raw)
trainer: Trainer
trainer, scenario_cls, model_cls = get_scenario_manager(config, modules)

trainer, scenario_cls, model_cls = get_scenario_manager(
config, modules)
except BaseException:
import traceback

traceback.print_exc()
dumped = cloudpickle.dumps(0)
worker.put_result(base64.b64encode(dumped))
return

while 1:
try:
t0 = time.time()

# chrom = -1
# def test():
# nonlocal chrom
chrom, d, agent_params = worker.get_task()
logger.debug(f"processor {proc_id} got chrom {chrom}")
scenario = scenario_cls()
@@ -151,17 +161,21 @@
agent_data[container.container_name] = df
for row in df:
agent = agent_container.get_agent(row["id"])
row["target_function_value"] = trainer.target_function(agent)
row["target_function_value"] = trainer.target_function(
agent)
row["utility"] = trainer.utility(agent)
row["agent_id"] = row.pop("id")
env: Environment = model.environment
env_data = env.to_dict(trainer.environment_properties)
dumped = cloudpickle.dumps((chrom, agent_data, env_data))

worker.put_result(base64.b64encode(dumped))
# run_profile(test)
# test()
t1 = time.time()
logger.info(
f"Processor {proc_id}, chromosome {chrom}, time: {MelodieGlobalConfig.Logger.round_elapsed_time(t1 - t0)}s"
)
worker.put_result(base64.b64encode(dumped))
except Exception:
import traceback

@@ -191,7 +205,8 @@ def sub_routine_calibrator(
try:
config = Config.from_dict(config_raw)
calibrator: "Calibrator"
calibrator, scenario_cls, model_cls = get_scenario_manager(config, modules)
calibrator, scenario_cls, model_cls = get_scenario_manager(
config, modules)
except BaseException:
import traceback

@@ -222,7 +237,8 @@
env: Environment = model.environment
env_data = env.to_dict(calibrator.watched_env_properties)
env_data.update(
{prop: scenario.to_dict()[prop] for prop in calibrator.properties}
{prop: scenario.to_dict()[prop]
for prop in calibrator.properties}
)
env_data["target_function_value"] = env_data[
"distance"
@@ -262,7 +278,8 @@ def sub_routine_simulator(
try:
config = Config.from_dict(config_raw)
simulator: "Simulator"
simulator, scenario_cls, model_cls = get_scenario_manager(config, modules)
simulator, scenario_cls, model_cls = get_scenario_manager(
config, modules)
except BaseException:
import traceback

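In the sub_routine_trainer() hunks above, the worker now calls put_result() right after serializing the result and before logging the elapsed time. Here is a small, self-contained sketch of that get_task / evaluate / serialize / put_result loop, with an in-process queue standing in for the connection-backed worker used in the diff; the fitness computation is a placeholder.

import base64
import queue
import time

import cloudpickle


class FakeWorker:
    # stands in for the connection-backed worker class in the diff
    def __init__(self):
        self.tasks = queue.Queue()
        self.results = queue.Queue()

    def get_task(self):
        return self.tasks.get()

    def put_result(self, result):
        self.results.put(result)


def worker_loop(worker, proc_id):
    while True:
        chrom, params = worker.get_task()
        if chrom is None:  # sentinel ends the loop
            break
        t0 = time.time()
        env_data = {"fitness": sum(params)}            # placeholder for a model run
        dumped = cloudpickle.dumps((chrom, {}, env_data))
        worker.put_result(base64.b64encode(dumped))    # send the result first
        print(f"Processor {proc_id}, chromosome {chrom}, time: {time.time() - t0:.4f}s")


w = FakeWorker()
w.tasks.put((0, [1, 2, 3]))
w.tasks.put((None, None))
worker_loop(w, proc_id=0)
print(cloudpickle.loads(base64.b64decode(w.results.get())))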
2 changes: 1 addition & 1 deletion tests/infra/config.py
@@ -34,7 +34,7 @@
"temp_db_trainer",
os.path.dirname(__file__),
input_folder=os.path.join(os.path.dirname(__file__), "resources", "excels"),
output_folder=os.path.join(os.path.dirname(__file__), "resources", "output"),
output_folder=os.path.join(os.path.dirname(__file__), "resources", "output", "trainer"),
)

cfg_dataloader_with_cache = Config(
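The test config change above only redirects trainer test output into a dedicated "trainer" subfolder. A tiny hedged sketch of building that path and ensuring it exists before tests write to it (the makedirs call is my addition, not something this commit does):

import os

project_root = os.path.dirname(__file__)
trainer_output = os.path.join(project_root, "resources", "output", "trainer")
os.makedirs(trainer_output, exist_ok=True)  # avoids failures on a fresh checkout
print(trainer_output)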