
Empty second-order derivative (= Hessian) for the segment_* reductions #299

Open
@jeanfeydy

Description

Hi @rusty1s,
Thanks again for your great work on this library!

I am currently experimenting with second-order derivatives of computations that involve torch_scatter operations, and I noticed that the segment_coo and segment_csr operators are not twice differentiable with the "sum" reduction. To reproduce this behavior, see e.g.:

import torch
import torch_scatter

# Values:
val = torch.FloatTensor([[0, 1, 2]])
# Groups:
gr_coo = torch.LongTensor([[0, 0, 1]])
gr_csr = torch.LongTensor([[0, 2, 3]])

val.requires_grad = True
B, D = val.shape


def group_reduce(*, values, groups, reduction, output_size, backend):

    if backend == "torch":
        # Compatibility switch for torch.scatter_reduce:
        if reduction == "max":
            reduction = "amax"
        return torch.scatter_reduce(
            values, 1, groups, reduction, output_size=output_size
        )
    elif backend == "pyg":
        return torch_scatter.scatter(
            values, groups, dim=1, dim_size=output_size, reduce=reduction
        )
    elif backend == "coo":
        return torch_scatter.segment_coo(
            values, groups, dim_size=output_size, reduce=reduction
        )
    elif backend == "csr":
        return torch_scatter.segment_csr(values, groups, reduce=reduction)
    else:
        raise ValueError(
            f"Invalid value for the scatter backend ({backend}), "
            "should be one of 'torch', 'pyg', 'coo' or 'csr'."
        )


for backend in ["torch", "pyg", "coo", "csr"]:

    red = group_reduce(
        values=val,
        groups=gr_csr if backend == "csr" else gr_coo,
        reduction="sum",
        output_size=2,
        backend=backend,
    )

    # Compute an arbitrary scalar value out of our reduction:
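    # (the 0 * val**2 term keeps val connected to the autograd graph of v in all cases)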
    v = (red ** 2).sum(-1) + 0.0 * (val ** 2).sum(-1)
    # Gradient:
    g = torch.autograd.grad(v.sum(), [val], create_graph=True)[0]
    # Hessian:
    h = torch.zeros(B, D, D).type_as(val)
    for d in range(D):
        h[:, d, :] = torch.autograd.grad(g[:, d].sum(), [val], retain_graph=True)[0]

    print(backend, ":")
    print("Value:", v.detach().numpy())
    print("Grad :", g.detach().numpy())
    print("Hessian:")
    print(h.detach().numpy())
    print("--------------")

The output shows that torch_scatter.scatter and torch.scatter_reduce agree on all derivatives, while the two segment_* implementations return an all-zero second-order derivative:

torch :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
pyg :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
  [2. 2. 0.]
  [0. 0. 2.]]]
--------------
coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
--------------
csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
--------------
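
As an independent check, torch.autograd.gradgradcheck seems to flag the same problem (a minimal sketch in double precision, as gradcheck expects; I would expect the call below to fail or to complain about an unconnected second-order gradient, while the same check passes with torch_scatter.scatter):

import torch
import torch_scatter

val64 = torch.randn(1, 3, dtype=torch.double, requires_grad=True)
indptr = torch.LongTensor([[0, 2, 3]])

# Compares the analytic double backward against finite differences:
torch.autograd.gradgradcheck(
    lambda v: torch_scatter.segment_csr(v, indptr, reduce="sum"),
    (val64,),
)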

Is this expected behavior on your side?
Support for second-order derivatives would be especially useful to perform e.g. Newton-type optimization, along the lines of the sketch below.
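
A generic helper for such a damped Newton step could look roughly like this (newton_step is just a name of mine, not an existing API; it assumes that x requires gradients and that objective(x) returns a scalar):

import torch


def newton_step(x, objective, damping=1e-6):
    # One damped Newton step x <- x - (H + damping * I)^{-1} g, with the
    # gradient g and Hessian H of the scalar objective obtained via autograd.
    # This only makes sense if objective(x) is twice differentiable,
    # which is exactly what breaks with the segment_* reductions above.
    g = torch.autograd.grad(objective(x), [x])[0]
    H = torch.autograd.functional.hessian(objective, x)
    n = x.numel()
    step = torch.linalg.solve(
        H.reshape(n, n) + damping * torch.eye(n, dtype=x.dtype),
        g.reshape(n),
    )
    return (x.reshape(n) - step).reshape(x.shape)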

I'm sure that I could hack something together with a torch.autograd.Function wrapper for my own use case (see the rough sketch below), but a proper fix would certainly be useful to other people. Unfortunately, I am not familiar enough with the PyTorch C++ API to fix e.g. segment_csr.cpp myself and write a Pull Request for this :-(
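
For reference, the wrapper I have in mind for the "sum" reduction would look roughly like this (SegmentSumCSR is just my own name for it; the sketch assumes a single indptr row shared across the leading dimensions, as in the snippet above):

import torch
import torch_scatter


class SegmentSumCSR(torch.autograd.Function):
    """segment_csr with reduce="sum", with a backward pass written in plain,
    differentiable PyTorch ops so that autograd can differentiate it twice."""

    @staticmethod
    def forward(ctx, src, indptr):
        ctx.save_for_backward(indptr)
        # Keep the fast kernel for the forward pass:
        return torch_scatter.segment_csr(src, indptr, reduce="sum")

    @staticmethod
    def backward(ctx, grad_out):
        (indptr,) = ctx.saved_tensors
        # The adjoint of a segment-wise sum broadcasts the gradient of each
        # segment back to all of its entries. repeat_interleave is
        # differentiable with respect to grad_out, so double backward
        # works automatically.
        lengths = (indptr[..., 1:] - indptr[..., :-1]).reshape(-1)
        grad_src = torch.repeat_interleave(grad_out, lengths, dim=-1)
        return grad_src, None

Dropping SegmentSumCSR.apply(values, groups) into the "csr" branch of group_reduce above should then produce the same non-zero Hessian as the "torch" and "pyg" backends, but this obviously covers neither the other reductions nor segment_coo.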

What do you think?
