-
I saw this pattern in the imagenet example. Can anyone help me understand why we are able to use an `if` statement inside a `pmap`-ed function?

```python
@jax.pmap
def train_step(state, ...):
    if state.dynamic_scale:
        ...

train_step(state, ...)
```

https://github.com/google/flax/blob/main/examples/imagenet/train.py#L128

I really appreciate your help!
YiYi
Replies: 2 comments
-
-
Ohhh, thank you! I got it now. It is not actually a ...
-
You can use Python control logic inside a `jit`/`pmap`, as long as that control flow is not value dependent, i.e. the condition does not depend on the runtime values of traced arrays. In this case, the `if` statement is about the value of `state.dynamic_scale`, which can be `None` if no dynamic scaling is used, or an instance of `dynamic_scale_lib.DynamicScale` if dynamic scaling is used. In either case, the value is set from the config (see lines 227-230 in the code snippet below) and is not changed during training, so the branch is resolved once in Python while the function is traced. In other words, `jit`/`pmap` will produce a single compiled version of the code.

flax/examples/imagenet/train.py, lines 222 to 244 in 6fd2446
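A minimal sketch of the distinction, assuming a recent JAX release (it uses `jax.Array` and `jax.errors.ConcretizationTypeError`). `State`, `scale`, `traces`, `step`, `bad_step` and `good_step` are hypothetical names for illustration, not the Flax API; `scale` only stands in for `dynamic_scale`. It uses `jax.jit` so it runs on a single CPU device; `jax.pmap` also traces the Python function, so the same reasoning should carry over.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class State(NamedTuple):
    # Hypothetical stand-in for a train state: `scale` plays the role of
    # `dynamic_scale` and is either None or a scalar array.
    params: jax.Array
    scale: Optional[jax.Array]


# Records one entry per trace of `step` (a list, so it is mutated in place).
traces: list = []


@jax.jit
def step(state: State) -> jax.Array:
    # Python side effect: executes only while JAX traces the function,
    # not on calls that hit the compilation cache.
    traces.append(state.scale is None)
    # Whether `scale` is None is part of the pytree *structure*, which jit
    # treats as static, so this is ordinary Python control flow.
    if state.scale is not None:
        return state.params * state.scale
    return state.params


@jax.jit
def bad_step(x: jax.Array) -> jax.Array:
    # Value-dependent branch: `x > 0` is a tracer, so Python cannot turn it
    # into a bool at trace time and JAX raises a ConcretizationTypeError.
    if x > 0:
        return x
    return -x


@jax.jit
def good_step(x: jax.Array) -> jax.Array:
    # A value-dependent choice has to be expressed inside the traced graph.
    return jnp.where(x > 0, x, -x)


if __name__ == "__main__":
    p = jnp.ones(3)
    step(State(p, None))
    step(State(p, None))              # same structure: reuses the cache
    step(State(p, jnp.float32(2.0)))  # new structure (array, not None): retraced
    step(State(p, jnp.float32(3.0)))  # same structure, shape and dtype: cached
    print(len(traces))
```

Calling `step` twice with `scale=None` should trace it only once; switching `scale` from `None` to an array changes the pytree structure and triggers one more trace, after which new scale values of the same shape and dtype reuse that compilation. `bad_step` fails because its condition depends on a traced value; `jnp.where` (or `jax.lax.cond`) is the usual way to express such a data-dependent branch.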