From 942be7bdf18f17e6d08f7d631d098e612d7ced29 Mon Sep 17 00:00:00 2001 From: anon Date: Mon, 22 Jun 2026 02:17:50 +0200 Subject: [PATCH 1/2] perf(render_images): stream-accumulate multi-channel composite Five multi-channel composite branches built the full (n_channels, H, W, 4) float64 cube via np.stack([...], 0).sum(0) before reducing (peak ~64*n*H*W bytes; GB-scale on large multiplex images). Replace with a streaming helper _composite_channels that accumulates per-channel RGB into one (H, W, 3) buffer (O(1) full buffers), mirroring the existing n>3 branch. Byte-identical: numpy sum over the stacked outer axis reduces sequentially, matching acc += rgb. Verified end-to-end main-vs-branch (8 cases, max|delta|=0) and via the regression test across channel counts incl. >128. Peak memory 3357MB -> 26MB (129x) at 200 channels @512x512. The n>3 premultiplied-alpha branch and the RGB-direct branch keep their own math and are left untouched. --- src/spatialdata_plot/pl/render.py | 49 +++++++++++++------------------ tests/pl/test_render_images.py | 34 +++++++++++++++++++-- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index b3a243e4..4c4f318d 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1709,6 +1709,22 @@ def _draw_channel_legend( ) +def _composite_channels( + channel_cmaps: list[Colormap], + layers: dict[Any, np.ndarray], + channels: list[Any], +) -> np.ndarray: + """Sum per-channel RGB into one ``(H, W, 3)`` buffer. + + Holds O(1) full-resolution buffers instead of the full ``(n_channels, H, W, 4)`` cube; + byte-identical to stacking then summing because the reduction is sequential. + """ + acc = channel_cmaps[0](layers[channels[0]])[:, :, :3].copy() + for cmap, ch in zip(channel_cmaps[1:], channels[1:], strict=True): + acc += cmap(layers[ch])[:, :, :3] + return acc + + def _render_images( sdata: sd.SpatialData, render_params: ImageRenderParams, @@ -2008,14 +2024,7 @@ def _render_images( legend_colors = ["red", "green", "blue"] else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels - stacked = ( - np.stack( - [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], - 0, - ).sum(0) - / n_channels - ) - stacked = stacked[:, :, :3] + stacked = _composite_channels(channel_cmaps, layers, channels) / n_channels logger.warning( "One cmap was given for multiple channels and is now used for each channel. " + _MULTI_CMAP_BLENDING_WARNING @@ -2036,19 +2045,11 @@ def _render_images( if n_channels == 2: seed_colors = ["#ff0000ff", "#00ff00ff"] channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack( - [channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)], - 0, - ).sum(0) - colored = np.clip(colored[:, :, :3], 0, 1) + colored = np.clip(_composite_channels(channel_cmaps, layers, channels), 0, 1) elif n_channels == 3: seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack( - [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], - 0, - ).sum(0) - colored = np.clip(colored[:, :, :3], 0, 1) + colored = np.clip(_composite_channels(channel_cmaps, layers, channels), 0, 1) else: if isinstance(render_params.cmap_params, list): cmap_is_default = render_params.cmap_params[0].cmap_is_default @@ -2102,8 +2103,7 @@ def _render_images( raise ValueError("If 'palette' is provided, its length must match the number of channels.") channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)] - colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) - colored = np.clip(colored[:, :, :3], 0, 1) + colored = np.clip(_composite_channels(channel_cmaps, layers, channels), 0, 1) legend_colors = list(palette) @@ -2118,14 +2118,7 @@ def _render_images( elif palette is None and got_multiple_cmaps: channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr] - colored = ( - np.stack( - [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], - 0, - ).sum(0) - / n_channels - ) - colored = colored[:, :, :3] + colored = _composite_channels(channel_cmaps, layers, channels) / n_channels legend_colors = [matplotlib.colors.to_hex(cm(0.75)) for cm in channel_cmaps] diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 643e8575..5cb53dbc 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -4,7 +4,7 @@ import numpy as np import pytest import scanpy as sc -from matplotlib.colors import LogNorm, Normalize +from matplotlib.colors import LinearSegmentedColormap, LogNorm, Normalize from spatial_image import to_spatial_image from spatialdata import SpatialData from spatialdata.models import Image2DModel, Image3DModel @@ -12,7 +12,7 @@ import spatialdata_plot # noqa: F401 from spatialdata_plot import PercentileNormalize from spatialdata_plot._logging import logger, logger_no_warns, logger_warns -from spatialdata_plot.pl.render import _is_rgb_image +from spatialdata_plot.pl.render import _composite_channels, _is_rgb_image from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over sc.pl.set_rcParams_defaults() @@ -937,3 +937,33 @@ def _render_and_grab(**kwargs): plt.close(fig) np.testing.assert_array_equal(_render_and_grab(), _render_and_grab(method="matplotlib")) + + +class TestCompositeChannels: + """`_composite_channels` must be byte-identical to the materialized stack-and-sum it replaces.""" + + @staticmethod + def _stack_sum(channel_cmaps, layers, channels): + return np.stack([channel_cmaps[i](layers[ch]) for i, ch in enumerate(channels)], 0).sum(0)[:, :, :3] + + @pytest.mark.parametrize("n", [2, 3, 129, 200]) # 129/200 cross numpy's 128-element pairwise threshold + def test_streaming_matches_stack_sum(self, n: int): + rng = np.random.default_rng(0) + channels = list(range(n)) + layers = {ch: rng.random((32, 24)) for ch in channels} + cmaps = [LinearSegmentedColormap.from_list("x", ["k", rng.random(3)], N=256) for _ in channels] + + result = _composite_channels(cmaps, layers, channels) + + np.testing.assert_array_equal(result, self._stack_sum(cmaps, layers, channels)) + + def test_returns_owned_rgb_buffer(self): + rng = np.random.default_rng(1) + channels = [0, 1] + layers = {ch: rng.random((8, 8)) for ch in channels} + cmaps = [LinearSegmentedColormap.from_list("x", ["k", "r"], N=256) for _ in channels] + + result = _composite_channels(cmaps, layers, channels) + assert result.shape == (8, 8, 3) + assert result.dtype == np.float64 + assert result.flags["C_CONTIGUOUS"] # owns its buffer -> the (H,W,4) cmap temps are freed each step From 15c6c901fb945a5592adbf2bce0497dba9dbad5f Mon Sep 17 00:00:00 2001 From: anon Date: Mon, 22 Jun 2026 03:11:21 +0200 Subject: [PATCH 2/2] chore(render_shapes): remove dead `table` local MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `table` was assigned (None or sdata_filt[table_name]) but never read — color resolution goes through sdata_filt + table_name, not this local. Drop both assignments, keep the _check_instance_ids_overlap validation, and invert the guard. Clears a pre-existing ruff F841. --- src/spatialdata_plot/pl/render.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 4c4f318d..8aa46ca0 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -618,13 +618,10 @@ def _render_shapes( ) table_name = render_params.table_name - if table_name is None: - table = None - else: + if table_name is not None: # No join/copy: _set_color_source_vec resolves each shape's color from the table (region-masked # and reindexed to the element), so unannotated shapes keep their place and render with na_color. _check_instance_ids_overlap(sdata_filt, table_name, element, sdata_filt[element].index) - table = sdata_filt[table_name] shapes = sdata_filt[element]