
PScan gradient is not correct? #65

Open
andrewwarrington opened this issue Dec 23, 2024 · 1 comment

Comments


andrewwarrington commented Dec 23, 2024

Thanks for the awesome repo!

I am trying to use the PScan functionality. I wanted to verify (just for my own edification) that the parallel and sequential scans give the same forward and gradient results. The forward test passes, but I don't think the gradients are correct. I have included my repro below.

In short, the script:

  1. Defines a (deterministic/seeded) data-generating procedure.
  2. Defines a sequential scan.
  3. Defines some parameters for generating data.
  4. Tests that the forward results are allclose.
  5. Tests that the gradients are allclose (using sum().abs() as the "loss" function to create a real scalar to take the gradient of; see the small example after this list).
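As an aside on the loss choice: since the scan output is complex, sum().abs() gives a real scalar, and backward() then fills in a complex .grad on each complex leaf, following PyTorch's Wirtinger-derivative convention. A minimal standalone illustration (not part of the repro itself):

import torch

# .sum().abs() turns a complex output into a real-valued scalar loss;
# backward() then populates a complex .grad on the complex leaf tensor.
z = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j], requires_grad=True)
loss = z.sum().abs()  # real-valued scalar
loss.backward()
print(loss, z.grad)   # z.grad is complex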

The forward test passes, but the gradient check then fails on the indicated line.

The pscan version I'm using is from a fresh clone, imported directly from pscan.py.

Please advise! :)
Andy

import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np

from pscan import pscan


# 1. Deterministic/seeded data generation: complex A and X of shape (B, L, D, N).
def gen_data(key=0):
    rng = torch.Generator()
    rng.manual_seed(key)
    rand = lambda shape: torch.rand(shape, generator=rng)
    A = torch.nn.Parameter(rand((B, L, D, N)) + 1j * rand((B, L, D, N)), requires_grad=True).to(DEVICE)
    X = torch.nn.Parameter(rand((B, L, D, N)) + 1j * rand((B, L, D, N)), requires_grad=True).to(DEVICE)
    return A, X


# 2. Sequential scan (reference implementation).
def sscan(A_in: Tensor, X_in: Tensor) -> Tensor:
    """
    Applies the sequential equivalent of the parallel scan operation imported from pscan.py.
    Returns a new tensor. Supports forward and backward passes natively in pure PyTorch.

    Mainly for sanity checking, and as a fallback for testing or cases of limited memory.

    Args:
        A_in :  (B, L, D, N)
        X_in :  (B, L, D, N)

    Returns:
        H : (B, L, D, N)
    """

    H = torch.zeros_like(X_in)
    h = torch.zeros_like(X_in[:, 0])

    for l in range(X_in.size(1)):
        h = A_in[:, l] * h + X_in[:, l]
        H[:, l] = h

    return H


# 3. Parameters for generating data.
(B, L, D, N) = (4, 64, 8, 12)
atol = 1e-6
rtol = 1e-6
DEVICE = "cpu"
data_key = 0


# 4. Test the forward pass.
A, X = gen_data(data_key)
results_par = pscan(A, X)
results_seq = sscan(A, X)
assert torch.allclose(results_par, results_seq, atol=atol, rtol=rtol), "Forward results are not equal!"


# 5. Now test the gradient.
A, X = gen_data(data_key)
pscan(A, X).sum().abs().backward()
print(A.grad.ravel()[-1])  # Quick check.
A_grad_par = A.grad.detach().clone().numpy()
X_grad_par = X.grad.detach().clone().numpy()

A, X = gen_data(data_key)
sscan(A, X).sum().abs().backward()
print(A.grad.ravel()[-1])  # Quick check.
A_grad_seq = A.grad.detach().clone().numpy()
X_grad_seq = X.grad.detach().clone().numpy()

assert np.allclose(A_grad_par, A_grad_seq, atol=atol, rtol=rtol), "A grads are not equal!"  # <FAILS HERE>
assert np.allclose(X_grad_par, X_grad_seq, atol=atol, rtol=rtol), "X grads are not equal!"
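One further check that could settle which implementation's backward is wrong: compare each one against finite differences with torch.autograd.gradcheck. This is only a rough sketch, assuming pscan from pscan.py accepts complex double-precision inputs, and reusing the sscan defined above; the shapes are kept tiny so the numerical check stays cheap.

# 6. (Sketch) Compare each implementation's analytic gradient against a
# finite-difference estimate; gradcheck raises by default if they disagree.
shape = (2, 8, 3, 4)  # small (B, L, D, N) so finite differences stay cheap
A = (torch.rand(shape, dtype=torch.float64) + 1j * torch.rand(shape, dtype=torch.float64)).requires_grad_()
X = (torch.rand(shape, dtype=torch.float64) + 1j * torch.rand(shape, dtype=torch.float64)).requires_grad_()

print(torch.autograd.gradcheck(pscan, (A, X), atol=1e-6))  # parallel scan
print(torch.autograd.gradcheck(sscan, (A, X), atol=1e-6))  # sequential reference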
@andrewwarrington (Author)

Hello @alxndrTL, I was wondering if there was any resolution to this? Thanks! :)
