diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dceb1d93..27cf77be3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - adds `InterventionalTreeExplainer` in `shapiq.tree.interventional` - adds `KNNExplainer`, `WeightedKNNExplainer` and `ThresholdNNExplainer` for nearest neighbor models - changes the default for all user-facing `Explainer` classes to `index="SV"`, `max_order=1` (Shapley values) — see Breaking Changes below +- adds `shapiq.scatter_plot` for SHAP-style scatter (dependence) plots of interaction values, supporting both first-order and higher-order interactions [#516](https://github.com/mmschlk/shapiq/pull/516) ### Introducing ProxySHAP [#501](https://github.com/mmschlk/shapiq/pull/501), [Preprint](https://arxiv.org/abs/2605.22738) diff --git a/examples/visualization/plot_scatter.py b/examples/visualization/plot_scatter.py new file mode 100644 index 000000000..f1a68bdd3 --- /dev/null +++ b/examples/visualization/plot_scatter.py @@ -0,0 +1,138 @@ +""" +Scatter Plot +============ + +This example demonstrates :func:`~shapiq.scatter_plot`, which plots the +per-sample value of an interaction against the value of one feature. For +first-order interactions this matches SHAP's ``shap.plots.scatter``; for +higher-order interactions the x-axis is restricted to a single feature in +the interaction tuple. +""" + +from __future__ import annotations + +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +from xgboost import XGBRegressor + +import shapiq + +# %% +# Train a Model +# ------------- + +x_data, y_data = shapiq.datasets.load_california_housing(to_numpy=False) +feature_names = list(x_data.columns) +x_data, y_data = x_data.values, y_data.values +x_train, x_test, y_train, y_test = train_test_split( + x_data, + y_data, + test_size=0.2, + random_state=42, +) +model = XGBRegressor(random_state=42, max_depth=4, n_estimators=50) +model.fit(x_train, y_train) + +# %% +# Compute Explanations for Multiple Instances +# --------------------------------------------- +# We explain 200 test instances so the scatter plots show a meaningful +# distribution while keeping the example fast. + +x_explain = x_test[:200] +explainer = shapiq.TabularExplainer( + model, + data=x_test, + index="FSII", + max_order=2, + random_state=42, +) +explanations = explainer.explain_X(x_explain, budget=200) + +# %% +# Default Scatter Plot +# --------------------- +# Without an explicit ``interaction``, the most important interaction is +# selected automatically (by mean absolute aggregated value). + +shapiq.scatter_plot(explanations, x_explain, feature_names=feature_names) + +# %% +# Main Effect of a Single Feature +# -------------------------------- +# Pass a feature name (or index) to plot its first-order Shapley value +# against its feature values. + +shapiq.scatter_plot( + explanations, + x_explain, + interaction="MedInc", + feature_names=feature_names, +) + +# %% +# Pairwise Interaction +# --------------------- +# Plot a higher-order interaction value. By default the x-axis is the first +# feature in the interaction tuple. + +shapiq.scatter_plot( + explanations, + x_explain, + interaction=("MedInc", "Latitude"), + feature_names=feature_names, +) + +# %% +# Pairwise Interaction with Chosen X-axis +# ----------------------------------------- +# Use ``x_feature`` to switch which feature in the interaction is on the x-axis. + +shapiq.scatter_plot( + explanations, + x_explain, + interaction=("MedInc", "Latitude"), + x_feature="Latitude", + feature_names=feature_names, +) + +# %% +# Color by Another Feature +# ------------------------- +# Set ``color`` to render points using a red-blue colormap based on another +# feature's value, and add a colorbar. + +shapiq.scatter_plot( + explanations, + x_explain, + interaction="MedInc", + color="HouseAge", + feature_names=feature_names, +) + +# %% +# Disable the X-axis Histogram Strip +# ----------------------------------- +# By default a faint histogram of the x-axis feature is drawn along the bottom +# (SHAP-style). Pass ``hist=False`` to hide it. + +shapiq.scatter_plot( + explanations, + x_explain, + interaction="MedInc", + feature_names=feature_names, + hist=False, +) + +# %% +# Custom Axis +# ----------- + +fig, ax = plt.subplots(figsize=(6, 5)) +shapiq.scatter_plot( + explanations, + x_explain, + interaction="MedInc", + feature_names=feature_names, + ax=ax, +) diff --git a/src/shapiq/__init__.py b/src/shapiq/__init__.py index 3e6b88a9c..140804b74 100644 --- a/src/shapiq/__init__.py +++ b/src/shapiq/__init__.py @@ -70,6 +70,7 @@ beeswarm_plot, force_plot, network_plot, + scatter_plot, sentence_plot, si_graph_plot, stacked_bar_plot, @@ -136,6 +137,7 @@ "sentence_plot", "upset_plot", "beeswarm_plot", + "scatter_plot", # public utils "powerset", "get_explicit_subsets", diff --git a/src/shapiq/plot/__init__.py b/src/shapiq/plot/__init__.py index 3814b37d4..afbb631ad 100644 --- a/src/shapiq/plot/__init__.py +++ b/src/shapiq/plot/__init__.py @@ -8,6 +8,7 @@ from .beeswarm import beeswarm_plot from .force import force_plot from .network import network_plot +from .scatter import scatter_plot from .sentence import sentence_plot from .si_graph import si_graph_plot from .stacked_bar import stacked_bar_plot @@ -25,6 +26,7 @@ "sentence_plot", "upset_plot", "beeswarm_plot", + "scatter_plot", # utils "abbreviate_feature_names", ] diff --git a/src/shapiq/plot/scatter.py b/src/shapiq/plot/scatter.py new file mode 100644 index 000000000..9a2222c8b --- /dev/null +++ b/src/shapiq/plot/scatter.py @@ -0,0 +1,316 @@ +"""Scatter (a.k.a. dependence) plot for :class:`~shapiq.InteractionValues`. + +Plots the per-sample interaction value of a chosen interaction tuple against +the value of one feature. For first-order interactions this matches +``shap.plots.scatter``; for higher-order interactions the x-axis is restricted +to a single feature (selected from the interaction tuple). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from shapiq.interaction_values import InteractionValues, aggregate_interaction_values + +from .beeswarm import _get_red_blue_cmap +from .utils import abbreviate_feature_names + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.figure import Figure + + +__all__ = ["scatter_plot"] + + +def _resolve_feature( + feature: int | str | np.integer, + name_to_idx: dict[str, int], + n_players: int, +) -> int: + """Resolves a feature identifier (index or name) to an integer index.""" + if isinstance(feature, int | np.integer) and not isinstance(feature, bool): + idx = int(feature) + if not 0 <= idx < n_players: + error_message = f"Feature index {idx} out of range [0, {n_players})." + raise ValueError(error_message) + return idx + if isinstance(feature, str): + if feature not in name_to_idx: + error_message = f"Unknown feature name: {feature!r}." + raise ValueError(error_message) + return name_to_idx[feature] + error_message = f"Feature identifier must be int or str, got {type(feature).__name__}." + raise TypeError(error_message) + + +def _resolve_interaction( + interaction: tuple[int, ...] | tuple[str, ...] | int | str | None, + interaction_values_list: list[InteractionValues], + name_to_idx: dict[str, int], + n_players: int, +) -> tuple[int, ...]: + """Resolves an ``interaction`` argument to a sorted tuple of feature indices.""" + if interaction is None: + agg = aggregate_interaction_values( + [abs(iv) for iv in interaction_values_list], aggregation="mean" + ) + candidates = [(k, v) for k, v in agg.interactions.items() if len(k) >= 1] + if not candidates: + error_message = "No non-empty interactions available to plot." + raise ValueError(error_message) + candidates.sort(key=lambda kv: kv[1], reverse=True) + return candidates[0][0] + + if isinstance(interaction, int | np.integer | str): + return (_resolve_feature(interaction, name_to_idx, n_players),) + + if isinstance(interaction, tuple): + resolved = tuple(sorted(_resolve_feature(f, name_to_idx, n_players) for f in interaction)) + if len(resolved) == 0: + error_message = "interaction tuple must contain at least one feature." + raise ValueError(error_message) + return resolved + + error_message = ( + f"interaction must be a tuple, int, str, or None. Got {type(interaction).__name__}." + ) + raise TypeError(error_message) + + +def scatter_plot( + interaction_values_list: list[InteractionValues], + data: pd.DataFrame | np.ndarray, + interaction: tuple[int, ...] | tuple[str, ...] | int | str | None = None, + *, + x_feature: int | str | None = None, + color: int | str | None = None, + feature_names: list[str] | None = None, + abbreviate: bool = True, + alpha: float = 0.8, + dot_size: float = 16, + jitter: float = 0.0, + hist: bool = True, + ax: Axes | None = None, + show: bool = True, +) -> Axes | None: + """Plots a scatter (dependence) plot of an interaction's per-sample value against one feature. + + Inspired by `SHAP's `_ + ``shap.plots.scatter``. For a first-order interaction ``(i,)`` the x-axis is feature ``i``'s + value across samples and the y-axis is its Shapley value. For higher-order interactions like + ``(i, j)`` the x-axis is the value of a single feature in the interaction (selected via + ``x_feature``, defaulting to the first feature in the sorted tuple) and the y-axis is the + higher-order interaction value. + + Args: + interaction_values_list: A non-empty list of :class:`~shapiq.InteractionValues` objects, + one per sample row of ``data``. + data: The feature values for the samples, as a ``pandas.DataFrame`` or 2D ``numpy`` array. + Must have the same number of rows as ``interaction_values_list``. + interaction: Identifies the interaction to plot. Accepts an ``int`` or ``str`` (treated + as a main effect single-element tuple), a tuple of feature indices like ``(0, 2)``, + or a tuple of feature names like ``("MedInc", "Latitude")``. If ``None``, the + globally most important interaction (by mean absolute aggregated value) is selected. + Defaults to ``None``. + x_feature: For higher-order interactions, which feature in ``interaction`` to place on + the x-axis. Must be a member of ``interaction``. Ignored for first-order + interactions. Defaults to the first feature in the sorted interaction tuple. + color: Feature index or name used to color the points (with a red-blue colormap and a + colorbar). If ``None`` (default), all points are drawn in a neutral color and no + colorbar is shown. ``NaN`` color values render gray. + feature_names: Names of the features. Defaults to ``["F0", "F1", ...]``. + abbreviate: Whether to abbreviate feature names for axis labels. Defaults to ``True``. + alpha: Transparency of the points, in ``(0, 1]``. Defaults to ``0.8``. + dot_size: Size of the scatter points. Defaults to ``16``. + jitter: If positive, adds Gaussian jitter to the plotted x-values, scaled to + ``jitter * std(x_vals)``. Useful for categorical or integer-valued features. + Defaults to ``0.0`` (disabled). + hist: Whether to draw a faint histogram of the x-axis feature's distribution along + the bottom of the plot (SHAP-style). The bars share the main x-axis: no separate + axes is created. Defaults to ``True``. + ax: ``matplotlib`` ``Axes`` object to plot on. If ``None``, a new figure and axes are + created. + show: Whether to call ``plt.show()`` at the end. If ``False``, returns the axes instead. + Defaults to ``True``. + + Returns: + The ``Axes`` object if ``show=False``, otherwise ``None``. + + Raises: + ValueError: If inputs are inconsistent (empty list, length mismatch, unknown feature + names or indices, an interaction tuple absent from every sample's lookup, an + out-of-tuple ``x_feature``, or invalid numeric parameters). + TypeError: If ``data`` is not a DataFrame or ndarray, or if a feature identifier has an + unsupported type. + + """ + if not isinstance(interaction_values_list, list) or len(interaction_values_list) == 0: + error_message = "interaction_values_list must be a non-empty list." + raise ValueError(error_message) + if not isinstance(data, pd.DataFrame) and not isinstance(data, np.ndarray): + error_message = f"data must be a pandas DataFrame or a numpy array. Got: {type(data)}." + raise TypeError(error_message) + if len(interaction_values_list) != len(data): + error_message = "Length of interaction_values_list must match number of rows in data." + raise ValueError(error_message) + if alpha <= 0 or alpha > 1: + error_message = "alpha must be between 0 and 1." + raise ValueError(error_message) + if dot_size <= 0: + error_message = "dot_size must be a positive value." + raise ValueError(error_message) + if jitter < 0: + error_message = "jitter must be non-negative." + raise ValueError(error_message) + + n_players = interaction_values_list[0].n_players + + if feature_names is None: + feature_names_full = [f"F{i}" for i in range(n_players)] + else: + if len(feature_names) != n_players: + error_message = "Length of feature_names must match n_players." + raise ValueError(error_message) + feature_names_full = list(feature_names) + + feature_names_display = ( + abbreviate_feature_names(feature_names_full) if abbreviate else list(feature_names_full) + ) + name_to_idx = {n: i for i, n in enumerate(feature_names_full)} + display_mapping = dict(enumerate(feature_names_display)) + + interaction_tuple = _resolve_interaction( + interaction, interaction_values_list, name_to_idx, n_players + ) + if not any(interaction_tuple in iv.interaction_lookup for iv in interaction_values_list): + error_message = f"Interaction {interaction_tuple} not found in InteractionValues lookup." + raise ValueError(error_message) + + if len(interaction_tuple) == 1 or x_feature is None: + x_idx = interaction_tuple[0] + else: + x_idx = _resolve_feature(x_feature, name_to_idx, n_players) + if x_idx not in interaction_tuple: + error_message = ( + f"x_feature {x_feature!r} must be a member of interaction {interaction_tuple}." + ) + raise ValueError(error_message) + + color_idx: int | None = None + if color is not None: + color_idx = _resolve_feature(color, name_to_idx, n_players) + + x_numpy = data.to_numpy(dtype=float) if isinstance(data, pd.DataFrame) else data.astype(float) + x_vals = x_numpy[:, x_idx] + y_vals = np.array([iv[interaction_tuple] for iv in interaction_values_list], dtype=float) + + if ax is None: + _fig, ax = plt.subplots(figsize=(7, 5)) + fig: Figure = ax.get_figure() # type: ignore[assignment] + + x_plot = x_vals + if jitter > 0: + std = float(np.nanstd(x_vals)) + if std > 0: + rng = np.random.default_rng(0) + x_plot = x_vals + rng.normal(0.0, jitter * std, size=x_vals.shape) + + n_samples = len(x_vals) + sc = None + if color_idx is None: + ax.scatter( + x_plot, + y_vals, + color="#1f77b4", + s=dot_size, + alpha=alpha, + linewidth=0, + rasterized=n_samples > 500, + ) + else: + c_vals = x_numpy[:, color_idx] + nan_mask = np.isnan(c_vals) + valid_mask = ~nan_mask + + if nan_mask.any(): + ax.scatter( + x_plot[nan_mask], + y_vals[nan_mask], + color="#777777", + s=dot_size, + alpha=alpha * 0.5, + linewidth=0, + rasterized=n_samples > 500, + ) + + if valid_mask.any(): + valid_color_vals = c_vals[valid_mask] + vmin = float(np.min(valid_color_vals)) + vmax = float(np.max(valid_color_vals)) + if vmin == vmax: + vmin -= 1e-9 + vmax += 1e-9 + sc = ax.scatter( + x_plot[valid_mask], + y_vals[valid_mask], + c=valid_color_vals, + cmap=_get_red_blue_cmap(), + vmin=vmin, + vmax=vmax, + s=dot_size, + alpha=alpha, + linewidth=0, + rasterized=n_samples > 500, + ) + + if sc is not None and color_idx is not None: + cb = fig.colorbar(sc, ax=ax, aspect=80) + cb.set_label(display_mapping[color_idx], size=11, labelpad=0) + cb.ax.tick_params(labelsize=10, length=0) + cb.outline.set_visible(False) # type: ignore[union-attr] + + valid_x_for_hist = x_vals[~np.isnan(x_vals)] + draw_hist = ( + hist + and len(valid_x_for_hist) >= 2 + and float(np.min(valid_x_for_hist)) < float(np.max(valid_x_for_hist)) + ) + if draw_hist: + n_bins = min(50, max(10, len(valid_x_for_hist) // 2)) + counts, bin_edges = np.histogram(valid_x_for_hist, bins=n_bins) + if counts.max() > 0: + hist_band = 0.10 # bottom 10% of the existing plot area + rel_heights = (counts / counts.max()) * hist_band + widths = np.diff(bin_edges) + ax.bar( + bin_edges[:-1], + rel_heights, + width=widths, + bottom=0.0, + align="edge", + color="#aaaaaa", + alpha=0.4, + edgecolor="none", + zorder=-2, + transform=ax.get_xaxis_transform(), + ) + + ax.axhline(0, color="#999999", linestyle="-", linewidth=1, zorder=1) + ax.set_xlabel(display_mapping[x_idx], fontsize=12) + index_name = interaction_values_list[0].index + feature_label = ", ".join(display_mapping[f] for f in interaction_tuple) + ax.set_ylabel(f"{index_name}({feature_label})", fontsize=12) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + plt.tight_layout() + + if not show: + return ax + plt.show() + return None diff --git a/tests/shapiq/tests_unit/tests_plots/test_scatter.py b/tests/shapiq/tests_unit/tests_plots/test_scatter.py new file mode 100644 index 000000000..cb8851c49 --- /dev/null +++ b/tests/shapiq/tests_unit/tests_plots/test_scatter.py @@ -0,0 +1,507 @@ +"""This module contains all tests for the scatter plot.""" + +from __future__ import annotations + +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest +from matplotlib import collections + +from shapiq.interaction_values import InteractionValues +from shapiq.plot import scatter_plot + +N_SAMPLES = 10 +N_PLAYERS = 5 + + +@pytest.fixture +def mock_interaction_data() -> tuple[list[InteractionValues], np.ndarray, pd.DataFrame, list[str]]: + """Creates mock data for scatter plot tests.""" + lookup = { + (): 1, + (0,): 0, + (1,): 1, + (2,): 2, + (3,): 3, + (4,): 4, + (0, 1): 5, + (0, 2): 6, + (1, 3): 7, + (2, 4): 8, + (0, 1, 2): 9, + } + n_interactions = len(lookup) + + interaction_values_list = [] + rng = np.random.default_rng(42) + for _ in range(N_SAMPLES): + values = rng.random(n_interactions) * 2 - 1 + iv = InteractionValues( + values=values, + interaction_lookup=lookup, + index="k-SII", + min_order=1, + max_order=3, + n_players=N_PLAYERS, + baseline_value=rng.random(), + ) + interaction_values_list.append(iv) + + feature_data_np = rng.random((N_SAMPLES, N_PLAYERS)) + feature_names = [f"feature_{i}" for i in range(N_PLAYERS)] + feature_data_pd = pd.DataFrame(feature_data_np, columns=feature_names) + + return interaction_values_list, feature_data_np, feature_data_pd, feature_names + + +def test_scatter_plot_basic(mock_interaction_data): + """Tests basic scatter plot calls with numpy/pandas inputs and shorthand interaction args.""" + interaction_values_list, feature_data_np, feature_data_pd, feature_names = mock_interaction_data + + ax = scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + feature_names=feature_names, + show=False, + ) + assert isinstance(ax, plt.Axes) + assert "feature" in ax.get_xlabel().lower() or "f0" in ax.get_xlabel().lower() + assert ax.get_ylabel() == "k-SII(F0)" + plt.close("all") + + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=feature_names, + show=False, + ) + assert isinstance(ax, plt.Axes) + plt.close("all") + + ax = scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + show=False, + ) + assert ax.get_xlabel().startswith("F") + plt.close("all") + + # Auto-pick interaction + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + feature_names=feature_names, + show=False, + ) + assert isinstance(ax, plt.Axes) + assert ax.get_ylabel() != "" + plt.close("all") + + # Equivalence of int / str / tuple shorthand for main effects + labels = [] + for arg in (0, "feature_0", (0,), ("feature_0",)): + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=arg, + feature_names=feature_names, + show=False, + ) + labels.append(ax.get_xlabel()) + plt.close("all") + assert len(set(labels)) == 1 + + # Higher-order: default x-axis is first feature in tuple + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0, 1), + feature_names=feature_names, + abbreviate=False, + show=False, + ) + assert ax.get_xlabel() == "feature_0" + assert "feature_0" in ax.get_ylabel() and "feature_1" in ax.get_ylabel() + plt.close("all") + + # Higher-order with explicit x_feature by index + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0, 1), + x_feature=1, + feature_names=feature_names, + abbreviate=False, + show=False, + ) + assert ax.get_xlabel() == "feature_1" + plt.close("all") + + # Higher-order via names + x_feature by name + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=("feature_0", "feature_1"), + x_feature="feature_1", + feature_names=feature_names, + abbreviate=False, + show=False, + ) + assert ax.get_xlabel() == "feature_1" + plt.close("all") + + +def test_scatter_plot_options(mock_interaction_data): + """Tests scatter_plot color, jitter, ax, abbreviate options.""" + interaction_values_list, _, feature_data_pd, feature_names = mock_interaction_data + + # color = explicit feature -> colorbar added + fig, ax = plt.subplots() + n_axes_before = len(fig.axes) + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + color="feature_2", + feature_names=feature_names, + abbreviate=False, + ax=ax, + show=False, + ) + assert len(fig.axes) > n_axes_before + cbar_labels = [a.get_ylabel() for a in fig.axes if a is not ax] + assert "feature_2" in cbar_labels + plt.close("all") + + # color = None -> no extra colorbar axes added (histogram lives on the main ax) + fig, ax = plt.subplots() + n_axes_before = len(fig.axes) + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + color=None, + feature_names=feature_names, + ax=ax, + show=False, + ) + assert len(fig.axes) == n_axes_before + plt.close("all") + + # NaN in color feature -> gray points at half alpha + data_with_nan = feature_data_pd.copy() + data_with_nan.iloc[0, 2] = np.nan + test_alpha = 0.8 + ax = scatter_plot( + interaction_values_list, + data_with_nan, + interaction=(0,), + color="feature_2", + alpha=test_alpha, + feature_names=feature_names, + show=False, + ) + expected_nan_color = list(mcolors.to_rgba("#777777")) + expected_nan_alpha = test_alpha * 0.5 + expected_nan_color[3] = expected_nan_alpha + nan_points_found = False + for collection in ax.collections: + colors = collection.get_facecolors() + if ( + isinstance(collection, collections.PathCollection) + and collection.get_alpha() == expected_nan_alpha + and len(colors) > 0 + and np.allclose(colors[0], expected_nan_color) + ): + nan_points_found = True + break + assert nan_points_found + plt.close("all") + + # All-same color values -> no error (vmin/vmax epsilon) + data_const_color = feature_data_pd.copy() + data_const_color.iloc[:, 2] = 0.5 + ax = scatter_plot( + interaction_values_list, + data_const_color, + interaction=(0,), + color="feature_2", + feature_names=feature_names, + show=False, + ) + assert isinstance(ax, plt.Axes) + plt.close("all") + + # ax= passed in returns the same axes + _, ax_existing = plt.subplots() + ax_returned = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=feature_names, + ax=ax_existing, + show=False, + ) + assert ax_returned is ax_existing + plt.close("all") + + # abbreviate=False keeps long names + long_names = [f"a_very_long_feature_name_{i}" for i in range(N_PLAYERS)] + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=long_names, + abbreviate=False, + show=False, + ) + assert ax.get_xlabel() == long_names[0] + plt.close("all") + + # jitter changes plotted x values + ax = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=feature_names, + jitter=0.5, + hist=False, + show=False, + ) + raw = feature_data_pd.iloc[:, 0].to_numpy() + plotted = ax.collections[0].get_offsets()[:, 0] + assert np.max(np.abs(plotted - raw)) > 0 + plt.close("all") + + # hist = True -> draws histogram bars on the main ax (no extra axes, no y-extension) + fig, ax = plt.subplots() + n_axes_before = len(fig.axes) + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=feature_names, + abbreviate=False, + hist=True, + ax=ax, + show=False, + ) + assert len(fig.axes) == n_axes_before + assert ax.get_xlabel() == "feature_0" + rect_patches = [p for p in ax.patches if type(p).__name__ == "Rectangle"] + assert len(rect_patches) > 0 + # histogram bars sit behind the scatter (lower zorder) + bar_zorder = min(p.get_zorder() for p in rect_patches) + scatter_zorder = min(c.get_zorder() for c in ax.collections) + assert bar_zorder < scatter_zorder + plt.close("all") + + # hist = True with constant x-values -> no histogram drawn + data_const_x = feature_data_pd.copy() + data_const_x.iloc[:, 0] = 0.5 + ax_const = scatter_plot( + interaction_values_list, + data_const_x, + interaction=(0,), + feature_names=feature_names, + abbreviate=False, + hist=True, + show=False, + ) + assert ax_const.get_xlabel() == "feature_0" + rect_patches = [p for p in ax_const.patches if type(p).__name__ == "Rectangle"] + assert len(rect_patches) == 0 + plt.close("all") + + # hist = True does not extend the y-axis (limits identical to hist = False) + ax_with = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=feature_names, + hist=True, + show=False, + ) + ylim_with = ax_with.get_ylim() + plt.close("all") + ax_without = scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0,), + feature_names=feature_names, + hist=False, + show=False, + ) + ylim_without = ax_without.get_ylim() + plt.close("all") + assert np.isclose(ylim_with[0], ylim_without[0]) + assert np.isclose(ylim_with[1], ylim_without[1]) + + +def test_scatter_plot_errors(mock_interaction_data): + """Tests that scatter_plot raises informative errors for bad inputs.""" + interaction_values_list, feature_data_np, feature_data_pd, feature_names = mock_interaction_data + + with pytest.raises(ValueError, match="non-empty list"): + scatter_plot([], feature_data_np, interaction=(0,), show=False) + + with pytest.raises(ValueError, match="must match number of rows"): + scatter_plot( + interaction_values_list, + feature_data_np[1:], + interaction=(0,), + show=False, + ) + + with pytest.raises(TypeError, match="must be a pandas DataFrame or a numpy array"): + scatter_plot(interaction_values_list, "not_a_valid_data_type", interaction=(0,), show=False) + + with pytest.raises(ValueError, match="Unknown feature name"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=("nonexistent",), + feature_names=feature_names, + show=False, + ) + + with pytest.raises(ValueError, match="out of range"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(99,), + feature_names=feature_names, + show=False, + ) + + with pytest.raises(ValueError, match="not found in InteractionValues"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(3, 4), + feature_names=feature_names, + show=False, + ) + + with pytest.raises(ValueError, match="must be a member of interaction"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(0, 1), + x_feature=2, + feature_names=feature_names, + show=False, + ) + + with pytest.raises(ValueError, match="Length of feature_names must match"): + scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + feature_names=feature_names[:-1], + show=False, + ) + + with pytest.raises(ValueError, match="alpha must be between 0 and 1"): + scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + alpha=-0.1, + show=False, + ) + + with pytest.raises(ValueError, match="alpha must be between 0 and 1"): + scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + alpha=2.0, + show=False, + ) + + with pytest.raises(ValueError, match="dot_size must be a positive value"): + scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + dot_size=-1, + show=False, + ) + + with pytest.raises(ValueError, match="jitter must be non-negative"): + scatter_plot( + interaction_values_list, + feature_data_np, + interaction=(0,), + jitter=-0.5, + show=False, + ) + + # interaction with unsupported type (e.g., float) -> TypeError from _resolve_interaction + with pytest.raises(TypeError, match="interaction must be a tuple"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=1.5, + feature_names=feature_names, + show=False, + ) + + # interaction with a list -> not a tuple -> TypeError + with pytest.raises(TypeError, match="interaction must be a tuple"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=[0, 1], + feature_names=feature_names, + show=False, + ) + + # empty interaction tuple -> ValueError + with pytest.raises(ValueError, match="must contain at least one feature"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(), + feature_names=feature_names, + show=False, + ) + + # tuple of unsupported feature type -> TypeError from _resolve_feature + with pytest.raises(TypeError, match="Feature identifier must be int or str"): + scatter_plot( + interaction_values_list, + feature_data_pd, + interaction=(1.5,), + feature_names=feature_names, + show=False, + ) + + +def test_scatter_plot_no_non_empty_interactions(): + """interaction=None with an InteractionValues containing only the empty key raises.""" + n_samples = 4 + n_players = 3 + lookup = {(): 0} + rng = np.random.default_rng(0) + ivs = [ + InteractionValues( + values=np.array([rng.random()]), + interaction_lookup=lookup, + index="k-SII", + min_order=0, + max_order=0, + n_players=n_players, + baseline_value=0.0, + ) + for _ in range(n_samples) + ] + data = rng.random((n_samples, n_players)) + with pytest.raises(ValueError, match="No non-empty interactions"): + scatter_plot(ivs, data, interaction=None, show=False)