
Commit

fix quantize with scale type
wenscarl committed Sep 24, 2024
1 parent c9b0fc5 commit 9d0f844
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions flax/linen/fp8_ops.py
@@ -184,16 +184,16 @@ def quantize_and_update(
   new_scale = compute_scale(amax_from_history, scale, dtype_max)
   new_history = compute_amax_history(x, amax_history)
 
-  if not use_direct_quant:
-    qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype)
-
   # convert f32->fmax32 so the autodiff system accumulates fp8 meta correctly
   if is_fmax32:
     new_history = lax.convert_element_type(new_history, fp32_max_grad)
     new_scale = lax.convert_element_type(new_scale, fp32_max_grad)
 
+  # Quantize the input
+  if not use_direct_quant:
+    qx = quantize_dequantize(x, q_dtype, new_scale, compute_dtype, quantize_only=quantize_only)
     return qx, new_scale, new_history
 
   return new_scale, new_history
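
For readers without the surrounding file: the change moves the quantize step to after the f32->fmax32 conversion, so the new_scale passed to quantize_dequantize carries the fp32_max_grad dtype (which, per the comment in the diff, is what lets the autodiff system accumulate the fp8 meta correctly), and it threads through a quantize_only=quantize_only argument that the old call site did not pass. Below is a minimal sketch of what a delayed-scaling quantize-dequantize step does, assuming clip-and-round semantics; it is not the flax implementation, and the helper name and rounding details are illustrative guesses:

# A minimal sketch, not the flax implementation: the helper name and the
# clip/round details are illustrative assumptions.
import jax
import jax.numpy as jnp

def quantize_dequantize_sketch(x, q_dtype, scale, compute_dtype):
  # Scale the input down into the fp8 range, clip to the dtype's finite
  # maximum, round by casting to fp8, then cast back up and undo the scaling.
  dtype_max = float(jnp.finfo(q_dtype).max)
  scaled = x.astype(jnp.float32) / scale
  clipped = jnp.clip(scaled, -dtype_max, dtype_max)
  qx = clipped.astype(q_dtype)  # the lossy step: rounding to fp8
  return (qx.astype(jnp.float32) * scale).astype(compute_dtype)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
y = quantize_dequantize_sketch(x, jnp.float8_e4m3fn, jnp.float32(2.0), jnp.bfloat16)

The cast to q_dtype is where precision is actually lost; the scale, derived from the running amax history, only positions the tensor inside the fp8 dynamic range, which is why the dtype it carries into this call matters for the backward pass.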



0 comments on commit 9d0f844
