Add support for gemma 3n #1248
base: master
Changes from all commits
d80d038
da01789
53c8c56
20bcf74
48f109a
424a736
ff6ea95
0990be3
In the `SamplingTests` constructor:

```diff
@@ -25,6 +25,7 @@ public SamplingTests(ITestOutputHelper testOutputHelper)
     _params = new ModelParams(Constants.GenerativeModelPath2) {
         ContextSize = 200,
         BatchSize = 200,
+        SeqMax = 4,
         GpuLayerCount = Constants.CIGpuLayerCount,
     };
     _model = LLamaWeights.LoadFromFile(_params);
```

Review thread on the added `SeqMax = 4` line:

> I don't think specifying SeqMax is necessary now that …

> It is due to the regression in llama.cpp; without this, the batched sampling and reranker tests would fail.

> Can you try without? I think it isn't needed.
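For context, here is a minimal sketch (not part of the PR; the model path is a placeholder and the property names simply mirror the test constructor above) of how these parameters relate to the batched-sampling test, which decodes several sequences in parallel:

```csharp
using LLama;
using LLama.Common;

// Sketch only: "model.gguf" is a placeholder path.
var @params = new ModelParams("model.gguf")
{
    ContextSize = 200,
    BatchSize = 200,
    SeqMax = 4, // the batched sampling test runs 4 sequences in parallel
};

using var weights = LLamaWeights.LoadFromFile(@params);
using var context = weights.CreateContext(@params);
// Each of the 4 sequences gets its own sequence id (0..3) when tokens are
// added to a batch, so n_seq_max must be at least 4 for decoding to succeed.
```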
Later in `BatchedSampling`:

```diff
@@ -104,7 +105,7 @@ public void BatchedSampling()
     }
 }

-// Add " repeat" and test whether next tokens will be "this phrase forever.".
+// Add " repeat" and test whether next tokens will be "this phrase forever."
 for (int i = 0; i < 4; i++)
 {
     for (int b = 0; b < batch_count; b++)
```
In `ToLlamaContextParams`:

```diff
@@ -49,9 +49,15 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
     result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
     result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
     result.offload_kqv = !@params.NoKqvOffload;
-    result.flash_attention = @params.FlashAttention;
     result.llama_pooling_type = @params.PoolingType;
     result.attention_type = @params.AttentionType;
+    result.llama_flash_attn_type = @params.FlashAttention switch
+    {
+        true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
+        false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
+        null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
+    };
+    result.kv_unified = true;

     result.n_threads = Threads(@params.Threads);
     result.n_threads_batch = Threads(@params.BatchThreads);
```

Review comment on the removed `result.flash_attention` line:

> Instead of completely removing the option to use flash attention, can you pass it through as:
>
> ```csharp
> result.llama_flash_attn_type = @params.FlashAttention switch
> {
>     true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
>     false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
>     null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
> };
> result.kv_unified = true; // if we wanna hardcode it here instead of in `Default()`.
> ```
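A brief usage sketch (not from the PR; it assumes `FlashAttention` becomes a nullable bool on the params type, as the `null` arm of the switch above implies, and uses a placeholder model path):

```csharp
using LLama.Common;

// Assumption: FlashAttention is bool? after this PR, so leaving it null
// maps to LLAMA_FLASH_ATTENTION_TYPE_AUTO and lets llama.cpp decide.
var @params = new ModelParams("model.gguf") // placeholder path
{
    FlashAttention = null,     // AUTO
    // FlashAttention = true,  // force-enable flash attention
    // FlashAttention = false, // force-disable flash attention
};
```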
New file adding the `LLamaFlashAttentionType` enum in `LLama.Native`:

```diff
@@ -0,0 +1,19 @@
+namespace LLama.Native;
+/// <summary>
+/// flash_attn_type
+/// </summary>
+public enum LLamaFlashAttentionType
+{
+    /// <summary>
+    /// attention type auto
+    /// </summary>
+    LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1,
+    /// <summary>
+    /// attention disabled
+    /// </summary>
+    LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0,
+    /// <summary>
+    /// attention enabled
+    /// </summary>
+    LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1,
+}
```
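For illustration only (not part of the PR; the extension class and method names here are hypothetical), a small helper sketch showing how the tri-state enum round-trips with the `bool?` exposed on the context params:

```csharp
namespace LLama.Native;

/// <summary>
/// Hypothetical helpers (illustration only) for converting between the
/// tri-state flash-attention enum and a nullable bool.
/// </summary>
public static class LLamaFlashAttentionTypeExtensions
{
    /// <summary>ENABLED => true, DISABLED => false, AUTO => null.</summary>
    public static bool? ToNullableBool(this LLamaFlashAttentionType type) => type switch
    {
        LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED => true,
        LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED => false,
        _ => null
    };

    /// <summary>true => ENABLED, false => DISABLED, null => AUTO.</summary>
    public static LLamaFlashAttentionType FromNullableBool(bool? value) => value switch
    {
        true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
        false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
        null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
    };
}
```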
Review thread:

> any reason for the changes in tests?

> @martindevans fixed that test here