4848# There is no guarantee that this image will work with newer Jax releases.
4949# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
5050# This image extends GRPC timeout for long context models.
51- _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
51+ # _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
52+ _PATHWAYS_IMAGE_TAG = "uds"
5253# The docker image used by pathways proxy container.
5354_PATHWAYS_PROXY_IMAGE = (
54- f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{ _PATHWAYS_IMAGE_TAG } "
55+ # f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
56+ "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
57+ f"unsanitized_proxy_server:{ _PATHWAYS_IMAGE_TAG } "
5558)
5659# The docker image used by pathways resource manager container and worker container.
5760_PATHWAYS_SERVER_IMAGE = (
58- f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{ _PATHWAYS_IMAGE_TAG } "
61+ # f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
62+ "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
63+ f"unsanitized_server:{ _PATHWAYS_IMAGE_TAG } "
5964)
6065# The container name of pathways resourcemanager.
6166_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
@@ -269,10 +274,16 @@ def _build_pathways_head_container(self) -> dict:
269274 head_container = copy .deepcopy (container )
270275
271276 env_list = head_container .get ("env" , [])
277+ # self._update_env_list(
278+ # env_list,
279+ # "JAX_BACKEND_TARGET",
280+ # f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
281+ # )
282+ # Unix domain socket
272283 self ._update_env_list (
273284 env_list ,
274285 "JAX_BACKEND_TARGET" ,
275- f "grpc://localhost: { _PATHWAYS_PROXY_PORT } " ,
286+ "grpc:///tmp/ifrt_proxy.sock " ,
276287 )
277288 self ._update_env_list (env_list , "XCLOUD_ENVIRONMENT" , "GCP" )
278289 self ._update_env_list (env_list , "JAX_PLATFORMS" , "proxy" )
@@ -327,6 +338,10 @@ def _build_pathways_head_container(self) -> dict:
327338 }
328339 head_container ["resources" ] = resources
329340
341+ volume_mounts = head_container .get ("volumeMounts" , [])
342+ volume_mounts .append (dict (name = "shared-memory" , mountPath = "/tmp/" ))
343+ head_container ["volumeMounts" ] = volume_mounts
344+
330345 return head_container
331346
332347 def _build_pathways_head_sidecar_containers (self ) -> list [Nested [Any ]]:
@@ -350,6 +365,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
350365
351366 cmd_args = [
352367 f"--resource_manager_address=localhost:{ _PATHWAYS_RESOURCE_MANAGER_PORT } " ,
368+ # using unix socket but port needs to be set anyway
353369 f"--server_port={ _PATHWAYS_PROXY_PORT } " ,
354370 f"--gcs_scratch_location={ staging_location } " ,
355371 ]
@@ -374,7 +390,10 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
374390 {"name" : "XLA_FLAGS" , "value" : f"--xla_dump_to=/output/{ cfg .name } /xla" },
375391 ],
376392 ports = [dict (containerPort = _PATHWAYS_PROXY_PORT )],
377- volumeMounts = [dict (name = "shared-output" , mountPath = "/output" )],
393+ volumeMounts = [
394+ dict (name = "shared-output" , mountPath = "/output" ),
395+ dict (name = "shared-memory" , mountPath = "/tmp/" ),
396+ ],
378397 ),
379398 dict (
380399 name = _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME ,
@@ -412,6 +431,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
412431 labels .update ({BASTION_JOB_VERSION_LABEL : os .environ .get (BASTION_JOB_VERSION_ENV_VAR )})
413432
414433 volumes .append (dict (name = "shared-output" , emptyDir = {}))
434+ volumes .append (dict (name = "shared-memory" , emptyDir = dict (medium = "Memory" )))
415435
416436 if cfg .gcsfuse_mount :
417437 annotations .update (
0 commit comments