Skip to content

Commit a4df15a

Browse files
committed
Use graphql query builder
1 parent a7f1429 commit a4df15a

File tree

3 files changed

+95
-32
lines changed

3 files changed

+95
-32
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ dependencies = [
4444
"tenacity>=8.0.1",
4545
"watchfiles>=0.19.0,<0.20",
4646
"truss-transfer>=0.0.37,<0.0.40",
47+
"gql-query-builder (>=0.1.7,<0.2.0)",
4748
]
4849

4950
[project.urls]

truss/remote/baseten/api.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
import logging
23
from enum import Enum
34
from typing import Any, Dict, List, Mapping, Optional
45

56
import requests
7+
from gql_query_builder import GqlQuery
68
from pydantic import BaseModel, Field
79

810
from truss.remote.baseten import custom_types as b10_types
@@ -299,37 +301,39 @@ def deploy_chain_atomic(
299301
chain_name: Optional[str] = None,
300302
environment: Optional[str] = None,
301303
is_draft: bool = False,
304+
original_source_artifact_s3_key: Optional[str] = None,
305+
allow_truss_download: Optional[bool] = True,
302306
):
303-
entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint)
304-
305-
dependencies_str = ", ".join(
306-
[
307+
if allow_truss_download is None:
308+
allow_truss_download = True
309+
310+
# Build the mutation parameters
311+
mutation_params = {
312+
"chain_id": chain_id,
313+
"chain_name": chain_name,
314+
"environment": environment,
315+
"original_source_artifact_s3_key": original_source_artifact_s3_key,
316+
"allow_truss_download": "false" if allow_truss_download is False else None,
317+
"is_draft": is_draft,
318+
"entrypoint": _chainlet_data_atomic_to_graphql_mutation(entrypoint),
319+
"dependencies": [
307320
_chainlet_data_atomic_to_graphql_mutation(dependency)
308321
for dependency in dependencies
309-
]
310-
)
322+
],
323+
"truss_user_env": "$trussUserEnv",
324+
}
325+
mutation_params = {
326+
str(k): v for k, v in mutation_params.items() if v is not None
327+
}
311328

312-
query_string = f"""
313-
mutation ($trussUserEnv: String) {{
314-
deploy_chain_atomic(
315-
{f'chain_id: "{chain_id}"' if chain_id else ""}
316-
{f'chain_name: "{chain_name}"' if chain_name else ""}
317-
{f'environment: "{environment}"' if environment else ""}
318-
is_draft: {str(is_draft).lower()}
319-
entrypoint: {entrypoint_str}
320-
dependencies: [{dependencies_str}]
321-
truss_user_env: $trussUserEnv
322-
) {{
323-
chain_deployment {{
324-
id
325-
chain {{
326-
id
327-
hostname
328-
}}
329-
}}
330-
}}
331-
}}
332-
"""
329+
gql = GqlQuery()
330+
gql.operation(
331+
"mutation",
332+
"deploy_chain_atomic",
333+
mutation_params,
334+
["chain_deployment { id chain { id hostname } }"],
335+
)
336+
query_string = gql.generate()
333337

334338
resp = self._post_graphql_query(
335339
query_string, variables={"trussUserEnv": truss_user_env.json()}

truss/tests/remote/baseten/test_api.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
353353

354354
gql_mutation = mock_post.call_args[1]["json"]["query"]
355355

356-
assert 'environment: "production"' in gql_mutation
357-
assert 'chain_id: "chain_id"' in gql_mutation
356+
assert "environment: production" in gql_mutation
357+
assert "chain_id: chain_id" in gql_mutation
358358
assert "dependencies:" in gql_mutation
359359
assert "entrypoint:" in gql_mutation
360360

@@ -378,8 +378,8 @@ def test_deploy_chain_deployment_with_gitinfo(mock_post, baseten_api):
378378

379379
gql_mutation = mock_post.call_args[1]["json"]["query"]
380380

381-
assert 'environment: "production"' in gql_mutation
382-
assert 'chain_id: "chain_id"' in gql_mutation
381+
assert "environment: production" in gql_mutation
382+
assert "chain_id: chain_id" in gql_mutation
383383
assert "dependencies:" in gql_mutation
384384
assert "entrypoint:" in gql_mutation
385385

@@ -402,12 +402,70 @@ def test_deploy_chain_deployment_no_environment(mock_post, baseten_api):
402402

403403
gql_mutation = mock_post.call_args[1]["json"]["query"]
404404

405-
assert 'chain_id: "chain_id"' in gql_mutation
405+
assert "chain_id: chain_id" in gql_mutation
406406
assert "environment" not in gql_mutation
407407
assert "dependencies:" in gql_mutation
408408
assert "entrypoint:" in gql_mutation
409409

410410

411+
@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
412+
def test_deploy_chain_deployment_with_dependencies(mock_post, baseten_api):
413+
dependencies = [
414+
ChainletDataAtomic(
415+
name="dependency-1",
416+
oracle=OracleData(
417+
model_name="dep-model-1",
418+
s3_key="dep-s3-key-1",
419+
encoded_config_str="dep-encoded-config-str-1",
420+
),
421+
),
422+
ChainletDataAtomic(
423+
name="dependency-2",
424+
oracle=OracleData(
425+
model_name="dep-model-2",
426+
s3_key="dep-s3-key-2",
427+
encoded_config_str="dep-encoded-config-str-2",
428+
),
429+
),
430+
]
431+
432+
baseten_api.deploy_chain_atomic(
433+
environment="production",
434+
chain_id="chain_id",
435+
dependencies=dependencies,
436+
entrypoint=ChainletDataAtomic(
437+
name="chainlet-1",
438+
oracle=OracleData(
439+
model_name="model-1",
440+
s3_key="s3-key-1",
441+
encoded_config_str="encoded-config-str-1",
442+
),
443+
),
444+
truss_user_env=b10_types.TrussUserEnv.collect(),
445+
)
446+
447+
gql_mutation = mock_post.call_args[1]["json"]["query"]
448+
449+
# Single regex to check all assertions
450+
import re
451+
452+
pattern = (
453+
r"(?=.*environment: production)"
454+
r"(?=.*chain_id: chain_id)"
455+
r"(?=.*dependencies:)"
456+
r"(?=.*entrypoint:)"
457+
r'(?=.*name: "dependency-1")'
458+
r'(?=.*name: "dependency-2")'
459+
r'(?=.*model_name: "dep-model-1")'
460+
r'(?=.*model_name: "dep-model-2")'
461+
r'(?=.*s3_key: "dep-s3-key-1")'
462+
r'(?=.*s3_key: "dep-s3-key-2")'
463+
)
464+
assert re.search(pattern, gql_mutation), (
465+
f"GraphQL mutation does not contain all expected elements: {gql_mutation}"
466+
)
467+
468+
411469
@mock.patch("requests.post", return_value=mock_upsert_training_project_response())
412470
def test_upsert_training_project(mock_post, baseten_api):
413471
baseten_api.upsert_training_project(

0 commit comments

Comments
 (0)