Skip to content
Merged
69 changes: 67 additions & 2 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3194,6 +3194,69 @@ def _prepare_transformation(
return trans, trans_data


def _apply_cmap_alpha_to_datashader_result(
result: Any,
agg: DataArray,
cmap: str | list[str] | Colormap,
span: list[float] | tuple[float, float] | None,
) -> Any:
"""Apply the colormap's alpha channel to a datashader RGBA result.

Datashader ignores the per-entry alpha channel of matplotlib colormaps,
so pixels that the cmap marks as transparent (alpha=0) are rendered
opaque. This function post-processes the shaded RGBA output to restore
the cmap's intended transparency. See :issue:`376`.
"""
if not isinstance(cmap, Colormap):
return result

# Quick check: does this cmap have any transparent entries?
test_vals = np.linspace(0, 1, min(cmap.N, 256))
cmap_alphas = cmap(test_vals)[:, 3]
if np.all(cmap_alphas >= 1.0):
return result

# Get or ensure we have an (H, W, 4) uint8 array
if hasattr(result, "values"):
# datashader Image — uint32 packed, convert via to_numpy()
rgba = result.to_numpy().base
if rgba is None:
return result
else:
rgba = result

if rgba.ndim != 3 or rgba.shape[2] != 4:
return result

# Normalise aggregate values to [0, 1] using the same span datashader used
agg_vals = agg.values.astype(np.float64)
valid = np.isfinite(agg_vals)
if not valid.any():
return result

if span is not None:
lo, hi = float(span[0]), float(span[1])
else:
lo = float(np.nanmin(agg_vals))
hi = float(np.nanmax(agg_vals))

if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi):
return result

normed = np.clip((agg_vals - lo) / (hi - lo), 0.0, 1.0)

# Look up cmap alpha for each pixel
desired_alpha = cmap(normed)[:, :, 3]

# Zero out pixels where the cmap wants transparency
transparent = valid & (desired_alpha < 1.0)
if transparent.any():
# Scale the existing alpha by the cmap's alpha
rgba[transparent, 3] = (rgba[transparent, 3].astype(np.float32) * desired_alpha[transparent]).astype(np.uint8)

return result


def _datashader_map_aggregate_to_color(
agg: DataArray,
cmap: str | list[str] | ListedColormap,
Expand Down Expand Up @@ -3245,16 +3308,18 @@ def _datashader_map_aggregate_to_color(
img_over = img_over.to_numpy().base
if img_over is not None:
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
return stack

return ds.tf.shade(
return _apply_cmap_alpha_to_datashader_result(stack, agg, cmap, span)

result = ds.tf.shade(
agg,
cmap=cmap,
color_key=color_key,
min_alpha=min_alpha,
span=span,
how="linear",
)
return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span)


def _hex_no_alpha(hex: str) -> str:
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
126 changes: 123 additions & 3 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@
import pandas as pd
import pytest
import scanpy as sc
import xarray as xr
from spatialdata import SpatialData

import spatialdata_plot
from spatialdata_plot.pl.utils import _get_subplots
from spatialdata_plot.pl.utils import (
_apply_cmap_alpha_to_datashader_result,
_datashader_map_aggregate_to_color,
_get_subplots,
set_zero_in_cmap_to_transparent,
)
from tests.conftest import DPI, PlotTester, PlotTesterMeta

sc.pl.set_rcParams_defaults()
Expand Down Expand Up @@ -52,8 +58,6 @@ def test_plot_colnames_that_are_valid_matplotlib_greyscale_colors_are_not_evalua
sdata_blobs.pl.render_shapes("blobs_polygons", color=colname).pl.show()

def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData):
from spatialdata_plot.pl.utils import set_zero_in_cmap_to_transparent

# set up figure and modify the data to add 0s
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
sdata_blobs.tables["table"].obs["my_var"] = list(range(len(sdata_blobs.tables["table"].obs)))
Expand All @@ -73,6 +77,49 @@ def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData
ax=axs[1], colorbar=False
)

def _render_transparent_cmap_shapes(self, sdata_blobs: SpatialData, method: str):
_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis")
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [0.0, 2.0, 3.0, 4.0, 5.0]

# left: baseline with standard viridis
sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
"blobs_polygons", color="value", cmap="viridis", method=method
).pl.show(ax=axs[0], colorbar=False)

# right: transparent cmap — shape with value=0 should reveal the image
sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
"blobs_polygons", color="value", cmap=new_cmap, method=method
).pl.show(ax=axs[1], colorbar=False)

