"""Defines classes for validating the configuration for parameter regionalization.
Classes:
- GeneralConfig: General configuration settings specific to parameter regionalization.
- DonorConfig: Configuration for donor selection.
- AttrDatasetConfig: Configuration for attribute datasets used in the regionalization process.
- AvailableAttrsConfig: Configuration for available attribute datasets for use in the regionalization process.
- SnowCoverConfig: Configuration for snow cover data.
- AlgorithmConfig: Algorithm configuration class.
- ParameterOutputConfig: Configuration for parameter regionalization output.
- Config: Top-level configuration class for parameter regionalization.
"""
import logging
from pathlib import Path
from typing import Dict, List, Literal, get_args
import pandas as pd
import pyarrow.parquet as pq
from pydantic import BaseModel, Field, model_validator
from nwm_region_mgr.utils import (
BaseConfig,
BaseGeneralConfig,
BaseOutputConfig,
read_table,
)
logger = logging.getLogger(__name__)
[docs]
class GeneralConfig(BaseGeneralConfig):
    """General configuration settings specific to parameter regionalization."""

    # Which attribute datasets supply catchment attributes for pairing.
    attr_dataset_list: List[Literal["ngen", "hlr", "streamcat"]] = Field(
        default=["ngen"],
        examples=["ngen", "streamcat"],
        description="List of attribute dataset names to use. Valid options include 'ngen', 'hlr', 'streamcat'.",
    )
    # Which regionalization algorithms to run.
    algorithm_list: List[
        Literal["gower", "urf", "kmeans", "kmedoids", "hdbscan", "birch", "proximity"]
    ] = Field(
        default=["gower"],
        examples=["gower", "kmeans"],
        description="Algorithms to use. Valid options ('gower', 'urf', 'kmeans', 'kmedoids', 'hdbscan', 'birch', 'proximity').",
    )
    # Optional file (or mapping of files) with hand-picked donor-receiver pairs.
    manual_pairings_file: Path | str | Dict[str, Path] | Dict[str, str] | None = Field(
        default=None,
        examples="{static_data_dir}/region/manual_pairings/manual_pairs_{vpu_list}.csv",
        description=(
            "Path to the manual pairings file. If provided, this file will be used to specify "
            "manual donor-receiver pairings, overriding the algorithmic selections."
        ),
    )
[docs]
class MetricEvalPeriod(BaseModel):
    """Configuration for the evaluation period of metrics to be used for screening donors."""

    # Column in the donor stats file holding the evaluation-period label.
    col_name: str | None = Field(
        default=None,
        examples="evalPeriod",
        description=(
            "Name of the column in the donor stats file that contains the evaluation period. "
            "No filtering by evaluation period if None."
        ),
    )
    # Period label to keep when filtering the donor stats.
    value: str | None = Field(
        default=None,
        examples="full",
        description=(
            "Value of the evaluation period to filter donor stats. No filtering by evaluation period if None."
        ),
    )
[docs]
class MetricThreshold(BaseModel):
    """Configuration for the thresholds of metrics to be used for screening donors."""

    # Lower bound (inclusive) applied to the metric.
    min: float | None = Field(
        description="Minimum threshold for the metric. If None, no minimum threshold is applied.",
        examples=None,
        default=None,
    )
    # Upper bound (inclusive) applied to the metric.
    max: float | None = Field(
        description="Maximum threshold for the metric. If None, no maximum threshold is applied.",
        examples=None,
        default=None,
    )
    # Whether to take abs(metric) before applying min/max.
    absolute: bool | None = Field(
        description="If True, apply the absolute value of the metric before applying the thresholds.",
        examples=None,
        default=False,
    )

    @model_validator(mode="after")
    def validate_either_field_exists(self):
        """Validate that at least one of 'min' or 'max' is provided.

        Compares against None explicitly so that a legitimate threshold of
        0.0 is not mistaken for "unset" (the previous truthiness check
        rejected min=0.0 / max=0.0).
        """
        if self.min is None and self.max is None:
            raise ValueError("At least one of 'min' or 'max' must be provided.")
        return self
