diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py
index fa47ef64b..098edbd1d 100644
--- a/axlearn/cloud/gcp/pathways_utils.py
+++ b/axlearn/cloud/gcp/pathways_utils.py
@@ -48,14 +48,19 @@
 # There is no guarantee that this image will work with newer Jax releases.
 # This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
 # This image extends GRPC timeout for long context models.
-_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
+# _PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
+_PATHWAYS_IMAGE_TAG = "uds"
 # The docker image used by pathways proxy container.
 _PATHWAYS_PROXY_IMAGE = (
-    f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
+    # f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
+    "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
+    f"unsanitized_proxy_server:{_PATHWAYS_IMAGE_TAG}"
 )
 # The docker image used by pathways resource manager container and worker container.
 _PATHWAYS_SERVER_IMAGE = (
-    f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
+    # f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
+    "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/shauryag/"
+    f"unsanitized_server:{_PATHWAYS_IMAGE_TAG}"
 )
 # The container name of pathways resourcemanager.
 _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
@@ -107,7 +112,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str:


 def get_megascale_options(
-    xla_options: dict[str, Union[str, bool, int]]
+    xla_options: dict[str, Union[str, bool, int]],
 ) -> dict[str, Union[str, bool, int]]:
     """Filters XLA options for those pertaining to Megascale.

@@ -122,7 +127,7 @@


 def get_xla_options(
-    xla_options: dict[str, Union[str, bool, int]]
+    xla_options: dict[str, Union[str, bool, int]],
 ) -> dict[str, Union[str, bool, int]]:
     """Filters XLA options for those starting with 'xla_'.

@@ -146,12 +151,14 @@ class Config(BaseReplicatedJob.Config):
             inner: The wrapped TPUReplicatedJob configuration.
             pathways_head_cpu: CPU request for pathways-head container.
             pathways_head_mem: Memory request for pathways-head container.
