Description
Describe the bug
First off, thanks for the amazing codebase! I'm a big fan of the pure Jax approach.
I think most places in the code that use bootstrapping don't handle episode truncation properly. For example, in ff_mz.py we have:
r_t = sequence.reward[:, :-1]
d_t = 1.0 - sequence.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
d_t = d_t[:, :-1]
search_values = sequence.search_value[:, 1:]
value_targets = batch_n_step_bootstrapped_returns(
r_t, d_t, search_values, config.system.n_steps
)
Note that sequence.done is set to timestep.last(), which in turn checks whether the step type is TRUNCATED or TERMINATED.
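To make the distinction concrete, here is a minimal sketch of how a truncated step differs from a terminated one. The `StepType` enum and the `last` and `env_discount` helpers are hypothetical stand-ins, not the library's actual classes. The sketch assumes the common convention that only a true termination sets the environment discount to zero.

```python
from enum import Enum


class StepType(Enum):
    # Hypothetical stand-in for the environment's step types.
    FIRST = 0
    MID = 1
    TERMINATED = 2
    TRUNCATED = 3


def last(step_type: StepType) -> bool:
    # Both ways of ending an episode count as "last", as described above.
    return step_type in (StepType.TERMINATED, StepType.TRUNCATED)


def env_discount(step_type: StepType) -> float:
    # Assumed convention: only a true termination removes future value.
    return 0.0 if step_type is StepType.TERMINATED else 1.0


for st in (StepType.MID, StepType.TRUNCATED, StepType.TERMINATED):
    print(st.name, "done =", last(st), "discount =", env_discount(st))
```

Under these assumptions, `done` alone cannot distinguish the two cases: it is true for both, while the discount is zero only for TERMINATED.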
Let's consider what happens for a sequence of three timesteps (for a single env) corresponding to StepTypes (MID, TRUNCATED, FIRST). The arguments to batch_n_step_bootstrapped_returns might look like the following:
r_t = jnp.asarray([[1.0, 1.0, 0.0]])
discount_t = d_t = jnp.asarray([[0.99, 0.0, 0.99]])
v_t = search_values = jnp.asarray([[6.0, 5.0, 10.0]])
Let's say n_steps = 1. Then the value target of the "MID" timestep (the second entry) should bootstrap from the value estimate: 1.0 + 5.0 = 6.0 ignoring the discount factor, or 1.0 + 0.99 * 5.0 = 5.95 with gamma = 0.99. But since we pass a discount of zero here, the value estimate is dropped and the target ends up being just 1.0.
To Reproduce
It's pretty deep in the code, so I haven't written a unit test yet. But plugging the above values into batch_n_step_bootstrapped_returns with n_steps = 1, I get Array([[6.94, 1. , 9.9 ]], dtype=float32).
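The reported numbers can be checked by hand without the library. The sketch below uses a hypothetical helper, `one_step_targets`, that models only the n_steps = 1 case, assuming each target is simply `r + discount * v`. It is not the library function itself.

```python
def one_step_targets(r_t, discount_t, v_t):
    # One-step bootstrapped target per timestep: r + discount * v.
    return [r + d * v for r, d, v in zip(r_t, discount_t, v_t)]


# Values from the example: a single env with done flags (0, 1, 0) and gamma = 0.99.
r_t = [1.0, 1.0, 0.0]
discount_t = [0.99, 0.0, 0.99]
v_t = [6.0, 5.0, 10.0]

targets = one_step_targets(r_t, discount_t, v_t)
print([round(t, 2) for t in targets])  # [6.94, 1.0, 9.9]
```

The second entry is 1.0 because the zero discount removes the bootstrap term `0.99 * 5.0` entirely.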
Expected behavior
In the above test case, the second element of the result (the value target for the MID state) should be 6.0 as described above (5.95 once gamma = 0.99 is applied), not 1.0.
Context (Environment)
I'm on commit fe9de0a, running on macOS 15.6.1. I don't think the full context is relevant for this issue.
Additional context
None
Possible Solution
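In the ExItTransition, we should include the original discount from the environment, so that a truncation can be told apart from a termination. Then we can pass d_t as the lambda_t parameter, which should stop returns from accumulating across the episode boundary while still bootstrapping from the value estimate at a truncated step.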
r_t = sequence.reward[:, :-1]
d_t = 1.0 - sequence.done[:, :-1].astype(jnp.float32)
discount_t = sequence.discount[:, :-1] * config.system.gamma
search_values = sequence.search_value[:, 1:]
value_targets = batch_n_step_bootstrapped_returns(
r_t, discount_t, search_values, config.system.n_steps, d_t
)
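To illustrate the effect of the proposed change, here is a plain-Python sketch of a single-sequence n-step bootstrapped return with a per-step `lambda_t`. The helper `n_step_returns` is hypothetical. It assumes the recursion `targets = r + discount * ((1 - lambda) * v + lambda * targets)`, as used by rlax's `n_step_bootstrapped_returns`; whether `batch_n_step_bootstrapped_returns` behaves identically is an assumption here, as is an environment discount of 1.0 at the truncated step.

```python
def n_step_returns(r_t, discount_t, v_t, n, lambda_t=None):
    """Single-sequence n-step bootstrapped returns (hypothetical helper)."""
    T = len(r_t)
    if lambda_t is None:
        lambda_t = [1.0] * T
    # Pad each sequence so every index has n steps to look ahead.
    r = list(r_t) + [0.0] * (n - 1)
    d = list(discount_t) + [1.0] * (n - 1)
    lam = list(lambda_t) + [1.0] * (n - 1)
    v = list(v_t) + [v_t[-1]] * (n - 1)
    # Start from the value n steps ahead, repeating the last value at the end.
    targets = list(v_t[n - 1:]) + [v_t[-1]] * min(n - 1, T)
    # Work backwards, mixing the bootstrap value and the running return.
    for i in reversed(range(n)):
        targets = [
            r[i + k]
            + d[i + k] * ((1.0 - lam[i + k]) * v[i + k] + lam[i + k] * targets[k])
            for k in range(T)
        ]
    return targets


gamma = 0.99
r_t = [1.0, 1.0, 0.0]
v_t = [6.0, 5.0, 10.0]
done = [0.0, 1.0, 0.0]

# Current behaviour: done is folded into the discount.
old = n_step_returns(r_t, [(1.0 - x) * gamma for x in done], v_t, n=1)

# Proposed behaviour: the env discount stays 1.0 at truncation (assumed),
# and 1 - done is passed as lambda_t instead.
env_discount = [1.0, 1.0, 1.0]
new = n_step_returns(
    r_t, [x * gamma for x in env_discount], v_t, n=1,
    lambda_t=[1.0 - x for x in done],
)
print([round(t, 2) for t in old])  # [6.94, 1.0, 9.9]
print([round(t, 2) for t in new])  # [6.94, 5.95, 9.9]
```

With n = 1 the only difference is the truncated step, whose target now includes the bootstrap term. For larger n, a `lambda_t` of zero at that step makes the return stop at the value estimate instead of rolling into the next episode's rewards.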