2525 FLAGS .max_sequence_length , FLAGS .tokenizer_vocab_size )
2626
2727train_enqueuer , train_steps = get_enqueuer (FLAGS .train_csv , FLAGS .batch_size , FLAGS , tokenizer_wrapper )
28- test_enqueuer , _ = get_enqueuer (FLAGS .test_csv , 1 , FLAGS , tokenizer_wrapper )
29- batch_test_enqueuer , _ = get_enqueuer (FLAGS .test_csv , FLAGS .batch_size , FLAGS , tokenizer_wrapper )
28+ test_enqueuer , test_steps = get_enqueuer (FLAGS .test_csv , 1 , FLAGS , tokenizer_wrapper )
29+ batch_test_enqueuer , batch_test_steps = get_enqueuer (FLAGS .test_csv , FLAGS .batch_size , FLAGS , tokenizer_wrapper )
3030
3131train_enqueuer .start (workers = FLAGS .generator_workers , max_queue_size = FLAGS .generator_queue_length )
3232
4242
4343encoder = CNN_Encoder ('pretrained_visual_model' , FLAGS .visual_model_name , FLAGS .visual_model_pop_layers ,
4444 FLAGS .encoder_layers ,
45- FLAGS .tags_threshold , tags_embeddings , FLAGS .finetune_visual_model )
45+ FLAGS .tags_threshold , tags_embeddings , FLAGS .finetune_visual_model , len ( FLAGS . tags ) )
4646decoder = TFGPT2LMHeadModel .from_pretrained ('distilgpt2' , from_pt = True , resume_download = True )
4747optimizer = get_optimizer (FLAGS .optimizer_type , FLAGS .learning_rate )
4848
@@ -126,7 +126,7 @@ def get_avg_score(scores_dict):
126126time_csv = {"epoch" : [], 'time_taken' : [], "scores" : []}
127127
128128
129- def get_overall_loss (enqueuer , batch_losses_csv ):
129+ def get_overall_loss (enqueuer , steps , batch_losses_csv ):
130130 tf .keras .backend .set_learning_phase (0 )
131131
132132 if not enqueuer .is_running ():
@@ -136,7 +136,7 @@ def get_overall_loss(enqueuer, batch_losses_csv):
136136 batch_losses = []
137137 total_loss = 0
138138 step = 0
139- for batch in range (generator . steps ):
139+ for batch in range (steps ):
140140 img , target , _ = next (generator )
141141 batch_loss = train_step (img , target , True )
142142 batch_losses_csv ['step' ].append (step )
@@ -186,8 +186,8 @@ def get_overall_loss(enqueuer, batch_losses_csv):
186186 print ('Time taken for 1 epoch {} sec\n ' .format (time .time () - start ))
187187 print ('Batches that took long: {}' .format (times_to_get_batch ))
188188 if FLAGS .calculate_loss_after_epoch :
189- test_epoch_loss , _ = get_overall_loss (batch_test_enqueuer , test_batch_losses_csv )
190- train_epoch_loss , _ = get_overall_loss (train_enqueuer , train_after_batch_losses_csv )
189+ test_epoch_loss , _ = get_overall_loss (batch_test_enqueuer , batch_test_steps , test_batch_losses_csv )
190+ train_epoch_loss , _ = get_overall_loss (train_enqueuer , train_steps , train_after_batch_losses_csv )
191191 losses_csv ['train_after_loss' ].append (train_epoch_loss .numpy ())
192192 losses_csv ['test_loss' ].append (test_epoch_loss .numpy ())
193193 else :
@@ -209,11 +209,11 @@ def get_overall_loss(enqueuer, batch_losses_csv):
209209 plt .title ('Loss Plot' )
210210 plt .savefig (FLAGS .ckpt_path + "/loss.png" )
211211
212- if epoch % FLAGS .epochs_to_evaluate == 0 :
212+ if epoch % FLAGS .epochs_to_evaluate == 0 and epoch > 0 :
213213 current_avg_score = 0
214214 print ("Evaluating on test set.." )
215215 train_enqueuer .stop ()
216- current_scores = evaluate_enqueuer (test_enqueuer , FLAGS , encoder , decoder , tokenizer_wrapper )
216+ current_scores = evaluate_enqueuer (test_enqueuer , test_steps , FLAGS , encoder , decoder , tokenizer_wrapper )
217217 time_csv ['epoch' ].append (epoch + 1 )
218218 time_csv ['time_taken' ].append (pure_training_time )
219219 time_csv ['scores' ].append (current_scores )
0 commit comments