@@ -96,40 +96,61 @@ struct DAGSpec
96
96
Dict {Int, Type} (),
97
97
Dict {Int, Vector{DatadepsArgSpec}} ())
98
98
end
99
"""
    dag_add_task!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask) -> Bool

Try to record `task` (described by `tspec`) as a new vertex of `dspec`.

Return `false`, leaving `dspec` untouched, when one of the task's tracked
arguments is a `DTask` that is already part of `dspec` (a within-DAG
dependency, which is not handled yet). Otherwise add a vertex to `dspec.g`,
record the function type, the per-argument `DatadepsArgSpec`s and the
`uid <-> id` mappings for the task, and return `true`.
"""
function dag_add_task!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
    # Check if this task depends on any other tasks within the DAG,
    # which we are not yet ready to handle. This runs before any mutation
    # so that bailing out leaves `dspec` unchanged.
    for (_, arg) in tspec.args
        inner, deps = unwrap_inout(arg)
        inner isa DTask || continue
        # Arguments without any dependency entry are never recorded below,
        # so they do not count as a within-DAG dependency either
        isempty(deps) && continue
        if haskey(dspec.uid_to_id, inner.uid)
            # Within-DAG dependency, bail out
            return false
        end
    end

    add_vertex!(dspec.g)
    id = nv(dspec.g)

    # Record function signature
    dspec.id_to_functype[id] = typeof(tspec.f)
    # Collected locally and stored once the whole argument list is processed
    argtypes = DatadepsArgSpec[]
    for (idx, (kwpos, arg)) in enumerate(tspec.args)
        arg, deps = unwrap_inout(arg)
        # Keyword arguments are identified by name, positional ones by index
        pos = kwpos isa Symbol ? kwpos : idx
        for (dep_mod, _, _) in deps
            if arg isa DTask
                #= TODO: Re-enable this when we can handle within-DAG dependencies
                if arg.uid in keys(dspec.uid_to_id)
                    # Within-DAG dependency
                    arg_id = dspec.uid_to_id[arg.uid]
                    push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
                    add_edge!(dspec.g, arg_id, id)
                    continue
                end
                =#

                # External DTask, so fetch this and track it as a raw value
                arg = fetch(arg; raw=true)
            end
            ainfo = aliasing(arg, dep_mod)
            push!(argtypes, DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
        end
    end
    dspec.id_to_argtypes[id] = argtypes

    # FIXME: Also record some portion of options
    # FIXME: Record syncdeps
    dspec.id_to_uid[id] = task.uid
    dspec.uid_to_id[task.uid] = id

    return true
end
152
"""
    dag_has_task(dspec::DAGSpec, task::DTask) -> Bool

Return `true` if `task.uid` is a key of `dspec.uid_to_id`, i.e. the task has
already been recorded in `dspec`.
"""
dag_has_task(dspec::DAGSpec, task::DTask) = haskey(dspec.uid_to_id, task.uid)
134
155
function Base.:(== )(dspec1:: DAGSpec , dspec2:: DAGSpec )
135
156
# Are the graphs the same size?
@@ -621,7 +642,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
621
642
@warn " Datadeps support for multi-GPU, multi-worker is currently broken\n Please be prepared for incorrect results or errors" maxlog= 1
622
643
end
623
644
624
- # Round-robin assign tasks to processors
625
645
upper_queue = get_options (:task_queue )
626
646
627
647
state = DataDepsState (queue. aliasing, all_procs)
@@ -632,7 +652,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
632
652
if DATADEPS_SCHEDULE_REUSABLE[]
633
653
# Compute DAG spec
634
654
for (spec, task) in queue. seen_tasks
635
- push! (state. dag_spec, spec, task)
655
+ if ! dag_add_task! (state. dag_spec, spec, task)
656
+ # This task needs to be deferred
657
+ break
658
+ end
636
659
end
637
660
638
661
# Find any matching DAG specs and reuse their schedule
@@ -654,13 +677,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
654
677
655
678
# Populate all task dependencies
656
679
write_num = 1
680
+ task_num = 0
657
681
for (spec, task) in queue. seen_tasks
682
+ if ! dag_has_task (state. dag_spec, task)
683
+ # This task needs to be deferred
684
+ break
685
+ end
658
686
write_num = populate_task_info! (state, spec, task, write_num)
687
+ task_num += 1
659
688
end
689
+ @assert task_num > 0
660
690
661
691
if isempty (schedule)
662
692
# Run AOT scheduling
663
- schedule = datadeps_create_schedule (queue. scheduler, state, queue. seen_tasks):: Dict{DTask, Processor}
693
+ schedule = datadeps_create_schedule (queue. scheduler, state, queue. seen_tasks[ 1 : task_num] ):: Dict{DTask, Processor}
664
694
665
695
if DATADEPS_SCHEDULE_REUSABLE[]
666
696
# Cache the schedule
@@ -680,6 +710,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
680
710
# Launch tasks and necessary copies
681
711
write_num = 1
682
712
for (spec, task) in queue. seen_tasks
713
+ if ! dag_has_task (state. dag_spec, task)
714
+ # This task needs to be deferred
715
+ break
716
+ end
717
+
683
718
our_proc = schedule[task]
684
719
@assert our_proc in all_procs
685
720
our_space = only (memory_spaces (our_proc))
@@ -829,6 +864,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
829
864
write_num += 1
830
865
end
831
866
867
+ # Remove processed tasks
868
+ deleteat! (queue. seen_tasks, 1 : task_num)
869
+
832
870
# Copy args from remote to local
833
871
if queue. aliasing
834
872
# We need to replay the writes from all tasks in-order (skipping any
@@ -961,18 +999,25 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
961
999
wait_all (; check_errors= true ) do
962
1000
scheduler = something (scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler ())
963
1001
launch_wait = something (launch_wait, DATADEPS_LAUNCH_WAIT[], false ):: Bool
1002
+ local result
964
1003
if launch_wait
965
- result = spawn_bulk () do
1004
+ spawn_bulk () do
966
1005
queue = DataDepsTaskQueue (get_options (:task_queue );
967
1006
scheduler, aliasing)
968
- with_options (f; task_queue= queue)
969
- distribute_tasks! (queue)
1007
+ result = with_options (f; task_queue= queue)
1008
+ while ! isempty (queue. seen_tasks)
1009
+ @dagdebug nothing :spawn_datadeps " Entering Datadeps region"
1010
+ distribute_tasks! (queue)
1011
+ end
970
1012
end
971
1013
else
972
1014
queue = DataDepsTaskQueue (get_options (:task_queue );
973
1015
scheduler, aliasing)
974
1016
result = with_options (f; task_queue= queue)
975
- distribute_tasks! (queue)
1017
+ while ! isempty (queue. seen_tasks)
1018
+ @dagdebug nothing :spawn_datadeps " Entering Datadeps region"
1019
+ distribute_tasks! (queue)
1020
+ end
976
1021
end
977
1022
return result
978
1023
end
0 commit comments