Skip to content

chumy6/Attention

Repository files navigation

基于 Verilog 的 Attention 实现

工具:vivado 2018.3

1. 功能描述

设计一个模块进行 Self-Attention 计算,输入 Q、K、V 矩阵,得到输出结果,并进行量化。模块参数设为:Token 数量为 8,Token 特征维度为 4,数据采用 Fix_8_8 即 8 位整数、8 位小数的定点无符号数格式存储和量化。用 python 设计 golden model 用以验证,用 verilog 完成硬件实现。

2. 设计细节

2.1 Softmax

硬件实现中简化为 $softmax(A[i,j]) = (A[i,j]-minA[:,j])^2$

2.2 量化

采用 Fix_8_8 格式的 16 位定点数来存储数据,即 8 位整数、8 位小数。但在 Attention 计算过程中,矩阵乘法和 Softmax 计算会导致数据位宽增大,因此在计算过程中需要将数据量化。

在矩阵乘法中,经过乘累加运算,输出的数据为 Fix_18_16,需将其量化为 Fix_8_8;在 Softmax 计算中,经过平方运算,输出的数据为 Fix_16_16,需将其量化为 Fix_8_8。

量化过程包括两种情况。在小数位数不足以表达该数值的小数部分时,采用舍入的方法,即根据超出部分的大小,若小于最小精度的一半,则舍去,否则进一位。在整数部分超出位数,发生溢出时,采用饱和的方法,即近似到最大的正值。

2.3 Golden model

用 python 设计该模块的 golden model,模拟计算中的矩阵乘法、softmax 和量化等过程,以验证硬件实现的正确性。为了编写方便,矩阵操作采用 pytorch 实现。同时设计 python 测试脚本,随机生成模块的输入数据,进行仿真,并将仿真结果与 golden model 计算结果比较。

3. 硬件模块设计

3.1 实现原理

Attention 计算可以分为三个步骤,分别是:矩阵乘法 Q·K^T、Softmax 计算、矩阵乘法 Score·V。

本设计采用流水线结构实现,即每一个步骤作为流水线的一级,形成 3 级流水,3 个周期即可完成全部计算。

顶层模块 Attention_top 输入 Q、K、V 矩阵,和 clk、rst 信号,输出计算结果。在顶层模块中,每一级流水分别作为一级子模块。在第一级子模块中,包括矩阵转置模块和矩阵乘法模块;在第二级子模块中,包括矩阵 Softmax 模块;在第三级子模块中,包括矩阵乘法模块。在输入和每一级流水结构中,均有寄存器,由 Dff 模块例化生成。

3.2 模块设计

在模块设计中,各个子模块相关接口均采用参数化设计,便于模块复用与扩展。下面介绍各模块及子模块的设计。

3.2.1 量化模块

量化模块实现对输入数据的量化输出,在矩阵乘法模块和 Softmax 模块中均有使用。该模块首先判断输入输出是否溢出,即高位超出量化位数的部分是否不位 0,若溢出则饱和,否则进行舍入;舍入的判断标准是超出最小精度的部分是否达到最小精度的一半,具体到实现中,通过判断低位超出量化位数部分的最高位是否为 1,若为 0 则舍去超出部分输出,否则进位输出,若进位后发生溢出则饱和输出。

3.2.2 矩阵乘法模块

矩阵乘法实现两个输入矩阵的乘法运算和结果的量化。该模块通过例化一系列带有量化的 MAC 模块实现,每个 MAC 模块的输入是两个输入矩阵的一行与一列,输出对应乘积矩阵的一个元素。

3.2.3 MAC 模块

MAC 模块实现两个输入向量的乘累加运算,将两个输入向量每个元素相乘后累加,并将结果通过量化模块输出。为了减小 MAC 的关键路径,累加操作通过树状加法器实现。

3.2.4 转置模块

转置模块对输入矩阵进行转置,用于 Q·K^T 的计算中。该模块通过对输入矩阵的元素地址进行变换来实现,即将按行顺序存储的输入矩阵按列读取得到输出转置矩阵。

3.2.5 Softmax 模块

Softmax 模块实现对输入矩阵的 Softmax 计算。该模块通过例化一系列带有量化的行 Softmax 模块,得到每一个元素按行 Softmax 的后的矩阵输出。

3.2.6 行 Softmax 模块

行 Softmax 模块实现对输入一行向量的 Softmax 结果。该模块具有子模块 Min,通过树状的逐个对比的方式得到输入向量的最小值,然后对每一个元素计算与最小值之差的平方,并通过量化模块,得到输出的 Softmax 结果。

3.2.7 寄存器模块

异步低电平有效复位的 Dff 模块,用于流水线中各级的寄存器。

3.3 Testbench

底层模块的 Testbench,通过读取 golden model 生成的测试数据作为输入,根据测试时给定的测试次数确定仿真周期数,生成对应的时钟,每个周期将将输出结果分别以二进制和十进制浮点数的形式输出到文件中,以供后续对比验证。

4. 结果与讨论

4.1 关键路径

结合模块设计和综合结果:

第一级流水包括转置和一个 8×4 矩阵和 4×8 矩阵的乘法,数据的最长路径为转置模块、一个4 输入的 MAC 模块和量化模块。MAC 路径包括 1 个乘法器,加法器树中的 2 个加法器。

第二级流水包括一个 8×8 矩阵的 Softmax 计算,数据的最长路径为一个最小值模块,一个加法(减法)器,一个乘法器和一个量化模块。最小值模块的比较器树包括3个比较选择器。

第三级流水包括一个 8×8 矩阵和 8×4 矩阵的乘法,数据的最长路径位一个 8 输入的 MAC 模块和量化模块。MAC 路径包括 1 个乘法器,加法器树中的 3 个加法器。

量化模块在各级流水线中相同,带来的延迟也相同,因此不影响关键路径。此处转置模块是简单的地址变换,连线直接相连,没有逻辑门带来的延迟。对比 MAC 部分的逻辑门数量,第三级的延迟要大于第一级。对比第二级和第三级,抛去量化模块和两者都有的一个乘法器,第二级有一个加法器和三个比较选择器,第三级有三个加法器,且比较选择器包含一个 MUX 和一个比较器,第二级的逻辑门延迟应该更大。因此关键路径为第二级流水即 Softmax 操作部分。若考虑面积开销,则是第一级最大,第三级次之,第二级最小。

4.2 延迟与吞吐率

本设计为三级流水线结构,数据输入后经过三个时间周期得到对应的结果,因此计算延迟为三个时钟周期。流水线启动后,每周期输入一组 QKV 输出进行处理,吞吐率为 3×8×4×16bit = 1536bit 每周期。考虑到启动延迟,假设执行 N 次注意力计算操作,吞吐率 𝑁/(𝑁+3)。

About

A tiny attention module

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published