Skip to content

Commit 7db1640

Browse files
committed
stencils: Add GPU support
1 parent a0b1502 commit 7db1640

File tree

4 files changed

+48
-54
lines changed

4 files changed

+48
-54
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
88
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
99
DistributedNext = "fab6aee4-877b-4bac-a744-3eca44acbb6f"
10-
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1110
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1211
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1312
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -52,7 +51,6 @@ DataFrames = "1"
5251
DataStructures = "0.18"
5352
DistributedNext = "1.0.0"
5453
Distributions = "0.25"
55-
FillArrays = "1.11.0"
5654
GraphViz = "0.2"
5755
Graphs = "1"
5856
JSON3 = "1"

src/Dagger.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ import TimespanLogging: timespan_start, timespan_finish
3232

3333
import Adapt
3434

35-
import FillArrays: Fill
36-
3735
# Preferences
3836
import Preferences: @load_preference, @set_preferences!
3937

src/stencil.jl

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ function load_neighbor_edge(arr, dim, dir, neigh_dist)
1111
start_idx = CartesianIndex(ntuple(i -> i == dim ? firstindex(arr, i) : firstindex(arr, i), ndims(arr)))
1212
stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), ndims(arr)))
1313
end
14-
return collect(@view arr[start_idx:stop_idx])
14+
# FIXME: Don't collect
15+
return move(thunk_processor(), collect(@view arr[start_idx:stop_idx]))
1516
end
1617
# Copy one corner region of `arr` and hand it to the processor returned by
# `thunk_processor()` via `move`.
#
# Per dimension `d`: `corner_side[d] == 0` selects the last `neigh_dist`
# indices of `arr` along `d`; any other value selects the first `neigh_dist`.
function load_neighbor_corner(arr, corner_side, neigh_dist)
    nd = ndims(arr)
    lo = ntuple(nd) do d
        corner_side[d] == 0 ? (lastindex(arr, d) - neigh_dist + 1) : firstindex(arr, d)
    end
    hi = ntuple(nd) do d
        corner_side[d] == 0 ? lastindex(arr, d) : (firstindex(arr, d) + neigh_dist - 1)
    end
    proc = thunk_processor()
    # The view is materialized with `collect` before being moved.
    return move(proc, collect(view(arr, CartesianIndex(lo):CartesianIndex(hi))))
end
2122
function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
2223
@assert neigh_dist isa Integer && neigh_dist > 0 "Neighborhood distance must be an Integer greater than 0"
@@ -71,19 +72,30 @@ function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
7172
@assert length(accesses) == 1+2*ndims(chunks)+2^ndims(chunks) "Accesses mismatch: $(length(accesses))"
7273
return accesses
7374
end
74-
function build_halo(neigh_dist, boundary, center::Array{T,N}, all_neighbors...) where {T,N}
75-
# FIXME: Don't collect views
76-
edges = collect.(all_neighbors[1:(2*N)])
77-
corners = collect.(all_neighbors[((2^N)+1):end])
75+
# Assemble a `HaloArray` around `center` from its neighbor regions.
#
# Arguments:
# - `neigh_dist`: halo width, applied to every dimension of `center`.
# - `boundary`: accepted for interface compatibility; not used in this function.
# - `center`: the central array; `N = ndims(center)`.
# - `all_neighbors...`: the `2N` edge regions followed by the `2^N` corner regions.
#
# Returns the `HaloArray` built from `center`, the edges and the corners.
function build_halo(neigh_dist, boundary, center, all_neighbors...)
    N = ndims(center)
    n_edges = 2 * N
    edges = all_neighbors[1:n_edges]
    # Corners begin right after the edges. Offsetting by `2^N` instead only
    # coincides with `2N` for N = 1 and N = 2; for N >= 3 it skips entries and
    # the assertion below fails.
    corners = all_neighbors[(n_edges + 1):end]
    @assert length(edges) == 2*N && length(corners) == 2^N "Halo mismatch: edges=$(length(edges)) corners=$(length(corners))"
    return HaloArray(center, (edges...,), (corners...,), ntuple(_->neigh_dist, N))
end
82-
function load_neighborhood(arr::HaloArray{T,N}, idx, neigh_dist) where {T,N}
82+
# Return a view of the neighborhood of `idx` in `arr`: the Cartesian box that
# extends the halo width in both directions along every dimension.
#
# The halo width must be the same in every dimension (asserted below); the
# neighborhood distance is read from `arr.halo_width` rather than passed in.
function load_neighborhood(arr::HaloArray{T,N}, idx) where {T,N}
    @assert all(arr.halo_width .== arr.halo_width[1])
    reach = arr.halo_width[1]
    offset = CartesianIndex(ntuple(_ -> reach, ndims(arr)))
    return view(arr, (idx - offset):(idx + offset))
