Skip to content

Commit 117de89

Browse files
authored
[ENH] Add keepalive and max conns to python client (#5822)
## Description of changes

_Summarize the changes made by this PR._

- Improvements & Bug fixes
  - Adds support for `chroma_http_keepalive_secs`, `chroma_http_max_connections`, and `chroma_http_max_keepalive_connections` in settings, propagating them to the FastAPI and async FastAPI httpx clients.
  - Fixes a bug where, when `chroma_server_ssl_verify` was set, a new httpx client was created with only the SSL `verify` option, which discarded all previously configured HTTP limits.
- New functionality
  - ...

## Test plan

_How are these changes tested?_

Added tests to ensure the settings are persisted and that the httpx client is constructed with the configured limits.

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Migration plan

_Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_

## Observability plan

_What is the plan to instrument and monitor this change?_

## Documentation Changes

_Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 91da387 commit 117de89

File tree

6 files changed

+114
-16
lines changed

6 files changed

+114
-16
lines changed

chromadb/api/async_fastapi.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,11 @@ def _get_client(self) -> httpx.AsyncClient:
130130
+ " (https://github.com/chroma-core/chroma)"
131131
)
132132

133-
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
134133
self._clients[loop_hash] = httpx.AsyncClient(
135134
timeout=None,
136135
headers=headers,
137136
verify=self._settings.chroma_server_ssl_verify or False,
138-
limits=limits,
137+
limits=self.http_limits,
139138
)
140139

