Source code for ordinal_xai.interpretation.pdp_prob

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from .base_interpretation import BaseInterpretation
from ..utils import pdp_modified
from typing import Optional, List, Union, Dict, Any

[docs] class PDPProb(BaseInterpretation): """ Partial Dependence Plot (PDP) interpretation method for ordinal regression models. This version visualizes class probabilities as stacked area plots. """
[docs] def __init__(self, model: Any, X: pd.DataFrame, y: Optional[pd.Series] = None): """ Initialize the PDPProb interpretation method. Parameters ---------- model : BaseOrdinalModel The trained ordinal regression model. X : pd.DataFrame DataFrame containing the dataset used for interpretation. y : Optional[pd.Series], default=None Series containing target labels. """ super().__init__(model, X, y)
[docs] def explain( self, observation_idx: Optional[int] = None, feature_subset: Optional[Union[List[str], List[int]]] = None, plot: bool = False, max_features_per_figure: int = 12 ) -> Dict[str, Dict[str, np.ndarray]]: """ Generate Partial Dependence Plots as stacked area plots for class probabilities. Parameters ---------- observation_idx : Optional[int], default=None Ignored (PDP is a global method). feature_subset : Optional[Union[List[str], List[int]]], default=None List of feature names or indices to plot. If None, all features are used. plot : bool, default=False Whether to create visualizations. max_features_per_figure : int, default=12 Maximum number of features to display per figure (for large datasets). Returns ------- Dict[str, Dict[str, np.ndarray]] Dictionary containing PDP results for each feature: - 'grid_values': Feature values used for plotting - 'average': Average predicted probabilities at each feature value """ if feature_subset is None: feature_subset = self.X.columns.tolist() else: feature_subset = [self.X.columns[i] if isinstance(i, int) else i for i in feature_subset] num_features = len(feature_subset) results = {} if not self.model.is_fitted_: self.model.fit(self.X, self.y) # Ensure model is fitted for feature in feature_subset: feature_idx = [self.X.columns.get_loc(feature)] pdp_result = pdp_modified.partial_dependence( self.model, self.X, features=feature_idx, response_method="predict_proba" ) results[feature] = pdp_result if plot: self._plot_pdps_prob(results, feature_subset, max_features_per_figure) print(f"Generated PDPs for features: {feature_subset}") return results
[docs] def _plot_pdps_prob( self, results: Dict[str, Dict[str, np.ndarray]], feature_subset: List[str], max_features_per_figure: int = 12 ) -> None: """ Create visualization of PDP probability plots with pagination for large feature sets. Parameters ---------- results : Dict[str, Dict[str, np.ndarray]] PDP results for each feature feature_subset : List[str] List of features to plot max_features_per_figure : int, default=12 Maximum number of features to display per figure """ num_features = len(feature_subset) pages = int(np.ceil(num_features / max_features_per_figure)) for page in range(pages): start = page * max_features_per_figure end = min((page + 1) * max_features_per_figure, num_features) features_page = feature_subset[start:end] n = len(features_page) num_cols = min(n, 4) num_rows = int(np.ceil(n / num_cols)) fig, axes = plt.subplots( nrows=num_rows, ncols=num_cols, figsize=(5 * num_cols, 4 * num_rows) ) axes = np.array(axes).reshape(num_rows, num_cols) legend_handles = None for idx, feature in enumerate(features_page): row, col = divmod(idx, num_cols) ax = axes[row, col] pdp_result = results[feature] num_ranks = pdp_result['average'].shape[0] x_values = pdp_result['grid_values'][0] probabilities = pdp_result['average'] stack = ax.stackplot( x_values, probabilities, labels=[f"Rank {rank}" for rank in range(num_ranks)], alpha=0.7 ) if legend_handles is None: legend_handles = stack ax.set_xlabel(feature, fontsize=10, labelpad=6) ax.set_ylabel("Partial Dependence Probabilities", fontsize=10, labelpad=6) ax.set_title(f"PDP for {feature}", fontsize=11, pad=10) ax.grid(True) for label in ax.get_xticklabels(): label.set_rotation(30) label.set_fontsize(8) for label in ax.get_yticklabels(): label.set_fontsize(8) # Shared legend for the figure if legend_handles is not None: fig.legend( handles=legend_handles, loc='lower right', fontsize=10, ncol=3, handletextpad=0.5, columnspacing=0.5, frameon=True, borderpad=0.5, labelspacing=0.5, handlelength=1.5, borderaxespad=0.5, fancybox=True ) # Hide empty subplots for idx in range(n, num_rows * num_cols): row, col = divmod(idx, num_cols) fig.delaxes(axes[row, col]) plt.tight_layout(pad=4) plt.show()