@@ -2899,7 +2899,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):
28992899
29002900 ckpt_files = glob .glob (ckpt_file_pattern )
29012901 ckpt_files .sort ()
2902- return ckpt_files
2902+ return ckpt_files , ckpt_file_pattern
29032903
29042904 def load_checkpoint (self ,
29052905 load_dir ,
@@ -2923,7 +2923,7 @@ def load_checkpoint(self,
29232923
29242924 Returns:
29252925 A tuple of ``load_path`` and ``client_state``.
2926- *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
2926+ *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or when loading an HF-based UCP.
29272927 *``client_state``: State dictionary used for loading required training states in the client code.
29282928
29292929 Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
@@ -2962,6 +2962,12 @@ def load_checkpoint(self,
29622962 custom_load_fn = custom_load_fn )
29632963
29642964 load_zero_checkpoint = load_path is not None and (self .zero_optimization () or self .bfloat16_enabled ())
2965+ # import pdb; pdb.set_trace()
2966+ if self .load_universal_checkpoint ():
2967+ ucp_ckpt_folder = os .path .join (load_dir , tag )
2968+ # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist
2969+ load_zero_checkpoint = os .path .isdir (ucp_ckpt_folder )
2970+
29652971 if load_zero_checkpoint :
29662972 if (load_optimizer_states and not load_module_only ) or self .load_universal_checkpoint ():
29672973 success = self ._load_zero_checkpoint (load_dir , tag , load_optimizer_states = load_optimizer_states )
@@ -3002,7 +3008,11 @@ def _load_checkpoint(self,
30023008
30033009 from deepspeed .runtime .state_dict_factory import SDLoaderFactory
30043010
3005- ckpt_list = self ._get_all_ckpt_names (load_dir , tag )
3011+ ckpt_list , ckpt_file_pattern = self ._get_all_ckpt_names (load_dir , tag )
3012+ if self .load_universal_checkpoint () and len (ckpt_list ) == 0 :
3013+ logger .warning (f"Unable to find { ckpt_file_pattern } files in UCP folder { load_dir } " )
3014+ return None , {}
3015+
30063016 sd_loader = SDLoaderFactory .get_sd_loader (ckpt_list , checkpoint_engine = self .checkpoint_engine )
30073017
30083018 is_pipe_parallel = isinstance (self .module , PipelineModule )
0 commit comments