Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- adds `InterventionalTreeExplainer` in `shapiq.tree.interventional`
- adds `KNNExplainer`, `WeightedKNNExplainer` and `ThresholdNNExplainer` for nearest neighbor models
- changes the default for all user-facing `Explainer` classes to `index="SV"`, `max_order=1` (Shapley values) — see Breaking Changes below
- adds `shapiq.scatter_plot` for SHAP-style scatter (dependence) plots of interaction values, supporting both first-order and higher-order interactions [#516](https://github.com/mmschlk/shapiq/pull/516)


### Introducing ProxySHAP [#501](https://github.com/mmschlk/shapiq/pull/501), [Preprint](https://arxiv.org/abs/2605.22738)
Expand Down
138 changes: 138 additions & 0 deletions examples/visualization/plot_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Scatter Plot
============

This example demonstrates :func:`~shapiq.scatter_plot`, which plots the
per-sample value of an interaction against the value of one feature. For
first-order interactions this matches SHAP's ``shap.plots.scatter``; for
higher-order interactions the x-axis is restricted to a single feature in
the interaction tuple.
"""

from __future__ import annotations

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

import shapiq

# %%
# Train a Model
# -------------

x_data, y_data = shapiq.datasets.load_california_housing(to_numpy=False)
feature_names = list(x_data.columns)
x_data, y_data = x_data.values, y_data.values
x_train, x_test, y_train, y_test = train_test_split(
x_data,
y_data,
test_size=0.2,
random_state=42,
)
model = XGBRegressor(random_state=42, max_depth=4, n_estimators=50)
model.fit(x_train, y_train)

# %%
# Compute Explanations for Multiple Instances
# ---------------------------------------------
# We explain 200 test instances so the scatter plots show a meaningful
# distribution while keeping the example fast.

x_explain = x_test[:200]
explainer = shapiq.TabularExplainer(
model,
data=x_test,
index="FSII",
max_order=2,
random_state=42,
)
explanations = explainer.explain_X(x_explain, budget=200)

# %%
# Default Scatter Plot
# ---------------------
# Without an explicit ``interaction``, the most important interaction is
# selected automatically (by mean absolute aggregated value).

shapiq.scatter_plot(explanations, x_explain, feature_names=feature_names)

# %%
# Main Effect of a Single Feature
# --------------------------------
# Pass a feature name (or index) to plot its first-order Shapley value
# against its feature values.

shapiq.scatter_plot(
explanations,
x_explain,
interaction="MedInc",
feature_names=feature_names,
)

# %%
# Pairwise Interaction
# ---------------------
# Plot a higher-order interaction value. By default the x-axis is the first
# feature in the interaction tuple.

shapiq.scatter_plot(
explanations,
x_explain,
interaction=("MedInc", "Latitude"),
feature_names=feature_names,
)

# %%
# Pairwise Interaction with Chosen X-axis
# -----------------------------------------
# Use ``x_feature`` to switch which feature in the interaction is on the x-axis.

shapiq.scatter_plot(
explanations,
x_explain,
interaction=("MedInc", "Latitude"),
x_feature="Latitude",
feature_names=feature_names,
)

# %%
# Color by Another Feature
# -------------------------
# Set ``color`` to render points using a red-blue colormap based on another
# feature's value, and add a colorbar.

shapiq.scatter_plot(
explanations,
x_explain,
interaction="MedInc",
color="HouseAge",
feature_names=feature_names,
)

# %%
# Disable the X-axis Histogram Strip
# -----------------------------------
# By default a faint histogram of the x-axis feature is drawn along the bottom
# (SHAP-style). Pass ``hist=False`` to hide it.

shapiq.scatter_plot(
explanations,
x_explain,
interaction="MedInc",
feature_names=feature_names,
hist=False,
)

# %%
# Custom Axis
# -----------

fig, ax = plt.subplots(figsize=(6, 5))
shapiq.scatter_plot(
explanations,
x_explain,
interaction="MedInc",
feature_names=feature_names,
ax=ax,
)
2 changes: 2 additions & 0 deletions src/shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
beeswarm_plot,
force_plot,
network_plot,
scatter_plot,
sentence_plot,
si_graph_plot,
stacked_bar_plot,
Expand Down Expand Up @@ -136,6 +137,7 @@
"sentence_plot",
"upset_plot",
"beeswarm_plot",
"scatter_plot",
# public utils
"powerset",
"get_explicit_subsets",
Expand Down
2 changes: 2 additions & 0 deletions src/shapiq/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .beeswarm import beeswarm_plot
from .force import force_plot
from .network import network_plot
from .scatter import scatter_plot
from .sentence import sentence_plot
from .si_graph import si_graph_plot
from .stacked_bar import stacked_bar_plot
Expand All @@ -25,6 +26,7 @@
"sentence_plot",
"upset_plot",
"beeswarm_plot",
"scatter_plot",
# utils
"abbreviate_feature_names",
]
Loading
Loading