Skip to content

Commit 60cfcb1

Browse files
authored
Merge pull request #465 from tensorlayer/models-mobilenetv1
tl.models.MobileNetV1
2 parents c2811ce + 1d05d45 commit 60cfcb1

File tree

9 files changed

+245
-16
lines changed

9 files changed

+245
-16
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Examples can be found [in this folder](https://github.com/zsdonghao/tensorlayer/
8585
- VGG 19 (ImageNet). Classification task, see [tutorial_vgg19.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg19.py).
8686
- InceptionV3 (ImageNet). Classification task, see [tutorial\_inceptionV3_tfslim.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_inceptionV3_tfslim.py).
8787
- SqueezeNet (ImageNet). Model compression, see [tl.models.SqueezeNetV1](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_models_squeezenetv1.py) or [tutorial_squeezenet.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_squeezenet.py)
88-
- MobileNet (ImageNet). Model compression, see [tutorial_mobilenet.py](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mobilenet.py).
88+
- MobileNet (ImageNet). Model compression, see [tl.models.MobileNetV1](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_models_mobilenetv1.py) or [tutorial_mobilenet.py](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mobilenet.py).
8989
- BinaryNet. Model compression, see [mnist](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_binarynet_mnist_cnn.py) [cifar10](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_binarynet_cifar10_tfrecord.py).
9090
- Ternary Weight Network. Model compression, see [mnist](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_ternaryweight_mnist_cnn.py) [cifar10](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_ternaryweight_cifar10_tfrecord.py).
9191
- DoReFa-Net. Model compression, see [mnist](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_dorefanet_mnist_cnn.py) [cifar10](https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_dorefanet_cifar10_tfrecord.py).

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ method, this part of the documentation is for you.
5252
modules/files
5353
modules/visualize
5454
modules/activation
55+
modules/models
5556
modules/distributed
5657

5758

docs/modules/models.rst

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
API - Models
2+
================================
3+
4+
TensorLayer provides many pretrained models, you can easily use the whole or a part of the pretrained models via these APIs.
5+
6+
.. automodule:: tensorlayer.models
7+
8+
.. autosummary::
9+
10+
VGG16
11+
SqueezeNetV1
12+
MobileNetV1
13+
14+
VGG16
15+
----------------------
16+
17+
.. autoclass:: VGG16
18+
19+
SqueezeNetV1
20+
----------------
21+
.. autoclass:: SqueezeNetV1
22+
23+
MobileNetV1
24+
----------------
25+
26+
.. autoclass:: MobileNetV1

docs/user/example.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Computer Vision
2222
- VGG 19 (ImageNet). Classification task, see `tutorial_vgg19.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg19.py>`_.
2323
- InceptionV3 (ImageNet). Classification task, see `tutorial_inceptionV3_tfslim.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_inceptionV3_tfslim.py>`_.
2424
- SqueezeNet (ImageNet). Model compression, see `tl.models.SqueezeNetV1 <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_models_squeezenetv1.py>`__ or `tutorial_squeezenet.py <https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_squeezenet.py>`_.
25-
- MobileNet (ImageNet). Model compression, see `tutorial_mobilenet.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mobilenet.py>`__.
25+
- MobileNet (ImageNet). Model compression, see `tl.models.MobileNetV1 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_models_mobilenetv1.py>`__ or `tutorial_mobilenet.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mobilenet.py>`__.
2626
- BinaryNet. Model compression, see `mnist <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_binarynet_mnist_cnn.py>`__ `cifar10 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_binarynet_cifar10_tfrecord.py>`__.
2727
- Ternary Weight Network. Model compression, see `mnist <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_ternaryweight_mnist_cnn.py>`__ `cifar10 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_ternaryweight_cifar10_tfrecord.py>`__.
2828
- DoReFa-Net. Model compression, see `mnist <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_dorefanet_mnist_cnn.py>`__ `cifar10 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_dorefanet_cifar10_tfrecord.py>`__.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
"""
4+
MobileNetV1 for ImageNet using TL models
5+
6+
- mobilenetv2 : https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet
7+
- tf.slim : https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models
8+
"""
9+
10+
import time
11+
import numpy as np
12+
import tensorflow as tf
13+
import tensorlayer as tl
14+
from tensorlayer.models.imagenet_classes import class_names
15+
16+
x = tf.placeholder(tf.float32, [None, 224, 224, 3])
17+
18+
# get the whole model
19+
mobilenetv1 = tl.models.MobileNetV1(x)
20+
21+
# restore pre-trained parameters
22+
sess = tf.InteractiveSession()
23+
24+
mobilenetv1.restore_params(sess)
25+
26+
probs = tf.nn.softmax(mobilenetv1.outputs)
27+
28+
mobilenetv1.print_params(False)
29+
30+
mobilenetv1.print_layers()
31+
32+
img1 = tl.vis.read_image('data/tiger.jpeg')
33+
img1 = tl.prepro.imresize(img1, (224, 224)) / 255
34+
35+
_ = sess.run(probs, feed_dict={x: [img1]})[0] # 1st time takes time to compile
36+
start_time = time.time()
37+
prob = sess.run(probs, feed_dict={x: [img1]})[0]
38+
print(" End time : %.5ss" % (time.time() - start_time))
39+
preds = (np.argsort(prob)[::-1])[0:5]
40+
for p in preds:
41+
print(class_names[p], prob[p])

tensorlayer/layers/shape.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,13 @@ class FlattenLayer(Layer):
3737
def __init__(
3838
self,
3939
prev_layer,
40-
name='flatten_layer',
40+
name='flatten',
4141
):
4242
Layer.__init__(self, prev_layer=prev_layer, name=name)
4343
self.inputs = prev_layer.outputs
4444
self.outputs = flatten_reshape(self.inputs, name=name)
4545
self.n_units = int(self.outputs.get_shape()[-1])
4646
logging.info("FlattenLayer %s: %d" % (self.name, self.n_units))
47-
# self.all_layers = list(layer.all_layers)
48-
# self.all_params = list(layer.all_params)
49-
# self.all_drop = dict(layer.all_drop)
5047
self.all_layers.append(self.outputs)
5148

5249

@@ -76,15 +73,12 @@ def __init__(
7673
self,
7774
prev_layer,
7875
shape,
79-
name='reshape_layer',
76+
name='reshape',
8077
):
8178
Layer.__init__(self, prev_layer=prev_layer, name=name)
8279
self.inputs = prev_layer.outputs
8380
self.outputs = tf.reshape(self.inputs, shape=shape, name=name)
8481
logging.info("ReshapeLayer %s: %s" % (self.name, self.outputs.get_shape()))
85-
# self.all_layers = list(layer.all_layers)
86-
# self.all_params = list(layer.all_params)
87-
# self.all_drop = dict(layer.all_drop)
8882
self.all_layers.append(self.outputs)
8983

9084

@@ -122,10 +116,5 @@ def __init__(
122116
assert perm is not None
123117

124118
logging.info("TransposeLayer %s: perm:%s" % (self.name, perm))
125-
# with tf.variable_scope(name) as vs:
126119
self.outputs = tf.transpose(self.inputs, perm=perm, name=name)
127-
# self.all_layers = list(layer.all_layers)
128-
# self.all_params = list(layer.all_params)
129-
# self.all_drop = dict(layer.all_drop)
130120
self.all_layers.append(self.outputs)
131-
# self.all_params.extend( variables )

tensorlayer/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
from .vgg16 import VGG16
44
from .squeezenetv1 import SqueezeNetV1
5+
from .mobilenetv1 import MobileNetV1

tensorlayer/models/mobilenetv1.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
MobileNet for ImageNet.
4+
"""
5+
6+
import os
7+
# import numpy as np
8+
import tensorflow as tf
9+
from .. import _logging as logging
10+
from ..layers import (Layer, BatchNormLayer, Conv2d, DepthwiseConv2d, FlattenLayer, GlobalMeanPool2d, InputLayer, ReshapeLayer)
11+
from ..files import maybe_download_and_extract, assign_params, load_npz
12+
13+
__all__ = [
14+
'MobileNetV1',
15+
]
16+
17+
18+
class MobileNetV1(Layer):
19+
"""Pre-trained MobileNetV1 model.
20+
21+
Parameters
22+
------------
23+
x : placeholder
24+
shape [None, 224, 224, 3], value range [0, 1].
25+
end_with : str
26+
The end point of the model [conv, depth1, depth2 ... depth13, globalmeanpool, out]. Default ``out`` i.e. the whole model.
27+
is_train : boolean
28+
Whether the model is used for training i.e. enable dropout.
29+
reuse : boolean
30+
Whether to reuse the model.
31+
32+
Examples
33+
---------
34+
Classify ImageNet classes, see `tutorial_models_mobilenetv1.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_models_mobilenetv1.py>__`
35+
>>> x = tf.placeholder(tf.float32, [None, 224, 224, 3])
36+
>>> # get the whole model
37+
>>> net = tl.models.MobileNetV1(x)
38+
>>> # restore pre-trained parameters
39+
>>> sess = tf.InteractiveSession()
40+
>>> net.restore_params(sess)
41+
>>> # use for inferencing
42+
>>> probs = tf.nn.softmax(net.outputs)
43+
44+
Extract features and Train a classifier with 100 classes
45+
>>> x = tf.placeholder(tf.float32, [None, 224, 224, 3])
46+
>>> # get model without the last layer
47+
>>> cnn = tl.models.MobileNetV1(x, end_with='reshape')
48+
>>> # add one more layer
49+
>>> net = Conv2d(cnn, 100, (1, 1), (1, 1), name='out')
50+
>>> net = FlattenLayer(net, name='flatten')
51+
>>> # initialize all parameters
52+
>>> sess = tf.InteractiveSession()
53+
>>> tl.layers.initialize_global_variables(sess)
54+
>>> # restore pre-trained parameters
55+
>>> cnn.restore_params(sess)
56+
>>> # train your own classifier (only update the last layer)
57+
>>> train_params = tl.layers.get_variables_with_name('output')
58+
59+
Reuse model
60+
>>> x1 = tf.placeholder(tf.float32, [None, 224, 224, 3])
61+
>>> x2 = tf.placeholder(tf.float32, [None, 224, 224, 3])
62+
>>> # get VGG without the last layer
63+
>>> net1 = tl.models.MobileNetV1(x1, end_with='reshape')
64+
>>> # reuse the parameters with different input
65+
>>> net2 = tl.models.MobileNetV1(x2, end_with='reshape', reuse=True)
66+
>>> # restore pre-trained parameters (as they share parameters, we don’t need to restore net2)
67+
>>> sess = tf.InteractiveSession()
68+
>>> net1.restore_params(sess)
69+
70+
"""
71+
72+
def __init__(self, x, end_with='out', is_train=False, reuse=None):
73+
74+
self.net = self.mobilenetv1(x, end_with, is_train, reuse)
75+
self.outputs = self.net.outputs
76+
self.all_params = self.net.all_params
77+
self.all_layers = self.net.all_layers
78+
self.all_drop = self.net.all_drop
79+
self.print_layers = self.net.print_layers
80+
self.print_params = self.net.print_params
81+
82+
# @classmethod
83+
def mobilenetv1(self, x, end_with='out', is_train=False, reuse=None):
84+
with tf.variable_scope("mobilenetv1", reuse=reuse):
85+
n = InputLayer(x)
86+
n = self.conv_block(n, 32, strides=(2, 2), is_train=is_train, name="conv")
87+
if end_with in n.outputs.name:
88+
return n
89+
n = self.depthwise_conv_block(n, 64, is_train=is_train, name="depth1")
90+
if end_with in n.outputs.name:
91+
return n
92+
93+
n = self.depthwise_conv_block(n, 128, strides=(2, 2), is_train=is_train, name="depth2")
94+
if end_with in n.outputs.name:
95+
return n
96+
n = self.depthwise_conv_block(n, 128, is_train=is_train, name="depth3")
97+
if end_with in n.outputs.name:
98+
return n
99+
100+
n = self.depthwise_conv_block(n, 256, strides=(2, 2), is_train=is_train, name="depth4")
101+
if end_with in n.outputs.name:
102+
return n
103+
n = self.depthwise_conv_block(n, 256, is_train=is_train, name="depth5")
104+
if end_with in n.outputs.name:
105+
return n
106+
107+
n = self.depthwise_conv_block(n, 512, strides=(2, 2), is_train=is_train, name="depth6")
108+
if end_with in n.outputs.name:
109+
return n
110+
n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth7")
111+
if end_with in n.outputs.name:
112+
return n
113+
n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth8")
114+
if end_with in n.outputs.name:
115+
return n
116+
n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth9")
117+
if end_with in n.outputs.name:
118+
return n
119+
n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth10")
120+
if end_with in n.outputs.name:
121+
return n
122+
n = self.depthwise_conv_block(n, 512, is_train=is_train, name="depth11")
123+
if end_with in n.outputs.name:
124+
return n
125+
126+
n = self.depthwise_conv_block(n, 1024, strides=(2, 2), is_train=is_train, name="depth12")
127+
if end_with in n.outputs.name:
128+
return n
129+
n = self.depthwise_conv_block(n, 1024, is_train=is_train, name="depth13")
130+
if end_with in n.outputs.name:
131+
return n
132+
133+
n = GlobalMeanPool2d(n, name='globalmeanpool')
134+
if end_with in n.outputs.name:
135+
return n
136+
# n = DropoutLayer(n, 1-1e-3, True, is_train, name='drop')
137+
# n = DenseLayer(n, 1000, act=tf.identity, name='output') # equal
138+
n = ReshapeLayer(n, [-1, 1, 1, 1024], name='reshape')
139+
if end_with in n.outputs.name:
140+
return n
141+
n = Conv2d(n, 1000, (1, 1), (1, 1), name='out')
142+
n = FlattenLayer(n, name='flatten')
143+
if end_with == 'out':
144+
return n
145+
146+
raise Exception("end_with : conv, depth1, depth2 ... depth13, globalmeanpool, out")
147+
148+
@classmethod
149+
def conv_block(cls, n, n_filter, filter_size=(3, 3), strides=(1, 1), is_train=False, name='conv_block'):
150+
# ref: https://github.com/keras-team/keras/blob/master/keras/applications/mobilenet.py
151+
with tf.variable_scope(name):
152+
n = Conv2d(n, n_filter, filter_size, strides, b_init=None, name='conv')
153+
n = BatchNormLayer(n, act=tf.nn.relu6, is_train=is_train, name='batchnorm')
154+
return n
155+
156+
@classmethod
157+
def depthwise_conv_block(cls, n, n_filter, strides=(1, 1), is_train=False, name="depth_block"):
158+
with tf.variable_scope(name):
159+
n = DepthwiseConv2d(n, (3, 3), strides, b_init=None, name='depthwise')
160+
n = BatchNormLayer(n, act=tf.nn.relu6, is_train=is_train, name='batchnorm1')
161+
n = Conv2d(n, n_filter, (1, 1), (1, 1), b_init=None, name='conv')
162+
n = BatchNormLayer(n, act=tf.nn.relu6, is_train=is_train, name='batchnorm2')
163+
return n
164+
165+
def restore_params(self, sess, path='models'):
166+
logging.info("Restore pre-trained parameters")
167+
maybe_download_and_extract(
168+
'mobilenet.npz', path, 'https://github.com/tensorlayer/pretrained-models/raw/master/models/', expected_bytes=25600116) # ls -al
169+
params = load_npz(name=os.path.join(path, 'mobilenet.npz'))
170+
assign_params(sess, params[:len(self.net.all_params)], self.net)
171+
del params

tensorlayer/models/squeezenetv1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SqueezeNetV1(Layer):
3131
3232
Examples
3333
---------
34-
Classify ImageNet classes, see `tutorial_models_vgg16.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_models_vgg16.py>__`
34+
Classify ImageNet classes, see `tutorial_models_squeezenetv1.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_models_squeezenetv1.py>__`
3535
>>> x = tf.placeholder(tf.float32, [None, 224, 224, 3])
3636
>>> # get the whole model
3737
>>> net = tl.models.SqueezeNetV1(x)

0 commit comments

Comments
 (0)