Message passing operation #59

Open
wants to merge 18 commits into master
10 changes: 8 additions & 2 deletions CMakeLists.txt
@@ -21,6 +21,9 @@ set(SRC_FILES src/ani/CpuANISymmetryFunctions.cpp
src/pytorch/CFConv.cpp
src/pytorch/CFConvNeighbors.cpp
src/pytorch/SymmetryFunctions.cpp
src/pytorch/messages/messages.cpp
src/pytorch/messages/passMessagesCPU.cpp
src/pytorch/messages/passMessagesCUDA.cu
src/pytorch/neighbors/getNeighborPairsCPU.cpp
src/pytorch/neighbors/getNeighborPairsCUDA.cu
src/pytorch/neighbors/neighbors.cpp
@@ -65,7 +68,9 @@ add_test(TestEnergyShifter pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestEne
add_test(TestOptimizedTorchANI pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestOptimizedTorchANI.py)
add_test(TestSpeciesConverter pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestSpeciesConverter.py)
add_test(TestSymmetryFunctions pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/TestSymmetryFunctions.py)
add_test(TestMessages pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/messages/TestMessages.py)
add_test(TestNeighbors pytest -v ${CMAKE_SOURCE_DIR}/src/pytorch/neighbors/TestNeighbors.py)
add_test(TestPassMessages pytest -v --doctest-modules ${CMAKE_SOURCE_DIR}/src/pytorch/messages/passMessages.py)
add_test(TestGetNeighborPairs pytest -v --doctest-modules ${CMAKE_SOURCE_DIR}/src/pytorch/neighbors/getNeighborPairs.py)

# Installation
@@ -78,9 +83,10 @@ install(FILES src/pytorch/__init__.py
src/pytorch/OptimizedTorchANI.py
src/pytorch/SpeciesConverter.py
src/pytorch/SymmetryFunctions.py
src/pytorch/neighbors/__init__.py
src/pytorch/neighbors/getNeighborPairs.py
DESTINATION ${Python_SITEARCH}/${NAME})
install(FILES src/pytorch/messages/__init__.py
src/pytorch/messages/passMessages.py
DESTINATION ${Python_SITEARCH}/${NAME}/messages)
install(FILES src/pytorch/neighbors/__init__.py
src/pytorch/neighbors/getNeighborPairs.py
DESTINATION ${Python_SITEARCH}/${NAME}/neighbors)
91 changes: 91 additions & 0 deletions src/pytorch/messages/TestMessages.py
@@ -0,0 +1,91 @@
import pytest
import torch as pt
from NNPOps.messages import passMessages


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
@pytest.mark.parametrize('num_pairs', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_states', [32, 64, 1024])
def testPassMessageValues(device, dtype, num_pairs, num_atoms, num_states):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

# Generate random neighbors
neighbors = pt.randint(0, num_atoms, (2, num_pairs), dtype=pt.int32, device=device)
neighbors[:, pt.rand(num_pairs) > 0.5] = -1

# Generate random messages and states
messages = pt.randn((num_pairs, num_states), dtype=dtype, device=device)
states = pt.randn((num_atoms, num_states), dtype=dtype, device=device)

# Compute reference
mask = pt.logical_and(neighbors[0] > -1, neighbors[1] > -1)
masked_neighbors = neighbors[:, mask].to(pt.long)
masked_messages = messages[mask, :]
ref_new_states = states.index_add(0, masked_neighbors[0], masked_messages)\
.index_add(0, masked_neighbors[1], masked_messages)

# Compute results
new_states = passMessages(neighbors, messages, states)

# Check data type and device
assert new_states.device == neighbors.device
assert new_states.dtype == dtype

# Check values
if dtype == pt.float32 and num_pairs > 10 and num_atoms < 10:
assert pt.allclose(ref_new_states, new_states, atol=1e-5, rtol=1e-3)
elif dtype == pt.float32:
assert pt.allclose(ref_new_states, new_states, atol=1e-6, rtol=1e-4)
else:
assert pt.allclose(ref_new_states, new_states, atol=1e-12, rtol=1e-8)

@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
@pytest.mark.parametrize('num_pairs', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_states', [32, 64, 1024])
def testPassMessagesGrads(dtype, num_pairs, num_atoms, num_states):

if not pt.cuda.is_available():
pytest.skip('No GPU')

# Generate random neighbors
neighbors = pt.randint(0, num_atoms, (2, num_pairs), dtype=pt.int32)
neighbors[:, pt.rand(num_pairs) > 0.5] = -1

# Generate random messages and states
messages = pt.randn((num_pairs, num_states), dtype=dtype)
states = pt.randn((num_atoms, num_states), dtype=dtype)

# Compute CPU gradients
neighbors_cpu = neighbors.detach().cpu()
messages_cpu = messages.detach().cpu()
states_cpu = states.detach().cpu()
messages_cpu.requires_grad_()
states_cpu.requires_grad_()
passMessages(neighbors_cpu, messages_cpu, states_cpu).norm().backward()

# Compute CUDA gradients
neighbors_cuda = neighbors.detach().cuda()
messages_cuda = messages.detach().cuda()
states_cuda = states.detach().cuda()
messages_cuda.requires_grad_()
states_cuda.requires_grad_()
passMessages(neighbors_cuda, messages_cuda, states_cuda).norm().backward()

# Check type and device
assert messages_cuda.grad.dtype == dtype
assert states_cuda.grad.dtype == dtype
assert messages_cuda.grad.device == neighbors_cuda.device
assert states_cuda.grad.device == neighbors_cuda.device

# Check gradients
if dtype == pt.float32:
assert pt.allclose(messages_cpu.grad, messages_cuda.grad.cpu(), atol=1e-6, rtol=1e-4)
assert pt.allclose(states_cpu.grad, states_cuda.grad.cpu(), atol=1e-6, rtol=1e-4)
else:
assert pt.allclose(messages_cpu.grad, messages_cuda.grad.cpu(), atol=1e-12, rtol=1e-8)
assert pt.allclose(states_cpu.grad, states_cuda.grad.cpu(), atol=1e-12, rtol=1e-8)
5 changes: 5 additions & 0 deletions src/pytorch/messages/__init__.py
@@ -0,0 +1,5 @@
'''
Message passing operations
'''

from NNPOps.messages.passMessages import passMessages
5 changes: 5 additions & 0 deletions src/pytorch/messages/messages.cpp
@@ -0,0 +1,5 @@
#include <torch/extension.h>

TORCH_LIBRARY(messages, m) {
m.def("passMessages(Tensor neighbors, Tensor messages, Tensor states) -> (Tensor states)");
}
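
The schema above only declares the operator; the CPU and CUDA files below register implementations against it, and the Python wrapper calls it through `torch.ops`. As a minimal sketch (the shared-library name here is an assumption for illustration, not taken from this PR):

import torch as pt

# Loading the compiled extension registers the "messages" namespace declared
# by TORCH_LIBRARY(messages, m) (library name assumed for illustration).
pt.ops.load_library('libNNPOpsPyTorch.so')

neighbors = pt.tensor([[0, 1], [1, 2]], dtype=pt.int32)  # 2 pairs among 3 atoms
messages = pt.randn((2, 32))                              # num_features = 32
states = pt.zeros((3, 32))

# Call the registered operator directly through torch.ops
new_states = pt.ops.messages.passMessages(neighbors, messages, states)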
60 changes: 60 additions & 0 deletions src/pytorch/messages/passMessages.py
@@ -0,0 +1,60 @@
from torch import ops, Tensor


def passMessages(neighbors: Tensor, messages: Tensor, states: Tensor) -> Tensor:
'''
Pass messages between the neighbor atoms.

Given a set of `num_atoms` atoms (each atom has a state with `num_features`
features) and a set of `num_pairs` neighbor atom pairs (each pair has a
message with `num_features` features), the messages of the pairs are added
to the states of the corresponding atoms.

Parameters
----------
neighbors: `torch.Tensor`
Atom pair indices. The shape of the tensor is `(2, num_pairs)`.
The indices can be in `[0, num_atoms)` or `-1` (ignored pairs).
See the documentation of `NNPOps.neighbors.getNeighborPairs` for details.
messages: `torch.Tensor`
Atom pair messages. The shape of the tensor is `(num_pairs, num_features)`.
For efficiency, `num_features` has to be a multiple of 32 and at most 1024.
Member: Are those limitations really necessary? It's very common for the number of features not to be a multiple of 32.

Contributor Author: On the contrary, the number of internal features is always a multiple of 32 (e.g. in TorchMD-NET I have seen 64, 128, and 256). The GPU computes in warps of 32 threads, so it is best to match that pattern for computational efficiency.

Member: I frequently create models that don't satisfy those requirements, including in TorchMD-Net. For example, I've trained models with 48 or 80 features per layer.

Contributor Author: Why?

Member: I don't understand your question. Why not? The number of features is a hyperparameter. It's one of many hyperparameters you tune to balance training accuracy, overfitting, speed, and memory use. Why place arbitrary limits on it when there's no need to?

states: `torch.Tensor`
Atom states. The shape of the tensor is `(num_atoms, num_features)`.

Returns
-------
new_states: `torch.Tensor`
Updated atom states. The shape of the tensor is `(num_atoms, num_features)`.

Note
----
The operation is compatible with CUDA Graphs, i.e. the shapes of the output
tensors are independent of the values of the input tensors.

Examples
--------
>>> import torch as pt
>>> from NNPOps.messages import passMessages

>>> num_atoms = 4
>>> num_neighbors = 3
>>> num_features = 32

>>> neighbors = pt.tensor([[0, -1, 1], [0, -1, 3]], dtype=pt.int32)

>>> messages = pt.ones((num_neighbors, num_features)); messages[1] = 5
>>> messages[:, 0]
tensor([1., 5., 1.])

>>> states = pt.zeros((num_atoms, num_features)); states[1] = 3
>>> states[:, 0]
tensor([0., 3., 0., 0.])

>>> new_states = passMessages(neighbors, messages, states)
>>> new_states[:, 0]
tensor([2., 4., 0., 1.])
'''

return ops.messages.passMessages(neighbors, messages, states)
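
For reference, the operation is equivalent to the following pure-PyTorch computation, a minimal sketch mirroring the reference used in TestMessages.py (the helper name is illustrative): pairs containing `-1` are ignored, and each remaining message is added to the states of both atoms of its pair.

import torch as pt

def pass_messages_reference(neighbors, messages, states):
    # Keep only pairs where both atom indices are valid (not -1)
    mask = pt.logical_and(neighbors[0] > -1, neighbors[1] > -1)
    pairs = neighbors[:, mask].to(pt.long)
    kept = messages[mask, :]
    # Add each message to the states of both atoms of its pair
    return states.index_add(0, pairs[0], kept).index_add(0, pairs[1], kept)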
43 changes: 43 additions & 0 deletions src/pytorch/messages/passMessagesCPU.cpp
@@ -0,0 +1,43 @@
#include <torch/extension.h>

using torch::kInt32;
using torch::logical_and;
using torch::Tensor;

static Tensor forward(const Tensor& neighbors, const Tensor& messages, const Tensor& states) {

TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");

TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be at most 1024");
TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");

TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");

const Tensor rows = neighbors[0];
const Tensor columns = neighbors[1];

const int num_features = messages.size(1);

const Tensor mask = logical_and(rows > -1, columns > -1);
const Tensor masked_rows = rows.masked_select(mask).to(torch::kLong);
const Tensor masked_columns = columns.masked_select(mask).to(torch::kLong);
const Tensor masked_messages = messages.masked_select(mask.unsqueeze(1)).reshape({-1, num_features});

Tensor new_states = states.clone();
new_states.index_add_(0, masked_rows, masked_messages);
new_states.index_add_(0, masked_columns, masked_messages);

return new_states;
}

TORCH_LIBRARY_IMPL(messages, CPU, m) {
m.impl("passMessages", &forward);
}
123 changes: 123 additions & 0 deletions src/pytorch/messages/passMessagesCUDA.cu
@@ -0,0 +1,123 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "common/accessor.cuh"
#include "common/atomicAdd.cuh"

using c10::cuda::CUDAStreamGuard;
using c10::cuda::getCurrentCUDAStream;
using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::tensor_list;
using torch::kInt32;
using torch::Tensor;
using torch::TensorOptions;

template <typename scalar_t> __global__ void kernel_forward(
const Accessor<int32_t, 2> neighbors,
const Accessor<scalar_t, 2> messages,
Accessor<scalar_t, 2> new_states
) {
const int32_t i_neig = blockIdx.x;
const int32_t i_dir = blockIdx.y;
const int32_t i_atom = neighbors[i_dir][i_neig];
if (i_atom < 0) return;

const int32_t i_feat = threadIdx.x;
atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
Comment on lines +27 to +28

Member: You can eliminate the limitations on the number of features by just rewriting this as a loop.

for (int32_t i_feat = threadIdx.x; i_feat < num_features; i_feat += blockDim.x)
    atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);

Contributor Author: Apart from solving a non-existent problem, this would make the memory access non-coalesced and reduce speed.

Member: It would have no effect on speed at all. If the number of features happens to satisfy your current requirement, the behavior would be identical to what it currently does. The atomicAdd() would be executed once by every thread, with i_feat equal to threadIdx.x. The only change would be if the number doesn't satisfy your current requirements, either because it's not a multiple of 32 or because it's more than 1024. In that case it would produce correct behavior, unlike the current code. So there's no downside at all.

Contributor Author: It would reduce the number of threads by the number of features. The reduced parallelism would result in reduced speed.

Member: I'm not suggesting any change to the number of threads. The only thing I'm suggesting is wrapping the atomicAdd() in a loop as shown above. If num_features happens to match your current restrictions, nothing will change. Every thread will still call it exactly once.

}
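
For concreteness, a minimal sketch of the loop-based variant suggested in the review thread above (not part of this PR; the kernel name is illustrative, and it assumes the packed accessor exposes `size(1)`):

// Sketch of the reviewer's suggestion: stride over the features so that any
// num_features is handled; behavior is identical to kernel_forward when
// num_features == blockDim.x.
template <typename scalar_t> __global__ void kernel_forward_looped(
    const Accessor<int32_t, 2> neighbors,
    const Accessor<scalar_t, 2> messages,
    Accessor<scalar_t, 2> new_states
) {
    const int32_t i_neig = blockIdx.x;
    const int32_t i_dir = blockIdx.y;
    const int32_t i_atom = neighbors[i_dir][i_neig];
    if (i_atom < 0) return;

    const int32_t num_features = messages.size(1);
    for (int32_t i_feat = threadIdx.x; i_feat < num_features; i_feat += blockDim.x)
        atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
}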

template <typename scalar_t> __global__ void kernel_backward(
const Accessor<int32_t, 2> neighbors,
const Accessor<scalar_t, 2> grad_new_state,
Accessor<scalar_t, 2> grad_messages
) {
const int32_t i_neig = blockIdx.x;
const int32_t i_dir = blockIdx.y;
const int32_t i_atom = neighbors[i_dir][i_neig];
if (i_atom < 0) return;

const int32_t i_feat = threadIdx.x;
atomicAdd(&grad_messages[i_neig][i_feat], grad_new_state[i_atom][i_feat]);
}

class Autograd : public Function<Autograd> {
public:
static tensor_list forward(AutogradContext* ctx,
const Tensor& neighbors,
const Tensor& messages,
const Tensor& states) {

TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");

TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be at most 1024");
TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");

TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");

const int num_neighbors = neighbors.size(1);
const int num_features = messages.size(1);

const dim3 blocks(num_neighbors, 2);
const dim3 threads(num_features);
const auto stream = getCurrentCUDAStream(neighbors.get_device());

Tensor new_states = states.clone();

AT_DISPATCH_FLOATING_TYPES(messages.scalar_type(), "passMessages::forward", [&]() {
const CUDAStreamGuard guard(stream);
kernel_forward<<<blocks, threads, 0, stream>>>(
get_accessor<int32_t, 2>(neighbors),
get_accessor<scalar_t, 2>(messages),
get_accessor<scalar_t, 2>(new_states));
});

ctx->save_for_backward({neighbors});

return {new_states};
}

static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) {

const Tensor neighbors = ctx->get_saved_variables()[0];
const Tensor grad_new_state = grad_inputs[0];

const int num_neighbors = neighbors.size(1);
const int num_features = grad_new_state.size(1);

const dim3 blocks(num_neighbors, 2);
const dim3 threads(num_features);
const auto stream = getCurrentCUDAStream(neighbors.get_device());

Tensor grad_messages = torch::zeros({num_neighbors, num_features}, grad_new_state.options());

AT_DISPATCH_FLOATING_TYPES(grad_new_state.scalar_type(), "passMessages::backward", [&]() {
const CUDAStreamGuard guard(stream);
kernel_backward<<<blocks, threads, 0, stream>>>(
get_accessor<int32_t, 2>(neighbors),
get_accessor<scalar_t, 2>(grad_new_state),
get_accessor<scalar_t, 2>(grad_messages));
});

return {Tensor(), // grad_neighbors
grad_messages,
grad_new_state.clone()}; // grad_state
}
};

TORCH_LIBRARY_IMPL(messages, AutogradCUDA, m) {
m.impl("passMessages", [](const Tensor& neighbors,
const Tensor& messages,
const Tensor& states) {
return Autograd::apply(neighbors, messages, states)[0];
});
}