diff --git a/src/doom/actions.py b/src/doom/actions.py
index 835b4dd..63146e7 100644
--- a/src/doom/actions.py
+++ b/src/doom/actions.py
@@ -88,8 +88,9 @@ def get_action(self, action):
                            for k in self.available_buttons]
             return doom_action
         else:
-            assert type(action) is int
-            return self.doom_actions[action]
+            a = action if type(action) == int else action.item()
+            assert type(a) is int
+            return self.doom_actions[a]
 
 
 action_categories_discrete = {
diff --git a/src/doom/scenarios/deathmatch.py b/src/doom/scenarios/deathmatch.py
index e78fcb5..db3d48a 100644
--- a/src/doom/scenarios/deathmatch.py
+++ b/src/doom/scenarios/deathmatch.py
@@ -173,7 +173,7 @@ def evaluate_deathmatch(game, network, params, n_train_iter=None):
 
         # observe the game state / select the next action
         game.observe_state(params, last_states)
-        action = network.next_action(last_states)
+        action = network.next_action(last_states).tolist()
         pred_features = network.pred_features
 
         # game features
diff --git a/src/model/bucketed_embedding.py b/src/model/bucketed_embedding.py
index f88a4d8..057894f 100644
--- a/src/model/bucketed_embedding.py
+++ b/src/model/bucketed_embedding.py
@@ -1,7 +1,7 @@
-import torch.nn as nn
+import torch
 
 
-class BucketedEmbedding(nn.Embedding):
+class BucketedEmbedding(torch.nn.Embedding):
 
     def __init__(self, bucket_size, num_embeddings, *args, **kwargs):
         self.bucket_size = bucket_size
@@ -9,4 +9,4 @@ def __init__(self, bucket_size, num_embeddings, *args, **kwargs):
         super(BucketedEmbedding, self).__init__(real_num_embeddings, *args, **kwargs)
 
     def forward(self, indices):
-        return super(BucketedEmbedding, self).forward(indices.div(self.bucket_size))
+        return super(BucketedEmbedding, self).forward(indices.div(self.bucket_size).type(torch.LongTensor))
diff --git a/src/model/dqn/base.py b/src/model/dqn/base.py
index 6e1c61c..504d30d 100644
--- a/src/model/dqn/base.py
+++ b/src/model/dqn/base.py
@@ -78,7 +78,12 @@ def base_forward(self, x_screens, x_variables):
 
         # create state input
         if self.n_variables:
-            output = torch.cat([conv_output] + embeddings, 1)
+            if(len(embeddings[0].shape) != 3):
+                embeddings[0] = embeddings[0].unsqueeze(0)
+                embeddings[1] = embeddings[1].unsqueeze(0)
+                output = torch.cat([conv_output.unsqueeze(0)] + embeddings, dim=2)
+            else:
+                output = torch.cat([conv_output.unsqueeze(0)] + embeddings, dim=2)
         else:
             output = conv_output
 
@@ -185,13 +190,14 @@ def prepare_f_train_args(self, screens, variables, features,
         return screens, variables, features, actions, rewards, isfinal
 
     def register_loss(self, loss_history, loss_sc, loss_gf):
-        loss_history['dqn_loss'].append(loss_sc.data[0])
-        loss_history['gf_loss'].append(loss_gf.data[0]
+        loss_history['dqn_loss'].append(loss_sc.data)
+        loss_history['gf_loss'].append(loss_gf.data
                                        if self.n_features else 0)
 
     def next_action(self, last_states, save_graph=False):
         scores, pred_features = self.f_eval(last_states)
         if self.params.network_type == 'dqn_ff':
+            scores = scores.squeeze(0)
             assert scores.size() == (1, self.module.n_actions)
             scores = scores[0]
             if pred_features is not None:
@@ -205,7 +211,7 @@ def next_action(self, last_states, save_graph=False):
             if pred_features is not None:
                 assert pred_features.size() == (1, seq_len, self.module.n_features)
                 pred_features = pred_features[0, -1]
-            action_id = scores.data.max(0)[1][0]
+            action_id = scores.data.max(0)[1]
         self.pred_features = pred_features
         return action_id
 
diff --git a/src/model/dqn/feedforward.py b/src/model/dqn/feedforward.py
index e2af5a6..e9e7acc 100644
--- a/src/model/dqn/feedforward.py
+++ b/src/model/dqn/feedforward.py
@@ -21,10 +21,14 @@ def forward(self, x_screens, x_variables):
         """
         batch_size = x_screens.size(0)
+
+        for x in x_variables:
+            x.unsqueeze_(0)
+
         assert x_screens.ndimension() == 4
         assert len(x_variables) == self.n_variables
-        assert all(x.ndimension() == 1 and x.size(0) == batch_size
-                   for x in x_variables)
+        #assert all(x.ndimension() == 1 and x.size(0) == batch_size
+        #           for x in x_variables)
 
         # state input (screen / depth / labels buffer + variables)
         state_input, output_gf = self.base_forward(x_screens, x_variables)
 
@@ -75,6 +79,9 @@ def f_train(self, screens, variables, features, actions, rewards, isfinal,
 
         # compute scores
         mask = torch.ByteTensor(output_sc1.size()).fill_(0)
+        mask = mask.squeeze(0)
+        output_sc2 = output_sc2.squeeze(0)
+
        for i in range(batch_size):
             mask[i, int(actions[i, -1])] = 1
         scores1 = output_sc1.masked_select(self.get_var(mask))
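
Note: taken together, these hunks port the code to the PyTorch >= 0.4 API, where
reductions and indexing return 0-dim tensors rather than Python scalars: hence
`.data[0]` becoming `.data`, `scores.data.max(0)[1][0]` losing its trailing `[0]`,
and the new `.item()` / `.tolist()` calls that convert a 0-dim tensor back to a
Python int before it reaches `get_action()`. A minimal standalone sketch of that
extraction pattern (made-up score values, not code from this repository), assuming
PyTorch >= 0.4:

    import torch

    scores = torch.tensor([0.1, 0.9, 0.3])

    # max over dim 0 returns (values, indices); indices is a 0-dim LongTensor,
    # so indexing it with [0] raises an error on PyTorch >= 0.4
    action_id = scores.max(0)[1]
    assert action_id.dim() == 0

    # .item() (or .tolist() on a 0-dim tensor) recovers the Python scalar,
    # which is what the type assertion in get_action() expects
    action = action_id.item()
    assert action == 1 and isinstance(action, int)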