[docs]
class DonorConfig(BaseModel):
    """Configuration for donor selection."""

    # Buffer distance (km) used to widen the donor search around the VPU.
    buffer_km: float | None = Field(
        description="Size of buffer (in km) around current VPU to identify qualified donors.",
        examples=100.0,
        default=0.0,
    )
    # Optional evaluation-period filter applied to the donor stats file.
    metric_eval_period: MetricEvalPeriod | None = Field(
        description="Evaluation period of metrics to be used for screening donors.",
        examples={"col_name": "eval_period", "value": "full"},
        default=None,
    )
    # Per-metric screening thresholds; optional to match the None default.
    metric_threshold: Dict[str, MetricThreshold] | None = Field(
        description=(
            "Dictionary of metric thresholds to be used for screening donors. Each key is a metric name, "
            "and the value is a MetricThreshold object specifying the min, max, and absolute settings. "
            "Refer to schema of MetricThreshold for details."
        ),
        examples={
            "cor": {"min": 0.4, "max": None, "absolute": False},
            "kge": {"min": 0.2, "max": None, "absolute": False},
        },
        default=None,
    )

    def get_qualified_donors(
        self,
        config: BaseModel,
        donors0: list | None = None,
        init_donor_df: pd.DataFrame | None = None,
    ) -> dict:
        """Screen donors based on the metric thresholds and evaluation period.

        Parameters
        ----------
        config : BaseModel
            Top-level configuration; ``config.general`` must provide
            ``id_col`` and ``calval_stats_file``.
        donors0 : list | None
            Initial candidate donor gage IDs; treated as empty if None.
        init_donor_df : pd.DataFrame | None
            Table mapping donor gage IDs to catchment (divide) IDs.

        Returns
        -------
        dict
            ``{gage_id_name: qualified_gages, divide_id_name: their_catchments}``.
        """
        divide_id_name = getattr(config.general.id_col, "divide", "divide_id")
        gage_id_name = getattr(config.general.id_col, "gage", "gage_id")
        donors0 = donors0 or []  # avoid TypeError on `in donors0` when None
        # initial donors
        donors = [
            d for d in init_donor_df[gage_id_name].unique().tolist() if d in donors0
        ]
        donor_cats = (
            init_donor_df.loc[init_donor_df[gage_id_name].isin(donors), divide_id_name]
            .unique()
            .tolist()
        )
        # read the donor stats file
        stats_file = config.general.calval_stats_file
        df = read_table(stats_file, dtype={gage_id_name: str})
        # filter based on initial donors; copy so the abs() mutation below
        # does not trigger pandas' SettingWithCopyWarning
        df = df[df[gage_id_name].isin(donors)].copy()
        if df.empty:
            logger.warning(
                f"No matching gages found in {stats_file} for the initial donors. Returning the initial list."
            )
            return {gage_id_name: donors, divide_id_name: donor_cats}
        gages_stat = df[gage_id_name].unique().tolist()
        gages_missing = [g for g in donors if g not in gages_stat]
        if gages_missing:
            logger.warning(
                f"Some gages in the initial list are not found in the stats file: {stats_file} "
            )
            logger.debug(f"Missing gages: {gages_missing}")
        # filter based on the evaluation period
        if self.metric_eval_period:
            periods = df[self.metric_eval_period.col_name].unique()
            if self.metric_eval_period.value not in periods:
                # the requested period value is absent (message previously
                # mislabeled this as a missing *column*)
                raise ValueError(
                    f"Evaluation period value '{self.metric_eval_period.value}' not found in "
                    f"column '{self.metric_eval_period.col_name}' of {stats_file}."
                )
            # Filter the DataFrame based on the evaluation period
            df = df[
                df[self.metric_eval_period.col_name] == self.metric_eval_period.value
            ]
        else:
            logger.warning(
                f"No evaluation period provided. Using all periods in {stats_file}."
            )
        # filter based on metric thresholds (metric_threshold may be None)
        for col, threshold in (self.metric_threshold or {}).items():
            if threshold.absolute:
                df[col] = df[col].abs()
            if threshold.min is not None:
                df = df[df[col] >= threshold.min]
            if threshold.max is not None:
                df = df[df[col] <= threshold.max]
        donors = df[gage_id_name].unique().tolist()
        donor_cats = (
            init_donor_df[init_donor_df[gage_id_name].isin(donors)][divide_id_name]
            .unique()
            .tolist()
        )
        logger.info(
            f"Number of donors after filtering: {len(donors)} gages, {len(donor_cats)} catchments"
        )
        # check if any donors are left after filtering
        if not donors:
            logger.info(
                "No donors left after filtering. Check the metric thresholds and evaluation period."
            )
        return {gage_id_name: donors, divide_id_name: donor_cats}
