Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions docs/getting_started/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ You can visit the `config.py` file under each subfolder to see what parameters a
- **freeze_modules**: List of submodules (e.g., `visual`) to freeze during training.
- **use_liger_kernel/use_rmpad**: Performance optimizations. Keep enabled if supported on your stack.
- **fsdp2/fsdp_config**: Enable FSDP2 sharding and wrap transformer layer classes via `transformer_layer_cls_to_wrap`. Tune `reshard_after_forward` for memory/perf trade-offs.
- **enable_cuda_event_profiler**: Enable a low-overhead CUDA event profiler for FSDP2 training. It writes per-rank JSONL files under `output_dir/cuda_event_profiler` for phases such as `host_to_device`, `training_step`, and `training_metrics`.
- **EMA (Exponential Moving Average)**: Enable EMA with `ema_enabled: true`. Configure `ema_decay` (default 0.9999), `ema_update_every`, `ema_start_step`, and optionally filter parameters via `ema_param_filter`. EMA checkpoints are saved alongside regular checkpoints and can be merged using `merge_fsdp.py` with `--state_dict_dirname pytorch_ema_model_fsdp_0`.

## Run
Expand Down Expand Up @@ -174,6 +175,27 @@ Here are frequently used parameters you can override:
- `trainer_args.ema_requires_grad_only`: Only apply EMA to trainable parameters (default: `true`)
- `trainer_args.ema_param_filter`: Filter parameters by name (supports `mode`, `include`, `exclude`)
- `trainer_args.ema_resume_from_ema`: Resume training from EMA weights (default: `false`)
- `trainer_args.enable_cuda_event_profiler`: Enable lightweight CUDA event timing (default: `false`)
- `trainer_args.cuda_event_profiler_config`: Optional profiler window, rank filter, and sampling config, e.g. `{start_step: 100, end_step: 1000, record_every_n_steps: 10, flush_every_n_steps: 50, ranks: [0, 1, 7]}`

### Lightweight CUDA Event Profiling

For long-running distributed jobs, `torch.profiler` traces can be too heavy to keep enabled. The CUDA event profiler records only named phase durations and writes one JSON object per completed event:

```yaml
trainer_args:
  enable_cuda_event_profiler: true
  cuda_event_profiler_config:
    start_step: 100
    end_step: 1000
    record_every_n_steps: 10
    flush_every_n_steps: 50
    ranks: [0, 1, 7]
```

Selected ranks write to `output_dir/cuda_event_profiler/cuda_events_rank_<rank>.jsonl`. These files can be aggregated into rank heatmaps or timeline views to diagnose stragglers without the synchronization overhead of full profiler traces.

The profiler is intended for diagnosis and remains disabled by default. For large jobs, prefer bounded windows, sampled steps, and rank filters instead of recording every rank on every step. If `record_every_n_steps` or `flush_every_n_steps` is omitted, each defaults to 10.

### Advanced Example

Expand Down Expand Up @@ -204,5 +226,3 @@ This loads all settings from `qwen2_5_vl_dp.yaml` in the specified directory and
- Boolean values: `packing=true` or `packing=false`
- For complex values (lists/arrays), use Hydra's syntax: `trainer_args.fsdp_config.transformer_layer_cls_to_wrap=["Qwen2_5_VLDecoderLayer"]`
- Add new parameters with `+`: `+dataset_config.extra_kwargs.image_max_pixels=4194304`


2 changes: 2 additions & 0 deletions src/lmms_engine/launch/config/default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ trainer_args:
print_batch_input_steps: -1
enable_profiler: false
profiler_config: null
enable_cuda_event_profiler: false
cuda_event_profiler_config: null
ep_degree: 1
sp_ulysses_degree: 1
tp_degree: 1
Expand Down
2 changes: 2 additions & 0 deletions src/lmms_engine/train/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class TrainingArguments(transformers.TrainingArguments):
# and auto-dumps a .pickle on CUDA OOM. View at https://pytorch.org/memory_viz
enable_memory_snapshot: Optional[bool] = False
memory_snapshot_config: Optional[Dict[str, Any]] = None
enable_cuda_event_profiler: Optional[bool] = False
cuda_event_profiler_config: Optional[Dict[str, Any]] = None

