# Source code for nwm_region_mgr.utils.config_utils

"""Utility functions for loading and validating configuration files using Pydantic and YAML.

config_utils.py

Classes/Functions:
    - LoggingConfig: Pydantic model for logging configuration.
    - BaseGeneralConfig: Pydantic model for general settings of the application.
    - BaseOutputConfig: Pydantic model for output settings of the application.
    - BaseConfig: Pydantic model for the base configuration of the application.
    - BaseConfigProcessor: base class for processing and validating configurations
    - _deep_merge_configs: Recursively merge two dictionaries, with values from dict #2 overwriting those in dict #1.
    - _load_and_validate_config: Load a YAML file, validate its structure using Pydantic
    - _substitute_placeholders: Substitute placeholders in the config with actual values.
    - load_and_process_config: Load, validate, process, and save the configuration files.
    - _validate_paths: Validate that all file and directory paths exist.
    - _assemble_file_paths: Assemble file paths from the configuration.
    - _required_columns_calval_stats: Return a set of required columns for calibration/validation statistics.
    - _file_required_column_map: Return a dictionary mapping files to required columns.
    - PydanticDictLike: Stand-in for dictionary-like behavior when you want specificity of a pydantic model.
    - FieldCrosswalk: Mapping of column names for unique identifiers in all required files for regionalization.
    - LayerCrosswalk: Dictionary mapping layer names for hydrofabric files.

"""

import logging
import os
import re
from contextlib import contextmanager
from functools import lru_cache, reduce
from pathlib import Path
from time import time
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

import geopandas as gpd
import pandas as pd
import yaml
from pydantic import BaseModel, Field, ValidationError, model_validator

from nwm_region_mgr.utils.dict_utils import flatten_dict
from nwm_region_mgr.utils.io_utils import read_table, save_data
from nwm_region_mgr.utils.logging_utils import setup_logging
from nwm_region_mgr.utils.plot_utils import plot_histogram, plot_spatial_map
from nwm_region_mgr.utils.string_utils import recursive_substitute
from nwm_region_mgr.utils.validation_utils import (
    check_columns_dataframe,
    check_columns_hydrofabric,
    check_options,
)

logger = logging.getLogger(__name__)


class LoggingConfig(BaseModel):
    """Logging configuration for the application.

    Controls the log level and, optionally, logging to a file in addition to
    the console.
    """

    level: Literal["debug", "info", "warning", "error", "critical"] = Field(
        description="Logging level.", examples="debug", default="info"
    )
    log_to_file: bool = Field(
        description=(
            "Whether to log to a file. If set to True, logging messages will be written to "
            "the specified log file, in addition to the console."
        ),
        default=False,
        examples=False,
    )
    file: str | None = Field(
        description="Path to the log file. If not provided, logging will be written to console only.",
        examples="logfile.log",
        default=None,
    )

    @model_validator(mode="after")
    def check_log_level(self) -> "LoggingConfig":
        """Ensure that the log level is valid.

        Raises:
            ValueError: If the upper-cased level is not a standard Python
                logging level (presumably raised by ``check_options`` — confirm
                against ``validation_utils``).
        """
        # BUGFIX: the list previously contained the Java-style levels "SEVERE"
        # and "FATAL" instead of Python's "ERROR" and "CRITICAL", so the levels
        # "error" and "critical" — explicitly allowed by the Literal annotation
        # on `level` — were rejected here.
        valid_levels = [
            "DEBUG",
            "INFO",
            "WARNING",
            "ERROR",
            "CRITICAL",
        ]  # use standard Python log levels
        check_options(self.level.upper(), valid_levels, "log level")
        return self

    @model_validator(mode="after")
    def check_log_file(self) -> "LoggingConfig":
        """Ensure that the log file path is valid if logging to a file.

        Raises:
            ValueError: If 'log_to_file' is True but no 'file' is given.
        """
        if self.log_to_file and not self.file:
            # Previously this inconsistency was only logged and the invalid
            # config was accepted; raise so validation fails fast, matching
            # the behavior of the other validators in this module.
            msg = "If 'log_to_file' is True, 'file' must be specified."
            logger.error(msg)
            raise ValueError(msg)
        return self
class PydanticDictLike(BaseModel):
    """Stand-in for dictionary-like behavior when you want specificity of a pydantic model."""

    @property
    def as_dict(self) -> dict[str, Any]:
        """The model's contents dumped to a plain dictionary."""
        return self.model_dump()

    def get(self, key: str, default=None):
        """Look up *key* as an attribute, falling back to *default* when absent."""
        return getattr(self, key, default)

    def items(self) -> Iterator[Tuple[str, Any]]:
        """Iterate over (field name, value) pairs of the dumped model."""
        return self.model_dump().items()

    def keys(self) -> Iterator[str]:
        """Iterate over the field names of the dumped model."""
        return self.model_dump().keys()

    def values(self) -> Iterator[Any]:
        """Iterate over the field values of the dumped model."""
        return self.model_dump().values()

    def lower_case(self) -> "PydanticDictLike":
        """Build a new instance with every string field converted to lower case."""
        lowered: dict[str, Any] = {}
        for field_name, field_value in self.model_dump().items():
            if isinstance(field_value, str):
                field_value = field_value.lower()
            lowered[field_name] = field_value
        return self.__class__(**lowered)
