From 4c4db5f92e0b0a85eece6c04eca2addcc508105f Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Thu, 27 Feb 2025 16:55:35 +0100 Subject: [PATCH] Datatype Support in Quality Control and Impute (#865) * Enhancement: Add Dask support for explicit imputation * Enhancement: Add Dask support for quality control metrics and imputation tests * Fix test for imputation to handle Dask arrays without raising errors * Refactor quality control metrics functions to streamline computation and improve readability * added expected error * Remove unused Dask import from quality control module * simplify missing value computation * Rename parameter 'arr' to 'mtx' in _compute_obs_metrics no longer creates copy * daskify qc_metrics * Add fixture for array types and update imputation tests for dask arrays * Refactor _compute_var_metrics to prevent modification of the original data matrix and add a test for encoding mode integrity * Add parameterized tests for array types in miceforest imputation * Update missing values handling to include array type in error message and refine parameterized tests for miceforest imputation * Fix array type handling in missing values computation and update test for miceforest imputation * Implement array type handling in load_dataframe function and update tests for miceforest imputation * Remove parameterization for array types in miceforest numerical data imputation test * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * revert deepcopy changes * Fix test to ensure original matrix is not modified after encoding * Remove unused parameters from observation and variable metrics computation functions * Add sparse.csr_matrix to explicit impute array types test case * Parameterize quality control metrics tests to support multiple array types * Remove unused imports from test_quality_control.py * encode blocks dask function * Add pytest fixtures for observation and variable data in tests * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * support dask explicit impute all object types --------- Co-authored-by: eroell Co-authored-by: Lukas Heumos Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> --- ehrapy/preprocessing/_imputation.py | 48 +++++++- ehrapy/preprocessing/_quality_control.py | 97 +++++++++------ tests/conftest.py | 53 +++++++- tests/preprocessing/test_imputation.py | 47 ++++++- tests/preprocessing/test_quality_control.py | 129 ++++++++++++-------- 5 files changed, 277 insertions(+), 97 deletions(-) diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index ab68ef5c..0590893b 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Iterable +from functools import singledispatch from typing import TYPE_CHECKING, Literal import numpy as np @@ -11,7 +12,7 @@ from sklearn.impute import SimpleImputer from ehrapy import settings -from ehrapy._compat import _check_module_importable +from ehrapy._compat import _check_module_importable, _raise_array_type_not_implemented from ehrapy._progress import spinner from ehrapy.anndata import check_feature_types from ehrapy.anndata.anndata_ext import ( @@ -23,6 +24,13 @@ if TYPE_CHECKING: from anndata import AnnData +try: + import dask.array as da + + DASK_AVAILABLE = True +except ImportError: + DASK_AVAILABLE = False + @spinner("Performing explicit impute") def explicit_impute( @@ -76,7 +84,9 @@ def explicit_impute( imputation_value = _extract_impute_value(replacement, column_name) # only replace if an explicit value got passed or could be extracted from replacement if imputation_value: - _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings) + adata.X[:, idx : idx + 1] = _replace_explicit( + adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings + ) else: logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.") else: @@ -87,13 +97,33 @@ def explicit_impute( return adata if copy else None -def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None: +@singledispatch +def _replace_explicit(arr, replacement: str | int, impute_empty_strings: bool) -> None: + _raise_array_type_not_implemented(_replace_explicit, type(arr)) + + +@_replace_explicit.register +def _(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> np.ndarray: """Replace one column or whole X with a value where missing values are stored.""" if not impute_empty_strings: # pragma: no cover impute_conditions = pd.isnull(arr) else: impute_conditions = np.logical_or(pd.isnull(arr), arr == "") arr[impute_conditions] = replacement + return arr + + +if DASK_AVAILABLE: + + @_replace_explicit.register(da.Array) + def _(arr: da.Array, replacement: str | int, impute_empty_strings: bool) -> da.Array: + """Replace one column or whole X with a value where missing values are stored.""" + if not impute_empty_strings: # pragma: no cover + impute_conditions = da.isnull(arr) + else: + impute_conditions = da.logical_or(da.isnull(arr), arr == "") + arr[impute_conditions] = replacement + return arr def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -> str | int | None: @@ -469,12 +499,22 @@ def mice_forest_impute( return adata if copy else None +@singledispatch +def load_dataframe(arr, columns, index): + _raise_array_type_not_implemented(load_dataframe, type(arr)) + + +@load_dataframe.register +def _(arr: np.ndarray, columns, index): + return pd.DataFrame(arr, columns=columns, index=index) + + def _miceforest_impute( adata, var_names, save_all_iterations_data, random_state, inplace, iterations, variable_parameters, verbose ) -> None: import miceforest as mf - data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names) + data_df = load_dataframe(adata.X, columns=adata.var_names, index=adata.obs_names) data_df = data_df.apply(pd.to_numeric, errors="coerce") if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): diff --git a/ehrapy/preprocessing/_quality_control.py b/ehrapy/preprocessing/_quality_control.py index 7a018e42..4953422d 100644 --- a/ehrapy/preprocessing/_quality_control.py +++ b/ehrapy/preprocessing/_quality_control.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +from functools import singledispatch from pathlib import Path from typing import TYPE_CHECKING, Literal @@ -9,6 +10,7 @@ from lamin_utils import logger from thefuzz import process +from ehrapy._compat import _raise_array_type_not_implemented from ehrapy.anndata import anndata_to_df from ehrapy.preprocessing._encoding import _get_encoded_features @@ -17,6 +19,13 @@ from anndata import AnnData +try: + import dask.array as da + + DASK_AVAILABLE = True +except ImportError: + DASK_AVAILABLE = False + def qc_metrics( adata: AnnData, qc_vars: Collection[str] = (), layer: str = None @@ -55,55 +64,57 @@ def qc_metrics( >>> obs_qc, var_qc = ep.pp.qc_metrics(adata) >>> obs_qc["missing_values_pct"].plot(kind="hist", bins=20) """ - obs_metrics = _obs_qc_metrics(adata, layer, qc_vars) - var_metrics = _var_qc_metrics(adata, layer) - adata.obs[obs_metrics.columns] = obs_metrics + mtx = adata.X if layer is None else adata.layers[layer] + var_metrics = _compute_var_metrics(mtx, adata) + obs_metrics = _compute_obs_metrics(mtx, adata, qc_vars=qc_vars, log1p=True) + adata.var[var_metrics.columns] = var_metrics + adata.obs[obs_metrics.columns] = obs_metrics return obs_metrics, var_metrics -def _missing_values( - arr: np.ndarray, mode: Literal["abs", "pct"] = "abs", df_type: Literal["obs", "var"] = "obs" -) -> np.ndarray: - """Calculates the absolute or relative amount of missing values. +@singledispatch +def _compute_missing_values(mtx, axis): + _raise_array_type_not_implemented(_compute_missing_values, type(mtx)) - Args: - arr: Numpy array containing a data row which is a subset of X (mtx). - mode: Whether to calculate absolute or percentage of missing values. - df_type: Whether to calculate the proportions for obs or var. One of 'obs' or 'var'. - Returns: - Absolute or relative amount of missing values. - """ - num_missing = pd.isnull(arr).sum() - if mode == "abs": - return num_missing - elif mode == "pct": - total_elements = arr.shape[0] if df_type == "obs" else len(arr) - return (num_missing / total_elements) * 100 +@_compute_missing_values.register +def _(mtx: np.ndarray, axis) -> np.ndarray: + return pd.isnull(mtx).sum(axis) + +if DASK_AVAILABLE: -def _obs_qc_metrics( - adata: AnnData, layer: str = None, qc_vars: Collection[str] = (), log1p: bool = True -) -> pd.DataFrame: + @_compute_missing_values.register + def _(mtx: da.Array, axis) -> np.ndarray: + return da.isnull(mtx).sum(axis).compute() + + +def _compute_obs_metrics( + mtx, + adata: AnnData, + *, + qc_vars: Collection[str] = (), + log1p: bool = True, +): """Calculates quality control metrics for observations. See :func:`~ehrapy.preprocessing._quality_control.calculate_qc_metrics` for a list of calculated metrics. Args: + mtx: Data array. adata: Annotated data matrix. - layer: Layer containing the actual data matrix. qc_vars: A list of previously calculated QC metrics to calculate summary statistics for. log1p: Whether to apply log1p normalization for the QC metrics. Only used with parameter 'qc_vars'. Returns: A Pandas DataFrame with the calculated metrics. """ + obs_metrics = pd.DataFrame(index=adata.obs_names) var_metrics = pd.DataFrame(index=adata.var_names) - mtx = adata.X if layer is None else adata.layers[layer] if "encoding_mode" in adata.var: for original_values_categorical in _get_encoded_features(adata): @@ -120,8 +131,8 @@ def _obs_qc_metrics( ) ) - obs_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 1, mtx, mode="abs") - obs_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 1, mtx, mode="pct", df_type="obs") + obs_metrics["missing_values_abs"] = _compute_missing_values(mtx, axis=1) + obs_metrics["missing_values_pct"] = (obs_metrics["missing_values_abs"] / mtx.shape[1]) * 100 # Specific QC metrics for qc_var in qc_vars: @@ -136,10 +147,19 @@ def _obs_qc_metrics( return obs_metrics -def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame: - var_metrics = pd.DataFrame(index=adata.var_names) - mtx = adata.X if layer is None else adata.layers[layer] +def _compute_var_metrics( + mtx, + adata: AnnData, +): + """Compute variable metrics for quality control. + + Args: + mtx: Data array. + adata: Annotated data matrix. + """ + categorical_indices = np.ndarray([0], dtype=int) + var_metrics = pd.DataFrame(index=adata.var_names) if "encoding_mode" in adata.var.keys(): for original_values_categorical in _get_encoded_features(adata): @@ -157,32 +177,35 @@ def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame: mtx[:, index].shape[1], ) categorical_indices = np.concatenate([categorical_indices, index]) + non_categorical_indices = np.ones(mtx.shape[1], dtype=bool) non_categorical_indices[categorical_indices] = False - var_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 0, mtx, mode="abs") - var_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 0, mtx, mode="pct", df_type="var") + + var_metrics["missing_values_abs"] = _compute_missing_values(mtx, axis=0) + var_metrics["missing_values_pct"] = (var_metrics["missing_values_abs"] / mtx.shape[0]) * 100 var_metrics["mean"] = np.nan var_metrics["median"] = np.nan var_metrics["standard_deviation"] = np.nan var_metrics["min"] = np.nan var_metrics["max"] = np.nan + var_metrics["iqr_outliers"] = np.nan try: var_metrics.loc[non_categorical_indices, "mean"] = np.nanmean( - np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0 + mtx[:, non_categorical_indices].astype(np.float64), axis=0 ) var_metrics.loc[non_categorical_indices, "median"] = np.nanmedian( - np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0 + mtx[:, non_categorical_indices].astype(np.float64), axis=0 ) var_metrics.loc[non_categorical_indices, "standard_deviation"] = np.nanstd( - np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0 + mtx[:, non_categorical_indices].astype(np.float64), axis=0 ) var_metrics.loc[non_categorical_indices, "min"] = np.nanmin( - np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0 + mtx[:, non_categorical_indices].astype(np.float64), axis=0 ) var_metrics.loc[non_categorical_indices, "max"] = np.nanmax( - np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0 + mtx[:, non_categorical_indices].astype(np.float64), axis=0 ) # Calculate IQR and define IQR outliers diff --git a/tests/conftest.py b/tests/conftest.py index 6c42f8a7..f996ef82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING import numpy as np +import pandas as pd import pytest from anndata import AnnData from matplotlib.testing.compare import compare_images @@ -29,6 +30,54 @@ def rng(): return np.random.default_rng(seed=42) +@pytest.fixture +def obs_data(): + return { + "disease": ["cancer", "tumor"], + "country": ["Germany", "switzerland"], + "sex": ["male", "female"], + } + + +@pytest.fixture +def var_data(): + return { + "alive": ["yes", "no", "maybe"], + "hospital": ["hospital 1", "hospital 2", "hospital 1"], + "crazy": ["yes", "yes", "yes"], + } + + +@pytest.fixture +def missing_values_adata(obs_data, var_data): + return AnnData( + X=np.array([[0.21, np.nan, 41.42], [np.nan, np.nan, 7.234]], dtype=np.float32), + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_data, index=["Acetaminophen", "hospital", "crazy"]), + ) + + +@pytest.fixture +def lab_measurements_simple_adata(obs_data, var_data): + X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32) + return AnnData( + X=X, + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]), + ) + + +@pytest.fixture +def lab_measurements_layer_adata(obs_data, var_data): + X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32) + return AnnData( + X=X, + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]), + layers={"layer_copy": X}, + ) + + @pytest.fixture def mimic_2(): adata = ep.dt.mimic_2() @@ -152,10 +201,10 @@ def asarray(a): return np.asarray(a) -def as_dense_dask_array(a): +def as_dense_dask_array(a, chunk_size=1000): import dask.array as da - return da.asarray(a) + return da.from_array(a, chunks=chunk_size) ARRAY_TYPES = (asarray, as_dense_dask_array) diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index 21379ef0..adf2fc17 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -3,9 +3,11 @@ from collections.abc import Iterable from pathlib import Path +import dask.array as da import numpy as np import pytest from anndata import AnnData +from scipy import sparse from sklearn.exceptions import ConvergenceWarning from ehrapy.anndata.anndata_ext import _are_ndarrays_equal, _is_val_missing, _to_dense_matrix @@ -17,7 +19,7 @@ miss_forest_impute, simple_impute, ) -from tests.conftest import TEST_DATA_PATH +from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH CURRENT_DIR = Path(__file__).parent _TEST_PATH = f"{TEST_DATA_PATH}/imputation" @@ -46,6 +48,11 @@ def _base_check_imputation( Raises: AssertionError: If any of the checks fail. """ + # Convert dask arrays to numpy arrays + if isinstance(adata_before_imputation.X, da.Array): + adata_before_imputation.X = adata_before_imputation.X.compute() + if isinstance(adata_after_imputation.X, da.Array): + adata_after_imputation.X = adata_after_imputation.X.compute() layer_before = _to_dense_matrix(adata_before_imputation, before_imputation_layer) layer_after = _to_dense_matrix(adata_after_imputation, after_imputation_layer) @@ -266,6 +273,21 @@ def test_missforest_impute_subset(impute_num_adata): _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.from_array, NotImplementedError), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_miceforest_array_types(impute_num_adata, array_type, expected_error): + impute_num_adata.X = array_type(impute_num_adata.X) + if expected_error: + with pytest.raises(expected_error): + mice_forest_impute(impute_num_adata, copy=True) + + @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_no_copy(impute_iris_adata): adata_not_imputed = impute_iris_adata.copy() @@ -296,7 +318,24 @@ def test_miceforest_impute_numerical_data(impute_iris_adata): _base_check_imputation(adata_not_imputed, impute_iris_adata) -def test_explicit_impute_all(impute_num_adata): +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.from_array, None), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_explicit_impute_array_types(impute_num_adata, array_type, expected_error): + impute_num_adata.X = array_type(impute_num_adata.X) + if expected_error: + with pytest.raises(expected_error): + explicit_impute(impute_num_adata, replacement=1011, copy=True) + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_explicit_impute_all(array_type, impute_num_adata): + impute_num_adata.X = array_type(impute_num_adata.X) warnings.filterwarnings("ignore", category=FutureWarning) adata_imputed = explicit_impute(impute_num_adata, replacement=1011, copy=True) @@ -304,7 +343,9 @@ def test_explicit_impute_all(impute_num_adata): assert np.sum([adata_imputed.X == 1011]) == 3 -def test_explicit_impute_subset(impute_adata): +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_explicit_impute_subset(impute_adata, array_type): + impute_adata.X = array_type(impute_adata.X) adata_imputed = explicit_impute(impute_adata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True) _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=("strcol", "intcol")) diff --git a/tests/preprocessing/test_quality_control.py b/tests/preprocessing/test_quality_control.py index dee27b3c..1ddaf084 100644 --- a/tests/preprocessing/test_quality_control.py +++ b/tests/preprocessing/test_quality_control.py @@ -8,70 +8,53 @@ import ehrapy as ep from ehrapy.io._read import read_csv from ehrapy.preprocessing._encoding import encode -from ehrapy.preprocessing._quality_control import _obs_qc_metrics, _var_qc_metrics, mcar_test -from tests.conftest import TEST_DATA_PATH +from ehrapy.preprocessing._quality_control import _compute_obs_metrics, _compute_var_metrics, mcar_test +from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH, as_dense_dask_array CURRENT_DIR = Path(__file__).parent _TEST_PATH_ENCODE = f"{TEST_DATA_PATH}/encode" -@pytest.fixture -def obs_data(): - return { - "disease": ["cancer", "tumor"], - "country": ["Germany", "switzerland"], - "sex": ["male", "female"], - } +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_qc_metrics_vanilla(array_type, missing_values_adata): + adata = missing_values_adata + adata.X = array_type(adata.X) + modification_copy = adata.copy() + obs_metrics, var_metrics = ep.pp.qc_metrics(adata) + assert np.array_equal(obs_metrics["missing_values_abs"].values, np.array([1, 2])) + assert np.allclose(obs_metrics["missing_values_pct"].values, np.array([33.3333, 66.6667])) -@pytest.fixture -def var_data(): - return { - "alive": ["yes", "no", "maybe"], - "hospital": ["hospital 1", "hospital 2", "hospital 1"], - "crazy": ["yes", "yes", "yes"], - } - - -@pytest.fixture -def missing_values_adata(obs_data, var_data): - return AnnData( - X=np.array([[0.21, np.nan, 41.42], [np.nan, np.nan, 7.234]], dtype=np.float32), - obs=pd.DataFrame(data=obs_data), - var=pd.DataFrame(data=var_data, index=["Acetaminophen", "hospital", "crazy"]), - ) - - -@pytest.fixture -def lab_measurements_simple_adata(obs_data, var_data): - X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32) - return AnnData( - X=X, - obs=pd.DataFrame(data=obs_data), - var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]), - ) - + assert np.array_equal(var_metrics["missing_values_abs"].values, np.array([1, 2, 0])) + assert np.allclose(var_metrics["missing_values_pct"].values, np.array([50.0, 100.0, 0.0])) + assert np.allclose(var_metrics["mean"].values, np.array([0.21, np.nan, 24.327]), equal_nan=True) + assert np.allclose(var_metrics["median"].values, np.array([0.21, np.nan, 24.327]), equal_nan=True) + assert np.allclose(var_metrics["min"].values, np.array([0.21, np.nan, 7.234]), equal_nan=True) + assert np.allclose(var_metrics["max"].values, np.array([0.21, np.nan, 41.419998]), equal_nan=True) + assert (~var_metrics["iqr_outliers"]).all() -@pytest.fixture -def lab_measurements_layer_adata(obs_data, var_data): - X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32) - return AnnData( - X=X, - obs=pd.DataFrame(data=obs_data), - var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]), - layers={"layer_copy": X}, - ) + # check that none of the columns were modified + for key in modification_copy.obs.keys(): + assert np.array_equal(modification_copy.obs[key], adata.obs[key]) + for key in modification_copy.var.keys(): + assert np.array_equal(modification_copy.var[key], adata.var[key]) -def test_obs_qc_metrics(missing_values_adata): - obs_metrics = _obs_qc_metrics(missing_values_adata) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_obs_qc_metrics(array_type, missing_values_adata): + missing_values_adata.X = array_type(missing_values_adata.X) + mtx = missing_values_adata.X + obs_metrics = _compute_obs_metrics(mtx, missing_values_adata) assert np.array_equal(obs_metrics["missing_values_abs"].values, np.array([1, 2])) assert np.allclose(obs_metrics["missing_values_pct"].values, np.array([33.3333, 66.6667])) -def test_var_qc_metrics(missing_values_adata): - var_metrics = _var_qc_metrics(missing_values_adata) +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +def test_var_qc_metrics(array_type, missing_values_adata): + missing_values_adata.X = array_type(missing_values_adata.X) + mtx = missing_values_adata.X + var_metrics = _compute_var_metrics(mtx, missing_values_adata) assert np.array_equal(var_metrics["missing_values_abs"].values, np.array([1, 2, 0])) assert np.allclose(var_metrics["missing_values_pct"].values, np.array([50.0, 100.0, 0.0])) @@ -82,19 +65,63 @@ def test_var_qc_metrics(missing_values_adata): assert (~var_metrics["iqr_outliers"]).all() +@pytest.mark.parametrize( + "array_type, expected_error", + [ + (np.array, None), + (as_dense_dask_array, None), + # can't test sparse matrices because they don't support string values + ], +) +def test_obs_qc_metrics_array_types(array_type, expected_error): + adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") + adata.X = array_type(adata.X) + mtx = adata.X + if expected_error: + with pytest.raises(expected_error): + _compute_obs_metrics(mtx, adata) + + def test_obs_nan_qc_metrics(): adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") adata.X[0][4] = np.nan adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]}) - obs_metrics = _obs_qc_metrics(adata2) + mtx = adata2.X + obs_metrics = _compute_obs_metrics(mtx, adata2) assert obs_metrics.iloc[0].iloc[0] == 1 +@pytest.mark.parametrize( + "array_type, expected_error", + [ + (np.array, None), + (as_dense_dask_array, None), + # can't test sparse matrices because they don't support string values + ], +) +def test_var_qc_metrics_array_types(array_type, expected_error): + adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") + adata.X = array_type(adata.X) + mtx = adata.X + if expected_error: + with pytest.raises(expected_error): + _compute_var_metrics(mtx, adata) + + +def test_var_encoding_mode_does_not_modify_original_matrix(): + adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") + adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]}) + mtx_copy = adata2.X.copy() + _compute_var_metrics(adata2.X, adata2) + assert np.array_equal(mtx_copy, adata2.X) + + def test_var_nan_qc_metrics(): adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") adata.X[0][4] = np.nan adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]}) - var_metrics = _var_qc_metrics(adata2) + mtx = adata2.X + var_metrics = _compute_var_metrics(mtx, adata2) assert var_metrics.iloc[0].iloc[0] == 1 assert var_metrics.iloc[1].iloc[0] == 1 assert var_metrics.iloc[2].iloc[0] == 1