"""Utility functions for plotting spatial maps and histograms etc.
plot_utils.py
Functions:
- _plot_columns_by_dtype: Plot multiple columns of a GeoDataFrame based on their data types.
- plot_spatial_map: Generate a spatial map plot for the given data.
- plot_histogram: Generate histogram plot for the spatial or attribute distance between donors and receivers.
- plot_point_map: Plot the spatial distribution of locations as points on a base layer map.
"""
import logging
import math
from itertools import cycle, islice
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# from matplotlib.lines import Line2D
from shapely.ops import unary_union
logger = logging.getLogger(__name__)
def _plot_columns_by_dtype(
    gdf: gpd.GeoDataFrame,
    columns: list[str],
    fillna_value=None,
    num_bins: int = None,
    cmap_numeric: str = "viridis",
    cmap_categorical: str = "Set3",
    figsize=(10, 6),
    ncols: int = 3,
    antialiased: bool = False,
):
    """Plot multiple GeoDataFrame columns in a grid of subplots, styled by dtype.

    Numeric columns are drawn with ``cmap_numeric`` (optionally binned with
    ``pd.cut``); categorical and object columns are cast to ``category`` and
    drawn with ``cmap_categorical``. Any other dtype gets an empty panel
    titled ``"Unsupported dtype: <column>"``. The outer boundary of all
    geometries is drawn on every supported panel.

    Args:
        gdf: GeoDataFrame to plot. It is not modified.
        columns: list of column names to plot (must not be empty).
        fillna_value: value used to fill NaNs before plotting; ``None`` leaves
            NaNs untouched.
        num_bins: number of equal-width bins for numeric columns (optional).
        cmap_numeric: colormap for numeric values.
        cmap_categorical: colormap for categorical values.
        figsize: figure size.
        ncols: number of columns in the subplot grid (default is 3).
        antialiased: antialiasing flag applied to the plotted polygon
            collections. The default ``False`` turns antialiasing off, which
            avoids faint seams between adjacent polygons.

    Returns:
        Tuple ``(fig, axes)`` where ``axes`` is a flat array holding only the
        axes that correspond to ``columns`` (unused grid cells are excluded).

    Raises:
        ValueError: if ``columns`` is empty.
    """
    n = len(columns)
    if n == 0:
        # Fail early with a clear message instead of an obscure matplotlib
        # error about a zero-row subplot grid.
        raise ValueError("No columns provided to plot.")

    nrows = math.ceil(n / ncols)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
    # plt.subplots returns a bare Axes, a 1D array or a 2D array depending on
    # the grid shape; normalise to a flat array.
    axes = np.array(axes).reshape(-1)

    # Outer boundary of the whole GeoDataFrame (loop-invariant).
    boundary = unary_union(gdf.geometry).boundary
    geometry_name = gdf.geometry.name

    for ax, column in zip(axes, columns):
        # The dtype is inspected before any NaN filling.
        dtype = gdf[column].dtype
        is_numeric = pd.api.types.is_numeric_dtype(dtype)
        # isinstance check replaces the deprecated
        # pd.api.types.is_categorical_dtype.
        is_categorical = isinstance(
            dtype, pd.CategoricalDtype
        ) or pd.api.types.is_object_dtype(dtype)

        if not (is_numeric or is_categorical):
            ax.set_title(f"Unsupported dtype: {column}")
            ax.axis("off")
            continue

        # Copy only what is plotted so the caller's frame is never modified
        # and the full frame is not duplicated for every column.
        gdf_plot = gdf[[column, geometry_name]].copy()
        if fillna_value is not None:
            gdf_plot[column] = gdf_plot[column].fillna(fillna_value)

        plot_column = column
        if is_numeric:
            cmap = cmap_numeric
            if num_bins is not None:
                gdf_plot["__binned__"] = pd.cut(gdf_plot[column], bins=num_bins)
                plot_column = "__binned__"
        else:
            cmap = cmap_categorical
            gdf_plot[column] = gdf_plot[column].astype("category")

        gdf_plot.plot(
            ax=ax,
            column=plot_column,
            cmap=cmap,
            legend=True,
            edgecolor="none",
            linewidth=0,
        )

        # Even when edgecolor="none" and linewidth=0, antialiasing can cause
        # matplotlib to blend edges between adjacent polygons, resulting in
        # faint grey or black lines. Apply the requested setting before the
        # boundary is added so the boundary line itself is not affected.
        for coll in ax.collections:
            coll.set_antialiased(antialiased)

        # plot outer boundary
        gpd.GeoSeries(boundary).plot(ax=ax, color="black", linewidth=0.5)
        ax.set_title(column)
        ax.set_axis_off()

        # Move the legend (if one was created) below the plot.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_bbox_to_anchor((0.5, -0.20))  # bottom center
            legend.set_loc("lower center")
            legend.set_frame_on(True)

    # Turn off any unused subplots
    for unused_ax in axes[n:]:
        unused_ax.axis("off")

    return fig, axes[:n]  # Return only used axes
