From f1fd28d7a3343753000ac0e41756a5fcf7785dfc Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 18:55:01 +0900 Subject: [PATCH 01/14] add internal /folder_paths route returns a json maps of folder paths --- api_server/routes/internal/internal_routes.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index 8c46215f07ee..63704f13a6dc 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -1,6 +1,6 @@ from aiohttp import web from typing import Optional -from folder_paths import models_dir, user_directory, output_directory +from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths from api_server.services.file_service import FileService import app.logger @@ -36,6 +36,13 @@ async def list_files(request): async def get_logs(request): return web.json_response(app.logger.get_logs()) + @self.routes.get('/folder_paths') + async def get_folder_paths(request): + response = {} + for key in folder_names_and_paths: + response[key] = folder_names_and_paths[key][0] + return web.json_response(response) + def get_app(self): if self._app is None: self._app = web.Application() From cbaffac4a560b6e5db0379bef2a3380495bd5e54 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 19:00:16 +0900 Subject: [PATCH 02/14] (minor) format download_models.py --- model_filemanager/download_models.py | 65 +++++++++++++++------------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 712d59328f63..b24cbba59c1d 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -17,6 +17,7 @@ class DownloadStatusType(Enum): COMPLETED = "completed" ERROR = "error" + @dataclass class DownloadModelStatus(): status: str @@ -29,7 +30,7 @@ def __init__(self, status: DownloadStatusType, progress_percentage: float, messa self.progress_percentage = progress_percentage self.message = message self.already_existed = already_existed - + def to_dict(self) -> Dict[str, Any]: return { "status": self.status, @@ -38,9 +39,10 @@ def to_dict(self) -> Dict[str, Any]: "already_existed": self.already_existed } + async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], - model_name: str, - model_url: str, + model_name: str, + model_url: str, model_sub_directory: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], progress_interval: float = 1.0) -> DownloadModelStatus: @@ -48,16 +50,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht Download a model file from a given URL into the models directory. Args: - model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): + model_download_request (Callable[[str], Awaitable[aiohttp.ClientResponse]]): A function that makes an HTTP request. This makes it easier to mock in unit tests. - model_name (str): + model_name (str): The name of the model file to be downloaded. This will be the filename on disk. - model_url (str): + model_url (str): The URL from which to download the model. - model_sub_directory (str): - The subdirectory within the main models directory where the model + model_sub_directory (str): + The subdirectory within the main models directory where the model should be saved (e.g., 'checkpoints', 'loras', etc.). - progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): + progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): An asynchronous function to call with progress updates. Returns: @@ -65,17 +67,17 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht """ if not validate_model_subdirectory(model_sub_directory): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - "Invalid model subdirectory", + "Invalid model subdirectory", False ) if not validate_filename(model_name): return DownloadModelStatus( - DownloadStatusType.ERROR, + DownloadStatusType.ERROR, 0, - "Invalid model name", + "Invalid model name", False ) @@ -101,7 +103,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht except Exception as e: logging.error(f"Error in downloading model: {e}") return await handle_download_error(e, model_name, progress_callback, relative_path) - + def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: full_model_dir = os.path.join(models_base_dir, model_directory) @@ -114,13 +116,13 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: raise Exception(f"Invalid model directory: {model_directory}/{model_name}") - relative_path = '/'.join([model_directory, model_name]) return file_path, relative_path -async def check_file_exists(file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + +async def check_file_exists(file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) @@ -129,11 +131,11 @@ async def check_file_exists(file_path: str, return None -async def track_download_progress(response: aiohttp.ClientResponse, - file_path: str, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - relative_path: str, +async def track_download_progress(response: aiohttp.ClientResponse, + file_path: str, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], + relative_path: str, interval: float = 1.0) -> DownloadModelStatus: try: total_size = int(response.headers.get('Content-Length', 0)) @@ -156,12 +158,12 @@ async def update_progress(): break f.write(chunk) downloaded += len(chunk) - + if time.time() - last_update_time >= interval: await update_progress() await update_progress() - + logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) await progress_callback(relative_path, status) @@ -172,18 +174,20 @@ async def update_progress(): logging.error(traceback.format_exc()) return await handle_download_error(e, model_name, progress_callback, relative_path) -async def handle_download_error(e: Exception, - model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Any], + +async def handle_download_error(e: Exception, + model_name: str, + progress_callback: Callable[[str, DownloadModelStatus], Any], relative_path: str) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) await progress_callback(relative_path, status) return status + def validate_model_subdirectory(model_subdirectory: str) -> bool: """ - Validate that the model subdirectory is safe to install into. + Validate that the model subdirectory is safe to install into. Must not contain relative paths, nested paths or special characters other than underscores and hyphens. @@ -204,10 +208,11 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool: return True + def validate_filename(filename: str)-> bool: """ Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. - + Args: filename (str): The filename to validate From 618f4d848c51c8ebe859d31cecd108a197c8137c Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 19:15:11 +0900 Subject: [PATCH 03/14] initial folder path input on download api --- model_filemanager/download_models.py | 38 +++++++++++++++++++++++----- server.py | 3 ++- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index b24cbba59c1d..124854727062 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -3,7 +3,7 @@ import os import traceback import logging -from folder_paths import models_dir +from folder_paths import models_dir, folder_names_and_paths, get_folder_paths import re from typing import Callable, Any, Optional, Awaitable, Dict from enum import Enum @@ -43,9 +43,10 @@ def to_dict(self) -> Dict[str, Any]: async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, model_url: str, - model_sub_directory: str, + model_directory: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - progress_interval: float = 1.0) -> DownloadModelStatus: + progress_interval: float = 1.0, + folder_path: str = None) -> DownloadModelStatus: """ Download a model file from a given URL into the models directory. @@ -56,16 +57,18 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht The name of the model file to be downloaded. This will be the filename on disk. model_url (str): The URL from which to download the model. - model_sub_directory (str): + model_directory (str): The subdirectory within the main models directory where the model should be saved (e.g., 'checkpoints', 'loras', etc.). progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): An asynchronous function to call with progress updates. + folder_path (str); + Optional path to which model folder should be used as the root, or None to use default. Returns: DownloadModelStatus: The result of the download operation. """ - if not validate_model_subdirectory(model_sub_directory): + if not validate_model_subdirectory(model_directory): return DownloadModelStatus( DownloadStatusType.ERROR, 0, @@ -81,7 +84,28 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht False ) - file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) + models_base_dir = models_dir + + if folder_path: + if not model_directory in folder_names_and_paths: + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid model directory, when using 'folder_path', model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.", + False + ) + + if not folder_path in get_folder_paths(model_directory): + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid folder path, does not match the list of known directories. If you're seeing this in the downloader UI, you may need to refresh the page.", + False + ) + models_base_dir = folder_path + model_directory = '' + + file_path, relative_path = create_model_path(model_name, model_directory, models_base_dir) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) if existing_file: return existing_file @@ -114,7 +138,7 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st abs_file_path = os.path.abspath(file_path) abs_base_dir = os.path.abspath(str(models_base_dir)) if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: - raise Exception(f"Invalid model directory: {model_directory}/{model_name}") + raise Exception(f"Invalid model directory: {models_base_dir}/{model_directory}/{model_name}") relative_path = '/'.join([model_directory, model_name]) return file_path, relative_path diff --git a/server.py b/server.py index 9321e4d088a1..7f184682af5e 100644 --- a/server.py +++ b/server.py @@ -684,6 +684,7 @@ async def report_progress(filename: str, status: DownloadModelStatus): data = await request.json() url = data.get('url') model_directory = data.get('model_directory') + folder_path = data.get('folder_path') model_filename = data.get('model_filename') progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. @@ -695,7 +696,7 @@ async def report_progress(filename: str, status: DownloadModelStatus): logging.error("Client session is not initialized") return web.Response(status=500) - task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval)) + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval, folder_path)) await task return web.json_response(task.result().to_dict()) From 80a916777b8de69da40e5ea04ee698f30c275b9c Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 19:47:27 +0900 Subject: [PATCH 04/14] actually, require folder_path and clean up some code --- model_filemanager/__init__.py | 2 +- model_filemanager/download_models.py | 110 +++++++++------------------ server.py | 2 +- 3 files changed, 37 insertions(+), 77 deletions(-) diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index e318351c0512..b7ac16256ac1 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename +from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 124854727062..dda69cfa1811 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -44,9 +44,9 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht model_name: str, model_url: str, model_directory: str, + folder_path: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - progress_interval: float = 1.0, - folder_path: str = None) -> DownloadModelStatus: + progress_interval: float = 1.0) -> DownloadModelStatus: """ Download a model file from a given URL into the models directory. @@ -68,89 +68,74 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht Returns: DownloadModelStatus: The result of the download operation. """ - if not validate_model_subdirectory(model_directory): + if not validate_filename(model_name): return DownloadModelStatus( DownloadStatusType.ERROR, 0, - "Invalid model subdirectory", + "Invalid model name", False ) - if not validate_filename(model_name): + if not model_directory in folder_names_and_paths: return DownloadModelStatus( DownloadStatusType.ERROR, 0, - "Invalid model name", + "Invalid or unrecognized model directory. model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.", False ) - models_base_dir = models_dir - - if folder_path: - if not model_directory in folder_names_and_paths: - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid model directory, when using 'folder_path', model_directory must be a known model type (eg 'checkpoints'). If you are seeing this error for a custom model type, ensure the relevant custom nodes are installed and working.", - False - ) - - if not folder_path in get_folder_paths(model_directory): - return DownloadModelStatus( - DownloadStatusType.ERROR, - 0, - "Invalid folder path, does not match the list of known directories. If you're seeing this in the downloader UI, you may need to refresh the page.", - False - ) - models_base_dir = folder_path - model_directory = '' - - file_path, relative_path = create_model_path(model_name, model_directory, models_base_dir) - existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) + if not folder_path in get_folder_paths(model_directory): + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid folder path, does not match the list of known directories. If you're seeing this in the downloader UI, you may need to refresh the page.", + False + ) + + file_path = create_model_path(model_name, folder_path) + existing_file = await check_file_exists(file_path, model_name, progress_callback) if existing_file: return existing_file try: status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" logging.error(error_message) status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) + return await track_download_progress(response, file_path, model_name, progress_callback, progress_interval) except Exception as e: logging.error(f"Error in downloading model: {e}") - return await handle_download_error(e, model_name, progress_callback, relative_path) + return await handle_download_error(e, model_name, progress_callback) -def create_model_path(model_name: str, model_directory: str, models_base_dir: str) -> tuple[str, str]: - full_model_dir = os.path.join(models_base_dir, model_directory) - os.makedirs(full_model_dir, exist_ok=True) - file_path = os.path.join(full_model_dir, model_name) +def create_model_path(model_name: str, folder_path: str) -> tuple[str, str]: + os.makedirs(folder_path, exist_ok=True) + file_path = os.path.join(folder_path, model_name) # Ensure the resulting path is still within the base directory abs_file_path = os.path.abspath(file_path) - abs_base_dir = os.path.abspath(str(models_base_dir)) + abs_base_dir = os.path.abspath(folder_path) if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: - raise Exception(f"Invalid model directory: {models_base_dir}/{model_directory}/{model_name}") + raise Exception(f"Invalid model directory: {folder_path}/{model_name}") - relative_path = '/'.join([model_directory, model_name]) - return file_path, relative_path + return file_path async def check_file_exists(file_path: str, model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - relative_path: str) -> Optional[DownloadModelStatus]: + progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]] + ) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return status return None @@ -159,7 +144,6 @@ async def track_download_progress(response: aiohttp.ClientResponse, file_path: str, model_name: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], - relative_path: str, interval: float = 1.0) -> DownloadModelStatus: try: total_size = int(response.headers.get('Content-Length', 0)) @@ -170,7 +154,7 @@ async def update_progress(): nonlocal last_update_time progress = (downloaded / total_size) * 100 if total_size > 0 else 0 status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) last_update_time = time.time() with open(file_path, 'wb') as f: @@ -190,49 +174,25 @@ async def update_progress(): logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return status except Exception as e: logging.error(f"Error in track_download_progress: {e}") logging.error(traceback.format_exc()) - return await handle_download_error(e, model_name, progress_callback, relative_path) + return await handle_download_error(e, model_name, progress_callback) async def handle_download_error(e: Exception, model_name: str, - progress_callback: Callable[[str, DownloadModelStatus], Any], - relative_path: str) -> DownloadModelStatus: + progress_callback: Callable[[str, DownloadModelStatus], Any] + ) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) - await progress_callback(relative_path, status) + await progress_callback(model_name, status) return status -def validate_model_subdirectory(model_subdirectory: str) -> bool: - """ - Validate that the model subdirectory is safe to install into. - Must not contain relative paths, nested paths or special characters - other than underscores and hyphens. - - Args: - model_subdirectory (str): The subdirectory for the specific model type. - - Returns: - bool: True if the subdirectory is safe, False otherwise. - """ - if len(model_subdirectory) > 50: - return False - - if '..' in model_subdirectory or '/' in model_subdirectory: - return False - - if not re.match(r'^[a-zA-Z0-9_-]+$', model_subdirectory): - return False - - return True - - def validate_filename(filename: str)-> bool: """ Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. diff --git a/server.py b/server.py index 7f184682af5e..713d4d1a9e2c 100644 --- a/server.py +++ b/server.py @@ -688,7 +688,7 @@ async def report_progress(filename: str, status: DownloadModelStatus): model_filename = data.get('model_filename') progress_interval = data.get('progress_interval', 1.0) # In seconds, how often to report download progress. - if not url or not model_directory or not model_filename: + if not url or not model_directory or not model_filename or not folder_path: return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400) session = self.client_session From 68f8d8c568d2e91c3e39fffd00fe346d47d3018b Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 19:47:35 +0900 Subject: [PATCH 05/14] partial tests update --- .../download_models_test.py | 71 +++++++------------ 1 file changed, 26 insertions(+), 45 deletions(-) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 66150a4682fd..e86ba2d763d9 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -4,7 +4,7 @@ import itertools import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename +from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename class AsyncIteratorMock: """ @@ -59,7 +59,7 @@ async def test_download_model_success(): mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) - with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \ + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ patch('builtins.open', mock_open), \ patch('time.time', side_effect=time_values): # Simulate time passing @@ -69,6 +69,7 @@ async def test_download_model_success(): 'model.sft', 'http://example.com/model.sft', 'checkpoints', + 'mock_directory', mock_progress_callback ) @@ -83,13 +84,13 @@ async def test_download_model_success(): # Check initial call mock_progress_callback.assert_any_call( - 'checkpoints/model.sft', + 'model.sft', DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False) ) # Check final call mock_progress_callback.assert_any_call( - 'checkpoints/model.sft', + 'model.sft', DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) ) @@ -110,7 +111,7 @@ async def test_download_model_url_request_failure(): mock_progress_callback = AsyncMock() # Mock the create_model_path function - with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')): + with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'): # Mock the check_file_exists function to return None (file doesn't exist) with patch('model_filemanager.check_file_exists', return_value=None): # Call the function @@ -118,6 +119,7 @@ async def test_download_model_url_request_failure(): mock_get, 'model.safetensors', 'http://example.com/model.safetensors', + 'checkpoints', 'mock_directory', mock_progress_callback ) @@ -163,12 +165,13 @@ async def test_download_model_invalid_model_subdirectory(): 'model.sft', 'http://example.com/model.sft', '../bad_path', + '../bad_path', mock_progress_callback ) # Assert the result assert isinstance(result, DownloadModelStatus) - assert result.message == 'Invalid model subdirectory' + assert result.message.startswith('Invalid or unrecognized model directory') assert result.status == 'error' assert result.already_existed is False @@ -177,14 +180,13 @@ async def test_download_model_invalid_model_subdirectory(): def test_create_model_path(tmp_path, monkeypatch): mock_models_dir = tmp_path / "models" monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) - + model_name = "test_model.sft" model_directory = "test_dir" - - file_path, relative_path = create_model_path(model_name, model_directory, mock_models_dir) - + + file_path = create_model_path(model_name, model_directory, mock_models_dir) + assert file_path == str(mock_models_dir / model_directory / model_name) - assert relative_path == f"{model_directory}/{model_name}" assert os.path.exists(os.path.dirname(file_path)) @@ -192,29 +194,29 @@ def test_create_model_path(tmp_path, monkeypatch): async def test_check_file_exists_when_file_exists(tmp_path): file_path = tmp_path / "existing_model.sft" file_path.touch() # Create an empty file - + mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft") - + + result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback) + assert result is not None assert result.status == "completed" assert result.message == "existing_model.sft already exists" assert result.already_existed is True - + mock_callback.assert_called_once_with( - "test/existing_model.sft", + "existing_model.sft", DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True) ) @pytest.mark.asyncio async def test_check_file_exists_when_file_does_not_exist(tmp_path): file_path = tmp_path / "non_existing_model.sft" - + mock_callback = AsyncMock() - - result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft") - + + result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback) + assert result is None mock_callback.assert_not_called() @@ -230,13 +232,13 @@ async def test_track_download_progress_no_content_length(): with patch('builtins.open', mock_open): result = await track_download_progress( mock_response, '/mock/path/model.sft', 'model.sft', - mock_callback, 'models/model.sft', interval=0.1 + mock_callback, interval=0.1 ) assert result.status == "completed" # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( - 'models/model.sft', + 'model.sft', DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False) ) @@ -257,7 +259,7 @@ async def test_track_download_progress_interval(): patch('time.time', mock_time): await track_download_progress( mock_response, '/mock/path/model.sft', 'model.sft', - mock_callback, 'models/model.sft', interval=1.0 + mock_callback, interval=1.0 ) # Print out the actual call count and the arguments of each call for debugging @@ -279,27 +281,6 @@ async def test_track_download_progress_interval(): assert last_call[0][1].status == "completed" assert last_call[0][1].progress_percentage == 100 -def test_valid_subdirectory(): - assert validate_model_subdirectory("valid-model123") is True - -def test_subdirectory_too_long(): - assert validate_model_subdirectory("a" * 51) is False - -def test_subdirectory_with_double_dots(): - assert validate_model_subdirectory("model/../unsafe") is False - -def test_subdirectory_with_slash(): - assert validate_model_subdirectory("model/unsafe") is False - -def test_subdirectory_with_special_characters(): - assert validate_model_subdirectory("model@unsafe") is False - -def test_subdirectory_with_underscore_and_dash(): - assert validate_model_subdirectory("valid_model-name") is True - -def test_empty_subdirectory(): - assert validate_model_subdirectory("") is False - @pytest.mark.parametrize("filename, expected", [ ("valid_model.safetensors", True), ("valid_model.sft", True), From c539ee55f4ef30ed6718957e73890c467a60ceea Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 20:05:14 +0900 Subject: [PATCH 06/14] fix & logging --- model_filemanager/download_models.py | 3 ++- server.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index dda69cfa1811..23a52c441eb0 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -88,7 +88,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht return DownloadModelStatus( DownloadStatusType.ERROR, 0, - "Invalid folder path, does not match the list of known directories. If you're seeing this in the downloader UI, you may need to refresh the page.", + f"Invalid folder path '{folder_path}', does not match the list of known directories ({get_folder_paths(model_directory)}). If you're seeing this in the downloader UI, you may need to refresh the page.", False ) @@ -98,6 +98,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht return existing_file try: + logging.info(f"Downloading {model_name} from {model_url}") status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) await progress_callback(model_name, status) diff --git a/server.py b/server.py index 713d4d1a9e2c..1f9135d6f37e 100644 --- a/server.py +++ b/server.py @@ -696,7 +696,7 @@ async def report_progress(filename: str, status: DownloadModelStatus): logging.error("Client session is not initialized") return web.Response(status=500) - task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, report_progress, progress_interval, folder_path)) + task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval)) await task return web.json_response(task.result().to_dict()) From cfb5d6a53227aa8a48c242eb1b0b0434f4a49009 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 20:06:53 +0900 Subject: [PATCH 07/14] also download to a tmp file not the live file to avoid compounding errors from network failure --- model_filemanager/download_models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 23a52c441eb0..c2d0c6f93000 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -158,7 +158,10 @@ async def update_progress(): await progress_callback(model_name, status) last_update_time = time.time() - with open(file_path, 'wb') as f: + if os.path.exists(file_path + '.tmp'): + os.remove(file_path + '.tmp') + + with open(file_path + '.tmp', 'wb') as f: chunk_iterator = response.content.iter_chunked(8192) while True: try: @@ -171,6 +174,8 @@ async def update_progress(): if time.time() - last_update_time >= interval: await update_progress() + os.rename(file_path + '.tmp', file_path) + await update_progress() logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") From 3e8dbd51da5e78a9069370e1a53de9f51a12b63c Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 20:10:40 +0900 Subject: [PATCH 08/14] update tests again --- .../download_models_test.py | 51 +++++++++++++------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index e86ba2d763d9..a8acf3c89d24 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -62,6 +62,7 @@ async def test_download_model_success(): with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ patch('builtins.open', mock_open), \ + patch('folder_paths.get_folder_paths', return_value=['mock_directory']), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( @@ -111,18 +112,18 @@ async def test_download_model_url_request_failure(): mock_progress_callback = AsyncMock() # Mock the create_model_path function - with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'): - # Mock the check_file_exists function to return None (file doesn't exist) - with patch('model_filemanager.check_file_exists', return_value=None): - # Call the function - result = await download_model( - mock_get, - 'model.safetensors', - 'http://example.com/model.safetensors', - 'checkpoints', - 'mock_directory', - mock_progress_callback - ) + with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ + patch('model_filemanager.check_file_exists', return_value=None), \ + patch('folder_paths.get_folder_paths', return_value=['mock_directory']): + # Call the function + result = await download_model( + mock_get, + 'model.safetensors', + 'http://example.com/model.safetensors', + 'checkpoints', + 'mock_directory', + mock_progress_callback + ) # Assert the expected behavior assert isinstance(result, DownloadModelStatus) @@ -132,7 +133,7 @@ async def test_download_model_url_request_failure(): # Check that progress_callback was called with the correct arguments mock_progress_callback.assert_any_call( - 'mock_directory/model.safetensors', + 'model.safetensors', DownloadModelStatus( status=DownloadStatusType.PENDING, progress_percentage=0, @@ -141,7 +142,7 @@ async def test_download_model_url_request_failure(): ) ) mock_progress_callback.assert_called_with( - 'mock_directory/model.safetensors', + 'model.safetensors', DownloadModelStatus( status=DownloadStatusType.ERROR, progress_percentage=0, @@ -155,11 +156,9 @@ async def test_download_model_url_request_failure(): @pytest.mark.asyncio async def test_download_model_invalid_model_subdirectory(): - mock_make_request = AsyncMock() mock_progress_callback = AsyncMock() - result = await download_model( mock_make_request, 'model.sft', @@ -175,6 +174,26 @@ async def test_download_model_invalid_model_subdirectory(): assert result.status == 'error' assert result.already_existed is False +@pytest.mark.asyncio +async def test_download_model_invalid_folder_path(): + mock_make_request = AsyncMock() + mock_progress_callback = AsyncMock() + + with patch('folder_paths.get_folder_paths', return_value=['valid_path']): + result = await download_model( + mock_make_request, + 'model.sft', + 'http://example.com/model.sft', + 'checkpoints', + 'invalid_path', + mock_progress_callback + ) + + # Assert the result + assert isinstance(result, DownloadModelStatus) + assert result.message.startswith("Invalid folder path") + assert result.status == 'error' + assert result.already_existed is False # For create_model_path function def test_create_model_path(tmp_path, monkeypatch): From 57a91a2c1c5e52eca2b7ffef1edec2435530c6c9 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Thu, 19 Sep 2024 20:45:27 +0900 Subject: [PATCH 09/14] test tweaks --- model_filemanager/download_models.py | 2 +- .../prompt_server_test/download_models_test.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index c2d0c6f93000..5b0642e3666e 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -3,7 +3,7 @@ import os import traceback import logging -from folder_paths import models_dir, folder_names_and_paths, get_folder_paths +from folder_paths import folder_names_and_paths, get_folder_paths import re from typing import Callable, Any, Optional, Awaitable, Dict from enum import Enum diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index a8acf3c89d24..0e6177c22df4 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -1,4 +1,5 @@ import pytest +import tempfile import aiohttp from aiohttp import ClientResponse import itertools @@ -6,6 +7,11 @@ from unittest.mock import AsyncMock, patch, MagicMock from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + class AsyncIteratorMock: """ A mock class that simulates an asynchronous iterator. @@ -42,7 +48,7 @@ def iter_chunked(self, chunk_size): return AsyncIteratorMock(self.chunks) @pytest.mark.asyncio -async def test_download_model_success(): +async def test_download_model_success(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.status = 200 mock_response.headers = {'Content-Length': '1000'} @@ -62,7 +68,7 @@ async def test_download_model_success(): with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ patch('builtins.open', mock_open), \ - patch('folder_paths.get_folder_paths', return_value=['mock_directory']), \ + patch('folder_paths.get_folder_paths', return_value=[temp_dir]), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( @@ -70,7 +76,7 @@ async def test_download_model_success(): 'model.sft', 'http://example.com/model.sft', 'checkpoints', - 'mock_directory', + temp_dir, mock_progress_callback ) From 3ac444c35f5e69f71b853981f78fa443e690f68f Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 22 Sep 2024 16:38:38 +0900 Subject: [PATCH 10/14] workaround the first tests blocker --- tests-unit/prompt_server_test/download_models_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 0e6177c22df4..c495f344f3ec 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -3,9 +3,10 @@ import aiohttp from aiohttp import ClientResponse import itertools -import os +import os from unittest.mock import AsyncMock, patch, MagicMock from model_filemanager import download_model, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename +import folder_paths @pytest.fixture def temp_dir(): @@ -65,9 +66,12 @@ async def test_download_model_success(temp_dir): mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) + fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} + with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ patch('builtins.open', mock_open), \ + patch('folder_paths.folder_names_and_paths', fake_paths), \ patch('folder_paths.get_folder_paths', return_value=[temp_dir]), \ patch('time.time', side_effect=time_values): # Simulate time passing @@ -116,11 +120,14 @@ async def test_download_model_url_request_failure(): mock_response.status = 404 # Simulate a "Not Found" error mock_get = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() + + fake_paths = {'checkpoints': (['mock_directory'], folder_paths.supported_pt_extensions)} # Mock the create_model_path function with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ patch('model_filemanager.check_file_exists', return_value=None), \ - patch('folder_paths.get_folder_paths', return_value=['mock_directory']): + patch('folder_paths.get_folder_paths', return_value=['mock_directory']), \ + patch('folder_paths.folder_names_and_paths', fake_paths): # Call the function result = await download_model( mock_get, From 6fbceea4a6213e2aa05849569b13c279c04245c5 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 22 Sep 2024 16:56:52 +0900 Subject: [PATCH 11/14] fix file handling in tests --- model_filemanager/download_models.py | 8 +- .../download_models_test.py | 78 ++++++++++--------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 5b0642e3666e..ae3032ecb2f8 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -158,10 +158,8 @@ async def update_progress(): await progress_callback(model_name, status) last_update_time = time.time() - if os.path.exists(file_path + '.tmp'): - os.remove(file_path + '.tmp') - - with open(file_path + '.tmp', 'wb') as f: + temp_file_path = file_path + '.tmp' + with open(temp_file_path, 'wb') as f: chunk_iterator = response.content.iter_chunked(8192) while True: try: @@ -174,7 +172,7 @@ async def update_progress(): if time.time() - last_update_time >= interval: await update_progress() - os.rename(file_path + '.tmp', file_path) + os.rename(temp_file_path, file_path) await update_progress() diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index c495f344f3ec..8f633f8cf095 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -60,19 +60,13 @@ async def test_download_model_success(temp_dir): mock_make_request = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() - # Mock file operations - mock_open = MagicMock() - mock_file = MagicMock() - mock_open.return_value.__enter__.return_value = mock_file time_values = itertools.count(0, 0.1) fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ patch('model_filemanager.check_file_exists', return_value=None), \ - patch('builtins.open', mock_open), \ patch('folder_paths.folder_names_and_paths', fake_paths), \ - patch('folder_paths.get_folder_paths', return_value=[temp_dir]), \ patch('time.time', side_effect=time_values): # Simulate time passing result = await download_model( @@ -105,10 +99,11 @@ async def test_download_model_success(temp_dir): DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False) ) - # Verify file writing - mock_file.write.assert_any_call(b'a' * 500) - mock_file.write.assert_any_call(b'b' * 300) - mock_file.write.assert_any_call(b'c' * 200) + mock_file_path = os.path.join(temp_dir, 'model.sft') + assert os.path.exists(mock_file_path) + with open(mock_file_path, 'rb') as mock_file: + assert mock_file.read() == b''.join(chunks) + os.remove(mock_file_path) # Verify request was made mock_make_request.assert_called_once_with('http://example.com/model.sft') @@ -192,15 +187,14 @@ async def test_download_model_invalid_folder_path(): mock_make_request = AsyncMock() mock_progress_callback = AsyncMock() - with patch('folder_paths.get_folder_paths', return_value=['valid_path']): - result = await download_model( - mock_make_request, - 'model.sft', - 'http://example.com/model.sft', - 'checkpoints', - 'invalid_path', - mock_progress_callback - ) + result = await download_model( + mock_make_request, + 'model.sft', + 'http://example.com/model.sft', + 'checkpoints', + 'invalid_path', + mock_progress_callback + ) # Assert the result assert isinstance(result, DownloadModelStatus) @@ -253,21 +247,28 @@ async def test_check_file_exists_when_file_does_not_exist(tmp_path): mock_callback.assert_not_called() @pytest.mark.asyncio -async def test_track_download_progress_no_content_length(): +async def test_track_download_progress_no_content_length(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {} # No Content-Length header - mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500]) + chunks = [b'a' * 500, b'b' * 500] + mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) mock_callback = AsyncMock() - mock_open = MagicMock(return_value=MagicMock()) - with patch('builtins.open', mock_open): - result = await track_download_progress( - mock_response, '/mock/path/model.sft', 'model.sft', - mock_callback, interval=0.1 - ) + full_path = os.path.join(temp_dir, 'model.sft') + + result = await track_download_progress( + mock_response, full_path, 'model.sft', + mock_callback, interval=0.1 + ) assert result.status == "completed" + + assert os.path.exists(full_path) + with open(full_path, 'rb') as f: + assert f.read() == b''.join(chunks) + os.remove(full_path) + # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( 'model.sft', @@ -275,10 +276,11 @@ async def test_track_download_progress_no_content_length(): ) @pytest.mark.asyncio -async def test_track_download_progress_interval(): +async def test_track_download_progress_interval(temp_dir): mock_response = AsyncMock(spec=aiohttp.ClientResponse) mock_response.headers = {'Content-Length': '1000'} - mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10) + chunks = [b'a' * 100] * 10 + mock_response.content.iter_chunked.return_value = AsyncIteratorMock(chunks) mock_callback = AsyncMock() mock_open = MagicMock(return_value=MagicMock()) @@ -287,18 +289,18 @@ async def test_track_download_progress_interval(): mock_time = MagicMock() mock_time.side_effect = [i * 0.5 for i in range(30)] # This should be enough for 10 chunks - with patch('builtins.open', mock_open), \ - patch('time.time', mock_time): + full_path = os.path.join(temp_dir, 'model.sft') + + with patch('time.time', mock_time): await track_download_progress( - mock_response, '/mock/path/model.sft', 'model.sft', + mock_response, full_path, 'model.sft', mock_callback, interval=1.0 ) - - # Print out the actual call count and the arguments of each call for debugging - print(f"mock_callback was called {mock_callback.call_count} times") - for i, call in enumerate(mock_callback.call_args_list): - args, kwargs = call - print(f"Call {i + 1}: {args[1].status}, Progress: {args[1].progress_percentage:.2f}%") + + assert os.path.exists(full_path) + with open(full_path, 'rb') as f: + assert f.read() == b''.join(chunks) + os.remove(full_path) # Assert that progress was updated at least 3 times (start, at least one interval, and end) assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}" From 0df1e732207949c1e6bbf6f146059fc1d100fa40 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 22 Sep 2024 17:00:49 +0900 Subject: [PATCH 12/14] rewrite test for create_model_path --- .../prompt_server_test/download_models_test.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 8f633f8cf095..bd41745fa677 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -202,19 +202,21 @@ async def test_download_model_invalid_folder_path(): assert result.status == 'error' assert result.already_existed is False -# For create_model_path function def test_create_model_path(tmp_path, monkeypatch): - mock_models_dir = tmp_path / "models" - monkeypatch.setattr('folder_paths.models_dir', str(mock_models_dir)) + model_name = "model.safetensors" + folder_path = os.path.join(tmp_path, "mock_dir") - model_name = "test_model.sft" - model_directory = "test_dir" + file_path = create_model_path(model_name, folder_path) - file_path = create_model_path(model_name, model_directory, mock_models_dir) - - assert file_path == str(mock_models_dir / model_directory / model_name) + assert file_path == os.path.join(folder_path, "model.safetensors") assert os.path.exists(os.path.dirname(file_path)) + with pytest.raises(Exception, match="Invalid model directory"): + create_model_path("../path_traversal.safetensors", folder_path) + + with pytest.raises(Exception, match="Invalid model directory"): + create_model_path("/etc/some_root_path", folder_path) + @pytest.mark.asyncio async def test_check_file_exists_when_file_exists(tmp_path): From ae931da0978d4cba102acb1f80695a99ffb8cb03 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Sun, 22 Sep 2024 17:07:43 +0900 Subject: [PATCH 13/14] minor doc fix --- model_filemanager/download_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index ae3032ecb2f8..5ffec395e2d3 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -63,7 +63,7 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht progress_callback (Callable[[str, DownloadModelStatus], Awaitable[Any]]): An asynchronous function to call with progress updates. folder_path (str); - Optional path to which model folder should be used as the root, or None to use default. + Path to which model folder should be used as the root. Returns: DownloadModelStatus: The result of the download operation. From 89b3539dd6f96086f635716fb47f6a5fa3805010 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Mon, 23 Sep 2024 19:00:18 +0900 Subject: [PATCH 14/14] avoid 'mock_directory' use temp dir to avoid accidental fs pollution from tests --- tests-unit/prompt_server_test/download_models_test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index bd41745fa677..128dfeb9a11e 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -109,19 +109,18 @@ async def test_download_model_success(temp_dir): mock_make_request.assert_called_once_with('http://example.com/model.sft') @pytest.mark.asyncio -async def test_download_model_url_request_failure(): +async def test_download_model_url_request_failure(temp_dir): # Mock dependencies mock_response = AsyncMock(spec=ClientResponse) mock_response.status = 404 # Simulate a "Not Found" error mock_get = AsyncMock(return_value=mock_response) mock_progress_callback = AsyncMock() - fake_paths = {'checkpoints': (['mock_directory'], folder_paths.supported_pt_extensions)} + fake_paths = {'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)} # Mock the create_model_path function with patch('model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ patch('model_filemanager.check_file_exists', return_value=None), \ - patch('folder_paths.get_folder_paths', return_value=['mock_directory']), \ patch('folder_paths.folder_names_and_paths', fake_paths): # Call the function result = await download_model( @@ -129,7 +128,7 @@ async def test_download_model_url_request_failure(): 'model.safetensors', 'http://example.com/model.safetensors', 'checkpoints', - 'mock_directory', + temp_dir, mock_progress_callback )