@@ -34,13 +34,14 @@ def _add_names_to_parents_idx_series(parents):
3434
3535class QDForest (object ):
3636 """A quasi-deterministic forest returned by `qd_screen`"""
37- __slots__ = ('_adjmat' , # a square numpy array or pandas DataFrame containing the adjacency matrix (parent->child)
38- '_parents' , # a 1d np array or a pandas Series relating each child to its parent index or -1 if a root
39- 'is_nparray' , # a boolean indicating if this was built from numpy array (and not pandas dataframe)
40- '_roots_mask' , # a 1d np array or pd Series containing a boolean mask for root variables
41- '_roots_wc_mask' , # a 1d np array or pd Series containing a boolean mask for root with children
42- 'stats' # an optional `Entropies` object stored for debug
43- )
37+ __slots__ = (
38+ '_adjmat' , # a square np array or pd DataFrame containing the adjacency matrix (parent->child)
39+ '_parents' , # a 1d np array or a pandas Series relating each child to its parent index or -1 if a root
40+ 'is_nparray' , # a boolean indicating if this was built from numpy array (and not pandas dataframe)
41+ '_roots_mask' , # a 1d np array or pd Series containing a boolean mask for root variables
42+ '_roots_wc_mask' , # a 1d np array or pd Series containing a boolean mask for root with children
43+ 'stats' # an optional `Entropies` object stored for debug
44+ )
4445
4546 def __init__ (self ,
4647 adjmat = None , # type: Union[np.ndarray, pd.DataFrame]
@@ -129,13 +130,13 @@ def mask_to_indices(self, mask):
129130 @property
130131 def adjmat_ar (self ):
131132 """The adjacency matrix as a 2D numpy array"""
132- return self .adjmat if self .is_nparray else self .adjmat .values
133+ return self .adjmat if self .is_nparray else self .adjmat .values
133134
134135 @property
135136 def adjmat (self ):
136137 """The adjacency matrix as a pandas DataFrame or a 2D numpy array"""
137138 if self ._adjmat is None :
138- # compute adjmat from parents.
139+ # compute adjmat from parents and cache it
139140 n = self .nb_vars
140141 adjmat = np .zeros ((n , n ), dtype = bool )
141142 # from https://stackoverflow.com/a/46018613/7262247
@@ -543,10 +544,30 @@ def plot_increasing_entropies(self):
543544 self .stats .plot_increasing_entropies ()
544545
545546
546- def qd_screen (X , # type: Union[pd.DataFrame, np.ndarray]
547+ def assert_df_or_2D_array (df_or_array # type: Union[pd.DataFrame, np.ndarray]
548+ ):
549+ """
550+ Raises a ValueError if `df_or_array` is
551+
552+ :param df_or_array:
553+ :return:
554+ """
555+ if isinstance (df_or_array , pd .DataFrame ):
556+ pass
557+ elif isinstance (df_or_array , np .ndarray ):
558+ # see https://numpy.org/doc/stable/user/basics.rec.html#manipulating-and-displaying-structured-datatypes
559+ if len (df_or_array .shape ) != 2 :
560+ raise ValueError ("Provided data is not a 2D array, the number of dimensions is %s" % len (df_or_array .shape ))
561+ else :
562+ # Raise error
563+ raise TypeError ("Provided data is neither a `pd.DataFrame` nor a `np.ndarray`" )
564+
565+
566+ def qd_screen (X , # type: Union[pd.DataFrame, np.ndarray]
547567 absolute_eps = None , # type: float
548568 relative_eps = None , # type: float
549- keep_stats = False # type: bool
569+ keep_stats = False , # type: bool
570+ non_categorical_mode = 'strict' ,
550571 ):
551572 # type: (...) -> QDForest
552573 """
@@ -574,12 +595,18 @@ def qd_screen(X, # type: Union[pd.DataFrame, np.ndarray]
574595 memory in the resulting forest object (`<QDForest>.stats`), for further analysis. By default this is `False`.
575596 :return:
576597 """
577- # only work on the categorical features
578- X = get_categorical_features (X )
598+ # Make sure this is a 2D table
599+ assert_df_or_2D_array (X )
579600
580- # sanity check
601+ # Sanity check: are there rows in here ?
581602 if len (X ) == 0 :
582- raise ValueError ("Empty dataset provided" )
603+ raise ValueError ("Provided dataset does not contain any row" )
604+
605+ # Only work on the categorical features
606+ X = get_categorical_features (X , non_categorical_mode = non_categorical_mode )
607+
608+ # Sanity check concerning the number of columns
609+ assert X .shape [1 ] > 0 , "Internal error: no columns remain in dataset after preprocessing."
583610
584611 # parameters check and defaults
585612 if absolute_eps is None :
@@ -1143,28 +1170,49 @@ def get_arcs_from_adjmat(A, # type: Union[np.ndarray, pd.DataFra
11431170 return ((cols [i ], cols [j ]) for i , j in zip (* res_ar ))
11441171
11451172
1146- def get_categorical_features (df_or_array # type: Union[np.ndarray, pd.DataFrame]
1173+ def get_categorical_features (df_or_array , # type: Union[np.ndarray, pd.DataFrame]
1174+ non_categorical_mode = "strict" # type: str
11471175 ):
11481176 # type: (...) -> Union[np.ndarray, pd.DataFrame]
11491177 """
11501178
11511179 :param df_or_array:
1180+ :param non_categorical_mode:
11521181 :return: a dataframe or array with the categorical features
11531182 """
1183+ assert_df_or_2D_array (df_or_array )
1184+
1185+ if non_categorical_mode == "strict" :
1186+ strict_mode = True
1187+ elif non_categorical_mode == "remove" :
1188+ strict_mode = False
1189+ else :
1190+ raise ValueError ("Unsupported value for `non_categorical_mode`: %r" % non_categorical_mode )
1191+
11541192 if isinstance (df_or_array , pd .DataFrame ):
11551193 is_categorical_dtype = df_or_array .dtypes .astype (str ).isin (["object" , "categorical" ])
1156- if not is_categorical_dtype .any ():
1157- raise TypeError ("Provided dataframe columns do not contain any categorical datatype (dtype in 'object' or "
1194+ if strict_mode and not is_categorical_dtype .all ():
1195+ raise ValueError ("Provided dataframe columns contains non-categorical datatypes (dtype in 'object' or "
1196+ "'categorical'): found dtypes %r. This is not supported when `non_categorical_mode` is set to "
1197+ "`'strict'`" % df_or_array .dtypes [~ is_categorical_dtype ].to_dict ())
1198+ elif not is_categorical_dtype .any ():
1199+ raise ValueError ("Provided dataframe columns do not contain any categorical datatype (dtype in 'object' or "
11581200 "'categorical'): found dtypes %r" % df_or_array .dtypes [~ is_categorical_dtype ].to_dict ())
11591201 return df_or_array .loc [:, is_categorical_dtype ]
1202+
11601203 elif isinstance (df_or_array , np .ndarray ):
11611204 # see https://numpy.org/doc/stable/user/basics.rec.html#manipulating-and-displaying-structured-datatypes
11621205 if df_or_array .dtype .names is not None :
11631206 # structured array
11641207 is_categorical_dtype = np .array ([str (df_or_array .dtype .fields [n ][0 ]) == "object"
11651208 for n in df_or_array .dtype .names ])
1166- if not is_categorical_dtype .any ():
1167- raise TypeError (
1209+ if strict_mode and not is_categorical_dtype .all ():
1210+ invalid_dtypes = df_or_array .dtype [~ is_categorical_dtype ].asdict ()
1211+ raise ValueError ("Provided numpy array columns contains non-categorical datatypes ('object' dtype): "
1212+ "found dtypes %r. This is not supported when `non_categorical_mode` is set to "
1213+ "`'strict'`" % invalid_dtypes )
1214+ elif not is_categorical_dtype .any ():
1215+ raise ValueError (
11681216 "Provided dataframe columns do not contain any categorical datatype (dtype in 'object' or "
11691217 "'categorical'): found dtypes %r" % df_or_array .dtype .fields )
11701218 categorical_names = np .array (df_or_array .dtype .names )[is_categorical_dtype ]
@@ -1176,6 +1224,7 @@ def get_categorical_features(df_or_array # type: Union[np.ndarray, pd.DataFrame
11761224 % df_or_array .dtype )
11771225 return df_or_array
11781226 else :
1227+ # Should not happen since `assert_df_or_2D_array` is called upfront now.
11791228 raise TypeError ("Provided data is neither a pd.DataFrame nor a np.ndarray" )
11801229
11811230
0 commit comments