diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index d9f67680c..6b3762419 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -2,6 +2,7 @@ from dstack._internal.server.background.pipeline_tasks.base import Pipeline from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline +from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, @@ -16,6 +17,7 @@ class PipelineManager: def __init__(self) -> None: self._pipelines: list[Pipeline] = [ ComputeGroupPipeline(), + FleetPipeline(), GatewayPipeline(), PlacementGroupPipeline(), VolumePipeline(), diff --git a/src/dstack/_internal/server/background/pipeline_tasks/base.py b/src/dstack/_internal/server/background/pipeline_tasks/base.py index 9d016934c..aa5af9a4a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/base.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/base.py @@ -3,9 +3,20 @@ import random import uuid from abc import ABC, abstractmethod +from collections.abc import Iterable, Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar +from typing import ( + Any, + ClassVar, + Final, + Generic, + Optional, + Protocol, + TypedDict, + TypeVar, + Union, +) from sqlalchemy import and_, or_, update from sqlalchemy.orm import Mapped @@ -337,16 +348,71 @@ async def process(self, item: ItemT): pass -UpdateMap = dict[str, Any] +class _NowPlaceholder: + pass + + +NOW_PLACEHOLDER: Final = _NowPlaceholder() +""" +Use `NOW_PLACEHOLDER` together with `resolve_now_placeholders()` in pipeline update maps +instead of `get_current_time()` to have the same current time for all updates in the transaction. +""" + + +UpdateMapDateTime = Union[datetime, _NowPlaceholder] + + +class _UnlockUpdateMap(TypedDict, total=False): + lock_expires_at: Optional[datetime] + lock_token: Optional[uuid.UUID] + lock_owner: Optional[str] + + +class _ProcessedUpdateMap(TypedDict, total=False): + last_processed_at: UpdateMapDateTime + +class ItemUpdateMap(_UnlockUpdateMap, _ProcessedUpdateMap, total=False): + lock_expires_at: Optional[datetime] + lock_token: Optional[uuid.UUID] + lock_owner: Optional[str] + last_processed_at: UpdateMapDateTime -def get_unlock_update_map() -> UpdateMap: - return { - "lock_expires_at": None, - "lock_token": None, - "lock_owner": None, - } +def set_unlock_update_map_fields(update_map: _UnlockUpdateMap): + update_map["lock_expires_at"] = None + update_map["lock_token"] = None + update_map["lock_owner"] = None -def get_processed_update_map() -> UpdateMap: - return {"last_processed_at": get_current_datetime()} + +def set_processed_update_map_fields( + update_map: _ProcessedUpdateMap, + now: UpdateMapDateTime = NOW_PLACEHOLDER, +): + update_map["last_processed_at"] = now + + +class _ResolveNowUpdateMap(Protocol): + def items(self) -> Iterable[tuple[str, object]]: ... 
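# A minimal usage sketch of the update-map helpers defined in this module
# (`set_processed_update_map_fields`, `set_unlock_update_map_fields`,
# `resolve_now_placeholders`), assuming a worker that has finished processing
# an item and wants every timestamp written in the transaction to be the same:

update_map: ItemUpdateMap = {}
set_processed_update_map_fields(update_map)  # last_processed_at -> NOW_PLACEHOLDER
set_unlock_update_map_fields(update_map)     # lock_expires_at / lock_token / lock_owner -> None

now = get_current_datetime()
resolve_now_placeholders(update_map, now=now)  # NOW_PLACEHOLDER -> now, in place
# `update_map` can then be passed to the UPDATE statement via .values(**update_map).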
+ + +_ResolveNowInput = Union[_ResolveNowUpdateMap, Sequence[_ResolveNowUpdateMap]] + + +def resolve_now_placeholders(update_values: _ResolveNowInput, now: datetime): + """ + Replaces `NOW_PLACEHOLDER` with `now` in an update map or a sequence of update rows. + """ + if isinstance(update_values, Sequence): + for update_row in update_values: + resolve_now_placeholders(update_row, now) + return + # Runtime dict narrowing is required here: pyright doesn't model TypedDicts as + # supporting generic dynamic-key mutation via protocol methods. + if not isinstance(update_values, dict): + raise TypeError( + "resolve_now_placeholders() expects update maps or sequences of update maps" + ) + for key, value in update_values.items(): + if value is NOW_PLACEHOLDER: + update_values[key] = now diff --git a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py index 33e839b8b..0ee2975eb 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/compute_groups.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Sequence +from typing import Sequence, TypedDict from sqlalchemy import or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -12,14 +12,17 @@ from dstack._internal.core.models.compute_groups import ComputeGroupStatus from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, + UpdateMapDateTime, Worker, - get_processed_update_map, - get_unlock_update_map, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel @@ -199,25 +202,28 @@ async def process(self, item: PipelineItem): ) return - terminate_result = _TerminateResult() + result = _TerminateResult() # TODO: Fetch only compute groups with all instances terminating. 
if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances): - terminate_result = await _terminate_compute_group(compute_group_model) - if terminate_result.compute_group_update_map: + result = await _terminate_compute_group(compute_group_model) + set_processed_update_map_fields(result.compute_group_update_map) + if result.instances_update_map: + set_processed_update_map_fields(result.instances_update_map) + set_unlock_update_map_fields(result.compute_group_update_map) + if result.compute_group_update_map.get("deleted", False): logger.info("Terminated compute group %s", compute_group_model.id) - else: - terminate_result.compute_group_update_map = get_processed_update_map() - - terminate_result.compute_group_update_map |= get_unlock_update_map() async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(result.compute_group_update_map, now=now) + resolve_now_placeholders(result.instances_update_map, now=now) res = await session.execute( update(ComputeGroupModel) .where( ComputeGroupModel.id == compute_group_model.id, ComputeGroupModel.lock_token == compute_group_model.lock_token, ) - .values(**terminate_result.compute_group_update_map) + .values(**result.compute_group_update_map) .returning(ComputeGroupModel.id) ) updated_ids = list(res.scalars().all()) @@ -229,13 +235,13 @@ async def process(self, item: PipelineItem): item.id, ) return - if not terminate_result.instances_update_map: + if not result.instances_update_map: return instances_ids = [i.id for i in compute_group_model.instances] res = await session.execute( update(InstanceModel) .where(InstanceModel.id.in_(instances_ids)) - .values(**terminate_result.instances_update_map) + .values(**result.instances_update_map) ) for instance_model in compute_group_model.instances: emit_instance_status_change_event( @@ -246,10 +252,28 @@ async def process(self, item: PipelineItem): ) +class _ComputeGroupUpdateMap(ItemUpdateMap, total=False): + status: ComputeGroupStatus + deleted: bool + deleted_at: UpdateMapDateTime + first_termination_retry_at: UpdateMapDateTime + last_termination_retry_at: UpdateMapDateTime + + +class _InstanceBulkUpdateMap(TypedDict, total=False): + last_processed_at: UpdateMapDateTime + deleted: bool + deleted_at: UpdateMapDateTime + finished_at: UpdateMapDateTime + status: InstanceStatus + + @dataclass class _TerminateResult: - compute_group_update_map: UpdateMap = field(default_factory=dict) - instances_update_map: UpdateMap = field(default_factory=dict) + compute_group_update_map: _ComputeGroupUpdateMap = field( + default_factory=_ComputeGroupUpdateMap + ) + instances_update_map: _InstanceBulkUpdateMap = field(default_factory=_InstanceBulkUpdateMap) async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _TerminateResult: @@ -283,15 +307,15 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T compute_group, ) except Exception as e: + retry_at = get_current_datetime() + first_termination_retry_at = compute_group_model.first_termination_retry_at if compute_group_model.first_termination_retry_at is None: - result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime() - result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime() - if _next_termination_retry_at( - result.compute_group_update_map["last_termination_retry_at"] - ) < _get_termination_deadline( - result.compute_group_update_map.get( - "first_termination_retry_at", 
compute_group_model.first_termination_retry_at - ) + result.compute_group_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER + first_termination_retry_at = retry_at + assert first_termination_retry_at is not None + result.compute_group_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER + if _next_termination_retry_at(retry_at) < _get_termination_deadline( + first_termination_retry_at ): logger.warning( "Failed to terminate compute group %s. Will retry. Error: %r", @@ -309,11 +333,9 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T exc_info=not isinstance(e, BackendError), ) terminated_result = _get_terminated_result() - return _TerminateResult( - compute_group_update_map=result.compute_group_update_map - | terminated_result.compute_group_update_map, - instances_update_map=result.instances_update_map | terminated_result.instances_update_map, - ) + terminated_result.compute_group_update_map.update(result.compute_group_update_map) + terminated_result.instances_update_map.update(result.instances_update_map) + return terminated_result def _next_termination_retry_at(last_termination_retry_at: datetime) -> datetime: @@ -325,19 +347,16 @@ def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime: def _get_terminated_result() -> _TerminateResult: - now = get_current_datetime() return _TerminateResult( compute_group_update_map={ - "last_processed_at": now, "deleted": True, - "deleted_at": now, + "deleted_at": NOW_PLACEHOLDER, "status": ComputeGroupStatus.TERMINATED, }, instances_update_map={ - "last_processed_at": now, "deleted": True, - "deleted_at": now, - "finished_at": now, + "deleted_at": NOW_PLACEHOLDER, + "finished_at": NOW_PLACEHOLDER, "status": InstanceStatus.TERMINATED, }, ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py new file mode 100644 index 000000000..55ffcd7f9 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -0,0 +1,558 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Sequence, TypedDict + +from sqlalchemy import or_, select, update +from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy.orm import joinedload, load_only, selectinload + +from dstack._internal.core.models.fleets import FleetSpec, FleetStatus +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import RunStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, + Fetcher, + Heartbeater, + ItemUpdateMap, + Pipeline, + PipelineItem, + UpdateMapDateTime, + Worker, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + FleetModel, + InstanceModel, + JobModel, + PlacementGroupModel, + RunModel, +) +from dstack._internal.server.services import events +from dstack._internal.server.services.fleets import ( + create_fleet_instance_model, + emit_fleet_status_change_event, + get_fleet_spec, + get_next_instance_num, + is_fleet_empty, + is_fleet_in_use, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils.common import get_current_datetime +from 
dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +class FleetPipeline(Pipeline[PipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=60), + lock_timeout: timedelta = timedelta(seconds=20), + heartbeat_trigger: timedelta = timedelta(seconds=10), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[PipelineItem]( + model_type=FleetModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = FleetFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + FleetWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return FleetModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[PipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[PipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["FleetWorker"]: + return self.__workers + + +class FleetFetcher(Fetcher[PipelineItem]): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[PipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.FleetFetcher.fetch") + async def fetch(self, limit: int) -> list[PipelineItem]: + fleet_lock, _ = get_locker(get_db().dialect_name).get_lockset(FleetModel.__tablename__) + async with fleet_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(FleetModel) + .where( + FleetModel.deleted == False, + or_( + FleetModel.last_processed_at <= now - self._min_processing_interval, + FleetModel.last_processed_at == FleetModel.created_at, + ), + or_( + FleetModel.lock_expires_at.is_(None), + FleetModel.lock_expires_at < now, + ), + or_( + FleetModel.lock_owner.is_(None), + FleetModel.lock_owner == FleetPipeline.__name__, + ), + ) + .order_by(FleetModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + FleetModel.id, + FleetModel.lock_token, + FleetModel.lock_expires_at, + ) + ) + ) + fleet_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for fleet_model in fleet_models: + prev_lock_expired = fleet_model.lock_expires_at is not None + fleet_model.lock_expires_at = lock_expires_at + fleet_model.lock_token = lock_token + fleet_model.lock_owner = FleetPipeline.__name__ + items.append( + PipelineItem( + __tablename__=FleetModel.__tablename__, + 
id=fleet_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + ) + ) + await session.commit() + return items + + +class FleetWorker(Worker[PipelineItem]): + def __init__( + self, + queue: asyncio.Queue[PipelineItem], + heartbeater: Heartbeater[PipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.FleetWorker.process") + async def process(self, item: PipelineItem): + async with get_session_ctx() as session: + res = await session.execute( + select(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .options(joinedload(FleetModel.project)) + .options( + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) + .joinedload(InstanceModel.jobs) + .load_only(JobModel.id), + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) + ) + ) + fleet_model = res.unique().scalar_one_or_none() + if fleet_model is None: + logger.warning( + "Failed to process %s item %s: lock_token mismatch." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + return + + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset( + InstanceModel.__tablename__ + ) + async with instance_lock: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.fleet_id == item.id, + InstanceModel.deleted == False, + # TODO: Lock instance models in the DB + # or_( + # InstanceModel.lock_expires_at.is_(None), + # InstanceModel.lock_expires_at < get_current_datetime(), + # ), + # or_( + # InstanceModel.lock_owner.is_(None), + # InstanceModel.lock_owner == FleetPipeline.__name__, + # ), + ) + .with_for_update(skip_locked=True, key_share=True) + ) + locked_instance_models = res.scalars().all() + if len(fleet_model.instances) != len(locked_instance_models): + logger.debug( + "Failed to lock fleet %s instances. The fleet will be processed later.", + item.id, + ) + now = get_current_datetime() + # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked + # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). + # Unset `lock_token` so that heartbeater can no longer update the item. + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == item.id, + FleetModel.lock_token == item.lock_token, + ) + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=now, + ) + ) + if res.rowcount == 0: # pyright: ignore[reportAttributeAccessIssue] + logger.warning( + "Failed to reset lock: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration." 
+ ) + return + + # TODO: Lock instance models in the DB + # for instance_model in locked_instance_models: + # instance_model.lock_expires_at = item.lock_expires_at + # instance_model.lock_token = item.lock_token + # instance_model.lock_owner = FleetPipeline.__name__ + # await session.commit() + + result = await _process_fleet(fleet_model) + fleet_update_map = _FleetUpdateMap() + fleet_update_map.update(result.fleet_update_map) + set_processed_update_map_fields(fleet_update_map) + set_unlock_update_map_fields(fleet_update_map) + instance_update_rows = _build_instance_update_rows(result.instance_id_to_update_map) + + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(fleet_update_map, now=now) + resolve_now_placeholders(instance_update_rows, now=now) + res = await session.execute( + update(FleetModel) + .where( + FleetModel.id == fleet_model.id, + FleetModel.lock_token == fleet_model.lock_token, + ) + .values(**fleet_update_map) + .returning(FleetModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.warning( + "Failed to update %s item %s after processing: lock_token changed." + " The item is expected to be processed and updated on another fetch iteration.", + item.__tablename__, + item.id, + ) + # TODO: Clean up fleet. + return + + if fleet_update_map.get("deleted"): + await session.execute( + update(PlacementGroupModel) + .where(PlacementGroupModel.fleet_id == item.id) + .values(fleet_deleted=True) + ) + if instance_update_rows: + await session.execute( + update(InstanceModel).execution_options(synchronize_session=False), + instance_update_rows, + ) + if result.new_instances_count > 0: + await _create_missing_fleet_instances( + session=session, + fleet_model=fleet_model, + new_instances_count=result.new_instances_count, + ) + emit_fleet_status_change_event( + session=session, + fleet_model=fleet_model, + old_status=fleet_model.status, + new_status=fleet_update_map.get("status", fleet_model.status), + status_message=fleet_update_map.get("status_message", fleet_model.status_message), + ) + + +class _FleetUpdateMap(ItemUpdateMap, total=False): + status: FleetStatus + status_message: str + deleted: bool + deleted_at: UpdateMapDateTime + consolidation_attempt: int + last_consolidated_at: UpdateMapDateTime + + +class _InstanceUpdateMap(TypedDict, total=False): + status: InstanceStatus + termination_reason: InstanceTerminationReason + termination_reason_message: str + deleted: bool + deleted_at: UpdateMapDateTime + last_processed_at: UpdateMapDateTime + id: uuid.UUID + + +@dataclass +class _ProcessResult: + fleet_update_map: _FleetUpdateMap = field(default_factory=_FleetUpdateMap) + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) + new_instances_count: int = 0 + + +@dataclass +class _MaintainNodesResult: + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap] = field(default_factory=dict) + new_instances_count: int = 0 + changes_required: bool = False + + @property + def has_changes(self) -> bool: + return len(self.instance_id_to_update_map) > 0 or self.new_instances_count > 0 + + +async def _process_fleet(fleet_model: FleetModel) -> _ProcessResult: + result = _consolidate_fleet_state_with_spec(fleet_model) + if result.new_instances_count > 0: + # Avoid deleting fleets that are about to provision new instances. 
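+        # The early return also skips the deletion check below; the fleet is
+        # re-evaluated for deletion on a later processing iteration, after the
+        # missing instances have been created.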
+ return result + delete = _should_delete_fleet(fleet_model) + if delete: + result.fleet_update_map["status"] = FleetStatus.TERMINATED + result.fleet_update_map["deleted"] = True + result.fleet_update_map["deleted_at"] = NOW_PLACEHOLDER + return result + + +def _consolidate_fleet_state_with_spec(fleet_model: FleetModel) -> _ProcessResult: + result = _ProcessResult() + if fleet_model.status == FleetStatus.TERMINATING: + return result + fleet_spec = get_fleet_spec(fleet_model) + if fleet_spec.configuration.nodes is None or fleet_spec.autocreated: + # Only explicitly created cloud fleets are consolidated. + return result + if not _is_fleet_ready_for_consolidation(fleet_model): + return result + maintain_nodes_result = _maintain_fleet_nodes_in_min_max_range(fleet_model, fleet_spec) + if maintain_nodes_result.has_changes: + result.instance_id_to_update_map = maintain_nodes_result.instance_id_to_update_map + result.new_instances_count = maintain_nodes_result.new_instances_count + if maintain_nodes_result.changes_required: + result.fleet_update_map["consolidation_attempt"] = fleet_model.consolidation_attempt + 1 + else: + # The fleet is consolidated with respect to nodes min/max. + result.fleet_update_map["consolidation_attempt"] = 0 + result.fleet_update_map["last_consolidated_at"] = NOW_PLACEHOLDER + return result + + +def _is_fleet_ready_for_consolidation(fleet_model: FleetModel) -> bool: + consolidation_retry_delay = _get_consolidation_retry_delay(fleet_model.consolidation_attempt) + last_consolidated_at = fleet_model.last_consolidated_at or fleet_model.last_processed_at + duration_since_last_consolidation = get_current_datetime() - last_consolidated_at + return duration_since_last_consolidation >= consolidation_retry_delay + + +# We use exponentially increasing consolidation retry delays so that +# consolidation does not happen too often. In particular, this prevents +# retrying instance provisioning constantly in case of no offers. +_CONSOLIDATION_RETRY_DELAYS = [ + timedelta(minutes=1), + timedelta(minutes=2), + timedelta(minutes=5), + timedelta(minutes=10), + timedelta(minutes=30), +] + + +def _get_consolidation_retry_delay(consolidation_attempt: int) -> timedelta: + if consolidation_attempt < len(_CONSOLIDATION_RETRY_DELAYS): + return _CONSOLIDATION_RETRY_DELAYS[consolidation_attempt] + return _CONSOLIDATION_RETRY_DELAYS[-1] + + +def _maintain_fleet_nodes_in_min_max_range( + fleet_model: FleetModel, + fleet_spec: FleetSpec, +) -> _MaintainNodesResult: + """ + Ensures the fleet has at least `nodes.min` and at most `nodes.max` instances. + """ + assert fleet_spec.configuration.nodes is not None + result = _MaintainNodesResult() + for instance in fleet_model.instances: + # Delete terminated but not deleted instances since + # they are going to be replaced with new pending instances. 
+ if instance.status == InstanceStatus.TERMINATED and not instance.deleted: + result.changes_required = True + result.instance_id_to_update_map[instance.id] = { + "deleted": True, + "deleted_at": NOW_PLACEHOLDER, + } + active_instances = [ + i for i in fleet_model.instances if i.status != InstanceStatus.TERMINATED and not i.deleted + ] + active_instances_num = len(active_instances) + if active_instances_num < fleet_spec.configuration.nodes.min: + result.changes_required = True + nodes_missing = fleet_spec.configuration.nodes.min - active_instances_num + result.new_instances_count = nodes_missing + return result + if ( + fleet_spec.configuration.nodes.max is None + or active_instances_num <= fleet_spec.configuration.nodes.max + ): + return result + # Fleet has more instances than allowed by nodes.max. + # This is possible due to race conditions (e.g. provisioning jobs in a fleet concurrently) + # or if nodes.max is updated. + result.changes_required = True + nodes_redundant = active_instances_num - fleet_spec.configuration.nodes.max + for instance in fleet_model.instances: + if nodes_redundant == 0: + break + if instance.status == InstanceStatus.IDLE: + result.instance_id_to_update_map[instance.id] = { + "termination_reason": InstanceTerminationReason.MAX_INSTANCES_LIMIT, + "termination_reason_message": "Fleet has too many instances", + "status": InstanceStatus.TERMINATING, + } + nodes_redundant -= 1 + return result + + +def _should_delete_fleet(fleet_model: FleetModel) -> bool: + if fleet_model.project.deleted: + # It used to be possible to delete project with active resources: + # https://github.com/dstackai/dstack/issues/3077 + logger.info("Fleet %s deleted due to deleted project", fleet_model.name) + return True + + if is_fleet_in_use(fleet_model) or not is_fleet_empty(fleet_model): + return False + + # TODO: Drop non-terminating fleets auto-deletion after dropping fleets auto-creation. + fleet_spec = get_fleet_spec(fleet_model) + if ( + fleet_model.status != FleetStatus.TERMINATING + and fleet_spec.configuration.nodes is not None + and fleet_spec.configuration.nodes.min == 0 + ): + # Empty fleets that allow 0 nodes should not be auto-deleted + return False + + logger.info("Automatic cleanup of an empty fleet %s", fleet_model.name) + return True + + +def _build_instance_update_rows( + instance_id_to_update_map: dict[uuid.UUID, _InstanceUpdateMap], +) -> list[_InstanceUpdateMap]: + instance_update_rows = [] + for instance_id, instance_update_map in instance_id_to_update_map.items(): + update_row = _InstanceUpdateMap() + update_row.update(instance_update_map) + update_row["id"] = instance_id + set_processed_update_map_fields(update_row) + instance_update_rows.append(update_row) + return instance_update_rows + + +async def _create_missing_fleet_instances( + session: AsyncSession, + fleet_model: FleetModel, + new_instances_count: int, +): + fleet_spec = get_fleet_spec(fleet_model) + res = await session.execute( + select(InstanceModel.instance_num).where( + InstanceModel.fleet_id == fleet_model.id, + InstanceModel.deleted == False, + ) + ) + taken_instance_nums = set(res.scalars().all()) + for _ in range(new_instances_count): + instance_num = get_next_instance_num(taken_instance_nums) + instance_model = create_fleet_instance_model( + session=session, + project=fleet_model.project, + # TODO: Store fleet.user and pass it instead of the project owner. 
+ username=fleet_model.project.owner.name, + spec=fleet_spec, + instance_num=instance_num, + ) + instance_model.fleet_id = fleet_model.id + taken_instance_nums.add(instance_num) + events.emit( + session=session, + message=( + "Instance created to meet target fleet node count." + f" Status: {instance_model.status.upper()}" + ), + actor=events.SystemActor(), + targets=[events.Target.from_model(instance_model)], + ) + logger.info( + "Added %d instances to fleet %s", + new_instances_count, + fleet_model.name, + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index cdd0904e1..2d5f0a947 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -2,7 +2,7 @@ import uuid from dataclasses import dataclass, field from datetime import timedelta -from typing import Optional, Sequence +from typing import Optional, Sequence, TypedDict from sqlalchemy import delete, or_, select, update from sqlalchemy.orm import joinedload, load_only @@ -14,12 +14,13 @@ from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, Worker, - get_processed_update_map, - get_unlock_update_map, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -227,13 +228,18 @@ async def _process_submitted_item(item: GatewayPipelineItem): return result = await _process_submitted_gateway(gateway_model) - update_map = result.update_map | get_processed_update_map() | get_unlock_update_map() + update_map = _GatewayUpdateMap() + update_map.update(result.update_map) + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: gateway_compute_model = result.gateway_compute_model if gateway_compute_model is not None: session.add(gateway_compute_model) await session.flush() update_map["gateway_compute_id"] = gateway_compute_model.id + now = get_current_datetime() + resolve_now_placeholders(update_map, now=now) res = await session.execute( update(GatewayModel) .where( @@ -262,9 +268,20 @@ async def _process_submitted_item(item: GatewayPipelineItem): ) +class _GatewayUpdateMap(ItemUpdateMap, total=False): + status: GatewayStatus + status_message: str + gateway_compute_id: uuid.UUID + + +class _GatewayComputeUpdateMap(TypedDict, total=False): + active: bool + deleted: bool + + @dataclass class _SubmittedResult: - update_map: UpdateMap = field(default_factory=dict) + update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) gateway_compute_model: Optional[GatewayComputeModel] = None @@ -337,15 +354,20 @@ async def _process_provisioning_item(item: GatewayPipelineItem): return result = await _process_provisioning_gateway(gateway_model) - update_map = result.gateway_update_map | get_processed_update_map() | get_unlock_update_map() + gateway_update_map = result.gateway_update_map + set_processed_update_map_fields(gateway_update_map) + set_unlock_update_map_fields(gateway_update_map) + async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(gateway_update_map, now=now) res = await session.execute( update(GatewayModel) .where( GatewayModel.id == gateway_model.id, GatewayModel.lock_token == 
gateway_model.lock_token, ) - .values(**update_map) + .values(**gateway_update_map) .returning(GatewayModel.id) ) updated_ids = list(res.scalars().all()) @@ -361,8 +383,8 @@ async def _process_provisioning_item(item: GatewayPipelineItem): session=session, gateway_model=gateway_model, old_status=gateway_model.status, - new_status=update_map.get("status", gateway_model.status), - status_message=update_map.get("status_message", gateway_model.status_message), + new_status=gateway_update_map.get("status", gateway_model.status), + status_message=gateway_update_map.get("status_message", gateway_model.status_message), ) if result.gateway_compute_update_map: res = await session.execute( @@ -383,8 +405,10 @@ async def _process_provisioning_item(item: GatewayPipelineItem): @dataclass class _ProvisioningResult: - gateway_update_map: UpdateMap = field(default_factory=dict) - gateway_compute_update_map: UpdateMap = field(default_factory=dict) + gateway_update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) + gateway_compute_update_map: _GatewayComputeUpdateMap = field( + default_factory=_GatewayComputeUpdateMap + ) async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: @@ -475,13 +499,17 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): targets=[events.Target.from_model(gateway_model)], ) else: + update_map = _GatewayUpdateMap() + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) + resolve_now_placeholders(update_map, now=get_current_datetime()) res = await session.execute( update(GatewayModel) .where( GatewayModel.id == gateway_model.id, GatewayModel.lock_token == gateway_model.lock_token, ) - .values(**get_processed_update_map()) + .values(**update_map) .returning(GatewayModel.id) ) updated_ids = list(res.scalars().all()) @@ -513,12 +541,14 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): @dataclass -class _DeletedResult: +class _ProcessToBeDeletedResult: delete_gateway: bool - gateway_compute_update_map: UpdateMap = field(default_factory=dict) + gateway_compute_update_map: _GatewayComputeUpdateMap = field( + default_factory=_GatewayComputeUpdateMap + ) -async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _DeletedResult: +async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _ProcessToBeDeletedResult: assert gateway_model.backend.type != BackendType.DSTACK backend = await backends_services.get_project_backend_by_type_or_error( project=gateway_model.project, backend_type=gateway_model.backend.type @@ -542,9 +572,9 @@ async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _Delete "Error when deleting gateway compute for %s", gateway_model.name, ) - return _DeletedResult(delete_gateway=False) + return _ProcessToBeDeletedResult(delete_gateway=False) logger.info("Deleted gateway compute for %s", gateway_model.name) - result = _DeletedResult(delete_gateway=True) + result = _ProcessToBeDeletedResult(delete_gateway=True) if gateway_model.gateway_compute is not None: await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) result.gateway_compute_update_map = {"active": False, "deleted": True} diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 193358ec0..703cfe154 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -1,5 +1,6 @@ import asyncio import uuid +from dataclasses import dataclass, field from datetime import timedelta from typing import Sequence @@ -9,14 +10,17 @@ from dstack._internal.core.backends.base.compute import ComputeWithPlacementGroupSupport from dstack._internal.core.errors import PlacementGroupInUseError from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, + UpdateMapDateTime, Worker, - get_processed_update_map, - get_unlock_update_map, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -193,15 +197,15 @@ async def process(self, item: PipelineItem): ) return - update_map = await _delete_placement_group(placement_group_model) - if update_map: + result = await _delete_placement_group(placement_group_model) + update_map = result.update_map + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) + if update_map.get("deleted", False): logger.info("Deleted placement group %s", placement_group_model.name) - else: - update_map = get_processed_update_map() - - update_map |= get_unlock_update_map() async with get_session_ctx() as session: + resolve_now_placeholders(update_map, now=get_current_datetime()) res = await session.execute( update(PlacementGroupModel) .where( @@ -221,13 +225,25 @@ async def process(self, item: PipelineItem): ) -async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> UpdateMap: +class _PlacementGroupUpdateMap(ItemUpdateMap, total=False): + deleted: bool + deleted_at: UpdateMapDateTime + + +@dataclass +class _DeleteResult: + update_map: _PlacementGroupUpdateMap = field(default_factory=_PlacementGroupUpdateMap) + + +async def _delete_placement_group( + placement_group_model: PlacementGroupModel, +) -> _DeleteResult: placement_group = placement_group_model_to_placement_group(placement_group_model) if placement_group.provisioning_data is None: logger.error( "Failed to delete placement group %s. provisioning_data is None.", placement_group.name ) - return _get_deleted_update_map() + return _get_deleted_result() backend = await backends_services.get_project_backend_by_type( project=placement_group_model.project, backend_type=placement_group.provisioning_data.backend, @@ -238,7 +254,7 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> "Failed to delete placement group %s. Backend not available. Please delete it manually.", placement_group.name, ) - return _get_deleted_update_map() + return _get_deleted_result() compute = backend.compute() assert isinstance(compute, ComputeWithPlacementGroupSupport) try: @@ -247,22 +263,18 @@ async def _delete_placement_group(placement_group_model: PlacementGroupModel) -> logger.info( "Placement group %s is still in use. Skipping deletion for now.", placement_group.name ) - return {} + return _DeleteResult() except Exception: # TODO: Retry deletion logger.exception( "Got exception when deleting placement group %s. 
Please delete it manually.", placement_group.name, ) - return _get_deleted_update_map() - - return _get_deleted_update_map() + return _get_deleted_result() -def _get_deleted_update_map() -> UpdateMap: - now = get_current_datetime() - return { - "last_processed_at": now, - "deleted": True, - "deleted_at": now, - } +def _get_deleted_result() -> _DeleteResult: + update_map = _PlacementGroupUpdateMap() + update_map["deleted"] = True + update_map["deleted_at"] = NOW_PLACEHOLDER + return _DeleteResult(update_map=update_map) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 578fe8423..c7a8f5761 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -11,14 +11,17 @@ from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.volumes import VolumeStatus from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, - UpdateMap, + UpdateMapDateTime, Worker, - get_processed_update_map, - get_unlock_update_map, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -233,8 +236,12 @@ async def _process_submitted_item(item: VolumePipelineItem): return result = await _process_submitted_volume(volume_model) - update_map = result.update_map | get_processed_update_map() | get_unlock_update_map() + update_map = result.update_map + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) + async with get_session_ctx() as session: + resolve_now_placeholders(update_map, now=get_current_datetime()) res = await session.execute( update(VolumeModel) .where( @@ -263,9 +270,17 @@ async def _process_submitted_item(item: VolumePipelineItem): ) +class _VolumeUpdateMap(ItemUpdateMap, total=False): + status: VolumeStatus + status_message: str + volume_provisioning_data: str + deleted: bool + deleted_at: UpdateMapDateTime + + @dataclass class _SubmittedResult: - update_map: UpdateMap = field(default_factory=dict) + update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) async def _process_submitted_volume(volume_model: VolumeModel) -> _SubmittedResult: @@ -363,8 +378,13 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): return result = await _process_to_be_deleted_volume(volume_model) - update_map = result.update_map | get_unlock_update_map() + update_map = _VolumeUpdateMap() + update_map.update(result.update_map) + set_processed_update_map_fields(update_map) + set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: + now = get_current_datetime() + resolve_now_placeholders(update_map, now=now) res = await session.execute( update(VolumeModel) .where( @@ -392,11 +412,11 @@ async def _process_to_be_deleted_item(item: VolumePipelineItem): @dataclass -class _DeletedResult: - update_map: UpdateMap = field(default_factory=dict) +class _ProcessToBeDeletedResult: + update_map: _VolumeUpdateMap = field(default_factory=_VolumeUpdateMap) -async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedResult: +async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _ProcessToBeDeletedResult: volume = volume_model_to_volume(volume_model) if 
volume.external: return _get_deleted_result() @@ -437,12 +457,10 @@ async def _process_to_be_deleted_volume(volume_model: VolumeModel) -> _DeletedRe return _get_deleted_result() -def _get_deleted_result() -> _DeletedResult: - now = get_current_datetime() - return _DeletedResult( +def _get_deleted_result() -> _ProcessToBeDeletedResult: + return _ProcessToBeDeletedResult( update_map={ - "last_processed_at": now, "deleted": True, - "deleted_at": now, + "deleted_at": NOW_PLACEHOLDER, } ) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 45ae8ec7f..9c7cd6ac1 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -102,13 +102,13 @@ def start_scheduled_tasks() -> AsyncIOScheduler: _scheduler.add_job( process_idle_volumes, IntervalTrigger(seconds=60, jitter=10), max_instances=1 ) - _scheduler.add_job( - process_fleets, - IntervalTrigger(seconds=10, jitter=2), - max_instances=1, - ) _scheduler.add_job(delete_instance_health_checks, IntervalTrigger(minutes=5), max_instances=1) if not FeatureFlags.PIPELINE_PROCESSING_ENABLED: + _scheduler.add_job( + process_fleets, + IntervalTrigger(seconds=10, jitter=2), + max_instances=1, + ) _scheduler.add_job( process_gateways, IntervalTrigger(seconds=10, jitter=2), max_instances=5 ) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py index feb1cc507..58d6b2c8b 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/compute_groups.py @@ -32,6 +32,8 @@ TERMINATION_RETRY_MAX_DURATION = timedelta(minutes=15) +# NOTE: This scheduled task is going to be deprecated in favor of `ComputeGroupPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/compute_groups.py` in sync. async def process_compute_groups(batch_size: int = 1): tasks = [] for _ in range(batch_size): diff --git a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py index a758f86ad..6b1ba7667 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/fleets.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/fleets.py @@ -5,10 +5,11 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only, selectinload, with_loader_criteria +from sqlalchemy.orm import joinedload, load_only, selectinload from dstack._internal.core.models.fleets import FleetSpec, FleetStatus from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import RunStatus from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( FleetModel, @@ -39,6 +40,8 @@ MIN_PROCESSING_INTERVAL = timedelta(seconds=30) +# NOTE: This scheduled task is going to be deprecated in favor of `FleetPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/fleets.py` in sync. 
@sentry_utils.instrument_scheduled_task async def process_fleets(): fleet_lock, fleet_lockset = get_locker(get_db().dialect_name).get_lockset( @@ -59,10 +62,9 @@ async def process_fleets(): ) .options( load_only(FleetModel.id, FleetModel.name), - selectinload(FleetModel.instances).load_only(InstanceModel.id), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True - ), + selectinload( + FleetModel.instances.and_(InstanceModel.deleted == False) + ).load_only(InstanceModel.id), ) .order_by(FleetModel.last_processed_at.asc()) .limit(BATCH_SIZE) @@ -115,14 +117,17 @@ async def _process_fleets(session: AsyncSession, fleet_models: List[FleetModel]) res = await session.execute( select(FleetModel) .where(FleetModel.id.in_(fleet_ids)) + .options(joinedload(FleetModel.project)) .options( - joinedload(FleetModel.instances).joinedload(InstanceModel.jobs).load_only(JobModel.id), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True - ), + selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)) + .joinedload(InstanceModel.jobs) + .load_only(JobModel.id), + ) + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.status) ) - .options(joinedload(FleetModel.project)) - .options(joinedload(FleetModel.runs).load_only(RunModel.status)) .execution_options(populate_existing=True) ) fleet_models = list(res.unique().scalars().all()) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py index fc12e8e3b..262f45a18 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/gateways.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/gateways.py @@ -35,6 +35,8 @@ async def process_gateways_connections(): await _process_active_connections() +# NOTE: This scheduled task is going to be deprecated in favor of `GatewayPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/gateways.py` in sync. @sentry_utils.instrument_scheduled_task async def process_gateways(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(GatewayModel.__tablename__) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py index 71ab51b07..1106ce491 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/placement_groups.py @@ -19,6 +19,8 @@ logger = get_logger(__name__) +# NOTE: This scheduled task is going to be deprecated in favor of `PlacementGroupPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/placement_groups.py` in sync. 
@sentry_utils.instrument_scheduled_task async def process_placement_groups(): lock, lockset = get_locker(get_db().dialect_name).get_lockset( diff --git a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py index 5d1b2e1a7..151f07dee 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/submitted_jobs.py @@ -13,7 +13,6 @@ load_only, noload, selectinload, - with_loader_criteria, ) from dstack._internal.core.backends.base.backend import Backend @@ -223,9 +222,8 @@ async def _process_submitted_job( .where(JobModel.id == job_model.id) .options(joinedload(JobModel.instance)) .options( - joinedload(JobModel.fleet).joinedload(FleetModel.instances), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True + joinedload(JobModel.fleet).selectinload( + FleetModel.instances.and_(InstanceModel.deleted == False) ), ) ) @@ -236,9 +234,8 @@ async def _process_submitted_job( .options(joinedload(RunModel.project).joinedload(ProjectModel.backends)) .options(joinedload(RunModel.user).load_only(UserModel.name)) .options( - joinedload(RunModel.fleet).joinedload(FleetModel.instances), - with_loader_criteria( - InstanceModel, InstanceModel.deleted == False, include_aliases=True + joinedload(RunModel.fleet).selectinload( + FleetModel.instances.and_(InstanceModel.deleted == False) ), ) ) @@ -584,6 +581,8 @@ async def _fetch_fleet_with_master_instance_provisioning_data( # To avoid violating fleet placement cluster during master provisioning, # we must lock empty fleets and respect existing instances in non-empty fleets. # On SQLite always take the lock during master provisioning for simplicity. + # It's fine to lock fleets currently locked by pipelines (with lock_* fields set) + # since we won't update fleets – we only need to ensure there is no parallel provisioning. await exit_stack.enter_async_context( get_locker(get_db().dialect_name).lock_ctx( FleetModel.__tablename__, [fleet_model.id] diff --git a/src/dstack/_internal/server/background/scheduled_tasks/volumes.py b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py index a61f79694..11e6f3c59 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/volumes.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/volumes.py @@ -24,6 +24,8 @@ logger = get_logger(__name__) +# NOTE: This scheduled task is going to be deprecated in favor of `VolumePipeline`. +# If this logic changes before removal, keep `pipeline_tasks/volumes.py` in sync. 
@sentry_utils.instrument_scheduled_task async def process_submitted_volumes(): lock, lockset = get_locker(get_db().dialect_name).get_lockset(VolumeModel.__tablename__) diff --git a/src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py new file mode 100644 index 000000000..fad3da790 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/02_27_1218_d21d3e61de27_add_fleetmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add FleetModel pipeline columns + +Revision ID: d21d3e61de27 +Revises: 9a363c3cbe04 +Create Date: 2026-02-27 12:18:01.768776+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "d21d3e61de27" +down_revision = "9a363c3cbe04" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("fleets", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py new file mode 100644 index 000000000..365aac41c --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_02_0530_46150101edec_add_ix_fleets_pipeline_fetch_q_index.py @@ -0,0 +1,49 @@ +"""Add ix_fleets_pipeline_fetch_q index + +Revision ID: 46150101edec +Revises: d21d3e61de27 +Create Date: 2026-03-02 05:30:07.196407+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "46150101edec" +down_revision = "d21d3e61de27" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_pipeline_fetch_q", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_fleets_pipeline_fetch_q", + "fleets", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("deleted = 0"), + postgresql_where=sa.text("deleted IS FALSE"), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_fleets_pipeline_fetch_q", + table_name="fleets", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index a7a8ec0bd..15c5488da 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -576,7 +576,7 @@ class PoolModel(BaseModel): instances: Mapped[List["InstanceModel"]] = relationship(back_populates="pool", lazy="selectin") -class FleetModel(BaseModel): +class FleetModel(PipelineModelMixin, BaseModel): __tablename__ = "fleets" id: Mapped[uuid.UUID] = mapped_column( @@ -604,9 +604,20 @@ class FleetModel(BaseModel): jobs: Mapped[List["JobModel"]] = relationship(back_populates="fleet") instances: Mapped[List["InstanceModel"]] = relationship(back_populates="fleet") + # `consolidation_attempt` counts how many times in a row fleet needed consolidation. + # Allows increasing delays between attempts. consolidation_attempt: Mapped[int] = mapped_column(Integer, server_default="0") last_consolidated_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + __table_args__ = ( + Index( + "ix_fleets_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=deleted == false(), + sqlite_where=deleted == false(), + ), + ) + class InstanceModel(BaseModel): __tablename__ = "instances" diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index c0483ebb6..c0ec21aea 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -102,11 +102,45 @@ def switch_fleet_status( return fleet_model.status = new_status + emit_fleet_status_change_event( + session=session, + fleet_model=fleet_model, + old_status=old_status, + new_status=new_status, + status_message=fleet_model.status_message, + actor=actor, + ) - msg = f"Fleet status changed {old_status.upper()} -> {new_status.upper()}" + +def emit_fleet_status_change_event( + session: AsyncSession, + fleet_model: FleetModel, + old_status: FleetStatus, + new_status: FleetStatus, + status_message: Optional[str], + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + msg = get_fleet_status_change_message( + old_status=old_status, + new_status=new_status, + status_message=status_message, + ) events.emit(session, msg, actor=actor, targets=[events.Target.from_model(fleet_model)]) +def get_fleet_status_change_message( + old_status: FleetStatus, + new_status: FleetStatus, + status_message: Optional[str], +) -> str: + msg = f"Fleet status changed {old_status.upper()} -> {new_status.upper()}" + if status_message is not None: + msg += f" ({status_message})" + return msg + + async def list_projects_with_no_active_fleets( session: AsyncSession, user: UserModel, @@ -225,7 +259,7 @@ async def list_projects_fleet_models( .where(*filters) .order_by(*order_by) .limit(limit) - .options(joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) ) fleet_models = list(res.unique().scalars().all()) return fleet_models @@ -256,7 +290,7 @@ async def list_project_fleet_models( res = await session.execute( select(FleetModel) .where(*filters) - .options(joinedload(FleetModel.instances.and_(InstanceModel.deleted == False))) + .options(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) 
) return list(res.unique().scalars().all()) @@ -485,13 +519,24 @@ async def apply_plan( .joinedload(InstanceModel.jobs) .load_only(JobModel.id) ) - .options(selectinload(FleetModel.runs)) + # `is_fleet_in_use()` only needs active run presence/status. + .options( + selectinload( + FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) + ).load_only(RunModel.id, RunModel.status) + ) .execution_options(populate_existing=True) .order_by(FleetModel.id) # take locks in order .with_for_update(key_share=True) ) fleet_model = res.scalars().unique().one_or_none() if fleet_model is not None: + if fleet_model.lock_expires_at is not None: + # TODO: Make the endpoint fully async so we don't need to lock and error: + # put the request in queue and process in the background. + raise ServerClientError( + "Failed to update fleet: fleet is being processed currently. Try again later." + ) return await _update_fleet( session=session, user=user, @@ -629,8 +674,7 @@ async def delete_fleets( FleetModel.name.in_(names), FleetModel.deleted == False, ) - .order_by(FleetModel.id) # take locks in order - .with_for_update(key_share=True) + .order_by(FleetModel.id) ) fleets_ids = list(res.scalars().unique().all()) res = await session.execute( @@ -639,8 +683,7 @@ async def delete_fleets( InstanceModel.fleet_id.in_(fleets_ids), InstanceModel.deleted == False, ) - .order_by(InstanceModel.id) # take locks in order - .with_for_update(key_share=True) + .order_by(InstanceModel.id) ) instances_ids = list(res.scalars().unique().all()) if is_db_sqlite(): @@ -654,22 +697,56 @@ async def delete_fleets( # TODO: Do not lock fleet when deleting only instances. res = await session.execute( select(FleetModel) - .where(FleetModel.id.in_(fleets_ids)) + .where( + FleetModel.project_id == project.id, + FleetModel.id.in_(fleets_ids), + FleetModel.deleted == False, + FleetModel.lock_expires_at.is_(None), + ) .options( - joinedload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) - .joinedload(InstanceModel.jobs) + selectinload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids))) + .selectinload(InstanceModel.jobs) .load_only(JobModel.id) ) .options( - joinedload( + selectinload( FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses())) - ) + ).load_only(RunModel.status) ) .execution_options(populate_existing=True) + .order_by(FleetModel.id) # take locks in order + .with_for_update(key_share=True, of=FleetModel) ) fleet_models = res.scalars().unique().all() - fleets = [fleet_model_to_fleet(m) for m in fleet_models] - for fleet in fleets: + if len(fleet_models) != len(fleets_ids): + # TODO: Make the endpoint fully async so we don't need to lock and error: + # put the request in queue and process in the background. + msg = ( + "Failed to delete fleets: fleets are being processed currently. Try again later." + if instance_nums is None + else "Failed to delete fleet instances: fleets are being processed currently. Try again later." + ) + raise ServerClientError(msg) + res = await session.execute( + select(InstanceModel.id) + .where( + InstanceModel.id.in_(instances_ids), + InstanceModel.deleted == False, + ) + .order_by(InstanceModel.id) # take locks in order + .with_for_update(key_share=True, of=InstanceModel) + .execution_options(populate_existing=True) + ) + instance_models_ids = list(res.scalars().unique().all()) + if len(instance_models_ids) != len(instances_ids): + msg = ( + "Failed to delete fleets: fleet instances are being processed currently. Try again later." 
+ if instance_nums is None + else "Failed to delete fleet instances: fleet instances are being processed currently. Try again later." + ) + raise ServerClientError(msg) + for fleet_model in fleet_models: + fleet = fleet_model_to_fleet(fleet_model) if fleet.spec.configuration.ssh_config is not None: _check_can_manage_ssh_fleets(user=user, project=project) if instance_nums is None: diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 762af8bef..ddc3d64c4 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -356,7 +356,7 @@ async def _delete_gateways_pipeline( ) gateway_models = res.scalars().all() if len(gateway_models) != len(gateways_ids): - # TODO: Make the delete endpoint fully async so we don't need to lock and error: + # TODO: Make the endpoint fully async so we don't need to lock and error: # put the request in queue and process in the background. raise ServerClientError( "Failed to delete gateways: gateways are being processed currently. Try again later." diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index f0d2fc703..1c846c724 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -369,7 +369,7 @@ async def _delete_volumes_pipeline( ) volume_models = res.scalars().unique().all() if len(volume_models) != len(volumes_ids): - # TODO: Make the delete endpoint fully async so we don't need to lock and error: + # TODO: Make the endpoint fully async so we don't need to lock and error: # put the request in queue and process in the background. raise ServerClientError( "Failed to delete volumes: volumes are being processed currently. Try again later." 
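Not part of the patch, just context for the hunks above: the models.py change wires FleetModel into the pipeline machinery via PipelineModelMixin and adds the partial index ix_fleets_pipeline_fetch_q on last_processed_at (restricted to deleted == false()), while the fleets service now rejects apply/delete requests whenever lock_expires_at is set. The fleet fetcher itself (pipeline_tasks/fleets.py) is not shown in this diff, so the sketch below only illustrates the query shape such an index typically serves; the 10-second interval, the limit, the skip_locked/key_share options, and the backoff schedule are assumptions, not the actual implementation.

    # Sketch only, assuming the fleet fetcher mirrors the other pipelines' fetch queries.
    from datetime import datetime, timedelta

    from sqlalchemy import or_, select

    from dstack._internal.server.models import FleetModel


    def build_fleet_fetch_query(now: datetime, limit: int = 10):
        return (
            select(FleetModel.id)
            .where(
                # Same predicate and ordering as ix_fleets_pipeline_fetch_q,
                # so the scan can stay on the partial index.
                FleetModel.deleted == False,  # noqa: E712
                # Skip fleets processed too recently (the interval is an assumption).
                FleetModel.last_processed_at < now - timedelta(seconds=10),
                # Skip fleets still locked by another worker or by an API request.
                or_(FleetModel.lock_expires_at.is_(None), FleetModel.lock_expires_at < now),
            )
            .order_by(FleetModel.last_processed_at.asc())
            .limit(limit)
            .with_for_update(key_share=True, skip_locked=True)
        )


    def consolidation_retry_delay(attempt: int) -> timedelta:
        # Assumed shape of the backoff implied by `consolidation_attempt` and
        # `last_consolidated_at`; the real schedule is not shown in this diff.
        return min(timedelta(minutes=1) * (2**attempt), timedelta(minutes=30))

Under these assumptions, a fleet would only be reconsolidated once now minus last_consolidated_at exceeds consolidation_retry_delay(consolidation_attempt), which is the behavior exercised by test_consolidation_respects_retry_delay in the new test file below.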
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py new file mode 100644 index 000000000..746ddf2ea --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -0,0 +1,398 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import Mock + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.fleets import FleetNodesSpec, FleetStatus +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.runs import RunStatus +from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.background.pipeline_tasks.base import PipelineItem +from dstack._internal.server.background.pipeline_tasks.fleets import ( + FleetWorker, +) +from dstack._internal.server.models import FleetModel, InstanceModel +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_fleet, + create_instance, + create_placement_group, + create_project, + create_repo, + create_run, + create_user, + get_fleet_spec, +) + + +@pytest.fixture +def worker() -> FleetWorker: + return FleetWorker(queue=Mock(), heartbeater=Mock()) + + +def _fleet_to_pipeline_item(fleet: FleetModel) -> PipelineItem: + assert fleet.lock_token is not None + assert fleet.lock_expires_at is not None + return PipelineItem( + __tablename__=fleet.__tablename__, + id=fleet.id, + lock_token=fleet.lock_token, + lock_expires_at=fleet.lock_expires_at, + prev_lock_expired=False, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestFleetWorker: + async def test_deletes_empty_autocreated_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.autocreated = True + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.deleted + + async def test_deletes_terminating_user_fleet( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.autocreated = False + fleet = await create_fleet( + session=session, + project=project, + status=FleetStatus.TERMINATING, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert fleet.deleted + + async def test_does_not_delete_fleet_with_active_run( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + user = await create_user(session=session, global_role=GlobalRole.USER) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + 
status=RunStatus.RUNNING, + ) + fleet.runs.append(run) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert not fleet.deleted + + async def test_does_not_delete_fleet_with_instance( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + user = await create_user(session=session, global_role=GlobalRole.USER) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + ) + fleet.instances.append(instance) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + assert not fleet.deleted + + async def test_consolidation_creates_missing_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=1, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = (await session.execute(select(InstanceModel))).scalars().all() + assert len(instances) == 2 + assert {i.instance_num for i in instances} == {0, 1} + assert fleet.consolidation_attempt == 1 + + async def test_consolidation_terminates_redundant_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance1 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + instance_num=0, + ) + instance2 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=1, + ) + instance3 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + instance_num=2, + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance1) + await session.refresh(instance2) + await session.refresh(instance3) + assert instance1.status == InstanceStatus.BUSY + assert instance2.status == InstanceStatus.TERMINATING + assert instance3.deleted + assert fleet.consolidation_attempt == 1 + + async def test_consolidation_attempt_increments_when_over_max_and_no_idle_instances( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + 
spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + instance1 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + instance_num=0, + ) + instance2 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.BUSY, + instance_num=1, + ) + + fleet.consolidation_attempt = 2 + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(instance1) + await session.refresh(instance2) + assert instance1.status == InstanceStatus.BUSY + assert instance2.status == InstanceStatus.BUSY + assert fleet.consolidation_attempt == 3 + + async def test_marks_placement_groups_fleet_deleted_on_fleet_delete( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + status=FleetStatus.TERMINATING, + ) + placement_group1 = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test-pg-1", + ) + placement_group2 = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test-pg-2", + ) + + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + await session.refresh(placement_group1) + await session.refresh(placement_group2) + assert fleet.deleted + assert placement_group1.fleet_deleted + assert placement_group2.fleet_deleted + + async def test_consolidation_respects_retry_delay( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=2, target=2, max=2) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + fleet.consolidation_attempt = 1 + fleet.last_consolidated_at = datetime.now(timezone.utc) + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = ( + ( + await session.execute( + select(InstanceModel).where( + InstanceModel.fleet_id == fleet.id, + InstanceModel.deleted == False, + ) + ) + ) + .scalars() + .all() + ) + assert len(instances) == 1 + assert fleet.consolidation_attempt == 1 + assert not fleet.deleted + + async def test_consolidation_attempt_resets_when_no_changes( + self, test_db, session: AsyncSession, worker: FleetWorker + ): + project = await create_project(session) + spec = get_fleet_spec() + spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=1) + fleet = await create_fleet( + session=session, + project=project, + spec=spec, + ) + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + instance_num=0, + ) + fleet.consolidation_attempt = 3 + previous_last_consolidated_at = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc) + fleet.last_consolidated_at 
= previous_last_consolidated_at + fleet.lock_token = uuid.uuid4() + fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + await worker.process(_fleet_to_pipeline_item(fleet)) + + await session.refresh(fleet) + instances = ( + ( + await session.execute( + select(InstanceModel).where( + InstanceModel.fleet_id == fleet.id, + InstanceModel.deleted == False, + ) + ) + ) + .scalars() + .all() + ) + assert len(instances) == 1 + assert fleet.consolidation_attempt == 0 + assert ( + fleet.last_consolidated_at is not None + and fleet.last_consolidated_at > previous_last_consolidated_at + ) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index 9628451bd..59cbd370e 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -257,6 +257,7 @@ async def test_keeps_gateway_if_terminate_fails( ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.lock_owner = "GatewayPipeline" gateway.to_be_deleted = True original_last_processed_at = gateway.last_processed_at await session.commit() @@ -286,6 +287,9 @@ async def test_keeps_gateway_if_terminate_fails( await session.refresh(gateway_compute) assert gateway.to_be_deleted is True assert gateway.last_processed_at > original_last_processed_at + assert gateway.lock_token is None + assert gateway.lock_expires_at is None + assert gateway.lock_owner is None assert gateway_compute.active is True assert gateway_compute.deleted is False events = await list_events(session) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py index 7baed58b6..c23d5e604 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_placement_groups.py @@ -5,6 +5,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.errors import PlacementGroupInUseError from dstack._internal.server.background.pipeline_tasks.base import PipelineItem from dstack._internal.server.background.pipeline_tasks.placement_groups import PlacementGroupWorker from dstack._internal.server.models import PlacementGroupModel @@ -62,3 +63,41 @@ async def test_deletes_placement_group( aws_mock.compute.return_value.delete_placement_group.assert_called_once() await session.refresh(placement_group) assert placement_group.deleted + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_retries_placement_group_deletion_if_still_in_use( + self, test_db, session: AsyncSession, worker: PlacementGroupWorker + ): + project = await create_project(session) + fleet = await create_fleet( + session=session, + project=project, + ) + placement_group = await create_placement_group( + session=session, + project=project, + fleet=fleet, + name="test2-pg", + fleet_deleted=True, + ) + placement_group.lock_token = uuid.uuid4() + placement_group.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + placement_group.lock_owner = "PlacementGroupPipeline" + original_last_processed_at = placement_group.last_processed_at + await session.commit() + with 
patch("dstack._internal.server.services.backends.get_project_backend_by_type") as m: + aws_mock = Mock() + m.return_value = aws_mock + aws_mock.compute.return_value = Mock(spec=ComputeMockSpec) + aws_mock.compute.return_value.delete_placement_group.side_effect = ( + PlacementGroupInUseError() + ) + await worker.process(_placement_group_to_pipeline_item(placement_group)) + aws_mock.compute.return_value.delete_placement_group.assert_called_once() + await session.refresh(placement_group) + assert not placement_group.deleted + assert placement_group.last_processed_at > original_last_processed_at + assert placement_group.lock_token is None + assert placement_group.lock_expires_at is None + assert placement_group.lock_owner is None diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py index 2ef1b27ab..2136a2c96 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_fleets.py @@ -154,8 +154,17 @@ async def test_consolidation_terminates_redundant_instances( status=InstanceStatus.IDLE, instance_num=1, ) + instance3 = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.TERMINATED, + instance_num=2, + ) await process_fleets() await session.refresh(instance1) await session.refresh(instance2) + await session.refresh(instance3) assert instance1.status == InstanceStatus.BUSY assert instance2.status == InstanceStatus.TERMINATING + assert instance3.deleted diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 1a250612b..02a4430b7 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -931,6 +931,37 @@ async def test_returns_400_when_fleets_in_use( assert not fleet.deleted assert instance.status == InstanceStatus.BUSY + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_400_when_fleet_locked( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + ) + fleet.instances.append(instance) + fleet.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc) + await session.commit() + + response = await client.post( + f"/api/project/{project.name}/fleets/delete", + headers=get_auth_headers(user.token), + json={"names": [fleet.name]}, + ) + assert response.status_code == 400 + + await session.refresh(fleet) + await session.refresh(instance) + assert fleet.status != FleetStatus.TERMINATING + assert instance.status != InstanceStatus.TERMINATING + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_forbids_if_no_permission_to_manage_ssh_fleets( @@ -1057,6 +1088,38 @@ async def test_returns_400_when_deleting_busy_instances( assert instance.status != InstanceStatus.TERMINATING assert fleet.status != FleetStatus.TERMINATING + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def 
test_returns_400_when_fleet_locked( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + instance_num=1, + ) + fleet.instances.append(instance) + fleet.lock_expires_at = datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc) + await session.commit() + + response = await client.post( + f"/api/project/{project.name}/fleets/delete_instances", + headers=get_auth_headers(user.token), + json={"name": fleet.name, "instance_nums": [1]}, + ) + assert response.status_code == 400 + + await session.refresh(fleet) + await session.refresh(instance) + assert fleet.status != FleetStatus.TERMINATING + assert instance.status != InstanceStatus.TERMINATING + class TestGetPlan: @pytest.mark.asyncio