diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 3e9c01a83..d9152664f 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -3,6 +3,10 @@
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
# NuGet Version 0.105.1
+__Breaking Changes__:
+
+The signature of `torch.nn.functional.scaled_dot_product_attention` has changed: the misspelled `is_casual` argument has been renamed to `is_causal`. Call sites that pass this argument by name must be updated.
+
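+For example, a call that previously passed the argument by name (tensor shapes below are illustrative only) now reads:
+
+```csharp
+using static TorchSharp.torch;
+
+var q = rand(2, 8, 16, 64);
+var k = rand(2, 8, 16, 64);
+var v = rand(2, 8, 16, 64);
+
+// 0.105.0 and earlier: nn.functional.scaled_dot_product_attention(q, k, v, is_casual: true);
+var y = nn.functional.scaled_dot_product_attention(q, k, v, is_causal: true);
+```
+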
__Bug Fixes__:
#1426 Sequential.eval() does not put model into eval mode
@@ -16,6 +20,7 @@ __API Changes__:
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.
Returning an input tensor has been corrected, is now `alias()`.
Add `torchvision.transforms.Resize` `interpolation` and `antialias`.
+Add optional `scale` and `enable_gqa` arguments to `torch.nn.functional.scaled_dot_product_attention`.
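+
+A minimal sketch of the new arguments (shapes and head counts are illustrative; as in PyTorch, grouped query attention expects the query head count to be a multiple of the key/value head count):
+
+```csharp
+using static TorchSharp.torch;
+
+var query = rand(2, 8, 16, 64);   // (N, H, L, E): 8 query heads
+var key   = rand(2, 2, 16, 64);   // 2 key/value heads
+var value = rand(2, 2, 16, 64);
+
+// Override the default 1/sqrt(E) scaling applied before softmax.
+var scaled = nn.functional.scaled_dot_product_attention(query, query, query, scale: 0.5);
+
+// Grouped query attention: 8 query heads share 2 key/value heads.
+var grouped = nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: true);
+```
+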
# NuGet Version 0.105.0
diff --git a/src/Native/LibTorchSharp/THSNN.cpp b/src/Native/LibTorchSharp/THSNN.cpp
index 941399e62..0a56313d5 100644
--- a/src/Native/LibTorchSharp/THSNN.cpp
+++ b/src/Native/LibTorchSharp/THSNN.cpp
@@ -1064,9 +1064,10 @@ Tensor THSNN_unfold(const Tensor input, const int64_t kernel1, const int64_t ker
CATCH_TENSOR(torch::nn::functional::unfold(*input, opts));
}
-Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual)
+Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa)
{
auto mask = attention_mask == nullptr ? c10::nullopt : c10::optional(*attention_mask);
- CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, casual));
+ auto scl = (scale == nullptr) ? c10::nullopt : c10::optional(*scale);
+ CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, is_causal, scl, enable_gqa));
}
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h
index 3dab43f90..870a45237 100644
--- a/src/Native/LibTorchSharp/THSNN.h
+++ b/src/Native/LibTorchSharp/THSNN.h
@@ -228,7 +228,7 @@ EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor i
EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim);
-EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual);
+EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa);
// Initializers
diff --git a/src/TorchSharp/NN/Transformer.cs b/src/TorchSharp/NN/Transformer.cs
index d69ff96de..5fbf71074 100644
--- a/src/TorchSharp/NN/Transformer.cs
+++ b/src/TorchSharp/NN/Transformer.cs
@@ -107,15 +107,26 @@ public static partial class functional
/// A float mask of the same type as query, key, value that is added to the attention score.
/// </param>
/// <param name="p">Dropout probability</param>
- /// <param name="is_casual">If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.</param>
+ /// <param name="is_causal">If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.</param>
+ /// <param name="scale">Scaling factor applied prior to softmax. If null, 1/sqrt(E) is used.</param>
+ /// <param name="enable_gqa">If true, enables grouped query attention (GQA); the number of query heads must be a multiple of the number of key/value heads.</param>
/// <returns></returns>
- public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, [MarshalAs(UnmanagedType.U1)] bool is_casual = false)
+ public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, [MarshalAs(UnmanagedType.U1)] bool is_causal = false, double? scale = null, bool enable_gqa = false)
{
if (p < 0) throw new ArgumentException("Dropout probability must be greater than or equal to zero.");
- if (is_casual && attn_mask is not null) throw new ArgumentException("Casual attention masking cannot pass a mask.");
- var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_casual);
- if (res == IntPtr.Zero) { torch.CheckForErrors(); }
- return new Tensor(res);
+ if (is_causal && attn_mask is not null) throw new ArgumentException("Causal attention masking cannot be combined with an explicit attention mask.");
+ if (query.dim() < 2 || key.dim() < 2 || value.dim() < 2) throw new ArgumentException("Query, key, and value must have at least 2 dimensions.");
+ // Without grouped query attention, the head dimension (dim 1 of (N, H, L, E) inputs) must match across query, key, and value.
+ if (!enable_gqa && (query.size(1) != key.size(1) || query.size(1) != value.size(1))) throw new InvalidOperationException("Query, key, and value must have the same number of heads when grouped query attention is not enabled.");
+
+ // Pin a one-element array (or null) so the optional scale can be passed to native code as a double*.
+ var _scale = scale.HasValue ? new double[] { scale.Value } : null;
+
+ unsafe {
+ fixed (double* scalePtr = _scale) {
+ var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_causal, (IntPtr)scalePtr, enable_gqa);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Tensor(res);
+ }
+ }
}
}
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
index 2b97b61eb..5f1b83c0c 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
@@ -552,7 +552,7 @@ internal static extern IntPtr THSNN_custom_module(
internal static extern IntPtr THSNN_ConvTranspose2d_ctor_1(long inputChannel, long outputChannel, long kernelSizeX, long kernelSizeY, long strideX, long strideY, long paddingX, long paddingY, long outputPaddingX, long outputPaddingY, long dilationX, long dilationY, long paddingMode, long groups, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool casual);
+ internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool is_causal, IntPtr scale, [MarshalAs(UnmanagedType.U1)] bool enable_gqa);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_Softshrink_forward(torch.nn.Module.HType module, IntPtr tensor);
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index 86e339d7f..9cdd5ae6c 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -5307,7 +5307,43 @@ public void TestScaledDotProductWithMask()
Assert.Equal(query.shape, x.shape);
Assert.Equal(value, x);
- Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_casual: true));
+ Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_causal: true));
+ }
+ [Fact]
+ public void TestScaledDotProductWithScale()
+ {
+
+ var query = torch.rand(32, 8, 128, 64) * 0.25;
+ var key = torch.rand(32, 8, 128, 64) * 0.5;
+ var value = torch.rand(32, 8, 128, 64) * 0.125;
+ var customScale = 0.5;
+
+ var defaultOutput = torch.nn.functional.scaled_dot_product_attention(query, key, value);
+ var withCustomScale = torch.nn.functional.scaled_dot_product_attention(query, key, value, scale: customScale);
+
+ Assert.Equal(query.shape, withCustomScale.shape);
+ Assert.False(defaultOutput.allclose(withCustomScale, rtol: 1e-5, atol: 1e-5));
+ }
+
+ [Fact]
+ public void TestScaledDotProductWithGQA()
+ {
+ var batchSize = 2;
+ var queryHeads = 8;
+ var kvHeads = 2; // For GQA, the number of query heads must be an integer multiple of the number of key/value heads
+ var seqLen = 16;
+ var headDim = 64;
+
+ var query = torch.ones(batchSize, queryHeads, seqLen, headDim) * 0.25;
+ var key = torch.ones(batchSize, kvHeads, seqLen, headDim) * 0.5;
+ var value = torch.ones(batchSize, kvHeads, seqLen, headDim) * 0.125;
+
+ Assert.Throws<InvalidOperationException>(() =>
+ torch.nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: false));
+
+ var output = torch.nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: true);
+
+ Assert.Equal(query.shape, output.shape);
}
[Fact]