"""
SNODAS data processing utilities.
Functions for loading, processing, and computing statistics from
NOAA SNODAS (Snow Data Assimilation System) snow depth data.
Pipeline steps:
- batch_process_snodas_data: Load and reproject SNODAS files
- calculate_snow_statistics: Compute aggregated snow statistics
High-level orchestration:
- load_snodas_stats: Complete pipeline with tiling, caching, and mock fallback
"""
import gzip
import logging
import shutil
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Tuple, Any
import numpy as np
import rasterio
import rasterio.crs as rcrs
from affine import Affine
from rasterio.warp import reproject, Resampling
from tqdm.auto import tqdm
logger = logging.getLogger(__name__)
# =============================================================================
# SNODAS File I/O Helpers
# =============================================================================
def _gunzip_snodas_file(gz_path, keep_original: bool = True) -> Path:
"""Decompress .gz to the same folder, remove if keep_original=False."""
gz_path = Path(gz_path)
out_path = gz_path.with_suffix("")
logger.debug(f"Gunzip {gz_path} -> {out_path}")
try:
with gzip.open(gz_path, "rb") as fin:
with open(out_path, "wb") as fout:
shutil.copyfileobj(fin, fout)
if not keep_original:
gz_path.unlink()
return out_path
except Exception as e:
logger.error(f"Failed to gunzip {gz_path}: {e}")
if out_path.exists():
out_path.unlink()
raise
def _read_snodas_header(header_path) -> Dict[str, Any]:
"""Parse SNODAS header file to get transform, dimensions, etc."""
logger.debug(f"Reading SNODAS header: {header_path}")
field_types = {
"data_units": str,
"data_intercept": float,
"data_slope": float,
"no_data_value": float,
"number_of_columns": int,
"number_of_rows": int,
"data_bytes_per_pixel": int,
"minimum_data_value": float,
"maximum_data_value": float,
"horizontal_datum": str,
"horizontal_precision": float,
"projected": str,
"geographically_corrected": str,
"benchmark_x-axis_coordinate": float,
"benchmark_y-axis_coordinate": float,
"x-axis_resolution": float,
"y-axis_resolution": float,
"x-axis_offset": float,
"y-axis_offset": float,
"minimum_x-axis_coordinate": float,
"maximum_x-axis_coordinate": float,
"minimum_y-axis_coordinate": float,
"maximum_y-axis_coordinate": float,
"benchmark_column": int,
"benchmark_row": int,
"start_year": int,
"start_month": int,
"start_day": int,
"start_hour": int,
"start_minute": int,
"start_second": int,
}
meta = {}
try:
with open(header_path, "r") as f:
for line in f:
parts = line.split(":", 1)
if len(parts) != 2:
continue
key, val = parts
key = key.strip().lower().replace(" ", "_")
val = val.strip()
if val == "Not applicable":
continue
if key in field_types:
try:
meta[key] = field_types[key](val)
except ValueError:
meta[key] = val
else:
meta[key] = val
meta["width"] = meta["number_of_columns"]
meta["height"] = meta["number_of_rows"]
# Build transform only if coordinate fields are present
required_coords = [
"minimum_x-axis_coordinate",
"minimum_y-axis_coordinate",
"maximum_x-axis_coordinate",
"maximum_y-axis_coordinate",
]
if all(coord in meta for coord in required_coords):
meta["transform"] = rasterio.transform.from_bounds(
meta["minimum_x-axis_coordinate"],
meta["minimum_y-axis_coordinate"],
meta["maximum_x-axis_coordinate"],
meta["maximum_y-axis_coordinate"],
meta["width"],
meta["height"],
)
meta["crs"] = "EPSG:4326"
return meta
except Exception as e:
logger.error(f"Error parsing header {header_path}: {e}")
raise
def _read_snodas_binary(binary_path, meta: Dict[str, Any]) -> np.ma.MaskedArray:
"""Read big-endian 16-bit .dat, applying slope/intercept, masking no_data."""
logger.debug(f"Reading SNODAS binary: {binary_path}")
try:
arr = np.fromfile(binary_path, dtype=">i2")
arr = arr.reshape((meta["height"], meta["width"]))
if "data_slope" in meta and "data_intercept" in meta:
arr = arr * meta["data_slope"] + meta["data_intercept"]
arr = np.ma.masked_equal(arr, meta["no_data_value"])
return arr
except Exception as e:
logger.error(f"Error reading binary {binary_path}: {e}")
raise
def _load_snodas_data(binary_path, header_path) -> Tuple[np.ma.MaskedArray, Dict]:
    """
    Load raw SNODAS ``.dat`` + ``.txt`` into a masked array + metadata.

    Either input may be gzip-compressed (``.gz`` suffix). Compressed files
    are decompressed next to the original unless an uncompressed copy is
    already present on disk, in which case that copy is reused.

    Returns:
        Tuple of (masked data array, header metadata dict).
    """
    logger.debug(f"Loading SNODAS from {binary_path}")

    reused_existing = False

    def _resolve(path: Path) -> Path:
        # Return the uncompressed counterpart of `path`, gunzipping if needed.
        nonlocal reused_existing
        if path.suffix != ".gz":
            # Already uncompressed
            return path
        plain = path.with_suffix("")
        if plain.exists():
            reused_existing = True
            return plain
        return _gunzip_snodas_file(path)

    dat_path = _resolve(Path(binary_path))
    hdr_path = _resolve(Path(header_path))
    if reused_existing:
        logger.debug("Decompressed file exists, did not decompress.")
    meta = _read_snodas_header(hdr_path)
    data = _read_snodas_binary(dat_path, meta)
    return data, meta
