diff --git a/src/spatialdata_plot/pl/_color.py b/src/spatialdata_plot/pl/_color.py
index 6181ac94..4db897ac 100644
--- a/src/spatialdata_plot/pl/_color.py
+++ b/src/spatialdata_plot/pl/_color.py
@@ -777,7 +777,16 @@ def to_rgba(self, cmap_params: CmapParams) -> np.ndarray:
         via norm+cmap with NaN/non-finite rows painted ``na_color``; an object vector mixes the two.
         """
         if self.source_vector is not None:  # categorical or none: color_vector holds per-row hex
-            return np.asarray(colors.to_rgba_array(list(self.color_vector)))
+            cv = self.color_vector
+            if isinstance(cv, (pd.Categorical, pd.CategoricalIndex)) and not (cv.codes < 0).any():
+                # every row has a real category (a -1 "missing" code would silently gather the last
+                # LUT row), so parse the few hex categories once and gather by code, not per row
+                lut = np.asarray(colors.to_rgba_array(cv.categories.to_numpy()))
+                return lut[cv.codes]
+            # anything else (object vector from align_to_length pad / uniform na, or a categorical with
+            # missing rows): factorize the distinct values, parse each once and gather back by code
+            codes, uniq = pd.factorize(np.asarray(cv, dtype=object), sort=False, use_na_sentinel=False)
+            return np.asarray(colors.to_rgba_array(list(uniq)))[codes]
         arr = np.asarray(self.color_vector)
         if arr.ndim == 2 and arr.shape[1] in (3, 4) and np.issubdtype(arr.dtype, np.number):
             return np.asarray(colors.to_rgba_array(arr))
diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py
index 2960a0d2..def6d94c 100644
--- a/tests/pl/test_utils.py
+++ b/tests/pl/test_utils.py
@@ -1127,6 +1127,28 @@ def test_precomputed_rgba_passthrough(self):
         arr = np.array([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]])
         np.testing.assert_allclose(ColorSpec("continuous", None, arr).to_rgba(self._params()), arr)
 
+    def test_codes_gather_matches_per_row_parse(self):
+        # the categorical codes-gather and the object factorize-gather must equal to_rgba_array(list(...))
+        from matplotlib import colors
+
+        from spatialdata_plot.pl._color import ColorSpec
+
+        hexes = ["#e41a1cff", "#377eb8ff", "#4daf4aff", "#984ea3ff"]
+        na = "#cccccc00"
+        rng = np.random.default_rng(0)
+        clean = pd.Categorical.from_codes(rng.integers(0, len(hexes), 2000), categories=hexes)
+        with_na = pd.Categorical.from_codes(rng.integers(0, len(hexes) + 1, 2000), categories=[*hexes, na])
+        variants = [
+            clean,  # categorical fast-path
+            with_na,  # NaN replaced with an na_color category -> codes still >= 0
+            clean[clean != hexes[0]].remove_unused_categories(),  # one category dropped -> remapped codes
+            np.full(500, na, dtype=object),  # uniform na (object vector)
+            np.concatenate([np.asarray(list(clean[:300]), dtype=object), np.full(100, na, dtype=object)]),  # padded
+        ]
+        for cv in variants:  # cv doubles as the (only-checked-for-not-None) source_vector
+            spec = ColorSpec("categorical" if hasattr(cv, "codes") else "none", cv, cv)
+            np.testing.assert_array_equal(spec.to_rgba(self._params()), np.asarray(colors.to_rgba_array(list(cv))))
+
 
 class TestPercentileNormalize:
     """PercentileNormalize + _resolve_continuous_norm (issue #370: dim multichannel renders)."""