Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion src/ninetoothed/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _get_tree(func):
inliner.visit(func_def)

func_def = ast.parse(ast.unparse(func_def))

name_mapping = type(self)._generate_name_mapping_from_tensors(self._args)
loop_fuser = _LoopFuser(self._context, name_mapping)
loop_fuser.visit(func_def)
module = ast.Module(body=[func_def], type_ignores=[])

if inliner.libdevice_used:
Expand Down Expand Up @@ -1264,3 +1266,80 @@ def visit_FunctionDef(self, node):
self.result = node

self.generic_visit(node)


class _LoopFuser(ast.NodeVisitor):
def __init__(self, context, name_mapping):
self._context = context
self._name_mapping = name_mapping
self.result = None

def _same_loop(self, f1, f2):
return ast.dump(f1.target) == ast.dump(f2.target) and ast.dump(
f1.iter
) == ast.dump(f2.iter)

# === 新增:变量分析 ===
class _VarRWAnalyzer(ast.NodeVisitor):
def __init__(self):
self.read = set()
self.write = set()

def visit_Name(self, node):
if isinstance(node.ctx, ast.Load):
self.read.add(node.id)
elif isinstance(node.ctx, ast.Store):
self.write.add(node.id)

def _loop_carried_vars(self, loop):
analyzer = self._VarRWAnalyzer()
for stmt in loop.body:
analyzer.visit(stmt)
return analyzer.read & analyzer.write

def _loop_reads_vars(self, loop, vars):
analyzer = self._VarRWAnalyzer()
for stmt in loop.body:
analyzer.visit(stmt)
return bool(analyzer.read & vars)

# === 新增:是否允许融合 ===
def _can_fuse(self, loop1, loop2):
carried = self._loop_carried_vars(loop1)
if not carried:
return True
if self._loop_reads_vars(loop2, carried):
return False
return True

def visit_FunctionDef(self, node):
self.generic_visit(node)

fused_body = []
body = node.body
i = 0

while i < len(body):
if not isinstance(body[i], ast.For):
fused_body.append(body[i])
i += 1
continue

fused_for = body[i]
j = i + 1

while (
j < len(body)
and isinstance(body[j], ast.For)
and self._same_loop(fused_for, body[j])
and self._can_fuse(fused_for, body[j]) # 👈 关键
):
fused_for.body.extend(body[j].body)
j += 1

fused_body.append(fused_for)
i = j

node.body = fused_body
self.result = node
return node