@@ -11,12 +11,13 @@ function load_neighbor_edge(arr, dim, dir, neigh_dist)
11
11
start_idx = CartesianIndex (ntuple (i -> i == dim ? firstindex (arr, i) : firstindex (arr, i), ndims (arr)))
12
12
stop_idx = CartesianIndex (ntuple (i -> i == dim ? (firstindex (arr, i) + neigh_dist - 1 ) : lastindex (arr, i), ndims (arr)))
13
13
end
14
- return collect (@view arr[start_idx: stop_idx])
14
+ # FIXME : Don't collect
15
+ return move (thunk_processor (), collect (@view arr[start_idx: stop_idx]))
15
16
end
16
17
function load_neighbor_corner (arr, corner_side, neigh_dist)
17
18
start_idx = CartesianIndex (ntuple (i -> corner_side[i] == 0 ? (lastindex (arr, i) - neigh_dist + 1 ) : firstindex (arr, i), ndims (arr)))
18
19
stop_idx = CartesianIndex (ntuple (i -> corner_side[i] == 0 ? lastindex (arr, i) : (firstindex (arr, i) + neigh_dist - 1 ), ndims (arr)))
19
- return collect (@view arr[start_idx: stop_idx])
20
+ return move ( thunk_processor (), collect (@view arr[start_idx: stop_idx]) )
20
21
end
21
22
function select_neighborhood_chunks (chunks, idx, neigh_dist, boundary)
22
23
@assert neigh_dist isa Integer && neigh_dist > 0 " Neighborhood distance must be an Integer greater than 0"
@@ -71,19 +72,30 @@ function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
71
72
@assert length (accesses) == 1 + 2 * ndims (chunks)+ 2 ^ ndims (chunks) " Accesses mismatch: $(length (accesses)) "
72
73
return accesses
73
74
end
74
- function build_halo (neigh_dist, boundary, center:: Array{T,N} , all_neighbors... ) where {T,N}
75
- # FIXME : Don't collect views
76
- edges = collect .( all_neighbors[1 : (2 * N)])
77
- corners = collect .( all_neighbors[((2 ^ N)+ 1 ): end ])
75
+ function build_halo (neigh_dist, boundary, center, all_neighbors... )
76
+ N = ndims (center)
77
+ edges = all_neighbors[1 : (2 * N)]
78
+ corners = all_neighbors[((2 ^ N)+ 1 ): end ]
78
79
@assert length (edges) == 2 * N && length (corners) == 2 ^ N " Halo mismatch: edges=$(length (edges)) corners=$(length (corners)) "
79
- arr = HaloArray (center, (edges... ,), (corners... ,), ntuple (_-> neigh_dist, N))
80
- return arr
80
+ return HaloArray (center, (edges... ,), (corners... ,), ntuple (_-> neigh_dist, N))
81
81
end
82
- function load_neighborhood (arr:: HaloArray{T,N} , idx, neigh_dist) where {T,N}
82
+ function load_neighborhood (arr:: HaloArray{T,N} , idx) where {T,N}
83
+ @assert all (arr. halo_width .== arr. halo_width[1 ])
84
+ neigh_dist = arr. halo_width[1 ]
83
85
start_idx = idx - CartesianIndex (ntuple (_-> neigh_dist, ndims (arr)))
84
86
stop_idx = idx + CartesianIndex (ntuple (_-> neigh_dist, ndims (arr)))
85
- # FIXME : Don't collect HaloArray view
86
- return collect (@view arr[start_idx: stop_idx])
87
+ return @view arr[start_idx: stop_idx]
88
+ end
89
+ function inner_stencil! (f, output, read_vars)
90
+ processor = thunk_processor ()
91
+ inner_stencil_proc! (processor, f, output, read_vars)
92
+ end
93
+ # Non-KA (for CPUs)
94
+ function inner_stencil_proc! (:: ThreadProc , f, output, read_vars)
95
+ for idx in CartesianIndices (output)
96
+ f (idx, output, read_vars)
97
+ end
98
+ return
87
99
end
88
100
89
101
is_past_boundary (size, idx) = any (ntuple (i -> idx[i] < 1 || idx[i] > size[i], length (size)))
@@ -108,17 +120,19 @@ function load_boundary_edge(pad::Pad, arr, dim, dir, neigh_dist)
108
120
stop_idx = CartesianIndex (ntuple (i -> i == dim ? (firstindex (arr, i) + neigh_dist - 1 ) : lastindex (arr, i), ndims (arr)))
109
121
end
110
122
edge_size = ntuple (i -> length (start_idx[i]: stop_idx[i]), ndims (arr))
111
- return Fill (pad. padval, edge_size)
123
+ # FIXME : return Fill(pad.padval, edge_size)
124
+ return move (thunk_processor (), fill (pad. padval, edge_size))
112
125
end
113
126
function load_boundary_corner (pad:: Pad , arr, corner_side, neigh_dist)
114
127
start_idx = CartesianIndex (ntuple (i -> corner_side[i] == 0 ? (lastindex (arr, i) - neigh_dist + 1 ) : firstindex (arr, i), ndims (arr)))
115
128
stop_idx = CartesianIndex (ntuple (i -> corner_side[i] == 0 ? lastindex (arr, i) : (firstindex (arr, i) + neigh_dist - 1 ), ndims (arr)))
116
129
corner_size = ntuple (i -> length (start_idx[i]: stop_idx[i]), ndims (arr))
117
- return Fill (pad. padval, corner_size)
130
+ # FIXME : return Fill(pad.padval, corner_size)
131
+ return move (thunk_processor (), fill (pad. padval, corner_size))
118
132
end
119
133
120
134
"""
121
- @stencil idx in range begin body end
135
+ @stencil begin body end
122
136
123
137
Allows the specification of stencil operations within a `spawn_datadeps`
124
138
region. The `idx` variable is used to iterate over `range`, which must be a
@@ -205,21 +219,25 @@ macro stencil(orig_ex)
205
219
@gensym chunk_idx
206
220
207
221
# Generate function with transformed body
208
- @gensym inner_index_var
222
+ @gensym inner_vars inner_index_var
209
223
new_inner_ex_body = prewalk (inner_ex) do old_inner_ex
210
224
if @capture (old_inner_ex, read_var_[read_idx_]) && read_idx == write_idx
211
225
# Direct access
212
- return :($ read_var[$ inner_index_var])
226
+ if read_var == write_var
227
+ return :($ write_var[$ inner_index_var])
228
+ else
229
+ return :($ inner_vars.$ read_var[$ inner_index_var])
230
+ end
213
231
elseif @capture (old_inner_ex, @neighbors (read_var_[read_idx_], neigh_dist_, boundary_))
214
232
# Neighborhood access
215
- return :($ load_neighborhood ($ read_var, $ inner_index_var, $ neigh_dist ))
233
+ return :($ load_neighborhood ($ inner_vars. $ read_var, $ inner_index_var))
216
234
end
217
235
return old_inner_ex
218
236
end
237
+ new_inner_f = :(($ inner_index_var, $ write_var, $ inner_vars)-> $ new_inner_ex_body)
219
238
new_inner_ex = quote
220
- for $ inner_index_var in CartesianIndices ($ write_var)
221
- $ new_inner_ex_body
222
- end
239
+ $ inner_vars = (;$ (read_vars... ))
240
+ $ inner_stencil! ($ new_inner_f, $ write_var, $ inner_vars)
223
241
end
224
242
inner_fn = Expr (:-> , Expr (:tuple , Expr (:parameters , write_var, read_vars... )), new_inner_ex)
225
243
@@ -254,7 +272,6 @@ macro stencil(orig_ex)
254
272
end )
255
273
end
256
274
257
- @show final_ex
258
275
259
276
return esc (final_ex)
260
277
end
0 commit comments