Skip to content

Commit f32fbda

Browse files
committed
-fix bug when looping over generator in TF 2.3.0
1 parent 60354cd commit f32fbda

File tree

4 files changed

+25
-46
lines changed

4 files changed

+25
-46
lines changed

CNN_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ class CNN_Encoder(tf.keras.Model):
99
# Since you have already extracted the features and dumped it using pickle
1010
# This encoder passes those features through a Fully connected layer
1111
def __init__(self, model_path, model_name, pop_conv_layers, encoder_layers, tags_threshold, tags_embeddings=None,
12-
finetune_visual_model=False):
12+
finetune_visual_model=False, num_tags=105):
1313
super(CNN_Encoder, self).__init__()
1414
# shape after fc == (batch_size, 64, embedding_dim)
1515
if tags_embeddings is not None:
1616
self.tags_embeddings = tf.Variable(shape=tags_embeddings.shape, initial_value=tags_embeddings,
1717
trainable=False, dtype=tf.float32)
1818
else:
19-
self.tags_embeddings = tf.Variable(shape=(105, 400), initial_value=tf.ones((105, 400)), trainable=False,
19+
self.tags_embeddings = tf.Variable(shape=(num_tags, 400), initial_value=tf.ones((num_tags, 400)), trainable=False,
2020
dtype=tf.float32)
2121
self.encoder_layers = get_layers(encoder_layers, 'relu')
2222
visual_model = load_model(model_path, model_name)

requirements.txt

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,9 @@
1-
absl-py==0.7.1
2-
astor==0.8.0
3-
boto==2.49.0
4-
boto3==1.9.214
5-
botocore==1.12.214
6-
certifi==2019.6.16
7-
chardet==3.0.4
8-
Click==7.0
9-
cycler==0.10.0
10-
decorator==4.4.0
111
docutils==0.15.2
122
efficientnet==1.0.0
13-
gast==0.2.2
14-
gensim==3.8.0
15-
google-pasta==0.1.7
16-
grpcio==1.23.0
17-
h5py==2.9.0
18-
idna==2.8
193
imageio==2.5.0
204
imgaug==0.3.0
215
jmespath==0.9.4
226
joblib==0.14.1
23-
Keras-Applications==1.0.8
24-
Keras-Preprocessing==1.1.0
257
kiwisolver==1.1.0
268
lxml==4.4.1
279
Markdown==3.1.1
@@ -31,31 +13,30 @@ nltk==3.4.5
3113
numpy==1.17.0
3214
opencv-python==4.1.0.25
3315
opencv-python-headless==4.1.2.30
34-
opt-einsum==3.1.0
3516
pandas==0.25.1
3617
Pillow==6.1.0
37-
protobuf==3.9.1
3818
psutil==5.6.7
3919
pyparsing==2.4.2
4020
python-dateutil==2.8.0
4121
python-docx==0.8.10
42-
pytz==2019.2
4322
PyWavelets==1.0.3
4423
requests==2.22.0
4524
s3transfer==0.2.1
4625
scikit-image==0.15.0
4726
scikit-learn==0.22.1
48-
scipy==1.3.1
49-
Shapely==1.6.4.post2
50-
six==1.12.0
27+
Shapely==1.7.1
5128
smart-open==1.8.4
52-
tensorflow-gpu==2.1.0
29+
tensorflow==2.3.0
5330
termcolor==1.1.0
54-
Theano==1.0.4
5531
tqdm==4.41.1
5632
urllib3==1.25.3
57-
Werkzeug==0.15.5
58-
wrapt==1.11.2
59-
xdg==4.0.1
33+
boto3==1.10.50
34+
botocore==1.13.50
35+
pymc3==3.11.0
36+
theano==1.0.4
37+
theano-pymc==1.1.0
6038
transformers==2.5.0
39+
torch==1.10.0
40+
torchaudio==0.10.0
41+
torchvision==0.11.1
6142
git+https://github.com/Maluuba/nlg-eval.git@master

test.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def save_output_prediction(FLAGS, img_name, target_sentence, predicted_sentence)
101101
plt.close(fig)
102102

103103

104-
def evaluate_enqueuer(enqueuer, FLAGS, encoder, decoder, tokenizer_wrapper, name='Test set', verbose=True,
104+
def evaluate_enqueuer(enqueuer, steps, FLAGS, encoder, decoder, tokenizer_wrapper, name='Test set', verbose=True,
105105
write_json=True, write_images=False, test_mode=False):
106106
tf.keras.backend.set_learning_phase(0)
107107
hypothesis = []
@@ -111,12 +111,10 @@ def evaluate_enqueuer(enqueuer, FLAGS, encoder, decoder, tokenizer_wrapper, name
111111
start = time.time()
112112
csv_dict = {"image_path": [], "real": [], "prediction": []}
113113
generator = enqueuer.get()
114-
for batch in tqdm(list(range(generator.steps))):
114+
for batch in tqdm(list(range(steps))):
115115
images, target, img_path = next(generator)
116-
117116
predicted_sentence = evaluate_full(FLAGS, encoder, decoder, tokenizer_wrapper,
118117
images)
119-
120118
csv_dict["prediction"].append(predicted_sentence)
121119
csv_dict["image_path"].append(os.path.basename(img_path[0]))
122120
target_sentence = tokenizer_wrapper.GPT2_decode(target[0])
@@ -158,7 +156,7 @@ def evaluate_enqueuer(enqueuer, FLAGS, encoder, decoder, tokenizer_wrapper, name
158156
test_enqueuer.start(workers=1, max_queue_size=8)
159157

160158
encoder = CNN_Encoder('pretrained_visual_model', FLAGS.visual_model_name, FLAGS.visual_model_pop_layers,
161-
FLAGS.encoder_layers, FLAGS.tags_threshold)
159+
FLAGS.encoder_layers, FLAGS.tags_threshold, num_tags=len(FLAGS.tags))
162160

163161
decoder = TFGPT2LMHeadModel.from_pretrained('distilgpt2', from_pt=True, resume_download=True)
164162

@@ -174,4 +172,4 @@ def evaluate_enqueuer(enqueuer, FLAGS, encoder, decoder, tokenizer_wrapper, name
174172
start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
175173
ckpt.restore(ckpt_manager.latest_checkpoint)
176174
print("Restored from checkpoint: {}".format(ckpt_manager.latest_checkpoint))
177-
evaluate_enqueuer(test_enqueuer, FLAGS, encoder, decoder, tokenizer_wrapper, write_images=True, test_mode=True)
175+
evaluate_enqueuer(test_enqueuer, test_steps, FLAGS, encoder, decoder, tokenizer_wrapper, write_images=True, test_mode=True)

train.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
FLAGS.max_sequence_length, FLAGS.tokenizer_vocab_size)
2626

2727
train_enqueuer, train_steps = get_enqueuer(FLAGS.train_csv, FLAGS.batch_size, FLAGS, tokenizer_wrapper)
28-
test_enqueuer, _ = get_enqueuer(FLAGS.test_csv, 1, FLAGS, tokenizer_wrapper)
29-
batch_test_enqueuer, _ = get_enqueuer(FLAGS.test_csv, FLAGS.batch_size, FLAGS, tokenizer_wrapper)
28+
test_enqueuer, test_steps = get_enqueuer(FLAGS.test_csv, 1, FLAGS, tokenizer_wrapper)
29+
batch_test_enqueuer, batch_test_steps = get_enqueuer(FLAGS.test_csv, FLAGS.batch_size, FLAGS, tokenizer_wrapper)
3030

3131
train_enqueuer.start(workers=FLAGS.generator_workers, max_queue_size=FLAGS.generator_queue_length)
3232

@@ -42,7 +42,7 @@
4242

4343
encoder = CNN_Encoder('pretrained_visual_model', FLAGS.visual_model_name, FLAGS.visual_model_pop_layers,
4444
FLAGS.encoder_layers,
45-
FLAGS.tags_threshold, tags_embeddings, FLAGS.finetune_visual_model)
45+
FLAGS.tags_threshold, tags_embeddings, FLAGS.finetune_visual_model, len(FLAGS.tags))
4646
decoder = TFGPT2LMHeadModel.from_pretrained('distilgpt2', from_pt=True, resume_download=True)
4747
optimizer = get_optimizer(FLAGS.optimizer_type, FLAGS.learning_rate)
4848

@@ -126,7 +126,7 @@ def get_avg_score(scores_dict):
126126
time_csv = {"epoch": [], 'time_taken': [], "scores": []}
127127

128128

129-
def get_overall_loss(enqueuer, batch_losses_csv):
129+
def get_overall_loss(enqueuer, steps, batch_losses_csv):
130130
tf.keras.backend.set_learning_phase(0)
131131

132132
if not enqueuer.is_running():
@@ -136,7 +136,7 @@ def get_overall_loss(enqueuer, batch_losses_csv):
136136
batch_losses = []
137137
total_loss = 0
138138
step = 0
139-
for batch in range(generator.steps):
139+
for batch in range(steps):
140140
img, target, _ = next(generator)
141141
batch_loss = train_step(img, target, True)
142142
batch_losses_csv['step'].append(step)
@@ -186,8 +186,8 @@ def get_overall_loss(enqueuer, batch_losses_csv):
186186
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
187187
print('Batches that took long: {}'.format(times_to_get_batch))
188188
if FLAGS.calculate_loss_after_epoch:
189-
test_epoch_loss, _ = get_overall_loss(batch_test_enqueuer, test_batch_losses_csv)
190-
train_epoch_loss, _ = get_overall_loss(train_enqueuer, train_after_batch_losses_csv)
189+
test_epoch_loss, _ = get_overall_loss(batch_test_enqueuer, batch_test_steps, test_batch_losses_csv)
190+
train_epoch_loss, _ = get_overall_loss(train_enqueuer, train_steps, train_after_batch_losses_csv)
191191
losses_csv['train_after_loss'].append(train_epoch_loss.numpy())
192192
losses_csv['test_loss'].append(test_epoch_loss.numpy())
193193
else:
@@ -209,11 +209,11 @@ def get_overall_loss(enqueuer, batch_losses_csv):
209209
plt.title('Loss Plot')
210210
plt.savefig(FLAGS.ckpt_path + "/loss.png")
211211

212-
if epoch % FLAGS.epochs_to_evaluate == 0:
212+
if epoch % FLAGS.epochs_to_evaluate == 0 and epoch > 0:
213213
current_avg_score = 0
214214
print("Evaluating on test set..")
215215
train_enqueuer.stop()
216-
current_scores = evaluate_enqueuer(test_enqueuer, FLAGS, encoder, decoder, tokenizer_wrapper)
216+
current_scores = evaluate_enqueuer(test_enqueuer, test_steps, FLAGS, encoder, decoder, tokenizer_wrapper)
217217
time_csv['epoch'].append(epoch + 1)
218218
time_csv['time_taken'].append(pure_training_time)
219219
time_csv['scores'].append(current_scores)

0 commit comments

Comments
 (0)