Commit
add a action
ZhiboRao committed Oct 20, 2024
1 parent 77ac339 commit d034128
Showing 30 changed files with 1,525 additions and 599 deletions.
2 changes: 1 addition & 1 deletion Datasets/debug_dataset.csv
@@ -1,2 +1,2 @@
img
/home4/datasets/jack/Documents_home2/DFC2019_track2_trainval/Track_Train/OMA389_028_025_LEFT_RGB.tif
/home4/datasets/jack/Documents_home2/DFC2019_track2_trainval/Track_Train/OMA342_038_036_LEFT_RGB.tif
488 changes: 244 additions & 244 deletions Datasets/whu_reconstruction_val_list.csv

Large diffs are not rendered by default.

320 changes: 59 additions & 261 deletions README.md
@@ -1,263 +1,25 @@
# Template-jf
[![Use the JackFramework Demo](https://github.com/Archaic-Atom/FrameworkTemplate/actions/workflows/build_env.yml/badge.svg?event=push)](https://github.com/Archaic-Atom/FrameworkTemplate/actions/workflows/build_env.yml)
![Python 3.8](https://img.shields.io/badge/python-3.8-green.svg?style=plastic)
![Pytorch 1.7](https://img.shields.io/badge/PyTorch%20-%23EE4C2C.svg?style=plastic)
![cuDnn 7.3.6](https://img.shields.io/badge/cudnn-7.3.6-green.svg?style=plastic)
![License MIT](https://img.shields.io/badge/license-MIT-green.svg?style=plastic)

>This is template project for JackFramework (https://github.com/Archaic-Atom/JackFramework). **It is used to rapidly build the model, without caring about the training process (such as DDP or DP, Tensorboard, et al.)**

Document:https://www.wolai.com/archaic-atom/rqKJVi7M1x44mPT8CdM1TL
# Cascaded Recurrent Networks with Masked Representation Learning for Stereo Matching of High-Resolution Satellite Images

Demo Project: https://github.com/Archaic-Atom/Demo-jf
## Project Overview
This project presents Masked Cascaded Recurrent Networks (MaskCRNet), a method for stereo matching of high-resolution satellite images. It employs masked representation learning to enhance feature extraction and uses cascaded recurrent modules to improve robustness against imperfect rectification, achieving accurate stereo matching for high-resolution satellite images.
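
As a rough illustration of the masking step that masked representation learning relies on, the sketch below hides a random subset of image patches. It is not the authors' implementation: the function name `random_patch_mask`, the patch size of 32, and the mask ratio of 0.5 are all hypothetical choices made for this example; only the 448x448 image size is taken from the training script in this commit.

```python
import random
from typing import List


def random_patch_mask(img_height: int, img_width: int, patch_size: int,
                      mask_ratio: float, seed: int = 0) -> List[List[bool]]:
    """Return a grid of booleans; True marks a patch hidden from the encoder.

    Hypothetical helper for illustration only, not part of MaskCRNet.
    """
    if img_height % patch_size or img_width % patch_size:
        raise ValueError("image size must be a multiple of patch_size")
    rows, cols = img_height // patch_size, img_width // patch_size
    num_masked = int(rows * cols * mask_ratio)
    rng = random.Random(seed)  # seeded so the mask is reproducible
    masked = set(rng.sample(range(rows * cols), num_masked))
    return [[(r * cols + c) in masked for c in range(cols)] for r in range(rows)]


# 448x448 input split into 32x32 patches gives a 14x14 grid.
mask = random_patch_mask(448, 448, patch_size=32, mask_ratio=0.5)
```

During pre-training, a network would typically be asked to reconstruct the hidden patches from the visible ones; how MaskCRNet does this is described in the paper, not in this sketch.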

---
### Software Environment
1. OS Environment
```
os >= linux 16.04
cudaToolKit == 10.1
cudnn == 7.3.6
```

2. Python Environment (We provide the whole env in )
```
python >= 3.8.5
pythorch >= 1.15.0
numpy >= 1.14.5
opencv >= 3.4.0
PIL >= 5.1.0
```
---
### Hardware Environment
The framework only can be used in GPUs.

### Train the model by running:
0. Install the JackFramework lib from Github (https://github.com/Archaic-Atom/JackFramework)
```
$ cd JackFramework/
$ ./install.sh
```

1. Get the Training list or Testing list (You need rewrite the code by your path, and my related demo code can be found in Source/Tools/genrate_**_traning_path.py)
```
$ ./GenPath.sh
```
Please check the path. The source code in Source/Tools.

2. Implement the model's interface and dataloader's interface of JackFramework in Source/UserModelImplementation/Models/your_model/inference.py and Source/UserModelImplementation/Dataloaders/your_dataloader.py.

The template of model is shown in follows:
```python
# -*- coding: utf-8 -*-
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim

import JackFramework as jf
# import UserModelImplementation.user_define as user_def


class YourModelInterface(jf.UserTemplate.ModelHandlerTemplate):
"""docstring for DeepLabV3Plus"""

def __init__(self, args: object) -> object:
super().__init__(args)
self.__args = args

def get_model(self) -> list:
# args = self.__args
# return model
return []

def optimizer(self, model: list, lr: float) -> list:
# args = self.__args
# return opt and sch
return [], []

def lr_scheduler(self, sch: object, ave_loss: list, sch_id: int) -> None:
# how to do schenduler
pass

def inference(self, model: list, input_data: list, model_id: int) -> list:
# args = self.__args
# return output
return []

def accuary(self, output_data: list, label_data: list, model_id: int) -> list:
# return acc
# args = self.__args
return []

def loss(self, output_data: list, label_data: list, model_id: int) -> list:
# return loss
# args = self.__args
return []

# Optional
def pretreatment(self, epoch: int, rank: object) -> None:
# do something before training epoch
pass

# Optional
def postprocess(self, epoch: int, rank: object,
ave_tower_loss: list, ave_tower_acc: list) -> None:
# do something after training epoch
pass

# Optional
def load_model(self, model: object, checkpoint: dict, model_id: int) -> bool:
# return False
return False

# Optional
def load_opt(self, opt: object, checkpoint: dict, model_id: int) -> bool:
# return False
return False

# Optional
def save_model(self, epoch: int, model_list: list, opt_list: list) -> dict:
# return None
return None

```

The template of Dataloader is shown in follows:
```python
# -*- coding: utf-8 -*-
import time
import JackFramework as jf
# import UserModelImplementation.user_define as user_def


class YourDataloader(jf.UserTemplate.DataHandlerTemplate):
"""docstring for DataHandlerTemplate"""

def __init__(self, args: object) -> object:
super().__init__(args)
self.__args = args
self.__result_str = jf.ResultStr()
self.__train_dataset = None
self.__val_dataset = None
self.__imgs_num = 0
self.__start_time = 0

def get_train_dataset(self, path: str, is_training: bool = True) -> object:
# args = self.__args
# return dataset
return None

def get_val_dataset(self, path: str) -> object:
# return dataset
# args = self.__args
# return dataset
return None

def split_data(self, batch_data: tuple, is_training: bool) -> list:
self.__start_time = time.time()
if is_training:
# return input_data_list, label_data_list
return [], []
# return input_data, supplement
return [], []

def show_train_result(self, epoch: int, loss:
list, acc: list,
duration: float) -> None:
assert len(loss) == len(acc) # same model number
info_str = self.__result_str.training_result_str(epoch, loss[0], acc[0], duration, True)
jf.log.info(info_str)

def show_val_result(self, epoch: int, loss:
list, acc: list,
duration: float) -> None:
assert len(loss) == len(acc) # same model number
info_str = self.__result_str.training_result_str(epoch, loss[0], acc[0], duration, False)
jf.log.info(info_str)

def save_result(self, output_data: list, supplement: list,
img_id: int, model_id: int) -> None:
assert self.__train_dataset is not None
# args = self.__args
# save method
pass

def show_intermediate_result(self, epoch: int,
loss: list, acc: list) -> str:
assert len(loss) == len(acc) # same model number
return self.__result_str.training_intermediate_result(epoch, loss[0], acc[0])
## Key Contributions

- **Masked Representation Learning Pre-training Strategy**: Addresses challenges in remote sensing stereo datasets by improving data utilization and feature representation on small datasets.
- **Improved Correlation Computation**: Based on self-attention, cross-attention, and deformable convolutions, it handles imperfect rectification to enhance performance.
- **State-of-the-Art Performance**: Achieves state-of-the-art results on the US3D and WHU-Stereo datasets.

```

you must implement the related class for using JackFramework, the demo can be find in Source/UserModelImplementation/Models/Your_Model/inference.py or Source/UserModelImplementation/Dataloaders/your_dataloader.py. Or you can find the other demo in Demo project.

Next, you need implement the interface file Source/user_interface.py (you can add some parameters in user\_parser function of this file ), as shown in follows:
```python
# -*- coding: utf-8 -*-
import argparse
import JackFramework as jf
# import UserModelImplementation.user_define as user_def

# model and dataloader
from UserModelImplementation import Models
from UserModelImplementation import Dataloaders


class UserInterface(jf.UserTemplate.NetWorkInferenceTemplate):
"""docstring for UserInterface"""

def __init__(self) -> object:
super().__init__()

def inference(self, args: object) -> object:
dataloader = Dataloaders.dataloaders_zoo(args, args.dataset)
model = Models.model_zoo(args, args.modelName)
return model, dataloader

def user_parser(self, parser: object) -> object:
# parser.add_argument('--startDisp', type=int, default=user_def.START_DISP,
# help='start disparity')
# return parser
return None

@staticmethod
def __str2bool(arg: str) -> bool:
if arg.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif arg.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
```
## Code Structure

Finally, you need pass this object to JackFramework, as shown in follows:
```python
# -*coding: utf-8 -*-
import JackFramework as jf
from UserModelImplementation.user_interface import UserInterface


def main()->None:
app = jf.Application(UserInterface(), "Stereo Matching Models")
app.start()


# execute the main function
if __name__ == "__main__":
main()

```

3. Run the program, like:
```
$ ./Scripts/start_debug_stereo_net.sh
```
---
### File Structure
```
Template-jf
MaskCRNet
├── Datasets # Get it by ./generate_path.sh, you need build folder
│ ├── dataset_example_training_list.csv
│ └── ...
@@ -278,19 +40,55 @@ Template-jf
├── LICENSE
└── README.md
```
---
### Update log
#### 2021-05-29
1. Add the depth for transformer;
2. Fork the JackFramework to a new project;
3. Remove the JackFramework from this project.

#### 2021-04-08
1. Add the stereo;
2. Add transformer.
## Dataset Preparation
1. US3D Dataset: Download from the US3D official website and organize according to the dataset's README.
2. WHU-Stereo Dataset: Download from the WHU-Stereo GitHub page and organize according to the dataset's README.

## Environment Dependencies

Ensure you have the following Python libraries installed:

- torch
- torchvision
- numpy
- JackFramework
- DatasetHandler

## Training the Model
1. Get the Training list or Testing list (You need rewrite the code by your path, and my related demo code can be found in Source/Tools/genrate_**_traning_path.py)
```
$ ./Scripts/GenPath.sh
```


2. Run the program, like:
```
$ ./Scripts/start_debug_stereo_net.sh
```
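
Step 1 asks you to adapt the path-generation code to your own data layout. A minimal sketch of such a script is shown below, assuming the list is a one-column CSV with the header `img`, as in `Datasets/debug_dataset.csv`. The function name `write_image_list` and the default glob pattern are hypothetical; the repository's own scripts in `Source/Tools` may use different columns (for example, separate left, right, and ground-truth paths for stereo lists).

```python
import csv
from pathlib import Path


def write_image_list(image_dir: str, list_path: str,
                     pattern: str = "*_LEFT_RGB.tif") -> int:
    """Write a one-column CSV (header `img`) of matching image paths.

    Hypothetical helper for illustration; returns the number of paths written.
    """
    paths = sorted(str(p) for p in Path(image_dir).glob(pattern))
    with open(list_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["img"])              # header row, as in debug_dataset.csv
        writer.writerows([p] for p in paths)  # one image path per row
    return len(paths)
```

Sorting the paths keeps the list stable between runs, which makes later diffs of the CSV easier to read.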

## Testing the Model

1. Run the program, like:
```
$ ./Scripts/start_test_stereo_net.sh
```

## Citation
If you use this code or method, please cite the following paper:
```
@article{rao2024cascaded,
title={Cascaded Recurrent Networks with Masked Representation Learning for Stereo Matching of High-Resolution Satellite Images},
author={Rao, Zhibo and Li, Xing and Xiong, Bangshu and Dai, Yuchao and Shen, Zhelun and Li, Hangbiao and Lou, Yue},
journal={ISPRS Journal of Photogrammetry and Remote Sensing},
year={2024},
url={https://github.com/Archaic-Atom/MaskCRNet}
}
```

## Contact Us
For any questions or suggestions, please contact us at:

- Email: [email protected]

#### 2021-01-13
1. Fork a new prject (based on pythorch);
2. Use a new code style;
3. Build the frameworks for pythorch;
4. Write ReadMe
Thank you for using our code!
6 changes: 3 additions & 3 deletions Scripts/start_test_whu_dataset.sh
@@ -1,8 +1,8 @@
#!/bin/bash
test_gpus_id=0,1,2,3,4
eva_gpus_id=7
# test_list_path='./Datasets/whu_stereo_testing_list.csv'
test_list_path='./Datasets/whu_stereo_val_list.csv'
test_list_path='./Datasets/whu_stereo_testing_list.csv'
# test_list_path='./Datasets/whu_stereo_val_list.csv'
evalution_format='training'

CUDA_VISIBLE_DEVICES=${test_gpus_id} python Source/main.py \
@@ -25,7 +25,7 @@ CUDA_VISIBLE_DEVICES=${test_gpus_id} python Source/main.py \
--pre_train_opt false \
--modelName SwinStereo \
--outputDir ./TestResult/ \
--modelDir ./Checkpoint/ \
--modelDir ./Checkpoint_old/ \
--dataset whu

CUDA_VISIBLE_DEVICES=${eva_gpus_id} python ./Source/Tools/evalution_stereo_net.py --gt_list_path ${test_list_path} --invaild_value -999 --img_path_format ./ResultImg/%06d_10.tiff
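
The evaluation command above passes a printf-style `--img_path_format` and an invalid-value marker of `-999`. The sketch below shows how such arguments could be interpreted: expanding the pattern per image id and averaging the absolute disparity error only over pixels with valid ground truth. The evaluation script itself is not shown in this commit view, so `result_path` and `masked_epe` are hypothetical names and this is not necessarily how `evalution_stereo_net.py` computes its metrics.

```python
from typing import Iterable, Tuple


def result_path(img_path_format: str, img_id: int) -> str:
    """Expand a printf-style pattern such as './ResultImg/%06d_10.tiff'."""
    return img_path_format % img_id


def masked_epe(pred: Iterable[float], gt: Iterable[float],
               invalid_value: float = -999.0) -> Tuple[float, int]:
    """Mean absolute disparity error over pixels whose ground truth is valid.

    Returns (error, valid_pixel_count); error is 0.0 when no pixel is valid.
    """
    total, count = 0.0, 0
    for p, g in zip(pred, gt):
        if g == invalid_value:  # pixel marked as having no ground truth
            continue
        total += abs(p - g)
        count += 1
    return (total / count if count else 0.0), count
```

For example, `result_path('./ResultImg/%06d_10.tiff', 7)` gives `./ResultImg/000007_10.tiff`. Real disparity maps would normally be NumPy arrays; flat sequences are used here only to keep the sketch dependency-free.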
8 changes: 4 additions & 4 deletions Scripts/start_train_pre_us3d_dataset.sh
@@ -3,8 +3,8 @@
tensorboard_port=6234
dist_port=8809
tensorboard_folder='./log/'
# train_list_path='./Datasets/debug_dataset.csv'
train_list_path='./Datasets/us3d_reconstruction_training_list.csv'
train_list_path='./Datasets/debug_dataset.csv'
# train_list_path='./Datasets/us3d_reconstruction_training_list.csv'
echo "The tensorboard_port:" ${tensorboard_port}
echo "The dist_port:" ${dist_port}

@@ -17,12 +17,12 @@ fi
echo "Begin to train the model!"
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 nohup python -u Source/main.py \
--batchSize 1 \
--gpu 8 \
--gpu 1 \
--trainListPath ${train_list_path} \
--imgWidth 448 \
--imgHeight 448 \
--dataloaderNum 8 \
--maxEpochs 1000 \
--maxEpochs 1 \
--imgNum 2440 \
--sampleNum 1 \
--log ${tensorboard_folder} \