diff --git a/tornado-assembly/src/bin/tornado-test b/tornado-assembly/src/bin/tornado-test index 56eeac654f..4e2657e43b 100755 --- a/tornado-assembly/src/bin/tornado-test +++ b/tornado-assembly/src/bin/tornado-test @@ -226,6 +226,8 @@ __TEST_THE_WORLD__ = [ TestEntry("uk.ac.manchester.tornado.unittests.pointers.TestCopyDevicePointers"), TestEntry("uk.ac.manchester.tornado.unittests.tensors.TestTensorAPIWithOnnx"), TestEntry("uk.ac.manchester.tornado.unittests.memory.MemoryConsumptionTest"), + TestEntry("uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsUnit"), + TestEntry("uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsFused"), ## Test for function calls - We force not to inline methods TestEntry(testName="uk.ac.manchester.tornado.unittests.tasks.TestMultipleFunctions", diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/GPULlama3Kernels.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/GPULlama3Kernels.java new file mode 100644 index 0000000000..2563aaa2e5 --- /dev/null +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/GPULlama3Kernels.java @@ -0,0 +1,1115 @@ +/* + * Copyright (c) 2025, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.unittests.gpullama; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.Int8Array; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +public class GPULlama3Kernels { + + /** + * Performs RMS (Root Mean Square) normalization using parallel reduction. This is the first phase of RMS normalization that computes the variance and scaling factor across all work groups. + * + * Algorithm: 1. Each thread computes square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums stored per work group 4. 
First thread combines all partial
+     * sums and computes normalization factor
+     *
+     * @param context
+     *            Kernel execution context
+     * @param output
+     *            Array to store partial sums and final normalization factor
+     * @param x
+     *            Input array to normalize
+     * @param size
+     *            Number of elements to process
+     * @param ermsNorm
+     *            Epsilon value for numerical stability
+     * @param localMemSize
+     *            Size of local memory allocation (must match work group size)
+     */
+    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
+        int gid = context.globalIdx;
+        int lid = context.localIdx;
+        int groupId = context.groupIdx;
+        int groupSize = context.localGroupSizeX;
+
+        // Allocate local memory with the provided size
+        float[] localX = context.allocateFloatLocalArray(localMemSize);
+
+        // Load input value and compute square
+        if (gid < size) {
+            localX[lid] = x.get(gid);
+            localX[lid] = localX[lid] * localX[lid];
+        } else {
+            localX[lid] = 0.0f;
+        }
+
+        // Perform parallel reduction within the work group
+        for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
+            context.localBarrier();
+            if (lid < stride) {
+                localX[lid] += localX[lid + stride];
+            }
+        }
+
+        // Each workgroup stores its partial sum in a different location
+        if (lid == 0) {
+            // Store the partial sum from each workgroup
+            output.set(groupId + 1, localX[0]);
+        }
+
+        // Only the first thread in the first workgroup computes the final normalization factor
+        if (gid == 0) {
+            // Combine the partial sums, one per workgroup (size / localMemSize workgroups)
+            float ss = 0.0f;
+            for (int i = 1; i <= (size / localMemSize); i++) {
+                ss += output.get(i);
+            }
+
+            ss /= size;
+            ss += ermsNorm;
+            ss = 1.0f / TornadoMath.sqrt(ss);
+            output.set(0, ss); // Store the final scale factor
+        }
+    }
+
+    /**
+     * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
+     *
+     * Formula: output[i] = weight[i] * (normalizationFactor * x[i])
+     *
+     * @param context
+     *            Kernel execution context
+     * @param output
+     *            Array for normalized output
+     * @param x
+     *            Input values to normalize
+     * @param weights
+     *            Weight values for each element
+     * @param temp
+     *            Temporary array containing normalization factor at index 0
+     */
+    public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
+        int gid = context.globalIdx;
+
+        float ss = temp.get(0);
+        output.set(gid, weights.get(gid) * (ss * x.get(gid)));
+    }
+
+    /**
+     * Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
+ * + * Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector + * + * @param destKeyCache + * Destination array for key cache + * @param srcKey + * Source keys to copy + * @param destValueCache + * Destination array for value cache + * @param srcValue + * Source values to copy + * @param positioNlayer + * Array containing current position + * @param kvDim + * Dimension of key/value vectors + * @param layer + * Current transformer layer index + * @param contextLength + * Maximum sequence length + */ + public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) { + + int position = positioNlayer.get(0); + int loff = layer * contextLength * kvDim; + int destOffset = loff + position * kvDim; + + for (@Parallel int i = 0; i < srcValue.getSize(); i++) { + destKeyCache.set(destOffset + i, srcKey.get(i)); + destValueCache.set(destOffset + i, srcValue.get(i)); + } + } + + public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) { + int totalSize = dimQ + 2 * dimKV; + + for (@Parallel int i = 0; i < totalSize; i++) { + if (i < dimQ) { + // Copy to Q + q.set(i, qkv.get(i)); + } else if (i < dimQ + dimKV) { + // Copy to K + int kIndex = i - dimQ; + k.set(kIndex, qkv.get(i)); + } else { + // Copy to V + int vIndex = i - dimQ - dimKV; + v.set(vIndex, qkv.get(i)); + } + } + } + + /** + * Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional + * information. + * + * For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair + * + * @param context + * Kernel execution context + * @param positionHolder + * Array containing current position + * @param sq + * Query vectors to rotate + * @param sk + * Key vectors to rotate + * @param kv_dim + * Dimension of key/value vectors + * @param head_size + * Dimension of each attention head + */ + public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) { + int i = context.globalIdx * 2; + + int head_dim = i % head_size; + // 50000.0f vs 10000.0f + float freq = 1.0f / TornadoMath.pow(50000.0f, head_dim / (float) head_size); + float val = positionHolder.get(0) * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + int rotn = i < kv_dim ? 2 : 1; // how many vectors? 
2 = q & k, 1 = q only + + // Rotate query vector + float v0q = sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // Rotate key vector if needed + if (rotn > 1 && i < sk.getSize()) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + sk.set(i, v0k * fcr - v1k * fci); + sk.set(i + 1, v0k * fci + v1k * fcr); + } + + } + + public static void ropeRotationPhi3(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) { + int idx = context.globalIdx; + + // For Phi3, we process pairs with offset of head_size/2 + int dimHalf = head_size / 2; + + // Each thread processes one dimension pair + if (idx >= dimHalf) { + return; + } + + int position = positionHolder.get(0); + + // Calculate frequency for this dimension + float freq = 1.0f / TornadoMath.pow(10000.0f, (float) (idx * 2) / (float) head_size); + float val = position * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Process all heads + int totalDim = sq.getSize(); + for (int base = 0; base < totalDim; base += head_size) { + // Skip if we're beyond the bounds + if (base + idx >= totalDim || base + idx + dimHalf >= totalDim) { + break; + } + + // Rotate query + float v0 = sq.get(base + idx); + float v1 = sq.get(base + idx + dimHalf); + sq.set(base + idx, v0 * fcr - v1 * fci); + sq.set(base + idx + dimHalf, v0 * fci + v1 * fcr); + + // Rotate key if within kv_dim + if (base < kv_dim && base + idx < sk.getSize() && base + idx + dimHalf < sk.getSize()) { + float k0 = sk.get(base + idx); + float k1 = sk.get(base + idx + dimHalf); + sk.set(base + idx, k0 * fcr - k1 * fci); + sk.set(base + idx + dimHalf, k0 * fci + k1 * fcr); + } + } + } + + /** + * Computes attention for a single head. Implements scaled dot-product attention with softmax normalization. + * + * Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. 
Compute weighted sum of values + * + * @param allQ + * All query vectors + * @param key_cache + * Cached keys + * @param value_cache + * Cached values + * @param allXb + * Output buffer + * @param h + * Head index to process + * @param headSize + * Dimension per head + * @param kvDim + * Key/value dimension + * @param kvMul + * Key multiplier for grouped attention + * @param loff + * Layer offset in cache + * @param pos + * Current position + * @param wrapAtt + * Attention weights buffer + */ + private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, + FloatArray wrapAtt) { + + // Base index for this head's attention weights + int headOffset = h * (pos + 1); + + // STEP 1: Calculate attention scores for all timesteps + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + + float score = 0.0f; + for (int i = 0; i < headSize; i++) { + score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i); + } + score = score / TornadoMath.sqrt(headSize); + + // Store in attention buffer + wrapAtt.set(headOffset + t, score); + } + + // STEP 2: Find max score for softmax stability + float maxScore = wrapAtt.get(headOffset); + for (int t = 1; t <= pos; t++) { + float val = wrapAtt.get(headOffset + t); + if (val > maxScore) { + maxScore = val; + } + } + + // STEP 3: Compute exponentials and sum + float sum = 0.0f; + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore); + wrapAtt.set(idx, expScore); + sum += expScore; + } + + // STEP 4: Normalize + float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos + 1)); + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + wrapAtt.set(idx, wrapAtt.get(idx) * normFactor); + } + + // STEP 5: Compute weighted sum of values for each dimension + for (int i = 0; i < headSize; i++) { + float weightedSum = 0.0f; + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i); + } + allXb.set(h * headSize + i, weightedSum); + } + } + + public static void processHeadsFlashAttention(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, + IntArray positionHolder, int layer, int contextLength) { + + // Thread and workgroup information + int tid = context.localIdx; + int h = context.groupIdx; // Each workgroup processes one head + int localSize = context.localGroupSizeX; + + // Early exit if this workgroup is beyond our head count + // This relies on the kernel being launched with nHeads workgroups. 
+        if (h >= nHeads) {
+            return;
+        }
+
+        int pos = positionHolder.get(0);
+        int loff = layer * contextLength * kvDim;
+        int kvHeadIdx = h / kvMul;
+        int BLOCK_SIZE_C = 8;
+
+        // Allocate shared memory for tiled computation
+        float[] q_shared = context.allocateFloatLocalArray(headSize);
+        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
+        float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max
+
+        // Thread-local accumulators for online softmax
+        float maxScore = Float.NEGATIVE_INFINITY;
+        float sumExp = 0.0f;
+
+        // Thread-local output accumulation
+        float[] output = new float[headSize];
+        for (int i = 0; i < headSize; i++) {
+            output[i] = 0.0f;
+        }
+
+        // Load query vector into shared memory
+        for (int i = tid; i < headSize; i += localSize) {
+            q_shared[i] = q.get(h * headSize + i);
+        }
+
+        context.localBarrier();
+
+        // Process sequence in tiles
+        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
+            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
+
+            // Load key and value vectors for this tile
+            // Each thread loads a portion of the K and V vectors for the tile
+            for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+                int k_v_idx_in_tile = tIdxInSeq - tileC; // 0 .. BLOCK_SIZE_C-1 within this tile
+                int tileMemOffset = k_v_idx_in_tile * headSize;
+                for (int d = 0; d < headSize; d++) {
+                    int kvCacheAbsolutePos = tIdxInSeq;
+                    int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d;
+                    k_tile[tileMemOffset + d] = key_cache.get(kvOffset);
+                    v_tile[tileMemOffset + d] = value_cache.get(kvOffset);
+                }
+            }
+
+            context.localBarrier();
+
+            // Compute attention scores for this tile
+            // Each thread computes one score for the tile
+            for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+                int score_idx_in_tile = tIdxInSeq - tileC; // 0 .. BLOCK_SIZE_C-1 within this tile
+
+                float score = 0.0f;
+                for (int d = 0; d < headSize; d++) {
+                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
+                }
+                score /= TornadoMath.sqrt(headSize);
+                s_tile[score_idx_in_tile] = score;
+            }
+
+            context.localBarrier();
+
+            // Find max score in this tile (all threads compute it redundantly over the small s_tile)
+            float tileLocalMax = Float.NEGATIVE_INFINITY;
+            for (int i = 0; i <= tileEnd - tileC; i++) { // Iterate over valid scores in s_tile
+                if (s_tile[i] > tileLocalMax) {
+                    tileLocalMax = s_tile[i];
+                }
+            }
+
+            // Broadcast max to all threads via shared memory
+            if (tid == 0) {
+                shared_tile_max_holder[0] = tileLocalMax; // FIX: Use dedicated holder
+            }
+            context.localBarrier();
+            float currentTileMax = shared_tile_max_holder[0]; // FIX: Read from dedicated holder
+
+            // Determine if we need to rescale previous results
+            float newMax = Math.max(maxScore, currentTileMax);
+            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
+                float scale = TornadoMath.exp(maxScore - newMax);
+                sumExp *= scale;
+                for (int d = 0; d < headSize; d++) {
+                    output[d] *= scale;
+                }
+            }
+            maxScore = newMax;
+
+            // Process each key-value pair using original scores from s_tile
+            // All threads iterate over all scores in the current tile
+            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
+                // s_tile[t_idx_in_s_tile] now correctly refers to the original score
+                float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
+                sumExp += expScore;
+
+                for (int d = 0; d < headSize; d++) {
+                    output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
+                }
+            }
+            context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
+        }
+
+        // Normalize and write final results
+        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; // Avoid division by zero, return 0 if sumExp is 0
+        for (int d = tid; d < headSize; d += localSize) {
+            xb.set(h * headSize + d, output[d] * normFactor);
+        }
+    }
+
+    /**
+     * Same as processHeadsFlashAttention, but with a larger tile size (BLOCK_SIZE_C = 32), coalesced tile loading, and a parallel tree reduction for the tile maximum. These optimizations tend to
+     * lower attention's execution time, especially in larger models.
+     */
+    public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul,
+            IntArray positionHolder, int layer, int contextLength) {
+
+        // Thread and workgroup information
+        int tid = context.localIdx;
+        int h = context.groupIdx; // Each workgroup processes one head
+        int localSize = context.localGroupSizeX;
+
+        // Early exit if this workgroup is beyond our head count
+        // This relies on the kernel being launched with nHeads workgroups.
+        if (h >= nHeads) {
+            return;
+        }
+
+        int pos = positionHolder.get(0);
+        int loff = layer * contextLength * kvDim;
+        int kvHeadIdx = h / kvMul;
+        int BLOCK_SIZE_C = 32;
+
+        // Allocate shared memory for tiled computation
+        float[] q_shared = context.allocateFloatLocalArray(headSize);
+        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
+        float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max
+
+        // Thread-local accumulators for online softmax
+        float maxScore = Float.NEGATIVE_INFINITY;
+        float sumExp = 0.0f;
+
+        // Thread-local output accumulation
+        float[] output = new float[headSize];
+        for (int i = 0; i < headSize; i++) {
+            output[i] = 0.0f;
+        }
+
+        // Load query vector into shared memory
+        for (int i = tid; i < headSize; i += localSize) {
+            q_shared[i] = q.get(h * headSize + i);
+        }
+
+        context.localBarrier();
+
+        // Process sequence in tiles
+        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
+            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
+
+            // Load key and value vectors for this tile
+            // Each thread loads a contiguous block of elements
+            int totalElements = (tileEnd - tileC + 1) * headSize;
+            int elementsPerThread = (totalElements + localSize - 1) / localSize;
+            int startElem = tid * elementsPerThread;
+            int endElem = Math.min(startElem + elementsPerThread, totalElements);
+
+            for (int globalElemIdx = startElem; globalElemIdx < endElem; globalElemIdx++) {
+                // Convert flat index to (sequence_pos, dimension)
+                int seqIdx = globalElemIdx / headSize;
+                int dimIdx = globalElemIdx % headSize;
+
+                int tIdxInSeq = tileC + seqIdx;
+                int tileMemOffset = seqIdx * headSize + dimIdx;
+
+                int kvCacheAbsolutePos = tIdxInSeq;
+                int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + dimIdx;
+
+                k_tile[tileMemOffset] = key_cache.get(kvOffset);
+                v_tile[tileMemOffset] = value_cache.get(kvOffset);
+            }
+
+            context.localBarrier();
+
+            // Compute attention scores for this tile
+            // Each thread computes one score for the tile
+            for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+                int score_idx_in_tile = tIdxInSeq - tileC; // 0 .. BLOCK_SIZE_C-1 within this tile
+
+                float score = 0.0f;
+                for (int d = 0; d < headSize; d++) {
+                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
+                }
+                score /= TornadoMath.sqrt(headSize);
+                s_tile[score_idx_in_tile] = score;
+            }
+
+            context.localBarrier();
+
+            // Allocate shared memory for reduction (needs to be power of 2)
+            int reductionSize = 1024; // Should be >= BLOCK_SIZE_C and power of 2
+            float[] reduction_shared = context.allocateFloatLocalArray(reductionSize);
+
+            // Step 1: Each thread finds max of its assigned subset
+            int itemsPerThread = (BLOCK_SIZE_C + localSize - 1) / localSize;
+            int startIdx = tid * itemsPerThread;
+            int endIdx = Math.min(startIdx + itemsPerThread, tileEnd - tileC + 1);
+
+            float threadLocalMax = Float.NEGATIVE_INFINITY;
+            for (int i = startIdx; i < endIdx; i++) {
+                if (s_tile[i] > threadLocalMax) {
+                    threadLocalMax = s_tile[i];
+                }
+            }
+
+            // Step 2: Store each thread's local max in shared memory
+            reduction_shared[tid] = threadLocalMax;
+            context.localBarrier();
+
+            // Step 3: Parallel reduction tree
+            for (int stride = localSize / 2; stride > 0; stride /= 2) {
+                if (tid < stride && tid + stride < localSize) {
+                    reduction_shared[tid] = Math.max(reduction_shared[tid], reduction_shared[tid + stride]);
+                }
+                context.localBarrier();
+            }
+
+            // Step 4: reduction_shared[0] now holds the tile max, visible to all threads after the barrier
+            float currentTileMax = reduction_shared[0];
+
+            // Determine if we need to rescale previous results
+            float newMax = Math.max(maxScore, currentTileMax);
+            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
+                float scale = TornadoMath.exp(maxScore - newMax);
+                sumExp *= scale;
+                for (int d = 0; d < headSize; d++) {
+                    output[d] *= scale;
+                }
+            }
+            maxScore = newMax;
+
+            // Process each key-value pair using original scores from s_tile
+            // All threads iterate over all scores in the current tile
+            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
+                // s_tile[t_idx_in_s_tile] now correctly refers to the original score
+                float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
+                sumExp += expScore;
+
+                for (int d = 0; d < headSize; d++) {
+                    output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
+                }
+            }
+            context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
+        }
+
+        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
+
+        int dimsPerThread = (headSize + localSize - 1) / localSize;
+        int startDim = tid * dimsPerThread;
+        int endDim = Math.min(startDim + dimsPerThread, headSize);
+        int baseOffset = h * headSize + startDim;
+
+        // Process 4 elements at a time when possible
+        int vectorEnd = startDim + ((endDim - startDim) & ~3); // Round down to multiple of 4
+
+        // Unrolled loop for better instruction-level parallelism
+        for (int d = startDim; d < vectorEnd; d += 4) {
+            int offset = d - startDim;
+            xb.set(baseOffset + offset, output[d] * normFactor);
+            xb.set(baseOffset + offset + 1, output[d + 1] * normFactor);
+            xb.set(baseOffset + offset + 2, output[d + 2] * normFactor);
+            xb.set(baseOffset + offset + 3, output[d + 3] * normFactor);
+        }
+
+        // Handle remaining elements (0-3 elements)
+        for (int d = vectorEnd; d < endDim; d++) {
+            xb.set(h * headSize + d, output[d] * normFactor);
+        }
+    }
+
+    /**
+     * Performs optimized matrix-vector multiplication where each work group processes one row of the matrix.
+     *
+     * Algorithm: 1. Each work group handles one output dimension 2.
Threads in work group compute partial dot products 3. Parallel reduction yields final row result + * + * @param context + * Kernel execution context + * @param x + * Input vector + * @param hb + * Output vector + * @param w + * Weight matrix (row-major) + * @param n + * Input dimension + * @param d + * Output dimension + * @param localWorkGroupSize + * Number of threads per work group + */ + public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } + + // @formatter:off + public static void matrixVectorGeneric( + KernelContext context, + FloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } + // @formatter:on + + /** + * Matrix-vector multiplication with residual connection. Combines regular matrix multiplication with addition of existing values. + * + * Formula: hb[i] = hb[i] + w[i]·x + * + * @param context + * Kernel execution context + * @param x + * Input vector + * @param hb + * Input/output vector (contains residual, receives result) + * @param w + * Weight matrix + * @param n + * Input dimension + * @param d + * Output dimension + * @param localWorkGroupSize + * Work group size + */ + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + + /** + * Fused feed-forward network with SiLU activation and GLU gating. Implements the SwiGLU variant used in LLaMA-style models. 
+ * + * Formula: FFN(x) = SiLU(x·W1) ⊙ (x·W3) where ⊙ denotes element-wise multiplication + * + * @param context + * Kernel execution context + * @param x + * Input vector + * @param hb + * Output buffer + * @param w1 + * First feed-forward weight matrix + * @param w3 + * Third feed-forward weight matrix (gate) + * @param n + * Input dimension + * @param d + * Hidden dimension + * @param localWorkGroupSize + * Work group size + */ + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w1, HalfFloatArray w3, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w1, n); + float sum3 = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w3, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, result); + } + } + + /** + * Gaussian Error Linear Unit (GELU) activation function. Approximation formula: GELU(x) ≈ 0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³))) + * + * @param x + * Input value + * @return Activated value + */ + public static float geluActivation(float x) { + float x3 = x * x * x; + return 0.5f * x * (1.0f + TornadoMath.tanh((0.797885f * (x + 0.044715f * x3)))); + } + + /** + * Sigmoid-weighted Linear Unit (SiLU) activation function. Also known as Swish activation. + * + * Formula: SiLU(x) = x * σ(x) = x / (1 + e^(-x)) + * + * @param x + * Input value + * @return Activated value + */ + public static float siluActivation(float x) { + return x * (1.0f / (1.0f + TornadoMath.exp(-x))); + } + + /** + * Optimized row-major matrix-vector multiplication for a single row. Uses parallel reduction within a work group to compute one dot product. + * + * Algorithm: 1. Each thread computes partial dot product 2. Partial results stored in local memory 3. Tree-based reduction combines partial results 4. 
Returns final dot product for the row + * + * @param context + * Kernel execution context + * @param localSize + * Work group size + * @param x + * Input vector + * @param w + * Weight matrix row + * @param n + * Input dimension + * @return Dot product result for this row + */ + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, FloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + partialSum += w.get(matrixIdx) * x.get(j); + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Allocate local memory for reduction + float[] localSum = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + + // Each thread calculates partial dot product + float partialSum = 0.0f; + for (int j = localId; j < n; j += localSize) { + int matrixIdx = rowOffset + j; + partialSum += w.get(matrixIdx).getFloat32() * x.get(j); + } + + // Store partial sum in local memory + localSum[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction within workgroup + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + + return localSum[0]; + } + + // Second kernel - Combines partial sums and computes final normalization + public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) { + int gid = context.globalIdx; + + // Only one thread needs to perform this calculation + if (gid == 0) { + // Combine partial sums from all workgroups + float ss = 0.0f; + for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds + ss += output.get(i); + } + + ss /= size; + ss += ermsNorm; + ss = 1.0f / TornadoMath.sqrt(ss); + output.set(0, ss); // Store the final scale factor + } + } + + public static void splitGateUpAndSiLU(FloatArray hb, FloatArray hbG, FloatArray hbU, int hiddenDim) { + // Copy and apply SiLU to gate in one pass + for (@Parallel int i = 0; i < hiddenDim; i++) { + float gateVal = hb.get(i); + float upVal = hb.get(hiddenDim + i); + + // Apply SiLU to gate + float siluGate = gateVal / (1.0f + TornadoMath.exp(-gateVal)); + + // Store activated gate and multiply with up + hbG.set(i, siluGate); + hbU.set(i, siluGate * upVal); + } + } + + public static void addInPlace(FloatArray arrayA, FloatArray arrayB, int size) { + // Element-wise addition: arrayA[i] = arrayA[i] + arrayB[i] + for (@Parallel int i = 0; i < size; i++) { + float result = arrayA.get(i) + arrayB.get(i); + arrayA.set(i, result); + } + } + + /** + * Matrix-vector multiplication for Q8_0 quantized weights. 
+ * + * @param context + * Kernel context + * @param x + * Input activations (FloatArray) + * @param output + * Output array (FloatArray) + * @param weightsQ + * Quantized weights (Int8Array) - from Q8_0QuantizedTensor.getQuants() + * @param weightScales + * Scale factors (HalfFloatArray) - from Q8_0QuantizedTensor.getScales() + * @param dim1 + * Input dimension (n - number of columns) + * @param dim0 + * Output dimension (d - number of rows) + * @param localWorkGroupSize + * Local workgroup size + */ + public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) { + + // One row per workgroup + int rowId = context.groupIdx; + int localId = context.localIdx; + + // Early exit if this workgroup is beyond output dimension + if (rowId >= dim0) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, weightsQ, weightScales, dim1); + + // Thread 0 writes the result + if (localId == 0) { + output.set(rowId, sum); + } + } + + /** + * Helper method to compute dot product for a single row with Q8_0 quantized weights. Uses 4-way unrolling for better performance. + */ + public static float matrixVectorRowMajorOptimizedQ8_0(KernelContext context, int localSize, FloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) { + int rowId = context.groupIdx; + int localId = context.localIdx; + int blockSize = 32; + + // Allocate local memory for reduction + float[] localSums = context.allocateFloatLocalArray(localSize); + + int rowOffset = rowId * n; + int scalesRowOffset = rowId * (n / blockSize); + + // 4-way unrolling + float partialSum1 = 0.0f; + float partialSum2 = 0.0f; + float partialSum3 = 0.0f; + float partialSum4 = 0.0f; + + // Main loop - process 4 elements at a time + for (int j = localId * 4; j < n - 3; j += localSize * 4) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + + // Dequantize and multiply + partialSum1 += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j); + partialSum2 += ((float) weightsQ.get(rowOffset + j + 1) * scale) * x.get(j + 1); + partialSum3 += ((float) weightsQ.get(rowOffset + j + 2) * scale) * x.get(j + 2); + partialSum4 += ((float) weightsQ.get(rowOffset + j + 3) * scale) * x.get(j + 3); + } + + float partialSum = partialSum1 + partialSum2 + partialSum3 + partialSum4; + + // Handle remaining elements + for (int j = ((n / 4) * 4) + localId; j < n; j += localSize) { + int blockIdx = j / blockSize; + float scale = weightScales.get(scalesRowOffset + blockIdx).getFloat32(); + partialSum += ((float) weightsQ.get(rowOffset + j) * scale) * x.get(j); + } + + // Store partial sum + localSums[localId] = partialSum; + context.localBarrier(); + + // Parallel reduction + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSums[localId] += localSums[localId + stride]; + } + context.localBarrier(); + } + + return localSums[0]; + } + + public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, Int8Array w_quants, HalfFloatArray w_scales, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = 
matrixVectorRowMajorOptimizedQ8_0(context, localSize, x, w_quants, w_scales, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float result = hb.get(rowId) + sum; + hb.set(rowId, result); + } + } + + public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, Int8Array w1_quants, HalfFloatArray w1_scales, Int8Array w3_quants, + HalfFloatArray w3_scales, int n, int d, int localWorkGroupSize) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= d) { + return; + } + + float sum1 = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, w1_quants, w1_scales, n); + float sum3 = matrixVectorRowMajorOptimizedQ8_0(context, localWorkGroupSize, x, w3_quants, w3_scales, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float silu = siluActivation(sum1); // Using the new SiLU method + float result = silu * sum3; + hb.set(rowId, result); + } + } + + /** + * Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel. + * + * Attention computation: 1. Compute attention scores (Q·K) 2. Apply softmax for attention weights 3. Compute weighted sum of values (attention·V) + * + * @param q + * Query vectors for all heads + * @param key_cache + * Cached key vectors + * @param value_cache + * Cached value vectors + * @param xb + * Output buffer for attention results + * @param nHeads + * Number of attention heads + * @param headSize + * Dimension of each head + * @param kvDim + * Total key/value dimension + * @param kvMul + * Key/value head multiplier for grouped-query attention + * @param seqLen + * Current sequence length + * @param positionHolder + * Array containing position and layer info + * @param wrapAtt + * Buffer for attention weights + * @param layer + * Current transformer layer + * @param contextLength + * Maximum context length + */ + public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, + IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) { + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + + // Parallelize computation across attention heads + for (@Parallel int h = 0; h < nHeads; h++) { + // Process each head in parallel + processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt); + } + } + +} diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsFused.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsFused.java new file mode 100644 index 0000000000..80a765a61c --- /dev/null +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsFused.java @@ -0,0 +1,809 @@ +/* + * Copyright (c) 2025, APT Group, Department of Computer Science, + * The University of Manchester. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ +package uk.ac.manchester.tornado.unittests.gpullama; + +import static org.junit.Assert.assertEquals; + +import java.util.Random; + +import org.junit.Before; +import org.junit.Test; + +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.TornadoExecutionPlan; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; +import uk.ac.manchester.tornado.unittests.common.TornadoTestBase; + +/** + * Fused pipeline tests with full numerical verification. + * Tests progressive combinations of kernels matching the LlamaFP16FFNLayers pipeline. + * + *
How to run:
+ *
+ * tornado-test -V uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsFused
+ *
+ */
+public class TestTransformerKernelsFused extends TornadoTestBase {
+
+ // Model configuration
+ private static final int DIM = 256;
+ private static final int KV_DIM = 64;
+ private static final int HIDDEN_DIM = 512;
+ private static final int N_HEADS = 4;
+ private static final int HEAD_SIZE = DIM / N_HEADS;
+ private static final int KV_MUL = N_HEADS / (KV_DIM / HEAD_SIZE);
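+    // With DIM = 256 and N_HEADS = 4, HEAD_SIZE = 64 and KV_DIM / HEAD_SIZE = 1 KV head,
+    // so KV_MUL = 4: four query heads share each key/value head (grouped-query attention).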
+ private static final int CONTEXT_LENGTH = 128;
+ private static final float RMS_NORM_EPS = 1e-5f;
+ private static final int LOCAL_SIZE = 64;
+ private static final int LOCAL_SIZE_RMS = 256;
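+    // LOCAL_SIZE_RMS must match the local work size of the RMS worker grid, since the reduction
+    // kernel allocates its local array with this size; with DIM = 256 it yields a single workgroup.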
+
+ // Tolerances
+ private static final float EPSILON_FP32 = 1e-4f;
+ private static final float EPSILON_FP16 = 0.05f;
+ private static final float EPSILON_ACCUMULATED = 0.15f; // For multi-stage pipelines
+
+ private Random random;
+ private KernelContext context;
+
+ // State arrays
+ private FloatArray x, xb, xb2, q, k, v;
+ private FloatArray keyCache, valueCache, att, hb;
+ private FloatArray temp, tempFFN;
+ private IntArray positionHolder;
+
+ // Weights
+ private FloatArray rmsAttWeight, rmsFfnWeight;
+ private HalfFloatArray wq, wk, wv, wo, w1, w2, w3;
+
+ @Before
+ public void setUp() {
+ random = new Random(42);
+ context = new KernelContext();
+ initializeArrays();
+ initializeWeights();
+ }
+
+ private void initializeArrays() {
+ x = new FloatArray(DIM);
+ xb = new FloatArray(DIM);
+ xb2 = new FloatArray(DIM);
+ q = new FloatArray(DIM);
+ k = new FloatArray(KV_DIM);
+ v = new FloatArray(KV_DIM);
+ keyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ valueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ att = new FloatArray(N_HEADS * CONTEXT_LENGTH);
+ hb = new FloatArray(HIDDEN_DIM);
+
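+        // Layout expected by reductionOneBlockWithLayer: slot 0 holds the final scale factor,
+        // slots 1..numGroups hold the per-workgroup partial sums.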
+ int numGroups = (DIM + LOCAL_SIZE_RMS - 1) / LOCAL_SIZE_RMS;
+ temp = new FloatArray(numGroups + 1);
+ tempFFN = new FloatArray(numGroups + 1);
+ positionHolder = new IntArray(1);
+
+ fillRandom(x, -1.0f, 1.0f);
+ temp.init(0.0f);
+ tempFFN.init(0.0f);
+ positionHolder.set(0, 5);
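+        // Position 5: the attention reference and kernels loop over cached timesteps 0..5.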
+ }
+
+ private void initializeWeights() {
+ rmsAttWeight = new FloatArray(DIM);
+ rmsFfnWeight = new FloatArray(DIM);
+ fillRandom(rmsAttWeight, 0.8f, 1.2f);
+ fillRandom(rmsFfnWeight, 0.8f, 1.2f);
+
+ wq = createRandomHalfFloatArray(DIM * DIM, -0.1f, 0.1f);
+ wk = createRandomHalfFloatArray(KV_DIM * DIM, -0.1f, 0.1f);
+ wv = createRandomHalfFloatArray(KV_DIM * DIM, -0.1f, 0.1f);
+ wo = createRandomHalfFloatArray(DIM * DIM, -0.1f, 0.1f);
+ w1 = createRandomHalfFloatArray(HIDDEN_DIM * DIM, -0.1f, 0.1f);
+ w2 = createRandomHalfFloatArray(DIM * HIDDEN_DIM, -0.1f, 0.1f);
+ w3 = createRandomHalfFloatArray(HIDDEN_DIM * DIM, -0.1f, 0.1f);
+ }
+
+ // ==================== Sequential Reference Implementations ====================
+
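+    // Single-threaded counterparts of the GPU kernels, used as the ground truth in every test below.
+    // RMS norm reference: scale = 1 / sqrt(mean(x[i]^2) + eps), then output[i] = w[i] * scale * x[i].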
+ private float computeRmsNormFactorSequential(FloatArray x, float eps) {
+ float ss = 0.0f;
+ for (int i = 0; i < x.getSize(); i++) {
+ ss += x.get(i) * x.get(i);
+ }
+ ss /= x.getSize();
+ ss += eps;
+ return 1.0f / (float) Math.sqrt(ss);
+ }
+
+ private void applyRmsNormSequential(FloatArray output, FloatArray x, FloatArray weights, float normFactor) {
+ for (int i = 0; i < x.getSize(); i++) {
+ output.set(i, weights.get(i) * (normFactor * x.get(i)));
+ }
+ }
+
+ private void matrixVectorSequentialFP16(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) {
+ for (int i = 0; i < rows; i++) {
+ float sum = 0.0f;
+ for (int j = 0; j < cols; j++) {
+ sum += weights.get(i * cols + j).getFloat32() * x.get(j);
+ }
+ output.set(i, sum);
+ }
+ }
+
+ private void matrixVectorWithResidualSequential(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) {
+ for (int i = 0; i < rows; i++) {
+ float sum = 0.0f;
+ for (int j = 0; j < cols; j++) {
+ sum += weights.get(i * cols + j).getFloat32() * x.get(j);
+ }
+ output.set(i, output.get(i) + sum);
+ }
+ }
+
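+    // RoPE reference: rotates each adjacent (even, odd) pair by angle pos * freq with
+    // freq = 1 / 50000^(headDim / headSize), matching the 50000 base in GPULlama3Kernels.ropeRotation.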
+ private void ropeRotationSequential(FloatArray sq, FloatArray sk, int pos, int kvDim, int headSize) {
+ int numPairs = sq.getSize() / 2;
+ for (int i = 0; i < numPairs; i++) {
+ int idx = i * 2;
+ int headDim = idx % headSize;
+ float freq = 1.0f / (float) Math.pow(50000.0f, headDim / (float) headSize);
+ float val = pos * freq;
+ float fcr = (float) Math.cos(val);
+ float fci = (float) Math.sin(val);
+
+ float v0q = sq.get(idx);
+ float v1q = sq.get(idx + 1);
+ sq.set(idx, v0q * fcr - v1q * fci);
+ sq.set(idx + 1, v0q * fci + v1q * fcr);
+
+ if (idx < kvDim && idx + 1 < sk.getSize()) {
+ float v0k = sk.get(idx);
+ float v1k = sk.get(idx + 1);
+ sk.set(idx, v0k * fcr - v1k * fci);
+ sk.set(idx + 1, v0k * fci + v1k * fcr);
+ }
+ }
+ }
+
+ private void copyToCacheSequential(FloatArray keyCache, FloatArray key, FloatArray valueCache, FloatArray value, int position, int kvDim, int layer, int contextLength) {
+ int loff = layer * contextLength * kvDim;
+ int destOffset = loff + position * kvDim;
+ for (int i = 0; i < key.getSize(); i++) {
+ keyCache.set(destOffset + i, key.get(i));
+ valueCache.set(destOffset + i, value.get(i));
+ }
+ }
+
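+    // Scaled dot-product attention for one head: scores = (q . k_t) / sqrt(headSize) for t = 0..pos,
+    // softmax with max subtraction for numerical stability, then a weighted sum over cached values.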
+ private void processHeadSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int h, int headSize, int kvDim, int kvMul, int loff, int pos) {
+ int kvHeadIdx = h / kvMul;
+
+ float[] attScores = new float[pos + 1];
+ for (int t = 0; t <= pos; t++) {
+ int keyOffset = loff + t * kvDim + kvHeadIdx * headSize;
+ float score = 0.0f;
+ for (int i = 0; i < headSize; i++) {
+ score += q.get(h * headSize + i) * keyCache.get(keyOffset + i);
+ }
+ attScores[t] = score / (float) Math.sqrt(headSize);
+ }
+
+ float maxScore = attScores[0];
+ for (int t = 1; t <= pos; t++) {
+ if (attScores[t] > maxScore)
+ maxScore = attScores[t];
+ }
+
+ float sumExp = 0.0f;
+ for (int t = 0; t <= pos; t++) {
+ attScores[t] = (float) Math.exp(attScores[t] - maxScore);
+ sumExp += attScores[t];
+ }
+
+ for (int t = 0; t <= pos; t++) {
+ attScores[t] /= sumExp;
+ }
+
+ for (int i = 0; i < headSize; i++) {
+ float weightedSum = 0.0f;
+ for (int t = 0; t <= pos; t++) {
+ int valueOffset = loff + t * kvDim + kvHeadIdx * headSize;
+ weightedSum += attScores[t] * valueCache.get(valueOffset + i);
+ }
+ xb.set(h * headSize + i, weightedSum);
+ }
+ }
+
+ private void processHeadsSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int pos, int layer,
+ int contextLength) {
+ int loff = layer * contextLength * kvDim;
+ for (int h = 0; h < nHeads; h++) {
+ processHeadSequential(q, keyCache, valueCache, xb, h, headSize, kvDim, kvMul, loff, pos);
+ }
+ }
+
+ private float siluSequential(float x) {
+ return x * (1.0f / (1.0f + (float) Math.exp(-x)));
+ }
+
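+    // SwiGLU reference: FFN(x) = SiLU(x * W1) element-wise multiplied by (x * W3),
+    // mirroring fusedFeedForwardWithSiLUAndGLUActivation.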
+ private void fusedFFNSequential(FloatArray output, FloatArray x, HalfFloatArray w1, HalfFloatArray w3, int inputDim, int hiddenDim) {
+ for (int i = 0; i < hiddenDim; i++) {
+ float sum1 = 0.0f;
+ float sum3 = 0.0f;
+ for (int j = 0; j < inputDim; j++) {
+ sum1 += w1.get(i * inputDim + j).getFloat32() * x.get(j);
+ sum3 += w3.get(i * inputDim + j).getFloat32() * x.get(j);
+ }
+ float silu = siluSequential(sum1);
+ output.set(i, silu * sum3);
+ }
+ }
+
+ // ==================== Helper Methods ====================
+
+ private void fillRandom(FloatArray array, float min, float max) {
+ float range = max - min;
+ for (int i = 0; i < array.getSize(); i++) {
+ array.set(i, min + random.nextFloat() * range);
+ }
+ }
+
+ private FloatArray copyArray(FloatArray src) {
+ FloatArray dst = new FloatArray(src.getSize());
+ for (int i = 0; i < src.getSize(); i++) {
+ dst.set(i, src.get(i));
+ }
+ return dst;
+ }
+
+ private HalfFloatArray createRandomHalfFloatArray(int size, float min, float max) {
+ HalfFloatArray array = new HalfFloatArray(size);
+ float range = max - min;
+ for (int i = 0; i < size; i++) {
+ array.set(i, new HalfFloat(min + random.nextFloat() * range));
+ }
+ return array;
+ }
+
+ private void assertArrayEquals(String message, FloatArray expected, FloatArray actual, float tolerance) {
+ assertEquals(message + " - size mismatch", expected.getSize(), actual.getSize());
+ for (int i = 0; i < expected.getSize(); i++) {
+ assertEquals(message + " at index " + i, expected.get(i), actual.get(i), tolerance);
+ }
+ }
+
+ private GridScheduler createScheduler() {
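+        // Each matmul kernel uses one workgroup per output row: the 1D global size is rows * LOCAL_SIZE,
+        // so groupIdx selects the row and the LOCAL_SIZE threads of a group reduce one dot product.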
+ WorkerGrid rmsWorker = new WorkerGrid1D(DIM);
+ rmsWorker.setLocalWork(LOCAL_SIZE_RMS, 1, 1);
+
+ WorkerGrid qMatmulWorker = new WorkerGrid1D(DIM * LOCAL_SIZE);
+ qMatmulWorker.setLocalWork(LOCAL_SIZE, 1, 1);
+
+ WorkerGrid kvMatmulWorker = new WorkerGrid1D(KV_DIM * LOCAL_SIZE);
+ kvMatmulWorker.setLocalWork(LOCAL_SIZE, 1, 1);
+
+ WorkerGrid ropeWorker = new WorkerGrid1D(DIM / 2);
+ ropeWorker.setLocalWork(Math.min(128, DIM / 2), 1, 1);
+
+ WorkerGrid ffnWorker = new WorkerGrid1D(HIDDEN_DIM * LOCAL_SIZE);
+ ffnWorker.setLocalWork(LOCAL_SIZE, 1, 1);
+
+ GridScheduler scheduler = new GridScheduler();
+ scheduler.addWorkerGrid("s0.rmsReduce", rmsWorker);
+ scheduler.addWorkerGrid("s0.rmsApply", rmsWorker);
+ scheduler.addWorkerGrid("s0.qmatmul", qMatmulWorker);
+ scheduler.addWorkerGrid("s0.kmatmul", kvMatmulWorker);
+ scheduler.addWorkerGrid("s0.vmatmul", kvMatmulWorker);
+ scheduler.addWorkerGrid("s0.rope", ropeWorker);
+ scheduler.addWorkerGrid("s0.outputProj", qMatmulWorker);
+ scheduler.addWorkerGrid("s0.rmsReduceFFN", rmsWorker);
+ scheduler.addWorkerGrid("s0.rmsApplyFFN", rmsWorker);
+ scheduler.addWorkerGrid("s0.fusedFFN", ffnWorker);
+ scheduler.addWorkerGrid("s0.ffnProj", qMatmulWorker);
+
+ return scheduler;
+ }
+
+ // ==================== Stage 1: RMS Normalization ====================
+
+ @Test
+ public void testFusedStage1_RMSNorm() throws TornadoExecutionPlanException {
+ // Sequential reference
+ FloatArray expectedXb = new FloatArray(DIM);
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xb, temp);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertEquals("RMS norm factor", normFactor, temp.get(0), 0.01f);
+ assertArrayEquals("Stage 1: RMS Norm output", expectedXb, xb, EPSILON_FP32);
+ }
+
+ // ==================== Stage 2: RMS Norm + Q Projection ====================
+
+ @Test
+ public void testFusedStage2_RMSNorm_QMatmul() throws TornadoExecutionPlanException {
+ // Sequential reference
+ FloatArray expectedXb = new FloatArray(DIM);
+ FloatArray expectedQ = new FloatArray(DIM);
+
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, q);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("Stage 2: Q projection output", expectedQ, q, EPSILON_FP16);
+ }
+
+ // ==================== Stage 3: RMS Norm + QKV Projections ====================
+
+ @Test
+ public void testFusedStage3_RMSNorm_QKVMatmul() throws TornadoExecutionPlanException {
+ // Sequential reference
+ FloatArray expectedXb = new FloatArray(DIM);
+ FloatArray expectedQ = new FloatArray(DIM);
+ FloatArray expectedK = new FloatArray(KV_DIM);
+ FloatArray expectedV = new FloatArray(KV_DIM);
+
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM);
+ matrixVectorSequentialFP16(expectedK, wk, expectedXb, KV_DIM, DIM);
+ matrixVectorSequentialFP16(expectedV, wv, expectedXb, KV_DIM, DIM);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, q, k, v);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("Stage 3: Q output", expectedQ, q, EPSILON_FP16);
+ assertArrayEquals("Stage 3: K output", expectedK, k, EPSILON_FP16);
+ assertArrayEquals("Stage 3: V output", expectedV, v, EPSILON_FP16);
+ }
+
+ // ==================== Stage 4: QKV + RoPE ====================
+
+ @Test
+ public void testFusedStage4_QKV_RoPE() throws TornadoExecutionPlanException {
+ int position = positionHolder.get(0);
+
+ // Sequential reference
+ FloatArray expectedXb = new FloatArray(DIM);
+ FloatArray expectedQ = new FloatArray(DIM);
+ FloatArray expectedK = new FloatArray(KV_DIM);
+
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM);
+ matrixVectorSequentialFP16(expectedK, wk, expectedXb, KV_DIM, DIM);
+ ropeRotationSequential(expectedQ, expectedK, position, KV_DIM, HEAD_SIZE);
+
+ GridScheduler scheduler = createScheduler();
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0").transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv, positionHolder)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, q, k);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("Stage 4: Q after RoPE", expectedQ, q, EPSILON_FP16);
+ assertArrayEquals("Stage 4: K after RoPE", expectedK, k, EPSILON_FP16);
+ }
+
+ // ==================== Stage 5: QKV + RoPE + Cache Update ====================
+
+ @Test
+ public void testFusedStage5_QKV_RoPE_Cache() throws TornadoExecutionPlanException {
+ final int layer = 0;
+ int position = positionHolder.get(0);
+
+ // Sequential reference
+ FloatArray expectedXb = new FloatArray(DIM);
+ FloatArray expectedQ = new FloatArray(DIM);
+ FloatArray expectedK = new FloatArray(KV_DIM);
+ FloatArray expectedV = new FloatArray(KV_DIM);
+ FloatArray expectedKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray expectedValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedXb, x, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(expectedQ, wq, expectedXb, DIM, DIM);
+ matrixVectorSequentialFP16(expectedK, wk, expectedXb, KV_DIM, DIM);
+ matrixVectorSequentialFP16(expectedV, wv, expectedXb, KV_DIM, DIM);
+ ropeRotationSequential(expectedQ, expectedK, position, KV_DIM, HEAD_SIZE);
+ copyToCacheSequential(expectedKeyCache, expectedK, expectedValueCache, expectedV, position, KV_DIM, layer, CONTEXT_LENGTH);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv, positionHolder, keyCache, valueCache)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE)
+ .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, keyCache, valueCache);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ // Verify cache at the specific position
+ int destOffset = position * KV_DIM;
+ for (int i = 0; i < KV_DIM; i++) {
+ assertEquals("Stage 5: Key cache at " + i, expectedKeyCache.get(destOffset + i), keyCache.get(destOffset + i), EPSILON_FP16);
+ assertEquals("Stage 5: Value cache at " + i, expectedValueCache.get(destOffset + i), valueCache.get(destOffset + i), EPSILON_FP16);
+ }
+ }
+
+ // ==================== Stage 6: Full Attention Block ====================
+
+ @Test
+ public void testFusedStage6_FullAttentionBlock() throws TornadoExecutionPlanException {
+ final int layer = 0;
+ int position = positionHolder.get(0);
+
+ // Sequential reference
+ FloatArray seqXb = new FloatArray(DIM);
+ FloatArray seqQ = new FloatArray(DIM);
+ FloatArray seqK = new FloatArray(KV_DIM);
+ FloatArray seqV = new FloatArray(KV_DIM);
+ FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqAttOut = new FloatArray(DIM);
+ FloatArray expectedX = copyArray(x);
+
+ float normFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, expectedX, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM);
+ matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM);
+ matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM);
+ ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE);
+ copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH);
+ processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH);
+ // Copy seqAttOut to seqXb for output projection
+ for (int i = 0; i < DIM; i++)
+ seqXb.set(i, seqAttOut.get(i));
+ matrixVectorWithResidualSequential(expectedX, wo, seqXb, DIM, DIM);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, temp, wq, wk, wv, wo, positionHolder, keyCache, valueCache, att)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE)
+ .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH)
+ .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH)
+ .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, x);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("Stage 6: Full attention block output", expectedX, x, EPSILON_ACCUMULATED);
+ }
+
+ // ==================== Stage 7: Attention + FFN RMS Norm ====================
+
+ @Test
+ public void testFusedStage7_AttentionBlock_FFNRMSNorm() throws TornadoExecutionPlanException {
+ final int layer = 0;
+ int position = positionHolder.get(0);
+
+ // Sequential reference (build on Stage 6)
+ FloatArray seqXb = new FloatArray(DIM);
+ FloatArray seqQ = new FloatArray(DIM);
+ FloatArray seqK = new FloatArray(KV_DIM);
+ FloatArray seqV = new FloatArray(KV_DIM);
+ FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqAttOut = new FloatArray(DIM);
+ FloatArray seqX = copyArray(x);
+ FloatArray expectedXb = new FloatArray(DIM);
+
+ // Attention block
+ float normFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, seqX, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM);
+ matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM);
+ matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM);
+ ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE);
+ copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH);
+ processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH);
+ for (int i = 0; i < DIM; i++)
+ seqXb.set(i, seqAttOut.get(i));
+ matrixVectorWithResidualSequential(seqX, wo, seqXb, DIM, DIM);
+
+ // FFN RMS norm
+ float ffnNormFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedXb, seqX, rmsFfnWeight, ffnNormFactor);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, rmsFfnWeight, temp, tempFFN, wq, wk, wv, wo, positionHolder, keyCache, valueCache, att)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE)
+ .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH)
+ .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH)
+ .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE)
+ .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xb);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("Stage 7: FFN input (after RMS norm)", expectedXb, xb, EPSILON_ACCUMULATED);
+ }
+
+ // ==================== Stage 8: Complete Transformer Layer ====================
+
+ @Test
+ public void testFusedStage8_CompleteTransformerLayer() throws TornadoExecutionPlanException {
+ final int layer = 0;
+ int position = positionHolder.get(0);
+
+ // Full sequential reference
+ FloatArray seqXb = new FloatArray(DIM);
+ FloatArray seqQ = new FloatArray(DIM);
+ FloatArray seqK = new FloatArray(KV_DIM);
+ FloatArray seqV = new FloatArray(KV_DIM);
+ FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqAttOut = new FloatArray(DIM);
+ FloatArray seqHb = new FloatArray(HIDDEN_DIM);
+ FloatArray expectedX = copyArray(x);
+
+ // Attention block
+ float normFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, expectedX, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM);
+ matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM);
+ matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM);
+ ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE);
+ copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH);
+ processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH);
+ for (int i = 0; i < DIM; i++)
+ seqXb.set(i, seqAttOut.get(i));
+ matrixVectorWithResidualSequential(expectedX, wo, seqXb, DIM, DIM);
+
+ // FFN block
+ float ffnNormFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, expectedX, rmsFfnWeight, ffnNormFactor);
+ fusedFFNSequential(seqHb, seqXb, w1, w3, DIM, HIDDEN_DIM);
+ matrixVectorWithResidualSequential(expectedX, w2, seqHb, DIM, HIDDEN_DIM);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, rmsFfnWeight, temp, tempFFN, wq, wk, wv, wo, w1, w2, w3, positionHolder, keyCache, valueCache, att)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE)
+ .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH)
+ .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH)
+ .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE)
+ .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN)
+ .task("fusedFFN", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, xb, hb, w1, w3, DIM, HIDDEN_DIM, LOCAL_SIZE)
+ .task("ffnProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, hb, x, w2, HIDDEN_DIM, DIM, LOCAL_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, x);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("Stage 8: Complete transformer layer output", expectedX, x, EPSILON_ACCUMULATED);
+
+ }
+
+ // ==================== Multi-Iteration Test ====================
+
+ @Test
+ public void testFusedMultipleIterations() throws TornadoExecutionPlanException {
+ final int layer = 0;
+ final int numIterations = 3;
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsAttWeight, rmsFfnWeight, temp, tempFFN, wq, wk, wv, wo, w1, w2, w3, positionHolder, keyCache, valueCache, att)
+ .task("rmsReduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsAttWeight, temp)
+ .task("qmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, q, wq, DIM, DIM, LOCAL_SIZE)
+ .task("kmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, k, wk, DIM, KV_DIM, LOCAL_SIZE)
+ .task("vmatmul", GPULlama3Kernels::matrixVectorGeneric, context, xb, v, wv, DIM, KV_DIM, LOCAL_SIZE)
+ .task("rope", GPULlama3Kernels::ropeRotation, context, positionHolder, q, k, KV_DIM, HEAD_SIZE)
+ .task("copyToCache", GPULlama3Kernels::copyToCache, keyCache, k, valueCache, v, positionHolder, KV_DIM, layer, CONTEXT_LENGTH)
+ .task("attention", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, CONTEXT_LENGTH, positionHolder, att, layer, CONTEXT_LENGTH)
+ .task("outputProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, xb, x, wo, DIM, DIM, LOCAL_SIZE)
+ .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN)
+ .task("fusedFFN", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, xb, hb, w1, w3, DIM, HIDDEN_DIM, LOCAL_SIZE)
+ .task("ffnProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, hb, x, w2, HIDDEN_DIM, DIM, LOCAL_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, x);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+
+ // Track sequential state
+ FloatArray seqX = copyArray(x);
+ FloatArray seqXb = new FloatArray(DIM);
+ FloatArray seqQ = new FloatArray(DIM);
+ FloatArray seqK = new FloatArray(KV_DIM);
+ FloatArray seqV = new FloatArray(KV_DIM);
+ FloatArray seqKeyCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqValueCache = new FloatArray(CONTEXT_LENGTH * KV_DIM);
+ FloatArray seqAttOut = new FloatArray(DIM);
+ FloatArray seqHb = new FloatArray(HIDDEN_DIM);
+
+ for (int iter = 0; iter < numIterations; iter++) {
+ int position = iter;
+ positionHolder.set(0, position);
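+ // Clear the reduction scratch buffers so partial sums left over from the
+ // previous iteration cannot leak into this iteration's RMS norm factors.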
+ temp.init(0.0f);
+ tempFFN.init(0.0f);
+
+ // Sequential computation for this iteration
+ float normFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, seqX, rmsAttWeight, normFactor);
+ matrixVectorSequentialFP16(seqQ, wq, seqXb, DIM, DIM);
+ matrixVectorSequentialFP16(seqK, wk, seqXb, KV_DIM, DIM);
+ matrixVectorSequentialFP16(seqV, wv, seqXb, KV_DIM, DIM);
+ ropeRotationSequential(seqQ, seqK, position, KV_DIM, HEAD_SIZE);
+ copyToCacheSequential(seqKeyCache, seqK, seqValueCache, seqV, position, KV_DIM, layer, CONTEXT_LENGTH);
+ processHeadsSequential(seqQ, seqKeyCache, seqValueCache, seqAttOut, N_HEADS, HEAD_SIZE, KV_DIM, KV_MUL, position, layer, CONTEXT_LENGTH);
+ for (int i = 0; i < DIM; i++)
+ seqXb.set(i, seqAttOut.get(i));
+ matrixVectorWithResidualSequential(seqX, wo, seqXb, DIM, DIM);
+ float ffnNormFactor = computeRmsNormFactorSequential(seqX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, seqX, rmsFfnWeight, ffnNormFactor);
+ fusedFFNSequential(seqHb, seqXb, w1, w3, DIM, HIDDEN_DIM);
+ matrixVectorWithResidualSequential(seqX, w2, seqHb, DIM, HIDDEN_DIM);
+
+ // TornadoVM execution
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ // Verify
+ assertArrayEquals("Iteration " + iter + ": output mismatch", seqX, x, EPSILON_ACCUMULATED);
+
+ System.out.printf("Iteration %d: x[0] expected=%.6f, actual=%.6f%n", iter, seqX.get(0), x.get(0));
+ }
+ }
+
+ // ==================== FFN Block Only Test ====================
+
+ @Test
+ public void testFusedFFNBlockOnly() throws TornadoExecutionPlanException {
+ // Sequential reference
+ FloatArray seqXb = new FloatArray(DIM);
+ FloatArray seqHb = new FloatArray(HIDDEN_DIM);
+ FloatArray expectedX = copyArray(x);
+
+ float ffnNormFactor = computeRmsNormFactorSequential(expectedX, RMS_NORM_EPS);
+ applyRmsNormSequential(seqXb, expectedX, rmsFfnWeight, ffnNormFactor);
+ fusedFFNSequential(seqHb, seqXb, w1, w3, DIM, HIDDEN_DIM);
+ matrixVectorWithResidualSequential(expectedX, w2, seqHb, DIM, HIDDEN_DIM);
+
+ GridScheduler scheduler = createScheduler();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, rmsFfnWeight, tempFFN, w1, w2, w3)
+ .task("rmsReduceFFN", GPULlama3Kernels::reductionOneBlockWithLayer, context, tempFFN, x, DIM, RMS_NORM_EPS, LOCAL_SIZE_RMS)
+ .task("rmsApplyFFN", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, rmsFfnWeight, tempFFN)
+ .task("fusedFFN", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, xb, hb, w1, w3, DIM, HIDDEN_DIM, LOCAL_SIZE)
+ .task("ffnProj", GPULlama3Kernels::matrixVectorGenericWithResidual, context, hb, x, w2, HIDDEN_DIM, DIM, LOCAL_SIZE)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, x, hb);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(scheduler).execute();
+ }
+
+ assertArrayEquals("FFN block: hidden state", seqHb, hb, EPSILON_ACCUMULATED);
+ assertArrayEquals("FFN block: output", expectedX, x, EPSILON_ACCUMULATED);
+ }
+}
\ No newline at end of file
diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsUnit.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsUnit.java
new file mode 100644
index 0000000000..e09a135a37
--- /dev/null
+++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/gpullama/TestTransformerKernelsUnit.java
@@ -0,0 +1,951 @@
+/*
+ * Copyright (c) 2025, APT Group, Department of Computer Science,
+ * The University of Manchester.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+package uk.ac.manchester.tornado.unittests.gpullama;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.Random;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.WorkerGrid1D;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;
+
+/**
+ * Unit tests for individual transformer kernels with full numerical verification.
+ *
+ * How to run:
+ *
+ * tornado-test -V uk.ac.manchester.tornado.unittests.gpullama.TestTransformerKernelsUnit
+ *
+ */
+public class TestTransformerKernelsUnit extends TornadoTestBase {
+
+ private static final float EPSILON_FP32 = 1e-4f;
+ private static final float EPSILON_FP16 = 0.05f; // FP16 has lower precision
+ private static final float EPSILON_ACCUMULATED = 0.1f; // For operations with accumulated error
+ private static final float RMS_NORM_EPS = 1e-5f;
+ private static final int LOCAL_SIZE = 64;
+ private static final int LOCAL_SIZE_RMS = 256;
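+ // Note: LOCAL_SIZE_RMS is also passed to reductionOneBlockWithLayer as the
+ // local-memory size, which must match that task's work-group size.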
+
+ private Random random;
+
+ @Before
+ public void setUp() {
+ random = new Random(42);
+ }
+
+ // ==================== Sequential Reference Implementations ====================
+
+ /**
+ * Sequential RMS normalization - Phase 1: compute normalization factor
+ */
+ private float computeRmsNormFactorSequential(FloatArray x, float eps) {
+ float ss = 0.0f;
+ for (int i = 0; i < x.getSize(); i++) {
+ ss += x.get(i) * x.get(i);
+ }
+ ss /= x.getSize();
+ ss += eps;
+ return 1.0f / (float) Math.sqrt(ss);
+ }
+
+ /**
+ * Sequential RMS normalization - Phase 2: apply normalization
+ */
+ private void applyRmsNormSequential(FloatArray output, FloatArray x, FloatArray weights, float normFactor) {
+ for (int i = 0; i < x.getSize(); i++) {
+ output.set(i, weights.get(i) * (normFactor * x.get(i)));
+ }
+ }
+
+ /**
+ * Sequential matrix-vector multiplication (FP32 weights)
+ */
+ private void matrixVectorSequentialFP32(FloatArray output, FloatArray weights, FloatArray x, int rows, int cols) {
+ for (int i = 0; i < rows; i++) {
+ float sum = 0.0f;
+ for (int j = 0; j < cols; j++) {
+ sum += weights.get(i * cols + j) * x.get(j);
+ }
+ output.set(i, sum);
+ }
+ }
+
+ /**
+ * Sequential matrix-vector multiplication (FP16 weights)
+ */
+ private void matrixVectorSequentialFP16(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) {
+ for (int i = 0; i < rows; i++) {
+ float sum = 0.0f;
+ for (int j = 0; j < cols; j++) {
+ sum += weights.get(i * cols + j).getFloat32() * x.get(j);
+ }
+ output.set(i, sum);
+ }
+ }
+
+ /**
+ * Sequential matrix-vector with residual addition (FP16 weights)
+ */
+ private void matrixVectorWithResidualSequential(FloatArray output, HalfFloatArray weights, FloatArray x, int rows, int cols) {
+ for (int i = 0; i < rows; i++) {
+ float sum = 0.0f;
+ for (int j = 0; j < cols; j++) {
+ sum += weights.get(i * cols + j).getFloat32() * x.get(j);
+ }
+ output.set(i, output.get(i) + sum);
+ }
+ }
+
+ /**
+ * Sequential RoPE rotation
+ */
+ private void ropeRotationSequential(FloatArray sq, FloatArray sk, int pos, int kvDim, int headSize) {
+ int numPairs = sq.getSize() / 2;
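+ // Each (even, odd) pair is rotated as a complex number by angle pos * freq,
+ // where freq = base^(-(idx mod headSize) / headSize); the base of 50000 here
+ // is chosen to match the ropeRotation kernel under test.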
+ for (int i = 0; i < numPairs; i++) {
+ int idx = i * 2;
+ int headDim = idx % headSize;
+ float freq = 1.0f / (float) Math.pow(50000.0f, headDim / (float) headSize);
+ float val = pos * freq;
+ float fcr = (float) Math.cos(val);
+ float fci = (float) Math.sin(val);
+
+ // Rotate query
+ float v0q = sq.get(idx);
+ float v1q = sq.get(idx + 1);
+ sq.set(idx, v0q * fcr - v1q * fci);
+ sq.set(idx + 1, v0q * fci + v1q * fcr);
+
+ // Rotate key if within kvDim
+ if (idx < kvDim && idx + 1 < sk.getSize()) {
+ float v0k = sk.get(idx);
+ float v1k = sk.get(idx + 1);
+ sk.set(idx, v0k * fcr - v1k * fci);
+ sk.set(idx + 1, v0k * fci + v1k * fcr);
+ }
+ }
+ }
+
+ /**
+ * Sequential SiLU activation: silu(x) = x * sigmoid(x) = x / (1 + e^(-x))
+ */
+ private float siluSequential(float x) {
+ return x * (1.0f / (1.0f + (float) Math.exp(-x)));
+ }
+
+ /**
+ * Sequential fused FFN with SiLU and GLU (SwiGLU): out[i] = silu((w1 · x)[i]) * (w3 · x)[i]
+ */
+ private void fusedFFNSequential(FloatArray output, FloatArray x, HalfFloatArray w1, HalfFloatArray w3, int inputDim, int hiddenDim) {
+ for (int i = 0; i < hiddenDim; i++) {
+ float sum1 = 0.0f;
+ float sum3 = 0.0f;
+ for (int j = 0; j < inputDim; j++) {
+ sum1 += w1.get(i * inputDim + j).getFloat32() * x.get(j);
+ sum3 += w3.get(i * inputDim + j).getFloat32() * x.get(j);
+ }
+ float silu = siluSequential(sum1);
+ output.set(i, silu * sum3);
+ }
+ }
+
+ /**
+ * Sequential copy to KV cache
+ */
+ private void copyToCacheSequential(FloatArray keyCache, FloatArray key, FloatArray valueCache, FloatArray value, int position, int kvDim, int layer, int contextLength) {
+ int loff = layer * contextLength * kvDim;
+ int destOffset = loff + position * kvDim;
+ for (int i = 0; i < key.getSize(); i++) {
+ keyCache.set(destOffset + i, key.get(i));
+ valueCache.set(destOffset + i, value.get(i));
+ }
+ }
+
+ /**
+ * Sequential element-wise addition
+ */
+ private void addInPlaceSequential(FloatArray a, FloatArray b) {
+ for (int i = 0; i < a.getSize(); i++) {
+ a.set(i, a.get(i) + b.get(i));
+ }
+ }
+
+ /**
+ * Sequential split of the fused gate/up projection with SiLU: gate = hb[0..hiddenDim), up = hb[hiddenDim..2*hiddenDim); hbG[i] = silu(gate[i]), hbU[i] = silu(gate[i]) * up[i]
+ */
+ private void splitGateUpAndSiLUSequential(FloatArray hb, FloatArray hbG, FloatArray hbU, int hiddenDim) {
+ for (int i = 0; i < hiddenDim; i++) {
+ float gateVal = hb.get(i);
+ float upVal = hb.get(hiddenDim + i);
+ float siluGate = siluSequential(gateVal);
+ hbG.set(i, siluGate);
+ hbU.set(i, siluGate * upVal);
+ }
+ }
+
+ /**
+ * Sequential attention for a single head
+ */
+ private void processHeadSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int h, int headSize, int kvDim, int kvMul, int loff, int pos) {
+ int kvHeadIdx = h / kvMul;
+
+ // Scaled dot-product scores: (q · k) / sqrt(headSize) for each cached position up to pos
+ float[] attScores = new float[pos + 1];
+ for (int t = 0; t <= pos; t++) {
+ int keyOffset = loff + t * kvDim + kvHeadIdx * headSize;
+ float score = 0.0f;
+ for (int i = 0; i < headSize; i++) {
+ score += q.get(h * headSize + i) * keyCache.get(keyOffset + i);
+ }
+ attScores[t] = score / (float) Math.sqrt(headSize);
+ }
+
+ // Softmax over positions 0..pos (subtract the max for numerical stability)
+ float maxScore = attScores[0];
+ for (int t = 1; t <= pos; t++) {
+ if (attScores[t] > maxScore)
+ maxScore = attScores[t];
+ }
+
+ float sumExp = 0.0f;
+ for (int t = 0; t <= pos; t++) {
+ attScores[t] = (float) Math.exp(attScores[t] - maxScore);
+ sumExp += attScores[t];
+ }
+
+ for (int t = 0; t <= pos; t++) {
+ attScores[t] /= sumExp;
+ }
+
+ // Weighted sum of values
+ for (int i = 0; i < headSize; i++) {
+ float weightedSum = 0.0f;
+ for (int t = 0; t <= pos; t++) {
+ int valueOffset = loff + t * kvDim + kvHeadIdx * headSize;
+ weightedSum += attScores[t] * valueCache.get(valueOffset + i);
+ }
+ xb.set(h * headSize + i, weightedSum);
+ }
+ }
+
+ /**
+ * Sequential multi-head attention
+ */
+ private void processHeadsSequential(FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int pos, int layer,
+ int contextLength) {
+ int loff = layer * contextLength * kvDim;
+ for (int h = 0; h < nHeads; h++) {
+ processHeadSequential(q, keyCache, valueCache, xb, h, headSize, kvDim, kvMul, loff, pos);
+ }
+ }
+
+ // ==================== Helper Methods ====================
+
+ private void fillRandom(FloatArray array, float min, float max) {
+ float range = max - min;
+ for (int i = 0; i < array.getSize(); i++) {
+ array.set(i, min + random.nextFloat() * range);
+ }
+ }
+
+ private FloatArray copyArray(FloatArray src) {
+ FloatArray dst = new FloatArray(src.getSize());
+ for (int i = 0; i < src.getSize(); i++) {
+ dst.set(i, src.get(i));
+ }
+ return dst;
+ }
+
+ private HalfFloatArray createRandomHalfFloatArray(int size, float min, float max) {
+ HalfFloatArray array = new HalfFloatArray(size);
+ float range = max - min;
+ for (int i = 0; i < size; i++) {
+ array.set(i, new HalfFloat(min + random.nextFloat() * range));
+ }
+ return array;
+ }
+
+ private void assertArrayEquals(String message, FloatArray expected, FloatArray actual, float tolerance) {
+ assertEquals(message + " - size mismatch", expected.getSize(), actual.getSize());
+ for (int i = 0; i < expected.getSize(); i++) {
+ assertEquals(message + " at index " + i, expected.get(i), actual.get(i), tolerance);
+ }
+ }
+
+ // ==================== Unit Tests ====================
+
+ @Test
+ public void testReductionOneBlockWithLayer() throws TornadoExecutionPlanException {
+ final int dim = 512;
+ final int localSize = 256;
+ final int numGroups = (dim + localSize - 1) / localSize;
+
+ FloatArray x = new FloatArray(dim);
+ FloatArray output = new FloatArray(numGroups + 1);
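+ // Output layout (see reductionOneBlockWithLayer): index 0 receives the final
+ // normalization factor; indices 1..numGroups hold the per-workgroup partial sums.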
+ fillRandom(x, -1.0f, 1.0f);
+
+ // Sequential reference
+ float expectedNormFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+
+ WorkerGrid worker = new WorkerGrid1D(dim);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, output)
+ .task("t0", GPULlama3Kernels::reductionOneBlockWithLayer, context, output, x, dim, RMS_NORM_EPS, localSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, output);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertEquals("RMS norm factor", expectedNormFactor, output.get(0), 0.01f);
+ }
+
+ @Test
+ public void testReductionOneBlock2WithLayer() throws TornadoExecutionPlanException {
+ final int dim = 512;
+ final int localSize = 256;
+ final int numGroups = (dim + localSize - 1) / localSize;
+
+ FloatArray x = new FloatArray(dim);
+ FloatArray xb = new FloatArray(dim);
+ FloatArray weights = new FloatArray(dim);
+ FloatArray temp = new FloatArray(numGroups + 1);
+ FloatArray expectedOutput = new FloatArray(dim);
+
+ fillRandom(x, -1.0f, 1.0f);
+ fillRandom(weights, 0.5f, 1.5f);
+
+ // Sequential reference
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ temp.set(0, normFactor);
+ applyRmsNormSequential(expectedOutput, x, weights, normFactor);
+
+ WorkerGrid worker = new WorkerGrid1D(dim);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights, temp)
+ .task("t0", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, weights, temp)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xb);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("RMS norm application", expectedOutput, xb, EPSILON_FP32);
+ }
+
+ @Test
+ public void testFullRMSNormalization() throws TornadoExecutionPlanException {
+ final int dim = 512;
+ final int localSize = 256;
+ final int numGroups = (dim + localSize - 1) / localSize;
+
+ FloatArray x = new FloatArray(dim);
+ FloatArray xb = new FloatArray(dim);
+ FloatArray weights = new FloatArray(dim);
+ FloatArray temp = new FloatArray(numGroups + 1);
+ FloatArray expectedOutput = new FloatArray(dim);
+
+ fillRandom(x, -1.0f, 1.0f);
+ fillRandom(weights, 0.5f, 1.5f);
+ temp.init(0.0f);
+
+ // Sequential reference: full RMS norm
+ float normFactor = computeRmsNormFactorSequential(x, RMS_NORM_EPS);
+ applyRmsNormSequential(expectedOutput, x, weights, normFactor);
+
+ WorkerGrid worker = new WorkerGrid1D(dim);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler();
+ gridScheduler.addWorkerGrid("s0.reduce", worker);
+ gridScheduler.addWorkerGrid("s0.apply", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0").transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights, temp)
+ .task("reduce", GPULlama3Kernels::reductionOneBlockWithLayer, context, temp, x, dim, RMS_NORM_EPS, localSize)
+ .task("apply", GPULlama3Kernels::reductionOneBlock2WithLayer, context, xb, x, weights, temp)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xb, temp);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertEquals("RMS norm factor", normFactor, temp.get(0), 0.01f);
+ assertArrayEquals("Full RMS normalization", expectedOutput, xb, EPSILON_FP32);
+ }
+
+ @Test
+ public void testMatrixVectorGenericFP32() throws TornadoExecutionPlanException {
+ final int inputDim = 256;
+ final int outputDim = 128;
+ final int localSize = 64;
+
+ FloatArray x = new FloatArray(inputDim);
+ FloatArray weights = new FloatArray(outputDim * inputDim);
+ FloatArray output = new FloatArray(outputDim);
+ FloatArray expectedOutput = new FloatArray(outputDim);
+
+ fillRandom(x, -1.0f, 1.0f);
+ fillRandom(weights, -0.1f, 0.1f);
+
+ // Sequential reference
+ matrixVectorSequentialFP32(expectedOutput, weights, x, outputDim, inputDim);
+
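+ // Grid sizing assumes the kernel assigns one workgroup of localSize threads
+ // per output row, hence a global size of outputDim * localSize.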
+ WorkerGrid worker = new WorkerGrid1D(outputDim * localSize);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights)
+ .task("t0", GPULlama3Kernels::matrixVectorGeneric, context, x, output, weights, inputDim, outputDim, localSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, output);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("MatVec FP32", expectedOutput, output, EPSILON_FP32);
+ }
+
+ @Test
+ public void testMatrixVectorGenericFP16() throws TornadoExecutionPlanException {
+ final int inputDim = 256;
+ final int outputDim = 128;
+ final int localSize = 64;
+
+ FloatArray x = new FloatArray(inputDim);
+ HalfFloatArray weights = createRandomHalfFloatArray(outputDim * inputDim, -0.1f, 0.1f);
+ FloatArray output = new FloatArray(outputDim);
+ FloatArray expectedOutput = new FloatArray(outputDim);
+
+ fillRandom(x, -1.0f, 1.0f);
+
+ // Sequential reference
+ matrixVectorSequentialFP16(expectedOutput, weights, x, outputDim, inputDim);
+
+ WorkerGrid worker = new WorkerGrid1D(outputDim * localSize);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, weights)
+ .task("t0", GPULlama3Kernels::matrixVectorGeneric, context, x, output, weights, inputDim, outputDim, localSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, output);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("MatVec FP16", expectedOutput, output, EPSILON_FP16);
+ }
+
+ @Test
+ public void testMatrixVectorGenericWithResidual() throws TornadoExecutionPlanException {
+ final int inputDim = 256;
+ final int outputDim = 128;
+ final int localSize = 64;
+
+ FloatArray x = new FloatArray(inputDim);
+ FloatArray residual = new FloatArray(outputDim);
+ FloatArray expectedResidual = new FloatArray(outputDim);
+ HalfFloatArray weights = createRandomHalfFloatArray(outputDim * inputDim, -0.1f, 0.1f);
+
+ fillRandom(x, -1.0f, 1.0f);
+ fillRandom(residual, -0.5f, 0.5f);
+
+ // Copy residual for sequential computation
+ for (int i = 0; i < outputDim; i++) {
+ expectedResidual.set(i, residual.get(i));
+ }
+ matrixVectorWithResidualSequential(expectedResidual, weights, x, outputDim, inputDim);
+
+ WorkerGrid worker = new WorkerGrid1D(outputDim * localSize);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, residual, weights)
+ .task("t0", GPULlama3Kernels::matrixVectorGenericWithResidual, context, x, residual, weights, inputDim, outputDim, localSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, residual);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("MatVec with residual", expectedResidual, residual, EPSILON_FP16);
+ }
+
+ @Test
+ public void testRopeRotation() throws TornadoExecutionPlanException {
+ final int dim = 128;
+ final int kvDim = 64;
+ final int headSize = 32;
+ final int position = 5;
+
+ FloatArray sq = new FloatArray(dim);
+ FloatArray sk = new FloatArray(kvDim);
+ FloatArray expectedSq = new FloatArray(dim);
+ FloatArray expectedSk = new FloatArray(kvDim);
+ IntArray positionHolder = new IntArray(1);
+
+ fillRandom(sq, -1.0f, 1.0f);
+ fillRandom(sk, -1.0f, 1.0f);
+ positionHolder.set(0, position);
+
+ // Copy for sequential computation
+ for (int i = 0; i < dim; i++)
+ expectedSq.set(i, sq.get(i));
+ for (int i = 0; i < kvDim; i++)
+ expectedSk.set(i, sk.get(i));
+ ropeRotationSequential(expectedSq, expectedSk, position, kvDim, headSize);
+
+ WorkerGrid worker = new WorkerGrid1D(dim / 2);
+ worker.setLocalWork(Math.min(128, dim / 2), 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, positionHolder, sq, sk)
+ .task("t0", GPULlama3Kernels::ropeRotation, context, positionHolder, sq, sk, kvDim, headSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, sq, sk);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("RoPE query rotation", expectedSq, sq, EPSILON_FP32);
+ assertArrayEquals("RoPE key rotation", expectedSk, sk, EPSILON_FP32);
+ }
+
+ @Test
+ public void testRopeRotationMultiplePositions() throws TornadoExecutionPlanException {
+ final int dim = 128;
+ final int kvDim = 64;
+ final int headSize = 32;
+
+ for (int position : new int[] { 0, 1, 10, 50, 100 }) {
+ FloatArray sq = new FloatArray(dim);
+ FloatArray sk = new FloatArray(kvDim);
+ FloatArray expectedSq = new FloatArray(dim);
+ FloatArray expectedSk = new FloatArray(kvDim);
+ IntArray positionHolder = new IntArray(1);
+
+ fillRandom(sq, -1.0f, 1.0f);
+ fillRandom(sk, -1.0f, 1.0f);
+ positionHolder.set(0, position);
+
+ for (int i = 0; i < dim; i++)
+ expectedSq.set(i, sq.get(i));
+ for (int i = 0; i < kvDim; i++)
+ expectedSk.set(i, sk.get(i));
+ ropeRotationSequential(expectedSq, expectedSk, position, kvDim, headSize);
+
+ WorkerGrid worker = new WorkerGrid1D(dim / 2);
+ worker.setLocalWork(Math.min(128, dim / 2), 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, positionHolder, sq, sk)
+ .task("t0", GPULlama3Kernels::ropeRotation, context, positionHolder, sq, sk, kvDim, headSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, sq, sk);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("RoPE Q at position " + position, expectedSq, sq, EPSILON_FP32);
+ assertArrayEquals("RoPE K at position " + position, expectedSk, sk, EPSILON_FP32);
+ }
+ }
+
+ @Test
+ public void testCopyToCache() throws TornadoExecutionPlanException {
+ final int kvDim = 64;
+ final int contextLength = 128;
+ final int numLayers = 4;
+ final int layer = 2;
+ final int position = 10;
+
+ FloatArray key = new FloatArray(kvDim);
+ FloatArray value = new FloatArray(kvDim);
+ FloatArray keyCache = new FloatArray(numLayers * contextLength * kvDim);
+ FloatArray valueCache = new FloatArray(numLayers * contextLength * kvDim);
+ FloatArray expectedKeyCache = new FloatArray(numLayers * contextLength * kvDim);
+ FloatArray expectedValueCache = new FloatArray(numLayers * contextLength * kvDim);
+ IntArray positionHolder = new IntArray(1);
+
+ fillRandom(key, -1.0f, 1.0f);
+ fillRandom(value, -1.0f, 1.0f);
+ positionHolder.set(0, position);
+
+ // Sequential reference
+ copyToCacheSequential(expectedKeyCache, key, expectedValueCache, value, position, kvDim, layer, contextLength);
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, key, value, keyCache, valueCache, positionHolder)
+ .task("t0", GPULlama3Kernels::copyToCache, keyCache, key, valueCache, value, positionHolder, kvDim, layer, contextLength)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, keyCache, valueCache);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.execute();
+ }
+
+ // Verify the copied slice of the cache at the target layer/position offset
+ int loff = layer * contextLength * kvDim;
+ int destOffset = loff + position * kvDim;
+ for (int i = 0; i < kvDim; i++) {
+ assertEquals("Key cache at " + i, key.get(i), keyCache.get(destOffset + i), EPSILON_FP32);
+ assertEquals("Value cache at " + i, value.get(i), valueCache.get(destOffset + i), EPSILON_FP32);
+ }
+ }
+
+ @Test
+ public void testFusedFeedForwardWithSiLUAndGLUActivation() throws TornadoExecutionPlanException {
+ final int inputDim = 128;
+ final int hiddenDim = 64;
+ final int localSize = 32;
+
+ FloatArray x = new FloatArray(inputDim);
+ HalfFloatArray w1 = createRandomHalfFloatArray(hiddenDim * inputDim, -0.1f, 0.1f);
+ HalfFloatArray w3 = createRandomHalfFloatArray(hiddenDim * inputDim, -0.1f, 0.1f);
+ FloatArray output = new FloatArray(hiddenDim);
+ FloatArray expectedOutput = new FloatArray(hiddenDim);
+
+ fillRandom(x, -1.0f, 1.0f);
+
+ // Sequential reference
+ fusedFFNSequential(expectedOutput, x, w1, w3, inputDim, hiddenDim);
+
+ WorkerGrid worker = new WorkerGrid1D(hiddenDim * localSize);
+ worker.setLocalWork(localSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, x, w1, w3)
+ .task("t0", GPULlama3Kernels::fusedFeedForwardWithSiLUAndGLUActivation, context, x, output, w1, w3, inputDim, hiddenDim, localSize)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, output);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("Fused FFN", expectedOutput, output, EPSILON_ACCUMULATED);
+ }
+
+ @Test
+ public void testAddInPlace() throws TornadoExecutionPlanException {
+ final int size = 512;
+
+ FloatArray a = new FloatArray(size);
+ FloatArray b = new FloatArray(size);
+ FloatArray expectedA = new FloatArray(size);
+
+ fillRandom(a, -1.0f, 1.0f);
+ fillRandom(b, -1.0f, 1.0f);
+
+ // Copy for sequential
+ for (int i = 0; i < size; i++) {
+ expectedA.set(i, a.get(i));
+ }
+ addInPlaceSequential(expectedA, b);
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, a, b)
+ .task("t0", GPULlama3Kernels::addInPlace, a, b, size)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, a);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.execute();
+ }
+
+ assertArrayEquals("Add in place", expectedA, a, EPSILON_FP32);
+ }
+
+ @Test
+ public void testSplitQKV() throws TornadoExecutionPlanException {
+ final int dimQ = 256;
+ final int dimKV = 64;
+ final int totalSize = dimQ + 2 * dimKV;
+
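+ // Fused QKV layout in the source buffer: [0, dimQ) = Q, [dimQ, dimQ + dimKV) = K,
+ // [dimQ + dimKV, totalSize) = V (verified element-wise below).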
+ FloatArray qkv = new FloatArray(totalSize);
+ FloatArray q = new FloatArray(dimQ);
+ FloatArray k = new FloatArray(dimKV);
+ FloatArray v = new FloatArray(dimKV);
+
+ fillRandom(qkv, -1.0f, 1.0f);
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, qkv)
+ .task("t0", GPULlama3Kernels::splitQKV, qkv, q, k, v, dimQ, dimKV)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, q, k, v);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.execute();
+ }
+
+ // Verify exact split
+ for (int i = 0; i < dimQ; i++) {
+ assertEquals("Q[" + i + "]", qkv.get(i), q.get(i), EPSILON_FP32);
+ }
+ for (int i = 0; i < dimKV; i++) {
+ assertEquals("K[" + i + "]", qkv.get(dimQ + i), k.get(i), EPSILON_FP32);
+ assertEquals("V[" + i + "]", qkv.get(dimQ + dimKV + i), v.get(i), EPSILON_FP32);
+ }
+ }
+
+ @Test
+ public void testSplitGateUpAndSiLU() throws TornadoExecutionPlanException {
+ final int hiddenDim = 256;
+
+ FloatArray hb = new FloatArray(hiddenDim * 2);
+ FloatArray hbG = new FloatArray(hiddenDim);
+ FloatArray hbU = new FloatArray(hiddenDim);
+ FloatArray expectedHbG = new FloatArray(hiddenDim);
+ FloatArray expectedHbU = new FloatArray(hiddenDim);
+
+ fillRandom(hb, -2.0f, 2.0f);
+
+ // Sequential reference
+ splitGateUpAndSiLUSequential(hb, expectedHbG, expectedHbU, hiddenDim);
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, hb)
+ .task("t0", GPULlama3Kernels::splitGateUpAndSiLU, hb, hbG, hbU, hiddenDim)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, hbG, hbU);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.execute();
+ }
+
+ assertArrayEquals("Gate with SiLU", expectedHbG, hbG, EPSILON_FP32);
+ assertArrayEquals("Up * Gate", expectedHbU, hbU, EPSILON_FP32);
+ }
+
+ @Test
+ public void testProcessHeadsParallel() throws TornadoExecutionPlanException {
+ final int nHeads = 4;
+ final int headSize = 32;
+ final int kvDim = nHeads * headSize;
+ final int dim = nHeads * headSize;
+ final int contextLength = 64;
+ final int position = 3;
+ final int layer = 0;
+ final int kvMul = 1;
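+ // kvMul = nHeads / nKvHeads; 1 means plain multi-head attention (no grouped-query sharing).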
+
+ FloatArray q = new FloatArray(dim);
+ FloatArray keyCache = new FloatArray(contextLength * kvDim);
+ FloatArray valueCache = new FloatArray(contextLength * kvDim);
+ FloatArray xb = new FloatArray(dim);
+ FloatArray expectedXb = new FloatArray(dim);
+ FloatArray att = new FloatArray(nHeads * contextLength);
+ IntArray positionHolder = new IntArray(1);
+
+ fillRandom(q, -1.0f, 1.0f);
+ fillRandom(keyCache, -1.0f, 1.0f);
+ fillRandom(valueCache, -1.0f, 1.0f);
+ positionHolder.set(0, position);
+
+ // Sequential reference
+ processHeadsSequential(q, keyCache, valueCache, expectedXb, nHeads, headSize, kvDim, kvMul, position, layer, contextLength);
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder, att)
+ .task("t0", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xb, nHeads, headSize, kvDim, kvMul, contextLength, positionHolder, att, layer, contextLength)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xb);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.execute();
+ }
+
+ assertArrayEquals("Parallel attention", expectedXb, xb, EPSILON_FP32);
+ }
+
+ @Test
+ public void testProcessHeadsFlashAttention() throws TornadoExecutionPlanException {
+ final int nHeads = 4;
+ final int headSize = 32;
+ final int kvDim = nHeads * headSize;
+ final int dim = nHeads * headSize;
+ final int contextLength = 64;
+ final int position = 3;
+ final int layer = 0;
+ final int kvMul = 1;
+
+ FloatArray q = new FloatArray(dim);
+ FloatArray keyCache = new FloatArray(contextLength * kvDim);
+ FloatArray valueCache = new FloatArray(contextLength * kvDim);
+ FloatArray xb = new FloatArray(dim);
+ FloatArray expectedXb = new FloatArray(dim);
+ IntArray positionHolder = new IntArray(1);
+
+ fillRandom(q, -1.0f, 1.0f);
+ fillRandom(keyCache, -1.0f, 1.0f);
+ fillRandom(valueCache, -1.0f, 1.0f);
+ positionHolder.set(0, position);
+
+ // Sequential reference (same as processHeadsParallel)
+ processHeadsSequential(q, keyCache, valueCache, expectedXb, nHeads, headSize, kvDim, kvMul, position, layer, contextLength);
+
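+ // Launch nHeads workgroups of headSize threads each, presumably one workgroup
+ // per attention head.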
+ WorkerGrid worker = new WorkerGrid1D(nHeads * headSize);
+ worker.setLocalWork(headSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraph = new TaskGraph("s0")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder)
+ .task("t0", GPULlama3Kernels::processHeadsFlashAttention, context, q, keyCache, valueCache, xb, nHeads, headSize, kvDim, kvMul, positionHolder, layer, contextLength)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xb);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ assertArrayEquals("Flash attention", expectedXb, xb, EPSILON_FP32);
+ }
+
+ @Test
+ public void testAttentionConsistency() throws TornadoExecutionPlanException {
+ // Test that both attention implementations produce the same result
+ final int nHeads = 4;
+ final int headSize = 32;
+ final int kvDim = nHeads * headSize;
+ final int dim = nHeads * headSize;
+ final int contextLength = 64;
+ final int position = 5;
+ final int layer = 0;
+ final int kvMul = 1;
+
+ FloatArray q = new FloatArray(dim);
+ FloatArray keyCache = new FloatArray(contextLength * kvDim);
+ FloatArray valueCache = new FloatArray(contextLength * kvDim);
+ FloatArray xbParallel = new FloatArray(dim);
+ FloatArray xbFlash = new FloatArray(dim);
+ FloatArray att = new FloatArray(nHeads * contextLength);
+ IntArray positionHolder = new IntArray(1);
+
+ fillRandom(q, -1.0f, 1.0f);
+ fillRandom(keyCache, -1.0f, 1.0f);
+ fillRandom(valueCache, -1.0f, 1.0f);
+ positionHolder.set(0, position);
+
+ // Test parallel attention
+ // @formatter:off
+ TaskGraph taskGraphParallel = new TaskGraph("s_parallel")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder, att)
+ .task("t0", GPULlama3Kernels::processHeadsParallel, q, keyCache, valueCache, xbParallel, nHeads, headSize, kvDim, kvMul, contextLength, positionHolder, att, layer, contextLength)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xbParallel);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraphParallel = taskGraphParallel.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraphParallel)) {
+ executionPlan.execute();
+ }
+
+ // Test flash attention
+ WorkerGrid worker = new WorkerGrid1D(nHeads * headSize);
+ worker.setLocalWork(headSize, 1, 1);
+ GridScheduler gridScheduler = new GridScheduler("s_flash.t0", worker);
+ KernelContext context = new KernelContext();
+
+ // @formatter:off
+ TaskGraph taskGraphFlash = new TaskGraph("s_flash")
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder)
+ .task("t0", GPULlama3Kernels::processHeadsFlashAttention, context, q, keyCache, valueCache, xbFlash, nHeads, headSize, kvDim, kvMul, positionHolder, layer, contextLength)
+ .transferToHost(DataTransferMode.EVERY_EXECUTION, xbFlash);
+ // @formatter:on
+
+ ImmutableTaskGraph immutableTaskGraphFlash = taskGraphFlash.snapshot();
+ try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraphFlash)) {
+ executionPlan.withGridScheduler(gridScheduler).execute();
+ }
+
+ // Both implementations should match
+ assertArrayEquals("Parallel vs Flash attention consistency", xbParallel, xbFlash, EPSILON_FP32);
+ }
+}
\ No newline at end of file