diff --git a/src/ATen/native/xpu/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/BatchLinearAlgebra.cpp new file mode 100644 index 000000000..46a6da83a --- /dev/null +++ b/src/ATen/native/xpu/BatchLinearAlgebra.cpp @@ -0,0 +1,52 @@ +#include +#include +#include +#if defined(USE_ONEMKL_XPU) +#include +#endif // USE_ONEMKL_XPU + +namespace at::native { + +void lu_solve_kernel_xpu( + const Tensor& LU, + const Tensor& pivots, + const Tensor& B, + TransposeType trans) { +#if defined(USE_ONEMKL_XPU) + native::xpu::lu_solve_mkl(LU, pivots, B, trans); +#else + const auto LU_cpu = LU.to(LU.options().device(kCPU)); + const auto pivots_cpu = pivots.to(pivots.options().device(kCPU)); + auto B_cpu = B.to(B.options().device(kCPU)); + + lu_solve_stub(at::kCPU, LU_cpu, pivots_cpu, B_cpu, trans); + + B.copy_(B_cpu); +#endif // USE_ONEMKL_XPU +} + +REGISTER_XPU_DISPATCH(lu_solve_stub, &lu_solve_kernel_xpu); + +void lu_factor_kernel_xpu( + const Tensor& input, + const Tensor& pivots, + const Tensor& infos, + bool compute_pivots) { +#if defined(USE_ONEMKL_XPU) + native::xpu::lu_factor_mkl(input, pivots, infos, compute_pivots); +#else + auto input_cpu = input.to(input.options().device(kCPU)); + auto pivots_cpu = pivots.to(pivots.options().device(kCPU)); + const auto infos_cpu = infos.to(infos.options().device(kCPU)); + + lu_factor_stub(at::kCPU, input_cpu, pivots_cpu, infos_cpu, compute_pivots); + + input.copy_(input_cpu); + pivots.copy_(pivots_cpu); + infos.copy_(infos_cpu); +#endif // USE_ONEMKL_XPU +} + +REGISTER_XPU_DISPATCH(lu_factor_stub, &lu_factor_kernel_xpu); + +} // namespace at::native diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 81bdd4564..aba93034b 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -193,7 +193,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "_flash_attention_forward", "geqrf", "linalg_cholesky_ex.L", - "_linalg_det.result", "linalg_eig", "_linalg_eigvals", "linalg_eigvals.out", @@ -203,13 +202,9 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "linalg_ldl_factor_ex.out", "linalg_ldl_solve.out", "linalg_lstsq.out", - "linalg_lu_factor_ex.out", "linalg_lu.out", - "linalg_lu_solve.out", "linalg_matrix_exp", "linalg_qr.out", - "_linalg_slogdet.sign", - "_linalg_solve_ex.result", "linalg_solve_triangular", "_linalg_svd.U", "lu_unpack.out", diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp new file mode 100644 index 000000000..42533078e --- /dev/null +++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp @@ -0,0 +1,541 @@ +#if defined(USE_ONEMKL_XPU) +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native::xpu { + +#define SYCL_ONEMKL_SUBMIT(q, routine, ...) 
\ + { \ + auto e = (routine(__VA_ARGS__)); \ + (q).throw_asynchronous(); \ + } + +// Transforms TransposeType into the BLAS / LAPACK format +static oneapi::mkl::transpose to_blas_(TransposeType trans) { + switch (trans) { + case TransposeType::Transpose: + return oneapi::mkl::transpose::trans; + case TransposeType::NoTranspose: + return oneapi::mkl::transpose::nontrans; + case TransposeType::ConjTranspose: + return oneapi::mkl::transpose::conjtrans; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} + +void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) { + auto errs = be.exceptions(); + auto ids = be.ids(); + for (auto& i : ids) { + try { + std::rethrow_exception(errs[i]); + } catch (const oneapi::mkl::lapack::exception& e) { + std::cout << "Cathed lapack exception:" + << "\nWhat: " << e.what() << "\nInfo: " << e.info() + << "\nDetail: " << e.detail() << std::endl; + infos[i] = e.info(); + } catch (const sycl::exception& e) { + std::cout << "Catched SYCL exception:" + << "\nWhat: " << e.what() << "\nInfo: -1" << std::endl; + infos[i] = -1; + } + } +} + +template +void mkl_getrf( + sycl::queue& queue, + int64_t m, + int64_t n, + scalar_t* a, + int64_t lda, + int64_t stride_a, + int64_t* ipiv, + int64_t stride_ipiv, + int64_t batch_size, + scalar_t* scratchpad, + int scratchpadsize) { + SYCL_ONEMKL_SUBMIT( + queue, + oneapi::mkl::lapack::getrf_batch, + queue, + m, + n, + a, + lda, + stride_a, + ipiv, + stride_ipiv, + batch_size, + scratchpad, + scratchpadsize); +} + +template <> +void mkl_getrf>( + sycl::queue& queue, + int64_t m, + int64_t n, + c10::complex* a, + int64_t lda, + int64_t stride_a, + int64_t* ipiv, + int64_t stride_ipiv, + int64_t batch_size, + c10::complex* scratchpad, + int scratchpadsize) { + SYCL_ONEMKL_SUBMIT( + queue, + oneapi::mkl::lapack::getrf_batch, + queue, + m, + n, + reinterpret_cast*>(a), + lda, + stride_a, + ipiv, + stride_ipiv, + batch_size, + reinterpret_cast*>(scratchpad), + scratchpadsize); +} + +template <> +void mkl_getrf>( + sycl::queue& queue, + int64_t m, + int64_t n, + c10::complex* a, + int64_t lda, + int64_t stride_a, + int64_t* ipiv, + int64_t stride_ipiv, + int64_t batch_size, + c10::complex* scratchpad, + int scratchpadsize) { + SYCL_ONEMKL_SUBMIT( + queue, + oneapi::mkl::lapack::getrf_batch, + queue, + m, + n, + reinterpret_cast*>(a), + lda, + stride_a, + ipiv, + stride_ipiv, + batch_size, + reinterpret_cast*>(scratchpad), + scratchpadsize); +} + +template +void mkl_getrs( + sycl::queue& queue, + oneapi::mkl::transpose trans, + int64_t n, + int64_t nrhs, + scalar_t* a, + int64_t lda, + int64_t stride_a, + int64_t* ipiv, + int64_t stride_ipiv, + scalar_t* b, + int64_t ldb, + int64_t stride_b, + int64_t batch_size, + scalar_t* scratchpad, + int64_t scratchpad_size) { + SYCL_ONEMKL_SUBMIT( + queue, + oneapi::mkl::lapack::getrs_batch, + queue, + trans, + n, + nrhs, + a, + lda, + stride_a, + ipiv, + stride_ipiv, + b, + ldb, + stride_b, + batch_size, + scratchpad, + scratchpad_size); +} + +template <> +void mkl_getrs>( + sycl::queue& queue, + oneapi::mkl::transpose trans, + int64_t n, + int64_t nrhs, + c10::complex* a, + int64_t lda, + int64_t stride_a, + int64_t* ipiv, + int64_t stride_ipiv, + c10::complex* b, + int64_t ldb, + int64_t stride_b, + int64_t batch_size, + c10::complex* scratchpad, + int64_t scratchpad_size) { + SYCL_ONEMKL_SUBMIT( + queue, + oneapi::mkl::lapack::getrs_batch, + queue, + trans, + n, + nrhs, + reinterpret_cast*>(a), + lda, + stride_a, + ipiv, + stride_ipiv, + reinterpret_cast*>(b), + 
ldb, + stride_b, + batch_size, + reinterpret_cast*>(scratchpad), + scratchpad_size); +} + +template <> +void mkl_getrs>( + sycl::queue& queue, + oneapi::mkl::transpose trans, + int64_t n, + int64_t nrhs, + c10::complex* a, + int64_t lda, + int64_t stride_a, + int64_t* ipiv, + int64_t stride_ipiv, + c10::complex* b, + int64_t ldb, + int64_t stride_b, + int64_t batch_size, + c10::complex* scratchpad, + int64_t scratchpad_size) { + SYCL_ONEMKL_SUBMIT( + queue, + oneapi::mkl::lapack::getrs_batch, + queue, + trans, + n, + nrhs, + reinterpret_cast*>(a), + lda, + stride_a, + ipiv, + stride_ipiv, + reinterpret_cast*>(b), + ldb, + stride_b, + batch_size, + reinterpret_cast*>(scratchpad), + scratchpad_size); +} + +template +int64_t mkl_getrf_scratchpad( + sycl::queue& queue, + int64_t m, + int64_t n, + int64_t lda, + int64_t stride_a, + int64_t stride_ipiv, + int64_t batch_size) { + return oneapi::mkl::lapack::getrf_batch_scratchpad_size( + queue, m, n, lda, stride_a, stride_ipiv, batch_size); +} + +template <> +int64_t mkl_getrf_scratchpad>( + sycl::queue& queue, + int64_t m, + int64_t n, + int64_t lda, + int64_t stride_a, + int64_t stride_ipiv, + int64_t batch_size) { + return oneapi::mkl::lapack::getrf_batch_scratchpad_size>( + queue, m, n, lda, stride_a, stride_ipiv, batch_size); +} + +template <> +int64_t mkl_getrf_scratchpad>( + sycl::queue& queue, + int64_t m, + int64_t n, + int64_t lda, + int64_t stride_a, + int64_t stride_ipiv, + int64_t batch_size) { + return oneapi::mkl::lapack::getrf_batch_scratchpad_size>( + queue, m, n, lda, stride_a, stride_ipiv, batch_size); +} + +template +int64_t mkl_getrs_scratchpad( + sycl::queue& queue, + oneapi::mkl::transpose trans, + int64_t n, + int64_t nrhs, + int64_t lda, + int64_t stride_a, + int64_t stride_ipiv, + int64_t ldb, + int64_t stride_b, + int64_t batch_size) { + return oneapi::mkl::lapack::getrs_batch_scratchpad_size( + queue, + trans, + n, + nrhs, + lda, + stride_a, + stride_ipiv, + ldb, + stride_b, + batch_size); +} + +template <> +int64_t mkl_getrs_scratchpad>( + sycl::queue& queue, + oneapi::mkl::transpose trans, + int64_t n, + int64_t nrhs, + int64_t lda, + int64_t stride_a, + int64_t stride_ipiv, + int64_t ldb, + int64_t stride_b, + int64_t batch_size) { + return oneapi::mkl::lapack::getrs_batch_scratchpad_size>( + queue, + trans, + n, + nrhs, + lda, + stride_a, + stride_ipiv, + ldb, + stride_b, + batch_size); +} + +template <> +int64_t mkl_getrs_scratchpad>( + sycl::queue& queue, + oneapi::mkl::transpose trans, + int64_t n, + int64_t nrhs, + int64_t lda, + int64_t stride_a, + int64_t stride_ipiv, + int64_t ldb, + int64_t stride_b, + int64_t batch_size) { + return oneapi::mkl::lapack::getrs_batch_scratchpad_size>( + queue, + trans, + n, + nrhs, + lda, + stride_a, + stride_ipiv, + ldb, + stride_b, + batch_size); +} + +template +static void apply_lu_xpu_( + const Tensor& self_, + Tensor& pivots_, + int32_t* infos_) { + // do nothing if empty input. + if (self_.numel() == 0) + return; + + auto& queue = at::xpu::getCurrentSYCLQueue(); + int64_t batch_size = native::batchCount(self_); + int64_t m = self_.size(-2); + int64_t n = self_.size(-1); + int64_t lda = m; + int64_t stride_a = lda * n; + int64_t stride_ipiv = (m < n) ? 
m : n; + scalar_t* a = (scalar_t*)(self_.data_ptr()); + int64_t* ipiv = (int64_t*)(pivots_.data_ptr()); + int64_t scratchpadsize = mkl_getrf_scratchpad( + queue, m, n, lda, stride_a, stride_ipiv, batch_size); + Tensor scratchpad_at = at::empty({scratchpadsize}, self_.options()); + try { + mkl_getrf( + queue, + m, + n, + a, + lda, + stride_a, + ipiv, + stride_ipiv, + batch_size, + (scalar_t*)(scratchpad_at.data_ptr()), + scratchpadsize); + } catch (const oneapi::mkl::lapack::batch_error& be) { + error_handle(infos_, be); + } +} + +template +static void apply_lu_solve_xpu_( + const Tensor& b_, + const Tensor& lu_, + const Tensor& pivots_, + TransposeType t) { + // do nothing if empty input + if (lu_.numel() == 0) + return; + + auto& queue = at::xpu::getCurrentSYCLQueue(); + int64_t batch_size = native::batchCount(b_); + + auto trans = to_blas_(t); + int64_t n = lu_.size(-2); + int64_t nrhs = b_.size(-1); + int64_t lda = lu_.size(-2); + int64_t stride_a = native::matrixStride(lu_); + int64_t stride_ipiv = pivots_.size(-1); + int64_t ldb = b_.size(-2); + int64_t stride_b = native::matrixStride(b_); + + scalar_t* a = lu_.data_ptr(); + Tensor pivots = pivots_; + if (pivots_.scalar_type() == at::ScalarType::Int) + pivots = pivots_.to(kLong); + int64_t* ipiv = pivots.data_ptr(); + scalar_t* b = b_.data_ptr(); + + auto execute_mkl_getrs = + [&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) { + int64_t scratchpad_size = mkl_getrs_scratchpad( + queue, + trans, + n, + nrhs, + lda, + stride_a, + stride_ipiv, + ldb, + stride_b, + batch_size); + Tensor scratchpad_at = at::empty({scratchpad_size}, b_.options()); + mkl_getrs( + queue, + trans, + n, + nrhs, + a, + lda, + stride_a, + ipiv, + stride_ipiv, + b, + ldb, + stride_b, + batch_size, + scratchpad_at.data_ptr(), + scratchpad_size); + }; + + bool is_broadcast = false; + IntArrayRef lu_batch_shape(lu_.sizes().data(), lu_.dim() - 2); + IntArrayRef b_batch_shape(b_.sizes().data(), b_.dim() - 2); + + { + auto infer_size_buffer = at::infer_size(lu_batch_shape, b_batch_shape); + IntArrayRef out_batch_shape(infer_size_buffer); + + is_broadcast = !(out_batch_shape.equals(lu_batch_shape)); + } + + if (!is_broadcast) { + execute_mkl_getrs(a, b, ipiv, batch_size); + return; + } + + BroadcastLinearIndices lu_index( + native::batchCount(lu_), lu_batch_shape, b_batch_shape); + + for (const auto i : c10::irange(batch_size)) { + int64_t lu_index_i = lu_index(i); + scalar_t* a_working_ptr = &a[lu_index_i * stride_a]; + scalar_t* b_working_ptr = &b[i * stride_b]; + int64_t* ipiv_working_ptr = &ipiv[lu_index_i * stride_ipiv]; + + execute_mkl_getrs(a_working_ptr, b_working_ptr, ipiv_working_ptr, 1); + } +} + +void lu_solve_mkl( + const Tensor& LU, + const Tensor& pivots, + const Tensor& B, + TransposeType trans) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_solve_xpu", [&] { + apply_lu_solve_xpu_(B, LU, pivots, trans); + }); +} + +void lu_factor_mkl( + const Tensor& LU, + const Tensor& pivots, + const Tensor& info, + bool pivot) { + TORCH_CHECK( + LU.dim() >= 2, + "torch.lu_factor: Expected tensor with 2 or more dimensions. 
Got size: ", + LU.sizes(), + " instead"); + TORCH_CHECK( + pivot, + "linalg.lu_factor: LU without pivoting is not implemented on the XPU"); + + // handle the info + info.zero_(); + int32_t* infos_data = info.data_ptr(); + + // oneMKL requires Long for pivots but PyTorch provides Int + Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong)); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_xpu", [&] { + apply_lu_xpu_(LU, pivots_, infos_data); + }); + + // Copy to original pivots tensor + pivots.copy_(pivots_); +} + +} // namespace at::native::xpu +#endif // USE_ONEMKL_XPU diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h new file mode 100644 index 000000000..c1cc1da5c --- /dev/null +++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TORCH_XPU_API void lu_solve_mkl( + const Tensor& LU, + const Tensor& pivots, + const Tensor& B, + TransposeType trans); + +TORCH_XPU_API void lu_factor_mkl( + const Tensor& LU, + const Tensor& pivots, + const Tensor& info, + bool pivot); + +} // namespace at::native::xpu diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 98245cf49..fb5a0381a 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -38,10 +38,6 @@ "test_errors_dot_xpu", "test_errors_vdot_xpu", # Linalg OPs not supported - "test_noncontiguous_samples_linalg_det_xpu_float32", - "test_noncontiguous_samples_linalg_slogdet_xpu_float32", - "test_noncontiguous_samples_linalg_solve_ex_xpu_float32", - "test_noncontiguous_samples_linalg_solve_xpu_float32", "test_noncontiguous_samples_linalg_tensorsolve_xpu_float32", "test_noncontiguous_samples_logdet_xpu_float32", # Sparse CSR OPs not supported @@ -133,7 +129,6 @@ "test_dtypes_linalg_cholesky_ex_xpu", "test_dtypes_linalg_cholesky_xpu", "test_dtypes_linalg_cond_xpu", - "test_dtypes_linalg_det_singular_xpu", "test_dtypes_linalg_det_xpu", "test_dtypes_linalg_eig_xpu", "test_dtypes_linalg_eigh_xpu", @@ -281,7 +276,6 @@ "test_out_requires_grad_error_inner_xpu_complex64", "test_out_requires_grad_error_linalg_cholesky_ex_xpu_complex64", "test_out_requires_grad_error_linalg_cholesky_xpu_complex64", - "test_out_requires_grad_error_linalg_det_singular_xpu_complex64", "test_out_requires_grad_error_linalg_eig_xpu_complex64", "test_out_requires_grad_error_linalg_eigh_xpu_complex64", "test_out_requires_grad_error_linalg_eigvals_xpu_complex64", @@ -350,7 +344,6 @@ "test_variant_consistency_eager_linalg_cholesky_ex_xpu_complex64", "test_variant_consistency_eager_linalg_cholesky_xpu_complex64", "test_variant_consistency_eager_linalg_cond_xpu_complex64", - "test_variant_consistency_eager_linalg_det_singular_xpu_complex64", "test_variant_consistency_eager_linalg_det_xpu_complex64", "test_variant_consistency_eager_linalg_eig_xpu_complex64", "test_variant_consistency_eager_linalg_eigh_xpu_complex64", @@ -425,7 +418,6 @@ "test_conj_view_linalg_cholesky_ex_xpu_complex64", "test_conj_view_linalg_cholesky_xpu_complex64", "test_conj_view_linalg_cond_xpu_complex64", - "test_conj_view_linalg_det_singular_xpu_complex64", "test_conj_view_linalg_det_xpu_complex64", "test_conj_view_linalg_eig_xpu_complex64", "test_conj_view_linalg_eigh_xpu_complex64", @@ -490,7 +482,6 @@ "test_neg_conj_view_linalg_cholesky_ex_xpu_complex128", "test_neg_conj_view_linalg_cholesky_xpu_complex128", "test_neg_conj_view_linalg_cond_xpu_complex128", - 
"test_neg_conj_view_linalg_det_singular_xpu_complex128", "test_neg_conj_view_linalg_eig_xpu_complex128", "test_neg_conj_view_linalg_eigh_xpu_complex128", "test_neg_conj_view_linalg_eigvals_xpu_complex128", @@ -551,8 +542,6 @@ "test_neg_view_linalg_cholesky_ex_xpu_float64", "test_neg_view_linalg_cholesky_xpu_float64", "test_neg_view_linalg_cond_xpu_float64", - "test_neg_view_linalg_det_singular_xpu_float64", - "test_neg_view_linalg_det_xpu_float64", "test_neg_view_linalg_eig_xpu_float64", "test_neg_view_linalg_eigh_xpu_float64", "test_neg_view_linalg_eigvalsh_xpu_float64", @@ -564,8 +553,6 @@ "test_neg_view_linalg_ldl_solve_xpu_float64", "test_neg_view_linalg_lstsq_grad_oriented_xpu_float64", "test_neg_view_linalg_lstsq_xpu_float64", - "test_neg_view_linalg_lu_factor_xpu_float64", - "test_neg_view_linalg_lu_solve_xpu_float64", "test_neg_view_linalg_matrix_norm_xpu_float64", "test_neg_view_linalg_matrix_power_xpu_float64", "test_neg_view_linalg_matrix_rank_hermitian_xpu_float64", @@ -577,16 +564,12 @@ "test_neg_view_linalg_pinv_singular_xpu_float64", "test_neg_view_linalg_pinv_xpu_float64", "test_neg_view_linalg_qr_xpu_float64", - "test_neg_view_linalg_slogdet_xpu_float64", - "test_neg_view_linalg_solve_ex_xpu_float64", "test_neg_view_linalg_solve_triangular_xpu_float64", - "test_neg_view_linalg_solve_xpu_float64", "test_neg_view_linalg_svd_xpu_float64", "test_neg_view_linalg_svdvals_xpu_float64", "test_neg_view_linalg_tensorinv_xpu_float64", "test_neg_view_linalg_tensorsolve_xpu_float64", "test_neg_view_logdet_xpu_float64", - "test_neg_view_lu_solve_xpu_float64", "test_neg_view_lu_xpu_float64", "test_neg_view_matmul_xpu_float64", "test_neg_view_mm_xpu_float64", @@ -1278,8 +1261,6 @@ "test_corner_cases_of_cublasltmatmul_xpu_complex128", "test_corner_cases_of_cublasltmatmul_xpu_complex64", "test_corner_cases_of_cublasltmatmul_xpu_float64", - "test_det_logdet_slogdet_batched_xpu_float64", - "test_det_logdet_slogdet_xpu_float64", "test_eig_check_magma_xpu_float32", "test_einsum_random_xpu_complex128", "test_einsum_random_xpu_float64", @@ -1312,7 +1293,6 @@ "test_linalg_lu_family_xpu_float64", "test_linalg_lu_solve_xpu_complex128", "test_linalg_lu_solve_xpu_complex64", - "test_linalg_lu_solve_xpu_float64", "test_linalg_solve_triangular_broadcasting_xpu_complex128", "test_linalg_solve_triangular_broadcasting_xpu_complex64", "test_linalg_solve_triangular_broadcasting_xpu_float64", @@ -1326,19 +1306,14 @@ "test_lobpcg_ortho_xpu_float64", "test_lu_solve_batched_broadcasting_xpu_complex128", "test_lu_solve_batched_broadcasting_xpu_complex64", - "test_lu_solve_batched_broadcasting_xpu_float64", "test_lu_solve_batched_many_batches_xpu_complex128", "test_lu_solve_batched_many_batches_xpu_complex64", - "test_lu_solve_batched_many_batches_xpu_float64", "test_lu_solve_batched_xpu_complex128", "test_lu_solve_batched_xpu_complex64", - "test_lu_solve_batched_xpu_float64", "test_lu_solve_large_matrices_xpu_complex128", "test_lu_solve_large_matrices_xpu_complex64", - "test_lu_solve_large_matrices_xpu_float64", "test_lu_solve_xpu_complex128", "test_lu_solve_xpu_complex64", - "test_lu_solve_xpu_float64", "test_matmul_out_kernel_errors_with_autograd_xpu_complex64", "test_matmul_small_brute_force_1d_Nd_xpu_complex64", "test_matmul_small_brute_force_2d_Nd_xpu_complex64", @@ -1383,7 +1358,6 @@ "test_pinverse_xpu_float64", "test_slogdet_xpu_complex128", "test_slogdet_xpu_complex64", - "test_slogdet_xpu_float64", "test_solve_batched_broadcasting_xpu_complex128", "test_solve_batched_broadcasting_xpu_complex64", 
"test_solve_batched_broadcasting_xpu_float64", @@ -1528,7 +1502,6 @@ "test_fn_fwgrad_bwgrad_linalg_cond_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_cond_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_det_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_det_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_eig_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_eig_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_eigh_xpu_complex128", @@ -1546,11 +1519,8 @@ "test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_lstsq_grad_oriented_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_lu_factor_ex_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_lu_factor_ex_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_lu_factor_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_lu_factor_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_lu_solve_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_lu_solve_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_lu_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_lu_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_matrix_norm_xpu_complex128", @@ -1568,13 +1538,10 @@ "test_fn_fwgrad_bwgrad_linalg_qr_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_qr_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_slogdet_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_slogdet_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_solve_ex_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_solve_ex_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_solve_triangular_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_solve_triangular_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_solve_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_solve_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_svd_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_svd_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_svdvals_xpu_complex128", @@ -1586,7 +1553,6 @@ "test_fn_fwgrad_bwgrad_logdet_xpu_complex128", "test_fn_fwgrad_bwgrad_logdet_xpu_float64", "test_fn_fwgrad_bwgrad_lu_solve_xpu_complex128", - "test_fn_fwgrad_bwgrad_lu_solve_xpu_float64", "test_fn_fwgrad_bwgrad_lu_xpu_complex128", "test_fn_fwgrad_bwgrad_lu_xpu_float64", "test_fn_fwgrad_bwgrad_matmul_xpu_complex128", @@ -1647,10 +1613,7 @@ "test_forward_mode_AD_linalg_cholesky_xpu_float64", "test_forward_mode_AD_linalg_cond_xpu_complex128", "test_forward_mode_AD_linalg_cond_xpu_float64", - "test_forward_mode_AD_linalg_det_singular_xpu_complex128", - "test_forward_mode_AD_linalg_det_singular_xpu_float64", "test_forward_mode_AD_linalg_det_xpu_complex128", - "test_forward_mode_AD_linalg_det_xpu_float64", "test_forward_mode_AD_linalg_eig_xpu_complex128", "test_forward_mode_AD_linalg_eig_xpu_float64", "test_forward_mode_AD_linalg_eigh_xpu_complex128", @@ -1668,11 +1631,8 @@ "test_forward_mode_AD_linalg_lstsq_grad_oriented_xpu_complex128", "test_forward_mode_AD_linalg_lstsq_grad_oriented_xpu_float64", "test_forward_mode_AD_linalg_lu_factor_ex_xpu_complex128", - "test_forward_mode_AD_linalg_lu_factor_ex_xpu_float64", "test_forward_mode_AD_linalg_lu_factor_xpu_complex128", - "test_forward_mode_AD_linalg_lu_factor_xpu_float64", "test_forward_mode_AD_linalg_lu_solve_xpu_complex128", - "test_forward_mode_AD_linalg_lu_solve_xpu_float64", "test_forward_mode_AD_linalg_lu_xpu_complex128", "test_forward_mode_AD_linalg_lu_xpu_float64", "test_forward_mode_AD_linalg_matrix_norm_xpu_complex128", @@ -1691,13 +1651,10 @@ "test_forward_mode_AD_linalg_qr_xpu_complex128", "test_forward_mode_AD_linalg_qr_xpu_float64", "test_forward_mode_AD_linalg_slogdet_xpu_complex128", - "test_forward_mode_AD_linalg_slogdet_xpu_float64", 
"test_forward_mode_AD_linalg_solve_ex_xpu_complex128", - "test_forward_mode_AD_linalg_solve_ex_xpu_float64", "test_forward_mode_AD_linalg_solve_triangular_xpu_complex128", "test_forward_mode_AD_linalg_solve_triangular_xpu_float64", "test_forward_mode_AD_linalg_solve_xpu_complex128", - "test_forward_mode_AD_linalg_solve_xpu_float64", "test_forward_mode_AD_linalg_svd_xpu_complex128", "test_forward_mode_AD_linalg_svd_xpu_float64", "test_forward_mode_AD_linalg_svdvals_xpu_complex128", @@ -1709,7 +1666,6 @@ "test_forward_mode_AD_logdet_xpu_complex128", "test_forward_mode_AD_logdet_xpu_float64", "test_forward_mode_AD_lu_solve_xpu_complex128", - "test_forward_mode_AD_lu_solve_xpu_float64", "test_forward_mode_AD_lu_xpu_complex128", "test_forward_mode_AD_lu_xpu_float64", "test_forward_mode_AD_matmul_xpu_complex128", @@ -1884,10 +1840,7 @@ "test_fn_grad_linalg_cholesky_xpu_float64", "test_fn_grad_linalg_cond_xpu_complex128", "test_fn_grad_linalg_cond_xpu_float64", - "test_fn_grad_linalg_det_singular_xpu_complex128", - "test_fn_grad_linalg_det_singular_xpu_float64", "test_fn_grad_linalg_det_xpu_complex128", - "test_fn_grad_linalg_det_xpu_float64", "test_fn_grad_linalg_eig_xpu_complex128", "test_fn_grad_linalg_eig_xpu_float64", "test_fn_grad_linalg_eigh_xpu_complex128", @@ -1904,11 +1857,8 @@ "test_fn_grad_linalg_lstsq_grad_oriented_xpu_complex128", "test_fn_grad_linalg_lstsq_grad_oriented_xpu_float64", "test_fn_grad_linalg_lu_factor_ex_xpu_complex128", - "test_fn_grad_linalg_lu_factor_ex_xpu_float64", "test_fn_grad_linalg_lu_factor_xpu_complex128", - "test_fn_grad_linalg_lu_factor_xpu_float64", "test_fn_grad_linalg_lu_solve_xpu_complex128", - "test_fn_grad_linalg_lu_solve_xpu_float64", "test_fn_grad_linalg_lu_xpu_complex128", "test_fn_grad_linalg_lu_xpu_float64", "test_fn_grad_linalg_matrix_norm_xpu_complex128", @@ -1927,13 +1877,10 @@ "test_fn_grad_linalg_qr_xpu_complex128", "test_fn_grad_linalg_qr_xpu_float64", "test_fn_grad_linalg_slogdet_xpu_complex128", - "test_fn_grad_linalg_slogdet_xpu_float64", "test_fn_grad_linalg_solve_ex_xpu_complex128", - "test_fn_grad_linalg_solve_ex_xpu_float64", "test_fn_grad_linalg_solve_triangular_xpu_complex128", "test_fn_grad_linalg_solve_triangular_xpu_float64", "test_fn_grad_linalg_solve_xpu_complex128", - "test_fn_grad_linalg_solve_xpu_float64", "test_fn_grad_linalg_svd_xpu_complex128", "test_fn_grad_linalg_svd_xpu_float64", "test_fn_grad_linalg_svdvals_xpu_complex128", @@ -1945,7 +1892,6 @@ "test_fn_grad_logdet_xpu_complex128", "test_fn_grad_logdet_xpu_float64", "test_fn_grad_lu_solve_xpu_complex128", - "test_fn_grad_lu_solve_xpu_float64", "test_fn_grad_lu_xpu_complex128", "test_fn_grad_lu_xpu_float64", "test_fn_grad_matmul_xpu_complex128", @@ -2009,7 +1955,6 @@ "test_fn_gradgrad_linalg_cond_xpu_complex128", "test_fn_gradgrad_linalg_cond_xpu_float64", "test_fn_gradgrad_linalg_det_xpu_complex128", - "test_fn_gradgrad_linalg_det_xpu_float64", "test_fn_gradgrad_linalg_eig_xpu_complex128", "test_fn_gradgrad_linalg_eig_xpu_float64", "test_fn_gradgrad_linalg_eigh_xpu_complex128", @@ -2027,11 +1972,8 @@ "test_fn_gradgrad_linalg_lstsq_grad_oriented_xpu_complex128", "test_fn_gradgrad_linalg_lstsq_grad_oriented_xpu_float64", "test_fn_gradgrad_linalg_lu_factor_ex_xpu_complex128", - "test_fn_gradgrad_linalg_lu_factor_ex_xpu_float64", "test_fn_gradgrad_linalg_lu_factor_xpu_complex128", - "test_fn_gradgrad_linalg_lu_factor_xpu_float64", "test_fn_gradgrad_linalg_lu_solve_xpu_complex128", - "test_fn_gradgrad_linalg_lu_solve_xpu_float64", 
"test_fn_gradgrad_linalg_lu_xpu_complex128", "test_fn_gradgrad_linalg_lu_xpu_float64", "test_fn_gradgrad_linalg_matrix_norm_xpu_complex128", @@ -2048,13 +1990,10 @@ "test_fn_gradgrad_linalg_qr_xpu_complex128", "test_fn_gradgrad_linalg_qr_xpu_float64", "test_fn_gradgrad_linalg_slogdet_xpu_complex128", - "test_fn_gradgrad_linalg_slogdet_xpu_float64", "test_fn_gradgrad_linalg_solve_ex_xpu_complex128", - "test_fn_gradgrad_linalg_solve_ex_xpu_float64", "test_fn_gradgrad_linalg_solve_triangular_xpu_complex128", "test_fn_gradgrad_linalg_solve_triangular_xpu_float64", "test_fn_gradgrad_linalg_solve_xpu_complex128", - "test_fn_gradgrad_linalg_solve_xpu_float64", "test_fn_gradgrad_linalg_svd_xpu_complex128", "test_fn_gradgrad_linalg_svd_xpu_float64", "test_fn_gradgrad_linalg_svdvals_xpu_complex128", @@ -2066,7 +2005,6 @@ "test_fn_gradgrad_logdet_xpu_complex128", "test_fn_gradgrad_logdet_xpu_float64", "test_fn_gradgrad_lu_solve_xpu_complex128", - "test_fn_gradgrad_lu_solve_xpu_float64", "test_fn_gradgrad_lu_xpu_complex128", "test_fn_gradgrad_lu_xpu_float64", "test_fn_gradgrad_matmul_xpu_complex128", @@ -2377,10 +2315,7 @@ "test_dispatch_meta_outplace_linalg_cholesky_ex_xpu_float64", "test_dispatch_meta_outplace_linalg_cholesky_xpu_complex", "test_dispatch_meta_outplace_linalg_cholesky_xpu_float64", - "test_dispatch_meta_outplace_linalg_det_singular_xpu_complex", - "test_dispatch_meta_outplace_linalg_det_singular_xpu_float64", "test_dispatch_meta_outplace_linalg_det_xpu_complex", - "test_dispatch_meta_outplace_linalg_det_xpu_float64", "test_dispatch_meta_outplace_linalg_eig_xpu_complex", "test_dispatch_meta_outplace_linalg_eig_xpu_float64", "test_dispatch_meta_outplace_linalg_eigh_xpu_complex", @@ -2403,9 +2338,7 @@ "test_dispatch_meta_outplace_linalg_lstsq_xpu_complex", "test_dispatch_meta_outplace_linalg_lstsq_xpu_float64", "test_dispatch_meta_outplace_linalg_lu_factor_xpu_complex", - "test_dispatch_meta_outplace_linalg_lu_factor_xpu_float64", "test_dispatch_meta_outplace_linalg_lu_solve_xpu_complex", - "test_dispatch_meta_outplace_linalg_lu_solve_xpu_float64", "test_dispatch_meta_outplace_linalg_matrix_power_xpu_complex", "test_dispatch_meta_outplace_linalg_matrix_power_xpu_float64", "test_dispatch_meta_outplace_linalg_matrix_rank_hermitian_xpu_complex", @@ -2423,11 +2356,8 @@ "test_dispatch_meta_outplace_linalg_qr_xpu_complex", "test_dispatch_meta_outplace_linalg_qr_xpu_float64", "test_dispatch_meta_outplace_linalg_slogdet_xpu_complex", - "test_dispatch_meta_outplace_linalg_slogdet_xpu_float64", "test_dispatch_meta_outplace_linalg_solve_ex_xpu_complex", - "test_dispatch_meta_outplace_linalg_solve_ex_xpu_float64", "test_dispatch_meta_outplace_linalg_solve_xpu_complex", - "test_dispatch_meta_outplace_linalg_solve_xpu_float64", "test_dispatch_meta_outplace_linalg_svd_xpu_complex", "test_dispatch_meta_outplace_linalg_svd_xpu_float64", "test_dispatch_meta_outplace_linalg_tensorinv_xpu_complex", @@ -2435,7 +2365,6 @@ "test_dispatch_meta_outplace_logdet_xpu_complex", "test_dispatch_meta_outplace_logdet_xpu_float64", "test_dispatch_meta_outplace_lu_solve_xpu_complex", - "test_dispatch_meta_outplace_lu_solve_xpu_float64", "test_dispatch_meta_outplace_lu_xpu_complex", "test_dispatch_meta_outplace_lu_xpu_float64", "test_dispatch_meta_outplace_matmul_xpu_complex", @@ -2506,10 +2435,7 @@ "test_dispatch_symbolic_meta_outplace_linalg_cholesky_ex_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_cholesky_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_cholesky_xpu_float64", - 
"test_dispatch_symbolic_meta_outplace_linalg_det_singular_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_det_singular_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_det_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_det_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_eig_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_eig_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_eigh_xpu_complex", @@ -2532,9 +2458,7 @@ "test_dispatch_symbolic_meta_outplace_linalg_lstsq_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_lstsq_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_lu_factor_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_lu_factor_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_lu_solve_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_lu_solve_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_matrix_power_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_matrix_power_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_matrix_rank_hermitian_xpu_complex", @@ -2552,11 +2476,8 @@ "test_dispatch_symbolic_meta_outplace_linalg_qr_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_qr_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_slogdet_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_slogdet_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_solve_ex_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_solve_ex_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_solve_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_solve_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_svd_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_svd_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_tensorinv_xpu_complex", @@ -2564,7 +2485,6 @@ "test_dispatch_symbolic_meta_outplace_logdet_xpu_complex", "test_dispatch_symbolic_meta_outplace_logdet_xpu_float64", "test_dispatch_symbolic_meta_outplace_lu_solve_xpu_complex", - "test_dispatch_symbolic_meta_outplace_lu_solve_xpu_float64", "test_dispatch_symbolic_meta_outplace_lu_xpu_complex", "test_dispatch_symbolic_meta_outplace_lu_xpu_float64", "test_dispatch_symbolic_meta_outplace_matmul_xpu_complex", @@ -2635,10 +2555,7 @@ "test_meta_outplace_linalg_cholesky_ex_xpu_float64", "test_meta_outplace_linalg_cholesky_xpu_complex", "test_meta_outplace_linalg_cholesky_xpu_float64", - "test_meta_outplace_linalg_det_singular_xpu_complex", - "test_meta_outplace_linalg_det_singular_xpu_float64", "test_meta_outplace_linalg_det_xpu_complex", - "test_meta_outplace_linalg_det_xpu_float64", "test_meta_outplace_linalg_eig_xpu_complex", "test_meta_outplace_linalg_eig_xpu_float64", "test_meta_outplace_linalg_eigh_xpu_complex", @@ -2661,9 +2578,7 @@ "test_meta_outplace_linalg_lstsq_xpu_complex", "test_meta_outplace_linalg_lstsq_xpu_float64", "test_meta_outplace_linalg_lu_factor_xpu_complex", - "test_meta_outplace_linalg_lu_factor_xpu_float64", "test_meta_outplace_linalg_lu_solve_xpu_complex", - "test_meta_outplace_linalg_lu_solve_xpu_float64", "test_meta_outplace_linalg_matrix_power_xpu_complex", "test_meta_outplace_linalg_matrix_power_xpu_float64", "test_meta_outplace_linalg_matrix_rank_hermitian_xpu_complex", @@ -2681,11 +2596,8 @@ "test_meta_outplace_linalg_qr_xpu_complex", "test_meta_outplace_linalg_qr_xpu_float64", "test_meta_outplace_linalg_slogdet_xpu_complex", - "test_meta_outplace_linalg_slogdet_xpu_float64", 
"test_meta_outplace_linalg_solve_ex_xpu_complex", - "test_meta_outplace_linalg_solve_ex_xpu_float64", "test_meta_outplace_linalg_solve_xpu_complex", - "test_meta_outplace_linalg_solve_xpu_float64", "test_meta_outplace_linalg_svd_xpu_complex", "test_meta_outplace_linalg_svd_xpu_float64", "test_meta_outplace_linalg_tensorinv_xpu_complex", @@ -2693,7 +2605,6 @@ "test_meta_outplace_logdet_xpu_complex", "test_meta_outplace_logdet_xpu_float64", "test_meta_outplace_lu_solve_xpu_complex", - "test_meta_outplace_lu_solve_xpu_float64", "test_meta_outplace_lu_xpu_complex", "test_meta_outplace_lu_xpu_float64", "test_meta_outplace_matmul_xpu_complex", diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py index a986e42f4..161abab14 100644 --- a/test/xpu/test_linalg_xpu.py +++ b/test/xpu/test_linalg_xpu.py @@ -13,10 +13,14 @@ instantiate_device_type_tests, precisionOverride, ) -from torch.testing._internal.common_dtype import floating_and_complex_types_and +from torch.testing._internal.common_dtype import ( + floating_and_complex_types, + floating_and_complex_types_and, +) from torch.testing._internal.common_mkldnn import bf32_on_and_off from torch.testing._internal.common_utils import ( IS_WINDOWS, + make_fullrank_matrices_with_distinct_singular_values, parametrize, run_tests, setBlasBackendsToDefaultFinally, @@ -392,6 +396,192 @@ def ck_blas_library(self): pass +@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) +@dtypes(*floating_and_complex_types()) +def linalg_lu_family(self, device, dtype): + # Tests torch.lu + # torch.linalg.lu_factor + # torch.linalg.lu_factor_ex + # torch.lu_unpack + # torch.linalg.lu_solve + # torch.linalg.solve + make_arg_full = partial( + make_fullrank_matrices_with_distinct_singular_values, device=device, dtype=dtype + ) + make_arg = partial(make_tensor, device=device, dtype=dtype) + + def run_test(A, pivot, singular, fn): + k = min(A.shape[-2:]) + batch = A.shape[:-2] + check_errors = fn == torch.linalg.lu_factor + if singular and check_errors: + # It may or may not throw as the LU decomposition without pivoting + # may still succeed for singular matrices + try: + LU, pivots = fn(A, pivot=pivot) + except RuntimeError: + return + else: + LU, pivots = fn(A, pivot=pivot)[:2] + + self.assertEqual(LU.size(), A.shape) + self.assertEqual(pivots.size(), batch + (k,)) + + if not pivot: + self.assertEqual( + pivots, + torch.arange(1, 1 + k, device=device, dtype=torch.int32).expand( + batch + (k,) + ), + ) + + P, L, U = torch.lu_unpack(LU, pivots, unpack_pivots=pivot) + + self.assertEqual(P @ L @ U if pivot else L @ U, A) + + PLU = torch.linalg.lu(A, pivot=pivot) + self.assertEqual(P, PLU.P) + self.assertEqual(L, PLU.L) + self.assertEqual(U, PLU.U) + + if not singular and A.size(-2) == A.size(-1): + nrhs = ((), (1,), (3,)) + for left, rhs in itertools.product((True, False), nrhs): + # Vector case when left = False is not allowed + if not left and rhs == (): + continue + if left: + shape_B = A.shape[:-1] + rhs + else: + shape_B = A.shape[:-2] + rhs + A.shape[-1:] + B = make_arg(shape_B) + + # Test linalg.lu_solve. 
It does not support vectors as rhs + # See https://github.com/pytorch/pytorch/pull/74045#issuecomment-1112304913 + if rhs != (): + for adjoint in (True, False): + X = torch.linalg.lu_solve( + LU, pivots, B, left=left, adjoint=adjoint + ) + A_adj = A.mH if adjoint else A + if left: + self.assertEqual(B, A_adj @ X) + else: + self.assertEqual(B, X @ A_adj) + + # Test linalg.solve + X = torch.linalg.solve(A, B, left=left) + X_ = X.unsqueeze(-1) if rhs == () else X + B_ = B.unsqueeze(-1) if rhs == () else B + if left: + self.assertEqual(B_, A @ X_) + else: + self.assertEqual(B_, X_ @ A) + + sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0)) + batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5)) + # Non pivoting just implemented for CUDA + pivots = (True, False) if self.device_type == "cuda" else (True,) + fns = ( + partial(torch.lu, get_infos=True), + torch.linalg.lu_factor, + torch.linalg.lu_factor_ex, + ) + for ms, batch, pivot, singular, fn in itertools.product( + sizes, batches, pivots, (True, False), fns + ): + shape = batch + ms + A = make_arg(shape) if singular else make_arg_full(*shape) + # Just do one of them on singular matrices + if A.numel() == 0 and not singular: + continue + run_test(A, pivot, singular, fn) + + # Reproducer of a magma bug, + # see https://bitbucket.org/icl/magma/issues/13/getrf_batched-kernel-produces-nans-on + # This is also a bug in cuSOLVER < 11.3 + if dtype == torch.double and singular: + A = torch.ones(batch + ms, dtype=dtype, device=device) + run_test(A, pivot, singular, fn) + + # Info should be positive for rank deficient matrices + A = torch.ones(5, 3, 3, device=device) + self.assertTrue((torch.linalg.lu_factor_ex(A, pivot=True).info >= 0).all()) + + if self.device_type == "cpu": + # Error checking, no pivoting variant on CPU + fns = [ + torch.lu, + torch.linalg.lu_factor, + torch.linalg.lu_factor_ex, + torch.linalg.lu, + ] + for f in fns: + with self.assertRaisesRegex( + RuntimeError, "LU without pivoting is not implemented on the CPU" + ): + f(torch.empty(1, 2, 2), pivot=False) + + +@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) +@setLinalgBackendsToDefaultFinally +@dtypes(*floating_and_complex_types()) +def linalg_lu_solve(self, device, dtype): + make_arg = partial(make_tensor, dtype=dtype, device=device) + + backends = ["default"] + + if torch.device(device).type == "cuda": + if torch.cuda.has_magma: + backends.append("magma") + + def gen_matrices(): + rhs = 3 + ns = (5, 2, 0) + batches = ((), (0,), (1,), (2,), (2, 1), (0, 2)) + for batch, n in itertools.product(batches, ns): + yield make_arg(batch + (n, n)), make_arg(batch + (n, rhs)) + # Shapes to exercise all the paths + shapes = ((1, 64), (2, 128), (1025, 2)) + for b, n in shapes: + yield make_arg((b, n, n)), make_arg((b, n, rhs)) + + for A, B in gen_matrices(): + LU, pivots = torch.linalg.lu_factor(A) + for backend in backends: + torch.backends.cuda.preferred_linalg_library(backend) + + for left, adjoint in itertools.product((True, False), repeat=2): + B_left = B if left else B.mT + X = torch.linalg.lu_solve( + LU, pivots, B_left, left=left, adjoint=adjoint + ) + A_adj = A.mH if adjoint else A + if left: + self.assertEqual(B_left, A_adj @ X) + else: + self.assertEqual(B_left, X @ A_adj) + + +@dtypes(torch.double) +def lu_unpack_check_input(self, device, dtype): + x = torch.rand(5, 5, 5, device=device, dtype=dtype) + lu_data, lu_pivots = torch.linalg.lu_factor(x) + + with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"): + torch.lu_unpack(lu_data, 
lu_pivots.long()) + + # check that onces flags are unset, Nones are returned + p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False) + self.assertTrue(l.numel() == 0 and u.numel() == 0) + p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False) + self.assertTrue(p.numel() == 0) + p, l, u = torch.lu_unpack( + lu_data, lu_pivots, unpack_data=False, unpack_pivots=False + ) + self.assertTrue(p.numel() == 0 and l.numel() == 0 and u.numel() == 0) + + with XPUPatchForImport(False): from test_linalg import TestLinalg @@ -410,6 +600,10 @@ def ck_blas_library(self): TestLinalg.test_matmul_small_brute_force_2d_Nd = matmul_small_brute_force_2d_Nd TestLinalg.test_matmul_small_brute_force_3d_Nd = matmul_small_brute_force_3d_Nd TestLinalg.test_ck_blas_library = ck_blas_library +TestLinalg.test_linalg_lu_family = linalg_lu_family +TestLinalg.test_linalg_lu_solve = linalg_lu_solve +TestLinalg.test_lu_unpack_check_input = lu_unpack_check_input + TestLinalg._default_dtype_check_enabled = True instantiate_device_type_tests(TestLinalg, globals(), only_for=("xpu"), allow_xpu=True) diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index e206fb1d5..9735b2fac 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -9350,3 +9350,96 @@ variants: function dispatch: XPU: _fft_r2c_xpu_out + +- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + python_module: linalg + structured_delegate: linalg_lu_factor_ex.out + variants: function + +- func: linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) + python_module: linalg + variants: function + structured: True + dispatch: + XPU: linalg_lu_factor_ex_out + +# linalg.lu_solve +- func: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + python_module: linalg + structured_delegate: linalg_lu_solve.out + variants: function + +- func: linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + structured: True + dispatch: + XPU: linalg_lu_solve_out + +- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + structured_delegate: lu_unpack.out + variants: function + +- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + variants: function + structured: True + dispatch: + CPU, CUDA: lu_unpack_out + +# linalg.det +- func: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + structured_delegate: _linalg_det.result + +- func: _linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) + structured: True + dispatch: + XPU: _linalg_det_out + +- func: linalg_det(Tensor A) -> Tensor + python_module: linalg + variants: function + +- func: linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) 
+ python_module: linalg + +# torch.det, alias for torch.linalg.det +- func: det(Tensor self) -> Tensor + variants: function, method + +# linalg.slogdet +- func: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + structured_delegate: _linalg_slogdet.sign + +- func: _linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) + structured: True + dispatch: + XPU: _linalg_slogdet_out + +- func: linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet) + python_module: linalg + +- func: linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + python_module: linalg + +- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + variants: function, method + +- func: slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + variants: function + +- func: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + structured_delegate: _linalg_solve_ex.result + +- func: _linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) + structured: True + dispatch: + XPU: _linalg_solve_ex_out + +- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info) + python_module: linalg + +- func: linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info) + python_module: linalg + +- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor + python_module: linalg
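
Not part of the patch: below is a minimal smoke-test sketch of how the operators this change removes from the XPU fallback list (linalg.lu_factor, linalg.lu_solve, linalg.det, linalg.slogdet, linalg.solve) could be exercised end to end against CPU references. It assumes a PyTorch build with XPU support (torch.xpu.is_available() returning True); the helper name check_lu_ops_on_xpu, the sizes, and the tolerances are illustrative only. Whether the build defines USE_ONEMKL_XPU (oneMKL getrf_batch/getrs_batch path) or not (the CPU round-trip fallback in lu_factor_kernel_xpu / lu_solve_kernel_xpu), the XPU results should agree with the CPU implementation.

# Hypothetical smoke test, not part of this diff.
import torch

def check_lu_ops_on_xpu(batch=4, n=16, dtype=torch.float32, atol=1e-4):
    if not torch.xpu.is_available():
        print("XPU not available, skipping")
        return

    # Diagonally dominant matrices keep the systems well conditioned.
    a_cpu = torch.randn(batch, n, n, dtype=dtype) + n * torch.eye(n, dtype=dtype)
    b_cpu = torch.randn(batch, n, 3, dtype=dtype)
    a_xpu, b_xpu = a_cpu.to("xpu"), b_cpu.to("xpu")

    # lu_factor / lu_solve dispatch through lu_factor_stub / lu_solve_stub,
    # which this patch registers for the XPU dispatch key.
    lu_cpu, piv_cpu = torch.linalg.lu_factor(a_cpu)
    lu_xpu, piv_xpu = torch.linalg.lu_factor(a_xpu)
    x_cpu = torch.linalg.lu_solve(lu_cpu, piv_cpu, b_cpu)
    x_xpu = torch.linalg.lu_solve(lu_xpu, piv_xpu, b_xpu)
    assert torch.allclose(x_cpu, x_xpu.cpu(), atol=atol)

    # det / slogdet / solve reuse the same LU kernels via the _linalg_det,
    # _linalg_slogdet and _linalg_solve_ex entries added to the yaml above.
    assert torch.allclose(torch.linalg.det(a_cpu),
                          torch.linalg.det(a_xpu).cpu(), atol=atol)
    sign_cpu, logabsdet_cpu = torch.linalg.slogdet(a_cpu)
    sign_xpu, logabsdet_xpu = torch.linalg.slogdet(a_xpu)
    assert torch.allclose(sign_cpu, sign_xpu.cpu(), atol=atol)
    assert torch.allclose(logabsdet_cpu, logabsdet_xpu.cpu(), atol=atol)
    assert torch.allclose(torch.linalg.solve(a_cpu, b_cpu),
                          torch.linalg.solve(a_xpu, b_xpu).cpu(), atol=atol)

if __name__ == "__main__":
    check_lu_ops_on_xpu()

The same checks are covered more thoroughly by the linalg_lu_family and linalg_lu_solve tests added to test_linalg_xpu.py in this patch; the sketch above is only a quick standalone sanity check.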