Skip to content

Commit cf24047

Browse files
committed
wip: unix socket pathways proxy
1 parent 08718bd commit cf24047

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,19 @@
4848
# There is no guarantee that this image will work with newer Jax releases.
4949
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
5050
# This image extends GRPC timeout for long context models.
51-
_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
51+
# _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
52+
_PATHWAYS_IMAGE_TAG = "uds"
5253
# The docker image used by pathways proxy container.
5354
_PATHWAYS_PROXY_IMAGE = (
54-
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
55+
# f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
56+
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
57+
f"unsanitized_proxy_server:{_PATHWAYS_IMAGE_TAG}"
5558
)
5659
# The docker image used by pathways resource manager container and worker container.
5760
_PATHWAYS_SERVER_IMAGE = (
58-
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
61+
# f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
62+
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
63+
f"unsanitized_server:{_PATHWAYS_IMAGE_TAG}"
5964
)
6065
# The container name of pathways resourcemanager.
6166
_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
@@ -269,10 +274,16 @@ def _build_pathways_head_container(self) -> dict:
269274
head_container = copy.deepcopy(container)
270275

271276
env_list = head_container.get("env", [])
277+
# self._update_env_list(
278+
# env_list,
279+
# "JAX_BACKEND_TARGET",
280+
# f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
281+
# )
282+
# Point JAX at the proxy's Unix domain socket instead of the localhost TCP port.
272283
self._update_env_list(
273284
env_list,
274285
"JAX_BACKEND_TARGET",
275-
f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
286+
"grpc:///tmp/ifrt_proxy.sock",
276287
)
277288
self._update_env_list(env_list, "XCLOUD_ENVIRONMENT", "GCP")
278289
self._update_env_list(env_list, "JAX_PLATFORMS", "proxy")
@@ -327,6 +338,10 @@ def _build_pathways_head_container(self) -> dict:
327338
}
328339
head_container["resources"] = resources
329340

341+
volume_mounts = head_container.get("volumeMounts", [])
342+
volume_mounts.append(dict(name="shared-memory", mountPath="/tmp/"))
343+
head_container["volumeMounts"] = volume_mounts
344+
330345
return head_container
331346

332347
def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
@@ -350,6 +365,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
350365

351366
cmd_args = [
352367
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
368+
# JAX connects over a Unix domain socket, but --server_port must still be set.
353369
f"--server_port={_PATHWAYS_PROXY_PORT}",
354370
f"--gcs_scratch_location={staging_location}",
355371
]
@@ -374,7 +390,10 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
374390
{"name": "XLA_FLAGS", "value": f"--xla_dump_to=/output/{cfg.name}/xla"},
375391
],
376392
ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)],
377-
volumeMounts=[dict(name="shared-output", mountPath="/output")],
393+
volumeMounts=[
394+
dict(name="shared-output", mountPath="/output"),
395+
dict(name="shared-memory", mountPath="/tmp/"),
396+
],
378397
),
379398
dict(
380399
name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME,
@@ -412,6 +431,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
412431
labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)})
413432

414433
volumes.append(dict(name="shared-output", emptyDir={}))
434+
volumes.append(dict(name="shared-memory", emptyDir=dict(medium="Memory")))
415435

416436
if cfg.gcsfuse_mount:
417437
annotations.update(

0 commit comments

Comments
 (0)