"""
Generic gridded data loader with pipeline caching for terrain visualization.
Handles loading external gridded datasets (SNODAS, temperature, precipitation, etc.),
processing through user-defined pipelines, and caching each step independently.
Features:
- Transparent automatic tiling for large datasets
- Memory monitoring with failsafe to prevent OOM/thrashing
- Per-step and merged result caching
- Smart aggregation (concatenation for spatial data, averaging for statistics)
"""
import gc
import hashlib
import inspect
import json
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Any, Tuple, List, Optional
import numpy as np
logger = logging.getLogger(__name__)
[docs]
class MemoryLimitExceeded(Exception):
"""Raised when memory usage exceeds configured limits."""
pass
[docs]
@dataclass
class TiledDataConfig:
"""Configuration for automatic tiling in GriddedDataLoader."""
max_output_pixels: int = 4096 * 4096
"""Maximum output pixels before triggering tiling (default: ~16M = ~64MB for float32)."""
target_tile_outputs: int = 2000
"""Target output pixels per tile dimension (default: 2000x2000)."""
halo: int = 0
"""Halo size for operations needing boundary overlap (default: 0 for gridded data)."""
enable_tile_cache: bool = True
"""Cache individual tiles (vs only final merged result)."""
aggregation_strategy: str = "auto"
"""How to merge tiles: 'concatenate', 'mean', 'weighted_mean', 'auto' (default)."""
max_memory_percent: float = 85.0
"""Maximum RAM usage percent before aborting (default: 85%)."""
max_swap_percent: float = 50.0
"""Maximum swap usage percent before aborting (default: 50%)."""
memory_check_interval: float = 5.0
"""Seconds between memory checks (default: 5s)."""
enable_memory_monitoring: bool = True
"""Enable memory monitoring failsafe (default: True)."""
[docs]
@dataclass
class TileSpecGridded:
"""Tile specification with geographic extent for gridded data."""
src_slice: Tuple[slice, slice]
"""Slice into source DEM (with halo)."""
out_slice: Tuple[slice, slice]
"""Slice into output arrays."""
extent: Tuple[float, float, float, float]
"""Geographic extent (minx, miny, maxx, maxy)."""
target_shape: Tuple[int, int]
"""Target output shape for this tile (height, width)."""
[docs]
class MemoryMonitor:
"""Monitor system memory and abort processing if limits exceeded."""
[docs]
def __init__(self, config: TiledDataConfig):
"""
Initialize memory monitor.
Args:
config: TiledDataConfig with memory thresholds
"""
try:
import psutil # pylint: disable=import-outside-toplevel
self.psutil = psutil
self.available = True
except ImportError:
logger.warning("psutil not installed - memory monitoring disabled")
self.available = False
self.max_memory_percent = config.max_memory_percent
self.max_swap_percent = config.max_swap_percent
self.check_interval = config.memory_check_interval
self.last_check = 0
self.enabled = config.enable_memory_monitoring and self.available
[docs]
def check_memory(self, force: bool = False) -> None:
"""
Check memory usage and raise MemoryLimitExceeded if over threshold.
Args:
force: Force check even if check_interval hasn't elapsed
Raises:
MemoryLimitExceeded: If memory or swap usage exceeds limits
"""
if not self.enabled:
return
current_time = time.time()
if not force and (current_time - self.last_check) < self.check_interval:
return
self.last_check = current_time
# Check RAM usage
memory = self.psutil.virtual_memory()
if memory.percent > self.max_memory_percent:
raise MemoryLimitExceeded(
f"Memory usage {memory.percent:.1f}% exceeds limit "
f"{self.max_memory_percent}%"
)
# Check swap usage
swap = self.psutil.swap_memory()
if swap.percent > self.max_swap_percent:
raise MemoryLimitExceeded(
f"Swap usage {swap.percent:.1f}% exceeds limit "
f"{self.max_swap_percent}%"
)
logger.debug(f"Memory: {memory.percent:.1f}% RAM, {swap.percent:.1f}% swap")
[docs]
class GriddedDataLoader:
"""
Load and cache external gridded data with pipeline processing.
This class provides a general framework for:
- Loading gridded data from arbitrary formats
- Processing data through multi-step pipelines
- Caching each pipeline step independently
- Smart cache invalidation based on step dependencies
Pipeline format: List of (name, function, kwargs) tuples
Example:
>>> def load_data(source, extent, target_shape):
... # Load and crop data
... return {"raw": data_array}
>>>
>>> def compute_stats(input_data):
... # Compute statistics from previous step
... raw = input_data["raw"]
... return {"mean": raw.mean(), "std": raw.std()}
>>>
>>> pipeline = [
... ("load", load_data, {}),
... ("stats", compute_stats, {}),
... ]
>>>
>>> loader = GriddedDataLoader(terrain, cache_dir=Path(".cache"))
>>> result = loader.run_pipeline(
... data_source="/path/to/data",
... pipeline=pipeline,
... cache_name="my_analysis"
... )
"""
[docs]
def __init__(
self,
terrain,
cache_dir: Path = None,
auto_tile: bool = True,
tile_config: Optional[TiledDataConfig] = None,
):
"""
Initialize gridded data loader.
Args:
terrain: Terrain object (provides extent and resolution)
cache_dir: Directory for caching (default: .gridded_data_cache)
auto_tile: Enable automatic tiling when outputs exceed memory threshold
(default: True)
tile_config: TiledDataConfig for tiling behavior (uses defaults if None)
"""
self.terrain = terrain
self.cache_dir = cache_dir or Path(".gridded_data_cache")
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.auto_tile = auto_tile
self.tile_config = tile_config or TiledDataConfig()
logger.debug(f"GriddedDataLoader initialized at: {self.cache_dir}")
logger.debug(f"Auto-tiling: {self.auto_tile}")
[docs]
def run_pipeline(
self,
data_source: Any,
pipeline: List[Tuple[str, Callable, Dict]],
cache_name: str,
force_reprocess: bool = False,
) -> Any:
"""
Execute a processing pipeline with caching at each step.
Features:
- Transparent automatic tiling for large outputs
- Memory monitoring with failsafe
- Per-step and merged result caching
Args:
data_source: Data source (directory, file list, URL, etc.)
pipeline: List of (step_name, function, kwargs) tuples
Each function receives previous step's output as first arg
cache_name: Base name for cache files
force_reprocess: Force reprocessing all steps even if cached
Returns:
Output of final pipeline step
Raises:
MemoryLimitExceeded: If memory limits exceeded during tiling
"""
logger.info(f"Running pipeline '{cache_name}' with {len(pipeline)} steps")
# Check if tiling is needed
if self.auto_tile and self._should_tile(data_source, pipeline, cache_name):
# Get terrain extent
extent = self.terrain.dem_bounds
# Determine target shape from first step output (cached from _should_tile)
try:
step_name, step_func, step_kwargs = pipeline[0]
step_output, _ = self._execute_step(
step_name=step_name,
func=step_func,
input_data=data_source,
kwargs=step_kwargs,
upstream_cache_key=self._compute_source_cache_key(data_source),
cache_name=cache_name,
force_reprocess=False, # Use cached result
)
target_shape = self._get_output_shape(step_output)
if not target_shape or len(target_shape) < 2:
logger.warning("Could not determine output shape, skipping tiling")
return self._run_pipeline_non_tiled(
data_source, pipeline, cache_name, force_reprocess
)
target_shape = tuple(target_shape[:2]) # Use first 2 dimensions
except Exception as e:
logger.warning(f"Error determining target shape for tiling: {e}, skipping tiling")
return self._run_pipeline_non_tiled(
data_source, pipeline, cache_name, force_reprocess
)
# Create tile specifications
tile_specs = self._create_tile_specs(target_shape, extent)
# Execute tiled pipeline
return self._execute_tiled_pipeline(
data_source, pipeline, cache_name, tile_specs, force_reprocess, target_shape, extent
)
else:
# No tiling needed, use standard pipeline execution
return self._run_pipeline_non_tiled(
data_source, pipeline, cache_name, force_reprocess
)
def _run_pipeline_non_tiled(
self,
data_source: Any,
pipeline: List[Tuple[str, Callable, Dict]],
cache_name: str,
force_reprocess: bool,
) -> Any:
"""
Execute pipeline without tiling (original logic).
Args:
data_source: Data source
pipeline: Pipeline to execute
cache_name: Cache name
force_reprocess: Force reprocessing flag
Returns:
Output of final pipeline step
"""
current_data = data_source
upstream_cache_key = self._compute_source_cache_key(data_source)
for i, (step_name, step_func, step_kwargs) in enumerate(pipeline):
logger.info(f" Step {i+1}/{len(pipeline)}: {step_name}")
# Execute step with caching
current_data, step_cache_key = self._execute_step(
step_name=step_name,
func=step_func,
input_data=current_data,
kwargs=step_kwargs,
upstream_cache_key=upstream_cache_key,
cache_name=cache_name,
force_reprocess=force_reprocess,
)
# Update upstream key for next step
upstream_cache_key = step_cache_key
logger.info(f"Pipeline '{cache_name}' completed")
return current_data
def _should_tile(
self,
data_source: Any,
pipeline: List[Tuple[str, Callable, Dict]],
cache_name: str,
) -> bool:
"""
Detect if pipeline output would exceed memory threshold via dry-run.
Executes first pipeline step to determine output shape and size.
Leverages caching to avoid redundant computation.
Args:
data_source: Data source
pipeline: Pipeline to analyze
cache_name: Cache name for steps
Returns:
True if output size exceeds max_output_pixels threshold
"""
if not self.auto_tile or not pipeline:
return False
try:
step_name, step_func, step_kwargs = pipeline[0]
# Execute first step (may hit cache)
step_output, _ = self._execute_step(
step_name=step_name,
func=step_func,
input_data=data_source,
kwargs=step_kwargs,
upstream_cache_key=self._compute_source_cache_key(data_source),
cache_name=cache_name,
force_reprocess=False,
)
# Determine output size
total_pixels = self._get_array_pixel_count(step_output)
if total_pixels > self.tile_config.max_output_pixels:
logger.info(
f"Auto-tiling triggered: output size {total_pixels:,} pixels "
f"exceeds threshold {self.tile_config.max_output_pixels:,}"
)
return True
logger.debug(
f"Auto-tiling not needed: output size {total_pixels:,} pixels "
f"within threshold {self.tile_config.max_output_pixels:,}"
)
return False
except Exception as e:
logger.debug(f"Error detecting tiling need: {e}, defaulting to False")
return False
def _get_array_pixel_count(self, data: Any) -> int:
"""
Get total pixel count from array or dict of arrays.
Args:
data: Array, dict of arrays, or other data
Returns:
Total number of pixels, or 0 if not array data
"""
if isinstance(data, np.ndarray):
return data.size
if isinstance(data, dict):
total = 0
for v in data.values():
if isinstance(v, np.ndarray):
total += v.size
return total
return 0
def _create_tile_specs(
self,
target_shape: Tuple[int, int],
extent: Tuple[float, float, float, float],
) -> List[TileSpecGridded]:
"""
Create tile specifications for gridded data processing.
Divides target shape into regular tiles based on target_tile_outputs.
Each tile gets geographic extent and output slice information.
Args:
target_shape: Target output shape (height, width)
extent: Geographic extent (minx, miny, maxx, maxy)
Returns:
List of TileSpecGridded with tile layout
"""
target_h, target_w = target_shape
minx, miny, maxx, maxy = extent
tile_size = self.tile_config.target_tile_outputs
tiles = []
# Calculate number of tiles
tiles_h = (target_h + tile_size - 1) // tile_size
tiles_w = (target_w + tile_size - 1) // tile_size
logger.debug(f"Creating tile grid: {tiles_h}x{tiles_w} tiles")
for tile_row in range(tiles_h):
for tile_col in range(tiles_w):
# Compute tile bounds in output space
row_start = tile_row * tile_size
row_end = min(row_start + tile_size, target_h)
col_start = tile_col * tile_size
col_end = min(col_start + tile_size, target_w)
# Compute geographic extent for this tile
pixel_size_y = (maxy - miny) / target_h
pixel_size_x = (maxx - minx) / target_w
tile_minx = minx + col_start * pixel_size_x
tile_miny = miny + row_start * pixel_size_y
tile_maxx = minx + col_end * pixel_size_x
tile_maxy = miny + row_end * pixel_size_y
tile_extent = (tile_minx, tile_miny, tile_maxx, tile_maxy)
tile_target_shape = (row_end - row_start, col_end - col_start)
tile_spec = TileSpecGridded(
src_slice=(slice(row_start, row_end), slice(col_start, col_end)),
out_slice=(slice(row_start, row_end), slice(col_start, col_end)),
extent=tile_extent,
target_shape=tile_target_shape,
)
tiles.append(tile_spec)
logger.debug(
f" Tile [{tile_row},{tile_col}]: "
f"output slice [{row_start}:{row_end}, {col_start}:{col_end}], "
f"shape {tile_target_shape}"
)
return tiles
def _create_tile_source(
self,
data_source: Any,
tile_spec: TileSpecGridded,
) -> Dict[str, Any]:
"""
Create tile-specific data source with extent and target_shape.
Extracts tile extent and target shape for this tile so that pipeline
functions can process just the tile (not full extent).
Args:
data_source: Original data source
tile_spec: Specification for this tile
Returns:
Dict with 'extent' and 'target_shape' for this tile, preserving
original data_source structure/path
"""
return {
"data_source": data_source,
"extent": tile_spec.extent,
"target_shape": tile_spec.target_shape,
}
def _execute_step(
self,
step_name: str,
func: Callable,
input_data: Any,
kwargs: Dict,
upstream_cache_key: str,
cache_name: str,
force_reprocess: bool,
) -> Tuple[Any, str]:
"""
Execute a single pipeline step with caching.
Args:
step_name: Name of this step
func: Function to execute
input_data: Output from previous step (or data_source for first step)
kwargs: Additional arguments for func
upstream_cache_key: Cache key from previous step
cache_name: Base cache name
force_reprocess: Force recomputation
Returns:
Tuple of (step_output, step_cache_key)
"""
# Compute cache key for this step
step_cache_key = self._compute_step_cache_key(
step_name, func, kwargs, upstream_cache_key
)
cache_file = self.cache_dir / f"{cache_name}_{step_name}_{step_cache_key[:16]}.npz"
# Try to load from cache
if not force_reprocess and cache_file.exists():
logger.debug(f" Cache hit: {cache_file.name}")
step_output = self._load_step_cache(cache_file)
if step_output is not None:
return step_output, step_cache_key
# Execute step
logger.debug(f" Executing: {func.__name__}")
# Inject terrain parameters if needed
if self._needs_terrain_params(func):
extent = self.terrain.dem_bounds
pixel_size = 1.0 / 120.0 # SNODAS native resolution
minx, miny, maxx, maxy = extent
target_width = int(round((maxx - minx) / pixel_size))
target_height = int(round((maxy - miny) / pixel_size))
# Enforce max_output_pixels limit to prevent OOM
total_pixels = target_height * target_width
max_pixels = self.tile_config.max_output_pixels
if total_pixels > max_pixels:
# Downsample to fit within limit while preserving aspect ratio
scale = (max_pixels / total_pixels) ** 0.5
target_height = int(target_height * scale)
target_width = int(target_width * scale)
logger.info(
f"Limiting output to {target_height}×{target_width} "
f"({target_height * target_width:,} pixels) to stay within "
f"max_output_pixels={max_pixels:,}"
)
kwargs = {
**kwargs,
"extent": extent,
"target_shape": (target_height, target_width),
}
step_output = func(input_data, **kwargs)
# Save to cache
self._save_step_cache(cache_file, step_output)
logger.debug(f" Cached: {cache_file.name}")
return step_output, step_cache_key
def _execute_tiled_pipeline(
self,
data_source: Any,
pipeline: List[Tuple[str, Callable, Dict]],
cache_name: str,
tile_specs: List[TileSpecGridded],
force_reprocess: bool,
target_shape: Tuple[int, int],
extent: Tuple[float, float, float, float],
) -> Any:
"""
Execute pipeline with tiling and memory monitoring.
Processes each tile independently through the full pipeline,
with memory checks before each tile to prevent OOM/thrashing.
Args:
data_source: Original data source
pipeline: Pipeline to execute
cache_name: Cache name for steps
tile_specs: List of TileSpecGridded specifications
force_reprocess: Force reprocessing all tiles
target_shape: Target output shape for final result
extent: Geographic extent for entire dataset
Returns:
Aggregated results from all tiles
Raises:
MemoryLimitExceeded: If memory limits exceeded before completion
"""
logger.info(f"Executing tiled pipeline with {len(tile_specs)} tiles")
# Initialize memory monitor
monitor = MemoryMonitor(self.tile_config)
# Check memory before starting
try:
monitor.check_memory(force=True)
except MemoryLimitExceeded as e:
logger.error(f"Memory limit exceeded before starting: {e}")
raise
tile_outputs = []
tile_shapes = {} # Track output shapes for aggregation
# Process each tile
for i, tile_spec in enumerate(tile_specs):
# Check memory before processing this tile
try:
monitor.check_memory()
except MemoryLimitExceeded as e:
logger.error(f"Memory limit exceeded before tile {i+1}/{len(tile_specs)}: {e}")
logger.error(
f"Processed {i}/{len(tile_specs)} tiles successfully. "
f"Consider reducing tile size or freeing memory."
)
raise
logger.info(f"Processing tile {i+1}/{len(tile_specs)}: shape {tile_spec.target_shape}")
# Create tile-specific data source
tile_source = self._create_tile_source(data_source, tile_spec)
# Execute full pipeline for this tile
try:
tile_output = self._execute_tile_pipeline(
tile_source,
pipeline,
f"{cache_name}_tile{i}",
force_reprocess,
)
except Exception as e:
logger.error(f"Error processing tile {i+1}: {e}")
raise
tile_outputs.append(tile_output)
# Track output shapes for later aggregation
tile_shapes[i] = self._get_output_shape(tile_output)
# Force garbage collection after each tile
gc.collect()
# Check memory before aggregation
try:
monitor.check_memory(force=True)
except MemoryLimitExceeded as e:
logger.error(f"Memory limit exceeded before aggregation: {e}")
raise
# Aggregate tiles into final result
merged = self._aggregate_tiles(tile_outputs, tile_specs, target_shape)
logger.info(f"Tiled pipeline completed, aggregated result shape: {self._get_output_shape(merged)}")
return merged
def _execute_tile_pipeline(
self,
tile_source: Dict[str, Any],
pipeline: List[Tuple[str, Callable, Dict]],
cache_name: str,
force_reprocess: bool,
) -> Any:
"""
Execute full pipeline for a single tile.
Similar to run_pipeline but with tile-specific caching.
Args:
tile_source: Tile-specific data source (includes extent/target_shape)
pipeline: Pipeline to execute
cache_name: Cache name for this tile's steps
force_reprocess: Force reprocessing
Returns:
Output of final pipeline step for this tile
"""
current_data = tile_source
upstream_cache_key = hashlib.sha256(
f"{tile_source['extent']}:{tile_source['target_shape']}".encode()
).hexdigest()
for step_name, step_func, step_kwargs in pipeline:
# Execute step with tile-specific cache
current_data, upstream_cache_key = self._execute_step(
step_name=step_name,
func=step_func,
input_data=current_data,
kwargs=step_kwargs,
upstream_cache_key=upstream_cache_key,
cache_name=cache_name,
force_reprocess=force_reprocess,
)
return current_data
def _get_output_shape(self, data: Any) -> Tuple[int, ...]:
"""Get shape of array or dict of arrays."""
if isinstance(data, np.ndarray):
return data.shape
if isinstance(data, dict):
for v in data.values():
if isinstance(v, np.ndarray):
return v.shape
return ()
def _aggregate_tiles(
self,
tile_outputs: List[Any],
tile_specs: List[TileSpecGridded],
target_shape: Tuple[int, int],
) -> Any:
"""
Aggregate tile outputs into final result.
Auto-detects aggregation strategy based on output type and shape.
Args:
tile_outputs: List of outputs from each tile
tile_specs: List of tile specifications
target_shape: Target output shape for final result
Returns:
Aggregated result matching original expected output shape
"""
if not tile_outputs:
return None
# Auto-detect aggregation strategy
if self.tile_config.aggregation_strategy == "auto":
strategy = self._determine_aggregation(tile_outputs)
else:
strategy = self.tile_config.aggregation_strategy
logger.debug(f"Aggregating {len(tile_outputs)} tiles using strategy: {strategy}")
if strategy == "concatenate":
return self._concatenate_spatial(tile_outputs, tile_specs, target_shape)
elif strategy == "mean":
return self._average_statistics(tile_outputs)
elif strategy == "weighted_mean":
return self._weighted_average_statistics(tile_outputs, tile_specs)
else:
logger.warning(f"Unknown aggregation strategy '{strategy}', using first tile")
return tile_outputs[0]
def _determine_aggregation(self, tile_outputs: List[Any]) -> str:
"""
Auto-detect aggregation strategy from tile outputs.
Args:
tile_outputs: List of outputs from tiles
Returns:
Strategy name: 'concatenate' for 2D+ spatial data, 'mean' for scalars/1D
"""
if not tile_outputs:
return "first"
first_output = tile_outputs[0]
if isinstance(first_output, dict):
# Check first array in dict
for v in first_output.values():
if isinstance(v, np.ndarray):
return "concatenate" if v.ndim >= 2 else "mean"
return "first"
elif isinstance(first_output, np.ndarray):
return "concatenate" if first_output.ndim >= 2 else "mean"
else:
return "first" # Non-array data
def _concatenate_spatial(
self,
tile_outputs: List[Any],
tile_specs: List[TileSpecGridded],
target_shape: Tuple[int, int],
) -> Any:
"""
Concatenate spatial arrays using tile specifications.
Args:
tile_outputs: List of tile outputs
tile_specs: List of tile specifications
target_shape: Target output shape
Returns:
Assembled spatial arrays with target_shape
"""
first_output = tile_outputs[0]
if isinstance(first_output, dict):
# Merge each key separately
result = {}
for key in first_output.keys():
arrays = [output[key] for output in tile_outputs if isinstance(output.get(key), np.ndarray)]
if arrays:
result[key] = self._assemble_grid(arrays, tile_specs, target_shape)
return result
else:
return self._assemble_grid(tile_outputs, tile_specs, target_shape)
def _average_statistics(self, tile_outputs: List[Any]) -> Any:
"""
Average statistical outputs across tiles.
Args:
tile_outputs: List of tile outputs
Returns:
Averaged result
"""
if not tile_outputs:
return None
first_output = tile_outputs[0]
if isinstance(first_output, dict):
result = {}
for key in first_output.keys():
values = [output[key] for output in tile_outputs if key in output]
if values:
if isinstance(values[0], np.ndarray):
result[key] = np.mean(values, axis=0)
else:
result[key] = np.mean(values)
return result
elif isinstance(first_output, np.ndarray):
return np.mean(tile_outputs, axis=0)
else:
return np.mean(tile_outputs)
def _weighted_average_statistics(
self,
tile_outputs: List[Any],
tile_specs: List[TileSpecGridded],
) -> Any:
"""
Weighted average across tiles based on tile sizes.
Larger tiles get higher weights in the average.
Args:
tile_outputs: List of tile outputs
tile_specs: List of tile specifications
Returns:
Weighted average result
"""
if not tile_outputs:
return None
# Compute weights based on tile size
tile_sizes = np.array([np.prod(spec.target_shape) for spec in tile_specs])
weights = tile_sizes / tile_sizes.sum()
first_output = tile_outputs[0]
if isinstance(first_output, dict):
result = {}
for key in first_output.keys():
values = [output[key] for output in tile_outputs if key in output]
if values:
if isinstance(values[0], np.ndarray):
# Weighted average for each array
weighted = sum(v * w for v, w in zip(values, weights))
result[key] = weighted
else:
result[key] = sum(v * w for v, w in zip(values, weights))
return result
elif isinstance(first_output, np.ndarray):
return sum(v * w for v, w in zip(tile_outputs, weights))
else:
return sum(v * w for v, w in zip(tile_outputs, weights))
def _assemble_grid(
self,
arrays: List[np.ndarray],
tile_specs: List[TileSpecGridded],
target_shape: Tuple[int, int],
) -> np.ndarray:
"""
Assemble tiled arrays into final grid using tile specifications.
Args:
arrays: List of tile arrays
tile_specs: List of tile specifications
target_shape: Target output shape
Returns:
Final assembled array with target_shape
"""
if not arrays or not tile_specs:
return None
# Initialize output array with first array's dtype
dtype = arrays[0].dtype
output = np.zeros(target_shape, dtype=dtype)
# Place each tile into its position
for array, spec in zip(arrays, tile_specs):
output[spec.out_slice] = array
return output
def _compute_source_cache_key(self, data_source: Any) -> str:
"""Compute cache key for data source."""
source_str = str(data_source)
extent = self.terrain.dem_bounds
cache_str = f"{source_str}:{extent}"
return hashlib.sha256(cache_str.encode()).hexdigest()
def _compute_step_cache_key(
self, step_name: str, func: Callable, kwargs: Dict, upstream_key: str
) -> str:
"""
Compute cache key for a pipeline step.
Includes:
- Step name
- Function source code hash
- Step kwargs
- Upstream cache key (dependency tracking)
"""
try:
# Hash function source code
func_source = inspect.getsource(func)
func_hash = hashlib.sha256(func_source.encode()).hexdigest()
except (OSError, TypeError):
# Fallback for built-in functions or functions without source
func_hash = hashlib.sha256(func.__name__.encode()).hexdigest()
# Hash kwargs (convert to sorted json for stability)
kwargs_str = json.dumps(kwargs, sort_keys=True, default=str)
# Combine all components
cache_str = f"{step_name}:{func_hash}:{kwargs_str}:{upstream_key}"
return hashlib.sha256(cache_str.encode()).hexdigest()
def _needs_terrain_params(self, func: Callable) -> bool:
"""Check if function signature expects extent/target_shape parameters."""
try:
sig = inspect.signature(func)
params = sig.parameters.keys()
return "extent" in params or "target_shape" in params
except (ValueError, TypeError):
return False
def _save_step_cache(self, cache_file: Path, data: Any):
"""Save step output to cache."""
try:
if isinstance(data, dict):
# Save dict of arrays
np.savez_compressed(cache_file, **data)
elif isinstance(data, np.ndarray):
# Save single array
np.savez_compressed(cache_file, data=data)
else:
# Pickle for other types
np.savez_compressed(cache_file, data=np.array(data, dtype=object))
except Exception as e:
logger.warning(f"Failed to save cache: {e}")
def _load_step_cache(self, cache_file: Path) -> Any:
"""Load step output from cache."""
try:
with np.load(cache_file, allow_pickle=True) as npz:
if len(npz.files) == 1 and "data" in npz.files:
# Single array or pickled object
data = npz["data"]
return data.item() if data.dtype == object else data
else:
# Dict of arrays - need to extract object arrays back to Python types
result = {}
for k in npz.files:
arr = npz[k]
# Object arrays with ndim=0 are pickled Python objects
if arr.dtype == object and arr.ndim == 0:
result[k] = arr.item()
else:
result[k] = arr
return result
except Exception as e:
logger.warning(f"Failed to load cache: {e}")
return None
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
[docs]
def downsample_for_viz(
arr: np.ndarray, max_dim: int = 2000
) -> Tuple[np.ndarray, int]:
"""
Downsample array using stride slicing for visualization.
Args:
arr: Input array to downsample
max_dim: Maximum dimension size for output
Returns:
Tuple of (downsampled_array, stride_used)
"""
max_shape = max(arr.shape)
if max_shape <= max_dim:
return arr, 1
stride = max(1, max_shape // max_dim)
downsampled = arr[::stride, ::stride]
return downsampled, stride
[docs]
def create_mock_snow_data(shape: Tuple[int, int]) -> Dict[str, np.ndarray]:
"""
Create mock snow data for testing.
Generates realistic-looking mock snow statistics using statistical
distributions that mimic real SNODAS patterns.
Args:
shape: Shape of the snow data arrays (height, width)
Returns:
Dictionary with mock snow statistics:
- median_max_depth: Snow depth in mm (gamma distribution)
- mean_snow_day_ratio: Fraction of days with snow (beta distribution)
- interseason_cv: Year-to-year variability (beta distribution)
- mean_intraseason_cv: Within-winter variability (beta distribution)
"""
logger.info(f"Creating mock snow data at shape {shape}...")
np.random.seed(42)
# Snow depth (0-300mm, concentrated in certain areas)
median_max_depth = np.random.gamma(2, 30, shape).astype(np.float32)
median_max_depth = np.clip(median_max_depth, 0, 300)
# Snow coverage ratio (0-1, mostly high in winter)
mean_snow_day_ratio = np.random.beta(8, 2, shape).astype(np.float32)
# Variability (0-1, lower is more consistent)
interseason_cv = np.random.beta(2, 8, shape).astype(np.float32) * 0.5
mean_intraseason_cv = np.random.beta(2, 8, shape).astype(np.float32) * 0.3
return {
"median_max_depth": median_max_depth,
"mean_snow_day_ratio": mean_snow_day_ratio,
"interseason_cv": interseason_cv,
"mean_intraseason_cv": mean_intraseason_cv,
}