Skip to content

Commit 1c5325f

Browse files
tangzhiyi11 and pdx1989 authored
[DICP][ascend] bugfix for llama finetune (#631)
Co-authored-by: Pan Daoxin <[email protected]>
1 parent 9c231e7 commit 1c5325f

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

dicp/dicp/vendor/AscendGraph/compile_job.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,11 @@ def build_graph(self, output_path, graph_path):
7777
def get_compile_result(self):
    """Ensure a compiled ``.om`` model exists on disk and load it.

    If neither cached model file is present, the graph is (re)built first.
    The compiler may emit the model under a platform-suffixed name, so the
    plain output path is probed first, then the ``_linux_x86_64`` and
    ``_linux_aarch64`` variants, stopping at the first ``.om`` file found.

    Returns:
        AscendModel: wrapper around the located ``.om`` model for this rank.

    Raises:
        AssertionError: if no ``.om`` file exists under any probed path.
    """
    if not (os.path.exists(self._model_path[0]) or os.path.exists(self._model_path[1])):
        self.build_graph(self._output_graph_path, self._input_path)
    base = self._output_graph_path
    # Probe the unsuffixed path, then platform-specific variants, keeping
    # the first candidate whose compiled model actually exists.
    for suffix in ('', '_linux_x86_64', '_linux_aarch64'):
        candidate = base + suffix
        if os.path.exists(candidate + '.om'):
            self._output_graph_path = candidate
            break
    else:
        # No candidate found; leave the last probed path so the assert
        # below reports a deterministic location (matches original flow).
        self._output_graph_path = base + '_linux_aarch64'
    assert (os.path.exists(self._output_graph_path + '.om'))
    from dicp.vendor.AscendGraph.codegen.load_and_run import AscendModel
    return AscendModel(self._local_rank, self._output_graph_path + '.om')

dicp/dicp/vendor/AscendGraph/conversion.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,12 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
577577
if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
578578
value = value.node.meta['val']
579579
dims = self.get_shape_proxy(dims)
580-
value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
580+
581+
# temporarily split the path for dynamic/static shape cases
582+
if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
583+
value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
584+
else:
585+
value = self.common_process_scalar(value, torch_dtype)
581586
return self.get_proxy(ascend_op.Fill, (dims, value))
582587

583588
@register_conversion(torch.ops.aten.fill.Scalar)

0 commit comments

Comments
 (0)