generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 150
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added tests for mobilnetv3 timm backbone
- Loading branch information
1 parent
ee3c8e1
commit 20f7fd5
Showing
1 changed file
with
55 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import pytest | ||
from icevision.all import * | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"ds, model_type", | ||
[ | ||
( | ||
"fridge_ds", | ||
models.mmdet.retinanet, | ||
), | ||
], | ||
) | ||
class TestTimmBackbones: | ||
def dls_model(self, ds, model_type, samples_source, request): | ||
train_ds, valid_ds = request.getfixturevalue(ds) | ||
train_dl = model_type.train_dl(train_ds, batch_size=2) | ||
valid_dl = model_type.valid_dl(valid_ds, batch_size=2) | ||
|
||
# backbone = model_type.backbones.mmdet.resnet50_fpn_1x() | ||
backbone = model_type.backbones.timm.mobilenet.mobilenetv3_large_100 | ||
backbone.config_path = samples_source / backbone.config_path | ||
|
||
model = model_type.model(backbone=backbone, num_classes=5) | ||
|
||
return train_dl, valid_dl, model | ||
|
||
def test_mmdet_bbox_models_fastai(self, ds, model_type, samples_source, request): | ||
train_dl, valid_dl, model = self.dls_model( | ||
ds, model_type, samples_source, request | ||
) | ||
|
||
learn = model_type.fastai.learner( | ||
dls=[train_dl, valid_dl], model=model, splitter=fastai.trainable_params | ||
) | ||
learn.fine_tune(1, 3e-4) | ||
|
||
def test_mmdet_bbox_models_light(self, ds, model_type, samples_source, request): | ||
train_dl, valid_dl, model = self.dls_model( | ||
ds, model_type, samples_source, request | ||
) | ||
|
||
class LitModel(model_type.lightning.ModelAdapter): | ||
def configure_optimizers(self): | ||
return Adam(self.parameters(), lr=1e-4) | ||
|
||
light_model = LitModel(model) | ||
trainer = pl.Trainer( | ||
max_epochs=1, | ||
weights_summary=None, | ||
num_sanity_val_steps=0, | ||
logger=False, | ||
checkpoint_callback=False, | ||
) | ||
trainer.fit(light_model, train_dl, valid_dl) |