diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a2e8f767..ed139b4a 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -733,12 +733,17 @@ def _set_color_source_vec( table_name: str | None = None, table_layer: str | None = None, render_type: Literal["points"] | None = None, -) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: +) -> tuple[pd.Categorical | None, ArrayLike, bool]: if value_to_plot is None and element is not None: color = np.full(len(element), na_color) - return color, color, False + return None, color, False + + # First check if value_to_plot is likely a color specification rather than a column name + if value_to_plot is not None and _is_color_like(value_to_plot) and element is not None: + # User passed a color, not a column name + color = np.full(len(element), value_to_plot) + return None, color, False - # Figure out where to get the color from origins = _locate_value( value_key=value_to_plot, sdata=sdata, @@ -760,27 +765,55 @@ def _set_color_source_vec( table_layer=table_layer, )[value_to_plot] - # numerical case, return early - # TODO temporary split until refactor is complete - if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype): - if ( - not isinstance(element, GeoDataFrame) - and isinstance(palette, list) - and palette[0] is not None - or isinstance(element, GeoDataFrame) - and isinstance(palette, list) - ): - logger.warning( - "Ignoring categorical palette which is given for a continuous variable. " - "Consider using `cmap` to pass a ColorMap." - ) - return None, color_source_vector, False - - color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series` + # Convert to categorical if not already + if not isinstance(color_source_vector, pd.Categorical): + try: + color_source_vector = pd.Categorical(color_source_vector) + except (ValueError, TypeError) as e: + logger.warning(f"Could not convert '{value_to_plot}' to categorical: {e}") + # For numeric data, return None to indicate non-categorical + if pd.api.types.is_numeric_dtype(color_source_vector): + if ( + not isinstance(element, GeoDataFrame) + and isinstance(palette, list) + and palette[0] is not None + or isinstance(element, GeoDataFrame) + and isinstance(palette, list) + ): + logger.warning( + "Ignoring categorical palette which is given for a continuous variable. " + "Consider using `cmap` to pass a ColorMap." + ) + return None, color_source_vector, False + # For other types, try to use as is + return None, color_source_vector, False + + # At this point color_source_vector should be categorical + adata_with_colors = None + cluster_key = value_to_plot + + # First check if the table_name is specified + if table_name is not None and table_name in sdata.tables: + adata_with_colors = sdata.tables[table_name] + adata_with_colors.uns["spatialdata_key"] = table_name + + # If not, but the element is annotated by any table, use that + elif element_name is not None: + annotator_tables = get_element_annotators(sdata, element_name) + if len(annotator_tables) > 0: + # Use the first table that annotates this element + first_table = next(iter(annotator_tables)) + adata_with_colors = sdata.tables[first_table] + adata_with_colors.uns["spatialdata_key"] = first_table + + # If no specific table is found, try using the default table + elif sdata.table is not None: + adata_with_colors = sdata.table + adata_with_colors.uns["spatialdata_key"] = "default_table" color_mapping = _get_categorical_color_mapping( - adata=sdata.table, - cluster_key=value_to_plot, + adata=adata_with_colors, + cluster_key=cluster_key, color_source_vector=color_source_vector, cmap_params=cmap_params, alpha=alpha, @@ -790,18 +823,28 @@ def _set_color_source_vec( render_type=render_type, ) + # Set categories to match the mapping keys color_source_vector = color_source_vector.set_categories(color_mapping.keys()) if color_mapping is None: raise ValueError("Unable to create color palette.") - # do not rename categories, as colors need not be unique - color_vector = color_source_vector.map(color_mapping) + # Map categorical values to colors + try: + color_vector = color_source_vector.map(color_mapping) + except (KeyError, TypeError, ValueError) as e: + logger.warning(f"Error mapping colors: {e}. Attempting alternate approach.") + # Try mapping with string conversion + str_mapping = {str(k): v for k, v in color_mapping.items()} + color_vector = pd.Series( + [str_mapping.get(str(x), color_mapping.get("NaN", "#d3d3d3")) for x in color_source_vector], + index=color_source_vector.index, + ) return color_source_vector, color_vector, True - logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not been found, using default colors.") + logger.warning(f"Color key '{value_to_plot}' for element '{element_name}' not found, using default colors.") color = np.full(sdata[table_name].n_obs, to_hex(na_color)) - return color, color, False + return None, color, False def _map_color_seg( @@ -817,20 +860,34 @@ def _map_color_seg( ) -> ArrayLike: cell_id = np.array(cell_id) - if pd.api.types.is_categorical_dtype(color_vector.dtype): - # Case A: users wants to plot a categorical column + is_categorical = pd.api.types.is_categorical_dtype(getattr(color_vector, "dtype", None)) + is_numeric = pd.api.types.is_numeric_dtype(getattr(color_vector, "dtype", None)) + is_pandas_series = isinstance(color_vector, pd.Series) + + # Case A: categorical column + if is_categorical: if np.any(color_source_vector.isna()): cell_id[color_source_vector.isna()] = 0 val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1) cols = colors.to_rgba_array(color_vector.categories) - elif pd.api.types.is_numeric_dtype(color_vector.dtype): - # Case B: user wants to plot a continous column - if isinstance(color_vector, pd.Series): + + # Case B: continuous column + elif is_numeric: + if is_pandas_series: color_vector = color_vector.to_numpy() cols = cmap_params.cmap(cmap_params.norm(color_vector)) val_im = map_array(seg.copy(), cell_id, cell_id) + + # Case C & D: Other cases (could be strings, or hex colors) else: - # Case C: User didn't specify any colors + # Get the first color safely, regardless of index structure + first_color = None + if is_pandas_series and len(color_vector) > 0: + first_color = color_vector.iloc[0] + elif not is_pandas_series and len(color_vector) > 0: + first_color = color_vector[0] + + # Case C: Using default colors with random generation if color_source_vector is not None and ( set(color_vector) == set(color_source_vector) and len(set(color_vector)) == 1 @@ -840,14 +897,31 @@ def _map_color_seg( val_im = map_array(seg.copy(), cell_id, cell_id) RNG = default_rng(42) cols = RNG.random((len(color_vector), 3)) + + # Case D: User specified explicit colors or we're using defaults else: - # Case D: User didn't specify a column to color by, but modified the na_color val_im = map_array(seg.copy(), cell_id, cell_id) - if "#" in str(color_vector[0]): - # we have hex colors - assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like." - cols = colors.to_rgba_array(color_vector) + + # Check if we're dealing with hex colors + if first_color is not None and isinstance(first_color, str) and "#" in first_color: + # We have hex colors + all_is_color = True + for c in color_vector: + if not _is_color_like(c): + all_is_color = False + break + + if all_is_color: + try: + cols = colors.to_rgba_array(color_vector) + except ValueError as e: + logger.warning(f"Error converting colors: {e}, falling back to default colormap") + cols = cmap_params.cmap(cmap_params.norm(np.arange(len(color_vector)))) + else: + # Fall back to colormap + cols = cmap_params.cmap(cmap_params.norm(color_vector)) else: + # Use the colormap cols = cmap_params.cmap(cmap_params.norm(color_vector)) if seg_erosionpx is not None: @@ -879,20 +953,93 @@ def _generate_base_categorial_color_mapping( na_color: ColorLike, cmap_params: CmapParams | None = None, ) -> Mapping[str, str]: - if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns: - colors = adata.uns[f"{cluster_key}_colors"] - categories = color_source_vector.categories.tolist() + ["NaN"] - if "#" not in na_color: - # should be unreachable, but just for safety - raise ValueError("Expected `na_color` to be a hex color, but got a non-hex color.") - - colors = [to_hex(to_rgba(color)[:3]) for color in colors] - na_color = to_hex(to_rgba(na_color)[:3]) - - if na_color and len(categories) > len(colors): - return dict(zip(categories, colors + [na_color], strict=True)) + color_key = f"{cluster_key}_colors" + color_found_in_uns_msg_template = ( + "Using colors from '{cluster}_colors' in .uns slot of table '{table}' for plotting. " + "If this is unexpected, please delete the column from your AnnData object." + ) - return dict(zip(categories, colors, strict=True)) + if adata is not None and cluster_key is not None: + if cluster_key in adata.uns and isinstance(adata.uns[cluster_key], dict): + # We have a direct color mapping dictionary + color_dict = adata.uns[cluster_key] + table_name = getattr(adata, "uns", {}).get("spatialdata_key", "") + if table_name: + logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) + + # Ensure all values are hex colors + for k, v in color_dict.items(): + if isinstance(v, str) and not v.startswith("#"): + color_dict[k] = to_hex(to_rgba(v)) + + categories = color_source_vector.categories.tolist() + na_color_hex = to_hex(to_rgba(na_color)[:3]) + + return {cat: color_dict.get(str(cat), color_dict.get(cat, na_color_hex)) for cat in categories} + + if color_key in adata.uns: + colors = adata.uns[color_key] + table_name = getattr(adata, "uns", {}).get("spatialdata_key", "") + if table_name: + logger.info(color_found_in_uns_msg_template.format(cluster=cluster_key, table=table_name)) + + if isinstance(colors, list): + colors = [to_hex(to_rgba(color)[:3]) for color in colors] + categories = color_source_vector.categories.tolist() + + na_color_hex = to_hex(to_rgba(na_color)[:3]) + if "NaN" not in categories: + categories.append("NaN") + + if len(colors) < len(categories) - 1: # -1 for NaN + logger.warning( + f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). " + "Some categories will use default colors." + ) + # Extend with default colors or duplicate the last color + colors.extend([na_color_hex] * (len(categories) - 1 - len(colors))) + + return dict(zip(categories, colors + [na_color_hex], strict=False)) + + if isinstance(colors, np.ndarray): + colors = [to_hex(to_rgba(color)[:3]) for color in colors] + categories = color_source_vector.categories.tolist() + + na_color_hex = to_hex(to_rgba(na_color)[:3]) + if "NaN" not in categories: + categories.append("NaN") + + if len(colors) < len(categories) - 1: # -1 for NaN + logger.warning( + f"Not enough colors in {color_key} ({len(colors)}) for all categories ({len(categories) - 1}). " + "Some categories will use default colors." + ) + colors.extend([na_color_hex] * (len(categories) - 1 - len(colors))) + + return dict(zip(categories, colors + [na_color_hex], strict=False)) + + if isinstance(colors, dict): + # Ensure all values are hex colors + for k, v in colors.items(): + if isinstance(v, str) and not v.startswith("#"): + colors[k] = to_hex(to_rgba(v)) + + categories = color_source_vector.categories.tolist() + na_color_hex = to_hex(to_rgba(na_color)[:3]) + + # Try to match color keys to categories, accounting for string/categorical differences + result = {} + for cat in categories: + # Try direct match first + if cat in colors: + result[cat] = colors[cat] + # Then try string conversion - handles int/string mismatches + elif str(cat) in colors: + result[cat] = colors[str(cat)] + else: + result[cat] = na_color_hex + + return result return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) @@ -1007,13 +1154,23 @@ def _maybe_set_colors( try: if palette is not None: raise KeyError("Unable to copy the palette when there was other explicitly specified.") - target.uns[color_key] = source.uns[color_key] + + # First check if source has the colors + if color_key in source.uns: + logger.info(f"Copying color information for '{key}' from source to target AnnData") + target.uns[color_key] = source.uns[color_key] + # Then check if the base key has colors (direct dict mapping) + elif key in source.uns and isinstance(source.uns[key], dict): + logger.info(f"Copying direct color mappings for '{key}' from source to target AnnData") + target.uns[key] = source.uns[key] + else: + raise KeyError(f"No color information found for '{key}' in source AnnData") + except KeyError: if isinstance(palette, str): palette = ListedColormap([palette]) if isinstance(palette, ListedColormap): # `scanpy` requires it palette = cycler(color=palette.colors) - palette = None add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette)