"""
Permutation Feature Importance (PFI) interpretation method for ordinal regression models.
This module implements the Permutation Feature Importance method for analyzing feature
importance in ordinal regression models. PFI works by measuring how model performance
changes when feature values are randomly permuted, breaking the relationship between
the feature and the target variable.
Key Features:
- Global feature importance analysis through feature permutation
- Support for multiple evaluation metrics
- Robust importance estimation through multiple permutations
- Comprehensive visualization of feature importance
- Support for both classification and probability-based metrics
The implementation is particularly useful for:
- Understanding feature importance in ordinal regression models
- Identifying key features that influence model predictions
- Analyzing feature importance across different evaluation metrics
- Comparing feature importance between different models
"""
from typing import Optional, List, Dict, Union, Callable, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator
from .base_interpretation import BaseInterpretation
from sklearn.inspection import permutation_importance
from ..utils.evaluation_metrics import (
mze, mae, mse, adjacent_accuracy, weighted_kappa,
cem, spearman_correlation, kendall_tau,
ranked_probability_score, ordinal_weighted_ce,
evaluate_ordinal_model
)
class PFI(BaseInterpretation):
    """
    Permutation Feature Importance (PFI) interpretation method for feature importance.

    This class implements the PFI method for analyzing feature importance in ordinal
    regression models. It measures the impact of each feature by evaluating how model
    performance changes when the feature values are randomly permuted.

    Parameters
    ----------
    model : object
        The trained ordinal regression model. Must implement fit and predict methods.
    X : pd.DataFrame
        Dataset used for interpretation. Should contain the same features used
        during model training.
    y : np.ndarray, optional
        Target labels. Optional at construction time, but required by ``explain``.
    n_repeats : int, default=5
        Number of times to permute each feature. Higher values provide more
        robust importance estimates but increase computation time.
    random_state : int, default=42
        Random seed for reproducibility of permutations.
    show_intervals : bool, default=False
        Whether plots draw dashed lines at mean +/- one standard deviation
        for every bar.

    Attributes
    ----------
    available_metrics : dict
        Dictionary of available evaluation metrics and their corresponding functions.
    n_repeats : int
        Number of permutation repeats.
    random_state : int
        Random state for reproducibility.
    show_intervals : bool
        Whether mean +/- std interval lines are drawn on plots.
    original_predictions : array-like or None
        Predictions of the model on ``X``; ``None`` when ``y`` was not provided.
    original_results : dict or None
        Output of ``evaluate_ordinal_model`` for the original predictions;
        ``None`` when ``y`` was not provided.

    Raises
    ------
    ValueError
        If X is not a pandas DataFrame or if model predictions fail.
    """

    # Metrics for which a lower value is better. Their scores are negated for
    # sklearn (which assumes "greater is better") and their importance is
    # reported as an increase instead of a drop.
    _LOWER_IS_BETTER = frozenset({
        'mae', 'mse', 'mze', 'ranked_probability_score',
        'ordinal_weighted_ce_linear', 'ordinal_weighted_ce_quadratic',
    })

    # Metrics that are fed probability predictions when the model offers them.
    _PROBA_METRICS = frozenset({
        'ranked_probability_score',
        'ordinal_weighted_ce_linear', 'ordinal_weighted_ce_quadratic',
    })

    # Metrics skipped in local (single-instance) mode.
    _LOCAL_EXCLUDED_METRICS = frozenset({
        'spearman_correlation', 'kendall_tau',
        'weighted_kappa_quadratic', 'weighted_kappa_linear',
    })

    # Local mode draws ``n_repeats * _LOCAL_REPEAT_MULTIPLIER`` replacement
    # values per feature.
    _LOCAL_REPEAT_MULTIPLIER = 20

    def __init__(self, model, X: pd.DataFrame, y: Optional[np.ndarray] = None,
                 n_repeats=5, random_state=42, show_intervals: bool = False):
        """
        Initialize the Permutation Feature Importance interpretation method.

        Parameters
        ----------
        model : object
            The trained ordinal regression model
        X : pd.DataFrame
            Dataset used for interpretation
        y : np.ndarray, optional
            Target labels
        n_repeats : int, default=5
            Number of times to permute each feature
        random_state : int, default=42
            Random seed for reproducibility
        show_intervals : bool, default=False
            Whether plots draw dashed mean +/- std interval lines

        Raises
        ------
        ValueError
            If X is not a pandas DataFrame or if model predictions fail
        """
        # Validate before delegating so the documented ValueError is what the
        # caller sees for a non-DataFrame input.
        if not isinstance(X, pd.DataFrame):
            raise ValueError("X must be a pandas DataFrame")
        # NOTE(review): self.model / self.X / self.y are presumably assigned by
        # BaseInterpretation.__init__ -- confirm against the base class.
        super().__init__(model, X, y)

        self.available_metrics = {
            'mze': mze,
            'mae': mae,
            'mse': mse,
            'adjacent_accuracy': adjacent_accuracy,
            'weighted_kappa_quadratic': lambda yt, yp: weighted_kappa(yt, yp, weights='quadratic'),
            'weighted_kappa_linear': lambda yt, yp: weighted_kappa(yt, yp, weights='linear'),
            'cem': cem,
            'spearman_correlation': spearman_correlation,
            'kendall_tau': kendall_tau,
            'ranked_probability_score': ranked_probability_score,
            'ordinal_weighted_ce_linear': lambda yt, yp: ordinal_weighted_ce(yt, yp, alpha=1),
            'ordinal_weighted_ce_quadratic': lambda yt, yp: ordinal_weighted_ce(yt, yp, alpha=2),
        }
        self.n_repeats = n_repeats
        self.random_state = random_state
        # Whether to draw mean +/- std intervals on plots
        self.show_intervals = show_intervals

        # Always define these attributes so that reading them without labels
        # yields None instead of raising AttributeError.
        self.original_predictions = None
        self.original_results = None
        if self.y is not None:
            try:
                self.original_predictions = self.model.predict(self.X)
                self.original_results = evaluate_ordinal_model(self.y, self.original_predictions)
            except Exception as e:
                raise ValueError(f"Failed to get original model predictions: {str(e)}") from e

    def _create_scoring_func(self, metric_name, metric_func):
        """
        Create a scoring function for a specific metric.

        The returned callable has the ``(estimator, X, y)`` signature expected by
        sklearn's ``permutation_importance``. Probability-based metrics receive
        ``predict_proba`` output when the estimator provides it and fall back to
        class predictions otherwise; all other metrics receive class predictions.
        Lower-is-better metrics are negated so that a larger score is always better.

        Parameters
        ----------
        metric_name : str
            Name of the metric to create scoring function for
        metric_func : callable
            The metric function to use for scoring, called as
            ``metric_func(y_true, y_pred)``

        Returns
        -------
        callable
            A scoring function that takes (estimator, X, y) as arguments and
            returns a score

        Raises
        ------
        ValueError
            Raised by the returned function if score calculation fails for
            the metric
        """
        # Resolve the flags once so the closure does not capture ``self``.
        wants_proba = metric_name in self._PROBA_METRICS
        negate = metric_name in self._LOWER_IS_BETTER

        def scoring_func(estimator, X, y):
            try:
                predictions = None
                if wants_proba:
                    # Only the predict_proba call is guarded: an error raised
                    # inside the metric itself must not silently trigger the
                    # class-prediction fallback.
                    try:
                        predictions = estimator.predict_proba(X)
                    except (AttributeError, NotImplementedError):
                        predictions = None
                if predictions is None:
                    predictions = estimator.predict(X)
                score = metric_func(y, predictions)
                return -score if negate else score
            except Exception as e:
                raise ValueError(f"Failed to calculate score for metric {metric_name}: {str(e)}") from e

        return scoring_func

    def explain(self, observation_idx=None, feature_subset=None, plot=False, metrics=None, title=True):
        """
        Generate Permutation Feature Importance explanations.

        Two modes are supported:

        1. Global mode (default): feature importance over the entire dataset,
           computed with sklearn's ``permutation_importance``. All columns of
           ``X`` are permuted; ``feature_subset`` only selects which rows of
           the result are returned.
        2. Local mode (when ``observation_idx`` is provided): for a single
           instance, each selected feature value is repeatedly replaced by a
           value drawn at random from that feature's column in ``X`` and the
           change in each metric is recorded.

        Parameters
        ----------
        observation_idx : int, optional
            Positional index of the instance to analyze (local explanation).
            If None, computes global PFI.
        feature_subset : list, optional
            List of feature names or integer positions to consider
        plot : bool, default=False
            Whether to create visualizations
        metrics : list, optional
            List of metric names to use; defaults to all available metrics.
            In local mode the rank-correlation and weighted-kappa metrics
            are skipped.
        title : bool, default=True
            Whether to add a suptitle to the visualization

        Returns
        -------
        dict
            Maps each metric name to a dict with the keys ``'features'``,
            ``'importances_mean'``, ``'importances_std'`` and ``'importances'``.

        Raises
        ------
        ValueError
            If no target labels are available, if ``metrics`` contains an
            unknown metric name, or if ``feature_subset`` names an unknown
            feature.
        """
        if self.y is None:
            raise ValueError("Target labels y are required to compute feature importance")

        # Models that do not expose ``is_fitted_`` are assumed to be trained
        # already rather than crashing with AttributeError.
        if not getattr(self.model, 'is_fitted_', True):
            self.model.fit(self.X, self.y)  # Ensure model is fitted

        all_features = self.X.columns.tolist()
        if feature_subset is None:
            permute_features = all_features
            feature_idxs = list(range(len(all_features)))
        else:
            feature_subset = list(feature_subset)
            # Accept numpy integers as positions as well as plain ints.
            if all(isinstance(f, (int, np.integer)) for f in feature_subset):
                feature_idxs = [int(i) for i in feature_subset]
                permute_features = [all_features[i] for i in feature_idxs]
            else:
                permute_features = feature_subset
                unknown = [f for f in permute_features if f not in all_features]
                if unknown:
                    raise ValueError(f"Unknown features: {unknown}. Available features: {all_features}")
                feature_idxs = [all_features.index(f) for f in permute_features]

        # Set metrics to use
        if metrics is None:
            metrics_to_use = list(self.available_metrics.keys())
        else:
            invalid_metrics = [m for m in metrics if m not in self.available_metrics]
            if invalid_metrics:
                raise ValueError(f"Invalid metrics: {invalid_metrics}. Available metrics: {list(self.available_metrics.keys())}")
            metrics_to_use = list(metrics)

        if observation_idx is not None:
            # Local (instance-level) PFI: exclude metrics not suitable for
            # single-instance explanation.
            metrics_to_use = [m for m in metrics_to_use if m not in self._LOCAL_EXCLUDED_METRICS]
            results = self._explain_local(observation_idx, permute_features, metrics_to_use)
        else:
            results = self._explain_global(permute_features, feature_idxs, metrics_to_use)

        if plot:
            self._plot_feature_importance(results, metrics=metrics_to_use, title=title)
        return results

    def _explain_local(self, observation_idx, permute_features, metrics_to_use):
        """Compute per-feature importance for the single instance at ``observation_idx``."""
        # Imported locally: this private helper is not part of the
        # module-level import list.
        from ..utils.evaluation_metrics import _get_class_counts

        X_instance = self.X.iloc[[observation_idx]]
        if hasattr(self.y, 'iloc'):
            y_instance = self.y.iloc[[observation_idx]]
        else:
            y_instance = np.array([self.y[observation_idx]])
        rng = np.random.RandomState(self.random_state)
        # For local explanation, some metrics require class counts from the whole dataset
        whole_dataset_class_counts = _get_class_counts(self.y)

        def evaluate(X_row):
            # Score one (possibly perturbed) row with every requested metric.
            pred = self.model.predict(X_row)
            try:
                pred_proba = self.model.predict_proba(X_row)
            except (AttributeError, NotImplementedError):
                pred_proba = None
            return evaluate_ordinal_model(y_instance, pred, pred_proba, metrics=metrics_to_use, class_counts=whole_dataset_class_counts, zero_indexed=True)

        original_results = evaluate(X_instance)
        n_draws = self.n_repeats * self._LOCAL_REPEAT_MULTIPLIER
        row_label = X_instance.index[0]

        importances = {metric: [] for metric in metrics_to_use}
        for feature in permute_features:
            candidate_values = self.X[feature].values
            feature_drops = {metric: [] for metric in metrics_to_use}
            for _ in range(n_draws):
                X_permuted = X_instance.copy()
                X_permuted.loc[row_label, feature] = rng.choice(candidate_values)
                permuted_results = evaluate(X_permuted)
                for metric in metrics_to_use:
                    original_score = original_results[metric]
                    permuted_score = permuted_results[metric]
                    # Importance is always "how much worse did it get".
                    if metric in self._LOWER_IS_BETTER:
                        drop = permuted_score - original_score
                    else:
                        drop = original_score - permuted_score
                    feature_drops[metric].append(drop)
            for metric in metrics_to_use:
                importances[metric].append(feature_drops[metric])

        results = {}
        for metric in metrics_to_use:
            # The explicit reshape keeps the array two-dimensional even when
            # no features were selected, so the axis=1 reductions stay valid.
            metric_importances = np.array(importances[metric]).reshape(len(permute_features), n_draws)
            results[metric] = {
                'features': permute_features,
                'importances_mean': metric_importances.mean(axis=1),
                'importances_std': metric_importances.std(axis=1),
                'importances': metric_importances
            }
        return results

    def _explain_global(self, permute_features, feature_idxs, metrics_to_use):
        """Compute dataset-wide importance with sklearn's ``permutation_importance``."""
        results = {}
        for metric_name in metrics_to_use:
            scoring_func = self._create_scoring_func(metric_name, self.available_metrics[metric_name])
            result = permutation_importance(
                self.model, self.X, self.y,
                n_repeats=self.n_repeats,
                n_jobs=-1,  # Use all available CPU cores
                random_state=self.random_state,
                scoring=scoring_func
            )
            # sklearn scores every column; keep only the requested ones.
            results[metric_name] = {
                'features': permute_features,
                'importances_mean': result.importances_mean[feature_idxs],
                'importances_std': result.importances_std[feature_idxs],
                'importances': result.importances[feature_idxs]
            }
        return results

    def _plot_feature_importance(self, results, metrics=None, title=True):
        """
        Plot feature importance scores for each metric.

        Creates one bar-chart subplot per metric, arranged in a grid of at most
        four columns. Within each subplot the features are sorted by mean
        importance in descending order. Nothing is drawn when there are no
        metrics to plot.

        Parameters
        ----------
        results : dict
            Dictionary containing all results from the PFI analysis
        metrics : list, optional
            List of metrics to plot (defaults to all in results)
        title : bool, default=True
            Whether to add a suptitle to the visualization

        Notes
        -----
        - Error metrics are shown in red with "Increase" in the axis label
        - Other metrics are shown in green with "Drop" in the axis label
        - Feature names are rotated 90 degrees for better readability
        - Score values are displayed at the vertical middle of each subplot
        - When ``show_intervals`` is set, dashed lines mark mean +/- std per bar
        """
        import math

        metric_abbr = {
            'adjacent_accuracy': 'AA',
            'weighted_kappa_linear': 'LWK',
            'weighted_kappa_quadratic': 'QWK',
            'spearman_correlation': 'Rho',
            'kendall_tau': 'Tau',
            'ranked_probability_score': 'RPS',
            'ordinal_weighted_ce_linear': 'LW-OCE',
            'ordinal_weighted_ce_quadratic': 'QW-OCE',
            'mae': 'MAE',
            'mse': 'MSE',
            'mze': 'MZE',
            'cem': 'CEM',
        }
        if metrics is None:
            metrics = list(results.keys())
        n_metrics = len(metrics)
        if n_metrics == 0:
            # A zero-row subplot grid is invalid; there is nothing to show.
            return

        n_cols = min(4, n_metrics)
        n_rows = math.ceil(n_metrics / n_cols)
        # squeeze=False always yields a 2-D array of axes, so the single-plot
        # case needs no special handling.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 4.5 * n_rows), squeeze=False)
        axes = axes.ravel()

        for ax, metric in zip(axes, metrics):
            metric_result = results[metric]
            importances_mean = np.asarray(metric_result['importances_mean'])
            importances_std = np.asarray(metric_result['importances_std'])
            # Sort features by importance
            sorted_indices = np.argsort(importances_mean)[::-1]
            features = np.array(metric_result['features'])[sorted_indices]
            means = importances_mean[sorted_indices]
            stds = importances_std[sorted_indices]

            abbr = metric_abbr.get(metric, metric)
            is_error_metric = metric in self._LOWER_IS_BETTER
            color = 'red' if is_error_metric else 'green'
            title_suffix = "Increase" if is_error_metric else "Drop"

            bars = ax.bar(features, means, color=color, alpha=0.85)
            # Add interval lines if requested
            if self.show_intervals:
                for bar, mean_val, std_val in zip(bars, means, stds):
                    x_left = bar.get_x()
                    x_right = x_left + bar.get_width()
                    # Upper and lower interval
                    for y_level in (mean_val + std_val, mean_val - std_val):
                        ax.hlines(y=y_level, xmin=x_left, xmax=x_right,
                                  colors='black', linestyles='dashed', linewidth=1)

            ax.set_ylabel(f'{abbr} {title_suffix}', fontsize=8, labelpad=10)
            ax.grid(axis='y', linestyle='--', alpha=0.5)
            # Shrink the labels when many features share one axis.
            feature_fontsize = 5 if len(features) > 12 else 8
            plt.setp(ax.get_xticklabels(), rotation=90, ha='right', fontsize=feature_fontsize)

            ylim = ax.get_ylim()
            y_mid = (ylim[0] + ylim[1]) / 2
            for bar, mean in zip(bars, means):
                ax.text(bar.get_x() + bar.get_width() / 2, y_mid,
                        f'{mean:.3f}', ha='center', va='center',
                        fontsize=feature_fontsize, rotation=90, clip_on=True)

        # Hide the unused cells of the grid.
        for unused_ax in axes[n_metrics:]:
            unused_ax.set_visible(False)

        if title:
            fig.suptitle('Permutation Feature Importance Across Metrics', fontsize=18, y=0.995)
        plt.tight_layout(h_pad=5, w_pad=5)
        plt.subplots_adjust(top=0.95, left=0.05)
        plt.show()