[BUG] Truncation not handled properly in n-step bootstrapped returns #181

@adzcai

Describe the bug

First off, thanks for the amazing codebase! I'm a big fan of the pure Jax approach.

I think most places in the code that use bootstrapping don't handle episode truncation properly. For example, in ff_mz.py we have:

r_t = sequence.reward[:, :-1]
d_t = 1.0 - sequence.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
d_t = d_t[:, :-1]
search_values = sequence.search_value[:, 1:]
value_targets = batch_n_step_bootstrapped_returns(
    r_t, d_t, search_values, config.system.n_steps
)

Note that sequence.done is set to timestep.last(), which returns true when the step type is either TRUNCATED or TERMINATED.
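As a small illustration of why this matters, here's a sketch contrasting the discount the snippet above derives from done with one taken from the environment. The integer step-type codes are hypothetical, and we assume, as in dm_env-style APIs, that the environment reports a discount of 1.0 on truncation and 0.0 on termination.

```python
import jax.numpy as jnp

# Hypothetical integer codes for the four step types (the real enum may differ).
FIRST, MID, TRUNCATED, TERMINATED = 0, 1, 2, 3
gamma = 0.99

step_types = jnp.asarray([MID, TRUNCATED, TERMINATED])

# Mirrors the described behaviour of timestep.last(): true for TRUNCATED or TERMINATED.
done = (step_types == TRUNCATED) | (step_types == TERMINATED)

# Current approach: discount derived from the done flag, so truncation zeroes it too.
discount_from_done = gamma * (1.0 - done.astype(jnp.float32))

# Assumed environment discount (dm_env-style): 0.0 only on true termination.
env_discount = jnp.where(step_types == TERMINATED, 0.0, 1.0)
discount_from_env = gamma * env_discount

print(discount_from_done)  # the truncated step gets 0.0
print(discount_from_env)   # the truncated step keeps gamma
```

Only the second vector distinguishes "the episode really ended" from "we stopped collecting", which is the information the bootstrap needs.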

Consider a sequence of three timesteps (for a single environment) with step types (MID, TRUNCATED, FIRST). The arguments to batch_n_step_bootstrapped_returns might then look like this:

import jax.numpy as jnp
r_t = jnp.asarray([[1.0, 1.0, 0.0]])
discount_t = d_t = jnp.asarray([[0.99, 0.0, 0.99]])
v_t = search_values = jnp.asarray([[6.0, 5.0, 10.0]])

Say n_steps = 1. Because the episode was truncated rather than terminated, the target for the "MID" timestep should still bootstrap from the next value estimate: 1.0 + 0.99 * 5.0 = 5.95 (roughly 6.0). But since we pass a discount of zero, the value estimate is dropped and the target ends up being just 1.0.

To Reproduce

It's pretty deep in the code, so I haven't gotten around to writing a unit test. But if you plug the values above into batch_n_step_bootstrapped_returns, I indeed get Array([[6.94, 1. , 9.9 ]], dtype=float32).

Expected behavior

In the test case above, the second element of the output (the value target for the MID state) should be about 6.0 (1.0 + 0.99 * 5.0 = 5.95), not 1.0.

Context (Environment)

I'm on commit fe9de0a, running on macOS 15.6.1. I don't think the rest of the environment is relevant to this issue.

Additional context

None

Possible Solution

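ExItTransition should also store the environment's original discount. We can then scale that by gamma to get discount_t and pass d_t (1 - done) as the lambda_t parameter, so that a truncation stops the n-step return at the episode boundary while still bootstrapping from the value there.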

r_t = sequence.reward[:, :-1]
d_t = 1.0 - sequence.done[:, :-1].astype(jnp.float32)
discount_t = sequence.discount[:, :-1] * config.system.gamma
search_values = sequence.search_value[:, 1:]
value_targets = batch_n_step_bootstrapped_returns(
    r_t, discount_t, search_values, config.system.n_steps, d_t
)
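To sanity-check the proposal, here's a sketch that runs the example above through a hypothetical rlax-style re-implementation (an assumption standing in for Stoix's function). env_discount is a hypothetical stand-in for the proposed sequence.discount field, assumed to be 1.0 on truncation, and 1 - done is passed as lambda_t.

```python
import jax.numpy as jnp


def n_step_bootstrapped_returns(r_t, discount_t, v_t, n, lambda_t=1.0):
    """Hypothetical rlax-style n-step returns for [B, T] inputs (assumes n >= 1)."""
    T = r_t.shape[-1]
    lambda_t = jnp.ones_like(discount_t) * lambda_t
    pad = ((0, 0), (0, n - 1))  # extend only the time axis by n - 1 steps
    r_p = jnp.pad(r_t, pad)  # zero reward past the end
    d_p = jnp.pad(discount_t, pad, constant_values=1.0)
    l_p = jnp.pad(lambda_t, pad, constant_values=1.0)
    v_p = jnp.pad(v_t, pad, mode="edge")  # repeat the last value
    targets = v_p[:, n - 1 : n - 1 + T]  # values n - 1 steps ahead
    for i in reversed(range(n)):
        sl = slice(i, i + T)
        # Where lambda_t == 0 the target bootstraps from v_t instead of
        # continuing to accumulate later rewards.
        targets = r_p[:, sl] + d_p[:, sl] * (
            (1.0 - l_p[:, sl]) * v_p[:, sl] + l_p[:, sl] * targets
        )
    return targets


gamma = 0.99
r_t = jnp.asarray([[1.0, 1.0, 0.0]])
v_t = jnp.asarray([[6.0, 5.0, 10.0]])
done = jnp.asarray([[0.0, 1.0, 0.0]])  # 1.0 where the episode ended (here: truncated)
env_discount = jnp.asarray([[1.0, 1.0, 1.0]])  # hypothetical stand-in for sequence.discount

discount_t = env_discount * gamma
lambda_t = 1.0 - done  # plays the role of d_t in the proposal

fixed_1 = n_step_bootstrapped_returns(r_t, discount_t, v_t, 1, lambda_t)
fixed_2 = n_step_bootstrapped_returns(r_t, discount_t, v_t, 2, lambda_t)
leaky_2 = n_step_bootstrapped_returns(r_t, discount_t, v_t, 2)  # lambda_t = 1 everywhere
```

With n = 1 the second target becomes 1.0 + 0.99 * 5.0 = 5.95. The lambda_t part only matters for n > 1: without it (leaky_2), the second target keeps accumulating past the truncation into the next episode's value and comes out as 10.801 instead of 5.95.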

Labels: bug (Something isn't working)
