|
| 1 | +# Stencil Operations |
| 2 | + |
| 3 | +The `@stencil` macro in Dagger.jl provides a convenient way to perform stencil computations on `DArray`s. It operates within a `Dagger.spawn_datadeps()` block and allows you to define operations that apply to each element of a `DArray`, potentially accessing values from each element's neighbors. |
| 4 | + |
| 5 | +## Basic Usage |
| 6 | + |
| 7 | +The fundamental structure of a `@stencil` block involves iterating over an implicit index, named `idx` in the following example , which represents the coordinates of an element in the processed `DArray`s. |
| 8 | + |
| 9 | +```julia |
| 10 | +using Dagger |
| 11 | +import Dagger: @stencil, Wrap, Pad |
| 12 | + |
| 13 | +# Initialize a DArray |
| 14 | +A = zeros(Blocks(2, 2), Int, 4, 4) |
| 15 | + |
| 16 | +Dagger.spawn_datadeps() do |
| 17 | + @stencil begin |
| 18 | + A[idx] = 1 # Assign 1 to every element of A |
| 19 | + end |
| 20 | +end |
| 21 | + |
| 22 | +@assert all(collect(A) .== 1) |
| 23 | +``` |
| 24 | + |
| 25 | +In this example, `A[idx] = 1` is executed for each chunk of `A`. The `idx` variable corresponds to the indices within each chunk. |
| 26 | + |
| 27 | +## Neighborhood Access with `@neighbors` |
| 28 | + |
| 29 | +The true power of stencils comes from accessing neighboring elements. The `@neighbors` macro facilitates this. |
| 30 | + |
| 31 | +`@neighbors(array[idx], distance, boundary_condition)` |
| 32 | + |
| 33 | +- `array[idx]`: The array and current index from which to find neighbors. |
| 34 | +- `distance`: An integer specifying the extent of the neighborhood (e.g., `1` for a 3x3 neighborhood in 2D). |
| 35 | +- `boundary_condition`: Defines how to handle accesses beyond the array boundaries. Available conditions are: |
| 36 | + - `Wrap()`: Wraps around to the other side of the array. |
| 37 | + - `Pad(value)`: Pads with a specified `value`. |
| 38 | + |
| 39 | +### Example: Averaging Neighbors with `Wrap` |
| 40 | + |
| 41 | +```julia |
| 42 | +import Dagger: Wrap |
| 43 | + |
| 44 | +# Initialize a DArray |
| 45 | +A = ones(Blocks(1, 1), Int, 3, 3) |
| 46 | +A[2,2] = 10 # Central element has a different value |
| 47 | +B = zeros(Blocks(1, 1), Float64, 3, 3) |
| 48 | + |
| 49 | +Dagger.spawn_datadeps() do |
| 50 | + @stencil begin |
| 51 | + # Calculate the average of the 3x3 neighborhood (including the center) |
| 52 | + B[idx] = sum(@neighbors(A[idx], 1, Wrap())) / 9.0 |
| 53 | + end |
| 54 | +end |
| 55 | + |
| 56 | +# Manually calculate expected B for verification |
| 57 | +expected_B = zeros(Float64, 3, 3) |
| 58 | +A_collected = collect(A) |
| 59 | +for r in 1:3, c in 1:3 |
| 60 | + local_sum = 0.0 |
| 61 | + for dr in -1:1, dc in -1:1 |
| 62 | + nr, nc = mod1(r+dr, 3), mod1(c+dc, 3) |
| 63 | + local_sum += A_collected[nr, nc] |
| 64 | + end |
| 65 | + expected_B[r,c] = local_sum / 9.0 |
| 66 | +end |
| 67 | + |
| 68 | +@assert collect(B) ≈ expected_B |
| 69 | +``` |
| 70 | + |
| 71 | +### Example: Convolution with `Pad` |
| 72 | + |
| 73 | +```julia |
| 74 | +import Pad |
| 75 | + |
| 76 | +# Initialize a DArray |
| 77 | +A = ones(Blocks(2, 2), Int, 4, 4) |
| 78 | +B = zeros(Blocks(2, 2), Int, 4, 4) |
| 79 | + |
| 80 | +Dagger.spawn_datadeps() do |
| 81 | + @stencil begin |
| 82 | + B[idx] = sum(@neighbors(A[idx], 1, Pad(0))) # Pad with 0 |
| 83 | + end |
| 84 | +end |
| 85 | + |
| 86 | +# Expected result for a 3x3 sum filter with zero padding |
| 87 | +expected_B_padded = [ |
| 88 | + 4 6 6 4; |
| 89 | + 6 9 9 6; |
| 90 | + 6 9 9 6; |
| 91 | + 4 6 6 4 |
| 92 | +] |
| 93 | +@assert collect(B) == expected_B_padded |
| 94 | +``` |
| 95 | + |
| 96 | +## Sequential Semantics |
| 97 | + |
| 98 | +Expressions within a `@stencil` block are executed sequentially in terms of their effect on the data. This means that the result of one statement is visible to the subsequent statements, as if they were applied "all at once" across all indices before the next statement begins. |
| 99 | + |
| 100 | +```julia |
| 101 | +A = zeros(Blocks(2, 2), Int, 4, 4) |
| 102 | +B = zeros(Blocks(2, 2), Int, 4, 4) |
| 103 | + |
| 104 | +Dagger.spawn_datadeps() do |
| 105 | + @stencil begin |
| 106 | + A[idx] = 1 # First, A is initialized |
| 107 | + B[idx] = A[idx] * 2 # Then, B is computed using the new values of A |
| 108 | + end |
| 109 | +end |
| 110 | + |
| 111 | +expected_A = [1 for r in 1:4, c in 1:4] |
| 112 | +expected_B_seq = expected_A .* 2 |
| 113 | + |
| 114 | +@assert collect(A) == expected_A |
| 115 | +@assert collect(B) == expected_B_seq |
| 116 | +``` |
| 117 | + |
| 118 | +## Operations on Multiple `DArray`s |
| 119 | + |
| 120 | +You can read from and write to multiple `DArray`s within a single `@stencil` block, provided they have compatible chunk structures. |
| 121 | + |
| 122 | +```julia |
| 123 | +A = ones(Blocks(1, 1), Int, 2, 2) |
| 124 | +B = DArray(fill(3, 2, 2), Blocks(1, 1)) |
| 125 | +C = zeros(Blocks(1, 1), Int, 2, 2) |
| 126 | + |
| 127 | +Dagger.spawn_datadeps() do |
| 128 | + @stencil begin |
| 129 | + C[idx] = A[idx] + B[idx] |
| 130 | + end |
| 131 | +end |
| 132 | +@assert all(collect(C) .== 4) |
| 133 | +``` |
| 134 | + |
| 135 | +## Example: Game of Life |
| 136 | + |
| 137 | +The following demonstrates a more complex example: Conway's Game of Life. |
| 138 | + |
| 139 | +```julia |
| 140 | +# Ensure Plots and other necessary packages are available for the example |
| 141 | +using Plots |
| 142 | + |
| 143 | +N = 27 # Size of one dimension of a tile |
| 144 | +nt = 3 # Number of tiles in each dimension (results in nt x nt grid of tiles) |
| 145 | +niters = 10 # Number of iterations for the animation |
| 146 | + |
| 147 | +tiles = zeros(Blocks(N, N), Bool, N*nt, N*nt) |
| 148 | +outputs = zeros(Blocks(N, N), Bool, N*nt, N*nt) |
| 149 | + |
| 150 | +# Create a fun initial state (e.g., a glider and some random noise) |
| 151 | +tiles[13, 14] = true |
| 152 | +tiles[14, 14] = true |
| 153 | +tiles[15, 14] = true |
| 154 | +tiles[15, 15] = true |
| 155 | +tiles[14, 16] = true |
| 156 | +# Add some random noise in one of the tiles |
| 157 | +@view(tiles[(2N+1):3N, (2N+1):3N]) .= rand(Bool, N, N) |
| 158 | + |
| 159 | + |
| 160 | + |
| 161 | +anim = @animate for _ in 1:niters |
| 162 | + Dagger.spawn_datadeps() do |
| 163 | + @stencil begin |
| 164 | + outputs[idx] = begin |
| 165 | + nhood = @neighbors(tiles[idx], 1, Wrap()) |
| 166 | + neighs = sum(nhood) - tiles[idx] # Sum neighborhood, but subtract own value |
| 167 | + if tiles[idx] && neighs < 2 |
| 168 | + 0 # Dies of underpopulation |
| 169 | + elseif tiles[idx] && neighs > 3 |
| 170 | + 0 # Dies of overpopulation |
| 171 | + elseif !tiles[idx] && neighs == 3 |
| 172 | + 1 # Becomes alive by reproduction |
| 173 | + else |
| 174 | + tiles[idx] # Keeps its prior value |
| 175 | + end |
| 176 | + end |
| 177 | + tiles[idx] = outputs[idx] # Update tiles for the next iteration |
| 178 | + end |
| 179 | + end |
| 180 | + heatmap(Int.(collect(outputs))) # Generate a heatmap visualization |
| 181 | +end |
| 182 | +path = mp4(anim; fps=5, show_msg=true).filename # Create an animation of the heatmaps over time |
| 183 | +``` |
0 commit comments