141140
return self._clients[loop_hash]
@@ -527,7 +526,7 @@ async def _get(
527526
return GetResult(
528527
ids=resp_json["ids"],
529528
embeddings=resp_json.get("embeddings", None),
530-
metadatas=metadatas, # type: ignore
529+
metadatas=metadatas,
531530
documents=resp_json.get("documents", None),
532531
data=None,
533532
uris=resp_json.get("uris", None),
@@ -723,7 +722,7 @@ async def _query(
723722
ids=resp_json["ids"],
724723
distances=resp_json.get("distances", None),
725724
embeddings=resp_json.get("embeddings", None),
726-
metadatas=metadata_batches, # type: ignore
725+
metadatas=metadata_batches,
727726
documents=resp_json.get("documents", None),
728727
uris=resp_json.get("uris", None),
729728
data=None,

chromadb/api/base_http_client.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,47 @@
55
import httpx
66

77
import chromadb.errors as errors
8-
from chromadb.config import Settings
8+
from chromadb.config import Component, Settings, System
99

1010
logger = logging.getLogger(__name__)
1111

1212

13-
class BaseHTTPClient:
13+
# Inherits from Component so that __init__ receives the System;
14+
# this lets the client build its httpx limits from the System's settings.
15+
class BaseHTTPClient(Component):
1416
_settings: Settings
1517
pre_flight_checks: Any = None
16-
keepalive_secs: int = 40
18+
DEFAULT_KEEPALIVE_SECS: float = 40.0
19+
20+
def __init__(self, system: System):
21+
super().__init__(system)
22+
self._settings = system.settings
23+
keepalive_setting = self._settings.chroma_http_keepalive_secs
24+
self.keepalive_secs: Optional[float] = (
25+
keepalive_setting
26+
if keepalive_setting is not None
27+
else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
28+
)
29+
self._http_limits = self._build_limits()
30+
31+
def _build_limits(self) -> httpx.Limits:
32+
limit_kwargs: Dict[str, Any] = {}
33+
if self.keepalive_secs is not None:
34+
limit_kwargs["keepalive_expiry"] = self.keepalive_secs
35+
36+
max_connections = self._settings.chroma_http_max_connections
37+
if max_connections is not None:
38+
limit_kwargs["max_connections"] = max_connections
39+
40+
max_keepalive_connections = self._settings.chroma_http_max_keepalive_connections
41+
if max_keepalive_connections is not None:
42+
limit_kwargs["max_keepalive_connections"] = max_keepalive_connections
43+
44+
return httpx.Limits(**limit_kwargs)
45+
46+
@property
47+
def http_limits(self) -> httpx.Limits:
48+
return self._http_limits
1749

1850
@staticmethod
1951
def _validate_host(host: str) -> None:

chromadb/api/fastapi.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,14 @@ def __init__(self, system: System):
7979
default_api_path=system.settings.chroma_server_api_default_path,
8080
)
8181

82-
limits = httpx.Limits(keepalive_expiry=self.keepalive_secs)
83-
self._session = httpx.Client(timeout=None, limits=limits)
82+
if self._settings.chroma_server_ssl_verify is not None:
83+
self._session = httpx.Client(
84+
timeout=None,
85+
limits=self.http_limits,
86+
verify=self._settings.chroma_server_ssl_verify,
87+
)
88+
else:
89+
self._session = httpx.Client(timeout=None, limits=self.http_limits)
8490

8591
self._header = system.settings.chroma_server_headers or {}
8692
self._header["Content-Type"] = "application/json"
@@ -90,8 +96,6 @@ def __init__(self, system: System):
9096
+ " (https://github.com/chroma-core/chroma)"
9197
)
9298

93-
if self._settings.chroma_server_ssl_verify is not None:
94-
self._session = httpx.Client(verify=self._settings.chroma_server_ssl_verify)
9599
if self._header is not None:
96100
self._session.headers.update(self._header)
97101

@@ -492,7 +496,7 @@ def _get(
492496
return GetResult(
493497
ids=resp_json["ids"],
494498
embeddings=resp_json.get("embeddings", None),
495-
metadatas=metadatas, # type: ignore
499+
metadatas=metadatas,
496500
documents=resp_json.get("documents", None),
497501
data=None,
498502
uris=resp_json.get("uris", None),
@@ -700,7 +704,7 @@ def _query(
700704
ids=resp_json["ids"],
701705
distances=resp_json.get("distances", None),
702706
embeddings=resp_json.get("embeddings", None),
703-
metadatas=metadata_batches, # type: ignore
707+
metadatas=metadata_batches,
704708
documents=resp_json.get("documents", None),
705709
uris=resp_json.get("uris", None),
706710
data=None,

chromadb/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ def empty_str_to_none(cls, v: str) -> Optional[str]:
154154
# eg ["http://localhost:8000"]
155155
chroma_server_cors_allow_origins: List[str] = []
156156

157+
chroma_http_keepalive_secs: Optional[float] = 40.0
158+
chroma_http_max_connections: Optional[int] = None
159+
chroma_http_max_keepalive_connections: Optional[int] = None
160+
157161
# ==================
158162
# Server config
159163
# ==================

chromadb/test/test_client.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import asyncio
2-
from typing import Any, Callable, Generator, cast
3-
from unittest.mock import patch
2+
from typing import Any, Callable, Generator, cast, Dict, Tuple
3+
from unittest.mock import MagicMock, patch
44
import chromadb
5-
from chromadb.config import Settings
5+
from chromadb.config import Settings, System
66
from chromadb.api import ClientAPI
77
import chromadb.server.fastapi
8+
from chromadb.api.fastapi import FastAPI
89
import pytest
910
import tempfile
1011
import os
@@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings(
110111
str(e)
111112
== "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
112113
)
114+
115+
116+
def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]:
117+
captured: Dict[str, Any] = {}
118+
119+
# takes any positional args to match httpx.Client
120+
def factory(*_: Any, **kwargs: Any) -> Any:
121+
captured.update(kwargs)
122+
session = MagicMock()
123+
session.headers = {}
124+
return session
125+
126+
return factory, captured
127+
128+
129+
def test_fastapi_uses_http_limits_from_settings() -> None:
130+
settings = Settings(
131+
chroma_api_impl="chromadb.api.fastapi.FastAPI",
132+
chroma_server_host="localhost",
133+
chroma_server_http_port=9000,
134+
chroma_server_ssl_verify=True,
135+
chroma_http_keepalive_secs=12.5,
136+
chroma_http_max_connections=64,
137+
chroma_http_max_keepalive_connections=16,
138+
)
139+
system = System(settings)
140+
141+
factory, captured = make_sync_client_factory()
142+
143+
with patch.object(FastAPI, "require", side_effect=[MagicMock(), MagicMock()]):
144+
with patch("chromadb.api.fastapi.httpx.Client", side_effect=factory):
145+
api = FastAPI(system)
146+
147+
api.stop()
148+
limits = captured["limits"]
149+
assert limits.keepalive_expiry == 12.5
150+
assert limits.max_connections == 64
151+
assert limits.max_keepalive_connections == 16
152+
assert captured["timeout"] is None
153+
assert captured["verify"] is True

chromadb/test/test_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,21 @@ def test_runtime_dependencies() -> None:
189189
assert data.starts == ["D", "C"]
190190
system.stop()
191191
assert data.stops == ["C", "D"]
192+
193+
194+
def test_http_client_setting_defaults() -> None:
195+
settings = Settings()
196+
assert settings.chroma_http_keepalive_secs == 40.0
197+
assert settings.chroma_http_max_connections is None
198+
assert settings.chroma_http_max_keepalive_connections is None
199+
200+
201+
def test_http_client_setting_overrides() -> None:
202+
settings = Settings(
203+
chroma_http_keepalive_secs=5.5,
204+
chroma_http_max_connections=123,
205+
chroma_http_max_keepalive_connections=17,
206+
)
207+
assert settings.chroma_http_keepalive_secs == 5.5
208+
assert settings.chroma_http_max_connections == 123
209+
assert settings.chroma_http_max_keepalive_connections == 17

0 commit comments

Comments
 (0)