@dataclass
class OptimizerConfig:
    """Optimizer settings: which optimizer to build and its basic hyper-parameters.

    Every field carries a ``help`` entry in its metadata so that argument
    parsers built from this dataclass can show a description.
    """

    # Identifier of the optimizer implementation.
    optimizer_name: str = field(metadata=dict(help="The name of the optimizer to use."), default="adamw")
    # Starting learning rate.
    lr: float = field(metadata=dict(help="The initial learning rate for the optimizer."), default=5e-5)
    # Weight-decay coefficient.
    weight_decay: float = field(metadata=dict(help="The weight decay to apply (if any)."), default=0.01)
+ }, + ) + + +@dataclass +class DatasetConfig: + """Configuration for datasets.""" + + tokenizer_name: str = field( + default="HuggingFaceTB/SmolLM-135M", + metadata={"help": "The name or path of the tokenizer to use."}, + ) + dataset_type: str = field( + default="seq_completion", + metadata={"help": "The type of dataset (e.g., 'seq_completion')."}, + ) + dataset_name: str = field( + default="knkarthick/samsum", + metadata={"help": "The name or path of the dataset."}, + ) + dataset_subset: str = field( + default="default", + metadata={"help": "The subset of the dataset to use, if applicable."}, + ) + train_split: str = field( + default="train", + metadata={"help": "The name of the training split."}, + ) + test_split: str = field( + default="test", + metadata={"help": "The name of the test/validation split."}, + ) + max_seq_length: int = field( + default=512, + metadata={"help": "The maximum sequence length for tokenization."}, + ) + split_ratio: float = field( + default=0.8, + metadata={"help": "Ratio for train/test split, used when only train_split is provided."}, + ) + input_columns: list[str] = field( + default_factory=lambda: ["text"], + metadata={"help": "List of column names containing input text."}, + ) + target_column: Optional[str] = field( + default=None, + metadata={"help": "Name of the column containing target labels (if applicable)."}, + ) + train_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during training."}, + ) + eval_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during evaluation."}, + ) + num_workers: int = field( + default=4, + metadata={"help": "Number of workers for dataset processing."}, + ) + collate_fn: str = field( + default="dynamic_padding", + metadata={"help": "The collation function to use (e.g., 'dynamic_padding')."}, + ) + group_by_length: bool = field( + default=True, + metadata={"help": "Whether to group samples by length to minimize padding."}, + ) + 
@dataclass
class PeftConfig:
    """Settings for the parameter-efficient fine-tuning (PEFT) adapter.

    Every field carries a ``help`` entry in its metadata so that argument
    parsers built from this dataclass can show a description.
    """

    # --- LoRA hyper-parameters ---
    lora_r: int = field(metadata=dict(help="Lora attention dimension."), default=8)
    lora_alpha: int = field(metadata=dict(help="Lora alpha."), default=16)
    lora_dropout: float = field(metadata=dict(help="The dropout probability for Lora layers."), default=0.1)
    # A factory is required here: a list default must not be shared between instances.
    target_modules: list[str] = field(
        metadata=dict(help="The modules to apply Lora to."),
        default_factory=lambda: ["q_proj", "v_proj"],
    )
    bias: str = field(metadata=dict(help="Bias type for Lora ('none', 'all', 'lora_only')."), default="none")

    # --- Adapter / task selection ---
    task_type: str = field(
        metadata=dict(help="The task type for PEFT (e.g., 'CAUSAL_LM', 'SEQ_2_SEQ_LM')."),
        default="CAUSAL_LM",
    )
    peft_type: str = field(metadata=dict(help="The PEFT method to use (e.g., 'LORA', 'IA3')."), default="LORA")
@dataclass
class GradientCheckpointingKwargs:
    """Keyword options that accompany gradient checkpointing.

    Every field carries a ``help`` entry in its metadata so that argument
    parsers built from this dataclass can show a description.
    """

    preserve_rng_state: bool = field(
        metadata=dict(help="Whether to preserve the RNG state when checkpointing."),
        default=True,
    )
    # NOTE(review): the name is spelled `use_reenrant` (PyTorch's checkpoint keyword is
    # `use_reentrant`). It is kept unchanged here because existing configs use this
    # spelling -- confirm how consumers map it before renaming.
    use_reenrant: bool = field(
        metadata=dict(help="Whether to use reentrant gradient checkpointing."),
        default=False,
    )