[docs]
class AttrDatasetConfig(BaseModel):
    """Configuration for attribute datasets used in the regionalization process."""

    # Explicit attribute list; takes priority over attr_select_file.
    attr_list: list | None = Field(
        description=(
            "List of attributes to use from this dataset. If not provided, attributes will be determined "
            "from attr_select_file. Either this field or attr_select_file must be provided."
            "If both are provided, attr_list takes priority."
        ),
        examples=None,
        default=None,
    )
    # CSV with 'attr_name' and 'select' columns marking attributes to use.
    attr_select_file: Path | str | None = Field(
        description="Path to file where selection of attributes to use during regionalization may be found.",
        examples=["attr_selection_ngen.csv"],
        default=None,
    )
    # CSV or parquet table with one column per attribute.
    attr_data_file: Path | str | None = Field(
        description="Path to file where attribute data may be found.",
        examples=["attr_ngen_{domain}.parquet"],
        default=None,
    )
    base_attr_list: list | None = Field(
        description=(
            "Small list of basic attributes during a 2nd round of pairing if no donor is found using "
            "the full set of selected attributes during the first round."
        ),
        examples=["elevation", "slope", "aspect"],
        default=None,
    )

    @model_validator(mode="after")
    def validate_either_field_exists(self):
        """Validate that at least one of 'attr_list' or 'attr_select_file' is provided."""
        if not self.attr_list and not self.attr_select_file:
            raise ValueError(
                "At least one of 'attr_list' or 'attr_select_file' must be provided."
                " If both are provided, 'attr_list' takes priority."
            )
        return self

    def _get_selected_attrs(self):
        """Private method to load the list of selected attributes.

        Populates ``self.attr_list`` from ``attr_select_file`` when no explicit
        list was given. When a list was given, it is validated against the data
        file's schema — parquet only; for CSV data files the column check in
        ``get_attr_data`` covers validation.
        """
        # determine list of attributes to use from either attr_list or attr_select_file
        # if both are provided, attr_list takes priority
        if self.attr_list:
            # Validate against the parquet schema without reading the whole
            # file. Previously pq.ParquetFile was called unconditionally,
            # which crashed for CSV data files (a format get_attr_data supports).
            if (
                self.attr_data_file
                and Path(self.attr_data_file).suffix.lower() == ".parquet"
            ):
                attrs1 = [
                    x
                    for x in self.attr_list
                    if x not in pq.ParquetFile(self.attr_data_file).schema.names
                ]
                if attrs1:
                    msg = f"These attributes {attrs1} are not found in {self.attr_data_file}. Please check the configuration."
                    logger.error(msg)
                    raise ValueError(msg)
        else:
            attr_select_path = Path(self.attr_select_file)
            if not attr_select_path.exists():
                raise FileNotFoundError(
                    f"Select file not found: {self.attr_select_file}"
                )
            df_attrs = pd.read_csv(attr_select_path)
            if "select" not in df_attrs.columns or "attr_name" not in df_attrs.columns:
                raise ValueError(
                    f"Missing required columns 'select' and/or 'attr_name' in file: {self.attr_select_file}"
                )
            self.attr_list = df_attrs[df_attrs["select"] == 1]["attr_name"].to_list()

    def get_attr_data(self, id_name: str = "divide_id") -> pd.DataFrame:
        """Load attribute data filtered by selected attributes.

        Parameters
        ----------
        id_name : str
            Identifier column kept alongside the selected attributes.

        Returns
        -------
        pd.DataFrame
            Columns: ``[id_name] + attr_list``.

        Raises
        ------
        FileNotFoundError
            If ``attr_data_file`` is unset or does not exist.
        ValueError
            If the file format is unsupported or attributes are missing.
        """
        # Validate the data file first, so a missing/unset file yields a clear
        # error instead of a parquet-reader crash inside _get_selected_attrs.
        if not self.attr_data_file:
            raise FileNotFoundError("No attr_data_file provided.")
        attr_data_path = Path(self.attr_data_file)
        if not attr_data_path.exists():
            raise FileNotFoundError(
                f"Attribute data file not found: {self.attr_data_file}"
            )
        self._get_selected_attrs()
        suffix = attr_data_path.suffix.lower()
        if suffix == ".csv":
            df_data = pd.read_csv(attr_data_path)
        elif suffix == ".parquet":
            df_data = pd.read_parquet(attr_data_path)
        else:
            raise ValueError(f"Unsupported file format: {suffix}")
        missing_cols = set(self.attr_list) - set(df_data.columns)
        if missing_cols:
            raise ValueError(f"Missing attributes in data file: {missing_cols}")
        return df_data[[id_name] + self.attr_list]
