Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def _add_application_parameters(self) -> None:
applications.append("streamlit")
if importlib.util.find_spec("snowflake.ml"):
applications.append("SnowparkML")
if importlib.util.find_spec("snowbook"):
applications.append("notebook")
self._lower_case_parameters[PARAM_APPLICATION] = (
":".join(applications) or get_application_name()
)
Expand Down
71 changes: 70 additions & 1 deletion tests/unit/test_server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import io
import logging
import os
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest

Expand Down Expand Up @@ -169,3 +170,71 @@ def test_run_query_when_ignore_results_true(mock_server_connection):
)
mock_server_connection._to_data_or_iter.assert_called()
assert "sfqid" in result and result["sfqid"] == "ignore_results is False"


def test_snowbook_detection_adds_notebook_application(mock_server_connection):
"""Test that 'notebook' is added when snowbook is available."""
with patch("importlib.util.find_spec") as mock_find_spec:
mock_find_spec.side_effect = (
lambda name: mock.MagicMock() if name == "snowbook" else None
)

mock_server_connection._lower_case_parameters = {}
mock_server_connection._add_application_parameters()

assert (
"notebook" in mock_server_connection._lower_case_parameters["application"]
)


def test_snowbook_detection_without_snowbook(mock_server_connection):
"""Test that 'notebook' is not added when snowbook is not available."""
with patch("importlib.util.find_spec", return_value=None):
mock_server_connection._lower_case_parameters = {}
mock_server_connection._add_application_parameters()

assert (
"notebook"
not in mock_server_connection._lower_case_parameters["application"]
)


def test_snowbook_detection_with_multiple_applications(mock_server_connection):
"""Test that snowbook works alongside other application detections."""
with patch("importlib.util.find_spec") as mock_find_spec:
mock_find_spec.side_effect = (
lambda name: mock.MagicMock()
if name in ["streamlit", "snowflake.ml", "snowbook"]
else None
)

mock_server_connection._lower_case_parameters = {}
mock_server_connection._add_application_parameters()

app_param = mock_server_connection._lower_case_parameters["application"]
assert app_param == "streamlit:SnowparkML:notebook"


def test_env_var_partner_takes_precedence(mock_server_connection):
"""Test that ENV_VAR_PARTNER takes precedence over module detection."""
with patch.dict(os.environ, {"SF_PARTNER": "custom_partner"}):
with patch("importlib.util.find_spec", return_value=mock.MagicMock()):
mock_server_connection._lower_case_parameters = {}
mock_server_connection._add_application_parameters()

assert (
mock_server_connection._lower_case_parameters["application"]
== "custom_partner"
)


def test_existing_application_param_not_overwritten(mock_server_connection):
"""Test that existing application parameter is preserved."""
with patch("importlib.util.find_spec", return_value=mock.MagicMock()):
mock_server_connection._lower_case_parameters = {"application": "existing_app"}
mock_server_connection._add_application_parameters()

assert (
mock_server_connection._lower_case_parameters["application"]
== "existing_app"
)
Loading