diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py index 286adf0fce42c..6602d168309f0 100644 --- a/onnxruntime/python/tools/transformers/fusion_transpose.py +++ b/onnxruntime/python/tools/transformers/fusion_transpose.py @@ -55,6 +55,10 @@ def fuse( cast_children = self.model.get_children(cast_node, input_name_to_nodes) if cast_children and len(cast_children) > 1: return + + if cast_node.input[0] not in output_name_to_node: + return + transpose_a = output_name_to_node[cast_node.input[0]] if transpose_a.op_type != "Transpose": diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 503cead70a12d..19d205fb356bb 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -82,7 +82,7 @@ pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/w pip install -r requirements-cuda.txt ``` -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. See https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html for compatible versions. +ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.7 and cuDNN 8.5 are used in our tests. #### Install Nightly (Optional) @@ -179,7 +179,10 @@ Before running benchmark on PyTorch, you need to be logged in via `huggingface-c Example to benchmark the optimized pipeline of stable diffusion 1.5 with batch size 1 on CUDA EP: ``` python benchmark.py -p ./sd_v1_5/fp16 -b 1 -v 1.5 +python benchmark.py -b 1 -v 1.5 ``` +For the first command, '-p' specifies a directory of optimized ONNX pipeline as generated by optimize_pipeline.py. +For the second command without '-p', we will use OnnxruntimeCudaStableDiffusionPipeline to export and optimize ONNX models for clip, unet and vae decoder. On ROCm EP, use the following command instead: ``` @@ -220,6 +223,21 @@ Sometime, it complains ptxas not found when there are multiple CUDA versions ins Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows. + +### Run Benchmark with TensorRT and TensorRT execution provider + +For TensorRT installation, follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. + +``` +pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install -r requirements-tensorrt.txt +export CUDA_MODULE_LOADING=LAZY +python benchmark.py -e tensorrt -b 1 -v 1.5 +python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 +python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 --enable_cuda_graph +``` + ### Example Benchmark output Common settings for below test results: @@ -232,13 +250,13 @@ Common settings for below test results: | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | | ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 1 | 4.8 | 4,117 | 4,625 | +| onnxruntime | 1.14.1 | CUDA | 1 | 4.8 | 4,117 | 4,625 | | torch | 2.0.0+cu117 | default | 1 | 5.6 | 4,325 | 4,047 | | torch | 1.13.1+cu117 | xformers | 1 | 6.0 | 9,124 | 9,130 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 4 | 17.7 | 6,659 | 6,659 | +| onnxruntime | 1.14.1 | CUDA | 4 | 17.7 | 6,659 | 6,659 | | torch | 2.0.0+cu117 | default | 4 | 20.1 | 6,421 | 6,907 | | torch | 1.13.1+cu117 | xformers | 4 | 21.6 | 10,407 | 10,409 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 8 | 33.5 | 6,663 | 6,663 | +| onnxruntime | 1.14.1 | CUDA | 8 | 33.5 | 6,663 | 6,663 | | torch | 2.0.0+cu117 | default | 8 | 39.5 | 10,767 | 10,813 | | torch | 1.13.1+cu117 | xformers | 8 | 41.1 | 10,825 | 9,255 | @@ -246,16 +264,16 @@ Common settings for below test results: #### Results of A100-SXM4-40GB (Ubuntu 20.04) | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | | ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 1 | 1.1 | 6,883 | 7,395 | +| onnxruntime | 1.14.1 | CUDA | 1 | 1.1 | 6,883 | 7,395 | | torch | 2.0.0+cu117 | default | 1 | 1.5 | 13,828 | 4,400 | | torch | 2.0.0+cu117 | compile | 1 | 1.8 | 13,892 | 4,386 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 4 | 3.7 | 7,381 | 7,381 | +| onnxruntime | 1.14.1 | CUDA | 4 | 3.7 | 7,381 | 7,381 | | torch | 2.0.0+cu117 | default | 4 | 3.9 | 31,278 | 6,870 | | torch | 2.0.0+cu117 | compile | 4 | 3.4 | 31,364 | 6,880 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 8 | 6.9 | 7,411 | 7,411 | +| onnxruntime | 1.14.1 | CUDA | 8 | 6.9 | 7,411 | 7,411 | | torch | 2.0.0+cu117 | default | 8 | 7.6 | 31,660 | 10,122 | | torch | 2.0.0+cu117 | compile | 8 | 6.5 | 31,800 | 10,308 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 16 | 13.6 | 11,479 | 11,479 | +| onnxruntime | 1.14.1 | CUDA | 16 | 13.6 | 11,479 | 11,479 | | torch | 2.0.0+cu117 | default | 16 | 14.8 | 32,306 | 16,520 | | torch | 2.0.0+cu117 | compile | 16 | 12.6 | 32,636 | 16,898 | @@ -265,15 +283,15 @@ Results from Standard_NC6s_v3 Azure virtual machine: | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | | ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 1 | 2.7 | 12,646 | 7,152 | +| onnxruntime | 1.14.1 | CUDA | 1 | 2.7 | 12,646 | 7,152 | | torch | 2.0.0+cu117 | compile | 1 | 3.2 | 13,317 | 3,909 | | torch | 2.0.0+cu117 | default | 1 | 2.7 | 13,343 | 3,921 | | torch | 1.13.1+cu117 | xformers | 1 | 3.5 | 14,979 | 10,449 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 4 | 8.4 | 7,114 | 7,114 | +| onnxruntime | 1.14.1 | CUDA | 4 | 8.4 | 7,114 | 7,114 | | torch | 2.0.0+cu117 | compile | 4 | 8.0 | 13,897 | 6,821 | | torch | 2.0.0+cu117 | default | 4 | 8.7 | 13,873 | 6,607 | | torch | 1.13.1+cu117 | xformers | 4 | 9.1 | 12,969 | 8,421 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 8 | 15.9 | 7,120 | 7,120 | +| onnxruntime | 1.14.1 | CUDA | 8 | 15.9 | 7,120 | 7,120 | | torch | 2.0.0+cu117 | compile | 8 | 15.5 | 14,669 | 10,355 | | torch | 2.0.0+cu117 | default | 8 | 17.0 | 14,469 | 9,657 | | torch | 1.13.1+cu117 | xformers | 8 | 17.4 | 15,593 | 9,133 | @@ -287,15 +305,22 @@ Results are from Standard_NC4as_T4_v3 Azure virtual machine: | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | | ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 1 | 5.6 | 4,925 | 4,925 | +| onnxruntime | 1.14.1 | CUDA | 1 | 5.6 | 4,925 | 4,925 | +| onnxruntime | 1.15.1 | CUDA | 1 | 5.5 | 3,738 | 4,250 | +| onnxruntime | 1.15.1 (tensorrt 8.6.1) | Tensorrt | 1 | 4.8 | 10,710 | 10,710 | +| onnxruntime | 1.16.0 nightly | Tensorrt (cuda graph) | 1 | 4.7 | 11,746 | 10,746 | +| tensorrt | 8.6.1 | default | 1 | 5.0 | 8,530 | 8,530 | | torch | 1.13.1+cu117 | xformers | 1 | 6.9 | 14,845 | 10,317 | | torch | 2.0.0+cu117 | compile | 1 | 6.0 | 12,989 | 3,841 | | torch | 2.0.0+cu117 | default | 1 | 6.4 | 12,987 | 3,841 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 4 | 23.0 | 6,977 | 6,977 | +| onnxruntime | 1.14.1 | CUDA | 4 | 23.0 | 6,977 | 6,977 | +| onnxruntime | 1.15.1 | CUDA | 4 | 22.6 | 6,298 | 6,298 | +| onnxruntime | 1.15.1 (tensorrt 8.6.1) | Tensorrt | 4 | 21.8 | 10,746 | 10,746 | +| tensorrt | 8.6.1 | default | 4 | 22.2 | 8,542 | 8,542 | | torch | 1.13.1+cu117 | xformers | 4 | 25.8 | 12,819 | 8,269 | | torch | 2.0.0+cu117 | compile | 4 | 22.2 | 14,637 | 6,583 | | torch | 2.0.0+cu117 | default | 4 | 25.2 | 14,409 | 6,355 | -| onnxruntime | 1.14.1 | CUDAExecutionProvider | 8 | 46.4 | 6,779 | 6,779 | +| onnxruntime | 1.14.1 | CUDA | 8 | 46.4 | 6,779 | 6,779 | | torch | 1.13.1+cu117 | xformers | 8 | 51.4 | 14,827 | 9,001 | | torch | 2.0.0+cu117 | compile | 8 | 46.5 | 12,595 | 10,171 | | torch | 2.0.0+cu117 | default | 8 | 50.7 | 11,955 | 9,531 | @@ -304,15 +329,15 @@ Results are from Standard_NC4as_T4_v3 Azure virtual machine: | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | | ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCMExecutionProvider | 1 | 2.2 | 5,548 | 4,908 | +| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 1 | 2.2 | 5,548 | 4,908 | | torch | 1.12.1+rocm5.4 | - | 1 | 3.4 | 6,653 | 4,613 | | torch | 2.0.0+rocm5.4.2 | default | 1 | 3.2 | 5,977 | 4,368 | | torch | 2.0.0+rocm5.4.2 | compile | 1 | 3.0 | 5,869 | 4,266 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCMExecutionProvider | 4 | 6.6 | 5,546 | 4,906 | +| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 4 | 6.6 | 5,546 | 4,906 | | torch | 1.12.1+rocm5.4 | - | 4 | 10.1 | 19,477 | 11,325 | | torch | 2.0.0+rocm5.4.2 | default | 4 | 10.5 | 13,051 | 7,300 | | torch | 2.0.0+rocm5.4.2 | compile | 4 | 9.2 | 12,879 | 7,190 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCMExecutionProvider | 8 | 12.5 | 9,778 | 9,006 | +| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 8 | 12.5 | 9,778 | 9,006 | | torch | 1.12.1+rocm5.4 | - | 8 | 19.3 | 55,851 | 20,014 | | torch | 2.0.0+rocm5.4.2 | default | 8 | 20.3 | 23,551 | 11,930 | | torch | 2.0.0+rocm5.4.2 | compile | 8 | 17.8 | 23,303 | 11,800 | @@ -321,15 +346,15 @@ Results are from Standard_NC4as_T4_v3 Azure virtual machine: | engine | version | provider | batch size | average latency | first run memory MB | second run memory MB | | ----------- | ----------------------- | --------------------- | ---------- | --------------- | ------------------- | -------------------- | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCMExecutionProvider | 1 | 2.4 | 5,254 | 4,614 | +| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 1 | 2.4 | 5,254 | 4,614 | | torch | 1.12.1+rocm5.4 | - | 1 | 3.5 | 5,771 | 4,672 | | torch | 2.0.0+rocm5.4.2 | default | 1 | 3.5 | 5,811 | 4,206 | | torch | 2.0.0+rocm5.4.2 | compile | 1 | 3.1 | 5,774 | 4,168 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCMExecutionProvider | 4 | 7.5 | 7,290 | 6,646 | +| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 4 | 7.5 | 7,290 | 6,646 | | torch | 1.12.1+rocm5.4 | - | 4 | 10.7 | 19,334 | 11,181 | | torch | 2.0.0+rocm5.4.2 | default | 4 | 11.5 | 12,881 | 7,151 | | torch | 2.0.0+rocm5.4.2 | compile | 4 | 10.0 | 12,740 | 7,073 | -| onnxruntime | 1.15.0+rocm5.4.2 | ROCMExecutionProvider | 8 | 14.4 | 7,320 | 6,676 | +| onnxruntime | 1.15.0+rocm5.4.2 | ROCM | 8 | 14.4 | 7,320 | 6,676 | | torch | 1.12.1+rocm5.4 | - | 8 | 20.2 | 31,820 | 19,908 | | torch | 2.0.0+rocm5.4.2 | default | 8 | 22.2 | 23,415 | 11,815 | | torch | 2.0.0+rocm5.4.2 | compile | 8 | 19.3 | 23,154 | 11,667 | @@ -346,13 +371,9 @@ Some kernels are enabled by MIOpen. We hereby thank for the AMD developers' coll ### Future Works There are other optimizations might improve the performance or reduce memory footprint: - -* Use IO Binding in the pipeline. Currently the input and output of each model is in CPU, and extra data copy between GPU and CPU slows down the pipeline. -* Use CUDA graph to speed up inference. * Export the whole pipeline into a single ONNX model. Currently, there are multiple ONNX models (CLIP, VAE and U-Net etc). Each model uses separated thread pool and memory allocator. Combine them into one model could share thread pool and memory allocator. The end result is more efficient and less memory footprint. -* For Stable Diffusion 2.1, we force Attention in fp32 to avoid black image. That slows down the inference significantly. We could potentially change attention kernel (like fp32 accumulation) to avoid the issue. +* For Stable Diffusion 2.1, we disable TensorRT flash attention kernel and use only memory efficient attention. It is possible to add flash attention using Triton compiler to improve performance. * Reduce GPU memory footprint by actively deleting buffers for intermediate results. -* Reduce GPU memory footprint by providing options for CPU RAM Offloading. * Attention fusion in CLIP * Safety Checker Optimization * Leverage FP8 in latest GPU diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 4e00ded9e33be..13126f648d290 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -10,6 +10,11 @@ import sys import time +import coloredlogs + +# import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. +import torch + SD_MODELS = { "1.5": "runwayml/stable-diffusion-v1-5", "2.0": "stabilityai/stable-diffusion-2", @@ -20,6 +25,7 @@ "cuda": "CUDAExecutionProvider", "rocm": "ROCMExecutionProvider", "migraphx": "MIGraphXExecutionProvider", + "tensorrt": "TensorrtExecutionProvider", } @@ -173,7 +179,7 @@ def measure_gpu_memory(monitor_type, func, start_memory=None): def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool): - from diffusers import DPMSolverMultistepScheduler, OnnxStableDiffusionPipeline + from diffusers import DDIMScheduler, OnnxStableDiffusionPipeline import onnxruntime @@ -192,7 +198,7 @@ def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_c provider=provider, use_auth_token=True, ) - pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.set_progress_bar_config(disable=True) if disable_safety_checker: @@ -203,7 +209,7 @@ def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_c def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_torch_compile: bool, use_xformers: bool): - from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline + from diffusers import DDIMScheduler, StableDiffusionPipeline from torch import channels_last, float16 pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=float16).to("cuda") @@ -214,14 +220,12 @@ def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_tor pipe.enable_xformers_memory_efficient_attention() if enable_torch_compile: - import torch - pipe.unet = torch.compile(pipe.unet) pipe.vae = torch.compile(pipe.vae) pipe.text_encoder = torch.compile(pipe.text_encoder) print("Torch compiled unet, vae and text_encoder") - pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) pipe.set_progress_bar_config(disable=True) if disable_safety_checker: @@ -262,8 +266,7 @@ def warmup(): first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) - if memory_monitor_type is None: - warmup() + warmup() latency_list = [] for i, prompt in enumerate(prompts): @@ -317,8 +320,6 @@ def run_torch_pipeline( start_memory, memory_monitor_type, ): - import torch - prompts = example_prompts() # total 2 runs of warm up, and measure GPU memory for CUDA EP @@ -329,8 +330,7 @@ def warmup(): first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) - if memory_monitor_type is None: - warmup() + warmup() torch.set_grad_enabled(False) @@ -382,14 +382,14 @@ def run_ort( provider: str, batch_size: int, disable_safety_checker: bool, - height, - width, - steps, - num_prompts, - batch_count, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, start_memory, memory_monitor_type, - tuning, + tuning: bool, ): provider_and_options = provider if tuning and provider in ["CUDAExecutionProvider", "ROCMExecutionProvider"]: @@ -418,29 +418,294 @@ def run_ort( { "model_name": model_name, "directory": directory, - "provider": provider, + "provider": provider.replace("ExecutionProvider", ""), "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": False, } ) return result +def export_and_run_ort( + model_name: str, + provider: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + enable_cuda_graph: bool, +): + assert provider == "CUDAExecutionProvider" + + from diffusers import DDIMScheduler + from onnxruntime_cuda_txt2img import OnnxruntimeCudaStableDiffusionPipeline + + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + + pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( + model_name, + scheduler=scheduler, + requires_safety_checker=not disable_safety_checker, + enable_cuda_graph=enable_cuda_graph, + ) + + # re-use cached folder to save ONNX models + pipe.set_cached_folder(model_name) + + pipe = pipe.to("cuda", torch_dtype=torch.float16) + + def warmup(): + pipe(["warm up"] * batch_size, image_height=height, image_width=width, num_inference_steps=steps) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + # An extra warm up run is needed for cuda graph + warmup() + + image_filename_prefix = get_image_filename_prefix("ort_cuda", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + images = pipe( + [prompt] * batch_size, + num_inference_steps=steps, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + from onnxruntime import __version__ as ort_version + + return { + "model_name": model_name, + "engine": "onnxruntime", + "version": ort_version, + "provider": provider.replace("ExecutionProvider", ""), + "directory": pipe.engine_dir, + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": enable_cuda_graph, + } + + +def run_ort_trt( + model_name: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + enable_cuda_graph: bool, +): + from diffusers import DDIMScheduler + from onnxruntime_tensorrt_txt2img import OnnxruntimeTensorRTStableDiffusionPipeline + + assert batch_size <= max_batch_size + + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( + model_name, + revision="fp16", + torch_dtype=torch.float16, + scheduler=scheduler, + requires_safety_checker=not disable_safety_checker, + image_height=height, + image_width=width, + max_batch_size=max_batch_size, + onnx_opset=17, + enable_cuda_graph=enable_cuda_graph, + ) + + # re-use cached folder to save ONNX models and TensorRT Engines + pipe.set_cached_folder(model_name, revision="fp16") + + pipe = pipe.to("cuda") + + def warmup(): + pipe(["warm up"] * batch_size, num_inference_steps=steps) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + images = pipe( + [prompt] * batch_size, + num_inference_steps=steps, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": model_name, + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt{trt_version})", + "directory": pipe.engine_dir, + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": enable_cuda_graph, + } + + +def run_tensorrt( + model_name: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, +): + from diffusers import DDIMScheduler + from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline + + assert batch_size <= max_batch_size + + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipe = StableDiffusionPipeline.from_pretrained( + model_name, + custom_pipeline="stable_diffusion_tensorrt_txt2img", + revision="fp16", + torch_dtype=torch.float16, + scheduler=scheduler, + requires_safety_checker=not disable_safety_checker, + image_height=height, + image_width=width, + max_batch_size=max_batch_size, + ) + + # re-use cached folder to save ONNX models and TensorRT Engines + pipe.set_cached_folder(model_name, revision="fp16") + + pipe = pipe.to("cuda") + + def warmup(): + pipe(["warm up"] * batch_size, num_inference_steps=steps) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + images = pipe( + [prompt] * batch_size, + num_inference_steps=steps, + ).images + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"Inference took {latency:.3f} seconds") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + from tensorrt import __version__ as trt_version + + return { + "engine": "tensorrt", + "version": trt_version, + "provider": "default", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": False, + } + + def run_torch( model_name: str, batch_size: int, disable_safety_checker: bool, enable_torch_compile: bool, use_xformers: bool, - height, - width, - steps, - num_prompts, - batch_count, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, start_memory, memory_monitor_type, ): - import torch - torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True @@ -487,6 +752,7 @@ def run_torch( "directory": None, "provider": "compile" if enable_torch_compile else "xformers" if use_xformers else "default", "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": False, } ) return result @@ -501,7 +767,7 @@ def parse_arguments(): required=False, type=str, default="onnxruntime", - choices=["onnxruntime", "torch"], + choices=["onnxruntime", "torch", "tensorrt"], help="Engines to benchmark. Default is onnxruntime.", ) @@ -539,7 +805,7 @@ def parse_arguments(): required=False, type=str, default=None, - help="Directory of saved onnx pipeline. It could be output directory of optimize_pipeline.py.", + help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.", ) parser.add_argument( @@ -619,16 +885,62 @@ def parse_arguments(): help="Number of batches to test. Default is 5.", ) + parser.add_argument( + "-m", + "--max_trt_batch_size", + required=False, + type=int, + choices=range(1, 16), + default=4, + help="Maximum batch size for TensorRT. Change the value may trigger TensorRT engine rebuild. Default is 4.", + ) + + parser.add_argument( + "-g", + "--enable_cuda_graph", + required=False, + action="store_true", + help="Enable Cuda Graph. Requires onnxruntime >= 1.16", + ) + parser.set_defaults(enable_cuda_graph=False) + args = parser.parse_args() + return args +def print_loaded_libraries(cuda_related_only=True): + import psutil + + p = psutil.Process(os.getpid()) + for lib in p.memory_maps(): + if (not cuda_related_only) or any(x in lib.path for x in ("libcu", "libnv", "tensorrt")): + print(lib.path) + + def main(): args = parse_arguments() print(args) + if args.enable_cuda_graph: + if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): + raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + + from packaging import version + + from onnxruntime import __version__ as ort_version + + if version.parse(ort_version) < version.parse("1.16"): + raise ValueError( + "CUDA graph requires ONNX Runtime 1.16. You can install nightly like the following:\n" + " pip uninstall onnxruntime-gpu\n" + " pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/" + ) + + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + memory_monitor_type = None - if args.provider == "cuda": + if args.provider in ["cuda", "tensorrt"]: memory_monitor_type = CudaMemoryMonitor elif args.provider == "rocm": memory_monitor_type = RocmMemoryMonitor @@ -638,8 +950,41 @@ def main(): sd_model = SD_MODELS[args.version] provider = PROVIDERS[args.provider] - if args.engine == "onnxruntime": - assert args.pipeline, "--pipeline should be specified for onnxruntime engine" + if args.engine == "onnxruntime" and args.provider == "tensorrt": + result = run_ort_trt( + sd_model, + args.batch_size, + not args.enable_safety_checker, + args.height, + args.width, + args.steps, + args.num_prompts, + args.batch_count, + start_memory, + memory_monitor_type, + args.max_trt_batch_size, + args.enable_cuda_graph, + ) + elif args.engine == "onnxruntime" and provider == "CUDAExecutionProvider" and args.pipeline is None: + print("Pipeline is not specified. Trying export and optimize onnx models...") + result = export_and_run_ort( + sd_model, + provider, + args.batch_size, + not args.enable_safety_checker, + args.height, + args.width, + args.steps, + args.num_prompts, + args.batch_count, + start_memory, + memory_monitor_type, + args.enable_cuda_graph, + ) + elif args.engine == "onnxruntime": + assert args.pipeline and os.path.isdir( + args.pipeline + ), "--pipeline should be specified for the directory of ONNX models" if args.version in ["2.1"]: # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model @@ -661,6 +1006,20 @@ def main(): memory_monitor_type, args.tuning, ) + elif args.engine == "tensorrt": + result = run_tensorrt( + sd_model, + args.batch_size, + not args.enable_safety_checker, + args.height, + args.width, + args.steps, + args.num_prompts, + args.batch_count, + start_memory, + memory_monitor_type, + args.max_trt_batch_size, + ) else: result = run_torch( sd_model, @@ -697,11 +1056,16 @@ def main(): "median_latency", "first_run_memory_MB", "second_run_memory_MB", + "enable_cuda_graph", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) csv_writer.writeheader() csv_writer.writerow(result) + # Show loaded DLLs when steps == 1 for debugging purpose. + if args.steps == 1: + print_loaded_libraries(args.provider in ["cuda", "tensorrt"]) + if __name__ == "__main__": try: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py new file mode 100644 index 0000000000000..cecfc976b351e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py @@ -0,0 +1,758 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Stable diffusion text to image pipeline using ONNX Runtime CUDA execution provider. +Based on https://github.com/huggingface/diffusers/blob/v0.17.1/examples/community/stable_diffusion_tensorrt_txt2img.py +Modifications: (1) Create ONNX Runtime session (2) Use I/O Binding of ONNX Runtime for inference + +Installation instructions +pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install --upgrade transformers diffusers>=0.16.0 +pip install --upgrade tensorrt>=8.6.1 +pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install onnxruntime-gpu +""" + +import gc +import os +import shutil +from typing import List, Optional, Union + +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import DIFFUSERS_CACHE, logging +from huggingface_hub import snapshot_download +from ort_utils import OrtCudaSession +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +import onnxruntime as ort +from onnxruntime.transformers.fusion_options import FusionOptions +from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel +from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel +from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel +from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Engine(OrtCudaSession): + def __init__(self, engine_path, provider, device_id: int = 0, enable_cuda_graph=False): + self.engine_path = engine_path + self.provider = provider + self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) + + device = torch.device("cuda", device_id) + ort_session = ort.InferenceSession( + self.engine_path, + providers=[ + (provider, self.provider_options), + "CPUExecutionProvider", + ], + ) + + super().__init__(ort_session, device, enable_cuda_graph) + + def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool): + return { + "device_id": device_id, + "arena_extend_strategy": "kSameAsRequested", + "enable_cuda_graph": enable_cuda_graph, + } + + +class OrtStableDiffusionOptimizer: + def __init__(self, model_type: str): + assert model_type in ["vae", "unet", "clip"] + self.model_type = model_type + self.model_type_class_mapping = { + "unet": UnetOnnxModel, + "vae": VaeOnnxModel, + "clip": ClipOnnxModel, + } + + def optimize_by_ort(self, onnx_model): + import tempfile + from pathlib import Path + + import onnx + + # Use this step to see the final graph that executed by Onnx Runtime. + with tempfile.TemporaryDirectory() as tmp_dir: + # Save to a temporary file so that we can load it with Onnx Runtime. + logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") + tmp_model_path = Path(tmp_dir) / "model.onnx" + onnx_model.save_model_to_file(str(tmp_model_path)) + ort_optimized_model_path = tmp_model_path + optimize_by_onnxruntime( + str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + ) + model = onnx.load(str(ort_optimized_model_path), load_external_data=True) + return self.model_type_class_mapping[self.model_type](model) + + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + """Optimize onnx model using ONNX Runtime transformers optimizer""" + logger.info(f"Optimize {input_fp32_onnx_path}...") + fusion_options = FusionOptions(self.model_type) + if self.model_type in ["unet"] and not float16: + fusion_options.enable_packed_kv = False + fusion_options.enable_packed_qkv = False + + m = optimize_model( + input_fp32_onnx_path, + model_type=self.model_type, + num_heads=0, # will be deduced from graph + hidden_size=0, # will be deduced from graph + opt_level=0, + optimization_options=fusion_options, + use_gpu=True, + ) + + if self.model_type == "clip": + m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + + if float16: + logger.info("Convert to float16 ...") + m.convert_float_to_float16( + keep_io_types=False, + op_block_list=["RandomNormalLike"], + ) + + # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 + if float16 or (self.model_type != "unet"): + m = self.optimize_by_ort(m) + + m.get_operator_statistics() + m.get_fused_operator_statistics() + m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) + + +class BaseModel: + def __init__(self, model, name, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = name + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" + self.optimizer = OrtStableDiffusionOptimizer(self.model_type) + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, fp16): + self.optimizer.optimize(input_fp32_onnx_path, optimized_onnx_path, fp16) + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +def get_onnx_path(model_name, onnx_dir): + return os.path.join(onnx_dir, model_name + ".onnx") + + +def get_engine_path(engine_dir, model_name, profile_id): + return os.path.join(engine_dir, model_name + profile_id + ".onnx") + + +def build_engines( + models, + engine_dir, + onnx_dir, + onnx_opset, + force_engine_rebuild: bool = False, + fp16: bool = True, + provider: str = "CUDAExecutionProvider", + device_id: int = 0, + enable_cuda_graph: bool = False, +): + profile_id = "_fp16" if fp16 else "_fp32" + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + onnx_path = get_onnx_path(model_name, onnx_dir) + onnx_opt_path = get_engine_path(engine_dir, model_name, profile_id) + if os.path.exists(onnx_opt_path): + logger.info("Found cached optimized model: %s", onnx_opt_path) + else: + if os.path.exists(onnx_path): + logger.info("Found cached model: %s", onnx_path) + else: + logger.info("Exporting model: %s", onnx_path) + model = model_obj.get_model().to(model_obj.device) + with torch.inference_mode(): + inputs = model_obj.get_sample_input(1, 512, 512) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + + # Optimize onnx + logger.info("Generating optimized model: %s", onnx_opt_path) + model_obj.optimize(onnx_path, onnx_opt_path, fp16) + + built_engines = {} + for model_name in models: + engine_path = get_engine_path(engine_dir, model_name, profile_id) + engine = Engine(engine_path, provider, device_id=device_id, enable_cuda_graph=enable_cuda_graph) + logger.info("%s options for %s: %s", provider, model_name, engine.provider_options) + built_engines[model_name] = engine + + return built_engines + + +def run_engine(engine, feed_dict): + return engine.infer(feed_dict) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super().__init__( + model=model, + name="CLIP", + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + ) + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + # "pooler_output": (batch_size, self.embedding_dim) + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + +class UNet(BaseModel): + def __init__( + self, + model, + device="cuda", + max_batch_size=16, + embedding_dim=768, + text_maxlen=77, + unet_dim=4, + ): + super().__init__( + model=model, + name="UNet", + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=torch.float32, device=self.device), + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super().__init__( + model=model, + name="VAE Decoder", + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + ) + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + + +class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline): + r""" + Pipeline for text-to-image generation using CUDA provider in ONNX Runtime. + This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + # ONNX export parameters + onnx_opset: int = 14, + onnx_dir: str = "raw_onnx", + # Onnxruntime execution provider parameters + engine_dir: str = "onnxruntime_optimized_onnx", + force_engine_rebuild: bool = False, + enable_cuda_graph: bool = False, + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + + self.vae.forward = self.vae.decode + self.unet_in_channels = unet.config.in_channels + + self.inpaint = False + self.onnx_opset = onnx_opset + self.onnx_dir = onnx_dir + self.engine_dir = engine_dir + self.force_engine_rebuild = force_engine_rebuild + self.enable_cuda_graph = enable_cuda_graph + + self.max_batch_size = 16 + + self.models = {} # loaded in __load_models() + self.engines = {} # loaded in build_engines() + + self.provider = "CUDAExecutionProvider" + self.fp16 = False + + def __load_models(self): + self.embedding_dim = self.text_encoder.config.hidden_size + + self.models["clip"] = CLIP( + self.text_encoder, + device=self.torch_device, + max_batch_size=self.max_batch_size, + embedding_dim=self.embedding_dim, + ) + + self.models["unet"] = UNet( + self.unet, + device=self.torch_device, + max_batch_size=self.max_batch_size, + embedding_dim=self.embedding_dim, + unet_dim=(9 if self.inpaint else 4), + ) + + self.models["vae"] = VAE( + self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim + ) + + @classmethod + def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + cls.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) + + def to( + self, + torch_device: Union[str, torch.device], + torch_dtype: Optional[torch.dtype] = None, + silence_dtype_warnings: bool = False, + ): + self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) + self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) + + # set device + self.torch_device = torch.device(torch_device) + + # load models + self.__load_models() + + # build engines + self.fp16 = torch_dtype == torch.float16 + self.engines = build_engines( + self.models, + self.engine_dir, + self.onnx_dir, + self.onnx_opset, + force_engine_rebuild=self.force_engine_rebuild, + fp16=self.fp16, + provider=self.provider, + device_id=self.torch_device.index or torch.cuda.current_device(), + enable_cuda_graph=self.enable_cuda_graph, + ) + + # Load the remaining modules to GPU. + self.text_encoder = None + self.vae = None + self.unet = None + super().to(torch_device, torch_dtype, silence_dtype_warnings=silence_dtype_warnings) + + self.torch_device = self._execution_device + logger.info(f"Running inference on device: {self.torch_device}") + + return self + + def __encode_prompt(self, prompt, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + """ + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + + for _step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32) + + # Predict the noise residual + noise_pred = run_engine( + self.engines["unet"], + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def __decode_latent(self, latents): + images = run_engine(self.engines["vae"], {"latent": latents})["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def __allocate_buffers(self, image_height, image_width, batch_size): + # Allocate output tensors for I/O bindings + for model_name, obj in self.models.items(): + self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + image_height: int = 512, + image_width: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + """ + self.generator = generator + self.denoising_steps = num_inference_steps + self.guidance_scale = guidance_scale + + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") + + if negative_prompt is None: + negative_prompt = [""] * batch_size + + if negative_prompt is not None and isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + assert len(prompt) == len(negative_prompt) + + if batch_size > self.max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + self.__allocate_buffers(image_height, image_width, batch_size) + + with torch.inference_mode(), torch.autocast("cuda"): + # CLIP text encoder + text_embeddings = self.__encode_prompt(prompt, negative_prompt) + + # Pre-initialize latents + num_channels_latents = self.unet_in_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + image_height, + image_width, + torch.float16 if self.fp16 else torch.float32, + self.torch_device, + generator, + ) + + # UNet denoiser + latents = self.__denoise_latent(latents, text_embeddings) + + # VAE decode latent + images = self.__decode_latent(latents) + + images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) + images = self.numpy_to_pil(images) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +if __name__ == "__main__": + import torch + from diffusers import DDIMScheduler + + model_name_or_path = "runwayml/stable-diffusion-v1-5" + scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") + + pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( + model_name_or_path, + scheduler=scheduler, + ) + + # re-use cached folder to save ONNX models + pipe.set_cached_folder(model_name_or_path) + + pipe = pipe.to("cuda", torch_dtype=torch.float16) + + prompt = "photorealistic new zealand hills" + image = pipe(prompt).images[0] + image.save("ort_trt_txt2img_new_zealand_hills.png") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py new file mode 100644 index 0000000000000..7c29fd4af16bd --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py @@ -0,0 +1,912 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Stable diffusion text to image pipeline using ONNX Runtime TensorRT execution provider. +Based on https://github.com/huggingface/diffusers/blob/v0.17.1/examples/community/stable_diffusion_tensorrt_txt2img.py +Modifications: (1) Create ONNX Runtime session (2) Use I/O Binding of ONNX Runtime for inference + +Installation instructions +pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install --upgrade transformers diffusers>=0.16.0 +pip install --upgrade tensorrt>=8.6.1 +pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install onnxruntime-gpu +""" + +import gc +import os +import shutil +from typing import List, Optional, Union + +import onnx +import onnx_graphsurgeon as gs +import torch +from cuda import cudart +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import DIFFUSERS_CACHE, logging +from huggingface_hub import snapshot_download +from onnx import shape_inference +from ort_utils import OrtCudaSession +from polygraphy.backend.onnx.loader import fold_constants +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +import onnxruntime as ort + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Engine(OrtCudaSession): + def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + self.engine_path = engine_path + self.ort_trt_provider_options = self.get_tensorrt_provider_options( + input_profile, + workspace_size, + fp16, + device_id, + enable_cuda_graph, + ) + + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + ort_session = ort.InferenceSession( + onnx_path, + sess_options, + providers=[ + ("TensorrtExecutionProvider", self.ort_trt_provider_options), + ], + ) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + trt_ep_options = { + "device_id": device_id, + "trt_fp16_enable": fp16, + "trt_engine_cache_enable": True, + "trt_timing_cache_enable": True, + "trt_detailed_build_log": True, + "trt_engine_cache_path": self.engine_path, + } + + if enable_cuda_graph: + trt_ep_options["trt_cuda_graph_enable"] = True + + if workspace_size > 0: + trt_ep_options["trt_max_workspace_size"] = workspace_size + + if input_profile: + min_shapes = [] + max_shapes = [] + opt_shapes = [] + for name, profile in input_profile.items(): + assert isinstance(profile, list) and len(profile) == 3 + min_shape = profile[0] + opt_shape = profile[1] + max_shape = profile[2] + assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) + + min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) + opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) + max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) + + trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) + trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) + trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) + + logger.info("trt_ep_options=%s", trt_ep_options) + + return trt_ep_options + + +class Optimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self): + self.graph.cleanup().toposort() + + def get_optimized_onnx_graph(self): + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + + def infer_shapes(self): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + + +class BaseModel: + def __init__(self, model, name, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = name + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + + if self.name != "CLIP": + if static_image_shape: + profile_id += f"_h_{image_height}_w_{image_width}" + else: + profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" + + return profile_id + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.cleanup() + return opt.get_optimized_onnx_graph() + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +def get_onnx_path(model_name, onnx_dir, opt=True): + return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") + + +def get_engine_path(engine_dir, model_name, profile_id): + return os.path.join(engine_dir, model_name + profile_id) + + +def has_engine_file(engine_path): + if os.path.isdir(engine_path): + children = os.scandir(engine_path) + for entry in children: + if entry.is_file() and entry.name.endswith(".engine"): + return True + return False + + +def get_work_space_size(model_name, max_workspace_size): + gibibyte = 2**30 + workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size + if workspace_size == 0: + _, free_mem, _ = cudart.cudaMemGetInfo() + # The following logic are adopted from TensorRT demo diffusion. + if free_mem > 6 * gibibyte: + workspace_size = free_mem - 4 * gibibyte + return workspace_size + + +def build_engines( + models, + engine_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_image_shape=True, + max_workspace_size=0, + device_id=0, + enable_cuda_graph=False, +): + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + engine_path = get_engine_path(engine_dir, model_name, profile_id) + if not has_engine_file(engine_path): + onnx_path = get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = get_onnx_path(model_name, onnx_dir) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info(f"Exporting model: {onnx_path}") + model = model_obj.get_model() + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Optimize onnx + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimizing model: %s", onnx_opt_path) + onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name, model_obj in models.items(): + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + + engine_path = get_engine_path(engine_dir, model_name, profile_id) + onnx_opt_path = get_onnx_path(model_name, onnx_dir) + + if not has_engine_file(engine_path): + logger.info( + "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", + model_name, + onnx_opt_path, + engine_path, + ) + else: + logger.info("Reuse cached TensorRT engine in directory %s", engine_path) + + input_profile = model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_image_shape=static_image_shape, + ) + + engine = Engine( + engine_path, + device_id, + onnx_opt_path, + fp16=True, + input_profile=input_profile, + workspace_size=get_work_space_size(model_name, max_workspace_size), + enable_cuda_graph=enable_cuda_graph, + ) + + built_engines[model_name] = engine + + return built_engines + + +def run_engine(engine, feed_dict): + return engine.infer(feed_dict) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super().__init__( + model=model, name="CLIP", device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_image_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.cleanup() + return opt.get_optimized_onnx_graph() + + +class UNet(BaseModel): + def __init__( + self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4 + ): + super().__init__( + model=model, + name="UNet", + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super().__init__( + model=model, name="VAE decoder", device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + + +class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline): + r""" + Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. + + This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + image_height: int = 768, + image_width: int = 768, + max_batch_size: int = 16, + # ONNX export parameters + onnx_opset: int = 17, + onnx_dir: str = "onnx", + # TensorRT engine build parameters + engine_dir: str = "onnxruntime_tensorrt_engine", + force_engine_rebuild: bool = False, + enable_cuda_graph: bool = False, + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + + self.vae.forward = self.vae.decode + + self.image_height = image_height + self.image_width = image_width + self.inpaint = False + self.onnx_opset = onnx_opset + self.onnx_dir = onnx_dir + self.engine_dir = engine_dir + self.force_engine_rebuild = force_engine_rebuild + self.enable_cuda_graph = enable_cuda_graph + + # Although cuda graph requires static input shape, engine built with dyamic batch gets better performance in T4. + # Use static batch could reduce GPU memory footprint. + self.build_static_batch = False + + # TODO: support dynamic image shape. + self.build_dynamic_shape = False + + self.max_batch_size = max_batch_size + # Restrict batch size to 4 for larger image dimensions as a walkaround for TensorRT limitation. + if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: + self.max_batch_size = 4 + + self.models = {} # loaded in __load_models() + self.engines = {} # loaded in build_engines() + + def __load_models(self): + self.embedding_dim = self.text_encoder.config.hidden_size + + self.models["clip"] = CLIP( + self.text_encoder, + device=self.torch_device, + max_batch_size=self.max_batch_size, + embedding_dim=self.embedding_dim, + ) + + self.models["unet"] = UNet( + self.unet, + device=self.torch_device, + max_batch_size=self.max_batch_size, + embedding_dim=self.embedding_dim, + unet_dim=(9 if self.inpaint else 4), + ) + + self.models["vae"] = VAE( + self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim + ) + + @classmethod + def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + cls.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) + + def to( + self, + torch_device: Optional[Union[str, torch.device]] = None, + silence_dtype_warnings: bool = False, + ): + super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings) + + self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) + self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) + + # set device + self.torch_device = self._execution_device + logger.info(f"Running inference on device: {self.torch_device}") + + self.__load_models() + + self.engines = build_engines( + self.models, + self.engine_dir, + self.onnx_dir, + self.onnx_opset, + opt_image_height=self.image_height, + opt_image_width=self.image_width, + force_engine_rebuild=self.force_engine_rebuild, + static_batch=self.build_static_batch, + static_image_shape=not self.build_dynamic_shape, + device_id=self.torch_device.index, + enable_cuda_graph=self.enable_cuda_graph, + ) + + return self + + def __encode_prompt(self, prompt, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + """ + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for _step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + noise_pred = run_engine( + self.engines["unet"], + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def __decode_latent(self, latents): + images = run_engine(self.engines["vae"], {"latent": latents})["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def __allocate_buffers(self, image_height, image_width, batch_size): + # Allocate output tensors for I/O bindings + for model_name, obj in self.models.items(): + self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + + """ + self.generator = generator + self.denoising_steps = num_inference_steps + self.guidance_scale = guidance_scale + + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") + + if negative_prompt is None: + negative_prompt = [""] * batch_size + + if negative_prompt is not None and isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + assert len(prompt) == len(negative_prompt) + + if batch_size > self.max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + self.__allocate_buffers(self.image_height, self.image_width, batch_size) + + with torch.inference_mode(), torch.autocast("cuda"): + # CLIP text encoder + text_embeddings = self.__encode_prompt(prompt, negative_prompt) + + # Pre-initialize latents + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + self.image_height, + self.image_width, + torch.float32, + self.torch_device, + generator, + ) + + # UNet denoiser + latents = self.__denoise_latent(latents, text_embeddings) + + # VAE decode latent + images = self.__decode_latent(latents) + + images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) + images = self.numpy_to_pil(images) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +if __name__ == "__main__": + import torch + from diffusers import DDIMScheduler + + scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler") + + pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", + revision="fp16", + torch_dtype=torch.float16, + scheduler=scheduler, + image_height=512, + image_width=512, + max_batch_size=1, + ) + + # re-use cached folder to save ONNX models and TensorRT Engines + pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision="fp16") + + pipe = pipe.to("cuda") + + prompt = "photorealistic new zealand hills" + image = pipe(prompt).images[0] + image.save("ort_trt_txt2img_new_zealand_hills.png") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py new file mode 100644 index 0000000000000..5ad43b6e39893 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from collections import OrderedDict +from typing import Dict + +import torch + +import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import TypeHelper + + +class OrtCudaSession: + """ONNX Runtime Session for CUDA or TensorRT provider""" + + def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False): + self.ort_session = ort_session + self.input_names = [input.name for input in self.ort_session.get_inputs()] + self.output_names = [output.name for output in self.ort_session.get_outputs()] + self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) + self.io_binding = self.ort_session.io_binding() + self.enable_cuda_graph = enable_cuda_graph + + self.input_tensors = OrderedDict() + self.output_tensors = OrderedDict() + self.device = device + + def __del__(self): + del self.input_tensors + del self.output_tensors + del self.io_binding + del self.ort_session + + def allocate_buffers(self, shape_dict: Dict[str, tuple]): + """Allocate tensors for I/O Binding""" + if self.enable_cuda_graph: + for name, shape in shape_dict.items(): + if name in self.input_names: + # Reuse allocated buffer when the shape is same + if name in self.input_tensors: + if tuple(self.input_tensors[name].shape) == tuple(shape): + continue + raise RuntimeError("Expect static input shape for cuda graph") + + numpy_dtype = self.io_name_to_numpy_type[name] + tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( + device=self.device + ) + self.input_tensors[name] = tensor + + self.io_binding.bind_input( + name, + tensor.device.type, + tensor.device.index, + numpy_dtype, + list(tensor.size()), + tensor.data_ptr(), + ) + + for name, shape in shape_dict.items(): + if name in self.output_names: + # Reuse allocated buffer when the shape is same + if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): + continue + + numpy_dtype = self.io_name_to_numpy_type[name] + tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( + device=self.device + ) + self.output_tensors[name] = tensor + + self.io_binding.bind_output( + name, + tensor.device.type, + tensor.device.index, + numpy_dtype, + list(tensor.size()), + tensor.data_ptr(), + ) + + def infer(self, feed_dict): + """Bind input tensors and run inference""" + for name, tensor in feed_dict.items(): + assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() + if name in self.input_names: + if self.enable_cuda_graph: + assert self.input_tensors[name].nelement() == tensor.nelement() + assert tensor.device.type == "cuda" + # Update input tensor inplace since cuda graph requires input and output has fixed memory address. + from cuda import cudart + + cudart.cudaMemcpy( + self.input_tensors[name].data_ptr(), + tensor.data_ptr(), + tensor.element_size() * tensor.nelement(), + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + ) + else: + self.io_binding.bind_input( + name, + tensor.device.type, + tensor.device.index, + TypeHelper.torch_type_to_numpy_type(tensor.dtype), + [1] if len(tensor.shape) == 0 else list(tensor.shape), + tensor.data_ptr(), + ) + + self.ort_session.run_with_iobinding(self.io_binding) + + return self.output_tensors diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt index 18852f515a20d..b942749f8dcd2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt @@ -1,6 +1,8 @@ -r requirements.txt onnxruntime-gpu>=1.14 -py3nvml==0.2.7 +py3nvml>=0.2.7 +# cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu. +cuda-python==11.7.0 #To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 #--extra-index-url https://download.pytorch.org/whl/cu117 #torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt new file mode 100644 index 0000000000000..e95fa6691fd47 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt @@ -0,0 +1,19 @@ +diffusers>=0.16.0 +transformers>=4.26.0 +numpy>=1.24.1 +accelerate +onnx>=1.13.0 +coloredlogs +packaging +protobuf +psutil +sympy +tensorrt>=8.6.1 +onnxruntime-gpu>=1.15.1 +py3nvml +wget +# cuda-python version shall be compatible with CUDA version of torch and onnxruntime-gpu +cuda-python==11.7.0 +#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 +#--extra-index-url https://download.pytorch.org/whl/cu117 +#torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt index 06d9c18468121..68947a1618a76 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements.txt @@ -1,10 +1,10 @@ -diffusers==0.15.1 -transformers==4.26.0 -numpy==1.24.1 -accelerate==0.15.0 -onnx==1.13.0 +diffusers>=0.15.1 +transformers>=4.26.0 +numpy>=1.24.1 +accelerate +onnx>=1.13.0 coloredlogs -packaging==23.0 +packaging protobuf==3.20.3 -psutil==5.9.4 -sympy==1.11.1 +psutil +sympy