diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..e5d96428a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -640,17 +640,25 @@ class StableDiffusionGGML { void apply_lora(const std::string& lora_name, float multiplier) { int64_t t0 = ggml_time_ms(); - std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors"); - std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt"); + std::vector extensions = {".safetensors", ".ckpt"}; + std::string st_file_path = path_join(lora_model_dir, lora_name + extensions[0]); + std::string ckpt_file_path = path_join(lora_model_dir, lora_name + extensions[1]); std::string file_path; if (file_exists(st_file_path)) { file_path = st_file_path; } else if (file_exists(ckpt_file_path)) { file_path = ckpt_file_path; } else { - LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); - return; + file_path = get_filepath_from_dir(lora_model_dir, lora_name, &extensions); + + if (file_path.empty()) { + LOG_WARN("can not find %s or %s for lora %s in main directory or subdirectories", + st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); + return; + } + LOG_INFO("found lora %s in subdirectory: %s", lora_name.c_str(), file_path.c_str()); } + LoraModel lora(backend, file_path); if (!lora.load_from_file()) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); diff --git a/util.cpp b/util.cpp index da11a14d6..46c404089 100644 --- a/util.cpp +++ b/util.cpp @@ -75,6 +75,60 @@ std::string format(const char* fmt, ...) { #ifdef _WIN32 // code for windows #include +std::string get_filepath_from_dir_recursive( + const std::string& dir_path, + const std::string& file_name, + const std::vector* extensions = nullptr) { + + if (extensions) { + // Search with provided extensions + for (const auto& ext : *extensions) { + std::string file_path = path_join(dir_path, file_name + ext); + if (file_exists(file_path)) { + return file_path; + } + } + } else { + // Search for exact filename without extensions + std::string file_path = path_join(dir_path, file_name); + if (file_exists(file_path)) { + return file_path; + } + } + + // Check subdirectories + WIN32_FIND_DATA findData; + HANDLE hFind; + std::string search_path = path_join(dir_path, "*"); + + hFind = FindFirstFile(search_path.c_str(), &findData); + if (hFind != INVALID_HANDLE_VALUE) { + do { + if ((findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) && + strcmp(findData.cFileName, ".") != 0 && + strcmp(findData.cFileName, "..") != 0) { + + std::string subdir = path_join(dir_path, findData.cFileName); + std::string result = get_filepath_from_dir_recursive(subdir, file_name, extensions); + + if (!result.empty()) { + FindClose(hFind); + return result; + } + } + } while (FindNextFile(hFind, &findData)); + + FindClose(hFind); + } + + return ""; +} + +std::string get_filepath_from_dir(const std::string& dir, const std::string& filename, const std::vector* extensions = nullptr) { + return get_filepath_from_dir_recursive(dir, filename, extensions); +} + + bool file_exists(const std::string& filename) { DWORD attributes = GetFileAttributesA(filename.c_str()); return (attributes != INVALID_FILE_ATTRIBUTES && !(attributes & FILE_ATTRIBUTE_DIRECTORY)); @@ -153,6 +207,65 @@ std::vector get_files_from_dir(const std::string& dir) { #include #include +std::string get_filepath_from_dir_recursive( + const std::string& dir_path, + const std::string& file_name, + const std::vector* extensions = nullptr) { + + DIR* dir = opendir(dir_path.c_str()); + if (dir == nullptr) { + return ""; + } + + std::string result = ""; + + if (extensions) { + for (const auto& ext : *extensions) { + std::string file_path = path_join(dir_path, file_name + ext); + if (file_exists(file_path)) { + closedir(dir); + return file_path; + } + } + } else { + std::string file_path = path_join(dir_path, file_name); + if (file_exists(file_path)) { + closedir(dir); + return file_path; + } + } + + // Check all subdirectories + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + if (entry->d_type == DT_DIR && + strcmp(entry->d_name, ".") != 0 && + strcmp(entry->d_name, "..") != 0) { + + std::string subdir = path_join(dir_path, entry->d_name); + + result = get_filepath_from_dir_recursive(subdir, file_name, extensions); + + if (!result.empty()) { + closedir(dir); + return result; + } + } + } + + closedir(dir); + return ""; +} + +std::string get_filepath_from_dir( + const std::string& dir_path, + const std::string& file_name, + const std::vector* extensions = nullptr) { + + return get_filepath_from_dir_recursive(dir_path, file_name, extensions); +} + + bool file_exists(const std::string& filename) { struct stat buffer; return (stat(filename.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode)); diff --git a/util.h b/util.h index 14fa812e5..913dd029b 100644 --- a/util.h +++ b/util.h @@ -15,6 +15,8 @@ std::string format(const char* fmt, ...); void replace_all_chars(std::string& str, char target, char replacement); +std::string get_filepath_from_dir(const std::string& dir, const std::string& filename, const std::vector* extensions); + bool file_exists(const std::string& filename); bool is_directory(const std::string& path); std::string get_full_path(const std::string& dir, const std::string& filename);