diff --git a/D2Go/create_d2go.py b/D2Go/create_d2go.py index 3dc326d6..bd5f7201 100644 --- a/D2Go/create_d2go.py +++ b/D2Go/create_d2go.py @@ -13,40 +13,11 @@ from d2go.model_zoo import model_zoo from mobile_cv.common.misc.file_utils import make_temp_directory -from d2go.tests.data_loader_helper import LocalImageGenerator, register_toy_dataset patch_d2_meta_arch() -@contextlib.contextmanager -def create_fake_detection_data_loader(height, width, is_train): - with make_temp_directory("detectron2go_tmp_dataset") as dataset_dir: - runner = create_runner("d2go.runner.GeneralizedRCNNRunner") - cfg = runner.get_default_cfg() - cfg.DATASETS.TRAIN = ["default_dataset_train"] - cfg.DATASETS.TEST = ["default_dataset_test"] - - with make_temp_directory("detectron2go_tmp_dataset") as dataset_dir: - image_dir = os.path.join(dataset_dir, "images") - os.makedirs(image_dir) - image_generator = LocalImageGenerator(image_dir, width=width, height=height) - - if is_train: - with register_toy_dataset( - "default_dataset_train", image_generator, num_images=3 - ): - train_loader = runner.build_detection_train_loader(cfg) - yield train_loader - else: - with register_toy_dataset( - "default_dataset_test", image_generator, num_images=3 - ): - test_loader = runner.build_detection_test_loader( - cfg, dataset_name="default_dataset_test" - ) - yield test_loader - def test_export_torchvision_format(): cfg_name = 'faster_rcnn_fbnetv3a_dsmask_C4.yaml' pytorch_model = model_zoo.get(cfg_name, trained=True) @@ -76,21 +47,27 @@ def forward(self, inputs: List[torch.Tensor]): size_divisibility = max(pytorch_model.backbone.size_divisibility, 10) h, w = size_divisibility, size_divisibility * 2 - with create_fake_detection_data_loader(h, w, is_train=False) as data_loader: - predictor_path = convert_and_export_predictor( - model_zoo.get_config(cfg_name), - copy.deepcopy(pytorch_model), - "torchscript_int8@tracing", - './', - data_loader, - ) - orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit")) - wrapped_model = Wrapper(orig_model) - # optionally do a forward - wrapped_model([torch.rand(3, 600, 600)]) - scripted_model = torch.jit.script(wrapped_model) - scripted_model.save("ObjectDetection/app/src/main/assets/d2go.pt") + runner = create_runner("d2go.runner.GeneralizedRCNNRunner") + cfg = model_zoo.get_config(cfg_name) + datasets = list(cfg.DATASETS.TRAIN) + + data_loader = runner.build_detection_test_loader(cfg, datasets) + + predictor_path = convert_and_export_predictor( + cfg, + copy.deepcopy(pytorch_model), + "torchscript_int8@tracing", + './', + data_loader, + ) + + orig_model = torch.jit.load(os.path.join(predictor_path, "model.jit")) + wrapped_model = Wrapper(orig_model) + # optionally do a forward + wrapped_model([torch.rand(3, 600, 600)]) + scripted_model = torch.jit.script(wrapped_model) + scripted_model.save("ObjectDetection/app/src/main/assets/d2go.pt") if __name__ == '__main__': test_export_torchvision_format()