Skip to content

Commit c326cb7

Browse files
feat: add fused_gdn_gating kernel in tilelang-ascend. (#1267)
1 parent 6b311e0 commit c326cb7

File tree

13 files changed

+1060
-73
lines changed

13 files changed

+1060
-73
lines changed

third_party/tilelang-ascend

Submodule tilelang-ascend updated from 46d8c6b to 289e1ae

xllm/compiler/tilelang/cli/compile_kernels.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -46,19 +46,26 @@ def main(argv: list[str] | None = None) -> None:
4646

4747
if args.target == "ascend":
4848
from ..bootstrap import prepare_ascend
49+
4950
prepare_ascend()
5051
from ..targets.ascend.build import build_kernels
52+
53+
manifests = build_kernels(
54+
output_root=output_root,
55+
kernel_names=args.kernels,
56+
force=args.force,
57+
device=args.device,
58+
)
5159
elif args.target == "cuda":
5260
from ..targets.cuda.build import build_kernels
61+
62+
manifests = build_kernels(
63+
output_root=output_root,
64+
kernel_names=args.kernels,
65+
force=args.force,
66+
)
5367
else:
5468
raise ValueError(f"Unsupported target: {args.target}")
55-
56-
manifests = build_kernels(
57-
output_root=output_root,
58-
kernel_names=args.kernels,
59-
force=args.force,
60-
device=args.device,
61-
)
6269
for manifest in manifests:
6370
print(f"[INFO] built {manifest.target}:{manifest.kernel_name}")
6471
print(f"[INFO] manifest: {Path(manifest.output_dir) / 'manifest.json'}")

xllm/compiler/tilelang/common/toolchain.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,9 @@ def require_env(name: str) -> str:
3434
def prepend_pythonpath(env: dict[str, str], path: str) -> None:
3535
current = env.get("PYTHONPATH", "")
3636
items = [item for item in current.split(os.pathsep) if item]
37-
if path not in items:
38-
items.insert(0, path)
39-
env["PYTHONPATH"] = os.pathsep.join(items)
37+
items = [item for item in items if item != path]
38+
items.insert(0, path)
39+
env["PYTHONPATH"] = os.pathsep.join(items)
4040

4141

4242
def prepare_tilelang_import(tilelang_root: str | Path | None = None) -> Path:
@@ -45,8 +45,11 @@ def prepare_tilelang_import(tilelang_root: str | Path | None = None) -> Path:
4545
)
4646
os.environ["TL_ROOT"] = str(tl_root)
4747
prepend_pythonpath(os.environ, str(tl_root))
48-
if str(tl_root) not in sys.path:
49-
sys.path.insert(0, str(tl_root))
48+
tl_root_str = str(tl_root)
49+
# Keep TL_ROOT at sys.path front to avoid resolving the sibling
50+
# package xllm/compiler/tilelang as top-level `tilelang`.
51+
sys.path = [p for p in sys.path if p != tl_root_str]
52+
sys.path.insert(0, tl_root_str)
5053
os.environ.setdefault("ACL_OP_INIT_MODE", "1")
5154
return tl_root
5255

xllm/compiler/tilelang/targets/ascend/build.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1-
from __future__ import annotations
2-
1+
import importlib
2+
import os
3+
import pkgutil
4+
import re
5+
from dataclasses import dataclass
36
from pathlib import Path
47

58
from ...common.manifest import KernelFamilyManifest

xllm/compiler/tilelang/targets/ascend/kernel_family_builder.py

Lines changed: 111 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3+
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
from dataclasses import dataclass
5+
import os
36
from pathlib import Path
47

58
from ...common.cache import compute_cache_key, is_cache_hit
69
from ...common.manifest import KernelAbi, KernelFamilyManifest, KernelVariantManifest
7-
from ...common.spec import DispatchField, KernelCompileSpec, TilelangKernel
10+
from ...common.spec import DispatchField, KernelCompileSpec, KernelSpec, TilelangKernel
811
from ...common.toolchain import repo_root, run_checked
912
from . import abi_entry, kernel_registry, toolchain
1013
from .kernel_registry import RegisteredKernelFamily
@@ -13,6 +16,21 @@
1316
from .toolchain import AscendBuildContext, TILELANG_BISHENG_COMMON_FLAGS
1417

