# Source code for nwm_region_mgr.utils.plot_utils

"""Utility functions for plotting spatial maps and histograms etc.

plot_utils.py

Functions:
    - _plot_columns_by_dtype: Plot multiple columns of a GeoDataFrame based on their data types.
    - plot_spatial_map: Generate a spatial map plot for the given data.
    - plot_histogram: Generate histogram plot for the spatial or attribute distance between donors and receivers.
    - plot_point_map: Plot the spatial distribution of locations as points on a base layer map.

"""

import logging
import math
from itertools import cycle, islice

import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from nwm_region_mgr.utils.hydrofabric_utils import dissolve_polygons

logger = logging.getLogger(__name__)


def _anchor_legend_below(legend, ncols: int) -> None:
    """Re-anchor *legend* centered below its axes, laid out in *ncols* columns."""
    legend.set_bbox_to_anchor((0.5, -0.18))
    legend.set_loc("upper center")
    legend.set_ncols(ncols)
    legend.set_frame_on(True)


def _plot_columns_by_dtype(
    gdf: gpd.GeoDataFrame,
    columns: list[str],
    fillna_value=None,
    num_bins: int = None,
    cmap_numeric: str = "viridis",
    cmap_categorical: str = "Set3",
    figsize=(10, 6),
    antialiased: bool = False,
):
    """Plot multiple GeoDataFrame columns in subplots based on data type.

    One subplot is drawn per column, on a grid of at most 4 columns:

    - numeric columns are drawn with a continuous colormap and a horizontal
      colorbar, or, when ``num_bins`` is given, cut into ``num_bins`` bins with
      ``pd.cut`` and drawn as categories with a legend;
    - object / string / categorical columns are drawn as categories with a
      legend;
    - any other dtype gets an "Unsupported dtype" title and a blank subplot.

    Rows whose value is still NaN after the optional ``fillna`` are dropped
    for that subplot; a column with no remaining rows is left blank.

    Args:
        gdf: GeoDataFrame to plot
        columns: list of column names to plot (must not be empty)
        fillna_value: value to fill NaNs with (None leaves NaNs in place)
        num_bins: number of bins for numeric columns (optional)
        cmap_numeric: colormap for numeric values
        cmap_categorical: colormap for categorical values
        figsize: figure size
        antialiased: antialiasing flag applied to every collection drawn for a
            column (default False, i.e. antialiasing switched off)

    Returns:
        Tuple ``(fig, axes)`` where ``axes`` holds one Axes per requested
        column, in the order of ``columns``.

    Raises:
        ValueError: if ``columns`` is empty.

    """
    n = len(columns)
    if n == 0:
        # Without this guard the grid computation divides by zero.
        raise ValueError("At least one column is required for plotting.")

    ncols = min(4, n)
    nrows = math.ceil(n / ncols)

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    axes = np.atleast_1d(axes).flatten()

    # outer boundary, computed once and overlaid on every subplot
    boundary = dissolve_polygons(gdf, remove_holes=True).boundary

    for ax, column in zip(axes, columns):
        dtype = gdf[column].dtype

        gdf_plot = gdf.copy()
        if fillna_value is not None:
            gdf_plot[column] = gdf_plot[column].fillna(fillna_value)

        gdf_plot = gdf_plot[~gdf_plot[column].isna()]
        if gdf_plot.empty:
            ax.axis("off")
            continue

        # numeric columns
        if pd.api.types.is_numeric_dtype(dtype):
            values = gdf_plot[column]

            if num_bins is not None:  # Treat as categorical (binned)
                gdf_plot["__binned__"] = pd.cut(values, bins=num_bins)
                gdf_plot.plot(
                    ax=ax,
                    column="__binned__",
                    cmap=cmap_numeric,
                    edgecolor="none",
                    linewidth=0,
                    # the legend must be requested explicitly, otherwise
                    # ax.get_legend() below is always None
                    legend=True,
                )

                legend = ax.get_legend()
                if legend:
                    n_labels = len(legend.get_texts())
                    _anchor_legend_below(legend, min(3, max(1, n_labels // 2)))

            else:
                vmin, vmax = values.min(), values.max()
                norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
                cmap = plt.get_cmap(cmap_numeric)

                gdf_plot.plot(
                    ax=ax,
                    column=column,
                    cmap=cmap,
                    edgecolor="none",
                    linewidth=0,
                )

                # dedicated colorbar axes below the map
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("bottom", size="5%", pad=0.4)

                sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
                sm.set_array([])

                fig.colorbar(sm, cax=cax, orientation="horizontal")

        # categorical columns
        elif (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        ):
            gdf_plot[column] = gdf_plot[column].astype("category")

            gdf_plot.plot(
                ax=ax,
                column=column,
                cmap=cmap_categorical,
                edgecolor="none",
                linewidth=0,
                legend=True,
            )

            legend = ax.get_legend()
            if legend:
                n_labels = len(legend.get_texts())

                # dynamic column layout
                if n_labels <= 3:
                    ncol_legend = n_labels
                elif n_labels <= 10:
                    ncol_legend = 4
                else:
                    ncol_legend = 5

                _anchor_legend_below(legend, ncol_legend)

                # shrink font if too many categories
                if n_labels > 10:
                    for text in legend.get_texts():
                        text.set_fontsize(8)

        else:
            ax.set_title(f"Unsupported dtype: {column}")
            ax.axis("off")
            continue

        # antialiasing fix: applied before the boundary overlay is added, so
        # the overlay keeps matplotlib's default setting
        for coll in ax.collections:
            coll.set_antialiased(antialiased)

        # boundary overlay
        gpd.GeoSeries(boundary).plot(ax=ax, color="black", linewidth=0.5)

        ax.set_title(column)
        ax.set_axis_off()

    # turn off unused axes
    for unused_ax in axes[n:]:
        unused_ax.axis("off")

    # leave space for bottom elements; scoped to this figure rather than to
    # whichever figure pyplot currently considers active
    fig.tight_layout(rect=[0, 0.12, 1, 0.95])

    return fig, axes[:n]


def plot_spatial_map(gdf: gpd.GeoDataFrame, d1: dict) -> None:
    """Generate a spatial map plot for the given data.

    Args:
        gdf: GeoDataFrame containing the data to be plotted.
        d1: Information needed for creating the plot. Keys read here:
            - columns: list of column names to plot (required)
            - var_str: variable string used in the default title and the log
              message (required)
            - vpu: VPU identifier used in the default title and the log
              message (required)
            - title: title of the plot (optional)
            - algorithm: appended to the title when present (optional)
            - fillna_value, num_bins, cmap_numeric, cmap_categorical:
              forwarded to ``_plot_columns_by_dtype`` (optional)
            - outfile: path the figure is saved to (optional; nothing is
              saved when missing or None)

    """
    # create spatial map
    fig, _ = _plot_columns_by_dtype(
        gdf,
        columns=d1["columns"],
        fillna_value=d1.get("fillna_value", None),
        num_bins=d1.get("num_bins", None),
        cmap_numeric=d1.get("cmap_numeric", "viridis"),
        cmap_categorical=d1.get("cmap_categorical", "Set3"),
    )

    # Build the title in a local variable: writing it back into d1 would make
    # a reused dict accumulate the algorithm suffix on every call.
    title = d1.get("title", f"Spatial Map of {d1['var_str']}: VPU {d1['vpu']}")
    if algorithm := d1.get("algorithm", None):
        title += f" (Algorithm: {algorithm})"
    fig.suptitle(title, fontsize=16, fontweight="bold")

    # Adjust layout to make room for the title
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # save the figure
    outfile = d1.get("outfile")
    if outfile is None:
        # The figure is intentionally left open in this case.
        logger.warning("No output file specified for spatial map plot. Skipping save.")
        return

    fig.savefig(outfile, bbox_inches="tight")
    plt.close(fig)
    logger.info(
        f"Spatial map of {d1['var_str']} for VPU {d1['vpu']} saved to {d1['outfile']}"
    )
def plot_histogram(data: pd.DataFrame, d1: dict) -> None:
    """Generate histogram plots for multiple columns in dataframe.

    Each column gets its own subplot (grid of at most 2 columns) showing a
    density histogram with a KDE curve overlaid.

    Args:
        data: DataFrame containing the columns to plot.
        d1: information needed for creating the plot, including:
            - columns: list of column names to plot
            - title: title of the plot (optional)
            - var_str: variable string for the plot
            - vpu: VPU identifier
            - algorithm: algorithm used for pairing (optional)
            - xlabel: x-axis label of every subplot (optional, default "Value")
            - ylabel: y-axis label of every subplot (optional, default "Density")
            - outfile: path the figure is saved to (optional; nothing is
              saved when missing or None)

    """
    columns = d1.get("columns", [])
    if not columns:
        logger.warning("No valid columns specified to plot for histogram.")
        return

    n = len(columns)
    ncols = min(2, n)  # cap at 2 columns, but don't exceed n
    nrows = math.ceil(n / ncols)

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(6 * ncols, 4 * nrows))
    axes = np.atleast_1d(axes).flatten()  # always normalize to an array

    xlabel = d1.get("xlabel", "Value")
    ylabel = d1.get("ylabel", "Density")

    for ax, col in zip(axes, columns):
        sns.histplot(
            data[col],
            ax=ax,
            kde=True,
            bins=30,
            color="lightblue",
            edgecolor="grey",
            stat="density",
        )
        sns.kdeplot(data[col], ax=ax, color="darkblue", linewidth=1.5)
        ax.set_title(f"{col}")
        # Label each subplot explicitly: plt.xlabel/plt.ylabel would only
        # touch the current axes, i.e. the last grid cell, which may be a
        # hidden unused subplot.
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # Hide any unused subplots
    for unused_ax in axes[n:]:
        unused_ax.set_visible(False)

    title = d1.get("title", f"Histograms of {d1['var_str']}: VPU {d1['vpu']}")
    if algorithm := d1.get("algorithm", None):
        title += f" (Algorithm: {algorithm})"
    fig.suptitle(title, fontsize=16, fontweight="bold")

    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave space for main title

    # save the figure
    outfile = d1.get("outfile")
    if outfile is None:
        # The figure is intentionally left open in this case.
        logger.warning("No output file specified for histogram plot. Skipping save.")
        return

    fig.savefig(outfile, bbox_inches="tight")
    plt.close(fig)
    logger.info(
        f"Histogram of {d1['var_str']} for VPU {d1['vpu']} saved to {d1['outfile']}"
    )
# NOT USED YET (keep for future use)
def plot_point_map(
    data: pd.DataFrame,
    d1: dict,
    base_layers: list[gpd.GeoDataFrame] = None,
) -> None:
    """Plot the spatial distribution of locations as points over base layers.

    Args:
        data: DataFrame containing the locations to be plotted. It is drawn
            with ``data.plot(..., markersize=10)`` — presumably a
            GeoDataFrame of points; verify against the caller.
        d1: Information needed for creating the plot. Not read by the current
            implementation.
        base_layers: Optional base layers to plot under the points (e.g., VPU
            boundaries, buffer zones). ``None`` is treated as no base layers.

    """
    if data.empty:
        logger.warning("No data provided for point map. Skipping plot.")
        return

    # The default is None, so normalize before taking len() / iterating.
    layers = [] if base_layers is None else list(base_layers)

    _, ax = plt.subplots(figsize=(8, 5))

    # Plot base layers, cycling through the palette when there are more
    # layers than colors.
    base_palette = ["lightgray", "lightpink", "lightgreen", "lightblue"]
    layer_colors = islice(cycle(base_palette), len(layers))
    for layer, layer_color in zip(layers, layer_colors):
        # NOTE(review): alpha=0 renders the layer fully transparent, so the
        # fill and edge colors have no visible effect — confirm the intended
        # opacity.
        layer.plot(ax=ax, color=layer_color, edgecolor="black", alpha=0)

    # Draw the points once, on top of all base layers, independent of how
    # many base layers were supplied.
    data.plot(ax=ax, color="red", markersize=10)

    # TODO : Add legend for base layers and points and save the figure