[docs]
class AvailableAttrsConfig(BaseModel):
    """Configuration for available attribute datasets for use in the regionalization process."""

    # NGEN hydrofabric attributes.
    ngen: AttrDatasetConfig = Field(
        examples={
            "attr_list": None,
            "attr_select_file": "{base_dir}/inputs/attr_config/attr_selection_ngen.csv",
            "attr_data_file": "{base_dir}/inputs/attr_datasets/ngen/attr_ngen_{domain}.parquet",
            "base_attr_list": ["elevation", "slope", "aspect"],
        },
        description=(
            "Configuration for NGEN attribute dataset."
            "(https://lynker-spatial.s3-us-west-2.amazonaws.com/hydrofabric/v2.2/hfv2.2-data_model.html)."
        ),
    )
    # USGS Hydrologic Landscape Regions attributes.
    hlr: AttrDatasetConfig = Field(
        examples={
            "attr_list": None,
            "attr_select_file": "{base_dir}/inputs/attr_config/attr_selection_hlr.csv",
            "attr_data_file": "{base_dir}/inputs/attr_datasets/hlr/attr_hlr_{domain}.parquet",
            "base_attr_list": ["PPT", "SAND"],
        },
        description=(
            "Configuration for Hydrologic Landscape Regions (HLR) attribute dataset "
            "(https://www.usgs.gov/publications/hydrologic-landscape-regions-united-states)."
        ),
    )
    # EPA StreamCat attributes.
    streamcat: AttrDatasetConfig = Field(
        examples={
            "attr_list": None,
            "attr_select_file": "{base_dir}/inputs/attr_config/attr_selection_streamcat.csv",
            "attr_data_file": "{base_dir}/inputs/attr_datasets/streamcat/attr_streamcat_{domain}.parquet",
            "base_attr_list": ["Precip_Minus_EVT", "Elev", "BFI"],
        },
        description=(
            "Configuration for StreamCat attribute dataset "
            "(https://www.epa.gov/national-aquatic-resource-surveys/streamcat-dataset)."
        ),
    )
