diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index a3ada735e6..8b7c125ea6 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -57,6 +57,7 @@ get_repo_creds_and_default_branch, ) from dstack._internal.core.services.ssh.ports import PortUsedError +from dstack._internal.settings import FeatureFlags from dstack._internal.utils.common import local_time from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator from dstack._internal.utils.logging import get_logger @@ -215,6 +216,7 @@ def apply_configuration( current_job_submission = run._run.latest_job_submission if run.status in (RunStatus.RUNNING, RunStatus.DONE): _print_service_urls(run) + _print_dev_environment_connection_info(run) bind_address: Optional[str] = getattr( configurator_args, _BIND_ADDRESS_ARG, None ) @@ -806,6 +808,30 @@ def _print_service_urls(run: Run) -> None: console.print() +def _print_dev_environment_connection_info(run: Run) -> None: + if not FeatureFlags.CLI_PRINT_JOB_CONNECTION_INFO: + return + if run._run.run_spec.configuration.type != RunConfigurationType.DEV_ENVIRONMENT.value: + return + jci = run._run.jobs[0].job_connection_info + if jci is None: + return + if jci.ide_name: + urls = [u for u in (jci.attached_ide_url, jci.proxied_ide_url) if u] + if urls: + console.print( + f"To open in {jci.ide_name}, use link{'s' if len(urls) > 1 else ''} below:\n" + ) + for link in urls: + console.print(f" [link={link}]{link}[/]\n") + ssh_commands = [" ".join(c) for c in (jci.attached_ssh_command, jci.proxied_ssh_command) if c] + if ssh_commands: + console.print( + f"To connect via SSH, use: {' or '.join(f'[code]{c}[/]' for c in ssh_commands)}\n" + ) + console.print() + + def print_finished_message(run: Run): status_message = ( run._run.latest_job_submission.status_message diff --git 
a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 2b0e3a6b41..1d6b734f54 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -46,6 +46,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeD ] ), }, + # Contains only informational computed fields, safe to exclude unconditionally + "job_connection_info": True, } } if current_resource.latest_job_submission is not None: diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index fdb7b58cd2..0c48664a17 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -431,9 +431,53 @@ def duration(self) -> timedelta: return end_time - self.submitted_at +class JobConnectionInfo(CoreModel): + ide_name: Annotated[ + Optional[str], Field(description="Dev environment IDE name for UI, human-readable.") + ] + attached_ide_url: Annotated[ + Optional[str], + Field( + description=( + "Dev environment IDE URL." + " Not set if the job has not started yet." + " Only works if the user is attached to the run via CLI or Python API." + ) + ), + ] + proxied_ide_url: Annotated[ + Optional[str], + Field( + description=( + "Dev environment IDE URL." + " Not set if the job has not started yet or sshproxy is not configured." + ) + ), + ] + attached_ssh_command: Annotated[ + Optional[list[str]], + Field( + description=( + "SSH command to connect to the job, list of command line arguments." + " Only works if the user is attached to the run via CLI or Python API." + ) + ), + ] + proxied_ssh_command: Annotated[ + Optional[list[str]], + Field( + description=( + "SSH command to connect to the job, list of command line arguments." + " Not set if sshproxy is not configured." 
+ ) + ), + ] + + class Job(CoreModel): job_spec: JobSpec job_submissions: List[JobSubmission] + job_connection_info: Optional[JobConnectionInfo] = None class RunSpecConfig(CoreConfig): diff --git a/src/dstack/_internal/server/routers/sshproxy.py b/src/dstack/_internal/server/routers/sshproxy.py index 3edc927e96..33c5d4cdcc 100644 --- a/src/dstack/_internal/server/routers/sshproxy.py +++ b/src/dstack/_internal/server/routers/sshproxy.py @@ -1,21 +1,21 @@ -import os from typing import Annotated from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.server import settings from dstack._internal.server.db import get_session from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount -from dstack._internal.server.services.sshproxy import get_upstream_response +from dstack._internal.server.services.sshproxy.handlers import get_upstream_response from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, ) -if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"): - _auth = ServiceAccount(_token) +if settings.SSHPROXY_API_TOKEN is not None: + _auth = ServiceAccount(settings.SSHPROXY_API_TOKEN) else: _auth = AlwaysForbidden() diff --git a/src/dstack/_internal/server/services/ides/__init__.py b/src/dstack/_internal/server/services/ides/__init__.py new file mode 100644 index 0000000000..377fc139c9 --- /dev/null +++ b/src/dstack/_internal/server/services/ides/__init__.py @@ -0,0 +1,21 @@ +from typing import Literal, Optional + +from dstack._internal.server.services.ides.base import IDE +from dstack._internal.server.services.ides.cursor import CursorDesktop +from dstack._internal.server.services.ides.vscode import VSCodeDesktop +from dstack._internal.server.services.ides.windsurf import 
WindsurfDesktop + +_IDELiteral = Literal["vscode", "cursor", "windsurf"] + +_ide_literal_to_ide_class_map: dict[_IDELiteral, type[IDE]] = { + "vscode": VSCodeDesktop, + "cursor": CursorDesktop, + "windsurf": WindsurfDesktop, +} + + +def get_ide(ide_literal: _IDELiteral) -> Optional[IDE]: + ide_class = _ide_literal_to_ide_class_map.get(ide_literal) + if ide_class is None: + return None + return ide_class() diff --git a/src/dstack/_internal/server/services/ides/base.py b/src/dstack/_internal/server/services/ides/base.py new file mode 100644 index 0000000000..f97aad6d91 --- /dev/null +++ b/src/dstack/_internal/server/services/ides/base.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from typing import ClassVar, Optional + + +class IDE(ABC): + name: ClassVar[str] + url_scheme: ClassVar[str] + + @abstractmethod + def get_install_commands( + self, version: Optional[str] = None, extensions: Optional[list[str]] = None + ) -> list[str]: + pass + + def get_url(self, authority: str, working_dir: str) -> str: + return f"{self.url_scheme}://vscode-remote/ssh-remote+{authority}{working_dir}" + + def get_print_readme_commands(self, authority: str) -> list[str]: + url = self.get_url(authority, working_dir="$DSTACK_WORKING_DIR") + return [ + f"echo 'To open in {self.name}, use link below:'", + "echo", + f'echo " {url}"', + "echo", + ] diff --git a/src/dstack/_internal/server/services/ides/cursor.py b/src/dstack/_internal/server/services/ides/cursor.py new file mode 100644 index 0000000000..95512f355e --- /dev/null +++ b/src/dstack/_internal/server/services/ides/cursor.py @@ -0,0 +1,31 @@ +from typing import Optional + +from dstack._internal.server.services.ides.base import IDE + + +class CursorDesktop(IDE): + name = "Cursor" + url_scheme = "cursor" + + def get_install_commands( + self, version: Optional[str] = None, extensions: Optional[list[str]] = None + ) -> list[str]: + commands = [] + if version is not None: + url = 
f"https://cursor.blob.core.windows.net/remote-releases/{version}/vscode-reh-linux-$arch.tar.gz" + archive = "vscode-reh-linux-$arch.tar.gz" + target = f'~/.cursor-server/cli/servers/"Stable-{version}"/server' + commands.extend( + [ + 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', + "mkdir -p /tmp", + f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', + f"mkdir -vp {target}", + f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"', + f'rm "/tmp/{archive}"', + ] + ) + if extensions: + _extensions = " ".join(f'--install-extension "{name}"' for name in extensions) + commands.append(f'PATH="$PATH":{target}/bin cursor-server {_extensions}') + return commands diff --git a/src/dstack/_internal/server/services/ides/vscode.py b/src/dstack/_internal/server/services/ides/vscode.py new file mode 100644 index 0000000000..3ab0b8ab95 --- /dev/null +++ b/src/dstack/_internal/server/services/ides/vscode.py @@ -0,0 +1,33 @@ +from typing import Optional + +from dstack._internal.server.services.ides.base import IDE + + +class VSCodeDesktop(IDE): + name = "VS Code" + url_scheme = "vscode" + + def get_install_commands( + self, version: Optional[str] = None, extensions: Optional[list[str]] = None + ) -> list[str]: + commands = [] + if version is not None: + url = ( + f"https://update.code.visualstudio.com/commit:{version}/server-linux-$arch/stable" + ) + archive = "vscode-server-linux-$arch.tar.gz" + target = f'~/.vscode-server/bin/"{version}"' + commands.extend( + [ + 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', + "mkdir -p /tmp", + f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', + f"mkdir -vp {target}", + f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"', + f'rm "/tmp/{archive}"', + ] + ) + if extensions: + _extensions = " ".join(f'--install-extension "{name}"' for name in extensions) + commands.append(f'PATH="$PATH":{target}/bin code-server {_extensions}') + 
return commands diff --git a/src/dstack/_internal/server/services/ides/windsurf.py b/src/dstack/_internal/server/services/ides/windsurf.py new file mode 100644 index 0000000000..3b5042bb6d --- /dev/null +++ b/src/dstack/_internal/server/services/ides/windsurf.py @@ -0,0 +1,32 @@ +from typing import Optional + +from dstack._internal.server.services.ides.base import IDE + + +class WindsurfDesktop(IDE): + name = "Windsurf" + url_scheme = "windsurf" + + def get_install_commands( + self, version: Optional[str] = None, extensions: Optional[list[str]] = None + ) -> list[str]: + commands = [] + if version is not None: + version, commit = version.split("@") + url = f"https://windsurf-stable.codeiumdata.com/linux-reh-$arch/stable/{commit}/windsurf-reh-linux-$arch-{version}.tar.gz" + archive = "windsurf-reh-linux-$arch.tar.gz" + target = f'~/.windsurf-server/bin/"{commit}"' + commands.extend( + [ + 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', + "mkdir -p /tmp", + f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', + f"mkdir -vp {target}", + f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"', + f'rm "/tmp/{archive}"', + ] + ) + if extensions: + _extensions = " ".join(f'--install-extension "{name}"' for name in extensions) + commands.append(f'PATH="$PATH":{target}/bin windsurf-server {_extensions}') + return commands diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 62254d7659..75d48d7cab 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -19,6 +19,7 @@ from dstack._internal.core.models.configurations import RunConfigurationType from dstack._internal.core.models.runs import ( Job, + JobConnectionInfo, JobProvisioningData, JobRuntimeData, JobSpec, @@ -37,6 +38,7 @@ ) from dstack._internal.server.services import events from dstack._internal.server.services import 
volumes as volumes_services +from dstack._internal.server.services.ides import get_ide from dstack._internal.server.services.instances import ( get_instance_ssh_private_keys, ) @@ -51,9 +53,14 @@ from dstack._internal.server.services.probes import probe_model_to_probe from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.sshproxy import ( + build_proxied_job_ssh_command, + build_proxied_job_ssh_url_authority, +) from dstack._internal.utils import common from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.ssh import build_ssh_command, build_ssh_url_authority logger = get_logger(__name__) @@ -490,6 +497,47 @@ def remove_job_spec_sensitive_info(spec: JobSpec): spec.ssh_key = None +def get_job_connection_info(job_model: JobModel, run_spec: RunSpec) -> JobConnectionInfo: + # Run.attach() Python API method, used internally by CLI, uses the following as the Hostname + # in the SSH config: + # * for the (job=0 replica=0) job - run name, e.g., `my-task` + # * for other jobs - job name, e.g., `my-task-0-1` + attached_hostname = run_spec.run_name + if job_model.job_num != 0 or job_model.replica_num != 0: + attached_hostname = job_model.job_name + assert attached_hostname is not None + + # ide_* fields are for dev-environment only + ide_name: Optional[str] = None + # IDE URLs are not set until the job status is switched to RUNNING, + # as JobRuntimeData.working_dir, which is required to build URLs, is returned + # by dstack-runner's `/api/run` method + attached_ide_url: Optional[str] = None + proxied_ide_url: Optional[str] = None + if ( + run_spec.configuration.type == RunConfigurationType.DEV_ENVIRONMENT.value + and run_spec.configuration.ide is not None + ): + ide = get_ide(run_spec.configuration.ide) + if ide is not None: + ide_name = ide.name + jrd = 
get_job_runtime_data(job_model) + if jrd is not None and jrd.working_dir is not None: + attached_url_authority = build_ssh_url_authority(hostname=attached_hostname) + attached_ide_url = ide.get_url(attached_url_authority, jrd.working_dir) + proxied_url_authority = build_proxied_job_ssh_url_authority(job_model) + if proxied_url_authority is not None: + proxied_ide_url = ide.get_url(proxied_url_authority, jrd.working_dir) + + return JobConnectionInfo( + ide_name=ide_name, + attached_ide_url=attached_ide_url, + proxied_ide_url=proxied_ide_url, + attached_ssh_command=build_ssh_command(hostname=attached_hostname), + proxied_ssh_command=build_proxied_job_ssh_command(job_model), + ) + + def _get_job_mount_point_attached_volume( volumes: List[Volume], job_provisioning_data: JobProvisioningData, diff --git a/src/dstack/_internal/server/services/jobs/configurators/dev.py b/src/dstack/_internal/server/services/jobs/configurators/dev.py index ae059b9e09..8bc82288d5 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/dev.py +++ b/src/dstack/_internal/server/services/jobs/configurators/dev.py @@ -4,10 +4,8 @@ from dstack._internal.core.models.configurations import PortMapping, RunConfigurationType from dstack._internal.core.models.profiles import SpotPolicy from dstack._internal.core.models.runs import RunSpec +from dstack._internal.server.services.ides import get_ide from dstack._internal.server.services.jobs.configurators.base import JobConfigurator -from dstack._internal.server.services.jobs.configurators.extensions.cursor import CursorDesktop -from dstack._internal.server.services.jobs.configurators.extensions.vscode import VSCodeDesktop -from dstack._internal.server.services.jobs.configurators.extensions.windsurf import WindsurfDesktop INSTALL_IPYKERNEL = ( "(echo 'pip install ipykernel...' 
&& pip install -q --no-cache-dir ipykernel 2> /dev/null) || " @@ -18,6 +16,8 @@ class DevEnvironmentJobConfigurator(JobConfigurator): TYPE: RunConfigurationType = RunConfigurationType.DEV_ENVIRONMENT + ide_extensions = ["ms-python.python", "ms-toolsai.jupyter"] + def __init__( self, run_spec: RunSpec, secrets: Dict[str, str], replica_group_name: Optional[str] = None ): @@ -26,19 +26,10 @@ def __init__( if run_spec.configuration.ide is None: self.ide = None else: - if run_spec.configuration.ide == "vscode": - __class = VSCodeDesktop - elif run_spec.configuration.ide == "cursor": - __class = CursorDesktop - elif run_spec.configuration.ide == "windsurf": - __class = WindsurfDesktop - else: + ide = get_ide(run_spec.configuration.ide) + if ide is None: raise ServerClientError(f"Unsupported IDE: {run_spec.configuration.ide}") - self.ide = __class( - run_name=run_spec.run_name, - version=run_spec.configuration.version, - extensions=["ms-python.python", "ms-toolsai.jupyter"], - ) + self.ide = ide super().__init__(run_spec=run_spec, secrets=secrets, replica_group_name=replica_group_name) def _shell_commands(self) -> List[str]: @@ -46,13 +37,16 @@ def _shell_commands(self) -> List[str]: commands = [] if self.ide is not None: - commands += self.ide.get_install_commands() + commands += self.ide.get_install_commands( + version=self.run_spec.configuration.version, extensions=self.ide_extensions + ) commands.append(INSTALL_IPYKERNEL) commands += self.run_spec.configuration.setup commands.append("echo") commands += self.run_spec.configuration.init if self.ide is not None: - commands += self.ide.get_print_readme_commands() + assert self.run_spec.run_name is not None + commands += self.ide.get_print_readme_commands(self.run_spec.run_name) commands += [ f"echo 'To connect via SSH, use: `ssh {self.run_spec.run_name}`'", "echo", diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/__init__.py 
b/src/dstack/_internal/server/services/jobs/configurators/extensions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/base.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/base.py deleted file mode 100644 index 73f30036f7..0000000000 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/base.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Callable, List - -CommandsExtension = Callable[[], List[str]] - - -def get_required_commands(executables: List[str]) -> CommandsExtension: - def wrapper() -> List[str]: - commands = [] - for exe in executables: - commands.append( - f'((command -v {exe} > /dev/null) || (echo "{exe} is required" && exit 1))' - ) - return commands - - return wrapper diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py deleted file mode 100644 index 5ecaa02e9f..0000000000 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/cursor.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import List, Optional - - -class CursorDesktop: - def __init__( - self, - run_name: Optional[str], - version: Optional[str], - extensions: List[str], - ): - self.run_name = run_name - self.version = version - self.extensions = extensions - - def get_install_commands(self) -> List[str]: - commands = [] - if self.version is not None: - url = f"https://cursor.blob.core.windows.net/remote-releases/{self.version}/vscode-reh-linux-$arch.tar.gz" - archive = "vscode-reh-linux-$arch.tar.gz" - target = f'~/.cursor-server/cli/servers/"Stable-{self.version}"/server' - commands.extend( - [ - 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', - "mkdir -p /tmp", - f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', - f"mkdir -vp {target}", - f'tar --no-same-owner -xz --strip-components=1 -C 
{target} -f "/tmp/{archive}"', - f'rm "/tmp/{archive}"', - ] - ) - if self.extensions: - extensions = " ".join(f'--install-extension "{name}"' for name in self.extensions) - commands.append(f'PATH="$PATH":{target}/bin cursor-server {extensions}') - return commands - - def get_print_readme_commands(self) -> List[str]: - return [ - "echo To open in Cursor, use link below:", - "echo", - f'echo " cursor://vscode-remote/ssh-remote+{self.run_name}$DSTACK_WORKING_DIR"', - "echo", - ] diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py deleted file mode 100644 index 87e79fd987..0000000000 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/vscode.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import List, Optional - - -class VSCodeDesktop: - def __init__( - self, - run_name: Optional[str], - version: Optional[str], - extensions: List[str], - ): - self.run_name = run_name - self.version = version - self.extensions = extensions - - def get_install_commands(self) -> List[str]: - commands = [] - if self.version is not None: - url = f"https://update.code.visualstudio.com/commit:{self.version}/server-linux-$arch/stable" - archive = "vscode-server-linux-$arch.tar.gz" - target = f'~/.vscode-server/bin/"{self.version}"' - commands.extend( - [ - 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', - "mkdir -p /tmp", - f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', - f"mkdir -vp {target}", - f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"', - f'rm "/tmp/{archive}"', - ] - ) - if self.extensions: - extensions = " ".join(f'--install-extension "{name}"' for name in self.extensions) - commands.append(f'PATH="$PATH":{target}/bin code-server {extensions}') - return commands - - def get_print_readme_commands(self) -> List[str]: - return [ - "echo 'To open in VS Code Desktop, use link below:'", - 
"echo", - f'echo " vscode://vscode-remote/ssh-remote+{self.run_name}$DSTACK_WORKING_DIR"', - "echo", - ] diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/windsurf.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/windsurf.py deleted file mode 100644 index 63fee839f0..0000000000 --- a/src/dstack/_internal/server/services/jobs/configurators/extensions/windsurf.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import List, Optional - - -class WindsurfDesktop: - def __init__( - self, - run_name: Optional[str], - version: Optional[str], - extensions: List[str], - ): - self.run_name = run_name - self.version = version - self.extensions = extensions - - def get_install_commands(self) -> List[str]: - commands = [] - if self.version is not None: - version, commit = self.version.split("@") - url = f"https://windsurf-stable.codeiumdata.com/linux-reh-$arch/stable/{commit}/windsurf-reh-linux-$arch-{version}.tar.gz" - archive = "windsurf-reh-linux-$arch.tar.gz" - target = f'~/.windsurf-server/bin/"{commit}"' - commands.extend( - [ - 'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi', - "mkdir -p /tmp", - f'wget -q --show-progress "{url}" -O "/tmp/{archive}"', - f"mkdir -vp {target}", - f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"', - f'rm "/tmp/{archive}"', - ] - ) - if self.extensions: - extensions = " ".join(f'--install-extension "{name}"' for name in self.extensions) - commands.append(f'PATH="$PATH":{target}/bin windsurf-server {extensions}') - return commands - - def get_print_readme_commands(self) -> List[str]: - return [ - "echo To open in Windsurf, use link below:", - "echo", - f'echo " windsurf://vscode-remote/ssh-remote+{self.run_name}$DSTACK_WORKING_DIR"', - "echo", - ] diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index 3fb9468d3a..86dfb0bac0 100644 --- 
a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -24,6 +24,7 @@ from dstack._internal.core.models.runs import ( ApplyRunPlanInput, Job, + JobConnectionInfo, JobStatus, JobSubmission, JobTerminationReason, @@ -53,6 +54,7 @@ check_can_attach_job_volumes, delay_job_instance_termination, get_job_configured_volumes, + get_job_connection_info, get_job_spec, get_jobs_from_run_spec, job_model_to_job_submission, @@ -294,7 +296,7 @@ async def get_run_by_name( run_model = await get_run_model_by_name(session=session, project=project, run_name=run_name) if run_model is None: return None - return run_model_to_run(run_model, return_in_api=True) + return run_model_to_run(run_model, return_in_api=True, include_job_connection_info=True) async def get_run_by_id( @@ -315,7 +317,7 @@ async def get_run_by_id( run_model = res.scalar() if run_model is None: return None - return run_model_to_run(run_model, return_in_api=True) + return run_model_to_run(run_model, return_in_api=True, include_job_connection_info=True) async def get_plan( @@ -744,18 +746,21 @@ def run_model_to_run( job_submissions_limit: Optional[int] = None, return_in_api: bool = False, include_sensitive: bool = False, + include_job_connection_info: bool = False, ) -> Run: + run_spec = get_run_spec(run_model) + jobs: List[Job] = [] if include_jobs: jobs = _get_run_jobs_with_submissions( run_model=run_model, + run_spec=run_spec, job_submissions_limit=job_submissions_limit, return_in_api=return_in_api, include_sensitive=include_sensitive, + include_job_connection_info=include_job_connection_info, ) - run_spec = get_run_spec(run_model) - latest_job_submission = None if len(jobs) > 0 and len(jobs[0].job_submissions) > 0: # TODO(egor-s): does it make sense with replicas and multi-node? 
@@ -808,9 +813,11 @@ def _set_run_resources_defaults(run_spec: RunSpec) -> None: def _get_run_jobs_with_submissions( run_model: RunModel, + run_spec: RunSpec, job_submissions_limit: Optional[int], return_in_api: bool = False, include_sensitive: bool = False, + include_job_connection_info: bool = False, ) -> List[Job]: jobs: List[Job] = [] run_jobs = sorted(run_model.jobs, key=lambda j: (j.replica_num, j.job_num, j.submission_num)) @@ -845,7 +852,16 @@ def _get_run_jobs_with_submissions( job_spec = get_job_spec(job_model) if not include_sensitive: remove_job_spec_sensitive_info(job_spec) - jobs.append(Job(job_spec=job_spec, job_submissions=submissions)) + job_connection_info: Optional[JobConnectionInfo] = None + if include_job_connection_info and job_model.status == JobStatus.RUNNING: + job_connection_info = get_job_connection_info(job_model, run_spec) + jobs.append( + Job( + job_spec=job_spec, + job_submissions=submissions, + job_connection_info=job_connection_info, + ) + ) return jobs diff --git a/src/dstack/_internal/server/services/sshproxy/__init__.py b/src/dstack/_internal/server/services/sshproxy/__init__.py new file mode 100644 index 0000000000..9966ba91f7 --- /dev/null +++ b/src/dstack/_internal/server/services/sshproxy/__init__.py @@ -0,0 +1,32 @@ +from typing import Optional + +from dstack._internal.server import settings +from dstack._internal.server.models import JobModel +from dstack._internal.utils.ssh import build_ssh_command, build_ssh_url_authority + + +def build_proxied_job_ssh_url_authority(job: JobModel) -> Optional[str]: + if not settings.SSHPROXY_ENABLED: + return None + assert settings.SSHPROXY_HOSTNAME is not None + return build_ssh_url_authority( + username=_build_proxied_job_username(job), + hostname=settings.SSHPROXY_HOSTNAME, + port=settings.SSHPROXY_PORT, + ) + + +def build_proxied_job_ssh_command(job: JobModel) -> Optional[list[str]]: + if not settings.SSHPROXY_ENABLED: + return None + assert settings.SSHPROXY_HOSTNAME is not None + 
return build_ssh_command( + username=_build_proxied_job_username(job), + hostname=settings.SSHPROXY_HOSTNAME, + port=settings.SSHPROXY_PORT, + ) + + +def _build_proxied_job_username(job: JobModel) -> str: + # Job's UUID in lowercase, without dashes + return job.id.hex diff --git a/src/dstack/_internal/server/services/sshproxy.py b/src/dstack/_internal/server/services/sshproxy/handlers.py similarity index 96% rename from src/dstack/_internal/server/services/sshproxy.py rename to src/dstack/_internal/server/services/sshproxy/handlers.py index 0877436d0e..9f03971148 100644 --- a/src/dstack/_internal/server/services/sshproxy.py +++ b/src/dstack/_internal/server/services/sshproxy/handlers.py @@ -37,7 +37,7 @@ async def get_upstream_response( JobModel.status == JobStatus.RUNNING, ) .options( - (joinedload(JobModel.project, innerjoin=True).load_only(ProjectModel.ssh_private_key)), + joinedload(JobModel.project, innerjoin=True).load_only(ProjectModel.ssh_private_key), ( joinedload(JobModel.instance, innerjoin=True) .load_only(InstanceModel.remote_connection_info) diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index 01216cff31..e005a6cb56 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -6,6 +6,7 @@ from enum import Enum from pathlib import Path +from dstack._internal.server.utils.settings import parse_hostname_port from dstack._internal.utils.env import environ from dstack._internal.utils.logging import get_logger @@ -105,6 +106,12 @@ os.getenv("DSTACK_SERVER_EVENTS_TTL_SECONDS", 30 * 24 * 3600) ) +SSHPROXY_API_TOKEN = environ.get("DSTACK_SSHPROXY_API_TOKEN") or None +SSHPROXY_HOSTNAME, SSHPROXY_PORT = environ.get_callback( + "DSTACK_SERVER_SSHPROXY_ADDRESS", parse_hostname_port, default=(None, None) +) +SSHPROXY_ENABLED = SSHPROXY_API_TOKEN is not None and SSHPROXY_HOSTNAME is not None + SERVER_KEEP_SHIM_TASKS = os.getenv("DSTACK_SERVER_KEEP_SHIM_TASKS") is not None 
DEFAULT_PROJECT_NAME = "main" diff --git a/src/dstack/_internal/server/utils/settings.py b/src/dstack/_internal/server/utils/settings.py new file mode 100644 index 0000000000..a80d63ed9e --- /dev/null +++ b/src/dstack/_internal/server/utils/settings.py @@ -0,0 +1,19 @@ +from typing import Optional +from urllib.parse import urlsplit + + +def parse_hostname_port(address: str) -> tuple[str, Optional[int]]: + err_msg = "must be valid HOSTNAME[:PORT]" + if "//" in address: + raise ValueError(err_msg) + res = urlsplit(f"//{address}") + if any((res.path, res.query, res.fragment, res.username, res.password)): + raise ValueError(err_msg) + hostname = res.hostname + if not hostname: + raise ValueError(err_msg) + try: + port = res.port + except ValueError as e: + raise ValueError(err_msg) from e + return hostname, port diff --git a/src/dstack/_internal/settings.py b/src/dstack/_internal/settings.py index d94bb56547..be0d6fbe3d 100644 --- a/src/dstack/_internal/settings.py +++ b/src/dstack/_internal/settings.py @@ -50,3 +50,8 @@ class FeatureFlags: # DSTACK_FF_PIPELINE_PROCESSING_ENABLED enables new pipeline-based processing tasks (background/pipeline_tasks/) # instead of scheduler-based processing tasks (background/scheduled_tasks/) for tasks that implement pipelines. PIPELINE_PROCESSING_ENABLED = os.getenv("DSTACK_FF_PIPELINE_PROCESSING_ENABLED") is not None + # If DSTACK_FF_CLI_PRINT_JOB_CONNECTION_INFO enabled, `dstack apply` command prints server-provided + # IDE URL(s) and SSH command(s) before job logs (for dev-environments only). 
+ CLI_PRINT_JOB_CONNECTION_INFO = ( + os.getenv("DSTACK_FF_CLI_PRINT_JOB_CONNECTION_INFO") is not None + ) diff --git a/src/dstack/_internal/utils/env.py b/src/dstack/_internal/utils/env.py index 53eb4c61de..9ce97b6c40 100644 --- a/src/dstack/_internal/utils/env.py +++ b/src/dstack/_internal/utils/env.py @@ -1,16 +1,27 @@ import os from collections.abc import Mapping from enum import Enum -from typing import Optional, TypeVar, Union, overload +from typing import Callable, Optional, TypeVar, Union, overload -_Value = Union[str, int] -_T = TypeVar("_T", bound=Enum) +_EVT = Union[str, int] +_ET = TypeVar("_ET", bound=Enum) + +_CVT = TypeVar("_CVT") class Environ: def __init__(self, environ: Mapping[str, str]): self._environ = environ + @overload + def get(self, name: str, *, default: None = None) -> Optional[str]: ... + + @overload + def get(self, name: str, *, default: str) -> str: ... + + def get(self, name: str, *, default: Optional[str] = None) -> Optional[str]: + return self._environ.get(name, default) + @overload def get_bool(self, name: str, *, default: None = None) -> Optional[bool]: ... @@ -49,30 +60,30 @@ def get_int(self, name: str, *, default: Optional[int] = None) -> Optional[int]: def get_enum( self, name: str, - enum_cls: type[_T], + enum_cls: type[_ET], *, - value_type: Optional[type[_Value]] = None, + value_type: Optional[type[_EVT]] = None, default: None = None, - ) -> Optional[_T]: ... + ) -> Optional[_ET]: ... @overload def get_enum( self, name: str, - enum_cls: type[_T], + enum_cls: type[_ET], *, - value_type: Optional[type[_Value]] = None, - default: _T, - ) -> _T: ... + value_type: Optional[type[_EVT]] = None, + default: _ET, + ) -> _ET: ... 
def get_enum( self, name: str, - enum_cls: type[_T], + enum_cls: type[_ET], *, - value_type: Optional[type[_Value]] = None, - default: Optional[_T] = None, - ) -> Optional[_T]: + value_type: Optional[type[_EVT]] = None, + default: Optional[_ET] = None, + ) -> Optional[_ET]: try: raw_value = self._environ[name] except KeyError: @@ -84,5 +95,27 @@ def get_enum( except (ValueError, TypeError) as e: raise ValueError(f"Invalid {enum_cls.__name__} value: {e}: {name}={raw_value}") from e + @overload + def get_callback( + self, name: str, callback: Callable[[str], _CVT], *, default: None = None + ) -> Optional[_CVT]: ... + + @overload + def get_callback( + self, name: str, callback: Callable[[str], _CVT], *, default: _CVT + ) -> _CVT: ... + + def get_callback( + self, name: str, callback: Callable[[str], _CVT], *, default: Optional[_CVT] = None + ) -> Optional[_CVT]: + try: + raw_value = self._environ[name] + except KeyError: + return default + try: + return callback(raw_value) + except ValueError as e: + raise ValueError(f"Invalid value: {e}: {name}={raw_value}") from e + environ = Environ(os.environ) diff --git a/src/dstack/_internal/utils/ssh.py b/src/dstack/_internal/utils/ssh.py index e2b5fef436..e1828bc33b 100644 --- a/src/dstack/_internal/utils/ssh.py +++ b/src/dstack/_internal/utils/ssh.py @@ -344,3 +344,64 @@ def find_ssh_util(name: str) -> Optional[Path]: if path.exists(): return path return None + + +def build_ssh_command( + *, + username: Optional[str] = None, + hostname: str, + port: Optional[int] = None, + ssh_executable: Optional[str] = None, +) -> list[str]: + """ + Builds an SSH client command line to connect. + + The resulting command is: + + ssh [username@]hostname [-p port] + + The port argument -p is only included if the port is not the default SSH port (22). + + :param username: an optional user login name. + :param hostname: a hostname, required. + :param port: an optional SSH port, defaults to 22. 
+ :param ssh_executable: an optional file name or path of the SSH client, defaults to `ssh`. + :return: a list of command line arguments including the executable. + """ + if ssh_executable is None: + ssh_executable = "ssh" + command: list[str] = [ssh_executable] + if username is not None: + command.append(f"{username}@{hostname}") + else: + command.append(hostname) + if port is not None and port != 22: + command.extend(("-p", str(port))) + return command + + +def build_ssh_url_authority( + *, username: Optional[str] = None, hostname: str, port: Optional[int] = None +) -> str: + """ + Builds an authority URL component for use with ssh:// and ssh-based URLs (e.g., vscode://). + + The authority component consists of subcomponents: + + authority = [userinfo "@"] host [":" port] + + The port subcomponent is only included if the port is not the default SSH port (22). + + :param username: an optional user login name, used as the userinfo if provided. + :param hostname: a hostname, required. + :param port: an optional SSH port, defaults to 22. + :return: the authority URL component as a string. 
+ """ + if ":" in hostname and not hostname.startswith("["): + hostname = f"[{hostname}]" + authority = hostname + if username is not None: + authority = f"{username}@{authority}" + if port is not None and port != 22: + authority = f"{authority}:{port}" + return authority diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 3d14263a37..879ad48231 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -6,7 +6,6 @@ from uuid import UUID import pytest -from fastapi.testclient import TestClient from freezegun import freeze_time from httpx import AsyncClient from sqlalchemy import select @@ -45,7 +44,6 @@ ) from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.core.models.volumes import InstanceMountPoint, MountPoint -from dstack._internal.server.main import app from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.schemas.runs import ApplyRunPlanRequest from dstack._internal.server.services.projects import add_project_member @@ -70,15 +68,19 @@ get_auth_headers, get_fleet_spec, get_job_provisioning_data, + get_job_runtime_data, get_run_spec, get_ssh_fleet_configuration, list_events, ) from dstack._internal.server.testing.matchers import SomeUUID4Str -pytestmark = pytest.mark.usefixtures("image_config_mock") +pytestmark = pytest.mark.usefixtures("image_config_mock", "disable_sshproxy") -client = TestClient(app) + +@pytest.fixture +def disable_sshproxy(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("dstack._internal.server.settings.SSHPROXY_ENABLED", False) def get_dev_env_run_plan_dict( @@ -107,7 +109,7 @@ def get_dev_env_run_plan_dict( " && pip install -q --no-cache-dir ipykernel 2> /dev/null)" " || echo 'no pip, ipykernel was not installed'" " && echo" - " && echo 'To open in VS Code Desktop, use link below:'" + " && echo 'To open in VS Code, use link below:'" 
" && echo" ' && echo " vscode://vscode-remote/ssh-remote+dry-run$DSTACK_WORKING_DIR"' " && echo" @@ -134,7 +136,7 @@ def get_dev_env_run_plan_dict( " && pip install -q --no-cache-dir ipykernel 2> /dev/null)" " || echo 'no pip, ipykernel was not installed'" " && echo" - " && echo 'To open in VS Code Desktop, use link below:'" + " && echo 'To open in VS Code, use link below:'" " && echo" ' && echo " vscode://vscode-remote/ssh-remote+dry-run$DSTACK_WORKING_DIR"' " && echo" @@ -340,7 +342,7 @@ def get_dev_env_run_dict( " && pip install -q --no-cache-dir ipykernel 2> /dev/null)" " || echo 'no pip, ipykernel was not installed'" " && echo" - " && echo 'To open in VS Code Desktop, use link below:'" + " && echo 'To open in VS Code, use link below:'" " && echo" ' && echo " vscode://vscode-remote/ssh-remote+test-run$DSTACK_WORKING_DIR"' " && echo" @@ -367,7 +369,7 @@ def get_dev_env_run_dict( " && pip install -q --no-cache-dir ipykernel 2> /dev/null)" " || echo 'no pip, ipykernel was not installed'" " && echo" - " && echo 'To open in VS Code Desktop, use link below:'" + " && echo 'To open in VS Code, use link below:'" " && echo" ' && echo " vscode://vscode-remote/ssh-remote+test-run$DSTACK_WORKING_DIR"' " && echo" @@ -561,6 +563,7 @@ def get_dev_env_run_dict( "probes": [], } ], + "job_connection_info": None, } ], "latest_job_submission": { @@ -729,6 +732,7 @@ async def test_lists_runs(self, test_db, session: AsyncSession, client: AsyncCli "probes": [], } ], + "job_connection_info": None, } ], "latest_job_submission": { @@ -919,6 +923,7 @@ async def test_limits_job_submissions( "probes": [], } ], + "job_connection_info": None, } ], "latest_job_submission": { @@ -1122,6 +1127,136 @@ async def test_patches_service_configuration_probes_for_old_clients( assert response.status_code == 200 assert response.json()["run_spec"]["configuration"]["probes"] == expected_probes + @pytest.mark.asyncio + @pytest.mark.parametrize( + "sshproxy", + [ + pytest.param(False, id="without-sshproxy"), 
+ pytest.param(True, id="with-sshproxy"), + ], + ) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_run_with_job_connection_info_dev_environment( + self, + monkeypatch: pytest.MonkeyPatch, + test_db, + session: AsyncSession, + client: AsyncClient, + sshproxy: bool, + ): + monkeypatch.setattr("dstack._internal.server.settings.SSHPROXY_ENABLED", sshproxy) + monkeypatch.setattr("dstack._internal.server.settings.SSHPROXY_HOSTNAME", "example.com") + monkeypatch.setattr("dstack._internal.server.settings.SSHPROXY_PORT", 2222) + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="dev-env", + configuration=DevEnvironmentConfiguration(ide="cursor"), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + run_name=run_spec.run_name, + ) + job_runtime_data = get_job_runtime_data(working_dir="/test") + job = await create_job( + session=session, run=run, status=JobStatus.RUNNING, job_runtime_data=job_runtime_data + ) + response = await client.post( + f"/api/project/{project.name}/runs/get", + headers=get_auth_headers(user.token), + json={"run_name": run.run_name}, + ) + assert response.status_code == 200, response.json() + assert response.json()["jobs"][0]["job_connection_info"] == { + "ide_name": "Cursor", + "attached_ide_url": "cursor://vscode-remote/ssh-remote+dev-env/test", + "proxied_ide_url": f"cursor://vscode-remote/ssh-remote+{job.id.hex}@example.com:2222/test" + if sshproxy + else None, + "attached_ssh_command": ["ssh", "dev-env"], + "proxied_ssh_command": ["ssh", f"{job.id.hex}@example.com", "-p", "2222"] + if sshproxy + else None, + } + 
+ @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_run_with_job_connection_info_task( + self, monkeypatch: pytest.MonkeyPatch, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_spec = get_run_spec( + repo_id=repo.name, + run_name="test-task", + configuration=TaskConfiguration(commands=["sleep inf"]), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + run_name=run_spec.run_name, + ) + job_runtime_data = get_job_runtime_data(working_dir="/test") + for replica_num in range(2): + for job_num in range(2): + await create_job( + session=session, + run=run, + # test-task-1-1 is still PULLING, other jobs are RUNNING + status=JobStatus.PULLING if replica_num == job_num == 1 else JobStatus.RUNNING, + job_runtime_data=job_runtime_data, + replica_num=replica_num, + job_num=job_num, + ) + response = await client.post( + f"/api/project/{project.name}/runs/get", + headers=get_auth_headers(user.token), + json={"run_name": run.run_name}, + ) + assert response.status_code == 200, response.json() + jobs = response.json()["jobs"] + common_fields = { + "ide_name": None, + "attached_ide_url": None, + "proxied_ide_url": None, + "proxied_ssh_command": None, + } + assert jobs[0]["job_connection_info"] == { + "attached_ssh_command": ["ssh", "test-task"], + **common_fields, + } + assert jobs[1]["job_connection_info"] == { + "attached_ssh_command": ["ssh", "test-task-1-0"], + **common_fields, + } + assert jobs[2]["job_connection_info"] == { + "attached_ssh_command": ["ssh", "test-task-0-1"], + **common_fields, + } + assert 
jobs[3]["job_connection_info"] is None + class TestGetRunPlan: @pytest.mark.asyncio diff --git a/src/tests/_internal/server/utils/test_settings.py b/src/tests/_internal/server/utils/test_settings.py new file mode 100644 index 0000000000..d2d40385db --- /dev/null +++ b/src/tests/_internal/server/utils/test_settings.py @@ -0,0 +1,46 @@ +from typing import Optional + +import pytest + +from dstack._internal.server.utils.settings import parse_hostname_port + + +class TestParseHostnamePort: + @pytest.mark.parametrize( + ["value", "expected_hostname", "expected_port"], + [ + pytest.param("example.com", "example.com", None, id="domain"), + pytest.param("example.com:22", "example.com", 22, id="domain-port"), + pytest.param("10.0.0.1", "10.0.0.1", None, id="ipv4"), + pytest.param( + "[fd69:b03c:7b2:b68a:6eda:b557:9526:757]", + "fd69:b03c:7b2:b68a:6eda:b557:9526:757", + None, + id="ipv6", + ), + pytest.param( + "[fd69:b03c:7b2:b68a:6eda:b557:9526:757]:22", + "fd69:b03c:7b2:b68a:6eda:b557:9526:757", + 22, + id="ipv6-port", + ), + ], + ) + def test_valid(self, value: str, expected_hostname: str, expected_port: Optional[int]): + hostname, port = parse_hostname_port(value) + assert hostname == expected_hostname + assert port == expected_port + + @pytest.mark.parametrize( + "value", + [ + pytest.param("", id="empty-string"), + pytest.param(":22", id="no-hostname"), + pytest.param("fd69:b03c:7b2:b68a:6eda:b557:9526:757", id="ipv6-without-brackets"), + pytest.param("example.com:port", id="non-integer-port"), + pytest.param("example.com:1000000", id="port-out-of-range"), + ], + ) + def test_invalid(self, value: str): + with pytest.raises(ValueError, match=r"must be valid HOSTNAME\[:PORT\]"): + parse_hostname_port(value) diff --git a/src/tests/_internal/utils/test_env.py b/src/tests/_internal/utils/test_env.py index 4e9e2f6b79..d9242a13e9 100644 --- a/src/tests/_internal/utils/test_env.py +++ b/src/tests/_internal/utils/test_env.py @@ -3,7 +3,7 @@ import pytest -from 
dstack._internal.utils.env import Environ, _Value +from dstack._internal.utils.env import Environ class _TestEnviron: @@ -87,7 +87,9 @@ class TestEnvironGetEnum(_TestEnviron): pytest.param(_IntEnum, int, "100", id="int"), ], ) - def test_is_set(self, enum_cls: type[_Enum], value_type: type[_Value], value: str): + def test_is_set( + self, enum_cls: type[_Enum], value_type: Union[type[str], type[int]], value: str + ): environ = self.get_environ(VAR=value) assert environ.get_enum("VAR", enum_cls, value_type=value_type) is enum_cls.FOO @@ -107,7 +109,31 @@ def test_not_set_default_is_set(self): pytest.param(_IntEnum, int, "10a", id="invalid-int"), ], ) - def test_error_bad_value(self, enum_cls: type[_Enum], value_type: type[_Value], value: str): + def test_error_bad_value( + self, enum_cls: type[_Enum], value_type: Union[type[str], type[int]], value: str + ): environ = self.get_environ(VAR=value) with pytest.raises(ValueError, match=f"VAR={value}"): environ.get_enum("VAR", enum_cls, value_type=value_type) + + +class TestEnvironGetCallback(_TestEnviron): + def test_is_set(self): + environ = self.get_environ(VAR="foo bar") + assert environ.get_callback("VAR", str.split) == ["foo", "bar"] + + def test_not_set_default_not_set(self): + environ = self.get_environ() + assert environ.get_callback("VAR", str.split) is None + + def test_not_set_default_is_set(self): + environ = self.get_environ() + assert environ.get_callback("VAR", str.split, default=["default"]) == ["default"] + + def test_error_bad_value(self): + def callback(value: str) -> list[str]: + raise ValueError("bad value") + + environ = self.get_environ(VAR="value") + with pytest.raises(ValueError, match="bad value: VAR=value"): + environ.get_callback("VAR", callback=callback)