diff --git a/pandera/api/base/checks.py b/pandera/api/base/checks.py
index 0f4790699..089e2b133 100644
--- a/pandera/api/base/checks.py
+++ b/pandera/api/base/checks.py
@@ -123,9 +123,15 @@ def from_builtin_check_name(
         init_kwargs,
         error: Union[str, Callable],
         statistics: Optional[dict[str, Any]] = None,
+        defaults: Optional[dict[str, Any]] = None,
         **check_kwargs,
     ):
         """Create a Check object from a built-in check's name."""
+        # Apply defaults to init_kwargs if provided
+        if defaults:
+            for key, value in defaults.items():
+                init_kwargs.setdefault(key, value)
+
         kws = {**init_kwargs, **check_kwargs}
         if "error" not in kws:
             kws["error"] = error
diff --git a/pandera/api/checks.py b/pandera/api/checks.py
index d651afc80..9fe7599a7 100644
--- a/pandera/api/checks.py
+++ b/pandera/api/checks.py
@@ -34,6 +34,7 @@ def __init__(
         description: Optional[str] = None,
         statistics: Optional[dict[str, Any]] = None,
         strategy: Optional[Any] = None,
+        determined_by_unique: bool = False,
         **check_kwargs,
     ) -> None:
         """Apply a validation function to a data object.
@@ -98,6 +99,12 @@ def __init__(
         :param strategy: A hypothesis strategy, used for implementing data
             synthesis strategies for this check. See the :ref:`User Guide
            ` for more details.
+        :param determined_by_unique: If True, indicates that this check's
+            result is fully determined by the unique values in the data, i.e.
+            duplicate values do not affect the outcome. This enables significant
+            performance optimizations for MultiIndex validation on large
+            datasets. When enabled, the check function must produce the same
+            result whether applied to the unique values or to the full values.
         :param check_kwargs: key-word arguments to pass into ``check_fn``
 
         :example:
@@ -177,6 +184,7 @@ def __init__(
         self.n_failure_cases = n_failure_cases
         self.title = title
         self.description = description
+        self.determined_by_unique = determined_by_unique
 
         if groupby is None and groups is not None:
             raise ValueError(
@@ -240,6 +248,7 @@ def equal_to(cls, value: Any, **kwargs) -> "Check":
             "equal_to",
             kwargs,
             error=f"equal_to({value})",
+            defaults={"determined_by_unique": True},
             value=value,
         )
 
@@ -253,6 +262,7 @@ def not_equal_to(cls, value: Any, **kwargs) -> "Check":
             "not_equal_to",
             kwargs,
             error=f"not_equal_to({value})",
+            defaults={"determined_by_unique": True},
             value=value,
         )
 
@@ -272,6 +282,7 @@ def greater_than(cls, min_value: Any, **kwargs) -> "Check":
             "greater_than",
             kwargs,
             error=f"greater_than({min_value})",
+            defaults={"determined_by_unique": True},
             min_value=min_value,
         )
 
@@ -289,6 +300,7 @@ def greater_than_or_equal_to(cls, min_value: Any, **kwargs) -> "Check":
             "greater_than_or_equal_to",
             kwargs,
             error=f"greater_than_or_equal_to({min_value})",
+            defaults={"determined_by_unique": True},
             min_value=min_value,
         )
 
@@ -306,6 +318,7 @@ def less_than(cls, max_value: Any, **kwargs) -> "Check":
             "less_than",
             kwargs,
             error=f"less_than({max_value})",
+            defaults={"determined_by_unique": True},
             max_value=max_value,
         )
 
@@ -323,6 +336,7 @@ def less_than_or_equal_to(cls, max_value: Any, **kwargs) -> "Check":
             "less_than_or_equal_to",
             kwargs,
             error=f"less_than_or_equal_to({max_value})",
+            defaults={"determined_by_unique": True},
             max_value=max_value,
         )
 
@@ -365,6 +379,7 @@ def in_range(
             "in_range",
             kwargs,
             error=f"in_range({min_value}, {max_value})",
+            defaults={"determined_by_unique": True},
             min_value=min_value,
             max_value=max_value,
             include_min=include_min,
@@ -395,6 +410,7 @@ def isin(cls, allowed_values: Iterable, **kwargs) -> "Check":
             "isin",
             kwargs,
             error=f"isin({allowed_values})",
+            defaults={"determined_by_unique": True},
             statistics={"allowed_values": allowed_values},
             allowed_values=allowed_values_mod,
         )
@@ -424,6 +440,7 @@ def notin(cls, forbidden_values: Iterable, **kwargs) -> "Check":
             "notin",
             kwargs,
             error=f"notin({forbidden_values})",
+            defaults={"determined_by_unique": True},
             statistics={"forbidden_values": forbidden_values},
             forbidden_values=forbidden_values_mod,
         )
@@ -445,6 +462,7 @@ def str_matches(cls, pattern: Union[str, re.Pattern], **kwargs) -> "Check":
             "str_matches",
             kwargs,
             error=f"str_matches('{pattern}')",
+            defaults={"determined_by_unique": True},
             statistics={"pattern": pattern},
             pattern=pattern,
         )
@@ -468,6 +486,7 @@ def str_contains(
             "str_contains",
             kwargs,
             error=f"str_contains('{pattern}')",
+            defaults={"determined_by_unique": True},
             statistics={"pattern": pattern},
             pattern=pattern,
         )
@@ -484,6 +503,7 @@ def str_startswith(cls, string: str, **kwargs) -> "Check":
             "str_startswith",
             kwargs,
             error=f"str_startswith('{string}')",
+            defaults={"determined_by_unique": True},
             string=string,
         )
 
@@ -498,6 +518,7 @@ def str_endswith(cls, string: str, **kwargs) -> "Check":
             "str_endswith",
             kwargs,
             error=f"str_endswith('{string}')",
+            defaults={"determined_by_unique": True},
             string=string,
         )
 
@@ -522,6 +543,7 @@ def str_length(
             "str_length",
             kwargs,
             error=f"str_length({min_value}, {max_value})",
+            defaults={"determined_by_unique": True},
             min_value=min_value,
             max_value=max_value,
         )
@@ -546,6 +568,7 @@ def unique_values_eq(cls, values: Iterable, **kwargs) -> "Check":
             "unique_values_eq",
             kwargs,
             error=f"unique_values_eq({values})",
+            defaults={"determined_by_unique": True},
             statistics={"values": values_mod},
             values=values_mod,
         )
diff --git a/pandera/backends/pandas/components.py b/pandera/backends/pandas/components.py
index 1b83a1b72..a547c033e 100644
--- a/pandera/backends/pandas/components.py
+++ b/pandera/backends/pandas/components.py
@@ -460,6 +460,8 @@ def validate(
         if not inplace:
             check_obj = check_obj.copy()
 
+        validate_full_df = not (head or tail or sample)
+
         # Ensure the object has a MultiIndex
         if not is_multiindex(check_obj.index):
             # Allow an exception for a *single-level* Index when the schema also
@@ -529,24 +531,40 @@ def validate(
         # Iterate over the expected index levels and validate each level with its
         # corresponding ``Index`` schema component.
         for level_pos, index_schema in level_mapping:
-            stub_df = pd.DataFrame(
-                index=check_obj.index.get_level_values(level_pos)
-            )
 
             # We've already taken care of coercion, so we can disable it now.
             index_schema = deepcopy(index_schema)
             index_schema.coerce = False
 
+            # Check if we can optimize validation for this level. We skip optimization
+            # if we're validating only a subset of the data, because subsetting the data
+            # doesn't commute with taking unique values, which can lead to inconsistent
+            # results. For instance, a check may fail on the first n unique values but
+            # pass on the first n values.
+            can_optimize = validate_full_df and self._can_optimize_level(
+                index_schema
+            )
             try:
-                # Validate using the schema for this level
-                index_schema.validate(
-                    stub_df,
-                    head=head,
-                    tail=tail,
-                    sample=sample,
-                    random_state=random_state,
-                    lazy=lazy,
-                    inplace=True,
-                )
+                if can_optimize:
+                    # Use optimized validation with unique values only
+                    self._validate_level_optimized(
+                        check_obj.index,
+                        level_pos,
+                        index_schema,
+                        lazy=lazy,
+                    )
+                else:
+                    # Fall back to validating all of the values.
+                    self._validate_level_with_full_materialization(
+                        check_obj.index,
+                        level_pos,
+                        index_schema,
+                        head=head,
+                        tail=tail,
+                        sample=sample,
+                        random_state=random_state,
+                        lazy=lazy,
+                    )
             except (SchemaError, SchemaErrors) as exc:
                 self._collect_or_raise(error_handler, exc, schema)
 
@@ -564,6 +582,102 @@ def validate(
 
         return check_obj
 
+    def _can_optimize_level(self, index_schema) -> bool:
+        """Check if we can optimize validation for this level.
+
+        :param index_schema: The schema for this level
+        :returns: True if optimization can be applied to this level
+        """
+        # Check whether all checks are determined by unique values.
+        # Note that if there are no checks, all([]) returns True.
+        return all(
+            self._check_determined_by_unique(check)
+            for check in index_schema.checks
+        )
+
+    def _check_determined_by_unique(self, check) -> bool:
+        """Determine whether a check's result depends only on unique values.
+
+        :param check: The check to analyze
+        :returns: True if the check result is determined by unique values
+        """
+        # All built-in checks that are determined by unique values have this
+        # property set.
+        return getattr(check, "determined_by_unique", False)
+
+    def _validate_level_optimized(
+        self,
+        multiindex: pd.MultiIndex,
+        level_pos: int,
+        index_schema,
+        lazy: bool = False,
+    ) -> None:
+        """Validate a level using the unique-values optimization.
+
+        :param multiindex: The MultiIndex being validated
+        :param level_pos: Position of this level in the MultiIndex
+        :param index_schema: The schema for this level
+        :param lazy: if True, collect errors instead of raising immediately
+        """
+        try:
+            # Use unique values. Use the MultiIndex.unique method rather than
+            # multiindex.levels[level_pos], which can contain extra values that
+            # don't appear in the full data. Additionally, multiindex.unique
+            # will include nan if present, whereas multiindex.levels[level_pos]
+            # will not.
+            unique_values = multiindex.unique(level=level_pos)
+            unique_stub_df = pd.DataFrame(index=unique_values)
+
+            # Run validation on unique values only, using lazy=False so we cut
+            # over to full validation as soon as we hit a failure.
+            index_schema.validate(
+                unique_stub_df,
+                lazy=False,
+                inplace=True,
+            )
+        except (SchemaError, SchemaErrors):
+            # Validation failed on unique values; materialize the full values
+            # for proper error reporting with correct indices.
+            self._validate_level_with_full_materialization(
+                multiindex,
+                level_pos,
+                index_schema,
+                lazy=lazy,
+            )
+
+    def _validate_level_with_full_materialization(
+        self,
+        multiindex: pd.MultiIndex,
+        level_pos: int,
+        index_schema,
+        head: Optional[int] = None,
+        tail: Optional[int] = None,
+        sample: Optional[int] = None,
+        random_state: Optional[int] = None,
+        lazy: bool = False,
+    ) -> None:
+        """Validate a level using full materialization.
+
+        This materializes all values (including duplicates) for validation.
+        It is used both as a fallback when optimization isn't possible and
+        when errors are identified during optimized validation, in order to
+        provide proper error reporting with correct indices.
+ """ + # Materialize the full level values + full_values = multiindex.get_level_values(level_pos) + full_stub_df = pd.DataFrame(index=full_values) + + # Run validation on full materialized values + index_schema.validate( + full_stub_df, + head=head, + tail=tail, + sample=sample, + random_state=random_state, + lazy=lazy, + inplace=True, + ) + def _check_strict( self, check_obj: pd.MultiIndex, diff --git a/tests/pandas/test_schema_components.py b/tests/pandas/test_schema_components.py index b4d65b773..fb5594706 100644 --- a/tests/pandas/test_schema_components.py +++ b/tests/pandas/test_schema_components.py @@ -2,6 +2,7 @@ import copy from typing import Any, Optional +from unittest.mock import patch, MagicMock import numpy as np import pandas as pd @@ -890,6 +891,207 @@ def test_multiindex_incorrect_input(indexes) -> None: MultiIndex(indexes) +@pytest.mark.parametrize( + "schema,expected_optimized_calls,expected_full_calls,expected_optimized_levels,expected_full_levels", + [ + # All optimizable checks -> optimized path for both levels + ( + DataFrameSchema( + columns={"value": Column(int)}, + index=MultiIndex( + [ + Index( + String, + checks=[ + Check.str_matches( + r"^(cat|dog)$" + ), # Optimizable + Check.isin(["cat", "dog"]), # Optimizable + ], + name="animal", + ), + Index( + Int, + checks=[ + Check.greater_than_or_equal_to( + 0 + ), # Optimizable + Check.less_than(1000), # Optimizable + ], + name="id", + ), + ] + ), + ), + 2, + 0, + [0, 1], + [], + ), + # Mixed checks -> full materialization for level with non-optimizable, optimized for others + ( + DataFrameSchema( + columns={"value": Column(int)}, + index=MultiIndex( + [ + Index( + String, + checks=[ + Check.str_matches( + r"^(cat|dog)$" + ), # Optimizable + Check( + lambda s: len(s) > 50, + determined_by_unique=False, + ), # NOT optimizable + ], + name="animal", + ), + Index( + Int, + checks=[ + Check.greater_than_or_equal_to( + 0 + ), # Optimizable + ], + name="id", + ), + ] + ), + ), + 1, + 1, + [1], + [0], + ), + ], +) +def test_multiindex_optimization_path_selection( + schema: DataFrameSchema, + expected_optimized_calls: int, + expected_full_calls: int, + expected_optimized_levels: list[int], + expected_full_levels: list[int], +) -> None: + """Test that MultiIndex validation chooses the correct optimization path.""" + # Create test MultiIndex with duplicates for optimization benefit + mi = pd.MultiIndex.from_arrays( + [ + ["cat", "dog", "cat", "dog"] * 100, # Lots of duplicates + list(range(400)), + ], + names=["animal", "id"], + ) + df = pd.DataFrame({"value": range(400)}, index=mi) + + # Mock the backend methods to track which path is taken + with ( + patch( + "pandera.backends.pandas.components.MultiIndexBackend._validate_level_optimized" + ) as mock_optimized, + patch( + "pandera.backends.pandas.components.MultiIndexBackend._validate_level_with_full_materialization" + ) as mock_full, + ): + + schema.validate(df) + + # Verify correct number of calls + assert ( + mock_optimized.call_count == expected_optimized_calls + ), f"Expected {expected_optimized_calls} calls to optimized path, got {mock_optimized.call_count}" + assert ( + mock_full.call_count == expected_full_calls + ), f"Expected {expected_full_calls} calls to full materialization, got {mock_full.call_count}" + + # Verify correct levels were called with correct methods + if expected_optimized_calls > 0: + optimized_calls = [ + call[0][1] for call in mock_optimized.call_args_list + ] # Extract level_pos argument + assert sorted(optimized_calls) == sorted( + 
+                expected_optimized_levels
+            ), f"Expected optimized calls for levels {expected_optimized_levels}, got {optimized_calls}"
+
+        if expected_full_calls > 0:
+            full_calls = [call[0][1] for call in mock_full.call_args_list]
+            assert sorted(full_calls) == sorted(
+                expected_full_levels
+            ), f"Expected full calls for levels {expected_full_levels}, got {full_calls}"
+
+
+@pytest.mark.parametrize(
+    "checks,expected_can_optimize",
+    [
+        # Schema with all optimizable checks
+        ([Check.str_matches(r"^test$"), Check.isin(["test"])], True),
+        # Schema with mixed checks (includes non-optimizable)
+        (
+            [
+                Check.str_matches(r"^test$"),
+                Check(lambda s: len(s) > 100, determined_by_unique=False),
+            ],
+            False,
+        ),
+        # Schema with no checks
+        ([], True),
+        # Schema with only non-optimizable checks
+        (
+            [
+                Check(
+                    lambda s: s.nunique() > 10,
+                    determined_by_unique=False,
+                )
+            ],
+            False,
+        ),
+    ],
+)
+def test_multiindex_can_optimize_level(
+    checks: list, expected_can_optimize: bool
+) -> None:
+    """Test the _can_optimize_level decision logic."""
+    from pandera.backends.pandas.components import MultiIndexBackend
+
+    backend = MultiIndexBackend()
+    schema = Index(String, checks=checks)
+
+    result = backend._can_optimize_level(schema)
+    assert result is expected_can_optimize
+
+
+@pytest.mark.parametrize(
+    "check,expected_supports_optimization",
+    [
+        # Built-in optimizable check
+        (Check.greater_than(5), True),
+        # Explicitly non-optimizable check
+        (
+            Check(lambda s: s.nunique() > 10, determined_by_unique=False),
+            False,
+        ),
+        # Custom check marked as optimizable
+        (
+            Check(lambda s: s.str.len() > 2, determined_by_unique=True),
+            True,
+        ),
+        # Built-in optimizable check - isin
+        (Check.isin(["test"]), True),
+        # Built-in optimizable check - str_matches
+        (Check.str_matches(r"^test$"), True),
+    ],
+)
+def test_check_determined_by_unique(
+    check, expected_supports_optimization: bool
+) -> None:
+    """Test individual check support detection for unique optimization."""
+    from pandera.backends.pandas.components import MultiIndexBackend
+
+    backend = MultiIndexBackend()
+    result = backend._check_determined_by_unique(check)
+    assert result is expected_supports_optimization
+
+
 def test_index_validation_pandas_string_dtype():
     """Test that pandas string type is correctly validated."""
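Usage note (not part of the patch): a minimal sketch of how the new determined_by_unique flag is intended to be used. The schema, index names, and data below are illustrative only. Built-in checks such as Check.isin opt in automatically via the defaults added above; a custom check must opt in explicitly and must return the same result on the unique values as on the full values.

import pandas as pd
import pandera as pa

schema = pa.DataFrameSchema(
    columns={"value": pa.Column(int)},
    index=pa.MultiIndex(
        [
            pa.Index(
                str,
                checks=[
                    # Built-in check: defaults to determined_by_unique=True.
                    pa.Check.isin(["cat", "dog"]),
                    # Custom check: opted in explicitly because its result
                    # depends only on which values occur, not how often.
                    pa.Check(
                        lambda s: s.str.islower(),
                        determined_by_unique=True,
                    ),
                ],
                name="animal",
            ),
            pa.Index(
                int,
                checks=[pa.Check.greater_than_or_equal_to(0)],
                name="id",
            ),
        ]
    ),
)

# With heavily duplicated level values, the "animal" level should be validated
# against its two unique values rather than all one million rows.
mi = pd.MultiIndex.from_arrays(
    [["cat", "dog"] * 500_000, list(range(1_000_000))],
    names=["animal", "id"],
)
df = pd.DataFrame({"value": range(1_000_000)}, index=mi)
schema.validate(df)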