Skip to content

Commit

Permalink
Parallel attempt berkeleydeeprlcourse#2, still throws CUDA errors
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-lsf committed Sep 27, 2020
1 parent 3109a76 commit 49efc21
Showing 1 changed file with 34 additions and 32 deletions.
66 changes: 34 additions & 32 deletions hw2/cs285/infrastructure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,22 @@ def sample_trajectories_sequential(
return paths, timesteps_this_batch


import concurrent.futures as cf
import copy
import os
import torch.multiprocessing as mp


def mp_worker(result_queue, env, policy, max_path_length, render, render_mode):
    """Produce rollouts in an endless loop and push each onto ``result_queue``.

    On every pass a deep copy of ``policy`` is made and handed, together with
    ``env``, ``max_path_length``, ``render`` and ``render_mode``, to
    ``sample_trajectory``; whatever that call returns is put on the queue.

    The loop has no exit condition, so this function never returns normally.
    Presumably the launching process is expected to terminate the worker once
    it has gathered enough data -- TODO confirm against the caller.
    """
    while True:
        rollout = sample_trajectory(
            env,
            copy.deepcopy(policy),  # a fresh copy per rollout, not per worker
            max_path_length,
            render,
            render_mode,
        )
        result_queue.put(rollout)


def sample_trajectories_parallel(
env,
policy,
Expand All @@ -250,40 +261,31 @@ def sample_trajectories_parallel(
"""Collect rollouts until we have collected min_timesteps_per_batch steps."""
# Number of tasks to submit to the executor. This should be larger than
# the number of workers (i.e. CPU count).
n_tasks = os.cpu_count() * 2

timesteps_this_batch = 0
paths = []

mp.set_start_method("spawn")
with cf.ProcessPoolExecutor(mp_context=mp) as executor:
task_args = (
sample_trajectory,
env,
policy,
max_path_length,
render,
render_mode,
ctx = mp.get_context("spawn")

def launch_worker():
proc = ctx.Process(
target=mp_worker,
args=(result_queue, env, policy, max_path_length, render, render_mode),
)
tasks = set(executor.submit(*task_args) for _ in range(n_tasks))
while True:
done_set, rest_set = cf.wait(tasks, return_when=cf.FIRST_COMPLETED)
(done,) = done_set
if not done.done():
raise done.exception()
path = done.result()
paths.append(path)
timesteps_this_batch += get_pathlength(path)

if timesteps_this_batch >= min_timesteps_per_batch:
# We have collected enough. Cancel the rest.
for task in rest_set:
task.cancel()
break
else:
# Submit a new sample_trajectory task
new_task = executor.submit(*task_args)
rest_set.add(new_task)
tasks = list(rest_set)
proc.start()
return proc

# mp.set_start_method("spawn", force=True)
result_queue = ctx.Queue(1)
# processes = [launch_worker() for _ in range(os.cpu_count())]
processes = [launch_worker() for _ in range(1)]

while True:
path = result_queue.get()
paths.append(path)
timesteps_this_batch += get_pathlength(path)
if timesteps_this_batch >= min_timesteps_per_batch:
# We have collected enough. Kill the workers.
for proc in processes:
proc.kill()

return paths, timesteps_this_batch

0 comments on commit 49efc21

Please sign in to comment.