def _process_single_file(args) -> Optional[Tuple[datetime, Path]]:
    """Worker for parallel ``.dat.gz`` -> processed ``.npz`` conversion.

    Args:
        args: Single tuple (so the function can be mapped over an executor):
            - dat_file: Path to the ``snow_depth_<YYYYMMDD>.dat.gz`` input
            - extent: (minx, miny, maxx, maxy) crop bounding box
            - processed_dir: output directory for ``.npz`` files
            - target_dims: (target_height, target_width, target_transform)
            - resample_to_extent: when True, reproject onto the target grid;
              when False, save the native SNODAS grid unchanged

    Returns:
        (date, output path) on success; None when the header is missing or
        processing fails. Errors are logged rather than raised so one bad
        file cannot take down the worker pool.
    """
    (
        dat_file,       # Path to .dat.gz
        extent,         # (minx, miny, maxx, maxy)
        processed_dir,  # Path to output directory
        target_dims,    # (tgt_h, tgt_w, tgt_transform)
        resample_to_extent,
    ) = args
    try:
        # File names look like snow_depth_<YYYYMMDD>.dat.gz
        date_str = dat_file.stem.split("_")[2].split(".")[0]
        header_file = dat_file.parent / f"snow_depth_metadata_{date_str}.txt.gz"
        if not header_file.exists():
            logger.warning(f"Missing header for {dat_file}")
            return None
        proc_tag = "processed" if resample_to_extent else "raw"
        date = datetime.strptime(date_str, "%Y%m%d")
        out_file = processed_dir / f"snodas_{proc_tag}_{date_str}.npz"
        target_height, target_width, target_transform = target_dims
        # If out_file exists with the correct shape, reuse it.
        if out_file.exists() and resample_to_extent:
            logger.debug(f"Checking existing file {out_file}")
            try:
                with np.load(out_file) as npz:
                    if npz["data"].shape == (target_height, target_width):
                        return (date, out_file)
            except Exception as e:
                # Unreadable/stale cache: log and fall through to regenerate.
                logger.error(f"Failed to read {out_file} for shape check: {e}", exc_info=True)
        if out_file.exists() and not resample_to_extent:
            logger.debug(f"Existing raw file exists {out_file}")
            return (date, out_file)
        # Load raw SNODAS
        data, meta = _load_snodas_data(dat_file, header_file)
        logger.debug(f"Loaded {dat_file} with shape {data.shape} and metadata: {meta}")
        if resample_to_extent:
            # Reproject from the native SNODAS grid (geographic WGS84) onto
            # the target grid (also WGS84) -- effectively a crop + resample.
            cropped = np.zeros((target_height, target_width), dtype=np.float32)
            reproject(
                data,
                cropped,
                src_transform=meta["transform"],
                dst_transform=target_transform,
                src_crs="EPSG:4326",
                dst_crs="EPSG:4326",
                resampling=Resampling.bilinear,
            )
            save_kwargs = {
                "data": cropped,
                # Six Affine coefficients of the *target* grid.
                "transform": np.array(list(target_transform), dtype=np.float64),
                # SNODAS is always geographic WGS84. np.bytes_ replaces
                # np.string_, which was removed in NumPy 2.0.
                "crs": np.bytes_("EPSG:4326"),
                "height": np.int32(target_height),
                "width": np.int32(target_width),
                "crop_extent": np.array(extent, dtype=np.float64),
                "no_data_value": np.array(meta["no_data_value"], dtype=np.float32),
            }
        else:
            # RAW branch: keep the original SNODAS data & native transform.
            save_kwargs = {
                "data": data,
                "transform": np.array(list(meta["transform"]), dtype=np.float64),
                "crs": np.bytes_("EPSG:4326"),
                "height": np.int32(data.shape[0]),
                "width": np.int32(data.shape[1]),
                "no_data_value": np.array(meta["no_data_value"], dtype=np.float32),
            }
        # Finally, write out the .npz with only numpy-friendly types.
        np.savez_compressed(out_file, **save_kwargs)
        return (date, out_file)
    except Exception as e:
        logger.error(f"Error processing {dat_file}: {e}")
        return None