def test_plot_transparent_cmap_shapes_matplotlib(self, sdata_blobs: SpatialData):
self._render_transparent_cmap_shapes(sdata_blobs, method="matplotlib")

def test_plot_transparent_cmap_shapes_datashader(self, sdata_blobs: SpatialData):
self._render_transparent_cmap_shapes(sdata_blobs, method="datashader")

def test_plot_transparent_cmap_shapes_clip_false(self, sdata_blobs: SpatialData):
"""Transparent cmap with clip=False norm (3-part shading path)."""
from matplotlib.colors import Normalize

_, axs = plt.subplots(nrows=1, ncols=2, layout="tight")
new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis")
norm = Normalize(vmin=0, vmax=5, clip=False)

sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs)
sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons"
sdata_blobs.shapes["blobs_polygons"]["value"] = [0.0, 2.0, 3.0, 4.0, 5.0]

sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
"blobs_polygons", color="value", cmap="viridis", norm=norm, method="datashader"
).pl.show(ax=axs[0], colorbar=False)

sdata_blobs.pl.render_images("blobs_image").pl.render_shapes(
"blobs_polygons", color="value", cmap=new_cmap, norm=norm, method="datashader"
).pl.show(ax=axs[1], colorbar=False)


@pytest.mark.parametrize(
"color_result",
Expand All @@ -90,6 +137,79 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]):
assert spatialdata_plot.pl.utils._is_color_like(color) == result


class TestCmapAlphaDatashader:
"""Regression tests for #376: set_zero_in_cmap_to_transparent with datashader."""

def test_transparent_pixels_get_alpha_zero(self):
"""Post-processing sets alpha=0 for pixels mapping to transparent cmap entries."""
import datashader as ds

cmap = set_zero_in_cmap_to_transparent("viridis")
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
agg = xr.DataArray(data, dims=["y", "x"])

shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear")
result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0])
rgba = result.to_numpy().base if hasattr(result, "to_numpy") else result

assert rgba[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {rgba[0, 0, 3]}"
assert rgba[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0"
assert rgba[0, 2, 3] > 0, "Expected non-zero alpha at value=10.0"

def test_opaque_cmap_unchanged(self):
"""Post-processing is a no-op for fully opaque cmaps."""
import datashader as ds

cmap = plt.get_cmap("viridis")
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
agg = xr.DataArray(data, dims=["y", "x"])

shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear")
rgba_before = shaded.to_numpy().base.copy()
result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0])
rgba_after = result.to_numpy().base if hasattr(result, "to_numpy") else result
np.testing.assert_array_equal(rgba_before, rgba_after)

def test_string_cmap_passthrough(self):
"""Post-processing is a no-op for string cmaps (early return)."""
dummy_rgba = np.zeros((2, 3, 4), dtype=np.uint8)
dummy_rgba[:, :, 3] = 200
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
agg = xr.DataArray(data, dims=["y", "x"])

result = _apply_cmap_alpha_to_datashader_result(dummy_rgba, agg, "viridis", span=[0.0, 10.0])
np.testing.assert_array_equal(result, dummy_rgba)

def test_end_to_end_datashader_map(self):
"""_datashader_map_aggregate_to_color produces alpha=0 for transparent cmap entries."""
cmap = set_zero_in_cmap_to_transparent("viridis")
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
agg = xr.DataArray(data, dims=["y", "x"])

result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254, span=[0.0, 10.0])
img = result.to_numpy().base if hasattr(result, "to_numpy") else result

assert img[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {img[0, 0, 3]}"
assert img[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0"

def test_span_none_preserves_colors(self):
"""With span=None, non-transparent shapes keep their correct colors."""
cmap = set_zero_in_cmap_to_transparent("viridis")
data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64)
agg = xr.DataArray(data, dims=["y", "x"])

result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254)
img = result.to_numpy().base if hasattr(result, "to_numpy") else result

# value=0 should be transparent
assert img[0, 0, 3] == 0
# value=5 and value=10 should be opaque with correct viridis colors (not white)
assert img[0, 1, 3] > 0
assert img[0, 2, 3] > 0
# The non-transparent pixels should NOT be white (R=255,G=255,B=255)
assert not (img[0, 1, 0] == 255 and img[0, 1, 1] == 255 and img[0, 1, 2] == 255)


def test_extract_scalar_value():
"""Test the new _extract_scalar_value function for robust numeric conversion."""

Expand Down
Loading