Source code for src.terrain.data_loading

"""
Data loading operations for terrain processing.

This module contains functions for loading and merging DEM (Digital Elevation Model)
files from various sources.
"""

import logging
import zipfile
from pathlib import Path
import numpy as np
import rasterio
from rasterio.merge import merge
from tqdm import tqdm

logger = logging.getLogger(__name__)


def _extract_dem_from_zips(directory: Path, pattern: str = "*.hgt") -> int:
    """
    Extract DEM files from ZIP archives if no matching files are found.

    NASADEM downloads come as ZIP files (e.g., NASADEM_HGT_N32W117.zip).
    This function automatically extracts matching files from ZIPs when needed.

    Archive members are selected by matching their base name against
    ``pattern`` case-insensitively using ``fnmatch``, so any glob-style
    pattern is honoured (``*.hgt``, ``N4*.hgt``, ...) rather than only
    patterns of the form ``*.<ext>``. Directory entries are skipped.
    Members are extracted with their archive-relative path preserved.

    Args:
        directory: Directory containing ZIP files
        pattern: File pattern to extract (e.g., "*.hgt")

    Returns:
        Number of files extracted. 0 when matching files already exist in
        ``directory`` or when it contains no ZIP archives.
    """
    import fnmatch

    # Nothing to do when at least one matching file is already present.
    if any(directory.glob(pattern)):
        return 0

    zip_files = list(directory.glob("*.zip"))
    if not zip_files:
        return 0  # No ZIP files to extract

    logger.info(f"No {pattern} files found, extracting from {len(zip_files)} ZIP archives...")

    # Lower-case once; member names are lower-cased too so that e.g.
    # "n32w117.HGT" still matches "*.hgt".
    wanted = pattern.lower()

    extracted_count = 0
    for zip_path in zip_files:
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                for info in zf.infolist():
                    if info.is_dir():
                        continue

                    # ZIP member names always use "/" as the separator.
                    base_name = info.filename.rsplit("/", 1)[-1]
                    if not fnmatch.fnmatchcase(base_name.lower(), wanted):
                        continue

                    zf.extract(info, directory)
                    extracted_count += 1
                    logger.debug(f"  Extracted {info.filename} from {zip_path.name}")

        except zipfile.BadZipFile:
            logger.warning(f"Skipping invalid ZIP file: {zip_path.name}")
        except Exception as e:
            # Best effort: one unreadable archive must not abort the others.
            logger.warning(f"Failed to extract from {zip_path.name}: {e}")

    if extracted_count > 0:
        logger.info(f"Extracted {extracted_count} files from ZIP archives")

    return extracted_count


