Skip to content

Commit 5a5ccde

Browse files
Lukas Valatka authored and astronautas committed
fix: Thread safe Clickhouse offline store (#5710)
* Add pull_all_from_table_or_query for ClickHouse, to align with the new materialization logic (which calls it). Signed-off-by: lukas.valatka <[email protected]>
* Make get_client thread-local, since the ClickHouse client is not thread-safe. Signed-off-by: lukas.valatka <[email protected]>
* Cleanup. Signed-off-by: lukas.valatka <[email protected]>
---------
Signed-off-by: lukas.valatka <[email protected]>
Co-authored-by: lukas.valatka <[email protected]>
1 parent 289849b commit 5a5ccde

File tree

2 files changed

+92
-10
lines changed

2 files changed

+92
-10
lines changed
Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1-
from functools import cache
1+
import threading
22

33
import clickhouse_connect
44
from clickhouse_connect.driver import Client
55

66
from feast.infra.utils.clickhouse.clickhouse_config import ClickhouseConfig
77

8+
# Per-thread storage: get_client() keeps the ClickHouse client it creates for
# the calling thread on this object, so clients are never shared across threads.
thread_local = threading.local()
9+
810

def get_client(config: ClickhouseConfig) -> Client:
    """Return the calling thread's ClickHouse client for ``config``.

    The ClickHouse client is not thread-safe, so every thread gets its own
    instance, stored on the module-level ``thread_local`` object. The client
    is reused while the same thread keeps asking for an equal ``config``; a
    call with a different ``config`` builds a new client instead of silently
    returning one that was connected with other settings.

    Args:
        config: Connection settings (host, port, user, password, database).

    Returns:
        A client owned by the current thread and created from ``config``.
    """
    cached_client = getattr(thread_local, "clickhouse_client", None)
    cached_config = getattr(thread_local, "clickhouse_client_config", None)

    if cached_client is None or cached_config != config:
        # A previously cached client (if any) is deliberately not closed here:
        # code running in this thread may still hold a reference to it.
        cached_client = clickhouse_connect.get_client(
            host=config.host,
            port=config.port,
            user=config.user,
            password=config.password,
            database=config.database,
        )
        thread_local.clickhouse_client = cached_client
        # Remember which config produced the client so a later call with
        # different settings is detected.
        thread_local.clickhouse_client_config = config

    return cached_client
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import threading
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
6+
from feast.infra.utils.clickhouse.clickhouse_config import ClickhouseConfig
7+
from feast.infra.utils.clickhouse.connection_utils import get_client, thread_local
8+
9+
10+
@pytest.fixture
def clickhouse_config():
    """Provide a ClickhouseConfig with fixed local connection settings for tests."""
    connection_settings = {
        "host": "localhost",
        "port": 9000,
        "user": "default",
        "password": "password",
        "database": "test_db",
    }
    return ClickhouseConfig(**connection_settings)
20+
21+
22+
@pytest.fixture(autouse=True)
def cleanup_thread_local():
    """Drop any client cached on ``thread_local`` once each test has finished."""
    yield
    try:
        del thread_local.clickhouse_client
    except AttributeError:
        # Nothing was cached in this thread during the test.
        pass
28+
29+
30+
@patch("feast.infra.utils.clickhouse.connection_utils.clickhouse_connect.get_client")
def test_get_client_returns_different_objects_for_separate_threads(
    mock_get_client, clickhouse_config
):
    """
    Verify that get_client caches per thread rather than globally.

    The Clickhouse client is thread-unsafe, so two calls made from one thread
    must yield the same instance, while a call made from another thread must
    yield a different instance.
    """
    # Every call to the patched factory hands back a brand-new mock client.
    mock_get_client.side_effect = lambda *args, **kwargs: MagicMock()

    results = {}

    def thread_1_work():
        """Call get_client twice from the first thread."""
        first_call = get_client(clickhouse_config)
        second_call = get_client(clickhouse_config)
        results["thread_1"] = (first_call, second_call)

    def thread_2_work():
        """Call get_client once from the second thread."""
        results["thread_2"] = get_client(clickhouse_config)

    workers = [
        threading.Thread(target=thread_1_work),
        threading.Thread(target=thread_2_work),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both calls made inside thread 1 must resolve to one shared instance.
    client_1a, client_1b = results["thread_1"]
    assert client_1a is client_1b, (
        "Same thread should get same client instance (cached)"
    )

    # The instance seen by thread 2 must not be the one thread 1 received.
    client_2 = results["thread_2"]
    assert client_1a is not client_2, (
        "Different threads should get different client instances (not cached)"
    )

0 commit comments

Comments
 (0)