Skip to content

Commit

Permalink
improved error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuaspear committed May 30, 2024
1 parent 7d04bfa commit 568b6c2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 50 deletions.
48 changes: 0 additions & 48 deletions .github/python-dev.yml

This file was deleted.

18 changes: 17 additions & 1 deletion src/offline_rl_ope/components/ImportanceSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,23 @@ def update(
states=states, actions=actions, eval_policy=eval_policy)
check_array_dim(_is_weights,2)
check_array_dim(_weight_msk,2)
assert len(eval_policy.policy_actions) == len(actions)
len_act = len(actions)
len_eval_act = len(eval_policy.policy_actions)
_msg = f"""
Actions have length: {len_act}.
Evalutaion policy predicted actions have length: {len_eval_act}
"""
try:
assert len_act == len_eval_act, _msg
except Exception as e:
if len_eval_act == 0:
logger.info(
"""
No actions assoicated with evaluation policy.
Has collect_act been set to true in the Policy class?
"""
)
raise e
self.policy_actions = eval_policy.policy_actions
self.is_weights = _is_weights
self.weight_msk = _weight_msk
Expand Down
2 changes: 1 addition & 1 deletion src/offline_rl_ope/components/Policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __call__(
self,
state:StateTensor,
action:ActionTensor
)->ActionTensor:
)->Float[torch.Tensor, "traj_length 1"]:
"""Defines the probability of the given actions under the given states
according to the policy defined by policy_func
Expand Down

0 comments on commit 568b6c2

Please sign in to comment.