Skip to content

Commit 29e3668

Browse files
committed
-handle batch size issues
1 parent 1457e96 commit 29e3668

File tree

3 files changed

+14
-28
lines changed

3 files changed

+14
-28
lines changed

generator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def __getitem__(self, idx):
5151
batch_x = np.asarray([self.load_image(x_path) for x_path in batch_x_path])
5252
batch_x = self.transform_batch_images(batch_x)
5353
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
54+
if batch_x.shape[0] < self.batch_size:
55+
remaining = self.batch_size - batch_x.shape[0]
56+
batch_x = np.concatenate((batch_x,batch_x[0:remaining]), axis = 0)
57+
batch_x_path = np.concatenate((batch_x_path,batch_x_path[0:remaining]), axis = 0)
58+
batch_y = np.concatenate((batch_y,batch_y[0:remaining]), axis = 0)
5459
return batch_x, batch_y, batch_x_path
5560

5661
def load_image(self, image_file):
@@ -83,7 +88,7 @@ def get_y_true(self):
8388

8489
def prepare_dataset(self):
8590
df = self.dataset_df.sample(frac=1., random_state=self.random_state)
86-
self.x_path, self.y = df["Image Index"].as_matrix(), self.tokenizer_wrapper.tokenize_sentences(df[self.class_names].as_matrix())
91+
self.x_path, self.y = df["Image Index"].values, self.tokenizer_wrapper.tokenize_sentences(df[self.class_names].values)
8792

8893
def on_epoch_end(self):
8994
if self.shuffle:

requirements.txt

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,9 @@
1-
absl-py==0.7.1
2-
astor==0.8.0
3-
boto==2.49.0
4-
boto3==1.9.214
5-
botocore==1.12.214
6-
certifi==2019.6.16
7-
chardet==3.0.4
8-
Click==7.0
9-
cycler==0.10.0
10-
decorator==4.4.0
111
docutils==0.15.2
122
efficientnet==1.0.0
13-
gast==0.2.2
14-
google-pasta==0.1.7
15-
h5py==2.9.0
16-
idna==2.8
173
imageio==2.5.0
184
imgaug==0.3.0
195
jmespath==0.9.4
206
joblib==0.14.1
21-
Keras-Applications==1.0.8
22-
Keras-Preprocessing==1.1.0
237
kiwisolver==1.1.0
248
lxml==4.4.1
259
Markdown==3.1.1
@@ -29,29 +13,26 @@ nltk==3.4.5
2913
numpy==1.17.0
3014
opencv-python==4.1.0.25
3115
opencv-python-headless==4.1.2.30
32-
opt-einsum==3.1.0
3316
pandas==0.25.1
3417
Pillow==6.1.0
35-
protobuf==3.9.1
3618
psutil==5.6.7
3719
pyparsing==2.4.2
3820
python-dateutil==2.8.0
3921
python-docx==0.8.10
40-
pytz==2019.2
4122
PyWavelets==1.0.3
4223
requests==2.22.0
4324
s3transfer==0.2.1
4425
scikit-image==0.15.0
4526
scikit-learn==0.22.1
46-
Shapely==1.6.4.post2
47-
six==1.12.0
27+
Shapely==1.7.1
4828
smart-open==1.8.4
49-
tensorflow==2.1.0
29+
tensorflow==2.3.0
5030
termcolor==1.1.0
51-
Theano==1.0.4
5231
tqdm==4.41.1
5332
urllib3==1.25.3
54-
Werkzeug==0.15.5
55-
wrapt==1.11.2
56-
xdg==4.0.1
33+
boto3==1.10.50
34+
botocore==1.13.50
35+
pymc3==3.11.0
36+
theano==1.0.4
37+
theano-pymc==1.1.0
5738
git+https://github.com/Maluuba/nlg-eval.git@master

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def train_step(tag_predictions, visual_features, target):
145145
print('Batches that took long: {}'.format(times_to_get_batch))
146146

147147
ckpt_manager.save()
148-
if epoch % FLAGS.epochs_to_evaluate == 0:
148+
if epoch % FLAGS.epochs_to_evaluate == 0 and epoch > 0:
149149
print("Evaluating on test set..")
150150
train_enqueuer.stop()
151151
current_scores = evaluate_enqueuer(test_enqueuer, test_steps, FLAGS, encoder, decoder, tokenizer_wrapper,

0 commit comments

Comments
 (0)