Skip to content

Commit c85eed4

Browse files
committed
datadeps: Support within-DAG deps for AOT scheduler
1 parent 2888fa4 commit c85eed4

File tree

1 file changed

+57
-12
lines changed

1 file changed

+57
-12
lines changed

src/datadeps.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,40 +96,61 @@ struct DAGSpec
9696
Dict{Int, Type}(),
9797
Dict{Int, Vector{DatadepsArgSpec}}())
9898
end
99-
function Base.push!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
99+
"""
    dag_add_task!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask) -> Bool

Record `task` (described by `tspec`) as a new vertex in the DAG spec `dspec`,
registering its function type, per-argument `DatadepsArgSpec`s, and the
id<->uid mappings. Returns `true` when the task was added. If any argument is
a `DTask` that is already part of this DAG (a within-DAG dependency), the task
is NOT added and `false` is returned so the caller can defer it.
"""
function dag_add_task!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
    # Check if this task depends on any other tasks within the DAG,
    # which we are not yet ready to handle
    for (idx, (kwpos, arg)) in enumerate(tspec.args)
        arg, deps = unwrap_inout(arg)
        pos = kwpos isa Symbol ? kwpos : idx
        for (dep_mod, readdep, writedep) in deps
            if arg isa DTask
                if arg.uid in keys(dspec.uid_to_id)
                    # Within-DAG dependency, bail out
                    return false
                end
            end
        end
    end

    # No within-DAG dependencies found: create a vertex for this task
    add_vertex!(dspec.g)
    id = nv(dspec.g)

    # Record function signature
    dspec.id_to_functype[id] = typeof(tspec.f)
    # Collect argument specs into a local vector first; publish it into
    # `dspec.id_to_argtypes` only after all arguments have been processed
    argtypes = DatadepsArgSpec[]
    for (idx, (kwpos, arg)) in enumerate(tspec.args)
        arg, deps = unwrap_inout(arg)
        # Positional arguments are keyed by index, keyword arguments by name
        pos = kwpos isa Symbol ? kwpos : idx
        for (dep_mod, readdep, writedep) in deps
            if arg isa DTask
                #= TODO: Re-enable this when we can handle within-DAG dependencies
                if arg.uid in keys(dspec.uid_to_id)
                    # Within-DAG dependency
                    arg_id = dspec.uid_to_id[arg.uid]
                    push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
                    add_edge!(dspec.g, arg_id, id)
                    continue
                end
                =#

                # External DTask, so fetch this and track it as a raw value
                arg = fetch(arg; raw=true)
            end
            ainfo = aliasing(arg, dep_mod)
            push!(argtypes, DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
        end
    end
    dspec.id_to_argtypes[id] = argtypes

    # FIXME: Also record some portion of options
    # FIXME: Record syncdeps
    # Maintain the bidirectional mapping between DAG-local vertex id
    # and the task's globally-unique uid
    dspec.id_to_uid[id] = task.uid
    dspec.uid_to_id[task.uid] = id

    return true
end
152+
"""
    dag_has_task(dspec::DAGSpec, task::DTask) -> Bool

Return whether `task` has already been registered in the DAG spec `dspec`
(i.e. whether its uid is present in `dspec.uid_to_id`).
"""
function dag_has_task(dspec::DAGSpec, task::DTask)
    return haskey(dspec.uid_to_id, task.uid)
end
134155
function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
135156
# Are the graphs the same size?
@@ -621,7 +642,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
621642
@warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
622643
end
623644

624-
# Round-robin assign tasks to processors
625645
upper_queue = get_options(:task_queue)
626646

627647
state = DataDepsState(queue.aliasing, all_procs)
@@ -632,7 +652,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
632652
if DATADEPS_SCHEDULE_REUSABLE[]
633653
# Compute DAG spec
634654
for (spec, task) in queue.seen_tasks
635-
push!(state.dag_spec, spec, task)
655+
if !dag_add_task!(state.dag_spec, spec, task)
656+
# This task needs to be deferred
657+
break
658+
end
636659
end
637660

638661
# Find any matching DAG specs and reuse their schedule
@@ -654,13 +677,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
654677

655678
# Populate all task dependencies
656679
write_num = 1
680+
task_num = 0
657681
for (spec, task) in queue.seen_tasks
682+
if !dag_has_task(state.dag_spec, task)
683+
# This task needs to be deferred
684+
break
685+
end
658686
write_num = populate_task_info!(state, spec, task, write_num)
687+
task_num += 1
659688
end
689+
@assert task_num > 0
660690

661691
if isempty(schedule)
662692
# Run AOT scheduling
663-
schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks)::Dict{DTask, Processor}
693+
schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks[1:task_num])::Dict{DTask, Processor}
664694

665695
if DATADEPS_SCHEDULE_REUSABLE[]
666696
# Cache the schedule
@@ -680,6 +710,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
680710
# Launch tasks and necessary copies
681711
write_num = 1
682712
for (spec, task) in queue.seen_tasks
713+
if !dag_has_task(state.dag_spec, task)
714+
# This task needs to be deferred
715+
break
716+
end
717+
683718
our_proc = schedule[task]
684719
@assert our_proc in all_procs
685720
our_space = only(memory_spaces(our_proc))
@@ -829,6 +864,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
829864
write_num += 1
830865
end
831866

867+
# Remove processed tasks
868+
deleteat!(queue.seen_tasks, 1:task_num)
869+
832870
# Copy args from remote to local
833871
if queue.aliasing
834872
# We need to replay the writes from all tasks in-order (skipping any
@@ -961,18 +999,25 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
961999
wait_all(; check_errors=true) do
9621000
scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler())
9631001
launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
1002+
local result
9641003
if launch_wait
965-
result = spawn_bulk() do
1004+
spawn_bulk() do
9661005
queue = DataDepsTaskQueue(get_options(:task_queue);
9671006
scheduler, aliasing)
968-
with_options(f; task_queue=queue)
969-
distribute_tasks!(queue)
1007+
result = with_options(f; task_queue=queue)
1008+
while !isempty(queue.seen_tasks)
1009+
@dagdebug nothing :spawn_datadeps "Entering Datadeps region"
1010+
distribute_tasks!(queue)
1011+
end
9701012
end
9711013
else
9721014
queue = DataDepsTaskQueue(get_options(:task_queue);
9731015
scheduler, aliasing)
9741016
result = with_options(f; task_queue=queue)
975-
distribute_tasks!(queue)
1017+
while !isempty(queue.seen_tasks)
1018+
@dagdebug nothing :spawn_datadeps "Entering Datadeps region"
1019+
distribute_tasks!(queue)
1020+
end
9761021
end
9771022
return result
9781023
end

0 commit comments

Comments (0)