diff --git a/trainers/trainer.py b/trainers/trainer.py index 9968092..d71d90b 100644 --- a/trainers/trainer.py +++ b/trainers/trainer.py @@ -320,6 +320,7 @@ def train(self,args=None): torch.save(self.algorithm.classifier.state_dict(), self.cpath) if losses_val['Total_loss']<= self.best_val_loss: + self.best_val_loss = losses_val['Total_loss'] # Update to current best loss value torch.save(self.algorithm.feature_extractor.state_dict(), f"{self.fpath}_best_val") torch.save(self.algorithm.classifier.state_dict(), f"{self.cpath}_best_val") self.f1_run_score.append(f1)