@@ -17,7 +17,9 @@ limitations under the License.
1717
1818#include < torch/torch.h>
1919
20+ #include " common/global_flags.h"
2021#include " platform/device.h"
22+ #include " util/env_var.h"
2123#if defined(USE_NPU)
2224#ifdef TORCH_HIGHER_THAN_PTA6
2325// #include <torch_npu/csrc/core/npu/NPUFormat.h>
@@ -30,6 +32,21 @@ limitations under the License.
3032#endif
3133
3234namespace xllm {
35+
36+ namespace {
37+
38+ bool should_enable_async_tiling_copy_stream () {
39+ // ATB copy-stream teardown is not reversible for the same context on the
40+ // current CANN/PTA stack, so contexts that may enter graph capture must not
41+ // pre-create the helper stream.
42+ if (FLAGS_enable_graph) {
43+ return false ;
44+ }
45+ return util::get_bool_env (" ATB_USE_TILING_COPY_STREAM" , false );
46+ }
47+
48+ } // namespace
49+
3350ModelContext::ModelContext (const ParallelArgs& input_parallel_args,
3451 const ModelArgs& model_args,
3552 const QuantArgs& quant_args,
@@ -44,7 +61,9 @@ ModelContext::ModelContext(const ParallelArgs& input_parallel_args,
4461 atb::CreateContext (&context_);
4562 void * stream = c10_npu::getCurrentNPUStream (device_id).stream ();
4663 context_->SetExecuteStream (stream);
47- context_->SetAsyncTilingCopyStatus (true );
64+ if (should_enable_async_tiling_copy_stream ()) {
65+ context_->SetAsyncTilingCopyStatus (true );
66+ }
4867 atb_workspace_ = std::make_shared<AtbWorkspace>(tensor_options.device ());
4968#endif
5069 derive_optimization_config ();
0 commit comments