Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/clustered/clustered_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,52 @@ func TestGetTaskPhase_Running(t *testing.T) {
assert.Equal(t, pluginsCore.PhaseRunning, phase.Phase())
}

// --- fast-fail / maintenance tests ---

// TestGetTaskPhase_FastFail_NoJobsFailed verifies that a JobSet whose "workers"
// ReplicatedJob reports zero failed Jobs is reported as PhaseRunning.
func TestGetTaskPhase_FastFail_NoJobsFailed(t *testing.T) {
	jobSet := makeJobSet("", "", false)

	// Workers are active and none have failed, so the fast-fail check has
	// nothing to inspect.
	workersStatus := jobsetv1alpha2.ReplicatedJobStatus{Name: "workers", Failed: 0, Active: 2}
	jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{workersStatus}

	// A true, non-terminal condition makes the phase switch fall through to
	// the running branch.
	activeCondition := metav1.Condition{
		Type:               "SomeActiveCondition",
		Status:             metav1.ConditionTrue,
		LastTransitionTime: metav1.NewTime(time.Now()),
	}
	jobSet.Status.Conditions = []metav1.Condition{activeCondition}

	taskSpec := &clusteredpb.ClusteredTaskSpec{Replicas: 2, NprocPerNode: 1}
	pluginCtx := dummyPluginCtx(buildTaskTemplate(taskSpec))

	var handler clusteredResourceHandler
	phaseInfo, err := handler.GetTaskPhase(context.Background(), pluginCtx, jobSet)

	assert.NoError(t, err)
	// With Failed=0 no pod is inspected; the task is simply running.
	assert.Equal(t, pluginsCore.PhaseRunning, phaseInfo.Phase())
}

// TestGetTaskPhase_MaintenanceRetry_FlagFalse verifies that a JobSet in the
// JobSetFailed condition maps to PhaseRetryableFailure when the task spec sets
// RestartOnHostMaintenance to false.
func TestGetTaskPhase_MaintenanceRetry_FlagFalse(t *testing.T) {
	failedJobSet := makeJobSet(jobsetv1alpha2.JobSetFailed, metav1.ConditionTrue, false)

	// Explicitly disable the host-maintenance restart behaviour.
	failurePolicy := &clusteredpb.ClusterFailurePolicy{RestartOnHostMaintenance: false}
	taskSpec := &clusteredpb.ClusteredTaskSpec{
		Replicas:      2,
		NprocPerNode:  1,
		FailurePolicy: failurePolicy,
	}
	pluginCtx := dummyPluginCtx(buildTaskTemplate(taskSpec))

	var handler clusteredResourceHandler
	phaseInfo, err := handler.GetTaskPhase(context.Background(), pluginCtx, failedJobSet)

	assert.NoError(t, err)
	// Flag disabled: the maintenance pod lookup is skipped and the failure is
	// reported as an ordinary retryable failure.
	assert.Equal(t, pluginsCore.PhaseRetryableFailure, phaseInfo.Phase())
}

// --- IsTerminal / GetCompletionTime ---

func TestIsTerminal(t *testing.T) {
Expand Down
60 changes: 59 additions & 1 deletion flyteplugins/go/tasks/plugins/k8s/clustered/phase.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"

pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
clusteredpb "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/plugins"
)

func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
Expand All @@ -21,6 +24,12 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("unexpected resource type %T", resource)
}

// Read spec for failure-policy flags (restart_on_host_maintenance).
var spec clusteredpb.ClusteredTaskSpec
if taskTemplate, err := pluginContext.TaskReader().Read(ctx); err == nil && taskTemplate != nil {
_ = utils.UnmarshalStruct(taskTemplate.GetCustom(), &spec) //nolint:staticcheck
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise an error if it fails to unmarshal?

}

taskLogs, err := getTaskLogs(ctx, pluginContext, jobSet)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
Expand All @@ -36,7 +45,7 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext

condition := extractCurrentCondition(jobSet.Status.Conditions)
if condition == nil {
// JobSet exists, not suspended, no terminal condition yet.
// JobSet exists, no terminal condition yet.
return pluginsCore.PhaseInfoInitializing(occurredAt, pluginsCore.DefaultPhaseVersion, "pods scheduling / DNS resolving", &taskInfo), nil
}

Expand All @@ -45,9 +54,19 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext
return pluginsCore.PhaseInfoSuccess(&taskInfo), nil

