Open
Labels
enhancement (New feature or request)
Description
Feature Summary
Allow passing the `forward_mode_differentiation` argument to autoguides (either at initialization or at call time) to control how they call `initialize_model`.
Why is this needed?
I am using a JAX routine in my model (`linalg.eigh`) that is not robust under reverse-mode differentiation: in that mode it often produces `nan` values, for reasons I have not pinned down. I therefore need to use forward-mode differentiation. However, if I use such a model with an autoguide (in my case `AutoNormal`, but I think this applies to all of them), the guide calls `initialize_model` without passing the `forward_mode_differentiation` argument. It then falls back to the default, reverse-mode differentiation, which fails to produce valid results.
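To check whether a particular input is affected, one can compare both modes on the same function. Below is a minimal sketch; `smallest_eigval` and `compare_modes` are hypothetical helpers standing in for the `eigh`-based part of a model. At the well-conditioned input used here both modes should agree. One commonly cited source of trouble is near-degenerate eigenvalues, since eigenvector derivatives involve terms in 1/(λ_i − λ_j); the sketch does not claim to reproduce the `nan` behaviour described above.

```python
import jax
import jax.numpy as jnp


def smallest_eigval(params):
    # Hypothetical example: params hold the entries of a symmetric 2x2 matrix.
    mat = jnp.array([[params[0], params[1]], [params[1], params[2]]])
    eigvals, _ = jnp.linalg.eigh(mat)
    return eigvals[0]  # eigh returns eigenvalues in ascending order


def compare_modes(fn, x):
    """Return (reverse-mode gradient, forward-mode gradient) of a scalar fn at x."""
    g_rev = jax.grad(fn)(x)    # reverse mode: a single vector-Jacobian product
    g_fwd = jax.jacfwd(fn)(x)  # forward mode: one Jacobian-vector product per input
    return g_rev, g_fwd


x = jnp.array([1.0, 0.5, 2.0])
g_rev, g_fwd = compare_modes(smallest_eigval, x)
print("reverse finite:", bool(jnp.all(jnp.isfinite(g_rev))))
print("forward finite:", bool(jnp.all(jnp.isfinite(g_fwd))))
```

Pointing `compare_modes` at inputs taken from the real model is a quick way to confirm that only the reverse-mode gradient contains `nan`s.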