def load_dem_files(
    directory_path: str, pattern: str = "*.hgt", recursive: bool = False
) -> tuple[np.ndarray, rasterio.Affine]:
    """
    Load and merge DEM files from a directory into a single elevation dataset.
    Supports any raster format readable by rasterio (HGT, GeoTIFF, etc.).

    Automatically extracts files from ZIP archives if no matching files are found.
    This is useful for NASADEM downloads which come as ZIP files.

    Files that cannot be opened, have no raster bands, or whose first band is
    not one of int16/int32/float32/float64 are skipped with a log message.
    Every dataset opened here is closed again before the function returns or
    raises.

    Args:
        directory_path: Path to directory containing DEM files (or ZIP archives)
        pattern: File pattern to match (default: ``*.hgt``)
        recursive: Whether to search subdirectories recursively (default: False)

    Returns:
        tuple: (merged_dem, transform) where:
            - merged_dem: numpy array containing the first band of the merged data
            - transform: affine transform mapping pixel to geographic coordinates

    Raises:
        ValueError: If no valid DEM files are found or directory doesn't exist
        OSError: If directory access fails or file reading fails
        rasterio.errors.RasterioIOError: If there are issues reading the DEM files
    """
    logger.info(f"Searching for DEM files matching '{pattern}' in: {directory_path}")

    try:
        directory = Path(directory_path)

        if not directory.exists():
            raise ValueError(f"Directory does not exist: {directory}")

        if not directory.is_dir():
            raise ValueError(f"Path is not a directory: {directory}")

        # Extract from ZIP archives if no matching files exist
        _extract_dem_from_zips(directory, pattern)

        # Find all matching files (sorted so the merge order is deterministic)
        glob_func = directory.rglob if recursive else directory.glob
        dem_files = sorted(glob_func(pattern))

        if not dem_files:
            raise ValueError(f"No files matching '{pattern}' found in {directory}")

        dem_datasets = []
        # The finally below covers opening as well as merging, so datasets
        # are released even if the loop is interrupted part-way through.
        try:
            with tqdm(dem_files, desc="Opening DEM files") as pbar:
                for file in pbar:
                    ds = None
                    accepted = False
                    try:
                        ds = rasterio.open(file)

                        # Basic validation
                        if ds.count == 0:
                            logger.warning(f"No raster bands found in {file}")
                        elif ds.dtypes[0] not in ("int16", "int32", "float32", "float64"):
                            logger.warning(f"Unexpected data type in {file}: {ds.dtypes[0]}")
                        else:
                            dem_datasets.append(ds)
                            accepted = True
                            pbar.set_postfix({"opened": len(dem_datasets)})

                    except rasterio.errors.RasterioIOError as e:
                        logger.warning(f"Failed to open {file}: {e}")
                    except Exception as e:
                        logger.error(f"Unexpected error with {file}: {e}")
                    finally:
                        # Rejected or half-validated datasets are closed
                        # immediately; accepted ones are closed after merging.
                        if ds is not None and not accepted:
                            ds.close()

            if not dem_datasets:
                raise ValueError("No valid DEM files could be opened")

            logger.info(f"Successfully opened {len(dem_datasets)} DEM files")

            # Merge datasets
            with rasterio.Env():
                merged_dem, transform = merge(dem_datasets)

                # Extract first band - merge() returns 3D array (bands, height, width)
                merged_dem = merged_dem[0]

                logger.info("Successfully merged DEMs:")
                logger.info(f"  Output shape: {merged_dem.shape}")
                logger.info(
                    f"  Value range: {np.nanmin(merged_dem):.2f} to {np.nanmax(merged_dem):.2f}"
                )
                logger.info(f"  Transform: {transform}")

                return merged_dem, transform

        finally:
            # Clean up
            for opened in dem_datasets:
                opened.close()

    except Exception as e:
        # Log for context, then let the caller see the original exception.
        logger.error(f"Error processing DEM files: {e}")
        raise


def load_score_grid(
    file_path: Path,
    data_keys: list[str] | None = None,
) -> tuple[np.ndarray, rasterio.Affine | None]:
    """
    Load georeferenced raster data from an NPZ file.

    Works with any NPZ file containing a 2D array and optional Affine transform.
    Common use cases: score grids, classification maps, derived terrain products.

    The function searches for data arrays using the provided keys, falling back
    to common key names and finally to the first available array.

    The archive is opened in a ``with`` block so its file handle is released
    before returning; the arrays are read out of it while it is still open.

    Args:
        file_path: Path to .npz file
        data_keys: Keys to try for data array. If None, tries ["data", "score", "values"]
            then falls back to first array in file.

    Returns:
        Tuple of (data_array, transform) where:
            - data_array: numpy array with the raster data
            - transform: Affine transform built from the first six values stored
              under the "transform" key, or None if that key is absent

    Raises:
        FileNotFoundError: If file doesn't exist
        ValueError: If file contains no arrays other than "transform"/"crs"

    Example:
        >>> scores, transform = load_score_grid("path/to/scores.npz")
        >>> if transform:
        ...     terrain.add_data_layer("scores", scores, transform, "EPSG:4326")
    """
    file_path = Path(file_path)

    if not file_path.exists():
        raise FileNotFoundError(f"Score file not found: {file_path}")

    logger.info(f"Loading score grid from {file_path}")

    # Determine keys to search for data
    if data_keys is None:
        data_keys = ["data", "score", "values"]

    with np.load(file_path) as data:
        # Find the data array under the first preferred key that exists
        data_array = None
        for key in data_keys:
            if key in data:
                data_array = data[key]
                logger.debug(f"  Found data under key '{key}'")
                break

        # Fallback to first available array, ignoring known metadata keys
        if data_array is None:
            array_keys = [k for k in data.files if k not in ("transform", "crs")]
            if not array_keys:
                raise ValueError(f"No data arrays found in {file_path}")
            first_key = array_keys[0]
            data_array = data[first_key]
            logger.debug(f"  Using first available array under key '{first_key}'")

        # Extract transform if present
        transform = None
        if "transform" in data:
            t = data["transform"]
            transform = rasterio.Affine(t[0], t[1], t[2], t[3], t[4], t[5])
            logger.debug(f"  Loaded transform: origin=({t[2]:.4f}, {t[5]:.4f})")

    logger.info(f"  Shape: {data_array.shape}, dtype: {data_array.dtype}")
    if transform is not None:
        logger.info(f"  Transform: pixel size=({transform.a:.6f}, {transform.e:.6f})")
    else:
        logger.info("  No transform metadata (will need same_extent_as for alignment)")

    return data_array, transform


