Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,13 +618,10 @@ def _render_shapes(
)

table_name = render_params.table_name
if table_name is None:
table = None
else:
if table_name is not None:
# No join/copy: _set_color_source_vec resolves each shape's color from the table (region-masked
# and reindexed to the element), so unannotated shapes keep their place and render with na_color.
_check_instance_ids_overlap(sdata_filt, table_name, element, sdata_filt[element].index)
table = sdata_filt[table_name]

shapes = sdata_filt[element]

Expand Down Expand Up @@ -1709,6 +1706,22 @@ def _draw_channel_legend(
)


def _composite_channels(
channel_cmaps: list[Colormap],
layers: dict[Any, np.ndarray],
channels: list[Any],
) -> np.ndarray:
"""Sum per-channel RGB into one ``(H, W, 3)`` buffer.

Holds O(1) full-resolution buffers instead of the full ``(n_channels, H, W, 4)`` cube;
byte-identical to stacking then summing because the reduction is sequential.
"""
acc = channel_cmaps[0](layers[channels[0]])[:, :, :3].copy()
for cmap, ch in zip(channel_cmaps[1:], channels[1:], strict=True):
acc += cmap(layers[ch])[:, :, :3]
return acc


def _render_images(
sdata: sd.SpatialData,
render_params: ImageRenderParams,
Expand Down Expand Up @@ -2008,14 +2021,7 @@ def _render_images(
legend_colors = ["red", "green", "blue"]
else: # -> use given cmap for each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
/ n_channels
)
stacked = stacked[:, :, :3]
stacked = _composite_channels(channel_cmaps, layers, channels) / n_channels
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
+ _MULTI_CMAP_BLENDING_WARNING
Expand All @@ -2036,19 +2042,11 @@ def _render_images(
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
0,
).sum(0)
colored = np.clip(colored[:, :, :3], 0, 1)
colored = np.clip(_composite_channels(channel_cmaps, layers, channels), 0, 1)
elif n_channels == 3:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
colored = np.clip(colored[:, :, :3], 0, 1)
colored = np.clip(_composite_channels(channel_cmaps, layers, channels), 0, 1)
else:
if isinstance(render_params.cmap_params, list):
cmap_is_default = render_params.cmap_params[0].cmap_is_default
Expand Down Expand Up @@ -2102,8 +2100,7 @@ def _render_images(
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
colored = np.clip(colored[:, :, :3], 0, 1)
colored = np.clip(_composite_channels(channel_cmaps, layers, channels), 0, 1)

legend_colors = list(palette)

Expand All @@ -2118,14 +2115,7 @@ def _render_images(

elif palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
colored = (
np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
/ n_channels
)
colored = colored[:, :, :3]
colored = _composite_channels(channel_cmaps, layers, channels) / n_channels

legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps]

Expand Down
34 changes: 32 additions & 2 deletions tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import numpy as np
import pytest
import scanpy as sc
from matplotlib.colors import LogNorm, Normalize
from matplotlib.colors import LinearSegmentedColormap, LogNorm, Normalize
from spatial_image import to_spatial_image
from spatialdata import SpatialData
from spatialdata.models import Image2DModel, Image3DModel

import spatialdata_plot # noqa: F401
from spatialdata_plot import PercentileNormalize
from spatialdata_plot._logging import logger, logger_no_warns, logger_warns
from spatialdata_plot.pl.render import _is_rgb_image
from spatialdata_plot.pl.render import _composite_channels, _is_rgb_image
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over

sc.pl.set_rcParams_defaults()
Expand Down Expand Up @@ -937,3 +937,33 @@ def _render_and_grab(**kwargs):
plt.close(fig)

np.testing.assert_array_equal(_render_and_grab(), _render_and_grab(method="matplotlib"))


class TestCompositeChannels:
"""`_composite_channels` must be byte-identical to the materialized stack-and-sum it replaces."""

@staticmethod
def _stack_sum(channel_cmaps, layers, channels):
return np.stack([channel_cmaps[i](layers[ch]) for i, ch in enumerate(channels)], 0).sum(0)[:, :, :3]

@pytest.mark.parametrize("n", [2, 3, 129, 200]) # 129/200 cross numpy's 128-element pairwise threshold
def test_streaming_matches_stack_sum(self, n: int):
rng = np.random.default_rng(0)
channels = list(range(n))
layers = {ch: rng.random((32, 24)) for ch in channels}
cmaps = [LinearSegmentedColormap.from_list("x", ["k", rng.random(3)], N=256) for _ in channels]

result = _composite_channels(cmaps, layers, channels)

np.testing.assert_array_equal(result, self._stack_sum(cmaps, layers, channels))

def test_returns_owned_rgb_buffer(self):
rng = np.random.default_rng(1)
channels = [0, 1]
layers = {ch: rng.random((8, 8)) for ch in channels}
cmaps = [LinearSegmentedColormap.from_list("x", ["k", "r"], N=256) for _ in channels]

result = _composite_channels(cmaps, layers, channels)
assert result.shape == (8, 8, 3)
assert result.dtype == np.float64
assert result.flags["C_CONTIGUOUS"] # owns its buffer -> the (H,W,4) cmap temps are freed each step
Loading