diff --git a/aspect/data.py b/aspect/data.py index b42cbc1..5379740 100644 --- a/aspect/data.py +++ b/aspect/data.py @@ -399,17 +399,17 @@ def _featurize( @staticmethod def _unsqueeze( - x: Mapping[str, ArrayLike] + x: Mapping[str, ArrayLike], + columns: Optional[Iterable[str]] = None ) -> Dict[str, np.ndarray]: - def _unsqueeze(x, columns=None): - columns = columns or x.keys() - for key in columns: - vals = x[key] - if not isinstance(vals, dict): - vals = np.asarray(x[key]) - if vals.ndim == 1 and np.issubdtype(vals.dtype, np.number): - x[key] = vals[:, None] - return x + columns = columns or x.keys() + for key in columns: + vals = x[key] + if not isinstance(vals, dict): + vals = np.asarray(x[key]) + if vals.ndim == 1 and np.issubdtype(vals.dtype, np.number): + x[key] = vals[:, None] + return x def __call__( self, @@ -477,6 +477,7 @@ def __call__( .select_columns(all_output_columns) .map( self._unsqueeze, + fn_kwargs={"columns": output_columns}, batched=True, batch_size=batch_size, desc="Unsqueezing", diff --git a/pyproject.toml b/pyproject.toml index e08f10b..707d39f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "aspect-data" -version = "0.0.3" +version = "0.0.3.post1" authors = [ { name="Eachan Johnson", email="eachan.johnson@crick.ac.uk" }, ]