class FieldCrosswalk(PydanticDictLike):
    """Mapping of column names for unique identifiers in all required files for regionalization.

    Each field holds the column name used for that identifier in the input
    tables/hydrofabric files, so differently-named sources can be crosswalked.
    """

    # Column holding the divide (catchment) identifier, e.g. in the NextGen hydrofabric.
    divide: str = Field(
        description="Column name for divide (catchment) ID.", default="divide_id"
    )
    # Column holding the gage (basin) identifier.
    gage: str = Field(description="Column name for gage (basin) ID.", default="gage_id")
    # Column holding the HUC12 identifier.
    huc12: str = Field(description="Column name for HUC12 ID.", default="huc_12")
    # Column holding the vector processing unit (VPU) identifier.
    vpu: str = Field(description="Column name for VPU ID.", default="vpuid")
    # Column holding the drainage area (default name suggests km^2).
    drainage_area: str = Field(
        description="Column name for drainage area.", default="areasqkm"
    )
class LayerCrosswalk(PydanticDictLike):
    """Dictionary mapping layer names for hydrofabric files.

    Identifies which layer to read from each (multi-layer) hydrofabric file.
    """

    # Layer name inside the HUC12 hydrofabric file.
    huc12: str = Field(
        description="Layer name for HUC12 hydrofabric file.",
        default="WBDSnapshot_National",
    )
    # Layer name inside the NextGen hydrofabric file.
    ngen: str = Field(
        description="Layer name for NextGen hydrofabric file.", default="divides"
    )
class BaseGeneralConfig(BaseModel):
    """Base general settings for the formulation regionalization application.

    Bundles run identity, domain/VPU selection, input file locations, and the
    column/layer name crosswalks shared by the regionalization workflows.
    """

    # --- Run identity and spatial scope ---------------------------------
    run_name: str = Field(
        description="Name of the run, used to create output folders and files.",
        examples="test",
        default="test",
    )
    domain: Literal["conus", "ak", "hi", "prvi"] = Field(
        description="Which National Water Model Domain this run uses.",
        examples="conus",
        default="conus",
    )
    vpu_list: Union[List[str], str] = Field(
        description=(
            "List of vector processing units (VPUs) to be processed within the domain. "
            "Set to 'all' to process all VPUs in the domain (not recommended for conus since there are many VPUs)."
        ),
        examples=["03S"],
        default=["03S"],
    )
    n_procs: int = Field(
        description="Number of processors to use for parallel processing. Set to -1 to use all available processors.",
        examples=2,
        default=-1,
    )
    # --- Directories and input files ------------------------------------
    base_dir: str = Field(
        description="Path to base directory for input/output files.",
        examples="/root/nwm-region-mgr/data/",
        default="./data/",
    )
    static_data_dir: str = Field(
        description="Path to static data directory containing hydrofabric and other static input files.",
        examples="/ngencerf-app/nwm-region-mgr/inputs/static_data/",
        default="/ngencerf-app/nwm-region-mgr/inputs/static_data/",
    )
    ngen_hydrofabric_file: Path | str | Dict[str, Path] | Dict[str, str] = Field(
        description=(
            "Path to NextGen hydrofabric file. Can be: 1) a single file path (Path or str), e.g., 'vpu_01.gpkg' "
            "or 2) a dictionary mapping VPU strings to file paths, e.g., {'09': 'vpu_09.gpkg'}."
            "If providing a string with placeholders like {vpu_list}, they will be substituted accordingly and "
            "expanded to a dictionary mapping each VPU to its corresponding file."
            " This file must include columns 'divide_id', 'vpuid' and 'geometry'."
        ),
        examples="{base_dir}/inputs/hydrofabric/vpu_09.gpkg",
        default="vpu_03S.gpkg",
    )
    gage_divide_cwt_file: Path | str = Field(
        description="Path to CSV or parquet file with gage divide CWTs, with columns 'divide_id' and 'gage_id'.",
        examples="{base_dir}/inputs/calib_gage_divide_{domain}.parquet",
        default="calib_gage_divide_{domain}.parquet",
    )
    donor_gage_file: Path | str = Field(
        description="Path to CSV file with donor gage information, including 'gage_id', 'longitude', and 'latitude'.",
        examples="{base_dir}/inputs/gages_nwm4_calib_all.csv",
        default="gages_nwm4_calib_all.csv",
    )
    calval_stats_file: Path | str = Field(
        description=(
            "Path to CSV or parquet file with calibration/validation statistics for all calibration gages "
            "and formulations, e.g., 'stat_calval_all_conus.parquet', 'stat_calval_all_conus.csv'. "
            "Must include columns for 'gage_id', 'formulation', and relevant metrics to be used for formulation "
            "and parameter regionalization."
        ),
        examples=["stat_calval_all_{domain}.csv", "stat_calval_all_{domain}.parquet"],
        default="stat_calval_all_{domain}.parquet",
    )
    calib_param_file: Path | str = Field(
        description=(
            "Path to CSV or parquet file containing calibrated parameters for all calibration gages "
            "and formulations in the domain. Must include columns for 'gage_id', 'formulation', and "
            "calibrated parameters."
        ),
        examples=["calib_params_{domain}.csv", "calib_params_{domain}.parquet"],
        default="calib_params_{domain}.csv",
    )
    # --- Behavior switches and crosswalks -------------------------------
    approach_calib_basins: Literal["regionalization", "summary_score"] = Field(
        description=(
            "Strategy for assigning formulations to calibrated basins. Valid options are 'regionalization' "
            "(assign the formulation chosen for the region) or 'summary_score' (assign based on formulation "
            "summary scores for the calibrated basin)."
        ),
        examples=["regionalization", "summary_score"],
        default="summary_score",
    )
    id_col: FieldCrosswalk = Field(
        description="Dictionary mapping column names for unique identifiers in all applicable files.",
        examples={
            "divide": "divide_id",
            "gage": "gage_id",
            "huc12": "huc_12",
            "vpu": "vpuid",
            "drainage_area": "areasqkm",
        },
        default_factory=FieldCrosswalk,
    )
    layer_name: LayerCrosswalk = Field(
        description=(
            "Dictionary mapping layer names for hydrofabric files. "
            "Identifies the layer in each hydrofabric file to be used during regionalization."
        ),
        examples={
            "huc12": "WBDSnapshot_National",
            "ngen": "divides",
        },
        default_factory=LayerCrosswalk,
    )
    logging: LoggingConfig = Field(
        description="Logging configuration for the application.",
        examples={
            "level": "info",
            "log_to_file": True,
            "file": "logs/{run_name}.log",
        },
        default_factory=LoggingConfig,
    )

    @model_validator(mode="after")
    def lower_case_ids(self) -> "BaseGeneralConfig":
        """Ensure that all ID columns are in lower case.

        Normalizes the crosswalk values so downstream column lookups are
        case-insensitive.
        """
        # NOTE(review): this also lower-cases layer names (e.g. the default
        # "WBDSnapshot_National" becomes "wbdsnapshot_national"); verify the
        # hydrofabric readers treat layer names case-insensitively.
        if self.id_col:
            self.id_col = self.id_col.lower_case()
        if self.layer_name:
            self.layer_name = self.layer_name.lower_case()
        return self

    @model_validator(mode="after")
    def check_vpu_list(self) -> "BaseGeneralConfig":
        """Ensure that the VPU list is valid.

        Expands the literal string 'all' to every VPU of the selected domain,
        otherwise verifies each requested VPU belongs to the domain.

        Raises:
            ValueError: If a VPU is invalid for the domain, or 'vpu_list' is
                neither a list nor the string 'all'.
        """
        # Known VPU identifiers per NWM domain.
        valid_vpus = {
            "conus": [
                "01", "02", "03N", "03S", "03W", "04", "05", "06", "07",
                "08", "09", "10L", "10U", "11", "12", "13", "14", "15",
                "16", "17", "18",
            ],
            "ak": ["ak"],
            "hi": ["hi"],
            "prvi": ["prvi"],
        }
        if isinstance(self.vpu_list, str) and self.vpu_list.lower() == "all":
            # 'all' expands to the full VPU list of the domain.
            self.vpu_list = valid_vpus[self.domain]
        elif isinstance(self.vpu_list, list):
            for vpu in self.vpu_list:
                if vpu not in valid_vpus[self.domain]:
                    msg = f"Invalid VPU '{vpu}' for domain '{self.domain}'. Valid options are: {valid_vpus[self.domain]}"
                    logger.error(msg)
                    raise ValueError(msg)
        else:
            msg = f"'vpu_list' must be a list of VPUs or 'all'. Got: {self.vpu_list}"
            logger.error(msg)
            raise ValueError(msg)
        return self