def _load_processed_snodas(processed_file) -> Tuple[np.ma.MaskedArray, Dict]:
"""Load processed .npz SNODAS and reconstruct spatial metadata."""
with np.load(processed_file) as npz:
# 1) Pull out the data (masked if necessary)
data = npz["data"]
if "no_data_value" in npz.files:
data = np.ma.masked_equal(data, npz["no_data_value"])
else:
data = np.ma.masked_invalid(data)
meta = {}
# 2) Transform -> Affine
if "transform" in npz.files:
t_arr = npz["transform"] # shape (6,)
meta["transform"] = Affine(*t_arr.tolist())
# 3) CRS -> rasterio.crs.CRS
if "crs" in npz.files:
crs_val = npz["crs"].item() # byte-string or unicode
if isinstance(crs_val, bytes):
crs_str = crs_val.decode()
else:
crs_str = str(crs_val)
meta["crs"] = rcrs.CRS.from_user_input(crs_str)
# 4) Height/Width -> int
if "height" in npz.files:
meta["height"] = int(npz["height"])
if "width" in npz.files:
meta["width"] = int(npz["width"])
# 5) Crop extent -> tuple of floats
if "crop_extent" in npz.files:
ext = npz["crop_extent"] # array([minx, miny, maxx, maxy])
meta["crop_extent"] = tuple(ext.tolist())
# 6) Copy any remaining fields verbatim (e.g. no_data_value)
for key in npz.files:
if key in ("data", "transform", "crs", "height", "width", "crop_extent"):
continue
meta[key] = npz[key]
return data, meta
# =============================================================================
# Pipeline Step Functions for GriddedDataLoader
# =============================================================================
def batch_process_snodas_data(
    snodas_dir,
    extent: Tuple[float, float, float, float],
    target_shape: Tuple[int, int],
    processed_dir: str = "processed_snodas",
    max_workers: int = 14,
) -> Dict[str, Any]:
    """
    Step 1: Load and process SNODAS files.

    Process all SNODAS .dat.gz files in snodas_dir, cropping to 'extent'
    & reprojecting onto the target grid (EPSG:4326). Conversion is fanned
    out to a process pool; workers skip outputs that already exist with
    the expected shape.

    Args:
        snodas_dir: Directory containing SNODAS .dat.gz files
        extent: (minx, miny, maxx, maxy) bounding box
        target_shape: (height, width) for output arrays
        processed_dir: Directory for processed .npz files
        max_workers: Number of parallel workers

    Returns:
        Dict with "processed_files" (datetime -> .npz path) and
        "num_files". Both empty when the directory or its files are missing.
    """
    snodas_dir = Path(snodas_dir)
    processed_dir = Path(processed_dir)
    processed_dir.mkdir(parents=True, exist_ok=True)
    if not snodas_dir.exists():
        logger.error(f"No SNODAS directory at {snodas_dir}")
        return {"processed_files": {}, "num_files": 0}
    dat_files = list(snodas_dir.glob("**/snow_depth_*.dat.gz"))
    if not dat_files:
        logger.error(f"No SNODAS .dat.gz files found in {snodas_dir}")
        return {"processed_files": {}, "num_files": 0}
    logger.info(f"Found {len(dat_files)} SNODAS data files to process")
    target_height, target_width = target_shape
    minx, miny, maxx, maxy = extent
    target_transform = rasterio.transform.from_bounds(
        minx, miny, maxx, maxy, target_width, target_height
    )
    args_list = [
        (
            f,
            extent,
            processed_dir,
            (target_height, target_width, target_transform),
            True,  # resample_to_extent
        )
        for f in dat_files
    ]
    processed_files = {}
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_map = {executor.submit(_process_single_file, arg): arg[0] for arg in args_list}
        with tqdm(total=len(dat_files), desc="Processing SNODAS files") as pbar:
            for future in as_completed(future_map):
                dat_file = future_map[future]
                try:
                    result = future.result()
                    if result is not None:
                        date, out_file = result
                        processed_files[date] = out_file
                except Exception as e:
                    # A failed worker is logged and skipped; the rest proceed.
                    logger.error(f"Error processing {dat_file}: {e}")
                finally:
                    pbar.update(1)
    logger.info(f"Successfully processed {len(processed_files)} SNODAS files")
    return {"processed_files": processed_files, "num_files": len(processed_files)}
