diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index c1f9e1d385..9cbbc4db26 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -222,6 +222,8 @@ def _add_application_parameters(self) -> None: applications.append("streamlit") if importlib.util.find_spec("snowflake.ml"): applications.append("SnowparkML") + if importlib.util.find_spec("snowbook"): + applications.append("notebook") self._lower_case_parameters[PARAM_APPLICATION] = ( ":".join(applications) or get_application_name() ) diff --git a/tests/unit/test_server_connection.py b/tests/unit/test_server_connection.py index aa06054ed3..92ef08a87d 100644 --- a/tests/unit/test_server_connection.py +++ b/tests/unit/test_server_connection.py @@ -4,8 +4,9 @@ import io import logging +import os from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -169,3 +170,71 @@ def test_run_query_when_ignore_results_true(mock_server_connection): ) mock_server_connection._to_data_or_iter.assert_called() assert "sfqid" in result and result["sfqid"] == "ignore_results is False" + + +def test_snowbook_detection_adds_notebook_application(mock_server_connection): + """Test that 'notebook' is added when snowbook is available.""" + with patch("importlib.util.find_spec") as mock_find_spec: + mock_find_spec.side_effect = ( + lambda name: mock.MagicMock() if name == "snowbook" else None + ) + + mock_server_connection._lower_case_parameters = {} + mock_server_connection._add_application_parameters() + + assert ( + "notebook" in mock_server_connection._lower_case_parameters["application"] + ) + + +def test_snowbook_detection_without_snowbook(mock_server_connection): + """Test that 'notebook' is not added when snowbook is not available.""" + with patch("importlib.util.find_spec", return_value=None): + mock_server_connection._lower_case_parameters = {} + mock_server_connection._add_application_parameters() + + assert ( + "notebook" + not in mock_server_connection._lower_case_parameters["application"] + ) + + +def test_snowbook_detection_with_multiple_applications(mock_server_connection): + """Test that snowbook works alongside other application detections.""" + with patch("importlib.util.find_spec") as mock_find_spec: + mock_find_spec.side_effect = ( + lambda name: mock.MagicMock() + if name in ["streamlit", "snowflake.ml", "snowbook"] + else None + ) + + mock_server_connection._lower_case_parameters = {} + mock_server_connection._add_application_parameters() + + app_param = mock_server_connection._lower_case_parameters["application"] + assert app_param == "streamlit:SnowparkML:notebook" + + +def test_env_var_partner_takes_precedence(mock_server_connection): + """Test that ENV_VAR_PARTNER takes precedence over module detection.""" + with patch.dict(os.environ, {"SF_PARTNER": "custom_partner"}): + with patch("importlib.util.find_spec", return_value=mock.MagicMock()): + mock_server_connection._lower_case_parameters = {} + mock_server_connection._add_application_parameters() + + assert ( + mock_server_connection._lower_case_parameters["application"] + == "custom_partner" + ) + + +def test_existing_application_param_not_overwritten(mock_server_connection): + """Test that existing application parameter is preserved.""" + with patch("importlib.util.find_spec", return_value=mock.MagicMock()): + mock_server_connection._lower_case_parameters = {"application": "existing_app"} + mock_server_connection._add_application_parameters() + + assert ( + mock_server_connection._lower_case_parameters["application"] + == "existing_app" + )