Commit

Updated to a version of MFA that takes H_Hk_ratio as a parameter.
liuliu committed Jan 4, 2024
1 parent 099e757 commit 02156f7
Showing 4 changed files with 4 additions and 3 deletions.
Binary file modified lib/nnc/mfa/3rdparty/libmfaios16-v1.0.2-a.metallib
Binary file not shown.
Binary file modified lib/nnc/mfa/3rdparty/libmfamacos13-v1.0.2-a.metallib
Binary file not shown.
3 changes: 2 additions & 1 deletion lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -230,7 +230,6 @@ mfa::attention::pipeline::pipeline(mfa::context* context, mfa::attention::hash h
   constants->setConstantValue(&hash.C, MTL::DataTypeUInt, 1);
   constants->setConstantValue(&hash.Hq, MTL::DataTypeUInt, 2);
   constants->setConstantValue(&hash.D, MTL::DataTypeUInt, 3);
-  constants->setConstantValue(&hash.Hk, MTL::DataTypeUInt, 4);
   constants->setConstantValue(&hash.Q_trans, MTL::DataTypeBool, 10);
   constants->setConstantValue(&hash.K_trans, MTL::DataTypeBool, 11);
   constants->setConstantValue(&hash.V_trans, MTL::DataTypeBool, 12);
@@ -248,6 +247,8 @@ mfa::attention::pipeline::pipeline(mfa::context* context, mfa::attention::hash h
   bool backward = false;
   bool generate_block_mask = false;
   bool grouped_query = (hash.Hq != hash.Hk);
+  uint32_t H_Hk_ratio = hash.Hq / hash.Hk;
+  constants->setConstantValue(&H_Hk_ratio, MTL::DataTypeUInt, 4);
   constants->setConstantValue(&block_sparse, MTL::DataTypeBool, 102);
   constants->setConstantValue(&triangular, MTL::DataTypeBool, 103);
   constants->setConstantValue(&forward, MTL::DataTypeBool, 110);
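Note on the change above: function constant index 4 now carries the ratio Hq / Hk instead of Hk itself. The kernel source inside the metallib is not shown in this commit, so the following is only a sketch of how such a ratio is commonly used in grouped-query attention: each group of H_Hk_ratio consecutive query heads shares one key/value head. The helper names head_ratio and kv_head_for_query_head are hypothetical and do not exist in ccv or MFA, and the divisibility check is an assumption added for illustration (the diff performs a plain integer division).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of ccv or MFA.
// Mirrors `hash.Hq / hash.Hk` from the diff. Assumes Hk > 0 and that
// Hk divides Hq evenly; the original code does not check either.
inline uint32_t head_ratio(uint32_t Hq, uint32_t Hk) {
  assert(Hk > 0 && Hq % Hk == 0);
  return Hq / Hk;
}

// Hypothetical helper: assumed mapping from a query head index to the
// key/value head it reads from. With a ratio of 1 this is the identity,
// which corresponds to ordinary multi-head attention (Hq == Hk).
inline uint32_t kv_head_for_query_head(uint32_t q_head, uint32_t H_Hk_ratio) {
  return q_head / H_Hk_ratio;
}
```

Under this assumption, Hq = 8 and Hk = 2 give a ratio of 4, so query heads 0 to 3 read key/value head 0 and query heads 4 to 7 read key/value head 1.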
4 changes: 2 additions & 2 deletions lib/nnc/mfa/libmfa.inc

Large diffs are not rendered by default.