def calculate_snow_statistics(
    input_data: Dict[str, Any],
    snow_season_start_month: int = 11,
    snow_season_end_month: int = 4,
) -> Dict[str, np.ndarray]:
    """
    Step 2: Compute snow statistics from loaded files.

    Group processed SNODAS by winter season and compute aggregated stats
    with streaming per-pixel accumulators, so only one daily grid is held
    in memory at a time. A season whose files all fail to load is skipped
    with a warning instead of crashing the run.

    Args:
        input_data: Dict with "processed_files" from batch_process_snodas_data
        snow_season_start_month: First month of snow season (default: November)
        snow_season_end_month: Last month of snow season (default: April)

    Returns:
        Dict with final statistics:
        - median_max_depth: Median of seasonal max depths
        - mean_snow_day_ratio: Average fraction of days with snow
        - interseason_cv: Year-to-year variability
        - mean_intraseason_cv: Within-winter variability
        - metadata: Processing metadata (from the first loaded file)
        - failed_files: List of (path, reason) for files that failed

    Raises:
        ValueError: If no season yields any successfully loaded files.
    """
    processed_files = input_data["processed_files"]
    failed_files = []
    # Group by winter season
    seasons = {}
    for date, filepath in processed_files.items():
        if (
            snow_season_start_month <= date.month <= 12
            or 1 <= date.month <= snow_season_end_month
        ):
            # Months after New Year belong to the winter that began the prior year.
            season_year = date.year if date.month >= snow_season_start_month else date.year - 1
            key = f"{season_year}-{season_year+1}"
            seasons.setdefault(key, []).append(filepath)
    logger.info(
        f"Computing stats for {len(seasons)} winter seasons, months {snow_season_start_month} to {snow_season_end_month}"
    )
    seasonal_stats = []
    expected_shape = None
    metadata = None
    for season, filelist in tqdm(seasons.items(), desc="Processing winter seasons"):
        # Running accumulators, bootstrapped on the first valid day.
        running_max = None
        running_sum = None
        running_sum_sq = None
        running_snow_sum = None  # sum of depth on snow-days only
        running_snow_sum_sq = None  # sum of squared depth on snow-days only
        snow_days = None
        N_successful = 0
        for fpath in tqdm(filelist, desc=f"Loading {season} data", leave=False):
            try:
                data, meta = _load_processed_snodas(fpath)
                if expected_shape is None:
                    expected_shape = data.shape
                    logger.info(f"Setting expected shape: {expected_shape}")
                elif data.shape != expected_shape:
                    logger.error(f"Shape mismatch {fpath}")
                    failed_files.append((fpath, "Shape mismatch"))
                    continue
                if running_max is None:
                    # first day -> bootstrap all accumulators
                    running_max = data.copy()
                    filled = data.filled(0.0).astype(np.float64)
                    running_sum = filled.copy()
                    running_sum_sq = filled**2
                    snow_mask = filled > 0
                    snow_days = snow_mask.astype(np.int32)
                    running_snow_sum = np.where(snow_mask, filled, 0.0)
                    running_snow_sum_sq = np.where(snow_mask, filled**2, 0.0)
                else:
                    # update maximum
                    running_max = np.maximum(running_max, data)
                    # update sum & sum of squares (all days, for interseason CV)
                    filled = data.filled(0.0).astype(np.float64)
                    running_sum += filled
                    running_sum_sq += filled**2
                    # update snow-day accumulators (snow-days only, for intraseason CV)
                    snow_mask = filled > 0
                    snow_days += snow_mask.astype(np.int32)
                    running_snow_sum += np.where(snow_mask, filled, 0.0)
                    running_snow_sum_sq += np.where(snow_mask, filled**2, 0.0)
                N_successful += 1
                if metadata is None:
                    metadata = meta
            except Exception as e:
                logger.error(f"Error loading {fpath}: {e}")
                failed_files.append((fpath, str(e)))
        # If every file in the season failed, the accumulators are still None;
        # skip the season instead of crashing on None attribute access (the
        # previous np.NaN placeholders -- removed in NumPy 2.0 -- also broke
        # the shape logging below and poisoned the cross-season aggregation).
        if N_successful == 0:
            logger.warning(f"No files loaded successfully for season {season}; skipping")
            continue
        mean_depth = running_sum / N_successful
        # Clamp variance to >=0 to guard against floating-point underflow
        std_depth = np.sqrt(
            np.maximum(running_sum_sq / N_successful - mean_depth**2, 0.0)
        )
        stats_dict = {
            "max_depth": running_max,
            "mean_depth": mean_depth,
            "std_depth": std_depth,
            "snow_days": snow_days,
            "total_days": N_successful,
        }
        logger.debug(
            f"Computed stats for season {season}: shapes -> "
            f"max {running_max.shape}, mean {mean_depth.shape}, std {std_depth.shape}"
        )
        # Intraseason CV: computed on snow-days only to avoid contamination from
        # zero-depth shoulder-season days, which inflate CV by pulling the mean
        # down while std captures the full seasonal arc (0 -> peak -> 0).
        # Result: CV measures within-season depth stability when snow is present.
        logger.debug(f"Computing intraseason cv_depth for season {season}")
        with np.errstate(divide="ignore", invalid="ignore"):
            snow_day_count = snow_days.astype(np.float64)
            mean_snow_depth = np.where(
                snow_day_count > 0, running_snow_sum / snow_day_count, 0.0
            )
            # Clamp variance to >=0 to guard against floating-point underflow
            var_snow_depth = np.where(
                snow_day_count > 1,
                np.maximum(running_snow_sum_sq / snow_day_count - mean_snow_depth**2, 0.0),
                0.0,
            )
            std_snow_depth = np.sqrt(var_snow_depth)
            cv_depth = np.where(
                mean_snow_depth > 0, std_snow_depth / mean_snow_depth, 0.0
            )
            cv_depth = np.where(np.isfinite(cv_depth), cv_depth, 0.0)
        stats_dict["cv_depth"] = cv_depth
        seasonal_stats.append(stats_dict)
        logger.info(
            f"Season {season}: "
            f"MaxDepth={float(np.ma.max(stats_dict['max_depth'])):.2f}, "
            f"MeanSnowDays={float(np.ma.mean(stats_dict['snow_days'])):.1f} / {stats_dict['total_days']}"
        )
    if not seasonal_stats:
        raise ValueError("No valid seasons found after processing SNODAS data.")
    # Aggregation across multiple seasons
    final_stats = {}
    final_stats["median_max_depth"] = np.ma.median(
        [s["max_depth"] for s in seasonal_stats], axis=0
    )
    # Snow day ratio
    ratios = [s["snow_days"] / s["total_days"] for s in seasonal_stats]
    final_stats["mean_snow_day_ratio"] = np.ma.mean(ratios, axis=0)
    # Inter-season variability (year-to-year variation in snow)
    seasonal_means = np.ma.stack([s["mean_depth"] for s in seasonal_stats])
    with np.errstate(divide="ignore", invalid="ignore"):
        interseason_cv = np.ma.std(seasonal_means, axis=0) / np.ma.mean(
            seasonal_means, axis=0
        )
    # Replace NaN/Inf with 0 for areas with no snow across all seasons
    interseason_cv = np.where(np.isfinite(interseason_cv), interseason_cv, 0.0)
    final_stats["interseason_cv"] = interseason_cv
    # Intra-season average CV (within-winter variability, averaged across years)
    mean_intraseason_cv = np.ma.mean(
        [s["cv_depth"] for s in seasonal_stats], axis=0
    )
    # Ensure no NaN values remain (should already be handled above, but be safe)
    mean_intraseason_cv = np.where(
        np.isfinite(mean_intraseason_cv), mean_intraseason_cv, 0.0
    )
    final_stats["mean_intraseason_cv"] = mean_intraseason_cv
    logger.info(
        f"Final stats computed: "
        f"median_max_depth range={float(np.nanmin(final_stats['median_max_depth'])):.1f}-"
        f"{float(np.nanmax(final_stats['median_max_depth'])):.1f}, "
        f"mean_intraseason_cv range={float(np.nanmin(final_stats['mean_intraseason_cv'])):.2f}-"
        f"{float(np.nanmax(final_stats['mean_intraseason_cv'])):.2f}"
    )
    # Add metadata to result
    final_stats["metadata"] = metadata
    final_stats["failed_files"] = failed_files
    return final_stats
