
Commit 71cfdac

stencils: Add tests and documentation (#619)
Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent 386dfd8 commit 71cfdac

5 files changed: 337 additions, 43 deletions

docs/src/index.md

Lines changed: 33 additions & 0 deletions
@@ -361,6 +361,39 @@ DA = rand(Blocks(32, 32), 256, 128)
361361
collect(DA) # returns a `Matrix{Float64}`
362362
```
363363

364+
-----
365+
366+
## Quickstart: Stencil Operations
367+
368+
Dagger's `@stencil` macro allows for easy specification of stencil operations on `DArray`s, often used in simulations and image processing. These operations typically involve updating an element based on the values of its neighbors.
369+
370+
For more details: [Stencil Operations](@ref)
371+
372+
### Applying a Simple Stencil
373+
374+
Here's how to apply a stencil that averages each element with its immediate neighbors, using a `Wrap` boundary condition (where neighbor accesses past the array edges wrap around to the opposite side).
375+
376+
```julia
377+
using Dagger
378+
import Dagger: @stencil, Wrap
379+
380+
# Create a 5x5 DArray, partitioned into 2x2 blocks
381+
A = rand(Blocks(2, 2), 5, 5)
382+
B = zeros(Blocks(2,2), 5, 5)
383+
384+
Dagger.spawn_datadeps() do
385+
@stencil begin
386+
# For each element in A, calculate the sum of its 3x3 neighborhood
387+
# (including itself) and store the average in B.
388+
# Values outside the array bounds are determined by Wrap().
389+
B[idx] = sum(@neighbors(A[idx], 1, Wrap())) / 9.0
390+
end
391+
end
392+
393+
# B now contains the averaged values.
394+
```
395+
In this example, `idx` refers to the coordinates of each element being processed. `@neighbors(A[idx], 1, Wrap())` fetches the 3x3 neighborhood around `A[idx]`. The `1` indicates a neighborhood distance of 1 from the central element, and `Wrap()` specifies the boundary behavior.
396+
364397
## Quickstart: Datadeps
365398

366399
Datadeps is a feature in Dagger.jl that facilitates parallelism control within designated regions, allowing tasks to write to their arguments while ensuring dependencies are respected.

docs/src/stencils.jl

Lines changed: 0 additions & 43 deletions
This file was deleted.

docs/src/stencils.md

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Stencil Operations
2+
3+
The `@stencil` macro in Dagger.jl provides a convenient way to perform stencil computations on `DArray`s. It operates within a `Dagger.spawn_datadeps()` block and allows you to define operations that apply to each element of a `DArray`, potentially accessing values from each element's neighbors.
4+
5+
## Basic Usage
6+
7+
The fundamental structure of a `@stencil` block involves iterating over an implicit index, named `idx` in the following example, which represents the coordinates of an element in the processed `DArray`s.
8+
9+
```julia
10+
using Dagger
11+
import Dagger: @stencil, Wrap, Pad
12+
13+
# Initialize a DArray
14+
A = zeros(Blocks(2, 2), Int, 4, 4)
15+
16+
Dagger.spawn_datadeps() do
17+
@stencil begin
18+
A[idx] = 1 # Assign 1 to every element of A
19+
end
20+
end
21+
22+
@assert all(collect(A) .== 1)
23+
```
24+
25+
In this example, `A[idx] = 1` is executed for each chunk of `A`. The `idx` variable corresponds to the indices within each chunk.
26+
27+
## Neighborhood Access with `@neighbors`
28+
29+
The true power of stencils comes from accessing neighboring elements. The `@neighbors` macro facilitates this.
30+
31+
`@neighbors(array[idx], distance, boundary_condition)`
32+
33+
- `array[idx]`: The array and current index from which to find neighbors.
34+
- `distance`: An integer specifying the extent of the neighborhood (e.g., `1` for a 3x3 neighborhood in 2D).
35+
- `boundary_condition`: Defines how to handle accesses beyond the array boundaries. Available conditions are:
36+
- `Wrap()`: Wraps around to the other side of the array.
37+
- `Pad(value)`: Pads with a specified `value`.
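
To make the two boundary conditions concrete, here is a minimal plain-Julia sketch of what a square neighborhood lookup computes on an ordinary `Matrix`. It does not use Dagger and is not how `@neighbors` is implemented; the function name `neighborhood` and the `:wrap` / `:pad` keyword values are hypothetical names chosen for this illustration only.

```julia
# Plain-Julia reference for a square neighborhood; no Dagger required.
# `neighborhood`, `:wrap`, and `:pad` are hypothetical names for this sketch.
function neighborhood(A::AbstractMatrix, r::Int, c::Int, distance::Int;
                      boundary::Symbol = :wrap, pad_value = zero(eltype(A)))
    nrows, ncols = size(A)
    side = 2distance + 1
    out = Matrix{eltype(A)}(undef, side, side)
    for (j, dc) in enumerate(-distance:distance), (i, dr) in enumerate(-distance:distance)
        nr, nc = r + dr, c + dc
        if boundary === :wrap
            # Out-of-range indices re-enter from the opposite edge.
            out[i, j] = A[mod1(nr, nrows), mod1(nc, ncols)]
        elseif 1 <= nr <= nrows && 1 <= nc <= ncols
            # In-range indices are read directly.
            out[i, j] = A[nr, nc]
        else
            # Out-of-range indices read the padding value instead.
            out[i, j] = pad_value
        end
    end
    return out
end

A = [1 2 3;
     4 5 6;
     7 8 9]

wrapped = neighborhood(A, 1, 1, 1; boundary = :wrap)               # [9 7 8; 3 1 2; 6 4 5]
padded  = neighborhood(A, 1, 1, 1; boundary = :pad, pad_value = 0) # [0 0 0; 0 1 2; 0 4 5]
```

With a distance of `1`, the corner element sees the far row and column under wrapping, and sees five padding values under padding.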
38+
39+
### Example: Averaging Neighbors with `Wrap`
40+
41+
```julia
42+
import Dagger: Wrap
43+
44+
# Initialize a DArray
45+
A = ones(Blocks(1, 1), Int, 3, 3)
46+
A[2,2] = 10 # Central element has a different value
47+
B = zeros(Blocks(1, 1), Float64, 3, 3)
48+
49+
Dagger.spawn_datadeps() do
50+
@stencil begin
51+
# Calculate the average of the 3x3 neighborhood (including the center)
52+
B[idx] = sum(@neighbors(A[idx], 1, Wrap())) / 9.0
53+
end
54+
end
55+
56+
# Manually calculate expected B for verification
57+
expected_B = zeros(Float64, 3, 3)
58+
A_collected = collect(A)
59+
for r in 1:3, c in 1:3
60+
local_sum = 0.0
61+
for dr in -1:1, dc in -1:1
62+
nr, nc = mod1(r+dr, 3), mod1(c+dc, 3)
63+
local_sum += A_collected[nr, nc]
64+
end
65+
expected_B[r,c] = local_sum / 9.0
66+
end
67+
68+
@assert collect(B) ≈ expected_B
69+
```
70+
71+
### Example: Convolution with `Pad`
72+
73+
```julia
74+
import Dagger: Pad
75+
76+
# Initialize a DArray
77+
A = ones(Blocks(2, 2), Int, 4, 4)
78+
B = zeros(Blocks(2, 2), Int, 4, 4)
79+
80+
Dagger.spawn_datadeps() do
81+
@stencil begin
82+
B[idx] = sum(@neighbors(A[idx], 1, Pad(0))) # Pad with 0
83+
end
84+
end
85+
86+
# Expected result for a 3x3 sum filter with zero padding
87+
expected_B_padded = [
88+
4 6 6 4;
89+
6 9 9 6;
90+
6 9 9 6;
91+
4 6 6 4
92+
]
93+
@assert collect(B) == expected_B_padded
94+
```
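
The values 4, 6, and 9 in the expected matrix follow from counting how many cells of each 3x3 neighborhood fall inside the array, since every in-bounds cell contributes 1 and every padded cell contributes 0. A small plain-Julia sketch of that arithmetic is below; the helper `inbounds_count` is a hypothetical name used only for this illustration.

```julia
n = 4

# Number of in-bounds indices among i-1, i, i+1 along an axis of length n.
inbounds_count(i, n) = count(k -> 1 <= k <= n, (i - 1):(i + 1))

# In-bounds cells of a 3x3 neighborhood = in-bounds rows * in-bounds columns.
expected = [inbounds_count(r, n) * inbounds_count(c, n) for r in 1:n, c in 1:n]
```

Corners keep 2 rows and 2 columns (4 cells), edges keep 2 by 3 (6 cells), and interior elements keep all 9.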
95+
96+
## Sequential Semantics
97+
98+
Expressions within a `@stencil` block are executed sequentially in terms of their effect on the data. Each statement behaves as if it were applied across all indices before the next statement begins, so its results are fully visible to the statements that follow.
99+
100+
```julia
101+
A = zeros(Blocks(2, 2), Int, 4, 4)
102+
B = zeros(Blocks(2, 2), Int, 4, 4)
103+
104+
Dagger.spawn_datadeps() do
105+
@stencil begin
106+
A[idx] = 1 # First, A is initialized
107+
B[idx] = A[idx] * 2 # Then, B is computed using the new values of A
108+
end
109+
end
110+
111+
expected_A = [1 for r in 1:4, c in 1:4]
112+
expected_B_seq = expected_A .* 2
113+
114+
@assert collect(A) == expected_A
115+
@assert collect(B) == expected_B_seq
116+
```
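
The distinction matters most when a later statement reads neighbors. The following plain-Julia sketch, which uses ordinary 1D vectors rather than `DArray`s, contrasts the statement-by-statement behavior described above with a fused per-element loop. The helpers `neighbor_sum`, `two_passes`, and `fused` are hypothetical names for this illustration and say nothing about how Dagger schedules the work internally.

```julia
# Wrapped left-plus-right neighbor sum over a 1D array (a stand-in for @neighbors).
neighbor_sum(A, i) = A[mod1(i - 1, length(A))] + A[mod1(i + 1, length(A))]

# Statement-by-statement: finish the write to A everywhere, then compute B.
function two_passes(n)
    A = zeros(Int, n); B = zeros(Int, n)
    A .= 1
    for i in 1:n
        B[i] = neighbor_sum(A, i)
    end
    return B
end

# Fused per-element loop, shown only for contrast: B reads stale neighbors.
function fused(n)
    A = zeros(Int, n); B = zeros(Int, n)
    for i in 1:n
        A[i] = 1
        B[i] = neighbor_sum(A, i)
    end
    return B
end

two_passes(5)  # [2, 2, 2, 2, 2]
fused(5)       # [0, 1, 1, 1, 2]
```

Only the first variant matches the sequential semantics described in this section.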
117+
118+
## Operations on Multiple `DArray`s
119+
120+
You can read from and write to multiple `DArray`s within a single `@stencil` block, provided they have compatible chunk structures.
121+
122+
```julia
123+
A = ones(Blocks(1, 1), Int, 2, 2)
124+
B = DArray(fill(3, 2, 2), Blocks(1, 1))
125+
C = zeros(Blocks(1, 1), Int, 2, 2)
126+
127+
Dagger.spawn_datadeps() do
128+
@stencil begin
129+
C[idx] = A[idx] + B[idx]
130+
end
131+
end
132+
@assert all(collect(C) .== 4)
133+
```
134+
135+
## Example: Game of Life
136+
137+
The following demonstrates a more complex example: Conway's Game of Life.
138+
139+
```julia
140+
# Ensure Plots and other necessary packages are available for the example
141+
using Plots
142+
143+
N = 27 # Size of one dimension of a tile
144+
nt = 3 # Number of tiles in each dimension (results in nt x nt grid of tiles)
145+
niters = 10 # Number of iterations for the animation
146+
147+
tiles = zeros(Blocks(N, N), Bool, N*nt, N*nt)
148+
outputs = zeros(Blocks(N, N), Bool, N*nt, N*nt)
149+
150+
# Create a fun initial state (e.g., a glider and some random noise)
151+
tiles[13, 14] = true
152+
tiles[14, 14] = true
153+
tiles[15, 14] = true
154+
tiles[15, 15] = true
155+
tiles[14, 16] = true
156+
# Add some random noise in one of the tiles
157+
@view(tiles[(2N+1):3N, (2N+1):3N]) .= rand(Bool, N, N)
158+
159+
160+
161+
anim = @animate for _ in 1:niters
162+
Dagger.spawn_datadeps() do
163+
@stencil begin
164+
outputs[idx] = begin
165+
nhood = @neighbors(tiles[idx], 1, Wrap())
166+
neighs = sum(nhood) - tiles[idx] # Sum neighborhood, but subtract own value
167+
if tiles[idx] && neighs < 2
168+
0 # Dies of underpopulation
169+
elseif tiles[idx] && neighs > 3
170+
0 # Dies of overpopulation
171+
elseif !tiles[idx] && neighs == 3
172+
1 # Becomes alive by reproduction
173+
else
174+
tiles[idx] # Keeps its prior value
175+
end
176+
end
177+
tiles[idx] = outputs[idx] # Update tiles for the next iteration
178+
end
179+
end
180+
heatmap(Int.(collect(outputs))) # Generate a heatmap visualization
181+
end
182+
path = mp4(anim; fps=5, show_msg=true).filename # Create an animation of the heatmaps over time
183+
```
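
When experimenting with the rules, it can help to compare against a serial reference. The sketch below applies the same birth and survival rules with wrap-around edges to an ordinary Boolean matrix; it does not use Dagger or Plots, and `life_step` is a hypothetical helper name introduced here for illustration.

```julia
# One Game of Life generation on a plain Bool matrix with wrap-around edges.
function life_step(tiles::AbstractMatrix{Bool})
    nrows, ncols = size(tiles)
    out = similar(tiles)
    for c in 1:ncols, r in 1:nrows
        neighs = 0
        for dc in -1:1, dr in -1:1
            (dr == 0 && dc == 0) && continue  # skip the cell itself
            neighs += tiles[mod1(r + dr, nrows), mod1(c + dc, ncols)]
        end
        # Live cells survive with 2 or 3 neighbors; dead cells are born with exactly 3.
        out[r, c] = tiles[r, c] ? (2 <= neighs <= 3) : (neighs == 3)
    end
    return out
end

# A "blinker": three cells in a row that flip between horizontal and vertical.
grid = falses(5, 5)
grid[3, 2:4] .= true

stepped = life_step(grid)
@assert stepped[2:4, 3] == [true, true, true]
@assert life_step(stepped) == grid
```

A small known pattern such as the blinker makes it easy to spot a mistake in the neighbor count or the boundary handling.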

test/array/stencil.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import Dagger: @stencil, Wrap, Pad
2+
3+
@testset "@stencil" begin
4+
@testset "Simple assignment" begin
5+
A = zeros(Blocks(2, 2), Int, 4, 4)
6+
Dagger.spawn_datadeps() do
7+
@stencil begin
8+
A[idx] = 1
9+
end
10+
end
11+
@test all(collect(A) .== 1)
12+
end
13+
14+
@testset "Wrap boundary" begin
15+
A = zeros(Blocks(2, 2), Int, 4, 4)
16+
A[1,1] = 10
17+
B = zeros(Blocks(2, 2), Int, 4, 4)
18+
Dagger.spawn_datadeps() do
19+
@stencil begin
20+
B[idx] = sum(@neighbors(A[idx], 1, Wrap()))
21+
end
22+
end
23+
# Expected result after convolution with wrap around
24+
# Corner element (1,1) will sum its 3 neighbors + itself (10) + 5 wrapped around neighbors
25+
# For A[1,1], neighbors are A[4,4], A[4,1], A[4,2], A[1,4], A[1,2], A[2,4], A[2,1], A[2,2]
26+
# Since only A[1,1] is 10 and others are 0, sum for B[1,1] will be 10 (A[1,1])
27+
# Sum for B[1,2] will be A[1,1] = 10
28+
# Sum for B[2,1] will be A[1,1] = 10
29+
# Sum for B[2,2] will be A[1,1] = 10
30+
# Sum for B[4,4] will be A[1,1] = 10
31+
# ... and so on for elements that wrap around to include A[1,1]
32+
expected_B_calc = zeros(Int, 4, 4)
33+
for i in 1:4, j in 1:4
34+
sum_val = 0
35+
for ni in -1:1, nj in -1:1
36+
# Apply wrap around logic for neighbors
37+
row = mod1(i+ni, 4)
38+
col = mod1(j+nj, 4)
39+
if row == 1 && col == 1 # Check if the wrapped neighbor is A[1,1]
40+
sum_val += 10
41+
end
42+
end
43+
expected_B_calc[i,j] = sum_val
44+
end
45+
@test collect(B) == expected_B_calc
46+
end
47+
48+
@testset "Pad boundary" begin
49+
A = DArray(ones(Int, 4, 4), Blocks(2, 2))
50+
B = DArray(zeros(Int, 4, 4), Blocks(2, 2))
51+
Dagger.spawn_datadeps() do
52+
@stencil begin
53+
B[idx] = sum(@neighbors(A[idx], 1, Pad(0)))
54+
end
55+
end
56+
# Expected result after convolution with zero padding
57+
# Inner elements (e.g., B[2,2]) will sum 9 (3x3 neighborhood of 1s)
58+
# Edge elements (e.g., B[1,2]) will sum 6 (2x3 neighborhood of 1s, 3 zeros from padding)
59+
# Corner elements (e.g., B[1,1]) will sum 4 (2x2 neighborhood of 1s, 5 zeros from padding)
60+
expected_B_pad = [
61+
4 6 6 4;
62+
6 9 9 6;
63+
6 9 9 6;
64+
4 6 6 4
65+
]
66+
@test collect(B) == expected_B_pad
67+
end
68+
69+
@testset "Multiple expressions" begin
70+
A = zeros(Blocks(2, 2), Int, 4, 4)
71+
B = zeros(Blocks(2, 2), Int, 4, 4)
72+
Dagger.spawn_datadeps() do
73+
@stencil begin
74+
A[idx] = 1
75+
B[idx] = A[idx] * 2
76+
end
77+
end
78+
expected_A_multi = [1 for r in 1:4, c in 1:4]
79+
expected_B_multi = expected_A_multi .* 2
80+
@test collect(A) == expected_A_multi
81+
@test collect(B) == expected_B_multi
82+
end
83+
84+
@testset "Multiple DArrays" begin
85+
A = ones(Blocks(2, 2), Int, 4, 4)
86+
B = DArray(fill(2, 4, 4), Blocks(2, 2))
87+
C = zeros(Blocks(2, 2), Int, 4, 4)
88+
Dagger.spawn_datadeps() do
89+
@stencil begin
90+
C[idx] = A[idx] + B[idx]
91+
end
92+
end
93+
@test all(collect(C) .== 3)
94+
end
95+
96+
@testset "Pad boundary with non-zero value" begin
97+
A = ones(Blocks(1, 1), Int, 2, 2) # Simpler 2x2 case
98+
B = zeros(Blocks(1, 1), Int, 2, 2)
99+
pad_value = 5
100+
Dagger.spawn_datadeps() do
101+
@stencil begin
102+
B[idx] = sum(@neighbors(A[idx], 1, Pad(pad_value)))
103+
end
104+
end
105+
# For A = [1 1; 1 1] and Pad(5)
106+
# B[1,1] neighbors considering a 3x3 neighborhood around A[1,1]:
107+
# P P P
108+
# P A11 A12
109+
# P A21 A22
110+
# Values:
111+
# 5 5 5
112+
# 5 1 1
113+
# 5 1 1
114+
# Sum = 5*5 (for the padded values) + 1*4 (for the actual values from A) = 25 + 4 = 29.
115+
# This logic applies to all elements in B because the array A is small (2x2) and the neighborhood is 1.
116+
# Every element's 3x3 neighborhood will include 5 padded values and the 4 values of A.
117+
expected_B_pad_val = fill(pad_value*5 + 1*4, 2, 2)
118+
@test collect(B) == expected_B_pad_val
119+
end
120+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ tests = [
2121
("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
2222
("Array - LinearAlgebra - LU", "array/linalg/lu.jl"),
2323
("Array - Random", "array/random.jl"),
24+
("Array - Stencils", "array/stencil.jl"),
2425
("Caching", "cache.jl"),
2526
("Disk Caching", "diskcaching.jl"),
2627
("File IO", "file-io.jl"),
