
Commit 1f0a4df

[QEff.finetuning] Adding config_manager and its test_cases.
Signed-off-by: Tanisha Chawada <[email protected]>
1 parent 28ec40b commit 1f0a4df

4 files changed, +196 -183 lines changed


QEfficient/finetune/experimental/core/config_manager.py

Lines changed: 166 additions & 67 deletions
@@ -11,10 +11,9 @@
 
 import json
 import os
-import sys
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, fields, is_dataclass
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import yaml
 from transformers.hf_argparser import HfArgumentParser
@@ -257,7 +256,7 @@ class DdpConfig:
         metadata={"help": "The DDP backend to use (e.g., 'nccl', 'gloo', 'qccl')."},
     )
     ddp_find_unused_parameters: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Whether to find unused parameters in DDP."},
     )
     ddp_bucket_cap_mb: Optional[int] = field(
@@ -294,7 +293,10 @@ class TrainingConfig:
         default=42,
         metadata={"help": "Random seed for reproducibility."},
     )
-
+    device: str = field(
+        default="qaic",
+        metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."},
+    )
     do_eval: bool = field(
         default=True,
         metadata={"help": "Whether to run evaluation during training."},
@@ -307,7 +309,6 @@ class TrainingConfig:
         default=100,
         metadata={"help": "Number of update steps between two evaluations."},
     )
-
     per_device_train_batch_size: int = field(
         default=1,
         metadata={"help": "Batch size per device during training."},
@@ -381,10 +382,6 @@ class TrainingConfig:
         default=True,
         metadata={"help": "Whether to compile the model with `torch.compile`."},
     )
-    include_tokens_per_second: bool = field(
-        default=True,
-        metadata={"help": "Whether to include tokens per second in logs."},
-    )
     include_num_input_tokens_seen: bool = field(
         default=True,
         metadata={"help": "Whether to include the number of input tokens seen in logs."},
@@ -426,6 +423,14 @@ class TrainingConfig:
         default=None,
         metadata={"help": "Whether to restore callback states from checkpoint."},
     )
+    report_to: Optional[List[str]] = field(
+        default=None,
+        metadata={"help": "The list of integrations to report the results and logs to."},
+    )
+    completion_only_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to compute loss only on completion tokens."},
+    )
 
 
 @dataclass
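
The new TrainingConfig fields above are plain dataclass fields, so they can be driven from a YAML config like the existing ones. Below is a minimal sketch of such an override mapping (the dict a YAML file would parse into); the `training` and `ddp_config` section names mirror the fields used elsewhere in this diff, while the variable name and the values themselves are illustrative, not repository defaults.

# Illustration only: a config mapping exercising the fields added/changed in this commit.
training_overrides = {
    "training": {
        "device": "qaic",                     # new TrainingConfig field (default "qaic")
        "report_to": ["tensorboard"],         # hypothetical integration choice
        "completion_only_loss": True,         # hypothetical override of the False default
        "ddp_config": {"ddp_find_unused_parameters": False},  # new default shown above
    }
}
# A dict like this would typically be handed to ConfigManager.update_config(...)
# or written out as YAML and passed to parse_arguments(args=["<path>.yaml"]).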
@@ -455,7 +460,7 @@ class MasterConfig:
 )
 
 
-def parse_arguments(config_path: Optional[str] = None) -> MasterConfig:
+def parse_arguments(config_path: Optional[str] = None, args: Optional[List[str]] = None) -> MasterConfig:
     """Create argument parser for the new finetuning interface."""
     parser = HfArgumentParser(MasterConfig)
 
@@ -472,12 +477,15 @@ def parse_arguments(config_path: Optional[str] = None) -> MasterConfig:
         except Exception as e:
             raise ValueError(f"Failed to parse YAML config '{config_path}': {e}")
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        # If we pass only one argument to the script and it's the path to a json file,
-        # let's parse it to get our arguments.
-        master_config = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))[0]
+    args = [] if args is None else args
+    # If a single positional YAML file was passed via args, parse it as YAML
+    if len(args) == 1 and (args[0].endswith(".yaml") or args[0].endswith(".yml")):
+        yaml_path = os.path.abspath(args[0])
+        (master_config,) = parser.parse_yaml_file(yaml_file=yaml_path)
     else:
-        master_config = parser.parse_args_into_dataclasses()
+        (master_config,) = parser.parse_args_into_dataclasses(args=args)
+        master_config = asdict(master_config)
+        master_config = MasterConfig(**master_config)
 
     return master_config
 
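`parse_arguments` now takes an explicit `args` list instead of reading `sys.argv`, which is what makes it drivable from the new test cases. A minimal usage sketch follows, assuming the module path matches the file path shown above and that every MasterConfig field defines a default so an empty argument list parses cleanly.

# Minimal usage sketch (assumed import path; file names are hypothetical).
from QEfficient.finetune.experimental.core.config_manager import parse_arguments

cfg_defaults = parse_arguments(args=[])                    # CLI branch, all defaults
cfg_from_yaml = parse_arguments(args=["run_config.yaml"])  # single YAML positional -> parse_yaml_file

# Note: the CLI branch round-trips through asdict()/MasterConfig(**...), so nested
# sections such as model, dataset, and training come back as plain dicts on cfg_defaults.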
@@ -512,34 +520,58 @@ def load_config(self, config_path: Union[str, Path]) -> None:
 
         self.update_config(config_dict)
 
+    def _ensure_extra_params(self, obj) -> Dict[str, Any]:
+        """Ensure obj.extra_params exists and is a dict; return it."""
+        ep = getattr(obj, "extra_params", None)
+        if ep is None:
+            setattr(obj, "extra_params", {})
+            ep = obj.extra_params
+        if not isinstance(ep, dict):
+            raise TypeError("extra_params must be a dict.")
+        return ep
+
+    def _stash_top_level_extra(self, section: str, nested_key: str, value: Any) -> None:
+        """Store unknown nested values under MasterConfig.extra_params['section.nested_key']."""
+        ep = self._ensure_extra_params(self.config)
+        ep[f"{section}.{nested_key}"] = value
+
     def update_config(self, config_dict: Dict[str, Any]) -> None:
         """Update configuration with dictionary values."""
+
+        SPECIAL_KEYS = {"callbacks"}
+
         for key, value in config_dict.items():
             if hasattr(self.config, key):
-                if isinstance(value, dict) and hasattr(getattr(self.config, key), "__dataclass_fields__"):
-                    # Special handling for callbacks
-                    if key in ["callbacks", "optimizers", "loss_functions"]:
-                        nested_config = getattr(self.config, key)
-                        for component_name, component_dict in value.items():
-                            if isinstance(component_dict, dict):
-                                getattr(nested_config, key)[component_name] = component_dict
-                            else:
-                                getattr(nested_config, "extra_params")[component_name] = nested_config.extra_params[
-                                    component_name
-                                ] = component_dict
+                target = getattr(self.config, key)
+
+                # Special handling for callbacks (dict inside CallbackConfig)
+                if key in SPECIAL_KEYS and isinstance(value, dict):
+                    if is_dataclass(target) and hasattr(target, "callbacks") and isinstance(target.callbacks, dict):
+                        for component_name, component_cfg in value.items():
+                            target.callbacks[component_name] = component_cfg
+                    elif isinstance(target, dict):
+                        target.update(value)
                     else:
-                        # Update nested dataclass
-                        nested_config = getattr(self.config, key)
-                        for nested_key, nested_value in value.items():
-                            if hasattr(nested_config, nested_key):
-                                setattr(getattr(self.config, key), nested_key, nested_value)
-                            elif hasattr(nested_config, "extra_params"):
-                                getattr(getattr(self.config, key), "extra_params")[nested_key] = nested_value
-                else:
-                    setattr(self.config, key, value)
+                        self._stash_top_level_extra(key, "__all__", value)
+                    continue
+
+                if isinstance(value, dict) and is_dataclass(target):
+                    known = {f.name for f in fields(target)}
+                    for nested_key, nested_value in value.items():
+                        if nested_key in known:
+                            setattr(target, nested_key, nested_value)
+                        else:
+                            self._stash_top_level_extra(key, nested_key, nested_value)
+                    continue
+
+                if isinstance(value, dict) and isinstance(target, dict):
+                    target.update(value)
+                    continue
+                setattr(self.config, key, value)
+
             else:
-                # Store unknown parameters in extra_params
-                self.config.extra_params[key] = value
+                ep = self._ensure_extra_params(self.config)
+                ep[key] = value
 
     def save_config(self, output_path: Union[str, Path]) -> None:
         """Save current configuration to file."""
@@ -557,38 +589,105 @@ def save_config(self, output_path: Union[str, Path]) -> None:
         else:
             raise ValueError(f"Unsupported output file format: {output_path.suffix}")
 
-    def validate_config(self) -> None:
-        """Validate configuration parameters."""
-        errors = []
-
-        # Validate model configuration
-        if not self.config.model.model_name:
-            errors.append("Model name is required")
-
-        # Validate dataset configuration
-        if not self.config.dataset.dataset_name:
-            errors.append("Dataset name is required")
-
-        # Validate training parameters
-        if self.config.dataset.train_batch_size <= 0:
-            errors.append("Train batch size must be positive")
-
-        if self.config.dataset.eval_batch_size <= 0:
-            errors.append("Validation batch size must be positive")
+    def _push(self, errs: List[str], cond: bool, msg: str) -> None:
+        """Append msg to errs if cond is True."""
+        if cond:
+            errs.append(msg)
 
-        if self.config.training.num_train_epochs <= 0:
-            errors.append("Number of epochs must be positive")
-
-        if self.config.training.gradient_accumulation_steps <= 0:
-            errors.append("Gradient accumulation steps must be positive")
-
-        # Validate device configuration
+    def validate_config(self) -> None:
+        """
+        Validate configuration parameters for MasterConfig.
+        """
+        errors: List[str] = []
+
+        cfg = self.config
+        model = getattr(cfg, "model", {})
+        dataset = getattr(cfg, "dataset", {})
+        training = getattr(cfg, "training", {})
+
+        # ---------- Model ----------
+        self._push(errors, not model.get("model_name"), "model.model_name is required.")
+
+        # PEFT validation
+        if model.get("use_peft"):
+            pc = model.get("peft_config", {})
+            self._push(errors, not isinstance(pc, dict), "model.peft_config must be a dict when use_peft=True.")
+            if isinstance(pc, dict):
+                self._push(
+                    errors,
+                    not isinstance(pc.get("lora_r", 0), int) or pc.get("lora_r", 0) <= 0,
+                    "model.peft_config.lora_r must be a positive integer.",
+                )
+                self._push(
+                    errors,
+                    not isinstance(pc.get("lora_alpha", 0), int) or pc.get("lora_alpha", 0) <= 0,
+                    "model.peft_config.lora_alpha must be a positive integer.",
+                )
+                self._push(
+                    errors,
+                    not (0.0 <= float(pc.get("lora_dropout", 0.0)) < 1.0),
+                    "model.peft_config.lora_dropout must be in [0,1).",
+                )
+
+        # ---------- Dataset ----------
+        self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.")
+        self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.")
+        self._push(errors, dataset.get("max_seq_length", 0) <= 0, "dataset.max_seq_length must be positive.")
+
+        # ---------- Training ----------
+        # Batch sizes
+        self._push(
+            errors,
+            training.get("per_device_train_batch_size", 0) <= 0,
+            "training.per_device_train_batch_size must be positive.",
+        )
+        self._push(
+            errors,
+            training.get("per_device_eval_batch_size", 0) <= 0,
+            "training.per_device_eval_batch_size must be positive.",
+        )
+
+        # Epochs / steps
+        n_epochs = training.get("num_train_epochs", 0)
+        max_steps = training.get("max_steps", -1)
+        self._push(
+            errors,
+            n_epochs <= 0 and max_steps <= 0,
+            "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.",
+        )
+
+        # Gradient accumulation
+        self._push(
+            errors,
+            training.get("gradient_accumulation_steps", 0) <= 0,
+            "training.gradient_accumulation_steps must be positive.",
+        )
+
+        # Logging / saving configs
+        self._push(errors, training.get("logging_steps", 0) < 0, "training.logging_steps must be >= 0.")
+        self._push(errors, training.get("save_total_limit", 0) < 0, "training.save_total_limit must be >= 0.")
+
+        # Device
         valid_devices = ["cpu", "cuda", "qaic"]
-        if self.config.training.device not in valid_devices:
-            errors.append(f"Device must be one of {valid_devices}")
-
+        training_device = training.get("device", None)
+        if training_device not in valid_devices:
+            self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.")
+
+        # DDP config
+        ddp = training.get("ddp_config", {})
+        if isinstance(ddp, dict):
+            backend = ddp.get("ddp_backend")
+            # Accept qccl for Qualcomm, nccl for CUDA, gloo for CPU
+            self._push(
+                errors,
+                backend not in {"qccl", "nccl", "gloo", None},
+                "training.ddp_config.ddp_backend must be one of {'qccl','nccl','gloo'} or omitted.",
+            )
+
+        # ---------- Final ----------
         if errors:
-            raise ValueError("Configuration validation failed:\n" + "\n".join(f"- {error}" for error in errors))
+            # Join messages with bullet points for readability
+            raise ValueError("Configuration validation failed:\n- " + "\n- ".join(errors))
 
     def get_callback_config(self) -> Dict[str, Any]:
         """Get callback configuration as dictionary."""