def save_score_grid(
    file_path: Path,
    data: np.ndarray,
    transform: "rasterio.Affine | None" = None,
    data_key: str = "data",
    **metadata,
) -> Path:
    """
    Save georeferenced raster data to an NPZ file.

    Creates an NPZ file compatible with load_score_grid(). The transform is
    stored as a 6-element array that can be reconstructed as an Affine.

    ``np.savez`` appends ``.npz`` to file names that do not already end with
    it. The same rule is applied to ``file_path`` up front so that the
    returned path always names the file that was actually written.

    Args:
        file_path: Output path for .npz file (``.npz`` is appended if missing)
        data: 2D numpy array with raster data
        transform: Optional Affine transform for georeferencing
        data_key: Key name for the data array (default: "data")
        **metadata: Additional key=value pairs to store in the file. A key
            equal to ``data_key`` replaces the data array in the archive.

    Returns:
        Path to the saved file

    Example:
        >>> from rasterio import Affine
        >>> scores = compute_sledding_scores(dem)
        >>> transform = Affine.translation(-83.5, 42.5) * Affine.scale(0.01, -0.01)
        >>> save_score_grid("scores.npz", scores, transform, crs="EPSG:4326")
        >>> # Load it back
        >>> loaded_scores, loaded_transform = load_score_grid("scores.npz")
    """
    file_path = Path(file_path)
    # Mirror np.savez's own suffix handling so the return value is accurate.
    if not str(file_path).endswith(".npz"):
        file_path = Path(str(file_path) + ".npz")
    file_path.parent.mkdir(parents=True, exist_ok=True)

    # Build save dict
    save_dict = {data_key: data}

    # Add transform as array (a, b, c, d, e, f)
    if transform is not None:
        save_dict["transform"] = [
            transform.a,
            transform.b,
            transform.c,
            transform.d,
            transform.e,
            transform.f,
        ]

    # Add any additional metadata
    save_dict.update(metadata)

    # Save
    np.savez(file_path, **save_dict)

    logger.info(f"Saved score grid to {file_path}")
    logger.info(f" Shape: {data.shape}, dtype: {data.dtype}")
    if transform is not None:
        logger.info(f" Transform: origin=({transform.c:.4f}, {transform.f:.4f})")

    return file_path
def find_score_file(
    name: str,
    search_dirs: list[Path] = None,
    subdirs: list[str] = None,
) -> Path | None:
    """
    Locate a score file by probing a list of candidate directories.

    Pre-computed score files can end up in different places depending on how
    the pipeline was run; this walks ``search_dirs`` in order and, inside
    each, every entry of ``subdirs`` in order, returning the first path that
    exists.

    Args:
        name: Base filename to search for (e.g., "sledding_scores.npz")
        search_dirs: Directories to search. Defaults to ``docs/images``,
            ``examples/output``, ``output`` and the current directory.
        subdirs: Subdirectories to check within each search dir
            (e.g., ["sledding", "xc_skiing"]). An empty string means the
            search dir itself. Defaults to ``[""]``.

    Returns:
        Path to found file, or None if not found

    Example:
        >>> path = find_score_file("sledding_scores.npz",
        ...     search_dirs=[Path("docs/images"), Path("output")],
        ...     subdirs=["sledding", ""])
        >>> if path:
        ...     scores, transform = load_score_grid(path)
    """
    if search_dirs is None:
        roots = [
            Path("docs/images"),
            Path("examples/output"),
            Path("output"),
            Path("."),
        ]
    else:
        roots = search_dirs

    # [""] means "look directly inside each root".
    folders = [""] if subdirs is None else subdirs

    # Lazily enumerate candidates so probing stops at the first hit.
    candidates = (
        (root / folder / name) if folder else (root / name)
        for root in roots
        for folder in folders
    )
    hit = next((candidate for candidate in candidates if candidate.exists()), None)

    if hit is None:
        logger.debug(f"Score file not found: {name}")
        return None

    logger.debug(f"Found score file: {hit}")
    return hit
