Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions rocket_learn/rollout_generator/redis/redis_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ class RedisRolloutWorker:
:param force_paging: Should paging be forced
:param auto_minimize: automatically minimize the launched rocket league instance
:param local_cache_name: name of local database used for model caching. If None, caching is not used
:param force_old_deterministic: if True, past-version models are set to deterministic whenever the matchup has n_new != 0 (i.e. it is not made up solely of old versions)
"""

def __init__(self, redis: Redis, name: str, match: Match,
past_version_prob=.2, evaluation_prob=0.01, sigma_target=2,
dynamic_gm=True, streamer_mode=False, send_gamestates=True,
send_obs=True, scoreboard=None, pretrained_agents=None,
human_agent=None, force_paging=False, auto_minimize=True,
local_cache_name=None):
local_cache_name=None,
force_old_deterministic=False):
# TODO model or config+params so workers can recreate just from redis connection?
self.redis = redis
self.name = name
Expand All @@ -61,6 +63,7 @@ def __init__(self, redis: Redis, name: str, match: Match,
self.pretrained_total_prob = sum([self.pretrained_agents[key] for key in self.pretrained_agents])

self.human_agent = human_agent
self.force_old_deterministic = force_old_deterministic

if human_agent and pretrained_agents:
print("** WARNING - Human Player and Pretrained Agents are in conflict. **")
Expand Down Expand Up @@ -351,13 +354,17 @@ def _generate_matchup(self, n_agents, latest_version, pretrained_choice):
n_new = n_agents - n_old
versions, ratings = self._get_opponent_ids(n_new, n_old, pretrained_choice)
agents = []
for version in versions:
for i, version in enumerate(versions):
if version == -1:
agents.append(self.current_agent)
elif pretrained_choice is not None and version == 'na':
agents.append(pretrained_choice)
else:
selected_agent = self._get_past_model("-".join(version.split("-")[:-1]))
if self.force_old_deterministic and n_new != 0:
versions[i] = versions[i].replace('stochastic', 'deterministic')
version = version.replace('stochastic', 'deterministic')

if version.endswith("deterministic"):
selected_agent.deterministic = True
elif version.endswith("stochastic"):
Expand Down