diff --git a/tornado-assembly/src/bin/tornado-test b/tornado-assembly/src/bin/tornado-test index 56eeac654f..4e2657e43b 100755 --- a/tornado-assembly/src/bin/tornado-test +++ b/tornado-assembly/src/bin/tornado-test @@ -226,6 +226,8 @@ __TEST_THE_WORLD__ = [ TestEntry("uk.ac.manchester.tornado.unittests.pointers.TestCopyDevicePointers"), TestEntry("uk.ac.manchester.tornado.unittests.tensors.TestTensorAPIWithOnnx"), TestEntry("uk.ac.manchester.tornado.unittests.memory.MemoryConsumptionTest"), + TestEntry("uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsUnit"), + TestEntry("uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsFused"), ## Test for function calls - We force not to inline methods TestEntry(testName="uk.ac.manchester.tornado.unittests.tasks.TestMultipleFunctions", diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/GPULlama3Kernels.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/GPULlama3Kernels.java new file mode 100644 index 0000000000..2563aaa2e5 --- /dev/null +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/GPULlama3Kernels.java @@ -0,0 +1,1115 @@ +/* + * Copyright (c) 2025, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.unittests.gpullama; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.Int8Array; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +public class GPULlama3Kernels { + + /** + * Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups. + * + * Algorithm: 1. Each thread computes square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums stored per work group 4. 
First thread combines all partial + * sums and computes normalization factor + * + * @param context + * Kernel execution context + * @param output + * Array to store partial sums and final normalization factor + * @param x + * Input array to normalize + * @param size + * Number of elements to process + * @param ermsNorm + * Epsilon value squared for numerical stability + * @param localMemSize + * Size of local memory allocation (must match work group size) + */ + public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) { + int gid = context.globalIdx; + int lid = context.localIdx; + int groupId = context.groupIdx; + int groupSize = context.localGroupSizeX; + + // Allocate local memory with the provided size + float[] localX = context.allocateFloatLocalArray(localMemSize); + + // Load input value and compute square + if (gid < size) { + localX[lid] = x.get(gid); + localX[lid] = localX[lid] * localX[lid]; + } else { + localX[lid] = 0.0f; + } + + // Perform parallel reduction within the work group + for (int stride = (groupSize / 2); stride > 0; stride /= 2) { + context.localBarrier(); + if (lid < stride) { + localX[lid] += localX[lid + stride]; + } + } + + // Each workgroup stores its partial sum in a different location + if (lid == 0) { + // Store the partial sum from each workgroup + output.set(groupId + 1, localX[0]); + } + + // Only the first thread in the first workgroup computes the final normalization factor + if (gid == 0) { + // Combine partial sums from all workgroups + float ss = 0.0f; + for (int i = 1; i <= (size / localMemSize); i++) { // Assuming 8 workgroups + ss += output.get(i); + } + + ss /= size; + ss += ermsNorm; + ss = 1.0f / TornadoMath.sqrt(ss); + output.set(0, ss); // Store the final scale factor + } + } + + /** + * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization. + * + * Formula: output[i] = weight[i] * (normalizationFactor * x[i]) + * + * @param context + * Kernel execution context + * @param output + * Array for normalized output + * @param x + * Input values to normalize + * @param weights + * Weight values for each element + * @param temp + * Temporary array containing normalization factor at index 0 + */ + public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) { + int gid = context.globalIdx; + + float ss = temp.get(0); + output.set(gid, weights.get(gid) * (ss * x.get(gid))); + } + + /** + * Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation. 
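+ *
+ * Illustrative index computation (a sketch with made-up values; the layout is
+ * described below):
+ * <pre>{@code
+ * // layer = 2, contextLength = 128, kvDim = 64, position = 5
+ * int loff = 2 * 128 * 64;         // base offset of layer 2's cache region
+ * int destOffset = loff + 5 * 64;  // slot for position 5 within that layer
+ * }</pre>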
+ * + * Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector + * + * @param destKeyCache + * Destination array for key cache + * @param srcKey + * Source keys to copy + * @param destValueCache + * Destination array for value cache + * @param srcValue + * Source values to copy + * @param positioNlayer + * Array containing current position + * @param kvDim + * Dimension of key/value vectors + * @param layer + * Current transformer layer index + * @param contextLength + * Maximum sequence length + */ + public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) { + + int position = positioNlayer.get(0); + int loff = layer * contextLength * kvDim; + int destOffset = loff + position * kvDim; + + for (@Parallel int i = 0; i < srcValue.getSize(); i++) { + destKeyCache.set(destOffset + i, srcKey.get(i)); + destValueCache.set(destOffset + i, srcValue.get(i)); + } + } + + public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) { + int totalSize = dimQ + 2 * dimKV; + + for (@Parallel int i = 0; i < totalSize; i++) { + if (i < dimQ) { + // Copy to Q + q.set(i, qkv.get(i)); + } else if (i < dimQ + dimKV) { + // Copy to K + int kIndex = i - dimQ; + k.set(kIndex, qkv.get(i)); + } else { + // Copy to V + int vIndex = i - dimQ - dimKV; + v.set(vIndex, qkv.get(i)); + } + } + } + + /** + * Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional + * information. + * + * For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair + * + * @param context + * Kernel execution context + * @param positionHolder + * Array containing current position + * @param sq + * Query vectors to rotate + * @param sk + * Key vectors to rotate + * @param kv_dim + * Dimension of key/value vectors + * @param head_size + * Dimension of each attention head + */ + public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) { + int i = context.globalIdx * 2; + + int head_dim = i % head_size; + // 50000.0f vs 10000.0f + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) head_size); + float val = positionHolder.get(0) * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + int rotn = i < kv_dim ? 2 : 1; // how many vectors? 
2 = q & k, 1 = q only + + // Rotate query vector + float v0q = sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // Rotate key vector if needed + if (rotn > 1 && i < sk.getSize()) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + sk.set(i, v0k * fcr - v1k * fci); + sk.set(i + 1, v0k * fci + v1k * fcr); + } + + } + + public static void ropeRotationPhi3(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) { + int idx = context.globalIdx; + + // For Phi3, we process pairs with offset of head_size/2 + int dimHalf = head_size / 2; + + // Each thread processes one dimension pair + if (idx >= dimHalf) { + return; + } + + int position = positionHolder.get(0); + + // Calculate frequency for this dimension + float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) head_size); + float val = position * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Process all heads + int totalDim = sq.getSize(); + for (int base = 0; base < totalDim; base += head_size) { + // Skip if we're beyond the bounds + if (base + idx >= totalDim || base + idx + dimHalf >= totalDim) { + break; + } + + // Rotate query + float v0 = sq.get(base + idx); + float v1 = sq.get(base + idx + dimHalf); + sq.set(base + idx, v0 * fcr - v1 * fci); + sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr); + + // Rotate key if within kv_dim + if (base < kv_dim && base + idx < sk.getSize() && base + idx + dimHalf < sk.getSize()) { + float k0 = sk.get(base + idx); + float k1 = sk.get(base + idx + dimHalf); + sk.set(base + idx, k0 * fcr - k1 * fci); + sk.set(base + idx + dimHalf, k0 * fci + k1 * fcr); + } + } + } + + /** + * Computes attention for a single head. Implements scaled dot-product attention with softmax normalization. + * + * Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. 
Compute weighted sum of values + * + * @param allQ + * All query vectors + * @param key_cache + * Cached keys + * @param value_cache + * Cached values + * @param allXb + * Output buffer + * @param h + * Head index to process + * @param headSize + * Dimension per head + * @param kvDim + * Key/value dimension + * @param kvMul + * Key multiplier for grouped attention + * @param loff + * Layer offset in cache + * @param pos + * Current position + * @param wrapAtt + * Attention weights buffer + */ + private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, + FloatArray wrapAtt) { + + // Base index for this head's attention weights + int headOffset = h * (pos + 1); + + // STEP 1: Calculate attention scores for all timesteps + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + + float score = 0.0f; + for (int i = 0; i < headSize; i++) { + score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i); + } + score = score / TornadoMath.sqrt(headSize); + + // Store in attention buffer + wrapAtt.set(headOffset + t, score); + } + + // STEP 2: Find max score for softmax stability + float maxScore = wrapAtt.get(headOffset); + for (int t = 1; t <= pos; t++) { + float val = wrapAtt.get(headOffset + t); + if (val > maxScore) { + maxScore = val; + } + } + + // STEP 3: Compute exponentials and sum + float sum = 0.0f; + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore); + wrapAtt.set(idx, expScore); + sum += expScore; + } + + // STEP 4: Normalize + float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos + 1)); + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + wrapAtt.set(idx, wrapAtt.get(idx) * normFactor); + } + + // STEP 5: Compute weighted sum of values for each dimension + for (int i = 0; i < headSize; i++) { + float weightedSum = 0.0f; + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i); + } + allXb.set(h * headSize + i, weightedSum); + } + } + + public static void processHeadsFlashAttention(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, + IntArray positionHolder, int layer, int contextLength) { + + // Thread and workgroup information + int tid = context.localIdx; + int h = context.groupIdx; // Each workgroup processes one head + int localSize = context.localGroupSizeX; + + // Early exit if this workgroup is beyond our head count + // This relies on the kernel being launched with nHeads workgroups. 
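+ // Illustrative note on the online softmax used below (comments only, nothing executed):
+ // the kernel keeps a running triple (maxScore, sumExp, output[]). When a tile raises
+ // the running max from m_old to m_new, the accumulators are rescaled by
+ // exp(m_old - m_new) so every term stays expressed relative to m_new. With made-up
+ // numbers: m_old = 1.0, tile max = 3.0 -> scale = exp(-2.0) ~= 0.1353, then
+ // sumExp *= scale and output[d] *= scale before the new tile's exp(s - 3.0) terms are added.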
+ if (h >= nHeads) { + return; + } + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_SIZE_C = 8; + + // Allocate shared memory for tiled computation + float[] q_shared = context.allocateFloatLocalArray(headSize); + float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); + float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max + + // Thread-local accumulators for online softmax + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + + // Thread-local output accumulation + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) { + output[i] = 0.0f; + } + + // Load query vector into shared memory + for (int i = tid; i < headSize; i += localSize) { + q_shared[i] = q.get(h * headSize + i); + } + + context.localBarrier(); + + // Process sequence in tiles + for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) { + int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos); + + // Load key and value vectors for this tile + // Each thread loads a portion of the K and V vectors for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) { + int k_v_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + int tileMemOffset = k_v_idx_in_tile * headSize; + for (int d = 0; d < headSize; d++) { + int kvCacheAbsolutePos = tIdxInSeq; + int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d; + k_tile[tileMemOffset + d] = key_cache.get(kvOffset); + v_tile[tileMemOffset + d] = value_cache.get(kvOffset); + } + } + + context.localBarrier(); + + // Compute attention scores for this tile + // Each thread computes one score for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) { + int score_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d]; + } + score /= TornadoMath.sqrt(headSize); + s_tile[score_idx_in_tile] = score; + } + + context.localBarrier(); + + // Find max score in this tile (all threads compute it redundantly over the small s_tile) + float tileLocalMax = Float.NEGATIVE_INFINITY; + for (int i = 0; i <= tileEnd - tileC; i++) { // Iterate over valid scores in s_tile + if (s_tile[i] > tileLocalMax) { + tileLocalMax = s_tile[i]; + } + } + + // Broadcast max to all threads via shared memory + if (tid == 0) { + shared_tile_max_holder[0] = tileLocalMax; // FIX: Use dedicated holder + } + context.localBarrier(); + float currentTileMax = shared_tile_max_holder[0]; // FIX: Read from dedicated holder + + // Determine if we need to rescale previous results + float newMax = Math.max(maxScore, currentTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) { + output[d] *= scale; + } + } + maxScore = newMax; + + // Process each key-value pair using original scores from s_tile + // All threads iterate over all scores in the current tile + for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) { + // s_tile[t_idx_in_s_tile] now correctly refers to the original score + float expScore = 
TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore); + sumExp += expScore; + + for (int d = 0; d < headSize; d++) { + output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d]; + } + } + context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load + } + + // Normalize and write final results + float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; // Avoid division by zero, return 0 if sumExp is 0 + for (int d = tid; d < headSize; d += localSize) { + xb.set(h * headSize + d, output[d] * normFactor); + } + } + + /** + * Same as processHeadsFlashAttention but with some optimizations that seem to lower attention's execution time, especially in larger models. + */ + public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, + IntArray positionHolder, int layer, int contextLength) { + + // Thread and workgroup information + int tid = context.localIdx; + int h = context.groupIdx; // Each workgroup processes one head + int localSize = context.localGroupSizeX; + + // Early exit if this workgroup is beyond our head count + // This relies on the kernel being launched with nHeads workgroups. + if (h >= nHeads) { + return; + } + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_SIZE_C = 32; + + // Allocate shared memory for tiled computation + float[] q_shared = context.allocateFloatLocalArray(headSize); + float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); + float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max + + // Thread-local accumulators for online softmax + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + + // Thread-local output accumulation + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) { + output[i] = 0.0f; + } + + // Load query vector into shared memory + for (int i = tid; i < headSize; i += localSize) { + q_shared[i] = q.get(h * headSize + i); + } + + context.localBarrier(); + + // Process sequence in tiles + for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) { + int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos); + + // Load key and value vectors for this tile + // Each thread loads a contiguous block of elements + int totalElements = (tileEnd - tileC + 1) * headSize; + int elementsPerThread = (totalElements + localSize - 1) / localSize; + int startElem = tid * elementsPerThread; + int endElem = Math.min(startElem + elementsPerThread, totalElements); + + for (int globalElemIdx = startElem; globalElemIdx < endElem; globalElemIdx++) { + // Convert flat index to (sequence_pos, dimension) + int seqIdx = globalElemIdx / headSize; + int dimIdx = globalElemIdx % headSize; + + int tIdxInSeq = tileC + seqIdx; + int tileMemOffset = seqIdx * headSize + dimIdx; + + int kvCacheAbsolutePos = tIdxInSeq; + int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + dimIdx; + + k_tile[tileMemOffset] = key_cache.get(kvOffset); + v_tile[tileMemOffset] = value_cache.get(kvOffset); + } + + context.localBarrier(); + + // Compute attention scores for this tile + // Each thread computes one score for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq 
+= localSize) { + int score_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d]; + } + score /= TornadoMath.sqrt(headSize); + s_tile[score_idx_in_tile] = score; + } + + context.localBarrier(); + + // Allocate shared memory for reduction (needs to be power of 2) + int reductionSize = 1024; // Should be >= BLOCK_SIZE_C and power of 2 + float[] reduction_shared = context.allocateFloatLocalArray(reductionSize); + + // Step 1: Each thread finds max of its assigned subset + int itemsPerThread = (BLOCK_SIZE_C + localSize - 1) / localSize; + int startIdx = tid * itemsPerThread; + int endIdx = Math.min(startIdx + itemsPerThread, tileEnd - tileC + 1); + + float threadLocalMax = Float.NEGATIVE_INFINITY; + for (int i = startIdx; i < endIdx; i++) { + if (s_tile[i] > threadLocalMax) { + threadLocalMax = s_tile[i]; + } + } + + // Step 2: Store each thread's local max in shared memory + reduction_shared[tid] = threadLocalMax; + context.localBarrier(); + + // Step 3: Parallel reduction tree + for (int stride = localSize / 2; stride > 0; stride /= 2) { + if (tid < stride && tid + stride < localSize) { + reduction_shared[tid] = Math.max(reduction_shared[tid], reduction_shared[tid + stride]); + } + context.localBarrier(); + } + + // Step 4: Thread 0 now has the final max + float currentTileMax = reduction_shared[0]; + + // Determine if we need to rescale previous results + float newMax = Math.max(maxScore, currentTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) { + output[d] *= scale; + } + } + maxScore = newMax; + + // Process each key-value pair using original scores from s_tile + // All threads iterate over all scores in the current tile + for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) { + // s_tile[t_idx_in_s_tile] now correctly refers to the original score + float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore); + sumExp += expScore; + + for (int d = 0; d < headSize; d++) { + output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d]; + } + } + context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load + } + + float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; + + int dimsPerThread = (headSize + localSize - 1) / localSize; + int startDim = tid * dimsPerThread; + int endDim = Math.min(startDim + dimsPerThread, headSize); + int baseOffset = h * headSize + startDim; + + // Process 4 elements at a time when possible + int vectorEnd = startDim + ((endDim - startDim) & ~3); // Round down to multiple of 4 + + // Unrolled loop for better instruction-level parallelism + for (int d = startDim; d < vectorEnd; d += 4) { + int offset = d - startDim; + xb.set(baseOffset + offset, output[d] * normFactor); + xb.set(baseOffset + offset + 1, output[d + 1] * normFactor); + xb.set(baseOffset + offset + 2, output[d + 2] * normFactor); + xb.set(baseOffset + offset + 3, output[d + 3] * normFactor); + } + + // Handle remaining elements (0-3 elements) + for (int d = vectorEnd; d < endDim; d++) { + xb.set(h * headSize + d, output[d] * normFactor); + } + } + + /** + * Performs optimized matrix-vector multiplication where each work group processes one row of the matrix. + * + * Algorithm: 1. Each work group handles one output dimension 2. 
Threads in work group compute partial dot products 3. Parallel reduction yields final row result + * + * @param context + * Kernel execution context + * @param x + * Input vector + * @param hb + * Output vector + * @param w + * Weight matrix (row-major) + * @param n + * Input dimension + * @param d + * Output dimension + * @param localWorkGroupSize + * Number of threads per work group + */ + public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } + + // @formatter:off + public static void matrixVectorGeneric( + KernelContext context, + FloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } + // @formatter:on + + /** + * Matrix-vector multiplication with residual connection. Combines regular matrix multiplication with addition of existing values. + * + * Formula: hb[i] = hb[i] + w[i]·x + * + * @param context + * Kernel execution context + * @param x + * Input vector + * @param hb + * Input/output vector (contains residual, receives result) + * @param w + * Weight matrix + * @param n + * Input dimension + * @param d + * Output dimension + * @param localWorkGroupSize + * Work group size + */ + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + + /** + * Fused feed-forward network with SiLU activation and GLU gating. Implements the SwiGLU variant used in LLaMA-style models. 
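+ * A per-row sketch of what each work group computes ({@code dot} is shorthand for
+ * the row dot product, not a real method in this class):
+ * <pre>{@code
+ * float sum1 = dot(w1[row], x);           // x · W1 for this output row
+ * float sum3 = dot(w3[row], x);           // x · W3, the gate input
+ * hb[row] = siluActivation(sum1) * sum3;  // SiLU-gated product
+ * }</pre>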
+ * + * Formula: FFN(x) = SiLU(x·W1) ⊙ (x·W3) where ⊙ denotes element-wise multiplication + * + * @param context + * Kernel execution context + * @param x + * Input vector + * @param hb + * Output buffer + * @param w1 + * First feed-forward weight matrix + * @param w3 + * Third feed-forward weight matrix (gate) + * @param n + * Input dimension + * @param d + * Hidden dimension + * @param localWorkGroupSize + * Work group size + */ + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w1, n); + float sum3 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w3, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, result); + } + } + + /** + * Gaussian Error Linear Unit (GELU) activation function. Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))) + * + * @param x + * Input value + * @return Activated value + */ + public static float geluActivation(float x) { + float x3 = x * x * x; + return 0.5f * x * (1.0f + TornadoMath.tanh((0.797885f * (x + 0.044715f * x3)))); + } + + /** + * Sigmoid-weighted Linear Unit (SiLU) activation function. Also known as Swish activation. + * + * Formula: SiLU(x) = x * σ(x) = x / (1 + e^(-x)) + * + * @param x + * Input value + * @return Activated value + */ + public static float siluActivation(float x) { + return x * (1.0f / (1.0f + TornadoMath.exp(-x))); + } + + /** + * Optimized row-major matrix-vector multiplication for a single row. Uses parallel reduction within a work group to compute one dot product. + * + * Algorithm: 1. Each thread computes partial dot product 2. Partial results stored in local memory 3. Tree-based reduction combines partial results 4. 
Returns final dot product for the row + * + * @param context + * Kernel execution context + * @param localSize + * Work group size + * @param x + * Input vector + * @param w + * Weight matrix row + * @param n + * Input dimension + * @return Dot product result for this row + */ + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, FloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + partialSum += w.get(matrixIdx) * x.get(j); + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + partialSum += w.get(matrixIdx).getFloat32() * x.get(j); + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + // Second kernel - Combines partial sums and computes final normalization + public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) { + int gid = context.globalIdx; + + // Only one thread needs to perform this calculation + if (gid == 0) { + // Combine partial sums from all workgroups + float ss = 0.0f; + for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds + ss += output.get(i); + } + + ss /= size; + ss += ermsNorm; + ss = 1.0f / TornadoMath.sqrt(ss); + output.set(0, ss); // Store the final scale factor + } + } + + public static void splitGateUpAndSiLU(FloatArray hb, FloatArray hbG, FloatArray hbU, int hiddenDim) { + // Copy and apply SiLU to gate in one pass + for (@Parallel int i = 0; i < hiddenDim; i++) { + float gateVal = hb.get(i); + float upVal = hb.get(hiddenDim + i); + + // Apply SiLU to gate + float siluGate = gateVal / (1.0f + TornadoMath.exp(-gateVal)); + + // Store activated gate and multiply with up + hbG.set(i, siluGate); + hbU.set(i, siluGate * upVal); + } + } + + public static void addInPlace(FloatArray arrayA, FloatArray arrayB, int size) { + // Element-wise addition: arrayA[i] = arrayA[i] + arrayB[i] + for (@Parallel int i = 0; i < size; i++) { + float result = arrayA.get(i) + arrayB.get(i); + arrayA.set(i, result); + } + } + + /** + * Matrix-vector multiplication for Q8_0 quantized weights. 
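+ *
+ * Q8_0 packs weights in blocks of 32 int8 values that share one half-precision scale
+ * (see {@code blockSize = 32} in the helper below). Dequantization, with made-up numbers:
+ * <pre>{@code
+ * float scale = 0.02f; // fp16 scale of this weight's block
+ * byte q = -64;        // quantized weight
+ * float w = q * scale; // -1.28f, the dequantized weight used in the dot product
+ * }</pre>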
+ * + * @param context + * Kernel context + * @param x + * Input activations (FloatArray) + * @param output + * Output array (FloatArray) + * @param weightsQ + * Quantized weights (Int8Array) - from Q8_0QuantizedTensor.getQuants() + * @param weightScales + * Scale factors (HalfFloatArray) - from Q8_0QuantizedTensor.getScales() + * @param dim1 + * Input dimension (n - number of columns) + * @param dim0 + * Output dimension (d - number of rows) + * @param localWorkGroupSize + * Local workgroup size + */ + public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) { + + // One row per workgroup + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Early exit if this workgroup is beyond output dimension + if (rowId >= dim0) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, weightsQ, weightScales, dim1); + + // Thread 0 writes the result + if (localId == 0) { + output.set(rowId, sum); + } + } + + /** + * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance. + */ + public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int localSize, FloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + int blockSize = 32; + + // Allocate local memory for reduction + float[] localSums = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + int scalesRowOffset = rowId * (n / blockSize); + + // 4-way unrolling + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + // Main loop - process 4 elements at a time + for (int j = localId * 4; j < n - 3; j += localSize * 4) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + + // Dequantize and multiply + partialSum1 += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j); + partialSum2 += ((float) weightsQ.get(rowOffset + j + 1) * scale) * x.get(j + 1); + partialSum3 += ((float) weightsQ.get(rowOffset + j + 2) * scale) * x.get(j + 2); + partialSum4 += ((float) weightsQ.get(rowOffset + j + 3) * scale) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + // Handle remaining elements + for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + partialSum += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j); + } + + // Store partial sum + localSums[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + return localSums[0]; + } + + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = 
matrixVectorRowMajorOptimizedQ8_0(context, localSize, x, w_quants, w_scales, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, Int8Array w1_quants, HalfFloatArray w1_scales, Int8Array w3_quants, + HalfFloatArray w3_scales, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, w1_quants, w1_scales, n); + float sum3 = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, w3_quants, w3_scales, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, result); + } + } + + /** + * Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel. + * + * Attention computation: 1. Compute attention scores (Q·K) 2. Apply softmax for attention weights 3. Compute weighted sum of values (attention·V) + * + * @param q + * Query vectors for all heads + * @param key_cache + * Cached key vectors + * @param value_cache + * Cached value vectors + * @param xb + * Output buffer for attention results + * @param nHeads + * Number of attention heads + * @param headSize + * Dimension of each head + * @param kvDim + * Total key/value dimension + * @param kvMul + * Key/value head multiplier for grouped-query attention + * @param seqLen + * Current sequence length + * @param positionHolder + * Array containing position and layer info + * @param wrapAtt + * Buffer for attention weights + * @param layer + * Current transformer layer + * @param contextLength + * Maximum context length + */ + public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, + IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) { + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + + // Parallelize computation across attention heads + for (@Parallel int h = 0; h < nHeads; h++) { + // Process each head in parallel + processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt); + } + } + +} diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsFused.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsFused.java new file mode 100644 index 0000000000..80a765a61c --- /dev/null +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsFused.java @@ -0,0 +1,809 @@ +/* + * Copyright (c) 2025, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.unittests.gpullama; + +import static org.junit.Assert.assertEquals; + +import java.util.Random; + +import org.junit.Before; +import org.junit.Test; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; +import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; + +/** + * Fused pipeline tests with full numerical verification. + * Tests progressive combinations of kernels matching the LlamaFP16FFNLayers pipeline. + * + *
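+ * The configuration below is a deliberately small toy model: DIM = 256 with
+ * N_HEADS = 4 gives HEAD_SIZE = 64, and KV_DIM = 64 means a single KV head,
+ * i.e. grouped-query attention with KV_MUL = 4 query heads per KV head.
+ *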
+ * <p>
+ * How to run:
+ * </p>
+ * + * tornado-test -V org.beehive.gpullama3.tornadovm.kernels.TestTransformerKernelsFused + * + */ +public class TestTransformerKernelsFused extends TornadoTestBase { + + // Model configuration + private static final int DIM = 256; + private static final int KV_DIM = 64; + private static final int HIDDEN_DIM = 512; + private static final int N_HEADS = 4; + private static final int HEAD_SIZE = DIM / N_HEADS; + private static final int KV_MUL = N_HEADS / (KV_DIM / HEAD_SIZE); + private static final int CONTEXT_LENGTH = 128; + private static final float RMS_NORM_EPS = 1e-5f; + private static final int LOCAL_SIZE = 64; + private static final int LOCAL_SIZE_RMS = 256; + + // Tolerances + private static final float EPSILON_FP32 = 1e-4f; + private static final float EPSILON_FP16 = 0.05f; + private static final float EPSILON_ACCUMULATED = 0.15f; // For multi-stage pipelines + + private Random random; + private KernelContext context; + + // State arrays + private FloatArray x, xb, xb2, q, k, v; + private FloatArray keyCache, valueCache, att, hb; + private FloatArray temp, tempFFN; + private IntArray positionHolder; + + // Weights + private FloatArray rmsAttWeight, rmsFfnWeight; + private HalfFloatArray wq, wk, wv, wo, w1, w2, w3; + + @Before + public void setUp() { + random = new Random(42); + context = new KernelContext(); + initializeArrays(); + initializeWeights(); + } + + private void initializeArrays() { + x = new FloatArray(DIM); + xb = new FloatArray(DIM); + xb2 = new FloatArray(DIM); + q = new FloatArray(DIM); + k = new FloatArray(KV_DIM); + v = new FloatArray(KV_DIM); + keyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + valueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + att = new FloatArray(N_HEADS * CONTEXT_LENGTH); + hb = new FloatArray(HIDDEN_DIM); + + int numGroups = (DIM + LOCAL_SIZE_RMS - 1) / LOCAL_SIZE_RMS; + temp = new FloatArray(numGroups + 1); + tempFFN = new FloatArray(numGroups + 1); + positionHolder = new IntArray(1); + + fillRandom(x, -1.0f, 1.0f); + temp.init(0.0f); + tempFFN.init(0.0f); + positionHolder.set(0, 5); + } + + private void initializeWeights() { + rmsAttWeight = new FloatArray(DIM); + rmsFfnWeight = new FloatArray(DIM); + fillRandom(rmsAttWeight, 0.8f, 1.2f); + fillRandom(rmsFfnWeight, 0.8f, 1.2f); + + wq = createRandomHalfFloatArray(DIM * DIM, -0.1f, 0.1f); + wk = createRandomHalfFloatArray(KV_DIM * DIM, -0.1f, 0.1f); + wv = createRandomHalfFloatArray(KV_DIM * DIM, -0.1f, 0.1f); + wo = createRandomHalfFloatArray(DIM * DIM, -0.1f, 0.1f); + w1 = createRandomHalfFloatArray(HIDDEN_DIM * DIM, -0.1f, 0.1f); + w2 = createRandomHalfFloatArray(DIM * HIDDEN_DIM, -0.1f, 0.1f); + w3 = createRandomHalfFloatArray(HIDDEN_DIM * DIM, -0.1f, 0.1f); + } + + // ==================== Sequential Reference Implementations ==================== + + private float computeRmsNormFactorSequential(FloatArray x, float eps) { + float ss = 0.0f; + for (int i = 0; i < x.getSize(); i++) { + ss += x.get(i) * x.get(i); + } + ss /= x.getSize(); + ss += eps; + return 1.0f / (float) Math.sqrt(ss); + } + + private void applyRmsNormSequential(FloatArray output, FloatArray x, FloatArray weights, float normFactor) { + for (int i = 0; i < x.getSize(); i++) { + output.set(i, weights.get(i) * (normFactor * x.get(i))); + } + } + + private void matrixVectorSequentialFP16(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) { + for (int i = 0; i < rows; i++) { + float sum = 0.0f; + for (int j = 0; j < cols; j++) { + sum += weights.get(i * cols + j).getFloat32() * x.get(j); + } 
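+ // sum is accumulated in fp32; only the weights are fp16, which is why the looser
+ // EPSILON_FP16 tolerance is used when comparing against the GPU kernels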
+ output.set(i, sum); + } + } + + private void matrixVectorWithResidualSequential(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) { + for (int i = 0; i < rows; i++) { + float sum = 0.0f; + for (int j = 0; j < cols; j++) { + sum += weights.get(i * cols + j).getFloat32() * x.get(j); + } + output.set(i, output.get(i) + sum); + } + } + + private void ropeRotationSequential(FloatArray sq, FloatArray sk, int pos, int kvDim, int headSize) { + int numPairs = sq.getSize() / 2; + for (int i = 0; i < numPairs; i++) { + int idx = i * 2; + int headDim = idx % headSize; + float freq = 1.0f / (float) Math.pow(50000.0f, headDim / (float) headSize); + float val = pos * freq; + float fcr = (float) Math.cos(val); + float fci = (float) Math.sin(val); + + float v0q = sq.get(idx); + float v1q = sq.get(idx + 1); + sq.set(idx, v0q * fcr - v1q * fci); + sq.set(idx + 1, v0q * fci + v1q * fcr); + + if (idx < kvDim && idx + 1 < sk.getSize()) { + float v0k = sk.get(idx); + float v1k = sk.get(idx + 1); + sk.set(idx, v0k * fcr - v1k * fci); + sk.set(idx + 1, v0k * fci + v1k * fcr); + } + } + } + + private void copyToCacheSequential(FloatArray keyCache, FloatArray key, FloatArray valueCache, FloatArray value, int position, int kvDim, int layer, int contextLength) { + int loff = layer * contextLength * kvDim; + int destOffset = loff + position * kvDim; + for (int i = 0; i < key.getSize(); i++) { + keyCache.set(destOffset + i, key.get(i)); + valueCache.set(destOffset + i, value.get(i)); + } + } + + private void processHeadSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int h, int headSize, int kvDim, int kvMul, int loff, int pos) { + int kvHeadIdx = h / kvMul; + + float[] attScores = new float[pos + 1]; + for (int t = 0; t <= pos; t++) { + int keyOffset = loff + t * kvDim + kvHeadIdx * headSize; + float score = 0.0f; + for (int i = 0; i < headSize; i++) { + score += q.get(h * headSize + i) * keyCache.get(keyOffset + i); + } + attScores[t] = score / (float) Math.sqrt(headSize); + } + + float maxScore = attScores[0]; + for (int t = 1; t <= pos; t++) { + if (attScores[t] > maxScore) + maxScore = attScores[t]; + } + + float sumExp = 0.0f; + for (int t = 0; t <= pos; t++) { + attScores[t] = (float) Math.exp(attScores[t] - maxScore); + sumExp += attScores[t]; + } + + for (int t = 0; t <= pos; t++) { + attScores[t] /= sumExp; + } + + for (int i = 0; i < headSize; i++) { + float weightedSum = 0.0f; + for (int t = 0; t <= pos; t++) { + int valueOffset = loff + t * kvDim + kvHeadIdx * headSize; + weightedSum += attScores[t] * valueCache.get(valueOffset + i); + } + xb.set(h * headSize + i, weightedSum); + } + } + + private void processHeadsSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int pos, int layer, + int contextLength) { + int loff = layer * contextLength * kvDim; + for (int h = 0; h < nHeads; h++) { + processHeadSequential(q, keyCache, valueCache, xb, h, headSize, kvDim, kvMul, loff, pos); + } + } + + private float siluSequential(float x) { + return x * (1.0f / (1.0f + (float) Math.exp(-x))); + } + + private void fusedFFNSequential(FloatArray output, FloatArray x, HalfFloatArray w1, HalfFloatArray w3, int inputDim, int hiddenDim) { + for (int i = 0; i < hiddenDim; i++) { + float sum1 = 0.0f; + float sum3 = 0.0f; + for (int j = 0; j < inputDim; j++) { + sum1 += w1.get(i * inputDim + j).getFloat32() * x.get(j); + sum3 += w3.get(i * inputDim + j).getFloat32() * x.get(j); + } + 
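+ // SwiGLU gating, matching the fused GPU kernel:
+ //   out[i] = SiLU(x·W1[i]) * (x·W3[i]), with SiLU(z) = z / (1 + e^(-z))
+ // e.g. SiLU(0) = 0 and SiLU(1) ~= 0.731: positive gate activations pass mostly
+ // through while large negative ones are squashed toward zero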
float silu = siluSequential(sum1); + output.set(i, silu * sum3); + } + } + + // ==================== Helper Methods ==================== + + private void fillRandom(FloatArray array, float min, float max) { + float range = max - min; + for (int i = 0; i < array.getSize(); i++) { + array.set(i, min + random.nextFloat() * range); + } + } + + private FloatArray copyArray(FloatArray src) { + FloatArray dst = new FloatArray(src.getSize()); + for (int i = 0; i < src.getSize(); i++) { + dst.set(i, src.get(i)); + } + return dst; + } + + private HalfFloatArray createRandomHalfFloatArray(int size, float min, float max) { + HalfFloatArray array = new HalfFloatArray(size); + float range = max - min; + for (int i = 0; i < size; i++) { + array.set(i, new HalfFloat(min + random.nextFloat() * range)); + } + return array; + } + + private void assertArrayEquals(String message, FloatArray expected, FloatArray actual, float tolerance) { + assertEquals(message + " - size mismatch", expected.getSize(), actual.getSize()); + for (int i = 0; i < expected.getSize(); i++) { + assertEquals(message + " at index " + i, expected.get(i), actual.get(i), tolerance); + } + } + + private GridScheduler createScheduler() { + WorkerGrid rmsWorker = new WorkerGrid1D(DIM); + rmsWorker.setLocalWork(LOCAL_SIZE_RMS, 1, 1); + + WorkerGrid qMatmulWorker = new WorkerGrid1D(DIM * LOCAL_SIZE); + qMatmulWorker.setLocalWork(LOCAL_SIZE, 1, 1); + + WorkerGrid kvMatmulWorker = new WorkerGrid1D(KV_DIM * LOCAL_SIZE); + kvMatmulWorker.setLocalWork(LOCAL_SIZE, 1, 1); + + WorkerGrid ropeWorker = new WorkerGrid1D(DIM / 2); + ropeWorker.setLocalWork(Math.min(128, DIM / 2), 1, 1); + + WorkerGrid ffnWorker = new WorkerGrid1D(HIDDEN_DIM * LOCAL_SIZE); + ffnWorker.setLocalWork(LOCAL_SIZE, 1, 1); + + GridScheduler scheduler = new GridScheduler(); + scheduler.addWorkerGrid("s0.rmsReduce", rmsWorker); + scheduler.addWorkerGrid("s0.rmsApply", rmsWorker); + scheduler.addWorkerGrid("s0.qmatmul", qMatmulWorker); + scheduler.addWorkerGrid("s0.kmatmul", kvMatmulWorker); + scheduler.addWorkerGrid("s0.vmatmul", kvMatmulWorker); + scheduler.addWorkerGrid("s0.rope", ropeWorker); + scheduler.addWorkerGrid("s0.outputProj", qMatmulWorker); + scheduler.addWorkerGrid("s0.rmsReduceFFN", rmsWorker); + scheduler.addWorkerGrid("s0.rmsApplyFFN", rmsWorker); + scheduler.addWorkerGrid("s0.fusedFFN", ffnWorker); + scheduler.addWorkerGrid("s0.ffnProj", qMatmulWorker); + + return scheduler; + } + + // ==================== Stage 1: RMS Normalization ==================== + + @Test + public void testFusedStage1_RMSNorm() throws TornadoExecutionPlanException { + // Sequential reference + FloatArray expectedXb = new FloatArray(DIM); + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xb, temp); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + 
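+ // temp.get(0) holds the device-computed scale factor 1.0f / sqrt(mean(x^2) + eps);
+ // the loose 0.01f tolerance allows for the different summation order of the
+ // parallel reduction relative to the sequential reference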
assertEquals("RMS norm factor", normFactor, temp.get(0), 0.01f); + assertArrayEquals("Stage 1: RMS Norm output", expectedXb, xb, EPSILON_FP32); + } + + // ==================== Stage 2: RMS Norm + Q Projection ==================== + + @Test + public void testFusedStage2_RMSNorm_QMatmul() throws TornadoExecutionPlanException { + // Sequential reference + FloatArray expectedXb = new FloatArray(DIM); + FloatArray expectedQ = new FloatArray(DIM); + + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .transferToHost(DataTransferMode.EVERY_EXECUTION, q); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("Stage 2: Q projection output", expectedQ, q, EPSILON_FP16); + } + + // ==================== Stage 3: RMS Norm + QKV Projections ==================== + + @Test + public void testFusedStage3_RMSNorm_QKVMatmul() throws TornadoExecutionPlanException { + // Sequential reference + FloatArray expectedXb = new FloatArray(DIM); + FloatArray expectedQ = new FloatArray(DIM); + FloatArray expectedK = new FloatArray(KV_DIM); + FloatArray expectedV = new FloatArray(KV_DIM); + + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM); + matrixVectorSequentialFP16(expectedK, wk, expectedXb, KV_DIM, DIM); + matrixVectorSequentialFP16(expectedV, wv, expectedXb, KV_DIM, DIM); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .transferToHost(DataTransferMode.EVERY_EXECUTION, q, k, v); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("Stage 3: Q output", expectedQ, q, EPSILON_FP16); + assertArrayEquals("Stage 3: K output", expectedK, k, EPSILON_FP16); + assertArrayEquals("Stage 3: V output", expectedV, v, EPSILON_FP16); + } + + // 
==================== Stage 4: QKV + RoPE ==================== + + @Test + public void testFusedStage4_QKV_RoPE() throws TornadoExecutionPlanException { + int position = positionHolder.get(0); + + // Sequential reference + FloatArray expectedXb = new FloatArray(DIM); + FloatArray expectedQ = new FloatArray(DIM); + FloatArray expectedK = new FloatArray(KV_DIM); + + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM); + matrixVectorSequentialFP16(expectedK, wk, expectedXb, KV_DIM, DIM); + ropeRotationSequential(expectedQ, expectedK, position, KV_DIM, HEAD_SIZE); + + GridScheduler scheduler = createScheduler(); + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0").transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv, positionHolder) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE) + .transferToHost(DataTransferMode.EVERY_EXECUTION, q, k); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("Stage 4: Q after RoPE", expectedQ, q, EPSILON_FP16); + assertArrayEquals("Stage 4: K after RoPE", expectedK, k, EPSILON_FP16); + } + + // ==================== Stage 5: QKV + RoPE + Cache Update ==================== + + @Test + public void testFusedStage5_QKV_RoPE_Cache() throws TornadoExecutionPlanException { + final int layer = 0; + int position = positionHolder.get(0); + + // Sequential reference + FloatArray expectedXb = new FloatArray(DIM); + FloatArray expectedQ = new FloatArray(DIM); + FloatArray expectedK = new FloatArray(KV_DIM); + FloatArray expectedV = new FloatArray(KV_DIM); + FloatArray expectedKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray expectedValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM); + matrixVectorSequentialFP16(expectedK, wk, expectedXb, KV_DIM, DIM); + matrixVectorSequentialFP16(expectedV, wv, expectedXb, KV_DIM, DIM); + ropeRotationSequential(expectedQ, expectedK, position, KV_DIM, HEAD_SIZE); + copyToCacheSequential(expectedKeyCache, expectedK, expectedValueCache, expectedV, position, KV_DIM, layer, CONTEXT_LENGTH); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv, positionHolder, keyCache, valueCache) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", 
GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE) + .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH) + .transferToHost(DataTransferMode.EVERY_EXECUTION, keyCache, valueCache); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + // Verify cache at the specific position + int destOffset = position * KV_DIM; + for (int i = 0; i < KV_DIM; i++) { + assertEquals("Stage 5: Key cache at " + i, expectedKeyCache.get(destOffset + i), keyCache.get(destOffset + i), EPSILON_FP16); + assertEquals("Stage 5: Value cache at " + i, expectedValueCache.get(destOffset + i), valueCache.get(destOffset + i), EPSILON_FP16); + } + } + + // ==================== Stage 6: Full Attention Block ==================== + + @Test + public void testFusedStage6_FullAttentionBlock() throws TornadoExecutionPlanException { + final int layer = 0; + int position = positionHolder.get(0); + + // Sequential reference + FloatArray seqXb = new FloatArray(DIM); + FloatArray seqQ = new FloatArray(DIM); + FloatArray seqK = new FloatArray(KV_DIM); + FloatArray seqV = new FloatArray(KV_DIM); + FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray seqAttOut = new FloatArray(DIM); + FloatArray expectedX = copyArray(x); + + float normFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS); + applyRmsNormSequential(seqXb, expectedX, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM); + matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM); + matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM); + ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE); + copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH); + processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH); + // Copy seqAttOut to seqXb for output projection + for (int i = 0; i < DIM; i++) + seqXb.set(i, seqAttOut.get(i)); + matrixVectorWithResidualSequential(expectedX, wo, seqXb, DIM, DIM); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv, wo, positionHolder, keyCache, valueCache, att) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, 
xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE) + .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH) + .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH) + .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE) + .transferToHost(DataTransferMode.EVERY_EXECUTION, x); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("Stage 6: Full attention block output", expectedX, x, EPSILON_ACCUMULATED); + } + + // ==================== Stage 7: Attention + FFN RMS Norm ==================== + + @Test + public void testFusedStage7_AttentionBlock_FFNRMSNorm() throws TornadoExecutionPlanException { + final int layer = 0; + int position = positionHolder.get(0); + + // Sequential reference (build on Stage 6) + FloatArray seqXb = new FloatArray(DIM); + FloatArray seqQ = new FloatArray(DIM); + FloatArray seqK = new FloatArray(KV_DIM); + FloatArray seqV = new FloatArray(KV_DIM); + FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray seqAttOut = new FloatArray(DIM); + FloatArray seqX = copyArray(x); + FloatArray expectedXb = new FloatArray(DIM); + + // Attention block + float normFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS); + applyRmsNormSequential(seqXb, seqX, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM); + matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM); + matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM); + ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE); + copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH); + processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH); + for (int i = 0; i < DIM; i++) + seqXb.set(i, seqAttOut.get(i)); + matrixVectorWithResidualSequential(seqX, wo, seqXb, DIM, DIM); + + // FFN RMS norm + float ffnNormFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS); + applyRmsNormSequential(expectedXb, seqX, rmsFfnWeight, ffnNormFactor); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, rmsFfnWeight, temp, tempFFN, wq, wk, wv, wo, positionHolder, keyCache, valueCache, att) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE) + 
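+                // Ordering matters in this chain: "rope" rotates q/k in place before
+                // "copyToCache" runs, so the cache stores already-rotated keys and the
+                // downstream "attention" task scores against them directly, mirroring
+                // the sequential reference above.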
.task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH) + .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH) + .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE) + .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xb); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("Stage 7: FFN input (after RMS norm)", expectedXb, xb, EPSILON_ACCUMULATED); + } + + // ==================== Stage 8: Complete Transformer Layer ==================== + + @Test + public void testFusedStage8_CompleteTransformerLayer() throws TornadoExecutionPlanException { + final int layer = 0; + int position = positionHolder.get(0); + + // Full sequential reference + FloatArray seqXb = new FloatArray(DIM); + FloatArray seqQ = new FloatArray(DIM); + FloatArray seqK = new FloatArray(KV_DIM); + FloatArray seqV = new FloatArray(KV_DIM); + FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM); + FloatArray seqAttOut = new FloatArray(DIM); + FloatArray seqHb = new FloatArray(HIDDEN_DIM); + FloatArray expectedX = copyArray(x); + + // Attention block + float normFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS); + applyRmsNormSequential(seqXb, expectedX, rmsAttWeight, normFactor); + matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM); + matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM); + matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM); + ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE); + copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH); + processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH); + for (int i = 0; i < DIM; i++) + seqXb.set(i, seqAttOut.get(i)); + matrixVectorWithResidualSequential(expectedX, wo, seqXb, DIM, DIM); + + // FFN block + float ffnNormFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS); + applyRmsNormSequential(seqXb, expectedX, rmsFfnWeight, ffnNormFactor); + fusedFFNSequential(seqHb, seqXb, w1, w3, DIM, HIDDEN_DIM); + matrixVectorWithResidualSequential(expectedX, w2, seqHb, DIM, HIDDEN_DIM); + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, rmsFfnWeight, temp, tempFFN, wq, wk, wv, wo, w1, w2, w3, positionHolder, keyCache, valueCache, att) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", 
GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE) + .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH) + .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH) + .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE) + .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN) + .task("fusedFFN", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, xb, hb, w1, w3, DIM, HIDDEN_DIM, LOCAL_SIZE) + .task("ffnProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, hb, x, w2, HIDDEN_DIM, DIM, LOCAL_SIZE) + .transferToHost(DataTransferMode.EVERY_EXECUTION, x); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("Stage 8: Complete transformer layer output", expectedX, x, EPSILON_ACCUMULATED); + + } + + // ==================== Multi-Iteration Test ==================== + + @Test + public void testFusedMultipleIterations() throws TornadoExecutionPlanException { + final int layer = 0; + final int numIterations = 3; + + GridScheduler scheduler = createScheduler(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, rmsFfnWeight, temp, tempFFN, wq, wk, wv, wo, w1, w2, w3, positionHolder, keyCache, valueCache, att) + .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp) + .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE) + .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE) + .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE) + .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE) + .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH) + .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH) + .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE) + .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS) + .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN) + .task("fusedFFN", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, xb, hb, w1, w3, DIM, HIDDEN_DIM, LOCAL_SIZE) + .task("ffnProj", 
GPULlama3Kernels::matrixVectorGenericWithResidual, context, hb, x, w2, HIDDEN_DIM, DIM, LOCAL_SIZE)
+                .transferToHost(DataTransferMode.EVERY_EXECUTION, x, keyCache, valueCache);
+        // @formatter:on
+        // NOTE: keyCache/valueCache are read back on every execution because a fresh
+        // execution plan is created per iteration below; without this round-trip, the
+        // EVERY_EXECUTION host-to-device copy would overwrite the device-side cache
+        // entries written at earlier positions with stale host data.
+
+        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+
+        // Track sequential state
+        FloatArray seqX = copyArray(x);
+        FloatArray seqXb = new FloatArray(DIM);
+        FloatArray seqQ = new FloatArray(DIM);
+        FloatArray seqK = new FloatArray(KV_DIM);
+        FloatArray seqV = new FloatArray(KV_DIM);
+        FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+        FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+        FloatArray seqAttOut = new FloatArray(DIM);
+        FloatArray seqHb = new FloatArray(HIDDEN_DIM);
+
+        for (int iter = 0; iter < numIterations; iter++) {
+            int position = iter;
+            positionHolder.set(0, position);
+            temp.init(0.0f);
+            tempFFN.init(0.0f);
+
+            // Sequential computation for this iteration
+            float normFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS);
+            applyRmsNormSequential(seqXb, seqX, rmsAttWeight, normFactor);
+            matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM);
+            matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM);
+            matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM);
+            ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE);
+            copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH);
+            processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH);
+            for (int i = 0; i < DIM; i++)
+                seqXb.set(i, seqAttOut.get(i));
+            matrixVectorWithResidualSequential(seqX, wo, seqXb, DIM, DIM);
+            float ffnNormFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS);
+            applyRmsNormSequential(seqXb, seqX, rmsFfnWeight, ffnNormFactor);
+            fusedFFNSequential(seqHb, seqXb, w1, w3, DIM, HIDDEN_DIM);
+            matrixVectorWithResidualSequential(seqX, w2, seqHb, DIM, HIDDEN_DIM);
+
+            // TornadoVM execution
+            try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+                executionPlan.withGridScheduler(scheduler).execute();
+            }
+
+            // Verify
+            assertArrayEquals("Iteration " + iter + ": output mismatch", seqX, x, EPSILON_ACCUMULATED);
+
+            System.out.printf("Iteration %d: x[0] expected=%.6f, actual=%.6f%n", iter, seqX.get(0), x.get(0));
+        }
+    }
+
+    // ==================== FFN Block Only Test ====================
+
+    @Test
+    public void testFusedFFNBlockOnly() throws TornadoExecutionPlanException {
+        // Sequential reference
+        FloatArray seqXb = new FloatArray(DIM);
+        FloatArray seqHb = new FloatArray(HIDDEN_DIM);
+        FloatArray expectedX = copyArray(x);
+
+        float ffnNormFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS);
+        applyRmsNormSequential(seqXb, expectedX, rmsFfnWeight, ffnNormFactor);
+        fusedFFNSequential(seqHb, seqXb, w1, w3, DIM, HIDDEN_DIM);
+        matrixVectorWithResidualSequential(expectedX, w2, seqHb, DIM, HIDDEN_DIM);
+
+        GridScheduler scheduler = createScheduler();
+
+        // @formatter:off
+        TaskGraph taskGraph = new TaskGraph("s0")
+                .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsFfnWeight, tempFFN, w1, w2, w3)
+                .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+                .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN)
+                .task("fusedFFN", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, xb, hb, w1, w3, DIM, HIDDEN_DIM, LOCAL_SIZE)
+                .task("ffnProj",
GPULlama3Kernels::matrixVectorGenericWithResidual, context, hb, x, w2, HIDDEN_DIM, DIM, LOCAL_SIZE) + .transferToHost(DataTransferMode.EVERY_EXECUTION, x, hb); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(scheduler).execute(); + } + + assertArrayEquals("FFN block: hidden state", seqHb, hb, EPSILON_ACCUMULATED); + assertArrayEquals("FFN block: output", expectedX, x, EPSILON_ACCUMULATED); + } +} \ No newline at end of file diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsUnit.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsUnit.java new file mode 100644 index 0000000000..e09a135a37 --- /dev/null +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsUnit.java @@ -0,0 +1,951 @@ +/* + * Copyright (c) 2025, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.unittests.gpullama; + +import static org.junit.Assert.assertEquals; + +import java.util.Random; + +import org.junit.Before; +import org.junit.Test; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; +import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; + +/** + * Unit tests for individual transformer kernels with full numerical verification. + * + *
+ * How to run:
+ *
+ * tornado-test -V uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsUnit
+ *
+ */
+public class TestTransformerKernelsUnit extends TornadoTestBase {
+
+    private static final float EPSILON_FP32 = 1e-4f;
+    private static final float EPSILON_FP16 = 0.05f; // FP16 has lower precision
+    private static final float EPSILON_ACCUMULATED = 0.1f; // For operations with accumulated error
+    private static final float RMS_NORM_EPS = 1e-5f;
+    private static final int LOCAL_SIZE = 64;
+    private static final int LOCAL_SIZE_RMS = 256;
+
+    private Random random;
+
+    @Before
+    public void setUp() {
+        random = new Random(42);
+    }
+
+    // ==================== Sequential Reference Implementations ====================
+
+    /**
+     * Sequential RMS normalization - Phase 1: compute normalization factor
+     */
+    private float computeRmsNormFactorSequential(FloatArray x, float eps) {
+        float ss = 0.0f;
+        for (int i = 0; i < x.getSize(); i++) {
+            ss += x.get(i) * x.get(i);
+        }
+        ss /= x.getSize();
+        ss += eps;
+        return 1.0f / (float) Math.sqrt(ss);
+    }
+
+    /**
+     * Sequential RMS normalization - Phase 2: apply normalization
+     */
+    private void applyRmsNormSequential(FloatArray output, FloatArray x, FloatArray weights, float normFactor) {
+        for (int i = 0; i < x.getSize(); i++) {
+            output.set(i, weights.get(i) * (normFactor * x.get(i)));
+        }
+    }
+
+    /**
+     * Sequential matrix-vector multiplication (FP32 weights)
+     */
+    private void matrixVectorSequentialFP32(FloatArray output, FloatArray weights, FloatArray x, int rows, int cols) {
+        for (int i = 0; i < rows; i++) {
+            float sum = 0.0f;
+            for (int j = 0; j < cols; j++) {
+                sum += weights.get(i * cols + j) * x.get(j);
+            }
+            output.set(i, sum);
+        }
+    }
+
+    /**
+     * Sequential matrix-vector multiplication (FP16 weights)
+     */
+    private void matrixVectorSequentialFP16(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) {
+        for (int i = 0; i < rows; i++) {
+            float sum = 0.0f;
+            for (int j = 0; j < cols; j++) {
+                sum += weights.get(i * cols + j).getFloat32() * x.get(j);
+            }
+            output.set(i, sum);
+        }
+    }
+
+    /**
+     * Sequential matrix-vector multiplication with residual addition (FP16 weights)
+     */
+    private void matrixVectorWithResidualSequential(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) {
+        for (int i = 0; i < rows; i++) {
+            float sum = 0.0f;
+            for (int j = 0; j < cols; j++) {
+                sum += weights.get(i * cols + j).getFloat32() * x.get(j);
+            }
+            output.set(i, output.get(i) + sum);
+        }
+    }
+
+    /**
+     * Sequential RoPE rotation
+     */
+    private void ropeRotationSequential(FloatArray sq, FloatArray sk, int pos, int kvDim, int headSize) {
+        int numPairs = sq.getSize() / 2;
+        for (int i = 0; i < numPairs; i++) {
+            int idx = i * 2;
+            int headDim = idx % headSize;
+            float freq = 1.0f / (float) Math.pow(50000.0f, headDim / (float) headSize);
+            float val = pos * freq;
+            float fcr = (float) Math.cos(val);
+            float fci = (float) Math.sin(val);
+
+            // Rotate query
+            float v0q = sq.get(idx);
+            float v1q = sq.get(idx + 1);
+            sq.set(idx, v0q * fcr - v1q * fci);
+            sq.set(idx + 1, v0q * fci + v1q * fcr);
+
+            // Rotate key if within kvDim
+            if (idx < kvDim && idx + 1 < sk.getSize()) {
+                float v0k = sk.get(idx);
+                float v1k = sk.get(idx + 1);
+                sk.set(idx, v0k * fcr - v1k * fci);
+                sk.set(idx + 1, v0k * fci + v1k * fcr);
+            }
+        }
+    }
+
+    /**
+     * Sequential SiLU activation
+     */
+    private float siluSequential(float x) {
+        return x * (1.0f / (1.0f + (float) Math.exp(-x)));
+    }
+
+    /**
+     * Sequential fused FFN with SiLU and GLU
+     */
+    private void
fusedFFNSequential(FloatArray output, FloatArray x, HalfFloatArray w1, HalfFloatArray w3, int inputDim, int hiddenDim) { + for (int i = 0; i < hiddenDim; i++) { + float sum1 = 0.0f; + float sum3 = 0.0f; + for (int j = 0; j < inputDim; j++) { + sum1 += w1.get(i * inputDim + j).getFloat32() * x.get(j); + sum3 += w3.get(i * inputDim + j).getFloat32() * x.get(j); + } + float silu = siluSequential(sum1); + output.set(i, silu * sum3); + } + } + + /** + * Sequential copy to KV cache + */ + private void copyToCacheSequential(FloatArray keyCache, FloatArray key, FloatArray valueCache, FloatArray value, int position, int kvDim, int layer, int contextLength) { + int loff = layer * contextLength * kvDim; + int destOffset = loff + position * kvDim; + for (int i = 0; i < key.getSize(); i++) { + keyCache.set(destOffset + i, key.get(i)); + valueCache.set(destOffset + i, value.get(i)); + } + } + + /** + * Sequential element-wise addition + */ + private void addInPlaceSequential(FloatArray a, FloatArray b) { + for (int i = 0; i < a.getSize(); i++) { + a.set(i, a.get(i) + b.get(i)); + } + } + + /** + * Sequential split gate/up with SiLU + */ + private void splitGateUpAndSiLUSequential(FloatArray hb, FloatArray hbG, FloatArray hbU, int hiddenDim) { + for (int i = 0; i < hiddenDim; i++) { + float gateVal = hb.get(i); + float upVal = hb.get(hiddenDim + i); + float siluGate = siluSequential(gateVal); + hbG.set(i, siluGate); + hbU.set(i, siluGate * upVal); + } + } + + /** + * Sequential attention for a single head + */ + private void processHeadSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int h, int headSize, int kvDim, int kvMul, int loff, int pos) { + int kvHeadIdx = h / kvMul; + + // Compute attention scores + float[] attScores = new float[pos + 1]; + for (int t = 0; t <= pos; t++) { + int keyOffset = loff + t * kvDim + kvHeadIdx * headSize; + float score = 0.0f; + for (int i = 0; i < headSize; i++) { + score += q.get(h * headSize + i) * keyCache.get(keyOffset + i); + } + attScores[t] = score / (float) Math.sqrt(headSize); + } + + // Softmax + float maxScore = attScores[0]; + for (int t = 1; t <= pos; t++) { + if (attScores[t] > maxScore) + maxScore = attScores[t]; + } + + float sumExp = 0.0f; + for (int t = 0; t <= pos; t++) { + attScores[t] = (float) Math.exp(attScores[t] - maxScore); + sumExp += attScores[t]; + } + + for (int t = 0; t <= pos; t++) { + attScores[t] /= sumExp; + } + + // Weighted sum of values + for (int i = 0; i < headSize; i++) { + float weightedSum = 0.0f; + for (int t = 0; t <= pos; t++) { + int valueOffset = loff + t * kvDim + kvHeadIdx * headSize; + weightedSum += attScores[t] * valueCache.get(valueOffset + i); + } + xb.set(h * headSize + i, weightedSum); + } + } + + /** + * Sequential multi-head attention + */ + private void processHeadsSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int pos, int layer, + int contextLength) { + int loff = layer * contextLength * kvDim; + for (int h = 0; h < nHeads; h++) { + processHeadSequential(q, keyCache, valueCache, xb, h, headSize, kvDim, kvMul, loff, pos); + } + } + + // ==================== Helper Methods ==================== + + private void fillRandom(FloatArray array, float min, float max) { + float range = max - min; + for (int i = 0; i < array.getSize(); i++) { + array.set(i, min + random.nextFloat() * range); + } + } + + private FloatArray copyArray(FloatArray src) { + FloatArray dst = new FloatArray(src.getSize()); + 
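+        // FloatArray is an off-heap TornadoVM buffer, so the copy below is element-wise
+        // rather than via System.arraycopy (which only operates on on-heap Java arrays).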
for (int i = 0; i < src.getSize(); i++) { + dst.set(i, src.get(i)); + } + return dst; + } + + private HalfFloatArray createRandomHalfFloatArray(int size, float min, float max) { + HalfFloatArray array = new HalfFloatArray(size); + float range = max - min; + for (int i = 0; i < size; i++) { + array.set(i, new HalfFloat(min + random.nextFloat() * range)); + } + return array; + } + + private void assertArrayEquals(String message, FloatArray expected, FloatArray actual, float tolerance) { + assertEquals(message + " - size mismatch", expected.getSize(), actual.getSize()); + for (int i = 0; i < expected.getSize(); i++) { + assertEquals(message + " at index " + i, expected.get(i), actual.get(i), tolerance); + } + } + + // ==================== Unit Tests ==================== + + @Test + public void testReductionOneBlockWithLayer() throws TornadoExecutionPlanException { + final int dim = 512; + final int localSize = 256; + final int numGroups = (dim + localSize - 1) / localSize; + + FloatArray x = new FloatArray(dim); + FloatArray output = new FloatArray(numGroups + 1); + fillRandom(x, -1.0f, 1.0f); + + // Sequential reference + float expectedNormFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + + WorkerGrid worker = new WorkerGrid1D(dim); + worker.setLocalWork(localSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, output) + .task("t0", GPULlama3Kernels::reductionOneBlockWithLayer, context, output, x, dim, RMS_NORM_EPS, localSize) + .transferToHost(DataTransferMode.EVERY_EXECUTION, output); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertEquals("RMS norm factor", expectedNormFactor, output.get(0), 0.01f); + } + + @Test + public void testReductionOneBlock2WithLayer() throws TornadoExecutionPlanException { + final int dim = 512; + final int localSize = 256; + final int numGroups = (dim + localSize - 1) / localSize; + + FloatArray x = new FloatArray(dim); + FloatArray xb = new FloatArray(dim); + FloatArray weights = new FloatArray(dim); + FloatArray temp = new FloatArray(numGroups + 1); + FloatArray expectedOutput = new FloatArray(dim); + + fillRandom(x, -1.0f, 1.0f); + fillRandom(weights, 0.5f, 1.5f); + + // Sequential reference + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + temp.set(0, normFactor); + applyRmsNormSequential(expectedOutput, x, weights, normFactor); + + WorkerGrid worker = new WorkerGrid1D(dim); + worker.setLocalWork(localSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights, temp) + .task("t0", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, weights, temp) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xb); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("RMS norm application", expectedOutput, xb, EPSILON_FP32); + } + + 
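+    // Taken together, the two kernels exercised above implement the standard RMSNorm
+    // formula out[i] = w[i] * x[i] / sqrt(mean(x^2) + eps). The helper below is a
+    // hypothetical host-side sketch of that composition, added here for illustration
+    // only; it reuses the sequential references of this class and runs in no task graph.
+    private void rmsNormComposedSketch(FloatArray out, FloatArray in, FloatArray w) {
+        // Phase 1: scalar factor 1 / sqrt(mean(in^2) + eps)
+        float factor = computeRmsNormFactorSequential(in, RMS_NORM_EPS);
+        // Phase 2: element-wise application, out[i] = w[i] * (factor * in[i])
+        applyRmsNormSequential(out, in, w, factor);
+    }
+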
@Test + public void testFullRMSNormalization() throws TornadoExecutionPlanException { + final int dim = 512; + final int localSize = 256; + final int numGroups = (dim + localSize - 1) / localSize; + + FloatArray x = new FloatArray(dim); + FloatArray xb = new FloatArray(dim); + FloatArray weights = new FloatArray(dim); + FloatArray temp = new FloatArray(numGroups + 1); + FloatArray expectedOutput = new FloatArray(dim); + + fillRandom(x, -1.0f, 1.0f); + fillRandom(weights, 0.5f, 1.5f); + temp.init(0.0f); + + // Sequential reference: full RMS norm + float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS); + applyRmsNormSequential(expectedOutput, x, weights, normFactor); + + WorkerGrid worker = new WorkerGrid1D(dim); + worker.setLocalWork(localSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler(); + gridScheduler.addWorkerGrid("s0.reduce", worker); + gridScheduler.addWorkerGrid("s0.apply", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0").transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights, temp) + .task("reduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, dim, RMS_NORM_EPS, localSize) + .task("apply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, weights, temp) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xb, temp); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertEquals("RMS norm factor", normFactor, temp.get(0), 0.01f); + assertArrayEquals("Full RMS normalization", expectedOutput, xb, EPSILON_FP32); + } + + @Test + public void testMatrixVectorGenericFP32() throws TornadoExecutionPlanException { + final int inputDim = 256; + final int outputDim = 128; + final int localSize = 64; + + FloatArray x = new FloatArray(inputDim); + FloatArray weights = new FloatArray(outputDim * inputDim); + FloatArray output = new FloatArray(outputDim); + FloatArray expectedOutput = new FloatArray(outputDim); + + fillRandom(x, -1.0f, 1.0f); + fillRandom(weights, -0.1f, 0.1f); + + // Sequential reference + matrixVectorSequentialFP32(expectedOutput, weights, x, outputDim, inputDim); + + WorkerGrid worker = new WorkerGrid1D(outputDim * localSize); + worker.setLocalWork(localSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights) + .task("t0", GPULlama3Kernels::matrixVectorGeneric, context, x, output, weights, inputDim, outputDim, localSize) + .transferToHost(DataTransferMode.EVERY_EXECUTION, output); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("MatVec FP32", expectedOutput, output, EPSILON_FP32); + } + + @Test + public void testMatrixVectorGenericFP16() throws TornadoExecutionPlanException { + final int inputDim = 256; + final int outputDim = 128; + final int localSize = 64; + + FloatArray x = new FloatArray(inputDim); + HalfFloatArray weights = createRandomHalfFloatArray(outputDim * inputDim, -0.1f, 0.1f); + FloatArray output = new FloatArray(outputDim); + 
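+        // HalfFloat carries a 10-bit mantissa (~3 decimal digits), so a 256-term dot
+        // product can drift by a few 1e-2; the assertion below therefore uses the looser
+        // EPSILON_FP16 tolerance rather than EPSILON_FP32.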
FloatArray expectedOutput = new FloatArray(outputDim); + + fillRandom(x, -1.0f, 1.0f); + + // Sequential reference + matrixVectorSequentialFP16(expectedOutput, weights, x, outputDim, inputDim); + + WorkerGrid worker = new WorkerGrid1D(outputDim * localSize); + worker.setLocalWork(localSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights) + .task("t0", GPULlama3Kernels::matrixVectorGeneric, context, x, output, weights, inputDim, outputDim, localSize) + .transferToHost(DataTransferMode.EVERY_EXECUTION, output); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("MatVec FP16", expectedOutput, output, EPSILON_FP16); + } + + @Test + public void testMatrixVectorGenericWithResidual() throws TornadoExecutionPlanException { + final int inputDim = 256; + final int outputDim = 128; + final int localSize = 64; + + FloatArray x = new FloatArray(inputDim); + FloatArray residual = new FloatArray(outputDim); + FloatArray expectedResidual = new FloatArray(outputDim); + HalfFloatArray weights = createRandomHalfFloatArray(outputDim * inputDim, -0.1f, 0.1f); + + fillRandom(x, -1.0f, 1.0f); + fillRandom(residual, -0.5f, 0.5f); + + // Copy residual for sequential computation + for (int i = 0; i < outputDim; i++) { + expectedResidual.set(i, residual.get(i)); + } + matrixVectorWithResidualSequential(expectedResidual, weights, x, outputDim, inputDim); + + WorkerGrid worker = new WorkerGrid1D(outputDim * localSize); + worker.setLocalWork(localSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, residual, weights) + .task("t0", GPULlama3Kernels::matrixVectorGenericWithResidual, context, x, residual, weights, inputDim, outputDim, localSize) + .transferToHost(DataTransferMode.EVERY_EXECUTION, residual); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("MatVec with residual", expectedResidual, residual, EPSILON_FP16); + } + + @Test + public void testRopeRotation() throws TornadoExecutionPlanException { + final int dim = 128; + final int kvDim = 64; + final int headSize = 32; + final int position = 5; + + FloatArray sq = new FloatArray(dim); + FloatArray sk = new FloatArray(kvDim); + FloatArray expectedSq = new FloatArray(dim); + FloatArray expectedSk = new FloatArray(kvDim); + IntArray positionHolder = new IntArray(1); + + fillRandom(sq, -1.0f, 1.0f); + fillRandom(sk, -1.0f, 1.0f); + positionHolder.set(0, position); + + // Copy for sequential computation + for (int i = 0; i < dim; i++) + expectedSq.set(i, sq.get(i)); + for (int i = 0; i < kvDim; i++) + expectedSk.set(i, sk.get(i)); + ropeRotationSequential(expectedSq, expectedSk, position, kvDim, headSize); + + WorkerGrid worker = new WorkerGrid1D(dim / 2); + worker.setLocalWork(Math.min(128, dim / 2), 1, 1); + GridScheduler gridScheduler = new 
GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, positionHolder, sq, sk) + .task("t0", GPULlama3Kernels::ropeRotation, context, positionHolder, sq, sk, kvDim, headSize) + .transferToHost(DataTransferMode.EVERY_EXECUTION, sq, sk); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("RoPE query rotation", expectedSq, sq, EPSILON_FP32); + assertArrayEquals("RoPE key rotation", expectedSk, sk, EPSILON_FP32); + } + + @Test + public void testRopeRotationMultiplePositions() throws TornadoExecutionPlanException { + final int dim = 128; + final int kvDim = 64; + final int headSize = 32; + + for (int position : new int[] { 0, 1, 10, 50, 100 }) { + FloatArray sq = new FloatArray(dim); + FloatArray sk = new FloatArray(kvDim); + FloatArray expectedSq = new FloatArray(dim); + FloatArray expectedSk = new FloatArray(kvDim); + IntArray positionHolder = new IntArray(1); + + fillRandom(sq, -1.0f, 1.0f); + fillRandom(sk, -1.0f, 1.0f); + positionHolder.set(0, position); + + for (int i = 0; i < dim; i++) + expectedSq.set(i, sq.get(i)); + for (int i = 0; i < kvDim; i++) + expectedSk.set(i, sk.get(i)); + ropeRotationSequential(expectedSq, expectedSk, position, kvDim, headSize); + + WorkerGrid worker = new WorkerGrid1D(dim / 2); + worker.setLocalWork(Math.min(128, dim / 2), 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, positionHolder, sq, sk) + .task("t0", GPULlama3Kernels::ropeRotation, context, positionHolder, sq, sk, kvDim, headSize) + .transferToHost(DataTransferMode.EVERY_EXECUTION, sq, sk); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("RoPE Q at position " + position, expectedSq, sq, EPSILON_FP32); + assertArrayEquals("RoPE K at position " + position, expectedSk, sk, EPSILON_FP32); + } + } + + @Test + public void testCopyToCache() throws TornadoExecutionPlanException { + final int kvDim = 64; + final int contextLength = 128; + final int numLayers = 4; + final int layer = 2; + final int position = 10; + + FloatArray key = new FloatArray(kvDim); + FloatArray value = new FloatArray(kvDim); + FloatArray keyCache = new FloatArray(numLayers * contextLength * kvDim); + FloatArray valueCache = new FloatArray(numLayers * contextLength * kvDim); + FloatArray expectedKeyCache = new FloatArray(numLayers * contextLength * kvDim); + FloatArray expectedValueCache = new FloatArray(numLayers * contextLength * kvDim); + IntArray positionHolder = new IntArray(1); + + fillRandom(key, -1.0f, 1.0f); + fillRandom(value, -1.0f, 1.0f); + positionHolder.set(0, position); + + // Sequential reference + copyToCacheSequential(expectedKeyCache, key, expectedValueCache, value, position, kvDim, layer, contextLength); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, key, value, keyCache, valueCache, 
positionHolder)
+                .task("t0", GPULlama3Kernels::copyToCache, keyCache, key, valueCache, value, positionHolder, kvDim, layer, contextLength)
+                .transferToHost(DataTransferMode.EVERY_EXECUTION, keyCache, valueCache);
+        // @formatter:on
+
+        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+            executionPlan.execute();
+        }
+
+        // Verify the kvDim entries written at this layer/position offset
+        // (the remainder of the cache is untouched and stays zero)
+        int loff = layer * contextLength * kvDim;
+        int destOffset = loff + position * kvDim;
+        for (int i = 0; i < kvDim; i++) {
+            assertEquals("Key cache at " + i, key.get(i), keyCache.get(destOffset + i), EPSILON_FP32);
+            assertEquals("Value cache at " + i, value.get(i), valueCache.get(destOffset + i), EPSILON_FP32);
+        }
+    }
+
+    @Test
+    public void testFusedFeedForwardWithSiLUAndGLUActivation() throws TornadoExecutionPlanException {
+        final int inputDim = 128;
+        final int hiddenDim = 64;
+        final int localSize = 32;
+
+        FloatArray x = new FloatArray(inputDim);
+        HalfFloatArray w1 = createRandomHalfFloatArray(hiddenDim * inputDim, -0.1f, 0.1f);
+        HalfFloatArray w3 = createRandomHalfFloatArray(hiddenDim * inputDim, -0.1f, 0.1f);
+        FloatArray output = new FloatArray(hiddenDim);
+        FloatArray expectedOutput = new FloatArray(hiddenDim);
+
+        fillRandom(x, -1.0f, 1.0f);
+
+        // Sequential reference
+        fusedFFNSequential(expectedOutput, x, w1, w3, inputDim, hiddenDim);
+
+        WorkerGrid worker = new WorkerGrid1D(hiddenDim * localSize);
+        worker.setLocalWork(localSize, 1, 1);
+        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+        KernelContext context = new KernelContext();
+
+        // @formatter:off
+        TaskGraph taskGraph = new TaskGraph("s0")
+                .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, w1, w3)
+                .task("t0", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, x, output, w1, w3, inputDim, hiddenDim, localSize)
+                .transferToHost(DataTransferMode.EVERY_EXECUTION, output);
+        // @formatter:on
+
+        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+            executionPlan.withGridScheduler(gridScheduler).execute();
+        }
+
+        assertArrayEquals("Fused FFN", expectedOutput, output, EPSILON_ACCUMULATED);
+    }
+
+    @Test
+    public void testAddInPlace() throws TornadoExecutionPlanException {
+        final int size = 512;
+
+        FloatArray a = new FloatArray(size);
+        FloatArray b = new FloatArray(size);
+        FloatArray expectedA = new FloatArray(size);
+
+        fillRandom(a, -1.0f, 1.0f);
+        fillRandom(b, -1.0f, 1.0f);
+
+        // Copy for sequential
+        for (int i = 0; i < size; i++) {
+            expectedA.set(i, a.get(i));
+        }
+        addInPlaceSequential(expectedA, b);
+
+        // @formatter:off
+        TaskGraph taskGraph = new TaskGraph("s0")
+                .transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b)
+                .task("t0", GPULlama3Kernels::addInPlace, a, b, size)
+                .transferToHost(DataTransferMode.EVERY_EXECUTION, a);
+        // @formatter:on
+
+        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+            executionPlan.execute();
+        }
+
+        assertArrayEquals("Add in place", expectedA, a, EPSILON_FP32);
+    }
+
+    @Test
+    public void testSplitQKV() throws TornadoExecutionPlanException {
+        final int dimQ = 256;
+        final int dimKV = 64;
+        final int totalSize = dimQ + 2 * dimKV;
+
+        FloatArray qkv = new FloatArray(totalSize);
+        FloatArray q = new FloatArray(dimQ);
+        FloatArray k = new
FloatArray(dimKV); + FloatArray v = new FloatArray(dimKV); + + fillRandom(qkv, -1.0f, 1.0f); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, qkv) + .task("t0", GPULlama3Kernels::splitQKV, qkv, q, k, v, dimQ, dimKV) + .transferToHost(DataTransferMode.EVERY_EXECUTION, q, k, v); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.execute(); + } + + // Verify exact split + for (int i = 0; i < dimQ; i++) { + assertEquals("Q[" + i + "]", qkv.get(i), q.get(i), EPSILON_FP32); + } + for (int i = 0; i < dimKV; i++) { + assertEquals("K[" + i + "]", qkv.get(dimQ + i), k.get(i), EPSILON_FP32); + assertEquals("V[" + i + "]", qkv.get(dimQ + dimKV + i), v.get(i), EPSILON_FP32); + } + } + + @Test + public void testSplitGateUpAndSiLU() throws TornadoExecutionPlanException { + final int hiddenDim = 256; + + FloatArray hb = new FloatArray(hiddenDim * 2); + FloatArray hbG = new FloatArray(hiddenDim); + FloatArray hbU = new FloatArray(hiddenDim); + FloatArray expectedHbG = new FloatArray(hiddenDim); + FloatArray expectedHbU = new FloatArray(hiddenDim); + + fillRandom(hb, -2.0f, 2.0f); + + // Sequential reference + splitGateUpAndSiLUSequential(hb, expectedHbG, expectedHbU, hiddenDim); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, hb) + .task("t0", GPULlama3Kernels::splitGateUpAndSiLU, hb, hbG, hbU, hiddenDim) + .transferToHost(DataTransferMode.EVERY_EXECUTION, hbG, hbU); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.execute(); + } + + assertArrayEquals("Gate with SiLU", expectedHbG, hbG, EPSILON_FP32); + assertArrayEquals("Up * Gate", expectedHbU, hbU, EPSILON_FP32); + } + + @Test + public void testProcessHeadsParallel() throws TornadoExecutionPlanException { + final int nHeads = 4; + final int headSize = 32; + final int kvDim = nHeads * headSize; + final int dim = nHeads * headSize; + final int contextLength = 64; + final int position = 3; + final int layer = 0; + final int kvMul = 1; + + FloatArray q = new FloatArray(dim); + FloatArray keyCache = new FloatArray(contextLength * kvDim); + FloatArray valueCache = new FloatArray(contextLength * kvDim); + FloatArray xb = new FloatArray(dim); + FloatArray expectedXb = new FloatArray(dim); + FloatArray att = new FloatArray(nHeads * contextLength); + IntArray positionHolder = new IntArray(1); + + fillRandom(q, -1.0f, 1.0f); + fillRandom(keyCache, -1.0f, 1.0f); + fillRandom(valueCache, -1.0f, 1.0f); + positionHolder.set(0, position); + + // Sequential reference + processHeadsSequential(q, keyCache, valueCache, expectedXb, nHeads, headSize, kvDim, kvMul, position, layer, contextLength); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder, att) + .task("t0", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, nHeads, headSize, kvDim, kvMul, contextLength, positionHolder, att, layer, contextLength) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xb); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new 
TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.execute(); + } + + assertArrayEquals("Parallel attention", expectedXb, xb, EPSILON_FP32); + } + + @Test + public void testProcessHeadsFlashAttention() throws TornadoExecutionPlanException { + final int nHeads = 4; + final int headSize = 32; + final int kvDim = nHeads * headSize; + final int dim = nHeads * headSize; + final int contextLength = 64; + final int position = 3; + final int layer = 0; + final int kvMul = 1; + + FloatArray q = new FloatArray(dim); + FloatArray keyCache = new FloatArray(contextLength * kvDim); + FloatArray valueCache = new FloatArray(contextLength * kvDim); + FloatArray xb = new FloatArray(dim); + FloatArray expectedXb = new FloatArray(dim); + IntArray positionHolder = new IntArray(1); + + fillRandom(q, -1.0f, 1.0f); + fillRandom(keyCache, -1.0f, 1.0f); + fillRandom(valueCache, -1.0f, 1.0f); + positionHolder.set(0, position); + + // Sequential reference (same as processHeadsParallel) + processHeadsSequential(q, keyCache, valueCache, expectedXb, nHeads, headSize, kvDim, kvMul, position, layer, contextLength); + + WorkerGrid worker = new WorkerGrid1D(nHeads * headSize); + worker.setLocalWork(headSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s0.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraph = new TaskGraph("s0") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder) + .task("t0", GPULlama3Kernels::processHeadsFlashAttention, context, q, keyCache, valueCache, xb, nHeads, headSize, kvDim, kvMul, positionHolder, layer, contextLength) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xb); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + assertArrayEquals("Flash attention", expectedXb, xb, EPSILON_FP32); + } + + @Test + public void testAttentionConsistency() throws TornadoExecutionPlanException { + // Test that both attention implementations produce the same result + final int nHeads = 4; + final int headSize = 32; + final int kvDim = nHeads * headSize; + final int dim = nHeads * headSize; + final int contextLength = 64; + final int position = 5; + final int layer = 0; + final int kvMul = 1; + + FloatArray q = new FloatArray(dim); + FloatArray keyCache = new FloatArray(contextLength * kvDim); + FloatArray valueCache = new FloatArray(contextLength * kvDim); + FloatArray xbParallel = new FloatArray(dim); + FloatArray xbFlash = new FloatArray(dim); + FloatArray att = new FloatArray(nHeads * contextLength); + IntArray positionHolder = new IntArray(1); + + fillRandom(q, -1.0f, 1.0f); + fillRandom(keyCache, -1.0f, 1.0f); + fillRandom(valueCache, -1.0f, 1.0f); + positionHolder.set(0, position); + + // Test parallel attention + // @formatter:off + TaskGraph taskGraphParallel = new TaskGraph("s_parallel") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder, att) + .task("t0", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xbParallel, nHeads, headSize, kvDim, kvMul, contextLength, positionHolder, att, layer, contextLength) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xbParallel); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraphParallel = taskGraphParallel.snapshot(); + try (TornadoExecutionPlan executionPlan = new 
TornadoExecutionPlan(immutableTaskGraphParallel)) { + executionPlan.execute(); + } + + // Test flash attention + WorkerGrid worker = new WorkerGrid1D(nHeads * headSize); + worker.setLocalWork(headSize, 1, 1); + GridScheduler gridScheduler = new GridScheduler("s_flash.t0", worker); + KernelContext context = new KernelContext(); + + // @formatter:off + TaskGraph taskGraphFlash = new TaskGraph("s_flash") + .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder) + .task("t0", GPULlama3Kernels::processHeadsFlashAttention, context, q, keyCache, valueCache, xbFlash, nHeads, headSize, kvDim, kvMul, positionHolder, layer, contextLength) + .transferToHost(DataTransferMode.EVERY_EXECUTION, xbFlash); + // @formatter:on + + ImmutableTaskGraph immutableTaskGraphFlash = taskGraphFlash.snapshot(); + try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraphFlash)) { + executionPlan.withGridScheduler(gridScheduler).execute(); + } + + // Both implementations should match + assertArrayEquals("Parallel vs Flash attention consistency", xbParallel, xbFlash, EPSILON_FP32); + } +} \ No newline at end of file