diff --git a/src/ninetoothed/jit.py b/src/ninetoothed/jit.py index 8e2113e..4c470b9 100644 --- a/src/ninetoothed/jit.py +++ b/src/ninetoothed/jit.py @@ -1,5 +1,7 @@ import importlib import sys +import torch +import os from ninetoothed.generation import CodeGenerator from ninetoothed.utils import calculate_default_configs @@ -97,13 +99,10 @@ def __call__(self): self._max_num_configs, self._prettify, ) - module = import_from_path(source_file, source_file) - module_vars = vars(module) - + handle = _Handle( - module_vars[self._kernel_name], - module_vars[code_generator.launch_func_name], source_file, + code_generator, ) return handle @@ -118,11 +117,57 @@ def import_from_path(module_name, file_path): return module +def get_target_device(*args, **kwargs): + target_device = None + for arg in args: + if isinstance(arg, torch.Tensor): + target_device = arg.device + break + + if target_device is None: + for val in kwargs.values(): + if isinstance(val, torch.Tensor): + target_device = val.device + break + + return target_device + + +def convert_to_cpu(source_file_path): + if not os.path.exists(source_file_path): + raise FileNotFoundError(f"源文件不存在: {source_file_path}") + + dir_name = os.path.dirname(source_file_path) + base_name = os.path.basename(source_file_path) + name, ext = os.path.splitext(base_name) + + new_file_name = f"{name}_cpu{ext}" + new_file_path = os.path.join(dir_name, new_file_name) + + with open(source_file_path, 'r', encoding='utf-8') as f: + content = f.read() + + new_content = content.replace("triton", "triton_cpu") + with open(new_file_path, 'w', encoding='utf-8') as f: + f.write(new_content) + + return new_file_path + + class _Handle: - def __init__(self, kernel, launch, source): - self._kernel = kernel - self._launch = launch + def __init__(self, source, code_generator): self._source = source + self._code_generator = code_generator def __call__(self, *args, **kwargs): - return self._launch(*args, **kwargs) + target_device = get_target_device(*args, **kwargs) + + if target_device is not None and str(target_device) == "cpu": + cpu_path=convert_to_cpu(self._source) + self._source = cpu_path + + module = import_from_path(self._source, self._source) + module_vars = vars(module) + self._launch = module_vars[self._code_generator.launch_func_name] + + return self._launch(*args, **kwargs) \ No newline at end of file