PyTorch 1.10+cu113: CUDA Graph capture with torch_scatter raises RuntimeError when the out argument is not given #244

@liulixinkerry


Dear authors,
I found that when torch_scatter is used inside a CUDA graph, the output memory must be pre-allocated (via the out parameter) for scatter_mean, scatter_max, scatter_add, etc.; otherwise graph capture fails with a RuntimeError.
Meanwhile, since scatter_softmax does not support an out argument, using scatter_softmax inside a CUDA graph always raises the same RuntimeError that scatter_mean raises when out is not given.
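The traceback in this report shows where the difference comes from: when out is None, scatter_sum infers the output length as int(index.max()) + 1. Calling int() on a CUDA tensor copies the value to the host and synchronizes, which is not permitted while a stream is being captured. The pure-Python sketch below (a hypothetical helper using plain lists, not torch_scatter code) contrasts the paths: inferring the size needs the values of index, whereas a supplied dim_size or a pre-allocated out fixes the shape up front. Since the traceback also shows scatter_sum receiving a dim_size argument, passing dim_size might likewise skip the index.max() read; that is an assumption I have not checked for every scatter function.

```python
from typing import List, Optional


def scatter_add_sketch(src: List[float], index: List[int],
                       out: Optional[List[float]] = None,
                       dim_size: Optional[int] = None) -> List[float]:
    """Hypothetical 1-D scatter-add mirroring the size logic seen in the traceback."""
    if out is None:
        if dim_size is None:
            # Data-dependent size: needs the *values* of index. On a GPU this is
            # the int(index.max()) + 1 step, i.e. a device-to-host read.
            dim_size = max(index) + 1 if index else 0
        out = [0.0] * dim_size
    # Accumulate into out; existing contents are kept, not cleared.
    for value, i in zip(src, index):
        out[i] += value
    return out


src = [2.0, 0.0, 1.0, 4.0, 3.0]
index = [4, 5, 4, 2, 3]

inferred = scatter_add_sketch(src, index)                 # size read from index
fixed = scatter_add_sketch(src, index, dim_size=6)        # size supplied up front
prealloc = scatter_add_sketch(src, index, out=[0.0] * 6)  # buffer supplied up front
print(inferred)  # [0.0, 0.0, 4.0, 3.0, 3.0, 0.0]
```

All three calls give the same result; only the first one has to look at the index values before it can allocate.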

from torch_scatter import scatter_mean, scatter_max, scatter_add
import torch
num_iters = 1000

device = torch.device("cuda:0")

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]).to(device)
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]).to(device)

print("Param out is given")
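# Pre-allocate (2, 6) outputs so the output size never has to be inferred from index during capture.
out_mean = src.new_zeros((2, 6))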
out_max = src.new_zeros((2, 6))
out_add = src.new_zeros((2, 6))
# Warm up on a side stream before capturing.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        out_mean = scatter_mean(src, index, out=out_mean)
        out_max, _ = scatter_max(src, index, out=out_max)
        out_add = scatter_add(src, index, out=out_add)
torch.cuda.current_stream().wait_stream(s)

# Capture the same three calls into a CUDA graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out_mean = scatter_mean(src, index, out=out_mean)
    out_max, _ = scatter_max(src, index, out=out_max)
    out_add = scatter_add(src, index, out=out_add)

# Replay the captured graph; out_mean, out_max and out_add are updated in place.
for i in range(num_iters):
    g.replay()
torch.cuda.synchronize(0)
print("out_mean: ", out_mean)
print("out_max: ", out_max)
print("out_add: ", out_add)

print("============")
print("Param out is not given")
# Same warm-up and capture, but without out: the output size must now be inferred from index.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        out_mean = scatter_mean(src, index)
        out_max, _ = scatter_max(src, index)
        out_add = scatter_add(src, index)
torch.cuda.current_stream().wait_stream(s)

# Capture fails here: scatter_mean reads index.max() to size its output.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out_mean = scatter_mean(src, index)
    out_max, _ = scatter_max(src, index)
    out_add = scatter_add(src, index)

for i in range(num_iters):
    g.replay()
torch.cuda.synchronize(0)
print("out_mean: ", out_mean)
print("out_max: ", out_max)
print("out_add: ", out_add)

Output:

Param out is given
out_mean:  tensor([[0.0000e+00, 0.0000e+00, 4.0120e+03, 3.0090e+03, 3.0000e+00, 0.0000e+00],
        [2.0000e+00, 4.0120e+03, 4.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
       device='cuda:0')
out_max:  tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]], device='cuda:0')
out_add:  tensor([[   0.,    0., 4012., 3009., 3009.,    0.],
        [2006., 4012., 4012.,    0.,    0.,    0.]], device='cuda:0')
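The large values above are consistent with out being accumulated into rather than overwritten: the script never clears the buffers, and in the first block the scatter calls run 3 times during warm-up plus 1000 replays, i.e. 1003 times (capture itself does not execute the kernels). out_max is unaffected because repeating a max is idempotent. The plain-Python sketch below reproduces the printed out_add and out_mean rows under that assumption, modelling scatter_mean with out as "add src into out, then divide each slot by its element count, with empty slots divided by 1"; that model is inferred from the numbers, not taken from the torch_scatter source. If fresh results are wanted on every replay, resetting the buffers (e.g. out_add.zero_()) inside the captured region is one option.

```python
from typing import List

RUNS = 3 + 1000  # warm-up iterations + graph replays; capture does not execute kernels


def simulate(src: List[float], index: List[int], size: int, runs: int, mean: bool) -> List[float]:
    """Repeatedly scatter src into the same, never-cleared buffer."""
    count = [0] * size
    for i in index:
        count[i] += 1
    out = [0.0] * size
    for _ in range(runs):
        for value, i in zip(src, index):
            out[i] += value  # scatter-add into the existing contents
        if mean:
            # Assumed mean step: divide by group size (empty slots divide by 1).
            out = [o / max(c, 1) for o, c in zip(out, count)]
    return out


rows = [([2.0, 0.0, 1.0, 4.0, 3.0], [4, 5, 4, 2, 3]),
        ([0.0, 2.0, 1.0, 3.0, 4.0], [0, 0, 2, 2, 1])]
out_add = [simulate(s, i, 6, RUNS, mean=False) for s, i in rows]
out_mean = [simulate(s, i, 6, RUNS, mean=True) for s, i in rows]
print(out_add)
# [[0.0, 0.0, 4012.0, 3009.0, 3009.0, 0.0], [2006.0, 4012.0, 4012.0, 0.0, 0.0, 0.0]]
```

Slots fed by a single element grow linearly (4 × 1003 = 4012), while slots averaged over two elements converge to the plain mean (3.0, 2.0, 4.0), matching the printed out_mean.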
============
Param out is not given
Traceback (most recent call last):
  File "test.py", line 50, in <module>
    out_mean = scatter_mean(src, index)
  File "~/anaconda3/lib/python3.8/site-packages/torch_scatter/scatter.py", line 41, in scatter_mean
    out = scatter_sum(src, index, dim, out, dim_size)
  File "~/anaconda3/lib/python3.8/site-packages/torch_scatter/scatter.py", line 19, in scatter_sum
    size[dim] = int(index.max()) + 1
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 52, in <module>
    out_add = scatter_add(src, index)
  File "~/anaconda3/lib/python3.8/site-packages/torch/cuda/graphs.py", line 149, in __exit__
    self.cuda_graph.capture_end()
  File "~/anaconda3/lib/python3.8/site-packages/torch/cuda/graphs.py", line 71, in capture_end
    super(CUDAGraph, self).capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
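Per the report, scatter_softmax has no out argument, so it cannot avoid this size inference. To illustrate what a capture-friendly formulation would need, the pure-Python sketch below (a hypothetical helper, not torch_scatter's implementation) computes a grouped softmax in which the number of groups, dim_size, is supplied by the caller. Every intermediate buffer (the per-group max and per-group sum) is then sized from dim_size alone rather than from the index values.

```python
import math
from typing import List


def scatter_softmax_sketch(src: List[float], index: List[int], dim_size: int) -> List[float]:
    """Hypothetical grouped softmax with a caller-supplied number of groups."""
    # Fixed-size intermediates: one slot per group, sized from dim_size only.
    group_max = [-math.inf] * dim_size
    group_sum = [0.0] * dim_size
    for value, i in zip(src, index):
        group_max[i] = max(group_max[i], value)  # scatter-max for numerical stability
    shifted = [math.exp(v - group_max[i]) for v, i in zip(src, index)]
    for e, i in zip(shifted, index):
        group_sum[i] += e  # scatter-add of the exponentials
    # Gather each element's group sum and normalise.
    return [e / group_sum[i] for e, i in zip(shifted, index)]


probs = scatter_softmax_sketch([2.0, 0.0, 1.0, 4.0, 3.0], [4, 5, 4, 2, 3], dim_size=6)
```

Elements 0 and 2 share group 4 and split its probability mass; every other element is alone in its group and gets 1.0.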
