1414 'word_embeddings' ,
1515 'embed_tokens' ,
1616 'embedding' ,
17- 'wte' , # GPT style embeddings
18- 'lm_head' # Often tied with embeddings
17+ 'wte' , # GPT style embeddings
18+ 'lm_head' # Language model head, often tied with embeddings
1919]
2020
2121
@@ -35,8 +35,8 @@ def get_parameter_type(name: str) -> dict:
3535if __name__ == '__main__' :
3636 import argparse
3737
38- parser = argparse .ArgumentParser (description = 'Load a HuggingFace model ' )
39- parser .add_argument ('--hf_checkpoint_dir' , type = str , help = 'Path to the HuggingFace checkpoint directory' )
38+ parser = argparse .ArgumentParser (description = 'Convert HuggingFace checkpoint to Universal Checkpoint format ' )
39+ parser .add_argument ('--hf_checkpoint_dir' , type = str , required = True , help = 'Path to the HuggingFace checkpoint directory' )
4040 parser .add_argument ('--safe_serialization' , action = 'store_true' , default = False , help = 'Use safetensors for serialization' )
4141 parser .add_argument ('--num_workers' , type = int , default = 4 , help = 'Number of workers to use for saving checkpoints' )
4242 parser .add_argument ('--save_dir' , type = str , required = True , help = 'Directory to save checkpoints' )
@@ -119,10 +119,12 @@ def get_shard_list(checkpoint_dir):
119119 return list (set (index ['weight_map' ].values ()))
120120 else :
121121 # Handle single file case
122- if args .safe_serialization :
122+ if args .safe_serialization and os . path . exists ( os . path . join ( checkpoint_dir , "model.safetensors" )) :
123123 return ["model.safetensors" ]
124- else :
124+ elif os . path . exists ( os . path . join ( checkpoint_dir , "pytorch_model.bin" )) :
125125 return ["pytorch_model.bin" ]
126+ else :
127+ raise FileNotFoundError (f"No checkpoint files found in { checkpoint_dir } " )
126128
127129 def process_shard_batch (shard_files : List [str ], checkpoint_dir : str , save_dir : str , safe_serialization : bool ):
128130 """Process a batch of shards in parallel."""
0 commit comments