Skip to content

Commit bfd08f4

Browse files
committed
tf/estimator.py: only write checkpoint in rank0
1 parent 0a73b38 commit bfd08f4

File tree

1 file changed

+26
-7
lines changed

1 file changed

+26
-7
lines changed

python/raydp/tf/estimator.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,15 @@ def train_func(config):
188188
if config["evaluate"]:
189189
test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
190190
results.append(test_history)
191-
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
192-
multi_worker_model.save(temp_checkpoint_dir, save_format="tf")
193-
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
191+
192+
# Only save checkpoint from the chief worker to avoid race conditions
193+
checkpoint = None
194+
if session.get_world_rank() == 0:
195+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
196+
multi_worker_model.save(temp_checkpoint_dir, save_format="tf")
197+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
194198

195-
session.report({}, checkpoint=checkpoint)
199+
session.report({}, checkpoint=checkpoint)
196200

197201
def fit(self,
198202
train_ds: Dataset,
@@ -290,6 +294,21 @@ def fit_on_spark(self,
290294

291295
def get_model(self) -> Any:
292296
assert self._trainer, "Trainer has not been created"
293-
return TensorflowCheckpoint.from_saved_model(
294-
self._results.checkpoint.to_directory()
295-
).get_model()
297+
checkpoint_dir = self._results.checkpoint.to_directory()
298+
299+
try:
300+
# Try standard loading
301+
return TensorflowCheckpoint.from_saved_model(checkpoint_dir).get_model()
302+
except RuntimeError as e:
303+
if "Fingerprint" in str(e) or "fingerprint.pb" in str(e):
304+
# Fallback: Load directly with tf.keras.models.load_model
305+
# This bypasses Ray Train's checkpoint wrapper and fingerprint validation
306+
import warnings
307+
warnings.warn(
308+
f"Encountered fingerprint error when loading checkpoint: {e}. "
309+
"Falling back to direct tf.keras.models.load_model. "
310+
"This may indicate a TensorFlow version incompatibility."
311+
)
312+
return tf.keras.models.load_model(checkpoint_dir)
313+
else:
314+
raise

0 commit comments

Comments
 (0)