@@ -338,7 +338,7 @@ def _render_shapes(
338338 cax = None
339339 if aggregate_with_reduction is not None :
340340 vmin = aggregate_with_reduction [0 ].values if norm .vmin is None else norm .vmin
341- vmax = aggregate_with_reduction [1 ].values if norm .vmin is None else norm .vmax
341+ vmax = aggregate_with_reduction [1 ].values if norm .vmax is None else norm .vmax
342342 if (norm .vmin is not None or norm .vmax is not None ) and norm .vmin == norm .vmax :
343343 # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and
344344 # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1)
@@ -846,20 +846,22 @@ def _render_images(
846846 # 2) Image has any number of channels but 1
847847 else :
848848 layers = {}
849- for ch_index , c in enumerate (channels ):
850- layers [c ] = img .sel (c = c ).copy (deep = True ).squeeze ()
851-
852- if not isinstance (render_params .cmap_params , list ):
853- if render_params .cmap_params .norm is not None :
854- layers [c ] = render_params .cmap_params .norm (layers [c ])
849+ for ch_idx , ch in enumerate (channels ):
850+ layers [ch ] = img .sel (c = ch ).copy (deep = True ).squeeze ()
851+ if isinstance (render_params .cmap_params , list ):
852+ ch_norm = render_params .cmap_params [ch_idx ].norm
853+ ch_cmap_is_default = render_params .cmap_params [ch_idx ].cmap_is_default
855854 else :
856- if render_params .cmap_params [ch_index ].norm is not None :
857- layers [c ] = render_params .cmap_params [ch_index ].norm (layers [c ])
855+ ch_norm = render_params .cmap_params .norm
856+ ch_cmap_is_default = render_params .cmap_params .cmap_is_default
857+
858+ if not ch_cmap_is_default and ch_norm is not None :
859+ layers [ch_idx ] = ch_norm (layers [ch_idx ])
858860
859861 # 2A) Image has 3 channels, no palette info, and no/only one cmap was given
860862 if palette is None and n_channels == 3 and not isinstance (render_params .cmap_params , list ):
861863 if render_params .cmap_params .cmap_is_default : # -> use RGB
862- stacked = np .stack ([layers [c ] for c in channels ], axis = - 1 )
864+ stacked = np .stack ([layers [ch ] for ch in layers ], axis = - 1 )
863865 else : # -> use given cmap for each channel
864866 channel_cmaps = [render_params .cmap_params .cmap ] * n_channels
865867 stacked = (
@@ -892,12 +894,54 @@ def _render_images(
892894 # overwrite if n_channels == 2 for intuitive result
893895 if n_channels == 2 :
894896 seed_colors = ["#ff0000ff" , "#00ff00ff" ]
895- else :
897+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
898+ colored = np .stack (
899+ [channel_cmaps [ch_ind ](layers [ch ]) for ch_ind , ch in enumerate (channels )],
900+ 0 ,
901+ ).sum (0 )
902+ colored = colored [:, :, :3 ]
903+ elif n_channels == 3 :
896904 seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
905+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
906+ colored = np .stack (
907+ [channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )],
908+ 0 ,
909+ ).sum (0 )
910+ colored = colored [:, :, :3 ]
911+ else :
912+ if isinstance (render_params .cmap_params , list ):
913+ cmap_is_default = render_params .cmap_params [0 ].cmap_is_default
914+ else :
915+ cmap_is_default = render_params .cmap_params .cmap_is_default
897916
898- channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
899- colored = np .stack ([channel_cmaps [ind ](layers [ch ]) for ind , ch in enumerate (channels )], 0 ).sum (0 )
900- colored = colored [:, :, :3 ]
917+ if cmap_is_default :
918+ seed_colors = _get_colors_for_categorical_obs (list (range (n_channels )))
919+ else :
920+ # Sample n_channels colors evenly from the colormap
921+ if isinstance (render_params .cmap_params , list ):
922+ seed_colors = [
923+ render_params .cmap_params [i ].cmap (i / (n_channels - 1 )) for i in range (n_channels )
924+ ]
925+ else :
926+ seed_colors = [render_params .cmap_params .cmap (i / (n_channels - 1 )) for i in range (n_channels )]
927+ channel_cmaps = [_get_linear_colormap ([c ], "k" )[0 ] for c in seed_colors ]
928+
929+ # Stack (n_channels, height, width) → (height*width, n_channels)
930+ H , W = next (iter (layers .values ())).shape
931+ comp_rgb = np .zeros ((H , W , 3 ), dtype = float )
932+
933+ # For each channel: map to RGBA, apply constant alpha, then add
934+ for ch_idx , ch in enumerate (channels ):
935+ layer_arr = layers [ch ]
936+ rgba = channel_cmaps [ch_idx ](layer_arr )
937+ rgba [..., 3 ] = render_params .alpha
938+ comp_rgb += rgba [..., :3 ] * rgba [..., 3 ][..., None ]
939+
940+ colored = np .clip (comp_rgb , 0 , 1 )
941+ logger .info (
942+ f"Your image has { n_channels } channels. Sampling categorical colors and using "
943+ f"multichannel strategy 'stack' to render."
944+ ) # TODO: update when pca is added as strategy
901945
902946 _ax_show_and_transform (
903947 colored ,
@@ -943,6 +987,7 @@ def _render_images(
943987 zorder = render_params .zorder ,
944988 )
945989
990+ # 2D) Image has n channels, no palette but cmap info
946991 elif palette is not None and got_multiple_cmaps :
947992 raise ValueError ("If 'palette' is provided, 'cmap' must be None." )
948993
0 commit comments