diff --git a/src/litdata/CHANGELOG.md b/src/litdata/CHANGELOG.md index fd392861..12013084 100644 --- a/src/litdata/CHANGELOG.md +++ b/src/litdata/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Introduced `CHANGELOG.md` to track changes across releases ([#733](https://github.com/lightning-ai/litdata/pull/733)) +- Add environment variable `LITDATA_DISABLE_VERSION_CHECK` to disable PyPI version check ([#737](https://github.com/Lightning-AI/litData/pull/737)) ### Changed diff --git a/src/litdata/constants.py b/src/litdata/constants.py index b7e7e46d..2820cc09 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -55,6 +55,7 @@ _MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120")) _FORCE_DOWNLOAD_TIME = int(os.getenv("FORCE_DOWNLOAD_TIME", "30")) +_LITDATA_DISABLE_VERSION_CHECK = int(os.getenv("LITDATA_DISABLE_VERSION_CHECK", "0")) # DON'T CHANGE ORDER _TORCH_DTYPES_MAPPING = { diff --git a/src/litdata/helpers.py b/src/litdata/helpers.py index 5a343415..f9d79550 100644 --- a/src/litdata/helpers.py +++ b/src/litdata/helpers.py @@ -5,6 +5,8 @@ import requests from packaging import version as packaging_version +from litdata.constants import _LITDATA_DISABLE_VERSION_CHECK + class WarningCache(set): """Cache for warnings.""" @@ -28,7 +30,7 @@ def _get_newer_version(curr_version: str) -> Optional[str]: Returning the newest version if different from the current or ``None`` otherwise. """ - if packaging_version.parse(curr_version).is_prerelease: + if _LITDATA_DISABLE_VERSION_CHECK == 1 or packaging_version.parse(curr_version).is_prerelease: return None try: response = requests.get(f"https://pypi.org/pypi/{__package_name__}/json", timeout=30) diff --git a/tests/test_helper.py b/tests/test_helper.py new file mode 100644 index 00000000..a6b8ac84 --- /dev/null +++ b/tests/test_helper.py @@ -0,0 +1,47 @@ +import warnings +from unittest.mock import Mock, patch + +import pytest + +from litdata.helpers import _check_version_and_prompt_upgrade, _get_newer_version + + +@pytest.mark.parametrize("disable_version_check", [1, 0, None]) +def test_get_newer_version_respects_env_flag(monkeypatch, disable_version_check): + """Verify that _get_newer_version respects LITDATA_DISABLE_VERSION_CHECK and skips requests when disabled.""" + if disable_version_check is not None: + monkeypatch.setattr("litdata.helpers._LITDATA_DISABLE_VERSION_CHECK", disable_version_check) + + # Mock requests.get + mock_get = Mock() + mock_get.return_value.json.return_value = { + "releases": {"0.2.50": [], "2.51.0": []}, + "info": {"version": "2.51.0", "yanked": False}, + } + + monkeypatch.setattr("litdata.helpers.requests.get", mock_get) + + # Clear cached function results + _get_newer_version.cache_clear() + + result = _get_newer_version("0.2.50") + + if disable_version_check: + assert result is None + mock_get.assert_not_called() + else: + assert result == "2.51.0" + mock_get.assert_called_once_with("https://pypi.org/pypi/litdata/json", timeout=30) + + +@patch("litdata.helpers._get_newer_version") +def test_check_version_default_behavior_warning(mock_get_newer, monkeypatch): + """Test default behavior: calls _get_newer_version and warns if newer version exists.""" + mock_get_newer.return_value = "0.2.58" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _check_version_and_prompt_upgrade("0.2.50") + assert len(w) == 1 + assert f"A newer version of litdata is available ({mock_get_newer.return_value})" in str(w[0].message) + mock_get_newer.assert_called_once_with("0.2.50")