+            pathways_head_on_tpu: Whether to run pathways head on TPU VM.
         """

         inner: Required[TPUReplicatedJob.Config] = REQUIRED
         pathways_xla_flags: list[str] = []
         pathways_head_cpu: Optional[str] = None
         pathways_head_mem: Optional[str] = None
+        pathways_head_on_tpu: bool = False

     @classmethod
     def define_flags(cls, fv):
@@ -180,6 +187,12 @@ def define_flags(cls, fv):
             "Memory request for pathways-head container in GiB. Default is 16GiB",
             **common_kwargs,
         )
+        flags.DEFINE_boolean(
+            "pathways_head_on_tpu",
+            False,
+            "If True, run pathways head on TPU VM.",
+            **common_kwargs,
+        )

     @classmethod
     def set_defaults(cls, fv):
@@ -261,10 +274,16 @@ def _build_pathways_head_container(self) -> dict:
         head_container = copy.deepcopy(container)

         env_list = head_container.get("env", [])
+        # self._update_env_list(
+        #     env_list,
+        #     "JAX_BACKEND_TARGET",
+        #     f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
+        # )
+        # Use a Unix domain socket instead of a localhost gRPC port.
         self._update_env_list(
             env_list,
             "JAX_BACKEND_TARGET",
-            f"grpc://localhost:{_PATHWAYS_PROXY_PORT}",
+            "grpc:///tmp/ifrt_proxy.sock",
         )
         self._update_env_list(env_list, "XCLOUD_ENVIRONMENT", "GCP")
         self._update_env_list(env_list, "JAX_PLATFORMS", "proxy")
@@ -315,10 +334,14 @@ def _build_pathways_head_container(self) -> dict:
         mem_req = f"{self.config.pathways_head_mem}Gi"
         resources = {
             "requests": {"cpu": cpu_req, "memory": mem_req},
-            "limits": {"cpu": cpu_req, "memory": mem_req},
+            # "limits": {"cpu": cpu_req, "memory": mem_req},
         }
         head_container["resources"] = resources

+        volume_mounts = head_container.get("volumeMounts", [])
+        volume_mounts.append(dict(name="shared-memory", mountPath="/tmp/"))
+        head_container["volumeMounts"] = volume_mounts
+
         return head_container

     def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
@@ -342,6 +365,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:

         cmd_args = [
             f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
+            # The proxy uses a Unix domain socket, but the server port still needs to be set.
             f"--server_port={_PATHWAYS_PROXY_PORT}",
             f"--gcs_scratch_location={staging_location}",
         ]
@@ -354,6 +378,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
             dict(
                 name=_PATHWAYS_PROXY_CONTAINER_NAME,
                 image=_PATHWAYS_PROXY_IMAGE,
+                securityContext={"privileged": True},
                 # https://kubernetes.io/docs/concepts/workloads/pods/sidecar-containers/#pod-sidecar-containers
                 # SideCar container is an init container with restartPolicy as "Always".
restartPolicy="Always", @@ -365,7 +390,10 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]: {"name": "XLA_FLAGS", "value": f"--xla_dump_to=/output/{cfg.name}/xla"}, ], ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)], - volumeMounts=[dict(name="shared-output", mountPath="/output")], + volumeMounts=[ + dict(name="shared-output", mountPath="/output"), + dict(name="shared-memory", mountPath="/tmp/"), + ], ), dict( name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME, @@ -403,6 +431,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]: labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)}) volumes.append(dict(name="shared-output", emptyDir={})) + volumes.append(dict(name="shared-memory", emptyDir=dict(medium="Memory"))) if cfg.gcsfuse_mount: annotations.update( @@ -414,9 +443,15 @@ def _build_pathways_head_pod(self) -> Nested[Any]: } ) - node_selector = { - _PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY: _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE, - } + if self.config.pathways_head_on_tpu: + # pylint: disable-next=protected-access + pod = self._inner._build_pod() + node_selector = {} + tolerations = pod["spec"]["tolerations"] + else: + node_selector = { + _PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY: _PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE, + } head_container = self._build_pathways_head_container() init_containers = [ @@ -444,6 +479,32 @@ def _build_pathways_head_pod(self) -> Nested[Any]: "hostNetwork": True, "dnsPolicy": "ClusterFirstWithHostNet", } + if self.config.pathways_head_on_tpu: + head_pod_spec["affinity"] = { + "podAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": [ + { + "labelSelector": { + "matchExpressions": [ + { + "key": "batch.kubernetes.io/job-name", + "operator": "In", + "values": [ + f"{cfg.name}-{_PATHWAYS_WORKER_REPLICATED_JOB_NAME}-0" + ], + } + ] + }, + "topologyKey": "kubernetes.io/hostname", + } + ] + } + } + + # Remove host ports to avoid scheduling conflicts on the same node. + # The pod runs on host network anyway, so the ports are still accessible. 
+ if "ports" in head_pod_spec["containers"][0]: + del head_pod_spec["containers"][0]["ports"] if cfg.priority_class: head_pod_spec["priorityClassName"] = cfg.priority_class @@ -537,6 +598,17 @@ def _build_pathways_worker_container( f"--resource_manager_address={pathways_head_address}:" + f"{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--gcs_scratch_location={cfg.output_dir}/pathways-staging", + # Set premap buffer to 17GB, needed for faster jax.device_put h2d + # "--pathways_tpu_premapped_buffer_size=17179869184" doesn't work in cloud + # Below flags did not help on 7b restore time + # Recycle vs on-demand seems to give a slight perf boost + "--tpu_pinned_host_allocation_recycle=true", + # pylint: disable=line-too-long + "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000" + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=65536", ] mega_scale_args = xla_flags_from_options(self._mxla_options).split() worker_container["args"].extend(mega_scale_args) @@ -634,18 +706,23 @@ def _build_pathways_worker_job( def __call__(self) -> Sequence[Nested[Any]]: cfg: TPUReplicatedJob.Config = self._inner.config - replicated_jobs = [ - dict( - name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME, - replicas=1, - template=self._build_pathways_head_job(), - ), - dict( - name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, - replicas=cfg.accelerator.num_replicas, - template=self._build_pathways_worker_job(), - ), - ] + worker_job = dict( + name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, + replicas=cfg.accelerator.num_replicas, + template=self._build_pathways_worker_job(), + ) + head_job = dict( + name=_PATHWAYS_HEAD_REPLICATED_JOB_NAME, + replicas=1, + template=self._build_pathways_head_job(), + ) + if self.config.pathways_head_on_tpu: + head_job["dependsOn"] = [ + dict(name=_PATHWAYS_WORKER_REPLICATED_JOB_NAME, status="Ready") + ] + replicated_jobs = [worker_job, head_job] + else: + replicated_jobs = [head_job, worker_job] return replicated_jobs @@ -865,6 +942,7 @@ def _build_pathways_proxy_container(self) -> dict: return dict( name=_PATHWAYS_PROXY_CONTAINER_NAME, image=_PATHWAYS_PROXY_IMAGE, + securityContext={"privileged": True}, args=[ f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}", f"--server_port={_PATHWAYS_PROXY_PORT}", @@ -900,6 +978,14 @@ def _build_pathways_rm_container(self) -> dict: "--instance_count=1", f"--instance_type={pathways_tpu_version}:{system.topology}", f"--gcs_scratch_location={staging_location}", + # Troubleshooting perf + "--tpu_pinned_host_allocation_recycle=true", + # pylint: disable=line-too-long + "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_premapped_buffer_size=68719476736", + # "--temporary_flags_for_debugging=temporary_flag_for_debuggings_max_num_threads_for_xla_compilation=1000" + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_max_inflight_async_computations=1000", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_xla_tpu_allow_async_allocations=true", + # "--temporary_flags_for_debugging=temporary_flag_for_debugging_tpu_num_premapped_partitions=65536", ], ports=[dict(containerPort=_PATHWAYS_RESOURCE_MANAGER_PORT)], ) 
@@ -910,7 +996,7 @@ def _build_head_container(self) -> dict:
         mem_req = f"{self.config.pathways_head_mem}Gi"
         resources = {
             "requests": {"cpu": cpu_req, "memory": mem_req},
-            "limits": {"cpu": cpu_req, "memory": mem_req},
+            # "limits": {"cpu": cpu_req, "memory": mem_req},
         }
         return dict(
             name=cfg.name,
@@ -936,9 +1022,9 @@ def _build_head_container(self) -> dict:
             ],
             imagePullPolicy="Always",
             resources=resources,
-            ports=[dict(containerPort=self.config.target_port)]
-            if self.config.enable_service
-            else [],
+            ports=(
+                [dict(containerPort=self.config.target_port)] if self.config.enable_service else []
+            ),
         )

     def build_leader_pod(self) -> Nested[Any]:
diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py
index 9ba0bbf81..3993de4a6 100644
--- a/axlearn/common/array_serialization.py
+++ b/axlearn/common/array_serialization.py
@@ -20,8 +20,10 @@
 import functools
 import math
 import os
+import sys
 import threading
 import time
+import uuid
 from collections import defaultdict
 from concurrent import futures
 from concurrent.futures import ThreadPoolExecutor
@@ -205,8 +207,7 @@ def _fix_metadata(tspec: dict[str, Any], shard_infos: list[_ShardInfo]):


 class TensorstoreSpecModifier:
-    def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]):
-        ...
+    def __call__(self, spec: dict[str, Any], *, shard_infos: list[_ShardInfo]): ...


 async def _async_serialize(
@@ -350,6 +351,8 @@ async def _run_serializer(


 def _blocking_device_put(out: Tensor, layout: Layout) -> Tensor:
+    # Make it non-blocking:
+    # return jax.device_put(out, layout)
     return jax.block_until_ready(jax.device_put(out, layout))


@@ -404,11 +407,26 @@ async def _async_deserialize(
             f" an instance of `jax.sharding.Sharding`. Got {in_sharding}"
         )
     dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None
+
+    # gcs_grpc improves performance.
+    if tensorstore_spec.get("kvstore", {}).get("driver", "") == "gcs":
+        tensorstore_spec["kvstore"]["driver"] = "gcs_grpc"
+
+    logging.info("tensorstore_spec: %s", tensorstore_spec)
+
     t = await ts.open(
         tensorstore_spec,
         open=True,
         assume_metadata=False,
-        context=serialization.TS_CONTEXT,
+        # context=serialization.TS_CONTEXT,
+        # Improve GCS performance.
+        context=ts.Context(
+            {
+                "cache_pool": {"total_bytes_limit": 0},
+                "data_copy_concurrency": {"limit": "shared"},
+                "gcs_request_concurrency": {"limit": 480},
+            }
+        ),
     )
     shape = tuple(t.shape if global_shape is None else global_shape)
     new_shard_shape = in_sharding.shard_shape(shape)
@@ -434,9 +452,11 @@ async def cb(index: array.Index, device: jax.Device):
             # the extra values will be filled with 0s.
             out = np.zeros(new_shard_shape, read_ts.dtype.numpy_dtype)

+        write_start_time = time.time()
         await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write(
             read_ts
         )
+        logging.info("ts.array.write took %.4f seconds.", time.time() - write_start_time)

         # Convert to jnp array so that layouts are initialized properly for
         # sub-byte dtypes.
@@ -450,13 +470,25 @@ async def cb(index: array.Index, device: jax.Device):
         mb_256 = 256 * 1024 * 1024
         out_size = math.ceil(out_size / mb_256) * mb_256

+        logging.info("in_sharding: %s", in_sharding)
         layout = Layout(
             dll, jax.sharding.SingleDeviceSharding(device, memory_kind=in_sharding.memory_kind)
         )
         try:
+            log_id = id(out)
+            logging.info(
+                "Sending jax.device_put of size %s MiB. Shape: %s. ID: %s",
ID: %s", + out_size / (1024 * 1024), + out.shape, + log_id, + ) + start_time = time.time() await h2d_limiter.wait_for_bytes(out_size) result = await loop.run_in_executor(None, _blocking_device_put, out, layout) await h2d_limiter.release_bytes(out_size) + logging.info("Device put took %.4f seconds. ID: %s", time.time() - start_time, log_id) + # We delete afterwards from HBM since we're testing on v5e with limited HBM + # result.delete(), this didn't work it causes instance device_puts except ValueError as e: if "Requested more bytes than we reserved" not in str(e): raise e # Raise if it's not the type of error we expect. @@ -588,7 +620,13 @@ def deserialize( dtypes: Optional[Sequence[typing.DTypeLike]] = None, concurrent_gb: int = 32, ): + # force to 64 + # concurrent_gb = max(64, concurrent_gb) + logging.info("concurrent_gb=%s GB.", concurrent_gb) self.wait_until_finished() + start_time = time.time() + uid = uuid.uuid4() + jax.profiler.start_trace(f"gs://cloud-tpu-multipod-dev-uss1/stoelinga-{uid}/") concurrent_bytes = concurrent_gb * 10**9 @@ -613,7 +651,14 @@ async def _run_deserializer(): return await asyncio.gather(*future_arrays) fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop) - return fut.result() + result = fut.result() + # Only needed when we use non blocking device put + # jax.block_until_ready(result) + logging.info("deserialize took %.4f seconds.", time.time() - start_time) + jax.profiler.stop_trace() + sys.exit(0) + # pylint: disable=unreachable + return result class BoundedDataShardedAsyncCheckpointManager(GlobalAsyncCheckpointManager): diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 9ec469dbb..3b892f701 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -15,6 +15,7 @@ import itertools from typing import Any, List, NamedTuple, Optional, Union +import jax from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies from axlearn.common import causal_lm, config @@ -409,8 +410,9 @@ def get_trainer_kwargs( learner_kwargs=dict(peak_lr=3e-4, weight_decay=0.1), max_sequence_length=max_sequence_length, train_batch_size=train_batch_size, + save_every_n_steps=100, max_step=max_step, - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8), + mesh_shape=mesh_shape_from_axes(fsdp=-1), mesh_rules=( # Step time: # v1 on tpu-v4-1024 (512 chips): 3.03s @@ -423,21 +425,28 @@ def get_trainer_kwargs( ("tpu-v4-(1024|2048)", mesh_shape_from_axes(data=-1, fsdp=16)), # tpu-v5e. 
                 (
-                    "tpu-v5litepod-256",
+                    "tpu-v5litepod-32-1",
                     ChainConfigModifier.default_config().set(
                         config_modifiers=[
                             MeshShapeModifier.default_config().set(
-                                mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
+                                mesh_shape=mesh_shape_from_axes(fsdp=32)
                             ),
                             RematSpecModifier.default_config().set(
                                 remat_policies={
                                     "model.decoder.transformer.layer": RematSpec(
                                         prevent_cse=False,
-                                        policy=offload_dots_saveable_policy,
+                                        policy=config_for_function(
+                                            save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved=None,
+                                            names_which_can_be_offloaded=None,
+                                            offload_src="device",
+                                            offload_dst="pinned_host",
+                                        ),
                                     ),
                                 }
                             ),
-                            GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
+                            # GradientAccumulationModifier.default_config().set(grad_acc_steps=4),
                         ],
                     ),
                 ),
@@ -704,6 +713,31 @@
                         ],
                     ),
                 ),
+                (
+                    "tpu-v5e-.*",
+                    ChainConfigModifier.default_config().set(
+                        config_modifiers=[
+                            MeshShapeModifier.default_config().set(
+                                mesh_shape=mesh_shape_from_axes(fsdp=-1)
+                            ),
+                            RematSpecModifier.default_config().set(
+                                remat_policies={
+                                    "model.decoder.transformer.layer": RematSpec(
+                                        prevent_cse=False,
+                                        policy=config_for_function(
+                                            save_and_offload_only_these_names_regex
+                                        ).set(
+                                            names_which_can_be_saved=None,
+                                            names_which_can_be_offloaded=None,
+                                            offload_src="device",
+                                            offload_dst="pinned_host",
+                                        ),
+                                    ),
+                                }
+                            ),
+                        ],
+                    ),
+                ),
                 (
                     "tpu-v5p-.*",
                     ChainConfigModifier.default_config().set(
@@ -841,6 +875,8 @@
         )
     else:
         raise NotImplementedError(f"Unknown model size {model_size}.")
+    total_chips = len(jax.devices())
+    trainer_kwargs["train_batch_size"] = total_chips
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
     if version == Version.V3_TIKTOKEN:  # tiktoken tokenizer
diff --git a/benchmark_deserialize.py b/benchmark_deserialize.py
new file mode 100644
index 000000000..490a277e9
--- /dev/null
+++ b/benchmark_deserialize.py
@@ -0,0 +1,254 @@
+"""
+A script to benchmark the GlobalAsyncCheckpointManager.deserialize function.
+
+This script contains a local patch for the deserialization logic to work around
+a bug in the installed axlearn library, avoiding any modification to the library itself.
+"""
+
+import asyncio
+import functools
+import math
+import os
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Optional, Sequence, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import tensorstore as ts
+from absl import app, flags
+from jax._src import array, typing
+from jax._src.layout import Layout
+from jax.experimental.array_serialization import serialization
+from jax.experimental.array_serialization.serialization import get_tensorstore_spec
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+from axlearn.common.array_serialization import (
+    GlobalAsyncCheckpointManager,
+    _get_premapped_buffer_size,
+)
+from axlearn.common.checkpointer import read_state_spec
+from axlearn.common.utils import flatten_items
+
+# JAX platforms might be initialized by another process.
+# We follow the logic in axlearn.common.launch to initialize JAX.
+if os.environ.get("JAX_PLATFORMS", "") == "proxy": + import pathwaysutils # type: ignore + + pathwaysutils.initialize() +else: + jax.distributed.initialize() + + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "checkpoint_dir", + "gs://cloud-tpu-multipod-dev-axlearn/stoelinga-v7-70b-17/checkpoints/step_00000100/", + "The GCS path to the checkpoint step directory.", +) +flags.DEFINE_integer("num_iterations", 5, "The number of benchmark iterations.") +flags.DEFINE_integer("warmup_iterations", 1, "The number of warmup iterations.") + + +# --- Local Patch for Deserialization --- +# The following functions are copied from axlearn.common.array_serialization +# and patched locally to fix a TypeError without modifying the library. + + +def _blocking_device_put(out: np.ndarray, layout: Layout) -> jax.Array: + return jax.block_until_ready(jax.device_put(out, layout)) + + +async def _patched_async_deserialize( + user_in_sharding: jax.sharding.Sharding | Layout, + tensorstore_spec: dict[str, Any], + global_shape: Optional[Sequence[int]], + dtype: Optional[typing.DTypeLike], + *, + h2d_limiter: serialization._LimitInFlightBytes, + byte_limiter: serialization._LimitInFlightBytes, + single_thread_pool: ThreadPoolExecutor, +): + """Patched version of _async_deserialize.""" + in_sharding = ( + user_in_sharding.sharding if isinstance(user_in_sharding, Layout) else user_in_sharding + ) + if not isinstance(in_sharding, jax.sharding.Sharding): + raise ValueError( + "sharding passed to deserialization should be specified, concrete and" + f" an instance of `jax.sharding.Sharding`. Got {in_sharding}" + ) + dll = user_in_sharding.device_local_layout if isinstance(user_in_sharding, Layout) else None + t = await ts.open( + tensorstore_spec, + open=True, + assume_metadata=False, + context=serialization.TS_CONTEXT, + ) + shape = tuple(t.shape if global_shape is None else global_shape) + new_shard_shape = in_sharding.shard_shape(shape) + loop = asyncio.get_running_loop() + + async def cb(index: array.Index, device: jax.Device): + requested_domain = ts.IndexTransform(input_shape=shape)[index].domain + restricted_domain = t.domain.intersect(requested_domain) + requested_bytes = serialization.estimate_read_memory_footprint(t, restricted_domain) + await byte_limiter.wait_for_bytes(requested_bytes) + read_ts = t[restricted_domain] + if dtype is not None: + read_ts = ts.cast(read_ts, dtype) + if tuple(t.shape) == shape: + out = np.empty(new_shard_shape, read_ts.dtype.numpy_dtype) + else: + out = np.zeros(new_shard_shape, read_ts.dtype.numpy_dtype) + + await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][restricted_domain].write( + read_ts + ) + + if out.dtype == jnp.int4: + out = jnp.asarray(out) + + out_size = out.size * out.dtype.itemsize + mb_256 = 256 * 1024 * 1024 + out_size = math.ceil(out_size / mb_256) * mb_256 + + sharding_for_put = jax.sharding.SingleDeviceSharding( + device, memory_kind=in_sharding.memory_kind + ) + if dll is not None: + sharding_for_put = Layout(dll, sharding_for_put) + + try: + await h2d_limiter.wait_for_bytes(out_size) + result = await loop.run_in_executor(None, _blocking_device_put, out, sharding_for_put) + await h2d_limiter.release_bytes(out_size) + except ValueError as e: + if "Requested more bytes than we reserved" not in str(e): + raise e + result = await loop.run_in_executor( + single_thread_pool, _blocking_device_put, out, sharding_for_put + ) + + await byte_limiter.release_bytes(requested_bytes) + return result + + # This is the patched line. 
+    # pylint: disable-next=protected-access
+    return await serialization.create_async_array_from_callback(shape, in_sharding, cb)
+
+
+class PatchedGlobalAsyncCheckpointManager(GlobalAsyncCheckpointManager):
+    """An override of the manager to use our patched deserialize logic."""
+
+    def deserialize(
+        self,
+        shardings: Sequence[Union[jax.sharding.Sharding, Layout]],
+        tensorstore_specs: Sequence[dict[str, Any]],
+        global_shapes: Optional[Sequence[array.Shape]] = None,
+        dtypes: Optional[Sequence[typing.DTypeLike]] = None,
+        concurrent_gb: int = 32,
+    ):
+        self.wait_until_finished()
+        concurrent_bytes = concurrent_gb * 10**9
+
+        max_shard_bytes = 0
+        if global_shapes and dtypes:
+            for sharding, shape, dtype in zip(shardings, global_shapes, dtypes):
+                if isinstance(sharding, Layout):
+                    sharding = sharding.sharding
+                shard_shape = sharding.shard_shape(shape)
+                shard_bytes = np.prod(shard_shape) * np.dtype(dtype).itemsize
+                if shard_bytes > max_shard_bytes:
+                    max_shard_bytes = shard_bytes
+
+        if max_shard_bytes > concurrent_bytes:
+            concurrent_bytes = int(max_shard_bytes) + 1
+
+        async def _run_deserializer():
+            # pylint: disable=protected-access
+            byte_limiter = serialization._LimitInFlightBytes(concurrent_bytes)
+            h2d_limiter = serialization._LimitInFlightBytes(_get_premapped_buffer_size())
+            future_arrays = jax.tree.map(
+                functools.partial(
+                    _patched_async_deserialize,  # Use our patched function.
+                    byte_limiter=byte_limiter,
+                    h2d_limiter=h2d_limiter,
+                    single_thread_pool=self._single_thread_pool,
+                ),
+                shardings,
+                tensorstore_specs,
+                [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
+                [None] * len(tensorstore_specs) if dtypes is None else dtypes,
+            )
+            return await asyncio.gather(*future_arrays)
+
+        fut = asyncio.run_coroutine_threadsafe(_run_deserializer(), self._loop)
+        return fut.result()
+
+
+def main(argv: Sequence[str]) -> None:
+    """Benchmarks the deserialize function."""
+    del argv
+
+    devices = jax.devices()
+    mesh = Mesh(devices, axis_names=("data",))
+
+    state_spec = read_state_spec(FLAGS.checkpoint_dir)
+    flat_state_spec = flatten_items(state_spec, separator="/")
+
+    ts_specs, shardings_list, global_shapes, dtypes = [], [], [], []
+
+    for path, spec in flat_state_spec:
+        gda_path = os.path.join(FLAGS.checkpoint_dir, "gda", path)
+        ts_specs.append(get_tensorstore_spec(gda_path))
+
+        partition_spec = PartitionSpec()
+        if len(spec.shape) > 0 and spec.shape[0] % len(devices) == 0:
+            partition_spec = PartitionSpec("data", *(None,) * (len(spec.shape) - 1))
+
+        shardings_list.append(NamedSharding(mesh, partition_spec))
+        global_shapes.append(spec.shape)
+        dtypes.append(spec.dtype)
+
+    manager = PatchedGlobalAsyncCheckpointManager()
+
+    def run_deserialize():
+        """Runs deserialization across all tensors."""
+        start_time = time.time()
+        restored_arrays = manager.deserialize(
+            shardings=shardings_list,
+            tensorstore_specs=ts_specs,
+            global_shapes=global_shapes,
+            dtypes=dtypes,
+        )
+        for arr in restored_arrays:
+            arr.block_until_ready()
+        return time.time() - start_time
+
+    print(f"Running {FLAGS.warmup_iterations} warmup iterations...")
+    for _ in range(FLAGS.warmup_iterations):
+        run_deserialize()
+
+    print(f"Running {FLAGS.num_iterations} benchmark iterations...")
+    durations = []
+    for i in range(FLAGS.num_iterations):
+        duration = run_deserialize()
+        print(f"Iteration {i+1} took {duration:.4f} seconds.")
+        durations.append(duration)
+
+    print("\n--- Benchmark Results ---")
+    print(f"Number of devices: {len(devices)}")
+    print(f"Iterations: {FLAGS.num_iterations}")
{FLAGS.num_iterations}") + print(f"Average time: {sum(durations) / len(durations):.4f} seconds") + print(f"Min time: {min(durations):.4f} seconds") + print(f"Max time: {max(durations):.4f} seconds") + print("-------------------------\n") + + manager.stop() + + +if __name__ == "__main__": + app.run(main) diff --git a/proxy_bench.py b/proxy_bench.py new file mode 100644 index 000000000..1f89a84c4 --- /dev/null +++ b/proxy_bench.py @@ -0,0 +1,90 @@ +"""A script to benchmark JAX device_put throughput.""" + +import asyncio +import os +import time +import uuid + +import jax +import numpy as np +import pathwaysutils +from jax.sharding import SingleDeviceSharding + + +async def benchmark_host_to_device_throughput( + device_put_buffer_mb: int = 512, num_transfers: int = 10 +): + """Benchmarks JAX device_put throughput from CPU host to a v5e-32 TPU slice.""" + print(f"JAX version: {jax.__version__}") + devices = jax.devices() if os.environ.get("JAX_PLATFORMS") else jax.local_devices() + num_devices = len(devices) + print(f"Available devices: {num_devices}") + + data_bytes_per_device = int(device_put_buffer_mb * 1024 * 1024) + dtype = np.float32 + num_elements = data_bytes_per_device // np.dtype(dtype).itemsize + data_gb_per_device = num_elements * np.dtype(dtype).itemsize / (1024**3) + + print( + f"Creating {num_devices} NumPy arrays of shape ({num_elements},) type {dtype}, size" + f" {data_gb_per_device:.2f} GiB each" + ) + host_arrays = [np.arange(num_elements, dtype=dtype) for _ in range(num_devices)] + shardings = [SingleDeviceSharding(device) for device in devices] + + loop = asyncio.get_running_loop() + transfer_times = [] + + print(f"Starting benchmark ({num_transfers} transfers)...") + for i in range(num_transfers): + uid = uuid.uuid4() + if i == 0: + trace_dir = f"gs://cloud-tpu-multipod-dev-axlearn/{uid}" + jax.profiler.start_trace(f"{trace_dir}/{device_put_buffer_mb}mb") + + start_time = time.perf_counter() + + # Issue device_put calls in parallel. + device_put_futures = [ + loop.run_in_executor(None, jax.device_put, host_arrays[j], shardings[j]) + for j in range(num_devices) + ] + device_arrays = await asyncio.gather(*device_put_futures) + + # Block until all transfers are complete. + for device_array in device_arrays: + device_array.block_until_ready() + + end_time = time.perf_counter() + + duration = end_time - start_time + transfer_times.append(duration) + print(f"Transfer {i+1}/{num_transfers}: {duration:.4f} seconds") + + if i == 0: + jax.profiler.stop_trace() + + # Optional: hint for early deletion. 
+        del device_arrays
+
+    avg_time = np.mean(transfer_times)
+    print(f"\nAverage time per parallel device_put batch: {avg_time:.4f} seconds")
+
+    total_data_moved_gb = data_gb_per_device * num_devices
+    throughput_gb_s = total_data_moved_gb / avg_time
+
+    print(f"Data per device: {data_gb_per_device:.2f} GiB")
+    print(f"Total data transferred from host per operation: {total_data_moved_gb:.2f} GiB")
+    print(f"Aggregated Host -> Devices Throughput: {throughput_gb_s:.2f} GiB/s")
+    print(f"Aggregated Host -> Devices Throughput: {throughput_gb_s * 8:.2f} Gbps")
+
+
+if __name__ == "__main__":
+    if os.environ.get("JAX_PLATFORMS") == "proxy":
+        pathwaysutils.initialize()
+    else:
+        jax.distributed.initialize()
+    scenarios_mb = [1, 128, 1024, 2048]
+    for scenario in scenarios_mb:
+        print(f"Running scenario {scenario}MB")
+        asyncio.run(benchmark_host_to_device_throughput(scenario))
diff --git a/restore_fuji_70b.sh b/restore_fuji_70b.sh
new file mode 100644
index 000000000..08bd5b352
--- /dev/null
+++ b/restore_fuji_70b.sh
@@ -0,0 +1,8 @@
+python3 -m axlearn.common.launch_trainer_main \
+    --module=text.gpt.c4_trainer \
+    --config=fuji-70B-v3-flash \
+    --trainer_dir=gs://cloud-tpu-multipod-dev-uss1/axlearn-fuji-v3-70b/ \
+    --data_dir=gs://axlearn-public/tensorflow_datasets \
+    --jax_backend=proxy \
+    --mesh_selector=tpu-v5litepod-32-1 \
+    --trace_at_steps=11
diff --git a/train_fuji_7b.sh b/train_fuji_7b.sh
new file mode 100644
index 000000000..06706944f
--- /dev/null
+++ b/train_fuji_7b.sh
@@ -0,0 +1,10 @@
+
+# --trainer_dir=gs://cloud-tpu-multipod-dev-uss1/axlearn-fuji-v2-7b/ \
+
+python3 -m axlearn.common.launch_trainer_main \
+    --module=text.gpt.c4_trainer \
+    --config=fuji-7B-v2-flash \
+    --trainer_dir=gs://cloud-tpu-multipod-dev-euw4/axlearn-fuji-v2-7b/ \
+    --data_dir=gs://axlearn-public/tensorflow_datasets \
+    --jax_backend=proxy
+# --mesh_selector=tpu-v5p