Skip to content

Commit

Permalink
[microNPU][ETHOSU] Fix Softmax quantization parameters (apache#14774)
Browse files Browse the repository at this point in the history
Fix zero point and scale values for operations according to the values in Vela, the test is updated to check case with different input and output zero point.
  • Loading branch information
Aleksei-grovety authored May 15, 2023
1 parent b6d7ce6 commit 0274930
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
28 changes: 14 additions & 14 deletions python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def callback(
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
pool_shape=(1, depth),
ofm_channels=1,
ofm_dtype=ifm_dtype,
Expand Down Expand Up @@ -141,7 +141,7 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=params.ifm.shape[-1],
ifm2_channels=1,
reversed_operands=False,
Expand All @@ -160,7 +160,7 @@ def callback(
ifm_scale=0.0,
ifm_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
pool_shape=(1, 1),
ofm_channels=1,
upscale="NONE",
Expand All @@ -175,7 +175,7 @@ def callback(
ifm_scale=0.0,
ifm_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ofm_channels=1,
)

Expand All @@ -186,12 +186,12 @@ def callback(
ifm2=headroom_plus_one,
lut=lut,
operator_type="SUB",
ifm_scale=1.0,
ifm_scale=0.0,
ifm_zero_point=0,
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=1.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_scale=0.0,
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
Expand All @@ -210,7 +210,7 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
Expand All @@ -228,7 +228,7 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=depth,
ifm2_channels=1,
reversed_operands=False,
Expand All @@ -250,7 +250,7 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
Expand All @@ -268,7 +268,7 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
Expand Down Expand Up @@ -378,9 +378,9 @@ def callback(
ifm2=half_denominator_times_x,
lut=lut,
operator_type="SUB",
ifm_scale=0.0,
ifm_scale=2.0,
ifm_zero_point=0,
ifm2_scale=2.0,
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=1.0,
ofm_zero_point=0,
Expand Down Expand Up @@ -422,7 +422,7 @@ def callback(
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ofm_zero_point=int(params.ifm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
Expand Down
2 changes: 1 addition & 1 deletion tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_ethosu_softmax(
def softmax(x):
return tf.nn.softmax(x)

infra.compare_tvm_with_tflite(softmax, [ifm_shape], accel_type)
infra.compare_tvm_with_tflite(softmax, [ifm_shape], accel_type, ranges=[(-1, 1)])


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
Expand Down

0 comments on commit 0274930

Please sign in to comment.