Thanks for the awesome repo! I am trying to use the PScan functionality, and wanted to verify (just for my own edification) that the forward and gradient results of the parallel and sequential scans are the same. The forward test passes, but I don't think the gradients are correct. My repro is included below.
In short:
1. Define a (deterministic, seeded) data-generating procedure.
2. Define a sequential scan.
3. Define some parameters for generating data.
4. Test that the forward results are allclose.
5. Test that the gradients are allclose (using sum().abs() as the "loss" function, to get a real scalar to take the gradient of).
The forward test passes, but then it fails on the indicated line.
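For concreteness, the recurrence both implementations should compute (as I read the code) is, elementwise over the (B, L, D, N) tensors:

    h_0 = X_0,    h_l = A_l * h_{l-1} + X_l    for l = 1, ..., L-1,

with the scan running over the length dimension.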
The pscan version I'm using is a fresh clone, imported directly from pscan.py.
Please advise! :)
Andy
import torch
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from pscan import pscan
# 1. Define a (deterministic, seeded) data-generating procedure.
# Uses the globals B, L, D, N, DEVICE defined in step 3 below.
def gen_data(key=0):
    rng = torch.Generator()
    rng.manual_seed(key)
    rand = lambda shape: torch.rand(shape, generator=rng)
    # Move to the device *before* wrapping in Parameter, so A and X remain
    # leaf tensors and .grad gets populated even when DEVICE is not "cpu".
    A = torch.nn.Parameter((rand((B, L, D, N)) + 1j * rand((B, L, D, N))).to(DEVICE), requires_grad=True)
    X = torch.nn.Parameter((rand((B, L, D, N)) + 1j * rand((B, L, D, N))).to(DEVICE), requires_grad=True)
    return A, X
# 2. Define a sequential scan.
def sscan(A_in: Tensor, X_in: Tensor) -> Tensor:
    """
    Sequential equivalent of the parallel scan: h_l = A_l * h_{l-1} + X_l,
    elementwise, with h_{-1} = 0.
    Returns a new tensor. Supports forward and backward natively in pure PyTorch.
    Mainly for sanity checking, and as a fallback for testing / in cases of
    limited memory.
    Args:
        A_in : (B, L, D, N)
        X_in : (B, L, D, N)
    Returns:
        H : (B, L, D, N)
    """
    H = torch.zeros_like(X_in)
    h = torch.zeros_like(X_in[:, 0])
    for l in range(X_in.size(1)):
        h = A_in[:, l] * h + X_in[:, l]
        H[:, l] = h
    return H
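As an aside, here is a minimal log-depth scan for the same recurrence (a Hillis-Steele-style sketch of my own, not the repo's pscan), which could help localize whether the discrepancy is specific to pscan's backward:

# Sketch only: an O(L log L)-work, log-depth scan for h_l = A_l * h_{l-1} + X_l.
# Not the repo's pscan; autograd differentiates it automatically, so its
# gradients should match sscan's.
def pscan_reference(A_in: Tensor, X_in: Tensor) -> Tensor:
    A, X = A_in, X_in
    step = 1
    while step < X.size(1):
        # Compose element l with element l - step along dim 1:
        # (a_l, x_l) <- (a_l * a_{l-step}, a_l * x_{l-step} + x_l).
        # X must be updated before A so that a_l is still the pre-update value.
        X = torch.cat([X[:, :step], A[:, step:] * X[:, :-step] + X[:, step:]], dim=1)
        A = torch.cat([A[:, :step], A[:, step:] * A[:, :-step]], dim=1)
        step *= 2
    return X

Comparing pscan's gradients against both sscan and this sketch would indicate whether the mismatch comes from pscan's hand-written backward.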
# 3. Define some parameters for generating data.
(B, L, D, N) = (4, 64, 8, 12)
atol = 1e-6
rtol = 1e-6
DEVICE = "cpu"
data_key = 0
# 4. Test the forward pass.
A, X = gen_data(data_key)
results_par = pscan(A, X)
results_seq = sscan(A, X)
assert torch.allclose(results_par, results_seq, atol=atol, rtol=rtol), "Forward results are not equal!"
# 5. Now test the gradient.
A, X = gen_data(data_key)
pscan(A, X).sum().abs().backward()
print(A.grad.ravel()[-1])  # Quick check.
A_grad_par = A.grad.detach().cpu().numpy()
X_grad_par = X.grad.detach().cpu().numpy()
A, X = gen_data(data_key)
sscan(A, X).sum().abs().backward()
print(A.grad.ravel()[-1])  # Quick check.
A_grad_seq = A.grad.detach().cpu().numpy()
X_grad_seq = X.grad.detach().cpu().numpy()
assert np.allclose(A_grad_par, A_grad_seq, atol=atol, rtol=rtol), "A grads are not equal!"  # <FAILS HERE>
assert np.allclose(X_grad_par, X_grad_seq, atol=atol, rtol=rtol), "X grads are not equal!"
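As a sharper diagnostic (my own suggestion, not from the repo), torch.autograd.gradcheck can compare each implementation's analytic gradients against finite differences. It needs double precision, so the sketch below builds fresh complex128 inputs, with a small size (and a power-of-two length, in case pscan assumes one); the helper name is illustrative:

# Hypothetical extra check: gradcheck compares analytic vs. numerical gradients.
def make_inputs(b=2, l=8, d=2, n=3):
    g = torch.Generator().manual_seed(0)
    r = lambda: torch.rand((b, l, d, n), generator=g, dtype=torch.float64)
    A = (r() + 1j * r()).requires_grad_(True)
    X = (r() + 1j * r()).requires_grad_(True)
    return A, X

A64, X64 = make_inputs()
# If the first check passes but the second raises, the bug is in pscan's
# backward rather than in this test harness.
torch.autograd.gradcheck(lambda a, x: sscan(a, x), (A64, X64), eps=1e-6, atol=1e-4)
torch.autograd.gradcheck(lambda a, x: pscan(a, x), (A64, X64), eps=1e-6, atol=1e-4)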