diff --git a/Dockerfile b/Dockerfile
index b186587a0..a8eddc74e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -90,7 +90,7 @@ ENTRYPOINT ["/opt/apache/beam/boot"]
 
 FROM base AS tpu
 
-ARG EXTRAS=
+ARG EXTRAS=orbax
 ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
 
 # Ensure we install the TPU version, even if building locally.
diff --git a/axlearn/cloud/gcp/jobset_utils.py b/axlearn/cloud/gcp/jobset_utils.py
index 0cdd34700..542ded227 100644
--- a/axlearn/cloud/gcp/jobset_utils.py
+++ b/axlearn/cloud/gcp/jobset_utils.py
@@ -243,6 +243,7 @@ class Config(BaseReplicatedJob.Config):
             env_vars: Optional env vars to set.
             gcsfuse_mount: Optional configs for the GCS FUSE sidecar and volume mount.
                 See `GCSFuseMount` for details.
+            mtc_mount: Optionally mount the multi-tier checkpointing volume.
             host_mounts: List of volumes from host to mount into the container.
                 See `HostMount` for details.
             service_account: Optional service account to execute the job as.
@@ -256,6 +257,7 @@ class Config(BaseReplicatedJob.Config):
         accelerator: AcceleratorConfig = AcceleratorConfig()
         env_vars: dict[str, str] = {}
         gcsfuse_mount: Optional[GCSFuseMount] = None
+        mtc_mount: Optional[bool] = None
         host_mounts: Optional[Sequence[HostMount]] = None
         service_account: Optional[str] = None
         # This config is made Optional for backwards compatibility. Default is False.
@@ -377,6 +379,7 @@ class Config(SingleReplicatedJob.Config):
                 to attach to the node pool. This is needed to support multiple NIC. Refer to
                 GKE TPU provisioner for more context:
                 https://github.com/GoogleCloudPlatform/ai-on-gke/blob/5f256eed7075a5cb8e73cd72328aea46237b8ce6/tpu-provisioner/internal/cloud/common.go#L29-L31
+            mtc_mount: Whether or not to mount the volume for multi-tier checkpointing.
         """
 
         reservation: Optional[str] = None
@@ -386,6 +389,7 @@ class Config(SingleReplicatedJob.Config):
         enable_tpu_smart_repair: bool = False
         priority_class: Optional[str] = None
         additional_node_networks: Optional[str] = None
+        mtc_mount: Optional[bool] = None
 
     @classmethod
     def define_flags(cls, fv: flags.FlagValues):
@@ -410,6 +414,12 @@ def define_flags(cls, fv: flags.FlagValues):
             "The GKE PriorityClass for the job.",
             **common_kwargs,
         )
+        flags.DEFINE_boolean(
+            "mtc_mount",
+            None,
+            "Whether to mount the multi-tier checkpointing volume.",
+            **common_kwargs,
+        )
 
     @classmethod
     def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
@@ -433,6 +443,7 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
         cfg.additional_node_networks = gcp_settings(
             "additional_node_networks", required=False, fv=fv
         )
+        cfg.mtc_mount = fv.mtc_mount
         return cfg
 
     def __init__(self, cfg: Config, *, bundler: Bundler):
@@ -476,6 +487,11 @@ def _build_container(self) -> Nested[Any]:
                 volume_mounts, spec=VolumeMount(name="shared-memory", mount_path="/dev/shm")
             )
 
+        if cfg.mtc_mount:
+            self._maybe_add_volume_mount(
+                volume_mounts, spec=VolumeMount(name="checkpoint", mount_path="/checkpoint")
+            )
+
         if cfg.host_mounts:
             for mount in cfg.host_mounts:
                 self._maybe_add_volume_mount(volume_mounts, spec=mount)
@@ -577,6 +593,17 @@ def _build_pod(self) -> Nested[Any]:
         annotations, labels, selector, volumes, tolerations = {}, {}, {}, [], []
 
         volumes.append(dict(name="shared-output", emptyDir={}))
+
+        if cfg.mtc_mount:
+            volumes.append(
+                dict(
+                    name="checkpoint",
+                    csi=dict(
+                        driver="multitier-checkpoint.csi.storage.gke.io",
+                    ),
+                )
+            )
+
         if cfg.gcsfuse_mount:
             # Increases the shared memory volumes when enabled gcsfuse. This is useful when grain
             # prefetch is enabled.
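
For reference, the new `mtc_mount` path above only adds two pod-spec fragments: a pod-level CSI volume backed by the GKE multi-tier checkpointing driver, and a matching mount at /checkpoint inside the training container. Below is a minimal sketch of those fragments, assuming the jobset builder wiring stays as in the diff; the helper function name and the dict-based rendering are illustrative only, not part of the change.

# Sketch only: the fragments that `_build_pod` / `_build_container` emit when `mtc_mount` is set.
# `mtc_volume_fragments` is a hypothetical helper used purely for illustration.
def mtc_volume_fragments(mtc_mount: bool) -> tuple[list, list]:
    if not mtc_mount:
        return [], []
    # Pod-level volume provided by the GKE multi-tier checkpointing CSI driver.
    volumes = [dict(name="checkpoint", csi=dict(driver="multitier-checkpoint.csi.storage.gke.io"))]
    # Container-level mount; the path matches the checkpointer's default `local_dir`.
    volume_mounts = [dict(name="checkpoint", mountPath="/checkpoint")]
    return volumes, volume_mounts
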
diff --git a/axlearn/common/checkpointer_orbax_emergency_replicator.py b/axlearn/common/checkpointer_orbax_emergency_replicator.py
new file mode 100644
index 000000000..3aa4714d9
--- /dev/null
+++ b/axlearn/common/checkpointer_orbax_emergency_replicator.py
@@ -0,0 +1,340 @@
+# Copyright © 2024 Apple Inc.
+
+"""Implements Orbax replicator checkpointing and provides utilities for correct checkpoint storage.
+
+See the docstring of `OrbaxEmergencyReplicatorCheckpointer` for more details.
+"""
+
+import copy
+import os
+import time
+from contextlib import contextmanager
+from typing import Any, Dict, Optional, Tuple, Union
+
+import jax
+import jax.lib
+import orbax.checkpoint as ocp
+import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as oercp
+from absl import flags, logging
+from etils import epath
+from jax._src.distributed import global_state
+from jax._src.mesh import thread_resources
+
+from axlearn.common.checkpointer import BaseCheckpointer
+from axlearn.common.config import REQUIRED, Required, config_class
+from axlearn.common.module import Module
+from axlearn.common.utils import Nested, Tensor, TensorSpec
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer(
+    "assume_data_parallelism",
+    None,
+    (
+        "Number of identical pipelines in the job; "
+        "should be equal to ICI data parallelism * DCN data parallelism."
+    ),
+)
+
+
+@contextmanager
+def setup(spec: str):
+    """Sets up FLAGS.process_id and FLAGS.distributed_coordinator as required by Orbax.
+
+    Args:
+        spec: Key=Value pairs separated by commas.
+    """
+    parsed_args = {}
+    allowed_fields = ["local_ckpt_dir", "assume_data_parallelism"]
+    for field in spec.split(","):
+        k, v = field.split("=")
+        if k not in allowed_fields:
+            raise ValueError(f"Expected key in {allowed_fields}, got key={k}.")
+        parsed_args[k] = v
+    if "local_ckpt_dir" not in parsed_args:
+        raise ValueError("local_ckpt_dir must be specified.")
+    if "assume_data_parallelism" not in parsed_args:
+        raise ValueError("assume_data_parallelism must be specified.")
+    # Get the process ID and the IP of the jax coordinator.
+    process_id, coordinator_address = _retrieve_jax_init_info(parsed_args["local_ckpt_dir"])
+    FLAGS.assume_data_parallelism = int(parsed_args["assume_data_parallelism"])
+    FLAGS.process_id = int(process_id)
+    FLAGS.distributed_coordinator = coordinator_address
+    FLAGS.experimental_orbax_use_distributed_process_id = True
+
+    yield
+
+
+def _wait_for_file_to_disappear(f, timeout=300):
+    for _ in range(timeout):
+        if not f.exists():
+            return True
+        time.sleep(1)
+    logging.error("File %s did not disappear in time.", f)
+    return False
+
+
+def _extract_step(f):
+    # The base file name is formatted as {job_name}-s{step}-n{node_rank}-w{worker_rank}.
+    return f.rsplit("-", 3)[1][1:]
+
+
+def _block_and_process_restore_dir(directory, timeout=300):
+    """Blocks until a directory symlink ending with `.restore` appears, then extracts
+    the step number and renames the directory using the step number.
+ """ + directory_path = epath.Path(directory) + for _ in range(timeout): + for f in directory_path.glob("*.restore"): + step = _extract_step(f.name) + if step != "0": + f.rename(directory_path / step) + logging.info( + "Renamed restore directory at step %s to %s.", + step, + directory_path / step, + ) + else: + logging.info("Found a restore directory at step 0, skipping renaming.") + return + time.sleep(1) + logging.error("%s seconds have passed but no .restore file was found.", timeout) + raise TimeoutError(f"{timeout} seconds have passed but no .restore file was found.") + + +def _retrieve_jax_init_info(local_ckpt_dir): + """Retrieve JAX init info from a local file.""" + jax_init_info_file = "jax-init-info.txt" + local_jax_init_info_file = epath.Path(local_ckpt_dir) / jax_init_info_file + # Allow time for the JAX init info file to be populated by GKE. + # File only populated when the worker with process id of 0 is determined. + for i in range(900): + if local_jax_init_info_file.exists(): + return local_jax_init_info_file.read_text().split("\n")[:2] + logging.info( + "Unable to locate %s after %d seconds, sleeping for 1 second before retrying...", + jax_init_info_file, + i, + ) + time.sleep(1) + logging.error("Unable to locate %s after 900 seconds.", jax_init_info_file) + raise TimeoutError( + f"Unable to locate {jax_init_info_file} after 900 seconds, " + "returning empty process id and coordinator address." + ) + + +class OrbaxEmergencyReplicatorCheckpointer(BaseCheckpointer): + """Checkpointer implementation that uses Orbax emergency replicator checkpointer. + + EXPERIMENTAL. Do not use for actual training runs since the checkpoint layout will likely + change in the future.""" + + @config_class + class Config(BaseCheckpointer.Config): + """Configures OrbaxEmergencyReplicatorCheckpointer. + + Attributes: + backup_interval_minutes: How often GKE multi-tier checkpointer should back up local + checkpoints to GCS. + save_interval_steps: Number of steps between each checkpoint. + local_dir: Ckpt base path for local storage. The content in this path must persist + across pod restarts unless the restart is caused by node failure. `local_dir` must + be the same for all processes or processes may hang. + trainer_dir: A string that's unique for the current run. Typically, this is set to + trainer_dir. Local checkpoint will be stored in local_dir/sha256(trainer_dir). + During init, all other folders in local_dir will be removed to prevent unexpected + memory usage. + async_timeout_secs: Timeout for async barrier in seconds when saving tensors. + """ + + backup_interval_minutes: int = 30 + save_interval_steps: int = 100 + local_dir: str = "/checkpoint" + trainer_dir: Required[str] = REQUIRED + async_timeout_secs: int = 3600 + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg: OrbaxEmergencyReplicatorCheckpointer.Config = self.config + self._name_format = ocp.step.standard_name_format( + step_prefix=None, + step_format_fixed_length=None, + ) + self._local_dir = cfg.local_dir + self._save_interval_steps = cfg.save_interval_steps + # Orbax replicator ckpt requires this function to be called prior to checkpointer + # operations. This function also serves as a barrier. 
+        ocp.multihost.initialize_runtime_to_distributed_ids()
+        ocp.multihost.initialize_distributed_to_device_ids()
+
+        num_slices = int(os.environ["MEGASCALE_NUM_SLICES"])
+
+        replicator_file = "replicator.yaml"
+        temp_file = replicator_file + ".tmp"
+        replicator_file_path = epath.Path(self._local_dir) / replicator_file
+        if not _wait_for_file_to_disappear(replicator_file_path):
+            logging.error("Existing replicator.yaml did not disappear in time.")
+            raise TimeoutError("Existing replicator.yaml did not disappear in time.")
+        else:
+            logging.info("replicator.yaml no longer exists, creating a new replicator.yaml.")
+        temp_file = epath.Path(self._local_dir) / temp_file
+        num_nodes = jax.process_count()
+        nodes_per_slice = num_nodes // num_slices
+
+        node_rank = global_state.process_id
+        my_process_index = jax.process_index()
+        proc_index_to_node_rank = ocp.multihost.runtime_to_distributed_ids()
+
+        my_in_pipeline_index = my_process_index % nodes_per_slice
+        peer_ranks = []
+        for i in range(num_slices):
+            peer_process_index = i * nodes_per_slice + my_in_pipeline_index
+            if peer_process_index != my_process_index:
+                peer_process_rank = proc_index_to_node_rank[peer_process_index]
+                peer_ranks.append(peer_process_rank)
+
+        logging.info("Peers for NodeRank %s: %s", node_rank, peer_ranks)
+
+        run_name = os.environ.get("HOSTNAME", "").split("job")[0].rstrip("-")
+
+        if not run_name:
+            logging.error("HOSTNAME is not set or its value is invalid.")
+            raise ValueError("HOSTNAME is not set or its value is invalid.")
+
+        replicator_yaml = f"""job-name: {run_name}
+        framework: orbax
+        assume-data-parallelism: {FLAGS.assume_data_parallelism}
+        node-rank: {node_rank}
+        nodes: {num_nodes}
+        peer-ranks: {peer_ranks}
+        backup-interval-minutes: {cfg.backup_interval_minutes}"""
+
+        temp_file.write_text("\n".join([l.strip() for l in replicator_yaml.split("\n")]))
+        os.rename(temp_file, replicator_file_path)
+        if not _wait_for_file_to_disappear(replicator_file_path):
+            logging.error("The newly created replicator.yaml was not deleted in time.")
+            raise TimeoutError("The newly created replicator.yaml was not deleted in time.")
+        else:
+            logging.info("The newly created replicator.yaml was deleted, moving forward.")
+        _block_and_process_restore_dir(self._local_dir)
+
+        self._tensor_manager: Optional[oercp.ReplicatorCheckpointManager] = None
+        # See comments of _eval_summaries in `OrbaxCheckpointer`.
+        self._eval_summaries = None
+        self._reached_preemption = False
+
+    def _get_abstract_state(
+        self, state_with_tensors: Nested[Tensor]
+    ) -> Nested[jax.ShapeDtypeStruct]:
+        """Generates the abstract states required by the Orbax replicator checkpointer."""
+        return jax.tree.map(
+            lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=x.sharding),
+            state_with_tensors,
+        )
+
+    def _get_tensor_manager(
+        self,
+    ) -> oercp.ReplicatorCheckpointManager:
+        """Creates the replicator checkpoint manager if it does not already exist.
+
+        We defer the creation of this checkpoint manager because it requires the state dict,
+        which is not present during __init__.
+ """ + if self._tensor_manager is not None: + return self._tensor_manager + + # For meaning of these options, refer to + # https://github.com/google/orbax/blob/de0b6d0bca643d12840ae73a1f7cfee80af73dcd/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager.py#L87 + self._tensor_manager = oercp.ReplicatorCheckpointManager( + self._local_dir, + options=oercp.ReplicatorCheckpointManagerOptions( + save_interval_steps=self._save_interval_steps, + step_name_format=self._name_format, + ), + global_mesh=thread_resources.env.physical_mesh, + ) + return self._tensor_manager + + def save( + self, *, step: int, state: Nested[Tensor], evaler_summaries: Optional[Dict[str, Any]] = None + ): + """See `BaseCheckpointer.save` for details.""" + assert self._eval_summaries is None, self._eval_summaries + self._eval_summaries = copy.deepcopy(evaler_summaries or {}) + self._reached_preemption = self._tensor_manager.reached_preemption(step) + + state_with_tensors = jax.tree.map( + lambda x: x if isinstance(x, (Tensor, TensorSpec)) else None, state + ) + + start_t = time.perf_counter() + self._get_tensor_manager().save( + step=step, args=ocp.args.Composite(state=ocp.args.PyTreeSave(item=state_with_tensors)) + ) + time_diff = time.perf_counter() - start_t + logging.info("Save time is %fs.", time_diff) + self._eval_summaries = None + if self._reached_preemption: + self.wait_until_finished() + raise SystemExit(f"Exiting after saving checkpoint at {step=} due to pre-emption.") + + def restore( + self, + *, + step: Optional[int] = None, + state: Union[Nested[Tensor], Nested[TensorSpec]], + ) -> Tuple[Optional[int], Nested[Tensor]]: + """Restores state from either local or persistent checkpoint.""" + start_t = time.perf_counter() + cfg: OrbaxEmergencyReplicatorCheckpointer.Config = self.config + state_with_tensors = jax.tree.map( + lambda x: x if isinstance(x, (Tensor, TensorSpec)) else None, state + ) + tensor_manager = self._get_tensor_manager() + if step is None: + common_steps = sorted(set(self._tensor_manager.all_steps())) + + if not common_steps: + logging.warning("Could not find any completed checkpoints under %s.", cfg.dir) + return None, state + + step = max(common_steps) + + def _restore_args(x: Any) -> ocp.RestoreArgs: + return ocp.checkpoint_utils.construct_restore_args( + jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + ) + + restore_args = jax.tree.map(_restore_args, state) + + restored_state_with_tensors = tensor_manager.restore( + step=step, + args=ocp.args.Composite( + state=ocp.args.PyTreeRestore( + item=self._get_abstract_state(state_with_tensors), + restore_args=restore_args, + ) + ), + ) + + restored_state_with_tensors = restored_state_with_tensors["state"] + + restored_state = jax.tree.map( + lambda non_tensor, tensor: non_tensor if tensor is None else tensor, + state, + restored_state_with_tensors, + ) + + time_diff = time.perf_counter() - start_t + logging.info("Took %ss to restore replicator checkpoint from %s.", time_diff, cfg.dir) + return step, restored_state + + def wait_until_finished(self): + """See `BaseCheckpointer.wait_until_finished` docstring for details.""" + self._tensor_manager.wait_until_finished() + + def stop(self, *, has_exception: bool = False): + """See `BaseCheckpointer.stop` for details.""" + if self._tensor_manager: + self._tensor_manager.close() diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index da86a87e9..f93e14f21 100644 --- a/axlearn/experiments/text/gpt/common.py +++ 
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index da86a87e9..f93e14f21 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -690,6 +690,7 @@ def get_trainer_config_fn(
     keep_every_n_steps: int = 50_000,
     save_every_n_steps: Optional[int] = None,
     init_state_builder: Optional[state_builder.Builder.Config] = None,
+    checkpointer: str = "",
 ) -> TrainerConfigFn:
     """Builds a TrainerConfigFn according to the model and input specs.
 
@@ -764,6 +765,63 @@ def config_fn() -> InstantiableConfig:
         )
         cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
         cfg.checkpointer.keep_last_n = 3
+        calculated_save_every_n_steps = save_every_n_steps or min(eval_every_n_steps, 100)
+
+        if not checkpointer:
+            cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            cfg.checkpointer.keep_last_n = 3
+        elif checkpointer == "OrbaxEmergencyCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax_emergency import OrbaxEmergencyCheckpointer
+
+            ckpt_config: OrbaxEmergencyCheckpointer.Config = (
+                OrbaxEmergencyCheckpointer.default_config()
+            )
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.local_save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.local_dir = "/checkpoint"
+            ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            ckpt_config.keep_last_n = 3
+            ckpt_config.replica_axis_index = 1
+            cfg.checkpointer = ckpt_config
+        elif checkpointer == "OrbaxEmergencyReplicatorCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax_emergency_replicator import (
+                OrbaxEmergencyReplicatorCheckpointer,
+            )
+
+            ckpt_config: OrbaxEmergencyReplicatorCheckpointer.Config = (
+                OrbaxEmergencyReplicatorCheckpointer.default_config()
+            )
+            ckpt_config.local_dir = "/checkpoint"
+            cfg.checkpointer = ckpt_config
+        elif checkpointer == "OrbaxRegularCheckpointer":
+            # Prevent global dependency on Orbax.
+            # pylint: disable-next=import-outside-toplevel
+            from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
+
+            ckpt_config: OrbaxCheckpointer.Config = OrbaxCheckpointer.default_config()
+            ckpt_config.save_policy = config_for_function(every_n_steps_and_last_policy).set(
+                n=calculated_save_every_n_steps,
+                max_step=max_step,
+            )
+            ckpt_config.keep_every_n_steps = min(max_step, keep_every_n_steps)
+            ckpt_config.keep_last_n = 3
+            ckpt_config.enable_single_replica_ckpt_restoring = True
+            cfg.checkpointer = ckpt_config
+
         cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
         cfg.summary_writer.max_queue = 1000
         if len(mesh_axis_names) != len(mesh_shape):
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 9ec469dbb..9101f05a5 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -944,17 +944,37 @@ def trainer_configs(
     """
     arch = "fuji"
    config_map = {}
-    for version, model_size, flash_attention in itertools.product(
-        Version, MODEL_SIZES, [True, False]
+    for version, model_size, flash_attention, checkpointer in itertools.product(
+        Version,
+        MODEL_SIZES,
+        [True, False],
+        [
+            "",
+            "OrbaxEmergencyCheckpointer",
+            "OrbaxEmergencyReplicatorCheckpointer",
+            "OrbaxRegularCheckpointer",
+        ],
     ):
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+
+        current_suffix_parts = []
+        if flash_attention:
+            current_suffix_parts.append("-flash")
+        if checkpointer == "OrbaxEmergencyCheckpointer":
+            current_suffix_parts.append("-orbaxem")
+        if checkpointer == "OrbaxEmergencyReplicatorCheckpointer":
+            current_suffix_parts.append("-orbaxem-replicator")
+        elif checkpointer == "OrbaxRegularCheckpointer":
+            current_suffix_parts.append("-orbax")
+        current_suffix = "".join(current_suffix_parts)
+
         config_name = make_config_name(
             arch=arch, model_size=model_size, version=f"v{version.value}",
-            suffix="-flash" if flash_attention else "",
+            suffix=current_suffix,
         )
         kwargs = get_trainer_kwargs(
             model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
@@ -969,6 +989,7 @@
         evalers=evaler_config_dict(
             eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
         ),
+        checkpointer=checkpointer,
         **kwargs,
     )
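
Taken together, the fuji.py changes register one config per (version, model size, flash attention, checkpointer) combination, with the checkpointer choice encoded in the config-name suffix. The sketch below mirrors the suffix logic added above; the concrete config name in the comment is illustrative and depends on which sizes exist in MODEL_SIZES / TOTAL_TOKENS.

# Sketch: mirrors the suffix composition added to `trainer_configs` in the diff above.
def make_suffix(flash_attention: bool, checkpointer: str) -> str:
    parts = []
    if flash_attention:
        parts.append("-flash")
    if checkpointer == "OrbaxEmergencyCheckpointer":
        parts.append("-orbaxem")
    if checkpointer == "OrbaxEmergencyReplicatorCheckpointer":
        parts.append("-orbaxem-replicator")
    elif checkpointer == "OrbaxRegularCheckpointer":
        parts.append("-orbax")
    return "".join(parts)

# e.g. a hypothetical fuji 7B v2 config with flash attention and the replicator checkpointer
# would be registered as "fuji-7B-v2-flash-orbaxem-replicator".
assert make_suffix(True, "OrbaxEmergencyReplicatorCheckpointer") == "-flash-orbaxem-replicator"
assert make_suffix(False, "") == ""
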