Add cancel function for concrete graph.
liuliu committed Oct 13, 2024
1 parent ef7bd53 commit e7c7639
Showing 4 changed files with 66 additions and 0 deletions.
7 changes: 7 additions & 0 deletions lib/nnc/_ccv_nnc_graph.h
@@ -110,6 +110,12 @@ typedef struct {
ccv_nnc_graph_t* graph;
} ccv_nnc_graph_tensor_wraps_ref_t;

enum {
CCV_NNC_GRAPH_STATE_IDLE = 0,
CCV_NNC_GRAPH_STATE_RUNNING = 1,
CCV_NNC_GRAPH_STATE_CANCEL = 2,
};

struct ccv_nnc_graph_s {
int p_idx; // Reference to the index in its parent graph's sub-graph array, Starts at 1.
int exec_idx; // Reference to the index in its parent graph's exec (the graph exec), Starts at 1.
@@ -119,6 +125,7 @@ struct ccv_nnc_graph_s {
int stream_size;
int signal_size;
int buffer_size;
int run_state; // One of CCV_NNC_GRAPH_STATE_*, read and written with atomics so a run can be cancelled from another thread.
ccv_array_t* exec_info; // deferred exec info
// I think that I can be more explicit about which are sources and which are destinations.
// These are int types.
11 changes: 11 additions & 0 deletions lib/nnc/ccv_cnnp_model.c
@@ -3013,3 +3013,14 @@ void ccv_cnnp_model_free(ccv_cnnp_model_t* const model)
ccfree(model->name);
ccfree(model);
}

void ccv_cnnp_model_cancel(ccv_cnnp_model_t* const model)
{
ccv_cnnp_compiled_data_t* const compiled_data = model->compiled_data;
if (!compiled_data)
return;
if (compiled_data->graph)
ccv_nnc_graph_cancel(compiled_data->graph);
if (compiled_data->apply_gradients.graph)
ccv_nnc_graph_cancel(compiled_data->apply_gradients.graph);
}
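
The model-level hook above only flips the cancel flag on whichever graphs have been compiled; it is intended to be called from a thread other than the one driving the forward / backward passes. Below is a minimal sketch of that pattern, assuming a pthread-based watchdog; `run_training_step` is a hypothetical stand-in for the application's own evaluate / backward / apply-gradients sequence and is not part of this commit.

```c
/* Hypothetical sketch: cancel a long-running training step from a watchdog
 * thread. Only ccv_cnnp_model_cancel comes from this commit; run_training_step
 * stands in for the application's own forward / backward / apply-gradients calls. */
#include <pthread.h>
#include <unistd.h>
#include "ccv_nnc.h" /* adjust include path to your install layout */

extern void run_training_step(ccv_cnnp_model_t* const model); /* hypothetical */

typedef struct {
	ccv_cnnp_model_t* model;
} watchdog_ctx_t;

static void* watchdog(void* arg)
{
	watchdog_ctx_t* const ctx = (watchdog_ctx_t*)arg;
	sleep(5); /* Give up after 5 seconds. */
	/* Sets the cancel flag on the evaluate / backward graph and the
	 * apply-gradients graph, if they have been compiled. */
	ccv_cnnp_model_cancel(ctx->model);
	return 0;
}

void train_with_timeout(ccv_cnnp_model_t* const model)
{
	watchdog_ctx_t ctx = { .model = model };
	pthread_t thread;
	pthread_create(&thread, 0, watchdog, &ctx);
	run_training_step(model); /* Returns early once the running graph observes the flag. */
	pthread_join(thread, 0);
}
```

Because ccv_nnc_graph_run and ccv_nnc_graph_run_with_schedule reset the state to running on entry (see the changes to ccv_nnc_graph_run.c below), a cancel that lands after the step has already finished simply has no effect on the next run.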
17 changes: 17 additions & 0 deletions lib/nnc/ccv_nnc.h
@@ -1346,6 +1346,14 @@ int ccv_nnc_graph_run(ccv_nnc_graph_t* const graph, const int flags, const ccv_n
* @return CCV_NNC_EXEC_SUCCESS if succeed.
*/
int ccv_nnc_graph_run_with_schedule(ccv_nnc_graph_t* const graph, const int flags, const ccv_nnc_graph_static_schedule_t* const schedule, ccv_nnc_tensor_tape_t* const tensor_tape, ccv_nnc_stream_context_t* const stream_context);
/**
* Cancel execution of a graph. You need to handle synchronization yourself when calling this method to make
* sure the graph is currently executing when cancelling. This method sets a flag internally; the graph
* execution checks that flag when pushing compute to the computation device and aborts if it is cancelled.
* When you call ccv_nnc_graph_run again, this cancellation is no longer in effect and you need to call cancel again.
* @param graph The concrete graph.
*/
void ccv_nnc_graph_cancel(ccv_nnc_graph_t* const graph);

/** @} */

@@ -3804,6 +3812,15 @@ void ccv_cnnp_model_backward(ccv_cnnp_model_t* const model, ccv_nnc_tensor_t* co
* @param stream_context The stream where the gradient computation can be executed upon.
*/
void ccv_cnnp_model_apply_gradients(ccv_cnnp_model_t* const model, ccv_nnc_stream_context_t* const stream_context);
/**
* Cancel execution of a model, whether it is in the forward / backward or the gradient application pass. You need
* to make sure the model is currently executing when cancelling. This method sets a flag internally; the
* execution checks that flag when pushing compute to the computation device and aborts if it is cancelled.
* When you call another model execution method again, this cancellation is no longer in effect and you need to
* call cancel again.
* @param model The composed model.
*/
void ccv_cnnp_model_cancel(ccv_cnnp_model_t* const model);
enum {
/**
* This is the default flag, if the model is not initialized, will attempt to read from the disk.
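
To make the "no longer in effect on the next run" clause concrete, here is a hedged sketch of the intended call pattern. Only ccv_nnc_graph_cancel and ccv_nnc_graph_run are library calls (both declared in this header); the surrounding control flow is illustrative.

```c
/* Illustrative only: cancel an in-flight run, then start a fresh one.
 * Assumes another thread is currently inside ccv_nnc_graph_run() on this graph. */
#include "ccv_nnc.h" /* adjust include path to your install layout */

void cancel_then_rerun(ccv_nnc_graph_t* const graph, ccv_nnc_stream_context_t* const stream_context)
{
	/* Sets the internal flag; the running graph observes it at its next
	 * scheduling point and aborts instead of pushing more compute. */
	ccv_nnc_graph_cancel(graph);
	/* ... wait for the aborted run to return ... */
	/* Re-running resets the internal state, so the earlier cancel does not
	 * carry over; call ccv_nnc_graph_cancel again if this run must stop too. */
	ccv_nnc_graph_run(graph, 0, 0, 0, 0, 0, 0, stream_context);
}
```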
31 changes: 31 additions & 0 deletions lib/nnc/ccv_nnc_graph_run.c
@@ -597,6 +597,8 @@ static co_decl_task(_ccv_nnc_graph_exec_run_loop, (ccv_nnc_graph_t* const graph,
CO_V(pending_node_size)[1] = 0;
for (CO_V(i) = CO_P(start_index); CO_V(i) < CO_P(exec_info_size); CO_V(i)++)
{
if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
break;
CO_V(idx) = CO_P(psort) ? CO_P(psort)[CO_V(i)] : CO_V(i);
CO_V(node) = CO_P(exec_info) + CO_V(idx);
CO_V(schd) = CO_P(schd_info) + CO_V(idx);
@@ -628,6 +630,8 @@ }
}
if (CO_V(sub_task_size))
co_apply(_ccv_nnc_graph_wait_any_sub_tasks, (CO_P(graph), CO_V(sub_tasks), CO_V(sub_task_size), CO_P(schd_info), CO_V(pending_nodes)[0], CO_V(pending_node_size)[0]));
if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
co_return();
CO_V(p) = 0;
CO_V(q) = 1;
while (CO_V(pending_node_size)[CO_V(p)] > 0)
@@ -746,6 +750,8 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
if (CO_V(count) > 0)
_ccv_nnc_graph_transit_move_to(CO_P(graph));
co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, 0, CO_V(graph_breakpoint_size), CO_P(tensor_tape), CO_P(flags)));
if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
break;
// Reached breakpoints, now check the breakpoint, if not met, break out.
// Wait until everything on the stream is executed.
for (CO_V(i) = CO_P(graph)->breakpoint_offset; CO_V(i) < CO_V(graph_breakpoint_size); CO_V(i)++)
@@ -759,6 +765,7 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
break;
}
co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, CO_V(graph_breakpoint_size), CO_P(graph)->exec_info->rnum, CO_P(tensor_tape), CO_P(flags)));
// If it is cancelled here, we don't need to break out yet; we can break out at an earlier place. The most important thing is to avoid the stream wait if there is a cancel.
_ccv_nnc_graph_from_move_transit(CO_P(graph));
_ccv_nnc_graph_rewrap(CO_P(graph));
}
@@ -774,6 +781,7 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
});
_ccv_nnc_graph_unwrap(CO_P(graph), CO_V(count), CO_V(reverse_count));
co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, CO_P(graph)->breakpoint_offset, CO_P(graph)->exec_info->rnum, CO_P(tensor_tape), CO_P(flags)));
// If it is cancelled here, we don't need to break out yet; we can break out later.
_ccv_nnc_graph_from_move_transit(CO_P(graph));
_ccv_nnc_graph_rewrap(CO_P(graph));
for (CO_V(count) = 1; CO_V(reverse_count) > 0; ++CO_V(count))
@@ -782,17 +790,33 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
_ccv_nnc_graph_unwrap(CO_P(graph), CO_V(count), CO_V(reverse_count));
_ccv_nnc_graph_transit_move_to(CO_P(graph));
co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, 0, CO_P(graph)->exec_info->rnum, CO_P(tensor_tape), CO_P(flags)));
if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
break;
_ccv_nnc_graph_from_move_transit(CO_P(graph));
_ccv_nnc_graph_rewrap(CO_P(graph));
}
}
if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
{
// The most important thing is to reset main and then return; we don't need to wait for any streaming event.
if (CO_P(exec_idx) == -1 && CO_P(stream_context)->main == co_self())
CO_P(stream_context)->main = 0;
co_return();
}
assert(CO_V(stream_0) == 0);
int i;
for (i = 0; i < CO_P(schedule)->wait_size; i++)
ccv_nnc_stream_context_wait_signal(CO_P(graph)->streams[0], CO_P(graph)->signals[CO_P(schedule)->waits[i]]);
} else {
CO_P(graph)->while_count = 0;
co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), CO_P(schedule)->psort, 0, CO_P(schedule)->psort ? CO_P(schedule)->psort_size : CO_P(schedule)->exec_info_size, CO_P(tensor_tape), CO_P(flags)));
if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
{
// The most important thing is to reset main and then return; we don't need to wait for any streaming event.
if (CO_P(exec_idx) == -1 && CO_P(stream_context)->main == co_self())
CO_P(stream_context)->main = 0;
co_return();
}
PRINT(CCV_CLI_INFO, "Graph Stream %d End", CO_V(stream_0));
int i, flag = 0;
for (i = 0; i < CO_P(schedule)->wait_size; i++)
@@ -1067,6 +1091,7 @@ static int _ccv_nnc_graph_run(ccv_nnc_graph_t* const graph, const int exec_idx,

int ccv_nnc_graph_run(ccv_nnc_graph_t* const graph, const int flags, const ccv_nnc_graph_exec_t* const sources, const int source_size, const ccv_nnc_graph_exec_t* const destinations, const int destination_size, ccv_nnc_tensor_tape_t* const tensor_tape, ccv_nnc_stream_context_t* const stream_context)
{
__atomic_store_n(&graph->run_state, CCV_NNC_GRAPH_STATE_RUNNING, __ATOMIC_RELEASE);
if (stream_context && graph->topsorted && graph->stream_size > 0 && graph->default_schedule && source_size == 0 && destination_size == 0)
{
co_scheduler_t* const scheduler = ccv_nnc_stream_context_get_scheduler(stream_context);
@@ -1083,6 +1108,7 @@ int ccv_nnc_graph_run_with_schedule(ccv_nnc_graph_t* const graph, const int flag
assert(graph->topsorted);
if (graph->exec_info->rnum == 0)
return CCV_NNC_EXEC_SUCCESS;
__atomic_store_n(&graph->run_state, CCV_NNC_GRAPH_STATE_RUNNING, __ATOMIC_RELEASE);
assert(graph->stream_size > 0);
const ccv_nnc_graph_static_schedule_t* const schedule = _schedule ? _schedule : graph->default_schedule;
assert(schedule);
@@ -1096,3 +1122,8 @@ int ccv_nnc_graph_run_with_schedule(ccv_nnc_graph_t* const graph, const int flag
ccv_nnc_stream_context_wait(stream_context);
return CCV_NNC_EXEC_SUCCESS;
}

void ccv_nnc_graph_cancel(ccv_nnc_graph_t* const graph)
{
__atomic_store_n(&graph->run_state, CCV_NNC_GRAPH_STATE_CANCEL, __ATOMIC_RELEASE);
}
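
The implementation keeps the whole mechanism lock-free: the run entry points publish CCV_NNC_GRAPH_STATE_RUNNING with a release store, ccv_nnc_graph_cancel publishes CCV_NNC_GRAPH_STATE_CANCEL the same way, and the coroutine run loop polls the flag with acquire loads at its scheduling points. The standalone sketch below isolates that pattern with hypothetical names; it is not library code.

```c
/* Standalone illustration of the run-state flag pattern used above; the names
 * here are illustrative and not part of the library. Build with: cc -pthread */
#include <pthread.h>
#include <stdio.h>
#include <unistd.h>

enum { STATE_IDLE = 0, STATE_RUNNING = 1, STATE_CANCEL = 2 };

typedef struct {
	int run_state;
} runner_t;

static void* run(void* arg)
{
	runner_t* const r = (runner_t*)arg;
	/* Publish that a new run started; any earlier cancel is discarded. */
	__atomic_store_n(&r->run_state, STATE_RUNNING, __ATOMIC_RELEASE);
	int i;
	for (i = 0; i < 1000; i++)
	{
		/* Check the flag at each scheduling point, mirroring the checks in
		 * _ccv_nnc_graph_exec_run_loop and the topsorted run coroutine. */
		if (__atomic_load_n(&r->run_state, __ATOMIC_ACQUIRE) == STATE_CANCEL)
		{
			printf("aborted at step %d\n", i);
			return 0;
		}
		usleep(1000); /* Stand-in for dispatching one node's compute. */
	}
	printf("ran to completion\n");
	return 0;
}

int main(void)
{
	runner_t r = { .run_state = STATE_IDLE };
	pthread_t thread;
	pthread_create(&thread, 0, run, &r);
	usleep(50000);
	/* The cancel path is a single release store, like ccv_nnc_graph_cancel above. */
	__atomic_store_n(&r.run_state, STATE_CANCEL, __ATOMIC_RELEASE);
	pthread_join(thread, 0);
	return 0;
}
```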
