Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/litdata/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Introduced `CHANGELOG.md` to track changes across releases ([#733](https://github.com/lightning-ai/litdata/pull/733))
- Add environment variable `LITDATA_DISABLE_VERSION_CHECK` to disable PyPI version check ([#737](https://github.com/Lightning-AI/litData/pull/737))

### Changed

Expand Down
1 change: 1 addition & 0 deletions src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

_MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
_FORCE_DOWNLOAD_TIME = int(os.getenv("FORCE_DOWNLOAD_TIME", "30"))
_LITDATA_DISABLE_VERSION_CHECK = int(os.getenv("LITDATA_DISABLE_VERSION_CHECK", "0"))

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
Expand Down
4 changes: 3 additions & 1 deletion src/litdata/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import requests
from packaging import version as packaging_version

from litdata.constants import _LITDATA_DISABLE_VERSION_CHECK


class WarningCache(set):
"""Cache for warnings."""
Expand All @@ -28,7 +30,7 @@ def _get_newer_version(curr_version: str) -> Optional[str]:
Returning the newest version if different from the current or ``None`` otherwise.

"""
if packaging_version.parse(curr_version).is_prerelease:
if _LITDATA_DISABLE_VERSION_CHECK == 1 or packaging_version.parse(curr_version).is_prerelease:
return None
try:
response = requests.get(f"https://pypi.org/pypi/{__package_name__}/json", timeout=30)
Expand Down
57 changes: 57 additions & 0 deletions tests/test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import importlib
import sys
import warnings
from unittest.mock import Mock, patch

import pytest


@pytest.mark.parametrize("disable_version_check", ["1", "0", None])
def test_get_newer_version_respects_env_flag(monkeypatch, disable_version_check):
"""Verify that _get_newer_version respects LITDATA_DISABLE_VERSION_CHECK and skips requests when disabled."""
monkeypatch.delenv("LITDATA_DISABLE_VERSION_CHECK", raising=False)

if disable_version_check is not None:
monkeypatch.setenv("LITDATA_DISABLE_VERSION_CHECK", disable_version_check)

# Reload both modules so constants re-evaluate environment variables
sys.modules.pop("litdata.constants", None)
sys.modules.pop("litdata.helpers", None)
importlib.import_module("litdata.helpers")
from litdata import helpers

# Mock requests.get
mock_get = Mock()
mock_get.return_value.json.return_value = {
"releases": {"0.2.50": [], "2.51.0": []},
"info": {"version": "2.51.0", "yanked": False},
}

monkeypatch.setattr("litdata.helpers.requests.get", mock_get)

# Clear cached function results
helpers._get_newer_version.cache_clear()

result = helpers._get_newer_version("0.2.50")

if disable_version_check == "1":
assert result is None
mock_get.assert_not_called()
else:
assert result == "2.51.0"
mock_get.assert_called_once_with("https://pypi.org/pypi/litdata/json", timeout=30)


@patch("litdata.helpers._get_newer_version")
def test_check_version_default_behavior_warning(mock_get_newer, monkeypatch):
"""Test default behavior: calls _get_newer_version and warns if newer version exists."""
mock_get_newer.return_value = "0.2.58"

from litdata.helpers import _check_version_and_prompt_upgrade

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_check_version_and_prompt_upgrade("0.2.50")
assert len(w) == 1
assert f"A newer version of litdata is available ({mock_get_newer.return_value})" in str(w[0].message)
mock_get_newer.assert_called_once_with("0.2.50")
Loading