11 changes: 11 additions & 0 deletions axlearn/cloud/gcp/jobset_utils.py
@@ -417,6 +417,7 @@ def __init__(self, cfg: Config, *, bundler: Bundler):
job_name=cfg.job_name,
)
self._output_volume_mount = dict(name="shared-output", mountPath="/output")
self._output_volume_mount = dict(name="checkpoint", mountPath="/checkpoint")
if cfg.additional_node_networks and not cfg.service_account:
raise ValueError("service_account must be set if additional_node_networks is set.")
self._load_balancer = _LoadBalancer(jobset_name=cfg.name, replicated_job_name=cfg.job_name)
@@ -543,6 +544,16 @@ def _build_pod(self) -> Nested[Any]:
annotations, labels, selector, volumes, tolerations = {}, {}, {}, [], []

volumes.append(dict(name="shared-output", emptyDir={}))

volumes.append(
dict(
name="checkpoint",
csi=dict(
driver="multitier-checkpoint.csi.storage.gke.io",
),
)
)

if cfg.gcsfuse_mount:
# Increases the shared memory volume when gcsfuse is enabled. This is useful
# when grain prefetch is enabled.
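For orientation, here is how the new volume pairs with the mount registered in `__init__`; a minimal sketch using only names from this diff (the full pod spec assembled by `_build_pod` contains much more):

# Sketch only: the "checkpoint" volume is backed by the GKE multitier-checkpoint
# CSI driver, and containers that attach the matching mount see it at /checkpoint.
volumes = [
    dict(name="shared-output", emptyDir={}),
    dict(name="checkpoint", csi=dict(driver="multitier-checkpoint.csi.storage.gke.io")),
]
volume_mounts = [
    dict(name="shared-output", mountPath="/output"),
    dict(name="checkpoint", mountPath="/checkpoint"),
]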
84 changes: 81 additions & 3 deletions axlearn/common/checkpointer_orbax.py
@@ -13,9 +13,12 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax
import numpy as np
import orbax.checkpoint as ocp
import tensorflow as tf
from absl import logging
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.serialization.type_handlers import ArrayHandler

from axlearn.common import utils
from axlearn.common.checkpointer import (
@@ -187,15 +190,20 @@ class Config(BaseCheckpointer.Config):

Attributes:
keep_last_n: Keep this many past ckpts.
keep_every_n_steps: If set, additionally keep a checkpoint every n steps.
validation_type: Checkpoint validation during restore.
async_timeout_secs: Timeout for async barrier in seconds.
max_concurrent_save_gb: If set, caps how many GB are serialized concurrently during save.
max_concurrent_restore_gb: If set, caps how many GB are deserialized concurrently
    during restore.
enable_single_replica_ckpt_restoring: Whether to read each array on a single replica
    during restore and broadcast it to the others.
use_replica_parallel: Whether to use replica-parallel serialization, in which each
    replica writes a distinct slice of a replicated array.
"""

keep_last_n: int = 1
keep_every_n_steps: Optional[int] = None
validation_type: CheckpointValidationType = CheckpointValidationType.EXACT
async_timeout_secs: int = 300
max_concurrent_save_gb: Optional[int] = None
max_concurrent_restore_gb: Optional[int] = None
enable_single_replica_ckpt_restoring: bool = True
use_replica_parallel: bool = False

@classmethod
def checkpoint_paths(cls, base_dir: str) -> List[str]:
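Usage sketch for the new knobs (assuming the standard axlearn `default_config().set(...)` idiom; values are illustrative):

# Hypothetical configuration: keep the 3 most recent checkpoints plus one every
# 1000 steps, restore via a single replica, and save without replica-parallel writes.
ckpt_cfg = OrbaxCheckpointer.default_config().set(
    keep_last_n=3,
    keep_every_n_steps=1000,
    enable_single_replica_ckpt_restoring=True,
    use_replica_parallel=False,
)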
@@ -210,6 +218,7 @@ def checkpoint_steps(cls, base_dir) -> list[int]:
def __init__(self, cfg: Config, *, parent: Optional[Module]):
super().__init__(cfg, parent=parent)

logging.set_verbosity(logging.DEBUG)
cfg: OrbaxCheckpointer.Config = self.config
save_policy = cfg.save_policy.instantiate()

@@ -237,6 +246,7 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool:
options=ocp.CheckpointManagerOptions(
create=True,
max_to_keep=cfg.keep_last_n,
keep_period=cfg.keep_every_n_steps,
enable_async_checkpointing=True,
step_name_format=self._name_format,
should_save_fn=save_fn_with_summaries,
@@ -285,10 +295,23 @@ def save(

Checkpoint saving is handled by `orbax` checkpoint manager.
"""
cfg: OrbaxCheckpointer.Config = self.config
spec = self._get_spec(step=step, state=state)
assert self._eval_summaries is None, self._eval_summaries
self._eval_summaries = copy.deepcopy(evaler_summaries or {})

# Store the original handler so it can be restored after the save.
original_handler = None
if not cfg.use_replica_parallel:
# Get the current handler for jax.Array
original_handler = ocp.type_handlers.get_type_handler(jax.Array)
# Register a new ArrayHandler with use_replica_parallel=False
custom_handler = ArrayHandler(
use_replica_parallel=False,
array_metadata_store=array_metadata_store_lib.Store(),
)
ocp.type_handlers.register_type_handler(jax.Array, custom_handler, override=True)

try:
# Note that save() waits for prior serialization to finish.
self._manager.save(
@@ -309,6 +332,9 @@ def save(
self._manager.wait_until_finished()
raise SystemExit(f"Exiting after saving checkpoint at {step=} due to pre-emption.")
finally:
# Restore the original handler if it was overridden.
if not cfg.use_replica_parallel and original_handler is not None:
ocp.type_handlers.register_type_handler(jax.Array, original_handler, override=True)
self._eval_summaries = None

def restore(
@@ -321,11 +347,33 @@ def restore(

cfg: OrbaxCheckpointer.Config = self.config

if cfg.enable_single_replica_ckpt_restoring:
array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000,  # 1000 MiB limit.
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

def _restore_args(x: Any) -> ocp.RestoreArgs:
    if isinstance(x, (Tensor, TensorSpec)):
        if cfg.enable_single_replica_ckpt_restoring:
            pspec = x.sharding.spec
            mesh = x.sharding.mesh
            replica_axis_index = 0
            replica_devices = _replica_devices(mesh.devices, replica_axis_index)
            replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
            single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)
            return ocp.type_handlers.SingleReplicaArrayRestoreArgs(
                sharding=jax.sharding.NamedSharding(mesh, pspec),
                single_replica_sharding=single_replica_sharding,
                global_shape=x.shape,
                dtype=x.dtype,
            )
        else:
            return ocp.checkpoint_utils.construct_restore_args(
                jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
            )
    elif isinstance(x, tf.data.Iterator):
        return _TfIteratorHandler.RestoreArgs(item=x)
    elif _GRAIN_INSTALLED and isinstance(x, _GrainIterator):
@@ -349,6 +397,13 @@ def _restore_args(x: Any) -> ocp.RestoreArgs:
raise ValueError(f"Failed to restore at step {step}.") from e
logging.info("Could not find any completed checkpoints under %s: %s", cfg.dir, e)
return None, state # Return the input state.
finally:
    if cfg.enable_single_replica_ckpt_restoring:
        # Re-register the default jax.Array handler after the single-replica restore.
        ocp.type_handlers.register_type_handler(
            jax.Array,
            ArrayHandler(array_metadata_store=array_metadata_store_lib.Store()),
            override=True,
        )

restored_index = composite_state["index"]
restored_state = composite_state["state"]
@@ -375,3 +430,26 @@ def wait_until_finished(self):
def stop(self, *, has_exception: bool = False):
"""See `BaseCheckpointer.stop` for details."""
self._manager.close()


def _find_idx(array: np.ndarray, replica_axis_idx: int) -> int:
    """Returns the index along the given dimension that the current host belongs to."""
    idx = None
    for idx, val in np.ndenumerate(array):
        if val.process_index == jax.process_index():
            break
    # Assumes the current host's devices appear in `array`; `idx` is the first match.
    return idx[replica_axis_idx]


def _replica_devices(device_array: np.ndarray, replica_axis_idx: int) -> np.ndarray:
    """Returns the devices from the replica that the current host belongs to.

    Replicas are assumed to be restricted to the first axis.

    Args:
        device_array: Devices of the mesh, as given by `mesh.devices`.
        replica_axis_idx: Axis dimension along which the replica is taken.

    Returns:
        The devices inside the replica that the current host is in, with the
        replica axis kept as a singleton dimension.
    """
    idx = _find_idx(device_array, replica_axis_idx)
    replica_result = np.take(device_array, idx, axis=replica_axis_idx)
    return np.expand_dims(replica_result, axis=replica_axis_idx)
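A small worked example of the two helpers, using stand-ins for `jax.Device` (only `.process_index` matters to `_find_idx`):

from types import SimpleNamespace

import numpy as np

# Two replicas of four devices each, shape (2, 4); axis 0 is the replica axis.
devices = np.array(
    [
        [SimpleNamespace(process_index=0)] * 4,
        [SimpleNamespace(process_index=1)] * 4,
    ]
)
# On a host with process_index == 1, _find_idx(devices, 0) returns 1, and
# _replica_devices(devices, 0) mirrors these two numpy ops:
row = np.take(devices, 1, axis=0)      # this replica's devices, shape (4,)
replica = np.expand_dims(row, axis=0)  # replica axis restored, shape (1, 4)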