16 changes: 8 additions & 8 deletions backend/compiler.py
@@ -141,14 +141,14 @@ def make_ttir(mod, metadata, opt):
return mod


def get_attrs_descriptor(self, params, args):
if self.driver.target == 'ascend':
from triton.backends.dicp_triton.npu import AscendAttrsDescriptor
return AscendAttrsDescriptor(params, args)
else:
raise RuntimeError(f"backend {self.driver.target} not supported for get_attrs_descriptor.")
# def get_attrs_descriptor(self, params, args):
# if self.driver.target == 'ascend':
# from triton.backends.dicp_triton.npu import AscendAttrsDescriptor
# return AscendAttrsDescriptor(params, args)
# else:
# raise RuntimeError(f"backend {self.driver.target} not supported for get_attrs_descriptor.")

def add_stages(self, stages, options):
def add_stages(self, stages, options, language):
if self.driver.target not in ['ascend', 'mlu']:
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
if self.driver.target == 'dicp':
@@ -250,7 +250,7 @@ def parse_options(self, options: dict) -> Any:
args.update({k: options[k] for k in DICPOptions.__dataclass_fields__.keys() if k in options})
return DICPOptions(**args)

def get_codegen_implementation(self):
def get_codegen_implementation(self, options):
codegen_fns = dict()
if self.target.backend == 'ascend':
from triton.backends.dicp_triton.npu import min_dot_size
50 changes: 49 additions & 1 deletion backend/driver.py
@@ -11,12 +11,24 @@
from triton.backends.compiler import GPUTarget
from triton.backends.dicp_triton.utils import get_current_backend

from triton.runtime.build import quiet
# from triton.runtime.build import quiet
import importlib
import shutil

import setuptools
import torch
import sys
import contextlib
import io

@contextlib.contextmanager
def quiet():
old_stdout, old_stderr = sys.stdout, sys.stderr
sys.stdout, sys.stderr = io.StringIO(), io.StringIO()
try:
yield
finally:
sys.stdout, sys.stderr = old_stdout, old_stderr


def build_for_backend(name, src, srcdir):
@@ -104,6 +116,31 @@ def __init__(self):
self.load_binary = mod.load_binary
self.get_device_properties = mod.get_device_properties

def _ty_to_cpp(ty):
if ty[0] == '*':
return "void*"
if ty == "constexpr":
return "PyObject*"
return {
"i1": "int32_t",
"i8": "int8_t",
"i16": "int16_t",
"i32": "int32_t",
"i64": "int64_t",
"u1": "uint32_t",
"u8": "uint8_t",
"u16": "uint16_t",
"u32": "uint32_t",
"u64": "uint64_t",
# Proper support for bfloat16 and float16 is not yet handled.
# https://github.com/microsoft/triton-shared/issues/348
# "fp16": "TODO",
# "bf16": "TODO",
"fp32": "float",
"f32": "float",
"fp64": "double",
}[ty]

class DICPDriver(DriverBase):
def __init__(self, target=None):
if(self.__initialized): return
@@ -294,3 +331,14 @@ def get_empty_cache_for_benchmark(self):
return torch.empty(int(cache_size // 4), dtype=torch.int, device='mlu')
else:
assert False, f"Not implemented for {self.target}"

def get_active_torch_device(self):
# todo: fix it.
import torch
return torch.device("cpu")



def map_python_to_cpp_type(self, ty: str) -> str:
return _ty_to_cpp(ty)
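
For reference, a minimal sketch (not part of the diff) of how the two helpers added to backend/driver.py behave; it assumes the module-level quiet() and _ty_to_cpp defined above are in scope:

# Minimal sketch, not part of the diff: assumes quiet() and _ty_to_cpp from
# backend/driver.py above are importable/in scope.
with quiet():
    print("suppressed")            # swallowed by the StringIO buffers
print("visible again")             # stdout/stderr are restored after the with-block

assert _ty_to_cpp("*fp32") == "void*"          # pointer types collapse to void*
assert _ty_to_cpp("constexpr") == "PyObject*"  # constexpr args stay as Python objects
assert _ty_to_cpp("i32") == "int32_t"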

41 changes: 35 additions & 6 deletions backend/npu.py
@@ -7,7 +7,7 @@
import hashlib
from triton.runtime.cache import get_cache_manager, get_dump_manager
from triton.backends.compiler import GPUTarget
from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes
from triton.runtime.cache import get_dump_manager
from dataclasses import dataclass
@@ -322,11 +322,16 @@ def ttir_to_ttsharedir(mod, metadata, opt, *, named_ops=False):
dst_ttshared_path = os.path.join(tmpdir, "kernel.ttshared.mlir")
Path(src_path).write_text(ttir_code)
triton_shared_opt_path = _get_triton_shared_opt_path()
triton_shared_opt_path = "/mnt/data01/zmz/workspace/04ttshared/modify/triton/build/cmake.linux-aarch64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt"
shutil.copy(src_path, "./tt.kernel.ttir.mlir")

cmd_shared_list = [triton_shared_opt_path, src_path,
f'--triton-to-linalg',
# f'--triton-to-linalg',
f'--triton-to-linalg-experimental',
"-o", dst_ttshared_path]
print(f"zmz debug cmd ttir_to_ttsharedir: {cmd_shared_list}")
ret = subprocess.run(cmd_shared_list, capture_output=True, check=True)
shutil.copy(dst_ttshared_path, "./tt.kernel.ttshared.mlir")
return Path(dst_ttshared_path).read_text()


@@ -337,9 +342,11 @@ def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
dst_path = os.path.join(tmpdir, "kernel.linkedir.mlir")
Path(src_path).write_text(ttsharedir_code)
dicp_opt_path = _get_dicp_opt_path()
dicp_opt_path = "/mnt/data01/zmz/workspace/04ttshared/Triton/third_party/triton/build/cmake.linux-aarch64-cpython-3.10/third_party/dicp_triton/tools/dicp_triton_opt/dicp_opt"
dicp_cmd_list = [dicp_opt_path, src_path,
f'--linalg-to-linked=global-kernel=false named-ops=true',
"-o", dst_path]
print(f"zmz debug cmd ttsharedir_to_linkedir: {dicp_cmd_list}")
ret = subprocess.run(dicp_cmd_list, capture_output=True, check=True)
# TODO(zmz): Modify the contents of test_path. Handle it in Python for now; bishengir-compile will support this later, at which point this logic can be removed.
with open(dst_path, 'r') as f:
@@ -350,6 +357,22 @@
content = content.replace("*xbf", "?xbf")
with open(dst_path, 'w') as f:
f.write(content)

# Match patterns of the form "memref<...> to tensor<...>"
pattern = r'(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)'

with open(dst_path, 'r') as f:
lines = f.readlines()

modified = []
for line in lines:
# Regex substitution: keep the memref and tensor types, inserting a comment marker between them
new_line = re.sub(pattern, r'\1 // to \2', line)
modified.append(new_line)

with open(dst_path, 'w') as f:
f.writelines(modified)
shutil.copy(dst_path, "./tt.kernel.linkedir.mlir")
return Path(dst_path).read_text()
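
For context, a standalone sketch (not part of the diff) of the substitution performed above; the sample line is hypothetical IR, used only to show the effect of the pattern from the diff:

# Sketch only: the input line below is hypothetical IR, not taken from a real kernel.
import re

pattern = r'(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)'
line = "%1 = some.op %0 : memref<16x16xf32> to tensor<16x16xf32>"
print(re.sub(pattern, r'\1 // to \2', line))
# -> %1 = some.op %0 : memref<16x16xf32> // to tensor<16x16xf32>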


@@ -597,11 +620,11 @@ def hash(self):
key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
return hashlib.md5(key.encode("utf-8")).hexdigest()

@register_descriptor
class AscendAttrsDescriptor(AttrsDescriptor):
# @register_descriptor
# class AscendAttrsDescriptor(AttrsDescriptor):

def _add_backend_properties(self, params=None, values=None):
pass
# def _add_backend_properties(self, params=None, values=None):
# pass

class NPUUtils(object):
def __new__(cls):
@@ -631,7 +654,9 @@ def __init__(self):
self.npu_utils_mod = mod

def load_binary(self, name, kernel, shared, device):
import sys
fnname, mix_mode = name.split()
# return (self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode), sys.maxsize)
return self.npu_utils_mod.load_kernel_binary(fnname, kernel, shared, device, mix_mode)

@functools.lru_cache()
@@ -796,6 +821,8 @@ def generate_npu_wrapper_src(constants, signature, workspace_size, mix_mode, loc
def _ty_to_cpp(ty):
if ty[0] == '*':
return "void*"
if ty == "constexpr":
return "PyObject*"
return {
"i1": "int32_t",
"i8": "int8_t",
@@ -814,6 +841,8 @@ def _ty_to_cpp(ty):
def _extracted_ty(ty):
if ty[0] == '*':
return "PyObject*"
if ty == "constexpr":
return "PyObject*"
return {
'i1': 'int32_t',
'i32': 'int32_t',
11 changes: 8 additions & 3 deletions compile_shared.sh
@@ -8,6 +8,7 @@ export LC_ALL="zh_CN.UTF-8"
# export LLVM_BUILD_DIR=/path/to/your/llvm-project/build
# export LLVM_TGZ_PATH=/path/to/your/llvm-86b69c31-ubuntu-arm64.tar.gz # optional: path to the LLVM tgz package
export TRITON_PLUGIN_DIRS=$(pwd)
echo $TRITON_PLUGIN_DIRS
apply_patch=false

# Parse command-line arguments
@@ -59,12 +60,16 @@ fi

pip uninstall triton -y

cd $TRITON_PLUGIN_DIRS/third_party/triton/python/
# cd $TRITON_PLUGIN_DIRS/third_party/triton/python/
cd $TRITON_PLUGIN_DIRS/third_party/triton/
rm -rf build/

if [ -z "$LLVM_BUILD_DIR" ]; then
# TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true \
# python3 -m pip install --no-build-isolation -vvv .[tests] -i https://mirrors.huaweicloud.com/repository/pypi/simple
echo "LLVM_BUILD_DIR is not set, will use LLVM_TGZ_PATH: $LLVM_TGZ_PATH"
TRITON_BUILD_WITH_CLANG_LLD=true TRITON_BUILD_WITH_CCACHE=true \
python3 -m pip install --no-build-isolation -vvv .[tests] -i https://mirrors.huaweicloud.com/repository/pypi/simple
python3 -m pip install --no-build-isolation -vvv '.[tests]' --trusted-host mirrors.huaweicloud.com -i https://mirrors.huaweicloud.com/repository/pypi/simple
else
echo "LLVM_BUILD_DIR is set to $LLVM_BUILD_DIR"
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
@@ -78,4 +83,4 @@ if [[ $apply_patch == true ]]; then
echo "编译前先清空了third_party源码改动, 然后执行了apply patch/*.patch, 请检查正确性!"
else
echo "编译前没有执行apply patch/*.patch, 请检查正确性!"
fi
fi
27 changes: 19 additions & 8 deletions compiler/lib/Conversion/LinalgToLinked/LinalgToLinkedPass.cpp
@@ -31,7 +31,7 @@ using namespace linked;

#define GEN_PASS_CLASSES
#include "dicp/Conversion/LinalgToLinked/Passes.h.inc"

// #include "bishengir/Dialect/Annotation/IR/AnnotationOps.h"
namespace {

const std::string globalKernelAttr = "global_kernel";
@@ -194,7 +194,8 @@ class LinalgToLinkedPass : public LinalgToLinkedBase<LinalgToLinkedPass> {
return WalkResult::interrupt();
});
this->populateLinalgToLinkedConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
// if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
moduleOp.emitError("Pattern application failed");
signalPassFailure();
}
@@ -235,9 +236,12 @@ class LinalgToLinkedPass : public LinalgToLinkedBase<LinalgToLinkedPass> {
MemRefType syncBlockLockArgType =
MemRefType::get(SmallVector<int64_t>(1, ShapedType::kDynamic),
IntegerType::get(context, 8));
func.insertArgument(syncBlockLockArgIdx, // argIndex
syncBlockLockArgType, // argType
nullptr, func->getLoc()); // dicAttr
if (failed(func.insertArgument(syncBlockLockArgIdx, // argIndex
syncBlockLockArgType, // argType
nullptr, func->getLoc()))) { // 添加错误检查
signalPassFailure();
return;
}
func->setAttr("SyncBlockLockArgIdx",
IntegerAttr::get(IntegerType::get(&getContext(), 64), 0)); // 64: 64位整型

@@ -248,9 +252,16 @@
NamedAttribute workspaceArgAttr(StringAttr::get(context, "workspace"),
UnitAttr::get(context));

func.insertArgument(/*argIndex*/ workspaceArgIdx,
/*argType*/ workspaceArgType,
/*dicAttr*/ nullptr, func->getLoc());
// func.insertArgument(/*argIndex*/ workspaceArgIdx,
// /*argType*/ workspaceArgType,
// /*dicAttr*/ nullptr, func->getLoc());
if (failed(func.insertArgument(/*argIndex*/ workspaceArgIdx,
/*argType*/ workspaceArgType,
/*dicAttr*/ nullptr, func->getLoc()))) { // 添加错误检查
signalPassFailure();
return;
}

func->setAttr("WorkspaceArgIdx",
IntegerAttr::get(IntegerType::get(&getContext(), 64), 1));
}
7 changes: 6 additions & 1 deletion language/deeplink/core.py
@@ -3,7 +3,7 @@
from triton.language.core import (
_tensor_member_fn,
_shape_check_impl,
_constexpr_to_value,
# _constexpr_to_value,
_unwrap_if_constexpr,
builtin,
constexpr,
@@ -14,6 +14,11 @@
from . import semantic as dl_semantic


def _constexpr_to_value(v):
if isinstance(v, constexpr):
return v.value
return v
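
A tiny sketch (not part of the diff) of the fallback _constexpr_to_value defined above; it assumes the constexpr class imported at the top of this file:

# Sketch only: assumes triton.language.core.constexpr, imported above.
assert _constexpr_to_value(constexpr(16)) == 16   # constexpr wrapper is unwrapped
assert _constexpr_to_value(16) == 16              # plain values pass through unchanged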

class layout:
ASCEND = ['ND', 'NZ']
