From fdb586e9ba9d31b219493f0677a8de38b38e1a85 Mon Sep 17 00:00:00 2001 From: Archit Mishra Date: Wed, 12 Nov 2025 11:45:38 -0800 Subject: [PATCH] feat(truss): Add teams param to training job creation --- truss-train/truss_train/definitions.py | 1 + truss-train/truss_train/deployment.py | 9 +- truss/cli/remote_cli.py | 19 +++ truss/cli/train_commands.py | 23 +++- truss/remote/baseten/api.py | 20 ++- .../cli/train/test_train_team_parameter.py | 118 ++++++++++++++++++ 6 files changed, 186 insertions(+), 4 deletions(-) create mode 100644 truss/tests/cli/train/test_train_team_parameter.py diff --git a/truss-train/truss_train/definitions.py b/truss-train/truss_train/definitions.py index a5ea6bdd8..33f43edb4 100644 --- a/truss-train/truss_train/definitions.py +++ b/truss-train/truss_train/definitions.py @@ -179,6 +179,7 @@ def model_dump(self, *args, **kwargs): class TrainingProject(custom_types.SafeModelNoExtra): name: str + team_name: Optional[str] = None # TrainingProject is the wrapper around project config and job config. However, we exclude job # in serialization so just TrainingProject metadata is included in API requests. job: TrainingJob = pydantic.Field(exclude=True) diff --git a/truss-train/truss_train/deployment.py b/truss-train/truss_train/deployment.py index 909872d46..cf73e3ff5 100644 --- a/truss-train/truss_train/deployment.py +++ b/truss-train/truss_train/deployment.py @@ -53,8 +53,13 @@ def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJo def create_training_job( - remote_provider: BasetenRemote, training_project: TrainingProject, config: Path + remote_provider: BasetenRemote, + training_project: TrainingProject, + config: Path, + team_name: Optional[str] = None, ) -> dict: + if team_name: + training_project.team_name = team_name project_resp = remote_provider.api.upsert_training_project( training_project=training_project ) @@ -70,6 +75,7 @@ def create_training_job_from_file( remote_provider: BasetenRemote, config: Path, job_name_from_cli: Optional[str] = None, + team_name: Optional[str] = None, ) -> dict: with loader.import_training_project(config) as training_project: if job_name_from_cli: @@ -82,6 +88,7 @@ def create_training_job_from_file( remote_provider=remote_provider, training_project=training_project, config=config, + team_name=team_name, ) job_resp["job_object"] = training_project.job return job_resp diff --git a/truss/cli/remote_cli.py b/truss/cli/remote_cli.py index 2d535d095..a99f61f10 100644 --- a/truss/cli/remote_cli.py +++ b/truss/cli/remote_cli.py @@ -1,7 +1,10 @@ +from typing import Optional + from InquirerPy import inquirer from InquirerPy.validator import ValidationError, Validator from truss.cli.utils.output import console +from truss.remote.baseten.remote import BasetenRemote from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory from truss.remote.truss_remote import RemoteConfig @@ -56,3 +59,19 @@ def inquire_remote_name() -> str: def inquire_model_name() -> str: return inquirer.text("📦 Name this model:", qmark="").execute() + + +def inquire_team(remote_provider: BasetenRemote) -> Optional[str]: + """ + Inquire for team selection if multiple teams are available. + Returns team name if selected, None otherwise. + """ + teams = remote_provider.api.get_teams() + if len(teams) > 1: + team_names = [team["name"] for team in teams] + selected_team_name = inquirer.select( + "👥 Which team do you want to use?", qmark="", choices=team_names + ).execute() + return selected_team_name + # If 0 or 1 teams, return None (don't propagate team param) + return None diff --git a/truss/cli/train_commands.py b/truss/cli/train_commands.py index c958253d3..ab0649435 100644 --- a/truss/cli/train_commands.py +++ b/truss/cli/train_commands.py @@ -116,9 +116,16 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context: @click.option("--remote", type=str, required=False, help="Remote to use") @click.option("--tail", is_flag=True, help="Tail for status + logs after push.") @click.option("--job-name", type=str, required=False, help="Name of the training job.") +@click.option( + "--team", type=str, required=False, help="Team name for the training project" +) @common.common_options() def push_training_job( - config: Path, remote: Optional[str], tail: bool, job_name: Optional[str] + config: Path, + remote: Optional[str], + tail: bool, + job_name: Optional[str], + team: Optional[str], ): """Run a training job""" from truss_train import deployment @@ -130,8 +137,20 @@ def push_training_job( remote_provider: BasetenRemote = cast( BasetenRemote, RemoteFactory.create(remote=remote) ) + # If team not provided, inquire for team selection + if team is None: + team = remote_cli.inquire_team(remote_provider) + # Validate team exists if provided + elif team is not None: + teams = remote_provider.api.get_teams() + team_names = [t["name"] for t in teams] + if team not in team_names: + available_teams_str = ", ".join(team_names) if team_names else "none" + raise click.ClickException( + f"Team '{team}' does not exist. Available teams: {available_teams_str}" + ) job_resp = deployment.create_training_job_from_file( - remote_provider, config, job_name + remote_provider, config, job_name, team_name=team ) # Note: This post create logic needs to happen outside the context diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 9f8b95c4c..bb3b0a9af 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -659,7 +659,7 @@ def create_api_key(self, api_key_type: APIKeyCategory, name: str) -> Any: def upsert_training_project(self, training_project): resp_json = self._rest_api_client.post( "v1/training_projects", - body={"training_project": training_project.model_dump()}, + body={"training_project": training_project.model_dump(exclude_none=True)}, ) return resp_json["training_project"] @@ -912,3 +912,21 @@ def get_instance_types(self) -> List[InstanceTypeV1]: return [ InstanceTypeV1(**instance_type) for instance_type in instance_types_data ] + + def get_teams(self) -> List[Dict[str, str]]: + """ + Get all available teams via GraphQL API. + Returns a list of dictionaries with 'id' and 'name' keys. + """ + query_string = """ + query Teams { + teams { + id + name + } + } + """ + + resp = self._post_graphql_query(query_string) + teams_data = resp["data"]["teams"] + return teams_data diff --git a/truss/tests/cli/train/test_train_team_parameter.py b/truss/tests/cli/train/test_train_team_parameter.py new file mode 100644 index 000000000..76a576811 --- /dev/null +++ b/truss/tests/cli/train/test_train_team_parameter.py @@ -0,0 +1,118 @@ +"""Tests for team parameter in training project creation.""" + +from pathlib import Path +from unittest.mock import Mock, patch + +from click.testing import CliRunner +from requests import Response + +from truss.cli.cli import truss_cli +from truss.remote.baseten.remote import BasetenRemote + + +def mock_upsert_training_project_response(): + """Create a mock response for upsert_training_project.""" + response = Response() + response.status_code = 200 + response.json = Mock( + return_value={"training_project": {"id": "12345", "name": "training-project"}} + ) + return response + + +class TestTeamParameter: + """Test team parameter in training project creation.""" + + @patch("truss_train.deployment.create_training_job_from_file") + @patch("truss.cli.train_commands.RemoteFactory.create") + @patch("truss.cli.train_commands.console.status") + @patch("truss.cli.train_commands._handle_post_create_logic") + def test_team_provided_propagated_to_backend( + self, mock_post_create, mock_status, mock_remote_factory, mock_create_job + ): + """Test that --team parameter is propagated to backend request.""" + mock_status.return_value.__enter__ = Mock(return_value=None) + mock_status.return_value.__exit__ = Mock(return_value=None) + + mock_remote = Mock(spec=BasetenRemote) + mock_api = Mock() + mock_remote.api = mock_api + mock_api.get_teams.return_value = [{"id": "team1", "name": "Team Alpha"}] + mock_remote_factory.return_value = mock_remote + + mock_create_job.return_value = { + "id": "job123", + "training_project": {"id": "12345", "name": "test-project"}, + } + + runner = CliRunner() + config_path = Path("/tmp/test_config.py") + # Create a dummy config file for the test + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text("# dummy config") + + result = runner.invoke( + truss_cli, + [ + "train", + "push", + str(config_path), + "--remote", + "test_remote", + "--team", + "Team Alpha", + ], + ) + + assert result.exit_code == 0 + mock_create_job.assert_called_once() + call_args = mock_create_job.call_args + assert call_args[1]["team_name"] == "Team Alpha" + + @patch("truss_train.deployment.create_training_job_from_file") + @patch("truss.cli.train_commands.RemoteFactory.create") + @patch("truss.cli.remote_cli.inquire_team") + @patch("truss.cli.train_commands.console.status") + @patch("truss.cli.train_commands._handle_post_create_logic") + def test_team_not_provided_inquire_team_called( + self, + mock_post_create, + mock_status, + mock_inquire_team, + mock_remote_factory, + mock_create_job, + ): + """Test that inquire_team is called when --team is not provided.""" + mock_status.return_value.__enter__ = Mock(return_value=None) + mock_status.return_value.__exit__ = Mock(return_value=None) + + mock_remote = Mock(spec=BasetenRemote) + mock_api = Mock() + mock_remote.api = mock_api + mock_api.get_teams.return_value = [ + {"id": "team1", "name": "Team Alpha"}, + {"id": "team2", "name": "Team Beta"}, + ] + mock_remote_factory.return_value = mock_remote + + mock_inquire_team.return_value = "Team Beta" + mock_create_job.return_value = { + "id": "job123", + "training_project": {"id": "12345", "name": "test-project"}, + } + + runner = CliRunner() + config_path = Path("/tmp/test_config.py") + # Create a dummy config file for the test + config_path.parent.mkdir(parents=True, exist_ok=True) + config_path.write_text("# dummy config") + + result = runner.invoke( + truss_cli, ["train", "push", str(config_path), "--remote", "test_remote"] + ) + + assert result.exit_code == 0 + mock_inquire_team.assert_called_once_with(mock_remote) + mock_create_job.assert_called_once() + call_args = mock_create_job.call_args + assert call_args[1]["team_name"] == "Team Beta"