diff --git a/pandera/backends/pyspark/builtin_checks.py b/pandera/backends/pyspark/builtin_checks.py
index 4d30eb6f3..104c7cba9 100644
--- a/pandera/backends/pyspark/builtin_checks.py
+++ b/pandera/backends/pyspark/builtin_checks.py
@@ -1,15 +1,22 @@
 """PySpark implementation of built-in checks"""
 
-from typing import Any, Iterable, TypeVar
+import re
+from typing import Any, Iterable, TypeVar, Union
 
+import pyspark.sql as ps
 import pyspark.sql.types as pst
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, when
 
 import pandera.strategies as st
 from pandera.api.extensions import register_builtin_check
 from pandera.api.pyspark.types import PysparkDataframeColumnObject
-from pandera.backends.pyspark.decorators import register_input_datatypes
-from pandera.backends.pyspark.utils import convert_to_list
+from pandera.backends.pyspark.decorators import (
+    register_input_datatypes,
+)
+from pandera.backends.pyspark.utils import (
+    convert_to_list,
+    get_full_table_validation,
+)
 
 T = TypeVar("T")
 ALL_NUMERIC_TYPE = [
@@ -22,6 +29,7 @@ pst.FloatType,
 ]
 ALL_DATE_TYPE = [pst.DateType, pst.TimestampType]
+# TODO: fix the BOLEAN_TYPE -> BOOLEAN_TYPE typo in a separate PR/commit
 BOLEAN_TYPE = pst.BooleanType
 BINARY_TYPE = pst.BinaryType
 STRING_TYPE = pst.StringType
@@ -36,7 +44,10 @@ ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE
     )
 )
-def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
+def equal_to(
+    data: PysparkDataframeColumnObject,
+    value: Any,
+) -> Union[bool, ps.Column]:
     """Ensure all elements of a data container equal a certain value.
 
     :param data: PysparkDataframeColumnObject column object which is a contains
         dataframe and column name to do the check
@@ -44,6 +55,9 @@ def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
         equal to this value.
     """
     cond = col(data.column_name) == value
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -57,7 +71,10 @@ def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
         ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE
     )
 )
-def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
+def not_equal_to(
+    data: PysparkDataframeColumnObject,
+    value: Any,
+) -> Union[bool, ps.Column]:
     """Ensure no elements of a data container equals a certain value.
 
     :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -65,6 +82,9 @@ def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
     :param value: This value must not occur in the checked
     """
     cond = col(data.column_name) != value
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -75,7 +95,10 @@ def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool:
 @register_input_datatypes(
     acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
 )
