Commit d0390b5

cleanup
1 parent e49028b commit d0390b5

8 files changed (+189, -188 lines)


CMakeLists.txt

Lines changed: 27 additions & 36 deletions
@@ -742,42 +742,33 @@ if(ESPRESSO_BUILD_WITH_METATENSOR)
   # expression from `metatensor_torch`
   find_package(Torch REQUIRED)

-  # # cmake-format: off
-  # set(METATENSOR_URL_BASE "https://github.com/lab-cosmo/metatensor/releases/download")
-  # set(METATENSOR_CORE_VERSION "0.1.8")
-  # set(METATENSOR_TORCH_VERSION "0.5.3")
-  #
-  # include(FetchContent)
-  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
-  # FetchContent_Declare(
-  #   metatensor
-  #   URL "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz"
-  #   URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef
-  # )
-  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
-  #
-  # # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146
-  # if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS "${metatensor_SOURCE_DIR}")
-  #   message(STATUS "Fetching metatensor v${METATENSOR_CORE_VERSION} from github")
-  #   FetchContent_Populate(metatensor)
-  # endif()
-  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
-  #
-  # FetchContent_Declare(
-  #   metatensor_torch
-  #   URL "${METATENSOR_URL_BASE}/metatensor-torch-v${METATENSOR_TORCH_VERSION}/metatensor-torch-cxx-${METATENSOR_TORCH_VERSION}.tar.gz"
-  # )
-  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
-  # if(NOT DEFINED metatensor_torch_SOURCE_DIR OR NOT EXISTS "${metatensor_torch_SOURCE_DIR}")
-  #   message(STATUS "Fetching metatensor torch v${METATENSOR_CORE_VERSION} from github")
-  #   FetchContent_Populate(metatensor_torch)
-  # endif()
-  # # cmake-format: on
-  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
-  #
-  # set(METATENSOR_INSTALL_BOTH_STATIC_SHARED on CACHE BOOL "")
-  # add_subdirectory("${metatensor_SOURCE_DIR}")
-  # add_subdirectory("${metatensor_torch_SOURCE_DIR}")
+  # # cmake-format: off set(METATENSOR_URL_BASE
+  # "https://github.com/lab-cosmo/metatensor/releases/download")
+  # set(METATENSOR_CORE_VERSION "0.1.8") set(METATENSOR_TORCH_VERSION "0.5.3")
+  #
+  # include(FetchContent) set(BUILD_SHARED_LIBS on CACHE BOOL "")
+  # FetchContent_Declare( metatensor URL
+  # "${METATENSOR_URL_BASE}/metatensor-core-v${METATENSOR_CORE_VERSION}/metatensor-core-cxx-${METATENSOR_CORE_VERSION}.tar.gz"
+  # URL_HASH SHA1=3ed389770e5ec6dbb8cbc9ed88f84d6809b552ef )
+  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
+  #
+  # # workaround for https://gitlab.kitware.com/cmake/cmake/-/issues/21146
+  # if(NOT DEFINED metatensor_SOURCE_DIR OR NOT EXISTS
+  # "${metatensor_SOURCE_DIR}") message(STATUS "Fetching metatensor
+  # v${METATENSOR_CORE_VERSION} from github") FetchContent_Populate(metatensor)
+  # endif() set(BUILD_SHARED_LIBS on CACHE BOOL "")
+  #
+  # FetchContent_Declare( metatensor_torch URL
+  # "${METATENSOR_URL_BASE}/metatensor-torch-v${METATENSOR_TORCH_VERSION}/metatensor-torch-cxx-${METATENSOR_TORCH_VERSION}.tar.gz"
+  # ) set(BUILD_SHARED_LIBS on CACHE BOOL "") if(NOT DEFINED
+  # metatensor_torch_SOURCE_DIR OR NOT EXISTS "${metatensor_torch_SOURCE_DIR}")
+  # message(STATUS "Fetching metatensor torch v${METATENSOR_CORE_VERSION} from
+  # github") FetchContent_Populate(metatensor_torch) endif() # cmake-format: on
+  # set(BUILD_SHARED_LIBS on CACHE BOOL "")
+  #
+  # set(METATENSOR_INSTALL_BOTH_STATIC_SHARED on CACHE BOOL "")
+  # add_subdirectory("${metatensor_SOURCE_DIR}")
+  # add_subdirectory("${metatensor_torch_SOURCE_DIR}")
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
   find_package(metatensor)
   find_package(metatensor_torch)

src/core/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -74,8 +74,8 @@ if(ESPRESSO_BUILD_WITH_CUDA)
 endif()
 if(ESPRESSO_BUILD_WITH_METATENSOR)
   target_link_libraries(espresso_core PUBLIC "${TORCH_LIBRARIES}")
-  target_link_libraries(espresso_core PUBLIC metatensor::shared)
-  target_link_libraries(espresso_core PUBLIC metatensor_torch)
+  target_link_libraries(espresso_core PUBLIC metatensor::shared)
+  target_link_libraries(espresso_core PUBLIC metatensor_torch)
 endif()

 install(TARGETS espresso_core

src/core/ml_metatensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -18,4 +18,4 @@
 #

 target_sources(espresso_core PRIVATE stub.cpp)
-target_sources(espresso_core PRIVATE load_model.cpp)
+# target_sources(espresso_core PRIVATE load_model.cpp)
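The change above drops load_model.cpp from the build. For orientation only, here is a minimal sketch of how a TorchScript model export is usually loaded through libtorch; the function name load_torchscript_model is a hypothetical placeholder, not the contents of the removed load_model.cpp.

// Hypothetical sketch, not the project's load_model.cpp: deserialize a
// TorchScript archive with torch::jit::load and move it to the target device.
#include <torch/script.h>

#include <stdexcept>
#include <string>

torch::jit::Module load_torchscript_model(const std::string &path,
                                           torch::Device device) {
  try {
    auto module = torch::jit::load(path); // read the serialized module from disk
    module.to(device);                    // move parameters/buffers to the device
    module.eval();                        // inference mode (disable dropout etc.)
    return module;
  } catch (const c10::Error &e) {
    throw std::runtime_error("could not load model '" + path + "': " + e.what());
  }
}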
Lines changed: 57 additions & 57 deletions
@@ -1,71 +1,71 @@
+#include "metatensor/torch/atomistic/system.hpp"
+#include "utils/Vector.hpp"
+#include <variant>
+
 struct PairInfo {
-  int part_id_1,
-  int part_id_2,
+  int part_id_1;
+  int part_id_2;
   Utils::Vector3d distance;
-}
-
-using Sample = std::array<int_32_t,5>;
-using Distances =
-    std::variant<std::vector<std::array<double,3>>, std::vector<std::array<float,3>>>;
+};

+using Sample = std::array<int32_t, 5>;
+using Distances = std::variant<std::vector<std::array<double, 3>>,
+                               std::vector<std::array<float, 3>>>;

 template <typename PairIterable>
-TorchTensorBlock neighbor_list_from_pairs(const metatensor_torch::System& system, const PairIterable& pairs) {
-  auto dtype = system->positions().scalar_type();
-  auto device = system->positions().device();
-  std::vector<Sample> samples;
-  Distances distances;
-  if (dtype == torch::kFloat64) {
-    distances = {std::vector<std::array<double,3>>()};
-  }
-  else if (dtype == torch::kFloat32) {
-    distances = {std::vector<std::array<float,3>>()};
-  }
-  else {
-    throw std::runtime_error("Unsupported floating poitn data type");
-  }
+metatensor_torch::TorchTensorBlock
+neighbor_list_from_pairs(const metatensor_torch::System &system,
+                         const PairIterable &pairs) {
+  auto dtype = system->positions().scalar_type();
+  auto device = system->positions().device();
+  std::vector<Sample> samples;
+  Distances distances;

-  for (auto const& pair: pairs) {
-    auto sample = Sample{
-        pair.particle_id_1, pair.particle_id_2, 0, 0, 0};
-    samples.push_back(sample);
-    (*distances).push_back(pair.distance);
-  }
+  if (dtype == torch::kFloat64) {
+    distances = {std::vector<std::array<double, 3>>()};
+  } else if (dtype == torch::kFloat32) {
+    distances = {std::vector<std::array<float, 3>>()};
+  } else {
+    throw std::runtime_error("Unsupported floating point data type");
+  }

+  for (auto const &pair : pairs) {
+    samples.emplace_back(pair.particle_id_1, pair.particle_id_2, 0, 0, 0);
+    std::visit([&pair](auto &vec) { vec.push_back(pair.distance); }, distances);
+  }

-  int64_t n_pairs = samples.size();
-  auto samples_tensor = torch::from_blob(
-      reinterpret_cast<int32_t*>(samples.data()),
-      {n_pairs, 5},
-      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU)
-  );
+  auto n_pairs = static_cast<int64_t>(samples.size());

-  auto samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
-      std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
-      samples_values
-  );
+  auto samples_tensor = torch::from_blob(
+      samples.data(), {n_pairs, 5},
+      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU));

-  distances_vectors = torch::from_blob(
-      (*distances).data(),
-      {n_pairs, 3, 1},
-      torch::TensorOptions().dtype(dtype).device(torch::kCPU)
-  );
-  return neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
-      distances_vectors.to(dtype).to(device),
-      samples->to(device),
-      std::vector<metatensor_torch::TorchLabels>{
-          metatensor_torch::LabelsHolder::create({"xyz"}, {{0}, {1}, {2}})->to(device),
-      },
-      metatensor_torch::LabelsHolder::create({"distance"}, {{0}})->to(device)
-  );
+  auto samples_ptr = torch::make_intrusive<metatensor_torch::LabelsHolder>(
+      std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a",
+                               "cell_shift_b", "cell_shift_c"},
+      samples);

-}
+  auto distances_vectors = torch::from_blob(
+      std::visit([](auto &vec) { return vec.data(); }, distances),
+      {n_pairs, 3, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCPU));

-void add_neighbor_list_to_system(MetatensorTorch::system& system,
-                                 const TorchTensorBlock& neighbors,
-                                 const NeighborListOptions& options) {
-  metatensor_torch::register_autograd_neighbors(system, neighbors, options_.check_consistency);
-  system->add_neighbor_list(options, neighbors);
-}
+  auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
+      distances_vectors.to(dtype).to(device), samples_ptr->to(device),
+      std::vector<metatensor_torch::TorchLabels>{
+          metatensor_torch::LabelsHolder::create({"xyz"}, {{0}, {1}, {2}})
+              ->to(device),
+      },
+      metatensor_torch::LabelsHolder::create({"distance"}, {{0}})->to(device));

+  return neighbors;
+}

+void add_neighbor_list_to_system(
+    metatensor_torch::System &system,
+    const metatensor_torch::TorchTensorBlock &neighbors,
+    const metatensor_torch::NeighborListOptions &options,
+    bool check_consistency) {
+  metatensor_torch::register_autograd_neighbors(system, neighbors,
+                                                check_consistency);
+  system->add_neighbor_list(options, neighbors);
+}
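Taken together, the two helpers above convert a list of particle pairs into a metatensor TensorBlock and register it on the System. A minimal usage sketch follows; it assumes a std::vector<PairInfo> filled by the caller's pair loop and a NeighborListOptions object obtained from the model's requested neighbor lists, and it uses the field names part_id_1/part_id_2 declared in PairInfo (the loop in the new code still refers to particle_id_1/particle_id_2).

// Hypothetical glue code, not part of this commit: build the neighbor-list
// block from collected pairs and attach it to the system. Cell shifts stay at
// zero, matching the samples written by neighbor_list_from_pairs.
void attach_neighbor_list(metatensor_torch::System &system,
                          const std::vector<PairInfo> &pairs,
                          const metatensor_torch::NeighborListOptions &options,
                          bool check_consistency) {
  // one distance vector per pair, stored as a metatensor TensorBlock
  auto neighbors = neighbor_list_from_pairs(system, pairs);
  // register autograd hooks and hand the list over to the model
  add_neighbor_list_to_system(system, neighbors, options, check_consistency);
}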

src/core/ml_metatensor/compute.hpp

Lines changed: 72 additions & 70 deletions
@@ -1,80 +1,82 @@
-
-torch_metatensor::TensorMapHolder run_model(metatensor_toch::System& system,
-    const metatensor_torch::ModelEvaluationOptions evaluation_options,
-    torch::dtype dtypem
-    torch::Device device) {
-
-
-  // only run the calculation for atoms actually in the current domain
-  auto options = torch::TensorOptions().dtype(torch::kInt32);
-  this->selected_atoms_values = torch::zeros({n_particles, 2}, options);
-  for (int i=0; i<n_atoms; i++) {
-    selected_atoms_values[i][0] = 0;
-    selected_atoms_values[i][1] = i;
-  }
-  auto selected_atoms = torch::make_intrusive<metatensor_torch::LabelsHolder>(
-      std::vector<std::string>{"system", "atom"}, mts_data->selected_atoms_values
-  );
-  evaluation_options->set_selected_atoms(selected_atoms->to(device));
-
-  torch::IValue result_ivalue;
-  model->forward({
-      std::vector<metatensor_torch::System>{system},
-      evaluation_options,
-      check_consistency
-  });
-
-  auto result = result_ivalue.toGenericDict();
-  return result.at("energy").toCustomClass<metatensor_torch::TensorMapHolder>();
+#include "metatensor/torch/atomistic/system.hpp"
+#include <metatensor/torch/atomistic/model.hpp>
+#include <metatensor/torch/tensor.hpp>
+
+metatensor_torch::TensorMapHolder
+run_model(metatensor_torch::System &system, int64_t n_particles,
+          const metatensor_torch::ModelEvaluationOptions evaluation_options,
+          torch::Dtype dtype, torch::Device device, bool check_consistency) {
+
+  // only run the calculation for atoms actually in the current domain
+  auto options = torch::TensorOptions().dtype(torch::kInt32);
+  auto selected_atoms_values = torch::zeros({n_particles, 2}, options);
+
+  for (int i = 0; i < n_particles; i++) {
+    selected_atoms_values[i][0] = 0;
+    selected_atoms_values[i][1] = i;
+  }
+  auto selected_atoms = torch::make_intrusive<metatensor_torch::LabelsHolder>(
+      std::vector<std::string>{"system", "atom"}, selected_atoms_values);
+  evaluation_options->set_selected_atoms(selected_atoms->to(device));
+
+  torch::IValue result_ivalue;
+  model->forward({std::vector<metatensor_torch::System>{system},
+                  evaluation_options, check_consistency});
+
+  auto result = result_ivalue.toGenericDict();
+  auto energy =
+      result.at("energy").toCustomClass<metatensor_torch::TensorMapHolder>();
+  auto energy_tensor =
+      metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
 }

-double get_energy(torch_metatensor::TensorMapHolder& energy, bool energy_is_per_atom) {
-  auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
-  auto energy_tensor = energy_block->values();
-  auto energy_detached = energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64);
-  auto energy_samples = energy_block->samples();
-
-  // store the energy returned by the model
-  torch::Tensor global_energy;
-  if (energy_is_per_atom) {
-    assert(energy_samples->size() == 2);
-    assert(energy_samples->names()[0] == "system");
-    assert(energy_samples->names()[1] == "atom");
-
-    auto samples_values = energy_samples->values().to(torch::kCPU);
-    auto samples = samples_values.accessor<int32_t, 2>();
-
-    // int n_atoms = selected_atoms_values.sizes();
-    // assert(samples_values.sizes() == selected_atoms_values.sizes());
-
-    auto energies = energy_detached.accessor<double, 2>();
-    global_energy = energy_detached.sum(0);
-    assert(energy_detached.sizes() == std::vector<int64_t>({1}));
-  } else {
-    assert(energy_samples->size() == 1);
-    assert(energy_samples->names()[0] == "system");
-
-    assert(energy_detached.sizes() == std::vector<int64_t>({1, 1}));
-    global_energy = energy_detached.reshape({1});
-  }
+double get_energy(metatensor_torch::TensorMapHolder &energy,
+                  bool energy_is_per_atom) {
+  auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
+  auto energy_tensor = energy_block->values();
+  auto energy_detached =
+      energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64);
+  auto energy_samples = energy_block->samples();
+
+  // store the energy returned by the model
+  torch::Tensor global_energy;
+  if (energy_is_per_atom) {
+    assert(energy_samples->size() == 2);
+    assert(energy_samples->names()[0] == "system");
+    assert(energy_samples->names()[1] == "atom");
+
+    auto samples_values = energy_samples->values().to(torch::kCPU);
+    auto samples = samples_values.accessor<int32_t, 2>();
+
+    // int n_atoms = selected_atoms_values.sizes();
+    // assert(samples_values.sizes() == selected_atoms_values.sizes());
+
+    auto energies = energy_detached.accessor<double, 2>();
+    global_energy = energy_detached.sum(0);
+    assert(energy_detached.sizes() == std::vector<int64_t>({1}));
+  } else {
+    assert(energy_samples->size() == 1);
+    assert(energy_samples->names()[0] == "system");
+
+    assert(energy_detached.sizes() == std::vector<int64_t>({1, 1}));
+    global_energy = energy_detached.reshape({1});
+  }

   return global_energy.item<double>();
 }

+torch::Tensor get_forces(metatensor::TensorMap &energy,
+                         metatensor_torch::System &system) {
+  // reset gradients to zero before calling backward
+  system->positions().mutable_grad() = torch::Tensor();

-torch::Tensor get_forces(torch_metatensor::TensorMap& energy, torch_metatensor::System& system) {
-  // reset gradients to zero before calling backward
-  system->positions.mutable_grad() = torch::Tensor();
-
-  auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
-  auto energy_tensor = energy_block->values();
+  auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
+  auto energy_tensor = energy_block->values();

-  // compute forces/virial with backward propagation
-  energy_tensor.backward(-torch::ones_like(energy_tensor));
-  auto forces_tensor = sytem->positions.grad();
-  assert(forces_tensor.is_cpu() && forces_tensor.scalar_type() == torch::kFloat64);
+  // compute forces/virial with backward propagation
+  energy_tensor.backward(-torch::ones_like(energy_tensor));
+  auto forces_tensor = system->positions().grad();
+  assert(forces_tensor.is_cpu() &&
+         forces_tensor.scalar_type() == torch::kFloat64);
   return forces_tensor;
 }
-
-
-
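A sketch of how these three helpers might be chained is given below. It is hypothetical: it assumes run_model eventually returns the "energy" TensorMap (the body above still ends without a return statement) and that the energy type is unified across the helpers (the code currently mixes metatensor::TensorMap and metatensor_torch::TensorMapHolder); evaluation_options, dtype, device and check_consistency are assumed to come from the loaded model's metadata.

// Hypothetical driver, not part of this commit: evaluate the model once and
// turn the result into a scalar energy plus a per-particle force tensor.
// Relies on run_model, get_energy and get_forces from compute.hpp above.
std::pair<double, torch::Tensor>
evaluate_model(metatensor_torch::System &system, int64_t n_particles,
               const metatensor_torch::ModelEvaluationOptions &evaluation_options,
               torch::Dtype dtype, torch::Device device, bool check_consistency,
               bool energy_is_per_atom) {
  // forward pass restricted to the particles of the local domain
  auto energy = run_model(system, n_particles, evaluation_options, dtype,
                          device, check_consistency);
  // reduce the model output to a single scalar
  auto total_energy = get_energy(energy, energy_is_per_atom);
  // backward pass through the positions gives the forces
  auto forces = get_forces(energy, system);
  return {total_energy, forces};
}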

src/core/ml_metatensor/stub.cpp

Lines changed: 4 additions & 4 deletions
@@ -1,13 +1,13 @@
-#include "config/config.hpp"
+#include "config/config.hpp"

 #ifdef METATENSOR
 #undef CUDA
-#include <torch/version.h>
-#include <torch/script.h>
 #include <torch/cuda.h>
+#include <torch/script.h>
+#include <torch/version.h>

 #if TORCH_VERSION_MAJOR >= 2
-#include <torch/mps.h>
+#include <torch/mps.h>
 #endif

 #include <metatensor/torch.hpp>

src/core/ml_metatensor/stub.hpp

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 #ifdef METATENSOR
 #undef CUDA

-#include <torch/version.h>
-#include <torch/script.h>
 #include <torch/cuda.h>
+#include <torch/script.h>
+#include <torch/version.h>
 #endif