class BaseOutputConfig(BaseModel):
    """Base Output Manager.

    Describes a single output product: whether it is written ('save'), where
    ('path' — either a full file path, or a directory combined with 'stem' and
    'format'), and optional plot settings ('plots'/'plot_path') used to emit
    histogram and spatial-map PNGs.
    """

    save: bool = Field(
        description="Whether to save output files",
        default=True,
        examples=True,
    )
    path: Path | str | None = Field(
        description="Path to save output file or files. If a directory, the 'stem' and 'format' must be specified.",
        examples=None,
        default=None,
    )
    stem: Optional[str | Dict[str, str]] = Field(
        description="File stem for output files, used to create unique file names based on the path.",
        default=None,
        examples=None,
    )
    stem_suffix: Optional[str] = Field(
        description="Suffix for the file stem, used to create unique file names based on the path for specific needs.",
        default=None,
        examples=None,
    )
    format: Optional[str] = Field(
        description="File format for output files, e.g., 'parquet', 'csv', 'yaml'. If not specified, the path must be a file.",
        default=None,
        examples=None,
    )
    plots: Optional[Dict[str, Any]] = Field(
        description="Configuration for output plots, if applicable.",
        default=None,
        examples=None,
    )
    plot_path: Optional[str] = Field(
        description=(
            "Path to save output plots, if applicable. If not specified, plots will be saved "
            "in a subfolder 'plots' in the defined output path."
        ),
        default=None,
        examples=None,
    )

    @model_validator(mode="after")
    def check_plot_path(cls, values):
        """Check if plot path is valid.

        Defaults 'plot_path' to a 'plots' subfolder of the main output path.
        """
        if not values.plot_path:
            values.plot_path = f"{values.path}/plots"
            logger.debug(f"Plot path not specified, using default: {values.plot_path}")
        return values

    @model_validator(mode="after")
    def check_plot_config(cls, values):
        """Check if plot configuration is valid.

        Raises:
            ValueError: If 'plots' is not a dict, or contains unsupported keys.
        """
        if values.plots is not None:
            if not isinstance(values.plots, dict):
                msg = f"'plots' must be a dictionary, got {type(values.plots)}"
                logger.error(msg)
                raise ValueError(msg)
            # only "histogram" and "spatial_map" are supported, currently
            check_options(
                list(values.plots.keys()),
                ["histogram", "spatial_map", "columns_to_plot"],
                "plot keys",
            )
        return values

    def _resolve_stem(self, vpu: str = None, algorithm: str = None) -> str:
        """Resolve the file stem from 'stem' (a str, or a VPU/algorithm-keyed dict).

        Raises:
            ValueError: If 'stem' has an invalid type or no stem can be found.
        """
        if isinstance(self.stem, dict):
            # If stem is a dict (for different VPUs), find the stem for current VPU
            if vpu and not algorithm:
                file_stem = self.stem.get(f"{vpu}")
            elif vpu and algorithm:
                file_stem = self.stem.get(f"{vpu}_{algorithm}")
            elif algorithm and not vpu:
                file_stem = self.stem.get(f"{algorithm}")
            else:
                # remove VPU part from stem
                file_stem = re.sub(r"_vpu.*$", "", next(iter(self.stem.values())))
        elif isinstance(self.stem, str):
            file_stem = self.stem
        else:
            msg = f"Invalid 'stem' type: {type(self.stem)}. Must be str or dict."
            logger.error(msg)
            raise ValueError(msg)
        if not file_stem:
            if isinstance(self.stem, dict):
                # Fall back to the first entry, swapping its key for the requested VPU.
                file_stem = next(iter(self.stem.values())).replace(
                    f"{next(iter(self.stem))}", f"{vpu}"
                )
            if not file_stem:
                msg = f"File stem not found for VPU {vpu}: {self.stem}"
                logger.error(msg)
                raise ValueError(msg)
        return file_stem

    def get_file_path(
        self,
        vpu: str = None,
        algorithm: str = None,
        plot_type: str = None,
        use_stem_suffix: bool = False,
    ) -> Path:
        """Get the file path for saving the output.

        Args:
            vpu: VPU identifier used to look up the file stem.
            algorithm: Algorithm identifier used to look up the file stem.
            plot_type: None for data output, or 'map'/'hist' for plot output (PNG).
            use_stem_suffix: Whether to append 'stem_suffix' to the file stem.

        Returns:
            Full path to the output file; parent directories are created if needed.

        Raises:
            ValueError: If 'stem'/'format' are missing for a directory path, the
                stem cannot be resolved, 'stem_suffix' is requested but unset,
                or 'plot_type' is unsupported.
        """
        file_path = Path(self.path) if plot_type is None else Path(self.plot_path)
        if file_path.suffix:
            # BUGFIX: 'path' already names a file. The previous implementation
            # skipped stem resolution in this case but still read the local
            # 'file_stem' further down, raising UnboundLocalError. Reuse the
            # file's own stem and extension instead.
            file_stem = file_path.stem
            base_format = file_path.suffix.lstrip(".")
            file_path = file_path.parent
        else:
            # Path is a directory: the file name must be built from 'stem' and 'format'.
            if not self.stem or not self.format:
                msg = f"File 'stem' and 'format' must be specified if 'path' is a directory: {file_path}"
                logger.error(msg)
                raise ValueError(msg)
            file_stem = self._resolve_stem(vpu, algorithm)
            base_format = self.format

        if use_stem_suffix:
            if not self.stem_suffix:
                msg = f"File stem suffix not specified: {self.stem_suffix}"
                logger.error(msg)
                raise ValueError(msg)
            file_stem += self.stem_suffix

        # make sure plot_type is supported
        if plot_type not in (None, "map", "hist"):
            msg = f"Unsupported plot type: {plot_type}. Supported types are None, 'map', 'hist'."
            logger.error(msg)
            raise ValueError(msg)

        # Plots are always PNG and prefixed with the plot type.
        file_format = base_format if plot_type is None else "png"
        file_prefix = "" if plot_type is None else f"{plot_type}_"

        # Construct the full file path and make sure its directory exists
        file_path = file_path / f"{file_prefix}{file_stem}.{file_format}"
        file_path.parent.mkdir(parents=True, exist_ok=True)
        return file_path

    def save_to_file(
        self,
        data: Any,
        vpu: str = None,
        algorithm: str = None,
        data_str: str = None,
        use_stem_suffix: bool = False,
    ) -> None:
        """Save output data to the specified path and format.

        Args:
            data: Data to save, can be a DataFrame or Pydantic model.
            vpu: VPU identifier for the output file name.
            algorithm: Algorithm identifier for the output file name.
            data_str: String representation of the data being saved.
            use_stem_suffix: Whether to use the stem suffix for the file name.

        Raises:
            ValueError: If the output path is a directory and no file name is provided.
        """
        # Honor the 'save' switch: silently skip when saving is disabled.
        if not self.save:
            return
        # get the file path to save the output
        filepath = self.get_file_path(
            vpu, algorithm=algorithm, use_stem_suffix=use_stem_suffix
        )
        # save the output data
        save_data(data, filepath)
        if data_str is None:
            logger.info(f"Saved output to {filepath}")
        else:
            logger.info(f"Saved {data_str} output to {filepath}")

    def read_from_file(
        self,
        vpu: str = None,
        algorithm: str = None,
        use_stem_suffix: bool = False,
        data_str: str = None,
        data_type: dict[str, Any] = None,
    ) -> pd.DataFrame | gpd.GeoDataFrame:
        """Read output data from the specified path and format.

        Args:
            vpu: VPU identifier for the output file name.
            algorithm: Algorithm identifier for the output file name.
            use_stem_suffix: Whether to use the stem suffix for the file name.
            data_str: String representation of the data being read.
            data_type: Optional dictionary specifying the data types for specific columns.

        Returns:
            DataFrame or GeoDataFrame containing the loaded data.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        # get the file path to read the output
        filepath = self.get_file_path(
            vpu, algorithm=algorithm, use_stem_suffix=use_stem_suffix
        )
        # read the output data
        data = read_table(filepath, dtype=data_type)
        if data_str is None:
            logger.info(f"Read output from {filepath}")
        else:
            logger.info(f"Read {data_str} output from {filepath}")
        return data

    def plot_data(
        self,
        data: pd.DataFrame | gpd.GeoDataFrame,
        plot_dict: Dict[str, Any],
    ) -> None:
        """Plot the data and save png to the specified path.

        Args:
            data: DataFrame or GeoDataFrame containing the data to plot.
            plot_dict: Dictionary containing plot configuration, including:
                var_str: String representation of the variable being plotted.
                columns: list of columns in the data to plot.
                vpu: VPU identifier for the output file name.
                plot_type: Type of plot being saved (e.g., 'map', 'hist').
        """
        # if columns_to_plot is specified by the user, use it
        if self.plots and self.plots.get("columns_to_plot", None) is not None:
            plot_dict["columns"] = self.plots["columns_to_plot"]
        if self.plots and self.plots.get("histogram", False):
            path1 = self.get_file_path(
                plot_dict.get("vpu"),
                algorithm=plot_dict.get("algorithm"),
                plot_type="hist",
            )
            plot_dict1 = plot_dict.copy()
            plot_dict1["outfile"] = path1
            plot_dict1["ncols"] = min(2, plot_dict1.get("ncols", 2))
            # remove non-numeric columns from data from histogram plotting
            numeric_columns = data.select_dtypes(include=["number"]).columns.tolist()
            plot_dict1["columns"] = [
                col for col in plot_dict1.get("columns", []) if col in numeric_columns
            ]
            plot_histogram(data, plot_dict1)
        if self.plots and self.plots.get("spatial_map", False):
            path2 = self.get_file_path(
                plot_dict.get("vpu"),
                algorithm=plot_dict.get("algorithm"),
                plot_type="map",
            )
            plot_dict2 = plot_dict.copy()
            plot_dict2["outfile"] = path2
            plot_dict2["ncols"] = min(3, plot_dict2.get("ncols", 3))
            plot_spatial_map(data, plot_dict2)
class BaseConfig(BaseModel):
    """Base configuration for the application.

    Top-level schema combining the general settings with a named set of output
    configurations.

    NOTE(review): downstream code (BaseConfigProcessor) reads output entries
    via attribute access (getattr(config.output, "config_final", ...)), which
    does not work on a plain dict — presumably concrete Config subclasses
    override 'output' with a model; confirm against the subclass schemas.
    """

    general: BaseGeneralConfig
    """Base general settings for the regionalization application."""

    output: dict[str, BaseOutputConfig]
    """Base output settings for the regionalization application."""
class BaseConfigProcessor:
    """Base configuration processor.

    Loads one or more YAML config files, deep-merges them, validates them
    against a Pydantic schema, substitutes placeholders, expands user/env
    variables in paths, sets up logging, and validates referenced input files.
    """

    def __init__(
        self,
        config_file: str | Path | list[str] | list[Path],
        config_schema: BaseModel = Field(...),
        sample_size: int | None = None,
    ):
        """Initialize regionalization processor.

        Args:
            config_file: One path or a list of paths to YAML config files;
                later files override earlier ones during the deep merge.
            config_schema: Pydantic model class used to validate the merged config.
            sample_size: Optional sample size (stored; used by subclasses).
        """
        # Normalize to a list so merging/logging can treat it uniformly.
        if isinstance(config_file, (str, Path)):
            self.config_file = [config_file]
        else:
            self.config_file = config_file
        self.config_schema = config_schema
        # Property access — triggers (and caches) loading + validation.
        self.config = self.load_and_process_config
        self.sample_size = sample_size
        # Expand '~', '$VARS' and '$USER' in all string/Path fields, in place.
        self._expand_user_file_paths(self.config)
        self.set_logging()
        self.validate_files()

    def _deep_merge_configs(self, a: dict, b: dict) -> dict:
        """Recursively merge two configuration dictionaries, with values from `b` overwriting those in `a`.

        Args:
            a: The base dictionary.
            b: The dictionary whose values will overwrite those in `a`.

        Returns:
            A new dictionary that is the result of merging `a` and `b`.
        """
        result = a.copy()
        for key, value in b.items():
            # Only recurse when both sides are dicts; otherwise b wins outright.
            if (
                key in result
                and isinstance(result[key], dict)
                and isinstance(value, dict)
            ):
                result[key] = self._deep_merge_configs(result[key], value)
            else:
                result[key] = value
        return result

    def _load_and_validate_config(
        self, config_paths: list[str], config_schema: BaseModel = Field(...)
    ) -> BaseModel:
        """Load a YAML file, validate its structure using Pydantic, and substitute placeholders in the config.

        Args:
            config_paths: list of paths to the config files
            config_schema: Pydantic model to validate the config structure

        Returns:
            Config object with validated structure

        Raises:
            ValidationError: If the merged config does not match the schema.
            Exception: Re-raised for any YAML/file loading error.
        """
        try:
            configs = []
            for p in config_paths:
                with open(p, "r") as f:
                    configs.append(yaml.safe_load(f))
            # Left-to-right merge: later files override earlier ones.
            merged_config = reduce(self._deep_merge_configs, configs)
            config = config_schema(**merged_config)
            return config
        except ValidationError as e:
            logger.exception(f"Validation Error: {e}")
            raise
        except Exception as e:
            logger.exception(f"Error loading YAML file: {e}")
            raise

    def _recursive_substitute_until_fixed(self, config, context, max_iter=5):
        # Repeat substitution until a fixed point is reached (placeholders may
        # themselves expand to strings containing placeholders), bounded by
        # max_iter to avoid infinite loops on circular references.
        for _ in range(max_iter):
            new_config = recursive_substitute(config, context)
            if new_config == config:
                break
            config = new_config
        return config

    def _substitute_placeholders(self, config: BaseModel) -> BaseModel:
        """Substitute placeholders in the config with actual values.

        Args:
            config: Config object with placeholders

        Returns:
            Config object with placeholders substituted
        """
        # Create a context dictionary with general config parameters
        context = {
            "domain": config.general.domain
            if hasattr(config.general, "domain")
            else None,
            "run_name": config.general.run_name
            if hasattr(config.general, "run_name")
            else None,
            "static_data_dir": config.general.static_data_dir
            if hasattr(config.general, "static_data_dir")
            else None,
            "base_dir": config.general.base_dir
            if hasattr(config.general, "base_dir")
            else None,
            "vpu_list": config.general.vpu_list
            if hasattr(config.general, "vpu_list")
            else None,
            "algorithm_list": config.general.algorithm_list
            if hasattr(config.general, "algorithm_list")
            else None,
        }
        # remove items with None values from context
        context = {k: v for k, v in context.items() if v is not None}
        # substitute placeholders in the config
        config = self._recursive_substitute_until_fixed(config, context)
        return config

    def _expand_user(self, val: str | Path) -> str:
        """Expand user home directory and environment variables in a file path."""
        s = str(val)
        # Resolve $USER explicitly from LOGNAME/USER; strip any '@domain' part.
        user = os.environ.get("LOGNAME") or os.environ.get("USER")
        if user:
            user = user.split("@", 1)[0]
            s = re.sub(r"\$USER\b", user, s)
        s = os.path.expandvars(s)
        s = os.path.expanduser(s)
        return str(s)

    def _expand_user_file_paths(self, obj) -> Any:
        """Recursively expand '~'/'$VAR' in every str/Path held by obj, mutating in place where possible."""
        if isinstance(obj, BaseModel):
            for name in obj.__class__.model_fields:
                val = getattr(obj, name)
                new_val = self._expand_user_file_paths(val)
                if new_val is not val:
                    setattr(obj, name, new_val)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                new_v = self._expand_user_file_paths(v)
                if new_v is not v:
                    obj[k] = new_v
        elif isinstance(obj, list):
            for i, v in enumerate(obj):
                new_v = self._expand_user_file_paths(v)
                if new_v is not v:
                    obj[i] = new_v
        elif isinstance(obj, tuple):
            # Tuples are immutable: rebuild and return a new one.
            return tuple(self._expand_user_file_paths(v) for v in obj)
        elif isinstance(obj, (str, Path)):
            s = str(obj)
            # Only expand when an expansion marker is present.
            if any(x in s for x in ("~", "$")):
                return self._expand_user(s)
        return obj

    def _required_columns_calval_stats(self, config) -> set[str]:
        """Return a set of required columns for calibration/validation statistics."""
        required_fields = {self.gage_id_name, "formulation"}
        # required fields for summary score configuration (for formulation regionalization)
        sc = getattr(config, "summary_score", None)
        if sc:
            mp = getattr(sc, "metric_eval_period", None)
            eval_col = getattr(mp, "col_name", None) if mp else None
            if eval_col:
                required_fields.add(eval_col)
            metrics = getattr(sc, "metrics", None)
            if isinstance(metrics, dict):
                # dict update adds the metric names (keys) as required columns
                required_fields.update(metrics)
        # required fields for parameter regionalization
        donor = getattr(config, "donor", None)
        if donor:
            eval_pd = getattr(donor, "metric_eval_period", None)
            if eval_pd:
                eval_col = getattr(eval_pd, "col_name", None)
                if eval_col:
                    required_fields.add(eval_col)
            metrics = getattr(donor, "metric_threshold", None)
            if isinstance(metrics, dict):
                required_fields.update(metrics)
        return required_fields

    def _file_required_column_map(self, config) -> Dict[str, set[str]]:
        """Return a dictionary mapping files to required columns."""
        file_dict = {
            "gage_divide_cwt_file": {self.divide_id_name, self.gage_id_name},
            "donor_gage_file": {self.gage_id_name, "longitude", "latitude"},
            "calval_stats_file": self._required_columns_calval_stats(config),
            "calib_param_file": {self.gage_id_name, "formulation"},
            "ngen_hydrofabric_file": {
                self.divide_id_name,
                self.vpu_id_name,
                "geometry",
            },
            "huc12_hydrofabric_file": {self.huc12_id_name, "geometry"},
            "divide_huc12_cwt_file": {self.divide_id_name, self.huc12_id_name},
            "formulation_file": {self.divide_id_name, "formulation"},
        }
        # add snow cover file if it exists in the config
        # (import here to avoid a circular import at module load time)
        from nwm_region_mgr.parreg import config_schema as pcs

        if isinstance(config, pcs.Config) and hasattr(config, "snow_cover"):
            if config.snow_cover.consider_snowness:
                snow_cover_file = getattr(config.snow_cover, "snow_cover_file", None)
                snow_frac_col = getattr(config.snow_cover, "column", None)
                if snow_cover_file and snow_frac_col:
                    file_dict["snow_cover_file"] = {self.divide_id_name, snow_frac_col}
        return file_dict

    def _assemble_file_paths(
        self, config: BaseModel, exclude: set[str] = None, include: set[str] = None
    ) -> dict[str, Path]:
        """Assemble file paths from the configuration.

        Args:
            config: The configuration object.
            exclude: Optional set of fields to exclude from the list of paths.
            include: Optional set of fields to include in the list of paths.

        Returns:
            Dictionary mapping path names to their Path objects.
        """
        # all input file paths to be checked in the config
        path_fields = self._file_required_column_map(config)
        # NOTE: the filters below turn path_fields (a dict) into a list of its
        # keys — this is fine because only the keys are iterated afterwards.
        if exclude:
            # exclude specified fields from the list
            path_fields = [field for field in path_fields if field not in exclude]
        if include:
            # include specified fields in the list
            path_fields = [field for field in path_fields if field in include]
        # create a dictionary to hold the paths
        paths = {}
        for path1 in path_fields:
            # Look on config.general first, then fall back to the snow_cover section.
            val = getattr(config.general, path1, None) or getattr(
                getattr(config, "snow_cover", None), path1, None
            )
            if val is not None:
                if isinstance(val, (str, Path)):
                    paths[path1] = Path(val)
                elif isinstance(val, dict):
                    paths[path1] = {
                        k: Path(v).expanduser()
                        for k, v in val.items()
                        if isinstance(v, (str, Path))
                    }
                else:
                    logger.warning(f"Unsupported type for {path1}: {type(val)}")
        # Remove any None values and return the paths dictionary
        return {k: v for k, v in paths.items() if v is not None}

    def _validate_paths(
        self, paths: str | Path | list[str | Path] | dict[str, str | Path]
    ) -> None:
        """Validate that all file and directory paths exist.

        Args:
            paths: A string, Path, list, or dict of strings/Paths to validate.

        Raises:
            FileNotFoundError: If any path does not exist.
        """
        if isinstance(paths, (str, Path)):
            paths = [paths]
        elif isinstance(paths, dict):
            # Nested dicts (e.g. per-VPU hydrofabric paths) are flattened first.
            paths = list(flatten_dict(paths).values())
        missing_paths = [p for p in paths if not Path(p).exists()]
        if missing_paths:
            msg = f"Missing paths: {missing_paths}"
            logger.error(msg)
            raise FileNotFoundError(msg)

    def _check_file_columns(
        self,
        config: BaseModel,
        dict_path: Dict[str, str | Path | dict[str, str | Path]] = None,
    ) -> None:
        """Check if the required columns are present in the files in the configuration.

        Args:
            config: The configuration object to check.
            dict_path: Optional dictionary of file paths to check.

        Raises:
            ValueError: If any required columns are missing in the files.
        """
        if dict_path is None:
            msg = "No file paths provided for checking columns."
            logger.error(msg)
            raise ValueError(msg)
        dict_cols = self._file_required_column_map(config)
        logger.debug("Required columns for files: %s", dict_cols)
        # loop through the required files and check their columns
        for file_key, file_path in dict_path.items():
            if file_key not in dict_cols:
                logger.warning(
                    f"No required columns defined for {file_key}. Skipping column check."
                )
                continue
            if isinstance(file_path, (str, Path)):
                file_path = [
                    Path(file_path)
                ]  # Ensure file_path is a list of Path objects
            elif isinstance(file_path, dict):
                file_path = [Path(v) for v in file_path.values()]
            else:
                msg = f"Invalid type for file path '{file_key}': {type(file_path)}. Must be str, Path, or dict."
                logger.error(msg)
                raise ValueError(msg)
            required_columns = dict_cols[file_key]
            if not required_columns:
                # Same warning text as above, but here the key exists with an
                # empty column set.
                logger.warning(
                    f"No required columns defined for {file_key}. Skipping column check."
                )
                continue
            for file in file_path:
                if "geometry" in required_columns:
                    # Spatial files: pick the layer for NextGen vs HUC12 hydrofabric.
                    layer = (
                        getattr(config.general.layer_name, "ngen", None)
                        if "ngen" in file_key
                        else getattr(config.general.layer_name, "huc12", None)
                    )
                    check_columns_hydrofabric(file, required_columns, layer_name=layer)
                else:
                    check_columns_dataframe(file, required_columns)

    @property
    @lru_cache
    def load_and_process_config(self) -> BaseModel:
        """Load, validate, and process the configuration files.

        Returns:
            Config object with validated structure and substituted placeholders.
        """
        # NOTE(review): lru_cache on an instance method keys the cache on
        # `self` and keeps every instance alive for the cache's lifetime
        # (ruff B019); functools.cached_property would be the conventional
        # replacement — confirm before changing.
        # Load and validate the configuration
        config = self._load_and_validate_config(self.config_file, self.config_schema)
        # Substitute placeholders in the configuration
        config = self._substitute_placeholders(config)
        return config

    def set_logging(self):
        """Set up logging based on the configuration."""
        log_level = self.config.general.logging.level.upper()
        # NOTE(review): Path(None) raises TypeError — this assumes
        # logging.file is always set in the config; confirm for console-only
        # configurations where 'file' defaults to None.
        log_file = Path(self.config.general.logging.file)
        setup_logging(
            level=log_level,
            log_file=log_file,
            file_level=log_level,
        )
        from nwm_region_mgr.formreg import config_schema as fcs  # avoid circular import

        config_str = (
            "Formulation Regionalization"
            if isinstance(self.config, fcs.Config)
            else "Parameter Regionalization"
        )
        logger.info(
            "%s - Config files: %s", config_str, [str(f) for f in self.config_file]
        )
        logger.info("Set up logging with level: %s", log_level)
        logger.info("Log files: %s", log_file)

    def validate_files(self):
        """Validate that all file paths exist and required columns are present in the files."""
        # Assemble file paths from the configuration
        # ('formulation_file' is an output of the workflow, so it is excluded)
        paths = self._assemble_file_paths(self.config, exclude={"formulation_file"})
        logger.debug("Input file paths from configuration: %s", paths)
        # Validate that all file paths exist
        self._validate_paths(paths)
        # Check if the required columns are present in the files
        self._check_file_columns(self.config, paths)
        logger.info("Successfully validated and processed the configurations.")
        # Save the final configuration
        cc = getattr(self.config.output, "config_final", None)
        if cc is not None:
            cc.save_to_file(self.config, data_str="Final Configuration")

    def set_vpu(self, vpu: str):
        """Set the vpu.

        Invalidates any cached per-VPU GeoDataFrame so it is re-read lazily.
        """
        self.vpu = vpu
        if hasattr(self, "vpu_gdf"):
            delattr(self, "vpu_gdf")

    def set_vpu_gdf(self) -> gpd.GeoDataFrame:
        """Set the GeoDataFrame for the current vpu."""
        layer_name = getattr(self.config.general.layer_name, "ngen", "divides")
        gdf = gpd.read_file(
            Path(self.config.general.ngen_hydrofabric_file[self.vpu]),
            layer=layer_name,
        )
        # Keep only the divide ID and geometry columns.
        gdf = gdf[[self.divide_id_name.lower(), "geometry"]]
        self.vpu_gdf = gdf.copy()

    def get_vpu_gdf(self) -> gpd.GeoDataFrame:
        """Get the GeoDataFrame for the current vpu (loading it lazily on first access)."""
        if not hasattr(self, "vpu_gdf"):
            self.set_vpu_gdf()
        return self.vpu_gdf

    @contextmanager
    def timing_block(self, step_str: str):
        """Context manager for timing code execution.

        Args:
            step_str: Description of the step being timed.
        """
        start = time()
        yield
        end = time()
        logger.info(f"  Execution time for {step_str}: {end - start} seconds")

    @property
    def divide_id_name(self):
        """Id_name for divide from the config."""
        return getattr(self.config.general.id_col, "divide", "divide_id")

    @property
    def gage_id_name(self):
        """Id_name for gage from the config."""
        return getattr(self.config.general.id_col, "gage", "gage_id")

    @property
    def drainage_area_name(self):
        """Id_name for drainage_area from the config."""
        return getattr(self.config.general.id_col, "drainage_area", "areasqkm")

    @property
    def huc12_id_name(self):
        """Id_name for huc12 from the config."""
        return getattr(self.config.general.id_col, "huc12", "huc_12")

    @property
    def vpu_id_name(self):
        """Id_name for vpu from the config."""
        return getattr(self.config.general.id_col, "vpu", "vpuid")

    @property
    def donor_id_name(self):
        """Id_name for donor gage from the config."""
        # NOTE(review): FieldCrosswalk declares no 'donor' field, so this
        # getattr always falls back to "donor" — confirm whether subclass
        # crosswalks add it.
        return getattr(self.config.general.id_col, "donor", "donor")

    @property
    def gage_crosswalk_file(self):
        """Get the gage crosswalk file path from the configuration."""
        return self.config.general.gage_divide_cwt_file

    @property
    def gage_crosswalk(self):
        """Get the gage crosswalk DataFrame (re-read from disk on every access)."""
        return read_table(self.gage_crosswalk_file, dtype={self.gage_id_name: str})

    @property
    def donor_gage_file(self):
        """Get the donor gage file path from the configuration."""
        return self.config.general.donor_gage_file

    @property
    def donor_gages(self):
        """Get the donor gage DataFrame (re-read from disk on every access)."""
        return read_table(self.donor_gage_file, dtype={self.gage_id_name: str})

    def expand_config_for_vpu(self, vpu: str):
        """Expand the config to include VPUs (e.g., needed for all donors).

        Some donors may come from nearby VPUs, so we need to make sure that
        the configuration includes all donor VPUs. For example, if the current
        VPU '03S' uses donor gages from '03W','03N', and '06', then this
        function will expand the config to include those VPUs as well for the
        ngen_hydrofabric_file and output files.

        --- Before expansion---
        ngen_hydrofabric_file:
            '03S': '/path/to/ngen_hydro_03S.gpkg'

        --- After expansion---
        ngen_hydrofabric_file:
            '03S': '/path/to/ngen_hydro_03S.gpkg'
            '03W': '/path/to/ngen_hydro_03W.gpkg'
            '03N': '/path/to/ngen_hydro_03N.gpkg'
            '06': '/path/to/ngen_hydro_06.gpkg'
        """

        def add_entry_if_missing(d: dict, vpu_key: str) -> None:
            # Derive the new entry by swapping the first key's VPU substring
            # for the requested one (mutates d in place).
            if vpu_key not in d:
                first_key = next(iter(d))
                d[vpu_key] = d[first_key].replace(first_key, vpu_key)

        add_entry_if_missing(self.config.general.ngen_hydrofabric_file, vpu)
        co = getattr(self.config.output, "summary_score", None)
        if co is not None:
            add_entry_if_missing(co.stem, vpu)
        co = getattr(self.config.output, "formulation", None)
        if co is not None:
            add_entry_if_missing(co.stem, vpu)
        # save the expanded configuration
        if hasattr(self.config.output, "config_final"):
            getattr(self.config.output, "config_final").save_to_file(
                self.config, data_str="Expanded final configuration"
            )

    def get_output_file_path(
        self,
        output_section: str,
        vpu: str = None,
        algorithm: str = None,
        use_stem_suffix: bool = False,
    ):
        """Get the output file name from the output configuration.

        Raises:
            ValueError: If the named output section is not configured.
        """
        output_config = getattr(self.config.output, output_section, None)
        if output_config is None:
            msg = f"Output section '{output_section}' not found in configuration for {self.__class__.__name__}."
            logger.error(msg)
            raise ValueError(msg)
        return output_config.get_file_path(
            vpu, algorithm=algorithm, use_stem_suffix=use_stem_suffix
        )