ddp_bucket_cap_mb: Optional[int] = field( + default=25, + metadata={"help": "The bucket size in MB for DDP communication."}, + ) + ddp_broadcast_buffers: bool = field( + default=True, + metadata={"help": "Whether to broadcast buffers in DDP."}, + ) + ddp_timeout: int = field( + default=1800, + metadata={"help": "Timeout for DDP operations in seconds."}, + ) + + +@dataclass +class TrainingConfig: + """Configuration for training.""" + + type: str = field( + default="sft", + metadata={"help": "The type of training (e.g., 'sft' for Supervised Fine-Tuning)."}, + ) + output_dir: str = field( + default="./training_results", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + overwrite_output_dir: bool = field( + default=False, + metadata={"help": "Whether to overwrite the output directory."}, + ) + seed: int = field( + default=42, + metadata={"help": "Random seed for reproducibility."}, + ) + device: str = field( + default="qaic", + metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."}, + ) + do_eval: bool = field( + default=True, + metadata={"help": "Whether to run evaluation during training."}, + ) + eval_strategy: str = field( + default="epoch", + metadata={"help": "The evaluation strategy to use ('no', 'steps', 'epoch')."}, + ) + eval_steps: int = field( + default=100, + metadata={"help": "Number of update steps between two evaluations."}, + ) + per_device_train_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during training."}, + ) + per_device_eval_batch_size: int = field( + default=1, + metadata={"help": "Batch size per device during evaluation."}, + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + num_train_epochs: int = field( + default=1, + metadata={"help": "Total number of training epochs to perform."}, + ) + max_steps: int = 
field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform."}, + ) + + log_level: str = field( + default="info", + metadata={"help": "Set the verbosity level of the logs ('debug', 'info', 'warning', 'error')."}, + ) + log_on_each_node: bool = field( + default=True, + metadata={"help": "Whether to log on each node in a distributed setup."}, + ) + logging_strategy: str = field( + default="steps", + metadata={"help": "The logging strategy to use ('no', 'steps', 'epoch')."}, + ) + logging_steps: int = field( + default=10, + metadata={"help": "Number of update steps between two loggings."}, + ) + + save_strategy: str = field( + default="epoch", + metadata={"help": "The checkpoint save strategy to use ('no', 'steps', 'epoch')."}, + ) + save_steps: int = field( + default=100, + metadata={"help": "Number of update steps between two checkpoints (if save_strategy is 'steps')."}, + ) + save_total_limit: int = field( + default=5, + metadata={"help": "Limit the total amount of checkpoints. 
Deletes older checkpoints to stay within limit."}, + ) + metric_for_best_model: str = field( + default="eval_loss", + metadata={"help": "The metric to use to compare two models ('eval_loss', etc.)."}, + ) + + dtype: str = field( + default="fp16", + metadata={"help": "The data type to use for training (e.g., 'fp16', 'bf16')."}, + ) + + gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Whether to use gradient checkpointing."}, + ) + gradient_checkpointing_kwargs: Optional[GradientCheckpointingKwargs] = field( + default_factory=GradientCheckpointingKwargs, + metadata={"help": "Arguments for gradient checkpointing."}, + ) + + torch_compile: bool = field( + default=True, + metadata={"help": "Whether to compile the model with `torch.compile`."}, + ) + include_num_input_tokens_seen: bool = field( + default=True, + metadata={"help": "Whether to include the number of input tokens seen in logs."}, + ) + average_tokens_across_devices: bool = field( + default=True, + metadata={"help": "Whether to average tokens across devices in distributed training."}, + ) + + disable_tqdm: Optional[bool] = field( + default=None, + metadata={"help": "Whether to disable the tqdm progress bar."}, + ) + fsdp_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "FSDP configuration dictionary."}, + ) + deepspeed_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "DeepSpeed configuration dictionary."}, + ) + accelerator_config: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Accelerate configuration dictionary."}, + ) + ddp_config: Optional[DdpConfig] = field( + default_factory=DdpConfig, + metadata={"help": "DDP configuration dictionary."}, + ) + use_cpu: Optional[bool] = field( + default=None, + metadata={"help": "Whether to explicitly run training on CPU."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "Path to a checkpoint to resume training from."}, + 
def _section(factory, help_text):
    """Build a dataclass field that defaults to ``factory()`` and carries ``help_text``."""
    return field(default_factory=factory, metadata={"help": help_text})


