From ba02957f4429f70b25791dd231ef3dc2eead9efa Mon Sep 17 00:00:00 2001 From: Zaheer Abbas Date: Mon, 13 May 2024 01:46:38 +0530 Subject: [PATCH 1/5] Add `PANDERA_FULL_TABLE_VALIDATION` config for full table validation - Add full table validation support for pyspark backend Signed-off-by: Zaheer Abbas --- pandera/backends/pyspark/utils.py | 13 +++++++++++++ pandera/config.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pandera/backends/pyspark/utils.py b/pandera/backends/pyspark/utils.py index 4ba71884f..64b58f951 100644 --- a/pandera/backends/pyspark/utils.py +++ b/pandera/backends/pyspark/utils.py @@ -1,5 +1,7 @@ """pyspark backend utilities.""" +from pandera.config import get_config_context + def convert_to_list(*args): """Converts arguments to a list""" @@ -11,3 +13,14 @@ def convert_to_list(*args): converted_list.append(arg) return converted_list + + +def get_full_table_validation(): + """ + Get the full table validation configuration. + - By default, full table validation is disabled for pyspark dataframes for performance reasons. + """ + config = get_config_context() + if config.full_table_validation is not None: + return config.full_table_validation + return False diff --git a/pandera/config.py b/pandera/config.py index 6042a8f63..a7f60fbf5 100644 --- a/pandera/config.py +++ b/pandera/config.py @@ -32,6 +32,7 @@ class PanderaConfig(BaseModel): export PANDERA_VALIDATION_DEPTH=DATA_ONLY export PANDERA_CACHE_DATAFRAME=True export PANDERA_KEEP_CACHED_DATAFRAME=True + export PANDERA_FULL_TABLE_VALIDATION=True """ validation_enabled: bool = True @@ -42,6 +43,8 @@ class PanderaConfig(BaseModel): validation_depth: Optional[ValidationDepth] = None cache_dataframe: bool = False keep_cached_dataframe: bool = False + # flag used to validate the complete dataframe or not, used to filter invalid rows + full_table_validation: Optional[bool] = None # this config variable should be accessible globally @@ -62,6 +65,10 @@ class PanderaConfig(BaseModel): "PANDERA_KEEP_CACHED_DATAFRAME", False, ), + full_table_validation=os.environ.get( + "PANDERA_FULL_TABLE_VALIDATION", + None, + ), ) _CONTEXT_CONFIG = deepcopy(CONFIG) @@ -73,6 +80,7 @@ def config_context( validation_depth: Optional[ValidationDepth] = None, cache_dataframe: Optional[bool] = None, keep_cached_dataframe: Optional[bool] = None, + full_table_validation: Optional[bool] = None, ): """Temporarily set pandera config options to custom settings.""" _outer_config_ctx = get_config_context(validation_depth_default=None) @@ -86,7 +94,8 @@ def config_context( _CONTEXT_CONFIG.cache_dataframe = cache_dataframe if keep_cached_dataframe is not None: _CONTEXT_CONFIG.keep_cached_dataframe = keep_cached_dataframe - + if full_table_validation is not None: + _CONTEXT_CONFIG.full_table_validation = full_table_validation yield finally: reset_config_context(_outer_config_ctx) From eed319c29dc5cb44255fa65ca91e9fa154e5336a Mon Sep 17 00:00:00 2001 From: Zaheer Abbas Date: Mon, 13 May 2024 01:49:30 +0530 Subject: [PATCH 2/5] Modify all builtin_checks for pyspark backend to support full table validation Signed-off-by: Zaheer Abbas --- pandera/backends/pyspark/builtin_checks.py | 152 +++++++++++++++------ pandera/backends/pyspark/decorators.py | 22 +++ 2 files changed, 129 insertions(+), 45 deletions(-) diff --git a/pandera/backends/pyspark/builtin_checks.py b/pandera/backends/pyspark/builtin_checks.py index 0725fbafd..5d1cdd7b2 100644 --- a/pandera/backends/pyspark/builtin_checks.py +++ b/pandera/backends/pyspark/builtin_checks.py @@ -1,16 +1,21 @@ """PySpark implementation of built-in checks""" import re -from typing import Any, Iterable, TypeVar - +from typing import Any, Iterable, TypeVar, Union +import pyspark.sql as ps import pyspark.sql.types as pst -from pyspark.sql.functions import col +from pyspark.sql.functions import col, when import pandera.strategies as st from pandera.api.extensions import register_builtin_check from pandera.api.pyspark.types import PysparkDataframeColumnObject -from pandera.backends.pyspark.decorators import register_input_datatypes -from pandera.backends.pyspark.utils import convert_to_list +from pandera.backends.pyspark.decorators import ( + builtin_check_validation_mode, + register_input_datatypes, +) +from pandera.backends.pyspark.utils import ( + convert_to_list, +) T = TypeVar("T") ALL_NUMERIC_TYPE = [ @@ -37,7 +42,12 @@ ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE ) ) -def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: +@builtin_check_validation_mode() +def equal_to( + data: PysparkDataframeColumnObject, + value: Any, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure all elements of a data container equal a certain value. :param data: PysparkDataframeColumnObject column object which is a contains dataframe and column name to do the check @@ -45,6 +55,8 @@ def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: equal to this value. """ cond = col(data.column_name) == value + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -58,7 +70,12 @@ def equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE ) ) -def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: +@builtin_check_validation_mode() +def not_equal_to( + data: PysparkDataframeColumnObject, + value: Any, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure no elements of a data container equals a certain value. :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys @@ -66,6 +83,8 @@ def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: :param value: This value must not occur in the checked """ cond = col(data.column_name) != value + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -76,7 +95,12 @@ def not_equal_to(data: PysparkDataframeColumnObject, value: Any) -> bool: @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool: +@builtin_check_validation_mode() +def greater_than( + data: PysparkDataframeColumnObject, + min_value: Any, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """ Ensure values of a data container are strictly greater than a minimum value. @@ -85,6 +109,8 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool: :param min_value: Lower bound to be exceeded. """ cond = col(data.column_name) > min_value + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -96,9 +122,12 @@ def greater_than(data: PysparkDataframeColumnObject, min_value: Any) -> bool: @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) +@builtin_check_validation_mode() def greater_than_or_equal_to( - data: PysparkDataframeColumnObject, min_value: Any -) -> bool: + data: PysparkDataframeColumnObject, + min_value: Any, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure all values are greater or equal a certain value. :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys to access the dataframe is "dataframe" and column name using "column_name". @@ -106,6 +135,8 @@ def greater_than_or_equal_to( a type comparable to the dtype of the column datatype of pyspark """ cond = col(data.column_name) >= min_value + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -117,7 +148,12 @@ def greater_than_or_equal_to( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool: +@builtin_check_validation_mode() +def less_than( + data: PysparkDataframeColumnObject, + max_value: Any, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure values of a series are strictly below a maximum value. :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys @@ -129,6 +165,8 @@ def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool: if max_value is None: # pragma: no cover raise ValueError("max_value must not be None") cond = col(data.column_name) < max_value + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -140,9 +178,12 @@ def less_than(data: PysparkDataframeColumnObject, max_value: Any) -> bool: @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) +@builtin_check_validation_mode() def less_than_or_equal_to( - data: PysparkDataframeColumnObject, max_value: Any -) -> bool: + data: PysparkDataframeColumnObject, + max_value: Any, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure values of a series are strictly below a maximum value. :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys @@ -154,6 +195,8 @@ def less_than_or_equal_to( if max_value is None: # pragma: no cover raise ValueError("max_value must not be None") cond = col(data.column_name) <= max_value + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -165,13 +208,15 @@ def less_than_or_equal_to( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) +@builtin_check_validation_mode() def in_range( data: PysparkDataframeColumnObject, min_value: T, max_value: T, include_min: bool = True, include_max: bool = True, -): + should_validate_full_table: bool = False, +) -> Union[bool, ps.Column]: """Ensure all values of a column are within an interval. Both endpoints must be a type comparable to the dtype of the @@ -201,7 +246,10 @@ def in_range( if include_max else col(data.column_name) < max_value ) - return data.dataframe.filter(~(cond_right & cond_left)).limit(1).count() == 0 # type: ignore + cond = cond_right & cond_left + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) + return data.dataframe.filter(~cond).limit(1).count() == 0 # type: ignore @register_builtin_check( @@ -213,7 +261,12 @@ def in_range( ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE ) ) -def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool: +@builtin_check_validation_mode() +def isin( + data: PysparkDataframeColumnObject, + allowed_values: Iterable, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure only allowed values occur within a series. Remember it can be a compute intensive check on large dataset. So, use it with caution. @@ -229,14 +282,10 @@ def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool: to access the dataframe is "dataframe" and column name using "column_name". :param allowed_values: The set of allowed values. May be any iterable. """ - return ( - data.dataframe.filter( - ~col(data.column_name).isin(list(allowed_values)) - ) - .limit(1) - .count() - == 0 - ) + cond = col(data.column_name).isin(list(allowed_values)) + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) + return data.dataframe.filter(~cond).limit(1).count() == 0 @register_builtin_check( @@ -248,9 +297,12 @@ def isin(data: PysparkDataframeColumnObject, allowed_values: Iterable) -> bool: ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE ) ) +@builtin_check_validation_mode() def notin( - data: PysparkDataframeColumnObject, forbidden_values: Iterable -) -> bool: + data: PysparkDataframeColumnObject, + forbidden_values: Iterable, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure some defined values don't occur within a series. Remember it can be a compute intensive check on large dataset. So, use it with caution. @@ -265,14 +317,10 @@ def notin( :param forbidden_values: The set of values which should not occur. May be any iterable. """ - return ( - data.dataframe.filter( - col(data.column_name).isin(list(forbidden_values)) - ) - .limit(1) - .count() - == 0 - ) + cond = col(data.column_name).isin(list(forbidden_values)) + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) + return data.dataframe.filter(cond).limit(1).count() == 0 @register_builtin_check( @@ -280,9 +328,12 @@ def notin( error="str_contains('{pattern}')", ) @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) +@builtin_check_validation_mode() def str_contains( - data: PysparkDataframeColumnObject, pattern: re.Pattern -) -> bool: + data: PysparkDataframeColumnObject, + pattern: re.Pattern, + should_validate_full_table: bool, +) -> Union[bool, ps.Column]: """Ensure that a pattern can be found within each row. Remember it can be a compute intensive check on large dataset. So, use it with caution. @@ -291,20 +342,22 @@ def str_contains( to access the dataframe is "dataframe" and column name using "column_name". :param pattern: Regular expression pattern to use for searching """ - - return ( - data.dataframe.filter(~col(data.column_name).rlike(pattern.pattern)) - .limit(1) - .count() - == 0 - ) + cond = col(data.column_name).rlike(pattern.pattern) + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) + return data.dataframe.filter(~cond).limit(1).count() == 0 @register_builtin_check( error="str_startswith('{string}')", ) @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) -def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool: +@builtin_check_validation_mode() +def str_startswith( + data: PysparkDataframeColumnObject, + string: str, + should_validate_full_table: bool, +) -> bool: """Ensure that all values start with a certain string. Remember it can be a compute intensive check on large dataset. So, use it with caution. @@ -314,6 +367,8 @@ def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool: :param string: String all values should start with """ cond = col(data.column_name).startswith(string) + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -321,7 +376,12 @@ def str_startswith(data: PysparkDataframeColumnObject, string: str) -> bool: strategy=st.str_endswith_strategy, error="str_endswith('{string}')" ) @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) -def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool: +@builtin_check_validation_mode() +def str_endswith( + data: PysparkDataframeColumnObject, + string: str, + should_validate_full_table: bool, +) -> bool: """Ensure that all values end with a certain string. Remember it can be a compute intensive check on large dataset. So, use it with caution. @@ -331,4 +391,6 @@ def str_endswith(data: PysparkDataframeColumnObject, string: str) -> bool: :param string: String all values should end with """ cond = col(data.column_name).endswith(string) + if should_validate_full_table: + return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py index c7bc9b928..fb847a5ea 100644 --- a/pandera/backends/pyspark/decorators.py +++ b/pandera/backends/pyspark/decorators.py @@ -9,6 +9,7 @@ from pyspark.sql import DataFrame from pandera.api.pyspark.types import PysparkDefaultTypes +from pandera.backends.pyspark.utils import get_full_table_validation from pandera.config import ValidationDepth, get_config_context from pandera.errors import SchemaError from pandera.validation_depth import ValidationScope @@ -192,3 +193,24 @@ def cached_check_obj(): return wrapper return _wrapper + + +def builtin_check_validation_mode(): + """ + Evaluates whether the full table validation is enabled or not for a builtin check and passes it to the function. + """ + + def _wrapper(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Skip if not enabled + should_validate_full_table = get_full_table_validation() + return func( + *args, + **kwargs, + should_validate_full_table=should_validate_full_table, + ) + + return wrapper + + return _wrapper From cd2f76ce5a24739422668c9e595533be776f39b6 Mon Sep 17 00:00:00 2001 From: Zaheer Abbas Date: Sat, 18 May 2024 23:36:19 +0530 Subject: [PATCH 3/5] Use the same function signature for builtin_checks across backends - Remove unused decorators Signed-off-by: Zaheer Abbas --- pandera/backends/pyspark/builtin_checks.py | 39 ++++++++-------------- pandera/backends/pyspark/decorators.py | 22 ------------ 2 files changed, 14 insertions(+), 47 deletions(-) diff --git a/pandera/backends/pyspark/builtin_checks.py b/pandera/backends/pyspark/builtin_checks.py index 5d1cdd7b2..41cd12ee7 100644 --- a/pandera/backends/pyspark/builtin_checks.py +++ b/pandera/backends/pyspark/builtin_checks.py @@ -10,11 +10,11 @@ from pandera.api.extensions import register_builtin_check from pandera.api.pyspark.types import PysparkDataframeColumnObject from pandera.backends.pyspark.decorators import ( - builtin_check_validation_mode, register_input_datatypes, ) from pandera.backends.pyspark.utils import ( convert_to_list, + get_full_table_validation, ) T = TypeVar("T") @@ -28,6 +28,7 @@ pst.FloatType, ] ALL_DATE_TYPE = [pst.DateType, pst.TimestampType] +# TODO: Fix the boolean typo in a new PR or in a different commit if that is acceptable BOLEAN_TYPE = pst.BooleanType BINARY_TYPE = pst.BinaryType STRING_TYPE = pst.StringType @@ -42,11 +43,9 @@ ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE ) ) -@builtin_check_validation_mode() def equal_to( data: PysparkDataframeColumnObject, value: Any, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure all elements of a data container equal a certain value. @@ -55,6 +54,7 @@ def equal_to( equal to this value. """ cond = col(data.column_name) == value + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -70,11 +70,9 @@ def equal_to( ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE, BOLEAN_TYPE ) ) -@builtin_check_validation_mode() def not_equal_to( data: PysparkDataframeColumnObject, value: Any, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure no elements of a data container equals a certain value. @@ -83,6 +81,7 @@ def not_equal_to( :param value: This value must not occur in the checked """ cond = col(data.column_name) != value + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -95,11 +94,9 @@ def not_equal_to( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -@builtin_check_validation_mode() def greater_than( data: PysparkDataframeColumnObject, min_value: Any, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """ Ensure values of a data container are strictly greater than a minimum @@ -109,6 +106,7 @@ def greater_than( :param min_value: Lower bound to be exceeded. """ cond = col(data.column_name) > min_value + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -122,11 +120,9 @@ def greater_than( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -@builtin_check_validation_mode() def greater_than_or_equal_to( data: PysparkDataframeColumnObject, min_value: Any, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure all values are greater or equal a certain value. :param data: NamedTuple PysparkDataframeColumnObject contains the dataframe and column name for the check. The keys @@ -135,6 +131,7 @@ def greater_than_or_equal_to( a type comparable to the dtype of the column datatype of pyspark """ cond = col(data.column_name) >= min_value + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -148,11 +145,9 @@ def greater_than_or_equal_to( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -@builtin_check_validation_mode() def less_than( data: PysparkDataframeColumnObject, max_value: Any, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure values of a series are strictly below a maximum value. @@ -165,6 +160,7 @@ def less_than( if max_value is None: # pragma: no cover raise ValueError("max_value must not be None") cond = col(data.column_name) < max_value + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -178,11 +174,9 @@ def less_than( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -@builtin_check_validation_mode() def less_than_or_equal_to( data: PysparkDataframeColumnObject, max_value: Any, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure values of a series are strictly below a maximum value. @@ -195,6 +189,7 @@ def less_than_or_equal_to( if max_value is None: # pragma: no cover raise ValueError("max_value must not be None") cond = col(data.column_name) <= max_value + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -208,14 +203,12 @@ def less_than_or_equal_to( @register_input_datatypes( acceptable_datatypes=convert_to_list(ALL_NUMERIC_TYPE, ALL_DATE_TYPE) ) -@builtin_check_validation_mode() def in_range( data: PysparkDataframeColumnObject, min_value: T, max_value: T, include_min: bool = True, include_max: bool = True, - should_validate_full_table: bool = False, ) -> Union[bool, ps.Column]: """Ensure all values of a column are within an interval. @@ -247,6 +240,7 @@ def in_range( else col(data.column_name) < max_value ) cond = cond_right & cond_left + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 # type: ignore @@ -261,11 +255,9 @@ def in_range( ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE ) ) -@builtin_check_validation_mode() def isin( data: PysparkDataframeColumnObject, allowed_values: Iterable, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure only allowed values occur within a series. @@ -283,6 +275,7 @@ def isin( :param allowed_values: The set of allowed values. May be any iterable. """ cond = col(data.column_name).isin(list(allowed_values)) + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -297,11 +290,9 @@ def isin( ALL_NUMERIC_TYPE, ALL_DATE_TYPE, STRING_TYPE, BINARY_TYPE ) ) -@builtin_check_validation_mode() def notin( data: PysparkDataframeColumnObject, forbidden_values: Iterable, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure some defined values don't occur within a series. @@ -318,6 +309,7 @@ def notin( be any iterable. """ cond = col(data.column_name).isin(list(forbidden_values)) + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(cond).limit(1).count() == 0 @@ -328,11 +320,9 @@ def notin( error="str_contains('{pattern}')", ) @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) -@builtin_check_validation_mode() def str_contains( data: PysparkDataframeColumnObject, pattern: re.Pattern, - should_validate_full_table: bool, ) -> Union[bool, ps.Column]: """Ensure that a pattern can be found within each row. @@ -343,6 +333,7 @@ def str_contains( :param pattern: Regular expression pattern to use for searching """ cond = col(data.column_name).rlike(pattern.pattern) + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -352,11 +343,9 @@ def str_contains( error="str_startswith('{string}')", ) @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) -@builtin_check_validation_mode() def str_startswith( data: PysparkDataframeColumnObject, string: str, - should_validate_full_table: bool, ) -> bool: """Ensure that all values start with a certain string. @@ -367,6 +356,7 @@ def str_startswith( :param string: String all values should start with """ cond = col(data.column_name).startswith(string) + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 @@ -376,11 +366,9 @@ def str_startswith( strategy=st.str_endswith_strategy, error="str_endswith('{string}')" ) @register_input_datatypes(acceptable_datatypes=convert_to_list(STRING_TYPE)) -@builtin_check_validation_mode() def str_endswith( data: PysparkDataframeColumnObject, string: str, - should_validate_full_table: bool, ) -> bool: """Ensure that all values end with a certain string. @@ -391,6 +379,7 @@ def str_endswith( :param string: String all values should end with """ cond = col(data.column_name).endswith(string) + should_validate_full_table = get_full_table_validation() if should_validate_full_table: return data.dataframe.select(when(cond, True).otherwise(False)) return data.dataframe.filter(~cond).limit(1).count() == 0 diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py index fb847a5ea..c7bc9b928 100644 --- a/pandera/backends/pyspark/decorators.py +++ b/pandera/backends/pyspark/decorators.py @@ -9,7 +9,6 @@ from pyspark.sql import DataFrame from pandera.api.pyspark.types import PysparkDefaultTypes -from pandera.backends.pyspark.utils import get_full_table_validation from pandera.config import ValidationDepth, get_config_context from pandera.errors import SchemaError from pandera.validation_depth import ValidationScope @@ -193,24 +192,3 @@ def cached_check_obj(): return wrapper return _wrapper - - -def builtin_check_validation_mode(): - """ - Evaluates whether the full table validation is enabled or not for a builtin check and passes it to the function. - """ - - def _wrapper(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - # Skip if not enabled - should_validate_full_table = get_full_table_validation() - return func( - *args, - **kwargs, - should_validate_full_table=should_validate_full_table, - ) - - return wrapper - - return _wrapper From 1f952d1fe609d7204a1c848ad41c38f4c81b0b22 Mon Sep 17 00:00:00 2001 From: Zaheer Abbas Date: Sun, 19 May 2024 15:32:24 +0530 Subject: [PATCH 4/5] Add full_table_validation as a flag to config context func - Will help to use the flag in backend validate functions Signed-off-by: Zaheer Abbas --- pandera/backends/pyspark/utils.py | 17 +++++++++++++---- pandera/config.py | 4 ++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pandera/backends/pyspark/utils.py b/pandera/backends/pyspark/utils.py index 64b58f951..debd49f26 100644 --- a/pandera/backends/pyspark/utils.py +++ b/pandera/backends/pyspark/utils.py @@ -1,6 +1,6 @@ """pyspark backend utilities.""" -from pandera.config import get_config_context +from pandera.config import get_config_context, get_config_global def convert_to_list(*args): @@ -20,7 +20,16 @@ def get_full_table_validation(): Get the full table validation configuration. - By default, full table validation is disabled for pyspark dataframes for performance reasons. """ - config = get_config_context() - if config.full_table_validation is not None: - return config.full_table_validation + config_global = get_config_global() + config_ctx = get_config_context(full_table_validation_default=None) + + if config_ctx.full_table_validation is not None: + # use context configuration if specified + return config_ctx.full_table_validation + + if config_global.full_table_validation is not None: + # use global configuration if specified + return config_global.full_table_validation + + # full table validation is disabled by default for pyspark dataframes return False diff --git a/pandera/config.py b/pandera/config.py index a7f60fbf5..300cd3592 100644 --- a/pandera/config.py +++ b/pandera/config.py @@ -117,6 +117,7 @@ def get_config_context( validation_depth_default: Optional[ ValidationDepth ] = ValidationDepth.SCHEMA_AND_DATA, + full_table_validation_default: Optional[bool] = None, ) -> PanderaConfig: """Gets the configuration context.""" config = deepcopy(_CONTEXT_CONFIG) @@ -124,4 +125,7 @@ def get_config_context( if config.validation_depth is None and validation_depth_default: config.validation_depth = validation_depth_default + if config.full_table_validation is None and full_table_validation_default: + config.full_table_validation = full_table_validation_default + return config From 14513a6961c77fe6cee57e6423f8a7b2e8f7ccea Mon Sep 17 00:00:00 2001 From: Zaheer Abbas Date: Sun, 19 May 2024 15:33:33 +0530 Subject: [PATCH 5/5] Fix broken tests and add new tests for full_table_validation config - More tests to come for full_table_validation config for built_in_checks after adding support in pyspark backend Signed-off-by: Zaheer Abbas --- tests/core/test_pandas_config.py | 36 ++++++++++++++++++++++++++++ tests/pyspark/test_pyspark_config.py | 18 ++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/tests/core/test_pandas_config.py b/tests/core/test_pandas_config.py index 7868a186d..e7c311685 100644 --- a/tests/core/test_pandas_config.py +++ b/tests/core/test_pandas_config.py @@ -49,12 +49,47 @@ class TestSchema(DataFrameModel): "keep_cached_dataframe": False, "validation_enabled": False, "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "full_table_validation": None, } assert get_config_context().dict() == expected assert pandera_schema.validate(self.sample_data) is self.sample_data assert TestSchema.validate(self.sample_data) is self.sample_data + def test_full_table_validation_true(self): + """ + Validates the full table validation is true for default pandas backend + #TODO: Need to check if there is anything else that needs to be tested here. + """ + with config_context(full_table_validation=True): + expected = { + "cache_dataframe": False, + "keep_cached_dataframe": False, + "validation_enabled": False, + "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "full_table_validation": True, + } + pandera_schema = DataFrameSchema( + { + "product": pa.Column( + str, pa.Check(lambda s: s.str.startswith("B")) + ), + "price_val": pa.Column(int), + } + ) + + class TestSchema(DataFrameModel): + """Test Schema class""" + + product: str = pa.Field(str_startswith="B") + price_val: int = pa.Field() + + assert get_config_context().dict() == expected + assert ( + pandera_schema.validate(self.sample_data) is self.sample_data + ) + assert TestSchema.validate(self.sample_data) is self.sample_data + class TestPandasSeriesConfig: """Class to test all the different configs types""" @@ -69,6 +104,7 @@ def test_disable_validation(self): "keep_cached_dataframe": False, "validation_enabled": False, "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "full_table_validation": None, } pandera_schema = SeriesSchema( int, pa.Check(lambda s: s.value_counts() == 2, element_wise=False) diff --git a/tests/pyspark/test_pyspark_config.py b/tests/pyspark/test_pyspark_config.py index d4a4de023..2d93eabe5 100644 --- a/tests/pyspark/test_pyspark_config.py +++ b/tests/pyspark/test_pyspark_config.py @@ -42,6 +42,7 @@ class TestSchema(DataFrameModel): "validation_depth": ValidationDepth.SCHEMA_AND_DATA, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } with config_context(validation_enabled=False): @@ -65,6 +66,7 @@ def test_schema_only(self, spark, sample_spark_schema): "validation_depth": ValidationDepth.SCHEMA_ONLY, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } input_df = spark_df(spark, self.sample_data, sample_spark_schema) @@ -152,6 +154,7 @@ def test_data_only(self, spark, sample_spark_schema): "validation_depth": ValidationDepth.DATA_ONLY, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } input_df = spark_df(spark, self.sample_data, sample_spark_schema) @@ -245,6 +248,7 @@ def test_schema_and_data(self, spark, sample_spark_schema): "validation_depth": ValidationDepth.SCHEMA_AND_DATA, "cache_dataframe": False, "keep_cached_dataframe": False, + "full_table_validation": None, } input_df = spark_df(spark, self.sample_data, sample_spark_schema) @@ -372,9 +376,23 @@ def test_cache_dataframe_settings( "validation_depth": ValidationDepth.SCHEMA_AND_DATA, "cache_dataframe": cache_dataframe, "keep_cached_dataframe": keep_cached_dataframe, + "full_table_validation": None, } with config_context( cache_dataframe=cache_dataframe, keep_cached_dataframe=keep_cached_dataframe, ): assert get_config_context().dict() == expected + + @pytest.mark.parametrize("full_table_validation", [True, False]) + def test_full_table_validation_settings(self, full_table_validation): + """This function validates that the full table validation is set correctly.""" + expected = { + "validation_enabled": True, + "validation_depth": ValidationDepth.SCHEMA_AND_DATA, + "cache_dataframe": False, + "keep_cached_dataframe": False, + "full_table_validation": full_table_validation, + } + with config_context(full_table_validation=full_table_validation): + assert get_config_context().dict() == expected