Skip to content

Commit

Permalink
all in decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Sep 24, 2024
1 parent de8d158 commit 2bf7694
Showing 1 changed file with 33 additions and 3 deletions.
36 changes: 33 additions & 3 deletions flax/linen/fp8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,40 @@ def q_dot_dq(
)
return y # type: ignore

# This decorator wraps a function to perform quantized dot product and dequantization.
# It prepares the arguments for q_dot_dq.
def q_dot_dq_config(
lhs_scale, rhs_scale, out_grad_scale,
lhs_amax_history, rhs_amax_history, out_grad_amax_history,
compute_dtype
):
def decorator(func):
def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
return q_dot_dq(
lhs,
rhs,
lhs_scale=lhs_scale,
rhs_scale=rhs_scale,
out_grad_scale=out_grad_scale,
lhs_amax_history=lhs_amax_history,
rhs_amax_history=rhs_amax_history,
out_grad_amax_history=out_grad_amax_history,
compute_dtype=compute_dtype,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
)
return wrapper
return decorator

# This decorator wraps a function to perform one-sided quantized dot product and dequantization.
# It prepares the arguments for fp8_ops.one_sided_q_dot_dq, including the pre-quantized input,
# It prepares the arguments for one_sided_q_dot_dq, including the pre-quantized input,
# scales, and amax histories. This allows for efficient FP8 matrix multiplication while
# managing quantization parameters.
def one_sided_q_dot_dq_config(comp_dtype, q_x, input_scale, kernel_scale, out_grad_scale, kernel_amax_history, out_grad_amax_history):
def one_sided_q_dot_dq_config(
compute_dtype, q_x, input_scale, kernel_scale, out_grad_scale,
kernel_amax_history, out_grad_amax_history
):
def decorator(func):
def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
return one_sided_q_dot_dq(
Expand All @@ -406,14 +435,15 @@ def wrapper(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=
out_grad_scale=out_grad_scale,
rhs_amax_history=kernel_amax_history,
out_grad_amax_history=out_grad_amax_history,
compute_dtype=comp_dtype,
compute_dtype=compute_dtype,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type
)
return wrapper
return decorator


class Fp8DotGeneralBase(module.Module):
amax_history_length: int = 1024
e4m3_dtype: DType = jnp.float8_e4m3fn
Expand Down

0 comments on commit 2bf7694

Please sign in to comment.