diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 6068533f..117ffbba 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -3194,6 +3194,69 @@ def _prepare_transformation( return trans, trans_data +def _apply_cmap_alpha_to_datashader_result( + result: Any, + agg: DataArray, + cmap: str | list[str] | Colormap, + span: list[float] | tuple[float, float] | None, +) -> Any: + """Apply the colormap's alpha channel to a datashader RGBA result. + + Datashader ignores the per-entry alpha channel of matplotlib colormaps, + so pixels that the cmap marks as transparent (alpha=0) are rendered + opaque. This function post-processes the shaded RGBA output to restore + the cmap's intended transparency. See :issue:`376`. + """ + if not isinstance(cmap, Colormap): + return result + + # Quick check: does this cmap have any transparent entries? + test_vals = np.linspace(0, 1, min(cmap.N, 256)) + cmap_alphas = cmap(test_vals)[:, 3] + if np.all(cmap_alphas >= 1.0): + return result + + # Get or ensure we have an (H, W, 4) uint8 array + if hasattr(result, "values"): + # datashader Image — uint32 packed, convert via to_numpy() + rgba = result.to_numpy().base + if rgba is None: + return result + else: + rgba = result + + if rgba.ndim != 3 or rgba.shape[2] != 4: + return result + + # Normalise aggregate values to [0, 1] using the same span datashader used + agg_vals = agg.values.astype(np.float64) + valid = np.isfinite(agg_vals) + if not valid.any(): + return result + + if span is not None: + lo, hi = float(span[0]), float(span[1]) + else: + lo = float(np.nanmin(agg_vals)) + hi = float(np.nanmax(agg_vals)) + + if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi): + return result + + normed = np.clip((agg_vals - lo) / (hi - lo), 0.0, 1.0) + + # Look up cmap alpha for each pixel + desired_alpha = cmap(normed)[:, :, 3] + + # Zero out pixels where the cmap wants transparency + transparent = valid & (desired_alpha < 1.0) + if transparent.any(): + # Scale the existing alpha by the cmap's alpha + rgba[transparent, 3] = (rgba[transparent, 3].astype(np.float32) * desired_alpha[transparent]).astype(np.uint8) + + return result + + def _datashader_map_aggregate_to_color( agg: DataArray, cmap: str | list[str] | ListedColormap, @@ -3245,9 +3308,10 @@ def _datashader_map_aggregate_to_color( img_over = img_over.to_numpy().base if img_over is not None: stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] - return stack - return ds.tf.shade( + return _apply_cmap_alpha_to_datashader_result(stack, agg, cmap, span) + + result = ds.tf.shade( agg, cmap=cmap, color_key=color_key, @@ -3255,6 +3319,7 @@ def _datashader_map_aggregate_to_color( span=span, how="linear", ) + return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span) def _hex_no_alpha(hex: str) -> str: diff --git a/tests/_images/Utils_transparent_cmap_shapes_clip_false.png b/tests/_images/Utils_transparent_cmap_shapes_clip_false.png new file mode 100644 index 00000000..ea15e0ad Binary files /dev/null and b/tests/_images/Utils_transparent_cmap_shapes_clip_false.png differ diff --git a/tests/_images/Utils_transparent_cmap_shapes_datashader.png b/tests/_images/Utils_transparent_cmap_shapes_datashader.png new file mode 100644 index 00000000..ea15e0ad Binary files /dev/null and b/tests/_images/Utils_transparent_cmap_shapes_datashader.png differ diff --git a/tests/_images/Utils_transparent_cmap_shapes_matplotlib.png b/tests/_images/Utils_transparent_cmap_shapes_matplotlib.png new file mode 100644 index 00000000..8af4cd35 Binary files /dev/null and b/tests/_images/Utils_transparent_cmap_shapes_matplotlib.png differ diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index a9296d2e..42165333 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -4,10 +4,16 @@ import pandas as pd import pytest import scanpy as sc +import xarray as xr from spatialdata import SpatialData import spatialdata_plot -from spatialdata_plot.pl.utils import _get_subplots +from spatialdata_plot.pl.utils import ( + _apply_cmap_alpha_to_datashader_result, + _datashader_map_aggregate_to_color, + _get_subplots, + set_zero_in_cmap_to_transparent, +) from tests.conftest import DPI, PlotTester, PlotTesterMeta sc.pl.set_rcParams_defaults() @@ -52,8 +58,6 @@ def test_plot_colnames_that_are_valid_matplotlib_greyscale_colors_are_not_evalua sdata_blobs.pl.render_shapes("blobs_polygons", color=colname).pl.show() def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData): - from spatialdata_plot.pl.utils import set_zero_in_cmap_to_transparent - # set up figure and modify the data to add 0s _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") sdata_blobs.tables["table"].obs["my_var"] = list(range(len(sdata_blobs.tables["table"].obs))) @@ -73,6 +77,49 @@ def test_plot_can_set_zero_in_cmap_to_transparent(self, sdata_blobs: SpatialData ax=axs[1], colorbar=False ) + def _render_transparent_cmap_shapes(self, sdata_blobs: SpatialData, method: str): + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis") + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + sdata_blobs.shapes["blobs_polygons"]["value"] = [0.0, 2.0, 3.0, 4.0, 5.0] + + # left: baseline with standard viridis + sdata_blobs.pl.render_images("blobs_image").pl.render_shapes( + "blobs_polygons", color="value", cmap="viridis", method=method + ).pl.show(ax=axs[0], colorbar=False) + + # right: transparent cmap — shape with value=0 should reveal the image + sdata_blobs.pl.render_images("blobs_image").pl.render_shapes( + "blobs_polygons", color="value", cmap=new_cmap, method=method + ).pl.show(ax=axs[1], colorbar=False) + + def test_plot_transparent_cmap_shapes_matplotlib(self, sdata_blobs: SpatialData): + self._render_transparent_cmap_shapes(sdata_blobs, method="matplotlib") + + def test_plot_transparent_cmap_shapes_datashader(self, sdata_blobs: SpatialData): + self._render_transparent_cmap_shapes(sdata_blobs, method="datashader") + + def test_plot_transparent_cmap_shapes_clip_false(self, sdata_blobs: SpatialData): + """Transparent cmap with clip=False norm (3-part shading path).""" + from matplotlib.colors import Normalize + + _, axs = plt.subplots(nrows=1, ncols=2, layout="tight") + new_cmap = set_zero_in_cmap_to_transparent(cmap="viridis") + norm = Normalize(vmin=0, vmax=5, clip=False) + + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_polygons"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + sdata_blobs.shapes["blobs_polygons"]["value"] = [0.0, 2.0, 3.0, 4.0, 5.0] + + sdata_blobs.pl.render_images("blobs_image").pl.render_shapes( + "blobs_polygons", color="value", cmap="viridis", norm=norm, method="datashader" + ).pl.show(ax=axs[0], colorbar=False) + + sdata_blobs.pl.render_images("blobs_image").pl.render_shapes( + "blobs_polygons", color="value", cmap=new_cmap, norm=norm, method="datashader" + ).pl.show(ax=axs[1], colorbar=False) + @pytest.mark.parametrize( "color_result", @@ -90,6 +137,79 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]): assert spatialdata_plot.pl.utils._is_color_like(color) == result +class TestCmapAlphaDatashader: + """Regression tests for #376: set_zero_in_cmap_to_transparent with datashader.""" + + def test_transparent_pixels_get_alpha_zero(self): + """Post-processing sets alpha=0 for pixels mapping to transparent cmap entries.""" + import datashader as ds + + cmap = set_zero_in_cmap_to_transparent("viridis") + data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64) + agg = xr.DataArray(data, dims=["y", "x"]) + + shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear") + result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0]) + rgba = result.to_numpy().base if hasattr(result, "to_numpy") else result + + assert rgba[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {rgba[0, 0, 3]}" + assert rgba[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0" + assert rgba[0, 2, 3] > 0, "Expected non-zero alpha at value=10.0" + + def test_opaque_cmap_unchanged(self): + """Post-processing is a no-op for fully opaque cmaps.""" + import datashader as ds + + cmap = plt.get_cmap("viridis") + data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64) + agg = xr.DataArray(data, dims=["y", "x"]) + + shaded = ds.tf.shade(agg, cmap=cmap, min_alpha=254, how="linear") + rgba_before = shaded.to_numpy().base.copy() + result = _apply_cmap_alpha_to_datashader_result(shaded, agg, cmap, span=[0.0, 10.0]) + rgba_after = result.to_numpy().base if hasattr(result, "to_numpy") else result + np.testing.assert_array_equal(rgba_before, rgba_after) + + def test_string_cmap_passthrough(self): + """Post-processing is a no-op for string cmaps (early return).""" + dummy_rgba = np.zeros((2, 3, 4), dtype=np.uint8) + dummy_rgba[:, :, 3] = 200 + data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64) + agg = xr.DataArray(data, dims=["y", "x"]) + + result = _apply_cmap_alpha_to_datashader_result(dummy_rgba, agg, "viridis", span=[0.0, 10.0]) + np.testing.assert_array_equal(result, dummy_rgba) + + def test_end_to_end_datashader_map(self): + """_datashader_map_aggregate_to_color produces alpha=0 for transparent cmap entries.""" + cmap = set_zero_in_cmap_to_transparent("viridis") + data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64) + agg = xr.DataArray(data, dims=["y", "x"]) + + result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254, span=[0.0, 10.0]) + img = result.to_numpy().base if hasattr(result, "to_numpy") else result + + assert img[0, 0, 3] == 0, f"Expected alpha=0 at value=0.0, got {img[0, 0, 3]}" + assert img[0, 1, 3] > 0, "Expected non-zero alpha at value=5.0" + + def test_span_none_preserves_colors(self): + """With span=None, non-transparent shapes keep their correct colors.""" + cmap = set_zero_in_cmap_to_transparent("viridis") + data = np.array([[0.0, 5.0, 10.0]], dtype=np.float64) + agg = xr.DataArray(data, dims=["y", "x"]) + + result = _datashader_map_aggregate_to_color(agg, cmap=cmap, min_alpha=254) + img = result.to_numpy().base if hasattr(result, "to_numpy") else result + + # value=0 should be transparent + assert img[0, 0, 3] == 0 + # value=5 and value=10 should be opaque with correct viridis colors (not white) + assert img[0, 1, 3] > 0 + assert img[0, 2, 3] > 0 + # The non-transparent pixels should NOT be white (R=255,G=255,B=255) + assert not (img[0, 1, 0] == 255 and img[0, 1, 1] == 255 and img[0, 1, 2] == 255) + + def test_extract_scalar_value(): """Test the new _extract_scalar_value function for robust numeric conversion."""