[docs]
class SnowCoverConfig(BaseModel):
    """Configuration for snow cover."""

    # If True, donors and receivers are matched within the same snowness class.
    consider_snowness: bool | None = Field(
        description=(
            "Whether to consider snow driven and non-snow driven catchments separately "
            "in the regionalization process. If True, snow-driven receivers will only consider "
            "snow-driven donors and non-snow-driven receivers will only consider non-snow-driven donors."
        ),
        examples=True,
        default=True,
    )
    # Single file, or per-VPU mapping of files, with snow-cover fractions.
    snow_cover_file: Path | str | dict[str, Path | str] | None = Field(
        description="Path to the snow cover data file, or a dictionary with VPU as keys and file paths as values.",
        examples="vpu{vpu_list}_snow_frac.parquet",
        default=None,
    )
    column: str | None = Field(
        description="Column name in the snow cover data file that contains the snow cover percentage.",
        examples="snow_pc_hydroatlas",
        default="snow_pc_hydroatlas",
    )
    threshold: float | None = Field(
        description="Threshold value for snow cover percentage to determine if a catchment is considered snow-driven.",
        examples="20",
        default=None,
    )

    @model_validator(mode="after")
    def validate_snow_cover_config(self):
        """Validate the snow cover configuration.

        The threshold is compared against None explicitly so that a valid
        threshold of 0.0 is not rejected (the previous truthiness check
        treated 0.0 as missing).
        """
        if self.consider_snowness:
            if not self.snow_cover_file or not self.column or self.threshold is None:
                raise ValueError(
                    "When 'consider_snowness' is True, 'snow_cover_file', 'column', and 'threshold' must be provided."
                )
        return self
[docs]
class AlgoGeneral(BaseModel):
    """General configurations shared by all regionalization algorithms."""

    # Donors farther than this (km) are never considered suitable.
    max_spa_dist: float | None = Field(
        description="Maximum spatial distance (km) to consider a donor suitable",
        examples=1500.0,
        default=1000.0,
    )
    # Cap on the number of donors retained per receiver.
    n_donor_max: int | None = Field(
        description="Maximum number of donors to keep that satisfy all criteria",
        examples=3,
        default=3,
    )
    # PCA components are retained until this fraction of variance is explained.
    min_var_pca: float | None = Field(
        description="Minimum total variance explained by chosen PCA components",
        examples=0.8,
        default=0.9,
    )
[docs]
class Gower(AlgoGeneral):
    """Configurations for the distance-based algorithm Gower."""

    # Early-exit distance: a match this close ends the search.
    min_attr_dist: float | None = Field(
        description="Minimum attribute distance. If one or more donors have a distance to receiver smaller than this threshold, stop searching.",
        examples=0.1,
        default=0.1,
    )
    # Soft cutoff: farther donors are used only when nothing closer exists.
    max_attr_dist: float | None = Field(
        description="Maximum attribute distance. Donors with distance to receiver larger than this value are discarded, unless no donor smaller than this threshold is available.",
        examples=0.25,
        default=0.20,
    )
    # Initial search radius for the iterative neighborhood expansion.
    min_spa_dist: float | None = Field(
        description="Starting distance (km) to iteratively search for donors in the neighborhood",
        examples=200.0,
        default=100.0,
    )
    # Within this distance the donor is adopted outright.
    zero_spa_dist: float | None = Field(
        description="Distance threshold (in km) where receiver adopts a donor directly (i.e., donor/receiver are considered overlapping each other)",
        examples=1.0,
        default=1.0,
    )
[docs]
class URF(AlgoGeneral):
    """Unsupervised Random Forest (URF) algorithm class."""

    # Optional PCA preprocessing before growing the forest.
    pca: bool | None = Field(
        description="Whether to perform PCA on the attribute data before building the forest. Preliminary testing indicates limited difference in results with/without PCA.",
        examples=False,
        default=False,
    )
    # Forest size.
    n_trees: int | None = Field(
        description="Number of trees in the random forest.",
        examples=500,
        default=500,
    )
    # Tree depth limit; None grows trees until leaves are pure.
    max_depth: int | None = Field(
        description="Maximum depth of each tree. If None, nodes are expanded until all leaves are pure.",
        examples=3,
        default=3,
    )
    # Early-exit distance: a match this close ends the search.
    min_attr_dist: float | None = Field(
        description="Minimum attribute distance. If one or more donors have a distance to receiver smaller than this threshold, stop searching.",
        examples=0.1,
        default=None,
    )
    # Soft cutoff: farther donors are used only when nothing closer exists.
    max_attr_dist: float | None = Field(
        description="Maximum attribute distance. Donors with distance to receiver larger than this value are discarded, unless no donor smaller than this threshold is available.",
        examples=0.25,
        default=None,
    )
    # Initial search radius for the iterative neighborhood expansion.
    min_spa_dist: float | None = Field(
        description="Starting distance (km) to iteratively search for donors in the neighborhood",
        examples=200.0,
        default=None,
    )
    # Within this distance the donor is adopted outright.
    zero_spa_dist: float | None = Field(
        description="Distance threshold (in km) where receiver adopts a donor directly (i.e., donor/receiver are considered overlapping each other)",
        examples=1.0,
        default=None,
    )
