Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dstack._internal.server.background.pipeline_tasks.base import Pipeline
from dstack._internal.server.background.pipeline_tasks.compute_groups import ComputeGroupPipeline
from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline
from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline
from dstack._internal.server.background.pipeline_tasks.placement_groups import (
PlacementGroupPipeline,
Expand All @@ -16,6 +17,7 @@ class PipelineManager:
def __init__(self) -> None:
self._pipelines: list[Pipeline] = [
ComputeGroupPipeline(),
FleetPipeline(),
GatewayPipeline(),
PlacementGroupPipeline(),
VolumePipeline(),
Expand Down
86 changes: 76 additions & 10 deletions src/dstack/_internal/server/background/pipeline_tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
import random
import uuid
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, ClassVar, Generic, Optional, Protocol, Sequence, TypeVar
from typing import (
Any,
ClassVar,
Final,
Generic,
Optional,
Protocol,
TypedDict,
TypeVar,
Union,
)

from sqlalchemy import and_, or_, update
from sqlalchemy.orm import Mapped
Expand Down Expand Up @@ -337,16 +348,71 @@ async def process(self, item: ItemT):
pass


UpdateMap = dict[str, Any]
class _NowPlaceholder:
    # Sentinel type. The single instance below stands in for "the current
    # time" inside update maps until resolve_now_placeholders() substitutes
    # one shared datetime for the whole transaction.
    pass


NOW_PLACEHOLDER: Final = _NowPlaceholder()
"""
Use `NOW_PLACEHOLDER` together with `resolve_now_placeholders()` in pipeline update maps
instead of `get_current_time()` to have the same current time for all updates in the transaction.
"""


# A datetime-valued update-map entry: either a concrete datetime or the
# NOW_PLACEHOLDER sentinel awaiting resolution.
UpdateMapDateTime = Union[datetime, _NowPlaceholder]


class _UnlockUpdateMap(TypedDict, total=False):
    # Lock-release columns; set to None to make the item claimable again.
    lock_expires_at: Optional[datetime]
    lock_token: Optional[uuid.UUID]
    lock_owner: Optional[str]


class _ProcessedUpdateMap(TypedDict, total=False):
    # Bookkeeping column recording when the item was last processed.
    last_processed_at: UpdateMapDateTime


class ItemUpdateMap(_UnlockUpdateMap, _ProcessedUpdateMap, total=False):
    """Combined update map with both the unlock and the processed-at fields."""

    pass


def set_unlock_update_map_fields(update_map: _UnlockUpdateMap):
    """Clear the lock columns in `update_map` so other workers can claim the item."""
    update_map["lock_expires_at"] = None
    update_map["lock_token"] = None
    update_map["lock_owner"] = None


def set_processed_update_map_fields(
    update_map: _ProcessedUpdateMap,
    now: UpdateMapDateTime = NOW_PLACEHOLDER,
):
    """Mark the item as processed at `now` (defaults to NOW_PLACEHOLDER, resolved later)."""
    update_map["last_processed_at"] = now


class _ResolveNowUpdateMap(Protocol):
    # Structural type for anything dict-like enough to enumerate key/value pairs.
    def items(self) -> Iterable[tuple[str, object]]: ...


_ResolveNowInput = Union[_ResolveNowUpdateMap, Sequence[_ResolveNowUpdateMap]]


def resolve_now_placeholders(update_values: _ResolveNowInput, now: datetime):
    """
    Replaces `NOW_PLACEHOLDER` with `now` in an update map or a sequence of update rows.
    """
    if isinstance(update_values, Sequence):
        for update_row in update_values:
            resolve_now_placeholders(update_row, now)
        return
    # Runtime dict narrowing is required here: pyright doesn't model TypedDicts as
    # supporting generic dynamic-key mutation via protocol methods.
    if not isinstance(update_values, dict):
        raise TypeError(
            "resolve_now_placeholders() expects update maps or sequences of update maps"
        )
    for key, value in update_values.items():
        if value is NOW_PLACEHOLDER:
            update_values[key] = now
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Sequence
from typing import Sequence, TypedDict

from sqlalchemy import or_, select, update
from sqlalchemy.orm import joinedload, load_only
Expand All @@ -12,14 +12,17 @@
from dstack._internal.core.models.compute_groups import ComputeGroupStatus
from dstack._internal.core.models.instances import InstanceStatus
from dstack._internal.server.background.pipeline_tasks.base import (
NOW_PLACEHOLDER,
Fetcher,
Heartbeater,
ItemUpdateMap,
Pipeline,
PipelineItem,
UpdateMap,
UpdateMapDateTime,
Worker,
get_processed_update_map,
get_unlock_update_map,
resolve_now_placeholders,
set_processed_update_map_fields,
set_unlock_update_map_fields,
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import ComputeGroupModel, InstanceModel, ProjectModel
Expand Down Expand Up @@ -199,25 +202,28 @@ async def process(self, item: PipelineItem):
)
return

terminate_result = _TerminateResult()
result = _TerminateResult()
# TODO: Fetch only compute groups with all instances terminating.
if all(i.status == InstanceStatus.TERMINATING for i in compute_group_model.instances):
terminate_result = await _terminate_compute_group(compute_group_model)
if terminate_result.compute_group_update_map:
result = await _terminate_compute_group(compute_group_model)
set_processed_update_map_fields(result.compute_group_update_map)
if result.instances_update_map:
set_processed_update_map_fields(result.instances_update_map)
set_unlock_update_map_fields(result.compute_group_update_map)
if result.compute_group_update_map.get("deleted", False):
logger.info("Terminated compute group %s", compute_group_model.id)
else:
terminate_result.compute_group_update_map = get_processed_update_map()

terminate_result.compute_group_update_map |= get_unlock_update_map()

async with get_session_ctx() as session:
now = get_current_datetime()
resolve_now_placeholders(result.compute_group_update_map, now=now)
resolve_now_placeholders(result.instances_update_map, now=now)
res = await session.execute(
update(ComputeGroupModel)
.where(
ComputeGroupModel.id == compute_group_model.id,
ComputeGroupModel.lock_token == compute_group_model.lock_token,
)
.values(**terminate_result.compute_group_update_map)
.values(**result.compute_group_update_map)
.returning(ComputeGroupModel.id)
)
updated_ids = list(res.scalars().all())
Expand All @@ -229,13 +235,13 @@ async def process(self, item: PipelineItem):
item.id,
)
return
if not terminate_result.instances_update_map:
if not result.instances_update_map:
return
instances_ids = [i.id for i in compute_group_model.instances]
res = await session.execute(
update(InstanceModel)
.where(InstanceModel.id.in_(instances_ids))
.values(**terminate_result.instances_update_map)
.values(**result.instances_update_map)
)
for instance_model in compute_group_model.instances:
emit_instance_status_change_event(
Expand All @@ -246,10 +252,28 @@ async def process(self, item: PipelineItem):
)


class _ComputeGroupUpdateMap(ItemUpdateMap, total=False):
    # Columns written back to the ComputeGroupModel row in one UPDATE.
    # UpdateMapDateTime fields may hold NOW_PLACEHOLDER until
    # resolve_now_placeholders() substitutes the transaction's datetime.
    status: ComputeGroupStatus
    deleted: bool
    deleted_at: UpdateMapDateTime
    first_termination_retry_at: UpdateMapDateTime
    last_termination_retry_at: UpdateMapDateTime


class _InstanceBulkUpdateMap(TypedDict, total=False):
    # Columns applied in bulk to every InstanceModel row of the compute group.
    # UpdateMapDateTime fields may hold NOW_PLACEHOLDER until resolved.
    last_processed_at: UpdateMapDateTime
    deleted: bool
    deleted_at: UpdateMapDateTime
    finished_at: UpdateMapDateTime
    status: InstanceStatus


@dataclass
class _TerminateResult:
    # Result of a compute-group termination attempt: one update map for the
    # ComputeGroupModel row and one bulk map for its instances.
    # Calling a TypedDict class with no arguments produces an empty dict,
    # so both fields default to empty update maps.
    compute_group_update_map: _ComputeGroupUpdateMap = field(
        default_factory=_ComputeGroupUpdateMap
    )
    instances_update_map: _InstanceBulkUpdateMap = field(default_factory=_InstanceBulkUpdateMap)


async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _TerminateResult:
Expand Down Expand Up @@ -283,15 +307,15 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T
compute_group,
)
except Exception as e:
retry_at = get_current_datetime()
first_termination_retry_at = compute_group_model.first_termination_retry_at
if compute_group_model.first_termination_retry_at is None:
result.compute_group_update_map["first_termination_retry_at"] = get_current_datetime()
result.compute_group_update_map["last_termination_retry_at"] = get_current_datetime()
if _next_termination_retry_at(
result.compute_group_update_map["last_termination_retry_at"]
) < _get_termination_deadline(
result.compute_group_update_map.get(
"first_termination_retry_at", compute_group_model.first_termination_retry_at
)
result.compute_group_update_map["first_termination_retry_at"] = NOW_PLACEHOLDER
first_termination_retry_at = retry_at
assert first_termination_retry_at is not None
result.compute_group_update_map["last_termination_retry_at"] = NOW_PLACEHOLDER
if _next_termination_retry_at(retry_at) < _get_termination_deadline(
first_termination_retry_at
):
logger.warning(
"Failed to terminate compute group %s. Will retry. Error: %r",
Expand All @@ -309,11 +333,9 @@ async def _terminate_compute_group(compute_group_model: ComputeGroupModel) -> _T
exc_info=not isinstance(e, BackendError),
)
terminated_result = _get_terminated_result()
return _TerminateResult(
compute_group_update_map=result.compute_group_update_map
| terminated_result.compute_group_update_map,
instances_update_map=result.instances_update_map | terminated_result.instances_update_map,
)
terminated_result.compute_group_update_map.update(result.compute_group_update_map)
terminated_result.instances_update_map.update(result.instances_update_map)
return terminated_result


def _next_termination_retry_at(last_termination_retry_at: datetime) -> datetime:
Expand All @@ -325,19 +347,16 @@ def _get_termination_deadline(first_termination_retry_at: datetime) -> datetime:


def _get_terminated_result() -> _TerminateResult:
    """
    Build the update maps that mark a compute group and its instances as terminated.

    Datetime fields are set to NOW_PLACEHOLDER so that every row updated in the
    same transaction receives the identical timestamp once
    `resolve_now_placeholders()` runs; `last_processed_at` is set separately by
    the caller via `set_processed_update_map_fields()`.
    """
    return _TerminateResult(
        compute_group_update_map={
            "deleted": True,
            "deleted_at": NOW_PLACEHOLDER,
            "status": ComputeGroupStatus.TERMINATED,
        },
        instances_update_map={
            "deleted": True,
            "deleted_at": NOW_PLACEHOLDER,
            "finished_at": NOW_PLACEHOLDER,
            "status": InstanceStatus.TERMINATED,
        },
    )
Loading