Skip to content

Commit 7147a64

Browse files
Refactor of labels logic (#336)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c6973bd commit 7147a64

File tree

1 file changed

+27
-94
lines changed

1 file changed

+27
-94
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 27 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -650,11 +650,9 @@ def _render_images(
650650
stacked = np.stack([layers[c] for c in channels], axis=-1)
651651
else: # -> use given cmap for each channel
652652
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
653-
# Apply cmaps to each channel, add up and normalize to [0, 1]
654653
stacked = (
655654
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
656655
)
657-
# Remove alpha channel so we can overwrite it from render_params.alpha
658656
stacked = stacked[:, :, :3]
659657
logger.warning(
660658
"One cmap was given for multiple channels and is now used for each channel. "
@@ -676,11 +674,7 @@ def _render_images(
676674
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
677675

678676
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
679-
680-
# Apply cmaps to each channel and add up
681677
colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0)
682-
683-
# Remove alpha channel so we can overwrite it from render_params.alpha
684678
colored = colored[:, :, :3]
685679

686680
_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
@@ -691,24 +685,16 @@ def _render_images(
691685
raise ValueError("If 'palette' is provided, its length must match the number of channels.")
692686

693687
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]
694-
695-
# Apply cmaps to each channel and add up
696688
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
697-
698-
# Remove alpha channel so we can overwrite it from render_params.alpha
699689
colored = colored[:, :, :3]
700690

701691
_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
702692

703693
elif palette is None and got_multiple_cmaps:
704694
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
705-
706-
# Apply cmaps to each channel, add up and normalize to [0, 1]
707695
colored = (
708696
np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) / n_channels
709697
)
710-
711-
# Remove alpha channel so we can overwrite it from render_params.alpha
712698
colored = colored[:, :, :3]
713699

714700
_ax_show_and_transform(colored, trans_data, ax, render_params.alpha, zorder=render_params.zorder)
@@ -794,119 +780,66 @@ def _render_labels(
794780
table_name=table_name,
795781
)
796782

797-
# default case: no contour, just fill
798-
# if fill_alpha and outline_alpha are the same, we're technically also at a no-outline situation
799-
if render_params.outline_alpha == 0.0 or render_params.outline_alpha == render_params.fill_alpha:
800-
labels_infill = _map_color_seg(
783+
def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) -> matplotlib.image.AxesImage:
784+
labels = _map_color_seg(
801785
seg=label.values,
802786
cell_id=instance_id,
803787
color_vector=color_vector,
804788
color_source_vector=color_source_vector,
805789
cmap_params=render_params.cmap_params,
806-
seg_erosionpx=None,
807-
seg_boundaries=False,
790+
seg_erosionpx=seg_erosionpx,
791+
seg_boundaries=seg_boundaries,
808792
na_color=render_params.cmap_params.na_color,
809793
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
810794
)
795+
811796
_cax = ax.imshow(
812-
labels_infill,
797+
labels,
813798
rasterized=True,
814799
cmap=None if categorical else render_params.cmap_params.cmap,
815800
norm=None if categorical else render_params.cmap_params.norm,
816-
alpha=render_params.fill_alpha,
801+
alpha=alpha,
817802
origin="lower",
818803
zorder=render_params.zorder,
819804
)
820805
_cax.set_transform(trans_data)
821806
cax = ax.add_image(_cax)
807+
return cax # noqa: RET504
808+
809+
# default case: no contour, just fill
810+
# since contour_px is passed to skimage.morphology.erosion to create the contour,
811+
# any border thickness is only within the label, not outside. Therefore, the case
812+
# of fill_alpha == outline_alpha is equivalent to fill-only
813+
if (render_params.fill_alpha > 0.0 and render_params.outline_alpha == 0.0) or (
814+
render_params.fill_alpha == render_params.outline_alpha
815+
):
816+
cax = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha)
822817
alpha_to_decorate_ax = render_params.fill_alpha
823818

824819
# outline-only case
825-
if render_params.fill_alpha == 0.0 and render_params.outline_alpha != 0.0:
826-
labels_contour = _map_color_seg(
827-
seg=label.values,
828-
cell_id=instance_id,
829-
color_vector=color_vector,
830-
color_source_vector=color_source_vector,
831-
cmap_params=render_params.cmap_params,
832-
seg_erosionpx=render_params.contour_px,
833-
seg_boundaries=True,
834-
na_color=render_params.cmap_params.na_color,
835-
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
820+
elif render_params.fill_alpha == 0.0 and render_params.outline_alpha > 0.0:
821+
cax = _draw_labels(
822+
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
836823
)
837-
_cax = ax.imshow(
838-
labels_contour,
839-
rasterized=True,
840-
cmap=None if categorical else render_params.cmap_params.cmap,
841-
norm=None if categorical else render_params.cmap_params.norm,
842-
alpha=render_params.outline_alpha,
843-
origin="lower",
844-
zorder=render_params.zorder,
845-
)
846-
_cax.set_transform(trans_data)
847-
cax = ax.add_image(_cax)
848824
alpha_to_decorate_ax = render_params.outline_alpha
849825

850826
# pretty case: both outline and infill
851-
if (
852-
render_params.fill_alpha > 0.0
853-
and render_params.outline_alpha > 0.0
854-
and render_params.fill_alpha != render_params.outline_alpha
855-
):
827+
elif render_params.fill_alpha > 0.0 and render_params.outline_alpha > 0.0:
856828
# first plot the infill ...
857-
label_infill = _map_color_seg(
858-
seg=label.values,
859-
cell_id=instance_id,
860-
color_vector=color_vector,
861-
color_source_vector=color_source_vector,
862-
cmap_params=render_params.cmap_params,
863-
seg_erosionpx=None,
864-
seg_boundaries=False,
865-
na_color=render_params.cmap_params.na_color,
866-
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
867-
)
868-
869-
_cax_infill = ax.imshow(
870-
label_infill,
871-
rasterized=True,
872-
cmap=None if categorical else render_params.cmap_params.cmap,
873-
norm=None if categorical else render_params.cmap_params.norm,
874-
alpha=render_params.fill_alpha,
875-
origin="lower",
876-
zorder=render_params.zorder,
877-
)
878-
_cax_infill.set_transform(trans_data)
879-
cax_infill = ax.add_image(_cax_infill)
829+
cax_infill = _draw_labels(seg_erosionpx=None, seg_boundaries=False, alpha=render_params.fill_alpha)
880830

881831
# ... then overlay the contour
882-
label_contour = _map_color_seg(
883-
seg=label.values,
884-
cell_id=instance_id,
885-
color_vector=color_vector,
886-
color_source_vector=color_source_vector,
887-
cmap_params=render_params.cmap_params,
888-
seg_erosionpx=render_params.contour_px,
889-
seg_boundaries=True,
890-
na_color=render_params.cmap_params.na_color,
891-
na_color_modified_by_user=render_params.cmap_params.na_color_modified_by_user,
832+
cax_contour = _draw_labels(
833+
seg_erosionpx=render_params.contour_px, seg_boundaries=True, alpha=render_params.outline_alpha
892834
)
893835

894-
_cax_contour = ax.imshow(
895-
label_contour,
896-
rasterized=True,
897-
cmap=None if categorical else render_params.cmap_params.cmap,
898-
norm=None if categorical else render_params.cmap_params.norm,
899-
alpha=render_params.outline_alpha,
900-
origin="lower",
901-
zorder=render_params.zorder,
902-
)
903-
_cax_contour.set_transform(trans_data)
904-
cax_contour = ax.add_image(_cax_contour)
905-
906836
# pass the less-transparent _cax for the legend
907837
cax = cax_infill if render_params.fill_alpha > render_params.outline_alpha else cax_contour
908838
alpha_to_decorate_ax = max(render_params.fill_alpha, render_params.outline_alpha)
909839

840+
else:
841+
raise ValueError("Parameters 'fill_alpha' and 'outline_alpha' cannot both be 0.")
842+
910843
_ = _decorate_axs(
911844
ax=ax,
912845
cax=cax,

0 commit comments

Comments
 (0)