Skip to content

Commit

Permalink
Add translated AttentionKernel. Still need to validate the files.
Browse files Browse the repository at this point in the history
liuliu committed Sep 12, 2024
1 parent ca70311 commit 6ac4534
Showing 11 changed files with 3,617 additions and 6 deletions.
2 changes: 1 addition & 1 deletion lib/nnc/mfa/makefile
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ include ../../config.mk

CFLAGS := -std=c++17 -O3 -Wall -I"../../" $(CFLAGS)

SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp
SRCS := Metal.cpp ccv_nnc_mfa.cpp ccv_nnc_mfa_attention.cpp ccv_nnc_mfa_error.cpp ccv_nnc_mfa_gemm.cpp ccv_nnc_mfa_normalization.cpp ccv_nnc_mfa_depalettize.cpp ccv_nnc_mfa_adam.cpp ccv_nnc_mfa_cmul.cpp ccv_nnc_mfa_gemv.cpp ccv_nnc_mfa_cast.cpp ccv_nnc_mfa_add.cpp 3rdparty/metal-cpp/Dispatch.cpp v2/CodeWriter.cpp v2/GEMMDescriptor.cpp v2/GEMMKernelDescriptor.cpp v2/GEMMHeaders.cpp v2/GEMMKernel.cpp v2/AttentionDescriptor.cpp v2/AttentionKernelDescriptor.cpp v2/AttentionKernel.cpp

SRC_OBJS := $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(SRCS)))

29 changes: 29 additions & 0 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "AttentionDescriptor.hpp"
#include "AttentionKernelDescriptor.hpp"
// #include "AttentionKernel.hpp"
#include "../ccv_nnc_mfa_hash.hpp"
#include "../ccv_nnc_mfa_error.hpp"

/// Memberwise equality over every field of the descriptor.
/// Must stay in sync with std::hash<AttentionDescriptor>: two descriptors
/// that compare equal here have to produce identical hash values.
bool AttentionDescriptor::operator==(const AttentionDescriptor& rhs) const {
  // Cheap scalar flags first; bail out early on a mismatch.
  if (lowPrecisionInputs != rhs.lowPrecisionInputs) {
    return false;
  }
  if (lowPrecisionIntermediates != rhs.lowPrecisionIntermediates) {
    return false;
  }
  // Vector comparisons yield per-lane results; simd_all collapses them.
  return simd_all(matrixDimensions == rhs.matrixDimensions) &&
         simd_all(transposeState == rhs.transposeState);
}

/// Hash an AttentionDescriptor for use as an unordered-container key.
/// Folds every field that participates in operator== into the seed, in a
/// fixed order, so equal descriptors always hash identically.
std::size_t std::hash<AttentionDescriptor>::operator()(const AttentionDescriptor& descriptor) const noexcept {
  using namespace ccv::nnc::mfa::hash;
  std::size_t seed = 0;

  // Three 32-bit matrix dimensions (row, column, head), one at a time.
  const auto& dims = descriptor.matrixDimensions;
  combine_32(seed, dims[0]);
  combine_32(seed, dims[1]);
  combine_32(seed, dims[2]);

  // Pack the four transpose flags into a single 32-bit word before mixing.
  const auto& transpose = descriptor.transposeState;
  combine_32(seed, pack_32(simd::uchar4 { transpose[0], transpose[1], transpose[2], transpose[3] }));

  // Pack both precision flags into one word; the upper two lanes stay zero.
  combine_32(seed, pack_32(simd::uchar4 { descriptor.lowPrecisionInputs, descriptor.lowPrecisionIntermediates, 0, 0 }));
  return seed;
}

/*
std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> AttentionDescriptor::findKernel(MTL::Device *const device, const DeviceProperties &dprops, std::unordered_map<AttentionKernelDescriptor, std::unique_ptr<AttentionKernel>> *const libraryCache) const noexcept {
}
*/
40 changes: 40 additions & 0 deletions lib/nnc/mfa/v2/AttentionDescriptor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#ifndef MFA_ATTENTIONDESCRIPTOR_HPP_
#define MFA_ATTENTIONDESCRIPTOR_HPP_

#include <simd/simd.h>
#include <utility>
#include "PipelineValue.hpp"
#include "DeviceProperties.hpp"
#include "GEMMOperandPrecision.hpp"

struct AttentionKernelDescriptor;
struct AttentionKernel;

/// Configuration that uniquely identifies an attention kernel variant.
/// Intended as a cache key: a std::hash specialization is declared below,
/// and operator== compares exactly these members (see AttentionDescriptor.cpp).
struct AttentionDescriptor {
/// Whether the attention inputs (Q, K, V, dO) use reduced precision.
bool lowPrecisionInputs;

/// Whether the intermediates (S, P, L, D, dP, dS) use reduced precision.
bool lowPrecisionIntermediates;

/// row: Output sequence length; rows of the attention matrix.
/// column: Input sequence length; columns of the attention matrix.
/// head: Head dimension, typically 32 - 256.
simd::uint3 matrixDimensions;

/// Transpose flags for the Q, K, V, O operands, in that order.
simd::uchar4 transposeState;

/// Memberwise equality over all fields above; defined in AttentionDescriptor.cpp.
bool operator==(const AttentionDescriptor& rhs) const;

// NOTE(review): kernel lookup is not yet wired up — this declaration is
// commented out pending validation of the translated AttentionKernel files.
// std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> findKernel(MTL::Device* const device, const DeviceProperties &dprops, std::unordered_map<AttentionKernelDescriptor, std::unique_ptr<AttentionKernel>> *const libraryCache) const noexcept;
};

/// Hash support so AttentionDescriptor can key std::unordered_map caches.
/// The implementation (AttentionDescriptor.cpp) folds in the same members
/// that operator== compares, keeping the hash/equality contract intact.
template<>
struct std::hash<AttentionDescriptor>
{
std::size_t operator()(const AttentionDescriptor& hash) const noexcept;
};

#endif

Loading

0 comments on commit 6ac4534

Please sign in to comment.