# ============================================================================= # HGT FILE UTILITIES # ============================================================================= def parse_hgt_filename(filename: str | Path) -> tuple[int | None, int | None]: """ Parse SRTM HGT filename to extract latitude and longitude. Works globally with standard SRTM naming convention: - N42W083.hgt (Northern/Western hemisphere) -> lat=42, lon=-83 - S15E028.hgt (Southern/Eastern hemisphere) -> lat=-15, lon=28 Args: filename: HGT filename or Path (e.g., "N42W083.hgt" or Path("/path/to/N42W083.hgt")) Returns: Tuple of (latitude, longitude) as signed integers, or (None, None) if invalid """ import re # Extract just the filename if it's a path name = Path(filename).stem.upper() # Match pattern: N/S followed by 2 digits, then E/W followed by 3 digits pattern = r'^([NS])(\d{2})([EW])(\d{3})$' match = re.match(pattern, name) if not match: return None, None ns, lat_str, ew, lon_str = match.groups() lat = int(lat_str) lon = int(lon_str) # Apply sign based on hemisphere if ns == 'S': lat = -lat if ew == 'W': lon = -lon return lat, lon
def load_filtered_hgt_files(
    dem_dir: Path | str,
    min_latitude: int | None = None,
    max_latitude: int | None = None,
    min_longitude: int | None = None,
    max_longitude: int | None = None,
    bbox: tuple | None = None,
    pattern: str = "*.hgt",
) -> tuple[np.ndarray, rasterio.Affine]:
    """
    Load SRTM HGT files filtered by latitude/longitude range.

    Works globally with standard SRTM naming convention. Filters files before
    loading to reduce memory usage for large DEM directories.

    Filtering compares the signed integer latitude/longitude parsed from each
    file name (see parse_hgt_filename) against the inclusive bounds. Files
    whose names cannot be parsed are skipped. Candidate files are processed in
    sorted order so the merge order does not depend on the filesystem. Every
    dataset opened here is closed before the function returns or raises.

    Args:
        dem_dir: Directory containing HGT files
        min_latitude: Southern bound (e.g., -45 for S45, 42 for N42)
        max_latitude: Northern bound (e.g., 60 for N60)
        min_longitude: Western bound (e.g., -120 for W120)
        max_longitude: Eastern bound (e.g., 30 for E30)
        bbox: Bounding box as (west, south, east, north) tuple. If provided, overrides
            individual min/max parameters. Uses standard GIS convention:
            (min_lon, min_lat, max_lon, max_lat).
        pattern: File pattern to match (default: ``*.hgt``)

    Returns:
        Tuple of (merged_dem, transform)

    Raises:
        ValueError: If no files match ``pattern``, none remain after
            filtering, or none of the remaining files could be opened

    Example:
        >>> # Load only tiles in Michigan area using individual params
        >>> dem, transform = load_filtered_hgt_files(
        ...     "/path/to/srtm",
        ...     min_latitude=41, max_latitude=47,
        ...     min_longitude=-90, max_longitude=-82
        ... )

        >>> # Same area using bbox (west, south, east, north)
        >>> dem, transform = load_filtered_hgt_files(
        ...     "/path/to/srtm",
        ...     bbox=(-90, 41, -82, 47)
        ... )

        >>> # Alps region (Switzerland/Austria)
        >>> dem, transform = load_filtered_hgt_files(
        ...     "/path/to/srtm",
        ...     bbox=(5, 45, 15, 48)
        ... )
    """
    # If bbox provided, unpack into individual params
    if bbox is not None:
        min_longitude, min_latitude, max_longitude, max_latitude = bbox

    dem_dir = Path(dem_dir)
    logger.info(f"Loading HGT files from {dem_dir}")

    if min_latitude is not None or max_latitude is not None:
        logger.info(f" Latitude filter: [{min_latitude}, {max_latitude}]")
    if min_longitude is not None or max_longitude is not None:
        logger.info(f" Longitude filter: [{min_longitude}, {max_longitude}]")

    # Find all matching files; sorted for a deterministic merge order
    all_files = sorted(dem_dir.glob(pattern))

    if not all_files:
        raise ValueError(f"No files matching '{pattern}' found in {dem_dir}")

    # Filter by lat/lon if specified
    filtered_files = []
    for f in all_files:
        lat, lon = parse_hgt_filename(f)

        # Skip files that couldn't be parsed
        if lat is None or lon is None:
            continue

        # Apply filters (bounds are inclusive; None means unbounded)
        if min_latitude is not None and lat < min_latitude:
            continue
        if max_latitude is not None and lat > max_latitude:
            continue
        if min_longitude is not None and lon < min_longitude:
            continue
        if max_longitude is not None and lon > max_longitude:
            continue

        filtered_files.append(f)

    if not filtered_files:
        raise ValueError(
            f"No HGT files found matching lat/lon filters in {dem_dir}"
        )

    logger.info(f" Found {len(filtered_files)} files after filtering (from {len(all_files)} total)")

    # Load the filtered files directly using rasterio. The finally covers the
    # opening loop too, so nothing stays open if a later file fails.
    dem_datasets = []
    try:
        for file in filtered_files:
            try:
                ds = rasterio.open(file)
            except rasterio.errors.RasterioIOError as e:
                logger.warning(f"Failed to open {file}: {str(e)}")
                continue

            if ds.count > 0:
                dem_datasets.append(ds)
            else:
                # Band-less datasets are not merged; release them right away.
                ds.close()

        if not dem_datasets:
            raise ValueError("No valid HGT files could be opened after filtering")

        with rasterio.Env():
            merged_dem, transform = merge(dem_datasets)
            merged_dem = merged_dem[0]  # Extract first band

            logger.info(f" Merged {len(dem_datasets)} HGT files:")
            logger.info(f" Shape: {merged_dem.shape}")
            logger.info(f" Value range: {np.nanmin(merged_dem):.2f} to {np.nanmax(merged_dem):.2f}")

            return merged_dem, transform

    finally:
        for opened in dem_datasets:
            opened.close()
