Skip to content

Loras can now be found in subdirectories of lora model folder. #634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,17 +640,25 @@ class StableDiffusionGGML {

void apply_lora(const std::string& lora_name, float multiplier) {
int64_t t0 = ggml_time_ms();
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
std::vector<std::string> extensions = {".safetensors", ".ckpt"};
std::string st_file_path = path_join(lora_model_dir, lora_name + extensions[0]);
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + extensions[1]);
std::string file_path;
if (file_exists(st_file_path)) {
file_path = st_file_path;
} else if (file_exists(ckpt_file_path)) {
file_path = ckpt_file_path;
} else {
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return;
file_path = get_filepath_from_dir(lora_model_dir, lora_name, &extensions);

if (file_path.empty()) {
LOG_WARN("can not find %s or %s for lora %s in main directory or subdirectories",
st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return;
}
LOG_INFO("found lora %s in subdirectory: %s", lora_name.c_str(), file_path.c_str());
}

LoraModel lora(backend, file_path);
if (!lora.load_from_file()) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
Expand Down
113 changes: 113 additions & 0 deletions util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,60 @@ std::string format(const char* fmt, ...) {
#ifdef _WIN32 // code for windows
#include <windows.h>

std::string get_filepath_from_dir_recursive(
const std::string& dir_path,
const std::string& file_name,
const std::vector<std::string>* extensions = nullptr) {

if (extensions) {
// Search with provided extensions
for (const auto& ext : *extensions) {
std::string file_path = path_join(dir_path, file_name + ext);
if (file_exists(file_path)) {
return file_path;
}
}
} else {
// Search for exact filename without extensions
std::string file_path = path_join(dir_path, file_name);
if (file_exists(file_path)) {
return file_path;
}
}

// Check subdirectories
WIN32_FIND_DATA findData;
HANDLE hFind;
std::string search_path = path_join(dir_path, "*");

hFind = FindFirstFile(search_path.c_str(), &findData);
if (hFind != INVALID_HANDLE_VALUE) {
do {
if ((findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) &&
strcmp(findData.cFileName, ".") != 0 &&
strcmp(findData.cFileName, "..") != 0) {

std::string subdir = path_join(dir_path, findData.cFileName);
std::string result = get_filepath_from_dir_recursive(subdir, file_name, extensions);

if (!result.empty()) {
FindClose(hFind);
return result;
}
}
} while (FindNextFile(hFind, &findData));

FindClose(hFind);
}

return "";
}

std::string get_filepath_from_dir(const std::string& dir, const std::string& filename, const std::vector<std::string>* extensions = nullptr) {
return get_filepath_from_dir_recursive(dir, filename, extensions);
}


bool file_exists(const std::string& filename) {
DWORD attributes = GetFileAttributesA(filename.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES && !(attributes & FILE_ATTRIBUTE_DIRECTORY));
Expand Down Expand Up @@ -153,6 +207,65 @@ std::vector<std::string> get_files_from_dir(const std::string& dir) {
#include <dirent.h>
#include <sys/stat.h>

std::string get_filepath_from_dir_recursive(
const std::string& dir_path,
const std::string& file_name,
const std::vector<std::string>* extensions = nullptr) {

DIR* dir = opendir(dir_path.c_str());
if (dir == nullptr) {
return "";
}

std::string result = "";

if (extensions) {
for (const auto& ext : *extensions) {
std::string file_path = path_join(dir_path, file_name + ext);
if (file_exists(file_path)) {
closedir(dir);
return file_path;
}
}
} else {
std::string file_path = path_join(dir_path, file_name);
if (file_exists(file_path)) {
closedir(dir);
return file_path;
}
}

// Check all subdirectories
struct dirent* entry;
while ((entry = readdir(dir)) != nullptr) {
if (entry->d_type == DT_DIR &&
strcmp(entry->d_name, ".") != 0 &&
strcmp(entry->d_name, "..") != 0) {

std::string subdir = path_join(dir_path, entry->d_name);

result = get_filepath_from_dir_recursive(subdir, file_name, extensions);

if (!result.empty()) {
closedir(dir);
return result;
}
}
}

closedir(dir);
return "";
}

std::string get_filepath_from_dir(
const std::string& dir_path,
const std::string& file_name,
const std::vector<std::string>* extensions = nullptr) {

return get_filepath_from_dir_recursive(dir_path, file_name, extensions);
}


bool file_exists(const std::string& filename) {
struct stat buffer;
return (stat(filename.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode));
Expand Down
2 changes: 2 additions & 0 deletions util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ std::string format(const char* fmt, ...);

void replace_all_chars(std::string& str, char target, char replacement);

std::string get_filepath_from_dir(const std::string& dir, const std::string& filename, const std::vector<std::string>* extensions);

bool file_exists(const std::string& filename);
bool is_directory(const std::string& path);
std::string get_full_path(const std::string& dir, const std::string& filename);
Expand Down