diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 3e9c01a83..d9152664f 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -3,6 +3,10 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the top.
 # NuGet Version 0.105.1
 
+__Breaking Changes__:
+
+The signature of `torch.nn.functional.scaled_dot_product_attention` has changed: the misspelled `is_casual` argument has been renamed to `is_causal`.
+
 __Bug Fixes__:
 
 #1426 Sequential.eval() does not put model into eval mode
 
@@ -16,6 +20,7 @@ __API Changes__:
 `torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.
 Returning an input tensor has been corrected, is now `alias()`.
 Add `torchvision.transforms.Resize` `interpolation` and `antialias`.
+Add optional `scale` and `enable_gqa` arguments to `torch.nn.functional.scaled_dot_product_attention`.
 # NuGet Version 0.105.0
diff --git a/src/Native/LibTorchSharp/THSNN.cpp b/src/Native/LibTorchSharp/THSNN.cpp
index 941399e62..0a56313d5 100644
--- a/src/Native/LibTorchSharp/THSNN.cpp
+++ b/src/Native/LibTorchSharp/THSNN.cpp
@@ -1064,9 +1064,10 @@ Tensor THSNN_unfold(const Tensor input, const int64_t kernel1, const int64_t ker
     CATCH_TENSOR(torch::nn::functional::unfold(*input, opts));
 }
 
-Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual)
+Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa)
 {
     auto mask = attention_mask == nullptr ? c10::nullopt : c10::optional<at::Tensor>(*attention_mask);
+    auto scl = (scale == nullptr) ? c10::nullopt : c10::optional<double>(*scale);
-    CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, casual));
+    CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, is_causal, scl, enable_gqa));
 }
\ No newline at end of file
diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h
index 3dab43f90..870a45237 100644
--- a/src/Native/LibTorchSharp/THSNN.h
+++ b/src/Native/LibTorchSharp/THSNN.h
@@ -228,7 +228,7 @@ EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor i
 EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim);
 
-EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool casual);
+EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa);
 
 // Initializers
diff --git a/src/TorchSharp/NN/Transformer.cs b/src/TorchSharp/NN/Transformer.cs
index d69ff96de..5fbf71074 100644
--- a/src/TorchSharp/NN/Transformer.cs
+++ b/src/TorchSharp/NN/Transformer.cs
@@ -107,15 +107,26 @@ public static partial class functional
             /// A float mask of the same type as query, key, value that is added to the attention score.
             /// </param>
             /// <param name="p">Dropout probability</param>
-            /// <param name="is_casual">If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.</param>
+            /// <param name="is_causal">If true, assumes causal attention masking and errors if both attn_mask and is_causal are set.</param>
+            /// <param name="scale">Scaling factor applied prior to softmax. If null, 1/sqrt(E) is used.</param>
+            /// <param name="enable_gqa">If true, enable Group Query Attention.</param>
             /// <returns></returns>
-            public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, [MarshalAs(UnmanagedType.U1)] bool is_casual = false)
+            public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, [MarshalAs(UnmanagedType.U1)] bool is_causal = false, double? scale = null, bool enable_gqa = false)
             {
                 if (p < 0) throw new ArgumentException("Dropout probability must be greater than or equal to zero.");
-                if (is_casual && attn_mask is not null) throw new ArgumentException("Casual attention masking cannot pass a mask.");
-                var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_casual);
-                if (res == IntPtr.Zero) { torch.CheckForErrors(); }
-                return new Tensor(res);
+                if (is_causal && attn_mask is not null) throw new ArgumentException("Causal attention masking cannot pass a mask.");
+                if (query.dim() < 2 || key.dim() < 2 || value.dim() < 2) throw new ArgumentException("Query, key, and value must have at least 2 dimensions.");
+                if (!enable_gqa && (query.size(1) != key.size(1) || query.size(1) != value.size(1))) throw new InvalidOperationException("Query and key/value heads must be equal when Group Query Attention is not enabled.");
+
+                var _scale = scale.HasValue ? new double[] { scale.Value } : null;
+
+                unsafe {
+                    fixed (double* scalePtr = _scale) {
+                        var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_causal, (IntPtr)scalePtr, enable_gqa);
+                        if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+                        return new Tensor(res);
+                    }
+                }
             }
         }
     }
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
index 2b97b61eb..5f1b83c0c 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
@@ -552,7 +552,7 @@ internal static extern IntPtr THSNN_custom_module(
         internal static extern IntPtr THSNN_ConvTranspose2d_ctor_1(long inputChannel, long outputChannel, long kernelSizeX, long kernelSizeY, long strideX, long strideY, long paddingX, long paddingY, long outputPaddingX, long outputPaddingY, long dilationX, long dilationY, long paddingMode, long groups, [MarshalAs(UnmanagedType.U1)] bool bias, out IntPtr pBoxedModule);
 
         [DllImport("LibTorchSharp")]
-        internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool casual);
+        internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool is_causal, IntPtr scale, [MarshalAs(UnmanagedType.U1)] bool enable_gqa);
 
         [DllImport("LibTorchSharp")]
         internal static extern IntPtr THSNN_Softshrink_forward(torch.nn.Module.HType module, IntPtr tensor);
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index 86e339d7f..9cdd5ae6c 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -5307,7 +5307,43 @@ public void TestScaledDotProductWithMask()
             Assert.Equal(query.shape, x.shape);
             Assert.Equal(value, x);
 
-            Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_casual: true));
+            Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_causal: true));
+        }
+
+        [Fact]
+        public void TestScaledDotProductWithScale()
+        {
+
+            var query = torch.rand(32, 8, 128, 64) * 0.25;
+            var key = torch.rand(32, 8, 128, 64) * 0.5;
+            var value = torch.rand(32, 8, 128, 64) * 0.125;
+            var customScale = 0.5;
+
+            var defaultOutput = torch.nn.functional.scaled_dot_product_attention(query, key, value);
+            var withCustomScale = torch.nn.functional.scaled_dot_product_attention(query, key, value, scale: customScale);
+
+            Assert.Equal(query.shape, withCustomScale.shape);
+            Assert.False(torch.allclose(defaultOutput, withCustomScale, rtol: 1e-5, atol: 1e-5));
+        }
+
+        [Fact]
+        public void TestScaledDotProductWithGQA()
+        {
+            var batchSize = 2;
+            var queryHeads = 8;
+            var kvHeads = 2; // Key/value heads must evenly divide the query heads for GQA
+            var seqLen = 16;
+            var headDim = 64;
+
+            var query = torch.ones(batchSize, queryHeads, seqLen, headDim) * 0.25;
+            var key = torch.ones(batchSize, kvHeads, seqLen, headDim) * 0.5;
+            var value = torch.ones(batchSize, kvHeads, seqLen, headDim) * 0.125;
+
+            Assert.Throws<InvalidOperationException>(() =>
+                torch.nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: false));
+
+            var output = torch.nn.functional.scaled_dot_product_attention(query, key, value, enable_gqa: true);
+
+            Assert.Equal(query.shape, output.shape);
         }
 
         [Fact]
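For reference (not part of the patch), a minimal usage sketch of the API as extended by this change. The tensor shapes, the scale value, and the surrounding setup are illustrative assumptions, not taken from the PR; as in PyTorch, `enable_gqa: true` expects the number of query heads to be an integer multiple of the key/value heads.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Illustrative shapes only: (batch, heads, seq_len, head_dim).
var query = rand(2, 8, 16, 64);
var key   = rand(2, 2, 16, 64);   // 2 key/value heads shared across 8 query heads
var value = rand(2, 2, 16, 64);

// Causal attention with an explicit softmax scale and grouped-query attention enabled.
var output = nn.functional.scaled_dot_product_attention(
    query, key, value,
    is_causal: true,
    scale: 0.125,        // overrides the default 1/sqrt(head_dim)
    enable_gqa: true);

Console.WriteLine(string.Join(",", output.shape));   // 2,8,16,64 (follows the query shape)
```

Omitting `scale` and `enable_gqa` keeps the previous behavior, so existing callers are only affected by the `is_casual` to `is_causal` rename noted in the release notes.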