def plot_spatial_map(gdf: gpd.GeoDataFrame, d1: dict) -> None:
    """Generate a multi-panel spatial map for the requested columns.

    Args:
        gdf : gpd.GeoDataFrame
            GeoDataFrame containing the data to be plotted.
        d1: dict
            Information needed for creating the plot. Keys read here:
            - columns: list of column names to plot (required; matched
              case-insensitively against ``gdf.columns``)
            - var_str, vpu: used in the default title and the log message
              (required)
            - title: plot title (optional)
            - algorithm: appended to the title when truthy (optional)
            - fillna_value, num_bins, cmap_numeric, cmap_categorical, ncols:
              forwarded to ``_plot_columns_by_dtype`` (optional)
            - outfile: path to save the figure to; when missing or ``None``
              the figure is not saved and is left open
            ``d1`` is not modified.
    """
    # Resolve each requested name to the actual column name: exact match
    # first, then a case-insensitive match. Names that match neither are
    # reported as missing.
    actual_by_lower = {str(col).lower(): col for col in gdf.columns}
    resolved = [
        (name, name if name in gdf.columns else actual_by_lower.get(name.lower()))
        for name in d1["columns"]
    ]
    cols_exist = [actual for _, actual in resolved if actual is not None]
    cols_missing = {name for name, actual in resolved if actual is None}

    if not cols_exist:
        logger.warning(
            f"Columns {cols_missing} not found in gdf. Cannot create spatial map plot."
        )
        return
    if cols_missing:
        logger.warning(
            f"Excluding missing columns {cols_missing} from spatial map plot."
        )

    # create spatial map
    fig, _ = _plot_columns_by_dtype(
        gdf,
        columns=cols_exist,
        fillna_value=d1.get("fillna_value", None),
        num_bins=d1.get("num_bins", None),
        cmap_numeric=d1.get("cmap_numeric", "viridis"),
        cmap_categorical=d1.get("cmap_categorical", "Set3"),
        ncols=d1.get("ncols", 3),
    )

    # Build the title locally so a dict reused across calls does not keep a
    # stale title or collect repeated algorithm suffixes.
    title = d1.get("title", f"Spatial Map of {d1['var_str']}: VPU {d1['vpu']}")
    if algorithm := d1.get("algorithm", None):
        title += f" (Algorithm: {algorithm})"
    fig.suptitle(title, fontsize=16, fontweight="bold")
    # Adjust layout to make room for the title
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # save the figure
    if d1.get("outfile") is not None:
        fig.savefig(d1["outfile"], bbox_inches="tight")
        plt.close(fig)
        logger.info(
            f"Spatial map of {d1['var_str']} for VPU {d1['vpu']} saved to {d1['outfile']}"
        )
    else:
        logger.warning("No output file specified for spatial map plot. Skipping save.")
