Skip to content

Commit c6973bd

Browse files
Donut-MultiPolygons are now correctly rendered again (#334)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4e34da3 commit c6973bd

8 files changed

+122
-237
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,11 @@ jobs:
5757
DISPLAY: :42
5858
run: |
5959
pytest -v --cov --color=yes --cov-report=xml
60-
# - name: Generate GH action "groundtruth" figures as artifacts, uncomment if needed
61-
# if: always()
62-
# uses: actions/upload-artifact@v3
63-
# with:
64-
# name: groundtruth-figures
65-
# path: /home/runner/work/spatialdata-plot/spatialdata-plot/tests/_images/*
6660
- name: Archive figures generated during testing
6761
if: always()
6862
uses: actions/upload-artifact@v3
6963
with:
70-
name: plotting-results
64+
name: visual_test_results_${{ matrix.os }}-python${{ matrix.python }}
7165
path: /home/runner/work/spatialdata-plot/spatialdata-plot/tests/figures/*
7266
- name: Upload coverage to Codecov
7367
uses: codecov/codecov-action@v4

src/spatialdata_plot/pl/utils.py

Lines changed: 76 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212

1313
import matplotlib
1414
import matplotlib.patches as mpatches
15-
import matplotlib.patches as mplp
1615
import matplotlib.path as mpath
1716
import matplotlib.pyplot as plt
18-
import multiscale_spatial_image as msi
1917
import numpy as np
2018
import pandas as pd
2119
import shapely
@@ -49,7 +47,6 @@
4947
from scanpy.plotting._tools.scatterplots import _add_categorical_legend
5048
from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation
5149
from scanpy.plotting.palettes import default_20, default_28, default_102
52-
from shapely.geometry import LineString, Polygon
5350
from skimage.color import label2rgb
5451
from skimage.morphology import erosion, square
5552
from skimage.segmentation import find_boundaries
@@ -283,6 +280,30 @@ def _sanitise_na_color(na_color: ColorLike | None) -> tuple[str, bool]:
283280
raise ValueError(f"Invalid na_color value: {na_color}")
284281

285282

283+
def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, float]:
284+
# Extract the vertices from the PathPatch
285+
path = pathpatch.get_path()
286+
vertices = path.vertices
287+
x = vertices[:, 0]
288+
y = vertices[:, 1]
289+
290+
area = 0.5 * np.sum(x[:-1] * y[1:] - x[1:] * y[:-1])
291+
292+
# Calculate the centroid coordinates
293+
centroid_x = np.sum((x[:-1] + x[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area)
294+
centroid_y = np.sum((y[:-1] + y[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area)
295+
296+
return centroid_x, centroid_y
297+
298+
299+
def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None:
300+
301+
centroid = _get_centroid_of_pathpatch(pathpatch)
302+
vertices = pathpatch.get_path().vertices
303+
scaled_vertices = np.array([centroid + (vertex - centroid) * scale_factor for vertex in vertices])
304+
pathpatch.get_path().vertices = scaled_vertices
305+
306+
286307
def _get_collection_shape(
287308
shapes: list[GeoDataFrame],
288309
c: Any,
@@ -352,63 +373,64 @@ def _get_collection_shape(
352373
outline_c = outline_c * fill_c.shape[0]
353374

354375
shapes_df = pd.DataFrame(shapes, copy=True)
355-
356-
# remove empty points/polygons
357376
shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)]
358-
359-
# reset index of shapes_df for case of spatial query
360377
shapes_df = shapes_df.reset_index(drop=True)
361378

362-
rows = []
363-
364-
def assign_fill_and_outline_to_row(
365-
shapes: list[GeoDataFrame], fill_c: list[Any], outline_c: list[Any], row: pd.Series, idx: int
379+
def _assign_fill_and_outline_to_row(
380+
fill_c: list[Any], outline_c: list[Any], row: dict[str, Any], idx: int, is_multiple_shapes: bool
366381
) -> None:
367382
try:
368-
if len(shapes) > 1 and len(fill_c) == 1:
369-
row["fill_c"] = fill_c
370-
row["outline_c"] = outline_c
383+
if is_multiple_shapes and len(fill_c) == 1:
384+
row["fill_c"] = fill_c[0]
385+
row["outline_c"] = outline_c[0]
371386
else:
372387
row["fill_c"] = fill_c[idx]
373388
row["outline_c"] = outline_c[idx]
374389
except IndexError as e:
375-
raise IndexError("Could not assign fill and outline colors due to a mismatch in row-numbers.") from e
376-
377-
# Match colors to the geometry, potentially expanding the row in case of
378-
# multipolygons
379-
for idx, row in shapes_df.iterrows():
380-
geom = row["geometry"]
381-
if geom.geom_type == "Polygon":
382-
row = row.to_dict()
383-
coords = np.array(geom.exterior.coords)
384-
centroid = np.mean(coords, axis=0)
385-
scaled_coords = [(centroid + (np.array(coord) - centroid) * s).tolist() for coord in geom.exterior.coords]
386-
row["geometry"] = mplp.Polygon(scaled_coords, closed=True)
387-
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
388-
rows.append(row)
389-
390-
elif geom.geom_type == "MultiPolygon":
391-
# mp = _make_patch_from_multipolygon(geom)
392-
for polygon in geom.geoms:
393-
mp_copy = row.to_dict()
394-
coords = np.array(polygon.exterior.coords)
395-
centroid = np.mean(coords, axis=0)
396-
scaled_coords = [(centroid + (coord - centroid) * s).tolist() for coord in coords]
397-
mp_copy["geometry"] = mplp.Polygon(scaled_coords, closed=True)
398-
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, mp_copy, idx)
399-
rows.append(mp_copy)
400-
401-
elif geom.geom_type == "Point":
402-
row = row.to_dict()
403-
scaled_radius = row["radius"] * s
404-
row["geometry"] = mplp.Circle(
405-
(geom.x, geom.y), radius=scaled_radius
406-
) # Circle is always scaled from its center
407-
assign_fill_and_outline_to_row(shapes, fill_c, outline_c, row, idx)
408-
rows.append(row)
409-
410-
patches = pd.DataFrame(rows)
411-
390+
raise IndexError("Could not assign fill and outline colors due to a mismatch in row numbers.") from e
391+
392+
def _process_polygon(row: pd.Series, s: float) -> dict[str, Any]:
393+
coords = np.array(row["geometry"].exterior.coords)
394+
centroid = np.mean(coords, axis=0)
395+
scaled_coords = (centroid + (coords - centroid) * s).tolist()
396+
return {**row.to_dict(), "geometry": mpatches.Polygon(scaled_coords, closed=True)}
397+
398+
def _process_multipolygon(row: pd.Series, s: float) -> list[dict[str, Any]]:
399+
mp = _make_patch_from_multipolygon(row["geometry"])
400+
row_dict = row.to_dict()
401+
for m in mp:
402+
_scale_pathpatch_around_centroid(m, s)
403+
404+
return [{**row_dict, "geometry": m} for m in mp]
405+
406+
def _process_point(row: pd.Series, s: float) -> dict[str, Any]:
407+
return {
408+
**row.to_dict(),
409+
"geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=row["radius"] * s),
410+
}
411+
412+
def _create_patches(shapes_df: GeoDataFrame, fill_c: list[Any], outline_c: list[Any], s: float) -> pd.DataFrame:
413+
rows = []
414+
is_multiple_shapes = len(shapes_df) > 1
415+
416+
for idx, row in shapes_df.iterrows():
417+
geom_type = row["geometry"].geom_type
418+
processed_rows = []
419+
420+
if geom_type == "Polygon":
421+
processed_rows.append(_process_polygon(row, s))
422+
elif geom_type == "MultiPolygon":
423+
processed_rows.extend(_process_multipolygon(row, s))
424+
elif geom_type == "Point":
425+
processed_rows.append(_process_point(row, s))
426+
427+
for processed_row in processed_rows:
428+
_assign_fill_and_outline_to_row(fill_c, outline_c, processed_row, idx, is_multiple_shapes)
429+
rows.append(processed_row)
430+
431+
return pd.DataFrame(rows)
432+
433+
patches = _create_patches(shapes_df, fill_c, outline_c, s)
412434
return PatchCollection(
413435
patches["geometry"].values.tolist(),
414436
snap=False,
@@ -788,7 +810,7 @@ def _map_color_seg(
788810
cell_id = np.array(cell_id)
789811
if color_vector is not None and isinstance(color_vector.dtype, pd.CategoricalDtype):
790812
# users wants to plot a categorical column
791-
if isinstance(na_color, tuple) and len(na_color) == 4 and np.any(color_source_vector.isna()):
813+
if np.any(color_source_vector.isna()):
792814
cell_id[color_source_vector.isna()] = 0
793815
val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1)
794816
cols = colors.to_rgba_array(color_vector.categories)
@@ -873,9 +895,9 @@ def _modify_categorical_color_mapping(
873895
modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"}
874896
elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list):
875897
modified_mapping = dict(zip(groups, palette))
876-
877898
else:
878899
raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.")
900+
879901
return modified_mapping
880902

881903

@@ -891,7 +913,7 @@ def _get_default_categorial_color_mapping(
891913
palette = default_102
892914
else:
893915
palette = ["grey" for _ in range(len_cat)]
894-
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
916+
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
895917

896918
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(color_source_vector.categories, palette[:len_cat])}
897919

@@ -922,54 +944,6 @@ def _get_categorical_color_mapping(
922944
return _modify_categorical_color_mapping(base_mapping, groups, palette)
923945

924946

925-
def _get_palette(
926-
categories: Sequence[Any],
927-
adata: AnnData | None = None,
928-
cluster_key: None | str = None,
929-
palette: ListedColormap | str | list[str] | None = None,
930-
alpha: float = 1.0,
931-
) -> Mapping[str, str] | None:
932-
palette = None if isinstance(palette, list) and palette[0] is None else palette
933-
if adata is not None and palette is None:
934-
try:
935-
palette = adata.uns[f"{cluster_key}_colors"] # type: ignore[arg-type]
936-
if len(palette) != len(categories):
937-
raise ValueError(
938-
f"Expected palette to be of length `{len(categories)}`, found `{len(palette)}`. "
939-
+ f"Removing the colors in `adata.uns` with `adata.uns.pop('{cluster_key}_colors')` may help."
940-
)
941-
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette)}
942-
except KeyError as e:
943-
logger.warning(e)
944-
return None
945-
946-
len_cat = len(categories)
947-
948-
if palette is None:
949-
if len_cat <= 20:
950-
palette = default_20
951-
elif len_cat <= 28:
952-
palette = default_28
953-
elif len_cat <= len(default_102): # 103 colors
954-
palette = default_102
955-
else:
956-
palette = ["grey" for _ in range(len_cat)]
957-
logger.info("input has more than 103 categories. Uniform " "'grey' color will be used for all categories.")
958-
return {cat: to_hex(to_rgba(col)[:3]) for cat, col in zip(categories, palette[:len_cat])}
959-
960-
if isinstance(palette, str):
961-
cmap = ListedColormap([palette])
962-
elif isinstance(palette, list):
963-
cmap = ListedColormap(palette)
964-
elif isinstance(palette, ListedColormap):
965-
cmap = palette
966-
else:
967-
raise TypeError(f"Palette is {type(palette)} but should be string or list.")
968-
palette = [to_hex(np.round(x, 5)) for x in cmap(np.linspace(0, 1, len_cat), alpha=alpha)]
969-
970-
return dict(zip(categories, palette))
971-
972-
973947
def _maybe_set_colors(
974948
source: AnnData, target: AnnData, key: str, palette: str | ListedColormap | Cycler | Sequence[Any] | None = None
975949
) -> None:
@@ -1137,34 +1111,6 @@ def save_fig(fig: Figure, path: str | Path, make_dir: bool = True, ext: str = "p
11371111
fig.savefig(path, **kwargs)
11381112

11391113

1140-
def _get_cs_element_map(
1141-
element: str | Sequence[str] | None,
1142-
element_map: Mapping[str, Any],
1143-
) -> Mapping[str, str]:
1144-
"""Get the mapping between the coordinate system and the class."""
1145-
# from spatialdata.models import Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel
1146-
element = list(element_map.keys())[0] if element is None else element
1147-
element = [element] if isinstance(element, str) else element
1148-
d = {}
1149-
for e in element:
1150-
cs = list(element_map[e].attrs["transform"].keys())[0]
1151-
d[cs] = e
1152-
# model = get_model(element_map["blobs_labels"])
1153-
# if model in [Image2DModel, Image3DModel, Labels2DModel, Labels3DModel]
1154-
return d
1155-
1156-
1157-
def _multiscale_to_image(sdata: sd.SpatialData) -> sd.SpatialData:
1158-
if sdata.images is None:
1159-
raise ValueError("No images found in the SpatialData object.")
1160-
1161-
for k, v in sdata.images.items():
1162-
if isinstance(v, msi.multiscale_spatial_image.DataTree):
1163-
sdata.images[k] = Image2DModel.parse(v["scale0"].ds.to_array().squeeze(axis=0))
1164-
1165-
return sdata
1166-
1167-
11681114
def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegmentedColormap]:
11691115
return [LinearSegmentedColormap.from_list(c, [background, c], N=256) for c in colors]
11701116

@@ -1176,62 +1122,6 @@ def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap:
11761122
return ListedColormap(["black"] + colors, N=len(colors) + 1)
11771123

11781124

1179-
def _translate_image(
1180-
image: DataArray,
1181-
translation: sd.transformations.transformations.Translation,
1182-
) -> DataArray:
1183-
shifts: dict[str, int] = {axis: int(translation.translation[idx]) for idx, axis in enumerate(translation.axes)}
1184-
img = image.values.copy()
1185-
# for yx images (important for rasterized MultiscaleImages as labels)
1186-
expanded_dims = False
1187-
if len(img.shape) == 2:
1188-
img = np.expand_dims(img, axis=0)
1189-
expanded_dims = True
1190-
1191-
shifted_channels = []
1192-
1193-
# split channels, shift axes individually, them recombine
1194-
if len(img.shape) == 3:
1195-
for c in range(img.shape[0]):
1196-
channel = img[c, :, :]
1197-
1198-
# iterates over [x, y]
1199-
for axis, shift in shifts.items():
1200-
pad_x, pad_y = (0, 0), (0, 0)
1201-
if axis == "x" and shift > 0:
1202-
pad_x = (abs(shift), 0)
1203-
elif axis == "x" and shift < 0:
1204-
pad_x = (0, abs(shift))
1205-
1206-
if axis == "y" and shift > 0:
1207-
pad_y = (abs(shift), 0)
1208-
elif axis == "y" and shift < 0:
1209-
pad_y = (0, abs(shift))
1210-
1211-
channel = np.pad(channel, (pad_y, pad_x), mode="constant")
1212-
1213-
shifted_channels.append(channel)
1214-
1215-
if expanded_dims:
1216-
return Labels2DModel.parse(
1217-
np.array(shifted_channels[0]),
1218-
dims=["y", "x"],
1219-
transformations=image.attrs["transform"],
1220-
)
1221-
return Image2DModel.parse(
1222-
np.array(shifted_channels),
1223-
dims=["c", "y", "x"],
1224-
transformations=image.attrs["transform"],
1225-
)
1226-
1227-
1228-
def _convert_polygon_to_linestrings(polygon: Polygon) -> list[LineString]:
1229-
b = polygon.boundary.coords
1230-
linestrings = [LineString(b[k : k + 2]) for k in range(len(b) - 1)]
1231-
1232-
return [list(ls.coords) for ls in linestrings]
1233-
1234-
12351125
def _split_multipolygon_into_outer_and_inner(mp: shapely.MultiPolygon): # type: ignore
12361126
# https://stackoverflow.com/a/21922058
12371127

Loading
-36.9 KB
Binary file not shown.
-936 Bytes
Loading

0 commit comments

Comments
 (0)