# Source code for src.terrain.flow_accumulation

"""
Flow accumulation module for hydrological analysis.

Implements D8 flow routing algorithm with precipitation weighting.
Based on flow-spec.md requirements.

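Supports three backends (see the ``backend`` argument of ``flow_accumulation``):
- "spec": Spec-compliant conditioning pipeline (recommended, default)
- "legacy": Morphological reconstruction implementation (deprecated)
- "pysheds": Uses pysheds library for core hydrology (experimental)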
"""

from pathlib import Path
from typing import Dict, Tuple, Union, Optional, Literal
import datetime
import json
import os
import shutil
import numpy as np
import rasterio
from rasterio import Affine
from scipy import ndimage
from scipy.ndimage import grey_dilation, grey_erosion

# Try to import pysheds
try:
    from pysheds.grid import Grid as PyshedsGrid
    PYSHEDS_AVAILABLE = True
except ImportError:
    PYSHEDS_AVAILABLE = False
    PyshedsGrid = None

# Try to import numba for performance optimizations
try:
    from numba import jit, prange
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

    def jit(*args, **kwargs):
        """No-op stand-in for ``numba.jit`` used when numba is not installed.

        Supports both decorator spellings that ``numba.jit`` accepts:

        * ``@jit`` (bare) - called with the function itself, which is
          returned unchanged.
        * ``@jit(nopython=True, ...)`` - called with options only; returns a
          decorator that hands the function back unchanged.
        """
        # Bare usage passes the decorated function as the sole positional
        # argument. Without this check the function would be replaced by the
        # inner decorator and every call would return its first argument.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def decorator(func):
            return func

        return decorator

    # Without numba there is no parallel range; plain range has the same
    # iteration semantics.
    prange = range


# ==============================================================================
# D8 FLOW DIRECTION ENCODING (ESRI ArcGIS Convention)
# ==============================================================================
#
# This module uses ESRI's standard D8 power-of-2 encoding for compatibility
# with GIS workflows. Each direction is assigned a unique power of 2, allowing
# bitwise operations and multi-directional flow calculations.
#
# D8 Neighbor Geometry:
#   8  4  2
#  16  x  1
#  32 64 128
#
# Direction codes and their meanings:
#   1 = East (→)     : (0, +1), distance = 1
#   2 = Northeast ↗ : (-1, +1), distance = sqrt(2)
#   4 = North (↑)    : (-1, 0), distance = 1
#   8 = Northwest ↖ : (-1, -1), distance = sqrt(2)
#  16 = West (←)    : (0, -1), distance = 1
#  32 = Southwest ↙ : (+1, -1), distance = sqrt(2)
#  64 = South (↓)   : (+1, 0), distance = 1
# 128 = Southeast ↘ : (+1, +1), distance = sqrt(2)
#
# Reference: ArcGIS D8 Direction (https://pro.arcgis.com/...)
#           flow-spec.md, Section 1 (Data Structures)

# D8 flow direction encoding: (row_offset, col_offset) -> direction_code
# Each of the eight neighbours gets a distinct power of 2. Codes are assigned
# counter-clockwise starting at East, with row offset -1 labelled "North".
# NOTE(review): ArcGIS's published Flow Direction coding runs clockwise from
# East (2 = SE, 4 = S, 8 = SW, 32 = NW, 64 = N, 128 = NE), i.e. mirrored
# north/south relative to this table. Confirm which convention is intended
# before exchanging direction rasters with ESRI or pysheds tooling.
D8_DIRECTIONS = {
    (0, 1): 1,      # East
    (-1, 1): 2,     # Northeast (diagonal, distance = sqrt(2))
    (-1, 0): 4,     # North
    (-1, -1): 8,    # Northwest (diagonal, distance = sqrt(2))
    (0, -1): 16,    # West
    (1, -1): 32,    # Southwest (diagonal, distance = sqrt(2))
    (1, 0): 64,     # South
    (1, 1): 128,    # Southeast (diagonal, distance = sqrt(2))
}

# Reverse mapping for flow routing: direction_code -> (row_offset, col_offset)
# Built by inverting D8_DIRECTIONS, so the two tables cannot drift apart.
D8_OFFSETS = {v: k for k, v in D8_DIRECTIONS.items()}


# ============================================================================
# Flow Computation Caching
# ============================================================================


def _get_cache_key_params(
    dem_path: str,
    backend: str,
    max_cells: Optional[int],
    target_vertices: Optional[int],
    fill_method: str,
    mask_ocean: bool,
    ocean_elevation_threshold: float,
    coastal_elev_threshold: float,
    edge_mode: str,
    max_breach_depth: float,
    max_breach_length: int,
    epsilon: float,
) -> Dict:
    """
    Collect the inputs whose values decide whether a cached result is reusable.

    Every argument is stored under a key equal to its parameter name, in
    signature order. ``dem_path`` is coerced to ``str``; all other values are
    stored as given.

    Returns
    -------
    dict
        Mapping of parameter name to value, used as the cache key
    """
    key_params = dict(
        dem_path=str(dem_path),
        backend=backend,
        max_cells=max_cells,
        target_vertices=target_vertices,
        fill_method=fill_method,
        mask_ocean=mask_ocean,
        ocean_elevation_threshold=ocean_elevation_threshold,
        coastal_elev_threshold=coastal_elev_threshold,
        edge_mode=edge_mode,
        max_breach_depth=max_breach_depth,
        max_breach_length=max_breach_length,
        epsilon=epsilon,
    )
    return key_params


def _get_dem_mtime(dem_path: Path) -> float:
    """Return the DEM file's last-modification time as reported by the OS."""
    dem_stat = os.stat(dem_path)
    return dem_stat.st_mtime


def _validate_cache(
    cache_dir: Path,
    cache_params: Dict,
    dem_mtime: float,
) -> bool:
    """
    Check if valid cache exists.

    A cache is valid when the metadata file and all four cached rasters are
    present, the recorded DEM modification time equals ``dem_mtime``, and
    every entry of ``cache_params`` equals the value recorded in the
    metadata. Any unreadable, undecodable or structurally unexpected
    metadata file is treated as a cache miss rather than an error.

    Parameters
    ----------
    cache_dir : Path
        Directory containing cached files
    cache_params : dict
        Current computation parameters
    dem_mtime : float
        Modification time of DEM file

    Returns
    -------
    bool
        True if cache is valid, False otherwise
    """
    metadata_file = cache_dir / "flow_cache_metadata.json"
    if not metadata_file.exists():
        return False

    # Check all required output files exist
    required_files = (
        "flow_direction.tif",
        "flow_accumulation_area.tif",
        "flow_accumulation_rainfall.tif",
        "dem_conditioned.tif",
    )
    if not all((cache_dir / filename).exists() for filename in required_files):
        return False

    # Load metadata. ValueError covers both json.JSONDecodeError and
    # UnicodeDecodeError (corrupt bytes); OSError covers unreadable files.
    try:
        with open(metadata_file, encoding="utf-8") as f:
            cached_metadata = json.load(f)
    except (OSError, ValueError):
        return False

    # Valid JSON is not necessarily an object (e.g. "[]" or "null"); calling
    # .get() on anything else would raise instead of reporting a miss.
    if not isinstance(cached_metadata, dict):
        return False

    # Check DEM modification time
    cached_dem_mtime = cached_metadata.get("dem_mtime")
    if cached_dem_mtime is None or cached_dem_mtime != dem_mtime:
        return False

    # Check all cache key parameters match
    cached_params = cached_metadata.get("cache_params", {})
    if not isinstance(cached_params, dict):
        return False

    return not any(
        cached_params.get(key) != value for key, value in cache_params.items()
    )


def _load_from_cache(cache_dir: Path) -> Dict:
    """
    Read a previously cached flow computation back into memory.

    The metadata JSON is read first, then band 1 of each of the four cached
    GeoTIFFs. The returned metadata is the stored ``computation_metadata``
    entry with ``cache_hit`` set to True.

    Parameters
    ----------
    cache_dir : Path
        Directory containing cached files

    Returns
    -------
    dict
        Dictionary with flow_direction, drainage_area, upstream_rainfall,
        conditioned_dem, breached_dem (always None), metadata, and files
    """
    with open(cache_dir / "flow_cache_metadata.json") as fh:
        stored = json.load(fh)

    # Only the computation metadata is handed back; cache-key fields stay
    # internal to the cache file.
    metadata = stored.get("computation_metadata", {})
    metadata["cache_hit"] = True

    # Result key -> cached raster file name (read in this order)
    raster_names = (
        ("flow_direction", "flow_direction.tif"),
        ("drainage_area", "flow_accumulation_area.tif"),
        ("upstream_rainfall", "flow_accumulation_rainfall.tif"),
        ("conditioned_dem", "dem_conditioned.tif"),
    )
    files = {key: str(cache_dir / name) for key, name in raster_names}

    bands = {}
    for key, raster_path in files.items():
        with rasterio.open(raster_path) as src:
            bands[key] = src.read(1)

    loaded = dict(bands)
    loaded["breached_dem"] = None  # Not saved to cache, available only on first run
    loaded["metadata"] = metadata
    loaded["files"] = files
    return loaded


def _save_to_cache(
    cache_dir: Path,
    cache_params: Dict,
    dem_mtime: float,
    result: Dict,
) -> None:
    """
    Write the cache metadata file for a finished flow computation.

    Creates ``cache_dir`` if needed and writes ``flow_cache_metadata.json``
    holding the cache key parameters, the DEM modification time, a creation
    timestamp and the computation metadata taken from ``result``. This
    function itself writes no raster files.

    Parameters
    ----------
    cache_dir : Path
        Directory for cached files
    cache_params : dict
        Parameters that form the cache key
    dem_mtime : float
        Modification time of DEM file
    result : dict
        Flow computation results; only ``result["metadata"]`` is read
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    record = {}
    # Convenience copies at the top level for quick inspection
    record["dem_path"] = cache_params["dem_path"]
    record["backend"] = cache_params["backend"]
    record["timestamp"] = datetime.datetime.now().isoformat()
    # Everything needed to validate the cache later
    record["cache_params"] = cache_params
    record["dem_mtime"] = dem_mtime
    record["computation_metadata"] = result["metadata"]

    target = cache_dir / "flow_cache_metadata.json"
    with target.open("w") as out:
        json.dump(record, out, indent=2)


def compute_flow_with_basins(
    dem: np.ndarray,
    dem_transform: Affine,
    precipitation: Optional[np.ndarray] = None,
    precip_transform: Optional[Affine] = None,
    lake_mask: Optional[np.ndarray] = None,
    lake_outlets: Optional[np.ndarray] = None,
    detect_basins: bool = True,
    min_basin_size: int = 5000,
    min_basin_depth: float = 1.0,
    backend: str = "spec",
    coastal_elev_threshold: float = 0.0,
    edge_mode: str = "all",
    max_breach_depth: float = 25.0,
    max_breach_length: int = 150,
    epsilon: float = 1e-4,
    ocean_threshold: float = 0.0,
    ocean_border_only: bool = True,
    upscale_precip: bool = False,
    upscale_factor: int = 4,
    upscale_method: str = "auto",
    verbose: bool = True,
) -> Dict[str, object]:
    """Compute flow using validated basin preservation pattern (array-based API).

    This is the array-based entry point for flow computation with basin-aware
    lake handling. Unlike ``flow_accumulation()`` which takes file paths,
    this accepts numpy arrays directly.

    Steps (numbered as in the progress messages):
    1. Detect ocean mask
    2. Detect endorheic basins (optional)
    3. Create conditioning mask (ocean + basins + lake cells inside basins)
    4. Condition DEM with the conditioning mask
    5. Identify lake inlets (only when a lake mask is given)
    6. Compute flow direction, then overlay DEM-based spillway lake routing
    7. Compute drainage area
    8. Compute upstream rainfall (if precipitation provided)

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model (2D array)
    dem_transform : Affine
        Geographic transform for DEM. Not read by the current implementation.
    precipitation : np.ndarray, optional
        Precipitation data (same shape as DEM, unless ``upscale_precip`` is
        set). The input array is copied, never modified in place.
    precip_transform : Affine, optional
        Not read by the current implementation.
    lake_mask : np.ndarray, optional
        Labeled mask of water bodies (0 = no lake, >0 = lake ID)
    lake_outlets : np.ndarray, optional
        Boolean mask of lake outlet cells. Passed to ``identify_lake_inlets``
        as ``outlet_mask``. Lake routing in step 6 is only applied when this
        is not None, but the outlets used for routing are the spillways
        returned by ``find_lake_spillways``, not this mask.
    detect_basins : bool, default=True
        Whether to detect and preserve endorheic basins
    min_basin_size : int, default=5000
        Minimum basin size in cells. When left at exactly 5000 it is replaced
        for basin detection by ``int(1e-3 * dem.size)``. The unmodified value
        is what is forwarded to ``condition_dem`` on the non-"spec" path.
    min_basin_depth : float, default=1.0
        Forwarded to ``detect_endorheic_basins`` (as ``min_depth``) and to
        ``condition_dem`` on the non-"spec" path.
    backend : str, default="spec"
        "spec" conditions the DEM with ``condition_dem_spec``; any other
        value uses ``condition_dem(method="breach")``.
    coastal_elev_threshold, edge_mode, max_breach_depth, max_breach_length, epsilon
        Forwarded unchanged to ``condition_dem_spec``; ignored for other
        backends.
    ocean_threshold : float, default=0.0
        Forwarded to ``detect_ocean_mask`` as ``threshold``.
    ocean_border_only : bool, default=True
        Forwarded to ``detect_ocean_mask`` as ``border_only``.
    upscale_precip : bool, default=False
        If True and ``precipitation.shape != dem.shape``, resample the
        precipitation to the DEM shape: ``upscale_scores`` when the ratio is
        (within 0.01) a uniform integer, otherwise cubic ``scipy.ndimage.zoom``.
    upscale_factor : int, default=4
        Not read by the current implementation (the factor is derived from
        the array shapes).
    upscale_method : str, default="auto"
        Forwarded to ``upscale_scores`` as ``method``.
    verbose : bool, default=True
        Print progress messages

    Returns
    -------
    dict
        Keys: 'flow_direction', 'drainage_area', 'dem_conditioned',
        'breached_dem' (None unless backend is "spec"), 'ocean_mask',
        'basin_mask' (None when detection is disabled or nothing is found),
        'lake_inlets' (None when no lake mask is given or no inlets are
        found), 'upstream_rainfall' (None when no precipitation is given),
        'conditioning_mask'
    """
    # Imported lazily inside the function rather than at module level.
    from src.terrain.water_bodies import (
        identify_lake_inlets,
        create_lake_flow_routing,
        compute_outlet_downstream_directions,
        find_lake_spillways,
    )
    from src.terrain.transforms import upscale_scores

    if verbose:
        print("=" * 60)
        print("FLOW PIPELINE WITH BASIN PRESERVATION")
        print("=" * 60)

    # Step 1: Detect ocean
    if verbose:
        print("\n1. Detecting ocean...")
    ocean_mask = detect_ocean_mask(
        dem, threshold=ocean_threshold, border_only=ocean_border_only
    )
    if verbose:
        ocean_pct = 100 * np.sum(ocean_mask) / dem.size
        print(f"   Ocean cells: {np.sum(ocean_mask):,} ({ocean_pct:.1f}%)")

    # Step 2: Detect endorheic basins (optional)
    basin_mask = None

    if detect_basins:
        if verbose:
            print("\n2. Detecting endorheic basins...")
        total_cells = dem.size
        # 0.1% of the domain; only used while min_basin_size is still at its
        # default of 5000, so an explicit 5000 is indistinguishable from
        # "not specified".
        adaptive_min_size = int(1e-3 * total_cells)
        effective_min_size = adaptive_min_size if min_basin_size == 5000 else min_basin_size

        if verbose and effective_min_size != min_basin_size:
            print(f"   Adaptive basin size: {effective_min_size:,} cells "
                  f"({100*effective_min_size/total_cells:.4f}% of domain)")

        basin_mask, endorheic_basins = detect_endorheic_basins(
            dem, min_size=effective_min_size, exclude_mask=ocean_mask,
            min_depth=min_basin_depth,
        )

        if basin_mask is not None and np.any(basin_mask):
            num_basins = len(endorheic_basins)
            if verbose:
                print(f"   Found {num_basins} endorheic basin(s)")
        else:
            if verbose:
                print("   No significant endorheic basins detected")
            # Normalise "empty mask" to None so later steps need one check.
            basin_mask = None

    # Step 3: Create conditioning mask (basin-aware lake pre-masking)
    if verbose:
        print("\n3. Creating DEM conditioning mask...")
    # Copy so the ocean mask returned to the caller stays ocean-only.
    conditioning_mask = ocean_mask.copy()

    if lake_mask is not None and basin_mask is not None and np.any(basin_mask):
        # Lake cells inside a basin are added to the mask; lake cells outside
        # basins are left for conditioning and routed in step 6.
        lakes_in_basins = (lake_mask > 0) & basin_mask
        if np.any(lakes_in_basins):
            if verbose:
                print(f"   Pre-masking {np.sum(lakes_in_basins):,} lake cells "
                      "inside basins (drainage sinks)")
            conditioning_mask = conditioning_mask | lakes_in_basins
        lakes_outside = (lake_mask > 0) & ~basin_mask
        if np.any(lakes_outside) and verbose:
            print(f"   NOT masking {np.sum(lakes_outside):,} lake cells "
                  "outside basins (river connectors)")
    elif lake_mask is not None and np.any(lake_mask > 0):
        if verbose:
            print(f"   NOT masking {np.sum(lake_mask > 0):,} lake cells "
                  "(no basins detected, all are connectors)")

    if basin_mask is not None and np.any(basin_mask):
        if verbose:
            print(f"   Pre-masking {np.sum(basin_mask):,} basin cells "
                  "to preserve topography")
        conditioning_mask = conditioning_mask | basin_mask

    # Step 4: Condition DEM
    if verbose:
        print(f"\n4. Conditioning DEM (backend={backend})...")
    if backend == "spec":
        # `outlets` is returned by condition_dem_spec but not used below.
        dem_conditioned, outlets, breached_dem = condition_dem_spec(
            dem, nodata_mask=conditioning_mask,
            coastal_elev_threshold=coastal_elev_threshold,
            edge_mode=edge_mode, max_breach_depth=max_breach_depth,
            max_breach_length=max_breach_length, epsilon=epsilon,
        )
    else:
        # Any backend other than "spec" lands here; note it receives the
        # caller's min_basin_size, not the adaptive value from step 2.
        dem_conditioned = condition_dem(
            dem, method="breach", ocean_mask=conditioning_mask,
            min_basin_size=min_basin_size, min_basin_depth=min_basin_depth,
        )
        breached_dem = None

    # Step 5: Identify lake inlets
    lake_inlets = None
    if lake_mask is not None and np.any(lake_mask > 0):
        if verbose:
            print("\n5. Identifying lake inlets...")
        outlet_mask_for_inlets = lake_outlets if lake_outlets is not None else None
        inlets_dict = identify_lake_inlets(
            lake_mask, dem_conditioned, outlet_mask=outlet_mask_for_inlets
        )
        if inlets_dict:
            # Rasterise the per-lake inlet coordinates into one boolean mask,
            # silently dropping any coordinate outside the grid.
            lake_inlets = np.zeros_like(lake_mask, dtype=bool)
            for lake_id, inlet_cells in inlets_dict.items():
                for row, col in inlet_cells:
                    if 0 <= row < lake_inlets.shape[0] and 0 <= col < lake_inlets.shape[1]:
                        lake_inlets[row, col] = True
            if verbose:
                print(f"   Inlet cells: {np.sum(lake_inlets)}")

    # Step 6: Compute flow direction with DEM-based spillway lake routing
    if verbose:
        print("\n6. Computing flow direction...")
    # The flow-direction mask is the ocean only, not the wider conditioning
    # mask.
    flow_dir_base = compute_flow_direction(dem_conditioned, mask=ocean_mask)
    flow_dir = flow_dir_base.copy()

    if lake_mask is not None and lake_outlets is not None and np.any(lake_mask > 0):
        if verbose:
            print("   Applying lake flow routing (DEM-based spillways)...")

        # Work on a copy: lakes inside basins are zeroed out of the labels so
        # only lakes outside basins get spillway routing.
        labeled_lakes = lake_mask.copy()
        if basin_mask is not None and np.any(basin_mask):
            labeled_lakes[basin_mask] = 0
            lakes_outside = (lake_mask > 0) & ~basin_mask
        else:
            lakes_outside = lake_mask > 0

        if np.any(lakes_outside):
            spillways = find_lake_spillways(labeled_lakes, dem_conditioned)
            # Each spillway entry unpacks as (row, col, direction); only the
            # cell position is needed for the outlet mask.
            spillway_outlets = np.zeros(lake_mask.shape, dtype=bool)
            for lake_id, (sr, sc, _sdir) in spillways.items():
                spillway_outlets[sr, sc] = True

            if verbose:
                print(f"   DEM spillway detection: {len(spillways)} spillways")

            lake_flow = create_lake_flow_routing(
                labeled_lakes, spillway_outlets, dem_conditioned
            )
            # Lake routing replaces the base direction on lake cells outside
            # basins; everywhere else the base direction is kept.
            flow_dir = np.where(lakes_outside, lake_flow, flow_dir_base)

            if np.any(spillway_outlets):
                flow_dir = compute_outlet_downstream_directions(
                    flow_dir, labeled_lakes, spillway_outlets,
                    dem_conditioned, basin_mask=basin_mask, spillways=spillways,
                )

            if verbose:
                print(f"   Applied routing to {np.sum(lakes_outside):,} cells "
                      f"with {len(spillways)} spillway outlets")

    # Step 7: Compute drainage area
    if verbose:
        print("\n7. Computing drainage area...")
    drainage_area = compute_drainage_area(flow_dir)

    # Step 8: Compute upstream rainfall (optional)
    upstream_rainfall = None
    if precipitation is not None:
        if verbose:
            print("\n8. Computing upstream rainfall...")
        precip_for_accumulation = precipitation.copy()

        if upscale_precip and precipitation.shape != dem.shape:
            scale_y = dem.shape[0] / precipitation.shape[0]
            scale_x = dem.shape[1] / precipitation.shape[1]
            # Uniform, near-integer ratio -> project upscaler; anything else
            # falls back to cubic spline zoom.
            if abs(scale_y - scale_x) < 0.01 and abs(scale_y - round(scale_y)) < 0.01:
                scale_int = int(round(scale_y))
                precip_for_accumulation = upscale_scores(
                    precipitation, scale=scale_int,
                    method=upscale_method, nodata_value=0.0
                )
            else:
                from scipy.ndimage import zoom
                precip_for_accumulation = zoom(
                    precipitation, (scale_y, scale_x), order=3, mode='reflect'
                )
            if verbose:
                print(f"   Upscaled precipitation: {precipitation.shape} → "
                      f"{precip_for_accumulation.shape}")

        # Ocean cells contribute no rainfall. The boolean indexing below
        # requires the precipitation grid to match the DEM/ocean-mask shape
        # at this point.
        precip_masked = precip_for_accumulation.copy()
        precip_masked[ocean_mask] = 0
        upstream_rainfall = compute_upstream_rainfall(flow_dir, precip_masked)

    if verbose:
        print("\nFlow pipeline complete.")
        print("=" * 60)

    return {
        "flow_direction": flow_dir,
        "drainage_area": drainage_area,
        "dem_conditioned": dem_conditioned,
        "breached_dem": breached_dem,
        "ocean_mask": ocean_mask,
        "basin_mask": basin_mask,
        "lake_inlets": lake_inlets,
        "upstream_rainfall": upstream_rainfall,
        "conditioning_mask": conditioning_mask,
    }


[docs] def flow_accumulation( dem_path: str, precipitation_path: str, output_dir: Optional[str] = None, flow_algorithm: str = "d8", # Legacy parameters (used when backend="legacy" or "custom") fill_method: str = "breach", min_basin_size: Optional[int] = 10000, max_fill_depth: Optional[float] = None, # Spec-compliant parameters (used when backend="spec") coastal_elev_threshold: float = 10.0, edge_mode: Literal["all", "local_minima", "outward_slope", "none"] = "all", max_breach_depth: float = 50.0, max_breach_length: int = 100, parallel_method: str = "checkerboard", epsilon: Optional[float] = None, masked_basin_outlets: Optional[np.ndarray] = None, # Basin detection parameters detect_basins: bool = False, min_basin_depth: float = 1.0, # Common parameters cell_size: Optional[float] = None, max_cells: Optional[int] = None, target_vertices: Optional[int] = None, mask_ocean: bool = True, ocean_elevation_threshold: float = 0.0, backend: Literal["legacy", "spec", "pysheds"] = "spec", lake_mask: Optional[np.ndarray] = None, lake_outlets: Optional[np.ndarray] = None, # Precipitation upscaling parameters upscale_precip: bool = False, upscale_factor: int = 4, upscale_method: str = "auto", # Caching parameters cache: bool = False, cache_dir: Optional[str] = None, ) -> Dict[str, Union[np.ndarray, Dict]]: """ Compute flow accumulation with precipitation weighting. Automatically downsamples DEM if it exceeds max_cells to improve performance. Parameters ---------- dem_path : str Path to DEM raster file (GeoTIFF) precipitation_path : str Path to annual precipitation raster (mm/year) output_dir : str, optional Directory for output files (default: same as DEM) flow_algorithm : str, default 'd8' Flow routing method ('d8' only for now) fill_method : str, default 'breach' Depression handling ('breach' or 'fill') cell_size : float, optional Override DEM resolution in meters max_cells : int, optional Maximum number of cells for flow computation. 
DEM will be downsampled if it exceeds this limit. Mutually exclusive with target_vertices. target_vertices : int, optional Target number of vertices for final rendering. Automatically sets max_cells = target_vertices * 3 for flow accuracy. Mutually exclusive with max_cells. mask_ocean : bool, default True If True, detect and exclude ocean/water bodies from flow computation. Ocean cells (elevation <= ocean_elevation_threshold, connected to border) will not be filled and will have flow_dir = 0. ocean_elevation_threshold : float, default 0.0 Elevation threshold (meters) for ocean detection. Cells at or below this elevation connected to the border are considered ocean. min_basin_size : int, optional, default 10000 Minimum basin size (cells) to preserve. Endorheic basins >= this size will not be filled (preserves large natural basins like Salton Sea). Set to None to disable basin preservation. max_fill_depth : float, optional Maximum fill depth (meters). Depressions requiring fill > this depth will be preserved. Set to None to allow unlimited fill depth. 
backend : {"legacy", "spec", "pysheds"}, default "spec" Which implementation to use for core hydrology algorithms: - "spec": Spec-compliant 4-stage pipeline (outlet ID + breaching + fill) - RECOMMENDED - "legacy": Morphological reconstruction + workarounds (deprecated) - "pysheds": PySheds library integration (experimental, may produce cycles) Legacy parameters (used when backend="legacy" or backend="pysheds"): fill_method : str, default "breach" Depression handling ("breach" or "fill") min_basin_size : int, default 10000 Minimum basin size (cells) to preserve max_fill_depth : float, optional Maximum fill depth (meters) for preserving deep basins Spec-compliant parameters (used when backend="spec"): coastal_elev_threshold : float, default 10.0 Maximum elevation for coastal outlets (meters above sea level) edge_mode : {"all", "local_minima", "outward_slope", "none"}, default "all" Boundary outlet strategy (see flow-spec.md for details) max_breach_depth : float, default 50.0 Maximum elevation drop at any cell during breaching (meters) max_breach_length : int, default 100 Maximum breach path length (cells) epsilon : float, optional Minimum gradient in filled areas (meters/cell). If None (default), automatically calculated as 1e-5 * cell_resolution per flow-spec.md guidelines (e.g., 1e-4 for 10m DEM). Pass explicit value to override. masked_basin_outlets : np.ndarray (bool), optional User-supplied outlet locations for known lakes/basins lake_mask : np.ndarray, optional Labeled mask of known water bodies (0 = no lake, N = lake ID). Lake interior cells will route flow toward their outlets. lake_outlets : np.ndarray (bool), optional Boolean mask of lake outlet cells. Required if lake_mask is provided. Outlets receive accumulated flow from all lake cells. detect_basins : bool, default False If True, automatically detect and preserve endorheic basins (closed drainage basins). Basins are masked during DEM conditioning to preserve their original topography. 
Works with spec backend only. min_basin_depth : float, default 1.0 Minimum basin depth (meters) to be considered endorheic. Only used when detect_basins=True. Basins shallower than this threshold are not preserved. upscale_precip : bool, default False If True, upscale precipitation data to match DEM resolution using ESRGAN/bilateral upscaling before computing upstream rainfall. This preserves fine-scale precipitation patterns and reduces coastal artifacts. Upscaling happens BEFORE ocean masking. upscale_factor : int, default 4 Target upscaling factor for precipitation (2, 4, or 8). Only used if upscale_precip=True. upscale_method : str, default "auto" Upscaling method: "auto" (try ESRGAN, fall back to bilateral), "esrgan" (Real-ESRGAN neural network), "bilateral" (bilateral filter), or "bicubic" (simple interpolation). Only used if upscale_precip=True. cache : bool, default False If True, cache computation results and load from cache if valid. Cache is invalidated if DEM file is modified or parameters change. cache_dir : str, optional Directory for cached files. Defaults to DEM's directory if not specified. 
Returns ------- dict Dictionary with keys: - flow_direction: np.ndarray (D8 encoded) - drainage_area: np.ndarray (cells draining to each pixel) - upstream_rainfall: np.ndarray (mm·m² total upstream) - conditioned_dem: np.ndarray (pit-filled DEM) - metadata: dict (processing info, including downsampling details) - files: dict (output file paths) Raises ------ FileNotFoundError If DEM or precipitation file doesn't exist ValueError If spatial alignment fails or both max_cells and target_vertices specified """ # Validate inputs print(" flow_accumulation: validating inputs...", flush=True) dem_path = Path(dem_path) precip_path = Path(precipitation_path) if not dem_path.exists(): raise FileNotFoundError(f"DEM file not found: {dem_path}") if not precip_path.exists(): raise FileNotFoundError(f"Precipitation file not found: {precip_path}") # Validate max_cells and target_vertices if max_cells is not None and target_vertices is not None: raise ValueError("Cannot specify both max_cells and target_vertices") # Calculate max_cells from target_vertices if specified original_target_vertices = target_vertices if target_vertices is not None: # Use 3x target_vertices for flow accuracy max_cells = target_vertices * 3 # === CACHE CHECK === if cache: # Determine cache directory if cache_dir is None: cache_path = dem_path.parent else: cache_path = Path(cache_dir) # Build cache key parameters cache_params = _get_cache_key_params( dem_path=str(dem_path), backend=backend, max_cells=max_cells, target_vertices=original_target_vertices, fill_method=fill_method, mask_ocean=mask_ocean, ocean_elevation_threshold=ocean_elevation_threshold, coastal_elev_threshold=coastal_elev_threshold, edge_mode=edge_mode, max_breach_depth=max_breach_depth, max_breach_length=max_breach_length, epsilon=epsilon, ) # Check cache validity dem_mtime = _get_dem_mtime(dem_path) if _validate_cache(cache_path, cache_params, dem_mtime): print(" flow_accumulation: loading from cache...", flush=True) return 
_load_from_cache(cache_path) print(" flow_accumulation: cache miss, computing...", flush=True) # Load DEM print(" flow_accumulation: loading DEM...", flush=True) with rasterio.open(dem_path) as src: dem_data = src.read(1).astype(np.float32) dem_transform = src.transform dem_crs = src.crs original_shape = dem_data.shape print(f" flow_accumulation: DEM loaded {original_shape}", flush=True) # Adaptive resolution: downsample if DEM exceeds max_cells downsampling_applied = False downsample_factor = 1.0 dem_shape = original_shape if max_cells is not None and (original_shape[0] * original_shape[1]) > max_cells: # Calculate downsample factor to achieve max_cells current_cells = original_shape[0] * original_shape[1] downsample_factor = np.sqrt(current_cells / max_cells) # Calculate new shape new_height = int(original_shape[0] / downsample_factor) new_width = int(original_shape[1] / downsample_factor) downsampled_shape = (new_height, new_width) print(f" Downsampling DEM from {original_shape} ({current_cells:,} cells) to " f"{downsampled_shape} ({new_height * new_width:,} cells) " f"[{downsample_factor:.2f}x factor]...", flush=True) # Downsample DEM using rasterio from rasterio.warp import reproject, Resampling dem_downsampled = np.empty(downsampled_shape, dtype=np.float32) # Calculate new transform (larger pixels) downsampled_transform = dem_transform * Affine.scale(downsample_factor) reproject( source=dem_data, destination=dem_downsampled, src_transform=dem_transform, src_crs=dem_crs, dst_transform=downsampled_transform, dst_crs=dem_crs, resampling=Resampling.bilinear ) dem_data = dem_downsampled dem_transform = downsampled_transform dem_shape = downsampled_shape downsampling_applied = True # Also downsample lake_mask and lake_outlets if provided if lake_mask is not None: from scipy.ndimage import zoom scale_y = downsampled_shape[0] / original_shape[0] scale_x = downsampled_shape[1] / original_shape[1] lake_mask = zoom(lake_mask, (scale_y, scale_x), order=0) print(f" ✓ 
Downsampled lake_mask to {lake_mask.shape}") if lake_outlets is not None: from scipy.ndimage import zoom scale_y = downsampled_shape[0] / original_shape[0] scale_x = downsampled_shape[1] / original_shape[1] lake_outlets = zoom(lake_outlets.astype(np.uint8), (scale_y, scale_x), order=0).astype(bool) print(f" ✓ Downsampled lake_outlets to {lake_outlets.shape}") print(f" ✓ Downsampled DEM to {dem_shape}", flush=True) elif max_cells is not None: print(f"DEM size ({original_shape[0] * original_shape[1]:,} cells) below max_cells ({max_cells:,}), " f"no downsampling needed") # Load precipitation (cropped to DEM bounds using library function) print(" flow_accumulation: loading precipitation...", flush=True) from src.terrain.data_loading import load_geotiff_cropped_to_dem precip_data, precip_transform, precip_crs = load_geotiff_cropped_to_dem( precip_path, dem_shape=dem_shape, dem_transform=dem_transform, dem_crs=dem_crs, use_windowed_read=True, ) print(f" flow_accumulation: precipitation loaded {precip_data.shape}", flush=True) # Fill missing values (nodata) using nearest neighbor interpolation # Common nodata values: -9999, -32768, 0, NaN, or any negative values (precipitation can't be negative) nodata_mask = ( np.isnan(precip_data) | (precip_data < -1000) | # Catch extreme negative nodata values like -9999, -32768 (precip_data < 0) # Any negative value is invalid for precipitation ) if np.any(nodata_mask): num_missing = np.sum(nodata_mask) total_pixels = precip_data.size pct_missing = 100.0 * num_missing / total_pixels print(f" Imputing {num_missing:,} missing values ({pct_missing:.1f}%) using nearest neighbor...", flush=True) from scipy.ndimage import distance_transform_edt # Find indices of nearest valid values # Returns shape (ndim, *input_shape) - for 2D: (2, H, W) indices = distance_transform_edt(nodata_mask, return_distances=False, return_indices=True) # Fill missing values with nearest valid neighbors # indices[0][nodata_mask] = row indices, 
indices[1][nodata_mask] = col indices precip_data[nodata_mask] = precip_data[indices[0][nodata_mask], indices[1][nodata_mask]] print(f" ✓ Imputation complete", flush=True) # Check spatial alignment and resample if needed if precip_data.shape != dem_shape: # Calculate required scale factor scale_y = dem_shape[0] / precip_data.shape[0] scale_x = dem_shape[1] / precip_data.shape[1] is_upscaling = scale_y > 1.0 and scale_x > 1.0 # Use ESRGAN upscaling if requested AND actually upscaling if upscale_precip and is_upscaling: print(f" Upscaling precipitation from {precip_data.shape} to {dem_shape} using {upscale_method}...", flush=True) # Use upscale_scores if scale is uniform and an integer if abs(scale_y - scale_x) < 0.01 and abs(scale_y - round(scale_y)) < 0.01: from src.terrain.transforms import upscale_scores scale_int = int(round(scale_y)) print(f" Running {upscale_method} {scale_int}x upscaling...", flush=True) precip_upscaled = upscale_scores( precip_data, scale=scale_int, method=upscale_method, nodata_value=0.0 ) precip_data = precip_upscaled print(f" ✓ Upscaled precipitation using {upscale_method}: {precip_data.shape}", flush=True) else: # Non-uniform scaling - Detroit-style approach for GPU acceleration # Step 1: Over-upscale to next power-of-2 with ESRGAN (GPU) # Step 2: Downsample to exact target with rasterio reproject import math avg_scale = (scale_y + scale_x) / 2 # Round UP to next power of 2 (e.g., 29.458 → 32) power_of_2_scale = 2 ** math.ceil(math.log2(avg_scale)) if power_of_2_scale >= 2 and upscale_method in ("auto", "esrgan"): # Use ESRGAN for over-upscaling, then downsample print(f" Detroit-style upscaling: ESRGAN {power_of_2_scale}x + downsample to exact shape...", flush=True) from src.terrain.transforms import upscale_scores # Step 1: ESRGAN over-upscaling to power-of-2 scale (GPU-accelerated) print(f" Running ESRGAN {power_of_2_scale}x upscaling (this may take 10-60s)...", flush=True) precip_esrgan = upscale_scores( precip_data, 
scale=power_of_2_scale, method=upscale_method, nodata_value=0.0 ) print(f" ✓ ESRGAN complete: {precip_data.shape}{precip_esrgan.shape}", flush=True) # Step 2: Downsample to exact target shape with rasterio reproject from rasterio.warp import reproject, Resampling print(f" Downsampling to exact target shape...", flush=True) precip_final = np.empty(dem_shape, dtype=np.float32) # Create transforms for intermediate and target shapes if precip_transform is not None and dem_transform is not None: # Calculate intermediate transform (after ESRGAN upscaling) esrgan_transform = precip_transform * Affine.scale(1.0 / power_of_2_scale) reproject( source=precip_esrgan, destination=precip_final, src_transform=esrgan_transform, src_crs=precip_crs, dst_transform=dem_transform, dst_crs=dem_crs, resampling=Resampling.bilinear, ) print(f" ✓ Downsampling complete: {precip_esrgan.shape}{precip_final.shape}", flush=True) else: # No transform available - use scipy zoom for final adjustment from scipy.ndimage import zoom scale_y_final = dem_shape[0] / precip_esrgan.shape[0] scale_x_final = dem_shape[1] / precip_esrgan.shape[1] precip_final = zoom(precip_esrgan, (scale_y_final, scale_x_final), order=1, mode='reflect') print(f" ✓ Downsampling complete: {precip_esrgan.shape}{precip_final.shape}", flush=True) precip_data = precip_final print(f" ✓ Detroit-style upscaling complete: {precip_final.shape}", flush=True) else: # Fall back to basic bicubic for small scales or non-ESRGAN methods from rasterio.warp import reproject, Resampling print(f" Non-uniform scaling ({scale_y:.2f}x, {scale_x:.2f}x), using rasterio reproject...", flush=True) precip_resampled = np.empty(dem_shape, dtype=np.float32) reproject( source=precip_data, destination=precip_resampled, src_transform=precip_transform, src_crs=precip_crs, dst_transform=dem_transform, dst_crs=dem_crs, resampling=Resampling.bilinear ) precip_data = precip_resampled print(f" ✓ Resampled precipitation to {precip_data.shape}", flush=True) elif 
upscale_precip and not is_upscaling: # User requested upscaling but data is being downscaled - inform and use standard resampling print(f" Precipitation is being downscaled ({scale_y:.2f}x, {scale_x:.2f}x), using bilinear resampling...", flush=True) from rasterio.warp import reproject, Resampling precip_resampled = np.empty(dem_shape, dtype=np.float32) reproject( source=precip_data, destination=precip_resampled, src_transform=precip_transform, src_crs=precip_crs, dst_transform=dem_transform, dst_crs=dem_crs, resampling=Resampling.bilinear ) precip_data = precip_resampled print(f" ✓ Downsampled precipitation to {precip_data.shape}") else: # Standard resampling (no upscaling requested) print(f" Resampling precipitation from {precip_data.shape} to match DEM {dem_shape}...", flush=True) from rasterio.warp import reproject, Resampling precip_resampled = np.empty(dem_shape, dtype=np.float32) reproject( source=precip_data, destination=precip_resampled, src_transform=precip_transform, src_crs=precip_crs, dst_transform=dem_transform, dst_crs=dem_crs, resampling=Resampling.bilinear ) precip_data = precip_resampled print(f" ✓ Resampled precipitation to {precip_data.shape}") # Determine cell size (convert from degrees to meters if geographic CRS) if cell_size is None: from rasterio.crs import CRS import math # Check if CRS is geographic (lat/lon in degrees) if dem_crs is not None and CRS.from_user_input(dem_crs).is_geographic: # Cell size is in degrees - convert to meters using Haversine approximation # Calculate at center latitude for best accuracy pixel_size_deg = abs(dem_transform.a) # Get bounds to find center latitude height, width = dem_shape left = dem_transform.c top = dem_transform.f bottom = top + (height * dem_transform.e) center_lat = (top + bottom) / 2 # Haversine approximation for degrees to meters # At center latitude lon_to_m = 111320 * math.cos(math.radians(center_lat)) lat_to_m = 110540 cell_width_m = pixel_size_deg * lon_to_m cell_height_m = 
abs(dem_transform.e) * lat_to_m # Use average of width and height for area calculations cell_size = math.sqrt(cell_width_m * cell_height_m) print(f" Geographic CRS detected (cell size: {pixel_size_deg:.8f}° ≈ {cell_size:.1f}m at lat {center_lat:.2f}°)") else: # Projected CRS - pixel size is already in CRS units (meters) cell_size = abs(dem_transform.a) # Auto-calculate epsilon if not provided (flow-spec.md guideline) # epsilon = 1e-5 * cell_resolution (e.g., 1e-4 for 10m DEM) if epsilon is None: epsilon = 1e-5 * cell_size print(f" Auto-calculated epsilon: {epsilon:.2e} m/cell (= 1e-5 × {cell_size:.1f}m cell size)") # Step 1: Detect ocean mask (if enabled) ocean_mask = None if mask_ocean: print(f"Detecting ocean (elevation <= {ocean_elevation_threshold}m, border-connected)...") ocean_mask = detect_ocean_mask( dem_data, threshold=ocean_elevation_threshold, border_only=True ) ocean_cells = np.sum(ocean_mask) ocean_pct = 100 * ocean_cells / ocean_mask.size print(f" Ocean detected: {ocean_cells:,} cells ({ocean_pct:.1f}%)") # Step 1b: Detect endorheic basins (if enabled and using spec backend) basin_mask = None if detect_basins and backend == "spec": print(f"Detecting endorheic basins (min_size={min_basin_size}, min_depth={min_basin_depth:.1f}m)...") basin_mask, endorheic_basins = detect_endorheic_basins( dem_data, min_size=min_basin_size, min_depth=min_basin_depth, exclude_mask=ocean_mask, ) if basin_mask is not None and np.any(basin_mask): num_basins = len(endorheic_basins) basin_coverage = 100 * np.sum(basin_mask) / dem_data.size print(f" Found {num_basins} endorheic basin(s) ({basin_coverage:.2f}% of domain)") print(f" Basins will be masked during conditioning to preserve topography") else: print(" No significant endorheic basins detected") basin_mask = None elif detect_basins and backend != "spec": import warnings warnings.warn( f"detect_basins=True is only supported with backend='spec'. 
" f"Use min_basin_size parameter with legacy backend instead.", UserWarning ) # Create combined conditioning mask for spec backend # Strategy: ocean + endorheic basins + selective lakes # (Lakes handled later in the pipeline based on basin location) conditioning_mask = ocean_mask.copy() if ocean_mask is not None else np.zeros(dem_data.shape, dtype=bool) # Basin-aware lake pre-masking: # Lakes INSIDE basins → masked (drainage sinks like Salton Sea) # Lakes OUTSIDE basins → NOT masked (river connectors) if lake_mask is not None and basin_mask is not None and np.any(basin_mask): lakes_in_basins = (lake_mask > 0) & basin_mask if np.any(lakes_in_basins): conditioning_mask = conditioning_mask | lakes_in_basins print(f" Pre-masking {np.sum(lakes_in_basins):,} lake cells inside basins (drainage sinks)") lakes_outside = (lake_mask > 0) & ~basin_mask if np.any(lakes_outside): print(f" NOT masking {np.sum(lakes_outside):,} lake cells outside basins (river connectors)") elif lake_mask is not None and np.any(lake_mask > 0): print(f" NOT masking {np.sum(lake_mask > 0):,} lake cells (no basins detected, all are connectors)") if basin_mask is not None and np.any(basin_mask): conditioning_mask = conditioning_mask | basin_mask print(f" Combined conditioning mask: {np.sum(conditioning_mask):,} cells " f"({100*np.sum(conditioning_mask)/conditioning_mask.size:.1f}%)") # Use ocean mask for flow computation (flow direction terminals) # Note: Basins are NOT masked from flow - flow is computed inside them flow_mask = ocean_mask if ocean_mask is not None else None # Initialize breached_dem (only set by spec backend, None for others) breached_dem = None # Branch based on backend if backend == "spec": # === SPEC-COMPLIANT BACKEND === # Use spec-compliant 4-stage pipeline (outlet ID + breaching + fill) print(f"Using spec-compliant backend (flow-spec.md)...") # Emit warnings if legacy parameters are specified if fill_method != "breach": import warnings warnings.warn( 
f"fill_method='{fill_method}' is ignored when backend='spec'. " "Use epsilon parameter instead (epsilon=0 for fill, epsilon>0 for breach-like).", DeprecationWarning ) if min_basin_size != 10000: import warnings warnings.warn( "min_basin_size is ignored when backend='spec'. " "Use max_breach_depth/max_breach_length to control basin preservation.", DeprecationWarning ) # Step 2: Condition DEM using spec-compliant pipeline # Use combined conditioning mask (ocean + basins) to preserve topography print("Conditioning DEM (outlets + breach + fill)...", flush=True) nodata_mask_for_spec = conditioning_mask if detect_basins else ocean_mask conditioned_dem, outlets, breached_dem = condition_dem_spec( dem_data, nodata_mask=nodata_mask_for_spec, coastal_elev_threshold=coastal_elev_threshold, edge_mode=edge_mode, max_breach_depth=max_breach_depth, max_breach_length=max_breach_length, epsilon=epsilon, masked_basin_outlets=masked_basin_outlets, parallel_method=parallel_method, ) # Step 3: Compute flow directions print("Computing flow directions...", flush=True) # CRITICAL: Combine ocean/nodata with all identified outlets (edge, coastal, masked basin) # so that ALL outlet types are properly used as flow direction terminals outlet_mask = ocean_mask | outlets flow_direction = compute_flow_direction(conditioned_dem, mask=outlet_mask) print(" ✓ Flow directions computed", flush=True) elif backend == "pysheds": if not PYSHEDS_AVAILABLE: raise ImportError("pysheds is not installed. Install with: pip install pysheds") # === PYSHEDS BACKEND === # Use pysheds for core hydrology (depression filling, flow direction, accumulation) # WARNING: PySheds may produce flow cycles in some cases. Use custom backend for # production work. import warnings warnings.warn( "PySheds backend is experimental and may produce flow cycles. 
" "Use backend='custom' for reliable results.", UserWarning ) print(f"Using pysheds backend for flow computation...") # Create pysheds grid from numpy array # PySheds expects nodata value - use a large negative number for ocean/masked areas dem_for_pysheds = dem_data.copy() nodata_value = -9999.0 if ocean_mask is not None and np.any(ocean_mask): dem_for_pysheds[ocean_mask] = nodata_value # Create ViewFinder and Raster objects for pysheds from pysheds.sview import Raster, ViewFinder # Create ViewFinder with proper metadata viewfinder = ViewFinder( shape=dem_for_pysheds.shape, affine=dem_transform, crs=dem_crs, nodata=nodata_value, ) # Wrap numpy array as Raster and create grid from viewfinder dem_raster = Raster(dem_for_pysheds, viewfinder=viewfinder) grid = PyshedsGrid(viewfinder=viewfinder) # Step 2: Condition DEM using pysheds print(" pysheds: Filling pits...") pit_filled = grid.fill_pits(dem_raster) print(" pysheds: Filling depressions...") flooded = grid.fill_depressions(pit_filled) print(" pysheds: Resolving flats...") inflated = grid.resolve_flats(flooded) conditioned_dem = np.array(inflated).astype(np.float32) # Restore original ocean values to conditioned DEM if ocean_mask is not None and np.any(ocean_mask): conditioned_dem[ocean_mask] = dem_data[ocean_mask] # Step 3: Compute flow direction using pysheds print(" pysheds: Computing flow direction...") fdir = grid.flowdir(inflated) # Convert pysheds flow direction to our D8 encoding # PySheds uses same D8 encoding by default (1,2,4,8,16,32,64,128) # PySheds returns negative values for outlets/boundaries (e.g., -2) # We need to convert these to 0 (outlet) before casting to uint8 fdir_arr = np.array(fdir) fdir_arr[fdir_arr < 0] = 0 # Convert negative values to outlet (0) flow_direction = fdir_arr.astype(np.uint8) # Apply ocean mask to flow direction (ocean cells = outlet) if ocean_mask is not None and np.any(ocean_mask): flow_direction[ocean_mask] = 0 # Fix coastal cells to flow toward ocean (pysheds doesn't do 
this automatically) _fix_coastal_flow_directions(flow_direction, ocean_mask) else: # === LEGACY BACKEND (morphological reconstruction + workarounds) === # Step 2: Condition DEM (fill pits/depressions with masking) # Scale min_basin_size if downsampling was applied (area scales with factor²) scaled_min_basin_size = min_basin_size if min_basin_size is not None and downsample_factor > 1.0: scaled_min_basin_size = max(100, int(min_basin_size / (downsample_factor ** 2))) print(f"Conditioning DEM (method={fill_method}, min_basin_size={min_basin_size}{scaled_min_basin_size} scaled)...") else: print(f"Conditioning DEM (method={fill_method}, min_basin_size={min_basin_size})...") conditioned_dem = condition_dem( dem_data, method=fill_method, ocean_mask=ocean_mask, min_basin_size=scaled_min_basin_size, max_fill_depth=max_fill_depth, ) # Step 3: Compute flow directions (with combined mask) print("Computing flow directions...") flow_direction = compute_flow_direction( conditioned_dem, mask=flow_mask if np.any(flow_mask) else None ) # Step 3.5: Apply lake routing if lake_mask provided # Uses DEM-based spillway detection (boundary cells) instead of HydroLAKES # pour points (which land in lake interiors and create terminal outlets). # Basin-aware: only routes lakes outside preserved endorheic basins. 
if lake_mask is not None and lake_outlets is not None and np.any(lake_mask > 0): from src.terrain.water_bodies import ( create_lake_flow_routing, find_lake_spillways, compute_outlet_downstream_directions, ) print("Applying lake flow routing (DEM-based spillways)...") # Basin-aware: only route lakes OUTSIDE preserved basins labeled_lakes = lake_mask.copy() if basin_mask is not None and np.any(basin_mask): labeled_lakes[basin_mask] = 0 lakes_outside = (lake_mask > 0) & ~basin_mask n_in = len(np.unique(lake_mask[(lake_mask > 0) & basin_mask])) n_out = len(np.unique(lake_mask[lakes_outside])) print(f" {n_in} lakes inside basins (natural flow), " f"{n_out} lakes outside basins (explicit routing)") else: lakes_outside = lake_mask > 0 if np.any(lakes_outside): # DEM-based spillways: find lowest boundary cell for each lake spillways = find_lake_spillways(labeled_lakes, conditioned_dem) spillway_outlets = np.zeros(lake_mask.shape, dtype=bool) for lake_id, (sr, sc, _sdir) in spillways.items(): spillway_outlets[sr, sc] = True print(f" DEM spillway detection: {len(spillways)} spillways " f"(replacing {int(np.sum(lake_outlets)):,} HydroLAKES pour points)") # BFS routing: all lake cells route toward DEM spillway lake_flow = create_lake_flow_routing( labeled_lakes, spillway_outlets, conditioned_dem ) flow_direction = np.where(lakes_outside, lake_flow, flow_direction) # Connect spillway outlets to downstream terrain (cycle-safe) if np.any(spillway_outlets): flow_direction = compute_outlet_downstream_directions( flow_direction, labeled_lakes, spillway_outlets, conditioned_dem, basin_mask=basin_mask, spillways=spillways, ) print(f" Applied routing to {np.sum(lakes_outside):,} cells " f"with {len(spillways)} spillway outlets") # Step 3.6: Identify lake inlets (after DEM conditioning + lake routing) lake_inlets = None if lake_mask is not None and np.any(lake_mask > 0): from src.terrain.water_bodies import identify_lake_inlets outlet_mask_for_inlets = lake_outlets if lake_outlets is 
not None else None inlets_dict = identify_lake_inlets( lake_mask, conditioned_dem, outlet_mask=outlet_mask_for_inlets ) if inlets_dict: lake_inlets = np.zeros_like(lake_mask, dtype=bool) for lake_id, inlet_cells in inlets_dict.items(): for row, col in inlet_cells: if 0 <= row < lake_inlets.shape[0] and 0 <= col < lake_inlets.shape[1]: lake_inlets[row, col] = True print(f" Lake inlets: {np.sum(lake_inlets):,} cells") # Step 4: Compute drainage area and upstream rainfall if backend == "pysheds": # Use pysheds accumulation print(" pysheds: Computing drainage area...") acc = grid.accumulation(fdir) drainage_area = np.array(acc).astype(np.float32) print(" pysheds: Computing upstream rainfall (weighted)...") # CRITICAL: Mask ocean in precipitation BEFORE accumulation # Otherwise ocean precip accumulates into coastal cells (coastline artifacts) precip_for_pysheds = precip_data.copy() if ocean_mask is not None and np.any(ocean_mask): precip_for_pysheds[ocean_mask] = 0 # Wrap precipitation as Raster for weighted accumulation precip_raster = Raster(precip_for_pysheds, viewfinder=viewfinder) weighted_acc = grid.accumulation(fdir, weights=precip_raster) upstream_rainfall = np.array(weighted_acc).astype(np.float32) # Apply ocean mask to drainage area output (already masked for upstream_rainfall) if ocean_mask is not None and np.any(ocean_mask): drainage_area[ocean_mask] = 0 else: # Use custom backend print("Computing drainage area...") drainage_area = compute_drainage_area(flow_direction) print("Computing upstream rainfall...") # CRITICAL: Mask ocean in precipitation BEFORE accumulation # Otherwise ocean precip accumulates into coastal cells (coastline artifacts) precip_masked = precip_data.copy() if ocean_mask is not None and np.any(ocean_mask): precip_masked[ocean_mask] = 0 upstream_rainfall = compute_upstream_rainfall(flow_direction, precip_masked) # Compute metadata total_area_km2 = (dem_shape[0] * dem_shape[1] * cell_size**2) / 1e6 max_drainage_cells = 
np.max(drainage_area) max_drainage_area_km2 = (max_drainage_cells * cell_size**2) / 1e6 max_upstream_m3 = np.max(upstream_rainfall) * cell_size**2 / 1000 metadata = { "cell_size_m": cell_size, "drainage_area_units": "cells", "total_area_km2": total_area_km2, "max_drainage_area_km2": max_drainage_area_km2, "max_upstream_rainfall_m3": max_upstream_m3, "algorithm": flow_algorithm, "fill_method": fill_method, "backend": backend, "downsampling_applied": downsampling_applied, "original_shape": original_shape, "downsampled_shape": dem_shape if downsampling_applied else original_shape, "downsample_factor": downsample_factor, # Store serializable versions of transform and crs "transform": tuple(dem_transform), # Affine as 6-tuple (a, b, c, d, e, f) "crs": str(dem_crs), # CRS as string (e.g., "EPSG:4326") } # Add target_vertices to metadata if specified if target_vertices is not None: metadata["target_vertices"] = target_vertices # Mark as fresh computation (not from cache) if cache: metadata["cache_hit"] = False # Save outputs if output_dir is None: output_dir = dem_path.parent else: output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) files = {} files["flow_direction"] = str(output_dir / "flow_direction.tif") files["drainage_area"] = str(output_dir / "flow_accumulation_area.tif") files["upstream_rainfall"] = str(output_dir / "flow_accumulation_rainfall.tif") files["conditioned_dem"] = str(output_dir / "dem_conditioned.tif") # Write output GeoTIFFs _write_geotiff(files["flow_direction"], flow_direction.astype(np.uint8), dem_transform, dem_crs) _write_geotiff(files["drainage_area"], drainage_area, dem_transform, dem_crs) _write_geotiff(files["upstream_rainfall"], upstream_rainfall, dem_transform, dem_crs) _write_geotiff(files["conditioned_dem"], conditioned_dem, dem_transform, dem_crs) result = { "flow_direction": flow_direction, "drainage_area": drainage_area, "upstream_rainfall": upstream_rainfall, "conditioned_dem": conditioned_dem, "breached_dem": 
breached_dem, "lake_mask": lake_mask, # Downsampled lake mask (or None if not provided) "lake_inlets": lake_inlets, "basin_mask": basin_mask, "ocean_mask": ocean_mask, "metadata": metadata, "files": files, } # === CACHE SAVE === if cache: # Determine cache directory (use output_dir if cache_dir not specified) if cache_dir is None: cache_save_path = dem_path.parent else: cache_save_path = Path(cache_dir) # Copy output files to cache location if different from output_dir if cache_save_path != output_dir: cache_save_path.mkdir(parents=True, exist_ok=True) for filename in ["flow_direction.tif", "flow_accumulation_area.tif", "flow_accumulation_rainfall.tif", "dem_conditioned.tif"]: src_file = output_dir / filename dst_file = cache_save_path / filename if src_file.exists(): shutil.copy2(src_file, dst_file) # Update files dict to point to cache location result["files"] = { "flow_direction": str(cache_save_path / "flow_direction.tif"), "drainage_area": str(cache_save_path / "flow_accumulation_area.tif"), "upstream_rainfall": str(cache_save_path / "flow_accumulation_rainfall.tif"), "conditioned_dem": str(cache_save_path / "dem_conditioned.tif"), } # Save cache metadata _save_to_cache(cache_save_path, cache_params, dem_mtime, result) print(f" flow_accumulation: cached to {cache_save_path}", flush=True) return result
@jit(nopython=True, cache=True)
def _compute_flow_direction_jit(dem: np.ndarray, flow_dir: np.ndarray) -> None:
    """
    JIT-compiled flow direction computation (numba accelerated).

    For every cell, selects the in-bounds neighbor with the largest strictly
    positive slope (elevation drop divided by neighbor distance) and writes
    the matching ESRI D8 code into ``flow_dir``. Cells with no lower
    neighbor (pits, flats, or boundary cells with nothing lower in the grid)
    are written as 0 (outlet).

    Modifies flow_dir in-place for maximum performance.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model (2D)
    flow_dir : np.ndarray
        Output array for flow directions (modified in-place); must have the
        same shape as ``dem``

    Notes
    -----
    Ties between equally steep neighbors are resolved by evaluation order:
    S, E, W, N, SE, SW, NE, NW. Because the comparison is strict (``>``),
    the first neighbor in that order that reaches the maximum slope wins.
    """
    rows, cols = dem.shape

    # D8 direction offsets and codes (explicit arrays for numba)
    offsets = np.array([
        (0, 1),    # East: 1
        (-1, 1),   # Northeast: 2
        (-1, 0),   # North: 4
        (-1, -1),  # Northwest: 8
        (0, -1),   # West: 16
        (1, -1),   # Southwest: 32
        (1, 0),    # South: 64
        (1, 1),    # Southeast: 128
    ], dtype=np.int32)
    codes = np.array([1, 2, 4, 8, 16, 32, 64, 128], dtype=np.uint8)

    # Distance to each neighbor: 1 for cardinal, sqrt(2) for diagonal.
    # The exact value is used (rather than a truncated 1.414) so diagonal
    # slopes are not systematically overestimated.
    diagonal = np.sqrt(2.0)
    distances = np.array(
        [1.0, diagonal, 1.0, diagonal, 1.0, diagonal, 1.0, diagonal],
        dtype=np.float32,
    )

    # Evaluation order used for tie-breaking: S, E, W, N, SE, SW, NE, NW.
    # Built once here; it is invariant across cells.
    priority_order = np.array([6, 0, 4, 2, 7, 5, 1, 3], dtype=np.int32)

    # For each cell, find the steepest downslope neighbor
    for i in range(rows):
        for j in range(cols):
            max_slope = 0.0
            best_dir = 0  # 0 = outlet / pit unless a lower neighbor is found
            current_elev = dem[i, j]

            for p in range(8):
                k = priority_order[p]
                ni = i + offsets[k, 0]
                nj = j + offsets[k, 1]

                # Skip neighbors that fall outside the grid
                if 0 <= ni < rows and 0 <= nj < cols:
                    slope = (current_elev - dem[ni, nj]) / distances[k]

                    # Strict comparison: only strictly downhill neighbors
                    # qualify, and earlier entries in priority_order win ties.
                    if slope > max_slope:
                        max_slope = slope
                        best_dir = codes[k]

            flow_dir[i, j] = best_dir
def compute_flow_direction(dem: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
    """
    Compute D8 flow direction from DEM (Stage 3 of flow-spec.md).

    Assigns each cell the direction of steepest descent, using ESRI's
    power-of-2 D8 encoding for compatibility with GIS workflows.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model (2D array, already conditioned by Stages 1-2)
    mask : np.ndarray (bool), optional
        Boolean mask indicating cells to exclude from flow computation.
        Masked cells (ocean, lakes, etc.) will have flow_dir = 0 (no flow).

    Returns
    -------
    np.ndarray (uint8)
        Flow direction encoded as D8 using ESRI power-of-2 convention:
        - 0 = outlet or masked cell (no downstream flow)
        - 1,2,4,8,16,32,64,128 = ESRI D8 codes (see module docstring)

    Notes
    -----
    This implements Stage 3 of flow-spec.md (lines 347-390).

    D8 Direction Encoding (ESRI ArcGIS):
        8   4   2
       16   x   1
       32  64 128

    Direction codes represent powers of 2 for bitwise compatibility:
    - 1=East, 2=NE, 4=North, 8=NW, 16=West, 32=SW, 64=South, 128=SE

    Distances:
    - Orthogonal (1,4,16,64): distance = 1
    - Diagonal (2,8,32,128): distance = sqrt(2)

    Flow Direction Selection:
    Each cell flows toward the neighbor with maximum slope (elevation drop
    per unit distance). If no neighbor is lower, the cell is an outlet
    (flow_dir = 0). This should not occur after proper DEM conditioning
    (Stage 2) unless the cell is an outlet or masked.

    Mask Handling:
    When a mask is given, masked cells are set to 0 and then unmasked pit
    cells (flow_dir == 0) adjacent to a masked cell are pointed toward it
    via ``_fix_coastal_flow_directions``. This post-processing is applied
    identically whether or not numba is available.

    See Also
    --------
    identify_outlets : Stage 1 (outlet identification)
    breach_depressions_constrained : Stage 2a (depression breaching)
    priority_flood_fill_epsilon : Stage 2b (depression filling)
    compute_drainage_area : Stage 4 (flow accumulation)
    """
    rows, cols = dem.shape
    flow_dir = np.zeros((rows, cols), dtype=np.uint8)

    if NUMBA_AVAILABLE:
        # JIT-compiled kernel (much faster than the Python loop below)
        _compute_flow_direction_jit(dem, flow_dir)
    else:
        # Fallback: pure Python implementation.
        # Neighbor distances are invariant, so compute them once up front
        # instead of once per neighbor per cell.
        neighbors = [
            (di, dj, direction_code, float(np.sqrt(di**2 + dj**2)))
            for (di, dj), direction_code in D8_DIRECTIONS.items()
        ]

        for i in range(rows):
            for j in range(cols):
                max_slope = 0.0
                best_dir = 0  # stays 0 (outlet) if no neighbor is lower
                current_elev = dem[i, j]

                for di, dj, direction_code, distance in neighbors:
                    ni, nj = i + di, j + dj

                    # Skip neighbors that fall outside the grid
                    if 0 <= ni < rows and 0 <= nj < cols:
                        slope = (current_elev - dem[ni, nj]) / distance
                        if slope > max_slope:
                            max_slope = slope
                            best_dir = direction_code

                flow_dir[i, j] = best_dir

    if mask is not None:
        # Masked cells never route flow onward
        flow_dir[mask] = 0

        # Point pit cells that touch a masked cell (ocean/sink) toward it so
        # watersheds drain into the mask. Previously this step ran only on
        # the numba path, making results depend on whether numba was
        # installed.
        _fix_coastal_flow_directions(flow_dir, mask)

    return flow_dir
@jit(nopython=True, cache=True)
def _fix_coastal_flow_directions_jit(flow_dir: np.ndarray, mask: np.ndarray) -> int:
    """
    JIT-compiled coastal flow direction fix.

    Only fixes pit cells (flow_dir == 0) adjacent to masked cells.
    Does not override valid flow directions.

    Parameters
    ----------
    flow_dir : np.ndarray (uint8)
        Flow direction grid (modified in-place)
    mask : np.ndarray (bool)
        Mask of ocean/sink cells

    Returns
    -------
    int
        Number of cells fixed

    Notes
    -----
    Neighbors are scanned in the order E, NE, N, NW, W, SW, S, SE and the
    first masked neighbor found determines the assigned direction.
    """
    rows, cols = flow_dir.shape

    # D8 offsets and codes (explicit arrays for numba)
    offsets = np.array([
        (0, 1), (-1, 1), (-1, 0), (-1, -1),
        (0, -1), (1, -1), (1, 0), (1, 1)
    ], dtype=np.int32)
    codes = np.array([1, 2, 4, 8, 16, 32, 64, 128], dtype=np.uint8)

    fixed_count = 0

    for i in range(rows):
        for j in range(cols):
            # Masked cells (ocean/sink) are terminals; leave them alone
            if mask[i, j]:
                continue

            # Only pits (flow_dir == 0) are candidates; never override a
            # valid flow direction
            if flow_dir[i, j] != 0:
                continue

            # Point the pit at the first masked neighbor, if any
            for k in range(8):
                ni = i + offsets[k, 0]
                nj = j + offsets[k, 1]

                if 0 <= ni < rows and 0 <= nj < cols and mask[ni, nj]:
                    flow_dir[i, j] = codes[k]
                    fixed_count += 1
                    break

    return fixed_count


def _fix_coastal_flow_directions(flow_dir: np.ndarray, mask: np.ndarray | None) -> None:
    """
    Fix pit cells adjacent to masked cells (ocean/sinks) to flow toward them.

    Only modifies cells with flow_dir == 0 (pits). This ensures coastal pits
    drain to ocean while preserving valid inland flow directions.

    This prevents coastal pits from fragmenting the drainage network.

    Parameters
    ----------
    flow_dir : np.ndarray
        Flow direction grid (modified in-place)
    mask : np.ndarray (bool) or None
        Mask of ocean/sink cells. If None there are no sinks to drain
        toward, so the grid is left untouched and nothing is printed.
    """
    if mask is None:
        return

    # The kernel is valid plain Python when numba is unavailable (``jit`` is
    # then a no-op decorator), so a single code path serves both cases and
    # the two implementations cannot drift apart.
    fixed_count = _fix_coastal_flow_directions_jit(flow_dir, mask)

    print(f"  Fixed {fixed_count:,} coastal cells to flow toward ocean/sinks")


@jit(nopython=True, cache=True)
def _compute_drainage_area_jit(flow_dir: np.ndarray, drainage_area: np.ndarray) -> np.bool_:
    """
    JIT-compiled drainage area computation (numba accelerated).

    Uses array-based topological sorting (Kahn's algorithm) for massive
    speedup over dict/set approach. Implements Stage 4 of flow-spec.md
    with cycle detection.

    Parameters
    ----------
    flow_dir : np.ndarray
        D8 flow direction grid (ESRI codes: 0,1,2,4,8,16,32,64,128)
    drainage_area : np.ndarray
        Output array (modified in-place), initialized to 1.0

    Returns
    -------
    bool
        True if a cycle was detected (should raise error in caller),
        False if topological sort completed successfully

    Notes
    -----
    A cell whose value is 0, is not one of the eight D8 codes, or points
    off the grid is treated as a sink: it receives flow from upstream but
    passes nothing on.
    """
    rows, cols = flow_dir.shape
    total_cells = rows * cols

    # D8 direction offsets (ESRI convention)
    # Maps direction codes 1,2,4,8,16,32,64,128 to (dy, dx) offsets
    offsets = np.array([
        (0, 1),    # 1: East
        (-1, 1),   # 2: Northeast
        (-1, 0),   # 4: North
        (-1, -1),  # 8: Northwest
        (0, -1),   # 16: West
        (1, -1),   # 32: Southwest
        (1, 0),    # 64: South
        (1, 1),    # 128: Southeast
    ], dtype=np.int32)
    codes = np.array([1, 2, 4, 8, 16, 32, 64, 128], dtype=np.uint8)

    # Count how many cells flow INTO each cell (in-degree for Kahn's
    # algorithm). A cell has at most 8 contributors, so int8 is sufficient.
    contributor_count = np.zeros((rows, cols), dtype=np.int8)

    for i in range(rows):
        for j in range(cols):
            direction = flow_dir[i, j]
            if direction > 0:
                # Find which offset corresponds to this direction code
                for k in range(8):
                    if codes[k] == direction:
                        ni = i + offsets[k, 0]
                        nj = j + offsets[k, 1]
                        if 0 <= ni < rows and 0 <= nj < cols:
                            contributor_count[ni, nj] += 1
                        break

    # Queue of flat cell indices, seeded with cells that have no
    # contributors. int64 so that i * cols + j cannot wrap on grids with
    # 2**31 or more cells.
    queue = np.zeros(total_cells, dtype=np.int64)
    queue_size = 0

    for i in range(rows):
        for j in range(cols):
            if contributor_count[i, j] == 0:
                queue[queue_size] = i * cols + j
                queue_size += 1

    # Process cells in topological order (Kahn's algorithm)
    cells_processed = 0
    queue_pos = 0
    while queue_pos < queue_size:
        flat_idx = queue[queue_pos]
        queue_pos += 1
        cells_processed += 1

        i = flat_idx // cols
        j = flat_idx % cols

        direction = flow_dir[i, j]
        if direction > 0:
            # Find the receiver for this direction code
            for k in range(8):
                if codes[k] == direction:
                    ni = i + offsets[k, 0]
                    nj = j + offsets[k, 1]

                    if 0 <= ni < rows and 0 <= nj < cols:
                        # Pass this cell's accumulated area downstream
                        drainage_area[ni, nj] += drainage_area[i, j]

                        # Receiver becomes ready once every contributor
                        # has been processed
                        contributor_count[ni, nj] -= 1
                        if contributor_count[ni, nj] == 0:
                            queue[queue_size] = ni * cols + nj
                            queue_size += 1
                    break
        # Cells with direction == 0 (outlets/ocean) are processed but send
        # no flow anywhere; they only accumulate from upstream.

    # Every cell is a node of the flow graph, so a complete topological
    # sort visits all of them. Any cell left unvisited never reached
    # in-degree 0, which means it sits on or downstream of a cycle.
    return cells_processed < total_cells
def compute_drainage_area(flow_dir: np.ndarray) -> np.ndarray:
    """
    Compute drainage area (unweighted flow accumulation).

    Stage 4 of flow-spec.md pipeline. Uses topological sort (Kahn's algorithm)
    to traverse the flow network and accumulate drainage areas.

    Parameters
    ----------
    flow_dir : np.ndarray
        D8 flow direction grid (ESRI codes: 0,1,2,4,8,16,32,64,128)
        where 0 indicates outlet/nodata (no downstream flow)

    Returns
    -------
    np.ndarray
        Number of cells draining through each pixel (including itself),
        as float32

    Raises
    ------
    RuntimeError
        If a cycle is detected in the flow network, indicating a bug in
        DEM conditioning. Cycles should never occur after proper Stage 2
        (DEM conditioning).

    Notes
    -----
    Implementation matches flow-spec.md Stage 4 (lines 392-438).
    Uses topological sort with cycle detection to ensure each cell's
    contributions are computed before it contributes to its receiver.

    If cycle detection fires, check that:
    - DEM was properly conditioned (Stage 2a: breaching, 2b: filling)
    - All outlets are properly identified (Stage 1)
    - No invalid flow directions exist (Stage 3)
    """
    rows, cols = flow_dir.shape
    drainage_area = np.ones((rows, cols), dtype=np.float32)

    # Use JIT-compiled version if available
    if NUMBA_AVAILABLE:
        cycle_detected = _compute_drainage_area_jit(flow_dir, drainage_area)
        if cycle_detected:
            raise RuntimeError(
                "Cycle detected in flow network! DEM conditioning failed. "
                "Check that outlets are properly identified (Stage 1) and "
                "DEM is properly conditioned (Stage 2a: breach, 2b: fill)."
            )
        return drainage_area

    # Fallback: pure Python implementation of Kahn's algorithm
    from collections import deque

    # receivers[cell]  = the cell it flows to (only for in-grid receivers)
    # in_degree[cell]  = number of not-yet-processed cells flowing into it
    receivers = {}
    in_degree = {}

    for i in range(rows):
        for j in range(cols):
            direction = int(flow_dir[i, j])
            if direction > 0 and direction in D8_OFFSETS:
                di, dj = D8_OFFSETS[direction]
                receiver = (i + di, j + dj)

                if 0 <= receiver[0] < rows and 0 <= receiver[1] < cols:
                    receivers[(i, j)] = receiver
                    in_degree[receiver] = in_degree.get(receiver, 0) + 1

    # Seed with cells that nothing flows into (ridgelines/peaks/isolated
    # outlets), in row-major order. deque gives O(1) pops from the front;
    # list.pop(0) is O(n) and made this loop quadratic.
    to_process = deque(
        (i, j)
        for i in range(rows)
        for j in range(cols)
        if (i, j) not in in_degree
    )

    processed_count = 0
    while to_process:
        cell = to_process.popleft()
        processed_count += 1

        receiver = receivers.get(cell)
        if receiver is None:
            # Outlet, off-grid flow, or unrecognised code: nothing to pass on
            continue

        # Add this cell's drainage area to its receiver
        drainage_area[receiver] += drainage_area[cell]

        # Receiver is ready once its last contributor has been processed
        in_degree[receiver] -= 1
        if in_degree[receiver] == 0:
            to_process.append(receiver)

    # Cycle detection: every cell is a node in the flow graph, so a complete
    # topological sort must visit all of them.
    total_land_cells = flow_dir.size
    if processed_count < total_land_cells:
        unprocessed = total_land_cells - processed_count
        raise RuntimeError(
            f"Cycle detected in flow network! {unprocessed} cells never reached in_degree 0. "
            "This indicates a bug in DEM conditioning (Stage 2). "
            "Check outlets (Stage 1) and breaching/filling (Stage 2a/2b)."
        )

    return drainage_area
@jit(nopython=True, cache=True)
def _compute_upstream_rainfall_jit(
    flow_dir: np.ndarray,
    precipitation: np.ndarray,
    upstream_rainfall: np.ndarray
) -> None:
    """
    Accumulate rainfall downstream along a D8 network (numba kernel).

    Cells are visited in topological order using flat arrays: a cell is
    handled only after every cell flowing into it has been handled, and it
    then adds its running total to its receiver.

    Parameters
    ----------
    flow_dir : np.ndarray
        D8 flow direction grid
    precipitation : np.ndarray
        Annual precipitation (mm/year). Not read by this kernel; the caller
        is expected to have copied it into ``upstream_rainfall`` already.
    upstream_rainfall : np.ndarray
        Output array, updated in place. Must be pre-initialised with the
        per-cell precipitation values.
    """
    n_rows, n_cols = flow_dir.shape

    # (d_row, d_col) steps, index-aligned with ``d8_codes``
    d8_steps = np.array([
        (0, 1),    # 1: East
        (-1, 1),   # 2: Northeast
        (-1, 0),   # 4: North
        (-1, -1),  # 8: Northwest
        (0, -1),   # 16: West
        (1, -1),   # 32: Southwest
        (1, 0),    # 64: South
        (1, 1),    # 128: Southeast
    ], dtype=np.int32)
    d8_codes = np.array([1, 2, 4, 8, 16, 32, 64, 128], dtype=np.uint8)

    # Pass 1: in-degree of every cell (how many neighbours drain into it)
    inflow_left = np.zeros((n_rows, n_cols), dtype=np.int32)
    for r in range(n_rows):
        for c in range(n_cols):
            code = flow_dir[r, c]
            if not code > 0:
                continue
            for k in range(8):
                if d8_codes[k] != code:
                    continue
                tr = r + d8_steps[k, 0]
                tc = c + d8_steps[k, 1]
                if 0 <= tr < n_rows and 0 <= tc < n_cols:
                    inflow_left[tr, tc] += 1
                break

    # Pass 2: seed the work list with cells nothing flows into
    work = np.zeros(n_rows * n_cols, dtype=np.int32)
    n_queued = 0
    for r in range(n_rows):
        for c in range(n_cols):
            if inflow_left[r, c] == 0:
                work[n_queued] = r * n_cols + c
                n_queued += 1

    # Pass 3: drain the work list; it grows as receivers become ready
    head = 0
    while head < n_queued:
        flat = work[head]
        head += 1
        r = flat // n_cols
        c = flat % n_cols

        code = flow_dir[r, c]
        if not code > 0:
            continue

        for k in range(8):
            if d8_codes[k] != code:
                continue
            tr = r + d8_steps[k, 0]
            tc = c + d8_steps[k, 1]
            if 0 <= tr < n_rows and 0 <= tc < n_cols:
                # Hand this cell's running total to its receiver
                upstream_rainfall[tr, tc] += upstream_rainfall[r, c]
                inflow_left[tr, tc] -= 1
                # Receiver is ready once its last contributor is done
                if inflow_left[tr, tc] == 0:
                    work[n_queued] = tr * n_cols + tc
                    n_queued += 1
            break
def compute_upstream_rainfall(flow_dir: np.ndarray, precipitation: np.ndarray) -> np.ndarray:
    """
    Compute precipitation-weighted flow accumulation.

    Each cell starts with its own precipitation value and passes its
    accumulated total to its D8 receiver once all of its contributors have
    been processed (topological order, ridgelines first).

    Parameters
    ----------
    flow_dir : np.ndarray
        D8 flow direction grid
    precipitation : np.ndarray
        Annual precipitation (mm/year)

    Returns
    -------
    np.ndarray
        float32 grid holding, for each pixel, the sum of ``precipitation``
        over the pixel itself and every cell upstream of it. Values are plain
        sums of the input (not scaled by cell area).

    Notes
    -----
    Cells whose direction is 0, is not a D8 code, or points off the grid have
    no receiver. Unlike ``compute_drainage_area`` this function does not
    raise on cycles: cells that are part of (or downstream of) a cycle keep
    whatever total they had accumulated when processing stopped.
    """
    rows, cols = flow_dir.shape
    upstream_rainfall = precipitation.copy().astype(np.float32)

    # Use JIT-compiled version if available
    if NUMBA_AVAILABLE:
        _compute_upstream_rainfall_jit(flow_dir, precipitation, upstream_rainfall)
        return upstream_rainfall

    # Fallback: pure Python implementation.
    # Local import keeps the file's top-level import block untouched.
    from collections import deque

    # receivers[cell] = cell it flows to; pending[cell] = number of
    # contributors of that cell that have not been processed yet.
    receivers = {}
    pending = {}
    for i in range(rows):
        for j in range(cols):
            direction = flow_dir[i, j]
            if direction > 0 and direction in D8_OFFSETS:
                di, dj = D8_OFFSETS[direction]
                receiver = (i + di, j + dj)
                if 0 <= receiver[0] < rows and 0 <= receiver[1] < cols:
                    receivers[(i, j)] = receiver
                    pending[receiver] = pending.get(receiver, 0) + 1

    # Seed with cells that nothing flows into (ridgelines), row-major order
    to_process = deque(
        (i, j)
        for i in range(rows)
        for j in range(cols)
        if (i, j) not in pending
    )

    # Each cell is queued exactly once, so no "processed" set is needed and
    # popleft() avoids the O(n) cost of list.pop(0).
    while to_process:
        cell = to_process.popleft()

        receiver = receivers.get(cell)
        if receiver is None:
            continue

        # Add this cell's upstream rainfall to its receiver
        upstream_rainfall[receiver] += upstream_rainfall[cell]

        # Receiver is ready once its last contributor has been handled
        pending[receiver] -= 1
        if pending[receiver] == 0:
            to_process.append(receiver)

    return upstream_rainfall
def compute_discharge_potential(
    drainage_area: np.ndarray,
    upstream_rainfall: np.ndarray,
) -> np.ndarray:
    """
    Combine drainage area and upstream rainfall into a discharge index.

    The index highlights cells where a large contributing area coincides
    with high accumulated rainfall.

    Parameters
    ----------
    drainage_area : np.ndarray
        Drainage area in cells (from compute_drainage_area)
    upstream_rainfall : np.ndarray
        Upstream rainfall accumulation (from compute_upstream_rainfall)

    Returns
    -------
    np.ndarray
        ``drainage_area × (upstream_rainfall / mean_rainfall)``, where
        ``mean_rainfall`` is the mean of the strictly positive entries of
        ``upstream_rainfall``. If there are no positive entries the drainage
        area is returned unchanged (as float32).

    Notes
    -----
    Dividing by the mean turns the rainfall term into a dimensionless
    multiplier: cells whose upstream rainfall is above the mean are scaled
    up, cells below the mean are scaled down, and cells with exactly zero
    upstream rainfall are forced to zero.

    Examples
    --------
    >>> drainage_area = np.array([[1, 2], [4, 8]], dtype=np.float32)
    >>> upstream_rainfall = np.array([[100, 200], [400, 800]], dtype=np.float32)
    >>> discharge = compute_discharge_potential(drainage_area, upstream_rainfall)
    >>> discharge.shape
    (2, 2)
    """
    area_f32 = drainage_area.astype(np.float32)

    # Only strictly positive rainfall takes part in the reference mean
    has_rain = upstream_rainfall > 0
    if not has_rain.any():
        # Nothing to normalise against: fall back to plain drainage area
        return area_f32

    reference_rainfall = np.mean(upstream_rainfall[has_rain])
    rainfall_multiplier = upstream_rainfall / reference_rainfall
    potential = area_f32 * rainfall_multiplier

    # Cells without any upstream rainfall stay exactly zero
    potential[upstream_rainfall == 0] = 0
    return potential
def _outlet_boundary_cells(rows: int, cols: int):
    """Yield (edge_cell, along_edge_neighbours, interior_cell_or_None) for all four grid edges."""
    # Top and bottom rows: neighbours along the edge are left/right,
    # the interior neighbour is one row inward.
    for r, inner_r in ((0, 1), (rows - 1, rows - 2)):
        for j in range(cols):
            along = [(r, nj) for nj in (j - 1, j + 1) if 0 <= nj < cols]
            interior = (inner_r, j) if rows > 1 else None
            yield (r, j), along, interior
    # Left and right columns: neighbours along the edge are up/down,
    # the interior neighbour is one column inward.
    for c, inner_c in ((0, 1), (cols - 1, cols - 2)):
        for i in range(rows):
            along = [(ni, c) for ni in (i - 1, i + 1) if 0 <= ni < rows]
            interior = (i, inner_c) if cols > 1 else None
            yield (i, c), along, interior


def _outlet_edge_is_local_min(dem, nodata_mask, cell, along_edge, interior) -> bool:
    """Return True if no valid along-edge or interior neighbour is strictly lower than ``cell``."""
    height = dem[cell]
    for other in along_edge:
        if not nodata_mask[other] and dem[other] < height:
            return False
    if interior is not None and not nodata_mask[interior] and dem[interior] < height:
        return False
    return True


def _outlet_interior_drains_to_edge(dem, nodata_mask, edge_cell, interior_cell) -> bool:
    """Return True if ``interior_cell`` slopes more steeply to ``edge_cell`` than to any other valid neighbour."""
    rows, cols = dem.shape
    ir, ic = interior_cell
    # Plain floats avoid wrap-around when the DEM has an unsigned dtype.
    interior_elev = float(dem[ir, ic])
    # The edge cell is an orthogonal neighbour, so the distance is 1.
    slope_to_edge = interior_elev - float(dem[edge_cell])

    steepest_elsewhere = -np.inf
    for di in (-1, 0, 1):
        for dj in (-1, 0, 1):
            if di == 0 and dj == 0:
                continue
            ni, nj = ir + di, ic + dj
            # The edge cell itself must not compete with itself.
            if (ni, nj) == edge_cell:
                continue
            if not (0 <= ni < rows and 0 <= nj < cols):
                continue
            if nodata_mask[ni, nj]:
                continue
            dist = np.sqrt(di**2 + dj**2)
            slope = (interior_elev - float(dem[ni, nj])) / dist
            steepest_elsewhere = max(steepest_elsewhere, slope)

    return slope_to_edge > steepest_elsewhere


def identify_outlets(
    dem: np.ndarray,
    nodata_mask: np.ndarray,
    coastal_elev_threshold: float = 10.0,
    edge_mode: Literal["all", "local_minima", "outward_slope", "none"] = "all",
    masked_basin_outlets: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Identify drainage outlet cells (Stage 1 of flow-spec.md).

    Classifies cells that act as drainage termini where water leaves the system.
    Implements three outlet types: coastal, edge, and masked basin outlets.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    nodata_mask : np.ndarray (bool)
        Mask of ocean/off-grid cells (True = NoData)
    coastal_elev_threshold : float, default 10.0
        Maximum elevation for coastal outlets (meters above sea level).
        Prevents high cliffs adjacent to ocean from being spurious outlets.
    edge_mode : {"all", "local_minima", "outward_slope", "none"}, default "all"
        Boundary outlet strategy:
        - "all": All boundary cells are outlets (safest, prevents artificial basins)
        - "local_minima": Only boundary cells with no strictly lower valid
          neighbour among their two along-edge neighbours and their interior
          neighbour
        - "outward_slope": Boundary cells whose interior neighbour slopes more
          steeply toward them than toward any of its other valid neighbours
        - "none": No edge outlets (for islands fully surrounded by coastline)
        Any other value adds no edge outlets.
    masked_basin_outlets : np.ndarray (bool), optional
        User-supplied outlet locations for known lakes/basins

    Returns
    -------
    np.ndarray (bool)
        Boolean mask where True = outlet cell

    Notes
    -----
    Coastal outlets are valid cells that touch a NoData cell (8-connected)
    and lie at or below ``coastal_elev_threshold``.

    Edge mode "all" is the safest default - it ensures no artificial endorheic
    basins form at boundaries. The cost is some fragmentation of edge drainage
    networks, but this is usually preferable to missed outlets. In this mode
    every NoData cell is cleared from the result before user-supplied
    ``masked_basin_outlets`` are merged in.

    In "outward_slope" mode the interior neighbour's slope toward the edge
    cell is compared against its slopes toward the other seven neighbours;
    the edge cell itself is excluded from that comparison and all indices are
    resolved as non-negative positions so bottom/right edges are evaluated
    the same way as top/left edges.

    Examples
    --------
    >>> dem = np.array([[5, 5, 5],
    ...                 [5, 3, 5],
    ...                 [5, 5, 5]], dtype=np.float32)
    >>> nodata = np.array([[False, False, False],
    ...                    [True, False, False],
    ...                    [False, False, False]])
    >>> outlets = identify_outlets(dem, nodata, coastal_elev_threshold=10.0)
    >>> outlets[1, 1]  # Low coastal cell adjacent to ocean
    True
    """
    rows, cols = dem.shape
    outlets = np.zeros((rows, cols), dtype=bool)
    if rows == 0 or cols == 0:
        # Nothing to classify on an empty grid
        return outlets

    # Boolean view so the mask is always used for masking, never as indices
    nodata = np.asarray(nodata_mask, dtype=bool)

    # --- Coastal outlets ---
    # Land cells adjacent to NoData (8-connected) AND elevation <= threshold.
    # Padding with False lets the eight shifted views cover the grid without
    # per-cell bounds checks.
    padded = np.pad(nodata, 1, mode="constant", constant_values=False)
    touches_nodata = np.zeros((rows, cols), dtype=bool)
    for di in (0, 1, 2):
        for dj in (0, 1, 2):
            if di == 1 and dj == 1:
                continue
            touches_nodata |= padded[di:di + rows, dj:dj + cols]
    outlets |= touches_nodata & ~nodata & (dem <= coastal_elev_threshold)

    # --- Edge outlets ---
    if edge_mode == "all":
        # All boundary cells are outlets
        outlets[0, :] = True   # Top edge
        outlets[-1, :] = True  # Bottom edge
        outlets[:, 0] = True   # Left edge
        outlets[:, -1] = True  # Right edge
        # Don't override NoData cells
        outlets[nodata] = False

    elif edge_mode == "local_minima":
        for cell, along_edge, interior in _outlet_boundary_cells(rows, cols):
            if nodata[cell]:
                continue
            if _outlet_edge_is_local_min(dem, nodata, cell, along_edge, interior):
                outlets[cell] = True

    elif edge_mode == "outward_slope":
        for cell, _along_edge, interior in _outlet_boundary_cells(rows, cols):
            # Needs a valid edge cell and a valid interior neighbour
            if nodata[cell] or interior is None or nodata[interior]:
                continue
            if _outlet_interior_drains_to_edge(dem, nodata, cell, interior):
                outlets[cell] = True

    elif edge_mode == "none":
        # No edge outlets (for islands)
        pass

    # --- Masked basin outlets ---
    if masked_basin_outlets is not None:
        outlets |= masked_basin_outlets

    return outlets
@jit(nopython=True, parallel=True, cache=True)
def _identify_sinks_jit(
    dem: np.ndarray,
    outlets: np.ndarray,
    nodata_mask: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    JIT-compiled sink identification using ``prange`` over rows.

    A sink is a cell that is neither an outlet nor NoData and has no valid
    (non-NoData, in-bounds) 8-neighbour with a strictly lower elevation.

    The scan runs twice: the first pass counts sinks per row so that each
    row gets a private, non-overlapping slice of the output arrays; the
    second pass repeats the same test and writes into that slice. This lets
    both passes run with ``prange`` without rows writing to shared indices.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    outlets : np.ndarray (bool)
        Outlet mask; outlet cells are never reported as sinks
    nodata_mask : np.ndarray (bool)
        Cells to exclude (True = NoData)

    Returns
    -------
    sink_rows, sink_cols, sink_elevs : np.ndarray
        Arrays of sink coordinates (int32) and elevations (float64),
        sorted by elevation ascending
    """
    rows, cols = dem.shape

    # First pass: count sinks per row (parallel)
    row_sink_counts = np.zeros(rows, dtype=np.int32)
    for i in prange(rows):
        count = 0
        for j in range(cols):
            # Skip outlets and nodata
            if outlets[i, j] or nodata_mask[i, j]:
                continue

            # Check if any neighbor is lower
            has_downslope = False
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    if di == 0 and dj == 0:
                        continue
                    ni, nj = i + di, j + dj
                    if 0 <= ni < rows and 0 <= nj < cols:
                        if not nodata_mask[ni, nj] and dem[ni, nj] < dem[i, j]:
                            has_downslope = True
                            break
                # Propagate the inner break out of the outer neighbour loop
                if has_downslope:
                    break

            if not has_downslope:
                count += 1
        row_sink_counts[i] = count

    # Calculate total sinks and row offsets (exclusive prefix sum of counts)
    total_sinks = np.sum(row_sink_counts)
    row_offsets = np.zeros(rows + 1, dtype=np.int32)
    for i in range(rows):
        row_offsets[i + 1] = row_offsets[i] + row_sink_counts[i]

    # Allocate output arrays
    sink_rows = np.empty(total_sinks, dtype=np.int32)
    sink_cols = np.empty(total_sinks, dtype=np.int32)
    sink_elevs = np.empty(total_sinks, dtype=np.float64)

    # Second pass: fill sink arrays (parallel); row i owns
    # indices row_offsets[i] .. row_offsets[i + 1] - 1
    for i in prange(rows):
        offset = row_offsets[i]
        local_count = 0
        for j in range(cols):
            # Skip outlets and nodata
            if outlets[i, j] or nodata_mask[i, j]:
                continue

            # Check if any neighbor is lower
            has_downslope = False
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    if di == 0 and dj == 0:
                        continue
                    ni, nj = i + di, j + dj
                    if 0 <= ni < rows and 0 <= nj < cols:
                        if not nodata_mask[ni, nj] and dem[ni, nj] < dem[i, j]:
                            has_downslope = True
                            break
                if has_downslope:
                    break

            if not has_downslope:
                sink_rows[offset + local_count] = i
                sink_cols[offset + local_count] = j
                sink_elevs[offset + local_count] = dem[i, j]
                local_count += 1

    # Sort by elevation (argsort)
    sort_indices = np.argsort(sink_elevs)
    return sink_rows[sort_indices], sink_cols[sort_indices], sink_elevs[sort_indices]


def _identify_sinks(
    dem: np.ndarray,
    outlets: np.ndarray,
    nodata_mask: Optional[np.ndarray] = None,
) -> list:
    """
    Identify all sink cells (cells with no downslope neighbor).

    Returns sinks sorted by elevation (lowest first) to reduce cascading breaches.
    Dispatches to the numba kernel when numba is available, otherwise runs
    an equivalent pure-Python scan.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    outlets : np.ndarray (bool)
        Outlet mask from identify_outlets()
    nodata_mask : np.ndarray (bool), optional
        Cells to exclude. When omitted, no cell is treated as NoData.

    Returns
    -------
    list of tuples
        List of (row, col, elevation) for each sink, sorted by elevation ascending.
        The JIT path yields Python ``int``/``float`` values; the fallback
        path stores the elevation exactly as read from ``dem``.
    """
    if nodata_mask is None:
        nodata_mask = np.zeros_like(dem, dtype=bool)

    # Use JIT-compiled version if available
    if NUMBA_AVAILABLE:
        sink_rows, sink_cols, sink_elevs = _identify_sinks_jit(dem, outlets, nodata_mask)
        return [(int(r), int(c), float(e)) for r, c, e in zip(sink_rows, sink_cols, sink_elevs)]

    # Fallback to pure Python
    rows, cols = dem.shape
    sinks = []

    for i in range(rows):
        for j in range(cols):
            # Skip outlets and nodata
            if outlets[i, j] or nodata_mask[i, j]:
                continue

            # Check if any neighbor is lower
            has_downslope = False
            for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
                ni, nj = i + di, j + dj
                if 0 <= ni < rows and 0 <= nj < cols:
                    if not nodata_mask[ni, nj] and dem[ni, nj] < dem[i, j]:
                        has_downslope = True
                        break

            if not has_downslope:
                sinks.append((i, j, dem[i, j]))

    # Sort by elevation (process lowest first to avoid cascading)
    sinks.sort(key=lambda x: x[2])
    return sinks


def _reconstruct_path(parent_map: dict, start_r: int, start_c: int, end_r: int, end_c: int) -> list:
    """
    Reconstruct breach path from Dijkstra parent map.

    Walks parent links backward from the drain point until the sink is
    reached (or a cell without a recorded parent is hit), then reverses
    the collected cells.

    Parameters
    ----------
    parent_map : dict
        Mapping from (row, col) to (parent_row, parent_col)
    start_r, start_c : int
        Starting cell (sink)
    end_r, end_c : int
        Ending cell (drain point)

    Returns
    -------
    list of tuples
        Path from sink to drain as [(row, col), ...] in order
    """
    path = []
    r, c = end_r, end_c

    # (None, None) marks "no parent": either the root entry or a missing key
    while (r, c) != (None, None):
        path.append((r, c))
        if r == start_r and c == start_c:
            break
        r, c = parent_map.get((r, c), (None, None))

    path.reverse()
    return path


@jit(nopython=True, cache=True)
def _find_breach_path_dijkstra_jit(
    dem: np.ndarray,
    start_row: int,
    start_col: int,
    outlets: np.ndarray,
    resolved: np.ndarray,
    max_depth: float,
    max_length: int,
) -> Tuple[bool, np.ndarray, np.ndarray]:
    """
    JIT-compiled Dijkstra breach path finder.

    Expands cells in order of accumulated breach cost (sum over the path of
    how far each cell rises above the sink elevation) until it pops an
    outlet, or a resolved cell at or below the sink elevation.

    The frontier is a fixed-size array queue scanned linearly for its
    minimum. NOTE(review): the queue holds at most ``min(10000, rows * cols)``
    entries and candidates arriving while it is full are dropped, so the
    search can miss a path that an unbounded queue would find.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    start_row, start_col : int
        Sink cell coordinates
    outlets : np.ndarray (bool)
        Outlet mask
    resolved : np.ndarray (bool)
        Cells that already have a drainage path
    max_depth : float
        Maximum amount any cell on the path may rise above the sink; also the
        maximum amount an accepted outlet may lie below the sink
    max_length : int
        Cells whose path length has reached this value are not expanded

    Returns
    -------
    found : bool
        Whether a valid path was found
    path_rows, path_cols : np.ndarray
        Path coordinates from sink to drain point (empty if not found)
    """
    start_elev = dem[start_row, start_col]
    rows, cols = dem.shape

    # Use arrays instead of dicts for JIT compatibility
    visited = np.zeros((rows, cols), dtype=np.bool_)
    parent_r = np.full((rows, cols), -1, dtype=np.int32)
    parent_c = np.full((rows, cols), -1, dtype=np.int32)

    # Simple priority queue stored as parallel arrays.
    # Each slot holds: (cost, length, row, col, parent_row, parent_col)
    # The lowest-cost slot is found by linear scan on every pop.
    max_queue_size = min(10000, rows * cols)  # Hard cap on frontier size
    queue_costs = np.full(max_queue_size, np.inf, dtype=np.float64)
    queue_lengths = np.zeros(max_queue_size, dtype=np.int32)
    queue_rows = np.zeros(max_queue_size, dtype=np.int32)
    queue_cols = np.zeros(max_queue_size, dtype=np.int32)
    queue_parent_r = np.zeros(max_queue_size, dtype=np.int32)
    queue_parent_c = np.zeros(max_queue_size, dtype=np.int32)

    # Seed the queue with the sink itself; parent (-1, -1) marks the root
    queue_size = 1
    queue_costs[0] = 0.0
    queue_lengths[0] = 0
    queue_rows[0] = start_row
    queue_cols[0] = start_col
    queue_parent_r[0] = -1
    queue_parent_c[0] = -1

    end_r, end_c = -1, -1
    found = False

    while queue_size > 0:
        # Find minimum cost item (simple linear search for now)
        min_idx = 0
        min_cost = queue_costs[0]
        for i in range(1, queue_size):
            if queue_costs[i] < min_cost:
                min_cost = queue_costs[i]
                min_idx = i

        # Pop item
        cost = queue_costs[min_idx]
        length = queue_lengths[min_idx]
        r = queue_rows[min_idx]
        c = queue_cols[min_idx]
        pr = queue_parent_r[min_idx]
        pc = queue_parent_c[min_idx]

        # Remove from queue (swap with last)
        queue_size -= 1
        if min_idx < queue_size:
            queue_costs[min_idx] = queue_costs[queue_size]
            queue_lengths[min_idx] = queue_lengths[queue_size]
            queue_rows[min_idx] = queue_rows[queue_size]
            queue_cols[min_idx] = queue_cols[queue_size]
            queue_parent_r[min_idx] = queue_parent_r[queue_size]
            queue_parent_c[min_idx] = queue_parent_c[queue_size]

        # Skip if already visited (the same cell can be queued several times)
        if visited[r, c]:
            continue

        visited[r, c] = True
        parent_r[r, c] = pr
        parent_c[r, c] = pc

        # Check termination: reached outlet
        if outlets[r, c]:
            # Verify outlet is not deeper than max_depth below sink
            # (to prevent breaching sink by more than max_depth)
            sink_lowering = start_elev - dem[r, c]
            if sink_lowering <= max_depth:
                end_r, end_c = r, c
                found = True
                break

        # Check termination: reached resolved cell at or below start elevation
        if resolved[r, c] and dem[r, c] <= start_elev:
            end_r, end_c = r, c
            found = True
            break

        # Check length constraint: do not expand beyond max_length steps
        if length >= max_length:
            continue

        # Explore neighbors
        for di in range(-1, 2):
            for dj in range(-1, 2):
                if di == 0 and dj == 0:
                    continue

                ni, nj = r + di, c + dj

                # Bounds check
                if not (0 <= ni < rows and 0 <= nj < cols):
                    continue

                if visited[ni, nj]:
                    continue

                # Cost to breach through this neighbor: how far it sits
                # above the sink elevation (zero if already at or below it)
                breach_depth_here = max(0.0, dem[ni, nj] - start_elev)

                # Check depth constraint
                if breach_depth_here > max_depth:
                    continue

                new_cost = cost + breach_depth_here
                new_length = length + 1

                # Add to queue if space available; otherwise the candidate
                # is silently discarded
                if queue_size < max_queue_size:
                    queue_costs[queue_size] = new_cost
                    queue_lengths[queue_size] = new_length
                    queue_rows[queue_size] = ni
                    queue_cols[queue_size] = nj
                    queue_parent_r[queue_size] = r
                    queue_parent_c[queue_size] = c
                    queue_size += 1

    if not found:
        return False, np.empty(0, dtype=np.int32), np.empty(0, dtype=np.int32)

    # Reconstruct path by following parent links from the drain point
    path_list_r = []
    path_list_c = []
    curr_r, curr_c = end_r, end_c

    while curr_r != -1:
        path_list_r.append(curr_r)
        path_list_c.append(curr_c)
        if curr_r == start_row and curr_c == start_col:
            break
        next_r = parent_r[curr_r, curr_c]
        next_c = parent_c[curr_r, curr_c]
        curr_r, curr_c = next_r, next_c

    # Convert to arrays and reverse (drain -> sink becomes sink -> drain)
    path_r = np.array(path_list_r[::-1], dtype=np.int32)
    path_c = np.array(path_list_c[::-1], dtype=np.int32)

    return True, path_r, path_c


@jit(nopython=True, cache=True)
def _dijkstra_single_sink(
    dem: np.ndarray,
    start_row: int,
    start_col: int,
    outlets: np.ndarray,
    max_depth: float,
    max_length: int,
    path_r_out: np.ndarray,
    path_c_out: np.ndarray,
) -> int:
    """
    Single sink Dijkstra worker for parallel processing.

    Similar to _find_breach_path_dijkstra_jit but:
    1. Has no `resolved` array: only outlets terminate the search (for Phase 1 parallelism)
    2. Writes path to pre-allocated output arrays
    3. Returns path length (0 if not found)

    This allows calling from within prange loops.

    The same bounded queue is used: at most ``min(10000, rows * cols)``
    entries, with candidates dropped while the queue is full.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    start_row, start_col : int
        Sink cell coordinates
    outlets : np.ndarray (bool)
        Outlet mask
    max_depth : float
        Maximum amount any cell on the path may rise above the sink; also the
        maximum amount an accepted outlet may lie below the sink
    max_length : int
        Cells whose path length has reached this value are not expanded
    path_r_out, path_c_out : np.ndarray
        Output buffers, filled in place with the path in sink -> outlet
        order. At most ``len(path_r_out)`` cells are written.

    Returns
    -------
    int
        Number of path cells written, or 0 if no path was found
    """
    start_elev = dem[start_row, start_col]
    rows, cols = dem.shape

    # Use arrays for JIT compatibility
    visited = np.zeros((rows, cols), dtype=np.bool_)
    parent_r = np.full((rows, cols), -1, dtype=np.int32)
    parent_c = np.full((rows, cols), -1, dtype=np.int32)

    # Simple priority queue (parallel arrays, linear-scan minimum)
    max_queue_size = min(10000, rows * cols)
    queue_costs = np.full(max_queue_size, np.inf, dtype=np.float64)
    queue_lengths = np.zeros(max_queue_size, dtype=np.int32)
    queue_rows = np.zeros(max_queue_size, dtype=np.int32)
    queue_cols = np.zeros(max_queue_size, dtype=np.int32)
    queue_parent_r = np.zeros(max_queue_size, dtype=np.int32)
    queue_parent_c = np.zeros(max_queue_size, dtype=np.int32)

    # Seed with the sink; parent (-1, -1) marks the root
    queue_size = 1
    queue_costs[0] = 0.0
    queue_lengths[0] = 0
    queue_rows[0] = start_row
    queue_cols[0] = start_col
    queue_parent_r[0] = -1
    queue_parent_c[0] = -1

    end_r, end_c = -1, -1
    found = False

    while queue_size > 0:
        # Find minimum cost item
        min_idx = 0
        min_cost = queue_costs[0]
        for i in range(1, queue_size):
            if queue_costs[i] < min_cost:
                min_cost = queue_costs[i]
                min_idx = i

        # Pop item
        cost = queue_costs[min_idx]
        length = queue_lengths[min_idx]
        r = queue_rows[min_idx]
        c = queue_cols[min_idx]
        pr = queue_parent_r[min_idx]
        pc = queue_parent_c[min_idx]

        # Remove from queue (swap with last)
        queue_size -= 1
        if min_idx < queue_size:
            queue_costs[min_idx] = queue_costs[queue_size]
            queue_lengths[min_idx] = queue_lengths[queue_size]
            queue_rows[min_idx] = queue_rows[queue_size]
            queue_cols[min_idx] = queue_cols[queue_size]
            queue_parent_r[min_idx] = queue_parent_r[queue_size]
            queue_parent_c[min_idx] = queue_parent_c[queue_size]

        if visited[r, c]:
            continue

        visited[r, c] = True
        parent_r[r, c] = pr
        parent_c[r, c] = pc

        # Termination: reached outlet
        if outlets[r, c]:
            # Verify outlet is not deeper than max_depth below sink
            # (to prevent breaching sink by more than max_depth)
            sink_lowering = start_elev - dem[r, c]
            if sink_lowering <= max_depth:
                end_r, end_c = r, c
                found = True
                break

        # Length constraint
        if length >= max_length:
            continue

        # Explore neighbors
        for di in range(-1, 2):
            for dj in range(-1, 2):
                if di == 0 and dj == 0:
                    continue

                ni, nj = r + di, c + dj

                if not (0 <= ni < rows and 0 <= nj < cols):
                    continue

                if visited[ni, nj]:
                    continue

                # Height of this neighbour above the sink (0 if not above)
                breach_depth_here = max(0.0, dem[ni, nj] - start_elev)

                if breach_depth_here > max_depth:
                    continue

                new_cost = cost + breach_depth_here
                new_length = length + 1

                # Candidates are dropped when the queue is full
                if queue_size < max_queue_size:
                    queue_costs[queue_size] = new_cost
                    queue_lengths[queue_size] = new_length
                    queue_rows[queue_size] = ni
                    queue_cols[queue_size] = nj
                    queue_parent_r[queue_size] = r
                    queue_parent_c[queue_size] = c
                    queue_size += 1

    if not found:
        return 0  # No path found

    # Reconstruct path into output arrays (written outlet -> sink first)
    path_idx = 0
    curr_r, curr_c = end_r, end_c
    max_path_len = len(path_r_out)

    while curr_r != -1 and path_idx < max_path_len:
        path_r_out[path_idx] = curr_r
        path_c_out[path_idx] = curr_c
        path_idx += 1
        if curr_r == start_row and curr_c == start_col:
            break
        next_r = parent_r[curr_r, curr_c]
        next_c = parent_c[curr_r, curr_c]
        curr_r, curr_c = next_r, next_c

    # Reverse path in-place (outlet -> sink becomes sink -> outlet order)
    for i in range(path_idx // 2):
        j = path_idx - 1 - i
        path_r_out[i], path_r_out[j] = path_r_out[j], path_r_out[i]
        path_c_out[i], path_c_out[j] = path_c_out[j], path_c_out[i]

    return path_idx


@jit(nopython=True, parallel=True, cache=True)
def _breach_sinks_parallel_batch(
    dem: np.ndarray,
    sinks_r: np.ndarray,
    sinks_c: np.ndarray,
    outlets: np.ndarray,
    max_depth: float,
    max_length: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Find breach paths for a batch of sinks in parallel using prange.

    This is Phase 1 of two-phase parallel processing:
    - Uses outlets only (no resolved array) for termination
    - Each sink's Dijkstra is independent: it only reads ``dem`` and
      ``outlets`` and writes to its own row of the output arrays
    - Sinks should be spatially distant (checkerboard partitioning)

    This function only searches for paths; it does not modify ``dem``.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    sinks_r, sinks_c : np.ndarray
        Row and column coordinates of sinks to process
    outlets : np.ndarray (bool)
        Outlet mask
    max_depth : float
        Maximum breach depth per cell
    max_length : int
        Maximum path length

    Returns
    -------
    found : np.ndarray (bool)
        Whether path was found for each sink
    paths_r, paths_c : np.ndarray (n_sinks, max_length+1)
        Path coordinates for each sink (-1 for unused slots)
    path_lengths : np.ndarray (int)
        Actual path length for each sink
    """
    n_sinks = len(sinks_r)
    # A path of max_length steps visits max_length + 1 cells
    max_path_len = max_length + 1

    # Pre-allocate outputs
    found = np.zeros(n_sinks, dtype=np.bool_)
    paths_r = np.full((n_sinks, max_path_len), -1, dtype=np.int32)
    paths_c = np.full((n_sinks, max_path_len), -1, dtype=np.int32)
    path_lengths = np.zeros(n_sinks, dtype=np.int32)

    # Process each sink in parallel
    for i in prange(n_sinks):
        sink_r = sinks_r[i]
        sink_c = sinks_c[i]

        # Run Dijkstra for this sink; row i of the path arrays is its buffer
        path_len = _dijkstra_single_sink(
            dem, sink_r, sink_c, outlets,
            max_depth, max_length,
            paths_r[i], paths_c[i]
        )

        found[i] = path_len > 0
        path_lengths[i] = path_len

    return found, paths_r, paths_c, path_lengths


def _cluster_sinks_checkerboard(
    sinks: list,
    grid_size: int,
    dem_shape: Tuple[int, int],
) -> Tuple[list, list]:
    """
    Cluster sinks into two batches using checkerboard pattern.

    Divides the DEM into grid cells of size `grid_size`. Sinks in "black" cells
    (even row+col) go in batch 1, "white" cells (odd) in batch 2.

    Grid cells of the same colour never share an edge, so same-batch sinks in
    different grid cells that are aligned along a row or column are separated
    by a full grid cell of the other colour. The split does NOT separate
    sinks that share a grid cell, nor sinks in diagonally adjacent grid cells
    (those have the same colour and can touch at the shared corner).

    Parameters
    ----------
    sinks : list
        List of (row, col, elev, depth) tuples
    grid_size : int
        Size of grid cells (should be >= max_breach_length). Must be non-zero.
    dem_shape : tuple
        Shape of DEM (rows, cols). Currently not used by the implementation.

    Returns
    -------
    batch_black, batch_white : list
        Two lists of sinks for parallel processing; input order is preserved
        within each batch
    """
    batch_black = []  # Even grid cells
    batch_white = []  # Odd grid cells

    for sink in sinks:
        r, c, elev, depth = sink
        grid_r = r // grid_size
        grid_c = c // grid_size

        # Parity of the grid cell's (row + col) selects the colour
        if (grid_r + grid_c) % 2 == 0:
            batch_black.append(sink)
        else:
            batch_white.append(sink)

    return batch_black, batch_white


def _find_breach_path_dijkstra(
    dem: np.ndarray,
    start_row: int,
    start_col: int,
    outlets: np.ndarray,
    resolved: np.ndarray,
    max_depth: float,
    max_length: int,
) -> Optional[list]:
    """
    Find least-cost breach path from sink to draining cell using Dijkstra.

    Cost metric: Total elevation that must be removed along the path
    (sum of how far each path cell rises above the sink elevation).

    Termination conditions:
    1. Reached an outlet cell that lies no more than max_depth below the sink
    2. Reached a resolved cell with elevation <= start_elev
    3. No more cells to explore within constraints (breach failed)

    Dispatches to the numba kernel when numba is available; otherwise runs a
    heapq-based search. NOTE(review): the two paths are not guaranteed to
    return the same route - the kernel uses a size-capped queue and breaks
    cost ties by queue position, while the fallback is unbounded and breaks
    ties by (length, row, col).

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    start_row, start_col : int
        Sink cell coordinates
    outlets : np.ndarray (bool)
        Outlet mask
    resolved : np.ndarray (bool)
        Tracks which cells already have drainage paths
    max_depth : float
        Maximum elevation drop allowed at any single cell (meters)
    max_length : int
        Maximum path length (cells)

    Returns
    -------
    list of tuples or None
        Breach path [(row, col), ...] from sink to drain, or None if failed
    """
    # Use JIT-compiled version if available
    if NUMBA_AVAILABLE:
        found, path_r, path_c = _find_breach_path_dijkstra_jit(
            dem, start_row, start_col, outlets, resolved, max_depth, max_length
        )
        if found:
            return [(int(r), int(c)) for r, c in zip(path_r, path_c)]
        else:
            return None

    # Fallback to pure Python with heapq
    import heapq

    start_elev = dem[start_row, start_col]
    rows, cols = dem.shape

    # Priority queue: (cost, length, row, col, parent_row, parent_col)
    # Only the root entry carries None parents; it is the only entry with
    # length 0, so tuple comparison never reaches the None fields.
    pq = [(0, 0, start_row, start_col, None, None)]
    visited = {}  # {(row, col): cost}
    parent_map = {}  # {(row, col): (parent_row, parent_col)}

    while pq:
        cost, length, r, c, pr, pc = heapq.heappop(pq)

        # Skip if already visited with lower cost
        if (r, c) in visited:
            continue

        visited[(r, c)] = cost
        parent_map[(r, c)] = (pr, pc)

        # Check termination: reached outlet
        if outlets[r, c]:
            # Verify outlet is not deeper than max_depth below sink
            # (to prevent breaching sink by more than max_depth)
            sink_lowering = start_elev - dem[r, c]
            if sink_lowering <= max_depth:
                return _reconstruct_path(parent_map, start_row, start_col, r, c)

        # Check termination: reached resolved cell at or below start elevation
        if resolved[r, c] and dem[r, c] <= start_elev:
            return _reconstruct_path(parent_map, start_row, start_col, r, c)

        # Check length constraint
        if length >= max_length:
            continue  # Don't expand further from this cell

        # Explore neighbors
        for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
            ni, nj = r + di, c + dj

            # Bounds check
            if not (0 <= ni < rows and 0 <= nj < cols):
                continue

            if (ni, nj) in visited:
                continue

            # Cost to breach through this neighbor
            # Must carve down to start_elev (or below)
            breach_depth_here = max(0, dem[ni, nj] - start_elev)

            # Check depth constraint
            if breach_depth_here > max_depth:
                continue  # Would exceed max breach depth at this cell

            new_cost = cost + breach_depth_here
            new_length = length + 1

            heapq.heappush(pq, (new_cost, new_length, ni, nj, r, c))

    # No viable path found within constraints
    return None


def _apply_breach(dem: np.ndarray, path: list, epsilon: float) -> None:
    """
    Carve breach path into DEM with monotonic gradient.

    Works backward from drain point to sink. The drain point keeps its
    elevation; a cell k steps upstream of it is capped at
    ``drain_elevation + k * epsilon``. Only lowers cells, never raises them,
    so cells already below their cap are left untouched.

    Paths with fewer than two cells are ignored.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model (modified in-place)
    path : list of tuples
        Breach path [(row, col), ...] from sink to drain
    epsilon : float
        Small gradient for breach paths (meters per cell)
    """
    n = len(path)
    if n < 2:
        return

    # Start from drain point (end of path)
    drain_r, drain_c = path[-1]
    base_elev = dem[drain_r, drain_c]

    # Work backward toward sink
    for i in range(n - 2, -1, -1):
        r, c = path[i]
        # Cap for this cell: epsilon per step above the drain elevation,
        # where (n - 1 - i) is the number of steps to the drain point
        required_elev = base_elev + epsilon * (n - 1 - i)

        # Only lower cells, never raise them
        if dem[r, c] > required_elev:
            dem[r, c] = required_elev
def _numba_thread_count() -> Optional[int]:
    """Return numba's current thread count, or None if it cannot be queried."""
    try:
        from numba import get_num_threads

        return get_num_threads()
    except Exception:
        # Purely diagnostic: a numba build without get_num_threads, or a
        # misconfigured threading layer, must not abort breaching. Unlike a
        # bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        return None


def breach_depressions_constrained(
    dem: np.ndarray,
    outlets: np.ndarray,
    max_breach_depth: float = 50.0,
    max_breach_length: int = 100,
    epsilon: float = 1e-4,
    nodata_mask: Optional[np.ndarray] = None,
    parallel_method: str = "checkerboard",
) -> np.ndarray:
    """
    Remove depressions via constrained breaching (Stage 2a of flow-spec.md).

    Uses Lindsay (2016) constrained least-cost breaching algorithm.
    Implements two-pass approach:
    1. Identify all sinks (cells with no downslope neighbor)
    2. For each sink, attempt Dijkstra breach within constraints

    Parameters
    ----------
    dem : np.ndarray
        Input digital elevation model
    outlets : np.ndarray (bool)
        Outlet mask from identify_outlets()
    max_breach_depth : float, default 50.0
        Maximum elevation drop allowed at any single cell (meters)
    max_breach_length : int, default 100
        Maximum breach path length (cells)
    epsilon : float, default 1e-4
        Small gradient for breach paths (meters per cell)
    nodata_mask : np.ndarray (bool), optional
        Cells to exclude from breaching
    parallel_method : {"checkerboard", "iterative"}, default "checkerboard"
        Parallelization strategy for breach path finding:
        - "checkerboard": Two-phase parallel processing using checkerboard
          partitioning. Fast but can only terminate at outlets (misses
          chaining opportunities). Only used when numba is available, more
          than 100 significant sinks exist and max_breach_length is small
          relative to the DEM; otherwise the serial path runs.
        - "iterative": Iterative refinement that runs multiple rounds,
          allowing each round's breaches to become termination targets for
          the next round. Finds more breaches via chaining but may be slower.
        Any other value (or numba being unavailable) selects the serial path.

    Returns
    -------
    np.ndarray
        DEM (float32) with depressions breached where possible

    Notes
    -----
    This implements Stage 2a of flow-spec.md (lines 115-289).

    **Outlet Virtual Elevation:**
    Outlets (identified in Stage 1) act as guaranteed sinks in the breach
    algorithm. Breach paths MUST terminate at an outlet or at a cell that can
    drain naturally. Outlets are marked in the `resolved` set before breaching
    starts, ensuring they are always considered valid drainage targets.

    **Breach Path Termination:**
    Breach paths from sinks can terminate at:
    1. An outlet cell (guaranteed sink)
    2. A cell lower than the sink (natural downslope)
    3. A cell that has already been resolved (has a path to an outlet)

    This ensures the final flow network has no artificial endorheic basins.

    **Residual Sinks:**
    Sinks that cannot be breached within constraints (max_breach_depth,
    max_breach_length) are left for Stage 2b (priority-flood fill) to handle.
    These are typically:
    - Large legitimate basins (lakes, endorheic systems)
    - Depressions too deep/long to breach efficiently

    Sinks whose fill depth is at most 1.0 m are skipped here as well and left
    to Stage 2b.

    References
    ----------
    Lindsay, J.B. (2016). Efficient hybrid breaching-filling sink removal
    methods for flow path enforcement in digital elevation models.
    Hydrological Processes, 30, 846–857.

    Spec Reference: flow-spec.md lines 42-108 (outlet identification),
    115-289 (constrained breach)
    """
    breached = dem.copy().astype(np.float64)
    rows, cols = dem.shape

    if nodata_mask is None:
        nodata_mask = np.zeros_like(dem, dtype=bool)

    # Track which cells have drainage paths
    resolved = outlets.copy()

    print(" Stage 2a: Identifying sinks...")

    # Show threading info (best effort)
    if NUMBA_AVAILABLE:
        num_threads = _numba_thread_count()
        if num_threads is not None:
            print(f" Numba: {num_threads} threads available for parallel operations")

    sinks = _identify_sinks(breached, outlets, nodata_mask)
    print(f" Found {len(sinks):,} sink cells")

    if len(sinks) == 0:
        print(" No sinks to breach")
        return breached.astype(np.float32)

    # Precompute fill depths to identify shallow sinks (optimization)
    print(" Computing sink depths...")
    from skimage.morphology import reconstruction

    filled_preview = breached.copy()
    seed = filled_preview.copy()
    seed[1:-1, 1:-1] = filled_preview.max() + 1000
    # nodata_mask is always an array at this point (defaulted above)
    seed[nodata_mask] = filled_preview[nodata_mask]
    filled_preview = reconstruction(seed, filled_preview, method='erosion')

    # Filter to significant sinks (> 1.0m depth)
    # Shallow sinks will be handled by Stage 2b (priority-flood fill)
    depth_threshold = 1.0  # meters
    significant_sinks = []
    for sink_r, sink_c, sink_elev in sinks:
        depth = filled_preview[sink_r, sink_c] - sink_elev
        if depth > depth_threshold:
            significant_sinks.append((sink_r, sink_c, sink_elev, depth))

    shallow_skipped = len(sinks) - len(significant_sinks)
    print(f" Skipping {shallow_skipped:,} shallow sinks (<{depth_threshold}m deep)")
    print(f" Processing {len(significant_sinks):,} significant sinks (>{depth_threshold}m deep)")

    print(f" Stage 2a: Attempting constrained breaching (max_depth={max_breach_depth}m, max_length={max_breach_length} cells)...")

    total_sinks = len(significant_sinks)
    breached_count = 0
    failed_count = 0
    already_resolved_count = 0

    # Estimate if parallel is useful: only helps when sinks are near outlets
    # With large max_breach_length and large DEM, most sinks are interior
    # and parallel (outlets-only) will find ~0 breaches - wasted effort
    avg_dim = (rows + cols) / 2
    parallel_useful = max_breach_length < avg_dim / 4  # Sinks likely near outlets

    if not parallel_useful and NUMBA_AVAILABLE and parallel_method != "iterative":
        print(f" Note: max_breach_length ({max_breach_length}) is large relative to DEM ({rows}x{cols})")
        print(f" Using serial JIT (better for interior sinks that chain together)")

    # Iterative refinement parallel processing
    if parallel_method == "iterative" and NUMBA_AVAILABLE and total_sinks > 0:
        # Iterative refinement: run parallel batches repeatedly, updating terminals each round
        # This allows chaining: sinks resolved in round N become terminals for round N+1
        num_threads = _numba_thread_count()
        if num_threads is not None:
            print(f" Using iterative refinement parallel breaching ({num_threads} CPU cores)")
        else:
            print(f" Using iterative refinement parallel breaching")

        grid_size = 2 * max_breach_length
        remaining_sinks = list(significant_sinks)
        iteration = 0
        max_iterations = 100  # Safety limit

        while remaining_sinks and iteration < max_iterations:
            iteration += 1
            iteration_breached = 0

            # Current terminals = outlets + all resolved cells from previous iterations
            terminals = outlets | resolved

            # Cluster remaining sinks using checkerboard pattern
            batch_black, batch_white = _cluster_sinks_checkerboard(
                remaining_sinks, grid_size, (rows, cols)
            )

            print(f" Iteration {iteration}: {len(remaining_sinks):,} sinks remaining...")

            newly_resolved_sinks = []

            for batch in (batch_black, batch_white):
                if not batch:
                    continue

                # Prepare arrays for batch
                batch_r = np.array([s[0] for s in batch], dtype=np.int32)
                batch_c = np.array([s[1] for s in batch], dtype=np.int32)

                # Run parallel Dijkstra with current terminals
                found, paths_r, paths_c, path_lengths = _breach_sinks_parallel_batch(
                    breached, batch_r, batch_c, terminals,
                    max_breach_depth, max_breach_length
                )

                # Apply successful breaches
                for i in range(len(batch)):
                    sink_r, sink_c = batch_r[i], batch_c[i]

                    if resolved[sink_r, sink_c]:
                        already_resolved_count += 1
                        newly_resolved_sinks.append((sink_r, sink_c))
                        continue

                    if found[i]:
                        path_len = path_lengths[i]
                        path = [(int(paths_r[i, j]), int(paths_c[i, j])) for j in range(path_len)]
                        _apply_breach(breached, path, epsilon)
                        for r, c in path:
                            resolved[r, c] = True
                        breached_count += 1
                        iteration_breached += 1
                        newly_resolved_sinks.append((sink_r, sink_c))

            # Remove resolved sinks from remaining list
            resolved_set = set(newly_resolved_sinks)
            remaining_sinks = [
                s for s in remaining_sinks if (s[0], s[1]) not in resolved_set
            ]

            print(f" Breached {iteration_breached:,} sinks this iteration, {len(remaining_sinks):,} remaining")

            # If no progress, stop iterating
            if iteration_breached == 0:
                print(f" No new breaches in iteration {iteration}, stopping refinement")
                break

        # Count remaining as failed
        failed_count = len(remaining_sinks)
        if failed_count > 0:
            print(f" {failed_count:,} sinks could not be breached (will be filled in Stage 2b)")

    # Use two-phase parallel processing with numba prange (checkerboard method)
    elif parallel_method == "checkerboard" and NUMBA_AVAILABLE and total_sinks > 100 and parallel_useful:
        # Two-phase parallel processing: uses numba prange for true multi-core parallelism
        num_threads = _numba_thread_count()
        if num_threads is not None:
            print(f" Using two-phase parallel breaching ({num_threads} CPU cores)")
        else:
            print(f" Using two-phase parallel breaching")

        # Cluster sinks using checkerboard pattern (grid_size = 2 * max_breach_length)
        # Sinks in same batch are guaranteed to be far enough apart that paths won't overlap
        grid_size = 2 * max_breach_length
        batch_black, batch_white = _cluster_sinks_checkerboard(
            significant_sinks, grid_size, (rows, cols)
        )
        print(f" Checkerboard clustering: {len(batch_black):,} black cells, {len(batch_white):,} white cells")

        # Sub-batch size for progress reporting (process in chunks)
        # Use 1% intervals like serial version (~100 updates per phase)
        sub_batch_size = max(10, total_sinks // 100)

        def process_batch_with_progress(batch, phase_name, start_breached, start_failed, start_resolved):
            """Process a batch in sub-batches with progress reporting."""
            # `breached` and `resolved` are only mutated in place, never
            # rebound, so no `nonlocal` declaration is needed.
            batch_breached = start_breached
            batch_failed = start_failed
            batch_resolved = start_resolved
            n = len(batch)

            for chunk_start in range(0, n, sub_batch_size):
                chunk_end = min(chunk_start + sub_batch_size, n)
                chunk = batch[chunk_start:chunk_end]

                # Prepare arrays for this chunk
                chunk_r = np.array([s[0] for s in chunk], dtype=np.int32)
                chunk_c = np.array([s[1] for s in chunk], dtype=np.int32)

                # Run parallel Dijkstra on chunk
                found, paths_r, paths_c, path_lengths = _breach_sinks_parallel_batch(
                    breached, chunk_r, chunk_c, outlets,
                    max_breach_depth, max_breach_length
                )

                # Apply successful breaches
                for i in range(len(chunk)):
                    sink_r, sink_c = chunk_r[i], chunk_c[i]

                    if resolved[sink_r, sink_c]:
                        batch_resolved += 1
                        continue

                    if found[i]:
                        path_len = path_lengths[i]
                        path = [(int(paths_r[i, j]), int(paths_c[i, j])) for j in range(path_len)]
                        _apply_breach(breached, path, epsilon)
                        for r, c in path:
                            resolved[r, c] = True
                        batch_breached += 1
                    else:
                        batch_failed += 1

                # Progress report
                processed = chunk_end
                percent = 100.0 * processed / n
                print(f" {phase_name}: {percent:5.1f}% ({processed:,}/{n:,}) - {batch_breached:,} breached, {batch_failed:,} failed")

            return batch_breached, batch_failed, batch_resolved

        # Phase 1: Process "black" cells in parallel sub-batches
        if len(batch_black) > 0:
            print(f" Phase 1: Processing {len(batch_black):,} sinks...")
            breached_count, failed_count, already_resolved_count = process_batch_with_progress(
                batch_black, "Phase 1", breached_count, failed_count, already_resolved_count
            )

        # Phase 2: Process "white" cells in parallel sub-batches
        if len(batch_white) > 0:
            print(f" Phase 2: Processing {len(batch_white):,} sinks...")
            breached_count, failed_count, already_resolved_count = process_batch_with_progress(
                batch_white, "Phase 2", breached_count, failed_count, already_resolved_count
            )

        # Note: Failed sinks are left for Stage 2b (priority-flood) to handle
        # Priority-flood is O(n log n) and much faster than re-running Dijkstra

    else:
        # Serial processing for small sink counts or when numba unavailable
        if NUMBA_AVAILABLE:
            print(f" Using serial JIT-compiled Dijkstra (small sink count)")
        else:
            print(f" Using serial pure-Python Dijkstra (numba unavailable)")

        # Progress reporting
        if total_sinks > 0:
            progress_interval = max(1, total_sinks // 100)  # Report every 1%
            print(f" Progress (showing every 1%):")

        for idx, (sink_r, sink_c, sink_elev, depth) in enumerate(significant_sinks):
            # Progress reporting
            if total_sinks > 0 and idx % progress_interval == 0:
                percent = 100.0 * idx / total_sinks
                print(f" {percent:5.1f}% ({idx:,} / {total_sinks:,} sinks, {breached_count:,} breached, {failed_count:,} failed)")

            # Skip if already resolved by previous breach
            if resolved[sink_r, sink_c]:
                already_resolved_count += 1
                continue

            # Attempt to find breach path
            path = _find_breach_path_dijkstra(
                breached, sink_r, sink_c, outlets, resolved,
                max_breach_depth, max_breach_length
            )

            if path is not None:
                # Apply breach
                _apply_breach(breached, path, epsilon)

                # Mark all cells in path as resolved
                for r, c in path:
                    resolved[r, c] = True
                breached_count += 1
            else:
                # Could not breach within constraints
                failed_count += 1

        if total_sinks > 0:
            print(f" 100.0% ({total_sinks:,} / {total_sinks:,} sinks, {breached_count:,} breached, {failed_count:,} failed)")

    print(f" Results: {breached_count:,} breached, {already_resolved_count:,} already resolved, {failed_count:,} failed")
    print(f" Total sinks handled: {shallow_skipped:,} shallow (skipped) + {breached_count:,} breached + {already_resolved_count:,} resolved = {shallow_skipped + breached_count + already_resolved_count:,}")
    print(f" Remaining for Stage 2b: {failed_count + shallow_skipped:,} sinks")

    return breached.astype(np.float32)
def priority_flood_fill_epsilon(
    dem: np.ndarray,
    outlets: np.ndarray,
    epsilon: float = 1e-4,
    nodata_mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Fill residual depressions with epsilon gradient (Stage 2b of flow-spec.md).

    Uses Barnes et al. (2014) priority-flood algorithm with epsilon gradient
    applied DURING fill (not after). This creates flow-directing gradients
    as flats form, ensuring proper drainage without post-processing.

    Parameters
    ----------
    dem : np.ndarray
        Input DEM (already breached by Stage 2a)
    outlets : np.ndarray (bool)
        Outlet mask from identify_outlets()
    epsilon : float, default 1e-4
        Minimum elevation increment per cell (meters per cell).
        Creates gradients in filled areas to ensure drainage.
        ``None`` is accepted and treated as the default (1e-4).
    nodata_mask : np.ndarray (bool), optional
        Cells to exclude from filling

    Returns
    -------
    np.ndarray
        DEM (float32) with residual depressions filled

    Notes
    -----
    This implements Stage 2b of flow-spec.md (lines 291-328).

    **Outlet Virtual Elevation:**
    Outlets are seeded into the priority queue with their original elevations.
    As the fill progresses outward from outlets, depression cells are raised with
    epsilon increments, creating micro-gradients that naturally point back toward
    the outlets. This ensures filled areas drain properly in Stage 3.

    **Epsilon Application (Flat Resolution Strategy):**
    Epsilon is applied DURING fill as cells are raised. This is the most robust
    flat resolution approach for several reasons:

    1. **Implicit Gradient Creation**: As the fill progresses from outlets toward
       interior sinks, cells filled later get progressively higher elevations (by
       epsilon increments). This creates natural gradients pointing back toward
       outlets without explicit post-processing.

    2. **Alternatives Mentioned in Literature:**
       - **Garbrecht & Martz (1997)**: Dual-gradient method assigns flow to flats
         based on proximity to higher terrain. More complex to implement.
       - **Postprocessing**: Some systems fill without gradients, then resolve
         flats afterward. Can be ambiguous for complex flat structures.

    3. **Recommendation**: The epsilon-during-fill approach (used here) is:
       - Simpler to understand and implement
       - Produces consistent drainage patterns
       - Guaranteed to resolve all flats into valid flow networks
       - No need for separate flat-resolution algorithm

    **Epsilon Selection:**
    For epsilon tuning, see flow-spec.md lines 484-489:
    - **Auto-calculated** (default): epsilon = 1e-5 * cell_resolution
      - For 10m DEM: epsilon ≈ 1e-4 m/cell (0.1 mm per cell)
      - For 1m DEM: epsilon ≈ 1e-5 m/cell (0.01 mm per cell)
    - **For integer DEMs**: Use epsilon = 1 in native elevation units
      - If elevation in millimeters: epsilon = 1 mm/cell
      - If elevation in centimeters: epsilon = 1 cm/cell
    - **Too small epsilon**: Floating-point accumulation errors may create ties
      or reversals in flat areas. Rule of thumb: epsilon should exceed DEM
      measurement precision.
    - **Too large epsilon**: Creates obvious artificial "stair-stepping" in
      filled areas. Visual inspection usually reveals values > 0.1m.

    **Seed Cells:**
    The priority queue is seeded with:
    1. All identified outlets (from Stage 1) - guaranteed sinks
    2. All cells adjacent to NoData (domain boundary) - implicit outlets

    This ensures water drains both to identified outlets AND off the map edge.
    Cells on the array edge are seeded only if they are flagged as outlets or
    touch a NoData cell; out-of-bounds neighbors are ignored.

    **Flat Area Behavior:**
    After filling, flat areas will have cells at different elevations (differing
    by epsilon). During Stage 3 (flow direction), steepest descent will cause
    water on flats to flow toward outlet-adjacent cells, creating coherent
    drainage patterns. This is more realistic than assuming multiple flow
    directions on wide flats.

    References
    ----------
    Barnes, R., Lehman, C., & Mulla, D. (2014). Priority-flood: An optimal
    depression-filling and watershed-labeling algorithm for digital elevation
    models. Computers & Geosciences, 62, 117–127.

    Garbrecht, J., & Martz, L.W. (1997). The assignment of drainage direction
    over flat surfaces in raster digital elevation models. Journal of Hydrology,
    193, 204–213.

    Spec Reference: flow-spec.md lines 291-328 (Stage 2b)
                    flow-spec.md lines 484-494 (flat resolution discussion)
    """
    import heapq
    from scipy.ndimage import binary_dilation

    # Callers may forward an unset Optional epsilon; fall back to the
    # documented default instead of failing on `elev + None` below.
    if epsilon is None:
        epsilon = 1e-4

    rows, cols = dem.shape
    filled = dem.copy().astype(np.float64)

    if nodata_mask is None:
        nodata_mask = np.zeros_like(dem, dtype=bool)
    # Coerce to bool so `~` is a logical NOT even for 0/1 integer masks.
    nodata = np.asarray(nodata_mask, dtype=bool)
    outlet_cells = np.asarray(outlets, dtype=bool)

    # Seeds: valid outlets plus every valid cell with a NoData 8-neighbor.
    # binary_dilation pads with False, so out-of-bounds neighbors never count.
    eight_connected = np.ones((3, 3), dtype=bool)
    touches_nodata = binary_dilation(nodata, structure=eight_connected)
    in_queue = (outlet_cells | touches_nodata) & ~nodata

    # Priority queue entries: (elevation, row, col). Entries are distinct
    # tuples, so pop order does not depend on insertion order.
    seed_r, seed_c = np.nonzero(in_queue)
    pq = list(zip(filled[seed_r, seed_c].tolist(), seed_r.tolist(), seed_c.tolist()))
    heapq.heapify(pq)

    neighbor_offsets = (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1),
    )

    # Process cells in elevation order (priority-flood)
    filled_count = 0
    while pq:
        elev, r, c = heapq.heappop(pq)
        min_neighbor_elev = elev + epsilon

        # Check all 8 neighbors
        for di, dj in neighbor_offsets:
            ni, nj = r + di, c + dj

            # Bounds and status checks
            if not (0 <= ni < rows and 0 <= nj < cols):
                continue
            if in_queue[ni, nj] or nodata[ni, nj]:
                continue

            # If neighbor is in depression (below current + epsilon), raise it
            # This is the KEY DIFFERENCE: epsilon applied DURING fill
            if filled[ni, nj] < min_neighbor_elev:
                filled[ni, nj] = min_neighbor_elev
                filled_count += 1

            # Add neighbor to queue
            heapq.heappush(pq, (float(filled[ni, nj]), ni, nj))
            in_queue[ni, nj] = True

    if filled_count > 0:
        print(f" Priority-flood raised {filled_count:,} cells to resolve depressions")

    return filled.astype(np.float32)
def condition_dem_spec(
    dem: np.ndarray,
    nodata_mask: Optional[np.ndarray] = None,
    coastal_elev_threshold: float = 10.0,
    edge_mode: Literal["all", "local_minima", "outward_slope", "none"] = "all",
    max_breach_depth: float = 50.0,
    max_breach_length: int = 100,
    epsilon: Optional[float] = None,
    masked_basin_outlets: Optional[np.ndarray] = None,
    parallel_method: str = "checkerboard",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Condition DEM using spec-compliant 4-stage pipeline.

    Orchestrates the complete flow-spec.md pipeline:
    1. Stage 1: Identify outlets (coastal, edge, masked basins)
    2. Stage 2a: Constrained breaching (Dijkstra least-cost)
    3. Stage 2b: Priority-flood fill residuals (with epsilon)
    4. (External) D8 flow direction → see compute_flow_direction()
    5. (External) Flow accumulation → see compute_drainage_area()

    Parameters
    ----------
    dem : np.ndarray
        Input digital elevation model
    nodata_mask : np.ndarray (bool), optional
        Mask of ocean/off-grid cells (True = NoData)
    coastal_elev_threshold : float, default 10.0
        Maximum elevation for coastal outlets (meters above sea level)
    edge_mode : {"all", "local_minima", "outward_slope", "none"}, default "all"
        Boundary outlet strategy
    max_breach_depth : float, default 50.0
        Maximum elevation drop at any single cell during breaching (meters).
        Values <= 0 disable Stage 2a.
    max_breach_length : int, default 100
        Maximum breach path length (cells). Values <= 0 disable Stage 2a.
    epsilon : float, optional
        Minimum elevation increment per cell in filled areas (meters).
        When omitted (None), 1e-4 is used.
    masked_basin_outlets : np.ndarray (bool), optional
        User-supplied outlet locations for known lakes/basins
    parallel_method : str, default "checkerboard"
        Parallelization strategy forwarded to breach_depressions_constrained
        ("checkerboard" or "iterative").

    Returns
    -------
    conditioned_dem : np.ndarray
        Breached and filled DEM
    outlets : np.ndarray (bool)
        Outlet mask (useful for diagnostics and visualization)
    breached_dem : np.ndarray
        DEM after breaching but before filling (for fill depth calculation)

    Notes
    -----
    This implements the complete spec-compliant pipeline from flow-spec.md.

    **Implementation vs. Spec:**
    The implementation uses a two-pass breach approach rather than the
    priority-flood breach sketch in the spec (lines 125-172). Both approaches
    are equivalent and spec-compliant; two-pass is chosen for clarity:

    1. Identify all sinks (cells with no downslope neighbor)
    2. For each sink, attempt least-cost Dijkstra path to an outlet/drain-point
    3. Apply monotonic gradient along successful paths

    The priority-flood approach integrates breach discovery into the same
    priority queue as flow processing, which is more complex but equivalent.

    **Key Advantages Over Legacy Approach:**
    - Explicit outlet classification prevents boundary artifacts
    - Selective breaching (max_depth, max_length constraints) preserves
      legitimate basins and endorheic systems
    - Epsilon gradient applied during fill (not after), eliminating need for
      post-hoc flat resolution
    - No workarounds needed (_fix_coastal_flow, fill_small_sinks)
    - Deterministic and reproducible (no randomness or ties)

    **Critical Correctness Requirements:**
    After conditioning, the DEM must satisfy:
    1. Every land cell (not NoData) has a path to an outlet
    2. Every land cell has at least one downslope neighbor (guaranteed by fill)
    3. No cycles in flow directions (validated by compute_drainage_area)
    4. Outlets are lowest points in their regions (set explicitly in Stage 1)

    If cycles are detected in compute_drainage_area, check:
    - Outlets are correctly identified (Stage 1)
    - Breach parameters (max_depth, max_length) aren't too restrictive
    - Epsilon isn't too large (causing reversals on filled areas)

    **Parameter Tuning:**
    - **max_breach_depth**: Controls minimum basin depth preserved
      - Smaller values (5-20m) preserve more detailed basins
      - Larger values (50-100m) allow breaching of deeper basins
      - Default 50m is suitable for most 30m DEMs

    - **max_breach_length**: Controls maximum breach path extent
      - Smaller values (10-30 cells) for detailed hydrology
      - Larger values (100-300 cells) for regional studies
      - Default 100 cells suitable for outlet-finding on large DEMs

    - **epsilon**: Controls micro-gradient in filled areas
      - The spec suggests 1e-5 × cell_resolution; this function has no cell
        resolution input, so it defaults to 1e-4 (the 10 m DEM value)
      - Usually no manual tuning needed
      - If flats look "stair-stepped", epsilon is too large

    **Performance Notes:**
    - Stage 1 (outlets): O(n) with 8-neighbor checks
    - Stage 2a (breach): O(n log n) per sink via Dijkstra + priority queue
    - Stage 2b (fill): O(n log n) via priority-flood
    - Overall: O(n log n) where n = number of cells
    - Typical performance: 100M cells in ~30s (C/Rust), ~5m (Python/numba)

    Examples
    --------
    Illustrative only; the exact outlet mask depends on identify_outlets and
    the chosen edge_mode.

    >>> dem = np.array([[5, 5, 5],
    ...                 [5, 3, 5],
    ...                 [5, 5, 5]], dtype=np.float32)
    >>> nodata = np.zeros((3, 3), dtype=bool)
    >>> conditioned, outlets, breached = condition_dem_spec(dem, nodata)
    >>> # Outlets: [[False, False, False],
    >>> #           [False, True, False],
    >>> #           [False, False, False]]
    >>> # Conditioned: [[5, 5, 5],
    >>> #               [5, 5, 5],
    >>> #               [5, 5, 5]] (pit filled to neighbor level)

    >>> # Now use with Stage 3 and 4:
    >>> flow_dir = compute_flow_direction(conditioned)
    >>> drainage_area = compute_drainage_area(flow_dir)

    See Also
    --------
    identify_outlets : Stage 1
    breach_depressions_constrained : Stage 2a
    priority_flood_fill_epsilon : Stage 2b
    compute_flow_direction : Stage 3
    compute_drainage_area : Stage 4
    """
    if nodata_mask is None:
        nodata_mask = np.zeros_like(dem, dtype=bool)

    # Both stages use epsilon arithmetically; passing the unset default (None)
    # straight through would raise TypeError, so resolve it once here.
    if epsilon is None:
        epsilon = 1e-4

    print(" Stage 1: Identifying outlets...")
    outlets = identify_outlets(
        dem, nodata_mask, coastal_elev_threshold, edge_mode, masked_basin_outlets
    )
    num_outlets = np.sum(outlets)
    print(f" Found {num_outlets:,} outlet cells")

    # Skip breaching if disabled (max_breach_depth <= 0 or max_breach_length <= 0)
    if max_breach_depth <= 0 or max_breach_length <= 0:
        print(" Stage 2a: Breaching SKIPPED (disabled via parameters)")
        breached = dem.copy()
    else:
        print(f" Stage 2a: Constrained breaching (max_depth={max_breach_depth}m, max_length={max_breach_length} cells)...")
        breached = breach_depressions_constrained(
            dem, outlets, max_breach_depth, max_breach_length, epsilon, nodata_mask,
            parallel_method=parallel_method
        )

    print(" Stage 2b: Priority-flood fill residuals...")
    filled = priority_flood_fill_epsilon(
        breached, outlets, epsilon, nodata_mask
    )

    # Ensure masked cells maintain original elevation
    filled[nodata_mask] = dem[nodata_mask]

    # Ensure we never lowered elevations below what breaching produced
    # (breaching can lower elevations to create flow paths)
    filled = np.maximum(filled, breached)

    print(" DEM conditioning complete (spec-compliant pipeline)")

    return filled, outlets, breached
def detect_ocean_mask(
    dem: np.ndarray,
    threshold: float = 0.0,
    border_only: bool = True
) -> np.ndarray:
    """
    Flag ocean / open-water cells in a DEM.

    Cells with elevation at or below ``threshold`` are water candidates.
    With ``border_only`` set, only candidate regions (8-connected) that
    touch the array border are kept, on the assumption that those are
    ocean or large water bodies; enclosed low areas are dropped.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    threshold : float, default 0.0
        Elevation threshold (meters). Cells <= threshold are candidates.
    border_only : bool, default True
        If True, return only border-connected low-elevation regions (ocean).
        If False, return every cell at or below the threshold
        (includes inland lakes).

    Returns
    -------
    np.ndarray (bool)
        Boolean mask where True = ocean/water

    Examples
    --------
    >>> dem = np.array([[0, 0, 5], [0, 1, 6], [5, 6, 7]])
    >>> ocean = detect_ocean_mask(dem, threshold=0.0, border_only=True)
    >>> ocean
    array([[ True,  True, False],
           [ True, False, False],
           [False, False, False]])
    """
    from scipy.ndimage import label

    candidates = dem <= threshold

    # Caller wants every low cell, connected to the border or not.
    if not border_only:
        return candidates

    # Nothing at or below the threshold means there is no ocean.
    if not np.any(candidates):
        return np.zeros_like(dem, dtype=bool)

    # Group candidates into 8-connected regions.
    eight_connected = np.ones((3, 3), dtype=bool)
    region_ids, _ = label(candidates, structure=eight_connected)

    # Collect the region ids seen along the four edges of the array.
    edge_ids = np.unique(
        np.concatenate(
            [region_ids[0, :], region_ids[-1, :], region_ids[:, 0], region_ids[:, -1]]
        )
    )
    edge_ids = edge_ids[edge_ids != 0]  # 0 is the background label

    # Keep every cell belonging to a border-touching region.
    return np.isin(region_ids, edge_ids)
def detect_endorheic_basins(
    dem: np.ndarray,
    min_size: int = 10,
    exclude_mask: np.ndarray | None = None,
    min_depth: float = 0.5,
) -> tuple[np.ndarray, dict]:
    """
    Locate endorheic (closed) basins in a DEM.

    A depression-filled copy of the DEM is compared with the original;
    cells that would be raised by more than ``min_depth`` are grouped into
    8-connected regions, and regions of at least ``min_size`` cells are
    reported as basins. Intended for preserving large natural basins such
    as the Salton Sea or Death Valley.

    Parameters
    ----------
    dem : np.ndarray
        Digital elevation model
    min_size : int, default 10
        Minimum basin size in cells to be considered significant.
        ``None`` is treated as 10.
    exclude_mask : np.ndarray (bool), optional
        Areas to leave out of basin detection (e.g., ocean). Passed to the
        depression-filling helper as its ``mask`` argument.
    min_depth : float, default 0.5
        Minimum depression depth in meters to be considered a basin.
        Higher values = only preserve truly deep basins, fill shallower ones.

    Returns
    -------
    basin_mask : np.ndarray (bool)
        Boolean mask where True = part of a basin meeting ``min_size``
    basin_sizes : dict
        Maps every labeled depression id (not only the large ones) to its
        size in cells

    Examples
    --------
    >>> # Create closed basin surrounded by mountains
    >>> dem = np.array([[50, 50, 50],
    ...                 [50, 10, 50],
    ...                 [50, 50, 50]])
    >>> mask, sizes = detect_endorheic_basins(dem, min_size=1)
    >>> mask[1, 1]  # Center is basin
    True
    """
    from scipy.ndimage import label

    # How much each cell would rise under a full fill (ocean excluded
    # via exclude_mask).
    would_fill = _fill_depressions(dem, epsilon=0.0, mask=exclude_mask) - dem

    # Only cells deeper than min_depth count as depression cells.
    in_depression = would_fill > min_depth
    if not np.any(in_depression):
        return np.zeros_like(dem, dtype=bool), {}

    # Group depression cells into 8-connected regions.
    eight_connected = np.ones((3, 3), dtype=bool)
    region_ids, _ = label(in_depression, structure=eight_connected)

    # Cell count per region.
    ids, counts = np.unique(region_ids[region_ids > 0], return_counts=True)
    basin_sizes = dict(zip(ids.tolist(), counts.tolist()))

    # Select the regions that meet the size cutoff.
    size_cutoff = 10 if min_size is None else min_size
    keep_ids = [bid for bid, n_cells in basin_sizes.items() if n_cells >= size_cutoff]
    if not keep_ids:
        return np.zeros_like(dem, dtype=bool), basin_sizes

    return np.isin(region_ids, keep_ids), basin_sizes
[docs] def condition_dem( dem: np.ndarray, method: str = "fill", ocean_mask: np.ndarray | None = None, min_basin_size: int | None = None, max_fill_depth: float | None = None, min_basin_depth: float = 0.5, fill_small_sinks: int | None = None, ) -> np.ndarray: """ Condition DEM by filling pits and resolving depressions. Uses morphological reconstruction (priority flood algorithm) to properly fill depressions and pits. This is the standard algorithm used by most hydrological analysis tools (e.g., GRASS, WhiteboxTools, ArcGIS). Supports masking ocean areas and preserving large endorheic basins. Parameters ---------- dem : np.ndarray Input DEM method : str, default 'fill' Depression handling method ('fill' or 'breach') - 'fill': Complete depression filling (raises elevations) - 'breach': Minimal filling to preserve terrain (uses epsilon) ocean_mask : np.ndarray (bool), optional Boolean mask indicating ocean/water cells to exclude from conditioning. Masked cells maintain original elevation. min_basin_size : int, optional Minimum basin size (cells) to preserve. Basins >= this size are not filled (preserves large endorheic basins like Salton Sea). max_fill_depth : float, optional Maximum fill depth (meters). Depressions requiring fill > this depth are preserved (protects deep natural basins). min_basin_depth : float, default 0.5 Minimum depression depth (meters) to be considered a preservable basin. Higher values = only preserve truly deep basins, fill shallower ones. Increase this for noisy high-resolution DEMs. fill_small_sinks : int, optional Maximum sink size (cells) to fill. After main filling, any remaining local minima (sinks) with contributing area < this size are filled. This removes small artifacts that create fragmented drainage. Typical values: 10-100 cells. 
Returns ------- np.ndarray Conditioned DEM with depressions resolved Examples -------- >>> # Basic usage >>> conditioned = condition_dem(dem) >>> # Mask ocean >>> ocean = detect_ocean_mask(dem, threshold=0.0) >>> conditioned = condition_dem(dem, ocean_mask=ocean) >>> # Preserve large basins >>> conditioned = condition_dem(dem, min_basin_size=10000) >>> # Fill small sinks (< 50 cells) to reduce fragmentation >>> conditioned = condition_dem(dem, fill_small_sinks=50) """ # ========== BASIN PRESERVATION LOGIC ========== # Large endorheic basins (e.g., Salton Sea, Death Valley) are preserved # using TWO independent mechanisms that can work together: # # 1. Size-based preservation (min_basin_size): # - Detects closed depressions larger than min_basin_size cells # - Preserves these naturally-occurring basins by excluding from filling # - Example: min_basin_size=10000 preserves basins >= 10,000 cells # # 2. Depth-based preservation (max_fill_depth): # - After filling, restores any cells that required > max_fill_depth meters # - Example: max_fill_depth=50 preserves basins requiring >50m fill # # These mechanisms preserve real geographic features while still filling # noise and local pits. 
# Create combined exclusion mask exclude_mask = np.zeros_like(dem, dtype=bool) if ocean_mask is not None: exclude_mask |= ocean_mask # === Size-based preservation: Preserve large closed basins === if min_basin_size is not None: basin_mask, basin_sizes = detect_endorheic_basins( dem, min_size=min_basin_size, exclude_mask=ocean_mask, min_depth=min_basin_depth ) total_depressions = len(basin_sizes) large_basins = sum(1 for size in basin_sizes.values() if size >= min_basin_size) num_cells_masked = np.sum(basin_mask) pct_masked = 100 * num_cells_masked / basin_mask.size print(f" Basin preservation: {total_depressions} depressions >{min_basin_depth}m deep, {large_basins} >= {min_basin_size} cells") print(f" Masked {num_cells_masked:,} cells ({pct_masked:.1f}% of DEM)") exclude_mask |= basin_mask # === Main depression filling === # Priority-flood algorithm properly fills depressions while respecting # exclusion masks (ocean, large basins, etc.) if method == "fill": filled = _fill_depressions(dem, epsilon=0.0, mask=exclude_mask) elif method == "breach": # Use 1e-4 (0.1mm) to avoid fragmentation from rounding (10mm precision) filled = _fill_depressions(dem, epsilon=1e-4, mask=exclude_mask) else: raise ValueError(f"Unknown fill method: {method}") # === Depth-based preservation: Preserve deep basins === # If a depression would require > max_fill_depth meters of fill, # preserve it at original elevation (it's likely a real natural feature) if max_fill_depth is not None: fill_depth = filled - dem deep_basins = fill_depth > max_fill_depth filled[deep_basins] = dem[deep_basins] # Ensure masked cells maintain original elevation filled[exclude_mask] = dem[exclude_mask] # Ensure we never lower original elevations filled = np.maximum(filled, dem) # Fill small sinks if requested if fill_small_sinks is not None and fill_small_sinks > 0: filled = _fill_small_sinks(filled, max_sink_size=fill_small_sinks, mask=exclude_mask) return filled
def _fill_small_sinks(
    dem: np.ndarray,
    max_sink_size: int = 50,
    mask: np.ndarray | None = None,
) -> np.ndarray:
    """
    Fill small sinks (local minima) that would create fragmented drainage.

    Local-minimum cells are detected once, grouped into 8-connected regions,
    and every region of at most ``max_sink_size`` cells is raised to just
    above (+1e-6) its lowest strictly-higher neighbor. Cell indices per
    region are collected in one pass instead of scanning the grid per sink.

    Parameters
    ----------
    dem : np.ndarray
        Input DEM (already filled/breached)
    max_sink_size : int, default 50
        Maximum size (cells) of sinks to fill. Larger sinks are left alone.
    mask : np.ndarray (bool), optional
        Cells to exclude: they are never treated as sinks and are ignored
        when looking for a sink's neighbors.

    Returns
    -------
    np.ndarray
        float32 DEM with small sinks filled
    """
    # Imported once here rather than inside the per-sink loop.
    from scipy.ndimage import (
        binary_dilation,
        find_objects,
        label,
        maximum_filter,
        minimum_filter,
    )

    filled = dem.astype(np.float64)  # astype returns a copy
    rows, cols = dem.shape

    # Local minima: cells equal to the minimum of their 3x3 neighborhood
    # (i.e. no neighbor is lower) that also have at least one higher cell in
    # that neighborhood, so uniformly flat neighborhoods are not counted.
    footprint = np.ones((3, 3), dtype=bool)
    local_min_filter = minimum_filter(filled, footprint=footprint, mode='constant', cval=np.inf)
    local_max_filter = maximum_filter(filled, footprint=footprint, mode='constant', cval=-np.inf)
    local_minima = (filled == local_min_filter) & (filled < local_max_filter)

    # Exclude masked cells and the grid boundary
    if mask is not None:
        local_minima &= ~mask
    local_minima[0, :] = False
    local_minima[-1, :] = False
    local_minima[:, 0] = False
    local_minima[:, -1] = False

    num_minima = np.sum(local_minima)
    print(f" Small sink detection: found {num_minima:,} local minima cells")
    if num_minima == 0:
        return filled.astype(np.float32)

    # Label connected sink regions
    structure = np.ones((3, 3), dtype=bool)  # 8-connectivity
    labeled, num_features = label(local_minima, structure=structure)
    print(f" Small sink detection: {num_features:,} connected sink regions")

    # Bounding boxes for all regions in one pass
    slices = find_objects(labeled)

    # Collect all labeled cell coordinates, sorted so that cells of the same
    # sink are contiguous.
    labeled_rows, labeled_cols = np.where(labeled > 0)
    labeled_ids = labeled[labeled_rows, labeled_cols]
    sort_order = np.argsort(labeled_ids)
    sorted_rows = labeled_rows[sort_order]
    sorted_cols = labeled_cols[sort_order]
    sorted_ids = labeled_ids[sort_order]

    # Positions where the label changes delimit each sink's run of cells.
    split_points = np.where(np.diff(sorted_ids) > 0)[0] + 1
    split_points = np.concatenate([[0], split_points, [len(sorted_ids)]])

    # sink_id -> (row_indices, col_indices)
    sink_indices = {}
    for i in range(len(split_points) - 1):
        start, end = split_points[i], split_points[i + 1]
        if start < end:
            sink_id = sorted_ids[start]
            sink_indices[sink_id] = (sorted_rows[start:end], sorted_cols[start:end])

    sinks_filled = 0
    cells_raised = 0
    for sink_id in range(1, num_features + 1):
        if sink_id not in sink_indices:
            continue
        sink_rows, sink_cols = sink_indices[sink_id]
        sink_size = len(sink_rows)
        if sink_size > max_sink_size:
            continue  # Skip large sinks

        obj_slice = slices[sink_id - 1]
        if obj_slice is None:
            continue

        # Expand bounding box by 1 cell so the sink's neighbors are included
        r_start = max(0, obj_slice[0].start - 1)
        r_stop = min(rows, obj_slice[0].stop + 1)
        c_start = max(0, obj_slice[1].start - 1)
        c_stop = min(cols, obj_slice[1].stop + 1)

        # Work on the cropped region only. local_filled is a view, so it
        # reflects elevations already raised for earlier sinks.
        local_labeled = labeled[r_start:r_stop, c_start:c_stop]
        local_filled = filled[r_start:r_stop, c_start:c_stop]
        local_mask = mask[r_start:r_stop, c_start:c_stop] if mask is not None else None

        sink_mask_local = local_labeled == sink_id

        # Boundary = cells adjacent to the sink that are not part of it
        expanded = binary_dilation(sink_mask_local, structure=structure)
        boundary = expanded & ~sink_mask_local
        if local_mask is not None:
            boundary &= ~local_mask
        if not np.any(boundary):
            continue

        sink_elevation = np.min(local_filled[sink_mask_local])
        boundary_elevs = local_filled[boundary]
        higher_neighbors = boundary_elevs > sink_elevation

        # Determine new elevation
        if np.any(higher_neighbors):
            new_elev = np.min(boundary_elevs[higher_neighbors]) + 1e-6
        else:
            lowest_boundary_elev = np.min(boundary_elevs)
            if lowest_boundary_elev <= sink_elevation:
                new_elev = lowest_boundary_elev + 1e-6
            else:
                continue

        # Update using precomputed indices (avoids a full-grid scan)
        filled[sink_rows, sink_cols] = new_elev
        sinks_filled += 1
        cells_raised += sink_size

    if sinks_filled > 0:
        print(f" Filled {sinks_filled} small sinks ({cells_raised:,} cells) with max_size={max_sink_size}")

    return filled.astype(np.float32)


def _fill_depressions(
    dem: np.ndarray,
    epsilon: float = 0.0,
    mask: np.ndarray | None = None
) -> np.ndarray:
    """
    Fill depressions in DEM using morphological reconstruction.

    Algorithm: build a seed that equals the DEM on the grid border (and on
    masked cells) and is far above the DEM everywhere else, then reconstruct
    by erosion with the original DEM as the lower constraint.

    Parameters
    ----------
    dem : np.ndarray
        Input digital elevation model
    epsilon : float, default 0.0
        Small gradient to add in flat areas (for 'breach' method).
        If > 0, the filled surface is passed through flat resolution.
    mask : np.ndarray (bool), optional
        Boolean mask indicating cells to exclude from filling.
        Masked cells maintain original elevation.

    Returns
    -------
    np.ndarray
        Depression-filled DEM (float32)
    """
    from skimage.morphology import reconstruction

    dem = dem.astype(np.float64)  # Use float64 for precision

    # Seed: border cells keep DEM elevation, interior is set well above the
    # highest DEM value so reconstruction can erode it down.
    seed = dem.copy()
    seed[1:-1, 1:-1] = dem.max() + 1000

    # Masked cells are seeded at DEM elevation, like border cells
    if mask is not None:
        seed[mask] = dem[mask]

    # Morphological reconstruction by erosion, constrained from below by the
    # original DEM
    filled = reconstruction(seed, dem, method='erosion')

    # Restore masked cells to original elevation
    if mask is not None:
        filled[mask] = dem[mask]

    # Add epsilon gradients if requested (for breach method)
    if epsilon > 0:
        filled = _resolve_flats(filled, epsilon)

    return filled.astype(np.float32)


def _resolve_flats(dem: np.ndarray, epsilon: float = 1e-5) -> np.ndarray:
    """
    Resolve flat areas using a Garbrecht-Martz (1997) style dual gradient.

    Uses TWO gradients combined:
    1. Gradient TOWARD lower terrain (pour points) - water flows to outlets
    2. Gradient AWAY from higher terrain (high points) - water flows from ridges

    References:
    - Garbrecht & Martz (1997): "The assignment of drainage direction over flat
      surfaces in raster digital elevation models" J. Hydrol. 193: 204-213
    - Barnes et al. (2014): "An Efficient Assignment of Drainage Direction Over
      Flat Surfaces" arXiv:1511.04433

    Parameters
    ----------
    dem : np.ndarray
        Input DEM (potentially with flat areas)
    epsilon : float
        Elevation increment per unit of combined gradient (default: 1e-5)

    Returns
    -------
    np.ndarray
        float32 DEM with flats resolved using dual gradients
    """
    resolved = dem.copy().astype(np.float64)
    rows, cols = dem.shape

    # Compare elevations at 0.01 precision so floating-point noise does not
    # split flats.
    dem_rounded = np.round(dem, 2)

    neighbor_offsets = (
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1),
    )

    def _padded(fill_value):
        # Rounded DEM surrounded by a 1-cell frame of fill_value.
        padded = np.full((rows + 2, cols + 2), fill_value, dtype=dem_rounded.dtype)
        padded[1:-1, 1:-1] = dem_rounded
        return padded

    def _neighbor(padded, di, dj):
        # View of the padded grid shifted so entry [i, j] is cell (i+di, j+dj).
        return padded[1 + di:rows + 1 + di, 1 + dj:cols + 1 + dj]

    # Flat cells: at least one neighbor at the SAME elevation. NaN padding
    # never compares equal, so off-grid neighbors are ignored.
    padded_nan = _padded(np.nan)
    flat_cells = np.zeros((rows, cols), dtype=bool)
    for di, dj in neighbor_offsets:
        flat_cells |= (dem_rounded == _neighbor(padded_nan, di, dj))

    flat_count = np.sum(flat_cells)
    if flat_count == 0:
        return resolved.astype(np.float32)
    print(f" Flat resolution: {flat_count:,} flat cells found")

    # Pour points: flat cells adjacent to STRICTLY LOWER terrain. The +inf
    # frame is built once (not per offset) and is never "lower".
    padded_inf = _padded(np.inf)
    pour_points = np.zeros((rows, cols), dtype=bool)
    for di, dj in neighbor_offsets:
        pour_points |= flat_cells & (_neighbor(padded_inf, di, dj) < dem_rounded)

    # Boundary flat cells are also pour points
    boundary_mask = np.zeros((rows, cols), dtype=bool)
    boundary_mask[0, :] = True
    boundary_mask[-1, :] = True
    boundary_mask[:, 0] = True
    boundary_mask[:, -1] = True
    pour_points |= flat_cells & boundary_mask

    # High points: flat cells adjacent to STRICTLY HIGHER terrain. The -inf
    # frame is never "higher".
    padded_neg = _padded(-np.inf)
    high_points = np.zeros((rows, cols), dtype=bool)
    for di, dj in neighbor_offsets:
        high_points |= flat_cells & (_neighbor(padded_neg, di, dj) > dem_rounded)

    pour_count = np.sum(pour_points)
    high_count = np.sum(high_points)
    if pour_count == 0 and high_count == 0:
        return resolved.astype(np.float32)
    print(f" Flat resolution: {pour_count:,} pour points, {high_count:,} high points")

    # Gradient 1: BFS distance from pour points (toward lower terrain)
    dist_to_low = _compute_flat_gradient_bfs(flat_cells, pour_points, dem_rounded)
    # Gradient 2: BFS distance from high points (away from higher terrain)
    dist_from_high = _compute_flat_gradient_bfs(flat_cells, high_points, dem_rounded)

    # Largest distance from a high point; used to invert the second gradient
    # so that cells nearer to higher terrain receive the larger increment.
    max_dist_high = np.max(dist_from_high[flat_cells]) if high_count > 0 else 0

    # Cells farther from outlets AND closer to ridges get higher elevation
    combined_gradient = np.zeros_like(dist_to_low)
    if pour_count > 0 and high_count > 0:
        combined_gradient[flat_cells] = (
            dist_to_low[flat_cells] + (max_dist_high - dist_from_high[flat_cells])
        )
    elif pour_count > 0:
        # Only pour points: just use distance toward outlets
        combined_gradient[flat_cells] = dist_to_low[flat_cells]
    elif high_count > 0:
        # Only high points: just use distance from ridges (inverted)
        combined_gradient[flat_cells] = max_dist_high - dist_from_high[flat_cells]

    # Apply combined gradient
    resolved[flat_cells] += combined_gradient[flat_cells] * epsilon

    return resolved.astype(np.float32)


@jit(nopython=True, cache=True)
def _compute_flat_gradient_bfs_jit(
    flat_cells: np.ndarray,
    pour_points: np.ndarray,
    dem_rounded: np.ndarray,
    gradient: np.ndarray
) -> None:
    """
    JIT-compiled multi-source BFS to compute gradient from pour points.

    Computes the BFS step distance from the source cells within each flat
    region; expansion only proceeds between flat cells whose rounded
    elevation equals that of the cell being expanded.

    Parameters
    ----------
    flat_cells : np.ndarray (bool)
        Mask of flat cells
    pour_points : np.ndarray (bool)
        Mask of pour points (sources for BFS)
    dem_rounded : np.ndarray
        Rounded DEM for elevation comparison
    gradient : np.ndarray (float64)
        Output gradient array (modified in-place); cells that are sources or
        are never reached are left untouched
    """
    rows, cols = flat_cells.shape

    # D8 offsets for neighbor checking
    offsets = np.array([
        (-1, 0), (1, 0), (0, -1), (0, 1),
        (-1, -1), (-1, 1), (1, -1), (1, 1)
    ], dtype=np.int32)

    # Array-backed FIFO queue of flattened cell indices (i * cols + j);
    # each cell is enqueued at most once, so rows * cols slots suffice.
    queue = np.zeros(rows * cols, dtype=np.int32)
    queue_start = 0
    queue_end = 0

    # Visited array (distance from nearest pour point at same elevation)
    visited = np.zeros((rows, cols), dtype=np.int32)
    visited[:, :] = -1  # -1 = not visited

    # Add all pour points to queue with distance 0
    for i in range(rows):
        for j in range(cols):
            if pour_points[i, j]:
                queue[queue_end] = i * cols + j
                queue_end += 1
                visited[i, j] = 0

    # BFS from all pour points simultaneously
    while queue_start < queue_end:
        flat_idx = queue[queue_start]
        queue_start += 1
        i = flat_idx // cols
        j = flat_idx % cols
        current_dist = visited[i, j]
        current_elev = dem_rounded[i, j]

        # Check all 8 neighbors
        for k in range(8):
            di = offsets[k, 0]
            dj = offsets[k, 1]
            ni = i + di
            nj = j + dj

            # Bounds check
            if 0 <= ni < rows and 0 <= nj < cols:
                # Only expand to flat cells at SAME elevation that haven't been visited
                if (flat_cells[ni, nj] and
                        visited[ni, nj] == -1 and
                        dem_rounded[ni, nj] == current_elev):
                    visited[ni, nj] = current_dist + 1
                    queue[queue_end] = ni * cols + nj
                    queue_end += 1

    # Copy distances to gradient (pour points stay at 0)
    for i in range(rows):
        for j in range(cols):
            if visited[i, j] > 0:
                gradient[i, j] = visited[i, j]


def _compute_flat_gradient_bfs(
    flat_cells: np.ndarray,
    pour_points: np.ndarray,
    dem_rounded: np.ndarray
) -> np.ndarray:
    """
    Compute gradient from pour points using multi-source BFS.

    Dispatches to the JIT-compiled kernel when ``NUMBA_AVAILABLE`` is true,
    otherwise runs an equivalent pure-Python BFS.

    Parameters
    ----------
    flat_cells : np.ndarray (bool)
        Mask of flat cells
    pour_points : np.ndarray (bool)
        Mask of pour points (sources for BFS)
    dem_rounded : np.ndarray
        Rounded DEM for elevation comparison

    Returns
    -------
    np.ndarray
        float64 array of BFS distances from the nearest source at the same
        elevation; 0 for sources and for cells that were never reached
    """
    gradient = np.zeros(flat_cells.shape, dtype=np.float64)

    if NUMBA_AVAILABLE:
        _compute_flat_gradient_bfs_jit(flat_cells, pour_points, dem_rounded, gradient)
    else:
        # Pure Python fallback (slower but works)
        from collections import deque

        rows, cols = flat_cells.shape
        offsets = [(-1, 0), (1, 0), (0, -1), (0, 1),
                   (-1, -1), (-1, 1), (1, -1), (1, 1)]

        queue = deque()
        visited = np.full((rows, cols), -1, dtype=np.int32)

        # Seed with pour points. tolist() yields plain Python ints, which
        # keeps the index arithmetic in the loop below cheap.
        for i, j in np.argwhere(pour_points).tolist():
            queue.append((i, j, 0))
            visited[i, j] = 0

        # BFS
        while queue:
            i, j, dist = queue.popleft()
            current_elev = dem_rounded[i, j]
            for di, dj in offsets:
                ni, nj = i + di, j + dj
                if (0 <= ni < rows and 0 <= nj < cols and
                        flat_cells[ni, nj] and
                        visited[ni, nj] == -1 and
                        dem_rounded[ni, nj] == current_elev):
                    visited[ni, nj] = dist + 1
                    gradient[ni, nj] = dist + 1
                    queue.append((ni, nj, dist + 1))

    return gradient


def _write_geotiff(
    path: str,
    data: np.ndarray,
    transform: Affine,
    crs: rasterio.crs.CRS
) -> None:
    """
    Write numpy array to a single-band, LZW-compressed GeoTIFF file.

    Parameters
    ----------
    path : str
        Output file path
    data : np.ndarray
        2-D data array to write (its dtype becomes the raster dtype)
    transform : Affine
        Affine transform
    crs : rasterio.crs.CRS
        Coordinate reference system
    """
    height, width = data.shape

    with rasterio.open(
        path,
        "w",
        driver="GTiff",
        height=height,
        width=width,
        count=1,
        dtype=data.dtype,
        crs=crs,
        transform=transform,
        compress="lzw",
    ) as dst:
        dst.write(data, 1)