|
1 | 1 | import asyncio |
2 | | -from typing import Any, Callable, Generator, cast |
3 | | -from unittest.mock import patch |
| 2 | +from typing import Any, Callable, Generator, cast, Dict, Tuple |
| 3 | +from unittest.mock import MagicMock, patch |
4 | 4 | import chromadb |
5 | | -from chromadb.config import Settings |
| 5 | +from chromadb.config import Settings, System |
6 | 6 | from chromadb.api import ClientAPI |
7 | 7 | import chromadb.server.fastapi |
| 8 | +from chromadb.api.fastapi import FastAPI |
8 | 9 | import pytest |
9 | 10 | import tempfile |
10 | 11 | import os |
@@ -110,3 +111,43 @@ def test_http_client_with_inconsistent_port_settings( |
110 | 111 | str(e) |
111 | 112 | == "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]" |
112 | 113 | ) |
| 114 | + |
| 115 | + |
| 116 | +def make_sync_client_factory() -> Tuple[Callable[..., Any], Dict[str, Any]]: |
| 117 | + captured: Dict[str, Any] = {} |
| 118 | + |
| 119 | + # takes any positional args to match httpx.Client |
| 120 | + def factory(*_: Any, **kwargs: Any) -> Any: |
| 121 | + captured.update(kwargs) |
| 122 | + session = MagicMock() |
| 123 | + session.headers = {} |
| 124 | + return session |
| 125 | + |
| 126 | + return factory, captured |
| 127 | + |
| 128 | + |
| 129 | +def test_fastapi_uses_http_limits_from_settings() -> None: |
| 130 | + settings = Settings( |
| 131 | + chroma_api_impl="chromadb.api.fastapi.FastAPI", |
| 132 | + chroma_server_host="localhost", |
| 133 | + chroma_server_http_port=9000, |
| 134 | + chroma_server_ssl_verify=True, |
| 135 | + chroma_http_keepalive_secs=12.5, |
| 136 | + chroma_http_max_connections=64, |
| 137 | + chroma_http_max_keepalive_connections=16, |
| 138 | + ) |
| 139 | + system = System(settings) |
| 140 | + |
| 141 | + factory, captured = make_sync_client_factory() |
| 142 | + |
| 143 | + with patch.object(FastAPI, "require", side_effect=[MagicMock(), MagicMock()]): |
| 144 | + with patch("chromadb.api.fastapi.httpx.Client", side_effect=factory): |
| 145 | + api = FastAPI(system) |
| 146 | + |
| 147 | + api.stop() |
| 148 | + limits = captured["limits"] |
| 149 | + assert limits.keepalive_expiry == 12.5 |
| 150 | + assert limits.max_connections == 64 |
| 151 | + assert limits.max_keepalive_connections == 16 |
| 152 | + assert captured["timeout"] is None |
| 153 | + assert captured["verify"] is True |
0 commit comments