Skip to content

Commit fd1e5bb

Browse files
committed
Update params to work for number of devices
1 parent 563f9e6 commit fd1e5bb

File tree

3 files changed

+17
-27
lines changed

3 files changed

+17
-27
lines changed

axlearn/cloud/gcp/jobset_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,14 @@ def _build_container(self) -> Nested[Any]:
452452
env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()
453453

454454
resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
455-
# Set request memory by host machine type.
456-
machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
457-
system.gce_machine_type, None
458-
)
459-
if machine_memory_gi is not None:
460-
request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
461-
resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
462-
resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
455+
# # Set request memory by host machine type.
456+
# machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
457+
# system.gce_machine_type, None
458+
# )
459+
# if machine_memory_gi is not None:
460+
# request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
461+
# resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
462+
# resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
463463

464464
k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
465465
k8s_env_vars.append(
@@ -509,8 +509,8 @@ def _build_uploader_container(
509509
interval_s = 60
510510
sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
511511
resources = {
512-
"requests": {"cpu": "100m", "memory": "128Mi"},
513-
"limits": {"cpu": "500m", "memory": "256Mi"},
512+
# "requests": {"cpu": "100m", "memory": "128Mi"},
513+
# "limits": {"cpu": "500m", "memory": "256Mi"},
514514
}
515515
return dict(
516516
name="output-uploader",

axlearn/common/trainer.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -624,24 +624,10 @@ def run(
624624

625625
self._step = self._step + 1
626626
self.vlog(3, "Start step %s", self.step)
627-
<<<<<<< HEAD
628-
self._maybe_record_event(measurement.Event.START_STEP, self._step)
629-
output = self._run_step(
630-
utils.host_to_global_array(
631-
input_batch,
632-
partition=self._train_step_input_partition_specs(),
633-
),
634-
force_run_evals=(
635-
force_run_eval_sets_at_max_step
636-
if self.step >= cfg.max_step
637-
else None
638-
),
639-
=======
640627
step_events_manager = (
641628
self._recorder.record_event(measurement.Event.STEP, self.step)
642629
if self._recorder
643630
else contextlib.nullcontext()
644-
>>>>>>> 3755939 (Add workload hang monitoring & rolling window goodput support)
645631
)
646632
with step_events_manager:
647633
output = self._run_step(

axlearn/experiments/text/gpt/fuji.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
import enum
1414
import functools
1515
import itertools
16+
import jax
1617
from typing import Any, List, NamedTuple, Optional, Union
18+
from absl import logging
1719

1820
from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
1921

@@ -813,6 +815,7 @@ def get_trainer_kwargs(
813815
),
814816
)
815817
elif model_size == "150B":
818+
logging.info("******* debugging number of devices: %s", len(jax.devices()))
816819
trainer_kwargs = dict(
817820
model_kwargs=dict(
818821
num_layers=80,
@@ -827,8 +830,9 @@ def get_trainer_kwargs(
827830
),
828831
learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
829832
max_sequence_length=max_sequence_length,
830-
train_batch_size=train_batch_size,
831-
max_step=max_step,
833+
train_batch_size=len(jax.devices()), # train_batch_size,
834+
max_step=10_000, # max_step,
835+
save_every_n_steps=100,
832836
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4),
833837
mesh_rules=(
834838
(
@@ -839,7 +843,7 @@ def get_trainer_kwargs(
839843
ChainConfigModifier.default_config().set(
840844
config_modifiers=[
841845
MeshShapeModifier.default_config().set(
842-
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4)
846+
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
843847
),
844848
RematSpecModifier.default_config().set(
845849
remat_policies={

0 commit comments

Comments
 (0)