Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 69 additions & 26 deletions folder_paths.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import threading
import os
import time
import mimetypes
import logging
from typing import Set, List, Dict, Tuple, Literal
from collections.abc import Collection
from concurrent.futures import ThreadPoolExecutor

supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

Expand Down Expand Up @@ -46,6 +48,8 @@

filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

async_executor = ThreadPoolExecutor(32)

class CacheHelper:
"""
Helper class for managing file list cache data.
Expand Down Expand Up @@ -210,6 +214,22 @@ def get_folder_paths(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
return folder_names_and_paths[folder_name][0][:]


def prebuild_lists():
start_time = time.perf_counter()

with ThreadPoolExecutor(32) as executor:
calls = []
for folder_name in folder_names_and_paths:
calls.append(executor.submit(lambda: get_filename_list(folder_name)))

for call in calls:
call.result()

end_time = time.perf_counter()
logging.info(f"Scanned model lists in {end_time - start_time:.2f} seconds")


def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
if not os.path.isdir(directory):
return [], {}
Expand All @@ -225,33 +245,43 @@ def recursive_search(directory: str, excluded_dir_names: list[str] | None=None)
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
return [], {}

logging.debug("recursive file list on directory {}".format(directory))
dirpath: str
subdirs: list[str]
filenames: list[str]

for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)

for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue

calls = []

def proc_subdir(path: str):
dirs[path] = os.path.getmtime(path)

def handle(file):
try:
if not os.path.isdir(file):
relative_path = os.path.relpath(file, directory)
result.append(relative_path)
return

calls.append(async_executor.submit(lambda f=file: proc_subdir(f)))

for subdir in os.listdir(file):
if subdir not in excluded_dir_names:
path = os.path.join(file, subdir)
calls.append(async_executor.submit(lambda p=path: handle(p)))
except Exception as e:
logging.error(f"recursive_search encountered error while handling '{file}': {e}")

calls.append(async_executor.submit(lambda: handle(directory)))
while len(calls) > 0:
calls.pop().result()

logging.debug("found {} files".format(len(result)))
return result, dirs


def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))



def get_full_path(folder_name: str, filename: str) -> str | None:
global folder_names_and_paths
folder_name = map_legacy(folder_name)
Expand Down Expand Up @@ -293,26 +323,39 @@ def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float]
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache

global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
must_invalidate = threading.Event()
folders = folder_names_and_paths[folder_name]

for x in out[1]:
time_modified = out[1][x]
folder = x
def check_folder_mtime(folder: str, time_modified: float):
if os.path.getmtime(folder) != time_modified:
return None
must_invalidate.set()

folders = folder_names_and_paths[folder_name]
for x in folders[0]:
def check_new_dirs(x: str):
if os.path.isdir(x):
if x not in out[1]:
return None
must_invalidate.set()

calls = []

for x in out[1]:
time_modified = out[1][x]
calls.append(async_executor.submit(lambda f=x, t=time_modified: check_folder_mtime(f, t)))

for x in folders[0]:
calls.append(async_executor.submit(lambda f=x: check_new_dirs(f)))

for call in calls:
call.result()

if must_invalidate.is_set():
return None
return out

def get_filename_list(folder_name: str) -> list[str]:
Expand Down
2 changes: 2 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def cleanup_temp():

nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

folder_paths.prebuild_lists()

cuda_malloc_warning()

server.add_routes()
Expand Down
Loading