From 9b52e5d57b00f31341660e2c6e1994a152776c80 Mon Sep 17 00:00:00 2001 From: crapromer Date: Mon, 22 Dec 2025 20:13:30 +0800 Subject: [PATCH] Basic loop fusion added --- src/ninetoothed/generation.py | 81 ++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/ninetoothed/generation.py b/src/ninetoothed/generation.py index f650c14..e6bc890 100644 --- a/src/ninetoothed/generation.py +++ b/src/ninetoothed/generation.py @@ -61,7 +61,9 @@ def _get_tree(func): inliner.visit(func_def) func_def = ast.parse(ast.unparse(func_def)) - + name_mapping = type(self)._generate_name_mapping_from_tensors(self._args) + loop_fuser = _LoopFuser(self._context, name_mapping) + loop_fuser.visit(func_def) module = ast.Module(body=[func_def], type_ignores=[]) if inliner.libdevice_used: @@ -1264,3 +1266,80 @@ def visit_FunctionDef(self, node): self.result = node self.generic_visit(node) + + +class _LoopFuser(ast.NodeVisitor): + def __init__(self, context, name_mapping): + self._context = context + self._name_mapping = name_mapping + self.result = None + + def _same_loop(self, f1, f2): + return ast.dump(f1.target) == ast.dump(f2.target) and ast.dump( + f1.iter + ) == ast.dump(f2.iter) + + # === 新增:变量分析 === + class _VarRWAnalyzer(ast.NodeVisitor): + def __init__(self): + self.read = set() + self.write = set() + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Load): + self.read.add(node.id) + elif isinstance(node.ctx, ast.Store): + self.write.add(node.id) + + def _loop_carried_vars(self, loop): + analyzer = self._VarRWAnalyzer() + for stmt in loop.body: + analyzer.visit(stmt) + return analyzer.read & analyzer.write + + def _loop_reads_vars(self, loop, vars): + analyzer = self._VarRWAnalyzer() + for stmt in loop.body: + analyzer.visit(stmt) + return bool(analyzer.read & vars) + + # === 新增:是否允许融合 === + def _can_fuse(self, loop1, loop2): + carried = self._loop_carried_vars(loop1) + if not carried: + return True + if self._loop_reads_vars(loop2, carried): + return False + return True + + def visit_FunctionDef(self, node): + self.generic_visit(node) + + fused_body = [] + body = node.body + i = 0 + + while i < len(body): + if not isinstance(body[i], ast.For): + fused_body.append(body[i]) + i += 1 + continue + + fused_for = body[i] + j = i + 1 + + while ( + j < len(body) + and isinstance(body[j], ast.For) + and self._same_loop(fused_for, body[j]) + and self._can_fuse(fused_for, body[j]) # 👈 关键 + ): + fused_for.body.extend(body[j].body) + j += 1 + + fused_body.append(fused_for) + i = j + + node.body = fused_body + self.result = node + return node