-def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
+def greater_than(
+    data: PysparkDataframeColumnObject,
+    min_value: Any,
+) -> Union[bool, ps.Column]:
     """
     Ensure values of a data container are strictly greater than a minimum value.
 
@@ -84,6 +107,9 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
     :param min_value: Lower bound to be exceeded.
     """
     cond = col(data.column_name) > min_value
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -96,8 +122,9 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool:
     acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
 )
 def greater_than_or_equal_to(
-    data: PysparkDataframeColumnObject, min_value: Any
-) -> bool:
+    data: PysparkDataframeColumnObject,
+    min_value: Any,
+) -> Union[bool, ps.Column]:
     """Ensure all values are greater or equal a certain value.
 
     :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
         to access the dataframe is "dataframe" and column name using "column_name".
@@ -105,6 +132,9 @@ def greater_than_or_equal_to(
         a type comparable to the dtype of the column datatype of pyspark
     """
     cond = col(data.column_name) >= min_value
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -116,7 +146,10 @@ def greater_than_or_equal_to(
 @register_input_datatypes(
     acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
 )
-def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
+def less_than(
+    data: PysparkDataframeColumnObject,
+    max_value: Any,
+) -> Union[bool, ps.Column]:
     """Ensure values of a series are strictly below a maximum value.
 
     :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -128,6 +161,9 @@ def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
     if max_value is None:  # pragma: no cover
         raise ValueError("max_value must not be None")
     cond = col(data.column_name) < max_value
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -140,8 +176,9 @@ def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool:
     acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE)
 )
 def less_than_or_equal_to(
-    data: PysparkDataframeColumnObject, max_value: Any
-) -> bool:
+    data: PysparkDataframeColumnObject,
+    max_value: Any,
+) -> Union[bool, ps.Column]:
     """Ensure values of a series are strictly below a maximum value.
 
     :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
@@ -153,6 +190,9 @@ def less_than_or_equal_to(
     if max_value is None:  # pragma: no cover
         raise ValueError("max_value must not be None")
     cond = col(data.column_name) <= max_value
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -170,7 +210,7 @@ def in_range(
     max_value: T,
     include_min: bool = True,
     include_max: bool = True,
-):
+) -> Union[bool, ps.Column]:
     """Ensure all values of a column are within an interval.
 
     Both endpoints must be a type comparable to the dtype of the
@@ -200,7 +240,11 @@ def in_range(
         if include_max
         else col(data.column_name) < max_value
     )
-    return data.dataframe.filter(~(cond_right & cond_left)).limit(1).count() == 0  # type: ignore
+    cond = cond_right & cond_left
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
+    return data.dataframe.filter(~cond).limit(1).count() == 0  # type: ignore
 
 
 @register_builtin_check(
@@ -212,7 +256,10 @@ def in_range(
         ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE
     )
 )
-def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
+def isin(
+    data: PysparkDataframeColumnObject,
+    allowed_values: Iterable,
+) -> Union[bool, ps.Column]:
     """Ensure only allowed values occur within a series.
 
     Remember it can be a compute intensive check on large dataset. So, use it with caution.
@@ -228,14 +275,11 @@ def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
         to access the dataframe is "dataframe" and column name using "column_name".
     :param allowed_values: The set of allowed values. May be any iterable.
     """
-    return (
-        data.dataframe.filter(
-            ~col(data.column_name).isin(list(allowed_values))
-        )
-        .limit(1)
-        .count()
-        == 0
-    )
+    cond = col(data.column_name).isin(list(allowed_values))
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
+    return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
 @register_builtin_check(
@@ -248,8 +292,9 @@ def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool:
     )
 )
 def notin(
-    data: PysparkDataframeColumnObject, forbidden_values: Iterable
-) -> bool:
+    data: PysparkDataframeColumnObject,
+    forbidden_values: Iterable,
+) -> Union[bool, ps.Column]:
     """Ensure some defined values don't occur within a series.
 
     Remember it can be a compute intensive check on large dataset. So, use it with caution.
@@ -264,14 +309,11 @@ def notin(
     :param forbidden_values: The set of values which should not occur. May
         be any iterable.
     """
-    return (
-        data.dataframe.filter(
-            col(data.column_name).isin(list(forbidden_values))
-        )
-        .limit(1)
-        .count()
-        == 0
-    )
+    cond = col(data.column_name).isin(list(forbidden_values))
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        # here cond marks the *failing* rows, so negate it to keep True = pass
+        return data.dataframe.select(when(~cond, True).otherwise(False))
+    return data.dataframe.filter(cond).limit(1).count() == 0
 
 
 @register_builtin_check(
@@ -279,7 +321,10 @@ def notin(
     error="str_contains('{pattern}')",
 )
 @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
-def str_contains(data: PysparkDataframeColumnObject, pattern: str) -> bool:
+def str_contains(
+    data: PysparkDataframeColumnObject,
+    pattern: re.Pattern,
+) -> Union[bool, ps.Column]:
     """Ensure that a pattern can be found within each row.
 
     Remember it can be a compute intensive check on large dataset. So, use it with caution.
@@ -288,10 +333,9 @@ def str_contains(data: PysparkDataframeColumnObject, pattern: str) -> bool:
     :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys
         to access the dataframe is "dataframe" and column name using "column_name".
     :param pattern: Regular expression pattern to use for searching
     """
-    return (
-        data.dataframe.filter(~col(data.column_name).rlike(pattern))
-        .limit(1)
-        .count()
-        == 0
-    )
+    cond = col(data.column_name).rlike(pattern.pattern)
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
+    return data.dataframe.filter(~cond).limit(1).count() == 0
@@ -300,7 +344,10 @@ def str_contains(data: PysparkDataframeColumnObject, pattern: str) -> bool:
     error="str_startswith('{string}')",
 )
 @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
-def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool:
+def str_startswith(
+    data: PysparkDataframeColumnObject,
+    string: str,
+) -> Union[bool, ps.Column]:
     """Ensure that all values start with a certain string.
 
     Remember it can be a compute intensive check on large dataset. So, use it with caution.
@@ -310,6 +357,9 @@ def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool:
     :param string: String all values should start with
     """
     cond = col(data.column_name).startswith(string)
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
 
 
@@ -317,7 +367,10 @@ def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool:
     strategy=st.str_endswith_strategy, error="str_endswith('{string}')"
 )
 @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE))
-def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool:
+def str_endswith(
+    data: PysparkDataframeColumnObject,
+    string: str,
+) -> Union[bool, ps.Column]:
     """Ensure that all values end with a certain string.
 
     Remember it can be a compute intensive check on large dataset. So, use it with caution.
@@ -327,4 +380,7 @@ def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool:
     :param string: String all values should end with
     """
     cond = col(data.column_name).endswith(string)
+    should_validate_full_table = get_full_table_validation()
+    if should_validate_full_table:
+        return data.dataframe.select(when(cond, True).otherwise(False))
     return data.dataframe.filter(~cond).limit(1).count() == 0
diff --git a/pandera/backends/pyspark/utils.py b/pandera/backends/pyspark/utils.py
index 4ba71884f..debd49f26 100644
--- a/pandera/backends/pyspark/utils.py
+++ b/pandera/backends/pyspark/utils.py
@@ -1,5 +1,7 @@
 """pyspark backend utilities."""
 
+from pandera.config import get_config_context, get_config_global
+
 
 def convert_to_list(*args):
     """Converts arguments to a list"""
@@ -11,3 +13,23 @@ def convert_to_list(*args):
         converted_list.append(arg)
 
     return converted_list
+
+
+def get_full_table_validation():
+    """
+    Get the full table validation configuration.
+    By default, full table validation is disabled for pyspark dataframes for performance reasons.
+ """ + config_global = get_config_global() + config_ctx = get_config_context(full_table_validation_default=None) + + if config_ctx.full_table_validation is not None: + # use context configuration if specified + return config_ctx.full_table_validation + + if config_global.full_table_validation is not None: + # use global configuration if specified + return config_global.full_table_validation + + # full table validation is disabled by default for pyspark dataframes + return False diff --git a/pandera/config.py b/pandera/config.py index 6042a8f63..300cd3592 100644 --- a/pandera/config.py +++ b/pandera/config.py @@ -32,6 +32,7 @@ class PanderaConfig(BaseModel): export PANDERA_VALIDATION_DEPTH=DATA_ONLY export PANDERA_CACHE_DATAFRAME=True export PANDERA_KEEP_CACHED_DATAFRAME=True + export PANDERA_FULL_TABLE_VALIDATION=True """ validation_enabled: bool = True @@ -42,6 +43,8 @@ class PanderaConfig(BaseModel): validation_depth: Optional[ValidationDepth] = None cache_dataframe: bool = False keep_cached_dataframe: bool = False + # flag used to validate the complete dataframe or not, used to filter invalid rows + full_table_validation: Optional[bool] = None # this config variable should be accessible globally @@ -62,6 +65,10 @@ class PanderaConfig(BaseModel): "PANDERA_KEEP_CACHED_DATAFRAME", False, ), + full_table_validation=os.environ.get( + "PANDERA_FULL_TABLE_VALIDATION", + None, + ), ) _CONTEXT_CONFIG = deepcopy(CONFIG) @@ -73,6 +80,7 @@ def config_context( validation_depth: Optional[ValidationDepth] = None, cache_dataframe: Optional[bool] = None, keep_cached_dataframe: Optional[bool] = None, + full_table_validation: Optional[bool] = None, ): """Temporarily set pandera config options to custom settings.""" _outer_config_ctx = get_config_context(validation_depth_default=None) @@ -86,7 +94,8 @@ def config_context( _CONTEXT_CONFIG.cache_dataframe = cache_dataframe if keep_cached_dataframe is not None: _CONTEXT_CONFIG.keep_cached_dataframe = keep_cached_dataframe - + if full_table_validation is not None: + _CONTEXT_CONFIG.full_table_validation = full_table_validation yield finally: reset_config_context(_outer_config_ctx) @@ -108,6 +117,7 @@ def get_config_context( validation_depth_default: Optional[ ValidationDepth ] = ValidationDepth.SCHEMA_AND_DATA, + full_table_validation_default: Optional[bool] = None, ) -> PanderaConfig: """Gets the configuration context.""" config = deepcopy(_CONTEXT_CONFIG) @@ -115,4 +125,7 @@ def get_config_context( if config.validation_depth is None and validation_depth_default: config.validation_depth = validation_depth_default + if config.full_table_validation is None and full_table_validation_default: + config.full_table_validation = full_table_validation_default + return config diff --git a/tests/core/test_pandas_config.py b/tests/core/test_pandas_config.py index 7868a186d..e7c311685 100644 --- a/tests/core/test_pandas_config.py +++ b/tests/core/test_pandas_config.py @@ -49,12 +49,47 @@ class TestSchema(DataFrameModel): "keep_cached_dataframe": False, "validation_enabled": False, "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "full_table_validation": None, } assert get_config_context().dict() == expected assert pandera_schema.validate(self.sample_data) is self.sample_data assert TestSchema.validate(self.sample_data) is self.sample_data + def test_full_table_validation_true(self): + """ + Validates the full table validation is true for default pandas backend + #TODO: Need to check if there is anything else that needs to be tested here. 
+ """ + with config_context(full_table_validation=True): + expected = { + "cache_dataframe": False, + "keep_cached_dataframe": False, + "validation_enabled": False, + "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "full_table_validation": True, + } + pandera_schema = DataFrameSchema( + { + "product": pa.Column( + str, pa.Check(lambda s: s.str.startswith("B")) + ), + "price_val": pa.Column(int), + } + ) + + class TestSchema(DataFrameModel): + """Test Schema class""" + + product: str = pa.Field(str_startswith="B") + price_val: int = pa.Field() + + assert get_config_context().dict() == expected + assert ( + pandera_schema.validate(self.sample_data) is self.sample_data + ) + assert TestSchema.validate(self.sample_data) is self.sample_data + class TestPandasSeriesConfig: """Class to test all the different configs types""" @@ -69,6 +104,7 @@ def test_disable_validation(self): "keep_cached_dataframe": False, "validation_enabled": False, "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "full_table_validation": None, } pandera_schema = SeriesSchema( int, pa.Check(lambda s: s.value_counts() == 2, element_wise=False) diff --git a/tests/pyspark/test_pyspark_config.py b/tests/pyspark/test_pyspark_config.py index ffb7a2119..274ee51e2 100644 --- a/tests/pyspark/test_pyspark_config.py +++ b/tests/pyspark/test_pyspark_config.py @@ -49,6 +49,7 @@ class TestSchema(DataFrameModel): "validation_depth": ValidationDepth.SCHEMA_AND_DATA, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } with config_context(validation_enabled=False): @@ -72,6 +73,7 @@ def test_schema_only(self, spark_session, sample_spark_schema, request): "validation_depth": ValidationDepth.SCHEMA_ONLY, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } input_df = spark_df(spark, self.sample_data, sample_spark_schema) @@ -159,6 +161,7 @@ def test_data_only(self, spark_session, sample_spark_schema, request): "validation_depth": ValidationDepth.DATA_ONLY, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } input_df = spark_df(spark, self.sample_data, sample_spark_schema) @@ -255,6 +258,7 @@ def test_schema_and_data( "validation_depth": ValidationDepth.SCHEMA_AND_DATA, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } input_df = spark_df(spark, self.sample_data, sample_spark_schema) @@ -383,9 +387,23 @@ def test_cache_dataframe_settings( "validation_depth": ValidationDepth.SCHEMA_AND_DATA, "cache_dataframe": cache_dataframe, "keep_cached_dataframe": keep_cached_dataframe, + "full_table_validation": None, } with config_context( cache_dataframe=cache_dataframe, keep_cached_dataframe=keep_cached_dataframe, ): assert get_config_context().dict() == expected + + @pytest.mark.parametrize("full_table_validation", [True, False]) + def test_full_table_validation_settings(self, full_table_validation): + """This function validates that the full table validation is set correctly.""" + expected = { + "validation_enabled": True, + "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "cache_dataframe": False, + "keep_cached_dataframe": False, + "full_table_validation": full_table_validation, + } + with config_context(full_table_validation=full_table_validation): + assert get_config_context().dict() == expected