
Update functional.scaled_dot_product_attention signature/arguments #1473

Open: wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions RELEASENOTES.md
@@ -3,6 +3,10 @@
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
# NuGet Version 0.105.1

__Breaking Changes__:

`torch.nn.functional.scaled_dot_product_attention`'s function signature has been changed. The `is_casual` argument has been renamed to `is_causal`.<br/>

__Bug Fixes__:

#1426 Sequential.eval() does not put model into eval mode<br/>
@@ -16,6 +20,7 @@ __API Changes__:
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
Returning an input tensor has been corrected, is now `alias()`.<br/>
Add `torchvision.transforms.Resize` `interpolation` and `antialias`.<br />
Add optional `scale` and `enable_gqa` arguments to `torch.nn.functional.scaled_dot_product_attention`.<br/>

# NuGet Version 0.105.0

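A minimal usage sketch of the updated managed API, assuming the usual `using static TorchSharp.torch;` idiom; tensor shapes and values are illustrative only:

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Shapes are (batch, heads, sequence length, head dimension).
var query = rand(2, 8, 16, 64);
var key = rand(2, 8, 16, 64);
var value = rand(2, 8, 16, 64);

// Callers that used the old misspelled label must rename `is_casual:` to `is_causal:`.
var causal = nn.functional.scaled_dot_product_attention(query, key, value, is_causal: true);

// New optional arguments: an explicit softmax scale (the default is 1/sqrt(head dim))
// and Group Query Attention, where several query heads share each key/value head.
var kvKey = rand(2, 2, 16, 64);
var kvValue = rand(2, 2, 16, 64);
var gqa = nn.functional.scaled_dot_product_attention(query, kvKey, kvValue, scale: 0.125, enable_gqa: true);
```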
5 changes: 3 additions & 2 deletions src/Native/LibTorchSharp/THSNN.cpp
@@ -1064,9 +1064,10 @@ Tensor THSNN_unfold(const Tensor input, const int64_t kernel1, const int64_t ker
CATCH_TENSOR(torch::nn::functional::unfold(*input, opts));
}

Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual)
Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa)
{
auto mask = attention_mask == nullptr ? c10::nullopt : c10::optional<at::Tensor>(*attention_mask);
auto scl = (scale == nullptr) ? c10::nullopt : c10::optional<double>(*scale);

CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, casual));
CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, is_causal, scl, enable_gqa));
}
2 changes: 1 addition & 1 deletion src/Native/LibTorchSharp/THSNN.h
@@ -228,7 +228,7 @@ EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor i

EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim);

EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual);
EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa);

// Initializers

23 changes: 17 additions & 6 deletions src/TorchSharp/NN/Transformer.cs
@@ -107,15 +107,26 @@ public static partial class functional
/// A float mask of the same type as query, key, value that is added to the attention score.
/// </param>
/// <param name="p">Dropout probability</param>
/// <param name="is_casual">If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.</param>
/// <param name="is_causal">If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.</param>
/// <param name="scale">Scaling factor applied prior to softmax. If null, 1/sqrt(E) is used.</param>
/// <param name="enable_gqa">If true, enable Group Query Attention</param>
/// <returns></returns>
public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, [MarshalAs(UnmanagedType.U1)] bool is_casual = false)
public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, [MarshalAs(UnmanagedType.U1)] bool is_causal = false, double? scale = null, bool enable_gqa = false)
{
if (p < 0) throw new ArgumentException("Dropout probability must be greater than or equal to zero.");
if (is_casual && attn_mask is not null) throw new ArgumentException("Casual attention masking cannot pass a mask.");
var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_casual);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
if (is_causal && attn_mask is not null) throw new ArgumentException("Causal attention masking cannot be combined with an explicit attention mask.");
if (query.dim() < 2 || key.dim() < 2 || value.dim() < 2) throw new ArgumentException("Query, key, and value must have at least 2 dimensions.");
if (!enable_gqa && (query.size(1) != key.size(1) || query.size(1) != value.size(1))) throw new InvalidOperationException("Query and key/value heads must be equal when Group Query Attention is not enabled.");

var _scale = scale.HasValue ? new double[] { scale.Value } : null;

unsafe {
fixed (double* scalePtr = _scale) {
var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_causal, (IntPtr)scalePtr, enable_gqa);
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
return new Tensor(res);
}
}
}
}
}
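The wrapper above passes the optional `scale` to native code by pinning a one-element array, so a missing value becomes a null pointer. A self-contained sketch of that pattern; `FakeNativeCall` is a hypothetical stand-in for the real P/Invoke entry point, and the file must be compiled with unsafe code enabled:

```csharp
using System;

static class OptionalScaleSketch
{
    // Hypothetical stand-in for a native import such as THSNN_scaled_dot_product_attention.
    static IntPtr FakeNativeCall(IntPtr scale) => scale;

    static unsafe IntPtr PassOptionalDouble(double? scale)
    {
        // Box the value in a one-element array only when it is present.
        var boxed = scale.HasValue ? new double[] { scale.Value } : null;
        fixed (double* p = boxed) {
            // `fixed` over a null array yields a null pointer, letting the native side
            // distinguish "use the default scale" (null) from an explicit value.
            return FakeNativeCall((IntPtr)p);
        }
    }

    static void Main()
    {
        Console.WriteLine(PassOptionalDouble(null) == IntPtr.Zero);  // True
        Console.WriteLine(PassOptionalDouble(0.5) == IntPtr.Zero);   // False
    }
}
```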
2 changes: 1 addition & 1 deletion src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
@@ -552,7 +552,7 @@ internal static extern IntPtr THSNN_custom_module(
internal static extern IntPtr THSNN_ConvTranspose2d_ctor_1(long inputChannel, long outputChannel, long kernelSizeX, long kernelSizeY, long strideX, long strideY, long paddingX, long paddingY, long outputPaddingX, long outputPaddingY, long dilationX, long dilationY, long paddingMode, long groups, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool casual);
internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool is_causal, IntPtr scale, [MarshalAs(UnmanagedType.U1)] bool enable_gqa);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_Softshrink_forward(torch.nn.Module.HType module, IntPtr tensor);
38 changes: 37 additions & 1 deletion test/TorchSharpTest/NN.cs
@@ -5307,7 +5307,43 @@ public void TestScaledDotProductWithMask()
Assert.Equal(query.shape, x.shape);
Assert.Equal(value, x);

Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_casual: true));
Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_causal: true));
}
[Fact]
public void TestScaledDotProductWithScale()
{

var query = torch.rand(32, 8, 128, 64) * 0.25;
var key = torch.rand(32, 8, 128, 64) * 0.5;
var value = torch.rand(32, 8, 128, 64) * 0.125;
var customScale = 0.5;

var defaultOutput = torch.nn.functional.scaled_dot_product_attention(query, key, value);
var withCustomScale = torch.nn.functional.scaled_dot_product_attention(query, key, value, scale: customScale);

Assert.Equal(query.shape, withCustomScale.shape);
Assert.False(torch.allclose(defaultOutput, withCustomScale, rtol: 1e-5, atol: 1e-5));
}

[Fact]
public void TestScaledDotProductWithGQA()
{
var batchSize = 2;
var queryHeads = 8;
var kvHeads = 2; // For GQA, the number of key/value heads must evenly divide the number of query heads
var seqLen = 16;
var headDim = 64;

var query = torch.ones(batchSize, queryHeads, seqLen, headDim) * 0.25;
var key = torch.ones(batchSize, kvHeads, seqLen, headDim) * 0.5;
var value = torch.ones(batchSize, kvHeads, seqLen, headDim) * 0.125;

Assert.Throws<InvalidOperationException>(() =>
torch.nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: false));

var output = torch.nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: true);

Assert.Equal(query.shape, output.shape);
}

[Fact]
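A short arithmetic sketch of why `TestScaledDotProductWithScale` expects the two outputs to diverge (the 1/sqrt(E) default comes from the parameter documentation above):

```csharp
using System;

// With headDim = 64 the default softmax scale is 1 / sqrt(64) = 0.125, so passing
// scale: 0.5 multiplies the attention logits by 4x; the resulting softmax weights,
// and hence the outputs, should differ enough to fail allclose at 1e-5 tolerances.
const int headDim = 64;
double defaultScale = 1.0 / Math.Sqrt(headDim); // 0.125
double customScale = 0.5;
Console.WriteLine(customScale / defaultScale);  // 4
```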