diff --git a/rocket_learn/rollout_generator/redis/redis_rollout_worker.py b/rocket_learn/rollout_generator/redis/redis_rollout_worker.py
index a29bfd3..fd89f55 100644
--- a/rocket_learn/rollout_generator/redis/redis_rollout_worker.py
+++ b/rocket_learn/rollout_generator/redis/redis_rollout_worker.py
@@ -2,6 +2,7 @@
 import itertools
 import os
 import time
+import copy
 from threading import Thread
 from uuid import uuid4
 
@@ -43,6 +44,7 @@ class RedisRolloutWorker:
     :param auto_minimize: automatically minimize the launched rocket league instance
     :param local_cache_name: name of local database used for model caching. If None, caching is not used
     :param gamemode_weights: dict of dynamic gamemode choice weights. If None, default equal experience
+    :param gamemode_weight_ema_alpha: alpha for the exponential moving average of gamemode weighting
     """
 
     def __init__(self, redis: Redis, name: str, match: Match,
@@ -51,7 +53,9 @@ def __init__(self, redis: Redis, name: str, match: Match,
                  send_obs=True, scoreboard=None, pretrained_agents=None,
                  human_agent=None, force_paging=False, auto_minimize=True,
                  local_cache_name=None,
-                 gamemode_weights=None,):
+                 gamemode_weights=None,
+                 gamemode_weight_ema_alpha=0.02,
+                 ):
         # TODO model or config+params so workers can recreate just from redis connection?
         self.redis = redis
         self.name = name
@@ -80,8 +84,18 @@ def __init__(self, redis: Redis, name: str, match: Match,
         self.send_obs = send_obs
         self.dynamic_gm = dynamic_gm
         self.gamemode_weights = gamemode_weights
-        if self.gamemode_weights is not None:
-            assert sum(self.gamemode_weights.values()) == 1, "gamemode_weights must sum to 1"
+        if self.gamemode_weights is None:
+            self.gamemode_weights = {'1v1': 1/3, '2v2': 1/3, '3v3': 1/3}
+        assert sum(self.gamemode_weights.values()) == 1, "gamemode_weights must sum to 1"
+        self.target_weights = copy.copy(self.gamemode_weights)
+        # change weights from percentage of experience desired to percentage of gamemodes necessary (approx)
+        self.current_weights = copy.copy(self.gamemode_weights)
+        for k in self.current_weights.keys():
+            b, o = k.split("v")
+            self.current_weights[k] /= int(b)
+        self.current_weights = {k: self.current_weights[k] / (sum(self.current_weights.values()) + 1e-8) for k in self.current_weights.keys()}
+        self.mean_exp_grant = {'1v1': 1000, '2v2': 2000, '3v3': 3000}
+        self.ema_alpha = gamemode_weight_ema_alpha
         self.local_cache_name = local_cache_name
 
         self.uuid = str(uuid4())
@@ -210,19 +224,15 @@ def _get_past_model(self, version):
         return model
 
     def select_gamemode(self):
-        mode_exp = {m.decode("utf-8"): int(v) for m, v in self.redis.hgetall(EXPERIENCE_PER_MODE).items()}
-        if self.gamemode_weights is None:
-            mode = min(mode_exp, key=mode_exp.get)
-        else:
-            total = sum(mode_exp.values()) + 1e-8
-            mode_exp = {k: mode_exp[k] / total for k in mode_exp.keys()}
-            # find exp which is farthest below desired exp
-            diff = {k: self.gamemode_weights[k] - mode_exp[k] for k in mode_exp.keys()}
-            mode = max(diff, key=diff.get)
+
+        emp_weight = {k: self.mean_exp_grant[k] / (sum(self.mean_exp_grant.values()) + 1e-8)
+                      for k in self.mean_exp_grant.keys()}
+        cor_weight = {k: self.gamemode_weights[k] / emp_weight[k] for k in self.gamemode_weights.keys()}
+        self.current_weights = {k: cor_weight[k] / (sum(cor_weight.values()) + 1e-8) for k in cor_weight}
+        mode = np.random.choice(list(self.current_weights.keys()), p=list(self.current_weights.values()))
         b, o = mode.split("v")
         return int(b), int(o)
 
     def run(self):  # Mimics Thread
         """
         begin processing in already launched match and push to redis
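To see what the reworked select_gamemode does in isolation, here is a minimal standalone sketch of the same correction. The variable names mirror the patch, but the mean_exp_grant numbers are illustrative only, and the dict-based helpers below are not part of the worker class:

import numpy as np

# target share of *experience* per mode (the patch's default gamemode_weights)
gamemode_weights = {'1v1': 1 / 3, '2v2': 1 / 3, '3v3': 1 / 3}
# EMA of steps granted per episode of each mode (made-up values)
mean_exp_grant = {'1v1': 1000, '2v2': 2000, '3v3': 3000}

# share of experience each mode would produce if all modes were picked equally often
emp_weight = {k: v / sum(mean_exp_grant.values()) for k, v in mean_exp_grant.items()}
# modes that grant more steps per episode must be *selected* less often,
# so divide the target share by the empirical share and renormalise
cor_weight = {k: gamemode_weights[k] / emp_weight[k] for k in gamemode_weights}
current_weights = {k: v / sum(cor_weight.values()) for k, v in cor_weight.items()}

print(current_weights)  # {'1v1': ~0.545, '2v2': ~0.273, '3v3': ~0.182}
mode = np.random.choice(list(current_weights), p=list(current_weights.values()))

With these numbers a 1v1 episode yields the fewest steps, so 1v1 is selected a bit more than half of the time, and the resulting experience mix still lands near the 1/3-1/3-1/3 target.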
@@ -310,7 +320,11 @@ def run(self):  # Mimics Thread
                 state = rollouts[0].infos[-2]["state"]
                 goal_speed = np.linalg.norm(state.ball.linear_velocity) * 0.036  # kph
                 str_result = ('+' if result > 0 else "") + str(result)
-                self.total_steps_generated += len(rollouts[0].observations) * len(rollouts)
+                episode_exp = len(rollouts[0].observations) * len(rollouts)
+                self.total_steps_generated += episode_exp
+                if self.dynamic_gm:
+                    old_exp = self.mean_exp_grant[f"{blue}v{orange}"]
+                    self.mean_exp_grant[f"{blue}v{orange}"] = ((episode_exp - old_exp) * self.ema_alpha) + old_exp
                 post_stats = f"Rollout finished after {len(rollouts[0].observations)} steps ({self.total_steps_generated} total steps), result was {str_result}"
                 if result != 0:
                     post_stats += f", goal speed: {goal_speed:.2f} kph"
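This last hunk feeds the estimate used above: after every episode, mean_exp_grant for the played mode is nudged toward that episode's step count by an exponential moving average. A small sketch of the same update with the default alpha of 0.02 and invented numbers (not taken from a real run):

ema_alpha = 0.02                      # gamemode_weight_ema_alpha default
mean_exp_grant = {'2v2': 2000.0}      # current estimate for 2v2 (illustrative)

episode_exp = 2600                    # steps * rollouts from one finished 2v2 episode
old_exp = mean_exp_grant['2v2']
mean_exp_grant['2v2'] = (episode_exp - old_exp) * ema_alpha + old_exp
print(mean_exp_grant['2v2'])          # 2012.0 -> drifts slowly toward 2600

Because alpha is small, a single unusually long or short episode barely moves the estimate, so the selection probabilities computed in select_gamemode change smoothly instead of oscillating.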