From 5b37bd8dfe6068899c11b9f3f7a3334f95ce8634 Mon Sep 17 00:00:00 2001 From: Nate Parsons <4307001+thehomebrewnerd@users.noreply.github.com> Date: Fri, 10 May 2024 09:59:53 -0500 Subject: [PATCH] Use filter arg in tarfile.extractall to prevent unsafe unarchival operations (#2722) * use filter in tarfile.extractall * update release notes * update release notes action * update docstring --- .github/workflows/release_notes_updated.yaml | 4 +++- docs/source/release_notes.rst | 1 + featuretools/entityset/deserialize.py | 10 +++++++++- .../tests/entityset_tests/test_serialization.py | 14 +++++++++++++- 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release_notes_updated.yaml b/.github/workflows/release_notes_updated.yaml index def24cbb32..1c88ebd111 100644 --- a/.github/workflows/release_notes_updated.yaml +++ b/.github/workflows/release_notes_updated.yaml @@ -10,6 +10,8 @@ jobs: - name: Check for development branch id: branch shell: python + env: + REF: ${{ github.event.pull_request.head.ref }} run: | from re import compile main = '^main$' @@ -19,7 +21,7 @@ jobs: min_dep_update = '^min-dep-update-[a-f0-9]{7}$' regex = main, release, backport, dep_update, min_dep_update patterns = list(map(compile, regex)) - ref = "${{ github.event.pull_request.head.ref }}" + ref = "$REF" is_dev = not any(pattern.match(ref) for pattern in patterns) print('::set-output name=is_dev::' + str(is_dev)) - if: ${{ steps.branch.outputs.is_dev == 'true' }} diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 77eb5b071a..9df35edd74 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -13,6 +13,7 @@ Future Release * Temporarily restrict Dask version (:pr:`2694`) * Remove support for creating ``EntitySets`` from Dask or Pyspark dataframes (:pr:`2705`) * Bump minimum versions of ``tqdm`` and ``pip`` in requirements files (:pr:`2716`) + * Use ``filter`` arg in call to ``tarfile.extractall`` to safely deserialize EntitySets (:pr:`2722`) * Documentation Changes * Testing Changes * Fix serialization test to work with pytest 8.1.1 (:pr:`2694`) diff --git a/featuretools/entityset/deserialize.py b/featuretools/entityset/deserialize.py index fbcd6c6ef8..8a7d448b46 100644 --- a/featuretools/entityset/deserialize.py +++ b/featuretools/entityset/deserialize.py @@ -2,6 +2,7 @@ import os import tarfile import tempfile +from inspect import getfullargspec import pandas as pd import woodwork.type_sys.type_system as ww_type_system @@ -140,6 +141,8 @@ def read_data_description(path): def read_entityset(path, profile_name=None, **kwargs): """Read entityset from disk, S3 path, or URL. + NOTE: Never attempt to read an archived EntitySet from an untrusted source. + Args: path (str): Directory on disk, S3 path, or URL to read `data_description.json`. profile_name (str, bool): The AWS profile specified to write to S3. Will default to None and search for AWS credentials. @@ -159,7 +162,12 @@ def read_entityset(path, profile_name=None, **kwargs): use_smartopen_es(local_path, path, transport_params) with tarfile.open(str(local_path)) as tar: - tar.extractall(path=tmpdir) + if "filter" in getfullargspec(tar.extractall).kwonlyargs: + tar.extractall(path=tmpdir, filter="data") + else: + raise RuntimeError( + "Please upgrade your Python version to the latest patch release to allow for safe extraction of the EntitySet archive.", + ) data_description = read_data_description(tmpdir) return description_to_entityset(data_description, **kwargs) diff --git a/featuretools/tests/entityset_tests/test_serialization.py b/featuretools/tests/entityset_tests/test_serialization.py index deb9da2d77..6511725a5e 100644 --- a/featuretools/tests/entityset_tests/test_serialization.py +++ b/featuretools/tests/entityset_tests/test_serialization.py @@ -2,7 +2,7 @@ import logging import os import tempfile -from unittest.mock import patch +from unittest.mock import MagicMock, patch from urllib.request import urlretrieve import boto3 @@ -292,6 +292,18 @@ def test_deserialize_local_tar(es): assert es.__eq__(new_es, deep=True) +@patch("featuretools.entityset.deserialize.getfullargspec") +def test_deserialize_errors_if_python_version_unsafe(mock_inspect, es): + mock_response = MagicMock() + mock_response.kwonlyargs = [] + mock_inspect.return_value = mock_response + with tempfile.TemporaryDirectory() as tmp_path: + temp_tar_filepath = os.path.join(tmp_path, TEST_FILE) + urlretrieve(URL, filename=temp_tar_filepath) + with pytest.raises(RuntimeError, match=""): + deserialize.read_entityset(temp_tar_filepath) + + def test_deserialize_url_csv(es): new_es = deserialize.read_entityset(URL) assert es.__eq__(new_es, deep=True)