diff --git a/lib/nnc/_ccv_nnc_graph.h b/lib/nnc/_ccv_nnc_graph.h
index 399ddf19e..7ca00721c 100644
--- a/lib/nnc/_ccv_nnc_graph.h
+++ b/lib/nnc/_ccv_nnc_graph.h
@@ -110,6 +110,12 @@ typedef struct {
 	ccv_nnc_graph_t* graph;
 } ccv_nnc_graph_tensor_wraps_ref_t;
 
+enum {
+	CCV_NNC_GRAPH_STATE_IDLE = 0,
+	CCV_NNC_GRAPH_STATE_RUNNING = 1,
+	CCV_NNC_GRAPH_STATE_CANCEL = 2,
+};
+
 struct ccv_nnc_graph_s {
 	int p_idx; // Reference to the index in its parent graph's sub-graph array, Starts at 1.
 	int exec_idx; // Reference to the index in its parent graph's exec (the graph exec), Starts at 1.
@@ -119,6 +125,7 @@ struct ccv_nnc_graph_s {
 	int stream_size;
 	int signal_size;
 	int buffer_size;
+	int run_state;
 	ccv_array_t* exec_info; // deferred exec info
 	// I think that I can be more explicit about which are sources and which are destinations.
 	// These are int types.
diff --git a/lib/nnc/ccv_cnnp_model.c b/lib/nnc/ccv_cnnp_model.c
index 3c1262b2a..4eb70b78b 100644
--- a/lib/nnc/ccv_cnnp_model.c
+++ b/lib/nnc/ccv_cnnp_model.c
@@ -3013,3 +3013,14 @@ void ccv_cnnp_model_free(ccv_cnnp_model_t* const model)
 	ccfree(model->name);
 	ccfree(model);
 }
+
+void ccv_cnnp_model_cancel(ccv_cnnp_model_t* const model)
+{
+	ccv_cnnp_compiled_data_t* const compiled_data = model->compiled_data;
+	if (!compiled_data)
+		return;
+	if (compiled_data->graph)
+		ccv_nnc_graph_cancel(compiled_data->graph);
+	if (compiled_data->apply_gradients.graph)
+		ccv_nnc_graph_cancel(compiled_data->apply_gradients.graph);
+}
diff --git a/lib/nnc/ccv_nnc.h b/lib/nnc/ccv_nnc.h
index 0d28e06a8..345339abd 100644
--- a/lib/nnc/ccv_nnc.h
+++ b/lib/nnc/ccv_nnc.h
@@ -1346,6 +1346,14 @@ int ccv_nnc_graph_run(ccv_nnc_graph_t* const graph, const int flags, const ccv_n
  * @return CCV_NNC_EXEC_SUCCESS if succeed.
  */
 int ccv_nnc_graph_run_with_schedule(ccv_nnc_graph_t* const graph, const int flags, const ccv_nnc_graph_static_schedule_t* const schedule, ccv_nnc_tensor_tape_t* const tensor_tape, ccv_nnc_stream_context_t* const stream_context);
+/**
+ * Cancel execution of a graph. You need to handle synchronization yourself when calling this method to make
+ * sure the graph is actually executing when you cancel it. This method sets a flag internally, and the graph
+ * execution checks that flag before pushing compute onto the computation device, aborting if it has been cancelled.
+ * The cancellation doesn't carry over to the next ccv_nnc_graph_run call; you need to call cancel again for that run.
+ * @param graph The concrete graph.
+ */
+void ccv_nnc_graph_cancel(ccv_nnc_graph_t* const graph);
 
 /** @} */
 
@@ -3804,6 +3812,15 @@ void ccv_cnnp_model_backward(ccv_cnnp_model_t* const model, ccv_nnc_tensor_t* co
  * @param stream_context The stream where the gradient computation can be executed upon.
  */
 void ccv_cnnp_model_apply_gradients(ccv_cnnp_model_t* const model, ccv_nnc_stream_context_t* const stream_context);
+/**
+ * Cancel execution of a model, whether it is in the forward / backward pass or the gradient application pass.
+ * You need to make sure the model is actually executing when you cancel it. This method sets a flag internally,
+ * and the execution checks that flag before pushing compute onto the computation device, aborting if it has
+ * been cancelled. The cancellation doesn't carry over to subsequent model execution calls; you need to call
+ * cancel again for those runs.
+ * @param model The composed model.
+ */
+void ccv_cnnp_model_cancel(ccv_cnnp_model_t* const model);
 enum {
 	/**
 	 * This is the default flag, if the model is not initialized, will attempt to read from the disk.
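For context (not part of the patch): a minimal usage sketch of the cancellation API declared above. It assumes a graph that has already been compiled, topsorted, and scheduled elsewhere, plus POSIX threads; the helper names, the one-second timeout, and the include path are illustrative assumptions, not part of this change.

/* Hypothetical usage sketch: cancel a long-running graph from a second thread.
 * `graph` and `stream` are assumed to come from the usual compile/schedule path. */
#include <nnc/ccv_nnc.h>
#include <pthread.h>
#include <unistd.h>

typedef struct {
	ccv_nnc_graph_t* graph;
	ccv_nnc_stream_context_t* stream;
} run_ctx_t;

static void* run_thread(void* arg)
{
	run_ctx_t* const ctx = (run_ctx_t*)arg;
	// With a stream context and no explicit sources / destinations, this takes the default-schedule path.
	ccv_nnc_graph_run(ctx->graph, 0, 0, 0, 0, 0, 0, ctx->stream);
	return 0;
}

static void run_with_timeout(ccv_nnc_graph_t* const graph, ccv_nnc_stream_context_t* const stream)
{
	run_ctx_t ctx = { graph, stream };
	pthread_t runner;
	pthread_create(&runner, 0, run_thread, &ctx);
	sleep(1); // Hypothetical policy: give the run one second before cancelling.
	ccv_nnc_graph_cancel(graph); // Sets the flag; the run loop aborts at its next checkpoint.
	pthread_join(runner, 0); // Wait for ccv_nnc_graph_run to return after the cancel takes effect.
}

ccv_cnnp_model_cancel composes the same way: call it from another thread while the model's forward / backward or gradient application call is in flight.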
diff --git a/lib/nnc/ccv_nnc_graph_run.c b/lib/nnc/ccv_nnc_graph_run.c
index 4e23e686c..2a23e7c24 100644
--- a/lib/nnc/ccv_nnc_graph_run.c
+++ b/lib/nnc/ccv_nnc_graph_run.c
@@ -597,6 +597,8 @@ static co_decl_task(_ccv_nnc_graph_exec_run_loop, (ccv_nnc_graph_t* const graph,
 	CO_V(pending_node_size)[1] = 0;
 	for (CO_V(i) = CO_P(start_index); CO_V(i) < CO_P(exec_info_size); CO_V(i)++)
 	{
+		if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
+			break;
 		CO_V(idx) = CO_P(psort) ? CO_P(psort)[CO_V(i)] : CO_V(i);
 		CO_V(node) = CO_P(exec_info) + CO_V(idx);
 		CO_V(schd) = CO_P(schd_info) + CO_V(idx);
@@ -628,6 +630,8 @@ static co_decl_task(_ccv_nnc_graph_exec_run_loop, (ccv_nnc_graph_t* const graph,
 	}
 	if (CO_V(sub_task_size))
 		co_apply(_ccv_nnc_graph_wait_any_sub_tasks, (CO_P(graph), CO_V(sub_tasks), CO_V(sub_task_size), CO_P(schd_info), CO_V(pending_nodes)[0], CO_V(pending_node_size)[0]));
+	if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
+		co_return();
 	CO_V(p) = 0;
 	CO_V(q) = 1;
 	while (CO_V(pending_node_size)[CO_V(p)] > 0)
@@ -746,6 +750,8 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
 			if (CO_V(count) > 0)
 				_ccv_nnc_graph_transit_move_to(CO_P(graph));
 			co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, 0, CO_V(graph_breakpoint_size), CO_P(tensor_tape), CO_P(flags)));
+			if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
+				break;
 			// Reached breakpoints, now check the breakpoint, if not met, break out.
 			// Wait until everything on the stream is executed.
 			for (CO_V(i) = CO_P(graph)->breakpoint_offset; CO_V(i) < CO_V(graph_breakpoint_size); CO_V(i)++)
@@ -759,6 +765,7 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
 				break;
 			}
 			co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, CO_V(graph_breakpoint_size), CO_P(graph)->exec_info->rnum, CO_P(tensor_tape), CO_P(flags)));
+			// If it is cancelled here, we don't need to break out yet; the check earlier in the loop will catch it on the next iteration. The most important thing is to avoid waiting on the stream when there is a cancellation.
 			_ccv_nnc_graph_from_move_transit(CO_P(graph));
 			_ccv_nnc_graph_rewrap(CO_P(graph));
 		}
@@ -774,6 +781,7 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
 		});
 		_ccv_nnc_graph_unwrap(CO_P(graph), CO_V(count), CO_V(reverse_count));
 		co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, CO_P(graph)->breakpoint_offset, CO_P(graph)->exec_info->rnum, CO_P(tensor_tape), CO_P(flags)));
+		// If it is cancelled here, we don't need to break out yet; a later check will catch it.
 		_ccv_nnc_graph_from_move_transit(CO_P(graph));
 		_ccv_nnc_graph_rewrap(CO_P(graph));
 		for (CO_V(count) = 1; CO_V(reverse_count) > 0; ++CO_V(count))
@@ -782,10 +790,19 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
 			_ccv_nnc_graph_unwrap(CO_P(graph), CO_V(count), CO_V(reverse_count));
 			_ccv_nnc_graph_transit_move_to(CO_P(graph));
 			co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), 0, 0, CO_P(graph)->exec_info->rnum, CO_P(tensor_tape), CO_P(flags)));
+			if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
+				break;
 			_ccv_nnc_graph_from_move_transit(CO_P(graph));
 			_ccv_nnc_graph_rewrap(CO_P(graph));
 		}
 	}
+	if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
+	{
+		// The most important thing is to reset main and then return; we don't need to wait for any stream event.
+		if (CO_P(exec_idx) == -1 && CO_P(stream_context)->main == co_self())
+			CO_P(stream_context)->main = 0;
+		co_return();
+	}
 	assert(CO_V(stream_0) == 0);
 	int i;
 	for (i = 0; i < CO_P(schedule)->wait_size; i++)
@@ -793,6 +810,13 @@ co_task(_ccv_nnc_graph_topsorted_run_coro, (ccv_nnc_graph_t* const graph, const
 } else {
 	CO_P(graph)->while_count = 0;
 	co_apply(_ccv_nnc_graph_exec_run_loop, (CO_P(graph), CO_V(exec_info), CO_V(schd_info), CO_P(schedule)->psort, 0, CO_P(schedule)->psort ? CO_P(schedule)->psort_size : CO_P(schedule)->exec_info_size, CO_P(tensor_tape), CO_P(flags)));
+	if (__atomic_load_n(&CO_P(graph)->run_state, __ATOMIC_ACQUIRE) == CCV_NNC_GRAPH_STATE_CANCEL)
+	{
+		// The most important thing is to reset main and then return; we don't need to wait for any stream event.
+		if (CO_P(exec_idx) == -1 && CO_P(stream_context)->main == co_self())
+			CO_P(stream_context)->main = 0;
+		co_return();
+	}
 	PRINT(CCV_CLI_INFO, "Graph Stream %d End", CO_V(stream_0));
 	int i, flag = 0;
 	for (i = 0; i < CO_P(schedule)->wait_size; i++)
@@ -1067,6 +1091,7 @@ static int _ccv_nnc_graph_run(ccv_nnc_graph_t* const graph, const int exec_idx,
 
 int ccv_nnc_graph_run(ccv_nnc_graph_t* const graph, const int flags, const ccv_nnc_graph_exec_t* const sources, const int source_size, const ccv_nnc_graph_exec_t* const destinations, const int destination_size, ccv_nnc_tensor_tape_t* const tensor_tape, ccv_nnc_stream_context_t* const stream_context)
 {
+	__atomic_store_n(&graph->run_state, CCV_NNC_GRAPH_STATE_RUNNING, __ATOMIC_RELEASE);
 	if (stream_context && graph->topsorted && graph->stream_size > 0 && graph->default_schedule && source_size == 0 && destination_size == 0)
 	{
 		co_scheduler_t* const scheduler = ccv_nnc_stream_context_get_scheduler(stream_context);
@@ -1083,6 +1108,7 @@ int ccv_nnc_graph_run_with_schedule(ccv_nnc_graph_t* const graph, const int flag
 	assert(graph->topsorted);
 	if (graph->exec_info->rnum == 0)
 		return CCV_NNC_EXEC_SUCCESS;
+	__atomic_store_n(&graph->run_state, CCV_NNC_GRAPH_STATE_RUNNING, __ATOMIC_RELEASE);
 	assert(graph->stream_size > 0);
 	const ccv_nnc_graph_static_schedule_t* const schedule = _schedule ? _schedule : graph->default_schedule;
 	assert(schedule);
@@ -1096,3 +1122,8 @@ int ccv_nnc_graph_run_with_schedule(ccv_nnc_graph_t* const graph, const int flag
 		ccv_nnc_stream_context_wait(stream_context);
 	return CCV_NNC_EXEC_SUCCESS;
 }
+
+void ccv_nnc_graph_cancel(ccv_nnc_graph_t* const graph)
+{
+	__atomic_store_n(&graph->run_state, CCV_NNC_GRAPH_STATE_CANCEL, __ATOMIC_RELEASE);
+}
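Implementation note (not part of the patch): the mechanism is a release-store into run_state paired with acquire-loads at the run loop's checkpoints, so a cancel is observed before the next batch of work is pushed to the device. Below is a standalone sketch of that pattern, using the same GCC/Clang __atomic builtins but with illustrative names that are not part of nnc.

#include <pthread.h>
#include <stdio.h>
#include <unistd.h>

enum { STATE_IDLE = 0, STATE_RUNNING = 1, STATE_CANCEL = 2 };

static int run_state = STATE_IDLE;

static void* worker(void* arg)
{
	(void)arg;
	__atomic_store_n(&run_state, STATE_RUNNING, __ATOMIC_RELEASE);
	int i;
	for (i = 0; i < 1000; i++)
	{
		// Check the flag before dispatching the next unit of work, as the run loop above does.
		if (__atomic_load_n(&run_state, __ATOMIC_ACQUIRE) == STATE_CANCEL)
		{
			printf("cancelled at step %d\n", i);
			return 0;
		}
		usleep(10 * 1000); // Stand-in for pushing one command onto a device stream.
	}
	return 0;
}

int main(void)
{
	pthread_t t;
	pthread_create(&t, 0, worker, 0);
	usleep(100 * 1000);
	__atomic_store_n(&run_state, STATE_CANCEL, __ATOMIC_RELEASE); // ccv_nnc_graph_cancel does the equivalent store.
	pthread_join(t, 0);
	return 0;
}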