Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions axlearn/add_one_colocated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Smoke-test script preamble: initialize pathwaysutils, then look up the
# devices and the colocated CPU devices that the rest of the script uses.
import pathwaysutils

# initialize() is called before `jax` is imported by this file.
# NOTE(review): presumably the ordering is intentional so that the Pathways
# backend is set up before any device query -- confirm before reordering.
print("initializing pathwaysutils")
pathwaysutils.initialize()
print("pathwaysutils initialized")

import numpy as np
import jax
from jax.experimental import colocated_python
# NOTE(review): `os`, `shutil` and `ocp` are not referenced anywhere else in
# this script.
import os
import shutil
import orbax.checkpoint as ocp

print("jax version on cpu host:", jax.__version__)

# jax.devices() is labelled "tpu devices" in the log output; the colocated CPU
# devices are derived from that device list.
print("getting tpu devices")
tpu_devices = jax.devices()
print("tpu devices: ", tpu_devices)
print("getting cpu devices")
cpu_devices = colocated_python.colocated_cpu_devices(tpu_devices)
print("cpu devices: ", cpu_devices)

# Imported here only to report its version in the log line below.
import cloudpickle

# NOTE(review): this message is printed unconditionally; JAX_PLATFORMS is not
# actually inspected anywhere in this script.
print("JAX_PLATFORMS is 'proxy'. Setting up pathways colocated python checkpointing.")
print(f" Using jax version {jax.__version__} and cloudpickle version {cloudpickle.__version__}")


print("def add_one")


@colocated_python.colocated_python
def add_one(x):
    """Return ``x + 1`` and log diagnostics to stderr.

    The function is wrapped with ``colocated_python.colocated_python``; this
    script calls it with arrays that were placed on ``cpu_devices[0]``.

    Args:
        x: The input array. Must expose a ``device`` attribute, which is
            included in the diagnostic output.

    Returns:
        ``x + 1``.
    """
    # Imported locally so the function body does not depend on a module-level
    # `sys` import being present.
    import sys

    sys.stderr.write("In colocated python function \n")
    sys.stderr.write(f"[Colocated] jax version: {jax.__version__} \n")
    # Terminate this line; without the newline it runs into the next message.
    sys.stderr.write("[Colocated] add_one \n")
    sys.stderr.write(f"[Colocated] x: {x} on device: {x.device} \n")
    return x + 1


# Smoke test: run add_one on each scalar input and verify the result.
# Each pair is (input value, expected output).
for case_index, (value, expected) in enumerate(((1, 2), (5, 6)), start=1):
    print(f"creating input {case_index}")
    x = np.array(value)
    print("putting on device")
    x = jax.device_put(x, cpu_devices[0])

    print(f"adding one to input {case_index}")
    out = add_one(x)
    print("getting out")
    out = jax.device_get(out)
    print(f"out {case_index}: ", out)

    # Check each result immediately, before any work for the next case starts.
    # Raise explicitly (rather than `assert`) so the check still runs under
    # `python -O`; the exception type and message are unchanged.
    if out != expected:
        raise AssertionError(f"out: {out}")
2 changes: 1 addition & 1 deletion axlearn/cloud/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy
import dataclasses
import functools
import importlib
import importlib.metadata
import logging as pylogging
import os
import shlex
Expand Down
40 changes: 34 additions & 6 deletions axlearn/cloud/gcp/pathways_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,23 @@
# The port used by pathways worker server.
# The specific value is not important, as long as clients and servers use the same port.
_PATHWAYS_WORKER_PORT = 29001
_COLOCATED_CONTAINER_PORT = 50051
# Pin to specific pathways image version for stable release.
# There is no guarantee that this image will work with newer Jax releases.
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
# This image extends GRPC timeout for long context models.
_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"

# The docker image used by pathways proxy container.
_PATHWAYS_PROXY_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_proxy_server@sha256:73516e07b3ccd9af487100c55cff35b7089025b4909847fd234f0f768d99ebea"
)
# The docker image used by pathways resource manager container and worker container.
_PATHWAYS_SERVER_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server@sha256:fde763e2bae514d0fa758840e501b71a9ea48781dddafa5d8ed3a0fa316fd1ae"
)
_COLOCATED_PYTHON_IMAGE = (
"us-docker.pkg.dev/cloud-tpu-multipod-dev/colocated-images/lk-colocated-image:latest"
)
# The container name of pathways resourcemanager.
_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
Expand All @@ -63,6 +68,10 @@
# The k8s replicatedJob name for pathways-worker pods.
_PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pathways-worker"

_COLOCATED_PYTHON_SIDECAR_NAME = "colocated-python-sidecar"



# Add node-selector for cpu workload to avoid sharing nodes with system services.
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type"
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload"
Expand Down Expand Up @@ -183,8 +192,8 @@ def define_flags(cls, fv):
@classmethod
def set_defaults(cls, fv):
super().set_defaults(fv)
fv.set_default("pathways_head_cpu", fv.pathways_head_cpu or "1")
fv.set_default("pathways_head_mem", fv.pathways_head_mem or "16")
fv.set_default("pathways_head_cpu", fv.pathways_head_cpu or "8")
fv.set_default("pathways_head_mem", fv.pathways_head_mem or "80")

