11from __future__ import annotations
22
3+ from concurrent .futures import ThreadPoolExecutor , as_completed
4+ from dataclasses import dataclass
5+ import os
36from pathlib import Path
47
58from ...common .cache import compute_cache_key , is_cache_hit
69from ...common .manifest import KernelAbi , KernelFamilyManifest , KernelVariantManifest
7- from ...common .spec import DispatchField , KernelCompileSpec , TilelangKernel
10+ from ...common .spec import DispatchField , KernelCompileSpec , KernelSpec , TilelangKernel
811from ...common .toolchain import repo_root , run_checked
912from . import abi_entry , kernel_registry , toolchain
1013from .kernel_registry import RegisteredKernelFamily
1316from .toolchain import AscendBuildContext , TILELANG_BISHENG_COMMON_FLAGS
1417
1518
19+ @dataclass (frozen = True )
20+ class _VariantBuildPlan :
21+ compile_spec : KernelCompileSpec
22+ kernel_spec : KernelSpec
23+ generated_source : Path
24+ compiled_binary : Path
25+ cache_key : str
26+
27+
28+ @dataclass (frozen = True )
29+ class _VariantBuildResult :
30+ manifest : KernelVariantManifest
31+ kernel_abi : KernelAbi
32+
33+
1634def _variant_entry_symbol (spec : KernelCompileSpec ) -> str :
1735 kernel_entry_name = spec .entry_name or spec .kernel_name
1836 return f"{ kernel_entry_name } __{ spec .variant_key } _call"
@@ -116,7 +134,8 @@ def build_kernel_family(
116134 existing_manifest = _read_family_manifest (manifest_path )
117135 dependency_files = _build_dependency_files (family )
118136
119- variant_manifests : list [KernelVariantManifest ] = []
137+ variant_manifest_by_key : dict [str , KernelVariantManifest ] = {}
138+ uncached_plans : list [_VariantBuildPlan ] = []
120139 family_kernel_abi : KernelAbi | None = None
121140
122141 for compile_spec , kernel_spec in family .spec_pairs :
@@ -167,24 +186,37 @@ def build_kernel_family(
167186 "All variants in a TileLang kernel must share the same exported "
168187 f"C ABI. Mismatch found in variant { compile_spec .variant_key !r} ."
169188 )
170- variant_manifests .append (
171- KernelVariantManifest (
172- variant_key = compile_spec .variant_key ,
173- specialization = dict (compile_spec .specialization ),
174- dispatch_values = dict (compile_spec .dispatch_values ),
175- generated_source = cached_variant .generated_source ,
176- compiled_binary = cached_variant .compiled_binary ,
177- entry_symbol = cached_variant .entry_symbol ,
178- cache_key = cached_variant .cache_key ,
179- toolchain_options = dict (context .toolchain_options ),
180- fingerprint = dict (context .fingerprint ),
181- compile_definitions = kernel_spec .render_compile_definitions (
182- entry_symbol = cached_variant .entry_symbol
183- ),
184- )
189+ variant_manifest_by_key [compile_spec .variant_key ] = KernelVariantManifest (
190+ variant_key = compile_spec .variant_key ,
191+ specialization = dict (compile_spec .specialization ),
192+ dispatch_values = dict (compile_spec .dispatch_values ),
193+ generated_source = cached_variant .generated_source ,
194+ compiled_binary = cached_variant .compiled_binary ,
195+ entry_symbol = cached_variant .entry_symbol ,
196+ cache_key = cached_variant .cache_key ,
197+ toolchain_options = dict (context .toolchain_options ),
198+ fingerprint = dict (context .fingerprint ),
199+ compile_definitions = kernel_spec .render_compile_definitions (
200+ entry_symbol = cached_variant .entry_symbol
201+ ),
185202 )
186203 continue
187204
205+ uncached_plans .append (
206+ _VariantBuildPlan (
207+ compile_spec = compile_spec ,
208+ kernel_spec = kernel_spec ,
209+ generated_source = generated_source ,
210+ compiled_binary = compiled_binary ,
211+ cache_key = cache_key ,
212+ )
213+ )
214+
215+ compile_cwd = repo_root ()
216+
217+ def _run_variant_job (plan : _VariantBuildPlan ) -> _VariantBuildResult :
218+ compile_spec = plan .compile_spec
219+ kernel_spec = plan .kernel_spec
188220 source = family .kernel_cls .generate_source (** compile_spec .specialization )
189221 entry_symbol = _variant_entry_symbol (compile_spec )
190222 rendered_source = abi_entry .rename_variant_internal_symbols (
@@ -194,44 +226,79 @@ def build_kernel_family(
194226 compile_spec .variant_key ,
195227 )
196228 kernel_abi = abi_entry .parse_kernel_abi (rendered_source , entry_symbol )
197- if family_kernel_abi is None :
198- family_kernel_abi = kernel_abi
199- elif kernel_abi != family_kernel_abi :
200- raise ValueError (
201- "All variants in a TileLang kernel must share the same exported "
202- f"C ABI. Mismatch found in variant { compile_spec .variant_key !r} ."
203- )
204- generated_source .write_text (rendered_source , encoding = "utf-8" )
229+ plan .generated_source .write_text (rendered_source , encoding = "utf-8" )
205230
206231 compile_cmd = [
207232 context .bisheng_executable ,
208233 f"--cce-aicore-arch={ context .bisheng_arch } " ,
209234 * TILELANG_BISHENG_COMMON_FLAGS ,
210235 f"-Dg_tilingKey=g_tilingKey__{ compile_spec .variant_key } " ,
211236 * [f"-I{ include_dir } " for include_dir in context .include_dirs ],
212- str (generated_source ),
237+ str (plan . generated_source ),
213238 "-c" ,
214239 "-o" ,
215- str (compiled_binary ),
240+ str (plan . compiled_binary ),
216241 ]
217- run_checked (compile_cmd , cwd = repo_root ())
218-
219- variant_manifests .append (
220- KernelVariantManifest (
221- variant_key = compile_spec .variant_key ,
222- specialization = compile_spec .specialization ,
223- dispatch_values = compile_spec .dispatch_values ,
224- generated_source = str (generated_source ),
225- compiled_binary = str (compiled_binary ),
226- entry_symbol = entry_symbol ,
227- cache_key = cache_key ,
228- toolchain_options = dict (context .toolchain_options ),
229- fingerprint = dict (context .fingerprint ),
230- compile_definitions = kernel_spec .render_compile_definitions (
231- entry_symbol = entry_symbol
232- ),
233- )
242+ run_checked (compile_cmd , cwd = compile_cwd )
243+ manifest = KernelVariantManifest (
244+ variant_key = compile_spec .variant_key ,
245+ specialization = dict (compile_spec .specialization ),
246+ dispatch_values = dict (compile_spec .dispatch_values ),
247+ generated_source = str (plan .generated_source ),
248+ compiled_binary = str (plan .compiled_binary ),
249+ entry_symbol = entry_symbol ,
250+ cache_key = plan .cache_key ,
251+ toolchain_options = dict (context .toolchain_options ),
252+ fingerprint = dict (context .fingerprint ),
253+ compile_definitions = kernel_spec .render_compile_definitions (
254+ entry_symbol = entry_symbol
255+ ),
234256 )
257+ return _VariantBuildResult (manifest = manifest , kernel_abi = kernel_abi )
258+
259+ if uncached_plans :
260+ max_workers = max (1 , os .cpu_count () or 1 )
261+ if max_workers == 1 :
262+ for plan in uncached_plans :
263+ result = _run_variant_job (plan )
264+ if family_kernel_abi is None :
265+ family_kernel_abi = result .kernel_abi
266+ elif result .kernel_abi != family_kernel_abi :
267+ raise ValueError (
268+ "All variants in a TileLang kernel must share the same exported "
269+ "C ABI. Mismatch found in variant "
270+ f"{ result .manifest .variant_key !r} ."
271+ )
272+ variant_manifest_by_key [result .manifest .variant_key ] = result .manifest
273+ else :
274+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
275+ future_to_plan = {
276+ executor .submit (_run_variant_job , plan ): plan
277+ for plan in uncached_plans
278+ }
279+ for future in as_completed (future_to_plan ):
280+ plan = future_to_plan [future ]
281+ try :
282+ result = future .result ()
283+ except Exception as exc :
284+ raise RuntimeError (
285+ "Ascend variant build failed for variant "
286+ f"{ plan .compile_spec .variant_key !r} "
287+ ) from exc
288+ if family_kernel_abi is None :
289+ family_kernel_abi = result .kernel_abi
290+ elif result .kernel_abi != family_kernel_abi :
291+ raise ValueError (
292+ "All variants in a TileLang kernel must share the same exported "
293+ "C ABI. Mismatch found in variant "
294+ f"{ result .manifest .variant_key !r} ."
295+ )
296+ variant_manifest_by_key [result .manifest .variant_key ] = result .manifest
297+
298+ variant_manifests : list [KernelVariantManifest ] = [
299+ variant_manifest_by_key [compile_spec .variant_key ]
300+ for compile_spec , _ in family .spec_pairs
301+ ]
235302
236303 if family_kernel_abi is None :
237304 raise ValueError (
0 commit comments