diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py
index 17d333bc89..6c7fe4de77 100644
--- a/seaborn/axisgrid.py
+++ b/seaborn/axisgrid.py
@@ -1248,12 +1248,43 @@ def __init__(
         data = handle_data_source(data)
 
         # Sort out the variables that define the grid
-        numeric_cols = self._find_numeric_cols(data)
-        if hue in numeric_cols:
-            numeric_cols.remove(hue)
-        if vars is not None:
+
+        # `_find_numeric_cols` crashes if the data has duplicated column names,
+        # even when those columns are not listed in `vars`.
+        # If `vars` is provided, `_find_numeric_cols` is not needed.
+        # If `vars` is not provided, `_find_numeric_cols` is needed, but it
+        # fails with an ambiguous "ValueError: The truth value of a DataFrame ..."
+        # So skip `_find_numeric_cols` when `vars` is provided, and raise an
+        # early, explicit error when the relevant columns are duplicated.
+        # (Duplicated columns are assumed to be unsupported by PairGrid.)
+
+        if vars is not None:  # user-provided vars
             x_vars = list(vars)
             y_vars = list(vars)
+            if len(set(x_vars)) < len(x_vars):
+                # Does not crash, only causes unexpected figures.
+                # No effort is made to list the duplicated items here.
+                warnings.warn(f"Duplicated items in vars: {x_vars}")
+                condensed_vars = list(dict.fromkeys(x_vars))
+            else:
+                condensed_vars = x_vars
+            # Use condensed_vars so that duplicated items in vars
+            # do not cause duplicates in data.loc[:, vars].columns.
+            selected_columns = data.loc[:, condensed_vars].columns
+            if not selected_columns.is_unique:
+                # Fail if a column selected by vars is duplicated in data.
+                # List the duplicates since we raise an error.
+                dupe_cols = selected_columns[selected_columns.duplicated()]
+                raise ValueError(
+                    f"Columns: {dupe_cols} are duplicated.")
+        else:
+            if not data.columns.is_unique:
+                dupe_cols = data.columns[data.columns.duplicated()]
+                raise ValueError(
+                    f"Columns: {dupe_cols} are duplicated.")
+            numeric_cols = self._find_numeric_cols(data)
+            if hue in numeric_cols:
+                numeric_cols.remove(hue)
         if x_vars is None:
             x_vars = numeric_cols
         if y_vars is None:
diff --git a/tests/test_axisgrid.py b/tests/test_axisgrid.py
index 6470edfa4f..d53321afd2 100644
--- a/tests/test_axisgrid.py
+++ b/tests/test_axisgrid.py
@@ -785,6 +785,32 @@ def test_remove_hue_from_default(self):
         assert hue in g.x_vars
         assert hue in g.y_vars
 
+    def test_duplicates_in_df_columns_without_vars(self):
+        # should fail with a clear message
+        df_with_dupe = self.df.loc[:, ["x", "y", "z"]].copy()
+        df_with_dupe.columns = ["x", "y", "y"]
+        with pytest.raises(ValueError, match=r"Columns: .* are duplicated\."):
+            ag.PairGrid(df_with_dupe)
+
+    def test_duplicates_in_df_columns_with_related_vars(self):
+        # should fail with a clear message
+        df_with_dupe = self.df.loc[:, ["x", "y", "z"]].copy()
+        df_with_dupe.columns = ["x", "y", "y"]
+        with pytest.raises(ValueError, match=r"Columns: .* are duplicated\."):
+            ag.PairGrid(df_with_dupe, vars=['x', 'y'])
+
+    def test_duplicated_vars(self):
+        # should only warn
+        with pytest.warns(UserWarning, match=r"Duplicated items in vars: .*"):
+            ag.PairGrid(self.df, vars=['x', 'y', 'y'])
+
+    def test_duplicates_in_df_columns_with_not_related_vars(self):
+        # should pass: the duplicated column is not selected by vars
+        df_with_dupe = pd.concat(
+            [self.df["x"], self.df["y"], self.df["z"], self.df["x"]], axis=1)
+        df_with_dupe.columns = ["x", "y", "z", "z"]
+        ag.PairGrid(df_with_dupe, vars=['x', 'y'])
+
     @pytest.mark.parametrize(
         "x_vars, y_vars",
         [