Description
Hi @rusty1s,
Thanks again for your great work on this library!
I am currently experimenting with computing second-order derivatives that involve torch_scatter operations, and noticed that the segment_coo and segment_csr operators are not twice differentiable with the sum reduction. To reproduce this behavior, see e.g.:
import torch
import torch_scatter
# Values:
val = torch.FloatTensor([[0, 1, 2]])
# Groups:
gr_coo = torch.LongTensor([[0, 0, 1]])
gr_csr = torch.LongTensor([[0, 2, 3]])
val.requires_grad = True
B, D = val.shape
def group_reduce(*, values, groups, reduction, output_size, backend):
    if backend == "torch":
        # Compatibility switch for torch.scatter_reduce:
        if reduction == "max":
            reduction = "amax"
        return torch.scatter_reduce(
            values, 1, groups, reduction, output_size=output_size
        )
    elif backend == "pyg":
        return torch_scatter.scatter(
            values, groups, dim=1, dim_size=output_size, reduce=reduction
        )
    elif backend == "coo":
        return torch_scatter.segment_coo(
            values, groups, dim_size=output_size, reduce=reduction
        )
    elif backend == "csr":
        return torch_scatter.segment_csr(values, groups, reduce=reduction)
    else:
        raise ValueError(
            f"Invalid value for the scatter backend ({backend}), "
            "should be one of 'torch', 'pyg', 'coo' or 'csr'."
        )
for backend in ["torch", "pyg", "coo", "csr"]:
    red = group_reduce(
        values=val,
        groups=gr_csr if backend == "csr" else gr_coo,
        reduction="sum",
        output_size=2,
        backend=backend,
    )
    # Compute an arbitrary scalar value out of our reduction:
    v = (red ** 2).sum(-1) + 0.0 * (val ** 2).sum(-1)
    # Gradient:
    g = torch.autograd.grad(v.sum(), [val], create_graph=True)[0]
    # Hessian:
    h = torch.zeros(B, D, D).type_as(val)
    for d in range(D):
        h[:, d, :] = torch.autograd.grad(g[:, d].sum(), [val], retain_graph=True)[0]
    print(backend, ":")
    print("Value:", v.detach().numpy())
    print("Grad :", g.detach().numpy())
    print("Hessian:")
    print(h.detach().numpy())
    print("--------------")
The output shows that torch_scatter.scatter and torch.scatter_reduce coincide on all derivatives, while the two segment_* implementations return an identically zero second-order derivative:
torch :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
[2. 2. 0.]
[0. 0. 2.]]]
--------------
pyg :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
[2. 2. 0.]
[0. 0. 2.]]]
--------------
coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]
--------------
csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]
--------------
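For what it's worth, torch.autograd.gradgradcheck may be a quicker way to flag this, directly on the reduction operators. Below is a minimal, untested sketch of that check (double-precision inputs are assumed, since gradcheck relies on finite differences); I would expect the "coo" and "csr" entries to fail while "pyg" passes:

import torch
import torch_scatter

# Double precision is needed for reliable finite-difference comparisons:
val = torch.tensor([[0.0, 1.0, 2.0]], dtype=torch.double, requires_grad=True)
gr_coo = torch.tensor([[0, 0, 1]])
gr_csr = torch.tensor([[0, 2, 3]])

checks = {
    "pyg": lambda v: torch_scatter.scatter(v, gr_coo, dim=1, dim_size=2, reduce="sum"),
    "coo": lambda v: torch_scatter.segment_coo(v, gr_coo, dim_size=2, reduce="sum"),
    "csr": lambda v: torch_scatter.segment_csr(v, gr_csr, reduce="sum"),
}
for name, fn in checks.items():
    # gradgradcheck compares the analytical second-order derivative
    # (including d(grad_input)/d(grad_output)) against finite differences:
    try:
        ok = torch.autograd.gradgradcheck(fn, (val,))
    except RuntimeError:
        ok = False
    print(name, ": gradgradcheck passes ->", ok)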
Is this expected behavior on your side?
Support for second-order derivatives would be especially useful, e.g. to perform Newton-type optimization.
I'm sure that I could hack something together with a torch.autograd.Function wrapper for my own use case (something along the lines of the sketch below), but a proper fix would certainly be useful to other people. Unfortunately, I am not familiar enough with the PyTorch C++ API to fix e.g. segment_csr.cpp myself and write a Pull Request for this :-(
What do you think?