# Source code for ordinal_xai.interpretation.ice

"""
Individual Conditional Expectation (ICE) Plot implementation for ordinal regression models.

This module implements ICE plots, a model-agnostic interpretation method that shows how
a model's prediction changes as a feature value changes, while keeping other features
constant. For ordinal regression, it shows how the predicted ordinal class changes across
individual observations with feature variations.
"""

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from .base_interpretation import BaseInterpretation
from ..utils import pdp_modified

class ICE(BaseInterpretation):
    """
    Individual Conditional Expectation (ICE) Plot interpretation method.

    ICE plots show how a model's prediction changes as a single feature value
    changes while all other features are held constant. One curve is drawn per
    observation, together with the average curve (the partial dependence, PDP).

    Parameters
    ----------
    model : object
        The trained ordinal regression model. ``explain`` reads
        ``model.is_fitted_`` and calls ``model.fit(X, y)`` when it is falsy;
        the model is then handed to ``pdp_modified.partial_dependence``.
    X : pd.DataFrame
        Dataset used for interpretation. Should contain the same features
        used during model training.
    y : pd.Series, optional
        Target labels. Only used if the model has to be fitted.

    Attributes
    ----------
    model, X, y
        Read by ``explain`` as ``self.model``, ``self.X`` and ``self.y``.
        NOTE(review): presumably stored by ``BaseInterpretation.__init__``,
        which is defined elsewhere -- confirm.
    """

    def __init__(self, model, X, y=None):
        """
        Initialize the ICE Plot interpretation method.

        Parameters
        ----------
        model : object
            The trained ordinal regression model.
        X : pd.DataFrame
            Dataset used for interpretation.
        y : pd.Series, optional
            Target labels.
        """
        super().__init__(model, X, y)

    def explain(self, observation_idx=None, feature_subset=None, plot=False):
        """
        Generate Individual Conditional Expectation plots.

        Computes, and optionally visualizes, how the model's predictions
        change as each selected feature is varied.

        Parameters
        ----------
        observation_idx : int, optional
            Positional index of a specific instance to highlight. If given,
            only this instance's ICE curve is drawn together with the average
            (PDP), and its original feature value is marked.
        feature_subset : list, optional
            Feature names or positional indices (Python or NumPy integers) to
            analyse. If None, all columns of ``X`` are used. An empty list
            yields an empty result and no figure.
        plot : bool, default=False
            Whether to draw the ICE plots.

        Returns
        -------
        dict
            Maps each feature name to the object returned by
            ``pdp_modified.partial_dependence(..., kind="both")``. The
            plotting code reads these keys from it:

            - 'grid_values': feature values used for prediction
            - 'average': average predictions (PDP)
            - 'individual': per-instance predictions

        Notes
        -----
        - The average curve (PDP) shows the overall effect of the feature.
        - Individual curves show instance-specific effects.
        - If the model reports ``is_fitted_`` as falsy it is fitted on
          ``(X, y)`` as a side effect.
        """
        if feature_subset is None:
            feature_subset = self.X.columns.tolist()
        else:
            # Integers (including NumPy integers, e.g. from np.arange) are
            # treated as column positions; anything else is a column label.
            feature_subset = [
                self.X.columns[item] if isinstance(item, (int, np.integer)) else item
                for item in feature_subset
            ]

        if not self.model.is_fitted_:
            self.model.fit(self.X, self.y)  # Ensure model is fitted

        # Compute ICE curves for each feature.
        results = {}
        for feature in feature_subset:
            feature_idx = [self.X.columns.get_loc(feature)]
            results[feature] = pdp_modified.partial_dependence(
                self.model, self.X, features=feature_idx, kind="both"
            )

        # The subplot grid is only laid out when there is something to draw;
        # an empty subset would otherwise divide by zero columns.
        if plot and feature_subset:
            self._plot_results(results, feature_subset, observation_idx)

        print(f"Generated ICE Plots for features: {feature_subset}")
        return results

    def _plot_results(self, results, feature_subset, observation_idx):
        """Draw one ICE subplot per feature (at most four per row) and show the figure."""
        num_features = len(feature_subset)
        num_cols = min(num_features, 4)  # Max 4 plots per row
        num_rows = int(np.ceil(num_features / num_cols))  # Compute required rows

        # squeeze=False always yields a 2-D axes array, whatever the grid size.
        fig, axes = plt.subplots(
            nrows=num_rows,
            ncols=num_cols,
            figsize=(5 * num_cols, 4 * num_rows),
            squeeze=False,
        )

        for idx, feature in enumerate(feature_subset):
            row, col = divmod(idx, num_cols)
            ax = axes[row, col]

            ice_result = results[feature]
            x_values = ice_result['grid_values'][0]
            # Per the original author's comments: 'average' is
            # (n_classes, n_grid_points) and 'individual' is
            # (n_classes, n_instances, n_grid_points) -- TODO confirm
            # against pdp_modified.partial_dependence.
            averaged_predictions = ice_result['average']
            individual_predictions = ice_result['individual']
            num_ranks = averaged_predictions.shape[0]

            if observation_idx is not None:
                self._plot_highlighted_instance(
                    ax, feature, x_values, averaged_predictions,
                    individual_predictions, observation_idx,
                )
            else:
                # A 2-D y draws one line per column, so transposing gives one
                # line per instance in a single call.
                for rank in range(num_ranks):
                    ax.plot(x_values, individual_predictions[rank].T,
                            color=f'C{rank}', alpha=0.1, linewidth=0.5)
                # Plot the average curves on top.
                for rank in range(num_ranks):
                    ax.plot(x_values, averaged_predictions[rank],
                            color=f'C{rank}', linestyle='--', linewidth=2,
                            label=f'Average Rank (PDP)')

            ax.set_xlabel(feature, fontsize=12, labelpad=6)
            ax.set_ylabel("Prediction", fontsize=12, labelpad=6)
            ax.set_title(f"ICE Plot for {feature}", fontsize=14, pad=15)
            ax.grid()
            ax.legend()

        # Hide empty subplots.
        for idx in range(num_features, num_rows * num_cols):
            row, col = divmod(idx, num_cols)
            fig.delaxes(axes[row, col])

        plt.tight_layout(pad=3.0)  # Increase padding to avoid overlap
        plt.show()

    def _plot_highlighted_instance(self, ax, feature, x_values, averaged_predictions,
                                   individual_predictions, observation_idx):
        """Draw one instance's ICE curve, the PDP, and markers at its original feature value."""
        num_ranks = averaged_predictions.shape[0]

        for rank in range(num_ranks):
            # Plot the specified instance.
            ax.plot(x_values, individual_predictions[rank, observation_idx, :],
                    color=f'C{rank}', linewidth=2,
                    label=f'Instance {observation_idx} Rank')
            # Plot the average.
            ax.plot(x_values, averaged_predictions[rank],
                    color=f'C{rank}', linestyle='--', linewidth=2,
                    label=f'Average Rank (PDP)')

        # Column-first access keeps the column's own dtype instead of the
        # row-wise upcast that X.iloc[i][feature] can introduce.
        original_value = self.X[feature].iloc[observation_idx]

        # The grid position does not depend on the rank, so look it up once.
        grid_pos = self._grid_position(x_values, original_value)
        if grid_pos is not None:
            for rank in range(num_ranks):
                ax.scatter(original_value,
                           individual_predictions[rank, observation_idx, grid_pos],
                           color=f'C{rank}', s=100, zorder=5)

        # Add vertical line at original feature value.
        ymin, ymax = ax.get_ylim()
        ax.vlines(x=original_value, ymin=ymin, ymax=ymax, colors='black',
                  linestyles='dashed', linewidth=1.5, zorder=1)

    @staticmethod
    def _grid_position(x_values, original_value):
        """Return the grid index matching ``original_value``, or None if a categorical value is absent."""
        # NumPy integer scalars are not instances of Python ``int``, so they
        # must be listed explicitly to take the numeric branch.
        if isinstance(original_value, (int, float, np.integer, np.floating)):
            # Numerical feature: nearest grid point.
            return int(np.argmin(np.abs(x_values - original_value)))
        # Categorical feature: exact match only.
        matches = np.flatnonzero(np.asarray(x_values) == original_value)
        return int(matches[0]) if matches.size else None