conditioner.hpp (53 changes: 42 additions & 11 deletions)
@@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
                 bool enable_vision = false)
         : version(version) {
         LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
             arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
-        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
+        } else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
             arch = LLM::LLMArch::QWEN3;
         }
         if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1708,6 +1708,9 @@ struct LLMEmbedder : public Conditioner {
         int prompt_template_encode_start_idx = 34;
         int max_length = 0;
         std::set<int> out_layers;
+        std::vector<int> tokens;
+        std::vector<float> weights;
+        std::vector<float> mask;
         if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;
@@ -1771,7 +1774,7 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2) {
             prompt_template_encode_start_idx = 0;
             out_layers = {10, 20, 30};
 
@@ -1793,17 +1796,28 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
+        } else if (version == VERSION_FLUX2_KLEIN) {
             prompt_template_encode_start_idx = 0;
-            out_layers = {10, 20, 30};
+            max_length = 512;
+            out_layers = {9, 18, 27};
 
-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+            prompt = "<|im_start|>user\n";
 
             prompt_attn_range.first = static_cast<int>(prompt.size());
             prompt += conditioner_params.text;
             prompt_attn_range.second = static_cast<int>(prompt.size());
 
-            prompt += "[/INST]";
+            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+
+            auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
+            tokens = std::get<0>(tokens_and_weights);
+            weights = std::get<1>(tokens_and_weights);
+
+            mask.insert(mask.end(), tokens.size(), 1.f);
+            if (tokens.size() < max_length) {
+                mask.insert(mask.end(), max_length - tokens.size(), 0.f);
+                tokenizer->pad_tokens(tokens, weights, max_length, true);
+            }
         } else if (version == VERSION_OVIS_IMAGE) {
             prompt_template_encode_start_idx = 28;
             max_length = prompt_template_encode_start_idx + 256;
@@ -1827,17 +1841,34 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}

auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
if (tokens.empty()) {
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
tokens = std::get<0>(tokens_and_weights);
weights = std::get<1>(tokens_and_weights);
}

int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]

auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);

ggml_tensor* attention_mask = nullptr;
if (!mask.empty()) {
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (mask[i0] == 0.f) {
value = -INFINITY;
} else if (i0 > i1) {
value = -INFINITY;
}
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
});
}

llm->compute(n_threads,
input_ids,
attention_mask,
image_embeds,
out_layers,
&hidden_states,
@@ -1861,7 +1892,7 @@ struct LLMEmbedder : public Conditioner {
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
         int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
+        if (version == VERSION_FLUX2) {
            min_length = 512;
         }
 
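Note on the conditioner.hpp changes above: the FLUX.2 Klein branch tokenizes inside the branch so the number of real tokens is known before pad_tokens() grows the sequence to max_length = 512. The per-token mask (1.f for real tokens, 0.f for padding) is then expanded into a square additive bias in which a key position is blanked with -INFINITY if it is padding or lies after the query position. A minimal standalone sketch of that expansion, in plain C++ rather than the ggml_ext_tensor_iter form used above:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Expand a per-token padding mask (1 = real token, 0 = padding) into a
    // square additive attention bias: entry (key i0, query i1) stays 0 when
    // the key is a real token at or before the query, and becomes -INFINITY
    // when the key is padding or lies in the query's future (causal mask).
    std::vector<float> build_additive_mask(const std::vector<float>& pad_mask) {
        const size_t n = pad_mask.size();
        std::vector<float> bias(n * n, 0.f);
        for (size_t i1 = 0; i1 < n; i1++) {      // query position
            for (size_t i0 = 0; i0 < n; i0++) {  // key position
                if (pad_mask[i0] == 0.f || i0 > i1) {
                    bias[i1 * n + i0] = -INFINITY;
                }
            }
        }
        return bias;
    }

    int main() {
        // Three real tokens padded to length five.
        const std::vector<float> pad_mask = {1.f, 1.f, 1.f, 0.f, 0.f};
        const auto bias = build_additive_mask(pad_mask);
        for (size_t i1 = 0; i1 < pad_mask.size(); i1++) {
            for (size_t i0 = 0; i0 < pad_mask.size(); i0++) {
                std::printf("%s", std::isinf(bias[i1 * pad_mask.size() + i0]) ? " -inf" : "    0");
            }
            std::printf("\n");
        }
        return 0;
    }

The softmax after Q*K^T turns each -INFINITY score into a zero weight, so padded positions contribute nothing to the hidden states later sliced out for conditioning.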
flux.hpp (29 changes: 21 additions & 8 deletions)
@@ -1288,13 +1288,9 @@ namespace Flux {
         } else if (version == VERSION_OVIS_IMAGE) {
             flux_params.semantic_txt_norm = true;
             flux_params.use_yak_mlp = true;
-            flux_params.context_in_dim = 2048;
             flux_params.vec_in_dim = 0;
         } else if (sd_version_is_flux2(version)) {
-            flux_params.context_in_dim = 15360;
             flux_params.in_channels = 128;
-            flux_params.hidden_size = 6144;
-            flux_params.num_heads = 48;
             flux_params.patch_size = 1;
             flux_params.out_channels = 128;
             flux_params.mlp_ratio = 3.f;
@@ -1307,12 +1303,12 @@ namespace Flux {
             flux_params.ref_index_scale = 10.f;
             flux_params.use_mlp_silu_act = true;
         }
+        int64_t head_dim = 0;
         for (auto pair : tensor_storage_map) {
             std::string tensor_name = pair.first;
             if (!starts_with(tensor_name, prefix))
                 continue;
             if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
-                // not schnell
                 flux_params.guidance_embed = true;
             }
             if (tensor_name.find("__x0__") != std::string::npos) {
@@ -1344,13 +1340,30 @@ namespace Flux {
                    flux_params.depth_single_blocks = block_depth + 1;
                 }
             }
+            if (ends_with(tensor_name, "txt_in.weight")) {
+                flux_params.context_in_dim = pair.second.ne[0];
+                flux_params.hidden_size = pair.second.ne[1];
+            }
+            if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
+                head_dim = pair.second.ne[0];
+            }
+            if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
+                head_dim = pair.second.ne[0];
+            }
         }
 
-        LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
+        flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
+
+        LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
+                 ", hidden_size = %" PRId64 ", num_heads = %d",
+                 flux_params.depth,
+                 flux_params.depth_single_blocks,
+                 flux_params.guidance_embed ? "true" : "false",
+                 flux_params.context_in_dim,
+                 flux_params.hidden_size,
+                 flux_params.num_heads);
         if (flux_params.is_chroma) {
             LOG_INFO("Using pruned modulation (Chroma)");
         } else if (!flux_params.guidance_embed) {
             LOG_INFO("Flux guidance is disabled (Schnell mode)");
         }
 
         flux = Flux(flux_params);
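Note on the flux.hpp changes above: instead of hard-coding context_in_dim, hidden_size, and num_heads per model version, the loader now infers them from checkpoint tensor shapes. In ggml's innermost-first layout a Linear weight [out_features, in_features] has ne[0] = in_features and ne[1] = out_features, so txt_in.weight yields context_in_dim and hidden_size directly, and the length of a QK-norm scale vector equals head_dim, giving num_heads = hidden_size / head_dim. A minimal sketch of the idea; the ShapeInfo type and the concrete shape values below are illustrative assumptions, not the real tensor_storage_map:

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    // Stand-in for one entry of the checkpoint's tensor-shape index.
    struct ShapeInfo {
        std::vector<int64_t> ne;  // dimensions, innermost first (ggml order)
    };

    int main() {
        // Shapes as they might appear in a hypothetical checkpoint.
        const std::map<std::string, ShapeInfo> tensors = {
            {"txt_in.weight", {{7680, 3072}}},                 // ne[0] = context_in_dim, ne[1] = hidden_size
            {"single_blocks.0.norm.key_norm.scale", {{128}}},  // ne[0] = head_dim
        };

        int64_t context_in_dim = 0, hidden_size = 0, head_dim = 0;
        for (const auto& [name, info] : tensors) {
            if (name == "txt_in.weight") {
                context_in_dim = info.ne[0];
                hidden_size = info.ne[1];
            } else if (name == "single_blocks.0.norm.key_norm.scale") {
                head_dim = info.ne[0];  // RMSNorm scale spans one attention head
            }
        }
        const int num_heads = static_cast<int>(hidden_size / head_dim);
        std::printf("context_in_dim=%lld hidden_size=%lld num_heads=%d\n",
                    (long long)context_in_dim, (long long)hidden_size, num_heads);
        return 0;
    }

This is what lets one code path serve both FLUX.2 and FLUX.2 Klein: the smaller Klein transformer simply reports smaller shapes, and guidance_embed is likewise detected from the presence of guidance_in.in_layer.weight rather than assumed.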
ggml_extend.hpp (3 changes: 2 additions & 1 deletion)
@@ -1348,7 +1348,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
     v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N);  // [N * n_kv_head, d_head, L_k]
 
     auto kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, L_q, L_k]
-    kq = ggml_scale_inplace(ctx, kq, scale);
+    ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+    kq = ggml_scale_inplace(ctx, kq, scale);
     if (mask) {
         kq = ggml_add_inplace(ctx, kq, mask);
     }
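Note on the ggml_extend.hpp change above: ggml_mul_mat_set_prec() tags the result node of a mul_mat op so that backends honoring the hint (notably CUDA) accumulate in F32 instead of F16. Long sequences combined with large additive mask values (-INFINITY entries) are exactly where F16 accumulation of the Q*K^T scores loses precision or overflows, which is presumably why the hint lands on kq here. A sketch of the pattern in isolation, assuming only ggml.h; this mirrors but does not reproduce ggml_ext_attention_ext:

    #include "ggml.h"

    // Build the scaled attention-score node with F32 accumulation requested.
    // q: [d_head, L_q, n_head], k: [d_head, L_k, n_head] (ggml innermost-first),
    // so ggml_mul_mat(ctx, k, q) yields kq: [L_k, L_q, n_head].
    static struct ggml_tensor* scaled_scores(struct ggml_context* ctx,
                                             struct ggml_tensor* q,
                                             struct ggml_tensor* k,
                                             float scale) {
        struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);
        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);  // hint applies to this op's output
        return ggml_scale_inplace(ctx, kq, scale);
    }

The hint is a per-node property set before the graph is computed, so the rest of the graph is unaffected.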
llm.hpp (51 changes: 38 additions & 13 deletions)
@@ -837,7 +837,8 @@ namespace LLM {

         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
-                                    struct ggml_tensor* input_pos) {
+                                    struct ggml_tensor* input_pos,
+                                    struct ggml_tensor* attention_mask = nullptr) {
             // x: [N, n_token, hidden_size]
             int64_t n_token = x->ne[1];
             int64_t N = x->ne[2];
@@ -880,7 +881,7 @@ namespace LLM {
             k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3));  // [N, num_kv_heads, n_token, head_dim]
             k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]);      // [N*num_kv_heads, n_token, head_dim]
 
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false);  // [N, n_token, hidden_size]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, false, true, false);  // [N, n_token, hidden_size]
 
             x = out_proj->forward(ctx, x);  // [N, n_token, hidden_size]
             return x;
@@ -898,7 +899,8 @@ namespace LLM {

         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
-                                    struct ggml_tensor* input_pos) {
+                                    struct ggml_tensor* input_pos,
+                                    struct ggml_tensor* attention_mask = nullptr) {
             // x: [N, n_token, hidden_size]
             auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
             auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
@@ -907,7 +909,7 @@ namespace LLM {

             auto residual = x;
             x = input_layernorm->forward(ctx, x);
-            x = self_attn->forward(ctx, x, input_pos);
+            x = self_attn->forward(ctx, x, input_pos, attention_mask);
             x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
 
             residual = x;
@@ -936,6 +938,7 @@ namespace LLM {
         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* input_ids,
                                     struct ggml_tensor* input_pos,
+                                    struct ggml_tensor* attention_mask,
                                     std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                     std::set<int> out_layers) {
             // input_ids: [N, n_token]
@@ -990,7 +993,7 @@ namespace LLM {
             for (int i = 0; i < num_layers; i++) {
                 auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
 
-                x = block->forward(ctx, x, input_pos);
+                x = block->forward(ctx, x, input_pos, attention_mask);
                 if (out_layers.find(i + 1) != out_layers.end()) {
                     intermediate_outputs.push_back(x);
                 }
@@ -1036,12 +1039,13 @@ namespace LLM {
         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* input_ids,
                                     struct ggml_tensor* input_pos,
+                                    struct ggml_tensor* attention_mask,
                                     std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                     std::set<int> out_layers) {
             // input_ids: [N, n_token]
             auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
 
-            auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers);
+            auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
             return x;
         }
 
@@ -1063,6 +1067,7 @@ namespace LLM {
     LLM model;
 
     std::vector<int> input_pos_vec;
+    std::vector<float> attention_mask_vec;
     std::vector<float> window_mask_vec;
     std::vector<int> window_index_vec;
     std::vector<int> window_inverse_index_vec;
@@ -1157,9 +1162,10 @@ namespace LLM {
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids,
                                 struct ggml_tensor* input_pos,
+                                struct ggml_tensor* attention_mask,
                                 std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                 std::set<int> out_layers) {
-        auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers);  // [N, n_token, hidden_size]
+        auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);  // [N, n_token, hidden_size]
         return hidden_states;
     }
 
@@ -1174,6 +1180,7 @@ namespace LLM {
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
+                                    struct ggml_tensor* attention_mask,
                                     std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                                     std::set<int> out_layers) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
@@ -1205,9 +1212,26 @@ namespace LLM {
                                         input_pos_vec.size());
         set_backend_tensor_data(input_pos, input_pos_vec.data());
 
+        if (attention_mask != nullptr) {
+            attention_mask = to_backend(attention_mask);
+        } else {
+            attention_mask_vec.resize(n_tokens * n_tokens);
+            for (int i0 = 0; i0 < n_tokens; i0++) {
+                for (int i1 = 0; i1 < n_tokens; i1++) {
+                    float value = 0.f;
+                    if (i0 > i1) {
+                        value = -INFINITY;
+                    }
+                    attention_mask_vec[i1 * n_tokens + i0] = value;
+                }
+            }
+            attention_mask = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, n_tokens, n_tokens);
+            set_backend_tensor_data(attention_mask, attention_mask_vec.data());
+        }
+
         auto runner_ctx = get_context();
 
-        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
+        struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
 
         ggml_build_forward_expand(gf, hidden_states);
 
@@ -1216,12 +1240,13 @@ namespace LLM {

     bool compute(const int n_threads,
                  struct ggml_tensor* input_ids,
+                 struct ggml_tensor* attention_mask,
                  std::vector<std::pair<int, ggml_tensor*>> image_embeds,
                  std::set<int> out_layers,
                  ggml_tensor** output,
                  ggml_context* output_ctx = nullptr) {
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(input_ids, image_embeds, out_layers);
+            return build_graph(input_ids, attention_mask, image_embeds, out_layers);
         };
         return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
     }
@@ -1525,7 +1550,7 @@ namespace LLM {
         struct ggml_tensor* out = nullptr;
 
         int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, image_embeds, {}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
         int64_t t1 = ggml_time_ms();
 
         print_ggml_tensor(out);
@@ -1565,7 +1590,7 @@ namespace LLM {
         struct ggml_tensor* out = nullptr;
 
         int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
         int64_t t1 = ggml_time_ms();
 
         print_ggml_tensor(out);
@@ -1588,7 +1613,7 @@ namespace LLM {
         struct ggml_tensor* out = nullptr;
 
         int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, {}, {35}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
         int64_t t1 = ggml_time_ms();
 
         print_ggml_tensor(out);
@@ -1611,7 +1636,7 @@ namespace LLM {
         struct ggml_tensor* out = nullptr;
 
         int64_t t0 = ggml_time_ms();
-        model.compute(8, input_ids, {}, {}, &out, work_ctx);
+        model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
         int64_t t1 = ggml_time_ms();
 
         print_ggml_tensor(out);
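Note on the llm.hpp plumbing above: every pre-existing call site now passes nullptr for attention_mask, in which case build_graph() synthesizes a plain causal mask, so prior behavior is preserved; judging from the call-site change in ggml_ext_attention_ext (nullptr, true becoming attention_mask, false), the explicit mask replaces the boolean that previously requested built-in causal masking. Only the FLUX.2 Klein conditioner supplies its own mask, which adds padding suppression on top of causality. The flat fill attention_mask_vec[i1 * n_tokens + i0] matches ggml's layout for the [n_tokens, n_tokens] F32 tensor: the key index i0 runs over the innermost dimension ne[0], the query index i1 over ne[1]. A standalone sketch of the fallback mask, assumed to mirror the loop in build_graph():

    #include <cassert>
    #include <cmath>
    #include <vector>

    // Default causal mask used when the caller passes attention_mask == nullptr:
    // row i1 is one query, and every key in its future (i0 > i1) is suppressed
    // with -INFINITY before the softmax.
    int main() {
        const int n_tokens = 4;
        std::vector<float> mask(n_tokens * n_tokens);
        for (int i0 = 0; i0 < n_tokens; i0++) {      // key position (ne[0])
            for (int i1 = 0; i1 < n_tokens; i1++) {  // query position (ne[1])
                mask[i1 * n_tokens + i0] = (i0 > i1) ? -INFINITY : 0.f;
            }
        }
        assert(mask[0 * n_tokens + 3] == -INFINITY);  // query 0 cannot see key 3
        assert(mask[3 * n_tokens + 0] == 0.f);        // query 3 sees key 0
        return 0;
    }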