[docs]
class KMeans(AlgoGeneral):
    """K-means algorithm class."""

    # Iteration cap per run.
    n_iter_max: int | None = Field(
        description="Maximum number of iterations for the algorithm.",
        examples=100,
        default=100,
    )
    # Centroid initialization strategy.
    init: Literal["k-means++", "random"] | None = Field(
        description="Method for initialization.",
        examples="k-means++",
        default="k-means++",
    )
    # Restarts with different centroid seeds; best run is kept.
    n_init: int | None = Field(
        description="Number of times the k-means algorithm will be run with different centroid seeds.",
        examples=3,
        default=None,
    )
[docs]
class KMedoids(AlgoGeneral):
    """K-medoids algorithm class."""

    # Iteration cap per run.
    n_iter_max: int | None = Field(
        description="Maximum number of iterations for the algorithm.",
        examples=100,
        default=None,
    )
    # Medoid initialization strategy.
    init: Literal["random", "heuristic", "k-medoids++", "build"] | None = Field(
        description="Method for initialization.",
        examples="heuristic",
        default="heuristic",
    )
[docs]
class HDBSCAN(AlgoGeneral):
    """Hierarchical Density Based Spatial Clustering of Applications with Noise (HDBSCAN) algorithm class."""

    # Overrides the AlgoGeneral default with a larger donor cap.
    n_donor_max: int | None = Field(
        description="Maximum number of donors to keep that satisfy all criteria.",
        examples=20,
        default=20,
    )
    # Clusters smaller than this are treated as noise.
    min_cluster_size: int | None = Field(
        description="Minimum size of clusters (to avoid being considered noise)",
        examples=3,
        default=3,
    )
[docs]
class Birch(AlgoGeneral):
    """Balanced Iterative Reducing and Clustering using Hierarchies (BIRCH) algorithm class."""

    # CF-tree branching factor.
    branching_factor: int | None = Field(
        description="Branching factor for the BIRCH algorithm.",
        examples=50,
        default=50,
    )
    # Lower end of the threshold sweep.
    min_thresh: float | None = Field(
        description="Minimum threshold for the BIRCH algorithm. The algorithm will iterate through thresholds between min_thresh and max_thresh to identify a suitable threshold.",
        examples=1.5,
        default=1.5,
    )
    # Upper end of the threshold sweep.
    max_thresh: float | None = Field(
        description="Maximum threshold for the BIRCH algorithm. The algorithm will iterate through thresholds between min_thresh and max_thresh to identify a suitable threshold.",
        examples=4.0,
        default=4.0,
    )
    # Resampling cap.
    max_resample: int | None = Field(
        description="Maximum number of resamples.",
        examples=20,
        default=20,
    )
