
Possible bug in half-precision scatter #269

Open
@noahstier


I get different results when I compute the mean of a half-precision tensor with `torch_scatter.scatter(..., reduce='mean')` than when I compute it directly with `torch.mean`.

Is this expected behavior?

```python
import torch
import torch_scatter

src_float = torch.randn(1_000_000).float()
src_half = src_float.half()
idx = torch.zeros(len(src_float), dtype=torch.long)

result_float = torch_scatter.scatter(src_float, idx, reduce='mean')
result_half = torch_scatter.scatter(src_half, idx, reduce='mean')

result_float_manual = torch.mean(src_float)
result_half_manual = torch.mean(src_half)

print(result_float)
print(result_float_manual)
print(result_half)
print(result_half_manual)
```

prints:

```
tensor([-0.0014])
tensor(-0.0014)
tensor([-0.7241], dtype=torch.float16)
tensor(-0., dtype=torch.float16)
```
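One plausible explanation, which I have not confirmed against the torch_scatter source, is that the mean reduction accumulates both the running sum and the element count in the input dtype. float16 has an 11-bit significand. Once a running total reaches 2048, the gap between adjacent representable values is 2, so adding 1.0 rounds back to 2048. The sketch below is a pure-Python model of that assumption. It uses the `struct` module's half-precision format (`'e'`) to round after every addition. `to_half` and `half_accumulate` are hypothetical helpers and are not part of torch or torch_scatter.

```python
import random
import statistics
import struct


def to_half(x: float) -> float:
    """Round x to the nearest float16 value (IEEE 754 binary16, ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def half_accumulate(values):
    """Sum values left to right, rounding the running total to float16 after every add.

    Hypothetical model of a reduction whose accumulator has the input's dtype.
    """
    acc = 0.0
    for v in values:
        # The float64 sum of two float16 values is exact, so rounding it once
        # reproduces a single IEEE float16 addition.
        acc = to_half(acc + to_half(v))
    return acc


random.seed(0)
n = 100_000  # smaller than the issue's 1_000_000 to keep the pure-Python loop fast
values = [to_half(random.gauss(0.0, 1.0)) for _ in range(n)]

# Counting elements in float16: 2048 + 1 rounds back to 2048, so the count stalls.
half_count = half_accumulate([1.0] * n)
half_mean = to_half(half_accumulate(values) / half_count)

# Accumulating in float64 and rounding only the final result stays accurate.
wide_mean = to_half(sum(values) / n)

print(half_count)  # 2048.0, not 100000
print(half_mean, wide_mean, statistics.fmean(values))
```

Under this model the float16 count stops at 2048. The "mean" is then the float16 sum divided by 2048 rather than by the true number of elements. Accumulating in wider precision and rounding once stays close to the exact mean. If that is what happens here, one workaround would be to scatter in float32 and cast back, e.g. `torch_scatter.scatter(src_half.float(), idx, reduce='mean').half()`, because the float32 results above agree with `torch.mean`.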
