Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/spatialdata_plot/pl/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,16 @@ def to_rgba(self, cmap_params: CmapParams) -> np.ndarray:
via norm+cmap with NaN/non-finite rows painted ``na_color``; an object vector mixes the two.
"""
if self.source_vector is not None: # categorical or none: color_vector holds per-row hex
return np.asarray(colors.to_rgba_array(list(self.color_vector)))
cv = self.color_vector
if isinstance(getattr(cv, "dtype", None), pd.CategoricalDtype):
# categories are hex (resolution fills NaN with an na_color category -> codes >= 0),
# so parse the few categories once and gather by code instead of parsing every row
lut = np.asarray(colors.to_rgba_array(cv.categories.to_numpy()))
return lut[cv.codes]
# object vector (align_to_length pad / uniform na): factorize the distinct colours and
# gather back (factorize keeps any NaN as a real code, so the gather stays in-bounds)
codes, uniq = pd.factorize(np.asarray(cv, dtype=object), sort=False, use_na_sentinel=False)
return np.asarray(colors.to_rgba_array(list(uniq)))[codes]
arr = np.asarray(self.color_vector)
if arr.ndim == 2 and arr.shape[1] in (3, 4) and np.issubdtype(arr.dtype, np.number):
return np.asarray(colors.to_rgba_array(arr))
Expand Down
22 changes: 22 additions & 0 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,28 @@ def test_precomputed_rgba_passthrough(self):
arr = np.array([[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]])
np.testing.assert_allclose(ColorSpec("continuous", None, arr).to_rgba(self._params()), arr)

def test_codes_gather_matches_per_row_parse(self):
# the categorical codes-gather and the object factorize-gather must equal to_rgba_array(list(...))
from matplotlib import colors

from spatialdata_plot.pl._color import ColorSpec

hexes = ["#e41a1cff", "#377eb8ff", "#4daf4aff", "#984ea3ff"]
na = "#cccccc00"
rng = np.random.default_rng(0)
clean = pd.Categorical.from_codes(rng.integers(0, len(hexes), 2000), categories=hexes)
with_na = pd.Categorical.from_codes(rng.integers(0, len(hexes) + 1, 2000), categories=[*hexes, na])
variants = [
clean, # categorical fast-path
with_na, # NaN replaced with an na_color category -> codes still >= 0
clean[:800].remove_unused_categories(), # filtered -> remapped codes
np.full(500, na, dtype=object), # uniform na (object vector)
np.concatenate([np.asarray(list(clean[:300]), dtype=object), np.full(100, na, dtype=object)]), # padded
]
for cv in variants: # cv doubles as the (only-checked-for-not-None) source_vector
spec = ColorSpec("categorical" if hasattr(cv, "codes") else "none", cv, cv)
np.testing.assert_array_equal(spec.to_rgba(self._params()), np.asarray(colors.to_rgba_array(list(cv))))


class TestPercentileNormalize:
"""PercentileNormalize + _resolve_continuous_norm (issue #370: dim multichannel renders)."""
Expand Down
Loading