2 changes: 0 additions & 2 deletions LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -40,7 +40,6 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
BatchSize = 512,
UBatchSize = 512,
FlashAttention = true,
UseMemorymap = true,
PoolingType = LLamaPoolingType.Mean,
};
@@ -68,7 +67,6 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
BatchSize = 512,
UBatchSize = 512,
FlashAttention = true,
UseMemorymap = true,
PoolingType = LLamaPoolingType.Mean,
};
2 changes: 0 additions & 2 deletions LLama.KernelMemory/LlamaSharpTextGenerator.cs
@@ -38,7 +38,6 @@ public LlamaSharpTextGenerator(LLamaSharpConfig config)
SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
BatchSize = 512,
UBatchSize = 512,
FlashAttention = true,
UseMemorymap = true
};
_weights = LLamaWeights.LoadFromFile(@params);
@@ -66,7 +65,6 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St
SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
BatchSize = 512,
UBatchSize = 512,
FlashAttention = true,
UseMemorymap = true
};
_executor = executor ?? new StatelessExecutor(_weights, @params);
4 changes: 2 additions & 2 deletions LLama.Unittest/LLamaContextTests.cs
@@ -13,7 +13,7 @@ public LLamaContextTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath2)
{
ContextSize = 128,
ContextSize = 512,
Contributor: any reason for the changes in tests?

Author: @martindevans fixed that test here

BatchSize = 8,
UBatchSize = 8,
SeqMax = 1,
@@ -33,7 +33,7 @@ public void Dispose()
[Fact]
public void CheckProperties()
{
Assert.Equal(128u, _context.ContextSize);
Assert.Equal(512u, _context.ContextSize);
Assert.Equal(960, _context.EmbeddingSize);
Assert.Equal(49152, _context.Vocab.Count);
}
4 changes: 2 additions & 2 deletions LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
@@ -30,7 +30,7 @@ public LLamaContextWithCustomLoggerTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath2)
{
ContextSize = 128,
ContextSize = 512,
GpuLayerCount = Constants.CIGpuLayerCount,
};

@@ -55,7 +55,7 @@ public void Dispose()
[Fact]
public void CheckProperties()
{
Assert.Equal(128u, _context.ContextSize);
Assert.Equal(512u, _context.ContextSize);
Assert.Equal(960, _context.EmbeddingSize);
Assert.Equal(49152, _context.Vocab.Count);
}
2 changes: 1 addition & 1 deletion LLama.Unittest/LLamaRerankerTests.cs
@@ -18,9 +18,9 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
var @params = new ModelParams(Constants.RerankingModelPath)
{
ContextSize = 0,
SeqMax = 1,
PoolingType = LLamaPoolingType.Rank,
GpuLayerCount = Constants.CIGpuLayerCount,

};
using var weights = LLamaWeights.LoadFromFile(@params);
_reranker = new LLamaReranker(weights, @params);
3 changes: 2 additions & 1 deletion LLama.Unittest/SamplingTests.cs
@@ -25,6 +25,7 @@ public SamplingTests(ITestOutputHelper testOutputHelper)
_params = new ModelParams(Constants.GenerativeModelPath2) {
ContextSize = 200,
BatchSize = 200,
SeqMax = 4,
Contributor: I don't think specifying SeqMax is necessary now that kv_unified is enabled.

Author: It is due to the regression in llama.cpp; without this, the batched sampling and reranker tests would fail. Look at this.

Contributor: Can you try without? I think it isn't needed.

GpuLayerCount = Constants.CIGpuLayerCount,
};
_model = LLamaWeights.LoadFromFile(_params);
@@ -104,7 +105,7 @@ public void BatchedSampling()
}
}

// Add " repeat" and test whether next tokens will be "this phrase forever.".
// Add " repeat" and test whether next tokens will be "this phrase forever."
for (int i = 0; i < 4; i++)
{
for (int b = 0; b < batch_count; b++)
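A minimal sketch of the batched-sampling setup the thread above is discussing (values are illustrative; whether SeqMax is still required once the unified KV cache is enabled is exactly the open question):

    using LLama;
    using LLama.Common;

    // Sketch only: four parallel sequences sharing one context.
    var sketchParams = new ModelParams("path/to/model.gguf")   // placeholder path
    {
        ContextSize = 200,
        BatchSize = 200,
        SeqMax = 4,   // one slot per parallel sequence; possibly redundant with a unified KV cache
    };
    using var sketchWeights = LLamaWeights.LoadFromFile(sketchParams);
    using var sketchContext = sketchWeights.CreateContext(sketchParams);
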
2 changes: 1 addition & 1 deletion LLama.Web/Common/ModelOptions.cs
@@ -102,7 +102,7 @@ public class ModelOptions
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public bool FlashAttention { get; set; }
public bool? FlashAttention { get; set; }

/// <inheritdoc />
public Encoding Encoding { get; set; } = Encoding.UTF8;
4 changes: 2 additions & 2 deletions LLama/Abstractions/IContextParams.cs
@@ -106,8 +106,8 @@ public interface IContextParams
/// <summary>
/// Whether to use flash attention
/// </summary>
bool FlashAttention { get; }

bool? FlashAttention { get; }
/// <summary>
/// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt;= 0 to disable (default)
/// </summary>
7 changes: 4 additions & 3 deletions LLama/Common/ModelParams.cs
@@ -1,3 +1,4 @@
using System;
using LLama.Abstractions;
using System.Text;
using System.Text.Json.Serialization;
@@ -95,12 +96,12 @@ public record ModelParams

/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />

public bool FlashAttention { get; set; }
public bool? FlashAttention { get; set; }

/// <inheritdoc />
[Obsolete]
public float? DefragThreshold { get; set; }

/// <inheritdoc />
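Caller-facing note on the nullable change, as a sketch (model path is a placeholder): code that previously assigned the bool keeps compiling, and leaving the property unset now means "let the backend decide".

    using LLama.Common;

    var flashOn   = new ModelParams("model.gguf") { FlashAttention = true };   // force flash attention on
    var flashOff  = new ModelParams("model.gguf") { FlashAttention = false };  // force it off
    var flashAuto = new ModelParams("model.gguf");                             // left null => auto
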
8 changes: 7 additions & 1 deletion LLama/Extensions/IContextParamsExtensions.cs
@@ -49,9 +49,15 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = [email protected];
result.flash_attention = @params.FlashAttention;
Contributor (@Lyrcaxis), Oct 10, 2025:

Instead of completely removing the option to use flash attention, can you pass it through to llama_flash_attn_type?
I would suggest keeping the previous FlashAttention bool as it was -- but turning it nullable, so null == Auto.

    result.llama_flash_attn_type = @params.FlashAttention switch
    {
        true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
        false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
        null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
    };
    result.kv_unified = true; // if we wanna hardcode it here instead of in `Default()`.

result.llama_pooling_type = @params.PoolingType;
result.attention_type = @params.AttentionType;
result.llama_flash_attn_type = @params.FlashAttention switch
{
true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
};
result.kv_unified = true;

result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
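A quick end-to-end check of the new mapping (a sketch, assuming the extension method above is reachable from user code and the model path is a placeholder):

    using LLama.Common;
    using LLama.Extensions;
    using LLama.Native;

    var checkParams = new ModelParams("model.gguf");            // FlashAttention left null
    checkParams.ToLlamaContextParams(out LLamaContextParams native);
    // Expected per the hunk above: native.llama_flash_attn_type is
    // LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO, and kv_unified is forced to true.
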
2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
@@ -57,7 +57,7 @@
</ItemGroup>

<PropertyGroup>
<BinaryReleaseId>11dd5a44eb180e</BinaryReleaseId>
<BinaryReleaseId>86587da</BinaryReleaseId>
</PropertyGroup>

<PropertyGroup>
5 changes: 5 additions & 0 deletions LLama/Native/LLamaContextParams.cs
@@ -64,6 +64,11 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
/// </summary>
public LLamaAttentionType attention_type;

/// <summary>
/// when to enable Flash Attention
/// </summary>
public LLamaFlashAttentionType llama_flash_attn_type;

/// <summary>
/// RoPE base frequency, 0 = from model
19 changes: 19 additions & 0 deletions LLama/Native/LLamaFlashAttentionType.cs
@@ -0,0 +1,19 @@
namespace LLama.Native;
/// <summary>
/// flash_attn_type
/// </summary>
public enum LLamaFlashAttentionType
{
/// <summary>
/// attention type auto
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1,
/// <summary>
/// attention disabled
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0,
/// <summary>
/// attention enabled
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1,
}
7 changes: 6 additions & 1 deletion LLama/Native/LLamaFtype.cs
@@ -201,7 +201,12 @@ public enum LLamaFtype
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,


/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,

/// <summary>
/// File type was not specified
/// </summary>
11 changes: 10 additions & 1 deletion LLama/Native/LLamaModelParams.cs
@@ -100,7 +100,16 @@ public bool check_tensors
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;


/// <summary>
/// use extra buffer types (used for weight repacking)
/// </summary>
public bool use_extra_bufts
{
readonly get => Convert.ToBoolean(_use_extra_bufts);
set => _use_extra_bufts = Convert.ToSByte(value);
}
private sbyte _use_extra_bufts;
/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
74 changes: 61 additions & 13 deletions LLama/Native/NativeApi.cs
@@ -175,12 +175,15 @@ public static void llama_empty_call()
/// <param name="buf">A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)</param>
/// <param name="length">The size of the allocated buffer</param>
/// <returns>The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.</returns>
public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
[MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
{
return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,
EntryPoint = "llama_chat_apply_template")]
static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
[MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
}

/// <summary>
@@ -215,7 +218,8 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
/// <param name="lstrip">User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')</param>
/// <param name="special">If true, special tokens are rendered in the output</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span<byte> buffer, int lstrip, bool special)
public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken,
Span<byte> buffer, int lstrip, bool special)
{
// Handle invalid tokens
if ((int)llamaToken < 0)
@@ -225,12 +229,14 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
{
fixed (byte* bufferPtr = buffer)
{
return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip,
special);
}
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken,
byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
}

/// <summary>
@@ -247,7 +253,9 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
/// Returns a negative number on failure - the number of tokens that would have been returned. Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special);
internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len,
LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special,
[MarshalAs(UnmanagedType.U1)] bool parse_special);

/// <summary>
/// Convert the provided tokens into text (inverse of llama_tokenize()).
@@ -261,7 +269,8 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
/// <param name="unparseSpecial">unparse_special If true, special tokens are rendered in the output.</param>
/// <returns>Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);
internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens,
byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);

/// <summary>
/// Register a callback to receive llama log messages
@@ -272,7 +281,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
{
NativeLogConfig.llama_log_set(logCallback);
}

/// <summary>
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
@@ -311,7 +320,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="il_end"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end);
public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len,
int n_embd, int il_start, int il_end);

/// <summary>
/// Build a split GGUF final path for this chunk.
@@ -324,7 +334,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="split_count"></param>
/// <returns>Returns the split_path length.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no,
int split_count);

/// <summary>
/// Extract the path prefix from the split_path if and only if the split_no and split_count match.
@@ -337,7 +348,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="split_count"></param>
/// <returns>Returns the split_prefix length.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no,
int split_count);

//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
@@ -380,5 +392,41 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <returns>Name of the buffer type</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);

/// <summary>
/// Get the exact size in bytes needed to copy the state of a single sequence
/// </summary>
/// <param name="ctx">Pointer to the native llama context</param>
/// <param name="seq_id">Sequence id whose state should be measured</param>
/// <param name="flags">State flags (0 for default behaviour)</param>
/// <returns>Required buffer size in bytes</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags);

/// <summary>
/// Copy the state of a single sequence into the supplied buffer
/// </summary>
/// <param name="ctx">Pointer to the native llama context</param>
/// <param name="dst">Destination buffer</param>
/// <param name="size">Size of the destination buffer in bytes</param>
/// <param name="seq_id">Sequence id whose state should be copied</param>
/// <param name="flags">State flags (0 for default behaviour)</param>
/// <returns>Number of bytes written to the buffer</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size,
int seq_id, uint flags);

/// <summary>
/// Load a previously saved sequence state from the buffer into the context
/// </summary>
/// <param name="ctx">Pointer to the native llama context</param>
/// <param name="src">Buffer containing the saved sequence state</param>
/// <param name="size">Size of the buffer in bytes</param>
/// <param name="dest_seq_id">Sequence id to load the state into</param>
/// <param name="flags">State flags (0 for default behaviour)</param>
/// <returns>Number of bytes read from the buffer (0 on failure)</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id,
uint flags);
}
}
}
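The three new *_ext state functions form a measure/save/restore set; a rough usage sketch (flags of 0 assumed as the default, raw context pointer obtained elsewhere, error handling omitted):

    using System;
    using LLama.Native;

    static byte[] SaveSequenceState(IntPtr ctx, int seqId, uint flags = 0)
    {
        // Measure, allocate, then copy the sequence state out of the context.
        var size = NativeApi.llama_state_seq_get_size_ext(ctx, seqId, flags);
        var buffer = new byte[checked((int)(ulong)size)];
        NativeApi.llama_state_seq_get_data_ext(ctx, buffer, size, seqId, flags);
        return buffer;
    }

    static void LoadSequenceState(IntPtr ctx, byte[] state, int destSeqId, uint flags = 0)
    {
        // Load a previously captured state into a sequence of (possibly another) context.
        NativeApi.llama_state_seq_set_data_ext(ctx, state, (UIntPtr)(uint)state.Length, destSeqId, flags);
    }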