@@ -353,8 +353,8 @@ def test_deploy_chain_deployment(mock_post, baseten_api):
353353
354354 gql_mutation = mock_post .call_args [1 ]["json" ]["query" ]
355355
356- assert ' environment: " production"' in gql_mutation
357- assert ' chain_id: " chain_id"' in gql_mutation
356+ assert " environment: production" in gql_mutation
357+ assert " chain_id: chain_id" in gql_mutation
358358 assert "dependencies:" in gql_mutation
359359 assert "entrypoint:" in gql_mutation
360360
@@ -378,8 +378,8 @@ def test_deploy_chain_deployment_with_gitinfo(mock_post, baseten_api):
378378
379379 gql_mutation = mock_post .call_args [1 ]["json" ]["query" ]
380380
381- assert ' environment: " production"' in gql_mutation
382- assert ' chain_id: " chain_id"' in gql_mutation
381+ assert " environment: production" in gql_mutation
382+ assert " chain_id: chain_id" in gql_mutation
383383 assert "dependencies:" in gql_mutation
384384 assert "entrypoint:" in gql_mutation
385385
@@ -402,12 +402,70 @@ def test_deploy_chain_deployment_no_environment(mock_post, baseten_api):
402402
403403 gql_mutation = mock_post .call_args [1 ]["json" ]["query" ]
404404
405- assert ' chain_id: " chain_id"' in gql_mutation
405+ assert " chain_id: chain_id" in gql_mutation
406406 assert "environment" not in gql_mutation
407407 assert "dependencies:" in gql_mutation
408408 assert "entrypoint:" in gql_mutation
409409
410410
411+ @mock .patch ("requests.post" , return_value = mock_deploy_chain_deployment_response ())
412+ def test_deploy_chain_deployment_with_dependencies (mock_post , baseten_api ):
413+ dependencies = [
414+ ChainletDataAtomic (
415+ name = "dependency-1" ,
416+ oracle = OracleData (
417+ model_name = "dep-model-1" ,
418+ s3_key = "dep-s3-key-1" ,
419+ encoded_config_str = "dep-encoded-config-str-1" ,
420+ ),
421+ ),
422+ ChainletDataAtomic (
423+ name = "dependency-2" ,
424+ oracle = OracleData (
425+ model_name = "dep-model-2" ,
426+ s3_key = "dep-s3-key-2" ,
427+ encoded_config_str = "dep-encoded-config-str-2" ,
428+ ),
429+ ),
430+ ]
431+
432+ baseten_api .deploy_chain_atomic (
433+ environment = "production" ,
434+ chain_id = "chain_id" ,
435+ dependencies = dependencies ,
436+ entrypoint = ChainletDataAtomic (
437+ name = "chainlet-1" ,
438+ oracle = OracleData (
439+ model_name = "model-1" ,
440+ s3_key = "s3-key-1" ,
441+ encoded_config_str = "encoded-config-str-1" ,
442+ ),
443+ ),
444+ truss_user_env = b10_types .TrussUserEnv .collect (),
445+ )
446+
447+ gql_mutation = mock_post .call_args [1 ]["json" ]["query" ]
448+
449+ # Single regex to check all assertions
450+ import re
451+
452+ pattern = (
453+ r"(?=.*environment: production)"
454+ r"(?=.*chain_id: chain_id)"
455+ r"(?=.*dependencies:)"
456+ r"(?=.*entrypoint:)"
457+ r'(?=.*name: "dependency-1")'
458+ r'(?=.*name: "dependency-2")'
459+ r'(?=.*model_name: "dep-model-1")'
460+ r'(?=.*model_name: "dep-model-2")'
461+ r'(?=.*s3_key: "dep-s3-key-1")'
462+ r'(?=.*s3_key: "dep-s3-key-2")'
463+ )
464+ assert re .search (pattern , gql_mutation ), (
465+ f"GraphQL mutation does not contain all expected elements: { gql_mutation } "
466+ )
467+
468+
411469@mock .patch ("requests.post" , return_value = mock_upsert_training_project_response ())
412470def test_upsert_training_project (mock_post , baseten_api ):
413471 baseten_api .upsert_training_project (
0 commit comments