From 6dc8c6837000f8c6e0a9b3fa3e459f4bdb876491 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Apr 2024 08:25:55 +0100 Subject: [PATCH] init --- .github/scripts/m1_script.sh | 2 +- .github/workflows/wheels.yml | 4 ++-- setup.py | 14 +++++++++----- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh index 4548b1430..fd374bf79 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/m1_script.sh @@ -1,3 +1,3 @@ #!/bin/bash -export BUILD_VERSION=0.4.0 +export TENSORDICT_BUILD_VERSION=0.4.0 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index accfc5994..bc8ca9167 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel + TENSORDICT_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -72,7 +72,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel + TENSORDICT_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/setup.py b/setup.py index c26fc30b3..c7edb4baa 100644 --- a/setup.py +++ b/setup.py @@ -44,8 +44,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace: def get_version(): version = (ROOT_DIR / "version.txt").read_text().strip() - if os.getenv("BUILD_VERSION"): - version = os.getenv("BUILD_VERSION") + if os.getenv("TENSORDICT_BUILD_VERSION"): + version = os.getenv("TENSORDICT_BUILD_VERSION") elif sha != "Unknown": version += "+" + sha[:7] return version @@ -62,11 +62,13 @@ def write_version_file(version): f.write(f"git_version = {repr(sha)}\n") -def _get_pytorch_version(is_nightly): +def _get_pytorch_version(is_nightly, is_local): # if "PYTORCH_VERSION" in os.environ: # return f"torch=={os.environ['PYTORCH_VERSION']}" if is_nightly: return "torch>=2.4.0.dev" + if is_local: + return "torch" return "torch>=2.3.0" @@ -153,9 +155,11 @@ def _main(argv): write_version_file(version) logging.info(f"Building wheel {package_name}-{version}") - logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}") + BUILD_VERSION = os.getenv("TENSORDICT_BUILD_VERSION") + logging.info(f"TENSORDICT_BUILD_VERSION is {BUILD_VERSION}") + local_build = BUILD_VERSION is None - pytorch_package_dep = _get_pytorch_version(is_nightly) + pytorch_package_dep = _get_pytorch_version(is_nightly, local_build) logging.info("-- PyTorch dependency:", pytorch_package_dep) long_description = (ROOT_DIR / "README.md").read_text()