-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_generator_vgg_face_to_onnx_truncated.py
119 lines (95 loc) · 3.81 KB
/
model_generator_vgg_face_to_onnx_truncated.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
**2060.io**
Open-source tools that enable Institutions building Verifiable Credentials based services
"""
"""Load the dependencies"""
import tensorflow as tf
tf_version = int(tf.__version__.split(".", maxsplit=1)[0])
if tf_version == 1:
from keras.models import Model, Sequential
from keras.layers import (
Convolution2D,
ZeroPadding2D,
MaxPooling2D,
Flatten,
Dropout,
Activation,
)
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
Convolution2D,
ZeroPadding2D,
MaxPooling2D,
Flatten,
Dropout,
Activation,
)
"""Define the architecture of the model (.h5 file does not include the architecture)
**Note**: The VGG-Face has a last layer with a SoftMax, so we have to remove it.
"""
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Flatten, Activation
def base_model() -> Sequential:
    """
    Build the VGG-Face classification network (architecture only, no weights).

    Layout: five convolutional stages, each a stack of zero-padded 3x3 ReLU
    convolutions followed by a 2x2 max-pool. These are followed by three
    convolutions that play the role of fully connected layers, then a flatten
    and a softmax over 2622 classes. The weights are loaded separately by the
    caller. This network classifies identities; it does not produce embeddings
    by itself.

    Returns:
        model (Sequential): network whose 2622-way softmax output scores the
            identities VGG-Face was trained on
    """
    # (filters, number of padded 3x3 convolutions) for each stage, in order.
    conv_stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    net = Sequential()
    for stage_index, (filters, depth) in enumerate(conv_stages):
        for conv_index in range(depth):
            if stage_index == 0 and conv_index == 0:
                # Only the very first layer declares the network input shape.
                padding = ZeroPadding2D((1, 1), input_shape=(224, 224, 3))
            else:
                padding = ZeroPadding2D((1, 1))
            net.add(padding)
            net.add(Convolution2D(filters, (3, 3), activation="relu"))
        net.add(MaxPooling2D((2, 2), strides=(2, 2)))

    # "Fully connected" head: a 7x7 convolution collapses the 7x7 feature map
    # (224 / 2**5), then a 1x1 convolution; each is followed by dropout.
    for kernel in ((7, 7), (1, 1)):
        net.add(Convolution2D(4096, kernel, activation="relu"))
        net.add(Dropout(0.5))

    # Classification layer: 2622 identity scores, flattened, then softmax.
    net.add(Convolution2D(2622, (1, 1)))
    net.add(Flatten())
    net.add(Activation("softmax"))
    return net
# Create the model and load the pretrained weights.
# NOTE(review): the weights path is relative, so it is resolved against the
# current working directory.
model = base_model()
model.load_weights('vgg_face_weights.h5')

# Truncated "descriptor" model. The last layer of base_model() is the softmax
# Activation, so layers[-2] is the Flatten layer that holds the flattened output
# of the final 2622-filter convolution. Exposing that output drops the softmax.
# `model.input` is the public handle on the network input; it replaces reaching
# into `model.layers[0].input`.
vgg_face_descriptor = Model(inputs=model.input, outputs=model.layers[-2].output)

# Export both the full classifier and the truncated descriptor in TensorFlow
# SavedModel format. The truncated export is the one converted to ONNX.
tf.saved_model.save(model, 'saved_vgg_face_model')
tf.saved_model.save(vgg_face_descriptor, 'saved_vgg_face_model_truncated')
"""
Then:
- Install Tensorflow to ONNX package
pip install -U tf2onnx
- Convert the model to ONNX
python -m tf2onnx.convert --saved-model /content/saved_vgg_face_model_truncated --output /content/saved_vgg_face_model_truncated.onnx --opset 13
Now the model (saved_vgg_face_model_truncated.onnx) shoud be in the files.
"""