-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualization_of_ocr_pred_and_gt.py
74 lines (60 loc) · 3.55 KB
/
visualization_of_ocr_pred_and_gt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import argparse
import numpy as np
from osgeo import gdal
import pickle
from utils import write_geolocated_image
import config_pop as cfg
from utils import read_input_raster_data
def visualization_of_ocr_pred_and_gt(dataset_name, rst_wp_regions_path, preproc_data_path, pred_map_path, output_dir):
    """Write per-region occupancy-rate (OCR) rasters for predictions and census data.

    For every region id listed in ``pdata["valid_ids"]`` the OCR is computed as
    population divided by the sum of the buildings layer inside that region,
    once for the predicted population map and once for the census counts.
    Four rasters are written to ``output_dir``:

    * ``ocr_fine_map_pred.tif`` - predicted OCR, painted over each region
    * ``ocr_fine_map_gt.tif``   - census OCR, painted over each region
    * ``ocr_fine_map_mape.tif`` - |pred - gt| / gt per region
    * ``ocr_fine_map_mpe.tif``  - (pred - gt) / gt per region

    Pixels that keep the value 0:

    * pixels outside the valid regions;
    * regions whose valid building sum is not positive, where OCR is undefined;
    * in the two error maps only, regions whose census OCR is not positive.

    Args:
        dataset_name: Key into ``cfg.input_paths`` selecting the input rasters
            (the ``"buildings"`` layer is used).
        rst_wp_regions_path: Raster of region ids. Its geotransform and
            projection are reused for every output raster.
        preproc_data_path: Pickle containing ``"valid_ids"`` and
            ``"valid_census"``. ``valid_census`` is indexed by region id.
        pred_map_path: Raster with the population prediction. NaN pixels are
            ignored when summing.
        output_dir: Directory that receives the output rasters (created if
            missing).

    Raises:
        IOError: If GDAL cannot open one of the input rasters (only reached
            when GDAL exceptions are disabled, since then ``gdal.Open``
            returns None).
    """
    # Read raster data
    input_paths = cfg.input_paths[dataset_name]
    inputs = read_input_raster_data(input_paths)
    # NOTE(review): presumably building counts per pixel; must share the
    # region raster's shape for the boolean indexing below to work.
    input_buildings = inputs["buildings"]

    regions_ds = gdal.Open(rst_wp_regions_path)
    if regions_ds is None:
        raise IOError("Could not open region raster: {}".format(rst_wp_regions_path))
    wp_rst_regions = regions_ds.ReadAsArray().astype(np.uint32)
    geo_transform = regions_ds.GetGeoTransform()
    projection = regions_ds.GetProjection()
    regions_ds = None  # release the GDAL dataset handle

    pred_ds = gdal.Open(pred_map_path)
    if pred_ds is None:
        raise IOError("Could not open prediction raster: {}".format(pred_map_path))
    pred_map = pred_ds.ReadAsArray().astype(np.float32)
    pred_ds = None  # release the GDAL dataset handle

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted
    # preprocessed files.
    with open(preproc_data_path, 'rb') as handle:
        pdata = pickle.load(handle)
    valid_ids = pdata["valid_ids"]
    valid_census = pdata["valid_census"]

    # Validity masks do not depend on the region, so build them once.
    valid_pred = ~np.isnan(pred_map)
    valid_buildings = ~np.isnan(input_buildings) & (input_buildings >= 0)

    # Fine-level maps. Pixels not painted below keep the value 0.
    fine_map_pred = np.zeros(wp_rst_regions.shape, dtype=np.float32)
    fine_map_gt = np.zeros(wp_rst_regions.shape, dtype=np.float32)
    fine_map_mape = np.zeros(wp_rst_regions.shape, dtype=np.float32)
    fine_map_mpe = np.zeros(wp_rst_regions.shape, dtype=np.float32)

    for region_id in valid_ids:
        # Full-raster comparison, done once per region and reused everywhere.
        region_mask = wp_rst_regions == region_id

        num_buildings = np.sum(input_buildings[region_mask & valid_buildings])
        if num_buildings <= 0:
            # OCR is undefined without buildings; dividing would write inf/nan.
            continue

        pred_pop = np.sum(pred_map[region_mask & valid_pred])
        ocr_census = valid_census[region_id] / num_buildings
        ocr_pred = pred_pop / num_buildings

        fine_map_pred[region_mask] = ocr_pred
        fine_map_gt[region_mask] = ocr_census
        if ocr_census > 0:
            fine_map_mape[region_mask] = abs(ocr_pred - ocr_census) / ocr_census
            fine_map_mpe[region_mask] = (ocr_pred - ocr_census) / ocr_census

    os.makedirs(output_dir, exist_ok=True)
    outputs = (
        ("ocr_fine_map_pred.tif", fine_map_pred),
        ("ocr_fine_map_gt.tif", fine_map_gt),
        ("ocr_fine_map_mape.tif", fine_map_mape),
        ("ocr_fine_map_mpe.tif", fine_map_mpe),
    )
    for file_name, image in outputs:
        write_geolocated_image(image, os.path.join(output_dir, file_name), geo_transform, projection)
def main():
    """Command-line entry point.

    Reads five positional arguments and forwards them, in order, to
    ``visualization_of_ocr_pred_and_gt``.
    """
    # (name, help text) for each positional argument, in CLI order.
    positional_specs = (
        ("dataset_name", "Dataset name (e.g., tza,uga)"),
        ("rst_wp_regions_path", "Raster of WorldPop administrative boundaries information"),
        ("preproc_data_path", "Preprocessed data of regions (pickle file)"),
        ("pred_map_path", "Population prediction map"),
        ("output_dir", "Output directory"),
    )

    cli = argparse.ArgumentParser()
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)
    opts = cli.parse_args()

    visualization_of_ocr_pred_and_gt(
        opts.dataset_name,
        opts.rst_wp_regions_path,
        opts.preproc_data_path,
        opts.pred_map_path,
        opts.output_dir,
    )


if __name__ == "__main__":
    main()