Skip to content

Commit b80113a

Browse files
committed
fix DeepSeek-3.2 failures when ACL Graph is enabled.
1 parent 87d9e35 commit b80113a

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

third_party/xllm_atb_layers

Submodule xllm_atb_layers updated from 918c03d to d6aa214

xllm/core/framework/model_context.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ limitations under the License.
1717

1818
#include <torch/torch.h>
1919

20+
#include "common/global_flags.h"
2021
#include "platform/device.h"
22+
#include "util/env_var.h"
2123
#if defined(USE_NPU)
2224
#ifdef TORCH_HIGHER_THAN_PTA6
2325
// #include <torch_npu/csrc/core/npu/NPUFormat.h>
@@ -30,6 +32,21 @@ limitations under the License.
3032
#endif
3133

3234
namespace xllm {
35+
36+
namespace {
37+
38+
bool should_enable_async_tiling_copy_stream() {
39+
// ATB copy-stream teardown is not reversible for the same context on the
40+
// current CANN/PTA stack, so contexts that may enter graph capture must not
41+
// pre-create the helper stream.
42+
if (FLAGS_enable_graph) {
43+
return false;
44+
}
45+
return util::get_bool_env("ATB_USE_TILING_COPY_STREAM", false);
46+
}
47+
48+
} // namespace
49+
3350
ModelContext::ModelContext(const ParallelArgs& input_parallel_args,
3451
const ModelArgs& model_args,
3552
const QuantArgs& quant_args,
@@ -44,7 +61,9 @@ ModelContext::ModelContext(const ParallelArgs& input_parallel_args,
4461
atb::CreateContext(&context_);
4562
void* stream = c10_npu::getCurrentNPUStream(device_id).stream();
4663
context_->SetExecuteStream(stream);
47-
context_->SetAsyncTilingCopyStatus(true);
64+
if (should_enable_async_tiling_copy_stream()) {
65+
context_->SetAsyncTilingCopyStatus(true);
66+
}
4867
atb_workspace_ = std::make_shared<AtbWorkspace>(tensor_options.device());
4968
#endif
5069
derive_optimization_config();

xllm/core/framework/parallel_state/mapping_npu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ MappingNPU::MappingNPU(const std::string rank_table_file,
4242
num_nodes_ = get_num_nodes();
4343
world_size_ = world_size;
4444
local_world_size_ = world_size / num_nodes_;
45-
attn_o_proj_tp_.backend("lccl");
46-
attn_inner_sp_.backend("lccl");
45+
attn_o_proj_tp_.backend("hccl");
46+
attn_inner_sp_.backend("hccl");
4747
parse_parallel_info();
4848
validate();
4949
get_tp_group(word_embed_tp_);

0 commit comments

Comments
 (0)