-
Notifications
You must be signed in to change notification settings - Fork 91
Validate config.yaml fields depending on transport kind, fix model wrapper test #1788
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -542,10 +542,10 @@ def _validate_path(cls, v: str) -> str: | |
|
|
||
| class DockerServer(custom_types.ConfigModel): | ||
| start_command: str | ||
| server_port: int | ||
| predict_endpoint: str | ||
| readiness_endpoint: str | ||
| liveness_endpoint: str | ||
| server_port: Optional[int] = None | ||
| predict_endpoint: Optional[str] = None | ||
| readiness_endpoint: Optional[str] = None | ||
| liveness_endpoint: Optional[str] = None | ||
|
|
||
|
|
||
| class Checkpoint(custom_types.ConfigModel): | ||
|
|
@@ -616,6 +616,13 @@ class TrussConfig(custom_types.ConfigModel): | |
| apply_library_patches: bool = True | ||
| spec_version: str = "2.0" | ||
|
|
||
| DOCKER_SERVER_OPTIONAL_FIELDS: ClassVar[list[str]] = [ | ||
| "server_port", | ||
| "predict_endpoint", | ||
| "readiness_endpoint", | ||
| "liveness_endpoint", | ||
| ] | ||
|
|
||
| class Config: | ||
| protected_namespaces = () # Silence warnings about fields starting with `model_`. | ||
|
|
||
|
|
@@ -720,6 +727,34 @@ def _serialize_trt_llm( | |
| exclude_unset = bool(info.context and "verbose" in info.context) | ||
| return trt_llm.model_dump(exclude_unset=exclude_unset) | ||
|
|
||
| @pydantic.model_validator(mode="after") | ||
| def _validate_docker_server(self) -> "TrussConfig": | ||
| is_grpc = self.runtime.transport.kind == TransportKind.GRPC | ||
| has_docker_server = self.docker_server is not None | ||
|
|
||
| if is_grpc: | ||
| if not has_docker_server: | ||
| raise ValueError( | ||
| "docker_server is required when transport kind is gRPC" | ||
| ) | ||
| if any( | ||
| getattr(self.docker_server, field) is not None | ||
|
||
| for field in TrussConfig.DOCKER_SERVER_OPTIONAL_FIELDS | ||
| ): | ||
| raise ValueError( | ||
| "When transport kind is gRPC, docker_server should only have start_command defined" | ||
| ) | ||
| elif has_docker_server: | ||
| if any( | ||
| getattr(self.docker_server, field) is None | ||
| for field in TrussConfig.DOCKER_SERVER_OPTIONAL_FIELDS | ||
| ): | ||
| raise ValueError( | ||
saptarshi-baseten marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| "Please define server_port, predict_endpoint, readiness_endpoint, and liveness_endpoint for docker_server" | ||
| ) | ||
|
|
||
| return self | ||
|
|
||
|
|
||
| def _map_to_supported_python_version(python_version: str) -> str: | ||
| """Map python version to truss supported python version. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -977,3 +977,90 @@ def test_supported_versions_are_sorted(): | |
| assert semvers == semvers_sorted, ( | ||
| f"{constants.SUPPORTED_PYTHON_VERSIONS} must be sorted ascendingly" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "transport_config", | ||
| [ | ||
| pytest.param( | ||
| { | ||
| "runtime": {"transport": {"kind": "grpc"}}, | ||
| "docker_server": {"start_command": "python main.py"}, | ||
| }, | ||
| id="valid-grpc-minimal", | ||
| ), | ||
| pytest.param( | ||
| { | ||
| "runtime": {"transport": {"kind": "http"}}, | ||
| "docker_server": { | ||
| "start_command": "./start.sh", | ||
| "server_port": 8080, | ||
| "predict_endpoint": "/predict", | ||
| "readiness_endpoint": "/ready", | ||
| "liveness_endpoint": "/health", | ||
| }, | ||
| }, | ||
| id="valid-http-full", | ||
| ), | ||
| pytest.param( | ||
| { | ||
| "runtime": {"transport": {"kind": "websocket"}}, | ||
| "docker_server": { | ||
| "start_command": "./start.sh", | ||
| "server_port": 8080, | ||
| "predict_endpoint": "/predict", | ||
| "readiness_endpoint": "/ready", | ||
| "liveness_endpoint": "/health", | ||
| }, | ||
| }, | ||
| id="valid-websocket-full", | ||
| ), | ||
| pytest.param( | ||
| {"runtime": {"transport": {"kind": "http"}}}, id="valid-http-no-docker" | ||
| ), | ||
| ], | ||
| ) | ||
| def test_valid_transport_configurations(transport_config, tmp_path): | ||
| config_path = tmp_path / "config.yaml" | ||
| config_path.write_text(yaml.dump(transport_config)) | ||
| config = TrussConfig.from_yaml(config_path) | ||
| assert config.runtime.transport.kind == TransportKind( | ||
| transport_config["runtime"]["transport"]["kind"] | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "invalid_config,expected_error", | ||
| [ | ||
| pytest.param( | ||
| { | ||
| "runtime": {"transport": {"kind": "grpc"}}, | ||
| "docker_server": { | ||
| "start_command": "./start.sh", | ||
| "server_port": 8080, | ||
| "predict_endpoint": "/predict", | ||
| }, | ||
| }, | ||
| "When transport kind is gRPC, docker_server should only have start_command defined", | ||
| id="invalid-grpc-extra-fields", | ||
| ), | ||
| pytest.param( | ||
| { | ||
| "runtime": {"transport": {"kind": "http"}}, | ||
| "docker_server": {"start_command": "./start.sh", "server_port": 8080}, | ||
| }, | ||
| "Please define server_port, predict_endpoint, readiness_endpoint, and liveness_endpoint for docker_server", | ||
| id="invalid-http-missing-fields", | ||
| ), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be good to add a case for websockets |
||
| pytest.param( | ||
| {"runtime": {"transport": {"kind": "grpc"}}}, | ||
| "docker_server is required when transport kind is gRPC", | ||
| id="invalid-grpc-missing-docker", | ||
| ), | ||
| ], | ||
| ) | ||
| def test_invalid_transport_configurations(invalid_config, expected_error, tmp_path): | ||
| config_path = tmp_path / "config.yaml" | ||
| config_path.write_text(yaml.dump(invalid_config)) | ||
| with pytest.raises(ValueError, match=expected_error): | ||
| TrussConfig.from_yaml(config_path) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are only optional in the case of grpc right? can we rename this to be clearer