1518

19+
@dataclass(frozen=True)
20+
class _VariantBuildPlan:
21+
compile_spec: KernelCompileSpec
22+
kernel_spec: KernelSpec
23+
generated_source: Path
24+
compiled_binary: Path
25+
cache_key: str
26+
27+
28+
@dataclass(frozen=True)
29+
class _VariantBuildResult:
30+
manifest: KernelVariantManifest
31+
kernel_abi: KernelAbi
32+
33+
1634
def _variant_entry_symbol(spec: KernelCompileSpec) -> str:
1735
kernel_entry_name = spec.entry_name or spec.kernel_name
1836
return f"{kernel_entry_name}__{spec.variant_key}_call"
@@ -116,7 +134,8 @@ def build_kernel_family(
116134
existing_manifest = _read_family_manifest(manifest_path)
117135
dependency_files = _build_dependency_files(family)
118136

119-
variant_manifests: list[KernelVariantManifest] = []
137+
variant_manifest_by_key: dict[str, KernelVariantManifest] = {}
138+
uncached_plans: list[_VariantBuildPlan] = []
120139
family_kernel_abi: KernelAbi | None = None
121140

122141
for compile_spec, kernel_spec in family.spec_pairs:
@@ -167,24 +186,37 @@ def build_kernel_family(
167186
"All variants in a TileLang kernel must share the same exported "
168187
f"C ABI. Mismatch found in variant {compile_spec.variant_key!r}."
169188
)
170-
variant_manifests.append(
171-
KernelVariantManifest(
172-
variant_key=compile_spec.variant_key,
173-
specialization=dict(compile_spec.specialization),
174-
dispatch_values=dict(compile_spec.dispatch_values),
175-
generated_source=cached_variant.generated_source,
176-
compiled_binary=cached_variant.compiled_binary,
177-
entry_symbol=cached_variant.entry_symbol,
178-
cache_key=cached_variant.cache_key,
179-
toolchain_options=dict(context.toolchain_options),
180-
fingerprint=dict(context.fingerprint),
181-
compile_definitions=kernel_spec.render_compile_definitions(
182-
entry_symbol=cached_variant.entry_symbol
183-
),
184-
)
189+
variant_manifest_by_key[compile_spec.variant_key] = KernelVariantManifest(
190+
variant_key=compile_spec.variant_key,
191+
specialization=dict(compile_spec.specialization),
192+
dispatch_values=dict(compile_spec.dispatch_values),
193+
generated_source=cached_variant.generated_source,
194+
compiled_binary=cached_variant.compiled_binary,
195+
entry_symbol=cached_variant.entry_symbol,
196+
cache_key=cached_variant.cache_key,
197+
toolchain_options=dict(context.toolchain_options),
198+
fingerprint=dict(context.fingerprint),
199+
compile_definitions=kernel_spec.render_compile_definitions(
200+
entry_symbol=cached_variant.entry_symbol
201+
),
185202
)
186203
continue
187204

205+
uncached_plans.append(
206+
_VariantBuildPlan(
207+
compile_spec=compile_spec,
208+
kernel_spec=kernel_spec,
209+
generated_source=generated_source,
210+
compiled_binary=compiled_binary,
211+
cache_key=cache_key,
212+
)
213+
)
214+
215+
compile_cwd = repo_root()
216+
217+
def _run_variant_job(plan: _VariantBuildPlan) -> _VariantBuildResult:
218+
compile_spec = plan.compile_spec
219+
kernel_spec = plan.kernel_spec
188220
source = family.kernel_cls.generate_source(**compile_spec.specialization)
189221
entry_symbol = _variant_entry_symbol(compile_spec)
190222
rendered_source = abi_entry.rename_variant_internal_symbols(
@@ -194,44 +226,79 @@ def build_kernel_family(
194226
compile_spec.variant_key,
195227
)
196228
kernel_abi = abi_entry.parse_kernel_abi(rendered_source, entry_symbol)
197-
if family_kernel_abi is None:
198-
family_kernel_abi = kernel_abi
199-
elif kernel_abi != family_kernel_abi:
200-
raise ValueError(
201-
"All variants in a TileLang kernel must share the same exported "
202-
f"C ABI. Mismatch found in variant {compile_spec.variant_key!r}."
203-
)
204-
generated_source.write_text(rendered_source, encoding="utf-8")
229+
plan.generated_source.write_text(rendered_source, encoding="utf-8")
205230

206231
compile_cmd = [
207232
context.bisheng_executable,
208233
f"--cce-aicore-arch={context.bisheng_arch}",
209234
*TILELANG_BISHENG_COMMON_FLAGS,
210235
f"-Dg_tilingKey=g_tilingKey__{compile_spec.variant_key}",
211236
*[f"-I{include_dir}" for include_dir in context.include_dirs],
212-
str(generated_source),
237+
str(plan.generated_source),
213238
"-c",
214239
"-o",
215-
str(compiled_binary),
240+
str(plan.compiled_binary),
216241
]
217-
run_checked(compile_cmd, cwd=repo_root())
218-
219-
variant_manifests.append(
220-
KernelVariantManifest(
221-
variant_key=compile_spec.variant_key,
222-
specialization=compile_spec.specialization,
223-
dispatch_values=compile_spec.dispatch_values,
224-
generated_source=str(generated_source),
225-
compiled_binary=str(compiled_binary),
226-
entry_symbol=entry_symbol,
227-
cache_key=cache_key,
228-
toolchain_options=dict(context.toolchain_options),
229-
fingerprint=dict(context.fingerprint),
230-
compile_definitions=kernel_spec.render_compile_definitions(
231-
entry_symbol=entry_symbol
232-
),
233-
)
242+
run_checked(compile_cmd, cwd=compile_cwd)
243+
manifest = KernelVariantManifest(
244+
variant_key=compile_spec.variant_key,
245+
specialization=dict(compile_spec.specialization),
246+
dispatch_values=dict(compile_spec.dispatch_values),
247+
generated_source=str(plan.generated_source),
248+
compiled_binary=str(plan.compiled_binary),
249+
entry_symbol=entry_symbol,
250+
cache_key=plan.cache_key,
251+
toolchain_options=dict(context.toolchain_options),
252+
fingerprint=dict(context.fingerprint),
253+
compile_definitions=kernel_spec.render_compile_definitions(
254+
entry_symbol=entry_symbol
255+
),
234256
)
257+
return _VariantBuildResult(manifest=manifest, kernel_abi=kernel_abi)
258+
259+
if uncached_plans:
260+
max_workers = max(1, os.cpu_count() or 1)
261+
if max_workers == 1:
262+
for plan in uncached_plans:
263+
result = _run_variant_job(plan)
264+
if family_kernel_abi is None:
265+
family_kernel_abi = result.kernel_abi
266+
elif result.kernel_abi != family_kernel_abi:
267+
raise ValueError(
268+
"All variants in a TileLang kernel must share the same exported "
269+
"C ABI. Mismatch found in variant "
270+
f"{result.manifest.variant_key!r}."
271+
)
272+
variant_manifest_by_key[result.manifest.variant_key] = result.manifest
273+
else:
274+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
275+
future_to_plan = {
276+
executor.submit(_run_variant_job, plan): plan
277+
for plan in uncached_plans
278+
}
279+
for future in as_completed(future_to_plan):
280+
plan = future_to_plan[future]
281+
try:
282+
result = future.result()
283+
except Exception as exc:
284+
raise RuntimeError(
285+
"Ascend variant build failed for variant "
286+
f"{plan.compile_spec.variant_key!r}"
287+
) from exc
288+
if family_kernel_abi is None:
289+
family_kernel_abi = result.kernel_abi
290+
elif result.kernel_abi != family_kernel_abi:
291+
raise ValueError(
292+
"All variants in a TileLang kernel must share the same exported "
293+
"C ABI. Mismatch found in variant "
294+
f"{result.manifest.variant_key!r}."
295+
)
296+
variant_manifest_by_key[result.manifest.variant_key] = result.manifest
297+
298+
variant_manifests: list[KernelVariantManifest] = [
299+
variant_manifest_by_key[compile_spec.variant_key]
300+
for compile_spec, _ in family.spec_pairs
301+
]
235302

236303
if family_kernel_abi is None:
237304
raise ValueError(

0 commit comments

Comments
 (0)