@dataclass
class MasterConfig:
    """Top-level configuration that groups every section of a training run."""

    model: ModelConfig = _section(ModelConfig, "Configuration for the model.")
    dataset: DatasetConfig = _section(DatasetConfig, "Configuration for the dataset.")
    optimizers: OptimizerConfig = _section(OptimizerConfig, "Configuration for optimizers.")
    scheduler: SchedulerConfig = _section(SchedulerConfig, "Configuration for the learning rate scheduler.")
    callbacks: CallbackConfig = _section(CallbackConfig, "Configuration for callbacks.")
    training: TrainingConfig = _section(TrainingConfig, "Configuration for training parameters.")
    # Catch-all for keys that do not belong to any declared section.
    extra_params: Dict[str, Any] = _section(dict, "Additional top-level parameters not explicitly defined.")
class ConfigManager:
    """Load, merge, validate, and persist a training configuration.

    The manager wraps a single configuration object, ``self.config``. Its
    sections (``model``, ``dataset``, ``optimizers``, ``training``, ...) may be
    dataclass instances or plain dicts; updating, validating and saving accept
    both forms.
    """

    def __init__(self, config: "MasterConfig"):
        """Wrap an already-built configuration object.

        Args:
            config: The configuration to manage. It is stored as-is (not
                copied), so later updates mutate the object passed in.
        """
        self.config = config

    def load_config(self, config_path: Union[str, Path]) -> None:
        """Read a YAML or JSON file and merge its content into the current config.

        Args:
            config_path: Path to a ``.yaml``/``.yml`` or ``.json`` file.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the suffix is unsupported, or the file's top-level
                value is not a mapping.
        """
        config_path = Path(config_path)

        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")

        suffix = config_path.suffix.lower()
        if suffix in (".yaml", ".yml"):
            with open(config_path, "r", encoding="utf-8") as f:
                config_dict = yaml.safe_load(f)
        elif suffix == ".json":
            with open(config_path, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
        else:
            raise ValueError(f"Unsupported configuration file format: {config_path.suffix}")

        if config_dict is None:
            # An empty YAML document (or a JSON `null`) parses to None: nothing to merge.
            return
        if not isinstance(config_dict, dict):
            raise ValueError(
                "Configuration file must contain a mapping at the top level, "
                f"got {type(config_dict).__name__}: {config_path}"
            )

        self.update_config(config_dict)

    def _ensure_extra_params(self, obj) -> Dict[str, Any]:
        """Return ``obj.extra_params``, creating an empty dict first if it is missing or None.

        Raises:
            TypeError: If ``obj.extra_params`` exists but is not a dict.
        """
        ep = getattr(obj, "extra_params", None)
        if ep is None:
            ep = {}
            setattr(obj, "extra_params", ep)
        if not isinstance(ep, dict):
            raise TypeError("extra_params must be a dict.")
        return ep

    def _stash_top_level_extra(self, section: str, nested_key: str, value: Any) -> None:
        """Store an unknown nested value under ``config.extra_params['section.nested_key']``."""
        ep = self._ensure_extra_params(self.config)
        ep[f"{section}.{nested_key}"] = value

    def update_config(self, config_dict: Dict[str, Any]) -> None:
        """Merge ``config_dict`` into the current configuration, in place.

        For each top-level key:
          * unknown keys are stored in ``config.extra_params``;
          * ``callbacks`` dicts are merged into the callbacks mapping;
          * a dict aimed at a dataclass section sets the known fields and
            stashes unknown ones in ``extra_params`` as ``"section.key"``;
          * a dict aimed at a dict section is merged with ``dict.update``;
          * anything else replaces the attribute.

        Nested values are assigned as given; they are not merged recursively.

        Args:
            config_dict: Mapping of section name to new values. ``None`` or an
                empty mapping is a no-op.
        """
        if not config_dict:
            return

        special_keys = {"callbacks"}

        for key, value in config_dict.items():
            if not hasattr(self.config, key):
                self._ensure_extra_params(self.config)[key] = value
                continue

            target = getattr(self.config, key)
            value_is_dict = isinstance(value, dict)

            # Callbacks are keyed by callback name, so they merge entry by entry.
            if key in special_keys and value_is_dict:
                if is_dataclass(target) and isinstance(getattr(target, "callbacks", None), dict):
                    target.callbacks.update(value)
                elif isinstance(target, dict):
                    target.update(value)
                else:
                    self._stash_top_level_extra(key, "__all__", value)
                continue

            if value_is_dict and is_dataclass(target):
                known = {f.name for f in fields(target)}
                for nested_key, nested_value in value.items():
                    if nested_key in known:
                        setattr(target, nested_key, nested_value)
                    else:
                        self._stash_top_level_extra(key, nested_key, nested_value)
                continue

            if value_is_dict and isinstance(target, dict):
                target.update(value)
                continue

            setattr(self.config, key, value)

    def save_config(self, output_path: Union[str, Path]) -> None:
        """Write the current configuration to a YAML or JSON file.

        The configuration is converted to plain dicts and serialized in memory
        before the file is opened, so a serialization error does not leave a
        truncated file behind. Missing parent directories are created.

        Args:
            output_path: Destination path ending in ``.yaml``/``.yml`` or ``.json``.

        Raises:
            ValueError: If the suffix is not supported.
        """
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix not in (".yaml", ".yml", ".json"):
            raise ValueError(f"Unsupported output file format: {output_path.suffix}")

        cfg = self.config
        # Dataclass instances are not serializable as-is; convert them to nested dicts.
        config_dict = asdict(cfg) if is_dataclass(cfg) and not isinstance(cfg, type) else cfg

        if suffix == ".json":
            text = json.dumps(config_dict, indent=2)
        else:
            text = yaml.dump(config_dict, default_flow_style=False, indent=2)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(text, encoding="utf-8")

    def _push(self, errs: List[str], cond: bool, msg: str) -> None:
        """Append msg to errs if cond is True."""
        if cond:
            errs.append(msg)

    @staticmethod
    def _as_mapping(section: Any) -> Dict[str, Any]:
        """Return a config section as a dict (dataclass -> dict, dict -> itself, other -> {})."""
        if is_dataclass(section) and not isinstance(section, type):
            return asdict(section)
        if isinstance(section, dict):
            return section
        return {}

    @staticmethod
    def _violates_minimum(value: Any, minimum: int, *, strict: bool) -> bool:
        """Return True when ``value`` is too small, or cannot be compared with a number.

        With ``strict=True`` the value must be greater than ``minimum``;
        otherwise it must be greater than or equal to it.
        """
        try:
            return value <= minimum if strict else value < minimum
        except TypeError:
            # None, strings, etc. are reported as invalid instead of crashing validation.
            return True

    @staticmethod
    def _as_float(value: Any) -> Optional[float]:
        """Return ``float(value)``, or None when the value is not convertible."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def validate_config(self) -> None:
        """Check the configuration and report every problem found at once.

        Sections may be dataclass instances or dicts. Values of the wrong type
        are reported as validation errors rather than raising ``TypeError``.

        Raises:
            ValueError: If at least one check fails; the message lists all failures.
        """
        errors: List[str] = []

        cfg = self.config
        model = self._as_mapping(getattr(cfg, "model", None))
        optimizers = self._as_mapping(getattr(cfg, "optimizers", None))
        dataset = self._as_mapping(getattr(cfg, "dataset", None))
        training = self._as_mapping(getattr(cfg, "training", None))

        # ---------- Model ----------
        self._push(errors, not model.get("model_name"), "model.model_name is required.")

        # PEFT validation
        if model.get("use_peft"):
            pc = model.get("peft_config", {})
            if is_dataclass(pc) and not isinstance(pc, type):
                pc = asdict(pc)
            self._push(errors, not isinstance(pc, dict), "model.peft_config must be a dict when use_peft=True.")
            if isinstance(pc, dict):
                for name in ("lora_r", "lora_alpha"):
                    number = pc.get(name, 0)
                    self._push(
                        errors,
                        not isinstance(number, int) or number <= 0,
                        f"model.peft_config.{name} must be a positive integer.",
                    )
                dropout = self._as_float(pc.get("lora_dropout", 0.0))
                self._push(
                    errors,
                    dropout is None or not (0.0 <= dropout < 1.0),
                    "model.peft_config.lora_dropout must be in [0,1).",
                )

        # ---------- Dataset ----------
        self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.")
        self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.")
        self._push(
            errors,
            self._violates_minimum(dataset.get("max_seq_length", 0), 0, strict=True),
            "dataset.max_seq_length must be positive.",
        )

        # ---------- Training ----------
        # Batch sizes and gradient accumulation must all be strictly positive.
        for name in ("per_device_train_batch_size", "per_device_eval_batch_size"):
            self._push(
                errors,
                self._violates_minimum(training.get(name, 0), 0, strict=True),
                f"training.{name} must be positive.",
            )

        # Epochs / steps: at least one of them has to bound the run.
        n_epochs = training.get("num_train_epochs", 0)
        max_steps = training.get("max_steps", -1)
        self._push(
            errors,
            self._violates_minimum(n_epochs, 0, strict=True) and self._violates_minimum(max_steps, 0, strict=True),
            "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.",
        )

        # Gradient accumulation
        self._push(
            errors,
            self._violates_minimum(training.get("gradient_accumulation_steps", 0), 0, strict=True),
            "training.gradient_accumulation_steps must be positive.",
        )

        # Logging / saving configs
        self._push(
            errors,
            self._violates_minimum(training.get("logging_steps", 0), 0, strict=False),
            "training.logging_steps must be >= 0.",
        )
        self._push(
            errors,
            self._violates_minimum(training.get("save_total_limit", 0), 0, strict=False),
            "training.save_total_limit must be >= 0.",
        )

        # Device
        valid_devices = ["cpu", "cuda", "qaic"]
        self._push(
            errors,
            training.get("device", None) not in valid_devices,
            f"training.device must be one of {valid_devices}.",
        )

        # DDP config
        ddp = training.get("ddp_config", {})
        if is_dataclass(ddp) and not isinstance(ddp, type):
            ddp = asdict(ddp)
        if isinstance(ddp, dict):
            backend = ddp.get("ddp_backend")
            # Accept qccl for Qualcomm, nccl for CUDA, gloo for CPU
            self._push(
                errors,
                backend not in ("qccl", "nccl", "gloo", None),
                "training.ddp_config.ddp_backend must be one of {'qccl','nccl','gloo'} or omitted.",
            )

        # ---------- Optimizers ----------
        # float() also accepts strings such as "5e-5", which YAML loaders may return for that spelling.
        lr = self._as_float(optimizers.get("lr", 0))
        self._push(errors, lr is None or lr <= 0, "optimizers.lr must be positive.")

        # ---------- Final ----------
        if errors:
            # Join messages with bullet points for readability
            raise ValueError("Configuration validation failed:\n- " + "\n- ".join(errors))

    def get_callback_config(self) -> Dict[str, Any]:
        """Return the ``callbacks`` section exactly as stored on the config."""
        return self.config.callbacks

    def get_optimizer_config(self) -> Dict[str, Any]:
        """Return the ``optimizers`` section exactly as stored on the config."""
        return self.config.optimizers

    def get_training_config(self) -> Dict[str, Any]:
        """Return the ``training`` section exactly as stored on the config."""
        return self.config.training

    def get_scheduler_config(self) -> Dict[str, Any]:
        """Return the ``scheduler`` section exactly as stored on the config."""
        return self.config.scheduler

    def get_dataset_config(self) -> Dict[str, Any]:
        """Return the ``dataset`` section exactly as stored on the config."""
        return self.config.dataset

    def get_model_config(self) -> Dict[str, Any]:
        """Return the ``model`` section exactly as stored on the config."""
        return self.config.model

    def to_dict(self) -> Dict[str, Any]:
        """Convert the (dataclass) configuration to nested plain dicts."""
        return asdict(self.config)

    def __getattr__(self, name: str) -> Any:
        """Fall back to attributes of the wrapped config for names not found on the manager.

        Raises:
            AttributeError: If neither the manager nor its config has ``name``.
        """
        # __getattr__ only runs after normal lookup failed. Read `config` from the
        # instance dict directly: going through `self.config` would re-enter this
        # method and recurse without bound when `config` has not been set yet
        # (e.g. on an instance created without calling __init__).
        config = self.__dict__.get("config")
        if config is not None and hasattr(config, name):
            return getattr(config, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# model configuration +model: + model_type: "hf" + auto_class_name: "AutoModelForCausalLM" + model_name: "HuggingFaceTB/SmolLM-135M" # Pretrained model name + load_in_4bit: false + use_peft: true + peft_config: + lora_r: 8 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "v_proj"] + bias: "none" + task_type: "CAUSAL_LM" + peft_type: "LORA" + +# Dataset configuration +dataset: + tokenizer_name: "HuggingFaceTB/SmolLM-135M" + dataset_type: "seq_completion" + # dataset_name: "Arthur-LAGACHERIE/very-smollm-corpus-0.5M" + dataset_name: "knkarthick/samsum" + train_split: "train" + max_seq_length: 512 + split_ratio: 0.8 # Ratio for train/test split, used when only train_split is provided + test_split: "test" + group_by_length: True + num_workers: 4 + dataloader_pin_memory: True + dataloader_persistent_workers: True + dataloader_prefetch_factor: 1 + dataloader_drop_last: False + +# Training configuration +training: + type: "sft" + output_dir: "./training_results" + overwrite_output_dir: False + seed: 42 + device: "qaic" + do_eval: True + eval_strategy: "epoch" + eval_steps: 100 + + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + gradient_accumulation_steps: 1 + num_train_epochs: 1 + max_steps: -1 + + log_level: "info" + log_on_each_node: True + logging_strategy: "steps" + logging_steps: 10 + + save_strategy: "epoch" + save_total_limit: 5 + metric_for_best_model: "eval_loss" + + dtype: "fp16" + completion_only_loss: True + report_to: "trackio" + + ddp_config: + ddp_backend: "qccl" + ddp_find_unused_parameters: False + ddp_bucket_cap_mb: 25 + ddp_broadcast_buffers: null + ddp_timeout: 1800 + + use_cpu: False + + gradient_checkpointing: False + 
def test_config(config_path):
    """Load the YAML fixture on top of the defaults and check every section.

    The default config is built with ``parse_arguments(args=[])``, the fixture
    is merged in with ``load_config``, and the merged result must validate.
    Each section getter must then return a dict holding the expected keys.
    """
    master_config = parse_arguments(args=[])
    config_manager = ConfigManager(master_config)
    assert isinstance(config_manager, ConfigManager)
    config_manager.load_config(config_path)
    # Let a validation error propagate: pytest reports it with the full traceback.
    config_manager.validate_config()

    # Test that all required sections are present
    missing = [
        a
        for a in ("model", "dataset", "optimizers", "scheduler", "callbacks", "training")
        if not hasattr(config_manager, a)
    ]
    assert not missing, f"Missing attributes: {missing}"

    # Sections are dicts, so key membership (not hasattr) is what has to be checked,
    # and the check must be reduced to a bool: a bare generator is always truthy.
    expected_keys = {
        "get_training_config": ("output_dir", "per_device_train_batch_size", "num_train_epochs", "ddp_config"),
        "get_dataset_config": ("dataset_type", "dataset_name", "tokenizer_name"),
        "get_model_config": ("model_type", "model_name", "use_peft", "peft_config"),
        "get_scheduler_config": ("scheduler_name",),
        "get_callback_config": ("early_stopping",),
        "get_optimizer_config": ("optimizer_name", "lr"),
    }
    for getter_name, keys in expected_keys.items():
        section = getattr(config_manager, getter_name)()
        assert section is not None, f"{getter_name}() returned None"
        assert isinstance(section, dict), f"{getter_name}() returned {type(section).__name__}, expected dict"
        absent = [key for key in keys if key not in section]
        assert not absent, f"{getter_name}() is missing keys: {absent}"