Skip to content

Commit 386dfd8

Browse files
committed
scopes/DArray: Prevent GPU running setindex!
1 parent 7db1640 commit 386dfd8

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

src/array/indexing.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ function Base.setindex!(A::DArray{T,N}, value, idx::NTuple{N,Int}) where {T,N}
127127
# Set the value
128128
part = A.chunks[part_idx...]
129129
space = memory_space(part)
130-
scope = Dagger.scope(worker=root_worker_id(space))
130+
# FIXME: Do this correctly w.r.t memory space of part
131+
scope = Dagger.scope(worker=root_worker_id(space), threads=:)
131132
return fetch(Dagger.@spawn scope=scope setindex!(part, value, offset_idx...))
132133
end
133134
Base.setindex!(A::DArray, value, idx::Integer...) =

src/scopes.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,20 @@ function to_scope(sc::NamedTuple)
325325
else
326326
nothing
327327
end
328+
all_threads = false
328329
threads = if haskey(sc, :thread)
329330
Int[sc.thread]
330331
elseif haskey(sc, :threads)
331-
Int[sc.threads...]
332+
if sc.threads == Colon()
333+
all_threads = true
334+
nothing
335+
else
336+
Int[sc.threads...]
337+
end
332338
else
333339
nothing
334340
end
341+
want_threads = all_threads || threads !== nothing
335342

336343
# Simple cases
337344
if workers !== nothing && threads !== nothing
@@ -341,18 +348,22 @@ function to_scope(sc::NamedTuple)
341348
end
342349
return simplified_union_scope(subscopes)
343350
elseif workers !== nothing && threads === nothing
344-
subscopes = AbstractScope[ProcessScope(w) for w in workers]
345-
return simplified_union_scope(subscopes)
351+
subscopes = simplified_union_scope(AbstractScope[ProcessScope(w) for w in workers])
352+
if all_threads
353+
return constrain(subscopes, ProcessorTypeScope(ThreadProc))
354+
else
355+
return subscopes
356+
end
346357
end
347358

348359
# More complex cases that require querying the cluster
349360
# FIXME: Use per-field scope taint
350361
if workers === nothing
351-
workers = procs()
362+
workers = map(p->p.pid, filter(p->p isa OSProc, procs(Dagger.Sch.eager_context())))
352363
end
353364
subscopes = AbstractScope[]
354365
for w in workers
355-
if threads === nothing
366+
if threads === nothing && want_threads
356367
threads = map(c->c.tid,
357368
filter(c->c isa ThreadProc,
358369
collect(children(OSProc(w)))))

0 commit comments

Comments
 (0)