Reproducing test performance #12
Comments
Hi Mansoor, we do not follow the original split for the test data. We evaluate on all ground-truth images (not counting the reconstructed images in SCARED), and our models are not trained on any data from SCARED or SERV-CT. To quote the paper..
We have not shared our training and evaluation code yet. I have refactored the code significantly and am just running some last tests to ensure the results are still the same. This should be released this week. Sorry for the delay, we have been having issues with our HPC cluster. In the meantime, I can share the loss/score function (ssimae). It is also worth noting that SCARED depth maps need to be inverted to disparity space (1/depth), as our models output relative disparity maps.

import torch
from typing import Optional, Sequence, Tuple


def __apply_mask(depths, masks) -> Sequence[torch.Tensor]:
    # Return one 1-D tensor of pixels per image (only the masked pixels if masks are given).
    if masks is None:
        # Flatten so that torch.vander, which needs 1-D input, also works without masks.
        return [depth.flatten() for depth in depths]
    else:
        return [depth[mask] for depth, mask in zip(depths, masks)]


def normalise_depths(depths:torch.Tensor, masks:Optional[torch.Tensor]) -> torch.Tensor:
    # Standardise each map using the mean and std of its valid pixels.
    masked_depths = __apply_mask(depths, masks)
    means = torch.stack([md.mean() for md in masked_depths])
    stds = torch.stack([md.std() for md in masked_depths])
    return (depths - means[:, None, None, None]) / stds[:, None, None, None]


def fit_shifts_and_scales(source_depths:torch.Tensor, target_depths:torch.Tensor, masks:Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Per-image least-squares fit of target ~ scale * source + shift over the valid pixels.
    masked_source_depths = __apply_mask(source_depths, masks)
    masked_target_depths = __apply_mask(target_depths, masks)
    shifts, scales = [], []
    for masked_source_depth, masked_target_depth in zip(masked_source_depths, masked_target_depths):
        # torch.vander(x, 2) has columns [x, 1], so the solution is ordered (scale, shift).
        A = torch.vander(masked_source_depth, 2)
        B = masked_target_depth.unsqueeze(-1)
        X = torch.pinverse(A) @ B
        scale, shift = X[0], X[1]
        shifts.append(shift)
        scales.append(scale)
    return torch.cat(shifts), torch.cat(scales)


def fit_depths(source_depths:torch.Tensor, target_depths:torch.Tensor, masks:Optional[torch.Tensor]) -> torch.Tensor:
    # The fit itself is not differentiated through.
    with torch.no_grad():
        shifts, scales = fit_shifts_and_scales(source_depths, target_depths, masks)
    fitted_depths = (shifts[:, None, None, None] + scales[:, None, None, None] * source_depths)
    return fitted_depths


def ssimae(predicted_depths:torch.Tensor, target_depths:torch.Tensor, masks:Optional[torch.Tensor], normalise:bool=True) -> torch.Tensor:
    # Scale- and shift-invariant mean absolute error, one value per image.
    if normalise:
        target_depths = normalise_depths(target_depths, masks)
    with torch.no_grad():
        shifts, scales = fit_shifts_and_scales(predicted_depths, target_depths, masks)
    error_maps = (shifts[:, None, None, None] + scales[:, None, None, None] * predicted_depths) - target_depths
    ssimaes = torch.stack([masked_error_map.abs().mean() for masked_error_map in __apply_mask(error_maps, masks)])
    return ssimaes
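For the inversion to disparity space mentioned above, a minimal sketch might look like the following. The helper name `depth_to_disparity` is hypothetical and not part of the released code. It assumes depth maps of shape (batch, 1, height, width) in which pixels without ground truth are stored as zero, negative or non-finite values; that convention should be checked against the actual data.

```python
import torch
from typing import Tuple


def depth_to_disparity(depths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Invert depth maps to disparity (1/depth) and build validity masks.

    Assumption: `depths` has shape (batch, 1, height, width) and invalid
    pixels are zero, negative or non-finite.
    """
    masks = torch.isfinite(depths) & (depths > 0)
    # Replace invalid pixels with 1 before inverting to avoid division by zero.
    safe_depths = torch.where(masks, depths, torch.ones_like(depths))
    # Invalid pixels are set to 0 in the output and flagged False in the mask.
    disparities = torch.where(masks, 1.0 / safe_depths, torch.zeros_like(depths))
    return disparities, masks


# Tiny example: one 2x2 depth map with two invalid pixels (0 and NaN).
depths = torch.tensor([[[[2.0, 0.0], [4.0, float("nan")]]]])
disparities, masks = depth_to_disparity(depths)
```

The resulting `disparities` and `masks` could then be passed as `target_depths` and `masks` to `ssimae`, so that only pixels with valid ground truth contribute to the fit and to the error.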
Hi,
Thanks for sharing your nice work on depth estimation. I am trying to reproduce your test results with the provided checkpoints. Could you share your code for that and the dataset splits? Correct me if I am wrong, but you are using SERV-CT and SCARED for testing.