Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
get_repo_creds_and_default_branch,
)
from dstack._internal.core.services.ssh.ports import PortUsedError
from dstack._internal.settings import FeatureFlags
from dstack._internal.utils.common import local_time
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -215,6 +216,7 @@ def apply_configuration(
current_job_submission = run._run.latest_job_submission
if run.status in (RunStatus.RUNNING, RunStatus.DONE):
_print_service_urls(run)
_print_dev_environment_connection_info(run)
bind_address: Optional[str] = getattr(
configurator_args, _BIND_ADDRESS_ARG, None
)
Expand Down Expand Up @@ -806,6 +808,30 @@ def _print_service_urls(run: Run) -> None:
console.print()


def _print_dev_environment_connection_info(run: Run) -> None:
if not FeatureFlags.CLI_PRINT_JOB_CONNECTION_INFO:
return
if run._run.run_spec.configuration.type != RunConfigurationType.DEV_ENVIRONMENT.value:
return
jci = run._run.jobs[0].job_connection_info
if jci is None:
return
if jci.ide_name:
urls = [u for u in (jci.attached_ide_url, jci.proxied_ide_url) if u]
if urls:
console.print(
f"To open in {jci.ide_name}, use link{'s' if len(urls) > 1 else ''} below:\n"
)
for link in urls:
console.print(f" [link={link}]{link}[/]\n")
ssh_commands = [" ".join(c) for c in (jci.attached_ssh_command, jci.proxied_ssh_command) if c]
if ssh_commands:
console.print(
f"To connect via SSH, use: {' or '.join(f'[code]{c}[/]' for c in ssh_commands)}\n"
)
console.print()


def print_finished_message(run: Run):
status_message = (
run._run.latest_job_submission.status_message
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/core/compatibility/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def get_apply_plan_excludes(plan: ApplyRunPlanInput) -> Optional[IncludeExcludeD
]
),
},
# Contains only informational computed fields, safe to exclude unconditionally
"job_connection_info": True,
}
}
if current_resource.latest_job_submission is not None:
Expand Down
44 changes: 44 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,53 @@ def duration(self) -> timedelta:
return end_time - self.submitted_at


class JobConnectionInfo(CoreModel):
ide_name: Annotated[
Optional[str], Field(description="Dev environment IDE name for UI, human-readable.")
]
attached_ide_url: Annotated[
Optional[str],
Field(
description=(
"Dev environment IDE URL."
" Not set if the job has not started yet."
" Only works if the user is attached to the run via CLI or Python API."
)
),
]
proxied_ide_url: Annotated[
Optional[str],
Field(
description=(
"Dev environment IDE URL."
" Not set if the job has hot started yet or sshproxy is not configured."
)
),
]
attached_ssh_command: Annotated[
Optional[list[str]],
Field(
description=(
"SSH command to connect to the job, list of command line arguments."
" Only works if the user is attached to the run via CLI or Python API."
)
),
]
proxied_ssh_command: Annotated[
Optional[list[str]],
Field(
description=(
"SSH command to connect to the job, list of command line arguments."
" Not set if sshproxy is not configured."
)
),
]


class Job(CoreModel):
job_spec: JobSpec
job_submissions: List[JobSubmission]
job_connection_info: Optional[JobConnectionInfo] = None


class RunSpecConfig(CoreConfig):
Expand Down
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/routers/sshproxy.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os
from typing import Annotated

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.server import settings
from dstack._internal.server.db import get_session
from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse
from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount
from dstack._internal.server.services.sshproxy import get_upstream_response
from dstack._internal.server.services.sshproxy.handlers import get_upstream_response
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
)

if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"):
_auth = ServiceAccount(_token)
if settings.SSHPROXY_API_TOKEN is not None:
_auth = ServiceAccount(settings.SSHPROXY_API_TOKEN)
else:
_auth = AlwaysForbidden()

Expand Down
21 changes: 21 additions & 0 deletions src/dstack/_internal/server/services/ides/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Literal, Optional

