@@ -1,8 +1,14 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
 import torch
 import os
 import shutil
 import logging
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor
+from deepspeed.accelerator import get_accelerator
 from tqdm import tqdm
 from typing import List
 
@@ -14,8 +20,8 @@
     'word_embeddings',
     'embed_tokens',
     'embedding',
-    'wte', # GPT style embeddings
-    'lm_head' # Language model head, often tied with embeddings
+    'wte',  # GPT style embeddings
+    'lm_head'  # Language model head, often tied with embeddings
 ]
 
 
@@ -24,20 +30,27 @@ def get_parameter_type(name: str) -> dict:
     param_info = {
         'cat_dim': 0  # Default concatenation dimension
     }
-
+
     # Check for vocabulary tensors (embeddings, etc.)
     if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
         param_info['vocab_tensor'] = True
-
+
     # TODO: figure out if we need to check for row-parallel parameters
     return param_info
 
+
 if __name__ == '__main__':
     import argparse
-
+
     parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
-    parser.add_argument('--hf_checkpoint_dir', type=str, required=True, help='Path to the HuggingFace checkpoint directory')
-    parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization')
+    parser.add_argument('--hf_checkpoint_dir',
+                        type=str,
+                        required=True,
+                        help='Path to the HuggingFace checkpoint directory')
+    parser.add_argument('--safe_serialization',
+                        action='store_true',
+                        default=False,
+                        help='Use safetensors for serialization')
     parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
     parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
     args = parser.parse_args()
@@ -50,19 +63,19 @@ def save_parameter(name: str, param: torch.Tensor, save_dir: str):
         # Create parameter directory under zero/
         param_dir = os.path.join(save_dir, name)
         os.makedirs(param_dir, exist_ok=True)
-
+
         # Get parameter type and required fields
         param_info = get_parameter_type(name)
-
+
         # Save parameter in fp32 with proper dictionary structure
         param_path = os.path.join(param_dir, "fp32.pt")
         param_dict = {
             'param': param.to(torch.float32),  # Main tensor goes in 'param' field
             **param_info  # Include all determined parameter info
         }
         torch.save(param_dict, param_path)
-
-        # Since HuggingFace checkpoints do not have optimizer states,
+
+        # Since HuggingFace checkpoints do not have optimizer states,
         # we initialize them with zeros
         for state in ("exp_avg", "exp_avg_sq"):
             state_path = os.path.join(param_dir, f"{state}.pt")
@@ -77,30 +90,30 @@ def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
         try:
             shard_path = os.path.join(checkpoint_dir, shard_file)
             logger.info(f"Loading shard from: {shard_path}")
-
+
             if safe_serialization:
                 from safetensors.torch import load_file
                 shard_dict = load_file(shard_path)
             else:
                 shard_dict = torch.load(shard_path, map_location='cpu')
-
+
             # Create progress bar for parameters within this shard
-            pbar = tqdm(total=len(shard_dict),
-                        desc=f"Processing {os.path.basename(shard_file)}",
-                        position=1,
-                        leave=False)
-
+            pbar = tqdm(total=len(shard_dict),
+                        desc=f"Processing {os.path.basename(shard_file)}",
+                        position=1,
+                        leave=False)
+
             for key, param in shard_dict.items():
                 save_parameter(key, param, save_dir)
                 del param
                 pbar.update(1)
                 pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})
-
+
             pbar.close()
             del shard_dict
-            torch.cuda.empty_cache()
+            get_accelerator().empty_cache()
             logger.info(f"Completed processing shard: {shard_file}")
-
+
         except Exception as e:
             logger.error(f"Error processing shard {shard_file}: {str(e)}")
             raise
@@ -111,7 +124,7 @@ def get_shard_list(checkpoint_dir):
             index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
         else:
             index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
-
+
         if os.path.exists(index_file):
             import json
             with open(index_file, 'r') as f:
@@ -131,18 +144,11 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
             futures = []
             for shard_file in shard_files:
-                future = executor.submit(process_shard,
-                                         shard_file,
-                                         checkpoint_dir,
-                                         save_dir,
-                                         safe_serialization)
+                future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
                 futures.append((shard_file, future))
-
+
             # Create progress bar for this batch
-            batch_pbar = tqdm(total=len(futures),
-                              desc=f"Processing shard batch",
-                              position=0,
-                              leave=True)
+            batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True)
 
             # Wait for all futures to complete
             for shard_file, future in futures:
@@ -153,7 +159,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
                 except Exception as e:
                     logger.error(f"Failed processing shard {shard_file}: {str(e)}")
                     raise
-
+
             batch_pbar.close()
 
     try:
@@ -162,42 +168,45 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         if os.path.exists(temp_zero_dir):
             logger.info(f"Removing existing temp directory: {temp_zero_dir}")
             shutil.rmtree(temp_zero_dir)
-
+
         shard_files = get_shard_list(args.hf_checkpoint_dir)
         total_shards = len(shard_files)
         logger.info(f"Found {total_shards} shards to process")
         # Process shards in batches equal to the number of workers
         batch_size = args.num_workers
         for i in range(0, total_shards, batch_size):
             batch_shards = shard_files[i:i + batch_size]
-            logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})")
-            process_shard_batch(batch_shards,
-                                args.hf_checkpoint_dir,
-                                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
-                                args.safe_serialization)
-
+            logger.info(
+                f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
+            )
+            process_shard_batch(
+                batch_shards,
+                args.hf_checkpoint_dir,
+                temp_zero_dir,  # Changed from temp_save_dir to temp_zero_dir
+                args.safe_serialization)
+
             # Clear CUDA cache after each batch to free up memory
-            torch.cuda.empty_cache()
-
+            get_accelerator().empty_cache()
+
         logger.info("All shard batches processed successfully")
-
+
         final_save_dir = os.path.join(args.save_dir, 'zero')
         if os.path.exists(final_save_dir):
             shutil.rmtree(final_save_dir)
-
+
         # Create the parent directory if it doesn't exist
         os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
         # Move the zero directory to its final location
         os.rename(temp_zero_dir, final_save_dir)
-
+
         # Clean up the temporary directory
         if os.path.exists(temp_save_dir):
             shutil.rmtree(temp_save_dir)
-
+
         # Write identifier file
         with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
             f.write("Huggingface checkpoint")
-
+
         logger.info(f"Successfully saved checkpoint to {final_save_dir}")
 
         # Update latest file
@@ -206,7 +215,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s
         latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
         with open(latest_file, 'w') as f:
             f.write(step_folder)
-
+
         logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")
 
     except Exception as e:
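
For reference, here is a minimal sketch of how the converted output could be inspected after the script has been run. The save directory and parameter name below are hypothetical examples; the layout it relies on ("<save_dir>/zero/<param_name>/fp32.pt" plus "exp_avg.pt" and "exp_avg_sq.pt") follows the save_parameter and finalization logic shown in the diff above.

import os
import torch

# Hypothetical values: use the directory that was passed to --save_dir and any key from the HF state dict.
save_dir = "/path/to/universal_ckpt"
param_name = "model.embed_tokens.weight"

param_dir = os.path.join(save_dir, "zero", param_name)

# fp32.pt stores a dict with the fp32 tensor under 'param', plus 'cat_dim' and,
# for embedding-like parameters, 'vocab_tensor' (see get_parameter_type above).
param_dict = torch.load(os.path.join(param_dir, "fp32.pt"), map_location="cpu")
print(param_dict["param"].shape, param_dict["param"].dtype)
print("cat_dim:", param_dict.get("cat_dim"), "vocab_tensor:", param_dict.get("vocab_tensor", False))

# The converter also writes zero-initialized optimizer states next to fp32.pt;
# their exact dict layout is not shown in the diff, so only check for their presence here.
for state in ("exp_avg", "exp_avg_sq"):
    print(state, os.path.exists(os.path.join(param_dir, f"{state}.pt")))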