-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy path train.py
More file actions
112 lines (90 loc) · 4.6 KB
/
train.py
File metadata and controls
112 lines (90 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Here, we write the code to train the model
import argparse
import json
import logging
import cv2
import keras
import numpy as np
from data.database import Database
from lib.decorator import GeneratorLoop
from lib.img_sim import compute_ssim
from lib.utils import chunks, CSVLogger
from src.CRNN import CRNN
from src.c3d import C3DModel
from src.unet import UNETModel
from src.vae import VAE
from src.vgg3d import VGG3DModel
# Command-line interface for the training script.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--db_path", dest="db_path", default="../dataset", type=str, help="dataset path")
parser.add_argument("--weight_file", dest="weight_file", type=str,
                    help="model weight to be loaded, blank if new model")
parser.add_argument("--sequence_size", dest="sequence_size", default=10, type=int,
                    help="length of each input sequence")
parser.add_argument("--batch_size", dest="batch_size", default=1, type=int, help="batch size")
# "--custom_length" is a correctly-spelled alias. The dest keeps the original
# (misspelled) name because the rest of the script reads options.custom_lenght.
parser.add_argument("--custom_lenght", "--custom_length", dest="custom_lenght", default=None, type=int,
                    help="max video to look at")
parser.add_argument("--n_epochs", dest="n_epochs", default=10, type=int, help="nb epochs")
parser.add_argument("--method", dest="method", default="c3d", type=str, help="[c3d,vgg,crnn,vae,unet]")
# NOTE(review): this global is never read; the models and Database are built
# with options.batch_size instead. Kept only to avoid changing module globals.
batch_size = 1
options = parser.parse_args()
print(vars(options))
# Record the run configuration both on stdout (above) and in logging.log.
logging.basicConfig(filename='logging.log', level=logging.DEBUG,
                    format='%(asctime)s -- %(name)s -- %(levelname)s -- %(message)s')
logging.info(vars(options))
# Map each accepted --method value to the model class that implements it.
model_classes = {
    "c3d": C3DModel,
    "crnn": CRNN,
    "vae": VAE,
    "unet": UNETModel,
    "vgg": VGG3DModel,
}
methods = list(model_classes)
if options.method not in model_classes:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Exit non-zero so shells and schedulers see the failure.
    print("{} is not available at this moment; valid methods: {}".format(
        options.method, ", ".join(methods)))
    raise SystemExit(1)
# All model classes share the same constructor signature.
model = model_classes[options.method](options.sequence_size,
                                      batch_size=options.batch_size,
                                      weight_file=options.weight_file)
# Dataset wrapper, sized to the selected model's input (img_size) and
# output (output_size) resolutions.
db = Database(
    options.db_path,
    options.sequence_size,
    batch_size=options.batch_size,
    size=model.img_size,
    output_size=model.output_size,
    custom_lenght=options.custom_lenght,
)
# Epoch bookkeeping. max_epoch is handed to training; n_epoch is not read
# elsewhere in the visible code.
n_epoch, max_epoch = 0, options.n_epochs
@GeneratorLoop
def get_generator():
    """Yield one preprocessed training sample at a time.

    For every ``(imgs, gt)`` pair from ``db.get_datas()``, the loaded images are
    wrapped in a leading axis of length 1 (``np.asarray([...])``). The ground
    truth from ``db.get_groundtruth(gt, 255.0)`` is passed without such an axis.
    Both are then handed to ``model.preprocess``, and its result is yielded.
    The ``GeneratorLoop`` wrapper's semantics live in ``lib.decorator``
    (presumably it restarts the generator when exhausted -- confirm there).
    """
    for img_ref, gt_ref in db.get_datas():
        frames = db.load_imgs(img_ref)
        target = db.get_groundtruth(gt_ref, 255.0)
        yield model.preprocess(np.asarray([frames]), target)
def _preprocess_batch(batch):
    """Load, stack and preprocess one chunk of ``(imgs, gt)`` pairs.

    Each image entry is loaded with ``db.load_imgs``, and each ground truth with
    ``db.get_groundtruth(gt, 255.0)``. Both lists are stacked with
    ``np.asarray`` and passed to ``model.preprocess``, whose result is returned.
    """
    imgs, gts = zip(*batch)
    inputs = np.asarray([db.load_imgs(img) for img in imgs])
    targets = np.asarray([db.get_groundtruth(gt, 255.0) for gt in gts])
    return model.preprocess(inputs, targets)


@GeneratorLoop
def get_generator_batched():
    """Yield preprocessed training batches of ``options.batch_size`` samples.

    Batches are built by ``chunks`` over ``db.get_datas()``.
    """
    for batch in chunks(db.get_datas(), options.batch_size):
        yield _preprocess_batch(batch)


@GeneratorLoop
def get_validation_generator_batched():
    """Yield preprocessed test batches of ``options.batch_size`` samples.

    Batches are built by ``chunks`` over ``db.get_tests()``.
    """
    for batch in chunks(db.get_tests(), options.batch_size):
        yield _preprocess_batch(batch)
def save_one(output_path="output.png", gt_path="gt.png"):
    """Predict one training sample, write prediction and ground truth as images.

    Takes the next sample from ``get_generator()`` and keeps the first prediction
    from ``model.get_model().predict``. Prediction and ground truth are each
    reshaped to ``output_size x output_size`` and multiplied by 255 (presumably
    the model outputs values in [0, 1] -- TODO confirm).

    Args:
        output_path: image file the predicted map is written to.
        gt_path: image file the ground-truth map is written to.

    Returns:
        ``abs(compute_ssim(output, gt, 255))`` computed on the float (unquantized)
        255-scaled maps.
    """
    def to_uint8(arr):
        # OpenCV documents PNG output as 8/16-bit unsigned only. Round and
        # saturate explicitly instead of relying on imwrite's implicit fallback.
        return np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    imgs, gts = next(get_generator())
    prediction = model.get_model().predict(imgs)[0]
    side = model.output_size
    output = 255.0 * np.reshape(prediction, [side, side])
    gt = 255. * np.reshape(gts, [side, side])
    cv2.imwrite(output_path, to_uint8(output))
    cv2.imwrite(gt_path, to_uint8(gt))
    return abs(compute_ssim(output, gt, 255))
# Train the model.
# - ModelCheckpoint writes to "mod_<method>_<sequence_size><custom_lenght>.model".
#   There is no separator between the last two fields; the name is kept as-is in
#   case other scripts load it.
# - CSVLogger (from lib.utils) is pointed at log.csv in append mode.
# - samples_per_epoch / nb_epoch are the Keras 1 argument names.
checkpoint_path = "mod_{}_{}{}.model".format(options.method, options.sequence_size,
                                            options.custom_lenght)
try:
    model.get_model().fit_generator(generator=get_generator_batched(),
                                    samples_per_epoch=db.get_total_count(),
                                    nb_epoch=max_epoch,
                                    callbacks=[keras.callbacks.ModelCheckpoint(checkpoint_path),
                                               CSVLogger("log.csv", append=True)])
except Exception:
    # Deliberately best-effort: a training failure still falls through to the
    # testing/saving steps that follow. Record the traceback instead of
    # discarding it.
    logging.exception("Model stopped training!")
# Persist the final weights first, so a failure during evaluation or the
# image dump below cannot lose the trained model.
model.get_model().save_weights("{}_w.h5".format(model.name))
logging.info("Starting Testing")
history = model.get_model().evaluate_generator(get_validation_generator_batched(),
                                               db.get_total_test_count())
logging.info(history)
with open("history.log", "w") as f:
    # evaluate_generator may return NumPy scalars or a list of them; json cannot
    # serialise some NumPy scalar types (e.g. float32). tolist() converts to
    # plain Python floats/lists.
    json.dump(np.asarray(history).tolist(), f)
ssms = save_one()
logging.info("SSMS {}".format(ssms))