case jobsetv1alpha2.JobSetFailed:
if spec.GetFailurePolicy().GetRestartOnHostMaintenance() {
if phase, ok := maybeSystemRetryOnMaintenance(ctx, pluginContext, jobSet, &taskInfo); ok {
return phase, nil
}
}
return pluginsCore.PhaseInfoRetryableFailure(condition.Reason, condition.Message, &taskInfo), nil
}

// Running: check for fast-fail before reporting PhaseRunning.
if phase, ok := maybeFastFailWorker0(ctx, pluginContext, jobSet, &taskInfo); ok {
return phase, nil
}

phaseInfo := pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskInfo)

if err := k8s.MaybeUpdatePhaseVersionFromPluginContext(&phaseInfo, &pluginContext); err != nil {
Expand All @@ -56,6 +75,45 @@ func (clusteredResourceHandler) GetTaskPhase(ctx context.Context, pluginContext
return phaseInfo, nil
}

// maybeFastFailWorker0 looks at the pod named "<jobset>-workers-0-0" whenever the
// "workers" ReplicatedJob status reports at least one failed Job.
//
// It returns (phase, true) when flytek8s.DemystifyFailedOrPendingPod classifies
// that pod as a failure phase, letting the caller report the failure without
// waiting for a JobSetFailed condition. In every other case (no failed worker
// Job, lookup error, or a non-failure pod phase) it returns
// (pluginsCore.PhaseInfoUndefined, false) and the caller carries on.
func maybeFastFailWorker0(ctx context.Context, pluginContext k8s.PluginContext, jobSet *jobsetv1alpha2.JobSet, taskInfo *pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, bool) {
	for _, rjStatus := range jobSet.Status.ReplicatedJobsStatus {
		// Only a "workers" entry with failed Jobs triggers pod inspection.
		if rjStatus.Name != "workers" || rjStatus.Failed <= 0 {
			continue
		}

		rank0Pod := jobSet.Name + "-workers-0-0"
		primaryContainer := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

		podPhase, err := flytek8s.DemystifyFailedOrPendingPod(ctx, pluginContext, *taskInfo, jobSet.Namespace, rank0Pod, primaryContainer)
		switch {
		case err != nil:
			// Best-effort: a failed lookup ends the check without a verdict.
			return pluginsCore.PhaseInfoUndefined, false
		case podPhase.Phase().IsFailure():
			return podPhase, true
		}
	}

	return pluginsCore.PhaseInfoUndefined, false
}

// maybeSystemRetryOnMaintenance inspects the pod named "<jobset>-workers-0-0"
// once the JobSet has reached the JobSetFailed condition.
//
// When flytek8s.DemystifyFailedOrPendingPod reports a retryable failure whose
// execution error kind is SYSTEM, the failure is treated as a host-maintenance
// eviction and the function returns a PhaseInfoSystemRetryableFailureWithCleanup
// phase together with true, so the retry is not charged against the user's
// max_restarts. The check is best-effort: on a lookup error (for instance
// because the pod is already cleaned up) or any other pod phase it returns
// (pluginsCore.PhaseInfoUndefined, false) and the caller falls through to its
// default handling.
func maybeSystemRetryOnMaintenance(ctx context.Context, pluginContext k8s.PluginContext, jobSet *jobsetv1alpha2.JobSet, taskInfo *pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, bool) {
	rank0Pod := jobSet.Name + "-workers-0-0"
	primaryContainer := pluginContext.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()

	podPhase, err := flytek8s.DemystifyFailedOrPendingPod(ctx, pluginContext, *taskInfo, jobSet.Namespace, rank0Pod, primaryContainer)
	if err != nil || podPhase.Phase() != pluginsCore.PhaseRetryableFailure {
		return pluginsCore.PhaseInfoUndefined, false
	}

	// Only system-kind errors qualify for the uncharged retry.
	execErr := podPhase.Err()
	if execErr == nil || execErr.GetKind() != core.ExecutionError_SYSTEM {
		return pluginsCore.PhaseInfoUndefined, false
	}

	systemRetry := pluginsCore.PhaseInfoSystemRetryableFailureWithCleanup(
		"HostMaintenance",
		"pod evicted due to host maintenance; retrying without charging max_restarts",
		taskInfo,
	)
	return systemRetry, true
}

// extractCurrentCondition returns the most recently transitioned condition with Status=True, or nil.
// Ported from kfoperators/common/common_operator.go — not imported to avoid the dependency.
func extractCurrentCondition(conditions []metav1.Condition) *metav1.Condition {
Expand Down
Loading