@classmethod
def default_config(cls):
Expand Down Expand Up @@ -314,7 +323,7 @@ def _build_pathways_head_container(self) -> dict:
mem_req = f"{self.config.pathways_head_mem}Gi"
resources = {
"requests": {"cpu": cpu_req, "memory": mem_req},
"limits": {"cpu": cpu_req, "memory": mem_req},
#"limits": {"cpu": cpu_req, "memory": mem_req},
}
head_container["resources"] = resources

Expand Down Expand Up @@ -357,7 +366,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
# https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
# SideCar container is an init container with restartPolicy as "Always".
restartPolicy="Always",
args=cmd_args,
args=cmd_args + ["--sidecar_name=external"],
env=proxy_env,
ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)],
),
Expand All @@ -382,6 +391,23 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
],
),
]

def _colocated_python_container(self):
    """Builds the container spec for the colocated Python sidecar.

    Returns:
        A dict holding the sidecar's name, image, restart policy, environment,
        image pull policy and exposed container port.
    """
    # Address the sidecar is told to use, passed via GRPC_SERVER_ADDRESS.
    grpc_address_env = {
        "name": "GRPC_SERVER_ADDRESS",
        "value": f"0.0.0.0:{_COLOCATED_CONTAINER_PORT}",
    }
    sidecar_port = {"containerPort": _COLOCATED_CONTAINER_PORT}

    return {
        "name": _COLOCATED_PYTHON_SIDECAR_NAME,
        "image": _COLOCATED_PYTHON_IMAGE,
        # A sidecar is an init container whose restartPolicy is "Always".
        "restartPolicy": "Always",
        "env": [grpc_address_env],
        "imagePullPolicy": "Always",
        "ports": [sidecar_port],
    }

def _build_pathways_head_pod(self) -> Nested[Any]:
"""Builds a pathways head pod. The pod includes a head container,
Expand Down Expand Up @@ -563,6 +589,8 @@ def _build_pathways_worker_pod(
pod_spec["containers"] = [
self._build_pathways_worker_container(pathways_worker_replicated_job_index)
]
pod_spec["initContainers"]=[self._colocated_python_container()]

worker_pod["spec"] = pod_spec

# Service account for nodes.
Expand Down
2 changes: 1 addition & 1 deletion axlearn/experiments/text/gpt/envy.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,4 +537,4 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config:
make_single_host_config_func = functools.partial(make_single_host_config, config_name)
config_map[f"{config_name}-single-host"] = make_single_host_config_func

return config_map
return config_map
32 changes: 32 additions & 0 deletions colocated/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Use the JAX image with the custom-built sidecar as the base.
# FROM gcr.io/cloud-tpu-multipod-dev/sujinesh_sidecar_debug@sha256:19abcd94addb6ff2749c299d6b0cc4748f27a4ab8759a18b466d0bdd3e5b71e8
FROM us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:python_3.10-jax_0.6.2

# Build argument: path (relative to the build context) of the requirements file
# containing the user's custom dependencies.
ARG REQUIREMENTS_FILE

# Set the working directory (this is already inherited)
WORKDIR /app

# Copy the user's requirements file into the image under a fixed name.
# `COPY <src> .` keeps only the basename of <src>, so installing with
# `-r ${REQUIREMENTS_FILE}` would fail whenever the argument contains a
# directory component (e.g. colocated/colocated_requirements.txt).
COPY ${REQUIREMENTS_FILE} /app/requirements.txt

# Install the additional user-provided dependencies, constrained by the base
# image's constraints file, then clean the cache to keep the image slim.
#
# NOTE: the jax/jaxlib safeguard below is currently disabled. It is kept here
# for reference, outside the RUN instruction, so the command stays readable:
#   if grep -i -E '^jax(lib)?' /app/requirements.txt; then \
#     echo "ERROR: Your requirements file attempts to re-install 'jax' or 'jaxlib'." >&2; \
#     echo "Please remove these lines. The base image provides a specific, custom-patched version of JAX." >&2; \
#     exit 1; \
#   fi && \
RUN uv pip install --prerelease=allow -r /app/requirements.txt -c /opt/venv/server_constraints.txt && \
    uv cache clean

# Note: The ENTRYPOINT and CMD are inherited from the base image, so they do not
# need to be redefined here. I.e. the sidecar will be launched automatically.

26 changes: 26 additions & 0 deletions colocated/colocated_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
absl-py==2.1.0
#chex==0.1.88
importlab==0.8.1
ml-dtypes==0.5.1
msgpack==1.1.0
nltk==3.7
#optax==0.1.7
portpicker
pyarrow<21.0.0
protobuf>=3.20.3
tensorboard-plugin-profile
tensorflow #==2.14.1
tensorflow-datasets #>=4.9.2
tensorflow-io #>=0.37.1
tensorflow_text # >=2.19.0; platform_machine == 'x86_64'
tensorstore>=0.1.63
toml
typing-extensions==4.12.2
scipy==1.15.0
seqio>=0.0.15
#flax==0.10.2
prefixed==0.9.0
grain==0.2.7
#axlearn[gcp]==0.0.1.dev20240211233521
#pathwaysutils==0.1.1
jaxlib @ file:///app/patched_jax/dist/jaxlib-0.6.2.dev20250930-cp310-cp310-manylinux2014_x86_64.whl#sha256=1141cab71a7950b5724b110594d7f661e14b7ca123bb6f37ea2976ac415b2a24
Loading