[docs]
class AlgorithmConfig(BaseModel):
    """Algorithm configuration class."""

    # Defaults merged into every per-algorithm section (see merge_algo_general).
    algo_general: AlgoGeneral = Field(
        description="General configurations shared by all regionalization algorithms.",
        default_factory=AlgoGeneral,
    )
    gower: Gower = Field(
        description="Configurations for the distance-based algorithm Gower.",
        default_factory=Gower,
    )
    urf: URF = Field(
        description="Configurations for the distance-based algorithm Unsupervised Random Forest (URF)",
        default_factory=URF,
    )
    kmeans: KMeans = Field(
        description="Configurations for the clustering algorithm K-means",
        default_factory=KMeans,
    )
    kmedoids: KMedoids = Field(
        description="Configurations for the clustering algorithm K-medoids",
        default_factory=KMedoids,
    )
    # NOTE: the descriptions below previously had a trailing comma inside the
    # parentheses, which silently turned each description into a 1-tuple
    # instead of a str.
    hdbscan: HDBSCAN = Field(
        description=(
            "Configurations for the clustering algorithm Hierarchical Density Based Spatial Clustering "
            "of Applications with Noise (HDBSCAN)"
        ),
        default_factory=HDBSCAN,
    )
    birch: Birch = Field(
        description=(
            "Configurations for the clustering algorithm Balanced Iterative Reducing and "
            "Clustering using Hierarchies (BIRCH)"
        ),
        default_factory=Birch,
    )

    @model_validator(mode="before")
    @classmethod
    def merge_algo_general(cls, values: dict) -> dict:
        """Merge general algorithm parameters with specific algorithm parameters.

        Runs before field validation, so entries in ``values`` are raw input
        (usually dicts from YAML). Per-algorithm dicts inherit keys from
        ``algo_general`` unless they override them. Non-dict entries (e.g.
        already-constructed models) are left untouched instead of crashing
        on the ``{**x}`` unpacking.
        """
        if not isinstance(values, dict):
            return values
        general = values.get("algo_general")
        if not isinstance(general, dict) or not general:
            return values
        # dynamically find all optional algorithm fields (excluding 'algo_general')
        algo_fields = [
            field
            for field, annotation in cls.__annotations__.items()
            if field != "algo_general"
            and (
                # either directly the class or Optional[class]
                isinstance(annotation, type)
                or (get_args(annotation) and isinstance(get_args(annotation)[0], type))
            )
        ]
        # merge general parameters into each algorithm's parameters;
        # algorithm-specific keys take precedence over the shared ones
        for key in algo_fields:
            if isinstance(values.get(key), dict):
                values[key] = {**general, **values[key]}
        return values
[docs]
class ParameterOutputConfig(BaseModel):
    """Configuration for parameter regionalization output."""

    # Donor-receiver pairing tables and plots.
    pairs: BaseOutputConfig = Field(
        description="Configuration for saving donor-receiver pairs.",
        default_factory=BaseOutputConfig,
        examples={
            "save": True,
            "path": "{base_dir}/outputs/{run_name}/pairs",
            "stem": "pairs_{algorithm_list}_{domain}_vpu{vpu_list}",
            "stem_suffix": "_mswm",  # suffix for the pairs file to be used by MSWM
            "format": "parquet",
            "plots": {
                "spatial_map": True,
                "histogram": True,
                "columns_to_plot": ["distSpatial", "distAttr"],  # columns to plot
            },
            "plot_path": "{base_dir}/outputs/{run_name}/pairs/plots",
        },
    )
    # Regionalized model parameters.
    params: BaseOutputConfig = Field(
        description="Configuration for saving regionalized parameters.",
        default_factory=BaseOutputConfig,
        examples={
            "save": True,
            "path": "{base_dir}/outputs/{run_name}/params",
            "stem": "formulation_params_{algorithm_list}_{domain}_vpu{vpu_list}",
            "format": "csv",
            "plots": {
                "spatial_map": True,
                "columns_to_plot": ["MP", "MFSNO", "uztwm", "uzfwm", "pxtemp", "plwhc"],
            },
            "plot_path": "{base_dir}/outputs/{run_name}/params/plots",
        },
    )
    # Final (merged, source-prefixed) attribute table.
    # NOTE: this description previously had a trailing comma inside the
    # parentheses, which silently turned it into a 1-tuple instead of a str.
    attr_data_final: BaseOutputConfig = Field(
        description=(
            "Configuration for saving and plotting final attribute data used in regionalization. "
            "Note only selected attributes are saved, and attribute names are prefixed with "
            "the name of the corresponding attribute source (e.g., 'Elev' in StreamCat becomes 'streamcat_Elev')."
        ),
        default_factory=BaseOutputConfig,
        examples={
            "save": True,
            "path": "{base_dir}/outputs/{run_name}/attr_data_final",
            "stem": "attr_{domain}_vpu{vpu_list}",
            "format": "parquet",
            "plots": {
                "spatial_map": True,
                "histogram": True,
                "columns_to_plot": [
                    "streamcat_Elev",
                    "streamcat_BFI",
                    "streamcat_Precip_Minus_EVT",
                    "hlr_PMPE",
                    "hlr_SAND",
                    "hlr_TAVE",
                ],  # attributes to plot
            },
            "plot_path": "{base_dir}/outputs/{run_name}/attr_data_final/plots",
        },
    )
    # Snapshot of the fully-resolved configuration used for the run.
    config_final: BaseOutputConfig = Field(
        description="Configuration for saving final configuration file used in regionalization.",
        default_factory=BaseOutputConfig,
        examples={
            "save": True,
            "path": "{base_dir}/outputs/{run_name}/config_parreg_final.yaml",
        },
    )
    # Donor-receiver spatial distance matrix.
    spatial_distance: BaseOutputConfig = Field(
        description="Configuration for saving spatial distance data.",
        default_factory=BaseOutputConfig,
        examples={
            "save": True,
            "path": "{base_dir}/outputs/{run_name}/spatial_distance",
            "format": "parquet",
        },
    )
