Skip to content

Commit

Permalink
Refactor, Delete unnecessary params
Browse files Browse the repository at this point in the history
  • Loading branch information
Halim, Calvin Janitra | Calvin | CTMO authored and magnusbarata committed Sep 29, 2021
1 parent 8f10193 commit 4bfa518
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions models/efficientnet.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
from tensorflow import keras
import efficientnet_3D.tfkeras as efn

def efficientnet(input_shape, n_class=2, n_filters=64, **kwargs):
def efficientnet(input_shape, n_class=2, variant='B0', **kwargs):
"""EfficientNet model from the paper https://arxiv.org/abs/1905.11946"""
default_effnet_params = {
'include_top': True,
'weights': None,
'input_shape': input_shape,
'pooling': None,
'classes': n_class,
}
if len(input_shape) == 4:
effnet_layer = efn.EfficientNetB0(
include_top=True,
weights=None,
input_shape=input_shape,
pooling=None,
classes=2,
)
effnet_layer = getattr(efn, f'EfficientNet{variant}')(**default_effnet_params)
modal = '3D'
elif len(input_shape) == 3:
effnet_layer = keras.applications.EfficientNetB0(
include_top=True,
weights=None,
input_shape=input_shape,
pooling=None,
classes=2,
classifier_activation='softmax'
effnet_layer = getattr(keras.applications, f'EfficientNet{variant}')(
classifier_activation='softmax', **default_effnet_params
)
modal = '2D'
else:
raise ValueError('input_shape is expected as an array ranked 3 or 4')

inputs = keras.Input(input_shape)

outputs = effnet_layer(inputs)

if modal == '3D':
outputs = keras.layers.Softmax()(outputs)

return keras.Model(inputs, outputs, name=f'efficientnet_{modal}')
return keras.Model(inputs, outputs, name=f'efficientnet_{modal}_{variant}')

0 comments on commit 4bfa518

Please sign in to comment.