diff --git a/README.md b/README.md index 2a61cb2f..50c60297 100644 --- a/README.md +++ b/README.md @@ -486,6 +486,7 @@ Same-CRS tiles skip reprojection entirely and are placed by direct coordinate al | [Gradient](xrspatial/morphology.py) | Dilation minus erosion (edge detection) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | | [White Top-hat](xrspatial/morphology.py) | Original minus opening (isolate bright features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | | [Black Top-hat](xrspatial/morphology.py) | Closing minus original (isolate dark features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Sieve](xrspatial/sieve.py) | Remove small connected clumps from classified rasters | GDAL sieve | ✅️ | ✅️ | 🔄 | 🔄 | ------- diff --git a/docs/source/reference/zonal.rst b/docs/source/reference/zonal.rst index c151f5ec..7d20706d 100644 --- a/docs/source/reference/zonal.rst +++ b/docs/source/reference/zonal.rst @@ -30,6 +30,13 @@ Regions xrspatial.zonal.regions +Sieve +===== +.. autosummary:: + :toctree: _autosummary + + xrspatial.sieve.sieve + Trim ==== .. autosummary:: diff --git a/examples/user_guide/48_Sieve_Filter.ipynb b/examples/user_guide/48_Sieve_Filter.ipynb new file mode 100644 index 00000000..d651cda6 --- /dev/null +++ b/examples/user_guide/48_Sieve_Filter.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sieve Filter: Removing Small Raster Clumps\n", + "\n", + "Classification outputs often contain salt-and-pepper noise: tiny clumps of 1-3 pixels that don't represent real features. The `sieve` function removes these by replacing connected regions smaller than a given threshold with the value of their largest spatial neighbor.\n", + "\n", + "This is the xarray-spatial equivalent of GDAL's `gdal_sieve.py`, and it pairs naturally with classification functions like `natural_breaks()` or `reclassify()` and with `polygonize()` for cleaning results before vectorization." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.colors import ListedColormap\n", + "\n", + "from xrspatial.sieve import sieve\n", + "from xrspatial.classify import natural_breaks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generate a Noisy Classified Raster\n", + "\n", + "We'll create a synthetic classified raster with three land-cover classes and scatter some salt-and-pepper noise across it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(42)\n", + "rows, cols = 80, 100\n", + "\n", + "# Build a base classification with three broad zones\n", + "base = np.ones((rows, cols), dtype=np.float64)\n", + "base[:, 40:70] = 2.0\n", + "base[30:60, :] = 3.0\n", + "base[30:60, 40:70] = 2.0\n", + "\n", + "# Add salt-and-pepper noise: randomly flip ~8% of pixels\n", + "noise_mask = np.random.random((rows, cols)) < 0.08\n", + "noise_vals = np.random.choice([1.0, 2.0, 3.0], size=(rows, cols))\n", + "noisy = base.copy()\n", + "noisy[noise_mask] = noise_vals[noise_mask]\n", + "\n", + "# Sprinkle some NaN (nodata) pixels\n", + "noisy[0:3, 0:3] = np.nan\n", + "noisy[77:, 97:] = np.nan\n", + "\n", + "raster = xr.DataArray(noisy, dims=['y', 'x'], name='landcover')\n", + "print(f'Raster shape: {raster.shape}')\n", + "print(f'Unique values (excl. 
NaN): {np.unique(raster.values[~np.isnan(raster.values)])}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cmap = ListedColormap(['#2ecc71', '#3498db', '#e74c3c'])\n", + "\n", + "fig, ax = plt.subplots(figsize=(8, 5))\n", + "im = ax.imshow(raster.values, cmap=cmap, vmin=0.5, vmax=3.5, interpolation='nearest')\n", + "ax.set_title('Noisy classified raster')\n", + "cbar = fig.colorbar(im, ax=ax, ticks=[1, 2, 3])\n", + "cbar.ax.set_yticklabels(['Class 1', 'Class 2', 'Class 3'])\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Sieve: Remove Single-Pixel Noise\n", + "\n", + "The simplest use case: set a threshold so isolated pixels are absorbed by their surroundings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sieved = sieve(raster, threshold=4)\n", + "\n", + "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", + "for ax, data, title in zip(axes, [raster, sieved], ['Before sieve', 'After sieve (threshold=4)']):\n", + " im = ax.imshow(data.values, cmap=cmap, vmin=0.5, vmax=3.5, interpolation='nearest')\n", + " ax.set_title(title)\n", + "fig.colorbar(im, ax=axes, ticks=[1, 2, 3], shrink=0.8)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connectivity: 4 vs 8\n", + "\n", + "With 4-connectivity (rook), only pixels sharing an edge are considered connected. With 8-connectivity (queen), diagonally adjacent pixels also form part of the same region. 
This affects which clumps are identified as \"small.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sieved_4 = sieve(raster, threshold=6, neighborhood=4)\n", + "sieved_8 = sieve(raster, threshold=6, neighborhood=8)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", + "for ax, data, title in zip(\n", + " axes,\n", + " [raster, sieved_4, sieved_8],\n", + " ['Original', '4-connectivity (threshold=6)', '8-connectivity (threshold=6)'],\n", + "):\n", + " im = ax.imshow(data.values, cmap=cmap, vmin=0.5, vmax=3.5, interpolation='nearest')\n", + " ax.set_title(title)\n", + "fig.colorbar(im, ax=axes, ticks=[1, 2, 3], shrink=0.8)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Selective Sieving with `skip_values`\n", + "\n", + "Sometimes certain class values should never be removed, even if their regions are small. Use `skip_values` to protect specific categories from merging while still allowing other small regions to be cleaned up." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Protect class 3 from sieving\n", + "sieved_skip = sieve(raster, threshold=10, skip_values=[3.0])\n", + "sieved_noskip = sieve(raster, threshold=10)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", + "for ax, data, title in zip(\n", + " axes,\n", + " [raster, sieved_noskip, sieved_skip],\n", + " ['Original', 'threshold=10 (no skip)', 'threshold=10 (skip class 3)'],\n", + "):\n", + " im = ax.imshow(data.values, cmap=cmap, vmin=0.5, vmax=3.5, interpolation='nearest')\n", + " ax.set_title(title)\n", + "fig.colorbar(im, ax=axes, ticks=[1, 2, 3], shrink=0.8)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Practical Example: Clean Up a Classification\n", + "\n", + "Generate a continuous surface, classify it with `natural_breaks`, and then sieve the result to remove small artifacts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a smooth surface with some high-frequency variation\n", + "y = np.linspace(0, 4 * np.pi, rows)\n", + "x = np.linspace(0, 4 * np.pi, cols)\n", + "Y, X = np.meshgrid(y, x, indexing='ij')\n", + "surface = np.sin(Y) * np.cos(X) + 0.4 * np.random.randn(rows, cols)\n", + "\n", + "surface_da = xr.DataArray(surface, dims=['y', 'x'])\n", + "classified = natural_breaks(surface_da, k=5)\n", + "\n", + "# Sieve the classification\n", + "cleaned = sieve(classified, threshold=8)\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", + "axes[0].imshow(surface, cmap='terrain', interpolation='nearest')\n", + "axes[0].set_title('Continuous surface')\n", + "axes[1].imshow(classified.values, cmap='tab10', interpolation='nearest')\n", + "axes[1].set_title('natural_breaks (k=5)')\n", + "axes[2].imshow(cleaned.values, cmap='tab10', interpolation='nearest')\n", + 
"axes[2].set_title('After sieve (threshold=8)')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Threshold Selection\n", + "\n", + "The right threshold depends on pixel resolution and the minimum feature size you care about. Here's a comparison across threshold values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "thresholds = [2, 5, 15, 50]\n", + "fig, axes = plt.subplots(1, len(thresholds), figsize=(5 * len(thresholds), 5))\n", + "\n", + "for ax, t in zip(axes, thresholds):\n", + " result = sieve(classified, threshold=t)\n", + " ax.imshow(result.values, cmap='tab10', interpolation='nearest')\n", + " ax.set_title(f'threshold={t}')\n", + "\n", + "plt.suptitle('Effect of sieve threshold on classified raster', y=1.02)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbformat_minor": 4, + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/xrspatial/__init__.py b/xrspatial/__init__.py index 8bd732a6..fd2715c2 100644 --- a/xrspatial/__init__.py +++ b/xrspatial/__init__.py @@ -104,6 +104,7 @@ from xrspatial.hydro import stream_link_d8, stream_link_dinf, stream_link_mfd # noqa from xrspatial.hydro import stream_order # noqa: unified wrapper from xrspatial.hydro import stream_order_d8, stream_order_dinf, stream_order_mfd # noqa +from xrspatial.sieve import sieve # noqa from xrspatial.sky_view_factor import sky_view_factor # noqa from xrspatial.slope import slope # noqa from xrspatial.surface_distance import surface_allocation # noqa diff --git a/xrspatial/sieve.py b/xrspatial/sieve.py new file mode 100644 
index 00000000..b959312d --- /dev/null +++ b/xrspatial/sieve.py @@ -0,0 +1,373 @@ +"""Sieve filter for removing small raster clumps. + +Given a categorical raster and a pixel-count threshold, replaces +connected regions smaller than the threshold with the value of +their largest spatial neighbor. Pairs with classification functions +(``natural_breaks``, ``reclassify``, etc.) and ``polygonize`` for +cleaning results before vectorization. + +Supports all four backends: numpy, cupy, dask+numpy, dask+cupy. +""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Sequence + +import numpy as np +import xarray as xr +from xarray import DataArray + +try: + import cupy +except ImportError: + + class cupy: + ndarray = False + + +try: + import dask.array as da +except ImportError: + da = None + +from xrspatial.utils import ( + _validate_raster, + has_cuda_and_cupy, + is_cupy_array, + is_dask_cupy, +) + + +# --------------------------------------------------------------------------- +# Adjacency helpers +# --------------------------------------------------------------------------- + + +def _build_adjacency(region_map, neighborhood): + """Build a region adjacency dict from a labeled map using vectorized shifts. + + Returns ``{region_id: set_of_neighbor_ids}``. 
+ """ + adjacency: dict[int, set[int]] = defaultdict(set) + + def _add_pairs(a, b): + mask = (a > 0) & (b > 0) & (a != b) + if not mask.any(): + return + pairs = np.unique( + np.column_stack([a[mask].ravel(), b[mask].ravel()]), axis=0 + ) + for x, y in pairs: + adjacency[int(x)].add(int(y)) + adjacency[int(y)].add(int(x)) + + # 4-connected directions (rook) + _add_pairs(region_map[:-1, :], region_map[1:, :]) # vertical + _add_pairs(region_map[:, :-1], region_map[:, 1:]) # horizontal + + # 8-connected adds diagonals (queen) + if neighborhood == 8: + _add_pairs(region_map[:-1, :-1], region_map[1:, 1:]) # SE + _add_pairs(region_map[:-1, 1:], region_map[1:, :-1]) # SW + + return adjacency + + +# --------------------------------------------------------------------------- +# numpy backend +# --------------------------------------------------------------------------- + + +def _label_all_regions(result, valid, structure): + """Label connected components per unique value. + + Returns + ------- + region_map : ndarray of int32 + Each pixel mapped to its region id (0 = nodata). + region_val : ndarray of float64 + Original raster value for each region id. + n_total : int + Total number of regions + 1 (length of *region_val*). 
+ """ + from scipy.ndimage import label + + unique_vals = np.unique(result[valid]) + region_map = np.zeros(result.shape, dtype=np.int32) + region_val_list: list[float] = [np.nan] # id 0 = nodata + uid = 1 + + for v in unique_vals: + mask = (result == v) & valid + labeled, n_features = label(mask, structure=structure) + if n_features > 0: + nonzero = labeled > 0 + region_map[nonzero] = labeled[nonzero] + (uid - 1) + region_val_list.extend([float(v)] * n_features) + uid += n_features + + region_val = np.array(region_val_list, dtype=np.float64) + return region_map, region_val, uid + + +def _sieve_numpy(data, threshold, neighborhood, skip_values): + """Replace connected regions smaller than *threshold* pixels.""" + structure = ( + np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + if neighborhood == 4 + else np.ones((3, 3), dtype=int) + ) + + result = data.astype(np.float64, copy=True) + is_float = np.issubdtype(data.dtype, np.floating) + valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool) + skip_set = set(skip_values) if skip_values is not None else set() + + for _ in range(50): # convergence limit + region_map, region_val, uid = _label_all_regions( + result, valid, structure + ) + region_size = np.bincount( + region_map.ravel(), minlength=uid + ).astype(np.int64) + + # Identify small regions eligible for merging + small_ids = [ + rid + for rid in range(1, uid) + if region_size[rid] < threshold + and region_val[rid] not in skip_set + ] + if not small_ids: + break + + adjacency = _build_adjacency(region_map, neighborhood) + + # Process smallest regions first so they merge into larger neighbors + small_ids.sort(key=lambda r: region_size[r]) + + merged_any = False + for rid in small_ids: + if region_size[rid] == 0 or region_size[rid] >= threshold: + continue + + neighbors = adjacency.get(rid) + if not neighbors: + continue # surrounded by nodata only + + largest_nid = max(neighbors, key=lambda n: region_size[n]) + mask = region_map == rid + 
result[mask] = region_val[largest_nid] + + # Update tracking in place + region_map[mask] = largest_nid + region_size[largest_nid] += region_size[rid] + region_size[rid] = 0 + + for n in neighbors: + if n != largest_nid: + adjacency[n].discard(rid) + adjacency[n].add(largest_nid) + adjacency.setdefault(largest_nid, set()).add(n) + if largest_nid in adjacency: + adjacency[largest_nid].discard(rid) + del adjacency[rid] + merged_any = True + + if not merged_any: + break + + return result + + +# --------------------------------------------------------------------------- +# cupy backend (CPU fallback – merge logic is serial) +# --------------------------------------------------------------------------- + + +def _sieve_cupy(data, threshold, neighborhood, skip_values): + """CuPy backend: transfer to CPU, sieve, transfer back.""" + import cupy as cp + + np_result = _sieve_numpy(data.get(), threshold, neighborhood, skip_values) + return cp.asarray(np_result) + + +# --------------------------------------------------------------------------- +# dask backends +# --------------------------------------------------------------------------- + + +def _available_memory_bytes(): + """Best-effort estimate of available memory in bytes.""" + try: + with open("/proc/meminfo", "r") as f: + for line in f: + if line.startswith("MemAvailable:"): + return int(line.split()[1]) * 1024 + except (OSError, ValueError, IndexError): + pass + try: + import psutil + + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + return 2 * 1024**3 + + +def _sieve_dask(data, threshold, neighborhood, skip_values): + """Dask+numpy backend: compute to numpy, sieve, wrap back.""" + avail = _available_memory_bytes() + estimated_bytes = np.prod(data.shape) * data.dtype.itemsize + if estimated_bytes * 5 > 0.5 * avail: + raise MemoryError( + f"sieve() needs the full array in memory " + f"(~{estimated_bytes * 5 / 1e9:.1f} GB) but only " + f"~{avail / 1e9:.1f} GB is available. 
Connected-component " + f"labeling is a global operation that cannot be chunked. " + f"Consider downsampling or tiling the input manually." + ) + + np_data = data.compute() + result = _sieve_numpy(np_data, threshold, neighborhood, skip_values) + return da.from_array(result, chunks=data.chunks) + + +def _sieve_dask_cupy(data, threshold, neighborhood, skip_values): + """Dask+CuPy backend: compute to cupy, sieve via CPU fallback, wrap back.""" + estimated_bytes = np.prod(data.shape) * data.dtype.itemsize + try: + import cupy as cp + + free_gpu, _total = cp.cuda.Device().mem_info + if estimated_bytes * 5 > 0.5 * free_gpu: + raise MemoryError( + f"sieve() needs the full array on GPU " + f"(~{estimated_bytes * 5 / 1e9:.1f} GB) but only " + f"~{free_gpu / 1e9:.1f} GB free. Connected-component " + f"labeling is a global operation that cannot be chunked. " + f"Consider downsampling or tiling the input manually." + ) + except (ImportError, AttributeError): + pass + + cp_data = data.compute() + result = _sieve_cupy(cp_data, threshold, neighborhood, skip_values) + return da.from_array(result, chunks=data.chunks) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def sieve( + raster: xr.DataArray, + threshold: int = 10, + neighborhood: int = 4, + skip_values: Sequence[float] | None = None, + name: str = "sieve", +) -> xr.DataArray: + """Remove small connected regions from a classified raster. + + Identifies connected components of same-value pixels and replaces + regions smaller than *threshold* pixels with the value of their + largest spatial neighbor. NaN pixels are always preserved. + + Parameters + ---------- + raster : xr.DataArray + 2D classified or categorical raster. + threshold : int, default=10 + Minimum region size in pixels. Regions with fewer pixels + are replaced by their largest neighbor's value. 
+ neighborhood : int, default=4 + Pixel connectivity: 4 (rook) or 8 (queen). + skip_values : sequence of float, optional + Category values whose regions are never replaced, regardless + of size. These regions can still serve as merge targets for + neighboring small regions. + name : str, default='sieve' + Output DataArray name. + + Returns + ------- + xr.DataArray + Sieved raster with the same shape, dims, coords, and attrs. + + Examples + -------- + .. sourcecode:: python + + >>> import numpy as np + >>> import xarray as xr + >>> from xrspatial.sieve import sieve + + >>> # Classified raster with salt-and-pepper noise + >>> arr = np.array([[1, 1, 1, 2, 2], + ... [1, 3, 1, 2, 2], + ... [1, 1, 1, 2, 2], + ... [2, 2, 2, 2, 2], + ... [2, 2, 2, 2, 2]], dtype=np.float64) + >>> raster = xr.DataArray(arr, dims=['y', 'x']) + + >>> # Remove regions smaller than 2 pixels + >>> result = sieve(raster, threshold=2) + >>> print(result.values) + [[1. 1. 1. 2. 2.] + [1. 1. 1. 2. 2.] + [1. 1. 1. 2. 2.] + [2. 2. 2. 2. 2.] + [2. 2. 2. 2. 2.]] + + Notes + ----- + This is a global operation: for dask-backed arrays the entire raster + is computed into memory before sieving. Connected-component labeling + cannot be performed on individual chunks because regions may span + chunk boundaries. + + The CuPy backends use a CPU fallback for the merge step, which is + inherently serial. + + See Also + -------- + xrspatial.zonal.regions : Connected-component labeling. + xrspatial.classify.natural_breaks : Classification that may produce + noisy output suitable for sieving. 
+ """ + _validate_raster(raster, func_name="sieve", name="raster", ndim=2) + + if neighborhood not in (4, 8): + raise ValueError("`neighborhood` must be 4 or 8") + + if not isinstance(threshold, (int, np.integer)) or threshold < 1: + raise ValueError("`threshold` must be a positive integer") + + data = raster.data + + if isinstance(data, np.ndarray): + out = _sieve_numpy(data, threshold, neighborhood, skip_values) + elif has_cuda_and_cupy() and is_cupy_array(data): + out = _sieve_cupy(data, threshold, neighborhood, skip_values) + elif da is not None and isinstance(data, da.Array): + if is_dask_cupy(raster): + out = _sieve_dask_cupy( + data, threshold, neighborhood, skip_values + ) + else: + out = _sieve_dask(data, threshold, neighborhood, skip_values) + else: + raise TypeError( + f"Unsupported array type {type(data).__name__} for sieve()" + ) + + return DataArray( + out, + name=name, + dims=raster.dims, + coords=raster.coords, + attrs=raster.attrs, + ) diff --git a/xrspatial/tests/test_sieve.py b/xrspatial/tests/test_sieve.py new file mode 100644 index 00000000..4e576fee --- /dev/null +++ b/xrspatial/tests/test_sieve.py @@ -0,0 +1,424 @@ +"""Tests for xrspatial.sieve.""" + +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr + +try: + import dask.array as da +except ImportError: + da = None + +from xrspatial.sieve import sieve +from xrspatial.tests.general_checks import create_test_raster, general_output_checks + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_raster(arr, backend): + return create_test_raster(arr, backend) + + +def _to_numpy(result): + data = result.data + if da is not None and isinstance(data, da.Array): + data = data.compute() + return np.asarray(data) + + +# --------------------------------------------------------------------------- +# Basic correctness +# 
--------------------------------------------------------------------------- + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_removes_single_pixel_noise(backend): + """A single pixel of class 3 surrounded by class 1 should be replaced.""" + arr = np.array( + [ + [1, 1, 1, 1, 1], + [1, 3, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + ], + dtype=np.float64, + ) + expected = np.ones((5, 5), dtype=np.float64) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=2) + general_output_checks(raster, result, expected) + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_preserves_large_regions(backend): + """Regions at or above threshold stay unchanged.""" + arr = np.array( + [ + [1, 1, 2, 2], + [1, 1, 2, 2], + [3, 3, 4, 4], + [3, 3, 4, 4], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=4) + general_output_checks(raster, result, arr) + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_merges_into_largest_neighbor(backend): + """A small region should merge into its largest neighbor, not the nearest.""" + # Class 2 occupies 2 pixels, surrounded by class 1 (large) and class 3 (medium). + # The 2 pixels are adjacent to both 1 and 3, but 1 is larger. 
+    arr = np.array(
+        [
+            [1, 1, 1, 1],
+            [1, 2, 2, 3],
+            [1, 1, 1, 3],
+            [1, 1, 1, 3],
+        ],
+        dtype=np.float64,
+    )
+    raster = _make_raster(arr, backend)
+    result = sieve(raster, threshold=3)
+    data = _to_numpy(result)
+    # The 2-pixel region of class 2 should become class 1 (11 pixels > 3 pixels of class 3)
+    assert data[1, 1] == 1.0
+    assert data[1, 2] == 1.0
+
+
+# ---------------------------------------------------------------------------
+# Connectivity
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
+def test_sieve_four_connectivity(backend):
+    """With 4-connectivity, diagonal pixels are separate regions."""
+    arr = np.array(
+        [
+            [1, 2, 1],
+            [2, 1, 2],
+            [1, 2, 1],
+        ],
+        dtype=np.float64,
+    )
+    raster = _make_raster(arr, backend)
+    # With 4-connectivity, no two same-valued pixels share an edge, so
+    # every pixel forms its own 1-pixel region. All regions are size 1.
+    # threshold=2 should merge them all.
+    result = sieve(raster, threshold=2, neighborhood=4)
+    data = _to_numpy(result)
+    # All pixels should end up the same value (merged into one)
+    assert len(np.unique(data)) == 1
+
+
+@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
+def test_sieve_eight_connectivity(backend):
+    """With 8-connectivity, diagonal pixels form connected regions."""
+    arr = np.array(
+        [
+            [1, 2, 1],
+            [2, 1, 2],
+            [1, 2, 1],
+        ],
+        dtype=np.float64,
+    )
+    raster = _make_raster(arr, backend)
+    # With 8-connectivity: all 1s are diagonally connected (5 pixels),
+    # all 2s are diagonally connected (4 pixels).
+    # threshold=5 means only the 2-region (4 pixels) gets sieved.
+ result = sieve(raster, threshold=5, neighborhood=8) + data = _to_numpy(result) + assert np.all(data == 1.0) + + +# --------------------------------------------------------------------------- +# skip_values +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_skip_values(backend): + """Regions of skipped values are never replaced, even if small.""" + arr = np.array( + [ + [1, 1, 1, 1], + [1, 2, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, backend) + # Class 2 is 1 pixel, below threshold=5, but skip_values=[2] + result = sieve(raster, threshold=5, skip_values=[2.0]) + data = _to_numpy(result) + assert data[1, 1] == 2.0 # Preserved because of skip_values + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_skip_values_still_serve_as_merge_target(backend): + """A skip-value region can absorb a neighboring small region.""" + arr = np.array( + [ + [2, 2, 2, 2], + [2, 3, 2, 2], + [2, 2, 2, 2], + [2, 2, 2, 2], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, backend) + # Class 2 is skipped, class 3 is 1 pixel. 3 should merge into 2. 
+ result = sieve(raster, threshold=2, skip_values=[2.0]) + data = _to_numpy(result) + assert data[1, 1] == 2.0 + + +# --------------------------------------------------------------------------- +# NaN handling +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_preserves_nan(backend): + """NaN pixels remain NaN after sieving.""" + arr = np.array( + [ + [1, 1, np.nan], + [1, 1, np.nan], + [1, 1, np.nan], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=2) + data = _to_numpy(result) + assert np.all(np.isnan(data[:, 2])) + assert np.all(data[:, :2] == 1.0) + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_all_nan(backend): + """All-NaN input returns all-NaN output.""" + arr = np.full((3, 3), np.nan, dtype=np.float64) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=2) + data = _to_numpy(result) + assert np.all(np.isnan(data)) + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_small_region_adjacent_only_to_nan(backend): + """A small region with no valid neighbors stays unchanged.""" + arr = np.array( + [ + [np.nan, np.nan, np.nan], + [np.nan, 5.0, np.nan], + [np.nan, np.nan, np.nan], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=10) + data = _to_numpy(result) + assert data[1, 1] == 5.0 + assert np.isnan(data[0, 0]) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_no_small_regions(backend): + """When all regions are above threshold, output equals input.""" + arr = np.array( + [ + [1, 1, 2, 2], + [1, 1, 2, 2], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, 
backend)
+    result = sieve(raster, threshold=2)
+    general_output_checks(raster, result, arr)
+
+
+@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
+def test_sieve_threshold_one(backend):
+    """threshold=1 means no region is below threshold; nothing changes."""
+    arr = np.array(
+        [
+            [1, 2, 1],
+            [2, 1, 2],
+            [1, 2, 1],
+        ],
+        dtype=np.float64,
+    )
+    raster = _make_raster(arr, backend)
+    result = sieve(raster, threshold=1)
+    general_output_checks(raster, result, arr)
+
+
+@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
+def test_sieve_single_value_raster(backend):
+    """Raster with one value everywhere should be unchanged."""
+    arr = np.full((4, 4), 7.0, dtype=np.float64)
+    raster = _make_raster(arr, backend)
+    result = sieve(raster, threshold=100)
+    general_output_checks(raster, result, arr)
+
+
+@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
+def test_sieve_integer_input(backend):
+    """Integer-typed rasters should work (output is float64)."""
+    arr = np.array(
+        [
+            [1, 1, 1],
+            [1, 2, 1],
+            [1, 1, 1],
+        ],
+        dtype=np.int64,
+    )
+    raster = _make_raster(arr, backend)
+    result = sieve(raster, threshold=2)
+    data = _to_numpy(result)
+    # The single pixel of 2 should be merged into 1
+    assert data[1, 1] == 1.0
+
+
+@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"])
+def test_sieve_cascading_merge(backend):
+    """Multiple small regions of different sizes are all removed."""
+    # With threshold=4, region A (1 pixel, val=3) and region B (2 pixels,
+    # val=2) each merge into the large class-1 region (smallest first).
+ arr = np.array( + [ + [1, 1, 1, 1], + [1, 2, 2, 1], + [1, 3, 1, 1], + [1, 1, 1, 1], + ], + dtype=np.float64, + ) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=4) + data = _to_numpy(result) + # All small regions should ultimately merge into class 1 + expected = np.ones((4, 4), dtype=np.float64) + np.testing.assert_array_equal(data, expected) + + +# --------------------------------------------------------------------------- +# Output properties +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_preserves_attrs(backend): + """Output should preserve input dims, coords, and attrs.""" + arr = np.ones((4, 4), dtype=np.float64) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=2) + general_output_checks(raster, result, verify_attrs=True) + + +@pytest.mark.parametrize("backend", ["numpy", "dask+numpy"]) +def test_sieve_custom_name(backend): + arr = np.ones((3, 3), dtype=np.float64) + raster = _make_raster(arr, backend) + result = sieve(raster, threshold=2, name="filtered") + assert result.name == "filtered" + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + + +def test_sieve_rejects_3d(): + arr = np.ones((3, 3, 3), dtype=np.float64) + raster = xr.DataArray(arr, dims=["z", "y", "x"]) + with pytest.raises(ValueError, match="2D"): + sieve(raster, threshold=2) + + +def test_sieve_rejects_bad_neighborhood(): + arr = np.ones((3, 3), dtype=np.float64) + raster = xr.DataArray(arr, dims=["y", "x"]) + with pytest.raises(ValueError, match="neighborhood"): + sieve(raster, threshold=2, neighborhood=6) + + +def test_sieve_rejects_bad_threshold(): + arr = np.ones((3, 3), dtype=np.float64) + raster = xr.DataArray(arr, dims=["y", "x"]) + with pytest.raises(ValueError, match="threshold"): + sieve(raster, 
threshold=0) + with pytest.raises(ValueError, match="threshold"): + sieve(raster, threshold=-1) + + +def test_sieve_rejects_ndarray(): + with pytest.raises(TypeError, match="xarray.DataArray"): + sieve(np.zeros((5, 5)), threshold=2) + + +# --------------------------------------------------------------------------- +# Dask memory guard +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(da is None, reason="dask not installed") +def test_sieve_dask_memory_guard(): + """Should raise MemoryError before .compute() on huge arrays.""" + from unittest.mock import patch + + from xrspatial.sieve import _sieve_dask + + huge = da.zeros( + (100_000, 100_000), chunks=(1000, 1000), dtype=np.float64 + ) + + with patch( + "xrspatial.sieve._available_memory_bytes", return_value=1 * 1024**3 + ): + with pytest.raises(MemoryError, match="global operation"): + _sieve_dask(huge, 10, 4, None) + + +# --------------------------------------------------------------------------- +# Numpy / dask consistency +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(da is None, reason="dask not installed") +def test_sieve_numpy_dask_match(): + """Numpy and dask backends should produce identical results.""" + arr = np.array( + [ + [1, 1, 2, 3, 3], + [1, 4, 2, 3, 5], + [1, 1, 2, 3, 3], + [6, 6, 2, 2, 2], + [6, 6, 6, 6, 6], + ], + dtype=np.float64, + ) + np_raster = _make_raster(arr, "numpy") + dk_raster = _make_raster(arr, "dask+numpy") + + np_result = _to_numpy(sieve(np_raster, threshold=3)) + dk_result = _to_numpy(sieve(dk_raster, threshold=3)) + + np.testing.assert_array_equal(np_result, dk_result)