Skip to content

Commit

Permalink
bug fix with Policy class when using GPU. Action input to policy_func…
Browse files Browse the repository at this point in the history
… not being pushed to GPU.
  • Loading branch information
joshuaspear committed Sep 6, 2024
1 parent 1b1c6a2 commit a281468
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/offline_rl_ope/components/Policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def __call__(
# assert isinstance(action,torch.Tensor)
# check_array_dim(action,2)
state = self.preproc_tens(state)
action = self.preproc_tens(action)
p_return = self.policy_func(state, action)
actions = self.postproc_tens(p_return.actions)
action_prs = self.postproc_tens(p_return.action_prs)
Expand Down

0 comments on commit a281468

Please sign in to comment.