# Parallelism
ep_degree: Optional[int] = 1
Expand Down
67 changes: 41 additions & 26 deletions src/lmms_engine/train/fsdp2/fsdp2_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
get_cosine_schedule_with_warmup,
get_wsd_schedule_with_warmup,
)
from lmms_engine.utils.profiler import MemorySnapshotProfiler, StepProfiler
from lmms_engine.utils.profiler import (
CudaEventProfiler,
MemorySnapshotProfiler,
StepProfiler,
)
from lmms_engine.utils.tracking import Tracking

DatasetType = Union[Dataset, IterableDataset]
Expand Down Expand Up @@ -80,6 +84,12 @@ def __init__(
rank=dist.get_rank(),
memory_snapshot_config=getattr(self.args, "memory_snapshot_config", None),
)
self.cuda_event_profiler = CudaEventProfiler(
enable=getattr(self.args, "enable_cuda_event_profiler", False),
directory=os.path.join(self.args.output_dir, "cuda_event_profiler"),
profiler_config=getattr(self.args, "cuda_event_profiler_config", None),
rank=dist.get_rank(),
)
self.accumulated_grad_steps = 0

# Optional EMA (fully opt-in)
Expand Down Expand Up @@ -364,11 +374,13 @@ def train(self, resume_from_checkpoint: bool = False):
if self.should_stop():
break
# send batch to device
batch = send_to_device(batch, self.fsdp2_model.device)
with self.cuda_event_profiler.record("host_to_device", self.global_step):
batch = send_to_device(batch, self.fsdp2_model.device)
self.memory_snapshot_profiler.step(self.global_step)
start_time = time.perf_counter()
try:
train_metrics = self.training_step(batch)
with self.cuda_event_profiler.record("training_step", self.global_step):
train_metrics = self.training_step(batch)
except torch.OutOfMemoryError:
self.memory_snapshot_profiler.dump_on_exception(f"oom_step{self.global_step}")
raise
Expand All @@ -383,29 +395,31 @@ def train(self, resume_from_checkpoint: bool = False):
end_time = time.perf_counter()
delta_time = end_time - start_time

# Calculate flops per rank
seq_len = batch.get("attention_mask", torch.zeros((1, 1))).sum(dim=1).detach().cpu().tolist()
flops, promised_flops, raw_flops = model_utils.flops_counter.estimate_flops(
seq_len, delta_time=delta_time
)
self.compute_tracker.accumulate_flops(raw_flops)
device = self.fsdp2_model.device
flops_tensor = torch.tensor(flops, device=device)
sp_size = pgm.process_group_manager.cp_world_size
tp_size = pgm.process_group_manager.tp_world_size
parallel_size = sp_size * tp_size

# Calculate training metrics (MFU, token stats, throughput)
perf_metrics, self.total_tokens = self.calculate_training_metrics(
flops_tensor=flops_tensor,
parallel_size=parallel_size,
promised_flops=promised_flops,
device=device,
seq_len=seq_len,
total_tokens=self.total_tokens,
delta_time=delta_time,
world_size=world_size,
)
with self.cuda_event_profiler.record("training_metrics", self.global_step):
# Calculate flops per rank
seq_len = batch.get("attention_mask", torch.zeros((1, 1))).sum(dim=1).detach().cpu().tolist()
flops, promised_flops, raw_flops = model_utils.flops_counter.estimate_flops(
seq_len, delta_time=delta_time
)
self.compute_tracker.accumulate_flops(raw_flops)
device = self.fsdp2_model.device
flops_tensor = torch.tensor(flops, device=device)
sp_size = pgm.process_group_manager.cp_world_size
tp_size = pgm.process_group_manager.tp_world_size
parallel_size = sp_size * tp_size