def plot_histogram(data: pd.DataFrame, d1: dict) -> None:
    """Generate histogram plots (with KDE overlay) for multiple numeric columns.

    Args:
        data : pd.DataFrame
            DataFrame containing the donor-receiver pairing results.
        d1: dict
            information needed for creating the plot, including:
            - columns: list of column names to plot (matched
              case-insensitively against ``data.columns``)
            - ncols: number of columns in the subplot grid (default is 3)
            - title: title of the plot
            - var_str: variable string for the plot
            - vpu: VPU identifier
            - algorithm: algorithm used for pairing (optional)
            - xlabel, ylabel: axis labels applied to every panel
              (defaults "Value" and "Density")
            - outfile: path to save the figure to; when missing or ``None``
              the figure is not saved and is left open

    Raises:
        ValueError: if the dataframe has numeric columns but none of the
            requested columns is an existing numeric column.
    """
    columns = d1.get("columns", [])
    if not columns:
        logger.warning("No valid columns specified to plot for histogram.")
        return

    # filter to numeric columns only
    numeric_columns = set(data.select_dtypes(include=[np.number]).columns)
    if not numeric_columns:
        logger.warning("No numeric columns found in the dataframe to plot histograms.")
        return

    # Resolve requested names to actual column names first (exact match, then
    # case-insensitive) so that missing columns are not misreported as
    # non-numeric and differently-cased names are still plotted.
    actual_by_lower = {str(col).lower(): col for col in data.columns}
    resolved = [
        (name, name if name in data.columns else actual_by_lower.get(name.lower()))
        for name in columns
    ]
    missing_columns = {name for name, actual in resolved if actual is None}
    existing_columns = [actual for _, actual in resolved if actual is not None]

    non_numeric = set(existing_columns) - numeric_columns
    if non_numeric:
        logger.warning(
            f"Some specified columns {non_numeric} are not numeric. "
            "Skipping these columns for histogram plots."
        )
    valid_columns = [col for col in existing_columns if col in numeric_columns]
    if not valid_columns:
        raise ValueError(
            "None of the specified columns exist as numeric columns in the dataframe."
        )
    if missing_columns:
        logger.warning(
            f"Columns {missing_columns} not found in data. Only plotting valid columns: {valid_columns}"
        )

    n = len(valid_columns)
    ncols = d1.get("ncols", 3)
    nrows = math.ceil(n / ncols)
    # squeeze=False always yields a 2D array, so flattening also works for a
    # 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(6 * ncols, 4 * nrows), squeeze=False
    )
    axes = axes.flatten()

    xlabel = d1.get("xlabel", "Value")
    ylabel = d1.get("ylabel", "Density")
    for ax, col in zip(axes, valid_columns):
        sns.histplot(
            data[col],
            ax=ax,
            kde=True,
            bins=30,
            color="lightblue",
            edgecolor="grey",
            stat="density",
        )
        sns.kdeplot(data[col], ax=ax, color="darkblue", linewidth=1.5)
        ax.set_title(f"{col}")
        # Label every panel explicitly; pyplot-level xlabel/ylabel would only
        # reach the most recently created subplot.
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    # Hide any unused subplots
    for unused_ax in axes[n:]:
        unused_ax.set_visible(False)

    title = d1.get("title", f"Histograms of {d1['var_str']}: VPU {d1['vpu']}")
    if algorithm := d1.get("algorithm", None):
        title += f" (Algorithm: {algorithm})"
    fig.suptitle(title, fontsize=16, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # leave space for main title

    # save the figure
    if d1.get("outfile") is not None:
        fig.savefig(d1["outfile"], bbox_inches="tight")
        plt.close(fig)
        logger.info(
            f"Histogram of {d1['var_str']} for VPU {d1['vpu']} saved to {d1['outfile']}"
        )
    else:
        logger.warning("No output file specified for histogram plot. Skipping save.")
# NOT USED YET (for future use)
def plot_point_map(
    data: pd.DataFrame,
    d1: dict,
    base_layers: list[gpd.GeoDataFrame] = None,
) -> None:
    """Plot locations as points on top of optional base layers.

    Args:
        data : pd.DataFrame
            DataFrame containing the locations to be plotted.
            NOTE(review): ``data.plot(..., markersize=...)`` suggests a
            GeoDataFrame of point geometries is expected - confirm.
        d1: dict
            Information needed for creating the plot (currently not read).
        base_layers: list of GeoDataFrames
            Optional base layers to plot under the points (e.g., VPU
            boundaries, buffer zones). ``None`` means no base layers.
    """
    if data.empty:
        logger.warning("No data provided for point map. Skipping plot.")
        return

    # The default of None must not reach the plotting loop.
    layers = [] if base_layers is None else base_layers

    _, ax = plt.subplots(figsize=(8, 5))

    # Plot base layers, cycling through the palette when there are more
    # layers than colors.
    colors_base = ["lightgray", "lightpink", "lightgreen", "lightblue"]
    for base_layer, clr_base in zip(layers, cycle(colors_base)):
        # NOTE(review): alpha=0 renders the layer fully transparent - confirm
        # whether a visible alpha was intended.
        base_layer.plot(ax=ax, color=clr_base, edgecolor="black", alpha=0)

    # Draw the points exactly once, above every base layer, so they also
    # appear when no base layers are supplied.
    data.plot(ax=ax, color="red", markersize=10)
    # TODO : Add legend for base layers and points and save the figure