diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 82328bccb..d299da50c 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1594,10 +1594,30 @@ struct SDGenerationParams { load_if_exists("skip_layers", skip_layers); load_if_exists("high_noise_skip_layers", high_noise_skip_layers); + load_if_exists("steps", sample_params.sample_steps); + load_if_exists("high_noise_steps", high_noise_sample_params.sample_steps); load_if_exists("cfg_scale", sample_params.guidance.txt_cfg); load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg); load_if_exists("guidance", sample_params.guidance.distilled_guidance); + auto load_sampler_if_exists = [&](const char* key, enum sample_method_t& out) { + if (j.contains(key) && j[key].is_string()) { + enum sample_method_t tmp = str_to_sample_method(j[key].get().c_str()); + if (tmp != SAMPLE_METHOD_COUNT) { + out = tmp; + } + } + }; + load_sampler_if_exists("sample_method", sample_params.sample_method); + load_sampler_if_exists("high_noise_sample_method", high_noise_sample_params.sample_method); + + if (j.contains("scheduler") && j["scheduler"].is_string()) { + enum scheduler_t tmp = str_to_scheduler(j["scheduler"].get().c_str()); + if (tmp != SCHEDULER_COUNT) { + sample_params.scheduler = tmp; + } + } + return true; } diff --git a/examples/server/main.cpp b/examples/server/main.cpp index b15daca1a..b0ac7eef9 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -420,6 +420,9 @@ int main(int argc, const char** argv) { return; } + if (gen_params.sample_params.sample_steps > 100) + gen_params.sample_params.sample_steps = 100; + if (!gen_params.process_and_check(IMG_GEN, "")) { res.status = 400; res.set_content(R"({"error":"invalid params"})", "application/json"); @@ -598,6 +601,9 @@ int main(int argc, const char** argv) { return; } + if (gen_params.sample_params.sample_steps > 100) + gen_params.sample_params.sample_steps = 100; + if (!gen_params.process_and_check(IMG_GEN, "")) { res.status = 400; res.set_content(R"({"error":"invalid params"})", "application/json");