# Calculate training metrics (MFU, token stats, throughput)
perf_metrics, self.total_tokens = self.calculate_training_metrics(
flops_tensor=flops_tensor,
parallel_size=parallel_size,
promised_flops=promised_flops,
device=device,
seq_len=seq_len,
total_tokens=self.total_tokens,
delta_time=delta_time,
world_size=world_size,
)
self.cuda_event_profiler.maybe_flush(self.global_step)
Comment on lines +398 to +422
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe calculating training metrics should not be recorded as a cuda event?

train_metrics.update(perf_metrics)
self.print_batch_input(batch)

Expand Down Expand Up @@ -463,6 +477,7 @@ def train(self, resume_from_checkpoint: bool = False):
"compute/co2_kg": summary.co2_kg,
}
)
self.cuda_event_profiler.close()

def evaluate(self):
raise NotImplementedError("Evaluation is not implemented")
Expand Down
123 changes: 123 additions & 0 deletions src/lmms_engine/utils/profiler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import time
from contextlib import contextmanager, nullcontext
Expand Down Expand Up @@ -184,3 +185,125 @@ def stop_and_save(self, reason: str = "manual"):
pass
self.stopped = True
logger.info(f"[MemSnapshot] recording stopped (reason={reason}, rank {self.rank})")


class CudaEventProfiler:
    """Low-overhead CUDA event profiler for long-running distributed jobs.

    Unlike torch.profiler, this profiler only records named CUDA event pairs and
    writes completed durations to JSONL. Non-blocking flushes avoid introducing
    synchronization into the training step.

    Recognised ``profiler_config`` keys (all optional; an explicit ``None``
    value is treated the same as a missing key):

    - ``start_step``: first step eligible for recording (default ``0``).
    - ``end_step``: last step eligible for recording, inclusive
      (default: no upper bound).
    - ``record_every_n_steps``: sampling stride counted from ``start_step``
      (default ``10``, clamped to at least ``1``).
    - ``flush_every_n_steps``: ``maybe_flush`` only flushes on steps divisible
      by this value (default ``10``, clamped to at least ``1``).
    - ``ranks``: iterable of ranks that are allowed to record
      (default: every rank).
    """

    def __init__(
        self,
        enable: bool,
        directory: str,
        rank: int = 0,
        profiler_config: Optional[Dict[str, Any]] = None,
    ):
        """Parse the configuration and, when active, open the per-rank output file.

        Args:
            enable: Requested on/off switch. The profiler is only active when
                this is truthy, CUDA is available, and ``rank`` passes the
                optional ``ranks`` filter.
            directory: Directory that receives ``cuda_events_rank_<rank>.jsonl``.
                It is created only when the profiler is active.
            rank: Rank of the current process; written into every record.
            profiler_config: Optional settings, see the class docstring.
        """
        self.enable = bool(enable) and torch.cuda.is_available()
        self.rank = rank
        self.directory = directory
        self.profiler_config = profiler_config or {}

        config = self.profiler_config
        # YAML configs frequently carry explicit nulls; treat them as "use the
        # default" instead of failing in int(None) or in `step < None` later.
        self.start_step = self._int_option(config, "start_step", 0)
        end_step = config.get("end_step")
        self.end_step = None if end_step is None else int(end_step)
        self.record_every_n_steps = max(self._int_option(config, "record_every_n_steps", 10), 1)
        self.flush_every_n_steps = max(self._int_option(config, "flush_every_n_steps", 10), 1)
        ranks = config.get("ranks")
        self.ranks = None if ranks is None else {int(r) for r in ranks}

        self.pending_events = []
        self._last_flush_step = -1
        self._file = None
        # Always defined so callers can inspect it even on disabled instances.
        self.path = None

        if not self.enable:
            if enable:
                logger.warning("[CudaEventProfiler] CUDA is unavailable; profiler is disabled")
            return
        if self.ranks is not None and self.rank not in self.ranks:
            # This rank was filtered out: stay silent and create no files.
            self.enable = False
            return

        os.makedirs(self.directory, exist_ok=True)
        self.path = os.path.join(self.directory, f"cuda_events_rank_{self.rank}.jsonl")
        # Append mode keeps records from earlier runs in the same output
        # directory; line buffering pushes each JSONL record out as written.
        self._file = open(self.path, "a", buffering=1, encoding="utf-8")
        logger.info(f"[CudaEventProfiler] Writing CUDA event timings to {self.path}")

    @staticmethod
    def _int_option(config, key: str, default: int) -> int:
        """Return ``config[key]`` as an int, or ``default`` when missing or ``None``."""
        value = config.get(key)
        return default if value is None else int(value)

    def should_record(self, step: int) -> bool:
        """Return True when events for ``step`` should be recorded.

        A step is recorded when the profiler is active, the rank passes the
        filter, the step lies inside ``[start_step, end_step]`` and it falls on
        the sampling stride measured from ``start_step``.
        """
        if not self.enable:
            return False
        if self.ranks is not None and self.rank not in self.ranks:
            return False
        if step < self.start_step:
            return False
        if self.end_step is not None and step > self.end_step:
            return False
        return (step - self.start_step) % self.record_every_n_steps == 0

    @contextmanager
    def record(self, name: str, step: int, **metadata):
        """Context manager that times the wrapped region with a CUDA event pair.

        When ``should_record(step)`` is False this is a plain no-op context.
        Otherwise a start/end event pair is recorded around the body and queued
        in ``pending_events``; the duration is resolved later by ``flush``.

        Args:
            name: Phase label written to the record.
            step: Training step the phase belongs to.
            **metadata: Extra JSON-serialisable fields merged into the record.
        """
        if not self.should_record(step):
            yield
            return

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            yield
        finally:
            # Record the end event even if the body raised so the pair stays
            # complete and can still be flushed.
            end_event.record()
            self.pending_events.append(
                {
                    "name": name,
                    "step": step,
                    "rank": self.rank,
                    "start_event": start_event,
                    "end_event": end_event,
                    "metadata": metadata,
                }
            )

    def maybe_flush(self, step: int):
        """Run a non-blocking flush on steps divisible by ``flush_every_n_steps``.

        Repeated calls for the same step flush at most once.
        """
        if not self.enable:
            return
        if step == self._last_flush_step:
            return
        if step % self.flush_every_n_steps != 0:
            return
        self.flush(blocking=False)
        self._last_flush_step = step

    def flush(self, blocking: bool = False):
        """Write queued events to the JSONL file.

        Args:
            blocking: When True, synchronize on every end event so all queued
                events are written. When False, only events whose end event
                reports completion via ``query()`` are written; the rest stay
                queued for a later flush.
        """
        if not self.enable or self._file is None:
            return

        remaining_events = []
        for event in self.pending_events:
            end_event = event["end_event"]
            if blocking:
                end_event.synchronize()
            elif not end_event.query():
                remaining_events.append(event)
                continue

            # Metadata is spread first so that caller-supplied keys can never
            # clobber the core fields (in particular the measured duration).
            record = {
                **(event["metadata"] or {}),
                "name": event["name"],
                "step": event["step"],
                "rank": event["rank"],
                "duration_ms": event["start_event"].elapsed_time(end_event),
            }
            self._file.write(json.dumps(record, sort_keys=True) + "\n")

        self.pending_events = remaining_events

    def close(self):
        """Flush everything (blocking) and release the output file.

        The file handle is closed even if the final flush raises; the
        exception still propagates. Afterwards the profiler is disabled so
        that later ``record`` calls become no-ops instead of queueing events
        that could never be written.
        """
        if not self.enable:
            return
        try:
            self.flush(blocking=True)
        finally:
            if self._file is not None:
                self._file.close()
                self._file = None
            self.pending_events = []
            self.enable = False
Loading
Loading