Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions axlearn/cloud/gcp/jobset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ class Config(SingleReplicatedJob.Config):
to attach to the node pool. This is needed to support multiple NIC.
Refer to GKE TPU provisioner for more context:
https://github.com/GoogleCloudPlatform/ai-on-gke/blob/5f256eed7075a5cb8e73cd72328aea46237b8ce6/tpu-provisioner/internal/cloud/common.go#L29-L31
scheduler: Optional; The name of the GKE scheduler to use for the job's pods. If set, it is written to the pod spec as `schedulerName`.
"""

reservation: Optional[str] = None
Expand All @@ -386,6 +387,7 @@ class Config(SingleReplicatedJob.Config):
enable_tpu_smart_repair: bool = False
priority_class: Optional[str] = None
additional_node_networks: Optional[str] = None
scheduler: Optional[str] = None

@classmethod
def define_flags(cls, fv: flags.FlagValues):
Expand All @@ -411,6 +413,13 @@ def define_flags(cls, fv: flags.FlagValues):
**common_kwargs,
)

flags.DEFINE_string(
"scheduler",
None,
"The GKE Scheduler for the job.",
**common_kwargs,
)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
cfg: TPUJobBuilder.Config = super().from_flags(fv, **kwargs)
Expand Down Expand Up @@ -759,6 +768,9 @@ def _build_pod(self) -> Nested[Any]:
spec["hostNetwork"] = True
spec["dnsPolicy"] = "ClusterFirstWithHostNet"

if cfg.scheduler:
spec["schedulerName"] = cfg.scheduler

return dict(
metadata=dict(annotations=annotations, labels=labels),
spec=spec,
Expand Down
Loading