|
1 | | - |
2 | | -torch_metatensor::TensorMapHolder run_model(metatensor_toch::System& system, |
3 | | - const metatensor_torch::ModelEvaluationOptions evaluation_options, |
4 | | - torch::dtype dtypem |
5 | | - torch::Device device) { |
6 | | - |
7 | | - |
8 | | - // only run the calculation for atoms actually in the current domain |
9 | | - auto options = torch::TensorOptions().dtype(torch::kInt32); |
10 | | - this->selected_atoms_values = torch::zeros({n_particles, 2}, options); |
11 | | - for (int i=0; i<n_atoms; i++) { |
12 | | - selected_atoms_values[i][0] = 0; |
13 | | - selected_atoms_values[i][1] = i; |
14 | | - } |
15 | | - auto selected_atoms = torch::make_intrusive<metatensor_torch::LabelsHolder>( |
16 | | - std::vector<std::string>{"system", "atom"}, mts_data->selected_atoms_values |
17 | | - ); |
18 | | - evaluation_options->set_selected_atoms(selected_atoms->to(device)); |
19 | | - |
20 | | - torch::IValue result_ivalue; |
21 | | - model->forward({ |
22 | | - std::vector<metatensor_torch::System>{system}, |
23 | | - evaluation_options, |
24 | | - check_consistency |
25 | | - }); |
26 | | - |
27 | | - auto result = result_ivalue.toGenericDict(); |
28 | | - return result.at("energy").toCustomClass<metatensor_torch::TensorMapHolder>(); |
| 1 | +#include "metatensor/torch/atomistic/system.hpp" |
| 2 | +#include <metatensor/torch/atomistic/model.hpp> |
| 3 | +#include <metatensor/torch/tensor.hpp> |
| 4 | + |
| 5 | +metatensor_torch::TensorMapHolder |
| 6 | +run_model(metatensor_torch::System &system, int64_t n_particles, |
| 7 | + const metatensor_torch::ModelEvaluationOptions evaluation_options, |
| 8 | + torch::Dtype dtype, torch::Device device, bool check_consistency) { |
| 9 | + |
| 10 | + // only run the calculation for atoms actually in the current domain |
| 11 | + auto options = torch::TensorOptions().dtype(torch::kInt32); |
| 12 | + auto selected_atoms_values = torch::zeros({n_particles, 2}, options); |
| 13 | + |
| 14 | + for (int i = 0; i < n_particles; i++) { |
| 15 | + selected_atoms_values[i][0] = 0; |
| 16 | + selected_atoms_values[i][1] = i; |
| 17 | + } |
| 18 | + auto selected_atoms = torch::make_intrusive<metatensor_torch::LabelsHolder>( |
| 19 | + std::vector<std::string>{"system", "atom"}, selected_atoms_values); |
| 20 | + evaluation_options->set_selected_atoms(selected_atoms->to(device)); |
| 21 | + |
| 22 | + torch::IValue result_ivalue; |
| 23 | + model->forward({std::vector<metatensor_torch::System>{system}, |
| 24 | + evaluation_options, check_consistency}); |
| 25 | + |
| 26 | + auto result = result_ivalue.toGenericDict(); |
| 27 | + auto energy = |
| 28 | + result.at("energy").toCustomClass<metatensor_torch::TensorMapHolder>(); |
| 29 | + auto energy_tensor = |
| 30 | + metatensor_torch::TensorMapHolder::block_by_id(energy, 0); |
29 | 31 | } |
30 | 32 |
|
31 | | -double get_energy(torch_metatensor::TensorMapHolder& energy, bool energy_is_per_atom) { |
32 | | - auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); |
33 | | - auto energy_tensor = energy_block->values(); |
34 | | - auto energy_detached = energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64); |
35 | | - auto energy_samples = energy_block->samples(); |
36 | | - |
37 | | - // store the energy returned by the model |
38 | | - torch::Tensor global_energy; |
39 | | - if (energy_is_per_atom) { |
40 | | - assert(energy_samples->size() == 2); |
41 | | - assert(energy_samples->names()[0] == "system"); |
42 | | - assert(energy_samples->names()[1] == "atom"); |
43 | | - |
44 | | - auto samples_values = energy_samples->values().to(torch::kCPU); |
45 | | - auto samples = samples_values.accessor<int32_t, 2>(); |
46 | | - |
47 | | -// int n_atoms = selected_atoms_values.sizes(); |
48 | | -// assert(samples_values.sizes() == selected_atoms_values.sizes()); |
49 | | - |
50 | | - auto energies = energy_detached.accessor<double, 2>(); |
51 | | - global_energy = energy_detached.sum(0); |
52 | | - assert(energy_detached.sizes() == std::vector<int64_t>({1})); |
53 | | - } else { |
54 | | - assert(energy_samples->size() == 1); |
55 | | - assert(energy_samples->names()[0] == "system"); |
56 | | - |
57 | | - assert(energy_detached.sizes() == std::vector<int64_t>({1, 1})); |
58 | | - global_energy = energy_detached.reshape({1}); |
59 | | - } |
| 33 | +double get_energy(metatensor_torch::TensorMapHolder &energy, |
| 34 | + bool energy_is_per_atom) { |
| 35 | + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); |
| 36 | + auto energy_tensor = energy_block->values(); |
| 37 | + auto energy_detached = |
| 38 | + energy_tensor.detach().to(torch::kCPU).to(torch::kFloat64); |
| 39 | + auto energy_samples = energy_block->samples(); |
| 40 | + |
| 41 | + // store the energy returned by the model |
| 42 | + torch::Tensor global_energy; |
| 43 | + if (energy_is_per_atom) { |
| 44 | + assert(energy_samples->size() == 2); |
| 45 | + assert(energy_samples->names()[0] == "system"); |
| 46 | + assert(energy_samples->names()[1] == "atom"); |
| 47 | + |
| 48 | + auto samples_values = energy_samples->values().to(torch::kCPU); |
| 49 | + auto samples = samples_values.accessor<int32_t, 2>(); |
| 50 | + |
| 51 | + // int n_atoms = selected_atoms_values.sizes(); |
| 52 | + // assert(samples_values.sizes() == selected_atoms_values.sizes()); |
| 53 | + |
| 54 | + auto energies = energy_detached.accessor<double, 2>(); |
| 55 | + global_energy = energy_detached.sum(0); |
| 56 | + assert(energy_detached.sizes() == std::vector<int64_t>({1})); |
| 57 | + } else { |
| 58 | + assert(energy_samples->size() == 1); |
| 59 | + assert(energy_samples->names()[0] == "system"); |
| 60 | + |
| 61 | + assert(energy_detached.sizes() == std::vector<int64_t>({1, 1})); |
| 62 | + global_energy = energy_detached.reshape({1}); |
| 63 | + } |
60 | 64 |
|
61 | 65 | return global_energy.item<double>(); |
62 | 66 | } |
63 | 67 |
|
| 68 | +torch::Tensor get_forces(metatensor::TensorMap &energy, |
| 69 | + metatensor_torch::System &system) { |
| 70 | + // reset gradients to zero before calling backward |
| 71 | + system->positions().mutable_grad() = torch::Tensor(); |
64 | 72 |
|
65 | | -torch::Tensor get_forces(torch_metatensor::TensorMap& energy, torch_metatensor::System& system) { |
66 | | - // reset gradients to zero before calling backward |
67 | | - system->positions.mutable_grad() = torch::Tensor(); |
68 | | - |
69 | | - auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); |
70 | | - auto energy_tensor = energy_block->values(); |
| 73 | + auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0); |
| 74 | + auto energy_tensor = energy_block->values(); |
71 | 75 |
|
72 | | - // compute forces/virial with backward propagation |
73 | | - energy_tensor.backward(-torch::ones_like(energy_tensor)); |
74 | | - auto forces_tensor = sytem->positions.grad(); |
75 | | - assert(forces_tensor.is_cpu() && forces_tensor.scalar_type() == torch::kFloat64); |
| 76 | + // compute forces/virial with backward propagation |
| 77 | + energy_tensor.backward(-torch::ones_like(energy_tensor)); |
| 78 | + auto forces_tensor = system->positions().grad(); |
| 79 | + assert(forces_tensor.is_cpu() && |
| 80 | + forces_tensor.scalar_type() == torch::kFloat64); |
76 | 81 | return forces_tensor; |
77 | 82 | } |
78 | | - |
79 | | - |
80 | | - |
|
0 commit comments