def load_snodas_stats(
    terrain=None,
    snodas_dir: Optional[Path] = None,
    *,
    cache_dir: Path = Path("snodas_cache"),
    cache_name: str = "snodas",
    tile_config=None,
    mock_data: bool = False,
    mock_shape: tuple = (500, 500),
) -> Dict[str, np.ndarray]:
    """Load SNODAS snow statistics with automatic tiling, caching, and mock fallback.

    Orchestrates the standard SNODAS pipeline:
    1. batch_process_snodas_data (load and reproject)
    2. calculate_snow_statistics (aggregate seasonal stats)

    Uses GriddedDataLoader for memory-safe tiled processing. Falls back to
    mock data when real data is unavailable or loading fails.

    Args:
        terrain: Terrain object providing extent and resolution context.
            Required for real data; can be None with mock_data=True.
        snodas_dir: Directory containing SNODAS .dat.gz files.
        cache_dir: Directory for pipeline caching.
        cache_name: Base name for cache files.
        tile_config: Optional TiledDataConfig for memory-safe processing.
            If None, uses GriddedDataLoader defaults.
        mock_data: If True, return mock data without attempting real loading.
        mock_shape: Shape for mock data arrays.

    Returns:
        Dict with snow statistics arrays:
        - median_max_depth: Median of seasonal max depths (mm)
        - mean_snow_day_ratio: Average fraction of days with snow
        - interseason_cv: Year-to-year variability
        - mean_intraseason_cv: Within-winter variability
    """
    # NOTE(review): imported here rather than at module top -- presumably to
    # avoid a circular import between this module and src.terrain; confirm.
    from src.terrain.gridded_data import (
        GriddedDataLoader,
        create_mock_snow_data,
    )
    if mock_data:
        logger.info("Using mock snow data")
        return create_mock_snow_data(mock_shape)
    # Validate prerequisites for real data
    can_load = True
    if not snodas_dir:
        logger.warning("No SNODAS directory specified")
        can_load = False
    elif not Path(snodas_dir).exists():
        logger.warning("SNODAS directory not found: %s", snodas_dir)
        can_load = False
    if not terrain:
        logger.warning("Terrain object not available for SNODAS processing")
        can_load = False
    if not can_load:
        logger.info("Falling back to mock data")
        return create_mock_snow_data(mock_shape)
    try:
        logger.info("Loading real SNODAS data from: %s", snodas_dir)
        pipeline = [
            ("load_snodas", batch_process_snodas_data, {}),
            ("compute_stats", calculate_snow_statistics, {}),
        ]
        loader_kwargs: Dict[str, Any] = {
            "terrain": terrain,
            "cache_dir": cache_dir,
        }
        if tile_config is not None:
            loader_kwargs["tile_config"] = tile_config
        loader = GriddedDataLoader(**loader_kwargs)
        result = loader.run_pipeline(
            data_source=snodas_dir,
            pipeline=pipeline,
            cache_name=cache_name,
        )
        # Strip bookkeeping keys so callers receive only statistic arrays.
        snow_stats = {
            k: v for k, v in result.items()
            if k not in ("metadata", "failed_files")
        }
        failed = result.get("failed_files", [])
        logger.info("Loaded SNODAS data (%d files failed)", len(failed))
        return snow_stats
    except Exception as e:
        # Any failure in the real pipeline degrades gracefully to mock data.
        logger.warning("Failed to load SNODAS data: %s", e)
        logger.info("Falling back to mock data")
        return create_mock_snow_data(mock_shape)