end
89+
# Apply the stencil body `f` to `output`, dispatching `inner_stencil_proc!` on
# the processor returned by `thunk_processor()`.
function inner_stencil!(f, output, read_vars)
    return inner_stencil_proc!(thunk_processor(), f, output, read_vars)
end
93+
# Non-KA (for CPUs): call `f(idx, output, read_vars)` once for every Cartesian
# index of `output`, in iteration order. Returns `nothing`.
function inner_stencil_proc!(::ThreadProc, f, output, read_vars)
    foreach(CartesianIndices(output)) do position
        f(position, output, read_vars)
    end
    return nothing
end
88100

89101
# Return `true` when any component of `idx` lies outside `1:size[d]` for its
# dimension `d`. Every dimension is checked (no short-circuit across
# dimensions), so `idx` must be indexable for each of the `length(size)` entries.
function is_past_boundary(size, idx)
    outside = ntuple(length(size)) do d
        coord = idx[d]
        coord < 1 || coord > size[d]
    end
    return any(outside)
end
@@ -108,17 +120,19 @@ function load_boundary_edge(pad::Pad, arr, dim, dir, neigh_dist)
108120
stop_idx = CartesianIndex(ntuple(i -> i == dim ? (firstindex(arr, i) + neigh_dist - 1) : lastindex(arr, i), ndims(arr)))
109121
end
110122
edge_size = ntuple(i -> length(start_idx[i]:stop_idx[i]), ndims(arr))
111-
return Fill(pad.padval, edge_size)
123+
# FIXME: return Fill(pad.padval, edge_size)
124+
return move(thunk_processor(), fill(pad.padval, edge_size))
112125
end
113126
# Build the padding block for one corner of `arr` under a `Pad` boundary: an
# array filled with `pad.padval`, sized like the corresponding corner region,
# and handed to the processor returned by `thunk_processor()` via `move`.
#
# Per dimension `d`: `corner_side[d] == 0` measures the last `neigh_dist`
# indices of `arr` along `d`; any other value measures the first `neigh_dist`.
function load_boundary_corner(pad::Pad, arr, corner_side, neigh_dist)
    nd = ndims(arr)
    lo_bounds = ntuple(nd) do d
        corner_side[d] == 0 ? (lastindex(arr, d) - neigh_dist + 1) : firstindex(arr, d)
    end
    hi_bounds = ntuple(nd) do d
        corner_side[d] == 0 ? lastindex(arr, d) : (firstindex(arr, d) + neigh_dist - 1)
    end
    lo = CartesianIndex(lo_bounds)
    hi = CartesianIndex(hi_bounds)
    extent = ntuple(d -> length(lo[d]:hi[d]), nd)
    # FIXME: ideally return a lazy `Fill(pad.padval, extent)` instead of a
    # materialized array.
    proc = thunk_processor()
    return move(proc, fill(pad.padval, extent))
end
119133

