Skip to content

Commit 180aa12

Browse files
committed
Fix lda on blas
1 parent bdb3107 commit 180aa12

File tree

2 files changed

+3
-14
lines changed

2 files changed

+3
-14
lines changed

enzyme/Enzyme/BlasDerivatives.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -801,8 +801,8 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
801801
(CopyLowerToUpper<""> $layout, $uplo, (Shadow $A), $n),
802802

803803
// Actual Math
804-
(BlasCall<"trsm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, $lda, (Shadow $A)),
805-
(BlasCall<"trsm"> $layout, Char<"R">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, $lda, (Shadow $A)),
804+
(BlasCall<"trsm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $A)),
805+
(BlasCall<"trsm"> $layout, Char<"R">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $A)),
806806
(BlasCall<"scal"> $n, Constant<"0.5">, (First (Shadow $A)), (Add $lda, ConstantInt<1>)),
807807

808808
(Seq<["tmp", "vector", "n"], [], 1>
@@ -814,7 +814,7 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
814814
(BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, (First (Shadow $A)), (Add $lda, ConstantInt<1>)),
815815

816816
// More math
817-
(BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"N">, Char<"N">, $n, $n, Constant<"1.0">, $A, $lda, (Shadow $A)),
817+
(BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"N">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $A)),
818818
// Restore the flipped side, but saving our computed diagonal
819819
(BlasCall<"copy"> $n, (First (Shadow $A)), (Add $lda, ConstantInt<1>), use<"tmp">, ConstantInt<1>),
820820
(BlasCall<"lacpy"> $layout, (flip_uplo $uplo), $n, $n, use<"tri">, $n, (Shadow $A)),

enzyme/Enzyme/Utils.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3985,17 +3985,6 @@ arePointersGuaranteedNoAlias(TargetLibraryInfo &TLI, llvm::AAResults &AA,
39853985
if (lhs == rhs) {
39863986
return false;
39873987
}
3988-
if (auto i1 = dyn_cast<Instruction>(op1))
3989-
if (isa<ConstantPointerNull>(op0) &&
3990-
hasMetadata(i1, LLVMContext::MD_nonnull)) {
3991-
return true;
3992-
}
3993-
if (auto i0 = dyn_cast<Instruction>(op0))
3994-
if (isa<ConstantPointerNull>(op1) &&
3995-
hasMetadata(i0, LLVMContext::MD_nonnull)) {
3996-
return true;
3997-
}
3998-
39993988
if (!lhs->getType()->isPointerTy() && !rhs->getType()->isPointerTy())
40003989
return {};
40013990

0 commit comments

Comments
 (0)