diff --git a/WORKSPACE b/WORKSPACE index 65cd2988473..2597b8765c9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -106,6 +106,14 @@ onedal_repo( root_env_var = "DALROOT", ) +http_archive( + name = "boost", + url = "https://archives.boost.io/release/1.86.0/source/boost_1_86_0.tar.gz", + sha256 = "2575e74ffc3ef1cd0babac2c1ee8bdb5782a0ee672b1912da40e5b4b591ca01f", + strip_prefix = "boost_1_86_0", + build_file = "@onedal//dev/bazel/deps:boost.tpl.BUILD", +) + http_archive( name = "catch2", url = "https://github.com/catchorg/Catch2/archive/v3.7.1.tar.gz", @@ -113,6 +121,14 @@ http_archive( strip_prefix = "Catch2-3.7.1", ) +http_archive( + name = "eigen", + url = "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz", + sha256 = "8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72", + build_file = "@onedal//dev/bazel/deps:eigen.tpl.BUILD", + strip_prefix = "eigen-3.4.0", +) + http_archive( name = "fmt", url = "https://github.com/fmtlib/fmt/archive/11.0.2.tar.gz", diff --git a/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp b/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp index 56484014a81..bf3dc1aee8c 100644 --- a/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/lapack/test/syevd_dpc.cpp @@ -128,6 +128,46 @@ class syevd_test : public te::float_algo_fixture { } } + void check_eigvals_with_eigen(const la::matrix& s, + const la::matrix& eigvecs, + const la::matrix& eigvals) const { + INFO("convert results to float64"); + const auto s_f64 = la::astype(s); + const auto eigvals_f64 = la::astype(eigvals); + const auto eigvecs_f64 = la::astype(eigvecs); + std::int64_t row_count = s.get_row_count(); + std::int64_t column_count = s.get_column_count(); + const Float* data = s.get_data(); + + Eigen::Matrix eigen_matrix(row_count, column_count); + for (int i = 0; i < eigen_matrix.rows(); ++i) { + for (int j = 0; j < eigen_matrix.cols(); ++j) { + eigen_matrix(i, j) = data[i * column_count + j]; + } + } + + Eigen::SelfAdjointEigenSolver> es( + eigen_matrix); + + auto eigenvalues = es.eigenvalues().real(); + INFO("oneDAL eigvals vs Eigen eigvals"); + la::enumerate_linear(eigvals_f64, [&](std::int64_t i, Float x) { + REQUIRE(abs(eigvals_f64.get(i) - eigenvalues(i)) < 0.1); + }); + + INFO("oneDAL eigvectors vs Eigen eigvectors"); + auto eigenvectors = es.eigenvectors().real(); + + const double* eigenvec_ptr = eigvecs_f64.get_data(); + //TODO: investigate Eigen classes and align checking between oneDAL and Eigen classes. + for (int j = 0; j < eigvecs.get_column_count(); ++j) { + auto column_eigen = eigenvectors.col(j); + for (int i = 0; i < eigvecs.get_row_count(); ++i) { + REQUIRE((abs(eigenvec_ptr[j * row_count + i]) - abs(column_eigen(i))) < 0.1); + } + } + } + void check_eigvals_are_ascending(const la::matrix& eigvals) const { INFO("check eigenvalues order is ascending"); la::enumerate_linear(eigvals, [&](std::int64_t i, Float x) { @@ -158,6 +198,7 @@ TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix", "[sym_eigvals this->check_eigvals_definition(s, eigenvectors, eigenvalues); this->check_eigvals_are_ascending(eigenvalues); + this->check_eigvals_with_eigen(s, eigenvectors, eigenvalues); } TEMPLATE_LIST_TEST_M(syevd_test, "test syevd with pos def matrix 2", "[sym_eigvals]", eigen_types) { diff --git a/cpp/oneapi/dal/test/engine/BUILD b/cpp/oneapi/dal/test/engine/BUILD index 6432a50021e..5732edd02a2 100644 --- a/cpp/oneapi/dal/test/engine/BUILD +++ b/cpp/oneapi/dal/test/engine/BUILD @@ -22,7 +22,9 @@ dal_test_module( "@onedal//cpp/oneapi/dal/test/engine/metrics", ], extra_deps = [ + "@boost//:boost", "@catch2//:catch2", + "@eigen//:eigen", "@fmt//:fmt", ], ) diff --git a/cpp/oneapi/dal/test/engine/common.hpp b/cpp/oneapi/dal/test/engine/common.hpp index 093e945b55b..765aadfc817 100644 --- a/cpp/oneapi/dal/test/engine/common.hpp +++ b/cpp/oneapi/dal/test/engine/common.hpp @@ -22,6 +22,11 @@ #include #include +//Necessary headers from boost +#include + +#include +#include #include "oneapi/dal/train.hpp" #include "oneapi/dal/infer.hpp" diff --git a/dev/bazel/deps/boost.tpl.BUILD b/dev/bazel/deps/boost.tpl.BUILD new file mode 100644 index 00000000000..e6d5c32872c --- /dev/null +++ b/dev/bazel/deps/boost.tpl.BUILD @@ -0,0 +1,18 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "boost", + srcs = glob([ + "libs/libboost*.a", + ]), + hdrs = glob([ + "boost/**/*.h", + "boost/**/*.hpp", + "boost/**/*.ipp", + ]), + includes = [ + ".", + ], + visibility = ["//visibility:public"], +) + diff --git a/dev/bazel/deps/eigen.tpl.BUILD b/dev/bazel/deps/eigen.tpl.BUILD new file mode 100644 index 00000000000..e4d892ecfcb --- /dev/null +++ b/dev/bazel/deps/eigen.tpl.BUILD @@ -0,0 +1,8 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "eigen", + hdrs = glob(["Eigen/**"]), + includes = [""], + visibility = ["//visibility:public"], +)