diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py index 0cdd34700..90a8b81f2 100644 --- a/axlearn/cloud/gcp/jobset_utils.py +++ b/axlearn/cloud/gcp/jobset_utils.py @@ -377,6 +377,7 @@ class Config(SingleReplicatedJob.Config): to attach to the node pool. This is needed to support multiple NIC. Refer to GKE TPU provisioner for more context: https://github.com/GoogleCloudPlatform/ai-on-gke/blob/5f256eed7075a5cb8e73cd72328aea46237b8ce6/tpu-provisioner/internal/cloud/common.go#L29-L31 + scheduler: Optional; Name of the scheduler for the job's pods. If set, it is used as the pod spec's `schedulerName`. """ reservation: Optional[str] = None @@ -386,6 +387,7 @@ class Config(SingleReplicatedJob.Config): enable_tpu_smart_repair: bool = False priority_class: Optional[str] = None additional_node_networks: Optional[str] = None + scheduler: Optional[str] = None @classmethod def define_flags(cls, fv: flags.FlagValues): @@ -411,6 +413,13 @@ def define_flags(cls, fv: flags.FlagValues): **common_kwargs, ) + flags.DEFINE_string( + "scheduler", + None, + "The GKE Scheduler for the job.", + **common_kwargs, + ) + @classmethod def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config: cfg: TPUJobBuilder.Config = super().from_flags(fv, **kwargs) @@ -759,6 +768,9 @@ def _build_pod(self) -> Nested[Any]: spec["hostNetwork"] = True spec["dnsPolicy"] = "ClusterFirstWithHostNet" + if cfg.scheduler: + spec["schedulerName"] = cfg.scheduler + return dict( metadata=dict(annotations=annotations, labels=labels), spec=spec,