diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 3bfbc01086cd3..b50d11a0aff6d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2074,6 +2074,17 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if is_windows(): cwd = os.path.join(cwd, config) + if args.enable_transformers_tool_test and not args.disable_contrib_ops and not args.use_rocm: + # PyTorch is required for transformers tests, and optional for some python tests. + # Install cpu only version of torch when cuda is not enabled in Linux. + extra = [] if args.use_cuda and is_linux() else ["--index-url", "https://download.pytorch.org/whl/cpu"] + run_subprocess( + [sys.executable, "-m", "pip", "install", "torch", *extra], + cwd=cwd, + dll_path=dll_path, + python_path=python_path, + ) + run_subprocess( [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path ) @@ -2128,6 +2139,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): dll_path=dll_path, python_path=python_path, ) + if not args.disable_contrib_ops: run_subprocess( [sys.executable, "-m", "unittest", "discover", "-s", "quantization"], cwd=cwd, dll_path=dll_path diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index cb93043e09b63..14aeff3df9c62 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -1,8 +1,9 @@ -# packages used by transformers python unittest (only enabled in Linux CPU CI Pipeline) +# packages used by transformers python unittest packaging -protobuf==3.20.2 -numpy==1.24.0 ; python_version < '3.12' -numpy==1.26.0 ; python_version >= '3.12' +# protobuf and numpy is same as tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +protobuf==4.21.12 +numpy==1.21.6 ; python_version < '3.9' +numpy==2.0.0 ; python_version >= '3.9' torch coloredlogs==15.0 transformers==4.46.3