20 changes: 10 additions & 10 deletions Dockerfile
@@ -1,7 +1,7 @@
# syntax=docker/dockerfile:1

ARG TARGET=base
ARG BASE_IMAGE=ubuntu:22.04
ARG BASE_IMAGE=ubuntu:24.04

FROM ${BASE_IMAGE} AS base

@@ -18,7 +18,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
apt-get update -y -qq && \
apt-get install -y -qq apt-transport-https ca-certificates gcc g++ \
git screen ca-certificates google-perftools google-cloud-cli python3.10-venv && \
git screen ca-certificates google-perftools google-cloud-cli python3.12-venv && \
apt clean -y -qq

# Setup.
@@ -30,7 +30,7 @@ COPY pyproject.toml pyproject.toml
RUN mkdir axlearn && touch axlearn/__init__.py
# Setup venv to suppress pip warnings.
ENV VIRTUAL_ENV=/opt/venv
RUN python3 -m venv $VIRTUAL_ENV
RUN python3.12 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# Install dependencies.
RUN pip install -qq --upgrade pip && \
@@ -81,7 +81,7 @@ COPY . .

# Dataflow workers can't start properly if the entrypoint is not set
# See: https://cloud.google.com/dataflow/docs/guides/build-container-image#use_a_custom_base_image
COPY --from=apache/beam_python3.10_sdk:2.52.0 /opt/apache/beam /opt/apache/beam
COPY --from=apache/beam_python3.12_sdk:2.68.0 /opt/apache/beam /opt/apache/beam
ENTRYPOINT ["/opt/apache/beam/boot"]

################################################################################
@@ -96,15 +96,16 @@ ARG EXTRAS=
# Needed until Jax is upgraded to 0.8.0 or newer.
ARG INSTALL_PATHWAYS_JAXLIB=false

ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Ensure we install the TPU version, even if building locally.
# Jax will fallback to CPU when run on a machine without TPU.
RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean
RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi
RUN if [ "$INSTALL_PATHWAYS_JAXLIB" = "true" ]; then \
uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" \
--find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
fi
COPY --from=libtpu-target:latest /wheels /wheels
RUN uv pip install --no-deps /wheels/*.whl && uv cache clean
COPY . .

################################################################################
@@ -114,13 +115,12 @@ COPY . .
FROM base AS gpu

# TODO(markblee): Support extras.
ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Enable the CUDA repository and install the required libraries (libnvrtc.so)
RUN curl -o cuda-keyring_1.1-1_all.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb && \
RUN curl -o cuda-keyring_1.1-1_all.deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb && \
dpkg -i cuda-keyring_1.1-1_all.deb && \
apt-get update && apt-get install -y cuda-libraries-dev-12-8 ibverbs-utils && \
apt-get update && apt-get install -y cuda-libraries-dev-12-9 ibverbs-utils && \
apt clean -y
RUN uv pip install -qq .[core,gpu] && uv cache clean
RUN uv pip install --prerelease=allow .[core,gpu] && uv cache clean
COPY . .

################################################################################
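Note: the Python bump above touches three places at once (the apt `python3.12-venv` package, the venv interpreter, and the Beam SDK base image). A small runtime guard, sketched below and not part of this diff, can catch a mismatch early; the expected version tuple is an assumption tied to the `beam_python3.12_sdk` base.

```python
# Hypothetical startup guard (not in this PR): fail fast if the interpreter in the
# image drifts from the Beam SDK harness version it was built against.
import sys

EXPECTED_PYTHON = (3, 12)  # assumed to track apache/beam_python3.12_sdk above

if sys.version_info[:2] != EXPECTED_PYTHON:
    raise RuntimeError(
        f"Image was built for Python {EXPECTED_PYTHON[0]}.{EXPECTED_PYTHON[1]}, "
        f"but is running {sys.version_info.major}.{sys.version_info.minor}."
    )
```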
6 changes: 3 additions & 3 deletions axlearn/cli/utils_test.py
@@ -368,7 +368,7 @@ def test_subprocess_argv(self):
absl_main(mock_args)
self.assertEqual(1, len(mock_popen.call_args_list))
self.assertEqual((expected,), mock_popen.call_args[0])
self.assertDictContainsSubset({"text": True}, mock_popen.call_args[1])
self.assertTrue({"text": True}.items() <= mock_popen.call_args[1].items())
self.assertEqual(self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"])

shell_cases = [
@@ -397,8 +397,8 @@ def test_subprocess_argv(self):
absl_main(mock_args)
self.assertEqual(1, len(mock_popen.call_args_list))
self.assertEqual((expected,), mock_popen.call_args[0])
self.assertDictContainsSubset(
{"text": True, "shell": True}, mock_popen.call_args[1]
self.assertTrue(
{"text": True, "shell": True}.items() <= mock_popen.call_args[1].items()
)
self.assertEqual(
self.root_module, mock_popen.call_args[1]["env"]["AXLEARN_CLI_NAME"]
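For context on the assertion change above: `assertDictContainsSubset` was removed from `unittest` in Python 3.12, and `dict.items()` views support set-style subset comparison, so the replacement is behaviorally equivalent. A standalone illustration with made-up values:

```python
# dict_items supports <=, so `small.items() <= big.items()` is True exactly when every
# key/value pair in `small` also appears in `big`.
call_kwargs = {"text": True, "shell": True, "env": {"AXLEARN_CLI_NAME": "root"}}

assert {"text": True}.items() <= call_kwargs.items()
assert {"text": True, "shell": True}.items() <= call_kwargs.items()
assert not ({"shell": False}.items() <= call_kwargs.items())
```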
3 changes: 3 additions & 0 deletions axlearn/cloud/gcp/jobset_utils.py
@@ -632,6 +632,9 @@ def _build_pod(self) -> Nested[Any]:
# Tier "0" corresponds to reserved; otherwise we use preemptible.
tier = os.environ.get("BASTION_TIER", None)

# TODO(samos123): Support using reservation when using local launch;
# the local launch command automatically sets tier=disabled.
logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation)
if tier == "0" and cfg.reservation is not None:
logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation)
selector.update({"cloud.google.com/reservation-name": cfg.reservation})
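A condensed restatement of the branch in this hunk may help; the function name and sample values below are made up for illustration. Only a reserved-tier job ("0") with a configured reservation pins pods to the reservation node selector; any other tier, including the "disabled" tier set by local launch, leaves the selector untouched.

```python
# Hypothetical standalone restatement of the selector logic above.
def reservation_selector(tier, reservation):
    selector = {}
    if tier == "0" and reservation is not None:
        selector["cloud.google.com/reservation-name"] = reservation
    return selector

assert reservation_selector("0", "my-reservation") == {
    "cloud.google.com/reservation-name": "my-reservation"
}
assert reservation_selector("disabled", "my-reservation") == {}
assert reservation_selector("0", None) == {}
```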
2 changes: 1 addition & 1 deletion axlearn/cloud/gcp/runners/utils.py
@@ -34,7 +34,7 @@ def should_recreate_job(
# We only consider recreation for TPU jobs.
return False

if str(tier) != "0" and reservation is not None:
if str(tier) not in ["0", "disabled"] and reservation is not None:
logging.info(
"Bastion tier is %s but reservation is %s. Job resources will be recreated.",
tier,
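The updated condition is easier to read as a truth table. The sketch below (hypothetical function name, covering only this condition and not the rest of `should_recreate_job`) shows the intent: a reservation only triggers recreation when the bastion tier is neither reserved ("0") nor "disabled", which the local launch path sets.

```python
# Hypothetical restatement of the updated tier/reservation check.
def recreate_for_tier(tier, reservation) -> bool:
    return str(tier) not in ["0", "disabled"] and reservation is not None

assert recreate_for_tier("1", "my-reservation") is True
assert recreate_for_tier(1, "my-reservation") is True  # non-string tiers are normalized
assert recreate_for_tier("0", "my-reservation") is False
assert recreate_for_tier("disabled", "my-reservation") is False
assert recreate_for_tier("1", None) is False
```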
53 changes: 21 additions & 32 deletions axlearn/common/array_serialization.py
@@ -34,7 +34,7 @@
import tensorstore as ts
from absl import logging
from jax._src import array, typing
from jax._src.layout import Layout
from jax._src.layout import Format
from jax.experimental.array_serialization import serialization

from axlearn.common.utils import Tensor
@@ -291,26 +291,23 @@ async def _async_serialize(
# Await for limiter before D2H.
if limiter is not None:
# pylint: disable-next=protected-access
if jax.__version__ == "0.6.2" and nbytes > limiter._max_bytes:
if nbytes > limiter._max_bytes:
raise ValueError(
"Attempting to read more bytes than we allocated space for in the limiter"
# pylint: disable-next=protected-access
f"{nbytes} > {limiter._max_bytes}"
)
await limiter.wait_for_bytes(nbytes)
else:
await limiter.wait_for_bytes(nbytes)

# Fully addressable arrays lead to races between multiple writing hosts.
assert not (
isinstance(arr_inp, array.ArrayImpl)
and jax.process_count() > 1
and arr_inp.is_fully_addressable
)
# pylint: disable=protected-access
spec_has_metadata = {
"0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
"0.5.3": lambda: serialization._spec_has_metadata,
}[jax.__version__]()
if not spec_has_metadata(tensorstore_spec):
# pylint: disable-next=protected-access
if not serialization.ts_impl._spec_has_metadata(tensorstore_spec):
# pylint: disable-next=protected-access
tensorstore_spec["metadata"] = serialization._get_metadata(arr_inp)
if "dtype" not in tensorstore_spec:
@@ -407,12 +404,12 @@ async def _run_serializer(
raise e


def _blocking_device_put(out: Tensor, layout: Layout) -> Tensor:
def _blocking_device_put(out: Tensor, layout: Format) -> Tensor:
return jax.block_until_ready(jax.device_put(out, layout))


async def _async_deserialize(
user_in_sharding: jax.sharding.Sharding | Layout,
user_in_sharding: jax.sharding.Sharding | Format,
tensorstore_spec: dict[str, Any],
global_shape: Optional[Sequence[int]],
dtype: Optional[typing.DTypeLike],
@@ -459,14 +456,14 @@ async def _async_deserialize(
huge pages (THP) can help, but it's only for jax 0.5.1+.
"""
in_sharding = (
user_in_sharding.sharding if isinstance(user_in_sharding, Layout) else user_in_sharding
user_in_sharding.sharding if isinstance(user_in_sharding, Format) else user_in_sharding
)
if not isinstance(in_sharding, jax.sharding.Sharding):
raise ValueError(
"sharding passed to deserialization should be specified, concrete and"
f" an instance of `jax.sharding.Sharding`. Got {in_sharding}"
)
dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None
dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Format) else None

# gcs_grpc is 2x to 4x faster than gcs on read performance. And this is recommended by Google
# GCS team.
@@ -486,11 +483,7 @@
async def cb(index: array.Index, device: jax.Device):
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
estimate_read_memory_footprint = {
"0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint,
"0.5.3": lambda: serialization.estimate_read_memory_footprint,
}[jax.__version__]()
requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
requested_bytes = serialization.ts_impl.estimate_read_memory_footprint(t, restricted_domain)
# Limit the bytes read for every shard.
await byte_limiter.wait_for_bytes(requested_bytes)
read_ts = t[restricted_domain]
@@ -528,9 +521,10 @@ async def cb(index: array.Index, device: jax.Device):
mb_256 = 256 * 1024 * 1024
out_size = math.ceil(out_size / mb_256) * mb_256

layout = Layout(
layout = Format(
dll, jax.sharding.SingleDeviceSharding(device, memory_kind=in_sharding.memory_kind)
)

# Jax >= 0.6.2 changes the behavior of _LimitInFlightBytes, where wait_for_bytes no longer
# throws an exception if requested_bytes > max_bytes
# pylint: disable-next=protected-access
@@ -566,12 +560,10 @@ async def cb(index: array.Index, device: jax.Device):
await byte_limiter.release_bytes(requested_bytes)
return result

# pylint: disable=protected-access
create_async_array_from_callback = {
"0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback,
"0.5.3": lambda: serialization.create_async_array_from_callback,
}[jax.__version__]()
return await create_async_array_from_callback(shape, in_sharding, cb)
# pylint: disable-next=protected-access
return await serialization.ts_impl._create_async_array_from_callback(
shape, dtype, in_sharding, cb
)


# Reference:
@@ -652,14 +644,11 @@ def serialize(

commit_futures = [[] for _ in range(len(tensorstore_specs))]

async_serialize = {
"0.6.2": lambda: serialization.ts_impl.async_serialize,
"0.5.3": lambda: serialization.async_serialize,
}[jax.__version__]()

# pylint: disable-next=redefined-outer-name
async def _run_serializer():
future_writer = jax.tree.map(async_serialize, arrays, tensorstore_specs, commit_futures)
future_writer = jax.tree.map(
serialization.ts_impl.async_serialize, arrays, tensorstore_specs, commit_futures
)
return await asyncio.gather(*future_writer)

# Note: We need to run the coroutine in another event loop driven by a separate thread.
@@ -680,7 +669,7 @@ async def _run_serializer():
# https://github.com/jax-ml/jax/blob/66037d10e7742c4fcadd07f0459a00813ec7ed5f/jax/experimental/array_serialization/serialization.py#L413-L429
def deserialize(
self,
shardings: Sequence[Union[jax.sharding.Sharding, Layout]],
shardings: Sequence[Union[jax.sharding.Sharding, Format]],
tensorstore_specs: Sequence[dict[str, Any]],
global_shapes: Optional[Sequence[array.Shape]] = None,
dtypes: Optional[Sequence[typing.DTypeLike]] = None,
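Several hunks above delete the `jax.__version__`-keyed shims because only the 0.6.2+ code path remains, and, per the comment in this file, `_LimitInFlightBytes.wait_for_bytes` in newer jax no longer raises when a single request exceeds the budget, so the explicit `ValueError` before the await preserves the old failure mode. A self-contained sketch of that caller-side pattern follows; the function name and the release-in-`finally` step are assumptions for illustration, not the module's exact control flow.

```python
# Hypothetical sketch of the reserve/run/release pattern used around the byte limiter.
async def reserve_then_run(limiter, nbytes: int, do_io):
    # Newer jax no longer raises inside wait_for_bytes for oversized requests,
    # so the caller checks the budget explicitly (mirroring the code above).
    # pylint: disable-next=protected-access
    if nbytes > limiter._max_bytes:
        # pylint: disable-next=protected-access
        raise ValueError(f"Requested {nbytes} bytes > limiter budget {limiter._max_bytes}")
    await limiter.wait_for_bytes(nbytes)
    try:
        return await do_io()
    finally:
        await limiter.release_bytes(nbytes)
```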
44 changes: 19 additions & 25 deletions axlearn/common/array_serialization_test.py
@@ -20,7 +20,6 @@
import tensorstore as ts
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

from axlearn.common import array_serialization
from axlearn.common.array_serialization import (
@@ -37,10 +36,6 @@
serialization,
)

# TODO(wyi): This dictionary is introduced for the temporary peroiod of upgrading JAX from 0.5.3 to
# 0.6.2. Once the upgrading is complete, we should remove it ASAP.
_ts_open = {"0.6.2": "ts.open", "0.5.3": "serialization.ts.open"}[jax.__version__]


@contextmanager
def get_tensorstore_spec(arr: jax.Array):
@@ -91,8 +86,9 @@ def _create_partially_replicated_array(self, sharded: bool):
if jax.device_count() != 8 or jax.process_count() != 1:
self.skipTest("Incorrect device count for mesh.")
devices = mesh_utils.create_device_mesh((8,))
sharding = PositionalSharding(devices)
arr = jax.device_put(single_device_arr, sharding.reshape(4, 2).replicate(0))
mesh = jax.sharding.Mesh(devices.reshape((4, 2)), ("x", "y"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "y"))
arr = jax.device_put(single_device_arr, sharding)
return arr
return single_device_arr

@@ -101,10 +97,7 @@ def test_async_serialize_d2h_sync(self, sharded):
arr = self._create_partially_replicated_array(sharded)

ts_open_handle: Any = None
old_open = {
"0.6.2": lambda: array_serialization.ts.open,
"0.5.3": lambda: array_serialization.serialization.ts.open,
}[jax.__version__]()
old_open = array_serialization.ts.open

async def ts_open_patch(*args, **kwargs):
nonlocal ts_open_handle
@@ -126,7 +119,7 @@ def transfer_to_host_patch(*args, **kwargs):

d2h_future = array_serialization.futures.Future()
with mock.patch(
f"{array_serialization.__name__}.{_ts_open}",
f"{array_serialization.__name__}.ts.open",
ts_open_patch,
), get_tensorstore_spec(arr) as spec, mock.patch(
f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch
@@ -152,7 +145,7 @@ def transfer_to_host_patch(*args, **kwargs):
arr_host = jax.device_get(arr)
d2h_future = array_serialization.futures.Future()
with mock.patch(
f"{array_serialization.__name__}.{_ts_open}",
f"{array_serialization.__name__}.ts.open",
ts_open_patch,
), get_tensorstore_spec(arr) as spec, mock.patch(
f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch
@@ -186,7 +179,7 @@ async def ts_open_patch(*_, **__):

d2h_future = array_serialization.futures.Future()
with mock.patch(
f"{array_serialization.__name__}.{_ts_open}",
f"{array_serialization.__name__}.ts.open",
ts_open_patch,
), get_tensorstore_spec(arr) as spec:
f = _CommitFuture(
@@ -284,7 +277,7 @@ async def _copy_to_host_patch(shard_infos: list[_ShardInfo]):
mock.patch(
f"{array_serialization.__name__}.serialization._get_metadata", lambda *_: {}
),
mock.patch(f"{array_serialization.__name__}.{_ts_open}", open_patch),
mock.patch(f"{array_serialization.__name__}.ts.open", open_patch),
mock.patch(f"{array_serialization.__name__}.ts.Spec", mock.MagicMock()),
):
manager.serialize(arrays, tensorstore_specs, on_commit_callback=lambda: None)
@@ -312,7 +305,8 @@ def test_deserialize(
)

devices = mesh_utils.create_device_mesh((8,))
sharding = PositionalSharding(devices)
mesh = jax.sharding.Mesh(devices, "x")
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(("x",)))
data = [jax.device_put(arr, sharding) for arr in arrays]

# create a temporary tensorstore spec for the arrays
@@ -470,9 +464,9 @@ def test_shard_info_partially_replicated(

single_device_arr = jnp.arange(0, 1024 * 1024).reshape(1024, 1024)
devices = mesh_utils.create_device_mesh((8,))
sharding = PositionalSharding(devices)

arr = jax.device_put(single_device_arr, sharding.reshape(4, 2).replicate(0))
mesh = jax.sharding.Mesh(devices.reshape((4, 2)), ("x", "y"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "y"))
arr = jax.device_put(single_device_arr, sharding)

replica_count = _num_replicas_per_shard(arr)
self.assertEqual(replica_count[((None, None, None), (0, 512, None))], 4)
@@ -490,9 +484,9 @@ def test_shard_info_fully_sharded(self, max_data_shard_degree: int, shard_thresh
self.skipTest("Incorrect device count for mesh.")
single_device_arr = jnp.arange(0, 1024 * 1024).reshape(1024, 1024)
devices = mesh_utils.create_device_mesh((8,))
sharding = PositionalSharding(devices)

arr = jax.device_put(single_device_arr, sharding.reshape(4, 2))
mesh = jax.sharding.Mesh(devices.reshape((4, 2)), ("x", "y"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x", "y"))
arr = jax.device_put(single_device_arr, sharding)

replica_count = _num_replicas_per_shard(arr)
self.assertEqual(replica_count[((0, 256, None), (0, 512, None))], 1)
@@ -513,9 +507,9 @@ def test_shard_info_fully_replicated(
self.skipTest("Incorrect device count for mesh.")
single_device_arr = jnp.arange(0, sz)
devices = mesh_utils.create_device_mesh((8,))
sharding = PositionalSharding(devices)

arr = jax.device_put(single_device_arr, sharding.replicate(0))
mesh = jax.sharding.Mesh(devices, "x")
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
arr = jax.device_put(single_device_arr, sharding)

replica_count = _num_replicas_per_shard(arr)
# Fully replicated on 8 devices.
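The recurring change in this test file replaces `PositionalSharding`, which newer jax releases drop, with an explicit mesh plus `NamedSharding`. A minimal sketch of the equivalence used here; it assumes 8 local devices, as the tests themselves do.

```python
# Roughly: PositionalSharding(devices).reshape(4, 2).replicate(0) becomes a NamedSharding
# over a (4, 2) mesh that replicates rows along "x" and shards columns along "y".
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((8,))
mesh = jax.sharding.Mesh(devices.reshape((4, 2)), ("x", "y"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "y"))
arr = jax.device_put(jnp.arange(0, 1024 * 1024).reshape(1024, 1024), sharding)
```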
1 change: 1 addition & 0 deletions axlearn/common/attention.py
@@ -4556,6 +4556,7 @@ class RematRegexSavePatterns(enum.Enum):
FLASH_CONTEXT = f".*{FLASH_ATTN_RESIDUAL_NAME}"
FLASH_ATTENTION = "|".join([FLASH_CONTEXT, QKV_PROJ, O_PROJ])
FEED_FORWARD = "|".join([LINEAR1_X, LINEAR2_X])
INPUT = r".*input"


def build_remat_spec(
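The new `INPUT` member follows the same convention as the other `RematRegexSavePatterns` entries: a regex applied to checkpoint names. A toy check is below; the sample names and the use of full-string matching are assumptions for illustration only.

```python
# The pattern selects any checkpoint name ending in "input".
import re

pattern = re.compile(r".*input")
assert pattern.fullmatch("decoder/transformer/layer0/input")
assert pattern.fullmatch("input")
assert not pattern.fullmatch("decoder/transformer/layer0/output")
```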