"""
Local Interpretable Model-agnostic Explanations (LIME) for ordinal regression models.
This module implements LIME for ordinal regression models, providing local explanations
by fitting interpretable surrogate models to explain individual predictions. The implementation
extends standard LIME to handle ordinal data by comparing predictions with adjacent or
following classes.
Key Features:
- Local explanations for individual predictions
- Support for both logistic regression and decision tree surrogate models
- Multiple sampling strategies (grid, uniform, permutation)
- Customizable kernel functions for sample weighting
- Visualization of feature importance through coefficients or decision trees
- Support for both numerical and categorical features
The implementation is particularly useful for:
- Understanding individual predictions in ordinal regression models
- Identifying key features that influence specific predictions
- Comparing feature importance across different classes
- Providing interpretable explanations for model decisions
"""
from typing import Optional, List, Dict, Union, Callable, Tuple
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier, plot_tree
from .base_interpretation import BaseInterpretation
import matplotlib.pyplot as plt
import gower
from sklearn.model_selection import ParameterGrid, train_test_split
import logging
from sklearn.metrics import accuracy_score, log_loss
import re
logger = logging.getLogger(__name__)
[docs]
class LIME(BaseInterpretation):
"""
Local Interpretable Model-agnostic Explanations for ordinal regression models.
This class implements LIME for ordinal regression models, providing local explanations
by fitting interpretable surrogate models to explain individual predictions. The implementation
extends standard LIME to handle ordinal data by comparing predictions with adjacent or
following classes.
Parameters
----------
model : object
The trained ordinal regression model. Must implement predict and transform methods.
X : pd.DataFrame
Dataset used for interpretation. Should contain the same features used
during model training.
y : np.ndarray, optional
Target labels. Not required for interpretation but useful for reference.
comparison_method : str, default='one_vs_following'
Method for comparing classes:
- 'one_vs_next': Compare with adjacent classes only
- 'one_vs_following': Compare with all higher/lower classes
model_type : str, default='logistic'
Type of surrogate model to use:
- 'logistic': Logistic regression model
- 'decision_tree': Decision tree model
kernel_width : float, default=0.75
Width of the exponential kernel for sample weighting. Controls how quickly
the weight of samples decreases with distance.
custom_kernel : callable, optional
Custom kernel function for sample weighting. Should take distances as input
and return weights.
sampling : str, default='permute'
Sampling strategy for generating perturbed samples:
- 'grid': Generate samples on a grid (for small feature spaces)
- 'uniform': Sample uniformly from feature ranges
- 'permute': Permute feature values from the dataset
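max_samples : int, default=10000
Maximum number of samples to generate for surrogate model training.
random_state : int, default=42
Seed used when drawing permutation samples.
**surrogate_kwargs : dict
Extra keyword arguments forwarded to the surrogate model constructor
(LogisticRegression or DecisionTreeClassifier).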
Attributes
----------
model : object
The trained ordinal regression model
X : pd.DataFrame
Training data
y : np.ndarray
Target labels
comparison_method : str
Method for comparing classes
model_type : str
Type of surrogate model
kernel_width : float
Width of the exponential kernel
custom_kernel : callable
Custom kernel function
sampling : str
Sampling strategy
max_samples : int
Maximum number of samples
Raises
------
ValueError
If comparison_method is invalid, kernel_width is non-positive,
or sampling strategy is invalid
"""
[docs]
def __init__(self,
             model,
             X: pd.DataFrame,
             y: Optional[np.ndarray] = None,
             comparison_method: str = 'one_vs_following',
             model_type: str = "logistic",
             kernel_width: float = 0.75,
             custom_kernel: Optional[Callable] = None,
             sampling: str = "permute",
             max_samples: int = 10000,
             random_state: int = 42,
             **surrogate_kwargs) -> None:
    """
    Set up a LIME explainer for an ordinal regression model.

    Parameters
    ----------
    model : object
        The trained ordinal regression model.
    X : pd.DataFrame
        Dataset handed to the base class together with ``model`` and ``y``.
    y : np.ndarray, optional
        Target labels.
    comparison_method : "one_vs_following" | "one_vs_next", default='one_vs_following'
        How the predicted class is contrasted with the other classes.
    model_type : "logistic" | "decision_tree", default='logistic'
        Surrogate model family. Stored as given; it is not validated in
        this constructor.
    kernel_width : float, default=0.75
        Width of the exponential kernel; must be strictly positive.
    custom_kernel : callable, optional
        Replacement kernel mapping distances to weights.
    sampling : "grid" | "uniform" | "permute", default='permute'
        Strategy used to generate perturbed samples.
    max_samples : int, default=10000
        Maximum number of perturbed samples.
    random_state : int, default=42
        Seed stored on the instance; the permutation sampler uses it.
    **surrogate_kwargs
        Extra keyword arguments kept on the instance for the surrogate
        model constructor.

    Raises
    ------
    ValueError
        If ``comparison_method`` or ``sampling`` is not one of the
        supported names, or if ``kernel_width`` is not positive.
    """
    super().__init__(model, X, y)

    # Reject unsupported options before storing this class's configuration.
    supported_comparisons = ("one_vs_next", "one_vs_following")
    supported_sampling = ("grid", "uniform", "permute")
    if comparison_method not in supported_comparisons:
        raise ValueError("comparison_method must be either 'one_vs_next' or 'one_vs_following'")
    if kernel_width <= 0:
        raise ValueError("kernel_width must be positive")
    if sampling not in supported_sampling:
        raise ValueError("sampling must be one of: 'grid', 'uniform', 'permute'")

    # Explained model and surrogate configuration.
    self.model = model
    self.model_type = model_type
    self.surrogate_kwargs = surrogate_kwargs
    self.comparison_method = comparison_method

    # Sample weighting.
    self.kernel_width = kernel_width
    self.custom_kernel = custom_kernel

    # Sample generation.
    self.sampling = sampling
    self.max_samples = max_samples
    self.random_state = random_state

    logger.info(f"Initialized LIME with {comparison_method} comparison method and {sampling} sampling")
[docs]
def _compute_weights(self,
                     samples: pd.DataFrame,
                     observation: pd.Series,
                     kernel: Optional[Callable] = None) -> np.ndarray:
    """
    Weight each perturbed sample by its proximity to the observation.

    Distances are obtained from ``gower.gower_matrix`` between every row
    of ``samples`` and the single observation, then mapped to weights.

    Parameters
    ----------
    samples : pd.DataFrame
        Perturbed samples, one per row.
    observation : pd.Series
        The observation being explained.
    kernel : callable, optional
        If given, it is called with the flat distance array and its
        return value is used as the weights. Otherwise the default
        kernel ``exp(-d**2 / kernel_width**2)`` is applied.

    Returns
    -------
    np.ndarray
        One weight per sample (or whatever ``kernel`` returns).
    """
    # gower_matrix takes two tabular inputs, so the observation is wrapped
    # into a one-row frame before measuring distances.
    reference_row = pd.DataFrame([observation])
    gower_dist = gower.gower_matrix(samples, reference_row).flatten()

    if kernel is not None:
        return kernel(gower_dist)

    # Default: weights shrink exponentially with squared distance.
    squared_width = self.kernel_width ** 2
    return np.exp(-(gower_dist ** 2) / squared_width)
[docs]
def _get_comparison_labels(self,
pred_class: int,
samples_preds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Get binary labels for comparison based on method.
This method creates binary labels for the surrogate model based on the
comparison method chosen. It identifies which samples are in higher or
lower classes than the prediction.
Parameters
----------
pred_class : int
Predicted class of the observation
samples_preds : np.ndarray
Array of predictions for all samples
Returns
-------
Tuple[np.ndarray, np.ndarray]
Tuple of (higher_mask, lower_mask) arrays indicating which samples
are in higher/lower classes than the prediction
Notes
-----
- For 'one_vs_next', only compares with adjacent classes
- For 'one_vs_following', compares with all higher/lower classes
"""
if self.comparison_method == 'one_vs_next':
# Compare with next class
higher_mask = samples_preds == pred_class + 1
lower_mask = samples_preds == pred_class - 1
else: # one_vs_following
# Compare with all higher/lower classes
higher_mask = samples_preds > pred_class
lower_mask = samples_preds < pred_class
return higher_mask, lower_mask
[docs]
def _plot_coefficients(self,
                       higher_coef: Optional[np.ndarray],
                       lower_coef: Optional[np.ndarray],
                       feature_names: List[str],
                       observation_idx: int,
                       observation: pd.Series,
                       pred_class: int,
                       higher_fidelity_in: Optional[Tuple[float, float]] = None,
                       higher_fidelity_out: Optional[Tuple[float, float]] = None,
                       lower_fidelity_in: Optional[Tuple[float, float]] = None,
                       lower_fidelity_out: Optional[Tuple[float, float]] = None,
                       is_effect: bool = False) -> None:
    """
    Draw a horizontal bar chart of surrogate coefficients or predictor effects.

    When both value sets are supplied, a single chart interleaves one
    "higher" and one "lower" bar per feature. When only one is supplied, a
    chart with one bar per feature is drawn. When neither is supplied a
    warning is logged and nothing is drawn.

    Parameters
    ----------
    higher_coef : np.ndarray, optional
        Values for the higher-class comparison, one per feature.
    lower_coef : np.ndarray, optional
        Values for the lower-class comparison, one per feature.
    feature_names : List[str]
        Feature names used as y-axis labels.
    observation_idx : int
        Index of the explained observation (shown in the figure title).
    observation : pd.Series
        The explained observation; its name/value pairs are shown in the
        figure title.
    pred_class : int
        Predicted class of the observation.
    higher_fidelity_in, higher_fidelity_out : tuple of (float, float), optional
        ``(BCE, MZE)`` pairs for the higher-class surrogate, in-sample and
        out-of-sample. Each supplied pair is appended to the figure title.
    lower_fidelity_in, lower_fidelity_out : tuple of (float, float), optional
        Same as above for the lower-class surrogate.
    is_effect : bool, default=False
        If True, axis label and title say "predictor effect(s)" instead of
        "coefficient(s)". The plotted values themselves are not altered.

    Notes
    -----
    - Blue bars are the higher-class comparison, red bars the lower-class one.
    - The figure is displayed with ``plt.show()``; nothing is returned.
    """
    show_higher = higher_coef is not None
    show_lower = lower_coef is not None
    if not (show_higher or show_lower):
        logger.warning("No coefficients to plot.")
        return

    quantity = 'Predictor Effects' if is_effect else 'Coefficients'
    xlabel = "Predictor effect" if is_effect else "Coefficient"
    one_vs_next = self.comparison_method == "one_vs_next"

    # Figure-level title: observation summary plus any fidelity pairs supplied.
    obs_header = f"Observation {observation_idx} | Predicted class: {pred_class}"
    obs_values = ", ".join(f"{name}: {value}" for name, value in observation.items())
    obs_info = f"{obs_header} | {obs_values}"
    fidelity_parts = [
        f"{label}: {fidelity[0]:.3f}/{fidelity[1]:.3f}"
        for label, fidelity in (
            ("Higher(in-sample)", higher_fidelity_in),
            ("Higher(out-of-sample)", higher_fidelity_out),
            ("Lower(in-sample)", lower_fidelity_in),
            ("Lower(out-of-sample)", lower_fidelity_out),
        )
        if fidelity is not None
    ]
    if fidelity_parts:
        fidelity_line = " | ".join(fidelity_parts)
        obs_info = f"{obs_info}\n Fidelity (BCE/MZE): {fidelity_line}"

    if show_higher and show_lower:
        # Two rows per feature: even rows hold the "higher" bar, odd rows the "lower" bar.
        n_features = len(feature_names)
        y_high = np.arange(0, n_features * 2, 2)
        y_low = y_high + 1
        # Let the figure grow with the number of features, but never below 6 inches.
        fig, ax = plt.subplots(figsize=(10, max(6, n_features)))
        ax.barh(y_high, higher_coef, color="#4682b4", label="Higher rank")
        ax.barh(y_low, lower_coef, color="#b44646", label="Lower rank")
        # Tick positions are listed as all "higher" rows then all "lower" rows,
        # so the labels are built in that same order.
        ax.set_yticks(np.concatenate([y_high, y_low]))
        ax.set_yticklabels(
            [f"{name} ↑" for name in feature_names] + [f"{name} ↓" for name in feature_names],
            fontsize=9,
        )
        if one_vs_next:
            title = f"Surrogate Model {quantity} – Class {pred_class - 1} & {pred_class + 1}"
        else:
            title = f"Surrogate Model {quantity} – Classes < {pred_class} & > {pred_class}"
    else:
        if show_higher:
            coef, color, label = higher_coef, "#4682b4", "Higher rank"
        else:
            coef, color, label = lower_coef, "#b44646", "Lower rank"
        fig, ax = plt.subplots(figsize=(10, 6))
        y_pos = np.arange(len(feature_names))
        ax.barh(y_pos, coef, color=color, label=label)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(feature_names, fontsize=10)
        if one_vs_next:
            neighbour = pred_class + 1 if show_higher else pred_class - 1
            title = f"Surrogate Model {quantity} for Class {neighbour}"
        elif show_higher:
            title = f"Surrogate Model {quantity} for Classes > {pred_class}"
        else:
            title = f"Surrogate Model {quantity} for Classes < {pred_class}"

    # Decoration shared by both layouts.
    fig.suptitle(obs_info, fontsize=8)
    ax.set_xlabel(xlabel)
    ax.set_title(title, fontsize=13, pad=10)
    ax.tick_params(axis="x", labelsize=10)
    ax.tick_params(axis="y", labelsize=10)
    ax.legend()
    plt.tight_layout()
    # Leave head-room so the figure-level title does not overlap the axes title.
    plt.subplots_adjust(top=0.9)
    plt.show()
[docs]
def _generate_grid_samples(self, X: pd.DataFrame) -> pd.DataFrame:
"""
Generate samples using grid sampling strategy.
This method creates a grid of samples by:
1. Identifying categorical and numerical columns
2. Using all unique values for categorical features
3. Creating evenly spaced points for numerical features
4. Combining all possible combinations
Parameters
----------
X : pd.DataFrame
Original dataset
Returns
-------
pd.DataFrame
DataFrame containing grid samples
Raises
------
ValueError
If grid size would exceed max_samples
Notes
-----
- For numerical features, uses linspace between min and max values
- For categorical features, uses all unique values
- Caps number of points per numerical feature at 100
- Samples randomly if grid size exceeds max_samples
"""
# Identify categorical and numerical columns
cat_cols = X.select_dtypes(include=["object", "category"]).columns.tolist()
num_cols = X.select_dtypes(include=[np.number]).columns.tolist()
n_features = len(X.columns)
# Throw error if 2^n_features > max_samples (for binary features, grid explosion)
if 2 ** n_features > self.max_samples:
raise ValueError(
f"Grid would have {2 ** n_features} rows, which exceeds max_samples={self.max_samples}. "
"Reduce the number of features or increase max_samples."
)
# For each categorical, use all unique values
cat_values = [X[col].unique() for col in cat_cols]
# For each numerical, use linspace across min/max
num_ranges = [(X[col].min(), X[col].max()) for col in num_cols]
n_cats = [len(vals) for vals in cat_values]
n_num = len(num_cols)
grid_dict = {}
if n_num > 0 and not n_cats: # Only numerical features
n_per_num = max(10, int(np.round(self.max_samples ** (1 / n_num))))
for col, (vmin, vmax) in zip(num_cols, num_ranges):
grid_dict[col] = np.linspace(vmin, vmax, n_per_num)
else:
n_grid_num = max(10, int(np.floor(self.max_samples / np.prod(n_cats)**(1/n_num))) if n_cats else self.max_samples)
num_grids = [min(n_grid_num, 100) for _ in num_cols] # cap at 100 per feature for safety
for col, (vmin, vmax), n in zip(num_cols, num_ranges, num_grids):
grid_dict[col] = np.linspace(vmin, vmax, n)
for col, vals in zip(cat_cols, cat_values):
grid_dict[col] = vals
# Use sklearn's ParameterGrid for grid creation
grid = list(ParameterGrid(grid_dict))
X_grid = pd.DataFrame(grid, columns=X.columns)
# If grid is too large, sample max_samples rows
if len(X_grid) > self.max_samples:
X_grid = X_grid.sample(n=self.max_samples, random_state=42).reset_index(drop=True)
return X_grid
[docs]
def _generate_permute_samples(self, X: pd.DataFrame) -> pd.DataFrame:
"""
Generate samples using permutation sampling strategy.
This method creates samples by randomly permuting the values of each
feature from the original dataset. This preserves the marginal
distribution of each feature.
Parameters
----------
X : pd.DataFrame
Original dataset
Returns
-------
pd.DataFrame
DataFrame containing permuted samples
Notes
-----
- Preserves marginal distributions of features
- Breaks correlations between features
- Number of samples controlled by max_samples parameter
"""
n_samples = self.max_samples
samples = pd.DataFrame(index=range(n_samples), columns=X.columns)
#set seed for reproducibility
np.random.seed(self.random_state)
for col in X.columns:
samples[col] = np.random.choice(X[col], n_samples, replace=True)
return samples
[docs]
def _plot_decision_tree(self,
                        higher_model: Optional[DecisionTreeClassifier],
                        lower_model: Optional[DecisionTreeClassifier],
                        feature_names: List[str],
                        observation_idx: int,
                        observation: pd.Series,
                        pred_class: int,
                        higher_fidelity_in: Optional[Tuple[float, float]] = None,
                        higher_fidelity_out: Optional[Tuple[float, float]] = None,
                        lower_fidelity_in: Optional[Tuple[float, float]] = None,
                        lower_fidelity_out: Optional[Tuple[float, float]] = None,
                        is_effect: bool = False) -> None:
    """
    Plot the decision-tree surrogate model(s) side by side.

    One panel is drawn per supplied model (higher-class first, then
    lower-class). When neither model is supplied a warning is logged and
    nothing is drawn.

    Parameters
    ----------
    higher_model : DecisionTreeClassifier, optional
        Fitted tree for the higher-class comparison.
    lower_model : DecisionTreeClassifier, optional
        Fitted tree for the lower-class comparison.
    feature_names : List[str]
        Feature names passed to ``plot_tree``.
    observation_idx : int
        Index of the explained observation (shown in the figure title).
    observation : pd.Series
        The explained observation; its name/value pairs are shown in the
        figure title.
    pred_class : int
        Predicted class of the observation.
    higher_fidelity_in, higher_fidelity_out : tuple of (float, float), optional
        ``(BCE, MZE)`` pairs for the higher-class surrogate, in-sample and
        out-of-sample. Each supplied pair is appended to the figure title.
    lower_fidelity_in, lower_fidelity_out : tuple of (float, float), optional
        Same as above for the lower-class surrogate.
    is_effect : bool, default=False
        Not used by this method.

    Notes
    -----
    - Node text is shortened after plotting for readability.
    - The figure is displayed with ``plt.show()``; nothing is returned.
    """
    one_vs_next = self.comparison_method == 'one_vs_next'

    # (model, label of the positive class, panel title) for each tree to draw.
    panels = []
    if higher_model is not None:
        title = (f'Decision Tree for Class {pred_class + 1}' if one_vs_next
                 else f'Decision Tree for Classes > {pred_class}')
        panels.append((higher_model, "Higher", title))
    if lower_model is not None:
        title = (f'Decision Tree for Class {pred_class - 1}' if one_vs_next
                 else f'Decision Tree for Classes < {pred_class}')
        panels.append((lower_model, "Lower", title))

    n_plots = len(panels)
    if n_plots == 0:
        logger.warning("No trees to plot.")
        return

    fig, axes = plt.subplots(1, n_plots, figsize=(20, 10 * n_plots))
    if n_plots == 1:
        # plt.subplots returns a bare Axes for a single panel.
        axes = [axes]

    # Figure-level title: observation summary plus any fidelity pairs supplied.
    obs_header = f"Observation {observation_idx} | Predicted class: {pred_class}"
    obs_values = ", ".join(f"{name}: {value}" for name, value in observation.items())
    suptitle = f"{obs_header} | {obs_values}"
    fidelity_parts = [
        f"{label}: {fidelity[0]:.3f}/{fidelity[1]:.3f}"
        for label, fidelity in (
            ("Higher(in-sample)", higher_fidelity_in),
            ("Higher(out-of-sample)", higher_fidelity_out),
            ("Lower(in-sample)", lower_fidelity_in),
            ("Lower(out-of-sample)", lower_fidelity_out),
        )
        if fidelity is not None
    ]
    if fidelity_parts:
        fidelity_line = " | ".join(fidelity_parts)
        suptitle = f"{suptitle}\n Fidelity (BCE/MZE): {fidelity_line}"
    fig.suptitle(suptitle, fontsize=8, y=0.97)

    for ax, (model, positive_label, title) in zip(axes, panels):
        plot_tree(model,
                  feature_names=feature_names,
                  class_names=["Same", positive_label],
                  filled=True,
                  rounded=True,
                  impurity=False,
                  proportion=False,
                  label="all",
                  fontsize=10,
                  ax=ax)
        # Shorten node text: drop the third- and second-to-last lines and keep
        # the final one. Texts containing "True"/"False" are left untouched
        # (presumably the branch annotations added by plot_tree - verify).
        for node_text in ax.texts:
            content = node_text.get_text()
            if "True" not in content and "False" not in content:
                lines = content.split('\n')
                node_text.set_text('\n'.join(lines[:-3] + [lines[-1]]))
        ax.set_title(title, fontsize=13, y=0.94)

    plt.tight_layout(h_pad=4)
    plt.subplots_adjust(top=0.95)
    plt.show()
[docs]
def explain(self,
observation_idx: Optional[int] = None,
feature_subset: Optional[List[Union[int, str]]] = None,
plot: bool = False,
show_coefficients: bool = False,
**kwargs) -> Dict[str, Union[List[str], np.ndarray, DecisionTreeClassifier]]:
"""
Generate LIME explanations for a specific observation.
Depending on the `show_coefficients` flag the horizontal bar plot will
display either (a) raw surrogate coefficients or (b) *predictor effects*
defined as `coefficient × local feature value` for numerical features.
Categorical one-hot features always use the raw coefficient value.
This method creates local explanations by:
1. Generating perturbed samples around the observation
2. Computing sample weights based on distance
3. Fitting surrogate models to explain the prediction
4. Visualizing the results if requested
Parameters
----------
observation_idx : int, optional
Index of the observation to explain
feature_subset : List[Union[int, str]], optional
List of feature indices or names to include
plot : bool, default=False
Whether to create visualizations. If True, visualise the explanation.
show_coefficients : bool, default=False
If False (default) plot predictor effects; if True plot raw
coefficients.
**kwargs : dict
Additional keyword arguments
Returns
-------
Dict[str, Union[List[str], np.ndarray, DecisionTreeClassifier]]
Dictionary containing:
- features: List of feature names
- higher_model: Decision tree model for higher class comparison (if model_type="decision_tree")
- lower_model: Decision tree model for lower class comparison (if model_type="decision_tree")
- higher_coef: Coefficients for higher class comparison (if model_type="logistic")
- lower_coef: Coefficients for lower class comparison (if model_type="logistic")
Raises
------
ValueError
If observation_idx is not specified or model_type is invalid
Notes
-----
- Requires observation_idx to be specified
- Supports both logistic regression and decision tree surrogate models
- Can focus on specific features using feature_subset
- Provides visualizations of coefficients or decision trees
"""
if observation_idx is None:
raise ValueError("observation_idx must be specified for LIME")
if self.model_type not in ["logistic", "decision_tree"]:
raise ValueError("model_type must be either 'logistic' or 'decision_tree'")
# Get the observation
observation = self.X.iloc[observation_idx]
if feature_subset is not None:
observation = observation[feature_subset]
# Choose samples for surrogate model fitting
if self.sampling == "grid":
samples = self._generate_grid_samples(self.X)
elif self.sampling == "uniform":
samples = self._generate_uniform_samples(self.X)
elif self.sampling == "permute":
samples = self._generate_permute_samples(self.X)
else:
samples = self.X
# Get predictions for all data points
samples_preds = self.model.predict(samples)
# Get original prediction
pred_class = self.model.predict(self.X)[observation_idx]
logger.info(f"Original prediction class: {pred_class}")
n_classes = len(np.unique(samples_preds))
# Compute sample weights
weights = self._compute_weights(samples, observation, self.custom_kernel)
# Get comparison labels
higher_mask, lower_mask = self._get_comparison_labels(pred_class, samples_preds)
# Transform features
X_transformed = self.model.transform(samples, fit=False, no_scaling=True)
# Transformed observation for effect calculation (ensure correct dtypes)
obs_transformed = self.model.transform(self.X.iloc[[observation_idx]], fit=False, no_scaling=True).iloc[0].values
feature_names = X_transformed.columns.tolist()
if feature_subset is not None:
if all(isinstance(f, int) for f in feature_subset):
idxs = feature_subset
else:
idxs = [feature_names.index(f) for f in feature_subset]
feature_names = [feature_names[i] for i in idxs]
X_transformed = X_transformed.iloc[:, idxs]
result = {'features': feature_names}
if self.model_type == "logistic":
higher_coef = None
lower_coef = None
if pred_class < n_classes - 1:
# Check if we have any positive samples for higher class
if np.sum(higher_mask) > 1:
try:
# Train/test split for fidelity estimation
idx_train, idx_test = train_test_split(np.arange(len(higher_mask)), test_size=0.2, random_state=42, stratify=higher_mask)
# Ensure at least 2 samples in each split
if np.sum(higher_mask[idx_train]) < 2:
idx_train = np.append(idx_train, np.where(higher_mask)[0])
if np.sum(higher_mask[idx_test]) < 2:
idx_test = np.append(idx_test, np.where(higher_mask)[0])
X_train, X_test = X_transformed.iloc[idx_train], X_transformed.iloc[idx_test]
y_train, y_test = higher_mask[idx_train], higher_mask[idx_test]
w_train, w_test = weights[idx_train], weights[idx_test]
higher_model = LogisticRegression(random_state=42, class_weight="balanced", max_iter=100000, **self.surrogate_kwargs)
higher_model.fit(X_train, y_train, sample_weight=w_train)
higher_coef = higher_model.coef_[0]
if not show_coefficients:
num_mask = ~np.isin(obs_transformed, [0, 1])
higher_effect = higher_coef * (obs_transformed * num_mask + (~num_mask) * 1) # multiply only numeric
result['higher_effect'] = higher_effect
# Fidelity (binary cross-entropy and zero-one loss) in and out of sample
prob_train = higher_model.predict_proba(X_train)[:, 1]
prob_test = higher_model.predict_proba(X_test)[:, 1]
in_BCE = log_loss(y_train, prob_train, sample_weight=w_train, labels=[0, 1])
out_BCE = log_loss(y_test, prob_test, sample_weight=w_test, labels=[0, 1])
zero_one_in = 1.0 - (higher_model.predict(X_train) == y_train).mean()
zero_one_out = 1.0 - (higher_model.predict(X_test) == y_test).mean()
result['higher_fidelity_in'] = (in_BCE, zero_one_in)
result['higher_fidelity_out'] = (out_BCE, zero_one_out)
# Refit surrogate on the FULL data for final explanation coefficients
higher_model_full = LogisticRegression(random_state=42, class_weight="balanced", max_iter=100000, **self.surrogate_kwargs)
higher_model_full.fit(X_transformed, higher_mask, sample_weight=weights)
higher_coef = higher_model_full.coef_[0]
if not show_coefficients:
num_mask = ~np.isin(obs_transformed, [0, 1])
higher_effect = higher_coef * (obs_transformed * num_mask + (~num_mask) * 1)
result['higher_effect'] = higher_effect
result['higher_coef'] = higher_coef
result['higher_intercept'] = higher_model_full.intercept_
except Exception as e:
logger.warning(f"Failed to fit higher class model: {str(e)}")
result['higher_coef'] = np.zeros(X_transformed.shape[1])
result['higher_intercept'] = 0
else:
logger.warning("No positive samples for higher class comparison")
result['higher_coef'] = np.zeros(X_transformed.shape[1])
result['higher_intercept'] = 0
if pred_class > 0:
# Check if we have any positive samples for lower class
if np.sum(lower_mask) > 1:
try:
idx_train_l, idx_test_l = train_test_split(np.arange(len(lower_mask)), test_size=0.2, random_state=42, stratify=lower_mask)
# Ensure at least 2 samples in each split
if np.sum(lower_mask[idx_train_l]) < 2:
idx_train_l = np.append(idx_train_l, np.where(lower_mask)[0])
if np.sum(lower_mask[idx_test_l]) < 2:
idx_test_l = np.append(idx_test_l, np.where(lower_mask)[0])
X_train_l, X_test_l = X_transformed.iloc[idx_train_l], X_transformed.iloc[idx_test_l]
y_train_l, y_test_l = lower_mask[idx_train_l], lower_mask[idx_test_l]
w_train_l, w_test_l = weights[idx_train_l], weights[idx_test_l]
lower_model = LogisticRegression(random_state=42, class_weight="balanced", max_iter=100000, **self.surrogate_kwargs)
lower_model.fit(X_train_l, y_train_l, sample_weight=w_train_l)
lower_coef = lower_model.coef_[0]
if not show_coefficients:
num_mask = ~np.isin(obs_transformed, [0, 1])
lower_effect = lower_coef * (obs_transformed * num_mask + (~num_mask) * 1)
result['lower_effect'] = lower_effect
result['lower_coef'] = lower_coef
prob_train_l = lower_model.predict_proba(X_train_l)[:, 1]
prob_test_l = lower_model.predict_proba(X_test_l)[:, 1]
in_BCE_l = log_loss(y_train_l, prob_train_l, sample_weight=w_train_l, labels=[0, 1])
out_BCE_l = log_loss(y_test_l, prob_test_l, sample_weight=w_test_l, labels=[0, 1])
zero_one_in_l = 1.0 - (lower_model.predict(X_train_l) == y_train_l).mean()
zero_one_out_l = 1.0 - (lower_model.predict(X_test_l) == y_test_l).mean()
result['lower_fidelity_in'] = (in_BCE_l, zero_one_in_l)
result['lower_fidelity_out'] = (out_BCE_l, zero_one_out_l)
# Refit surrogate on the FULL data for final explanation coefficients
lower_model_full = LogisticRegression(random_state=42, class_weight="balanced", max_iter=100000, **self.surrogate_kwargs)
lower_model_full.fit(X_transformed, lower_mask, sample_weight=weights)
lower_coef = lower_model_full.coef_[0]
if not show_coefficients:
num_mask = ~np.isin(obs_transformed, [0, 1])
lower_effect = lower_coef * (obs_transformed * num_mask + (~num_mask) * 1)
result['lower_effect'] = lower_effect
result['lower_coef'] = lower_coef
result['lower_intercept'] = lower_model_full.intercept_
except Exception as e:
logger.warning(f"Failed to fit lower class model: {str(e)}")
result['lower_coef'] = np.zeros(X_transformed.shape[1])
result['lower_intercept'] = 0
else:
logger.warning("No positive samples for lower class comparison")
result['lower_coef'] = np.zeros(X_transformed.shape[1])
result['lower_intercept'] = 0
# Aggregate in- and out-of-sample BCE losses
losses_in = []
losses_out = []
if 'higher_fidelity_in' in result:
losses_in.append(result['higher_fidelity_in'])
if 'lower_fidelity_in' in result:
losses_in.append(result['lower_fidelity_in'])
if 'higher_fidelity_out' in result:
losses_out.append(result['higher_fidelity_out'])
if 'lower_fidelity_out' in result:
losses_out.append(result['lower_fidelity_out'])
if plot:
# Determine data to plot: coefficients or effects
plot_high = higher_coef if show_coefficients else result.get('higher_effect', higher_coef)
plot_low = lower_coef if show_coefficients else result.get('lower_effect', lower_coef)
self._plot_coefficients(
plot_high,
plot_low,
feature_names=result.get('features'),
observation_idx=observation_idx,
observation=observation,
pred_class=pred_class,
higher_fidelity_in=result.get('higher_fidelity_in'),
higher_fidelity_out=result.get('higher_fidelity_out'),
lower_fidelity_in=result.get('lower_fidelity_in'),
lower_fidelity_out=result.get('lower_fidelity_out'),
is_effect=not show_coefficients,
)
print(result)
elif self.model_type == "decision_tree":
higher_model = None
lower_model = None
if pred_class < n_classes - 1:
# Check if we have any positive samples for higher class
if np.sum(higher_mask) > 1:
try:
idx_train_h, idx_test_h = train_test_split(np.arange(len(higher_mask)), test_size=0.2, random_state=42, stratify=higher_mask)
# Ensure at least 2 samples in each split
if np.sum(higher_mask[idx_train_h]) < 2:
idx_train_h = np.append(idx_train_h, np.where(higher_mask)[0])
if np.sum(higher_mask[idx_test_h]) < 2:
idx_test_h = np.append(idx_test_h, np.where(higher_mask)[0])
X_train_h, X_test_h = X_transformed.iloc[idx_train_h], X_transformed.iloc[idx_test_h]
y_train_h, y_test_h = higher_mask[idx_train_h], higher_mask[idx_test_h]
w_train_h, w_test_h = weights[idx_train_h], weights[idx_test_h]
tree_kwargs = {'random_state': 42, 'class_weight': 'balanced', **self.surrogate_kwargs}
tree_kwargs.setdefault('max_depth', 3)
higher_model = DecisionTreeClassifier(**tree_kwargs)
higher_model.fit(X_train_h, y_train_h, sample_weight=w_train_h)
result['higher_model'] = higher_model
prob_train_h = higher_model.predict_proba(X_train_h)[:, 1]
prob_test_h = higher_model.predict_proba(X_test_h)[:, 1]
BCE_in_h = log_loss(y_train_h, prob_train_h, sample_weight=w_train_h, labels=[0, 1])
BCE_out_h = log_loss(y_test_h, prob_test_h, sample_weight=w_test_h, labels=[0, 1])
zero_one_in_h = 1.0 - (higher_model.predict(X_train_h) == y_train_h).mean()
zero_one_out_h = 1.0 - (higher_model.predict(X_test_h) == y_test_h).mean()
result['higher_fidelity_in'] = (BCE_in_h, zero_one_in_h)
result['higher_fidelity_out'] = (BCE_out_h, zero_one_out_h)
# Refit decision tree on the FULL data for final explanation
tree_kwargs = {'random_state': 42, 'class_weight': 'balanced', **self.surrogate_kwargs}
tree_kwargs.setdefault('max_depth', 3)
higher_model = DecisionTreeClassifier(**tree_kwargs)
higher_model.fit(X_transformed, higher_mask, sample_weight=weights)
result['higher_model'] = higher_model
except Exception as e:
logger.warning(f"Failed to fit higher class model: {str(e)}")
else:
logger.warning("No positive samples for higher class comparison")
if pred_class > 0:
# Check if we have any positive samples for lower class
if np.sum(lower_mask) > 1:
try:
idx_train_l, idx_test_l = train_test_split(np.arange(len(lower_mask)), test_size=0.2, random_state=42, stratify=lower_mask)
# Ensure at least 2 samples in each split
if np.sum(lower_mask[idx_train_l]) < 2:
idx_train_l = np.append(idx_train_l, np.where(lower_mask)[0])
if np.sum(lower_mask[idx_test_l]) < 2:
idx_test_l = np.append(idx_test_l, np.where(lower_mask)[0])
X_train_l, X_test_l = X_transformed.iloc[idx_train_l], X_transformed.iloc[idx_test_l]
y_train_l, y_test_l = lower_mask[idx_train_l], lower_mask[idx_test_l]
w_train_l, w_test_l = weights[idx_train_l], weights[idx_test_l]
tree_kwargs = {'random_state': 42, 'class_weight': 'balanced', **self.surrogate_kwargs}
tree_kwargs.setdefault('max_depth', 3)
lower_model = DecisionTreeClassifier(**tree_kwargs)
lower_model.fit(X_train_l, y_train_l, sample_weight=w_train_l)
result['lower_model'] = lower_model
prob_train_l = lower_model.predict_proba(X_train_l)[:, 1]
prob_test_l = lower_model.predict_proba(X_test_l)[:, 1]
BCE_in_l = log_loss(y_train_l, prob_train_l, sample_weight=w_train_l, labels=[0, 1])
BCE_out_l = log_loss(y_test_l, prob_test_l, sample_weight=w_test_l, labels=[0, 1])
zero_one_in_l = 1.0 - (lower_model.predict(X_train_l) == y_train_l).mean()
zero_one_out_l = 1.0 - (lower_model.predict(X_test_l) == y_test_l).mean()
result['lower_fidelity_in'] = (BCE_in_l, zero_one_in_l)
result['lower_fidelity_out'] = (BCE_out_l, zero_one_out_l)
# Refit decision tree on the FULL data for final explanation
tree_kwargs = {'random_state': 42, 'class_weight': 'balanced', **self.surrogate_kwargs}
tree_kwargs.setdefault('max_depth', 3)
lower_model = DecisionTreeClassifier(**tree_kwargs)
lower_model.fit(X_transformed, lower_mask, sample_weight=weights)
result['lower_model'] = lower_model
except Exception as e:
logger.warning(f"Failed to fit lower class model: {str(e)}")
else:
logger.warning("No positive samples for lower class comparison")
# Aggregate losses
losses_in = []
losses_out = []
if 'higher_fidelity_in' in result:
losses_in.append(result['higher_fidelity_in'])
if 'lower_fidelity_in' in result:
losses_in.append(result['lower_fidelity_in'])
if 'higher_fidelity_out' in result:
losses_out.append(result['higher_fidelity_out'])
if 'lower_fidelity_out' in result:
losses_out.append(result['lower_fidelity_out'])
if plot:
# pass fidelity for trees if needed (uses 'fidelity')
self._plot_decision_tree(
higher_model,
lower_model,
feature_names,
observation_idx,
observation,
pred_class,
higher_fidelity_in=result.get('higher_fidelity_in'),
higher_fidelity_out=result.get('higher_fidelity_out'),
lower_fidelity_in=result.get('lower_fidelity_in'),
lower_fidelity_out=result.get('lower_fidelity_out'),
is_effect=not show_coefficients,
)
return result