Message passing operation #59
base: master
@@ -0,0 +1,91 @@
import pytest
import torch as pt
from NNPOps.messages import passMessages


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
@pytest.mark.parametrize('num_pairs', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_states', [32, 64, 1024])
def testPassMessageValues(device, dtype, num_pairs, num_atoms, num_states):

    if not pt.cuda.is_available() and device == 'cuda':
        pytest.skip('No GPU')

    # Generate random neighbors
    neighbors = pt.randint(0, num_atoms, (2, num_pairs), dtype=pt.int32, device=device)
    neighbors[:, pt.rand(num_pairs) > 0.5] = -1

    # Generate random messages and states
    messages = pt.randn((num_pairs, num_states), dtype=dtype, device=device)
    states = pt.randn((num_atoms, num_states), dtype=dtype, device=device)

    # Compute reference
    mask = pt.logical_and(neighbors[0] > -1, neighbors[1] > -1)
    masked_neighbors = neighbors[:, mask].to(pt.long)
    masked_messages = messages[mask, :]
    ref_new_states = states.index_add(0, masked_neighbors[0], masked_messages)\
                           .index_add(0, masked_neighbors[1], masked_messages)

    # Compute results
    new_states = passMessages(neighbors, messages, states)

    # Check data type and device
    assert new_states.device == neighbors.device
    assert new_states.dtype == dtype

    # Check values
    if dtype == pt.float32 and num_pairs > 10 and num_atoms < 10:
        assert pt.allclose(ref_new_states, new_states, atol=1e-5, rtol=1e-3)
    elif dtype == pt.float32:
        assert pt.allclose(ref_new_states, new_states, atol=1e-6, rtol=1e-4)
    else:
        assert pt.allclose(ref_new_states, new_states, atol=1e-12, rtol=1e-8)


@pytest.mark.parametrize('dtype', [pt.float32, pt.float64])
@pytest.mark.parametrize('num_pairs', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100])
@pytest.mark.parametrize('num_states', [32, 64, 1024])
def testPassMessagesGrads(dtype, num_pairs, num_atoms, num_states):

    if not pt.cuda.is_available():
        pytest.skip('No GPU')

    # Generate random neighbors
    neighbors = pt.randint(0, num_atoms, (2, num_pairs), dtype=pt.int32)
    neighbors[:, pt.rand(num_pairs) > 0.5] = -1

    # Generate random messages and states
    messages = pt.randn((num_pairs, num_states), dtype=dtype)
    states = pt.randn((num_atoms, num_states), dtype=dtype)

    # Compute CPU gradients
    neighbors_cpu = neighbors.detach().cpu()
    messages_cpu = messages.detach().cpu()
    states_cpu = states.detach().cpu()
    messages_cpu.requires_grad_()
    states_cpu.requires_grad_()
    passMessages(neighbors_cpu, messages_cpu, states_cpu).norm().backward()

    # Compute CUDA gradients
    neighbors_cuda = neighbors.detach().cuda()
    messages_cuda = messages.detach().cuda()
    states_cuda = states.detach().cuda()
    messages_cuda.requires_grad_()
    states_cuda.requires_grad_()
    passMessages(neighbors_cuda, messages_cuda, states_cuda).norm().backward()

    # Check type and device
    assert messages_cuda.grad.dtype == dtype
    assert states_cuda.grad.dtype == dtype
    assert messages_cuda.grad.device == neighbors_cuda.device
    assert states_cuda.grad.device == neighbors_cuda.device

    # Check gradients
    if dtype == pt.float32:
        assert pt.allclose(messages_cpu.grad, messages_cuda.grad.cpu(), atol=1e-6, rtol=1e-4)
        assert pt.allclose(states_cpu.grad, states_cuda.grad.cpu(), atol=1e-6, rtol=1e-4)
    else:
        assert pt.allclose(messages_cpu.grad, messages_cuda.grad.cpu(), atol=1e-12, rtol=1e-8)
        assert pt.allclose(states_cpu.grad, states_cuda.grad.cpu(), atol=1e-12, rtol=1e-8)
@@ -0,0 +1,5 @@
'''
Message passing operations
'''

from NNPOps.messages.passMessages import passMessages
@@ -0,0 +1,5 @@
#include <torch/extension.h>

TORCH_LIBRARY(messages, m) {
    m.def("passMessages(Tensor neighbors, Tensor messages, Tensor states) -> (Tensor states)");
}
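For context, the schema above is what the Python wrapper (next file) resolves through `torch.ops`. A minimal sketch of that call path, assuming the compiled NNPOps extension that registers the `messages` namespace has already been loaded (importing `NNPOps` is assumed to do this; the loading code is not part of this diff):

    import torch as pt
    import NNPOps  # assumed to load the compiled extension registering TORCH_LIBRARY(messages, ...)

    neighbors = pt.tensor([[0, 1], [1, 2]], dtype=pt.int32)  # two atom pairs
    messages = pt.ones((2, 32))   # num_features = 32, satisfying the multiple-of-32 requirement
    states = pt.zeros((3, 32))

    # The op declared by m.def("passMessages(...)") is reachable as:
    new_states = pt.ops.messages.passMessages(neighbors, messages, states)
    print(new_states.shape)  # torch.Size([3, 32])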
@@ -0,0 +1,60 @@
from torch import ops, Tensor


def passMessages(neighbors: Tensor, messages: Tensor, states: Tensor) -> Tensor:
    '''
    Pass messages between the neighbor atoms.

    Given a set of `num_atoms` atoms (each atom has a state with `num_features`
    features) and a set of `num_neighbors` neighbor atom pairs (each pair has a
    message with `num_features` features), the message of each pair is added
    to the states of both atoms of that pair.

    Parameters
    ----------
    neighbors: `torch.Tensor`
        Atom pair indices. The shape of the tensor is `(2, num_pairs)`.
        The indices can be in `[0, num_atoms)` or `-1` (ignored pairs).
        See the documentation of `NNPOps.neighbors.getNeighborPairs` for
        details.
    messages: `torch.Tensor`
        Atom pair messages. The shape of the tensor is `(num_pairs, num_features)`.
        For efficiency, `num_features` has to be a multiple of 32 and <= 1024.
    states: `torch.Tensor`
        Atom states. The shape of the tensor is `(num_atoms, num_features)`.

    Returns
    -------
    new_states: `torch.Tensor`
        Updated atom states. The shape of the tensor is `(num_atoms, num_features)`.

    Note
    ----
    The operation is compatible with CUDA Graphs, i.e. the shapes of the output
    tensors are independent of the values of the input tensors.

    Examples
    --------
    >>> import torch as pt
    >>> from NNPOps.messages import passMessages

    >>> num_atoms = 4
    >>> num_neighbors = 3
    >>> num_features = 32

    >>> neighbors = pt.tensor([[0, -1, 1], [0, -1, 3]], dtype=pt.int32)

    >>> messages = pt.ones((num_neighbors, 32)); messages[1] = 5
    >>> messages[:, 0]
    tensor([1., 5., 1.])

    >>> states = pt.zeros((num_atoms, num_features)); states[1] = 3
    >>> states[:, 0]
    tensor([0., 3., 0., 0.])

    >>> new_states = passMessages(neighbors, messages, states)
    >>> new_states[:, 0]
    tensor([2., 4., 0., 1.])
    '''

    return ops.messages.passMessages(neighbors, messages, states)
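To make the semantics concrete outside the compiled extension, here is a minimal pure-PyTorch sketch of the same update, mirroring the reference computation used in the test file above (the helper name `pass_messages_reference` is hypothetical and not part of NNPOps):

    import torch as pt

    def pass_messages_reference(neighbors, messages, states):
        # Pairs with an index of -1 are ignored
        mask = pt.logical_and(neighbors[0] > -1, neighbors[1] > -1)
        masked_neighbors = neighbors[:, mask].to(pt.long)
        masked_messages = messages[mask, :]
        # The message of each remaining pair is added to the states of both of its atoms
        return states.index_add(0, masked_neighbors[0], masked_messages) \
                     .index_add(0, masked_neighbors[1], masked_messages)

    # Reproduces the docstring example: atom 0 receives its own pair's message twice,
    # atoms 1 and 3 each receive the message of pair (1, 3), and the -1 pair is skipped.
    neighbors = pt.tensor([[0, -1, 1], [0, -1, 3]], dtype=pt.int32)
    messages = pt.ones((3, 32)); messages[1] = 5
    states = pt.zeros((4, 32)); states[1] = 3
    print(pass_messages_reference(neighbors, messages, states)[:, 0])  # tensor([2., 4., 0., 1.])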
@@ -0,0 +1,43 @@
#include <torch/extension.h>

using torch::kInt32;
using torch::logical_and;
using torch::Tensor;

static Tensor forward(const Tensor& neighbors, const Tensor& messages, const Tensor& states) {

    TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
    TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
    TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
    TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");

    TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
    TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
    TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be at most 1024");
    TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");

    TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
    TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
    TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
    TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");

    const Tensor rows = neighbors[0];
    const Tensor columns = neighbors[1];

    const int num_features = messages.size(1);

    // Drop pairs marked with -1, then add each remaining message to the states
    // of both atoms of its pair
    const Tensor mask = logical_and(rows > -1, columns > -1);
    const Tensor masked_rows = rows.masked_select(mask).to(torch::kLong);
    const Tensor masked_columns = columns.masked_select(mask).to(torch::kLong);
    const Tensor masked_messages = messages.masked_select(mask.unsqueeze(1)).reshape({-1, num_features});

    Tensor new_states = states.clone();
    new_states.index_add_(0, masked_rows, masked_messages);
    new_states.index_add_(0, masked_columns, masked_messages);

    return new_states;
}

TORCH_LIBRARY_IMPL(messages, CPU, m) {
    m.impl("passMessages", &forward);
}
@@ -0,0 +1,123 @@
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "common/accessor.cuh"
#include "common/atomicAdd.cuh"

using c10::cuda::CUDAStreamGuard;
using c10::cuda::getCurrentCUDAStream;
using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::tensor_list;
using torch::kInt32;
using torch::Tensor;
using torch::TensorOptions;

template <typename scalar_t> __global__ void kernel_forward(
    const Accessor<int32_t, 2> neighbors,
    const Accessor<scalar_t, 2> messages,
    Accessor<scalar_t, 2> new_states
) {
    // One block per (pair, direction); each thread handles one feature
    const int32_t i_neig = blockIdx.x;
    const int32_t i_dir = blockIdx.y;
    const int32_t i_atom = neighbors[i_dir][i_neig];
    if (i_atom < 0) return;

    const int32_t i_feat = threadIdx.x;
    atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
}

Review comment on lines +27 to +28:
    You can eliminate the limitations on the number of features by just rewriting this as a loop:
        for (int32_t i_feat = threadIdx.x; i_feat < num_features; i_feat += blockDim.x)
            atomicAdd(&new_states[i_atom][i_feat], messages[i_neig][i_feat]);
Reply:
    Apart from solving a non-existing problem, this would make the memory access non-coalesced and reduce speed.
Reply:
    It would have no effect on speed at all. If the number of features happens to satisfy your current requirement, the behavior would be identical to what it currently does. The …
Reply:
    It will reduce the number of threads by the number of features. The reduced parallelism would result in reduced speed.
Reply:
    I'm not suggesting any change to the number of threads. The only thing I'm suggesting is wrapping the …

template <typename scalar_t> __global__ void kernel_backward(
    const Accessor<int32_t, 2> neighbors,
    const Accessor<scalar_t, 2> grad_new_state,
    Accessor<scalar_t, 2> grad_messages
) {
    const int32_t i_neig = blockIdx.x;
    const int32_t i_dir = blockIdx.y;
    const int32_t i_atom = neighbors[i_dir][i_neig];
    if (i_atom < 0) return;

    const int32_t i_feat = threadIdx.x;
    atomicAdd(&grad_messages[i_neig][i_feat], grad_new_state[i_atom][i_feat]);
}

class Autograd : public Function<Autograd> {
public:
    static tensor_list forward(AutogradContext* ctx,
                               const Tensor& neighbors,
                               const Tensor& messages,
                               const Tensor& states) {

        TORCH_CHECK(neighbors.dim() == 2, "Expected \"neighbors\" to have two dimensions");
        TORCH_CHECK(neighbors.size(0) == 2, "Expected the 1st dimension size of \"neighbors\" to be 2");
        TORCH_CHECK(neighbors.scalar_type() == kInt32, "Expected \"neighbors\" to have data type of int32");
        TORCH_CHECK(neighbors.is_contiguous(), "Expected \"neighbors\" to be contiguous");

        TORCH_CHECK(messages.dim() == 2, "Expected \"messages\" to have two dimensions");
        TORCH_CHECK(messages.size(1) % 32 == 0, "Expected the 2nd dimension size of \"messages\" to be a multiple of 32");
        TORCH_CHECK(messages.size(1) <= 1024, "Expected the 2nd dimension size of \"messages\" to be at most 1024");
        TORCH_CHECK(messages.is_contiguous(), "Expected \"messages\" to be contiguous");

        TORCH_CHECK(states.dim() == 2, "Expected \"states\" to have two dimensions");
        TORCH_CHECK(states.size(1) == messages.size(1), "Expected the 2nd dimension size of \"messages\" and \"states\" to be the same");
        TORCH_CHECK(states.scalar_type() == messages.scalar_type(), "Expected the data type of \"messages\" and \"states\" to be the same");
        TORCH_CHECK(states.is_contiguous(), "Expected \"states\" to be contiguous");

        const int num_neighbors = neighbors.size(1);
        const int num_features = messages.size(1);

        const dim3 blocks(num_neighbors, 2);
        const dim3 threads(num_features);
        const auto stream = getCurrentCUDAStream(neighbors.get_device());

        Tensor new_states = states.clone();

        AT_DISPATCH_FLOATING_TYPES(messages.scalar_type(), "passMessages::forward", [&]() {
            const CUDAStreamGuard guard(stream);
            kernel_forward<<<blocks, threads, 0, stream>>>(
                get_accessor<int32_t, 2>(neighbors),
                get_accessor<scalar_t, 2>(messages),
                get_accessor<scalar_t, 2>(new_states));
        });

        ctx->save_for_backward({neighbors});

        return {new_states};
    }

    static tensor_list backward(AutogradContext* ctx, tensor_list grad_inputs) {

        const Tensor neighbors = ctx->get_saved_variables()[0];
        const Tensor grad_new_state = grad_inputs[0];

        const int num_neighbors = neighbors.size(1);
        const int num_features = grad_new_state.size(1);

        const dim3 blocks(num_neighbors, 2);
        const dim3 threads(num_features);
        const auto stream = getCurrentCUDAStream(neighbors.get_device());

        Tensor grad_messages = torch::zeros({num_neighbors, num_features}, grad_new_state.options());

        AT_DISPATCH_FLOATING_TYPES(grad_new_state.scalar_type(), "passMessages::backward", [&]() {
            const CUDAStreamGuard guard(stream);
            kernel_backward<<<blocks, threads, 0, stream>>>(
                get_accessor<int32_t, 2>(neighbors),
                get_accessor<scalar_t, 2>(grad_new_state),
                get_accessor<scalar_t, 2>(grad_messages));
        });

        return {Tensor(),                 // grad_neighbors
                grad_messages,
                grad_new_state.clone()};  // grad_state
    }
};

TORCH_LIBRARY_IMPL(messages, AutogradCUDA, m) {
    m.impl("passMessages", [](const Tensor& neighbors,
                              const Tensor& messages,
                              const Tensor& states) {
        return Autograd::apply(neighbors, messages, states)[0];
    });
}
Review comment:
    Are those limitations really necessary? It's very common for the number of features not to be a multiple of 32.
Reply:
    On the contrary, the number of internal features is always a multiple of 32 (e.g. in TorchMD-NET, I have seen usage of 64, 128, 256). The GPU computes in warps of 32 threads, so it is best to match that pattern for computational efficiency.
Reply:
    I frequently create models that don't satisfy those requirements, including in TorchMD-Net. For example, I've trained models with 48 or 80 features per layer.
Reply:
    Why?
Reply:
    I don't understand your question. Why not? The number of features is a hyperparameter. It's one of many hyperparameters you tune to balance training accuracy, overfitting, speed, and memory use. Why place arbitrary limits on it when there's no need to?
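To make the constraint under discussion concrete, a short sketch (assuming NNPOps is installed with the op from this pull request built): with the checks as written in this diff, a feature width that is not a multiple of 32, such as the 48 mentioned above, is rejected at the TORCH_CHECK stage rather than computed.

    import torch as pt
    from NNPOps.messages import passMessages

    neighbors = pt.tensor([[0], [1]], dtype=pt.int32)
    messages = pt.randn((1, 48))   # 48 features: not a multiple of 32
    states = pt.zeros((2, 48))

    try:
        passMessages(neighbors, messages, states)
    except RuntimeError as e:
        print(e)  # Expected the 2nd dimension size of "messages" to be a multiple of 32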