def load_geotiff_cropped_to_dem(
    geotiff_path: Path,
    dem_shape: tuple,
    dem_transform,
    dem_crs: str,
    use_windowed_read: bool = True,
) -> tuple:
    """
    Load a GeoTIFF file cropped to DEM's geographic bounds.

    This is a common pattern for loading auxiliary data (precipitation, land cover, etc.)
    that needs to be aligned with a DEM. The function:

    1. Crops to DEM's geographic bounds (via windowed reading if same CRS)
    2. Returns data with its transform for further processing

    Only band 1 is read, and it is always converted to float32.

    Args:
        geotiff_path: Path to GeoTIFF file (e.g., precipitation, land cover);
            a plain string path is accepted as well
        dem_shape: DEM shape (height, width) as tuple
        dem_transform: DEM's affine transform
        dem_crs: DEM's coordinate reference system (e.g., "EPSG:4326")
        use_windowed_read: If True and CRS match, use windowed reading for efficiency

    Returns:
        tuple: (data, transform, crs) where:
            - data: np.ndarray (float32) cropped to DEM bounds, or the full
              band when cropping was not possible
            - transform: Affine transform for the returned data
            - crs: Coordinate reference system of the GeoTIFF

    Example:
        >>> precip_data, precip_transform, precip_crs = load_geotiff_cropped_to_dem(
        ...     precip_path, dem.shape, dem_transform, "EPSG:4326"
        ... )
        >>> # precip_data is now cropped to DEM's geographic bounds

    Notes:
        - If CRS match and windowed reading is enabled, only loads overlapping region
          (memory efficient)
        - If CRS differ, loads full file (caller should use rasterio.reproject)
        - Falls back to full read if windowed read fails
    """
    geotiff_path = Path(geotiff_path)
    logger.info(f"Loading {geotiff_path.name} cropped to DEM bounds...")

    with rasterio.open(geotiff_path) as src:
        src_crs = src.crs

        # If CRS differ, can't use windowed reading - return full file
        if src_crs != dem_crs:
            logger.info(f" CRS differ ({src_crs} vs {dem_crs}), loading full file for reprojection")
            data = src.read(1).astype(np.float32)
            return data, src.transform, src_crs

        # Same CRS - use windowed reading if enabled
        if use_windowed_read:
            try:
                from rasterio.windows import from_bounds
                from rasterio.transform import array_bounds

                # Get DEM bounds
                dem_bounds = array_bounds(dem_shape[0], dem_shape[1], dem_transform)

                # Calculate window that overlaps with DEM
                # NOTE(review): the window is not clipped to the source
                # raster; confirm the returned transform still matches the
                # data when the DEM extends beyond the GeoTIFF's extent.
                window = from_bounds(*dem_bounds, transform=src.transform)

                # Read windowed data
                data = src.read(1, window=window).astype(np.float32)
                transform = src.window_transform(window)

                logger.info(f" ✓ Cropped {src.shape} -> {data.shape} using windowed read")
                return data, transform, src_crs

            except Exception as e:
                # Deliberate best effort: any failure degrades to a full read.
                logger.warning(f" Windowed read failed ({e}), falling back to full read")

        # Fallback: load full file
        logger.info(f" Loading full file ({src.shape})")
        data = src.read(1).astype(np.float32)
        return data, src.transform, src_crs