
Support forward-mode differentiation with autoguide. #2068

@PhilReinhold

Description


Feature Summary

Allow passing a forward_mode_differentiation argument to autoguides (either at initialization or at call time) to control how they call initialize_model.

Why is this needed?

My model uses a JAX routine (linalg.eigh) that is not robust under reverse-mode differentiation: in this mode it often produces nan values, for reasons I have not pinned down. I therefore need forward-mode differentiation. However, when I use such a model with an autoguide (AutoNormal in my case, but I believe this applies to all autoguides), the autoguide calls initialize_model without passing the forward_mode_differentiation argument. Initialization then falls back to the default reverse-mode differentiation, which fails to produce valid results.
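To make the distinction between the two modes concrete, here is a minimal JAX sketch. It assumes a hypothetical two-parameter objective built on jnp.linalg.eigh. jax.grad differentiates in reverse mode, while jax.jacfwd builds the same gradient from forward-mode JVPs. On this well-conditioned input the two agree. The toy does not reproduce the nan values described above; it only shows which JAX transform each mode corresponds to.

```python
import jax
import jax.numpy as jnp


def objective(params):
    # Hypothetical scalar function standing in for a log density that calls eigh.
    a, b = params
    m = jnp.array([[2.0 + a, b], [b, -1.0]])
    eigvals, _ = jnp.linalg.eigh(m)
    return jnp.sum(eigvals**2)  # equals trace(m @ m) = (2 + a)^2 + 2 b^2 + 1


params = jnp.array([0.3, -0.7])

# Reverse mode: one backward pass (vector-Jacobian product).
grad_rev = jax.grad(objective)(params)

# Forward mode: one JVP per input dimension; for a scalar output the result
# has the same shape as params, i.e. it is the gradient.
grad_fwd = jax.jacfwd(objective)(params)
```

Forward mode costs one JVP per parameter, so it is practical mainly for models with few latent dimensions, which is the usual trade-off when opting into it.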

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
