result predicted on kitti sample data error #81

Open
lyfadvance opened this issue Feb 12, 2023 · 2 comments

@lyfadvance
[Image: disp]

I downloaded the KITTI fine-tuned model and ran it on the sample data. The result is shown in the image above. I don't know which step went wrong. My Python version is 3.9.

This is my code:

#!/usr/bin/env python
# coding: utf-8

# In[1]:


from PIL import Image
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../') # add relative path

from module.sttr import STTR
from dataset.preprocess import normalization, compute_left_occ_region
from utilities.misc import NestedTensor


# ### Define STTR model

# In[2]:


# Default parameters
args = type('', (), {})() # create empty args
args.channel_dim = 128
args.position_encoding='sine1d_rel'
args.num_attn_layers=6
args.nheads=8
args.regression_head='ot'
args.context_adjustment_layer='cal'
args.cal_num_blocks=8
args.cal_feat_dim=16
args.cal_expansion_ratio=4


# In[3]:


model = STTR(args).cuda().eval()


# In[4]:


# Load the pretrained model
model_file_name = "../kitti_finetuned_model.pth.tar"
#model_file_name = "../run/kitti/kitti_ft/original/model.pth.tar"
checkpoint = torch.load(model_file_name)
pretrained_dict = checkpoint['state_dict']
model.load_state_dict(pretrained_dict, strict=False) # prevent BN parameters from breaking the model loading
print("Pre-trained model successfully loaded.")


# ### Read image

# In[5]:


left = np.array(Image.open('../sample_data/KITTI_2015/training/image_2/000046_10.png'))
right = np.array(Image.open('../sample_data/KITTI_2015/training/image_3/000046_10.png'))
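# KITTI stores disparity as uint16 scaled by 256; np.float was removed in NumPy 1.24, so use the builtin float
disp = np.array(Image.open('../sample_data/KITTI_2015/training/disp_occ_0/000046_10.png')).astype(float) / 256.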


# In[6]:


# Visualize image
plt.figure(1)
plt.imshow(left)
plt.figure(2)
plt.imshow(right)
plt.figure(3)
plt.imshow(disp)


# Preprocess data for STTR

# In[7]:


# normalize
input_data = {'left': left, 'right':right, 'disp':disp}
input_data = normalization(**input_data)


# In[8]:


# downsample attention by stride of 3
h, w, _ = left.shape
bs = 1

downsample = 3
col_offset = int(downsample / 2)
row_offset = int(downsample / 2)
sampled_cols = torch.arange(col_offset, w, downsample)[None,].expand(bs, -1).cuda()
sampled_rows = torch.arange(row_offset, h, downsample)[None,].expand(bs, -1).cuda()


# In[9]:


# build NestedTensor
input_data = NestedTensor(input_data['left'].cuda()[None,],input_data['right'].cuda()[None,], sampled_cols=sampled_cols, sampled_rows=sampled_rows)


# ### Inference

# In[10]:


output = model(input_data)


# In[11]:
# set disparity of occ area to 0
disp_pred = output['disp_pred'].data.cpu().numpy()[0]
occ_pred = output['occ_pred'].data.cpu().numpy()[0] > 0.5
disp_pred[occ_pred] = 0.0
disp_pred[disp_pred<0] = 0
disp_pred_flat = disp_pred.reshape(-1)

disp_image = Image.fromarray(disp_pred.astype(np.uint8))
disp_image.save("disp.jpg")
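One thing worth checking in the last two lines of the script: casting to `np.uint8` drops the sub-pixel part of the disparity, and JPEG compression is lossy on top of that. A minimal sketch of an alternative, assuming the prediction is a float array in pixels and following the same "value times 256 in a 16-bit PNG" convention that the script uses when reading the ground truth. The helper names `encode_kitti_disparity` and `decode_kitti_disparity` are hypothetical and not part of the STTR repository.

```python
import numpy as np


def encode_kitti_disparity(disp):
    """Scale a float disparity map by 256 and store it as uint16 (hypothetical helper)."""
    disp = np.asarray(disp, dtype=np.float64)
    # Clip so that values outside the uint16 range saturate instead of wrapping around.
    scaled = np.clip(np.round(disp * 256.0), 0, 65535)
    return scaled.astype(np.uint16)


def decode_kitti_disparity(encoded):
    """Invert the encoding; 0 is treated as 'no value', as in the KITTI ground truth."""
    disp = encoded.astype(np.float32) / 256.0
    valid = encoded > 0
    return disp, valid


# Toy prediction, only to show the round trip.
pred = np.array([[0.0, 12.5], [63.25, 300.0]])
encoded = encode_kitti_disparity(pred)
decoded, valid = decode_kitti_disparity(encoded)
```

In the script above, `Image.fromarray(encode_kitti_disparity(disp_pred)).save("disp.png")` would then write a 16-bit PNG that can be read back the same way as the ground-truth file.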
@mli0603
Owner

mli0603 commented Feb 15, 2023

Hi @lyfadvance

Thank you for your interest in the project. The result does look odd. Have you checked the logged accuracy? On my side, the evaluation returns:
Index 46, l1_raw 0.1737, rr 0.5591, l1 0.0545, occ_be 0.0006, epe 0.2251, iou 0.9996, px error 0.0017
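To compare a local run against numbers like these, the error can be computed directly from the predicted and ground-truth arrays. The sketch below is an assumption about what such metrics usually mean, not the repository's evaluation code: it takes `epe` as the mean absolute disparity error over pixels with ground truth, and `px_error` as the fraction of those pixels whose error exceeds a threshold (3 px here, an assumed value). The function name `disparity_metrics` is hypothetical.

```python
import numpy as np


def disparity_metrics(disp_pred, disp_gt, threshold=3.0):
    """Mean absolute error and outlier ratio over valid ground-truth pixels (hypothetical helper)."""
    # Assumption: a ground-truth value of 0 marks a pixel without a measurement.
    valid = disp_gt > 0
    if not valid.any():
        raise ValueError("ground truth has no valid pixels")
    err = np.abs(disp_pred[valid] - disp_gt[valid])
    return {
        "epe": float(err.mean()),
        "px_error": float((err > threshold).mean()),
    }


# Toy arrays; the first pixel has no ground truth and is ignored.
gt = np.array([[0.0, 10.0], [20.0, 30.0]])
pred = np.array([[5.0, 10.5], [20.0, 34.0]])
metrics = disparity_metrics(pred, gt)
```

If the values from a local run are far from the logged ones, the problem is likely in the model or the inputs; if they are close, the odd-looking image is more likely a visualization issue.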

I am testing the pre-trained weights with the original PyTorch 1.5.1 to rule out any torch version inconsistency. I don't think the Python version makes a difference, though I am not entirely sure.
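A related check that may be worth doing when versions differ: the script loads the weights with `strict=False`, which lets key mismatches pass silently. A minimal sketch of a key comparison follows. It uses plain Python so it runs without PyTorch, and the function name `compare_state_dict_keys` and the toy key names are hypothetical.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names as sorted lists (hypothetical helper)."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    # Parameters the model expects but the checkpoint does not provide.
    missing = sorted(model_keys - checkpoint_keys)
    # Parameters in the checkpoint that the model has no slot for.
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy key names, for illustration only.
model_keys = ["backbone.conv.weight", "backbone.bn.running_mean", "head.weight"]
checkpoint_keys = ["backbone.conv.weight", "head.weight", "module.extra.bias"]
missing, unexpected = compare_state_dict_keys(model_keys, checkpoint_keys)
```

With the real objects this would be called as `compare_state_dict_keys(model.state_dict().keys(), checkpoint['state_dict'].keys())`. The value returned by `model.load_state_dict(..., strict=False)` carries the same information in its `missing_keys` and `unexpected_keys` fields.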

@lyfadvance
Author

I checked the KITTI dataset. The poor result is probably because the regions where the ground-truth disparity is 0 are not trained on.
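A minimal sketch of how this could be checked visually, assuming (as in the KITTI ground-truth files) that a disparity of 0 marks a pixel without a measurement: mask the prediction to the pixels that have ground truth before comparing the two images. The helper name `mask_to_valid_ground_truth` is hypothetical.

```python
import numpy as np


def mask_to_valid_ground_truth(disp_pred, disp_gt):
    """Zero out predictions where the ground truth has no value (hypothetical helper)."""
    valid = disp_gt > 0
    masked = np.where(valid, disp_pred, 0.0)
    # Also report how much of the image is covered by ground truth.
    return masked, float(valid.mean())


# Toy arrays: half of the pixels have no ground truth.
gt = np.array([[0.0, 0.0, 12.0], [0.0, 8.0, 9.0]])
pred = np.array([[40.0, 3.0, 12.2], [55.0, 8.1, 9.3]])
masked_pred, valid_fraction = mask_to_valid_ground_truth(pred, gt)
```

If the masked prediction looks reasonable next to the ground truth, the artifacts are confined to regions the fine-tuning never supervised.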