from dstack._internal.server.services.ides.base import IDE
from dstack._internal.server.services.ides.cursor import CursorDesktop
from dstack._internal.server.services.ides.vscode import VSCodeDesktop
from dstack._internal.server.services.ides.windsurf import WindsurfDesktop

_IDELiteral = Literal["vscode", "cursor", "windsurf"]

_ide_literal_to_ide_class_map: dict[_IDELiteral, type[IDE]] = {
"vscode": VSCodeDesktop,
"cursor": CursorDesktop,
"windsurf": WindsurfDesktop,
}


def get_ide(ide_literal: _IDELiteral) -> Optional[IDE]:
ide_class = _ide_literal_to_ide_class_map.get(ide_literal)
if ide_class is None:
return None
return ide_class()
25 changes: 25 additions & 0 deletions src/dstack/_internal/server/services/ides/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Optional


class IDE(ABC):
name: ClassVar[str]
url_scheme: ClassVar[str]

@abstractmethod
def get_install_commands(
self, version: Optional[str] = None, extensions: Optional[list[str]] = None
) -> list[str]:
pass

def get_url(self, authority: str, working_dir: str) -> str:
return f"{self.url_scheme}://vscode-remote/ssh-remote+{authority}{working_dir}"

def get_print_readme_commands(self, authority: str) -> list[str]:
url = self.get_url(authority, working_dir="$DSTACK_WORKING_DIR")
return [
f"echo 'To open in {self.name}, use link below:'",
"echo",
f'echo " {url}"',
"echo",
]
31 changes: 31 additions & 0 deletions src/dstack/_internal/server/services/ides/cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional

from dstack._internal.server.services.ides.base import IDE


class CursorDesktop(IDE):
name = "Cursor"
url_scheme = "cursor"

def get_install_commands(
self, version: Optional[str] = None, extensions: Optional[list[str]] = None
) -> list[str]:
commands = []
if version is not None:
url = f"https://cursor.blob.core.windows.net/remote-releases/{version}/vscode-reh-linux-$arch.tar.gz"
archive = "vscode-reh-linux-$arch.tar.gz"
target = f'~/.cursor-server/cli/servers/"Stable-{version}"/server'
commands.extend(
[
'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi',
"mkdir -p /tmp",
f'wget -q --show-progress "{url}" -O "/tmp/{archive}"',
f"mkdir -vp {target}",
f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"',
f'rm "/tmp/{archive}"',
]
)
if extensions:
_extensions = " ".join(f'--install-extension "{name}"' for name in extensions)
commands.append(f'PATH="$PATH":{target}/bin cursor-server {_extensions}')
return commands
33 changes: 33 additions & 0 deletions src/dstack/_internal/server/services/ides/vscode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional

from dstack._internal.server.services.ides.base import IDE


class VSCodeDesktop(IDE):
name = "VS Code"
url_scheme = "vscode"

def get_install_commands(
self, version: Optional[str] = None, extensions: Optional[list[str]] = None
) -> list[str]:
commands = []
if version is not None:
url = (
f"https://update.code.visualstudio.com/commit:{version}/server-linux-$arch/stable"
)
archive = "vscode-server-linux-$arch.tar.gz"
target = f'~/.vscode-server/bin/"{version}"'
commands.extend(
[
'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi',
"mkdir -p /tmp",
f'wget -q --show-progress "{url}" -O "/tmp/{archive}"',
f"mkdir -vp {target}",
f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"',
f'rm "/tmp/{archive}"',
]
)
if extensions:
_extensions = " ".join(f'--install-extension "{name}"' for name in extensions)
commands.append(f'PATH="$PATH":{target}/bin code-server {_extensions}')
return commands
32 changes: 32 additions & 0 deletions src/dstack/_internal/server/services/ides/windsurf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional

from dstack._internal.server.services.ides.base import IDE


class WindsurfDesktop(IDE):
name = "Windsurf"
url_scheme = "windsurf"