[docs]
class Config(BaseConfig):
    """Top-level configuration class for parameter regionalization.

    Aggregates the general settings, donor screening, attribute datasets,
    snow-cover handling, algorithm parameters, and output configuration.
    """
    general: GeneralConfig = Field(
        description="General configuration settings specific to parameter regionalization.",
        default_factory=GeneralConfig,
    )
    donor: DonorConfig = Field(
        description="Configuration for donor selection.",
        default_factory=DonorConfig,
    )
    attr_datasets: AvailableAttrsConfig = Field(
        description="Configuration for attribute datasets available for use in regionalization.",
        default_factory=AvailableAttrsConfig,
    )
    snow_cover: SnowCoverConfig = Field(
        description="Configuration for snow cover data to be used in determining whether catchments are snow-driven.",
        default_factory=SnowCoverConfig,
    )
    algorithms: AlgorithmConfig = Field(
        description="Algorithm configuration class. See specific algorithms for additional arguments.",
        default_factory=AlgorithmConfig,
    )
    output: ParameterOutputConfig = Field(
        description="Configuration for parameter regionalization output.",
        default_factory=ParameterOutputConfig,
    )
    @model_validator(mode="after")
    def check_required_algorithms_present(self):
        """Check if required algorithms are present.

        Every algorithm named in ``general.algorithm_list`` must have an
        explicitly-set section under ``algorithms`` (``exclude_unset=True``
        means default-constructed sections do not count).
        """
        required = set(self.general.algorithm_list)
        defined = set(self.algorithms.model_dump(exclude_unset=True).keys())
        missing = required - defined
        missing = {
            m for m in missing if m != "proximity"
        }  # the "proximity" approach does not need algorithm parameters
        if missing:
            raise ValueError(
                f"The following algorithms are listed in 'general.algorithm_list' "
                f"but missing from the 'algorithms' section: {missing}"
            )
        return self
    @model_validator(mode="after")
    def check_required_attr_datasets_present(self):
        """Check that the required attributes are present.

        Every dataset named in ``general.attr_dataset_list`` must have an
        explicitly-set section under ``attr_datasets``.
        """
        required = set(self.general.attr_dataset_list)
        defined = set(self.attr_datasets.model_dump(exclude_unset=True).keys())
        missing = required - defined
        if missing:
            raise ValueError(
                f"The following attribute datasets are listed in 'general.attr_dataset_list' "
                f"but missing from the 'attr_datasets' section: {missing}"
            )
        return self