This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit 7646355: added finetuned monodepth models
ranftlr committed Apr 12, 2021 (1 parent: 822b5f7)
Showing 3 changed files with 85 additions and 9 deletions.
README.md: 13 additions & 3 deletions
@@ -15,7 +15,7 @@ This repository contains code and models for our [paper](https://arxiv.org/abs/2


Monodepth:
-- [dpt_hybrid-midas-501f0c75.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), [Mirror](https://drive.google.com/file/d/1dgcJEYYw1F8qirXhZxgNK8dWWz_8gZBD/view?usp=sharing)
+- [dpt_hybrid-midas-501f0c75.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt) [Mirror](https://drive.google.com/file/d/1dgcJEYYw1F8qirXhZxgNK8dWWz_8gZBD/view?usp=sharing)
- [dpt_large-midas-2f21e586.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt), [Mirror](https://drive.google.com/file/d/1vnuhoMc6caF-buQQ4hK0CeiMk9SjwB-G/view?usp=sharing)

Segmentation:
@@ -31,7 +31,6 @@ Segmentation:

The code was tested with Python 3.7, PyTorch 1.8.0, OpenCV 4.5.1, and timm 0.4.5


### Usage

1) Place one or more input images in the folder `input`.
@@ -48,10 +47,21 @@
python run_segmentation.py
```

-3) The results are written to the folder `output_monodepth` and `output_segmentation`, respectively.
+3) The results are written to the folders `output_monodepth` and `output_semseg`, respectively.

Use the flag `-t` to switch between different models. Possible options are `dpt_hybrid` (default) and `dpt_large`.

+**Additional models:**
+
+- Finetuned on KITTI: [dpt_hybrid_kitti-cb926ef4.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_kitti-cb926ef4.pt)
+- Finetuned on NYUv2: [dpt_hybrid_nyu-2ce69ec7.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_nyu-2ce69ec7.pt)
+
+Run with
+
+```shell
+python run_monodepth.py -t [dpt_hybrid_kitti|dpt_hybrid_nyu]
+```
+

### Citation

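Editor's note: for reference, a plausible end-to-end invocation of one of the new checkpoints, assuming the weights are placed in `weights/` (where the `default_models` table in `run_monodepth.py` below looks for them). The `--kitti_crop` flag is introduced in this commit.

```shell
# Fetch the KITTI-finetuned weights into the folder run_monodepth.py expects.
wget -P weights/ https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid_kitti-cb926ef4.pt

# Run monocular depth estimation; --kitti_crop applies the 352x1216 crop
# matching the resolution the KITTI model was finetuned at.
python run_monodepth.py -t dpt_hybrid_kitti --kitti_crop
```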
dpt/models.py: 16 additions & 2 deletions
@@ -87,9 +87,15 @@ def forward(self, x):


class DPTDepthModel(DPT):
-    def __init__(self, path=None, non_negative=True, **kwargs):
+    def __init__(
+        self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs
+    ):
        features = kwargs["features"] if "features" in kwargs else 256

+        self.scale = scale
+        self.shift = shift
+        self.invert = invert
+
        head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
@@ -106,7 +112,15 @@ def __init__(self, path=None, non_negative=True, **kwargs):
            self.load(path)

    def forward(self, x):
-        return super().forward(x).squeeze(dim=1)
+        inv_depth = super().forward(x).squeeze(dim=1)
+
+        if self.invert:
+            depth = self.scale * inv_depth + self.shift
+            depth[depth < 1e-8] = 1e-8
+            depth = 1.0 / depth
+            return depth
+        else:
+            return inv_depth


class DPTSegmentationModel(DPT):
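Editor's note: the new `forward` treats the raw network output as inverse depth; with `invert=True` it applies a dataset-specific affine transform, clamps, and takes the reciprocal to obtain metric depth. A minimal standalone sketch of that mapping, using the KITTI `scale`/`shift` constants that `run_monodepth.py` passes below (the input tensor here is a stand-in, not a real prediction):

```python
import torch

def inv_depth_to_depth(inv_depth: torch.Tensor, scale: float, shift: float) -> torch.Tensor:
    # depth = 1 / (scale * inv_depth + shift), clamped so the reciprocal stays finite
    depth = scale * inv_depth + shift
    depth[depth < 1e-8] = 1e-8
    return 1.0 / depth

# Constants from the dpt_hybrid_kitti branch of run_monodepth.py below.
fake_prediction = torch.rand(352, 1216) * 3000.0  # stand-in inverse-depth map
depth_m = inv_depth_to_depth(fake_prediction, scale=0.00006016, shift=0.00579)
```

Each finetuned checkpoint gets its own `scale`/`shift` pair, which is presumably why the transform lives behind constructor arguments rather than being hard-coded.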
run_monodepth.py: 56 additions & 4 deletions
@@ -31,10 +31,9 @@ def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=T
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

-    net_w = net_h = 384
-
    # load network
    if model_type == "dpt_large":  # DPT-Large
+        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
@@ -43,22 +42,55 @@
        )
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w = net_h = 384
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
            enable_attention_hooks=args.vis,
        )
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    elif model_type == "dpt_hybrid_kitti":
+        net_w = 1216
+        net_h = 352
+
+        model = DPTDepthModel(
+            path=model_path,
+            scale=0.00006016,
+            shift=0.00579,
+            invert=True,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+            enable_attention_hooks=args.vis,
+        )
+
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+    elif model_type == "dpt_hybrid_nyu":
+        net_w = 640
+        net_h = 480
+
+        model = DPTDepthModel(
+            path=model_path,
+            scale=0.000305,
+            shift=0.1378,
+            invert=True,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+            enable_attention_hooks=args.vis,
+        )
+
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    elif model_type == "midas_v21":  # Convolutional model
+        net_w = net_h = 384
+
        model = MidasNet_large(model_path, non_negative=True)
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
    else:
        assert (
            False
-        ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|midas_v21]"
+        ), f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid|dpt_hybrid_kitti|dpt_hybrid_nyu|midas_v21]"

    transform = Compose(
        [
@@ -93,11 +125,20 @@ def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=T

print("start processing")
for ind, img_name in enumerate(img_names):
if os.path.isdir(img_name):
continue

print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
# input

img = util.io.read_image(img_name)

if args.kitti_crop is True:
height, width, _ = img.shape
top = height - 352
left = (width - 1216) // 2
img = img[top : top + 352, left : left + 1216, :]

img_input = transform({"image": img})["image"]

# compute
@@ -121,9 +162,14 @@ def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=T
                .numpy()
            )

+            if model_type == "dpt_hybrid_kitti":
+                prediction *= 256
+
+            if model_type == "dpt_hybrid_nyu":
+                prediction *= 1000.0
+
            if args.vis:
                visualize_attention(sample, model, prediction, args.model_type)
-                # exit()

        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
@@ -160,16 +206,22 @@ def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=T
help="model type [dpt_large|dpt_hybrid|midas_v21]",
)

parser.add_argument("--kitti_crop", dest="kitti_crop", action="store_true")

parser.add_argument("--optimize", dest="optimize", action="store_true")
parser.add_argument("--no-optimize", dest="optimize", action="store_false")

parser.set_defaults(optimize=True)
parser.set_defaults(kitti_crop=False)

args = parser.parse_args()

default_models = {
"midas_v21": "weights/midas_v21-f6b98070.pt",
"dpt_large": "weights/dpt_large-midas-2f21e586.pt",
"dpt_hybrid": "weights/dpt_hybrid-midas-501f0c75.pt",
"dpt_hybrid_kitti": "weights/dpt_hybrid_kitti-cb926ef4.pt",
"dpt_hybrid_nyu": "weights/dpt_hybrid_nyu-2ce69ec7.pt",
}

if args.model_weights is None:
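Editor's note: as a worked example of the `--kitti_crop` arithmetic above, for a typical 375x1242 KITTI frame (sizes vary by a few pixels across sequences), the crop keeps the bottom 352 rows and the central 1216 columns:

```python
# Worked example of the --kitti_crop arithmetic, assuming a 375x1242 frame.
height, width = 375, 1242
top = height - 352          # 23 -> discard the top 23 rows
left = (width - 1216) // 2  # 13 -> center the crop horizontally
# img[top : top + 352, left : left + 1216, :] yields a 352x1216 image,
# matching net_h x net_w for the dpt_hybrid_kitti model.
```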

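Editor's note: the per-dataset output scaling (`prediction *= 256` for KITTI, `prediction *= 1000.0` for NYUv2) matches the usual benchmark storage conventions: KITTI depth PNGs hold meters times 256 in uint16, and NYUv2 tooling commonly stores millimeters. Assuming the writer saves a 16-bit PNG into `output_monodepth` (the writing code is outside this diff, and the filename here is illustrative), metric depth could be recovered along these lines:

```python
import cv2

# Hypothetical readback of a saved prediction as a 16-bit PNG; use the
# divisor matching the model that produced the file.
raw = cv2.imread("output_monodepth/example.png", cv2.IMREAD_UNCHANGED)

depth_m = raw / 256.0    # dpt_hybrid_kitti output (KITTI: value = meters * 256)
# depth_m = raw / 1000.0 # dpt_hybrid_nyu output (NYUv2: value = millimeters)
```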