Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions truss-train/truss_train/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def model_dump(self, *args, **kwargs):

class TrainingProject(custom_types.SafeModelNoExtra):
name: str
team_name: Optional[str] = None
# TrainingProject is the wrapper around project config and job config. However, we exclude job
# in serialization so just TrainingProject metadata is included in API requests.
job: TrainingJob = pydantic.Field(exclude=True)
Expand Down
9 changes: 8 additions & 1 deletion truss-train/truss_train/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ def prepare_push(api: BasetenApi, config: pathlib.Path, training_job: TrainingJo


def create_training_job(
remote_provider: BasetenRemote, training_project: TrainingProject, config: Path
remote_provider: BasetenRemote,
training_project: TrainingProject,
config: Path,
team_name: Optional[str] = None,
) -> dict:
if team_name:
training_project.team_name = team_name
project_resp = remote_provider.api.upsert_training_project(
training_project=training_project
)
Expand All @@ -70,6 +75,7 @@ def create_training_job_from_file(
remote_provider: BasetenRemote,
config: Path,
job_name_from_cli: Optional[str] = None,
team_name: Optional[str] = None,
) -> dict:
with loader.import_training_project(config) as training_project:
if job_name_from_cli:
Expand All @@ -82,6 +88,7 @@ def create_training_job_from_file(
remote_provider=remote_provider,
training_project=training_project,
config=config,
team_name=team_name,
)
job_resp["job_object"] = training_project.job
return job_resp
19 changes: 19 additions & 0 deletions truss/cli/remote_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional

from InquirerPy import inquirer
from InquirerPy.validator import ValidationError, Validator

from truss.cli.utils.output import console
from truss.remote.baseten.remote import BasetenRemote
from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
from truss.remote.truss_remote import RemoteConfig

Expand Down Expand Up @@ -56,3 +59,19 @@ def inquire_remote_name() -> str:

def inquire_model_name() -> str:
return inquirer.text("📦 Name this model:", qmark="").execute()


def inquire_team(remote_provider: BasetenRemote) -> Optional[str]:
"""
Inquire for team selection if multiple teams are available.
Returns team name if selected, None otherwise.
"""
teams = remote_provider.api.get_teams()
if len(teams) > 1:
team_names = [team["name"] for team in teams]
selected_team_name = inquirer.select(
"👥 Which team do you want to use?", qmark="", choices=team_names
).execute()
return selected_team_name
# If 0 or 1 teams, return None (don't propagate team param)
return None
23 changes: 21 additions & 2 deletions truss/cli/train_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,16 @@ def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
@click.option("--remote", type=str, required=False, help="Remote to use")
@click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
@click.option("--job-name", type=str, required=False, help="Name of the training job.")
@click.option(
"--team", type=str, required=False, help="Team name for the training project"
)
@common.common_options()
def push_training_job(
config: Path, remote: Optional[str], tail: bool, job_name: Optional[str]
config: Path,
remote: Optional[str],
tail: bool,
job_name: Optional[str],
team: Optional[str],
):
"""Run a training job"""
from truss_train import deployment
Expand All @@ -130,8 +137,20 @@ def push_training_job(
remote_provider: BasetenRemote = cast(
BasetenRemote, RemoteFactory.create(remote=remote)
)
# If team not provided, inquire for team selection
if team is None:
team = remote_cli.inquire_team(remote_provider)
# Validate team exists if provided
elif team is not None:
teams = remote_provider.api.get_teams()
team_names = [t["name"] for t in teams]
if team not in team_names:
available_teams_str = ", ".join(team_names) if team_names else "none"
raise click.ClickException(
f"Team '{team}' does not exist. Available teams: {available_teams_str}"
)
job_resp = deployment.create_training_job_from_file(
remote_provider, config, job_name
remote_provider, config, job_name, team_name=team
)

# Note: This post create logic needs to happen outside the context
Expand Down
20 changes: 19 additions & 1 deletion truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def create_api_key(self, api_key_type: APIKeyCategory, name: str) -> Any:
def upsert_training_project(self, training_project):
resp_json = self._rest_api_client.post(
"v1/training_projects",
body={"training_project": training_project.model_dump()},
body={"training_project": training_project.model_dump(exclude_none=True)},
)
return resp_json["training_project"]

Expand Down Expand Up @@ -912,3 +912,21 @@ def get_instance_types(self) -> List[InstanceTypeV1]:
return [
InstanceTypeV1(**instance_type) for instance_type in instance_types_data
]

def get_teams(self) -> List[Dict[str, str]]:
"""
Get all available teams via GraphQL API.
Returns a list of dictionaries with 'id' and 'name' keys.
"""
query_string = """
query Teams {
teams {
id
name
}
}
"""

resp = self._post_graphql_query(query_string)
teams_data = resp["data"]["teams"]
return teams_data
118 changes: 118 additions & 0 deletions truss/tests/cli/train/test_train_team_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Tests for team parameter in training project creation."""

from pathlib import Path
from unittest.mock import Mock, patch

from click.testing import CliRunner
from requests import Response

from truss.cli.cli import truss_cli
from truss.remote.baseten.remote import BasetenRemote


def mock_upsert_training_project_response():
"""Create a mock response for upsert_training_project."""
response = Response()
response.status_code = 200
response.json = Mock(
return_value={"training_project": {"id": "12345", "name": "training-project"}}
)
return response


class TestTeamParameter:
"""Test team parameter in training project creation."""

@patch("truss_train.deployment.create_training_job_from_file")
@patch("truss.cli.train_commands.RemoteFactory.create")
@patch("truss.cli.train_commands.console.status")
@patch("truss.cli.train_commands._handle_post_create_logic")
def test_team_provided_propagated_to_backend(
self, mock_post_create, mock_status, mock_remote_factory, mock_create_job
):
"""Test that --team parameter is propagated to backend request."""
mock_status.return_value.__enter__ = Mock(return_value=None)
mock_status.return_value.__exit__ = Mock(return_value=None)

mock_remote = Mock(spec=BasetenRemote)
mock_api = Mock()
mock_remote.api = mock_api
mock_api.get_teams.return_value = [{"id": "team1", "name": "Team Alpha"}]
mock_remote_factory.return_value = mock_remote

mock_create_job.return_value = {
"id": "job123",
"training_project": {"id": "12345", "name": "test-project"},
}

runner = CliRunner()
config_path = Path("/tmp/test_config.py")
# Create a dummy config file for the test
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text("# dummy config")

result = runner.invoke(
truss_cli,
[
"train",
"push",
str(config_path),
"--remote",
"test_remote",
"--team",
"Team Alpha",
],
)

assert result.exit_code == 0
mock_create_job.assert_called_once()
call_args = mock_create_job.call_args
assert call_args[1]["team_name"] == "Team Alpha"

@patch("truss_train.deployment.create_training_job_from_file")
@patch("truss.cli.train_commands.RemoteFactory.create")
@patch("truss.cli.remote_cli.inquire_team")
@patch("truss.cli.train_commands.console.status")
@patch("truss.cli.train_commands._handle_post_create_logic")
def test_team_not_provided_inquire_team_called(
self,
mock_post_create,
mock_status,
mock_inquire_team,
mock_remote_factory,
mock_create_job,
):
"""Test that inquire_team is called when --team is not provided."""
mock_status.return_value.__enter__ = Mock(return_value=None)
mock_status.return_value.__exit__ = Mock(return_value=None)

mock_remote = Mock(spec=BasetenRemote)
mock_api = Mock()
mock_remote.api = mock_api
mock_api.get_teams.return_value = [
{"id": "team1", "name": "Team Alpha"},
{"id": "team2", "name": "Team Beta"},
]
mock_remote_factory.return_value = mock_remote

mock_inquire_team.return_value = "Team Beta"
mock_create_job.return_value = {
"id": "job123",
"training_project": {"id": "12345", "name": "test-project"},
}

runner = CliRunner()
config_path = Path("/tmp/test_config.py")
# Create a dummy config file for the test
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text("# dummy config")

result = runner.invoke(
truss_cli, ["train", "push", str(config_path), "--remote", "test_remote"]
)

assert result.exit_code == 0
mock_inquire_team.assert_called_once_with(mock_remote)
mock_create_job.assert_called_once()
call_args = mock_create_job.call_args
assert call_args[1]["team_name"] == "Team Beta"
Loading