diff --git a/cmake/ONEMKL.cmake b/cmake/ONEMKL.cmake
index 04518f7a6..4b1a0f8c4 100644
--- a/cmake/ONEMKL.cmake
+++ b/cmake/ONEMKL.cmake
@@ -21,3 +21,5 @@
 set(TORCH_XPU_OPS_ONEMKL_LIBRARIES ${ONEMKL_LIBRARIES})
 list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 1 "-Wl,--start-group")
 list(APPEND TORCH_XPU_OPS_ONEMKL_LIBRARIES "-Wl,--end-group")
+list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 0 "-Wl,--no-as-needed")
+list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 2 "-Wl,--as-needed")
diff --git a/src/ATen/CMakeLists.txt b/src/ATen/CMakeLists.txt
index 22e060111..2f3cb2765 100644
--- a/src/ATen/CMakeLists.txt
+++ b/src/ATen/CMakeLists.txt
@@ -7,7 +7,9 @@ file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse
 file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/nested/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")
 
 list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
-list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
+if(USE_ONEMKL)
+  list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
+endif()
 list(APPEND ATen_XPU_NATIVE_CPP_SRCS ${xpu_native_cpp})
 list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl})
 
diff --git a/src/ATen/native/xpu/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/BatchLinearAlgebra.cpp
new file mode 100644
index 000000000..ec55c8f43
--- /dev/null
+++ b/src/ATen/native/xpu/BatchLinearAlgebra.cpp
@@ -0,0 +1,61 @@
+#include
+#include
+#include
+#include
+#if defined(USE_ONEMKL)
+#include
+#endif // USE_ONEMKL
+
+namespace at::native {
+
+void svd_kernel_xpu(
+    const Tensor& A,
+    const bool full_matrices,
+    const bool compute_uv,
+    const c10::optional<c10::string_view>& driver,
+    const Tensor& U,
+    const Tensor& S,
+    const Tensor& Vh,
+    const Tensor& info) {
+#if defined(USE_ONEMKL)
+  native::xpu::svd_mkl(A, full_matrices, compute_uv, driver, U, S, Vh, info);
+#else
+  const auto A_cpu = A.to(
+      A.options().device(kCPU).memory_format(at::MemoryFormat::Contiguous));
+  // U, S, Vh, info are the right size and strides, but these tensors are on GPU
+  // and need to be copied
+  const auto empty_like_cpu = [](const Tensor& t) {
+    return at::empty_like(t, t.options().device(kCPU));
+  };
+
+  auto U_cpu = compute_uv ? empty_like_cpu(U) : Tensor{};
+  auto S_cpu = empty_like_cpu(S);
+  auto Vh_cpu = compute_uv ? empty_like_cpu(Vh) : Tensor{};
+  auto info_cpu = empty_like_cpu(info);
+
+  svd_stub(
+      at::kCPU,
+      A_cpu,
+      full_matrices,
+      compute_uv,
+      driver,
+      U_cpu,
+      S_cpu,
+      Vh_cpu,
+      info_cpu);
+
+  // Copy from CPU back to XPU
+  // We can do a non_blocking copy, as there is an unconditional check of the
+  // infos in the calling function
+  if (compute_uv) {
+    U.copy_(U_cpu);
+    Vh.copy_(Vh_cpu);
+  }
+  S.copy_(S_cpu);
+  info.copy_(info_cpu);
+#endif // USE_ONEMKL
+}
+
+REGISTER_XPU_DISPATCH(svd_stub, &svd_kernel_xpu);
+
+} // namespace at::native
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index dded45ad3..86e02c34c 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -211,7 +211,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_linalg_slogdet.sign",
       "_linalg_solve_ex.result",
       "linalg_solve_triangular",
-      "_linalg_svd.U",
       "lu_unpack.out",
       "ormqr",
       "_scaled_mm",
diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
new file mode 100644
index 000000000..695d73f72
--- /dev/null
+++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp
@@ -0,0 +1,349 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+namespace at::native::xpu {
+
+namespace impl {
+
+#define SYCL_ONEMKL_SUBMIT(q, routine, ...) \
+  {                                         \
+    auto e = (routine(__VA_ARGS__));        \
+    (q).throw_asynchronous();               \
+  }
+
+static inline std::tuple<Tensor, Tensor, Tensor> _create_U_S_VT(
+    const Tensor& input,
+    bool some,
+    bool compute_uv) {
+  auto sizes = input.sizes().vec();
+  int64_t m = input.size(-2), n = input.size(-1);
+  auto k = std::min(m, n);
+
+  sizes[input.dim() - 1] = (compute_uv && some) ? k : m;
+  auto U_strides =
+      at::native::batched_matrix_contiguous_strides(sizes, /*f-contig*=*/true);
+  // U should be a column-major or a batch of column-major matrices
+  // ... x m x ucol will have strides: ...., ucol, 1
+  // We require: ...., 1, m
+
+  Tensor U_empty;
+  U_empty = at::empty_strided(sizes, U_strides, input.options());
+
+  sizes[input.dim() - 2] = some ? k : n;
+  sizes[input.dim() - 1] = n;
+  auto Vh_strides =
+      at::native::batched_matrix_contiguous_strides(sizes, /*f-contig*=*/true);
+  Tensor VT_empty;
+  VT_empty = at::empty_strided(sizes, Vh_strides, input.options());
+
+  sizes.pop_back();
+  sizes[input.dim() - 2] = std::min(m, n);
+  Tensor S_empty;
+  ScalarType dtype = toRealValueType(typeMetaToScalarType(input.dtype()));
+  S_empty = at::empty(sizes, input.options().dtype(dtype));
+  return std::tuple<Tensor, Tensor, Tensor>(U_empty, S_empty, VT_empty);
+}
+
+template <typename scalar_t, typename value_t>
+static void apply_svd(
+    sycl::queue& queue,
+    scalar_t* self_data,
+    int64_t lda,
+    int64_t self_stride,
+    int64_t batchsize,
+    int64_t m,
+    int64_t n,
+    TensorOptions self_opt,
+    scalar_t* U_data,
+    int64_t ldu,
+    int64_t U_stride,
+    value_t* S_data,
+    int64_t S_stride,
+    scalar_t* VT_data,
+    int64_t ldvt,
+    int64_t VT_stride,
+    char jobz) {
+  oneapi::mkl::jobsvd jobu, jobvt;
+  if (jobz == 'N') {
+    jobu = oneapi::mkl::jobsvd::N;
+    jobvt = oneapi::mkl::jobsvd::N;
+  } else if (jobz == 'S') {
+    jobu = oneapi::mkl::jobsvd::S;
+    jobvt = oneapi::mkl::jobsvd::S;
+  } else {
+    jobu = oneapi::mkl::jobsvd::A;
+    jobvt = oneapi::mkl::jobsvd::A;
+  }
+
+  std::int64_t scratchpadsize =
+      oneapi::mkl::lapack::gesvd_scratchpad_size<scalar_t>(
+          queue, jobu, jobvt, m, n, lda, ldu, ldvt);
+  Tensor scratchpad_at = at::empty({scratchpadsize}, self_opt);
+
+  for (int64_t i = 0; i < batchsize; i++) {
+    scalar_t* self_working_ptr = &self_data[i * self_stride];
+    scalar_t* U_working_ptr = &U_data[i * U_stride];
+    value_t* S_working_ptr = &S_data[i * S_stride];
+    scalar_t* VT_working_ptr = &VT_data[i * VT_stride];
+
+    SYCL_ONEMKL_SUBMIT(
+        queue,
+        oneapi::mkl::lapack::gesvd,
+        queue,
+        jobu,
+        jobvt,
+        m,
+        n,
+        self_working_ptr,
+        lda,
+        S_working_ptr,
+        U_working_ptr,
+        ldu,
+        VT_working_ptr,
+        ldvt,
+        scratchpad_at.data_ptr<scalar_t>(),
+        scratchpadsize);
+  }
+}
+
+template <>
+void apply_svd<c10::complex<double>, double>(
+    sycl::queue& queue,
+    c10::complex<double>* self_data,
+    int64_t lda,
+    int64_t self_stride,
+    int64_t batchsize,
+    int64_t m,
+    int64_t n,
+    TensorOptions self_opt,
+    c10::complex<double>* U_data,
+    int64_t ldu,
+    int64_t U_stride,
+    double* S_data,
+    int64_t S_stride,
+    c10::complex<double>* VT_data,
+    int64_t ldvt,
+    int64_t VT_stride,
+    char jobz) {
+  oneapi::mkl::jobsvd jobu, jobvt;
+  if (jobz == 'N') {
+    jobu = oneapi::mkl::jobsvd::N;
+    jobvt = oneapi::mkl::jobsvd::N;
+  } else if (jobz == 'S') {
+    jobu = oneapi::mkl::jobsvd::S;
+    jobvt = oneapi::mkl::jobsvd::S;
+  } else {
+    jobu = oneapi::mkl::jobsvd::A;
+    jobvt = oneapi::mkl::jobsvd::A;
+  }
+
+  std::int64_t scratchpadsize =
+      oneapi::mkl::lapack::gesvd_scratchpad_size<std::complex<double>>(
+          queue, jobu, jobvt, m, n, lda, ldu, ldvt);
+  Tensor scratchpad_at = at::empty({scratchpadsize}, self_opt);
+
+  for (int64_t i = 0; i < batchsize; i++) {
+    c10::complex<double>* self_working_ptr = &self_data[i * self_stride];
+    c10::complex<double>* U_working_ptr = &U_data[i * U_stride];
+    double* S_working_ptr = &S_data[i * S_stride];
+    c10::complex<double>* VT_working_ptr = &VT_data[i * VT_stride];
+
+    SYCL_ONEMKL_SUBMIT(
+        queue,
+        oneapi::mkl::lapack::gesvd,
+        queue,
+        jobu,
+        jobvt,
+        m,
+        n,
+        reinterpret_cast<std::complex<double>*>(self_working_ptr),
+        lda,
+        S_working_ptr,
+        reinterpret_cast<std::complex<double>*>(U_working_ptr),
+        ldu,
+        reinterpret_cast<std::complex<double>*>(VT_working_ptr),
+        ldvt,
+        reinterpret_cast<std::complex<double>*>(scratchpad_at.data_ptr()),
+        scratchpadsize);
+  }
+}
+
+template <>
+void apply_svd<c10::complex<float>, float>(
+    sycl::queue& queue,
+    c10::complex<float>* self_data,
+    int64_t lda,
+    int64_t self_stride,
+    int64_t batchsize,
+    int64_t m,
+    int64_t n,
+    TensorOptions self_opt,
+    c10::complex<float>* U_data,
+    int64_t ldu,
+    int64_t U_stride,
+    float* S_data,
+    int64_t S_stride,
+    c10::complex<float>* VT_data,
+    int64_t ldvt,
+    int64_t VT_stride,
+    char jobz) {
+  oneapi::mkl::jobsvd jobu, jobvt;
+  if (jobz == 'N') {
+    jobu = oneapi::mkl::jobsvd::N;
+    jobvt = oneapi::mkl::jobsvd::N;
+  } else if (jobz == 'S') {
+    jobu = oneapi::mkl::jobsvd::S;
+    jobvt = oneapi::mkl::jobsvd::S;
+  } else {
+    jobu = oneapi::mkl::jobsvd::A;
+    jobvt = oneapi::mkl::jobsvd::A;
+  }
+
+  std::int64_t scratchpadsize =
+      oneapi::mkl::lapack::gesvd_scratchpad_size<std::complex<float>>(
+          queue, jobu, jobvt, m, n, lda, ldu, ldvt);
+  Tensor scratchpad_at = at::empty({scratchpadsize}, self_opt);
+
+  for (int64_t i = 0; i < batchsize; i++) {
+    c10::complex<float>* self_working_ptr = &self_data[i * self_stride];
+    c10::complex<float>* U_working_ptr = &U_data[i * U_stride];
+    float* S_working_ptr = &S_data[i * S_stride];
+    c10::complex<float>* VT_working_ptr = &VT_data[i * VT_stride];
+
+    SYCL_ONEMKL_SUBMIT(
+        queue,
+        oneapi::mkl::lapack::gesvd,
+        queue,
+        jobu,
+        jobvt,
+        m,
+        n,
+        reinterpret_cast<std::complex<float>*>(self_working_ptr),
+        lda,
+        S_working_ptr,
+        reinterpret_cast<std::complex<float>*>(U_working_ptr),
+        ldu,
+        reinterpret_cast<std::complex<float>*>(VT_working_ptr),
+        ldvt,
+        reinterpret_cast<std::complex<float>*>(scratchpad_at.data_ptr()),
+        scratchpadsize);
+  }
+}
+
+} // namespace impl
+
+std::tuple<Tensor, Tensor, Tensor> _svd_helper(
+    const Tensor& self,
+    bool some,
+    bool compute_uv) {
+  auto infos_tensor = at::zeros(
+      native::batchCount(self),
+      self.options().dtype(kInt).device(DeviceType::CPU));
+  std::vector<int64_t> infos(native::batchCount(self), 0);
+
+  char jobz = compute_uv ? (some ? 'S' : 'A') : 'N';
+
+  Tensor U_working_copy, S_working_copy, VT_working_copy;
+  std::tie(U_working_copy, S_working_copy, VT_working_copy) =
+      impl::_create_U_S_VT(self, some, compute_uv);
+
+  if (self.numel() > 0) {
+    auto self_working_copy = native::cloneBatchedColumnMajor(self);
+    auto& queue = at::xpu::getCurrentSYCLQueue();
+    auto self_stride = at::native::matrixStride(self_working_copy);
+    auto U_stride = compute_uv ? at::native::matrixStride(U_working_copy) : 1;
+    auto S_stride = S_working_copy.size(-1);
+    auto VT_stride = compute_uv ? at::native::matrixStride(VT_working_copy) : 1;
+    auto batchsize = at::native::batchCount(self_working_copy);
+
+    auto m = self_working_copy.size(-2);
+    auto n = self_working_copy.size(-1);
+    int64_t lda = self_working_copy.stride(-1);
+    int64_t ldu = compute_uv ? U_working_copy.stride(-1) : 1;
+    int64_t ldvt = compute_uv ? VT_working_copy.stride(-1) : 1;
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_xpu", [&] {
+      using value_t = typename c10::scalar_value_type<scalar_t>::type;
+      impl::apply_svd<scalar_t, value_t>(
+          queue,
+          self_working_copy.data_ptr<scalar_t>(),
+          lda,
+          self_stride,
+          batchsize,
+          m,
+          n,
+          self.options(),
+          compute_uv ? U_working_copy.data_ptr<scalar_t>() : nullptr,
+          ldu,
+          U_stride,
+          S_working_copy.data_ptr<value_t>(),
+          S_stride,
+          compute_uv ? VT_working_copy.data_ptr<scalar_t>() : nullptr,
+          ldvt,
+          VT_stride,
+          jobz);
+    });
+
+    std::copy(
+        infos.begin(), infos.end(), infos_tensor.template data_ptr<int>());
+    at::_linalg_check_errors(infos_tensor, "svd_xpu", self.dim() == 2);
+
+    if (!compute_uv) {
+      VT_working_copy.zero_();
+      U_working_copy.zero_();
+    }
+  } else {
+    U_working_copy.zero_();
+    VT_working_copy.zero_();
+  }
+
+  return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
+}
+
+static void svd_resize_and_copy(
+    const char* name,
+    const Tensor& src,
+    const Tensor& dst) {
+  TORCH_CHECK(
+      src.device() == dst.device(),
+      "svd output tensor ",
+      name,
+      " is on the wrong device: expected ",
+      src.device(),
+      " got ",
+      dst.device());
+  at::native::resize_output(dst, src.sizes());
+  dst.copy_(src);
+}
+
+void svd_mkl(
+    const Tensor& A,
+    const bool full_matrices,
+    const bool compute_uv,
+    const c10::optional<c10::string_view>& driver,
+    const Tensor& U,
+    const Tensor& S,
+    const Tensor& Vh,
+    const Tensor& info) {
+  Tensor U_tmp, S_tmp, Vh_tmp;
+  bool some = !full_matrices;
+  std::tie(U_tmp, S_tmp, Vh_tmp) = _svd_helper(A, some, compute_uv);
+
+  // TODO: Remove copy
+  if (compute_uv) {
+    svd_resize_and_copy("U", U_tmp, U);
+    svd_resize_and_copy("Vh", Vh_tmp, Vh);
+  }
+  svd_resize_and_copy("S", S_tmp, S);
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
new file mode 100644
index 000000000..a808a31ba
--- /dev/null
+++ b/src/ATen/native/xpu/mkl/BatchLinearAlgebra.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include
+
+namespace at::native::xpu {
+
+TORCH_XPU_API void svd_mkl(
+    const Tensor& A,
+    const bool full_matrices,
+    const bool compute_uv,
+    const c10::optional<c10::string_view>& driver,
+    const Tensor& U,
+    const Tensor& S,
+    const Tensor& Vh,
+    const Tensor& info);
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp
index 4f1e028b4..69e25621d 100644
--- a/src/ATen/native/xpu/mkl/SpectralOps.cpp
+++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp
@@ -1,4 +1,3 @@
-#if defined(USE_ONEMKL)
 #include
 #include
 #include
@@ -591,4 +590,3 @@ Tensor& _fft_r2c_mkl_out(
 }
 
 } // namespace at::native::xpu
-#endif // USE_ONEMKL
diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index 98245cf49..431fe6b48 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -317,18 +317,15 @@
         "test_out_warning_nn_functional_linear_xpu",
         "test_python_ref__refs_linalg_svd_xpu_complex128",
         "test_python_ref__refs_linalg_svd_xpu_complex64",
-        "test_python_ref__refs_linalg_svd_xpu_float64",
         "test_python_ref_executor__refs_linalg_svd_executor_aten_xpu_complex128",
         "test_python_ref_executor__refs_linalg_svd_executor_aten_xpu_complex64",
         "test_python_ref_executor__refs_linalg_svd_executor_aten_xpu_float64",
         "test_python_ref_executor__refs_nn_functional_pdist_executor_aten_xpu_float64",
         "test_python_ref_meta__refs_linalg_svd_xpu_complex128",
         "test_python_ref_meta__refs_linalg_svd_xpu_complex64",
-        "test_python_ref_meta__refs_linalg_svd_xpu_float64",
         "test_python_ref_meta__refs_nn_functional_pdist_xpu_float64",
         "test_python_ref_torch_fallback__refs_linalg_svd_xpu_complex128",
         "test_python_ref_torch_fallback__refs_linalg_svd_xpu_complex64",
-        "test_python_ref_torch_fallback__refs_linalg_svd_xpu_float64",
         "test_python_ref_torch_fallback__refs_nn_functional_pdist_xpu_float64",
         "test_variant_consistency_eager___rmatmul___xpu_complex64",
"test_variant_consistency_eager_addmm_decomposed_xpu_complex64", @@ -530,7 +527,6 @@ "test_neg_conj_view_tensordot_xpu_complex128", "test_neg_conj_view_triangular_solve_xpu_complex128", "test_neg_view___rmatmul___xpu_float64", - "test_neg_view__refs_linalg_svd_xpu_float64", "test_neg_view__refs_nn_functional_pdist_xpu_float64", "test_neg_view_addbmm_xpu_float64", "test_neg_view_addmm_decomposed_xpu_float64", @@ -581,7 +577,6 @@ "test_neg_view_linalg_solve_ex_xpu_float64", "test_neg_view_linalg_solve_triangular_xpu_float64", "test_neg_view_linalg_solve_xpu_float64", - "test_neg_view_linalg_svd_xpu_float64", "test_neg_view_linalg_svdvals_xpu_float64", "test_neg_view_linalg_tensorinv_xpu_float64", "test_neg_view_linalg_tensorsolve_xpu_float64", @@ -601,7 +596,6 @@ "test_neg_view_pinverse_xpu_float64", "test_neg_view_qr_xpu_float64", "test_neg_view_svd_lowrank_xpu_float64", - "test_neg_view_svd_xpu_float64", "test_neg_view_tensordot_xpu_float64", "test_neg_view_triangular_solve_xpu_float64", "test_noncontiguous_samples_pca_lowrank_xpu_complex64", @@ -1395,7 +1389,6 @@ "test_svd_lowrank_xpu_float64", "test_svd_xpu_complex128", "test_svd_xpu_complex64", - "test_svd_xpu_float64", "test_triangular_solve_batched_broadcasting_xpu_complex128", "test_triangular_solve_batched_broadcasting_xpu_complex64", "test_triangular_solve_batched_broadcasting_xpu_float64", @@ -1576,7 +1569,6 @@ "test_fn_fwgrad_bwgrad_linalg_solve_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_solve_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_svd_xpu_complex128", - "test_fn_fwgrad_bwgrad_linalg_svd_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_svdvals_xpu_complex128", "test_fn_fwgrad_bwgrad_linalg_svdvals_xpu_float64", "test_fn_fwgrad_bwgrad_linalg_tensorinv_xpu_complex128", @@ -1611,7 +1603,6 @@ "test_fn_fwgrad_bwgrad_qr_xpu_float64", "test_fn_fwgrad_bwgrad_svd_lowrank_xpu_float64", "test_fn_fwgrad_bwgrad_svd_xpu_complex128", - "test_fn_fwgrad_bwgrad_svd_xpu_float64", "test_fn_fwgrad_bwgrad_tensordot_xpu_complex128", "test_fn_fwgrad_bwgrad_tensordot_xpu_float64", "test_forward_mode_AD___rmatmul___xpu_complex128", @@ -1699,7 +1690,6 @@ "test_forward_mode_AD_linalg_solve_xpu_complex128", "test_forward_mode_AD_linalg_solve_xpu_float64", "test_forward_mode_AD_linalg_svd_xpu_complex128", - "test_forward_mode_AD_linalg_svd_xpu_float64", "test_forward_mode_AD_linalg_svdvals_xpu_complex128", "test_forward_mode_AD_linalg_svdvals_xpu_float64", "test_forward_mode_AD_linalg_tensorinv_xpu_complex128", @@ -1730,7 +1720,6 @@ "test_forward_mode_AD_qr_xpu_float64", "test_forward_mode_AD_svd_lowrank_xpu_float64", "test_forward_mode_AD_svd_xpu_complex128", - "test_forward_mode_AD_svd_xpu_float64", "test_forward_mode_AD_tensordot_xpu_complex128", "test_forward_mode_AD_tensordot_xpu_float64", "test_forward_mode_AD_triangular_solve_xpu_complex128", @@ -1935,7 +1924,6 @@ "test_fn_grad_linalg_solve_xpu_complex128", "test_fn_grad_linalg_solve_xpu_float64", "test_fn_grad_linalg_svd_xpu_complex128", - "test_fn_grad_linalg_svd_xpu_float64", "test_fn_grad_linalg_svdvals_xpu_complex128", "test_fn_grad_linalg_svdvals_xpu_float64", "test_fn_grad_linalg_tensorinv_xpu_complex128", @@ -1970,7 +1958,6 @@ "test_fn_grad_qr_xpu_float64", "test_fn_grad_svd_lowrank_xpu_float64", "test_fn_grad_svd_xpu_complex128", - "test_fn_grad_svd_xpu_float64", "test_fn_grad_tensordot_xpu_complex128", "test_fn_grad_tensordot_xpu_float64", "test_fn_grad_triangular_solve_xpu_complex128", @@ -2056,7 +2043,6 @@ "test_fn_gradgrad_linalg_solve_xpu_complex128", 
"test_fn_gradgrad_linalg_solve_xpu_float64", "test_fn_gradgrad_linalg_svd_xpu_complex128", - "test_fn_gradgrad_linalg_svd_xpu_float64", "test_fn_gradgrad_linalg_svdvals_xpu_complex128", "test_fn_gradgrad_linalg_svdvals_xpu_float64", "test_fn_gradgrad_linalg_tensorinv_xpu_complex128", @@ -2091,7 +2077,6 @@ "test_fn_gradgrad_qr_xpu_float64", "test_fn_gradgrad_svd_lowrank_xpu_float64", "test_fn_gradgrad_svd_xpu_complex128", - "test_fn_gradgrad_svd_xpu_float64", "test_fn_gradgrad_tensordot_xpu_complex128", "test_fn_gradgrad_tensordot_xpu_float64", "test_fn_gradgrad_triangular_solve_xpu_complex128", @@ -2429,7 +2414,6 @@ "test_dispatch_meta_outplace_linalg_solve_xpu_complex", "test_dispatch_meta_outplace_linalg_solve_xpu_float64", "test_dispatch_meta_outplace_linalg_svd_xpu_complex", - "test_dispatch_meta_outplace_linalg_svd_xpu_float64", "test_dispatch_meta_outplace_linalg_tensorinv_xpu_complex", "test_dispatch_meta_outplace_linalg_tensorinv_xpu_float64", "test_dispatch_meta_outplace_logdet_xpu_complex", @@ -2458,7 +2442,6 @@ "test_dispatch_meta_outplace_svd_lowrank_xpu_complex", "test_dispatch_meta_outplace_svd_lowrank_xpu_float64", "test_dispatch_meta_outplace_svd_xpu_complex", - "test_dispatch_meta_outplace_svd_xpu_float64", "test_dispatch_meta_outplace_tensordot_xpu_complex", "test_dispatch_meta_outplace_tensordot_xpu_float64", "test_dispatch_meta_outplace_triangular_solve_xpu_complex", @@ -2558,7 +2541,6 @@ "test_dispatch_symbolic_meta_outplace_linalg_solve_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_solve_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_svd_xpu_complex", - "test_dispatch_symbolic_meta_outplace_linalg_svd_xpu_float64", "test_dispatch_symbolic_meta_outplace_linalg_tensorinv_xpu_complex", "test_dispatch_symbolic_meta_outplace_linalg_tensorinv_xpu_float64", "test_dispatch_symbolic_meta_outplace_logdet_xpu_complex", @@ -2587,7 +2569,6 @@ "test_dispatch_symbolic_meta_outplace_svd_lowrank_xpu_complex", "test_dispatch_symbolic_meta_outplace_svd_lowrank_xpu_float64", "test_dispatch_symbolic_meta_outplace_svd_xpu_complex", - "test_dispatch_symbolic_meta_outplace_svd_xpu_float64", "test_dispatch_symbolic_meta_outplace_tensordot_xpu_complex", "test_dispatch_symbolic_meta_outplace_tensordot_xpu_float64", "test_dispatch_symbolic_meta_outplace_triangular_solve_xpu_complex", @@ -2687,7 +2668,6 @@ "test_meta_outplace_linalg_solve_xpu_complex", "test_meta_outplace_linalg_solve_xpu_float64", "test_meta_outplace_linalg_svd_xpu_complex", - "test_meta_outplace_linalg_svd_xpu_float64", "test_meta_outplace_linalg_tensorinv_xpu_complex", "test_meta_outplace_linalg_tensorinv_xpu_float64", "test_meta_outplace_logdet_xpu_complex", @@ -2716,7 +2696,6 @@ "test_meta_outplace_svd_lowrank_xpu_complex", "test_meta_outplace_svd_lowrank_xpu_float64", "test_meta_outplace_svd_xpu_complex", - "test_meta_outplace_svd_xpu_float64", "test_meta_outplace_tensordot_xpu_complex", "test_meta_outplace_tensordot_xpu_float64", "test_meta_outplace_triangular_solve_xpu_complex", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index e206fb1d5..be242de40 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -9350,3 +9350,22 @@ variants: function dispatch: XPU: _fft_r2c_xpu_out + +# This function exposes the `compute_uv` flag, which is then used to implement `linalg.svd` and +# `linalg.svdvals` as composite functions that call this one +- func: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? 
+- func: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+  variants: function
+  structured_delegate: _linalg_svd.U
+
+- func: _linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+  structured: True
+  dispatch:
+    XPU: _linalg_svd_out
+
+- func: linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+  python_module: linalg
+  variants: function
+
+- func: linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+  python_module: linalg
+  variants: function