@@ -801,8 +801,8 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
801801 (CopyLowerToUpper<""> $layout, $uplo, (Shadow $A), $n),
802802
803803 // Actual Math
804- (BlasCall<"trsm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, $ lda, (Shadow $A)),
805- (BlasCall<"trsm"> $layout, Char<"R">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, $ lda, (Shadow $A)),
804+ (BlasCall<"trsm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $ lda, $n, $n) , (Shadow $A)),
805+ (BlasCall<"trsm"> $layout, Char<"R">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $ lda, $n, $n) , (Shadow $A)),
806806 (BlasCall<"scal"> $n, Constant<"0.5">, (First (Shadow $A)), (Add $lda, ConstantInt<1>)),
807807
808808 (Seq<["tmp", "vector", "n"], [], 1>
@@ -814,7 +814,7 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
814814 (BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, (First (Shadow $A)), (Add $lda, ConstantInt<1>)),
815815
816816 // More math
817- (BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"N">, Char<"N">, $n, $n, Constant<"1.0">, $A, $ lda, (Shadow $A)),
817+ (BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"N">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $ lda, $n, $n) , (Shadow $A)),
818818 // Restore the flipped side, but saving our computed diagonal
819819 (BlasCall<"copy"> $n, (First (Shadow $A)), (Add $lda, ConstantInt<1>), use<"tmp">, ConstantInt<1>),
820820 (BlasCall<"lacpy"> $layout, (flip_uplo $uplo), $n, $n, use<"tri">, $n, (Shadow $A)),
0 commit comments