diff --git a/nle/agent/README.md b/nle/agent/README.md new file mode 100644 index 000000000..7ce7aa438 --- /dev/null +++ b/nle/agent/README.md @@ -0,0 +1,115 @@ +# Neurips 2020 code release + +Here we release updated code to get results competitive with our NeurIPS 2020 paper. + +To be clear, this is not the exact code used for the paper: we made a number of performance improvements to NLE since the original results, dramatically increasing the speed of the environment (which was already one of the fastest-performing environments when the paper was published!). + +We also introduced some additional modeling options, including conditioning the model on the in-game messages (i.e. msg.model=lt_cnn) and introducing new ways of observing the environment through different glyph types (i.e. glyph_type=all_cat). These features are enabled by default for the model now, which outperforms the models in the paper. + +## Reproduced results + +After 1e9 training steps, the average mean_episode_return achieved by the agents in their last 10k episodes are listed below, averaged over three runs. + +We give 1 reward for winning the staircase, pet, and oracle tasks. This gives a lower reward per completion than the original paper. +We give 1 reward for every 1000 steps the agent stays alive on all tasks. This is absent in the original paper. + +**Staircase Task** +The average steps per episode for all models was about 400, so mean_episode_return - 0.4 gives the approximate percentage success rate of the task (so getting close to 1.4 means the agent is reaching the goal every episode). +These models perform better than the original paper result. +- mon-hum-neu-mal: baseline 1.37, RND 1.00. +- val-dwa-law-fem: baseline 1.17, RND 1.14 +- wiz-elf-cha-mal: baseline 1.01, RND 0.97 +- tou-hum-neu-fem: baseline 0.94, RND 1.18 + +**Pet Task** +The average steps per episode for all models was about 400, so mean_episode_return - 0.4 gives the approximate percentage success rate of the task (so getting close to 1.4 means the agent is reaching the goal every episode). +These models perform better than the original paper result. +- mon-hum-neu-mal: baseline 1.26, RND 1.31 +- val-dwa-law-fem: baseline 1.08, RND 1.08 +- wiz-elf-cha-mal: baseline 0.86, RND 0.85 +- tou-hum-neu-fem: baseline 0.75, RND 0.79 + +**Eat Task** +The valkyrie baseline performed worse than in the original paper on this task; otherwise, every other model and character performed much better. +- mon-hum-neu-mal: baseline 2282, RND 2193 +- val-dwa-law-fem: baseline 145, RND 1240 +- wiz-elf-cha-mal: baseline 993, RND 1066 +- tou-hum-neu-fem: baseline 1131, RND 1230 + +**Gold Task** +These models perform much better than the original paper result. +- mon-hum-neu-mal: baseline 159, RND 116 +- val-dwa-law-fem: baseline 60, RND 63 +- wiz-elf-cha-mal: baseline 37, RND 38 +- tou-hum-neu-fem: baseline 18, RND 19 + +**Score Task** +The RND models here perform slightly worse than in the paper. +The baseline models here perform significantly better than in the paper. +- mon-hum-neu-mal: baseline 971, RND 941 +- val-dwa-law-fem: baseline 653, RND 657 +- wiz-elf-cha-mal: baseline 343, RND 343 +- tou-hum-neu-fem: baseline 116, RND 202 + +**Scout Task** +These models perform significantly better than in the paper. +- mon-hum-neu-mal: baseline 2491, RND 2452 +- val-dwa-law-fem: baseline 1978, RND 1982 +- wiz-elf-cha-mal: baseline 1411, RND 1397 +- tou-hum-neu-fem: baseline 1112, RND 1114 + +**Oracle Task** +These models perform similarly to the paper. +- mon-hum-neu-mal: baseline -4.1, RND -4.8 +- val-dwa-law-fem: baseline -4.9, RND -6.0 +- wiz-elf-cha-mal: baseline -4.4, RND -4.3 +- tou-hum-neu-fem: baseline -3.4, RND -6.8 + + +## Changed params from the paper (better performing!) + +- change the reward_win parameter from 100 to 1. this only affects the staircase, pet, and oracle tasks. this is a more appropriate ratio between the reward and the step penalty and results in more consistent performance. you can compare with the plots in the paper but consider the scale to be leading towards 1 rather than 100 for the agent to be reaching the goal on every episode. +- decrease learning rate from 0.0002 to 0.0001 +- added reward_normalization parameter. set this reward_normalization=true, set reward_clipping=none (was "tim" before). +- increase hidden_size from 128 to 256 +- increase embedding size from 32 to 64 +- add a "message model" which conditions on the in-game message, providing a fourth input to the policy (in addition to the full dungeon screen, the crop of the dungeon right around the agent, and the status bar). set msg.model=lt_cnn. +- add different interpretations of the glyphs in the environment. see below for explanation. set glyph_type=all_cat. + +When msg.model=lt_cnn, and int.input=full (the default), we also add the message model to the target and predictor networks for RND. This should also improve RND network's performance, as seeing new messages in the game should be a particularly high-signal new experience to seek out (it could include taking new actions or seeing new monsters, new items, new environments or more) but we have not yet analyzed this in detail. + +## Glyph Types + +The default setting `full` from in the paper uses a unique identifier for every glyph (entity on the screen) that you encounter. There are about 6000 unique glyphs. + +However, this masks the relationships between entities that the agent might otherwise be able to use (and humans do!), for example that several entities that share certain traits might have a similar appearance such as different kinds of dogs using the same symbol but having different colours. + +The `group_id` setting is one such breakdown, splitting each entity into one of twelve groups and then an id within each group. + +An alternative breakdown `color_char`(_special) is by actual appearance, splitting each glyph into the colour used for the entity, the character used for it, and then finally a special identifier which plays can enable in the game and provides additional information about the entity at that tile. + +We then also provide the `all` encoding which uses the group, id, colour, character, and special traits for each item, embedding them each in an `edim` vector, concatenating the vectors, and then projecting back to a final `edim` vector using a Linear(5 * edim, edim) layer. + +Finally, we provide `all_cat` which partitions the `edim` final vector into sub-vectors based on how unique each component is so that the final vectors can be concatened (without projection) into an `edim` vector. That is, for a 64-dim embedding_dim, we would use a 24-dim embedding for the id, 8 dim for group, 8 dim for color, 16 dim for character, and 8 dim for special. This is a bit arbitrary but performs well and is the recommended setting. + +## Polyhydra Syntax + +``` +# install hydra +pip install hydra-core hydra_colorlog + +# single run +python -m hackrl.polyhydra model=baseline env=score num_actors=80 + +# to sweep on the cluster: add another -m (for multirun) and comma-separate values +python -m hackrl.polyhydra model=baseline,ride,rnd,dynamics env=score,gold + +# hydra supports nested arguments. +python -m hackrl.polyhydra msg.model=cnn fwd.forward_cost=0.01 ride.count_norm=false +``` + +In addition to specifiying arguments on the command line, you can edit config.yaml directly. + +## Reproducing our NeurIPS sweep + +Run neurips_sweep.sh to run a sweep covering the characters, tasks, and models we report in the paper. \ No newline at end of file diff --git a/nle/agent/__init__.py b/nle/agent/__init__.py index 9020c2df2..8daf2005d 100644 --- a/nle/agent/__init__.py +++ b/nle/agent/__init__.py @@ -1 +1,13 @@ # Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nle/agent/agent.py b/nle/agent/agent.py deleted file mode 100644 index 21f888840..000000000 --- a/nle/agent/agent.py +++ /dev/null @@ -1,956 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This is an example self-contained agent running NLE based on MonoBeast. - -import argparse -import logging -import os -import pprint -import threading -import time -import timeit -import traceback - -# Necessary for multithreading. -os.environ["OMP_NUM_THREADS"] = "1" - -try: - import torch - from torch import multiprocessing as mp - from torch import nn - from torch.nn import functional as F -except ImportError: - logging.exception( - "PyTorch not found. Please install the agent dependencies with " - '`pip install "nle[agent]"`' - ) - -import gym # noqa: E402 - -import nle # noqa: F401, E402 -from nle.agent import vtrace # noqa: E402 -from nle import nethack # noqa: E402 - - -# yapf: disable -parser = argparse.ArgumentParser(description="PyTorch Scalable Agent") - -parser.add_argument("--env", type=str, default="NetHackScore-v0", - help="Gym environment.") -parser.add_argument("--mode", default="train", - choices=["train", "test", "test_render"], - help="Training or test mode.") - -# Training settings. -parser.add_argument("--disable_checkpoint", action="store_true", - help="Disable saving checkpoint.") -parser.add_argument("--savedir", default="~/torchbeast/", - help="Root dir where experiment data will be saved.") -parser.add_argument("--num_actors", default=4, type=int, metavar="N", - help="Number of actors (default: 4).") -parser.add_argument("--total_steps", default=100000, type=int, metavar="T", - help="Total environment steps to train for.") -parser.add_argument("--batch_size", default=8, type=int, metavar="B", - help="Learner batch size.") -parser.add_argument("--unroll_length", default=80, type=int, metavar="T", - help="The unroll length (time dimension).") -parser.add_argument("--num_buffers", default=None, type=int, - metavar="N", help="Number of shared-memory buffers.") -parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int, - metavar="N", help="Number learner threads.") -parser.add_argument("--disable_cuda", action="store_true", - help="Disable CUDA.") -parser.add_argument("--use_lstm", action="store_true", - help="Use LSTM in agent model.") - -# Loss settings. -parser.add_argument("--entropy_cost", default=0.0006, - type=float, help="Entropy cost/multiplier.") -parser.add_argument("--baseline_cost", default=0.5, - type=float, help="Baseline cost/multiplier.") -parser.add_argument("--discounting", default=0.99, - type=float, help="Discounting factor.") -parser.add_argument("--reward_clipping", default="abs_one", - choices=["abs_one", "none"], - help="Reward clipping.") - -# Optimizer settings. -parser.add_argument("--learning_rate", default=0.00048, - type=float, metavar="LR", help="Learning rate.") -parser.add_argument("--alpha", default=0.99, type=float, - help="RMSProp smoothing constant.") -parser.add_argument("--momentum", default=0, type=float, - help="RMSProp momentum.") -parser.add_argument("--epsilon", default=0.01, type=float, - help="RMSProp epsilon.") -parser.add_argument("--grad_norm_clipping", default=40.0, type=float, - help="Global gradient norm clip.") -# yapf: enable - - -logging.basicConfig( - format=( - "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" - ), - level=logging.INFO, -) - - -def nested_map(f, n): - if isinstance(n, tuple) or isinstance(n, list): - return n.__class__(nested_map(f, sn) for sn in n) - elif isinstance(n, dict): - return {k: nested_map(f, v) for k, v in n.items()} - else: - return f(n) - - -def compute_baseline_loss(advantages): - return 0.5 * torch.sum(advantages ** 2) - - -def compute_entropy_loss(logits): - """Return the entropy loss, i.e., the negative entropy of the policy.""" - policy = F.softmax(logits, dim=-1) - log_policy = F.log_softmax(logits, dim=-1) - return torch.sum(policy * log_policy) - - -def compute_policy_gradient_loss(logits, actions, advantages): - cross_entropy = F.nll_loss( - F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), - target=torch.flatten(actions, 0, 1), - reduction="none", - ) - cross_entropy = cross_entropy.view_as(advantages) - return torch.sum(cross_entropy * advantages.detach()) - - -def create_env(name, *args, **kwargs): - return gym.make(name, observation_keys=("glyphs", "blstats"), *args, **kwargs) - - -def act( - flags, - actor_index: int, - free_queue: mp.SimpleQueue, - full_queue: mp.SimpleQueue, - model: torch.nn.Module, - buffers, - initial_agent_state_buffers, -): - try: - logging.info("Actor %i started.", actor_index) - - gym_env = create_env(flags.env, savedir=flags.rundir) - env = ResettingEnvironment(gym_env) - env_output = env.initial() - agent_state = model.initial_state(batch_size=1) - agent_output, unused_state = model(env_output, agent_state) - while True: - index = free_queue.get() - if index is None: - break - - # Write old rollout end. - for key in env_output: - buffers[key][index][0, ...] = env_output[key] - for key in agent_output: - buffers[key][index][0, ...] = agent_output[key] - for i, tensor in enumerate(agent_state): - initial_agent_state_buffers[index][i][...] = tensor - - # Do new rollout. - for t in range(flags.unroll_length): - with torch.no_grad(): - agent_output, agent_state = model(env_output, agent_state) - - env_output = env.step(agent_output["action"]) - - for key in env_output: - buffers[key][index][t + 1, ...] = env_output[key] - for key in agent_output: - buffers[key][index][t + 1, ...] = agent_output[key] - - full_queue.put(index) - - except KeyboardInterrupt: - pass # Return silently. - except Exception: - logging.error("Exception in worker process %i", actor_index) - traceback.print_exc() - print() - raise - - -def get_batch( - flags, - free_queue: mp.SimpleQueue, - full_queue: mp.SimpleQueue, - buffers, - initial_agent_state_buffers, - lock=threading.Lock(), -): - with lock: - indices = [full_queue.get() for _ in range(flags.batch_size)] - batch = { - key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers - } - initial_agent_state = ( - torch.cat(ts, dim=1) - for ts in zip(*[initial_agent_state_buffers[m] for m in indices]) - ) - for m in indices: - free_queue.put(m) - batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()} - initial_agent_state = tuple( - t.to(device=flags.device, non_blocking=True) for t in initial_agent_state - ) - return batch, initial_agent_state - - -def learn( - flags, - actor_model, - model, - batch, - initial_agent_state, - optimizer, - scheduler, - lock=threading.Lock(), # noqa: B008 -): - """Performs a learning (optimization) step.""" - with lock: - learner_outputs, unused_state = model(batch, initial_agent_state) - - # Take final value function slice for bootstrapping. - bootstrap_value = learner_outputs["baseline"][-1] - - # Move from obs[t] -> action[t] to action[t] -> obs[t]. - batch = {key: tensor[1:] for key, tensor in batch.items()} - learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()} - - rewards = batch["reward"] - if flags.reward_clipping == "abs_one": - clipped_rewards = torch.clamp(rewards, -1, 1) - elif flags.reward_clipping == "none": - clipped_rewards = rewards - - discounts = (~batch["done"]).float() * flags.discounting - - vtrace_returns = vtrace.from_logits( - behavior_policy_logits=batch["policy_logits"], - target_policy_logits=learner_outputs["policy_logits"], - actions=batch["action"], - discounts=discounts, - rewards=clipped_rewards, - values=learner_outputs["baseline"], - bootstrap_value=bootstrap_value, - ) - - pg_loss = compute_policy_gradient_loss( - learner_outputs["policy_logits"], - batch["action"], - vtrace_returns.pg_advantages, - ) - baseline_loss = flags.baseline_cost * compute_baseline_loss( - vtrace_returns.vs - learner_outputs["baseline"] - ) - entropy_loss = flags.entropy_cost * compute_entropy_loss( - learner_outputs["policy_logits"] - ) - - total_loss = pg_loss + baseline_loss + entropy_loss - - episode_returns = batch["episode_return"][batch["done"]] - stats = { - "episode_returns": tuple(episode_returns.cpu().numpy()), - "mean_episode_return": torch.mean(episode_returns).item(), - "total_loss": total_loss.item(), - "pg_loss": pg_loss.item(), - "baseline_loss": baseline_loss.item(), - "entropy_loss": entropy_loss.item(), - } - - optimizer.zero_grad() - total_loss.backward() - nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping) - optimizer.step() - scheduler.step() - - actor_model.load_state_dict(model.state_dict()) - return stats - - -def create_buffers(flags, observation_space, num_actions, num_overlapping_steps=1): - size = (flags.unroll_length + num_overlapping_steps,) - - # Get specimens to infer shapes and dtypes. - samples = {k: torch.from_numpy(v) for k, v in observation_space.sample().items()} - - specs = { - key: dict(size=size + sample.shape, dtype=sample.dtype) - for key, sample in samples.items() - } - specs.update( - reward=dict(size=size, dtype=torch.float32), - done=dict(size=size, dtype=torch.bool), - episode_return=dict(size=size, dtype=torch.float32), - episode_step=dict(size=size, dtype=torch.int32), - policy_logits=dict(size=size + (num_actions,), dtype=torch.float32), - baseline=dict(size=size, dtype=torch.float32), - last_action=dict(size=size, dtype=torch.int64), - action=dict(size=size, dtype=torch.int64), - ) - buffers = {key: [] for key in specs} - for _ in range(flags.num_buffers): - for key in buffers: - buffers[key].append(torch.empty(**specs[key]).share_memory_()) - return buffers - - -def _format_observations(observation, keys=("glyphs", "blstats")): - observations = {} - for key in keys: - entry = observation[key] - entry = torch.from_numpy(entry) - entry = entry.view((1, 1) + entry.shape) # (...) -> (T,B,...). - observations[key] = entry - return observations - - -class ResettingEnvironment: - """Turns a Gym environment into something that can be step()ed indefinitely.""" - - def __init__(self, gym_env): - self.gym_env = gym_env - self.episode_return = None - self.episode_step = None - - def initial(self): - initial_reward = torch.zeros(1, 1) - # This supports only single-tensor actions ATM. - initial_last_action = torch.zeros(1, 1, dtype=torch.int64) - self.episode_return = torch.zeros(1, 1) - self.episode_step = torch.zeros(1, 1, dtype=torch.int32) - initial_done = torch.ones(1, 1, dtype=torch.uint8) - - result = _format_observations(self.gym_env.reset()) - result.update( - reward=initial_reward, - done=initial_done, - episode_return=self.episode_return, - episode_step=self.episode_step, - last_action=initial_last_action, - ) - return result - - def step(self, action): - observation, reward, done, unused_info = self.gym_env.step(action.item()) - self.episode_step += 1 - self.episode_return += reward - episode_step = self.episode_step - episode_return = self.episode_return - if done: - observation = self.gym_env.reset() - self.episode_return = torch.zeros(1, 1) - self.episode_step = torch.zeros(1, 1, dtype=torch.int32) - - result = _format_observations(observation) - - reward = torch.tensor(reward).view(1, 1) - done = torch.tensor(done).view(1, 1) - - result.update( - reward=reward, - done=done, - episode_return=episode_return, - episode_step=episode_step, - last_action=action, - ) - return result - - def close(self): - self.gym_env.close() - - -def train(flags): # pylint: disable=too-many-branches, too-many-statements - flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) - - rundir = os.path.join( - flags.savedir, "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S") - ) - - if not os.path.exists(rundir): - os.makedirs(rundir) - logging.info("Logging results to %s", rundir) - - symlink = os.path.join(flags.savedir, "latest") - try: - if os.path.islink(symlink): - os.remove(symlink) - if not os.path.exists(symlink): - os.symlink(rundir, symlink) - logging.info("Symlinked log directory: %s", symlink) - except OSError: - raise - - logfile = open(os.path.join(rundir, "logs.tsv"), "a", buffering=1) - checkpointpath = os.path.join(rundir, "model.tar") - - flags.rundir = rundir - - if flags.num_buffers is None: # Set sensible default for num_buffers. - flags.num_buffers = max(2 * flags.num_actors, flags.batch_size) - if flags.num_actors >= flags.num_buffers: - raise ValueError("num_buffers should be larger than num_actors") - if flags.num_buffers < flags.batch_size: - raise ValueError("num_buffers should be larger than batch_size") - - T = flags.unroll_length - B = flags.batch_size - - flags.device = None - if not flags.disable_cuda and torch.cuda.is_available(): - logging.info("Using CUDA.") - flags.device = torch.device("cuda") - else: - logging.info("Not using CUDA.") - flags.device = torch.device("cpu") - - env = create_env(flags.env, archivefile=None) - observation_space = env.observation_space - action_space = env.action_space - del env # End this before forking. - - model = Net(observation_space, action_space.n, flags.use_lstm) - buffers = create_buffers(flags, observation_space, model.num_actions) - - model.share_memory() - - # Add initial RNN state. - initial_agent_state_buffers = [] - for _ in range(flags.num_buffers): - state = model.initial_state(batch_size=1) - for t in state: - t.share_memory_() - initial_agent_state_buffers.append(state) - - actor_processes = [] - ctx = mp.get_context("fork") - free_queue = ctx.SimpleQueue() - full_queue = ctx.SimpleQueue() - - for i in range(flags.num_actors): - actor = ctx.Process( - target=act, - args=( - flags, - i, - free_queue, - full_queue, - model, - buffers, - initial_agent_state_buffers, - ), - name="Actor-%i" % i, - ) - actor.start() - actor_processes.append(actor) - - learner_model = Net(observation_space, action_space.n, flags.use_lstm).to( - device=flags.device - ) - learner_model.load_state_dict(model.state_dict()) - - optimizer = torch.optim.RMSprop( - learner_model.parameters(), - lr=flags.learning_rate, - momentum=flags.momentum, - eps=flags.epsilon, - alpha=flags.alpha, - ) - - def lr_lambda(epoch): - return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps - - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - - stat_keys = [ - "total_loss", - "mean_episode_return", - "pg_loss", - "baseline_loss", - "entropy_loss", - ] - logfile.write("# Step\t%s\n" % "\t".join(stat_keys)) - - step, stats = 0, {} - - def batch_and_learn(i, lock=threading.Lock()): - """Thread target for the learning process.""" - nonlocal step, stats - while step < flags.total_steps: - batch, agent_state = get_batch( - flags, free_queue, full_queue, buffers, initial_agent_state_buffers - ) - stats = learn( - flags, model, learner_model, batch, agent_state, optimizer, scheduler - ) - with lock: - logfile.write("%i\t" % step) - logfile.write("\t".join(str(stats[k]) for k in stat_keys)) - logfile.write("\n") - step += T * B - - for m in range(flags.num_buffers): - free_queue.put(m) - - threads = [] - for i in range(flags.num_learner_threads): - thread = threading.Thread( - target=batch_and_learn, - name="batch-and-learn-%d" % i, - args=(i,), - daemon=True, # To support KeyboardInterrupt below. - ) - thread.start() - threads.append(thread) - - def checkpoint(): - if flags.disable_checkpoint: - return - logging.info("Saving checkpoint to %s", checkpointpath) - torch.save( - { - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_state_dict": scheduler.state_dict(), - "flags": vars(flags), - }, - checkpointpath, - ) - - timer = timeit.default_timer - try: - last_checkpoint_time = timer() - while step < flags.total_steps: - start_step = step - start_time = timer() - time.sleep(5) - - if timer() - last_checkpoint_time > 10 * 60: # Save every 10 min. - checkpoint() - last_checkpoint_time = timer() - - sps = (step - start_step) / (timer() - start_time) - if stats.get("episode_returns", None): - mean_return = ( - "Return per episode: %.1f. " % stats["mean_episode_return"] - ) - else: - mean_return = "" - total_loss = stats.get("total_loss", float("inf")) - logging.info( - "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s", - step, - sps, - total_loss, - mean_return, - pprint.pformat(stats), - ) - except KeyboardInterrupt: - logging.warning("Quitting.") - return # Try joining actors then quit. - else: - for thread in threads: - thread.join() - logging.info("Learning finished after %d steps.", step) - finally: - for _ in range(flags.num_actors): - free_queue.put(None) - for actor in actor_processes: - actor.join(timeout=1) - - checkpoint() - logfile.close() - - -def test(flags, num_episodes=10): - flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) - checkpointpath = os.path.join(flags.savedir, "latest", "model.tar") - - gym_env = create_env(flags.env, archivefile=None) - env = ResettingEnvironment(gym_env) - model = Net(gym_env.observation_space, gym_env.action_space.n, flags.use_lstm) - model.eval() - checkpoint = torch.load(checkpointpath, map_location="cpu") - model.load_state_dict(checkpoint["model_state_dict"]) - - observation = env.initial() - returns = [] - - agent_state = model.initial_state(batch_size=1) - - while len(returns) < num_episodes: - if flags.mode == "test_render": - env.gym_env.render() - policy_outputs, agent_state = model(observation, agent_state) - observation = env.step(policy_outputs["action"]) - if observation["done"].item(): - returns.append(observation["episode_return"].item()) - logging.info( - "Episode ended after %d steps. Return: %.1f", - observation["episode_step"].item(), - observation["episode_return"].item(), - ) - env.close() - logging.info( - "Average returns over %i steps: %.1f", num_episodes, sum(returns) / len(returns) - ) - - -class RandomNet(nn.Module): - def __init__(self, observation_shape, num_actions, use_lstm): - super(RandomNet, self).__init__() - del observation_shape, use_lstm - self.num_actions = num_actions - self.theta = torch.nn.Parameter(torch.zeros(self.num_actions)) - - def forward(self, inputs, core_state): - # print(inputs) - T, B, *_ = inputs["observation"].shape - zeros = self.theta * 0 - # set logits to 0 - policy_logits = zeros[None, :].expand(T * B, -1) - # set baseline to 0 - baseline = policy_logits.sum(dim=1).view(-1, B) - - # sample random action - action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1).view( - T, B - ) - policy_logits = policy_logits.view(T, B, self.num_actions) - return ( - dict(policy_logits=policy_logits, baseline=baseline, action=action), - core_state, - ) - - def initial_state(self, batch_size): - return () - - -def _step_to_range(delta, num_steps): - """Range of `num_steps` integers with distance `delta` centered around zero.""" - return delta * torch.arange(-num_steps // 2, num_steps // 2) - - -class Crop(nn.Module): - """Helper class for NetHackNet below.""" - - def __init__(self, height, width, height_target, width_target): - super(Crop, self).__init__() - self.width = width - self.height = height - self.width_target = width_target - self.height_target = height_target - width_grid = _step_to_range(2 / (self.width - 1), self.width_target)[ - None, : - ].expand(self.height_target, -1) - height_grid = _step_to_range(2 / (self.height - 1), height_target)[ - :, None - ].expand(-1, self.width_target) - - # "clone" necessary, https://github.com/pytorch/pytorch/issues/34880 - self.register_buffer("width_grid", width_grid.clone()) - self.register_buffer("height_grid", height_grid.clone()) - - def forward(self, inputs, coordinates): - """Calculates centered crop around given x,y coordinates. - Args: - inputs [B x H x W] - coordinates [B x 2] x,y coordinates - Returns: - [B x H' x W'] inputs cropped and centered around x,y coordinates. - """ - assert inputs.shape[1] == self.height - assert inputs.shape[2] == self.width - - inputs = inputs[:, None, :, :].float() - - x = coordinates[:, 0] - y = coordinates[:, 1] - - x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2) - y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2) - - grid = torch.stack( - [ - self.width_grid[None, :, :] + x_shift[:, None, None], - self.height_grid[None, :, :] + y_shift[:, None, None], - ], - dim=3, - ) - - # TODO: only cast to int if original tensor was int - return ( - torch.round(F.grid_sample(inputs, grid, align_corners=True)) - .squeeze(1) - .long() - ) - - -class NetHackNet(nn.Module): - def __init__( - self, - observation_shape, - num_actions, - use_lstm, - embedding_dim=32, - crop_dim=9, - num_layers=5, - ): - super(NetHackNet, self).__init__() - - self.glyph_shape = observation_shape["glyphs"].shape - self.blstats_size = observation_shape["blstats"].shape[0] - - self.num_actions = num_actions - self.use_lstm = use_lstm - - self.H = self.glyph_shape[0] - self.W = self.glyph_shape[1] - - self.k_dim = embedding_dim - self.h_dim = 512 - - self.crop_dim = crop_dim - - self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim) - - self.embed = nn.Embedding(nethack.MAX_GLYPH, self.k_dim) - - K = embedding_dim # number of input filters - F = 3 # filter dimensions - S = 1 # stride - P = 1 # padding - M = 16 # number of intermediate filters - Y = 8 # number of output filters - L = num_layers # number of convnet layers - - in_channels = [K] + [M] * (L - 1) - out_channels = [M] * (L - 1) + [Y] - - def interleave(xs, ys): - return [val for pair in zip(xs, ys) for val in pair] - - conv_extract = [ - nn.Conv2d( - in_channels=in_channels[i], - out_channels=out_channels[i], - kernel_size=(F, F), - stride=S, - padding=P, - ) - for i in range(L) - ] - - self.extract_representation = nn.Sequential( - *interleave(conv_extract, [nn.ELU()] * len(conv_extract)) - ) - - # CNN crop model. - conv_extract_crop = [ - nn.Conv2d( - in_channels=in_channels[i], - out_channels=out_channels[i], - kernel_size=(F, F), - stride=S, - padding=P, - ) - for i in range(L) - ] - - self.extract_crop_representation = nn.Sequential( - *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract)) - ) - - out_dim = self.k_dim - # CNN over full glyph map - out_dim += self.H * self.W * Y - - # CNN crop model. - out_dim += self.crop_dim ** 2 * Y - - self.embed_blstats = nn.Sequential( - nn.Linear(self.blstats_size, self.k_dim), - nn.ReLU(), - nn.Linear(self.k_dim, self.k_dim), - nn.ReLU(), - ) - - self.fc = nn.Sequential( - nn.Linear(out_dim, self.h_dim), - nn.ReLU(), - nn.Linear(self.h_dim, self.h_dim), - nn.ReLU(), - ) - - if self.use_lstm: - self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1) - - self.policy = nn.Linear(self.h_dim, self.num_actions) - self.baseline = nn.Linear(self.h_dim, 1) - - def initial_state(self, batch_size=1): - if not self.use_lstm: - return tuple() - return tuple( - torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) - for _ in range(2) - ) - - def _select(self, embed, x): - # Work around slow backward pass of nn.Embedding, see - # https://github.com/pytorch/pytorch/issues/24912 - out = embed.weight.index_select(0, x.reshape(-1)) - return out.reshape(x.shape + (-1,)) - - def forward(self, env_outputs, core_state): - # -- [T x B x H x W] - glyphs = env_outputs["glyphs"] - - # -- [T x B x F] - blstats = env_outputs["blstats"] - - T, B, *_ = glyphs.shape - - # -- [B' x H x W] - glyphs = torch.flatten(glyphs, 0, 1) # Merge time and batch. - - # -- [B' x F] - blstats = blstats.view(T * B, -1).float() - - # -- [B x H x W] - glyphs = glyphs.long() - # -- [B x 2] x,y coordinates - coordinates = blstats[:, :2] - # TODO ??? - # coordinates[:, 0].add_(-1) - - # -- [B x F] - # FIXME: hack to use compatible blstats to before - # blstats = blstats[:, [0, 1, 21, 10, 11]] - - blstats = blstats.view(T * B, -1).float() - # -- [B x K] - blstats_emb = self.embed_blstats(blstats) - - assert blstats_emb.shape[0] == T * B - - reps = [blstats_emb] - - # -- [B x H' x W'] - crop = self.crop(glyphs, coordinates) - - # print("crop", crop) - # print("at_xy", glyphs[:, coordinates[:, 1].long(), coordinates[:, 0].long()]) - - # -- [B x H' x W' x K] - crop_emb = self._select(self.embed, crop) - - # CNN crop model. - # -- [B x K x W' x H'] - crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? - # -- [B x W' x H' x K] - crop_rep = self.extract_crop_representation(crop_emb) - - # -- [B x K'] - crop_rep = crop_rep.view(T * B, -1) - assert crop_rep.shape[0] == T * B - - reps.append(crop_rep) - - # -- [B x H x W x K] - glyphs_emb = self._select(self.embed, glyphs) - # glyphs_emb = self.embed(glyphs) - # -- [B x K x W x H] - glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? - # -- [B x W x H x K] - glyphs_rep = self.extract_representation(glyphs_emb) - - # -- [B x K'] - glyphs_rep = glyphs_rep.view(T * B, -1) - - assert glyphs_rep.shape[0] == T * B - - # -- [B x K''] - reps.append(glyphs_rep) - - st = torch.cat(reps, dim=1) - - # -- [B x K] - st = self.fc(st) - - if self.use_lstm: - core_input = st.view(T, B, -1) - core_output_list = [] - notdone = (~env_outputs["done"]).float() - for input, nd in zip(core_input.unbind(), notdone.unbind()): - # Reset core state to zero whenever an episode ended. - # Make `done` broadcastable with (num_layers, B, hidden_size) - # states: - nd = nd.view(1, -1, 1) - core_state = tuple(nd * s for s in core_state) - output, core_state = self.core(input.unsqueeze(0), core_state) - core_output_list.append(output) - core_output = torch.flatten(torch.cat(core_output_list), 0, 1) - else: - core_output = st - - # -- [B x A] - policy_logits = self.policy(core_output) - # -- [B x A] - baseline = self.baseline(core_output) - - if self.training: - action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) - else: - # Don't sample when testing. - action = torch.argmax(policy_logits, dim=1) - - policy_logits = policy_logits.view(T, B, self.num_actions) - baseline = baseline.view(T, B) - action = action.view(T, B) - - return ( - dict(policy_logits=policy_logits, baseline=baseline, action=action), - core_state, - ) - - -Net = NetHackNet - - -def main(flags): - if flags.mode == "train": - train(flags) - else: - test(flags) - - -if __name__ == "__main__": - flags = parser.parse_args() - main(flags) diff --git a/nle/agent/config.yaml b/nle/agent/config.yaml new file mode 100644 index 000000000..2e456f328 --- /dev/null +++ b/nle/agent/config.yaml @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +defaults: +- hydra/job_logging: colorlog +- hydra/hydra_logging: colorlog +- hydra/launcher: submitit_slurm + +# pip install hydra-core hydra_colorlog +# can set these on the commandline too, e.g. `hydra.launcher.partition=dev` +hydra: + launcher: + timeout_min: 4300 + cpus_per_task: 20 + gpus_per_node: 2 + tasks_per_node: 1 + mem_gb: 20 + nodes: 1 + partition: dev + comment: null + max_num_timeout: 5 # will requeue on timeout or preemption + + +name: null # can use this to have multiple runs with same params, eg name=1,2,3,4,5 + +## WandB settings +wandb: false # enable wandb logging +project: neurips2020 # specifies the project name to log to +entity: nethack # the team to log to +group: default # defines a group name for the experiment + +# Polybeast settings +mock: false +single_ttyrec: true +num_seeds: 0 + +write_profiler_trace: false +relative_reward: false +fn_penalty_step: constant +penalty_time: 0.0 +penalty_step: -0.001 +reward_lose: 0 +reward_win: 1 +character: mon-hum-neu-mal +## typical characters we use +# mon-hum-neu-mal +# val-dwa-law-fem +# wiz-elf-cha-mal +# tou-hum-neu-fem + +# Run settings. +mode: train +env: score +## env (task) names +# staircase +# pet +# eat +# gold +# score +# scout +# oracle + +# Training settings. +num_actors: 256 # should be at least batch_size +total_steps: 1e9 # 1e9 used in paper +batch_size: 32 # 32 is standard, can use 128 with small model variants +unroll_length: 80 # 80 is standard +num_learner_threads: 1 +num_inference_threads: 1 +disable_cuda: false +learner_device: cuda:1 +actor_device: cuda:0 + +# Model settings. +model: baseline # random, baseline, rnd, ride +use_lstm: true +hidden_dim: 256 # use at least 128, 256 is stronger +embedding_dim: 64 # use at least 32, 64 is stronger +glyph_type: all_cat # full, group_id, color_char, all, all_cat* (all_cat best, full fastest) +equalize_input_dim: false # project inputs to same dim (*false unless doing dynamics) +equalize_factor: 2 # multiplies hdim by this when equalize is enabled (2 > 1) +layers: 5 # number of cnn layers for crop/glyph model +crop_model: cnn +crop_dim: 9 # size of crop +use_index_select: true # use index select instead of normal embedding lookup + + +# Loss settings. +entropy_cost: 0.001 # 0.001 is better than 0.0001 +baseline_cost: 0.5 +discounting: 0.999 # probably a bit better at 0.999, esp with intrinsic reward +reward_clipping: none # use none with normalize_reward, else use tim +normalize_reward: true # true is reliable across tasks, but false & tim-clip is best on score + +# Optimizer settings. +learning_rate: 0.0002 +grad_norm_clipping: 40 +# rmsprop settings +alpha: 0.99 # 0.99 vs 0.9 vs 0.5 seems to make no difference +momentum: 0 # keep at 0 +epsilon: 0.000001 # do not use 0.01, 1e-6 seems same as 1e-8 + +# Experimental settings. +state_counter: none # none, coordinates +no_extrinsic: false # ignore extrinsic reward + +int: # intrinsic reward options + twoheaded: true # separate value heads for extrinsic & intrinsic, use True + input: full # what to model? full, crop_only, glyph_only (for RND, RIDE) + intrinsic_weight: 0.1 # this needs to be tuned per-model, each have different scale + discounting: 0.99 + baseline_cost: 0.5 + episodic: true + reward_clipping: none # none is best with normalize enabled + normalize_reward: true # whether to use reward normalization for intrinsic reward + +ride: # Rewarding Impact-Driven Exploration + count_norm: true # normalise reward by the number of visits to a state + forward_cost: 1 + inverse_cost: 0.1 + hidden_dim: 128 + +rnd: # Random Network Distillation + forward_cost: 0.01 # weight on modelling loss (ie convergence of predictor) + +msg: + model: lt_cnn # character model? none, lt_cnn*, cnn, gru, lstm + hidden_dim: 64 # recommend 256 + embedding_dim: 32 # recommend 64 \ No newline at end of file diff --git a/nle/agent/core/file_writer.py b/nle/agent/core/file_writer.py new file mode 100644 index 000000000..99b06fe4f --- /dev/null +++ b/nle/agent/core/file_writer.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import csv +import datetime +import json +import logging +import os +import time +import weakref + + +def _save_metadata(path, metadata): + metadata["date_save"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + with open(path, "w") as f: + json.dump(metadata, f, indent=4, sort_keys=True) + + +def gather_metadata(): + metadata = dict( + date_start=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + env=os.environ.copy(), + successful=False, + ) + + # Git metadata. + try: + import git + except ImportError: + logging.warning( + "Couldn't import gitpython module; install it with `pip install gitpython`." + ) + else: + try: + repo = git.Repo(search_parent_directories=True) + metadata["git"] = { + "commit": repo.commit().hexsha, + "is_dirty": repo.is_dirty(), + "path": repo.git_dir, + } + if not repo.head.is_detached: + metadata["git"]["branch"] = repo.active_branch.name + except git.InvalidGitRepositoryError: + pass + + if "git" not in metadata: + logging.warning("Couldn't determine git data.") + + # Slurm metadata. + if "SLURM_JOB_ID" in os.environ: + slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")] + metadata["slurm"] = {} + for k in slurm_env_keys: + d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower() + metadata["slurm"][d_key] = os.environ[k] + + return metadata + + +class FileWriter: + def __init__(self, xp_args=None, rootdir="~/palaas"): + if rootdir == "~/palaas": + # make unique id in case someone uses the default rootdir + xpid = "{proc}_{unixtime}".format( + proc=os.getpid(), unixtime=int(time.time()) + ) + rootdir = os.path.join(rootdir, xpid) + self.basepath = os.path.expandvars(os.path.expanduser(rootdir)) + + self._tick = 0 + + # metadata gathering + if xp_args is None: + xp_args = {} + self.metadata = gather_metadata() + # we need to copy the args, otherwise when we close the file writer + # (and rewrite the args) we might have non-serializable objects (or + # other nasty stuff). + self.metadata["args"] = copy.deepcopy(xp_args) + + formatter = logging.Formatter("%(message)s") + self._logger = logging.getLogger("palaas/out") + + # to stdout handler + shandle = logging.StreamHandler() + shandle.setFormatter(formatter) + self._logger.addHandler(shandle) + self._logger.setLevel(logging.INFO) + + # to file handler + if not os.path.exists(self.basepath): + self._logger.info("Creating log directory: %s", self.basepath) + os.makedirs(self.basepath, exist_ok=True) + else: + self._logger.info("Found log directory: %s", self.basepath) + + self.paths = dict( + msg="{base}/out.log".format(base=self.basepath), + logs="{base}/logs.csv".format(base=self.basepath), + fields="{base}/fields.csv".format(base=self.basepath), + meta="{base}/meta.json".format(base=self.basepath), + ) + + self._logger.info("Saving arguments to %s", self.paths["meta"]) + if os.path.exists(self.paths["meta"]): + self._logger.warning( + "Path to meta file already exists. " "Not overriding meta." + ) + else: + self.save_metadata() + + self._logger.info("Saving messages to %s", self.paths["msg"]) + if os.path.exists(self.paths["msg"]): + self._logger.warning( + "Path to message file already exists. " "New data will be appended." + ) + + fhandle = logging.FileHandler(self.paths["msg"]) + fhandle.setFormatter(formatter) + self._logger.addHandler(fhandle) + + self._logger.info("Saving logs data to %s", self.paths["logs"]) + self._logger.info("Saving logs' fields to %s", self.paths["fields"]) + self.fieldnames = ["_tick", "_time"] + if os.path.exists(self.paths["logs"]): + self._logger.warning( + "Path to log file already exists. " "New data will be appended." + ) + # Override default fieldnames. + with open(self.paths["fields"], "r") as csvfile: + reader = csv.reader(csvfile) + lines = list(reader) + if len(lines) > 0: + self.fieldnames = lines[-1] + # Override default tick: use the last tick from the logs file plus 1. + with open(self.paths["logs"], "r") as csvfile: + reader = csv.reader(csvfile) + lines = list(reader) + # Need at least two lines in order to read the last tick: + # the first is the csv header and the second is the first line + # of data. + if len(lines) > 1: + self._tick = int(lines[-1][0]) + 1 + + self._fieldfile = open(self.paths["fields"], "a") + self._fieldwriter = csv.writer(self._fieldfile) + self._fieldfile.flush() + self._logfile = open(self.paths["logs"], "a") + self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames) + + # Auto-close (and save) on destruction. + weakref.finalize(self, _save_metadata, self.paths["meta"], self.metadata) + + def log(self, to_log, tick=None, verbose=False): + if tick is not None: + raise NotImplementedError + else: + to_log["_tick"] = self._tick + self._tick += 1 + to_log["_time"] = time.time() + + old_len = len(self.fieldnames) + for k in to_log: + if k not in self.fieldnames: + self.fieldnames.append(k) + if old_len != len(self.fieldnames): + self._fieldwriter.writerow(self.fieldnames) + self._fieldfile.flush() + self._logger.info("Updated log fields: %s", self.fieldnames) + + if to_log["_tick"] == 0: + self._logfile.write("# %s\n" % ",".join(self.fieldnames)) + + if verbose: + self._logger.info( + "LOG | %s", + ", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]), + ) + + self._logwriter.writerow(to_log) + self._logfile.flush() + + def close(self, successful=True): + self.metadata["successful"] = successful + self.save_metadata() + + for f in [self._logfile, self._fieldfile]: + f.close() + + def save_metadata(self): + _save_metadata(self.paths["meta"], self.metadata) diff --git a/nle/agent/core/prof.py b/nle/agent/core/prof.py new file mode 100644 index 000000000..3a031489a --- /dev/null +++ b/nle/agent/core/prof.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Naive profiling using timeit.""" + +import collections +import timeit + + +class Timings: + """Not thread-safe.""" + + def __init__(self): + self._means = collections.defaultdict(int) + self._vars = collections.defaultdict(int) + self._counts = collections.defaultdict(int) + self.reset() + + def reset(self): + self.last_time = timeit.default_timer() + + def time(self, name): + """Save an update for event `name`. + + Nerd alarm: We could just store a + collections.defaultdict(list) + and compute means and standard deviations at the end. But thanks to the + clever math in Sutton-Barto + (http://www.incompleteideas.net/book/first/ebook/node19.html) and + https://math.stackexchange.com/a/103025/5051 we can update both the + means and the stds online. O(1) FTW! + """ + now = timeit.default_timer() + x = now - self.last_time + self.last_time = now + + n = self._counts[name] + + mean = self._means[name] + (x - self._means[name]) / (n + 1) + var = ( + n * self._vars[name] + n * (self._means[name] - mean) ** 2 + (x - mean) ** 2 + ) / (n + 1) + + self._means[name] = mean + self._vars[name] = var + self._counts[name] += 1 + + def means(self): + return self._means + + def vars(self): + return self._vars + + def stds(self): + return {k: v ** 0.5 for k, v in self._vars.items()} + + def summary(self, prefix=""): + means = self.means() + stds = self.stds() + total = sum(means.values()) + + result = prefix + for k in sorted(means, key=means.get, reverse=True): + result += "\n %s: %.6fms +- %.6fms (%.2f%%) " % ( + k, + 1000 * means[k], + 1000 * stds[k], + 100 * means[k] / total, + ) + result += "\nTotal: %.6fms" % (1000 * total) + return result diff --git a/nle/agent/vtrace.py b/nle/agent/core/vtrace.py similarity index 93% rename from nle/agent/vtrace.py rename to nle/agent/core/vtrace.py index 34935286a..8d851a044 100644 --- a/nle/agent/vtrace.py +++ b/nle/agent/core/vtrace.py @@ -49,8 +49,8 @@ def action_log_probs(policy_logits, actions): return -F.nll_loss( - F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1), - torch.flatten(actions), + F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1), + torch.flatten(actions, 0, 1), reduction="none", ).view_as(actions) @@ -125,10 +125,7 @@ def from_importance_weights( vs = torch.add(vs_minus_v_xs, values) # Advantage for policy gradient. - broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value - vs_t_plus_1 = torch.cat( - [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0 - ) + vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0) if clip_pg_rho_threshold is not None: clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) else: diff --git a/nle/agent/core/vtrace_test.py b/nle/agent/core/vtrace_test.py new file mode 100644 index 000000000..97e815828 --- /dev/null +++ b/nle/agent/core/vtrace_test.py @@ -0,0 +1,263 @@ +# This file taken from +# https://github.com/deepmind/scalable_agent/blob/ +# d24bd74bd53d454b7222b7f0bea57a358e4ca33e/vtrace_test.py +# and modified. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for V-trace. + +For details and theory see: + +"IMPALA: Scalable Distributed Deep-RL with +Importance Weighted Actor-Learner Architectures" +by Espeholt, Soyer, Munos et al. +""" + +import unittest + +import numpy as np +import torch +import vtrace + + +def _shaped_arange(*shape): + """Runs np.arange, converts to float and reshapes.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + + +def _softmax(logits): + """Applies softmax non-linearity on inputs.""" + return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) + + +def _ground_truth_calculation( + discounts, + log_rhos, + rewards, + values, + bootstrap_value, + clip_rho_threshold, + clip_pg_rho_threshold, +): + """Calculates the ground truth for V-trace in Python/Numpy.""" + vs = [] + seq_len = len(discounts) + rhos = np.exp(log_rhos) + cs = np.minimum(rhos, 1.0) + clipped_rhos = rhos + if clip_rho_threshold: + clipped_rhos = np.minimum(rhos, clip_rho_threshold) + clipped_pg_rhos = rhos + if clip_pg_rho_threshold: + clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) + + # This is a very inefficient way to calculate the V-trace ground truth. + # We calculate it this way because it is close to the mathematical notation + # of V-trace. + # v_s = V(x_s) + # + \sum^{T-1}_{t=s} \gamma^{t-s} + # * \prod_{i=s}^{t-1} c_i + # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) + # Note that when we take the product over c_i, we write `s:t` as the + # notation of the paper is inclusive of the `t-1`, but Python is exclusive. + # Also note that np.prod([]) == 1. + values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) + for s in range(seq_len): + v_s = np.copy(values[s]) # Very important copy. + for t in range(s, seq_len): + v_s += ( + np.prod(discounts[s:t], axis=0) + * np.prod(cs[s:t], axis=0) + * clipped_rhos[t] + * (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t]) + ) + vs.append(v_s) + vs = np.stack(vs, axis=0) + pg_advantages = clipped_pg_rhos * ( + rewards + + discounts * np.concatenate([vs[1:], bootstrap_value[None, :]], axis=0) + - values + ) + + return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages) + + +def assert_allclose(actual, desired): + return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05) + + +class ActionLogProbsTest(unittest.TestCase): + def test_action_log_probs(self, batch_size=2): + seq_len = 7 + num_actions = 3 + + policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 + actions = np.random.randint( + 0, num_actions, size=(seq_len, batch_size), dtype=np.int64 + ) + + action_log_probs_tensor = vtrace.action_log_probs( + torch.from_numpy(policy_logits), torch.from_numpy(actions) + ) + + # Ground Truth + # Using broadcasting to create a mask that indexes action logits + action_index_mask = actions[..., None] == np.arange(num_actions) + + def index_with_mask(array, mask): + return array[mask].reshape(*array.shape[:-1]) + + # Note: Normally log(softmax) is not a good idea because it's not + # numerically stable. However, in this test we have well-behaved values. + ground_truth_v = index_with_mask( + np.log(_softmax(policy_logits)), action_index_mask + ) + + assert_allclose(ground_truth_v, action_log_probs_tensor) + + def test_action_log_probs_batch_1(self): + self.test_action_log_probs(1) + + +class VtraceTest(unittest.TestCase): + def test_vtrace(self, batch_size=5): + """Tests V-trace against ground truth data calculated in python.""" + seq_len = 5 + + # Create log_rhos such that rho will span from near-zero to above the + # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), + # so that rho is in approx [0.08, 12.2). + log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) + log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). + values = { + "log_rhos": log_rhos, + # T, B where B_i: [0.9 / (i+1)] * T + "discounts": np.array( + [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)], + dtype=np.float32, + ), + "rewards": _shaped_arange(seq_len, batch_size), + "values": _shaped_arange(seq_len, batch_size) / batch_size, + "bootstrap_value": _shaped_arange(batch_size) + 1.0, + "clip_rho_threshold": 3.7, + "clip_pg_rho_threshold": 2.2, + } + + ground_truth = _ground_truth_calculation(**values) + + values = {key: torch.tensor(value) for key, value in values.items()} + output = vtrace.from_importance_weights(**values) + + for a, b in zip(ground_truth, output): + assert_allclose(a, b) + + def test_vtrace_batch_1(self): + self.test_vtrace(1) + + def test_vtrace_from_logits(self, batch_size=2): + """Tests V-trace calculated from logits.""" + seq_len = 5 + num_actions = 3 + clip_rho_threshold = None # No clipping. + clip_pg_rho_threshold = None # No clipping. + + values = { + "behavior_policy_logits": _shaped_arange(seq_len, batch_size, num_actions), + "target_policy_logits": _shaped_arange(seq_len, batch_size, num_actions), + "actions": np.random.randint( + 0, num_actions - 1, size=(seq_len, batch_size) + ), + "discounts": np.array( # T, B where B_i: [0.9 / (i+1)] * T + [[0.9 / (b + 1) for b in range(batch_size)] for _ in range(seq_len)], + dtype=np.float32, + ), + "rewards": _shaped_arange(seq_len, batch_size), + "values": _shaped_arange(seq_len, batch_size) / batch_size, + "bootstrap_value": _shaped_arange(batch_size) + 1.0, # B + } + values = {k: torch.from_numpy(v) for k, v in values.items()} + + from_logits_output = vtrace.from_logits( + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold, + **values, + ) + + target_log_probs = vtrace.action_log_probs( + values["target_policy_logits"], values["actions"] + ) + behavior_log_probs = vtrace.action_log_probs( + values["behavior_policy_logits"], values["actions"] + ) + log_rhos = target_log_probs - behavior_log_probs + + # Calculate V-trace using the ground truth logits. + from_iw = vtrace.from_importance_weights( + log_rhos=log_rhos, + discounts=values["discounts"], + rewards=values["rewards"], + values=values["values"], + bootstrap_value=values["bootstrap_value"], + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold, + ) + + assert_allclose(from_iw.vs, from_logits_output.vs) + assert_allclose(from_iw.pg_advantages, from_logits_output.pg_advantages) + assert_allclose( + behavior_log_probs, from_logits_output.behavior_action_log_probs + ) + assert_allclose(target_log_probs, from_logits_output.target_action_log_probs) + assert_allclose(log_rhos, from_logits_output.log_rhos) + + def test_vtrace_from_logits_batch_1(self): + self.test_vtrace_from_logits(1) + + def test_higher_rank_inputs_for_importance_weights(self): + """Checks support for additional dimensions in inputs.""" + T = 3 # pylint: disable=invalid-name + B = 2 # pylint: disable=invalid-name + values = { + "log_rhos": torch.zeros(T, B, 1), + "discounts": torch.zeros(T, B, 1), + "rewards": torch.zeros(T, B, 42), + "values": torch.zeros(T, B, 42), + "bootstrap_value": torch.zeros(B, 42), + } + output = vtrace.from_importance_weights(**values) + self.assertSequenceEqual(output.vs.shape, (T, B, 42)) + + def test_inconsistent_rank_inputs_for_importance_weights(self): + """Test one of many possible errors in shape of inputs.""" + T = 3 # pylint: disable=invalid-name + B = 2 # pylint: disable=invalid-name + + values = { + "log_rhos": torch.zeros(T, B, 1), + "discounts": torch.zeros(T, B, 1), + "rewards": torch.zeros(T, B, 42), + "values": torch.zeros(T, B, 42), + # Should be [B, 42]. + "bootstrap_value": torch.zeros(B), + } + + with self.assertRaisesRegex( + RuntimeError, "same number of dimensions: got 3 and 2" + ): + vtrace.from_importance_weights(**values) + + +if __name__ == "__main__": + unittest.main() diff --git a/nle/agent/envs/README.md b/nle/agent/envs/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/nle/agent/envs/__init__.py b/nle/agent/envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/nle/agent/envs/tasks.py b/nle/agent/envs/tasks.py new file mode 100644 index 000000000..137ba63c4 --- /dev/null +++ b/nle/agent/envs/tasks.py @@ -0,0 +1,114 @@ +from collections import defaultdict +from nle.env import tasks +import numpy as np + + +class SharedPatch(object): + def __init__(self, *args, state_counter="none", **kwargs): + # intialize state counter + self.state_counter = state_counter + if self.state_counter != "none": + self.state_count_dict = defaultdict(int) + # this super() goes to the parent of the particular task, not to `object` + super().__init__(*args, **kwargs) + + def step(self, action): + # add state counting to step function if desired + step_return = super().step(action) + if self.state_counter == "none": + # do nothing + return step_return + + obs, reward, done, info = step_return + + if self.state_counter == "ones": + # treat every state as unique + state_visits = 1 + elif self.state_counter == "coordinates": + # use the location of the agent within the dungeon to accumulate visits + features = obs["blstats"] + x = features[0] + y = features[1] + # TODO: prefer to use dungeon level and dungeon number from Blstats + d = features[12] + coord = (d, x, y) + self.state_count_dict[coord] += 1 + state_visits = self.state_count_dict[coord] + else: + raise NotImplementedError("state_counter=%s" % self.state_counter) + + obs.update(state_visits=np.array([state_visits])) + + if done: + self.state_count_dict.clear() + + return step_return + + def reset(self, wizkit_items=None): + # reset state counter when env resets + obs = super().reset(wizkit_items=wizkit_items) + if self.state_counter != "none": + self.state_count_dict.clear() + # current state counts as one visit + obs.update(state_visits=np.array([1])) + return obs + + +class PatchedNetHackScore(SharedPatch, tasks.NetHackScore): + pass + + +class PatchedNetHackStaircase(SharedPatch, tasks.NetHackStaircase): + def __init__(self, *args, reward_win=1, reward_lose=-1, **kwargs): + super().__init__(*args, **kwargs) + self.reward_win = reward_win + self.reward_lose = reward_lose + + def _reward_fn(self, last_response, response, end_status): + if end_status == self.StepStatus.TASK_SUCCESSFUL: + reward = self.reward_win + elif end_status == self.StepStatus.RUNNING: + reward = 0 + else: # death or aborted + reward = self.reward_lose + return reward + self._get_time_penalty(last_response, response) + + +class PatchedNetHackStaircasePet(PatchedNetHackStaircase, tasks.NetHackStaircasePet): + pass # inherit from PatchedNetHackStaircase + + +class PatchedNetHackStaircaseOracle(PatchedNetHackStaircase, tasks.NetHackOracle): + pass # inherit from PatchedNetHackStaircase + + +class PatchedNetHackGold(SharedPatch, tasks.NetHackGold): + pass + + +class PatchedNetHackEat(SharedPatch, tasks.NetHackEat): + pass + + +class PatchedNetHackScout(SharedPatch, tasks.NetHackScout): + pass + + +NetHackScore = PatchedNetHackScore +NetHackStaircase = PatchedNetHackStaircase +NetHackStaircasePet = PatchedNetHackStaircasePet +NetHackOracle = PatchedNetHackStaircaseOracle +NetHackGold = PatchedNetHackGold +NetHackEat = PatchedNetHackEat +NetHackScout = PatchedNetHackScout + + +ENVS = dict( + staircase=NetHackStaircase, + score=NetHackScore, + pet=NetHackStaircasePet, + oracle=NetHackOracle, + gold=NetHackGold, + eat=NetHackEat, + scout=NetHackScout, +) diff --git a/nle/agent/models/__init__.py b/nle/agent/models/__init__.py new file mode 100644 index 000000000..f61c25b4a --- /dev/null +++ b/nle/agent/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nle.agent.envs import tasks +from nle.env.base import DUNGEON_SHAPE +from nle.agent.models.base import BaseNet, RandomNet +from nle.agent.models.intrinsic import RNDNet, RIDENet + + +def create_model(flags, device): + model_string = flags.model + if model_string == "random": + model_cls = RandomNet + elif model_string == "baseline": + model_cls = BaseNet + elif model_string == "rnd": + model_cls = RNDNet + elif model_string == "ride": + model_cls = RIDENet + elif model_string == "cnn" or model_string == "transformer": + raise RuntimeError( + "model=%s deprecated, use model=baseline crop_model=%s instead" + % (model_string, model_string) + ) + else: + raise NotImplementedError("model=%s" % model_string) + + num_actions = len(tasks.ENVS[flags.env](savedir=None, archivefile=None)._actions) + + model = model_cls(DUNGEON_SHAPE, num_actions, flags, device) + model.to(device=device) + return model diff --git a/nle/agent/models/base.py b/nle/agent/models/base.py new file mode 100644 index 000000000..9a6569001 --- /dev/null +++ b/nle/agent/models/base.py @@ -0,0 +1,541 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import torch +from torch import nn +from torch.nn import functional as F + +from nle import nethack + +from nle.agent.models.embed import GlyphEmbedding +from nle.agent.models.transformer import TransformerEncoder + +NUM_GLYPHS = nethack.MAX_GLYPH +NUM_FEATURES = nethack.BLSTATS_SHAPE[0] +PAD_CHAR = 0 +NUM_CHARS = 128 + + +class NetHackNet(nn.Module): + AgentOutput = collections.namedtuple("AgentOutput", "action policy_logits baseline") + + def __init__(self): + super(NetHackNet, self).__init__() + + self.register_buffer("reward_sum", torch.zeros(())) + self.register_buffer("reward_m2", torch.zeros(())) + self.register_buffer("reward_count", torch.zeros(()).fill_(1e-8)) + + def forward(self, inputs, core_state): + raise NotImplementedError + + def initial_state(self, batch_size=1): + return () + + def prepare_input(self, inputs): + # -- [T x B x H x W] + glyphs = inputs["glyphs"] + + # -- [T x B x F] + features = inputs["blstats"] + + T, B, *_ = glyphs.shape + + # -- [B' x H x W] + glyphs = torch.flatten(glyphs, 0, 1) # Merge time and batch. + + # -- [B' x F] + features = features.view(T * B, -1).float() + + return glyphs, features + + def embed_state(self, inputs): + raise NotImplementedError + + @torch.no_grad() + def update_running_moments(self, reward_batch): + """Maintains a running mean of reward.""" + new_count = len(reward_batch) + new_sum = torch.sum(reward_batch) + new_mean = new_sum / new_count + + curr_mean = self.reward_sum / self.reward_count + new_m2 = torch.sum((reward_batch - new_mean) ** 2) + ( + (self.reward_count * new_count) + / (self.reward_count + new_count) + * (new_mean - curr_mean) ** 2 + ) + + self.reward_count += new_count + self.reward_sum += new_sum + self.reward_m2 += new_m2 + + @torch.no_grad() + def get_running_std(self): + """Returns standard deviation of the running mean of the reward.""" + return torch.sqrt(self.reward_m2 / self.reward_count) + + +class RandomNet(NetHackNet): + def __init__(self, observation_shape, num_actions, flags, device=None): + super(RandomNet, self).__init__() + self.num_actions = num_actions + self.theta = torch.nn.Parameter(torch.zeros(self.num_actions)) + + def forward(self, inputs, core_state): + T, B, *_ = inputs["glyphs"].shape + zeros = self.theta * 0 + # set logits to 0 + policy_logits = zeros[None, :].expand(T * B, -1) + # set baseline to 0 + baseline = policy_logits.sum(dim=1).view(-1, B) + + # sample random action + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1).view( + T, B + ) + policy_logits = policy_logits.view(T, B, self.num_actions) + return ( + dict(policy_logits=policy_logits, baseline=baseline, action=action), + core_state, + ) + + def embed_state(self, inputs): + raise NotImplementedError + + +class Crop(nn.Module): + def __init__(self, height, width, height_target, width_target, device=None): + super(Crop, self).__init__() + self.width = width + self.height = height + self.width_target = width_target + self.height_target = height_target + self.width_grid = self._step_to_range(2 / (self.width - 1), self.width_target)[ + None, : + ].expand(self.height_target, -1) + self.height_grid = self._step_to_range(2 / (self.height - 1), height_target)[ + :, None + ].expand(-1, self.width_target) + + if device is not None: + self.width_grid = self.width_grid.to(device) + self.height_grid = self.height_grid.to(device) + + def _step_to_range(self, step, num_steps): + return torch.tensor([step * (i - num_steps // 2) for i in range(num_steps)]) + + def forward(self, inputs, coordinates): + """Calculates centered crop around given x,y coordinates. + + Args: + inputs [B x H x W] or [B x H x W x C] + coordinates [B x 2] x,y coordinates + + Returns: + [B x H' x W'] inputs cropped and centered around x,y coordinates. + """ + assert inputs.shape[1] == self.height, "expected %d but found %d" % ( + self.height, + inputs.shape[1], + ) + assert inputs.shape[2] == self.width, "expected %d but found %d" % ( + self.width, + inputs.shape[2], + ) + + permute_results = False + if inputs.dim() == 3: + inputs = inputs.unsqueeze(1) + else: + permute_results = True + inputs = inputs.permute(0, 2, 3, 1) + inputs = inputs.float() + + x = coordinates[:, 0] + y = coordinates[:, 1] + + x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2) + y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2) + + grid = torch.stack( + [ + self.width_grid[None, :, :] + x_shift[:, None, None], + self.height_grid[None, :, :] + y_shift[:, None, None], + ], + dim=3, + ) + + crop = ( + torch.round(F.grid_sample(inputs, grid, align_corners=True)) + .squeeze(1) + .long() + ) + + if permute_results: + # [B x C x H x W] -> [B x H x W x C] + crop = crop.permute(0, 2, 3, 1) + + return crop + + +class Flatten(nn.Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +class BaseNet(NetHackNet): + def __init__(self, observation_shape, num_actions, flags, device): + super(BaseNet, self).__init__() + + self.flags = flags + + self.observation_shape = observation_shape + + self.H = observation_shape[0] + self.W = observation_shape[1] + + self.num_actions = num_actions + self.use_lstm = flags.use_lstm + + self.k_dim = flags.embedding_dim + self.h_dim = flags.hidden_dim + + self.crop_model = flags.crop_model + self.crop_dim = flags.crop_dim + + self.num_features = NUM_FEATURES + + self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim, device) + + self.glyph_type = flags.glyph_type + self.glyph_embedding = GlyphEmbedding( + flags.glyph_type, flags.embedding_dim, device, flags.use_index_select + ) + + K = flags.embedding_dim # number of input filters + F = 3 # filter dimensions + S = 1 # stride + P = 1 # padding + M = 16 # number of intermediate filters + self.Y = 8 # number of output filters + L = flags.layers # number of convnet layers + + in_channels = [K] + [M] * (L - 1) + out_channels = [M] * (L - 1) + [self.Y] + + def interleave(xs, ys): + return [val for pair in zip(xs, ys) for val in pair] + + conv_extract = [ + nn.Conv2d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=(F, F), + stride=S, + padding=P, + ) + for i in range(L) + ] + + self.extract_representation = nn.Sequential( + *interleave(conv_extract, [nn.ELU()] * len(conv_extract)) + ) + + if self.crop_model == "transformer": + self.extract_crop_representation = TransformerEncoder( + K, + N=L, + heads=8, + height=self.crop_dim, + width=self.crop_dim, + device=device, + ) + elif self.crop_model == "cnn": + conv_extract_crop = [ + nn.Conv2d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=(F, F), + stride=S, + padding=P, + ) + for i in range(L) + ] + + self.extract_crop_representation = nn.Sequential( + *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract)) + ) + + # MESSAGING MODEL + if "msg" not in flags: + self.msg_model = "none" + else: + self.msg_model = flags.msg.model + self.msg_hdim = flags.msg.hidden_dim + self.msg_edim = flags.msg.embedding_dim + if self.msg_model in ("gru", "lstm", "lt_cnn"): + # character-based embeddings + self.char_lt = nn.Embedding(NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR) + else: + # forward will set up one-hot inputs for the cnn, no lt needed + pass + + if self.msg_model.endswith("cnn"): + # from Zhang et al, 2016 + # Character-level Convolutional Networks for Text Classification + # https://arxiv.org/abs/1509.01626 + if self.msg_model == "cnn": + # inputs will be one-hot vectors, as done in paper + self.conv1 = nn.Conv1d(NUM_CHARS, self.msg_hdim, kernel_size=7) + elif self.msg_model == "lt_cnn": + # replace one-hot inputs with learned embeddings + self.conv1 = nn.Conv1d(self.msg_edim, self.msg_hdim, kernel_size=7) + else: + raise NotImplementedError("msg.model == %s", flags.msg.model) + + # remaining convolutions, relus, pools, and a small FC network + self.conv2_6_fc = nn.Sequential( + nn.ReLU(), + nn.MaxPool1d(kernel_size=3, stride=3), + # conv2 + nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=7), + nn.ReLU(), + nn.MaxPool1d(kernel_size=3, stride=3), + # conv3 + nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3), + nn.ReLU(), + # conv4 + nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3), + nn.ReLU(), + # conv5 + nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3), + nn.ReLU(), + # conv6 + nn.Conv1d(self.msg_hdim, self.msg_hdim, kernel_size=3), + nn.ReLU(), + nn.MaxPool1d(kernel_size=3, stride=3), + # fc receives -- [ B x h_dim x 5 ] + Flatten(), + nn.Linear(5 * self.msg_hdim, 2 * self.msg_hdim), + nn.ReLU(), + nn.Linear(2 * self.msg_hdim, self.msg_hdim), + ) # final output -- [ B x h_dim x 5 ] + elif self.msg_model in ("gru", "lstm"): + + def rnn(flag): + return nn.LSTM if flag == "lstm" else nn.GRU + + self.char_rnn = rnn(self.msg_model)( + self.msg_edim, self.msg_hdim // 2, batch_first=True, bidirectional=True + ) + elif self.msg_model != "none": + raise NotImplementedError("msg.model == %s", flags.msg.model) + + self.embed_features = nn.Sequential( + nn.Linear(self.num_features, self.k_dim), + nn.ReLU(), + nn.Linear(self.k_dim, self.k_dim), + nn.ReLU(), + ) + + self.equalize_input_dim = flags.equalize_input_dim + if not self.equalize_input_dim: + # just added up the output dimensions of the input featurizers + # feature / status dim + out_dim = self.k_dim + # CNN over full glyph map + out_dim += self.H * self.W * self.Y + if self.crop_model == "transformer": + out_dim += self.crop_dim ** 2 * K + elif self.crop_model == "cnn": + out_dim += self.crop_dim ** 2 * self.Y + # messaging model + if self.msg_model != "none": + out_dim += self.msg_hdim + else: + # otherwise, project them all to h_dim + NUM_INPUTS = 4 if self.msg_model != "none" else 3 + project_hdim = flags.equalize_factor * self.h_dim + out_dim = project_hdim * NUM_INPUTS + + # set up linear layers for projections + self.project_feature_dim = nn.Linear(self.k_dim, project_hdim) + self.project_glyph_dim = nn.Linear(self.H * self.W * self.Y, project_hdim) + c__2 = self.crop_dim ** 2 + if self.crop_model == "transformer": + self.project_crop_dim = nn.Linear(c__2 * K, project_hdim) + elif self.crop_model == "cnn": + self.project_crop_dim = nn.Linear(c__2 * self.Y, project_hdim) + if self.msg_model != "none": + self.project_msg_dim = nn.Linear(self.msg_hdim, project_hdim) + + self.fc = nn.Sequential( + nn.Linear(out_dim, self.h_dim), + nn.ReLU(), + nn.Linear(self.h_dim, self.h_dim), + nn.ReLU(), + ) + + if self.use_lstm: + self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1) + + self.policy = nn.Linear(self.h_dim, self.num_actions) + self.baseline = nn.Linear(self.h_dim, 1) + + def initial_state(self, batch_size=1): + if not self.use_lstm: + return tuple() + return tuple( + torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) + for _ in range(2) + ) + + def prepare_input(self, inputs): + # -- [T x B x H x W] + T, B, H, W = inputs["glyphs"].shape + + # take our chosen glyphs and merge the time and batch + + glyphs = self.glyph_embedding.prepare_input(inputs) + + # -- [T x B x F] + features = inputs["blstats"] + # -- [B' x F] + features = features.view(T * B, -1).float() + + return glyphs, features + + def forward(self, inputs, core_state, learning=False): + T, B, *_ = inputs["glyphs"].shape + + glyphs, features = self.prepare_input(inputs) + + # -- [B x 2] x,y coordinates + coordinates = features[:, :2] + + features = features.view(T * B, -1).float() + # -- [B x K] + features_emb = self.embed_features(features) + if self.equalize_input_dim: + features_emb = self.project_feature_dim(features_emb) + + assert features_emb.shape[0] == T * B + + reps = [features_emb] + + # -- [B x H' x W'] + crop = self.glyph_embedding.GlyphTuple( + *[self.crop(g, coordinates) for g in glyphs] + ) + # -- [B x H' x W' x K] + crop_emb = self.glyph_embedding(crop) + + if self.crop_model == "transformer": + # -- [B x W' x H' x K] + crop_rep = self.extract_crop_representation(crop_emb, mask=None) + elif self.crop_model == "cnn": + # -- [B x K x W' x H'] + crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? + # -- [B x W' x H' x K] + crop_rep = self.extract_crop_representation(crop_emb) + # -- [B x K'] + + crop_rep = crop_rep.view(T * B, -1) + if self.equalize_input_dim: + crop_rep = self.project_crop_dim(crop_rep) + assert crop_rep.shape[0] == T * B + + reps.append(crop_rep) + + # -- [B x H x W x K] + glyphs_emb = self.glyph_embedding(glyphs) + # glyphs_emb = self.embed(glyphs) + # -- [B x K x W x H] + glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? + # -- [B x W x H x K] + glyphs_rep = self.extract_representation(glyphs_emb) + + # -- [B x K'] + glyphs_rep = glyphs_rep.view(T * B, -1) + if self.equalize_input_dim: + glyphs_rep = self.project_glyph_dim(glyphs_rep) + + assert glyphs_rep.shape[0] == T * B + + # -- [B x K''] + reps.append(glyphs_rep) + + # MESSAGING MODEL + if self.msg_model != "none": + # [T x B x 256] -> [T * B x 256] + messages = inputs["message"].long().view(T * B, -1) + if self.msg_model == "cnn": + # convert messages to one-hot, [T * B x 96 x 256] + one_hot = F.one_hot(messages, num_classes=NUM_CHARS).transpose(1, 2) + char_rep = self.conv2_6_fc(self.conv1(one_hot.float())) + elif self.msg_model == "lt_cnn": + # [ T * B x E x 256 ] + char_emb = self.char_lt(messages).transpose(1, 2) + char_rep = self.conv2_6_fc(self.conv1(char_emb)) + else: # lstm, gru + char_emb = self.char_lt(messages) + output = self.char_rnn(char_emb)[0] + fwd_rep = output[:, -1, : self.h_dim // 2] + bwd_rep = output[:, 0, self.h_dim // 2 :] + char_rep = torch.cat([fwd_rep, bwd_rep], dim=1) + + if self.equalize_input_dim: + char_rep = self.project_msg_dim(char_rep) + reps.append(char_rep) + + st = torch.cat(reps, dim=1) + + # -- [B x K] + st = self.fc(st) + + if self.use_lstm: + core_input = st.view(T, B, -1) + core_output_list = [] + notdone = (~inputs["done"]).float() + for input, nd in zip(core_input.unbind(), notdone.unbind()): + # Reset core state to zero whenever an episode ended. + # Make `done` broadcastable with (num_layers, B, hidden_size) + # states: + nd = nd.view(1, -1, 1) + core_state = tuple(nd * t for t in core_state) + output, core_state = self.core(input.unsqueeze(0), core_state) + core_output_list.append(output) + core_output = torch.flatten(torch.cat(core_output_list), 0, 1) + else: + core_output = st + + # -- [B x A] + policy_logits = self.policy(core_output) + # -- [B x A] + baseline = self.baseline(core_output) + + if self.training: + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) + else: + # Don't sample when testing. + action = torch.argmax(policy_logits, dim=1) + + policy_logits = policy_logits.view(T, B, self.num_actions) + baseline = baseline.view(T, B) + action = action.view(T, B) + + output = dict(policy_logits=policy_logits, baseline=baseline, action=action) + return (output, core_state) diff --git a/nle/agent/models/dynamics.py b/nle/agent/models/dynamics.py new file mode 100644 index 000000000..a7bf5ba02 --- /dev/null +++ b/nle/agent/models/dynamics.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from torch.nn import functional as F + + +class ForwardDynamicsNet(nn.Module): + def __init__(self, num_actions, hidden_dim, input_dim, output_dim): + super(ForwardDynamicsNet, self).__init__() + self.num_actions = num_actions + + # TODO: add more layers + total_input_dim = input_dim + self.num_actions + self.forward_dynamics = nn.Sequential( + nn.Linear(total_input_dim, hidden_dim), + nn.ELU(), + nn.Linear(hidden_dim, output_dim), + ) + + def forward(self, state_embedding, action): + action_one_hot = F.one_hot(action, num_classes=self.num_actions).float() + inputs = torch.cat((state_embedding, action_one_hot), dim=-1) + next_state_emb = self.forward_dynamics(inputs) + return next_state_emb + + +class InverseDynamicsNet(nn.Module): + def __init__(self, num_actions, hidden_dim, input_dim1, input_dim2): + super(InverseDynamicsNet, self).__init__() + self.num_actions = num_actions + + # TODO: add more layers + total_input_dim = input_dim1 + input_dim2 # concat the inputs + self.inverse_dynamics = nn.Sequential( + nn.Linear(total_input_dim, hidden_dim), + nn.ELU(), + nn.Linear(hidden_dim, self.num_actions), + ) + + def forward(self, state_embedding, next_state_embedding): + inputs = torch.cat((state_embedding, next_state_embedding), dim=-1) + action_logits = self.inverse_dynamics(inputs) + return action_logits diff --git a/nle/agent/models/embed.py b/nle/agent/models/embed.py new file mode 100644 index 000000000..38a573a84 --- /dev/null +++ b/nle/agent/models/embed.py @@ -0,0 +1,178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch import nn +import torch +from nle import nethack as nh + +from typing import NamedTuple, Union +from collections import namedtuple +from nle.agent.util.id_pairs import id_pairs_table +import logging + +Ratio = Union[int, bool] + + +class Targets(NamedTuple): + """Class for configuring whch ids you want to embed into the single + GlyphEmbedding, and in what ratios. The ratio is only relevant if + do_linear_layer is false, and the embedding is pure concatenation. + """ + + glyphs: Ratio = 0 + groups: Ratio = 0 + subgroup_ids: Ratio = 0 + colors: Ratio = 0 + chars: Ratio = 0 + specials: Ratio = 0 + do_linear_layer: bool = True + + def count_matrices(self): + """Count of matrices required""" + return sum(self) - int(self.do_linear_layer) + + +GLYPH_TYPE_STRATEGIES = { + "full": Targets(glyphs=True), + "group_id": Targets(groups=True, subgroup_ids=True), + "color_char": Targets(colors=True, chars=True, specials=True), + "all": Targets( + groups=True, subgroup_ids=True, colors=True, chars=True, specials=True + ), + "all_cat": Targets( + groups=1, subgroup_ids=3, colors=1, chars=2, specials=1, do_linear_layer=False + ), +} + + +class GlyphEmbedding(nn.Module): + """Take the glyph information and return an embedding vector.""" + + def __init__(self, glyph_type, dimension, device=None, use_index_select=None): + super(GlyphEmbedding, self).__init__() + logging.debug("Emdedding on device: %s ", device) + self.glyph_type = glyph_type + self.use_index_select = use_index_select + self.device = device + self.dim = dimension + + if glyph_type not in GLYPH_TYPE_STRATEGIES: + raise RuntimeError("unexpected glyph_type=%s" % self.glyph_type) + strategy = GLYPH_TYPE_STRATEGIES[glyph_type] + self.strategy = strategy + + self._unit_dim = dimension // strategy.count_matrices() + self._remainder_dim = self.dim - self._unit_dim * strategy.count_matrices() + + self._id_pairs_table = None + if self.requires_id_pairs_table: + self._id_pairs_table = torch.from_numpy(id_pairs_table()) + + # Build our custom embedding matrices + embed = {} + if strategy.glyphs: + embed["glyphs"] = nn.Embedding(nh.MAX_GLYPH, self._dim(strategy.glyphs)) + if strategy.colors: + embed["colors"] = nn.Embedding(16, self._dim(strategy.colors)) + if strategy.chars: + embed["chars"] = nn.Embedding(256, self._dim(strategy.chars)) + if strategy.specials: + embed["specials"] = nn.Embedding(256, self._dim(strategy.specials)) + if strategy.groups: + num_groups = self.id_pairs_table.select(1, 1).max().item() + 1 + embed["groups"] = nn.Embedding(num_groups, self._dim(strategy.groups)) + if strategy.subgroup_ids: + num_subgroup_ids = self.id_pairs_table.select(1, 0).max().item() + 1 + embed["subgroup_ids"] = nn.Embedding( + num_subgroup_ids, self._dim(strategy.subgroup_ids) + ) + + if self.id_pairs_table is not None and device is not None: + self._id_pairs_table = self._id_pairs_table.to(device) + + self.embeddings = nn.ModuleDict(embed) + self.targets = list(embed.keys()) + self.GlyphTuple = namedtuple("GlyphTuple", self.targets) + + if strategy.do_linear_layer and strategy.count_matrices() > 1: + self.linear = nn.Linear(strategy.count_matrices() * self.dim, self.dim) + + self.to(device) + + def _dim(self, units): + """Decide the embedding size for a single matrix. If using a linear layer + at the end this is always the embedding dimension, otherwise it is a + fraction of the embedding dim""" + if self.strategy.do_linear_layer: + return self.dim + else: + dim = units * self._unit_dim + self._remainder_dim + self._remainder_dim = 0 + return dim + + @property + def requires_id_pairs_table(self): + return self.strategy.groups or self.strategy.subgroup_ids + + @property + def id_pairs_table(self): + return self._id_pairs_table + + def prepare_input(self, inputs): + """Take the inputs to the network as dictionary and return a namedtuple + of the input/index tensors to be embedded (GlyphTuple)""" + embeddable_data = {} + # Only flatten the data we want + for key, value in inputs.items(): + if key in self.embeddings: + # -- [ T x B x ...] -> [ B' x ... ] + embeddable_data[key] = torch.flatten(value, 0, 1).long() + + # add our group id and subgroup id if we want them + if self.requires_id_pairs_table: + ids, groups = self.glyphs_to_idgroup(inputs["glyphs"]) + embeddable_data["groups"] = groups + embeddable_data["subgroup_ids"] = ids + + # convert embeddable_data to a named tuple + return self.GlyphTuple(**embeddable_data) + + def forward(self, data_tuple): + """Output the embdedded tuple prepared in in prepare input. This will be + a GlyphTuple.""" + embs = [] + for field, data in zip(self.targets, data_tuple): + embs.append(self._select(self.embeddings[field], data)) + if len(embs) == 1: + return embs[0] + + embedded = torch.cat(embs, dim=-1) + if self.strategy.do_linear_layer: + embedded = self.linear(embedded) + return embedded + + def _select(self, embedding_layer, x): + if self.use_index_select: + out = embedding_layer.weight.index_select(0, x.view(-1)) + # handle reshaping x to 1-d and output back to N-d + return out.view(x.shape + (-1,)) + else: + return embedding_layer(x) + + def glyphs_to_idgroup(self, glyphs): + T, B, H, W = glyphs.shape + ids_groups = self.id_pairs_table.index_select(0, glyphs.view(-1).long()) + ids = ids_groups.select(1, 0).view(T * B, H, W).long() + groups = ids_groups.select(1, 1).view(T * B, H, W).long() + return (ids, groups) diff --git a/nle/agent/models/intrinsic.py b/nle/agent/models/intrinsic.py new file mode 100644 index 000000000..c90d490d5 --- /dev/null +++ b/nle/agent/models/intrinsic.py @@ -0,0 +1,584 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import torch +from torch import nn +from torch.nn import functional as F + +from nle.agent.models.base import BaseNet, PAD_CHAR, NUM_CHARS +from nle.agent.models.embed import GlyphEmbedding +from nle.agent.models.dynamics import ForwardDynamicsNet, InverseDynamicsNet + + +class IntrinsicRewardNet(BaseNet): + def __init__(self, observation_shape, num_actions, flags, device): + super(IntrinsicRewardNet, self).__init__( + observation_shape, num_actions, flags, device + ) + self.register_buffer("intrinsic_sum", torch.zeros(())) + self.register_buffer("intrinsic_m2", torch.zeros(())) + self.register_buffer("intrinsic_count", torch.zeros(()).fill_(1e-8)) + + self.intrinsic_input = flags.int.input + + self.int_baseline = nn.Linear(self.h_dim, 1) + + def intrinsic_enabled(self): + return True + + @torch.no_grad() + def update_intrinsic_moments(self, reward_batch): + """Maintains a running mean of reward.""" + new_count = len(reward_batch) + new_sum = torch.sum(reward_batch) + new_mean = new_sum / new_count + + curr_mean = self.intrinsic_sum / self.intrinsic_count + new_m2 = torch.sum((reward_batch - new_mean) ** 2) + ( + (self.intrinsic_count * new_count) + / (self.intrinsic_count + new_count) + * (new_mean - curr_mean) ** 2 + ) + + self.intrinsic_count += new_count + self.intrinsic_sum += new_sum + self.intrinsic_m2 += new_m2 + + @torch.no_grad() + def get_intrinsic_std(self): + """Returns standard deviation of the running mean of the intrinsic reward.""" + return torch.sqrt(self.intrinsic_m2 / self.intrinsic_count) + + +class RNDNet(IntrinsicRewardNet): + def __init__(self, observation_shape, num_actions, flags, device): + super(RNDNet, self).__init__(observation_shape, num_actions, flags, device) + + if self.equalize_input_dim: + raise NotImplementedError("rnd model does not support equalize_input_dim") + + Y = 8 # number of output filters + + # IMPLEMENTED HERE: RND net using the default feature extractor + self.rndtgt_embed = GlyphEmbedding( + flags.glyph_type, flags.embedding_dim, device, flags.use_index_select + ).requires_grad_(False) + self.rndprd_embed = GlyphEmbedding( + flags.glyph_type, flags.embedding_dim, device, flags.use_index_select + ) + + if self.intrinsic_input not in ("crop_only", "glyph_only", "full"): + raise NotImplementedError("RND input type %s" % self.intrinsic_input) + + rnd_out_dim = 0 + if self.intrinsic_input in ("crop_only", "full"): + self.rndtgt_extract_crop_representation = copy.deepcopy( + self.extract_crop_representation + ).requires_grad_(False) + self.rndprd_extract_crop_representation = copy.deepcopy( + self.extract_crop_representation + ) + + rnd_out_dim += self.crop_dim ** 2 * Y # crop dim + + if self.intrinsic_input in ("full", "glyph_only"): + self.rndtgt_extract_representation = copy.deepcopy( + self.extract_representation + ).requires_grad_(False) + self.rndprd_extract_representation = copy.deepcopy( + self.extract_representation + ) + rnd_out_dim += self.H * self.W * Y # glyph dim + + if self.intrinsic_input == "full": + self.rndtgt_embed_features = nn.Sequential( + nn.Linear(self.num_features, self.k_dim), + nn.ELU(), + nn.Linear(self.k_dim, self.k_dim), + nn.ELU(), + ).requires_grad_(False) + self.rndprd_embed_features = nn.Sequential( + nn.Linear(self.num_features, self.k_dim), + nn.ELU(), + nn.Linear(self.k_dim, self.k_dim), + nn.ELU(), + ) + rnd_out_dim += self.k_dim # feature dim + + if self.intrinsic_input == "full" and self.msg_model != "none": + # we only implement the lt_cnn msg model for RND for simplicity & speed + if self.msg_model != "lt_cnn": + logging.warning( + "msg.model set to %s, but RND overriding to lt_cnn for its input--" + "so the policy and RND are using different models for the messages" + % self.msg_model + ) + + self.rndtgt_char_lt = nn.Embedding( + NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR + ).requires_grad_(False) + self.rndprd_char_lt = nn.Embedding( + NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR + ) + + # similar to Zhang et al, 2016 + # Character-level Convolutional Networks for Text Classification + # https://arxiv.org/abs/1509.01626 + # replace one-hot inputs with learned embeddings + self.rndtgt_conv1 = nn.Conv1d( + self.msg_edim, self.msg_hdim, kernel_size=7 + ).requires_grad_(False) + self.rndprd_conv1 = nn.Conv1d(self.msg_edim, self.msg_hdim, kernel_size=7) + + # remaining convolutions, relus, pools, and a small FC network + self.rndtgt_conv2_6_fc = copy.deepcopy(self.conv2_6_fc).requires_grad_( + False + ) + self.rndprd_conv2_6_fc = copy.deepcopy(self.conv2_6_fc) + rnd_out_dim += self.msg_hdim + + self.rndtgt_fc = nn.Sequential( # matching RND paper making this smaller + nn.Linear(rnd_out_dim, self.h_dim) + ).requires_grad_(False) + self.rndprd_fc = nn.Sequential( # matching RND paper making this bigger + nn.Linear(rnd_out_dim, self.h_dim), + nn.ELU(), + nn.Linear(self.h_dim, self.h_dim), + nn.ELU(), + nn.Linear(self.h_dim, self.h_dim), + ) + + modules_to_init = [ + self.rndtgt_embed, + self.rndprd_embed, + self.rndtgt_fc, + self.rndprd_fc, + ] + + SQRT_2 = math.sqrt(2) + + def init(p): + if isinstance(p, nn.Conv2d) or isinstance(p, nn.Linear): + # init method used in paper + nn.init.orthogonal_(p.weight, SQRT_2) + p.bias.data.zero_() + if isinstance(p, nn.Embedding): + nn.init.orthogonal_(p.weight, SQRT_2) + + # manually init all to orthogonal dist + + if self.intrinsic_input in ("full", "crop_only"): + modules_to_init.append(self.rndtgt_extract_crop_representation) + modules_to_init.append(self.rndprd_extract_crop_representation) + if self.intrinsic_input in ("full", "glyph_only"): + modules_to_init.append(self.rndtgt_extract_representation) + modules_to_init.append(self.rndprd_extract_representation) + if self.intrinsic_input == "full": + modules_to_init.append(self.rndtgt_embed_features) + modules_to_init.append(self.rndprd_embed_features) + if self.msg_model != "none": + modules_to_init.append(self.rndtgt_conv2_6_fc) + modules_to_init.append(self.rndprd_conv2_6_fc) + + for m in modules_to_init: + for p in m.modules(): + init(p) + + def forward(self, inputs, core_state, learning=False): + if not learning: + # no need to calculate RND outputs when not in learn step + return super(RNDNet, self).forward(inputs, core_state, learning) + T, B, *_ = inputs["glyphs"].shape + + glyphs, features = self.prepare_input(inputs) + + # -- [B x 2] x,y coordinates + coordinates = features[:, :2] + + features = features.view(T * B, -1).float() + # -- [B x K] + features_emb = self.embed_features(features) + + assert features_emb.shape[0] == T * B + + reps = [features_emb] + + # -- [B x H' x W'] + crop = self.glyph_embedding.GlyphTuple( + *[self.crop(g, coordinates) for g in glyphs] + ) + # -- [B x H' x W' x K] + crop_emb = self.glyph_embedding(crop) + + if self.crop_model == "transformer": + # -- [B x W' x H' x K] + crop_rep = self.extract_crop_representation(crop_emb, mask=None) + elif self.crop_model == "cnn": + # -- [B x K x W' x H'] + crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? + # -- [B x W' x H' x K] + crop_rep = self.extract_crop_representation(crop_emb) + # -- [B x K'] + + crop_rep = crop_rep.view(T * B, -1) + assert crop_rep.shape[0] == T * B + + reps.append(crop_rep) + + # -- [B x H x W x K] + glyphs_emb = self.glyph_embedding(glyphs) + # -- [B x K x W x H] + glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? + # -- [B x W x H x K] + glyphs_rep = self.extract_representation(glyphs_emb) + + # -- [B x K'] + glyphs_rep = glyphs_rep.view(T * B, -1) + if self.equalize_input_dim: + glyphs_rep = self.project_glyph_dim(glyphs_rep) + + assert glyphs_rep.shape[0] == T * B + + # -- [B x K''] + reps.append(glyphs_rep) + + # MESSAGING MODEL + if self.msg_model != "none": + # [T x B x 256] -> [T * B x 256] + messages = inputs["message"].long().view(T * B, -1) + if self.msg_model == "cnn": + # convert messages to one-hot, [T * B x 96 x 256] + one_hot = F.one_hot(messages, num_classes=NUM_CHARS).transpose(1, 2) + char_rep = self.conv2_6_fc(self.conv1(one_hot.float())) + elif self.msg_model == "lt_cnn": + # [ T * B x E x 256 ] + char_emb = self.char_lt(messages).transpose(1, 2) + char_rep = self.conv2_6_fc(self.conv1(char_emb)) + else: # lstm, gru + char_emb = self.char_lt(messages) + output = self.char_rnn(char_emb)[0] + fwd_rep = output[:, -1, : self.h_dim // 2] + bwd_rep = output[:, 0, self.h_dim // 2 :] + char_rep = torch.cat([fwd_rep, bwd_rep], dim=1) + + if self.equalize_input_dim: + char_rep = self.project_msg_dim(char_rep) + reps.append(char_rep) + + st = torch.cat(reps, dim=1) + + # -- [B x K] + st = self.fc(st) + + # TARGET NETWORK + with torch.no_grad(): + if self.intrinsic_input == "crop_only": + tgt_crop_emb = self.rndtgt_embed(crop).transpose(1, 3) + tgt_crop_rep = self.rndtgt_extract_crop_representation(tgt_crop_emb) + tgt_st = self.rndtgt_fc(tgt_crop_rep.view(T * B, -1)) + elif self.intrinsic_input == "glyph_only": + tgt_glyphs_emb = self.rndtgt_embed(glyphs).transpose(1, 3) + tgt_glyphs_rep = self.rndtgt_extract_representation(tgt_glyphs_emb) + tgt_st = self.rndtgt_fc(tgt_glyphs_rep.view(T * B, -1)) + else: # full + tgt_reps = [] + tgt_feats = self.rndtgt_embed_features(features) + tgt_reps.append(tgt_feats) + + tgt_crop_emb = self.rndtgt_embed(crop).transpose(1, 3) + tgt_crop_rep = self.rndtgt_extract_crop_representation(tgt_crop_emb) + tgt_reps.append(tgt_crop_rep.view(T * B, -1)) + + tgt_glyphs_emb = self.rndtgt_embed(glyphs).transpose(1, 3) + tgt_glyphs_rep = self.rndtgt_extract_representation(tgt_glyphs_emb) + tgt_reps.append(tgt_glyphs_rep.view(T * B, -1)) + + if self.msg_model != "none": + tgt_char_emb = self.rndtgt_char_lt(messages).transpose(1, 2) + tgt_char_rep = self.rndtgt_conv2_6_fc( + self.rndprd_conv1(tgt_char_emb) + ) + tgt_reps.append(tgt_char_rep) + + tgt_st = self.rndtgt_fc(torch.cat(tgt_reps, dim=1)) + + # PREDICTOR NETWORK + if self.intrinsic_input == "crop_only": + prd_crop_emb = self.rndprd_embed(crop).transpose(1, 3) + prd_crop_rep = self.rndprd_extract_crop_representation(prd_crop_emb) + prd_st = self.rndprd_fc(prd_crop_rep.view(T * B, -1)) + elif self.intrinsic_input == "glyph_only": + prd_glyphs_emb = self.rndprd_embed(glyphs).transpose(1, 3) + prd_glyphs_rep = self.rndprd_extract_representation(prd_glyphs_emb) + prd_st = self.rndprd_fc(prd_glyphs_rep.view(T * B, -1)) + else: # full + prd_reps = [] + prd_feats = self.rndprd_embed_features(features) + prd_reps.append(prd_feats) + + prd_crop_emb = self.rndprd_embed(crop).transpose(1, 3) + prd_crop_rep = self.rndprd_extract_crop_representation(prd_crop_emb) + prd_reps.append(prd_crop_rep.view(T * B, -1)) + + prd_glyphs_emb = self.rndprd_embed(glyphs).transpose(1, 3) + prd_glyphs_rep = self.rndprd_extract_representation(prd_glyphs_emb) + prd_reps.append(prd_glyphs_rep.view(T * B, -1)) + + if self.msg_model != "none": + prd_char_emb = self.rndprd_char_lt(messages).transpose(1, 2) + prd_char_rep = self.rndprd_conv2_6_fc(self.rndprd_conv1(prd_char_emb)) + prd_reps.append(prd_char_rep) + + prd_st = self.rndprd_fc(torch.cat(prd_reps, dim=1)) + + assert tgt_st.size() == prd_st.size() + + if self.use_lstm: + core_input = st.view(T, B, -1) + core_output_list = [] + notdone = (~inputs["done"]).float() + for input, nd in zip(core_input.unbind(), notdone.unbind()): + # Reset core state to zero whenever an episode ended. + # Make `done` broadcastable with (num_layers, B, hidden_size) + # states: + nd = nd.view(1, -1, 1) + core_state = tuple(nd * t for t in core_state) + output, core_state = self.core(input.unsqueeze(0), core_state) + core_output_list.append(output) + core_output = torch.flatten(torch.cat(core_output_list), 0, 1) + else: + core_output = st + + # -- [B x A] + policy_logits = self.policy(core_output) + # -- [B x A] + baseline = self.baseline(core_output) + + if self.training: + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) + else: + # Don't sample when testing. + action = torch.argmax(policy_logits, dim=1) + + policy_logits = policy_logits.view(T, B, self.num_actions) + baseline = baseline.view(T, B) + action = action.view(T, B) + + output = dict( + policy_logits=policy_logits, + baseline=baseline, + action=action, + target=tgt_st.view(T, B, -1), + predicted=prd_st.view(T, B, -1), + int_baseline=self.int_baseline(core_output).view(T, B), + ) + return (output, core_state) + + +class RIDENet(IntrinsicRewardNet): + def __init__(self, observation_shape, num_actions, flags, device): + super(RIDENet, self).__init__(observation_shape, num_actions, flags, device) + + if flags.msg.model != "none": + raise NotImplementedError( + "model=%s + msg.model=%s" % (flags.model, flags.msg.model) + ) + + self.forward_dynamics_model = ForwardDynamicsNet( + num_actions, flags.ride.hidden_dim, flags.hidden_dim, flags.hidden_dim + ) + self.inverse_dynamics_model = InverseDynamicsNet( + num_actions, flags.ride.hidden_dim, flags.hidden_dim, flags.hidden_dim + ) + + Y = 8 # number of output filters + + # IMPLEMENTED HERE: RIDE net using the default feature extractor + self.ride_embed = GlyphEmbedding( + flags.glyph_type, flags.embedding_dim, device, flags.use_index_select + ) + + if self.intrinsic_input not in ("crop_only", "glyph_only", "full"): + raise NotImplementedError("RIDE input type %s" % self.intrinsic_input) + + ride_out_dim = 0 + if self.intrinsic_input in ("crop_only", "full"): + self.ride_extract_crop_representation = copy.deepcopy( + self.extract_crop_representation + ) + ride_out_dim += self.crop_dim ** 2 * Y # crop dim + + if self.intrinsic_input in ("full", "glyph_only"): + self.ride_extract_representation = copy.deepcopy( + self.extract_representation + ) + ride_out_dim += self.H * self.W * Y # glyph dim + + if self.intrinsic_input == "full": + self.ride_embed_features = nn.Sequential( + nn.Linear(self.num_features, self.k_dim), + nn.ELU(), + nn.Linear(self.k_dim, self.k_dim), + nn.ELU(), + ) + ride_out_dim += self.k_dim # feature dim + + self.ride_fc = nn.Sequential( + nn.Linear(ride_out_dim, self.h_dim), + # nn.ELU(), + # nn.Linear(self.h_dim, self.h_dim), + # nn.ELU(), + # nn.Linear(self.h_dim, self.h_dim), + ) + + # reinitialize all deep-copied layers + modules_to_init = [] + if self.intrinsic_input in ("full", "crop_only"): + modules_to_init.append(self.ride_extract_crop_representation) + if self.intrinsic_input in ("full", "glyph_only"): + modules_to_init.append(self.ride_extract_representation) + + for m in modules_to_init: + for p in m.modules(): + if isinstance(p, nn.Conv2d): + p.reset_parameters() + + def forward(self, inputs, core_state, learning=False): + if not learning: + # no need to calculate RIDE outputs when not in learn step + return super(RIDENet, self).forward(inputs, core_state, learning) + + T, B, *_ = inputs["glyphs"].shape + + glyphs, features = self.prepare_input(inputs) + + # -- [B x 2] x,y coordinates + coordinates = features[:, :2] + + features = features.view(T * B, -1).float() + # -- [B x K] + features_emb = self.embed_features(features) + + assert features_emb.shape[0] == T * B + + reps = [features_emb] + + # -- [B x H' x W'] + crop = self.glyph_embedding.GlyphTuple( + *[self.crop(g, coordinates) for g in glyphs] + ) + # -- [B x H' x W' x K] + crop_emb = self.glyph_embedding(crop) + + if self.crop_model == "transformer": + # -- [B x W' x H' x K] + crop_rep = self.extract_crop_representation(crop_emb, mask=None) + elif self.crop_model == "cnn": + # -- [B x K x W' x H'] + crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? + # -- [B x W' x H' x K] + crop_rep = self.extract_crop_representation(crop_emb) + # -- [B x K'] + + crop_rep = crop_rep.view(T * B, -1) + assert crop_rep.shape[0] == T * B + + reps.append(crop_rep) + + # -- [B x H x W x K] + glyphs_emb = self.glyph_embedding(glyphs) + # glyphs_emb = self.embed(glyphs) + # -- [B x K x W x H] + glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? + # -- [B x W x H x K] + glyphs_rep = self.extract_representation(glyphs_emb) + + # -- [B x K'] + glyphs_rep = glyphs_rep.view(T * B, -1) + assert glyphs_rep.shape[0] == T * B + + # -- [B x K''] + reps.append(glyphs_rep) + + st = torch.cat(reps, dim=1) + + # -- [B x K] + st = self.fc(st) + + # PREDICTOR NETWORK + if self.intrinsic_input == "crop_only": + ride_crop_emb = self.ride_embed(crop).transpose(1, 3) + ride_crop_rep = self.ride_extract_crop_representation(ride_crop_emb) + ride_st = self.ride_fc(ride_crop_rep.view(T * B, -1)) + elif self.intrinsic_input == "glyph_only": + ride_glyphs_emb = self.ride_embed(glyphs).transpose(1, 3) + ride_glyphs_rep = self.ride_extract_representation(ride_glyphs_emb) + ride_st = self.ride_fc(ride_glyphs_rep.view(T * B, -1)) + else: # full + ride_reps = [] + ride_feats = self.ride_embed_features(features) + ride_reps.append(ride_feats) + + ride_crop_emb = self.ride_embed(crop).transpose(1, 3) + ride_crop_rep = self.ride_extract_crop_representation(ride_crop_emb) + ride_reps.append(ride_crop_rep.view(T * B, -1)) + + ride_glyphs_emb = self.ride_embed(glyphs).transpose(1, 3) + ride_glyphs_rep = self.ride_extract_representation(ride_glyphs_emb) + ride_reps.append(ride_glyphs_rep.view(T * B, -1)) + + ride_st = self.ride_fc(torch.cat(ride_reps, dim=1)) + + if self.use_lstm: + core_input = st.view(T, B, -1) + core_output_list = [] + notdone = (~inputs["done"]).float() + for input, nd in zip(core_input.unbind(), notdone.unbind()): + # Reset core state to zero whenever an episode ended. + # Make `done` broadcastable with (num_layers, B, hidden_size) + # states: + nd = nd.view(1, -1, 1) + core_state = tuple(nd * t for t in core_state) + output, core_state = self.core(input.unsqueeze(0), core_state) + core_output_list.append(output) + core_output = torch.flatten(torch.cat(core_output_list), 0, 1) + else: + core_output = st + + # -- [B x A] + policy_logits = self.policy(core_output) + # -- [B x A] + baseline = self.baseline(core_output) + + if self.training: + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) + else: + # Don't sample when testing. + action = torch.argmax(policy_logits, dim=1) + + policy_logits = policy_logits.view(T, B, self.num_actions) + baseline = baseline.view(T, B) + action = action.view(T, B) + + output = dict( + policy_logits=policy_logits, + baseline=baseline, + action=action, + state_embedding=ride_st.view(T, B, -1), + int_baseline=self.int_baseline(core_output).view(T, B), + ) + return (output, core_state) diff --git a/nle/agent/models/losses.py b/nle/agent/models/losses.py new file mode 100644 index 000000000..d5653d0e1 --- /dev/null +++ b/nle/agent/models/losses.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.nn import functional as F + + +def compute_baseline_loss(advantages): + return 0.5 * torch.sum(advantages ** 2) + + +def compute_entropy_loss(logits): + policy = F.softmax(logits, dim=-1) + log_policy = F.log_softmax(logits, dim=-1) + entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1) + return -torch.sum(entropy_per_timestep) + + +def compute_policy_gradient_loss(logits, actions, advantages): + cross_entropy = F.nll_loss( + F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), + target=torch.flatten(actions, 0, 1), + reduction="none", + ) + cross_entropy = cross_entropy.view_as(advantages) + policy_gradient_loss_per_timestep = cross_entropy * advantages.detach() + return torch.sum(policy_gradient_loss_per_timestep) + + +def compute_forward_dynamics_loss(pred_next_emb, next_emb): + forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2) + return torch.sum(torch.mean(forward_dynamics_loss, dim=1)) + + +def compute_forward_binary_loss(pred_next_binary, next_binary): + return F.binary_cross_entropy_with_logits(pred_next_binary, next_binary) + + +def compute_forward_class_loss(pred_next_glyphs, next_glyphs): + next_glyphs = torch.flatten(next_glyphs, 0, 2).long() + pred_next_glyphs = pred_next_glyphs.view(next_glyphs.size(0), -1) + return F.cross_entropy(pred_next_glyphs, next_glyphs) + + +def compute_inverse_dynamics_loss(pred_actions, true_actions): + inverse_dynamics_loss = F.cross_entropy( + torch.flatten(pred_actions, 0, 1), + torch.flatten(true_actions, 0, 1), + reduction="none", + ) + inverse_dynamics_loss = inverse_dynamics_loss.view_as(true_actions) + return torch.sum(torch.mean(inverse_dynamics_loss, dim=1)) diff --git a/nle/agent/models/transformer.py b/nle/agent/models/transformer.py new file mode 100644 index 000000000..2e687505f --- /dev/null +++ b/nle/agent/models/transformer.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from torch.nn.modules import transformer + + +class LearnedPositionalEncoder(nn.Module): + def __init__(self, k, height, width, device): + super().__init__() + + self.height = height + self.width = width + + self.enc = torch.randn(height, width, k) + + self.enc = self.enc.div( + torch.norm(self.enc, p=2, dim=2)[:, :, None].expand_as(self.enc) + ) + + self.mlp = nn.Sequential( + nn.Linear(2 * k, k), nn.ReLU(), nn.Linear(k, k), nn.ReLU() + ) + + self.enc = nn.Parameter(self.enc, requires_grad=True)[None, :, :, :] + + if device is not None: + self.enc = self.enc.to(device) + + def forward(self, x): + x = torch.cat([x, self.enc.expand_as(x)], dim=3) + x = self.mlp(x) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, d_model, N, heads, height, width, device): + super().__init__() + self.N = N + self.pe = LearnedPositionalEncoder(d_model, height, width, device) + self.layers = transformer._get_clones( + transformer.TransformerEncoderLayer( + d_model, heads, dim_feedforward=d_model + ), + N, + ) + + def forward(self, src, mask=None): + x = src + x = self.pe(x) + + bs, h, w, k = x.shape + + x = x.view(bs, h * w, k).transpose(1, 0) + + for i in range(self.N): + x = self.layers[i](x, mask) + + # FIXME: probably slow due to contiguous; we can adapt the rest of the base + # model to not assume the batch as first dimension + return x.transpose(1, 0).view(bs, h, w, k).contiguous() diff --git a/nle/agent/neurips_sweep.sh b/nle/agent/neurips_sweep.sh new file mode 100755 index 000000000..8547d5ba8 --- /dev/null +++ b/nle/agent/neurips_sweep.sh @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# try both extrinsic-only and Random Network Distillation on different tasks with 3 repeats + +python polyhydra.py --multirun model=baseline,rnd character=mon-hum-neu-mal,val-dwa-law-fem,wiz-elf-cha-mal,tou-hum-neu-fem env=score,staircase,pet,eat,gold,scout,oracle name=1,2,3 diff --git a/nle/agent/polybeast_env.py b/nle/agent/polybeast_env.py new file mode 100644 index 000000000..101a85264 --- /dev/null +++ b/nle/agent/polybeast_env.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import multiprocessing as mp +import logging +import os +import threading +import time + +import torch +from nle.agent.envs import tasks +import libtorchbeast + + +# yapf: disable +parser = argparse.ArgumentParser(description='Remote Environment Server') + +parser.add_argument('--env', default='staircase', type=str, metavar='E', + help='Name of Gym environment to create.') +parser.add_argument('--character', default='mon-hum-neu-mal', type=str, metavar='C', + help='Specification of the NetHack character.') +parser.add_argument("--pipes_basename", default="unix:/tmp/polybeast", + help="Basename for the pipes for inter-process communication. " + "Has to be of the type unix:/some/path.") +parser.add_argument('--num_servers', default=4, type=int, metavar='N', + help='Number of environment servers.') + +parser.add_argument('--mock', action="store_true", + help='Use mock environment instead of NetHack.') +parser.add_argument('--single_ttyrec', action="store_true", + help='Record ttyrec only for actor 0.') +parser.add_argument('--num_seeds', default=0, type=int, metavar='S', + help='If larger than 0, samples fixed number of environment seeds ' + 'to be used.') +parser.add_argument('--seedspath', default="", type=str, + help="Path to json file with seeds.") + +# Training settings. +parser.add_argument('--savedir', default='~/nethackruns', + help='Root dir where experiment data will be saved.') + +# Task-Specific settings. +parser.add_argument('--reward_win', default=1.0, type=float, + help='Reward for winning (finding the staircase).') +parser.add_argument('--reward_lose', default=-1.0, type=float, + help='Reward for losing (dying before finding the staircase).') + +parser.add_argument('--penalty_step', default=-0.0001, type=float, + help='Penalty per step in the episode.') +parser.add_argument('--penalty_time', default=-0.0001, type=float, + help='Penalty per time step in the episode.') +parser.add_argument('--fn_penalty_step', default="constant", type=str, + help='Function to accumulate penalty.') +parser.add_argument('--max_num_steps', default=1000, type=int, + help='Maximum number of steps in the game.') +parser.add_argument('--state_counter', default="none", choices=['none', 'coordinates'], + help='Method for counting state visits. Default none. ' + 'Coordinates concatenates dungeon level with player x,y.') +# yapf: enable + +logging.basicConfig( + format=( + "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" + ), + level=0, +) + + +# Helper functions for NethackEnv. +def _format_observation(obs): + obs = torch.from_numpy(obs) + return obs.view((1, 1) + obs.shape) # (...) -> (T,B,...). + + +def create_folders(flags): + # Creates some of the folders that would be created by the filewriter. + logdir = os.path.join(flags.savedir, "archives") + if not os.path.exists(logdir): + logging.info("Creating archive directory: %s" % logdir) + os.makedirs(logdir, exist_ok=True) + else: + logging.info("Found archive directory: %s" % logdir) + + +def create_env(flags, env_id=0, lock=threading.Lock()): + # commenting out these options for now because they use too much disk space + # archivefile = "nethack.%i.%%(pid)i.%%(time)s.zip" % env_id + # if flags.single_ttyrec and env_id != 0: + # archivefile = None + + # logdir = os.path.join(flags.savedir, "archives") + + with lock: + env_class = tasks.ENVS[flags.env] + kwargs = dict( + savedir=None, + archivefile=None, + character=flags.character, + max_episode_steps=flags.max_num_steps, + observation_keys=( + "glyphs", + "chars", + "colors", + "specials", + "blstats", + "message", + ), + penalty_step=flags.penalty_step, + penalty_time=flags.penalty_time, + penalty_mode=flags.fn_penalty_step, + ) + if flags.env in ("staircase", "pet", "oracle"): + kwargs.update(reward_win=flags.reward_win, reward_lose=flags.reward_lose) + elif env_id == 0: # print warning once + print("Ignoring flags.reward_win and flags.reward_lose") + if flags.state_counter != "none": + kwargs.update(state_counter=flags.state_counter) + env = env_class(**kwargs) + if flags.seedspath is not None and len(flags.seedspath) > 0: + json # Unused. + raise NotImplementedError("seedspath > 0 not implemented yet.") + # with open(flags.seedspath) as f: + # seeds = json.load(f) + # assert flags.num_seeds == len(seeds) + # env = SeedingWrapper(env, seeds=seeds) + # elif flags.num_seeds > 0: + # env = SeedingWrapper(env, num_seeds=flags.num_seeds) + return env + + +def serve(flags, server_address, env_id): + env = lambda: create_env(flags, env_id) + server = libtorchbeast.Server(env, server_address=server_address) + server.run() + + +def main(flags): + if flags.num_seeds > 0: + raise NotImplementedError("num_seeds > 0 not currently implemented.") + + create_folders(flags) + + if not flags.pipes_basename.startswith("unix:"): + raise Exception("--pipes_basename has to be of the form unix:/some/path.") + + processes = [] + for i in range(flags.num_servers): + p = mp.Process( + target=serve, args=(flags, f"{flags.pipes_basename}.{i}", i), daemon=True + ) + p.start() + processes.append(p) + + try: + # We are only here to listen to the interrupt. + while True: + time.sleep(10) + except KeyboardInterrupt: + pass + + +if __name__ == "__main__": + flags = parser.parse_args() + main(flags) diff --git a/nle/agent/polybeast_learner.py b/nle/agent/polybeast_learner.py new file mode 100644 index 000000000..2246b66d0 --- /dev/null +++ b/nle/agent/polybeast_learner.py @@ -0,0 +1,765 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Run with OMP_NUM_THREADS=1. +# + +import argparse +import collections +import logging +import omegaconf +import os +import threading +import time +import timeit +import traceback + +import wandb + +import nest +import torch +from nle.agent.core import file_writer +from nle.agent.core import vtrace +from nle.agent.models import create_model, losses +from nle.agent.models.base import NetHackNet +from nle.agent.models.intrinsic import IntrinsicRewardNet +import libtorchbeast +from torch import nn +from torch.nn import functional as F + +# yapf: disable +parser = argparse.ArgumentParser(description="PyTorch Scalable Agent") + +parser.add_argument("--mode", default="train", + choices=["train", "test", "test_render"], + help="Training or test mode.") +parser.add_argument('--env', default='staircase', type=str, metavar='E', + help='Name of Gym environment to create.') +parser.add_argument("--wandb", action="store_true", + help="Log to wandb.") +parser.add_argument('--group', default='default', type=str, metavar='G', + help='Name of the experiment group (as being used by wandb).') +parser.add_argument('--project', default='nle', type=str, metavar='P', + help='Name of the project (as being used by wandb).') +parser.add_argument('--entity', default='nethack', type=str, metavar='P', + help='Which team to log to.') + +# Training settings. +parser.add_argument("--pipes_basename", default="unix:/tmp/polybeast", + help="Basename for the pipes for inter-process communication. " + "Has to be of the type unix:/some/path.") +parser.add_argument("--savedir", default="~/palaas/torchbeast", + help="Root dir where experiment data will be saved.") +parser.add_argument("--num_actors", default=4, type=int, metavar="N", + help="Number of actors") +parser.add_argument("--total_steps", default=1e6, type=float, metavar="T", + help="Total environment steps to train for. Will be cast to int.") +parser.add_argument("--batch_size", default=8, type=int, metavar="B", + help="Learner batch size") +parser.add_argument("--unroll_length", default=80, type=int, metavar="T", + help="The unroll length (time dimension)") +parser.add_argument("--num_learner_threads", default=2, type=int, + metavar="N", help="Number learner threads.") +parser.add_argument("--num_inference_threads", default=2, type=int, + metavar="N", help="Number learner threads.") +parser.add_argument("--learner_device", default="cuda:0", help="Set learner device") +parser.add_argument("--actor_device", default="cuda:1", help="Set actor device") +parser.add_argument("--disable_cuda", action="store_true", + help="Disable CUDA.") +parser.add_argument("--use_lstm", action="store_true", + help="Use LSTM in agent model.") +parser.add_argument("--use_index_select", action="store_true", + help="Whether to use index_select instead of embedding lookup.") +parser.add_argument("--max_learner_queue_size", default=None, type=int, metavar="N", + help="Optional maximum learner queue size. Defaults to batch_size.") + + +# Model settings. +parser.add_argument('--model', default="baseline", + help='Name of the model to run') +parser.add_argument('--crop_model', default="cnn", choices=["cnn", "transformer"], + help='Size of cropping window around the agent') +parser.add_argument('--crop_dim', type=int, default=9, + help='Size of cropping window around the agent') +parser.add_argument('--embedding_dim', type=int, default=32, + help='Size of glyph embeddings.') +parser.add_argument('--hidden_dim', type=int, default=128, + help='Size of hidden representations.') +parser.add_argument('--layers', type=int, default=5, + help='Number of ConvNet/Transformer layers.') +# Loss settings. +parser.add_argument("--entropy_cost", default=0.0006, type=float, + help="Entropy cost/multiplier.") +parser.add_argument("--baseline_cost", default=0.5, type=float, + help="Baseline cost/multiplier.") +parser.add_argument("--discounting", default=0.99, type=float, + help="Discounting factor.") +parser.add_argument("--reward_clipping", default="tim", + choices=["soft_asymmetric", "none", "tim"], + help="Reward clipping.") +parser.add_argument("--no_extrinsic", action="store_true", + help=("Disables extrinsic reward (no baseline/pg_loss).")) +parser.add_argument("--normalize_reward", action="store_true", + help=("Normalizes reward by dividing by running stdev from mean.")) + +# Optimizer settings. +parser.add_argument("--learning_rate", default=0.00048, type=float, + metavar="LR", help="Learning rate.") +parser.add_argument("--alpha", default=0.99, type=float, + help="RMSProp smoothing constant.") +parser.add_argument("--momentum", default=0, type=float, + help="RMSProp momentum.") +parser.add_argument("--epsilon", default=0.01, type=float, + help="RMSProp epsilon.") +parser.add_argument("--grad_norm_clipping", default=40.0, type=float, + help="Global gradient norm clip.") + +# Misc settings. +parser.add_argument("--write_profiler_trace", action="store_true", + help="Collect and write a profiler trace " + "for chrome://tracing/.") + +# yapf: enable + + +logging.basicConfig( + format=( + "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" + ), + level=0, +) + + +def inference( + inference_batcher, model, flags, actor_device, lock=threading.Lock() +): # noqa: B008 + with torch.no_grad(): + for batch in inference_batcher: + batched_env_outputs, agent_state = batch.get_inputs() + observation, reward, done, *_ = batched_env_outputs + # Observation is a dict with keys 'features' and 'glyphs'. + observation["done"] = done + observation, agent_state = nest.map( + lambda t: t.to(actor_device, non_blocking=True), + (observation, agent_state), + ) + with lock: + outputs = model(observation, agent_state) + core_outputs, agent_state = nest.map(lambda t: t.cpu(), outputs) + # Restructuring the output in the way that is expected + # by the functions in actorpool. + outputs = ( + tuple( + ( + core_outputs["action"], + core_outputs["policy_logits"], + core_outputs["baseline"], + ) + ), + agent_state, + ) + batch.set_outputs(outputs) + + +# TODO(heiner): Given that our nest implementation doesn't support +# namedtuples, using them here doesn't seem like a good fit. We +# probably want to nestify the environment server and deal with +# dictionaries? +EnvOutput = collections.namedtuple( + "EnvOutput", "frame rewards done episode_step episode_return" +) +AgentOutput = NetHackNet.AgentOutput +Batch = collections.namedtuple("Batch", "env agent") + + +def clip(flags, rewards): + if flags.reward_clipping == "tim": + clipped_rewards = torch.tanh(rewards / 100.0) + elif flags.reward_clipping == "soft_asymmetric": + squeezed = torch.tanh(rewards / 5.0) + # Negative rewards are given less weight than positive rewards. + clipped_rewards = torch.where(rewards < 0, 0.3 * squeezed, squeezed) * 5.0 + elif flags.reward_clipping == "none": + clipped_rewards = rewards + else: + raise NotImplementedError("reward_clipping=%s" % flags.reward_clipping) + return clipped_rewards + + +def learn( + learner_queue, + model, + actor_model, + optimizer, + scheduler, + stats, + flags, + plogger, + learner_device, + lock=threading.Lock(), # noqa: B008 +): + for tensors in learner_queue: + tensors = nest.map(lambda t: t.to(learner_device), tensors) + + batch, initial_agent_state = tensors + env_outputs, actor_outputs = batch + observation, reward, done, *_ = env_outputs + observation["reward"] = reward + observation["done"] = done + + lock.acquire() # Only one thread learning at a time. + + output, _ = model(observation, initial_agent_state, learning=True) + + # Use last baseline value (from the value function) to bootstrap. + learner_outputs = AgentOutput._make( + (output["action"], output["policy_logits"], output["baseline"]) + ) + + # At this point, the environment outputs at time step `t` are the inputs + # that lead to the learner_outputs at time step `t`. After the following + # shifting, the actions in `batch` and `learner_outputs` at time + # step `t` is what leads to the environment outputs at time step `t`. + batch = nest.map(lambda t: t[1:], batch) + learner_outputs = nest.map(lambda t: t[:-1], learner_outputs) + + # Turn into namedtuples again. + env_outputs, actor_outputs = batch + # Note that the env_outputs.frame is now a dict with 'features' and 'glyphs' + # instead of actually being the frame itself. This is currently not a problem + # because we never use actor_outputs.frame in the rest of this function. + env_outputs = EnvOutput._make(env_outputs) + actor_outputs = AgentOutput._make(actor_outputs) + learner_outputs = AgentOutput._make(learner_outputs) + + rewards = env_outputs.rewards + if flags.normalize_reward: + model.update_running_moments(rewards) + rewards /= model.get_running_std() + + total_loss = 0 + + # INTRINSIC REWARDS + calculate_intrinsic = ( + isinstance(model, IntrinsicRewardNet) and model.intrinsic_enabled() + ) + if calculate_intrinsic: + # Compute intrinsic reward and loss + if "int_baseline" not in output: + raise RuntimeError("Expected intrinsic outputs but found none") + + # set intrinsic reward dimensions here so we don't make any mistakes later + intrinsic_reward = rewards.new_zeros(rewards.size()).float() + + if flags.model == "rnd": + # Random Network Distillation + target = output["target"][1:] + predicted = output["predicted"][1:] + # loss for prediction failures, not really "forward" model + forward_loss = flags.rnd.forward_cost * F.mse_loss( + target, predicted, reduction="mean" + ) + total_loss += forward_loss + + # reward based on unpredicted scenarios + intrinsic_reward += (target - predicted).pow(2).sum(2) * 0.5 + elif flags.model == "ride": + # Rewarding Impact-Driven Exploration + state_emb = output["state_embedding"][:-1] + next_state_emb = output["state_embedding"][1:] + actions = actor_outputs.action + + pred_next_state_emb = model.forward_dynamics_model(state_emb, actions) + pred_actions = model.inverse_dynamics_model(state_emb, next_state_emb) + + forward_loss = ( + flags.ride.forward_cost + * losses.compute_forward_dynamics_loss( + pred_next_state_emb, next_state_emb + ) + ) + inverse_loss = ( + flags.ride.inverse_cost + * losses.compute_inverse_dynamics_loss(pred_actions, actions) + ) + total_loss += forward_loss + inverse_loss + + intrinsic_reward += torch.norm(next_state_emb - state_emb, dim=2, p=2) + if flags.ride.count_norm: + if "state_visits" not in observation: + raise RuntimeError( + "ride.count_norm=true but state_counter=none" + ) + # -- [T x B ] + counts = observation["state_visits"][1:].squeeze(-1).float().sqrt() + intrinsic_reward /= counts + + if flags.int.normalize_reward: + model.update_intrinsic_moments(intrinsic_reward) + intrinsic_reward /= model.get_intrinsic_std() + intrinsic_reward *= flags.int.intrinsic_weight + + if not flags.int.twoheaded and not flags.no_extrinsic: + # add intrinsic rewards to extrinsic ones + rewards += intrinsic_reward + + # STANDARD EXTRINSIC LOSSES / REWARDS + if flags.entropy_cost > 0: + entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( + learner_outputs.policy_logits + ) + total_loss += entropy_loss + + if not flags.no_extrinsic: + clipped_rewards = clip(flags, rewards) + + discounts = (~env_outputs.done).float() * flags.discounting + + # This could be in C++. In TF, this is actually slower on the GPU. + vtrace_returns = vtrace.from_logits( + behavior_policy_logits=actor_outputs.policy_logits, + target_policy_logits=learner_outputs.policy_logits, + actions=actor_outputs.action, + discounts=discounts, + rewards=clipped_rewards, + values=learner_outputs.baseline, + bootstrap_value=learner_outputs.baseline[-1], + ) + + # Compute loss as a weighted sum of the baseline loss, the policy + # gradient loss and an entropy regularization term. + pg_loss = losses.compute_policy_gradient_loss( + learner_outputs.policy_logits, + actor_outputs.action, + vtrace_returns.pg_advantages, + ) + baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( + vtrace_returns.vs - learner_outputs.baseline + ) + total_loss += pg_loss + baseline_loss + + # TWO-HEADED INTRINSIC REWARDS / LOSSES + if calculate_intrinsic and (flags.int.twoheaded or flags.no_extrinsic): + # here we calculate RL loss on the intrinsic reward using its own value head + # 1) twoheaded always separates ext and int rewards to their own heads + # 2) no_extrinsic skips the ext value head and uses only the int one + int_clipped_rewards = clip(flags, intrinsic_reward) + + # use a separate discounting factor for intrinsic rewards + if flags.int.episodic: + int_discounts = (~env_outputs.done).float() * flags.int.discounting + else: + # can also do non-episodic intrinsic rewards + int_discounts = discounts.new_full( + discounts.size(), flags.int.discounting + ) + + int_vtrace_returns = vtrace.from_logits( + behavior_policy_logits=actor_outputs.policy_logits, + target_policy_logits=learner_outputs.policy_logits, + actions=actor_outputs.action, + discounts=int_discounts, # intrinsic discounts + rewards=int_clipped_rewards, # intrinsic reward + values=output["int_baseline"][1:], # intrinsic baseline + bootstrap_value=output["int_baseline"][-1], # intrinsic bootstrap + ) + + # intrinsic baseline loss + int_baseline_loss = flags.int.baseline_cost * losses.compute_baseline_loss( + int_vtrace_returns.vs - output["int_baseline"][1:] + ) + + # intrinsic policy gradient loss + int_pg_loss = losses.compute_policy_gradient_loss( + learner_outputs.policy_logits, + actor_outputs.action, + int_vtrace_returns.pg_advantages, + ) + + total_loss += int_pg_loss + int_baseline_loss + + # BACKWARD STEP + optimizer.zero_grad() + total_loss.backward() + if flags.grad_norm_clipping > 0: + nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping) + optimizer.step() + scheduler.step() + + actor_model.load_state_dict(model.state_dict()) + + # LOGGING + episode_returns = env_outputs.episode_return[env_outputs.done] + stats["step"] = stats.get("step", 0) + flags.unroll_length * flags.batch_size + stats["mean_episode_return"] = torch.mean(episode_returns).item() + stats["mean_episode_step"] = torch.mean(env_outputs.episode_step.float()).item() + stats["total_loss"] = total_loss.item() + if flags.entropy_cost > 0: + stats["entropy_loss"] = entropy_loss.item() + if not flags.no_extrinsic: + stats["pg_loss"] = pg_loss.item() + stats["baseline_loss"] = baseline_loss.item() + + stats["learner_queue_size"] = learner_queue.size() + + if calculate_intrinsic: + stats["intrinsic_reward"] = torch.mean(intrinsic_reward).item() + if flags.model == "rnd": + stats["forward_loss"] = forward_loss.item() + elif flags.model == "ride": + stats["forward_loss"] = forward_loss.item() + stats["inverse_loss"] = inverse_loss.item() + if flags.int.twoheaded: + stats["int_baseline_loss"] = int_baseline_loss.item() + stats["int_pg_loss"] = int_pg_loss.item() + + if "state_visits" in observation: + visits = observation["state_visits"][:-1] + metric = visits[env_outputs.done].float() + key1 = "mean_state_visits" + key2 = "max_state_visits" + if not len(episode_returns): + stats[key1] = None + stats[key2] = None + else: + stats[key1] = torch.mean(metric).item() + stats[key2] = torch.max(metric).item() + + DEBUG = False + + if DEBUG and env_outputs.done.sum() > 0: + print() + print("glyphs shape", env_outputs.frame["glyphs"].shape) + print("features shape", env_outputs.frame["features"].shape) + print( + "episode_step", + env_outputs.episode_step[:, 0], + env_outputs.episode_step.shape, + ) + print("rewards", env_outputs.rewards[:, 0], env_outputs.rewards.shape) + print( + "episode_return", + env_outputs.episode_return[:, 0], + env_outputs.episode_return.shape, + ) + print("done", env_outputs.done[:, 0], env_outputs.done.shape) + + if not len(episode_returns): + # Hide the mean-of-empty-tuple NaN as it scares people. + stats["mean_episode_return"] = None + + # Only logging if at least one episode was finished + if len(episode_returns): + # TODO: log also SPS + plogger.log(stats) + if flags.wandb: + wandb.log(stats, step=stats["step"]) + + lock.release() + + +def train(flags): + logging.info("Logging results to %s", flags.savedir) + if isinstance(flags, omegaconf.DictConfig): + flag_dict = omegaconf.OmegaConf.to_container(flags) + else: + flag_dict = vars(flags) + plogger = file_writer.FileWriter(xp_args=flag_dict, rootdir=flags.savedir) + + if not flags.disable_cuda and torch.cuda.is_available(): + logging.info("Using CUDA.") + learner_device = torch.device(flags.learner_device) + actor_device = torch.device(flags.actor_device) + else: + logging.info("Not using CUDA.") + learner_device = torch.device("cpu") + actor_device = torch.device("cpu") + + if flags.max_learner_queue_size is None: + flags.max_learner_queue_size = flags.batch_size + + # The queue the learner threads will get their data from. + # Setting `minimum_batch_size == maximum_batch_size` + # makes the batch size static. We could make it dynamic, but that + # requires a loss (and learning rate schedule) that's batch size + # independent. + learner_queue = libtorchbeast.BatchingQueue( + batch_dim=1, + minimum_batch_size=flags.batch_size, + maximum_batch_size=flags.batch_size, + check_inputs=True, + maximum_queue_size=flags.max_learner_queue_size, + ) + + # The "batcher", a queue for the inference call. Will yield + # "batch" objects with `get_inputs` and `set_outputs` methods. + # The batch size of the tensors will be dynamic. + inference_batcher = libtorchbeast.DynamicBatcher( + batch_dim=1, + minimum_batch_size=1, + maximum_batch_size=512, + timeout_ms=100, + check_outputs=True, + ) + + addresses = [] + connections_per_server = 1 + pipe_id = 0 + while len(addresses) < flags.num_actors: + for _ in range(connections_per_server): + addresses.append(f"{flags.pipes_basename}.{pipe_id}") + if len(addresses) == flags.num_actors: + break + pipe_id += 1 + + logging.info("Using model %s", flags.model) + + model = create_model(flags, learner_device) + + plogger.metadata["model_numel"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info("Number of model parameters: %i", plogger.metadata["model_numel"]) + + actor_model = create_model(flags, actor_device) + + # The ActorPool that will run `flags.num_actors` many loops. + actors = libtorchbeast.ActorPool( + unroll_length=flags.unroll_length, + learner_queue=learner_queue, + inference_batcher=inference_batcher, + env_server_addresses=addresses, + initial_agent_state=model.initial_state(), + ) + + def run(): + try: + actors.run() + except Exception as e: + logging.error("Exception in actorpool thread!") + traceback.print_exc() + print() + raise e + + actorpool_thread = threading.Thread(target=run, name="actorpool-thread") + + optimizer = torch.optim.RMSprop( + model.parameters(), + lr=flags.learning_rate, + momentum=flags.momentum, + eps=flags.epsilon, + alpha=flags.alpha, + ) + + def lr_lambda(epoch): + return ( + 1 + - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps) + / flags.total_steps + ) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + stats = {} + + if flags.checkpoint and os.path.exists(flags.checkpoint): + logging.info("Loading checkpoint: %s" % flags.checkpoint) + checkpoint_states = torch.load( + flags.checkpoint, map_location=flags.learner_device + ) + model.load_state_dict(checkpoint_states["model_state_dict"]) + optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"]) + scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"]) + stats = checkpoint_states["stats"] + logging.info(f"Resuming preempted job, current stats:\n{stats}") + + # Initialize actor model like learner model. + actor_model.load_state_dict(model.state_dict()) + + learner_threads = [ + threading.Thread( + target=learn, + name="learner-thread-%i" % i, + args=( + learner_queue, + model, + actor_model, + optimizer, + scheduler, + stats, + flags, + plogger, + learner_device, + ), + ) + for i in range(flags.num_learner_threads) + ] + inference_threads = [ + threading.Thread( + target=inference, + name="inference-thread-%i" % i, + args=(inference_batcher, actor_model, flags, actor_device), + ) + for i in range(flags.num_inference_threads) + ] + + actorpool_thread.start() + for t in learner_threads + inference_threads: + t.start() + + def checkpoint(checkpoint_path=None): + if flags.checkpoint: + if checkpoint_path is None: + checkpoint_path = flags.checkpoint + logging.info("Saving checkpoint to %s", checkpoint_path) + torch.save( + { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "stats": stats, + "flags": vars(flags), + }, + checkpoint_path, + ) + + # TODO: test this again then uncomment (from deleted polyhydra code) + # def receive_slurm_signal(signal_num=None, frame=None): + # logging.info("Received SIGTERM, checkpointing") + # make_checkpoint() + + # signal.signal(signal.SIGTERM, receive_slurm_signal) + + def format_value(x): + return f"{x:1.5}" if isinstance(x, float) else str(x) + + try: + train_start_time = timeit.default_timer() + train_time_offset = stats.get("train_seconds", 0) # used for resuming training + last_checkpoint_time = timeit.default_timer() + + dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75] + + loop_start_time = timeit.default_timer() + loop_start_step = stats.get("step", 0) + while True: + if loop_start_step >= flags.total_steps: + break + time.sleep(5) + loop_end_time = timeit.default_timer() + loop_end_step = stats.get("step", 0) + + stats["train_seconds"] = round( + loop_end_time - train_start_time + train_time_offset, 1 + ) + + if loop_end_time - last_checkpoint_time > 10 * 60: + # Save every 10 min. + checkpoint() + last_checkpoint_time = loop_end_time + + if len(dev_checkpoint_intervals) > 0: + step_percentage = loop_end_step / flags.total_steps + i = dev_checkpoint_intervals[0] + if step_percentage > i: + checkpoint(flags.checkpoint[:-4] + "_" + str(i) + ".tar") + dev_checkpoint_intervals = dev_checkpoint_intervals[1:] + + logging.info( + "Step %i @ %.1f SPS. Inference batcher size: %i." + " Learner queue size: %i." + " Other stats: (%s)", + loop_end_step, + (loop_end_step - loop_start_step) / (loop_end_time - loop_start_time), + inference_batcher.size(), + learner_queue.size(), + ", ".join( + f"{key} = {format_value(value)}" for key, value in stats.items() + ), + ) + loop_start_time = loop_end_time + loop_start_step = loop_end_step + except KeyboardInterrupt: + pass # Close properly. + else: + logging.info("Learning finished after %i steps.", stats["step"]) + + checkpoint() + + # Done with learning. Let's stop all the ongoing work. + inference_batcher.close() + learner_queue.close() + + actorpool_thread.join() + + for t in learner_threads + inference_threads: + t.join() + + +def test(flags): + test_checkpoint = os.path.join(flags.savedir, "test_checkpoint.tar") + + if not os.path.exists(os.path.dirname(test_checkpoint)): + os.makedirs(os.path.dirname(test_checkpoint)) + + logging.info("Creating test copy of checkpoint '%s'", flags.checkpoint) + + checkpoint = torch.load(flags.checkpoint) + for d in checkpoint["optimizer_state_dict"]["param_groups"]: + d["lr"] = 0.0 + d["initial_lr"] = 0.0 + + checkpoint["scheduler_state_dict"]["last_epoch"] = 0 + checkpoint["scheduler_state_dict"]["_step_count"] = 0 + checkpoint["scheduler_state_dict"]["base_lrs"] = [0.0] + checkpoint["stats"]["step"] = 0 + checkpoint["stats"]["_tick"] = 0 + + flags.checkpoint = test_checkpoint + flags.learning_rate = 0.0 + + logging.info("Saving test checkpoint to %s", test_checkpoint) + torch.save(checkpoint, test_checkpoint) + + train(flags) + + +def main(flags): + if flags.wandb: + wandb.init( + project=flags.project, + config=vars(flags), + group=flags.group, + entity=flags.entity, + ) + if flags.mode == "train": + if flags.write_profiler_trace: + logging.info("Running with profiler.") + with torch.autograd.profiler.profile() as prof: + train(flags) + filename = "chrome-%s.trace" % time.strftime("%Y%m%d-%H%M%S") + logging.info("Writing profiler trace to '%s.gz'", filename) + prof.export_chrome_trace(filename) + os.system("gzip %s" % filename) + else: + train(flags) + elif flags.mode.startswith("test"): + test(flags) + + +if __name__ == "__main__": + flags = parser.parse_args() + flags.total_steps = int(flags.total_steps) # Allows e.g. 1e6. + main(flags) diff --git a/nle/agent/polyhydra.py b/nle/agent/polyhydra.py new file mode 100644 index 000000000..3ebe350d9 --- /dev/null +++ b/nle/agent/polyhydra.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Installation for hydra: +pip install hydra-core hydra_colorlog --upgrade + +Runs like polybeast but use = to set flags: +python -m polyhydra.py learning_rate=0.001 rnd.twoheaded=true + +Run sweep with another -m after the module: +python -m polyhydra.py -m learning_rate=0.01,0.001,0.0001,0.00001 momentum=0,0.5 +""" + +import hydra +from omegaconf import OmegaConf, DictConfig +from pathlib import Path + +import logging +import os + +import numpy as np +import multiprocessing as mp + +from nle.agent import polybeast_env, polybeast_learner + +import torch + +if torch.__version__.startswith("1.5") or torch.__version__.startswith("1.6"): + # pytorch 1.5.* needs this for some reason on the cluster + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" + +logging.basicConfig( + format=( + "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" + ), + level=0, +) + + +def pipes_basename(): + logdir = Path(os.getcwd()) + name = ".".join([logdir.parents[1].name, logdir.parents[0].name, logdir.name]) + return "unix:/tmp/poly.%s" % name + + +def get_common_flags(flags): + flags = OmegaConf.to_container(flags) + flags["pipes_basename"] = pipes_basename() + flags["savedir"] = os.getcwd() + return OmegaConf.create(flags) + + +def get_learner_flags(flags): + lrn_flags = OmegaConf.to_container(flags) + lrn_flags["checkpoint"] = os.path.join(flags["savedir"], "checkpoint.tar") + lrn_flags["entropy_cost"] = float(lrn_flags["entropy_cost"]) + return OmegaConf.create(lrn_flags) + + +def run_learner(flags: DictConfig): + polybeast_learner.main(flags) + + +def get_environment_flags(flags): + env_flags = OmegaConf.to_container(flags) + env_flags["num_servers"] = flags.num_actors + max_num_steps = 1e6 + if flags.env in ("staircase", "pet"): + max_num_steps = 1000 + env_flags["max_num_steps"] = int(max_num_steps) + env_flags["seedspath"] = "" + return OmegaConf.create(env_flags) + + +def run_env(flags): + np.random.seed() # Get new random seed in forked process. + polybeast_env.main(flags) + + +def symlink_latest(savedir, symlink): + try: + if os.path.islink(symlink): + os.remove(symlink) + if not os.path.exists(symlink): + os.symlink(savedir, symlink) + logging.info("Symlinked log directory: %s" % symlink) + except OSError: + # os.remove() or os.symlink() raced. Don't do anything. + pass + + +@hydra.main(config_name="config") +def main(flags: DictConfig): + if os.path.exists("config.yaml"): + # this ignores the local config.yaml and replaces it completely with saved one + logging.info("loading existing configuration, we're continuing a previous run") + new_flags = OmegaConf.load("config.yaml") + cli_conf = OmegaConf.from_cli() + # however, you can override parameters from the cli still + # this is useful e.g. if you did total_steps=N before and want to increase it + flags = OmegaConf.merge(new_flags, cli_conf) + + logging.info(flags.pretty(resolve=True)) + OmegaConf.save(flags, "config.yaml") + + flags = get_common_flags(flags) + + # set flags for polybeast_env + env_flags = get_environment_flags(flags) + env_processes = [] + for _ in range(1): + p = mp.Process(target=run_env, args=(env_flags,)) + p.start() + env_processes.append(p) + + symlink_latest( + flags.savedir, os.path.join(hydra.utils.get_original_cwd(), "latest") + ) + + lrn_flags = get_learner_flags(flags) + run_learner(lrn_flags) + + for p in env_processes: + p.kill() + p.join() + + +if __name__ == "__main__": + main() diff --git a/nle/agent/scripts/hiplot.py b/nle/agent/scripts/hiplot.py new file mode 100644 index 000000000..a579202f3 --- /dev/null +++ b/nle/agent/scripts/hiplot.py @@ -0,0 +1,176 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +To use: +1) forward a port to 5005 + +2) start hiplot + +`python -m hackrl.scripts.hiplot` + +3) open hiplot in the browser at localhost:5000 and enter the {sweep_path} with globbing + +e.g. +`/home/user/outputs/2020-07-24/07-04-02/*/logs.csv` + +This collects logs via a function imported from the gnuplot plot script. + +""" + +import copy + +import hiplot as hip + +from omegaconf import OmegaConf, DictConfig +from pathlib import Path +from threading import Timer +from nle.agent.scripts.plot import collect_logs + + +# Default hiplot server. +HIPLOT_SERVER_URL = "http://127.0.0.1:5005/" + + +def flatten(cfg): + """Collapse configurations -- {"foo": {"bar": 0}} -> {"foo.bar": 0}""" + flat = False + while not flat: + flat = True + new_cfg = {} + for key, val in cfg.items(): + if isinstance(val, DictConfig) or isinstance(val, dict): + flat = False + for subkey, subval in val.items(): + newkey = key + "." + subkey + new_cfg[newkey] = subval + else: + new_cfg[key] = val + cfg = new_cfg + return new_cfg + + +def fetcher(uri): + """Prepare param sweep output for hiplot + Collects the sweep results and simplifies them for easy display using hiplot. + :param uri: root dir that containing all the param_sweeping results. + :returns: hiplot Experiment Object for display + """ + + print("got request for %s, collecting logs" % uri) + + exp = hip.Experiment() + exp.display_data(hip.Displays.XY).update( + {"axis_x": "step", "axis_y": "cumulative_reward"} + ) + + dfs = collect_logs(Path(uri)) # list of (name, log, df) triplets + cfg_variants = {} + cfgs = {} + for name, _dfs in dfs: + # first collect each config + print("loading config from %s" % name) + target = Path(name) + configpath = target / "config.yaml" + cfg = flatten(OmegaConf.load(str(configpath))) + cfgs[name] = cfg + for k, v in cfg.items(): + if k not in cfg_variants: + cfg_variants[k] = set() + cfg_variants[k].add(v) + + print("Read in %d logs successfully" % len(cfgs)) + + order = [] + order.append("mean_final_reward") + # cfg_variants are hyperparams with more than one value + for key, vals in cfg_variants.items(): + if len(vals) > 1: + order.append(key) + order.append("cumulative_reward") + print("headers found to plot: ", order) + exp.display_data(hip.Displays.PARALLEL_PLOT).update( + hide=["step", "uid", "from_uid"], order=order + ) + + # min_points = min(len(df["step"]) for _name, df in dfs) + # max_points = max(len(df["step"]) for _name, df in dfs) + ave_points = sum(len(df["step"]) for _name, df in dfs) // len(dfs) + step_size = ave_points // 100 + 1 # I want an average of 100 points per experiment + print("ave_points:", ave_points, "step_size:", step_size) + + for name, df in dfs: + # now go through each dataframe + cfg = cfgs[name] + + hyperparams = dict() + for key, val in cfg.items(): + if len(cfg_variants[key]) > 1: + try: + hyperparams[key] = float(val) + except ValueError: + hyperparams[key] = str(val) + + steps = df["step"] + prev_name = None + cum_sum = df["mean_episode_return"].cumsum() + + for idx in range(0, len(cum_sum), step_size): + step = int(steps[idx]) + cumulative_reward = cum_sum[idx] + curr_name = "{},step{}".format(name, step) + sp = hip.Datapoint( + uid=curr_name, + values=dict(step=step, cumulative_reward=cumulative_reward), + ) + if prev_name is not None: + sp.from_uid = prev_name + exp.datapoints.append(sp) + prev_name = curr_name + + mean_final_reward = float(df["mean_episode_return"][-10000:].mean()) + peak_performance = float( + df["mean_episode_return"].rolling(window=1000).mean().max() + ) + end_vals = copy.deepcopy(hyperparams) + end_vals.update( + step=int(steps.iloc[-1]), + cumulative_reward=cum_sum.iloc[-1], + mean_final_reward=mean_final_reward, + peak_performance=peak_performance, + ) + dp = hip.Datapoint(uid=name, from_uid=prev_name, values=end_vals) + exp.datapoints.append(dp) + + return exp + + +def open_browser(): + import webbrowser + + webbrowser.open(HIPLOT_SERVER_URL, new=2, autoraise=True) + + +def main(): + # By running the following command, a hiplot server will be rendered to display + # your experiment results using the udf fetcher passed to hiplot. + try: + Timer(1, open_browser).start() + except Exception as e: + print("Fail to open browser", e) + hip.server.run_server(fetchers=[fetcher]) + + +if __name__ == "__main__": + main() diff --git a/nle/agent/scripts/plot.py b/nle/agent/scripts/plot.py new file mode 100644 index 000000000..24df30cd0 --- /dev/null +++ b/nle/agent/scripts/plot.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script for plotting results from an NLE agent's logs.csv file. + +Examples of using the plotting tool. + +Plot the most recent run (symlinked at ~/torchbeast/latest by default). +``` +python -m nle.scripts.plot +``` + +Plot a specific run using a rolling window of size 10 (if window > 1, shows error bars). +``` +python -m nle.scripts.plot path/to/run/logs.csv --window 10 +``` + +Plot a specific run to a specific window size. (PATH/logs.csv is found automatically) +``` +python -m nle.scripts.plot path/to/run -x 100 -y 50 +``` + +Plot all runs under a specific directory without a legend matching plots to runs. +``` +python -m nle.scripts.plot path/to/multi_runs --no_legend +``` + +Plot all runs matching a directory prefix, zooming in on a specific prefix. +Note that negative ranges need a little help on the command line. +``` +python -m nle.scripts.plot path/to/multi_runs/2020-05 --xrange 0,1e8 --yrange='-10,80' +``` +""" + +import argparse +import glob +import gnuplotlib as gp +import numpy as np +import pandas as pd +import random + +from pathlib import Path + + +def str_to_float_pair(s): + """ + Convert string to pair of floats. + """ + if s is None: + return None + split = s.split(",") + if len(split) != 2: + raise RuntimeError("range does not match pattern 'float,float'") + return (float(split[0]), float(split[1])) + + +parser = argparse.ArgumentParser("NetHack GnuPlotter", allow_abbrev=False) +parser.register("type", "pair", str_to_float_pair) +parser.add_argument( + "-f", + "--file", + type=str, + default="~/torchbeast/latest/logs.csv", + help="file to plot or directory to look for log files", +) +parser.add_argument( + "-k", "--key", type=str, default="mean_episode_return", help="logged value to plot" +) +parser.add_argument( + "-w", "--window", type=int, default=-1, help="override automatic window size." +) +parser.add_argument("-x", "--width", type=int, default=80, help="width of plot") +parser.add_argument("-y", "--height", type=int, default=30, help="height of plot") +parser.add_argument( + "--no_legend", + action="store_true", + help="skip printing legend when plotting multiple experiments", +) +parser.add_argument( + "--xrange", + type="pair", + default=None, + help="float,float. range of x values to plot. overrides automatic zoom for x axis.", +) +parser.add_argument( + "--yrange", + type="pair", + default=None, + help="float,float. range of y values to plot. overrides automatic zoom for y axis.", +) +parser.add_argument( + "--shuffle", + action="store_true", + help="shuffles the order of plotting if rendering multiple curves.", +) + + +def plot_single_ascii( + target, + width, + height, + key="mean_episode_return", + window=-1, + xrange=None, + yrange=None, +): + """ + Plot the target file using the specified width and height. + If window > 0, use it to specify the window size for rolling averages. + xrange and yrange are used to specify the zoom level of the plot. + """ + print("plotting %s" % str(target)) + df = pd.read_csv(target, sep=",", engine="python") + steps = np.array(df["step"]) + + if window < 0: + window = len(steps) // width + 1 + window = df[key].rolling(window=window, min_periods=0) + returns = np.array(window.mean()) + stderrs = np.array(window.std()) + + plot_options = {} + plot_options["with"] = "yerrorbars" + plot_options["terminal"] = "dumb %d %d ansi" % (width, height) + plot_options["tuplesize"] = 3 + plot_options["title"] = key + plot_options["xlabel"] = "steps" + + if xrange is not None: + plot_options["xrange"] = xrange + + if yrange is not None: + plot_options["yrange"] = yrange + + gp.plot(steps, returns, stderrs, **plot_options) + + +def collect_logs(target): + """ + Collect results from log files at the target directory. + Can be fully specified or a partial match, for example: + full: /checkpoint/me/outputs/2020-05-12/00-02-13/ + part: /checkpoint/me/outputs/2020-05-12/00 + """ + dfs = [] + for child in sorted(glob.iglob(str(target))): + child = Path(child) + try: + df = pd.read_csv(child, sep=",") + # TODO: remove rows with nan? maybe bad as will damage rolling window + # df[df[key] == df[key]].copy() + if len(df) > 0: + name = str(child.parent) + dfs.append((name, df)) + except pd.errors.EmptyDataError: + print("Found no data in %s" % str(child)) + except pd.errors.ParserError: + print("Error reading file %s" % str(child)) + + if len(dfs) == 0: + # didn't find any valid csv logs + raise FileNotFoundError("No logs found under %s" % target) + + return dfs + + +def plot_multiple_ascii( + target, + width, + height, + key="mean_episode_return", + window=-1, + xrange=None, + yrange=None, + no_legend=False, + shuffle=False, +): + """ + Plot files under the target path using the specified width and height. + If window > 0, use it to specify the window size for rolling averages. + xrange and yrange are used to specify the zoom level of the plot. + Set no_legend to true to save the visual space for the plot. + shuffle randomizes the order of the plot (does NOT preserve auto-assigned curve + labels), which can help to see a curve which otherwise is overwritten. + """ + dfs = collect_logs(target) + + if window < 0: + max_size = max(len(df["step"]) for _name, df in dfs) + window = 2 * max_size // width + 1 + + datasets = [] + for name, df in dfs: + steps = np.array(df["step"]) + if window > 1: + roll = df[key].rolling(window=window, min_periods=0) + try: + rewards = np.array(roll.mean()) + except pd.core.base.DataError: + print("Error reading file at %s" % name) + continue + else: + rewards = np.array(df[key]) + if no_legend: + datasets.append((steps, rewards)) + else: + datasets.append((steps, rewards, dict(legend=" " + name + ":"))) + + errs = len(dfs) - len(datasets) + if errs > 0: + print( + "Skipped %d runs (%f) due to errors reading data" % (errs, errs / len(dfs)) + ) + + if len(dfs) == 1: + print("Plotting only one found run: %s" % dfs[0][0]) + plot_single_ascii( + dfs[0][0] + "/logs.csv", width, height, key, window, xrange, yrange + ) + return + + print( + "Plotting %d runs with window_size %d from %s" % (len(datasets), window, target) + ) + + plot_options = {} + plot_options["terminal"] = "dumb %d %d ansi" % (width, height) + plot_options["tuplesize"] = 2 + plot_options["title"] = key + plot_options["xlabel"] = "steps" + plot_options["set"] = "key outside below" + + if xrange is not None: + plot_options["xrange"] = xrange + + if yrange is not None: + plot_options["yrange"] = yrange + + if shuffle: + random.shuffle(datasets) + gp.plot(*datasets, **plot_options) + + +def plot(flags): + target = Path(flags.file).expanduser() + + if target.is_file(): + # plot single torchbeast run, path/to/logs.csv + if target.suffix == ".csv": + plot_single_ascii( + target, + flags.width, + flags.height, + flags.key, + flags.window, + flags.xrange, + flags.yrange, + ) + else: + raise RuntimeError( + "Filetype not recognised (expected .csv): %s" % target.suffix + ) + elif (target / "logs.csv").is_file(): + # next check if this is actually a single run directory with file "logs.csv" + plot_single_ascii( + target / "logs.csv", + flags.width, + flags.height, + flags.key, + flags.window, + flags.xrange, + flags.yrange, + ) + else: + # look for runs underneath the specified directory + plot_multiple_ascii( + target, + flags.width, + flags.height, + flags.key, + flags.window, + flags.xrange, + flags.yrange, + flags.no_legend, + flags.shuffle, + ) + + +if __name__ == "__main__": + flags = parser.parse_args() + plot(flags) diff --git a/nle/agent/util/__init__.py b/nle/agent/util/__init__.py new file mode 100644 index 000000000..8daf2005d --- /dev/null +++ b/nle/agent/util/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nle/agent/util/id_pairs.py b/nle/agent/util/id_pairs.py new file mode 100644 index 000000000..30352401f --- /dev/null +++ b/nle/agent/util/id_pairs.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + +import numpy as np + +from nle.nethack import * # noqa: F403 + +# flake8: noqa: F405 + +# TODO: import this from NLE again +NUM_OBJECTS = 453 +MAXEXPCHARS = 9 + + +class GlyphGroup(enum.IntEnum): + # See display.h in NetHack. + MON = 0 + PET = 1 + INVIS = 2 + DETECT = 3 + BODY = 4 + RIDDEN = 5 + OBJ = 6 + CMAP = 7 + EXPLODE = 8 + ZAP = 9 + SWALLOW = 10 + WARNING = 11 + STATUE = 12 + + +def id_pairs_table(): + """Returns a lookup table for glyph -> NLE id pairs.""" + table = np.zeros([MAX_GLYPH, 2], dtype=np.int16) + + num_nle_ids = 0 + + for glyph in range(GLYPH_MON_OFF, GLYPH_PET_OFF): + table[glyph] = (glyph, GlyphGroup.MON) + num_nle_ids += 1 + + for glyph in range(GLYPH_PET_OFF, GLYPH_INVIS_OFF): + table[glyph] = (glyph - GLYPH_PET_OFF, GlyphGroup.PET) + + for glyph in range(GLYPH_INVIS_OFF, GLYPH_DETECT_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.INVIS) + num_nle_ids += 1 + + for glyph in range(GLYPH_DETECT_OFF, GLYPH_BODY_OFF): + table[glyph] = (glyph - GLYPH_DETECT_OFF, GlyphGroup.DETECT) + + for glyph in range(GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF): + table[glyph] = (glyph - GLYPH_BODY_OFF, GlyphGroup.BODY) + + for glyph in range(GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF): + table[glyph] = (glyph - GLYPH_RIDDEN_OFF, GlyphGroup.RIDDEN) + + for glyph in range(GLYPH_OBJ_OFF, GLYPH_CMAP_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.OBJ) + num_nle_ids += 1 + + for glyph in range(GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.CMAP) + num_nle_ids += 1 + + for glyph in range(GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF): + id_ = num_nle_ids + (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS + table[glyph] = (id_, GlyphGroup.EXPLODE) + + num_nle_ids += EXPL_MAX + + for glyph in range(GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF): + id_ = num_nle_ids + (glyph - GLYPH_ZAP_OFF) // 4 + table[glyph] = (id_, GlyphGroup.ZAP) + + num_nle_ids += NUM_ZAP + + for glyph in range(GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.SWALLOW) + num_nle_ids += 1 + + for glyph in range(GLYPH_WARNING_OFF, GLYPH_STATUE_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.WARNING) + num_nle_ids += 1 + + for glyph in range(GLYPH_STATUE_OFF, MAX_GLYPH): + table[glyph] = (glyph - GLYPH_STATUE_OFF, GlyphGroup.STATUE) + + return table + + +def id_pairs_func(glyph): + result = glyph_to_mon(glyph) + if result != NO_GLYPH: + return result + if glyph_is_invisible(glyph): + return NUMMONS + if glyph_is_body(glyph): + return glyph - GLYPH_BODY_OFF + + offset = NUMMONS + 1 + + # CORPSE handled by glyph_is_body; STATUE handled by glyph_to_mon. + result = glyph_to_obj(glyph) + if result != NO_GLYPH: + return result + offset + offset += NUM_OBJECTS + + # I don't understand glyph_to_cmap and/or the GLYPH_EXPLODE_OFF definition + # with MAXPCHARS - MAXEXPCHARS. + if GLYPH_CMAP_OFF <= glyph < GLYPH_EXPLODE_OFF: + return glyph - GLYPH_CMAP_OFF + offset + offset += MAXPCHARS - MAXEXPCHARS + + if GLYPH_EXPLODE_OFF <= glyph < GLYPH_ZAP_OFF: + return (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS + offset + offset += EXPL_MAX + + if GLYPH_ZAP_OFF <= glyph < GLYPH_SWALLOW_OFF: + return ((glyph - GLYPH_ZAP_OFF) >> 2) + offset + offset += NUM_ZAP + + if GLYPH_SWALLOW_OFF <= glyph < GLYPH_WARNING_OFF: + return offset + offset += 1 + + result = glyph_to_warning(glyph) + if result != NO_GLYPH: + return result + offset