diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 86bc1abfd..80cc311a0 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -115,7 +115,7 @@ struct SDCliParams { if (mode_found == -1) { LOG_ERROR("error: invalid mode %s, must be one of [%s]\n", mode_c_str, SD_ALL_MODES_STR); - exit(1); + return -1; } mode = (SDMode)mode_found; } @@ -209,19 +209,23 @@ void print_usage(int argc, const char* argv[], const std::vector& op options_list[2].print(); } -void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextParams& ctx_params, SDGenerationParams& gen_params) { +int parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextParams& ctx_params, SDGenerationParams& gen_params, bool& proceed) { std::vector options_vec = {cli_params.get_options(), ctx_params.get_options(), gen_params.get_options()}; + proceed = true; + if (!parse_options(argc, argv, options_vec)) { print_usage(argc, argv, options_vec); - exit(cli_params.normal_exit ? 0 : 1); + proceed = false; + return cli_params.normal_exit ? 0 : 1; } if (!cli_params.process_and_check() || !ctx_params.process_and_check(cli_params.mode) || !gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) { print_usage(argc, argv, options_vec); - exit(1); + proceed = false; + return 1; } } @@ -469,7 +473,9 @@ int main(int argc, const char* argv[]) { SDContextParams ctx_params; SDGenerationParams gen_params; - parse_args(argc, argv, cli_params, ctx_params, gen_params); + bool proceed; + int parsing_exit_code = parse_args(argc, argv, cli_params, ctx_params, gen_params, proceed); + if (proceed) { if (gen_params.video_frames > 4) { size_t last_dot_pos = cli_params.preview_path.find_last_of("."); std::string base_path = cli_params.preview_path; @@ -813,6 +819,6 @@ int main(int argc, const char* argv[]) { free(results); release_all_resources(); - - return 0; + } + return parsing_exit_code; } diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 7ea95ed14..a39dd7be5 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -413,6 +413,10 @@ static bool parse_options(int argc, const char** argv, const std::vector