30 changes: 29 additions & 1 deletion docs/distilled_sd.md
@@ -83,7 +83,7 @@ python convert_diffusers_to_original_stable_diffusion.py \
The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
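For example, a minimal test run might look like this (an assumed invocation; adjust the binary path to match your build):

```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m segmind_tiny-sd.ckpt -p "portrait of a lovely cat"
```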


### Another available .ckpt file:
##### Another available .ckpt file:

* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt

@@ -97,3 +97,31 @@ for key, value in ckpt['state_dict'].items():
    ckpt['state_dict'][key] = value.contiguous()
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
```


### SDXS-512

Another very tiny and **incredibly fast** model is SDXS by IDKiro et al., which the authors describe as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details, see the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed further blocks from the U-Net, and unlike other SD1 models it uses an adjusted _AutoEncoderTiny_ instead of the default _AutoEncoderKL_ for the VAE.
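You can verify the VAE type directly from the diffusers pipeline (a minimal sketch, assuming the `IDKiro/sdxs-512-dreamshaper` repository used in step 1 below):

```python
from diffusers import StableDiffusionPipeline

# Load the SDXS pipeline and check which autoencoder class it ships with.
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
print(type(pipe.vae).__name__)  # expected: AutoencoderTiny rather than AutoencoderKL
```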

##### 1. Download the diffusers model from Hugging Face using Python:

```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.save_pretrained(save_directory="sdxs")
```
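Optionally, you can sanity-check the pipeline in diffusers before converting. The parameters below are assumptions chosen to mirror the sd.cpp flags in step 3 (one step, guidance disabled):

```python
# Optional sanity check (assumed parameters): a single denoising step with CFG disabled.
image = pipe("portrait of a lovely cat", num_inference_steps=1, guidance_scale=0.0).images[0]
image.save("sdxs_check.png")
```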
##### 2. Create a safetensors file:

```bash
python convert_diffusers_to_original_stable_diffusion.py \
--model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
```

##### 3. Run the model as follows:

```bash
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
--cfg-scale 1 --steps 1
```

Both options, `--cfg-scale 1` and `--steps 1`, are mandatory here: SDXS is distilled for single-step sampling without classifier-free guidance.
7 changes: 7 additions & 0 deletions model.cpp
@@ -1038,6 +1038,7 @@ SDVersion ModelLoader::get_sd_version() {
int64_t patch_embedding_channels = 0;
bool has_img_emb = false;
bool has_middle_block_1 = false;
bool has_output_block_71 = false;

for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!(is_xl)) {
@@ -1094,6 +1095,9 @@ SDVersion ModelLoader::get_sd_version() {
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
has_middle_block_1 = true;
}
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
has_output_block_71 = true;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1155,6 +1159,9 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_SD1_PIX2PIX;
}
if (!has_middle_block_1) {
if (!has_output_block_71) {
return VERSION_SDXS;
}
return VERSION_SD1_TINY_UNET;
}
return VERSION_SD1;
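The detection added above keys purely on tensor names: an SD1-family checkpoint without the middle-block-1 resnet is a tiny-UNet distillation, and one that is additionally missing `output_blocks.7.1` is SDXS. A condensed sketch of the cascade (illustrative only; `detect_sd1_variant` and `Sd1Variant` are made-up names, not the real ModelLoader API, and the ldm-style middle-block pattern is assumed from context):

```cpp
#include <string>
#include <vector>

enum class Sd1Variant { SD1, TINY_UNET, SDXS };

// Scan the checkpoint's tensor names and classify the SD1 variant by which
// UNet blocks are present, mirroring the checks in get_sd_version() above.
Sd1Variant detect_sd1_variant(const std::vector<std::string>& tensor_names) {
    bool has_middle_block_1  = false;
    bool has_output_block_71 = false;
    for (const auto& name : tensor_names) {
        if (name.find("model.diffusion_model.middle_block.1.") != std::string::npos ||
            name.find("unet.mid_block.resnets.1.") != std::string::npos) {
            has_middle_block_1 = true;
        }
        if (name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
            has_output_block_71 = true;
        }
    }
    if (!has_middle_block_1) {
        // No middle block: a tiny UNet; if output block 7.1 is also gone, it's SDXS.
        return has_output_block_71 ? Sd1Variant::TINY_UNET : Sd1Variant::SDXS;
    }
    return Sd1Variant::SD1;
}
```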
3 changes: 2 additions & 1 deletion model.h
@@ -28,6 +28,7 @@ enum SDVersion {
VERSION_SD2,
VERSION_SD2_INPAINT,
VERSION_SD2_TINY_UNET,
VERSION_SDXS,
VERSION_SDXL,
VERSION_SDXL_INPAINT,
VERSION_SDXL_PIX2PIX,
@@ -50,7 +51,7 @@ enum SDVersion {
};

static inline bool sd_version_is_sd1(SDVersion version) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET) {
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
return true;
}
return false;
108 changes: 77 additions & 31 deletions stable-diffusion.cpp
@@ -31,6 +31,7 @@ const char* model_version_to_str[] = {
"SD 2.x",
"SD 2.x Inpaint",
"SD 2.x Tiny UNet",
"SDXS",
"SDXL",
"SDXL Inpaint",
"SDXL Instruct-Pix2Pix",
@@ -114,7 +115,8 @@ class StableDiffusionGGML {
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
std::shared_ptr<DiffusionModel> diffusion_model;
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
std::shared_ptr<VAE> first_stage_model;
std::shared_ptr<VAE> first_stage_model = nullptr;
std::shared_ptr<TinyAutoEncoder> first_stage_model_tiny = nullptr;
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
Contributor: Why not just use tae_first_stage?

Contributor: Also, even though it's recommended to use it with TAESD by default, it should work just fine with the sd1.x KL-F8 VAE, just slower.

std::shared_ptr<ControlNet> control_net;
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
@@ -605,28 +607,42 @@ class StableDiffusionGGML {
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
offload_params_to_cpu);
} else {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true);
}
if (version == VERSION_SDXL &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
float vae_conv_2d_scale = 1.f / 32.f;
LOG_WARN(
"No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
"using Conv2D scale %.3f",
vae_conv_2d_scale);
first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
if (version == VERSION_SDXS) {
first_stage_model_tiny = std::make_shared<TinyImageAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"decoder.layers",
vae_decode_only,
version);
first_stage_model_tiny->alloc_params_buffer();
first_stage_model_tiny->get_param_tensors(tensors, "first_stage_model");
if (sd_ctx_params->vae_conv_direct) {
first_stage_model_tiny->set_conv2d_direct_enabled(true);
}
} else {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->set_conv2d_direct_enabled(true);
}
if (version == VERSION_SDXL &&
(strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
float vae_conv_2d_scale = 1.f / 32.f;
LOG_WARN(
"No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, "
"using Conv2D scale %.3f",
vae_conv_2d_scale);
first_stage_model->set_conv2d_scale(vae_conv_2d_scale);
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
}
}

@@ -722,6 +738,9 @@ class StableDiffusionGGML {
if (first_stage_model) {
first_stage_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (first_stage_model_tiny) {
first_stage_model_tiny->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
if (tae_first_stage) {
tae_first_stage->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
}
@@ -783,7 +802,11 @@ class StableDiffusionGGML {
}
size_t vae_params_mem_size = 0;
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
if (first_stage_model_tiny != nullptr) {
vae_params_mem_size = first_stage_model_tiny->get_params_buffer_size();
} else {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
}
}
if (use_tiny_autoencoder) {
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
@@ -2517,9 +2540,17 @@ class StableDiffusionGGML {
};
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
if (version == VERSION_SDXS) {
first_stage_model_tiny->compute(n_threads, x, false, &result, work_ctx);
} else {
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
}
}
if (version == VERSION_SDXS) {
first_stage_model_tiny->free_compute_buffer();
} else {
first_stage_model->free_compute_buffer();
}
first_stage_model->free_compute_buffer();
} else {
if (vae_tiling_params.enabled && !encode_video) {
// split latent in 32x32 tiles and compute in several steps
@@ -2573,6 +2604,7 @@ class StableDiffusionGGML {
sd_version_is_qwen_image(version) ||
sd_version_is_wan(version) ||
sd_version_is_flux2(version) ||
version == VERSION_SDXS ||
version == VERSION_CHROMA_RADIANCE) {
latent = vae_output;
} else if (version == VERSION_SD1_PIX2PIX) {
@@ -2631,7 +2663,9 @@ class StableDiffusionGGML {
if (sd_version_is_qwen_image(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
}
process_latent_out(x);
if (first_stage_model_tiny == nullptr) {
process_latent_out(x);
}
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
if (vae_tiling_params.enabled && !decode_video) {
float tile_overlap;
@@ -2642,14 +2676,22 @@

// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, true, &out, nullptr);
if (first_stage_model_tiny != nullptr) {
    first_stage_model_tiny->compute(n_threads, in, true, &out, nullptr);
} else {
    first_stage_model->compute(n_threads, in, true, &out, nullptr);
}
};
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
} else {
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
if (first_stage_model_tiny != nullptr) {
first_stage_model_tiny->compute(n_threads, x, true, &result, work_ctx);
} else {
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
}
}
if (first_stage_model_tiny != nullptr) {
first_stage_model_tiny->free_compute_buffer();
} else {
first_stage_model->free_compute_buffer();
process_vae_output_tensor(result);
}
first_stage_model->free_compute_buffer();
process_vae_output_tensor(result);
} else {
if (vae_tiling_params.enabled && !decode_video) {
// split latent in 64x64 tiles and compute in several steps
@@ -3411,7 +3453,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
sd_ctx->sd->first_stage_model->free_params_buffer();
if (sd_ctx->sd->first_stage_model_tiny != nullptr) {
sd_ctx->sd->first_stage_model_tiny->free_params_buffer();
} else {
sd_ctx->sd->first_stage_model->free_params_buffer();
}
}

sd_ctx->sd->lora_stat();
9 changes: 9 additions & 0 deletions tae.hpp
@@ -506,6 +506,7 @@ struct TinyAutoEncoder : public GGMLRunner {
struct ggml_context* output_ctx = nullptr) = 0;

virtual bool load_from_file(const std::string& file_path, int n_threads) = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
};

struct TinyImageAutoEncoder : public TinyAutoEncoder {
@@ -555,6 +556,10 @@ struct TinyImageAutoEncoder : public TinyAutoEncoder {
return success;
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
taesd.get_param_tensors(tensors, prefix);
}

struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
@@ -624,6 +629,10 @@ struct TinyVideoAutoEncoder : public TinyAutoEncoder {
return success;
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
taehv.get_param_tensors(tensors, prefix);
}

struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
5 changes: 4 additions & 1 deletion unet.hpp
@@ -215,10 +215,13 @@ class UnetModelBlock : public GGMLBlock {
} else if (sd_version_is_unet_edit(version)) {
in_channels = 8;
}
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET) {
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) {
num_res_blocks = 1;
channel_mult = {1, 2, 4};
tiny_unet = true;
if (version == VERSION_SDXS) {
attention_resolutions = {4, 2}; // same as in SDXL
}
}

// dims is always 2