def get_install_commands(
self, version: Optional[str] = None, extensions: Optional[list[str]] = None
) -> list[str]:
commands = []
if version is not None:
version, commit = version.split("@")
url = f"https://windsurf-stable.codeiumdata.com/linux-reh-$arch/stable/{commit}/windsurf-reh-linux-$arch-{version}.tar.gz"
archive = "windsurf-reh-linux-$arch.tar.gz"
target = f'~/.windsurf-server/bin/"{commit}"'
commands.extend(
[
'if [ $(uname -m) = "aarch64" ]; then arch="arm64"; else arch="x64"; fi',
"mkdir -p /tmp",
f'wget -q --show-progress "{url}" -O "/tmp/{archive}"',
f"mkdir -vp {target}",
f'tar --no-same-owner -xz --strip-components=1 -C {target} -f "/tmp/{archive}"',
f'rm "/tmp/{archive}"',
]
)
if extensions:
_extensions = " ".join(f'--install-extension "{name}"' for name in extensions)
commands.append(f'PATH="$PATH":{target}/bin windsurf-server {_extensions}')
return commands
48 changes: 48 additions & 0 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dstack._internal.core.models.configurations import RunConfigurationType
from dstack._internal.core.models.runs import (
Job,
JobConnectionInfo,
JobProvisioningData,
JobRuntimeData,
JobSpec,
Expand All @@ -37,6 +38,7 @@
)
from dstack._internal.server.services import events
from dstack._internal.server.services import volumes as volumes_services
from dstack._internal.server.services.ides import get_ide
from dstack._internal.server.services.instances import (
get_instance_ssh_private_keys,
)
Expand All @@ -51,9 +53,14 @@
from dstack._internal.server.services.probes import probe_model_to_probe
from dstack._internal.server.services.runner import client
from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
from dstack._internal.server.services.sshproxy import (
build_proxied_job_ssh_command,
build_proxied_job_ssh_url_authority,
)
from dstack._internal.utils import common
from dstack._internal.utils.common import run_async
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import build_ssh_command, build_ssh_url_authority

logger = get_logger(__name__)

Expand Down Expand Up @@ -490,6 +497,47 @@ def remove_job_spec_sensitive_info(spec: JobSpec):
spec.ssh_key = None


def get_job_connection_info(job_model: JobModel, run_spec: RunSpec) -> JobConnectionInfo:
# Run.attach() Python API method, used internally by CLI, uses the following as the Hostname
# in the SSH config:
# * for the (job=0 replica=0) job - run name, e.g., `my-task`
# * for other jobs - job name, e.g., `my-task-0-1`
attached_hostname = run_spec.run_name
if job_model.job_num != 0 or job_model.replica_num != 0:
attached_hostname = job_model.job_name
assert attached_hostname is not None

# ide_* fields are for dev-environment only
ide_name: Optional[str] = None
# IDE URLs are not set until the job status is switched to RUNNING,
# as JobRuntimeData.working_dir, which is required to build URLs, is returned
# by dstack-runner's `/api/run` method
attached_ide_url: Optional[str] = None
proxied_ide_url: Optional[str] = None
if (
run_spec.configuration.type == RunConfigurationType.DEV_ENVIRONMENT.value
and run_spec.configuration.ide is not None
):
ide = get_ide(run_spec.configuration.ide)
if ide is not None:
ide_name = ide.name
jrd = get_job_runtime_data(job_model)
if jrd is not None and jrd.working_dir is not None:
attached_url_authority = build_ssh_url_authority(hostname=attached_hostname)
attached_ide_url = ide.get_url(attached_url_authority, jrd.working_dir)
proxied_url_authority = build_proxied_job_ssh_url_authority(job_model)
if proxied_url_authority is not None:
proxied_ide_url = ide.get_url(proxied_url_authority, jrd.working_dir)

return JobConnectionInfo(
ide_name=ide_name,
attached_ide_url=attached_ide_url,
proxied_ide_url=proxied_ide_url,
attached_ssh_command=build_ssh_command(hostname=attached_hostname),
proxied_ssh_command=build_proxied_job_ssh_command(job_model),
)


def _get_job_mount_point_attached_volume(
volumes: List[Volume],
job_provisioning_data: JobProvisioningData,
Expand Down
Loading
Loading