120134
"""
121-
@stencil idx in range begin body end
135+
@stencil begin body end
122136
123137
Allows the specification of stencil operations within a `spawn_datadeps`
124138
region. The `idx` variable is used to iterate over `range`, which must be a
@@ -205,21 +219,25 @@ macro stencil(orig_ex)
205219
@gensym chunk_idx
206220

207221
# Generate function with transformed body
208-
@gensym inner_index_var
222+
@gensym inner_vars inner_index_var
209223
new_inner_ex_body = prewalk(inner_ex) do old_inner_ex
210224
if @capture(old_inner_ex, read_var_[read_idx_]) && read_idx == write_idx
211225
# Direct access
212-
return :($read_var[$inner_index_var])
226+
if read_var == write_var
227+
return :($write_var[$inner_index_var])
228+
else
229+
return :($inner_vars.$read_var[$inner_index_var])
230+
end
213231
elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_))
214232
# Neighborhood access
215-
return :($load_neighborhood($read_var, $inner_index_var, $neigh_dist))
233+
return :($load_neighborhood($inner_vars.$read_var, $inner_index_var))
216234
end
217235
return old_inner_ex
218236
end
237+
new_inner_f = :(($inner_index_var, $write_var, $inner_vars)->$new_inner_ex_body)
219238
new_inner_ex = quote
220-
for $inner_index_var in CartesianIndices($write_var)
221-
$new_inner_ex_body
222-
end
239+
$inner_vars = (;$(read_vars...))
240+
$inner_stencil!($new_inner_f, $write_var, $inner_vars)
223241
end
224242
inner_fn = Expr(:->, Expr(:tuple, Expr(:parameters, write_var, read_vars...)), new_inner_ex)
225243

@@ -254,7 +272,6 @@ macro stencil(orig_ex)
254272
end)
255273
end
256274

257-
@show final_ex
258275

259276
return esc(final_ex)
260277
end

src/utils/haloarray.jl

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Define the HaloArray type with minimized halo storage
2-
struct HaloArray{T,N,E,C,A,EA,CA} <: AbstractArray{T,N}
2+
# N-dimensional array of element type `T` made of a central array plus
# separately stored halo regions.
#
# Type parameters `E` and `C` carry the number of edge and corner regions;
# `A`, `EAT` and `CAT` are the concrete types of the three storage fields.
struct HaloArray{T,N,E,C,A,EAT<:Tuple,CAT<:Tuple} <: AbstractArray{T,N}
    # Central (non-halo) data.
    center::A
    # Tuple of edge halo regions; entries may have different types.
    edges::EAT
    # Tuple of corner halo regions; entries may have different types.
    corners::CAT
    # Halo width along each of the N dimensions.
    halo_width::NTuple{N,Int}
end
88

@@ -17,11 +17,11 @@ function HaloArray{T,N}(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}) w
1717
corners = ntuple(2^N) do i
1818
return Array{T,N}(undef, halo_width)
1919
end
20-
return HaloArray{T,N,2N,2^N}(center, edges, corners, halo_width)
20+
return HaloArray(center, edges, corners, halo_width)
2121
end
2222

23-
HaloArray(center::AT, edges::NTuple{E, EA}, corners::NTuple{C, CA}, halo_width::NTuple{N, Int}) where {T,N,AT<:AbstractArray{T,N},C,E,CA,EA} =
24-
HaloArray{T,N,E,C,AT,EA,CA}(center, edges, corners, halo_width)
23+
# Convenience constructor: infer every type parameter of `HaloArray` from the
# arguments. The edge and corner counts are taken from the lengths of the
# `edges` and `corners` tuples.
function HaloArray(center::AT, edges::EAT, corners::CAT, halo_width::NTuple{N, Int}) where {T,N,AT<:AbstractArray{T,N},CAT<:Tuple,EAT<:Tuple}
    n_edges = length(edges)
    n_corners = length(corners)
    return HaloArray{T,N,n_edges,n_corners,AT,EAT,CAT}(center, edges, corners, halo_width)
end
2525

2626
# Total size of the tile: the center's size plus the halo width on both sides
# of every dimension.
function Base.size(tile::HaloArray)
    halo_total = 2 .* tile.halo_width
    return size(tile.center) .+ halo_total
end
2727
function Base.axes(tile::HaloArray{T,N,H}) where {T,N,H}
@@ -57,10 +57,10 @@ function Base.getindex(tile::HaloArray{T,N}, I::Vararg{Int,N}) where {T,N}
5757
else
5858
for d in 1:N
5959
if I[d] < 1
60-
halo_idx = (I[1:d-1]..., I[d] + tile.halo_width[d], I[d+1:end]...)
60+
halo_idx = ntuple(i->i == d ? I[i] + tile.halo_width[i] : I[i], N)
6161
return tile.edges[(2*(d-1))+1][halo_idx...]
6262
elseif I[d] > size(tile.center, d)
63-
halo_idx = (I[1:d-1]..., I[d] - size(tile.center, d), I[d+1:end]...)
63+
halo_idx = ntuple(i->i == d ? I[i] - size(tile.center, d) : I[i], N)
6464
return tile.edges[(2*(d-1))+2][halo_idx...]
6565
end
6666
end
@@ -84,32 +84,13 @@ function Base.setindex!(tile::HaloArray{T,N}, value, I::Vararg{Int,N}) where {T,
8484
# Edge
8585
for d in 1:N
8686
if I[d] < 1
87-
halo_idx = (I[1:d-1]..., I[d] + tile.halo_width[d], I[d+1:end]...)
87+
halo_idx = ntuple(i->i == d ? I[i] + tile.halo_width[i] : I[i], N)
8888
return tile.edges[(2*(d-1))+1][halo_idx...] = value
8989
elseif I[d] > size(tile.center, d)
90-
halo_idx = (I[1:d-1]..., I[d] - size(tile.center, d), I[d+1:end]...)
90+
halo_idx = ntuple(i->i == d ? I[i] - size(tile.center, d) : I[i], N)
9191
return tile.edges[(2*(d-1))+2][halo_idx...] = value
9292
end
9393
end
9494
end
9595
error("Index out of bounds")
9696
end
97-
98-
#=
99-
# Example usage
100-
center_size = (3, 5)
101-
halo_width = (1, 1)
102-
tile = HaloArray{Float64, 2}(center_size, halo_width)
103-
104-
# Set values in the center and halo
105-
tile[2, 2] = 1.0
106-
tile[0, 2] = 2.0 # This should be in an edge
107-
tile[0, 0] = 3.0 # This should be in a corner
108-
tile[4, 6] = 4.0 # This should be in a corner
109-
110-
# Get values from the center and halo
111-
println(tile[2, 2]) # 1.0
112-
println(tile[0, 2]) # 2.0
113-
println(tile[0, 0]) # 3.0
114-
println(tile[4, 6]) # 4.0
115-
=#

0 commit comments

Comments
 (0)