Commit d18c1c8

[QEff. Finetune] Adding callback and its test cases. (#652)
Adds a script for registering and retrieving callback classes. It provides a create_callbacks() function that creates an instance of a registered callback by name. Additionally, a test_callbacks.py script validates the registration and retrieval process.

Signed-off-by: Tanisha Chawada <[email protected]>
1 parent 5cd3fd1 commit d18c1c8

File tree

3 files changed: +350 -0 lines changed


QEfficient/finetune/experimental/core/callbacks.py

Lines changed: 199 additions & 0 deletions
@@ -4,3 +4,202 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import json
import os
from typing import Any, Dict, Optional

from transformers import (
    DefaultFlowCallback,
    EarlyStoppingCallback,
    PrinterCallback,
    ProgressCallback,
    TrainingArguments,
)
from transformers.integrations.integration_utils import TensorBoardCallback
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.core.utils.profiler_utils import (
    get_op_verifier_ctx,
    init_qaic_profiling,
    stop_qaic_profiling,
)

# Register the stock transformers callbacks under short names.
registry.callback("early_stopping")(EarlyStoppingCallback)
registry.callback("printer")(PrinterCallback)
registry.callback("default_flow")(DefaultFlowCallback)
registry.callback("tensorboard")(TensorBoardCallback)


@registry.callback("enhanced_progressbar")
class EnhancedProgressCallback(ProgressCallback):
    """
    A [`TrainerCallback`] that displays the progress of training or evaluation.
    You can modify `max_str_len` to control how long strings are truncated when logging.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the callback with an optional `max_str_len` parameter to control string truncation length.

        Args:
            max_str_len (`int`):
                Maximum length of strings to display in logs.
                Longer strings will be truncated with a message.
        """
        super().__init__(*args, **kwargs)

    def on_train_begin(self, args, state, control, **kwargs):
        """Set the progress bar description at the start of training."""
        super().on_train_begin(args, state, control, **kwargs)
        if self.training_bar is not None:
            self.training_bar.set_description("Training Progress")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """
        Override the default `on_log` behavior during training to display
        the current epoch number, loss, and learning rate in the logs.
        """
        if state.is_world_process_zero and self.training_bar is not None:
            # Make a shallow copy of logs so we can mutate the copied fields
            # without pickling any values.
            shallow_logs = {}
            for k, v in logs.items():
                if isinstance(v, str) and len(v) > self.max_str_len:
                    shallow_logs[k] = (
                        f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
                        "Consider increasing `max_str_len` if needed.]"
                    )
                else:
                    shallow_logs[k] = v
            _ = shallow_logs.pop("total_flos", None)
            # Round numbers so they look better in the console.
            if "epoch" in shallow_logs:
                shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)

            updated_dict = {}
            if "epoch" in shallow_logs:
                updated_dict["epoch"] = shallow_logs["epoch"]
            if "loss" in shallow_logs:
                updated_dict["loss"] = shallow_logs["loss"]
            if "learning_rate" in shallow_logs:
                updated_dict["lr"] = shallow_logs["learning_rate"]
            self.training_bar.set_postfix(updated_dict)


@registry.callback("json_logger")
class JSONLoggerCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs training and evaluation metrics to a JSON file.
    """

    def __init__(self, log_path=None, *args, **kwargs):
        """
        Initialize the callback with the path to the JSON log file.

        Args:
            log_path (`str`):
                Path to the JSONL file where logs will be saved.
        """
        super().__init__(*args, **kwargs)
        if log_path is None:
            log_path = os.path.join(os.environ.get("OUTPUT_DIR", "./"), "training_logs.jsonl")
        self.log_path = log_path
        # Ensure the log file is created and empty.
        with open(self.log_path, "w") as _:
            pass

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[Dict] = None,
        **kwargs,
    ):
        """Append sanitized log metrics (including global_step) to a JSONL file."""
        if logs is None:
            return
        logs.pop("entropy", None)
        logs.pop("mean_token_accuracy", None)
        if state.global_step:
            logs["global_step"] = state.global_step
        with open(self.log_path, "a") as f:
            json_line = json.dumps(logs, separators=(",", ":"))
            f.write(json_line + "\n")


@registry.callback("qaic_profiler_callback")
class QAICProfilerCallback(TrainerCallback):
    """Callback to profile QAIC devices over a specified training step range."""

    def __init__(self, *args, **kwargs):
        """
        Initialize QAIC profiler settings (start/end steps and target device IDs).
        """
        self.start_step = kwargs.get("start_step", -1)
        self.end_step = kwargs.get("end_step", -1)
        self.device_ids = kwargs.get("device_ids", [0])

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of a training step. If using gradient
        accumulation, one training step might take several inputs.
        """
        if state.global_step == self.start_step:
            for device_id in self.device_ids:
                init_qaic_profiling(True, f"qaic:{device_id}")
        elif state.global_step == self.end_step:
            for device_id in self.device_ids:
                stop_qaic_profiling(True, f"qaic:{device_id}")


@registry.callback("qaic_op_by_op_verifier_callback")
class QAICOpByOpVerifierCallback(TrainerCallback):
    """Callback to verify QAIC operations step-by-step during a specified training range."""

    def __init__(self, *args, **kwargs):
        """
        Initialize the QAIC op-by-op verifier callback with profiling and tolerance settings.
        """
        self.start_step = kwargs.get("start_step", -1)
        self.end_step = kwargs.get("end_step", -1)
        self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces")
        self.atol = kwargs.get("atol", 1e-1)
        self.rtol = kwargs.get("rtol", 1e-5)
        # Set in on_step_begin while verification is active.
        self.op_verifier_ctx_step = None

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the beginning of a training step. If using gradient
        accumulation, one training step might take several inputs.
        """
        if self.start_step <= state.global_step < self.end_step:
            self.op_verifier_ctx_step = get_op_verifier_ctx(
                use_op_by_op_verifier=True,
                device_type="qaic",
                dump_dir=self.trace_dir,
                step=state.global_step,
                atol=self.atol,
                rtol=self.rtol,
            )
            self.op_verifier_ctx_step.__enter__()

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Event called at the end of a training step. If using gradient
        accumulation, one training step might take several inputs.
        """
        if self.start_step <= state.global_step < self.end_step:
            if self.op_verifier_ctx_step is not None:
                self.op_verifier_ctx_step.__exit__(None, None, None)


def create_callbacks(name: str, **kwargs) -> Any:
    """Create a callback instance by registry name."""
    callback_class = registry.get_callback(name)
    if callback_class is None:
        raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}")
    return callback_class(**kwargs)
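
As a usage sketch (the kwargs follow the constructors above; the Trainer wiring is an assumption, not part of this commit), callbacks are created by registry name via the factory and behave like any other TrainerCallback:

from QEfficient.finetune.experimental.core.callbacks import create_callbacks

# Build callbacks by registry name; extra kwargs are forwarded to the
# callback's constructor.
json_logger = create_callbacks("json_logger", log_path="training_logs.jsonl")
profiler = create_callbacks("qaic_profiler_callback", start_step=10, end_step=20, device_ids=[0])

# Hypothetical Trainer setup: both instances are ordinary TrainerCallback
# objects, so they plug straight into a transformers Trainer, e.g.
# trainer = Trainer(model=model, args=training_args, callbacks=[json_logger, profiler])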

QEfficient/finetune/experimental/core/utils/profiler_utils.py

Lines changed: 88 additions & 0 deletions
@@ -4,3 +4,91 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


from contextlib import nullcontext
from typing import ContextManager

import torch


def get_op_verifier_ctx(
    use_op_by_op_verifier: bool,
    device_type: str,
    dump_dir: str,
    step: int,
    ref_device: str = "cpu",
    ref_dtype: torch.dtype = torch.float32,
    atol: float = 1e-1,
    rtol: float = 1e-5,
    use_ref_output_on_mismatch: bool = True,
) -> ContextManager:
    """Get the op-by-op verifier context manager when op-by-op verification is
    enabled. It helps in debugging operator-related issues by matching the
    operator execution on qaic vs. cpu. This is meant only for the qaic backend.

    Args:
        use_op_by_op_verifier (bool): Boolean flag to enable the op-by-op verifier.
        device_type (str): Device on which the model is being executed.
        dump_dir (str): Directory to dump the op-by-op verification results.
        step (int): Step number for which the op-by-op verification is to be performed.
        ref_device (str, optional): Device to use as reference for verification.
            Defaults to "cpu".
        ref_dtype (torch.dtype, optional): Data type to use as the reference
            datatype for verification. Defaults to torch.float32.
        atol (float, optional): Absolute tolerance to match the results. Defaults to 1e-1.
        rtol (float, optional): Relative tolerance to match the results. Defaults to 1e-5.
        use_ref_output_on_mismatch (bool, optional): If an operator has a
            mismatch with respect to the reference device, use the reference
            device outputs and continue the rest of the verification. Defaults to True.

    Returns:
        ContextManager: Instance of the context manager used to verify the operators.
    """
    if (not use_op_by_op_verifier) or ("qaic" not in device_type):
        return nullcontext()

    # Lazily import qaic_debug only when it is actually needed.
    import torch_qaic.debug as qaic_debug

    filter_config = qaic_debug.DispatchFilterConfig.default(device_type)
    dump_dir = dump_dir + "/mismatches/step_" + str(step)
    return qaic_debug.OpByOpVerifierMode(
        ref_device=ref_device,
        ref_dtype=ref_dtype,
        atol=atol,
        rtol=rtol,
        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
        filter_config=filter_config,
        dump_root_dir=dump_dir,
    )


def init_qaic_profiling(use_profiler: bool, device_type: str) -> None:
    """Initialize the qaic profiling tool. Note: the profiler only works
    with the qaic backend.

    Args:
        use_profiler (bool): Boolean flag to enable the profiler.
        device_type (str): Device on which the model is being executed.
    """
    if use_profiler and ("qaic" in device_type):
        # Lazily import qaic_profile only when it is actually needed.
        import torch_qaic.profile as qaic_profile

        qaic_profile.start_profiling(device_type, 1)


def stop_qaic_profiling(use_profiler: bool, device_type: str) -> None:
    """Stop the qaic profiling tool. Note: the profiler only works
    with the qaic backend.

    Args:
        use_profiler (bool): Boolean flag to enable the profiler.
        device_type (str): Device on which the model is being executed.
    """
    if use_profiler and ("qaic" in device_type):
        # Lazily import qaic_profile only when it is actually needed.
        import torch_qaic.profile as qaic_profile

        qaic_profile.stop_profiling(device_type)
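
A minimal usage sketch for the verifier helper, assuming a model and batch from the surrounding training loop are in scope; when disabled or on a non-qaic device the returned context is a no-op nullcontext, so the wrapped step runs unchanged:

from QEfficient.finetune.experimental.core.utils.profiler_utils import get_op_verifier_ctx

# Compare each dispatched op on this step against a float32 CPU reference;
# mismatches are dumped under qaic_op_by_op_traces/mismatches/step_5.
with get_op_verifier_ctx(
    use_op_by_op_verifier=True,
    device_type="qaic",
    dump_dir="qaic_op_by_op_traces",
    step=5,
):
    loss = model(**batch).loss  # model/batch assumed from the training loop
    loss.backward()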
test_callbacks.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import pytest
from transformers import TrainerCallback

from QEfficient.finetune.experimental.core.callbacks import create_callbacks
from QEfficient.finetune.experimental.core.component_registry import registry


class ModelSummaryCallback(TrainerCallback):
    def __init__(self):
        pass


# Set up test data.
CALLBACK_CONFIGS = {
    "early_stopping": {
        "name": "early_stopping",
        "early_stopping_patience": 3,
        "early_stopping_threshold": 0.001,
    },
    "tensorboard": {"name": "tensorboard", "tb_writer": "SummaryWriter"},
    "model_summary": {
        "name": "model_summary",
        "max_depth": 1,
    },
}

REGISTRY_CALLBACK_CONFIGS = {
    "model_summary": {
        "name": "model_summary",
        "max_depth": 1,
        "callback_class": ModelSummaryCallback,
    },
}


@pytest.mark.parametrize("callback_name", CALLBACK_CONFIGS.keys())
def test_callbacks(callback_name):
    """Test that registered callbacks can be created with their configs."""
    # Create callbacks using the factory.
    config = CALLBACK_CONFIGS[callback_name]
    try:
        callback_inst = create_callbacks(**config)
    except ValueError as e:
        # "model_summary" is not registered here, so the factory should raise.
        assert "Unknown callback" in str(e)
        return
    assert callback_inst is not None
    assert isinstance(callback_inst, TrainerCallback)


@pytest.mark.parametrize("callback_name,config", REGISTRY_CALLBACK_CONFIGS.items())
def test_callbacks_registry(callback_name, config):
    """Test that a callback registers and can be retrieved correctly."""
    callback_class = config["callback_class"]
    registry.callback(callback_name)(callback_class)
    callback = registry.get_callback(callback_name)
    assert callback is not None
    assert callback == callback_class
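
These tests run with the standard pytest invocation; the path below is illustrative, since the test file's directory is not shown in this view:

pytest test_callbacks.py -v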
