Skip to content

Commit

Permalink
Datatype Support in Quality Control and Impute (#865)
Browse files Browse the repository at this point in the history
* Enhancement: Add Dask support for explicit imputation

* Enhancement: Add Dask support for quality control metrics and imputation tests

* Fix test for imputation to handle Dask arrays without raising errors

* Refactor quality control metrics functions to streamline computation and improve readability

* added expected error

* Remove unused Dask import from quality control module

* simplify missing value computation

* Rename parameter 'arr' to 'mtx' in _compute_obs_metrics; it no longer creates a copy

* daskify qc_metrics

* Add fixture for array types and update imputation tests for dask arrays

* Refactor _compute_var_metrics to prevent modification of the original data matrix and add a test for encoding mode integrity

* Add parameterized tests for array types in miceforest imputation

* Update missing values handling to include array type in error message and refine parameterized tests for miceforest imputation

* Fix array type handling in missing values computation and update test for miceforest imputation

* Implement array type handling in load_dataframe function and update tests for miceforest imputation

* Remove parameterization for array types in miceforest numerical data imputation test

* Update tests/preprocessing/test_quality_control.py

Co-authored-by: Eljas Roellin <[email protected]>

* Update tests/preprocessing/test_quality_control.py

Co-authored-by: Eljas Roellin <[email protected]>

* revert deepcopy changes

* Fix test to ensure original matrix is not modified after encoding

* Remove unused parameters from observation and variable metrics computation functions

* Add sparse.csr_matrix to explicit impute array types test case

* Parameterize quality control metrics tests to support multiple array types

* Remove unused imports from test_quality_control.py

* encode blocks dask function

* Add pytest fixtures for observation and variable data in tests

* Update tests/preprocessing/test_quality_control.py

Co-authored-by: Eljas Roellin <[email protected]>

* Update tests/preprocessing/test_quality_control.py

Co-authored-by: Eljas Roellin <[email protected]>

* support dask explicit impute all object types

---------

Co-authored-by: eroell <[email protected]>
Co-authored-by: Lukas Heumos <[email protected]>
Co-authored-by: Eljas Roellin <[email protected]>
  • Loading branch information
4 people authored Feb 27, 2025
1 parent 324a978 commit 4c4db5f
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 97 deletions.
48 changes: 44 additions & 4 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from collections.abc import Iterable
from functools import singledispatch
from typing import TYPE_CHECKING, Literal

import numpy as np
Expand All @@ -11,7 +12,7 @@
from sklearn.impute import SimpleImputer

from ehrapy import settings
from ehrapy._compat import _check_module_importable
from ehrapy._compat import _check_module_importable, _raise_array_type_not_implemented
from ehrapy._progress import spinner
from ehrapy.anndata import check_feature_types
from ehrapy.anndata.anndata_ext import (
Expand All @@ -23,6 +24,13 @@
if TYPE_CHECKING:
from anndata import AnnData

try:
import dask.array as da

DASK_AVAILABLE = True
except ImportError:
DASK_AVAILABLE = False


@spinner("Performing explicit impute")
def explicit_impute(
Expand Down Expand Up @@ -76,7 +84,9 @@ def explicit_impute(
imputation_value = _extract_impute_value(replacement, column_name)
# only replace if an explicit value got passed or could be extracted from replacement
if imputation_value:
_replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
adata.X[:, idx : idx + 1] = _replace_explicit(
adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings
)
else:
logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.")
else:
Expand All @@ -87,13 +97,33 @@ def explicit_impute(
return adata if copy else None


def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None:
@singledispatch
def _replace_explicit(arr, replacement: str | int, impute_empty_strings: bool) -> None:
_raise_array_type_not_implemented(_replace_explicit, type(arr))


@_replace_explicit.register
def _(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> np.ndarray:
"""Replace one column or whole X with a value where missing values are stored."""
if not impute_empty_strings: # pragma: no cover
impute_conditions = pd.isnull(arr)
else:
impute_conditions = np.logical_or(pd.isnull(arr), arr == "")
arr[impute_conditions] = replacement
return arr


if DASK_AVAILABLE:

    @_replace_explicit.register(da.Array)
    def _(arr: da.Array, replacement: str | int, impute_empty_strings: bool) -> da.Array:
        """Replace one column or whole X with a value where missing values are stored.

        Dask variant of the NumPy implementation above: builds the boolean mask
        lazily with ``da.isnull`` and writes ``replacement`` through masked
        item assignment.
        NOTE(review): relies on dask's boolean-mask ``__setitem__`` support —
        confirm the project's minimum dask version provides it.
        """
        if not impute_empty_strings:  # pragma: no cover
            impute_conditions = da.isnull(arr)
        else:
            # treat empty strings as missing as well
            impute_conditions = da.logical_or(da.isnull(arr), arr == "")
        arr[impute_conditions] = replacement
        return arr


def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -> str | int | None:
Expand Down Expand Up @@ -469,12 +499,22 @@ def mice_forest_impute(
return adata if copy else None


@singledispatch
def load_dataframe(arr, columns, index):
    """Convert an array-like data matrix into a :class:`pandas.DataFrame`.

    Dispatches on the concrete array type; types without a registered
    implementation raise via ``_raise_array_type_not_implemented``.
    """
    _raise_array_type_not_implemented(load_dataframe, type(arr))


@load_dataframe.register
def _(arr: np.ndarray, columns, index):
    """Build the DataFrame directly from an in-memory NumPy array."""
    frame = pd.DataFrame(data=arr, index=index, columns=columns)
    return frame


def _miceforest_impute(
adata, var_names, save_all_iterations_data, random_state, inplace, iterations, variable_parameters, verbose
) -> None:
import miceforest as mf

data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
data_df = load_dataframe(adata.X, columns=adata.var_names, index=adata.obs_names)
data_df = data_df.apply(pd.to_numeric, errors="coerce")

if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
Expand Down
97 changes: 60 additions & 37 deletions ehrapy/preprocessing/_quality_control.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
from functools import singledispatch
from pathlib import Path
from typing import TYPE_CHECKING, Literal

Expand All @@ -9,6 +10,7 @@
from lamin_utils import logger
from thefuzz import process

from ehrapy._compat import _raise_array_type_not_implemented
from ehrapy.anndata import anndata_to_df
from ehrapy.preprocessing._encoding import _get_encoded_features

Expand All @@ -17,6 +19,13 @@

from anndata import AnnData

try:
import dask.array as da

DASK_AVAILABLE = True
except ImportError:
DASK_AVAILABLE = False


def qc_metrics(
adata: AnnData, qc_vars: Collection[str] = (), layer: str = None
Expand Down Expand Up @@ -55,55 +64,57 @@ def qc_metrics(
>>> obs_qc, var_qc = ep.pp.qc_metrics(adata)
>>> obs_qc["missing_values_pct"].plot(kind="hist", bins=20)
"""
obs_metrics = _obs_qc_metrics(adata, layer, qc_vars)
var_metrics = _var_qc_metrics(adata, layer)

adata.obs[obs_metrics.columns] = obs_metrics
mtx = adata.X if layer is None else adata.layers[layer]
var_metrics = _compute_var_metrics(mtx, adata)
obs_metrics = _compute_obs_metrics(mtx, adata, qc_vars=qc_vars, log1p=True)

adata.var[var_metrics.columns] = var_metrics
adata.obs[obs_metrics.columns] = obs_metrics

return obs_metrics, var_metrics


def _missing_values(
arr: np.ndarray, mode: Literal["abs", "pct"] = "abs", df_type: Literal["obs", "var"] = "obs"
) -> np.ndarray:
"""Calculates the absolute or relative amount of missing values.
@singledispatch
def _compute_missing_values(mtx, axis):
    """Count missing (null) entries of ``mtx`` along ``axis``.

    Dispatches on the concrete array type; unregistered types raise via
    ``_raise_array_type_not_implemented``.
    """
    _raise_array_type_not_implemented(_compute_missing_values, type(mtx))

Args:
arr: Numpy array containing a data row which is a subset of X (mtx).
mode: Whether to calculate absolute or percentage of missing values.
df_type: Whether to calculate the proportions for obs or var. One of 'obs' or 'var'.

Returns:
Absolute or relative amount of missing values.
"""
num_missing = pd.isnull(arr).sum()
if mode == "abs":
return num_missing
elif mode == "pct":
total_elements = arr.shape[0] if df_type == "obs" else len(arr)
return (num_missing / total_elements) * 100
@_compute_missing_values.register
def _(mtx: np.ndarray, axis) -> np.ndarray:
    """Absolute number of null entries along ``axis`` for in-memory NumPy arrays."""
    null_mask = pd.isnull(mtx)
    return null_mask.sum(axis)


if DASK_AVAILABLE:

def _obs_qc_metrics(
adata: AnnData, layer: str = None, qc_vars: Collection[str] = (), log1p: bool = True
) -> pd.DataFrame:
    @_compute_missing_values.register
    def _(mtx: da.Array, axis) -> np.ndarray:
        # Dask variant: build the lazy null-count graph, then materialize it
        # eagerly with .compute() so callers get a concrete NumPy result.
        return da.isnull(mtx).sum(axis).compute()


def _compute_obs_metrics(
mtx,
adata: AnnData,
*,
qc_vars: Collection[str] = (),
log1p: bool = True,
):
"""Calculates quality control metrics for observations.
See :func:`~ehrapy.preprocessing._quality_control.calculate_qc_metrics` for a list of calculated metrics.
Args:
mtx: Data array.
adata: Annotated data matrix.
layer: Layer containing the actual data matrix.
qc_vars: A list of previously calculated QC metrics to calculate summary statistics for.
log1p: Whether to apply log1p normalization for the QC metrics. Only used with parameter 'qc_vars'.
Returns:
A Pandas DataFrame with the calculated metrics.
"""

obs_metrics = pd.DataFrame(index=adata.obs_names)
var_metrics = pd.DataFrame(index=adata.var_names)
mtx = adata.X if layer is None else adata.layers[layer]

if "encoding_mode" in adata.var:
for original_values_categorical in _get_encoded_features(adata):
Expand All @@ -120,8 +131,8 @@ def _obs_qc_metrics(
)
)

obs_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 1, mtx, mode="abs")
obs_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 1, mtx, mode="pct", df_type="obs")
obs_metrics["missing_values_abs"] = _compute_missing_values(mtx, axis=1)
obs_metrics["missing_values_pct"] = (obs_metrics["missing_values_abs"] / mtx.shape[1]) * 100

# Specific QC metrics
for qc_var in qc_vars:
Expand All @@ -136,10 +147,19 @@ def _obs_qc_metrics(
return obs_metrics


def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame:
var_metrics = pd.DataFrame(index=adata.var_names)
mtx = adata.X if layer is None else adata.layers[layer]
def _compute_var_metrics(
mtx,
adata: AnnData,
):
"""Compute variable metrics for quality control.
Args:
mtx: Data array.
adata: Annotated data matrix.
"""

categorical_indices = np.ndarray([0], dtype=int)
var_metrics = pd.DataFrame(index=adata.var_names)

if "encoding_mode" in adata.var.keys():
for original_values_categorical in _get_encoded_features(adata):
Expand All @@ -157,32 +177,35 @@ def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame:
mtx[:, index].shape[1],
)
categorical_indices = np.concatenate([categorical_indices, index])

non_categorical_indices = np.ones(mtx.shape[1], dtype=bool)
non_categorical_indices[categorical_indices] = False
var_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 0, mtx, mode="abs")
var_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 0, mtx, mode="pct", df_type="var")

var_metrics["missing_values_abs"] = _compute_missing_values(mtx, axis=0)
var_metrics["missing_values_pct"] = (var_metrics["missing_values_abs"] / mtx.shape[0]) * 100

var_metrics["mean"] = np.nan
var_metrics["median"] = np.nan
var_metrics["standard_deviation"] = np.nan
var_metrics["min"] = np.nan
var_metrics["max"] = np.nan
var_metrics["iqr_outliers"] = np.nan

try:
var_metrics.loc[non_categorical_indices, "mean"] = np.nanmean(
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
mtx[:, non_categorical_indices].astype(np.float64), axis=0
)
var_metrics.loc[non_categorical_indices, "median"] = np.nanmedian(
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
mtx[:, non_categorical_indices].astype(np.float64), axis=0
)
var_metrics.loc[non_categorical_indices, "standard_deviation"] = np.nanstd(
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
mtx[:, non_categorical_indices].astype(np.float64), axis=0
)
var_metrics.loc[non_categorical_indices, "min"] = np.nanmin(
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
mtx[:, non_categorical_indices].astype(np.float64), axis=0
)
var_metrics.loc[non_categorical_indices, "max"] = np.nanmax(
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
mtx[:, non_categorical_indices].astype(np.float64), axis=0
)

# Calculate IQR and define IQR outliers
Expand Down
53 changes: 51 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import pytest
from anndata import AnnData
from matplotlib.testing.compare import compare_images
Expand All @@ -29,6 +30,54 @@ def rng():
return np.random.default_rng(seed=42)


@pytest.fixture
def obs_data():
    """Per-patient (obs) annotations shared by preprocessing tests."""
    data = {
        "disease": ["cancer", "tumor"],
        "country": ["Germany", "switzerland"],
        "sex": ["male", "female"],
    }
    return data


@pytest.fixture
def var_data():
    """Per-feature (var) annotations shared by preprocessing tests."""
    data = {
        "alive": ["yes", "no", "maybe"],
        "hospital": ["hospital 1", "hospital 2", "hospital 1"],
        "crazy": ["yes", "yes", "yes"],
    }
    return data


@pytest.fixture
def missing_values_adata(obs_data, var_data):
    """AnnData whose X contains NaNs, for missing-value QC/imputation tests."""
    matrix = np.array(
        [[0.21, np.nan, 41.42], [np.nan, np.nan, 7.234]],
        dtype=np.float32,
    )
    return AnnData(
        X=matrix,
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=["Acetaminophen", "hospital", "crazy"]),
    )


@pytest.fixture
def lab_measurements_simple_adata(obs_data, var_data):
    """AnnData of fully observed lab measurements (no extra layers)."""
    measurements = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32)
    feature_index = ["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]
    return AnnData(
        X=measurements,
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=feature_index),
    )


@pytest.fixture
def lab_measurements_layer_adata(obs_data, var_data):
    """Same lab-measurement AnnData, with X duplicated into a named layer."""
    measurements = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32)
    feature_index = ["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]
    return AnnData(
        X=measurements,
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=feature_index),
        layers={"layer_copy": measurements},
    )


@pytest.fixture
def mimic_2():
adata = ep.dt.mimic_2()
Expand Down Expand Up @@ -152,10 +201,10 @@ def asarray(a):
return np.asarray(a)


def as_dense_dask_array(a):
def as_dense_dask_array(a, chunk_size=1000):
import dask.array as da

return da.asarray(a)
return da.from_array(a, chunks=chunk_size)


ARRAY_TYPES = (asarray, as_dense_dask_array)
Loading

0 comments on commit 4c4db5f

Please sign in to comment.