 
 import json
 import os
-import sys
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, fields, is_dataclass
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import yaml
 from transformers.hf_argparser import HfArgumentParser
@@ -257,7 +256,7 @@ class DdpConfig:
         metadata={"help": "The DDP backend to use (e.g., 'nccl', 'gloo', 'qccl')."},
     )
     ddp_find_unused_parameters: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Whether to find unused parameters in DDP."},
     )
     ddp_bucket_cap_mb: Optional[int] = field(
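The flipped default matches PyTorch's own DistributedDataParallel default; find_unused_parameters=True forces an extra traversal of the autograd graph on every backward pass. A minimal sketch of how this field would feed DDP (the wiring below is assumed, not shown in the diff):

    # assumes: import torch; `model` already on its device; process group initialized
    ddp_cfg = DdpConfig()  # ddp_find_unused_parameters now defaults to False
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        find_unused_parameters=ddp_cfg.ddp_find_unused_parameters,
    )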
@@ -294,7 +293,10 @@ class TrainingConfig:
         default=42,
         metadata={"help": "Random seed for reproducibility."},
     )
-
+    device: str = field(
+        default="qaic",
+        metadata={"help": "The device to use for training ('cuda', 'cpu', etc.)."},
+    )
     do_eval: bool = field(
         default=True,
         metadata={"help": "Whether to run evaluation during training."},
@@ -307,7 +309,6 @@ class TrainingConfig:
         default=100,
         metadata={"help": "Number of update steps between two evaluations."},
     )
-
     per_device_train_batch_size: int = field(
         default=1,
         metadata={"help": "Batch size per device during training."},
@@ -381,10 +382,6 @@ class TrainingConfig:
         default=True,
         metadata={"help": "Whether to compile the model with `torch.compile`."},
     )
-    include_tokens_per_second: bool = field(
-        default=True,
-        metadata={"help": "Whether to include tokens per second in logs."},
-    )
     include_num_input_tokens_seen: bool = field(
         default=True,
         metadata={"help": "Whether to include the number of input tokens seen in logs."},
@@ -426,6 +423,14 @@ class TrainingConfig:
         default=None,
         metadata={"help": "Whether to restore callback states from checkpoint."},
     )
+    report_to: Optional[List[str]] = field(
+        default=None,
+        metadata={"help": "The list of integrations to report the results and logs to."},
+    )
+    completion_only_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to compute loss only on completion tokens."},
+    )
 
 
 @dataclass
@@ -455,7 +460,7 @@ class MasterConfig:
     )
 
 
-def parse_arguments(config_path: Optional[str] = None) -> MasterConfig:
+def parse_arguments(config_path: Optional[str] = None, args: Optional[List[str]] = None) -> MasterConfig:
     """Create argument parser for the new finetuning interface."""
     parser = HfArgumentParser(MasterConfig)
 
@@ -472,12 +477,15 @@ def parse_arguments(config_path: Optional[str] = None) -> MasterConfig:
     except Exception as e:
         raise ValueError(f"Failed to parse YAML config '{config_path}': {e}")
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
-        # If we pass only one argument to the script and it's the path to a json file,
-        # let's parse it to get our arguments.
-        master_config = parser.parse_yaml_file(yaml_file=os.path.abspath(sys.argv[1]))[0]
+    args = [] if args is None else args
+    # If a single positional YAML file was passed via args, parse it as YAML
+    if len(args) == 1 and (args[0].endswith(".yaml") or args[0].endswith(".yml")):
+        yaml_path = os.path.abspath(args[0])
+        (master_config,) = parser.parse_yaml_file(yaml_file=yaml_path)
    else:
-        master_config = parser.parse_args_into_dataclasses()
+        (master_config,) = parser.parse_args_into_dataclasses(args=args)
+        master_config = asdict(master_config)
+        master_config = MasterConfig(**master_config)
 
     return master_config
 
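A usage sketch for the reworked entry point (the file name is hypothetical): passing tokens through the new `args` parameter exercises both branches without touching sys.argv, which also makes the function straightforward to unit-test.

    cfg = parse_arguments(args=["experiment.yaml"])  # single YAML path -> parse_yaml_file
    cfg = parse_arguments(args=[])                   # falls through to parse_args_into_dataclasses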
@@ -512,34 +520,58 @@ def load_config(self, config_path: Union[str, Path]) -> None:
 
         self.update_config(config_dict)
 
+    def _ensure_extra_params(self, obj) -> Dict[str, Any]:
+        """Ensure obj.extra_params exists and is a dict; return it."""
+        ep = getattr(obj, "extra_params", None)
+        if ep is None:
+            setattr(obj, "extra_params", {})
+            ep = obj.extra_params
+        if not isinstance(ep, dict):
+            raise TypeError("extra_params must be a dict.")
+        return ep
+
+    def _stash_top_level_extra(self, section: str, nested_key: str, value: Any) -> None:
+        """Store unknown nested values under MasterConfig.extra_params['section.nested_key']."""
+        ep = self._ensure_extra_params(self.config)
+        ep[f"{section}.{nested_key}"] = value
+
     def update_config(self, config_dict: Dict[str, Any]) -> None:
         """Update configuration with dictionary values."""
+
+        SPECIAL_KEYS = {"callbacks"}
+
         for key, value in config_dict.items():
             if hasattr(self.config, key):
-                if isinstance(value, dict) and hasattr(getattr(self.config, key), "__dataclass_fields__"):
-                    # Special handling for callbacks
-                    if key in ["callbacks", "optimizers", "loss_functions"]:
-                        nested_config = getattr(self.config, key)
-                        for component_name, component_dict in value.items():
-                            if isinstance(component_dict, dict):
-                                getattr(nested_config, key)[component_name] = component_dict
-                            else:
-                                getattr(nested_config, "extra_params")[component_name] = nested_config.extra_params[
-                                    component_name
-                                ] = component_dict
+                target = getattr(self.config, key)
+
+                # Special handling for callbacks (dict inside CallbackConfig)
+                if key in SPECIAL_KEYS and isinstance(value, dict):
+                    if is_dataclass(target) and hasattr(target, "callbacks") and isinstance(target.callbacks, dict):
+                        for component_name, component_cfg in value.items():
+                            target.callbacks[component_name] = component_cfg
+                    elif isinstance(target, dict):
+                        target.update(value)
                     else:
-                        # Update nested dataclass
-                        nested_config = getattr(self.config, key)
-                        for nested_key, nested_value in value.items():
-                            if hasattr(nested_config, nested_key):
-                                setattr(getattr(self.config, key), nested_key, nested_value)
-                            elif hasattr(nested_config, "extra_params"):
-                                getattr(getattr(self.config, key), "extra_params")[nested_key] = nested_value
-                else:
-                    setattr(self.config, key, value)
+                        self._stash_top_level_extra(key, "__all__", value)
+                    continue
+
+                if isinstance(value, dict) and is_dataclass(target):
+                    known = {f.name for f in fields(target)}
+                    for nested_key, nested_value in value.items():
+                        if nested_key in known:
+                            setattr(target, nested_key, nested_value)
+                        else:
+                            self._stash_top_level_extra(key, nested_key, nested_value)
+                    continue
+
+                if isinstance(value, dict) and isinstance(target, dict):
+                    target.update(value)
+                    continue
+                setattr(self.config, key, value)
+
             else:
-                # Store unknown parameters in extra_params
-                self.config.extra_params[key] = value
+                ep = self._ensure_extra_params(self.config)
+                ep[key] = value
 
     def save_config(self, output_path: Union[str, Path]) -> None:
         """Save current configuration to file."""
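To illustrate the new routing (a sketch; `manager` and the key names are hypothetical, and the config sections are assumed to still be dataclass instances): known nested keys are set on the section, unknown nested keys are stashed under dotted names, and unknown top-level keys land in extra_params directly.

    manager.update_config({
        "training": {"seed": 7, "mystery_knob": 3},    # seed set on the dataclass;
                                                       # extra_params["training.mystery_knob"] = 3
        "callbacks": {"early_stop": {"patience": 2}},  # merged into CallbackConfig.callbacks
        "unknown_key": True,                           # extra_params["unknown_key"] = True
    })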
@@ -557,38 +589,105 @@ def save_config(self, output_path: Union[str, Path]) -> None:
         else:
             raise ValueError(f"Unsupported output file format: {output_path.suffix}")
 
-    def validate_config(self) -> None:
-        """Validate configuration parameters."""
-        errors = []
-
-        # Validate model configuration
-        if not self.config.model.model_name:
-            errors.append("Model name is required")
-
-        # Validate dataset configuration
-        if not self.config.dataset.dataset_name:
-            errors.append("Dataset name is required")
-
-        # Validate training parameters
-        if self.config.dataset.train_batch_size <= 0:
-            errors.append("Train batch size must be positive")
-
-        if self.config.dataset.eval_batch_size <= 0:
-            errors.append("Validation batch size must be positive")
+    def _push(self, errs: List[str], cond: bool, msg: str) -> None:
+        """Append msg to errs if cond is True."""
+        if cond:
+            errs.append(msg)
 
-        if self.config.training.num_train_epochs <= 0:
-            errors.append("Number of epochs must be positive")
-
-        if self.config.training.gradient_accumulation_steps <= 0:
-            errors.append("Gradient accumulation steps must be positive")
-
-        # Validate device configuration
+    def validate_config(self) -> None:
+        """
+        Validate configuration parameters for MasterConfig.
+        """
+        errors: List[str] = []
+
+        cfg = self.config
+        model = getattr(cfg, "model", {})
+        dataset = getattr(cfg, "dataset", {})
+        training = getattr(cfg, "training", {})
+
+        # ---------- Model ----------
+        self._push(errors, not model.get("model_name"), "model.model_name is required.")
+
+        # PEFT validation
+        if model.get("use_peft"):
+            pc = model.get("peft_config", {})
+            self._push(errors, not isinstance(pc, dict), "model.peft_config must be a dict when use_peft=True.")
+            if isinstance(pc, dict):
+                self._push(
+                    errors,
+                    not isinstance(pc.get("lora_r", 0), int) or pc.get("lora_r", 0) <= 0,
+                    "model.peft_config.lora_r must be a positive integer.",
+                )
+                self._push(
+                    errors,
+                    not isinstance(pc.get("lora_alpha", 0), int) or pc.get("lora_alpha", 0) <= 0,
+                    "model.peft_config.lora_alpha must be a positive integer.",
+                )
+                self._push(
+                    errors,
+                    not (0.0 <= float(pc.get("lora_dropout", 0.0)) < 1.0),
+                    "model.peft_config.lora_dropout must be in [0, 1).",
+                )
+
+        # ---------- Dataset ----------
+        self._push(errors, not dataset.get("dataset_name"), "dataset.dataset_name is required.")
+        self._push(errors, not dataset.get("tokenizer_name"), "dataset.tokenizer_name is required.")
+        self._push(errors, dataset.get("max_seq_length", 0) <= 0, "dataset.max_seq_length must be positive.")
+
+        # ---------- Training ----------
+        # Batch sizes
+        self._push(
+            errors,
+            training.get("per_device_train_batch_size", 0) <= 0,
+            "training.per_device_train_batch_size must be positive.",
+        )
+        self._push(
+            errors,
+            training.get("per_device_eval_batch_size", 0) <= 0,
+            "training.per_device_eval_batch_size must be positive.",
+        )
+
+        # Epochs / steps
+        n_epochs = training.get("num_train_epochs", 0)
+        max_steps = training.get("max_steps", -1)
+        self._push(
+            errors,
+            n_epochs <= 0 and max_steps <= 0,
+            "Either training.num_train_epochs > 0 or training.max_steps > 0 must be set.",
+        )
+
+        # Gradient accumulation
+        self._push(
+            errors,
+            training.get("gradient_accumulation_steps", 0) <= 0,
+            "training.gradient_accumulation_steps must be positive.",
+        )
+
+        # Logging / saving configs
+        self._push(errors, training.get("logging_steps", 0) < 0, "training.logging_steps must be >= 0.")
+        self._push(errors, training.get("save_total_limit", 0) < 0, "training.save_total_limit must be >= 0.")
+
+        # Device
         valid_devices = ["cpu", "cuda", "qaic"]
-        if self.config.training.device not in valid_devices:
-            errors.append(f"Device must be one of {valid_devices}")
-
+        training_device = training.get("device")
+        self._push(errors, training_device not in valid_devices, f"training.device must be one of {valid_devices}.")
+
+        # DDP config
+        ddp = training.get("ddp_config", {})
+        if isinstance(ddp, dict):
+            backend = ddp.get("ddp_backend")
+            # Accept qccl for Qualcomm, nccl for CUDA, gloo for CPU
+            self._push(
+                errors,
+                backend not in {"qccl", "nccl", "gloo", None},
+                "training.ddp_config.ddp_backend must be one of {'qccl','nccl','gloo'} or omitted.",
+            )
+
+        # ---------- Final ----------
         if errors:
-            raise ValueError("Configuration validation failed:\n" + "\n".join(f"- {error}" for error in errors))
+            # Join messages with bullet points for readability
+            raise ValueError("Configuration validation failed:\n- " + "\n- ".join(errors))
 
     def get_callback_config(self) -> Dict[str, Any]:
         """Get callback configuration as dictionary."""
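Error-handling sketch (hypothetical caller; note the .get() calls above assume the config sections are plain dicts): validate_config accumulates every failed check via _push and raises them together, rather than stopping at the first problem.

    try:
        manager.validate_config()
    except ValueError as e:
        print(e)
        # Configuration validation failed:
        # - model.model_name is required.
        # - dataset.dataset_name is required.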