@@ -188,11 +188,15 @@ def train_func(config):
188188 if config ["evaluate" ]:
189189 test_history = multi_worker_model .evaluate (eval_tf_dataset , callbacks = callbacks )
190190 results .append (test_history )
191- with tempfile .TemporaryDirectory () as temp_checkpoint_dir :
192- multi_worker_model .save (temp_checkpoint_dir , save_format = "tf" )
193- checkpoint = Checkpoint .from_directory (temp_checkpoint_dir )
191+
192+ # Only save checkpoint from the chief worker to avoid race conditions
193+ checkpoint = None
194+ if session .get_world_rank () == 0 :
195+ with tempfile .TemporaryDirectory () as temp_checkpoint_dir :
196+ multi_worker_model .save (temp_checkpoint_dir , save_format = "tf" )
197+ checkpoint = Checkpoint .from_directory (temp_checkpoint_dir )
194198
195- session .report ({}, checkpoint = checkpoint )
199+ session .report ({}, checkpoint = checkpoint )
196200
197201 def fit (self ,
198202 train_ds : Dataset ,
@@ -290,6 +294,21 @@ def fit_on_spark(self,
290294
291295 def get_model (self ) -> Any :
292296 assert self ._trainer , "Trainer has not been created"
293- return TensorflowCheckpoint .from_saved_model (
294- self ._results .checkpoint .to_directory ()
295- ).get_model ()
297+ checkpoint_dir = self ._results .checkpoint .to_directory ()
298+
299+ try :
300+ # Try standard loading
301+ return TensorflowCheckpoint .from_saved_model (checkpoint_dir ).get_model ()
302+ except RuntimeError as e :
303+ if "Fingerprint" in str (e ) or "fingerprint.pb" in str (e ):
304+ # Fallback: Load directly with tf.keras.models.load_model
305+ # This bypasses Ray Train's checkpoint wrapper and fingerprint validation
306+ import warnings
307+ warnings .warn (
308+ f"Encountered fingerprint error when loading checkpoint: { e } . "
309+ "Falling back to direct tf.keras.models.load_model. "
310+ "This may indicate a TensorFlow version incompatibility."
311+ )
312+ return tf .keras .models .load_model (checkpoint_dir )
313+ else :
314+ raise
0 commit comments