Dear authors,

I find that when `torch_scatter` is used inside a CUDA Graph, the output memory must be pre-allocated (via the `out` parameter) for `scatter_mean`, `scatter_max`, `scatter_add`, etc. Meanwhile, since `scatter_softmax` does not accept an `out` argument, `scatter_softmax` + CUDA Graph always raises the same `RuntimeError` as `scatter_mean` without a given `out`.
```python
from torch_scatter import scatter_mean, scatter_max, scatter_add
import torch

num_iters = 1000
device = torch.device("cuda:0")
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]).to(device)
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]).to(device)

print("Param out is given")
out_mean = src.new_zeros((2, 6))
out_max = src.new_zeros((2, 6))
out_add = src.new_zeros((2, 6))

# Warm-up on a side stream, as required before CUDA Graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        out_mean = scatter_mean(src, index, out=out_mean)
        out_max, _ = scatter_max(src, index, out=out_max)
        out_add = scatter_add(src, index, out=out_add)
torch.cuda.current_stream().wait_stream(s)

# Capture one iteration into a graph, then replay it.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out_mean = scatter_mean(src, index, out=out_mean)
    out_max, _ = scatter_max(src, index, out=out_max)
    out_add = scatter_add(src, index, out=out_add)

for i in range(num_iters):
    g.replay()
torch.cuda.synchronize(0)
print("out_mean: ", out_mean)
print("out_max: ", out_max)
print("out_add: ", out_add)

print("============")
print("Param out is not given")

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        out_mean = scatter_mean(src, index)
        out_max, _ = scatter_max(src, index)
        out_add = scatter_add(src, index)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out_mean = scatter_mean(src, index)
    out_max, _ = scatter_max(src, index)
    out_add = scatter_add(src, index)

for i in range(num_iters):
    g.replay()
torch.cuda.synchronize(0)
print("out_mean: ", out_mean)
print("out_max: ", out_max)
print("out_add: ", out_add)
```
Output:

```
Param out is given
out_mean:  tensor([[0.0000e+00, 0.0000e+00, 4.0120e+03, 3.0090e+03, 3.0000e+00, 0.0000e+00],
        [2.0000e+00, 4.0120e+03, 4.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
       device='cuda:0')
out_max:  tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]], device='cuda:0')
out_add:  tensor([[   0.,    0., 4012., 3009., 3009.,    0.],
        [2006., 4012., 4012.,    0.,    0.,    0.]], device='cuda:0')
============
Param out is not given
Traceback (most recent call last):
  File "test.py", line 50, in <module>
    out_mean = scatter_mean(src, index)
  File "~/anaconda3/lib/python3.8/site-packages/torch_scatter/scatter.py", line 41, in scatter_mean
    out = scatter_sum(src, index, dim, out, dim_size)
  File "~/anaconda3/lib/python3.8/site-packages/torch_scatter/scatter.py", line 19, in scatter_sum
    size[dim] = int(index.max()) + 1
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 52, in <module>
    out_add = scatter_add(src, index)
  File "~/anaconda3/lib/python3.8/site-packages/torch/cuda/graphs.py", line 149, in __exit__
    self.cuda_graph.capture_end()
  File "~/anaconda3/lib/python3.8/site-packages/torch/cuda/graphs.py", line 71, in capture_end
    super(CUDAGraph, self).capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
```
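
For what it's worth, the traceback points at the root cause: when `out` is omitted, `scatter_sum` infers the output size via `int(index.max()) + 1`, which forces a device-to-host synchronization, and synchronization is not permitted while a stream is capturing. A minimal, hedged sketch of a possible way around this for the reducers that accept `dim_size` (I have not confirmed this is a supported fix):

```python
# Hedged sketch, not a confirmed fix: supplying dim_size explicitly skips
# the `int(index.max()) + 1` size inference, so no device-to-host sync
# should happen during capture. Whether the fresh output allocation inside
# capture is legal depends on the PyTorch version routing it to the
# graph's private memory pool.
with torch.cuda.graph(g):
    out_mean = scatter_mean(src, index, dim=-1, dim_size=6)
```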
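Since `scatter_softmax` has no `out` argument, one possible interim workaround is to compose it from the primitives that do accept `out`, with all buffers pre-allocated and reset inside the captured region so replays don't accumulate. This is only a sketch under those assumptions; the helper name and buffer shapes below are mine, not part of the torch_scatter API:

```python
import torch
from torch_scatter import scatter_max, scatter_add

def capture_safe_scatter_softmax(src, index, out_max, out_sum, dim=-1):
    """Hypothetical capture-friendly scatter_softmax built from scatter_max
    and scatter_add with pre-allocated `out` buffers (a sketch, not the
    torch_scatter API)."""
    # Reset buffers inside the captured region so each replay starts clean.
    out_max.fill_(float("-inf"))
    out_sum.zero_()
    # Per-group max for numerical stability.
    max_per_group, _ = scatter_max(src, index, dim=dim, out=out_max)
    recentered = (src - max_per_group.gather(dim, index)).exp()
    # Per-group normalizer.
    denom = scatter_add(recentered, index, dim=dim, out=out_sum)
    return recentered / denom.gather(dim, index)
```

For the shapes in the repro above, the buffers would be created once before capture, e.g. `out_max = src.new_empty((2, 6))` and `out_sum = src.new_empty((2, 6))`, and reused across replays; their initial contents don't matter since they are reset inside the function.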