|
4 | 4 |
|
5 | 5 | import io |
6 | 6 | import logging |
| 7 | +import os |
7 | 8 | from unittest import mock |
8 | | -from unittest.mock import MagicMock |
| 9 | +from unittest.mock import MagicMock, patch |
9 | 10 |
|
10 | 11 | import pytest |
11 | 12 |
|
@@ -169,3 +170,71 @@ def test_run_query_when_ignore_results_true(mock_server_connection): |
169 | 170 | ) |
170 | 171 | mock_server_connection._to_data_or_iter.assert_called() |
171 | 172 | assert "sfqid" in result and result["sfqid"] == "ignore_results is False" |
| 173 | + |
| 174 | + |
| 175 | +def test_snowbook_detection_adds_notebook_application(mock_server_connection): |
| 176 | + """Test that 'notebook' is added when snowbook is available.""" |
| 177 | + with patch("importlib.util.find_spec") as mock_find_spec: |
| 178 | + mock_find_spec.side_effect = ( |
| 179 | + lambda name: mock.MagicMock() if name == "snowbook" else None |
| 180 | + ) |
| 181 | + |
| 182 | + mock_server_connection._lower_case_parameters = {} |
| 183 | + mock_server_connection._add_application_parameters() |
| 184 | + |
| 185 | + assert ( |
| 186 | + "notebook" in mock_server_connection._lower_case_parameters["application"] |
| 187 | + ) |
| 188 | + |
| 189 | + |
| 190 | +def test_snowbook_detection_without_snowbook(mock_server_connection): |
| 191 | + """Test that 'notebook' is not added when snowbook is not available.""" |
| 192 | + with patch("importlib.util.find_spec", return_value=None): |
| 193 | + mock_server_connection._lower_case_parameters = {} |
| 194 | + mock_server_connection._add_application_parameters() |
| 195 | + |
| 196 | + assert ( |
| 197 | + "notebook" |
| 198 | + not in mock_server_connection._lower_case_parameters["application"] |
| 199 | + ) |
| 200 | + |
| 201 | + |
| 202 | +def test_snowbook_detection_with_multiple_applications(mock_server_connection): |
| 203 | + """Test that snowbook works alongside other application detections.""" |
| 204 | + with patch("importlib.util.find_spec") as mock_find_spec: |
| 205 | + mock_find_spec.side_effect = ( |
| 206 | + lambda name: mock.MagicMock() |
| 207 | + if name in ["streamlit", "snowflake.ml", "snowbook"] |
| 208 | + else None |
| 209 | + ) |
| 210 | + |
| 211 | + mock_server_connection._lower_case_parameters = {} |
| 212 | + mock_server_connection._add_application_parameters() |
| 213 | + |
| 214 | + app_param = mock_server_connection._lower_case_parameters["application"] |
| 215 | + assert app_param == "streamlit:SnowparkML:notebook" |
| 216 | + |
| 217 | + |
| 218 | +def test_env_var_partner_takes_precedence(mock_server_connection): |
| 219 | + """Test that ENV_VAR_PARTNER takes precedence over module detection.""" |
| 220 | + with patch.dict(os.environ, {"SF_PARTNER": "custom_partner"}): |
| 221 | + with patch("importlib.util.find_spec", return_value=mock.MagicMock()): |
| 222 | + mock_server_connection._lower_case_parameters = {} |
| 223 | + mock_server_connection._add_application_parameters() |
| 224 | + |
| 225 | + assert ( |
| 226 | + mock_server_connection._lower_case_parameters["application"] |
| 227 | + == "custom_partner" |
| 228 | + ) |
| 229 | + |
| 230 | + |
| 231 | +def test_existing_application_param_not_overwritten(mock_server_connection): |
| 232 | + """Test that existing application parameter is preserved.""" |
| 233 | + with patch("importlib.util.find_spec", return_value=mock.MagicMock()): |
| 234 | + mock_server_connection._lower_case_parameters = {"application": "existing_app"} |
| 235 | + mock_server_connection._add_application_parameters() |
| 236 | + |
| 237 | + assert ( |
| 238 | + mock_server_connection._lower_case_parameters["application"] |
| 239 | + == "existing_app" |
| 240 | + ) |
0 commit comments