diff --git a/PolarPointBEV/augment.py b/PolarPointBEV/augment.py new file mode 100644 index 0000000..260a1fe --- /dev/null +++ b/PolarPointBEV/augment.py @@ -0,0 +1,106 @@ +####################################################################################################### +# This file is borrowed from COiLTRAiNE https://github.com/felipecode/coiltraine by Felipe Codevilla # +# COiLTRAiNE itself is under MIT License # +####################################################################################################### + + +import imgaug as ia +from imgaug import augmenters as iaa + + +def hard(image_iteration): + + iteration = image_iteration/32 + frequency_factor = min(0.05 + float(iteration)/200000.0, 1.0) + color_factor = min(float(iteration)/1000000.0, 1.0) + dropout_factor = 0.198667 + (0.03856658 - 0.198667) / (1 + (iteration / 196416.6) ** 1.863486) + + blur_factor = min(0.5 + (0.5*iteration/100000.0), 1.0) + + add_factor = 10 + 10*iteration/100000.0 + + multiply_factor_pos = 1 + (2.5*iteration/200000.0) + multiply_factor_neg = 1 - (0.91 * iteration / 500000.0) + + contrast_factor_pos = 1 + (0.5*iteration/500000.0) + contrast_factor_neg = 1 - (0.5 * iteration / 500000.0) + + augmenter = iaa.Sequential([ + + iaa.Sometimes(frequency_factor, iaa.GaussianBlur((0, blur_factor))), + # blur images with a sigma between 0 and 1.5 + iaa.Sometimes(frequency_factor, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, dropout_factor), + per_channel=color_factor)), + # add gaussian noise to images + iaa.Sometimes(frequency_factor, iaa.CoarseDropout((0.0, dropout_factor), size_percent=( + 0.08, 0.2), per_channel=color_factor)), + # randomly remove up to X% of the pixels + iaa.Sometimes(frequency_factor, iaa.Dropout((0.0, dropout_factor), per_channel=color_factor)), + # randomly remove up to X% of the pixels + iaa.Sometimes(frequency_factor, + iaa.Add((-add_factor, add_factor), per_channel=color_factor)), + # change brightness of images (by -X to Y of original value) + iaa.Sometimes(frequency_factor, + iaa.Multiply((multiply_factor_neg, multiply_factor_pos), per_channel=color_factor)), + # change brightness of images (X-Y% of original value) + # iaa.Sometimes(frequency_factor, iaa.ContrastNormalization((contrast_factor_neg, contrast_factor_pos), + # per_channel=color_factor)), + iaa.Sometimes(frequency_factor, iaa.contrast.LinearContrast((contrast_factor_neg, contrast_factor_pos), + per_channel=color_factor)), + # improve or worsen the contrast + iaa.Sometimes(frequency_factor, iaa.Grayscale((0.0, 1))), # put grayscale + + ], + random_order=True # do all of the above in random order + ) + + return augmenter + + +def hard_1(image_iteration): + + iteration = image_iteration/32 + frequency_factor = min(0.05 + float(iteration)/200000.0, 1.0) + color_factor = min(float(iteration)/1000000.0, 1.0) + dropout_factor = 0.198667 + (0.03856658 - 0.198667) / (1 + (iteration / 196416.6) ** 1.863486) + + blur_factor = min(0.5 + (0.5*iteration/100000.0), 1.0) + + add_factor = 10 + 10*iteration/100000.0 + + multiply_factor_pos = 1 + (2.5*iteration/200000.0) + multiply_factor_neg = 1 - (0.91 * iteration / 500000.0) + + contrast_factor_pos = 1 + (0.5*iteration/500000.0) + contrast_factor_neg = 1 - (0.5 * iteration / 500000.0) + + augmenter = iaa.Sequential([ + + iaa.Sometimes(frequency_factor, iaa.GaussianBlur((0, blur_factor))), + # blur images with a sigma between 0 and 1.5 + iaa.Sometimes(frequency_factor, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, dropout_factor), + per_channel=color_factor)), + # add gaussian noise to images + iaa.Sometimes(frequency_factor, iaa.CoarseDropout((0.0, dropout_factor), size_percent=( + 0.02, 0.2), per_channel=color_factor)), + # randomly remove up to X% of the pixels + iaa.Sometimes(frequency_factor, iaa.Dropout((0.0, dropout_factor), per_channel=color_factor)), + # randomly remove up to X% of the pixels + iaa.Sometimes(frequency_factor, + iaa.Add((-add_factor, add_factor), per_channel=color_factor)), + # change brightness of images (by -X to Y of original value) + iaa.Sometimes(frequency_factor, + iaa.Multiply((multiply_factor_neg, multiply_factor_pos), per_channel=color_factor)), + # change brightness of images (X-Y% of original value) + # iaa.Sometimes(frequency_factor, iaa.ContrastNormalization((contrast_factor_neg, contrast_factor_pos), + # per_channel=color_factor)), + iaa.Sometimes(frequency_factor, iaa.contrast.LinearContrast((contrast_factor_neg, contrast_factor_pos), + per_channel=color_factor)), + # improve or worsen the contrast + iaa.Sometimes(frequency_factor, iaa.Grayscale((0.0, 1))), # put grayscale + + ], + random_order=True # do all of the above in random order + ) + + return augmenter diff --git a/PolarPointBEV/config.py b/PolarPointBEV/config.py new file mode 100644 index 0000000..de500f2 --- /dev/null +++ b/PolarPointBEV/config.py @@ -0,0 +1,68 @@ +import os + +class GlobalConfig: + """ base architecture configurations """ + # Data + seq_len = 1 # input timesteps + pred_len = 4 # future waypoints predicted + + # data root for pretrain + # root_dir_all = "tcp_carla_data" + # data root for training + root_dir_all = "/workspace/dataset/carla_data/" + + train_towns = ['train_data'] + val_towns = ['val_data'] + train_data, val_data = [], [] + for town in train_towns: + train_data.append(os.path.join(root_dir_all, town)) + # train_data.append(os.path.join(root_dir_all, town+'_addition')) + for town in val_towns: + val_data.append(os.path.join(root_dir_all, town)) + + ignore_sides = True # don't consider side cameras + ignore_rear = True # don't consider rear cameras + + input_resolution = 256 + + scale = 1 # image pre-processing + crop = 256 # image pre-processing + + lr = 1e-4 # learning rate + + # Controller + turn_KP = 0.75 + turn_KI = 0.75 + turn_KD = 0.3 + turn_n = 40 # buffer size + + speed_KP = 5.0 + speed_KI = 0.5 + speed_KD = 1.0 + speed_n = 40 # buffer size + + max_throttle = 0.75 # upper limit on throttle signal value in dataset + brake_speed = 0.4 # desired speed below which brake is triggered + brake_ratio = 1.1 # ratio of speed to desired speed at which brake is triggered + clip_delta = 0.25 # maximum change in speed input to logitudinal controller + + + aim_dist = 4.0 # distance to search around for aim point + angle_thresh = 0.3 # outlier control detection angle + dist_thresh = 10 # target point y-distance for outlier filtering + + + speed_weight = 0.05 + value_weight = 0.001 + features_weight = 0.05 + bev_weight = 1.0 + graph_weight = 1.0 + + rl_ckpt = "/workspace/published_project/polar_BEV/roach/log/ckpt_11833344.pth" + + img_aug = True + + + def __init__(self, **kwargs): + for k,v in kwargs.items(): + setattr(self, k, v) diff --git a/PolarPointBEV/data.py b/PolarPointBEV/data.py new file mode 100644 index 0000000..a1c8cdb --- /dev/null +++ b/PolarPointBEV/data.py @@ -0,0 +1,279 @@ +import json +import os +from PIL import Image +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision import transforms as T + +from PolarPointBEV.augment import hard as augmenter + +class PolarPoint_Data(Dataset): + + def __init__(self, root, data_folders, img_aug = False): + self.root = root + self.img_aug = img_aug + self._batch_read_number = 0 + + self.front_img = [] + self.graph = [] + self.x = [] + self.y = [] + self.command = [] + self.target_command = [] + self.target_gps = [] + self.theta = [] + self.speed = [] + + + self.value = [] + self.feature = [] + self.action = [] + self.action_mu = [] + self.action_sigma = [] + + self.future_x = [] + self.future_y = [] + self.future_theta = [] + + self.future_feature = [] + self.future_action = [] + self.future_action_mu = [] + self.future_action_sigma = [] + self.future_only_ap_brake = [] + + self.x_command = [] + self.y_command = [] + self.command = [] + self.only_ap_brake = [] + + for sub_root in data_folders: + data = np.load(os.path.join(sub_root, "packed_data_normal.npy"), allow_pickle=True).item() + + self.x_command += data['x_target'] + self.y_command += data['y_target'] + self.command += data['target_command'] + + self.front_img += data['front_img'] + self.graph += data['graph'] + self.x += data['input_x'] + self.y += data['input_y'] + self.theta += data['input_theta'] + self.speed += data['speed'] + + self.future_x += data['future_x'] + self.future_y += data['future_y'] + self.future_theta += data['future_theta'] + + self.future_feature += data['future_feature'] + self.future_action += data['future_action'] + self.future_action_mu += data['future_action_mu'] + self.future_action_sigma += data['future_action_sigma'] + self.future_only_ap_brake += data['future_only_ap_brake'] + + self.value += data['value'] + self.feature += data['feature'] + self.action += data['action'] + self.action_mu += data['action_mu'] + self.action_sigma += data['action_sigma'] + self.only_ap_brake += data['only_ap_brake'] + self._im_transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])]) + + def __len__(self): + """Returns the length of the dataset. """ + return len(self.front_img) + + def __getitem__(self, index): + """Returns the item at index idx. """ + data = dict() + data['front_img'] = self.front_img[index] + data['graph'] = self.graph[index] + + if self.img_aug: + data['front_img'] = self._im_transform(augmenter(self._batch_read_number).augment_image(np.array( + Image.open(self.root+self.front_img[index][0])))) + else: + data['front_img'] = self._im_transform(np.array( + Image.open(self.root+self.front_img[index][0]))) + + with open(self.root + self.graph[index][0]) as jfile: + data['graph'] = json.load(jfile) + data['graph'] = torch.tensor(data['graph'], dtype=torch.long) + + # fix for theta=nan in some measurements + if np.isnan(self.theta[index][0]): + self.theta[index][0] = 0. + + ego_x = self.x[index][0] + ego_y = self.y[index][0] + ego_theta = self.theta[index][0] + + waypoints = [] + for i in range(4): + R = np.array([ + [np.cos(np.pi/2+ego_theta), -np.sin(np.pi/2+ego_theta)], + [np.sin(np.pi/2+ego_theta), np.cos(np.pi/2+ego_theta)] + ]) + local_command_point = np.array([self.future_y[index][i]-ego_y, self.future_x[index][i]-ego_x] ) + local_command_point = R.T.dot(local_command_point) + waypoints.append(local_command_point) + + data['waypoints'] = np.array(waypoints) + + data['action'] = self.action[index] + data['action_mu'] = self.action_mu[index] + data['action_sigma'] = self.action_sigma[index] + + + future_only_ap_brake = self.future_only_ap_brake[index] + future_action_mu = self.future_action_mu[index] + future_action_sigma = self.future_action_sigma[index] + + # use the average value of roach braking action when the brake is only performed by the rule-based detector + for i in range(len(future_only_ap_brake)): + if future_only_ap_brake[i]: + future_action_mu[i][0] = 0.8 + future_action_sigma[i][0] = 5.5 + data['future_action_mu'] = future_action_mu + data['future_action_sigma'] = future_action_sigma + data['future_feature'] = self.future_feature[index] + + only_ap_brake = self.only_ap_brake[index] + if only_ap_brake: + data['action_mu'][0] = 0.8 + data['action_sigma'][0] = 5.5 + + R = np.array([ + [np.cos(np.pi/2+ego_theta), -np.sin(np.pi/2+ego_theta)], + [np.sin(np.pi/2+ego_theta), np.cos(np.pi/2+ego_theta)] + ]) + local_command_point = np.array([-1*(self.x_command[index]-ego_x), self.y_command[index]-ego_y] ) + local_command_point = R.T.dot(local_command_point) + data['target_point'] = local_command_point[:2] + + + local_command_point_aim = np.array([(self.y_command[index]-ego_y), self.x_command[index]-ego_x] ) + local_command_point_aim = R.T.dot(local_command_point_aim) + data['target_point_aim'] = local_command_point_aim[:2] + + data['target_point'] = local_command_point_aim[:2] + + data['speed'] = self.speed[index] + data['feature'] = self.feature[index] + data['value'] = self.value[index] + command = self.command[index] + + # VOID = -1 + # LEFT = 1 + # RIGHT = 2 + # STRAIGHT = 3 + # LANEFOLLOW = 4 + # CHANGELANELEFT = 5 + # CHANGELANERIGHT = 6 + if command < 0: + command = 4 + command -= 1 + assert command in [0, 1, 2, 3, 4, 5] + cmd_one_hot = [0] * 6 + cmd_one_hot[command] = 1 + data['target_command'] = torch.tensor(cmd_one_hot) + + self._batch_read_number += 1 + return data + + +def scale_and_crop_image(image, scale=1, crop_w=256, crop_h=256): + """ + Scale and crop a PIL image + """ + (width, height) = (int(image.width // scale), int(image.height // scale)) + im_resized = image.resize((width, height)) + start_x = height//2 - crop_h//2 + start_y = width//2 - crop_w//2 + cropped_image = im_resized.crop((start_y, start_x, start_y+crop_w, start_x+crop_h)) + + # cropped_image = image[start_x:start_x+crop, start_y:start_y+crop] + # cropped_image = np.transpose(cropped_image, (2,0,1)) + return cropped_image + + +def transform_2d_points(xyz, r1, t1_x, t1_y, r2, t2_x, t2_y): + """ + Build a rotation matrix and take the dot product. + """ + # z value to 1 for rotation + xy1 = xyz.copy() + xy1[:,2] = 1 + + c, s = np.cos(r1), np.sin(r1) + r1_to_world = np.matrix([[c, s, t1_x], [-s, c, t1_y], [0, 0, 1]]) + + # np.dot converts to a matrix, so we explicitly change it back to an array + world = np.asarray(r1_to_world @ xy1.T) + + c, s = np.cos(r2), np.sin(r2) + r2_to_world = np.matrix([[c, s, t2_x], [-s, c, t2_y], [0, 0, 1]]) + world_to_r2 = np.linalg.inv(r2_to_world) + + out = np.asarray(world_to_r2 @ world).T + + # reset z-coordinate + out[:,2] = xyz[:,2] + + return out + +def rot_to_mat(roll, pitch, yaw): + roll = np.deg2rad(roll) + pitch = np.deg2rad(pitch) + yaw = np.deg2rad(yaw) + + yaw_matrix = np.array([ + [np.cos(yaw), -np.sin(yaw), 0], + [np.sin(yaw), np.cos(yaw), 0], + [0, 0, 1] + ]) + pitch_matrix = np.array([ + [np.cos(pitch), 0, -np.sin(pitch)], + [0, 1, 0], + [np.sin(pitch), 0, np.cos(pitch)] + ]) + roll_matrix = np.array([ + [1, 0, 0], + [0, np.cos(roll), np.sin(roll)], + [0, -np.sin(roll), np.cos(roll)] + ]) + + rotation_matrix = yaw_matrix.dot(pitch_matrix).dot(roll_matrix) + return rotation_matrix + + +def vec_global_to_ref(target_vec_in_global, ref_rot_in_global): + R = rot_to_mat(ref_rot_in_global['roll'], ref_rot_in_global['pitch'], ref_rot_in_global['yaw']) + np_vec_in_global = np.array([[target_vec_in_global[0]], + [target_vec_in_global[1]], + [target_vec_in_global[2]]]) + np_vec_in_ref = R.T.dot(np_vec_in_global) + return np_vec_in_ref[:,0] + +def get_action_beta(alpha, beta): + x = torch.zeros_like(alpha) + x[:, 1] += 0.5 + mask1 = (alpha > 1) & (beta > 1) + x[mask1] = (alpha[mask1]-1)/(alpha[mask1]+beta[mask1]-2) + + mask2 = (alpha <= 1) & (beta > 1) + x[mask2] = 0.0 + + mask3 = (alpha > 1) & (beta <= 1) + x[mask3] = 1.0 + + # mean + mask4 = (alpha <= 1) & (beta <= 1) + x[mask4] = alpha[mask4]/(alpha[mask4]+beta[mask4]) + + x = x * 2 - 1 + + return x + + + \ No newline at end of file diff --git a/PolarPointBEV/data_pretrain.py b/PolarPointBEV/data_pretrain.py new file mode 100644 index 0000000..37efc5a --- /dev/null +++ b/PolarPointBEV/data_pretrain.py @@ -0,0 +1,269 @@ +import os +from PIL import Image +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision import transforms as T + +from PolarPointBEV.augment import hard as augmenter + + +class CARLA_Data(Dataset): + + def __init__(self, root, data_folders, img_aug=False): + self.root = root + self.img_aug = img_aug + self._batch_read_number = 0 + + self.front_img = [] + self.x = [] + self.y = [] + self.command = [] + self.target_command = [] + self.target_gps = [] + self.theta = [] + self.speed = [] + + self.value = [] + self.feature = [] + self.action = [] + self.action_mu = [] + self.action_sigma = [] + + self.future_x = [] + self.future_y = [] + self.future_theta = [] + + self.future_feature = [] + self.future_action = [] + self.future_action_mu = [] + self.future_action_sigma = [] + self.future_only_ap_brake = [] + + self.x_command = [] + self.y_command = [] + self.command = [] + self.only_ap_brake = [] + + for sub_root in data_folders: + data = np.load(os.path.join(sub_root, "packed_data.npy"), allow_pickle=True).item() + + self.x_command += data['x_target'] + self.y_command += data['y_target'] + self.command += data['target_command'] + + self.front_img += data['front_img'] + self.x += data['input_x'] + self.y += data['input_y'] + self.theta += data['input_theta'] + self.speed += data['speed'] + + self.future_x += data['future_x'] + self.future_y += data['future_y'] + self.future_theta += data['future_theta'] + + self.future_feature += data['future_feature'] + self.future_action += data['future_action'] + self.future_action_mu += data['future_action_mu'] + self.future_action_sigma += data['future_action_sigma'] + self.future_only_ap_brake += data['future_only_ap_brake'] + + self.value += data['value'] + self.feature += data['feature'] + self.action += data['action'] + self.action_mu += data['action_mu'] + self.action_sigma += data['action_sigma'] + self.only_ap_brake += data['only_ap_brake'] + self._im_transform = T.Compose( + [T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) + + def __len__(self): + """Returns the length of the dataset. """ + return len(self.front_img) + + def __getitem__(self, index): + """Returns the item at index idx. """ + data = dict() + data['front_img'] = self.front_img[index] + + if self.img_aug: + data['front_img'] = self._im_transform(augmenter(self._batch_read_number).augment_image(np.array( + Image.open(self.root + self.front_img[index][0])))) + else: + data['front_img'] = self._im_transform(np.array( + Image.open(self.root + self.front_img[index][0]))) + + # fix for theta=nan in some measurements + if np.isnan(self.theta[index][0]): + self.theta[index][0] = 0. + + ego_x = self.x[index][0] + ego_y = self.y[index][0] + ego_theta = self.theta[index][0] + + waypoints = [] + for i in range(4): + R = np.array([ + [np.cos(np.pi / 2 + ego_theta), -np.sin(np.pi / 2 + ego_theta)], + [np.sin(np.pi / 2 + ego_theta), np.cos(np.pi / 2 + ego_theta)] + ]) + local_command_point = np.array([self.future_y[index][i] - ego_y, self.future_x[index][i] - ego_x]) + local_command_point = R.T.dot(local_command_point) + waypoints.append(local_command_point) + + data['waypoints'] = np.array(waypoints) + + data['action'] = self.action[index] + data['action_mu'] = self.action_mu[index] + data['action_sigma'] = self.action_sigma[index] + + future_only_ap_brake = self.future_only_ap_brake[index] + future_action_mu = self.future_action_mu[index] + future_action_sigma = self.future_action_sigma[index] + + # use the average value of roach braking action when the brake is only performed by the rule-based detector + for i in range(len(future_only_ap_brake)): + if future_only_ap_brake[i]: + future_action_mu[i][0] = 0.8 + future_action_sigma[i][0] = 5.5 + data['future_action_mu'] = future_action_mu + data['future_action_sigma'] = future_action_sigma + data['future_feature'] = self.future_feature[index] + + only_ap_brake = self.only_ap_brake[index] + if only_ap_brake: + data['action_mu'][0] = 0.8 + data['action_sigma'][0] = 5.5 + + R = np.array([ + [np.cos(np.pi / 2 + ego_theta), -np.sin(np.pi / 2 + ego_theta)], + [np.sin(np.pi / 2 + ego_theta), np.cos(np.pi / 2 + ego_theta)] + ]) + local_command_point = np.array([-1 * (self.x_command[index] - ego_x), self.y_command[index] - ego_y]) + local_command_point = R.T.dot(local_command_point) + data['target_point'] = local_command_point[:2] + + local_command_point_aim = np.array([(self.y_command[index] - ego_y), self.x_command[index] - ego_x]) + local_command_point_aim = R.T.dot(local_command_point_aim) + data['target_point_aim'] = local_command_point_aim[:2] + + data['target_point'] = local_command_point_aim[:2] + + data['speed'] = self.speed[index] + data['feature'] = self.feature[index] + data['value'] = self.value[index] + command = self.command[index] + + # VOID = -1 + # LEFT = 1 + # RIGHT = 2 + # STRAIGHT = 3 + # LANEFOLLOW = 4 + # CHANGELANELEFT = 5 + # CHANGELANERIGHT = 6 + if command < 0: + command = 4 + command -= 1 + assert command in [0, 1, 2, 3, 4, 5] + cmd_one_hot = [0] * 6 + cmd_one_hot[command] = 1 + data['target_command'] = torch.tensor(cmd_one_hot) + + self._batch_read_number += 1 + return data + + +def scale_and_crop_image(image, scale=1, crop_w=256, crop_h=256): + """ + Scale and crop a PIL image + """ + (width, height) = (int(image.width // scale), int(image.height // scale)) + im_resized = image.resize((width, height)) + start_x = height // 2 - crop_h // 2 + start_y = width // 2 - crop_w // 2 + cropped_image = im_resized.crop((start_y, start_x, start_y + crop_w, start_x + crop_h)) + + # cropped_image = image[start_x:start_x+crop, start_y:start_y+crop] + # cropped_image = np.transpose(cropped_image, (2,0,1)) + return cropped_image + + +def transform_2d_points(xyz, r1, t1_x, t1_y, r2, t2_x, t2_y): + """ + Build a rotation matrix and take the dot product. + """ + # z value to 1 for rotation + xy1 = xyz.copy() + xy1[:, 2] = 1 + + c, s = np.cos(r1), np.sin(r1) + r1_to_world = np.matrix([[c, s, t1_x], [-s, c, t1_y], [0, 0, 1]]) + + # np.dot converts to a matrix, so we explicitly change it back to an array + world = np.asarray(r1_to_world @ xy1.T) + + c, s = np.cos(r2), np.sin(r2) + r2_to_world = np.matrix([[c, s, t2_x], [-s, c, t2_y], [0, 0, 1]]) + world_to_r2 = np.linalg.inv(r2_to_world) + + out = np.asarray(world_to_r2 @ world).T + + # reset z-coordinate + out[:, 2] = xyz[:, 2] + + return out + + +def rot_to_mat(roll, pitch, yaw): + roll = np.deg2rad(roll) + pitch = np.deg2rad(pitch) + yaw = np.deg2rad(yaw) + + yaw_matrix = np.array([ + [np.cos(yaw), -np.sin(yaw), 0], + [np.sin(yaw), np.cos(yaw), 0], + [0, 0, 1] + ]) + pitch_matrix = np.array([ + [np.cos(pitch), 0, -np.sin(pitch)], + [0, 1, 0], + [np.sin(pitch), 0, np.cos(pitch)] + ]) + roll_matrix = np.array([ + [1, 0, 0], + [0, np.cos(roll), np.sin(roll)], + [0, -np.sin(roll), np.cos(roll)] + ]) + + rotation_matrix = yaw_matrix.dot(pitch_matrix).dot(roll_matrix) + return rotation_matrix + + +def vec_global_to_ref(target_vec_in_global, ref_rot_in_global): + R = rot_to_mat(ref_rot_in_global['roll'], ref_rot_in_global['pitch'], ref_rot_in_global['yaw']) + np_vec_in_global = np.array([[target_vec_in_global[0]], + [target_vec_in_global[1]], + [target_vec_in_global[2]]]) + np_vec_in_ref = R.T.dot(np_vec_in_global) + return np_vec_in_ref[:, 0] + + +def get_action_beta(alpha, beta): + x = torch.zeros_like(alpha) + x[:, 1] += 0.5 + mask1 = (alpha > 1) & (beta > 1) + x[mask1] = (alpha[mask1] - 1) / (alpha[mask1] + beta[mask1] - 2) + + mask2 = (alpha <= 1) & (beta > 1) + x[mask2] = 0.0 + + mask3 = (alpha > 1) & (beta <= 1) + x[mask3] = 1.0 + + # mean + mask4 = (alpha <= 1) & (beta <= 1) + x[mask4] = alpha[mask4] / (alpha[mask4] + beta[mask4]) + + x = x * 2 - 1 + + return x diff --git a/PolarPointBEV/metric.py b/PolarPointBEV/metric.py new file mode 100644 index 0000000..272654e --- /dev/null +++ b/PolarPointBEV/metric.py @@ -0,0 +1,92 @@ +from sklearn.metrics import f1_score, classification_report +import torch +import numpy as np +import json + + +def List2List(List): + Arr1 = np.array(List[:-1]).reshape(-1, List[0].shape[1]) + Arr2 = np.array(List[-1]).reshape(-1, List[0].shape[1]) + Arr = np.vstack((Arr1, Arr2)) + + return [i for item in Arr for i in item] + +def Gragh_Metric(total_gt_graph, total_pred_graph): + graph_num = len(total_gt_graph) + gt_graph_list = [] + pred_graph_list = [] + + for idx in range(graph_num): + gt_graph = total_gt_graph[idx].cpu().numpy() + pred_graph = torch.argmax(total_pred_graph[idx], dim=1) + pred_graph = pred_graph.cpu().numpy() + gt_graph_list.append(gt_graph) + pred_graph_list.append(pred_graph) + + + gt_graph_list = List2List(gt_graph_list) + pred_graph_list = List2List(pred_graph_list) + + # Overall F1 for graph + f1_graph_overall = f1_score(gt_graph_list, pred_graph_list, average='micro') + # Mean F1 for graph + f1_graph_mean = f1_score(gt_graph_list, pred_graph_list, average='macro') + + return f1_graph_overall, classification_report(gt_graph_list, pred_graph_list), f1_graph_mean + + +def Cal_IoU(total_gt_bev, total_pred_bev): + confmat = ConfusionMatrix(3) + for idx in range(len(total_gt_bev)): + confmat.update(total_gt_bev[idx].flatten(), total_pred_bev[idx].argmax(1).flatten()) + confmat.reduce_from_all_processes() + + return confmat + +class ConfusionMatrix(object): + def __init__(self, num_classes): + self.num_classes = num_classes + self.mat = None + + def update(self, a, b): + n = self.num_classes + if self.mat is None: + # 创建混淆矩阵 + self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) + with torch.no_grad(): + # 寻找GT中为目标的像素索引 + k = (a >= 0) & (a < n) + # 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙) + inds = n * a[k].to(torch.int64) + b[k] + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + + def reset(self): + if self.mat is not None: + self.mat.zero_() + + def compute(self): + h = self.mat.float() + acc_global = torch.diag(h).sum() / h.sum() + acc = torch.diag(h) / h.sum(1) + iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) + return acc_global, acc, iu + + def reduce_from_all_processes(self): + if not torch.distributed.is_available(): + return + if not torch.distributed.is_initialized(): + return + torch.distributed.barrier() + torch.distributed.all_reduce(self.mat) + + def __str__(self): + acc_global, acc, iu = self.compute() + return ( + 'global correct: {:.1f}\n' + 'average row correct: {}\n' + 'IoU: {}\n' + 'mean IoU: {:.1f}').format( + acc_global.item() * 100, + ['{:.1f}'.format(i) for i in (acc * 100).tolist()], + ['{:.1f}'.format(i) for i in (iu * 100).tolist()], + iu.mean().item() * 100) \ No newline at end of file diff --git a/PolarPointBEV/model.py b/PolarPointBEV/model.py new file mode 100644 index 0000000..1a89a0e --- /dev/null +++ b/PolarPointBEV/model.py @@ -0,0 +1,348 @@ +from collections import deque +import numpy as np +import torch +from torch import nn +from PolarPointBEV.resnet import * +from PolarPointBEV.modules import Graph_Pred + +class PIDController(object): + def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, n=20): + self._K_P = K_P + self._K_I = K_I + self._K_D = K_D + + self._window = deque([0 for _ in range(n)], maxlen=n) + self._max = 0.0 + self._min = 0.0 + + def step(self, error): + self._window.append(error) + self._max = max(self._max, abs(error)) + self._min = -abs(self._max) + + if len(self._window) >= 2: + integral = np.mean(self._window) + derivative = (self._window[-1] - self._window[-2]) + else: + integral = 0.0 + derivative = 0.0 + + return self._K_P * error + self._K_I * integral + self._K_D * derivative + +class XPlan(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + + self.turn_controller = PIDController(K_P=config.turn_KP, K_I=config.turn_KI, K_D=config.turn_KD, n=config.turn_n) + self.speed_controller = PIDController(K_P=config.speed_KP, K_I=config.speed_KI, K_D=config.speed_KD, n=config.speed_n) + + self.perception = resnet34(pretrained=True) + + self.graph_predictor = Graph_Pred(num_views=1, num_class=3, output_size=[16, 57], + map_extents=[-15., 0., 15., 30.], map_resolution=0.2) + + + self.measurements = nn.Sequential( + nn.Linear(1+2+6, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 128), + nn.ReLU(inplace=True), + ) + + self.join_traj = nn.Sequential( + nn.Linear(128+1000, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 256), + nn.ReLU(inplace=True), + ) + + self.join_ctrl = nn.Sequential( + nn.Linear(128+512, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 256), + nn.ReLU(inplace=True), + ) + + self.speed_branch = nn.Sequential( + nn.Linear(1000, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + + self.value_branch_traj = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + self.value_branch_ctrl = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + # shared branches_neurons + dim_out = 2 + + self.policy_head = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + ) + self.decoder_ctrl = nn.GRUCell(input_size=256+4, hidden_size=256) + self.output_ctrl = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.ReLU(inplace=True), + ) + self.dist_mu = nn.Sequential(nn.Linear(256, dim_out), nn.Softplus()) + self.dist_sigma = nn.Sequential(nn.Linear(256, dim_out), nn.Softplus()) + + + self.decoder_traj = nn.GRUCell(input_size=4, hidden_size=256) + self.output_traj = nn.Linear(256, 2) + + self.init_att = nn.Sequential( + nn.Linear(128, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 29*8), + nn.Softmax(1) + ) + + self.wp_att = nn.Sequential( + nn.Linear(256+256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 29*8), + nn.Softmax(1) + ) + + self.merge = nn.Sequential( + nn.Linear(512+256, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 256), + ) + + + def forward(self, img, state, target_point): + feature_emb, cnn_feature, bev_fearture = self.perception(img) + # feature_emb2, cnn_feature2, bev_fearture2 = self.perception(img) + + outputs = {} + outputs['graph'] = self.graph_predictor(bev_fearture) + + outputs['pred_speed'] = self.speed_branch(feature_emb) + measurement_feature = self.measurements(state) + + j_traj = self.join_traj(torch.cat([feature_emb, measurement_feature], 1)) + outputs['pred_value_traj'] = self.value_branch_traj(j_traj) + outputs['pred_features_traj'] = j_traj + z = j_traj + output_wp = list() + traj_hidden_state = list() + + # initial input variable to GRU + x = torch.zeros(size=(z.shape[0], 2), dtype=z.dtype).type_as(z) + + # autoregressive generation of output waypoints + for _ in range(self.config.pred_len): + x_in = torch.cat([x, target_point], dim=1) + z = self.decoder_traj(x_in, z) + traj_hidden_state.append(z) + dx = self.output_traj(z) + x = dx + x + output_wp.append(x) + + pred_wp = torch.stack(output_wp, dim=1) + outputs['pred_wp'] = pred_wp + + traj_hidden_state = torch.stack(traj_hidden_state, dim=1) + init_att = self.init_att(measurement_feature).view(-1, 1, 8, 29) + feature_emb = torch.sum(cnn_feature*init_att, dim=(2, 3)) + j_ctrl = self.join_ctrl(torch.cat([feature_emb, measurement_feature], 1)) + outputs['pred_value_ctrl'] = self.value_branch_ctrl(j_ctrl) + outputs['pred_features_ctrl'] = j_ctrl + policy = self.policy_head(j_ctrl) + outputs['mu_branches'] = self.dist_mu(policy) + outputs['sigma_branches'] = self.dist_sigma(policy) + + x = j_ctrl + mu = outputs['mu_branches'] + sigma = outputs['sigma_branches'] + future_feature, future_mu, future_sigma = [], [], [] + + # initial hidden variable to GRU + h = torch.zeros(size=(x.shape[0], 256), dtype=x.dtype).type_as(x) + + for _ in range(self.config.pred_len): + x_in = torch.cat([x, mu, sigma], dim=1) + h = self.decoder_ctrl(x_in, h) + wp_att = self.wp_att(torch.cat([h, traj_hidden_state[:, _]], 1)).view(-1, 1, 8, 29) + new_feature_emb = torch.sum(cnn_feature*wp_att, dim=(2, 3)) + merged_feature = self.merge(torch.cat([h, new_feature_emb], 1)) + dx = self.output_ctrl(merged_feature) + x = dx + x + + policy = self.policy_head(x) + mu = self.dist_mu(policy) + sigma = self.dist_sigma(policy) + future_feature.append(x) + future_mu.append(mu) + future_sigma.append(sigma) + + + outputs['future_feature'] = future_feature + outputs['future_mu'] = future_mu + outputs['future_sigma'] = future_sigma + return outputs + + def process_action(self, pred, command, speed, target_point): + action = self._get_action_beta(pred['mu_branches'].view(1,2), pred['sigma_branches'].view(1,2)) + acc, steer = action.cpu().numpy()[0].astype(np.float64) + if acc >= 0.0: + throttle = acc + brake = 0.0 + else: + throttle = 0.0 + brake = np.abs(acc) + + throttle = np.clip(throttle, 0, 1) + steer = np.clip(steer, -1, 1) + brake = np.clip(brake, 0, 1) + + metadata = { + 'speed': float(speed.cpu().numpy().astype(np.float64)), + 'steer': float(steer), + 'throttle': float(throttle), + 'brake': float(brake), + 'command': command, + 'target_point': tuple(target_point[0].data.cpu().numpy().astype(np.float64)), + } + return steer, throttle, brake, metadata + + def _get_action_beta(self, alpha, beta): + x = torch.zeros_like(alpha) + x[:, 1] += 0.5 + mask1 = (alpha > 1) & (beta > 1) + x[mask1] = (alpha[mask1]-1)/(alpha[mask1]+beta[mask1]-2) + + mask2 = (alpha <= 1) & (beta > 1) + x[mask2] = 0.0 + + mask3 = (alpha > 1) & (beta <= 1) + x[mask3] = 1.0 + + # mean + mask4 = (alpha <= 1) & (beta <= 1) + x[mask4] = alpha[mask4]/torch.clamp((alpha[mask4]+beta[mask4]), min=1e-5) + + x = x * 2 - 1 + + return x + + def control_pid(self, waypoints, velocity, target): + ''' Predicts vehicle control with a PID controller. + Args: + waypoints (tensor): output of self.plan() + velocity (tensor): speedometer input + ''' + assert(waypoints.size(0)==1) + waypoints = waypoints[0].data.cpu().numpy() + target = target.squeeze().data.cpu().numpy() + + # flip y (forward is negative in our waypoints) + waypoints[:,1] *= -1 + target[1] *= -1 + + # iterate over vectors between predicted waypoints + num_pairs = len(waypoints) - 1 + best_norm = 1e5 + desired_speed = 0 + aim = waypoints[0] + for i in range(num_pairs): + # magnitude of vectors, used for speed + desired_speed += np.linalg.norm( + waypoints[i+1] - waypoints[i]) * 2.0 / num_pairs + + # norm of vector midpoints, used for steering + norm = np.linalg.norm((waypoints[i+1] + waypoints[i]) / 2.0) + if abs(self.config.aim_dist-best_norm) > abs(self.config.aim_dist-norm): + aim = waypoints[i] + best_norm = norm + + aim_last = waypoints[-1] - waypoints[-2] + + angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90 + angle_last = np.degrees(np.pi / 2 - np.arctan2(aim_last[1], aim_last[0])) / 90 + angle_target = np.degrees(np.pi / 2 - np.arctan2(target[1], target[0])) / 90 + + use_target_to_aim = np.abs(angle_target) < np.abs(angle) + use_target_to_aim = use_target_to_aim or (np.abs(angle_target-angle_last) > self.config.angle_thresh and target[1] < self.config.dist_thresh) + if use_target_to_aim: + angle_final = angle_target + else: + angle_final = angle + + steer = self.turn_controller.step(angle_final) + steer = np.clip(steer, -1.0, 1.0) + + speed = velocity[0].data.cpu().numpy() + brake = desired_speed < self.config.brake_speed or (speed / desired_speed) > self.config.brake_ratio + + delta = np.clip(desired_speed - speed, 0.0, self.config.clip_delta) + throttle = self.speed_controller.step(delta) + throttle = np.clip(throttle, 0.0, self.config.max_throttle) + throttle = throttle if not brake else 0.0 + + metadata = { + 'speed': float(speed.astype(np.float64)), + 'steer': float(steer), + 'throttle': float(throttle), + 'brake': float(brake), + 'wp_4': tuple(waypoints[3].astype(np.float64)), + 'wp_3': tuple(waypoints[2].astype(np.float64)), + 'wp_2': tuple(waypoints[1].astype(np.float64)), + 'wp_1': tuple(waypoints[0].astype(np.float64)), + 'aim': tuple(aim.astype(np.float64)), + 'target': tuple(target.astype(np.float64)), + 'desired_speed': float(desired_speed.astype(np.float64)), + 'angle': float(angle.astype(np.float64)), + 'angle_last': float(angle_last.astype(np.float64)), + 'angle_target': float(angle_target.astype(np.float64)), + 'angle_final': float(angle_final.astype(np.float64)), + 'delta': float(delta.astype(np.float64)), + } + + return steer, throttle, brake, metadata + + + def get_action(self, mu, sigma): + action = self._get_action_beta(mu.view(1,2), sigma.view(1,2)) + acc, steer = action[:, 0], action[:, 1] + if acc >= 0.0: + throttle = acc + brake = torch.zeros_like(acc) + else: + throttle = torch.zeros_like(acc) + brake = torch.abs(acc) + + throttle = torch.clamp(throttle, 0, 1) + steer = torch.clamp(steer, -1, 1) + brake = torch.clamp(brake, 0, 1) + + return throttle, steer, brake \ No newline at end of file diff --git a/PolarPointBEV/model_pretrain.py b/PolarPointBEV/model_pretrain.py new file mode 100644 index 0000000..1f078b9 --- /dev/null +++ b/PolarPointBEV/model_pretrain.py @@ -0,0 +1,345 @@ +from collections import deque +import numpy as np +import torch +from torch import nn +from PolarPointBEV.resnet_original import * + + +class PIDController(object): + def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, n=20): + self._K_P = K_P + self._K_I = K_I + self._K_D = K_D + + self._window = deque([0 for _ in range(n)], maxlen=n) + self._max = 0.0 + self._min = 0.0 + + def step(self, error): + self._window.append(error) + self._max = max(self._max, abs(error)) + self._min = -abs(self._max) + + if len(self._window) >= 2: + integral = np.mean(self._window) + derivative = (self._window[-1] - self._window[-2]) + else: + integral = 0.0 + derivative = 0.0 + + return self._K_P * error + self._K_I * integral + self._K_D * derivative + + +class TCP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + + self.turn_controller = PIDController(K_P=config.turn_KP, K_I=config.turn_KI, K_D=config.turn_KD, + n=config.turn_n) + self.speed_controller = PIDController(K_P=config.speed_KP, K_I=config.speed_KI, K_D=config.speed_KD, + n=config.speed_n) + + self.perception = resnet34(pretrained=True) + + self.measurements = nn.Sequential( + nn.Linear(1 + 2 + 6, 128), + nn.ReLU(inplace=True), + nn.Linear(128, 128), + nn.ReLU(inplace=True), + ) + + self.join_traj = nn.Sequential( + nn.Linear(128 + 1000, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 256), + nn.ReLU(inplace=True), + ) + + self.join_ctrl = nn.Sequential( + nn.Linear(128 + 512, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 256), + nn.ReLU(inplace=True), + ) + + self.speed_branch = nn.Sequential( + nn.Linear(1000, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + + self.value_branch_traj = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + self.value_branch_ctrl = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + # shared branches_neurons + dim_out = 2 + + self.policy_head = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.Dropout2d(p=0.5), + nn.ReLU(inplace=True), + ) + self.decoder_ctrl = nn.GRUCell(input_size=256 + 4, hidden_size=256) + self.output_ctrl = nn.Sequential( + nn.Linear(256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 256), + nn.ReLU(inplace=True), + ) + self.dist_mu = nn.Sequential(nn.Linear(256, dim_out), nn.Softplus()) + self.dist_sigma = nn.Sequential(nn.Linear(256, dim_out), nn.Softplus()) + + self.decoder_traj = nn.GRUCell(input_size=4, hidden_size=256) + self.output_traj = nn.Linear(256, 2) + + self.init_att = nn.Sequential( + nn.Linear(128, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 29 * 8), + nn.Softmax(1) + ) + + self.wp_att = nn.Sequential( + nn.Linear(256 + 256, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 29 * 8), + nn.Softmax(1) + ) + + self.merge = nn.Sequential( + nn.Linear(512 + 256, 512), + nn.ReLU(inplace=True), + nn.Linear(512, 256), + ) + + def forward(self, img, state, target_point): + feature_emb, cnn_feature = self.perception(img) + outputs = {} + outputs['pred_speed'] = self.speed_branch(feature_emb) + measurement_feature = self.measurements(state) + + j_traj = self.join_traj(torch.cat([feature_emb, measurement_feature], 1)) + outputs['pred_value_traj'] = self.value_branch_traj(j_traj) + outputs['pred_features_traj'] = j_traj + z = j_traj + output_wp = list() + traj_hidden_state = list() + + # initial input variable to GRU + x = torch.zeros(size=(z.shape[0], 2), dtype=z.dtype).type_as(z) + + # autoregressive generation of output waypoints + for _ in range(self.config.pred_len): + x_in = torch.cat([x, target_point], dim=1) + z = self.decoder_traj(x_in, z) + traj_hidden_state.append(z) + dx = self.output_traj(z) + x = dx + x + output_wp.append(x) + + pred_wp = torch.stack(output_wp, dim=1) + outputs['pred_wp'] = pred_wp + + traj_hidden_state = torch.stack(traj_hidden_state, dim=1) + init_att = self.init_att(measurement_feature).view(-1, 1, 8, 29) + feature_emb = torch.sum(cnn_feature * init_att, dim=(2, 3)) + j_ctrl = self.join_ctrl(torch.cat([feature_emb, measurement_feature], 1)) + outputs['pred_value_ctrl'] = self.value_branch_ctrl(j_ctrl) + outputs['pred_features_ctrl'] = j_ctrl + policy = self.policy_head(j_ctrl) + outputs['mu_branches'] = self.dist_mu(policy) + outputs['sigma_branches'] = self.dist_sigma(policy) + + x = j_ctrl + mu = outputs['mu_branches'] + sigma = outputs['sigma_branches'] + future_feature, future_mu, future_sigma = [], [], [] + + # initial hidden variable to GRU + h = torch.zeros(size=(x.shape[0], 256), dtype=x.dtype).type_as(x) + + for _ in range(self.config.pred_len): + x_in = torch.cat([x, mu, sigma], dim=1) + h = self.decoder_ctrl(x_in, h) + wp_att = self.wp_att(torch.cat([h, traj_hidden_state[:, _]], 1)).view(-1, 1, 8, 29) + new_feature_emb = torch.sum(cnn_feature * wp_att, dim=(2, 3)) + merged_feature = self.merge(torch.cat([h, new_feature_emb], 1)) + dx = self.output_ctrl(merged_feature) + x = dx + x + + policy = self.policy_head(x) + mu = self.dist_mu(policy) + sigma = self.dist_sigma(policy) + future_feature.append(x) + future_mu.append(mu) + future_sigma.append(sigma) + + outputs['future_feature'] = future_feature + outputs['future_mu'] = future_mu + outputs['future_sigma'] = future_sigma + return outputs + + def process_action(self, pred, command, speed, target_point): + action = self._get_action_beta(pred['mu_branches'].view(1, 2), pred['sigma_branches'].view(1, 2)) + acc, steer = action.cpu().numpy()[0].astype(np.float64) + if acc >= 0.0: + throttle = acc + brake = 0.0 + else: + throttle = 0.0 + brake = np.abs(acc) + + throttle = np.clip(throttle, 0, 1) + steer = np.clip(steer, -1, 1) + brake = np.clip(brake, 0, 1) + + metadata = { + 'speed': float(speed.cpu().numpy().astype(np.float64)), + 'steer': float(steer), + 'throttle': float(throttle), + 'brake': float(brake), + 'command': command, + 'target_point': tuple(target_point[0].data.cpu().numpy().astype(np.float64)), + } + return steer, throttle, brake, metadata + + def _get_action_beta(self, alpha, beta): + x = torch.zeros_like(alpha) + x[:, 1] += 0.5 + mask1 = (alpha > 1) & (beta > 1) + x[mask1] = (alpha[mask1] - 1) / (alpha[mask1] + beta[mask1] - 2) + + mask2 = (alpha <= 1) & (beta > 1) + x[mask2] = 0.0 + + mask3 = (alpha > 1) & (beta <= 1) + x[mask3] = 1.0 + + # mean + mask4 = (alpha <= 1) & (beta <= 1) + x[mask4] = alpha[mask4] / torch.clamp((alpha[mask4] + beta[mask4]), min=1e-5) + + x = x * 2 - 1 + + return x + + def control_pid(self, waypoints, velocity, target): + ''' Predicts vehicle control with a PID controller. + Args: + waypoints (tensor): output of self.plan() + velocity (tensor): speedometer input + ''' + assert (waypoints.size(0) == 1) + waypoints = waypoints[0].data.cpu().numpy() + target = target.squeeze().data.cpu().numpy() + + # flip y (forward is negative in our waypoints) + waypoints[:, 1] *= -1 + target[1] *= -1 + + # iterate over vectors between predicted waypoints + num_pairs = len(waypoints) - 1 + best_norm = 1e5 + desired_speed = 0 + aim = waypoints[0] + for i in range(num_pairs): + # magnitude of vectors, used for speed + desired_speed += np.linalg.norm( + waypoints[i + 1] - waypoints[i]) * 2.0 / num_pairs + + # norm of vector midpoints, used for steering + norm = np.linalg.norm((waypoints[i + 1] + waypoints[i]) / 2.0) + if abs(self.config.aim_dist - best_norm) > abs(self.config.aim_dist - norm): + aim = waypoints[i] + best_norm = norm + + aim_last = waypoints[-1] - waypoints[-2] + + angle = np.degrees(np.pi / 2 - np.arctan2(aim[1], aim[0])) / 90 + angle_last = np.degrees(np.pi / 2 - np.arctan2(aim_last[1], aim_last[0])) / 90 + angle_target = np.degrees(np.pi / 2 - np.arctan2(target[1], target[0])) / 90 + + # choice of point to aim for steering, removing outlier predictions + # use target point if it has a smaller angle or if error is large + # predicted point otherwise + # (reduces noise in eg. straight roads, helps with sudden turn commands) + use_target_to_aim = np.abs(angle_target) < np.abs(angle) + use_target_to_aim = use_target_to_aim or ( + np.abs(angle_target - angle_last) > self.config.angle_thresh and target[ + 1] < self.config.dist_thresh) + if use_target_to_aim: + angle_final = angle_target + else: + angle_final = angle + + steer = self.turn_controller.step(angle_final) + steer = np.clip(steer, -1.0, 1.0) + + speed = velocity[0].data.cpu().numpy() + brake = desired_speed < self.config.brake_speed or (speed / desired_speed) > self.config.brake_ratio + + delta = np.clip(desired_speed - speed, 0.0, self.config.clip_delta) + throttle = self.speed_controller.step(delta) + throttle = np.clip(throttle, 0.0, self.config.max_throttle) + throttle = throttle if not brake else 0.0 + + metadata = { + 'speed': float(speed.astype(np.float64)), + 'steer': float(steer), + 'throttle': float(throttle), + 'brake': float(brake), + 'wp_4': tuple(waypoints[3].astype(np.float64)), + 'wp_3': tuple(waypoints[2].astype(np.float64)), + 'wp_2': tuple(waypoints[1].astype(np.float64)), + 'wp_1': tuple(waypoints[0].astype(np.float64)), + 'aim': tuple(aim.astype(np.float64)), + 'target': tuple(target.astype(np.float64)), + 'desired_speed': float(desired_speed.astype(np.float64)), + 'angle': float(angle.astype(np.float64)), + 'angle_last': float(angle_last.astype(np.float64)), + 'angle_target': float(angle_target.astype(np.float64)), + 'angle_final': float(angle_final.astype(np.float64)), + 'delta': float(delta.astype(np.float64)), + } + + return steer, throttle, brake, metadata + + def get_action(self, mu, sigma): + action = self._get_action_beta(mu.view(1, 2), sigma.view(1, 2)) + acc, steer = action[:, 0], action[:, 1] + if acc >= 0.0: + throttle = acc + brake = torch.zeros_like(acc) + else: + throttle = torch.zeros_like(acc) + brake = torch.abs(acc) + + throttle = torch.clamp(throttle, 0, 1) + steer = torch.clamp(steer, -1, 1) + brake = torch.clamp(brake, 0, 1) + + return throttle, steer, brake diff --git a/PolarPointBEV/modules.py b/PolarPointBEV/modules.py new file mode 100644 index 0000000..1c82271 --- /dev/null +++ b/PolarPointBEV/modules.py @@ -0,0 +1,70 @@ +from torch import nn + + +class TransformModule(nn.Module): + def __init__(self, dim=(37, 60), num_view=1): + super(TransformModule, self).__init__() + self.num_view = num_view + self.dim = dim + self.mat_list = nn.ModuleList() + + for i in range(self.num_view): + fc_transform = nn.Sequential( + nn.Linear(dim[0] * dim[1], dim[0] * dim[1]), + nn.ReLU(), + nn.Linear(dim[0] * dim[1], dim[0] * dim[1]), + nn.ReLU() + ) + self.mat_list += [fc_transform] + + def forward(self, x): + # shape x: B, V, C, H, W + x = x.view(list(x.size()[:3]) + [self.dim[0] * self.dim[1],]) + view_comb = self.mat_list[0](x[:, 0]) + for index in range(x.size(1))[1:]: + view_comb += self.mat_list[index](x[:, index]) + view_comb = view_comb.view(list(view_comb.size()[:2]) + list(self.dim)) + return view_comb + +class Graph_Pred(nn.Module): + def __init__(self, num_views, num_class, output_size, map_extents, map_resolution): + + super(Graph_Pred, self).__init__() + self.num_views = num_views + self.output_size = output_size + + self.seg_size = ( + int((map_extents[3] - map_extents[1]) / map_resolution), + int((map_extents[2] - map_extents[0]) / map_resolution), + ) + + self.transform_module = TransformModule(dim=self.output_size, num_view=self.num_views) + self.decoder = Graph(num_class) + + def forward(self, x, *args): + B, N, C, H, W = x.view([-1, self.num_views, int(x.size()[1] / self.num_views)] \ + + list(x.size()[2:])).size() + + x = x.view( B*N, C, H, W) + x = x.view([B, N] + list(x.size()[1:])) + x = self.transform_module(x) + x = self.decoder(x) + return x + +class Graph(nn.Module): + def __init__(self, num_class=3): + super(Graph, self).__init__() + self.pre_class = nn.Sequential( + nn.AdaptiveAvgPool2d((16, 27)), + # (16, 27) for normal; (16, 15) for sparse; (16, 21) for light; (16, 33) for thick; (16, 41) for dense + nn.Conv2d(256, 512, kernel_size=1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + def forward(self, x): + x = self.pre_class(x) + return x + + diff --git a/PolarPointBEV/pre_train.py b/PolarPointBEV/pre_train.py new file mode 100644 index 0000000..8a8f19a --- /dev/null +++ b/PolarPointBEV/pre_train.py @@ -0,0 +1,211 @@ +import argparse +import os +from collections import OrderedDict + +import torch +import torch.optim as optim +from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch.distributions import Beta + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins import DDPPlugin + +from PolarPointBEV.model_pretrain import TCP +from PolarPointBEV.data_pretrain import CARLA_Data +from PolarPointBEV.config import GlobalConfig + + +class TCP_planner(pl.LightningModule): + def __init__(self, config, lr): + super().__init__() + self.lr = lr + self.config = config + self.model = TCP(config) + self._load_weight() + + counters = 0 + for child in self.model.children(): + counters += 1 + print('xxxxxxxxxx', counters) + + def _load_weight(self): + rl_state_dict = torch.load(self.config.rl_ckpt, map_location='cpu')['policy_state_dict'] + self._load_state_dict(self.model.value_branch_traj, rl_state_dict, 'value_head') + self._load_state_dict(self.model.value_branch_ctrl, rl_state_dict, 'value_head') + self._load_state_dict(self.model.dist_mu, rl_state_dict, 'dist_mu') + self._load_state_dict(self.model.dist_sigma, rl_state_dict, 'dist_sigma') + + def _load_state_dict(self, il_net, rl_state_dict, key_word): + rl_keys = [k for k in rl_state_dict.keys() if key_word in k] + il_keys = il_net.state_dict().keys() + assert len(rl_keys) == len(il_net.state_dict().keys()), f'mismatch number of layers loading {key_word}' + new_state_dict = OrderedDict() + for k_il, k_rl in zip(il_keys, rl_keys): + new_state_dict[k_il] = rl_state_dict[k_rl] + il_net.load_state_dict(new_state_dict) + + def forward(self, batch): + pass + + def training_step(self, batch, batch_idx): + front_img = batch['front_img'] + speed = batch['speed'].to(dtype=torch.float32).view(-1, 1) / 12. + target_point = batch['target_point'].to(dtype=torch.float32) + command = batch['target_command'] + + state = torch.cat([speed, target_point, command], 1) + value = batch['value'].view(-1, 1) + feature = batch['feature'] + + gt_waypoints = batch['waypoints'] + + pred = self.model(front_img, state, target_point) + + dist_sup = Beta(batch['action_mu'], batch['action_sigma']) + dist_pred = Beta(pred['mu_branches'], pred['sigma_branches']) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + action_loss = torch.mean(kl_div[:, 0]) * 0.5 + torch.mean(kl_div[:, 1]) * 0.5 + speed_loss = F.l1_loss(pred['pred_speed'], speed) * self.config.speed_weight + value_loss = (F.mse_loss(pred['pred_value_traj'], value) + F.mse_loss(pred['pred_value_ctrl'], + value)) * self.config.value_weight + feature_loss = (F.mse_loss(pred['pred_features_traj'], feature) + F.mse_loss(pred['pred_features_ctrl'], + feature)) * self.config.features_weight + + future_feature_loss = 0 + future_action_loss = 0 + for i in range(self.config.pred_len): + dist_sup = Beta(batch['future_action_mu'][i], batch['future_action_sigma'][i]) + dist_pred = Beta(pred['future_mu'][i], pred['future_sigma'][i]) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + future_action_loss += torch.mean(kl_div[:, 0]) * 0.5 + torch.mean(kl_div[:, 1]) * 0.5 + future_feature_loss += F.mse_loss(pred['future_feature'][i], + batch['future_feature'][i]) * self.config.features_weight + future_feature_loss /= self.config.pred_len + future_action_loss /= self.config.pred_len + wp_loss = F.l1_loss(pred['pred_wp'], gt_waypoints, reduction='none').mean() + loss = action_loss + speed_loss + value_loss + feature_loss + wp_loss + future_feature_loss + future_action_loss + self.log('train_action_loss', action_loss.item()) + self.log('train_wp_loss_loss', wp_loss.item()) + self.log('train_speed_loss', speed_loss.item()) + self.log('train_value_loss', value_loss.item()) + self.log('train_feature_loss', feature_loss.item()) + self.log('train_future_feature_loss', future_feature_loss.item()) + self.log('train_future_action_loss', future_action_loss.item()) + return loss + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-7) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 30, 0.5) + return [optimizer], [lr_scheduler] + + def validation_step(self, batch, batch_idx): + front_img = batch['front_img'] + speed = batch['speed'].to(dtype=torch.float32).view(-1, 1) / 12. + target_point = batch['target_point'].to(dtype=torch.float32) + command = batch['target_command'] + state = torch.cat([speed, target_point, command], 1) + value = batch['value'].view(-1, 1) + feature = batch['feature'] + gt_waypoints = batch['waypoints'] + + pred = self.model(front_img, state, target_point) + dist_sup = Beta(batch['action_mu'], batch['action_sigma']) + dist_pred = Beta(pred['mu_branches'], pred['sigma_branches']) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + action_loss = torch.mean(kl_div[:, 0]) * 0.5 + torch.mean(kl_div[:, 1]) * 0.5 + speed_loss = F.l1_loss(pred['pred_speed'], speed) * self.config.speed_weight + value_loss = (F.mse_loss(pred['pred_value_traj'], value) + F.mse_loss(pred['pred_value_ctrl'], + value)) * self.config.value_weight + feature_loss = (F.mse_loss(pred['pred_features_traj'], feature) + F.mse_loss(pred['pred_features_ctrl'], + feature)) * self.config.features_weight + wp_loss = F.l1_loss(pred['pred_wp'], gt_waypoints, reduction='none').mean() + + B = batch['action_mu'].shape[0] + batch_steer_l1 = 0 + batch_brake_l1 = 0 + batch_throttle_l1 = 0 + for i in range(B): + throttle, steer, brake = self.model.get_action(pred['mu_branches'][i], pred['sigma_branches'][i]) + batch_throttle_l1 += torch.abs(throttle - batch['action'][i][0]) + batch_steer_l1 += torch.abs(steer - batch['action'][i][1]) + batch_brake_l1 += torch.abs(brake - batch['action'][i][2]) + + batch_throttle_l1 /= B + batch_steer_l1 /= B + batch_brake_l1 /= B + + future_feature_loss = 0 + future_action_loss = 0 + for i in range(self.config.pred_len - 1): + dist_sup = Beta(batch['future_action_mu'][i], batch['future_action_sigma'][i]) + dist_pred = Beta(pred['future_mu'][i], pred['future_sigma'][i]) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + future_action_loss += torch.mean(kl_div[:, 0]) * 0.5 + torch.mean(kl_div[:, 1]) * 0.5 + future_feature_loss += F.mse_loss(pred['future_feature'][i], + batch['future_feature'][i]) * self.config.features_weight + future_feature_loss /= self.config.pred_len + future_action_loss /= self.config.pred_len + + val_loss = wp_loss + batch_throttle_l1 + 5 * batch_steer_l1 + batch_brake_l1 + + self.log("val_action_loss", action_loss.item(), sync_dist=True) + self.log('val_speed_loss', speed_loss.item(), sync_dist=True) + self.log('val_value_loss', value_loss.item(), sync_dist=True) + self.log('val_feature_loss', feature_loss.item(), sync_dist=True) + self.log('val_wp_loss_loss', wp_loss.item(), sync_dist=True) + self.log('val_future_feature_loss', future_feature_loss.item(), sync_dist=True) + self.log('val_future_action_loss', future_action_loss.item(), sync_dist=True) + self.log('val_loss', val_loss.item(), sync_dist=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument('--id', type=str, default='TCP_pre', help='Unique experiment identifier.') + parser.add_argument('--epochs', type=int, default=60, help='Number of train epochs.') + parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate.') + parser.add_argument('--val_every', type=int, default=3, help='Validation frequency (epochs).') + parser.add_argument('--batch_size', type=int, default=32, help='Batch size') + parser.add_argument('--logdir', type=str, default='log_pre', help='Directory to log data to.') + parser.add_argument('--gpus', type=int, default=1, help='number of gpus') + + args = parser.parse_args() + args.logdir = os.path.join(args.logdir, args.id) + + # Config + config = GlobalConfig() + + # Data + train_set = CARLA_Data(root=config.root_dir_all, data_folders=config.train_data, img_aug=config.img_aug) + print(len(train_set)) + val_set = CARLA_Data(root=config.root_dir_all, data_folders=config.val_data, ) + print(len(val_set)) + + dataloader_train = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=8) + dataloader_val = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=8) + + TCP_model = TCP_planner(config, args.lr) + + checkpoint_callback = ModelCheckpoint(save_weights_only=False, mode="min", monitor="val_loss", save_top_k=2, + save_last=True, + dirpath=args.logdir, filename="best_{epoch:02d}-{val_loss:.3f}") + checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last" + trainer = pl.Trainer.from_argparse_args(args, + default_root_dir=args.logdir, + gpus=args.gpus, + accelerator='ddp', + sync_batchnorm=True, + plugins=DDPPlugin(find_unused_parameters=False), + profiler='simple', + benchmark=True, + log_every_n_steps=1, + flush_logs_every_n_steps=5, + callbacks=[checkpoint_callback, + ], + check_val_every_n_epoch=args.val_every, + max_epochs=args.epochs + ) + + trainer.fit(TCP_model, dataloader_train, dataloader_val) diff --git a/PolarPointBEV/resnet.py b/PolarPointBEV/resnet.py new file mode 100644 index 0000000..f2f0df4 --- /dev/null +++ b/PolarPointBEV/resnet.py @@ -0,0 +1,392 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.hub import load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + z = x + x_layer4 = self.layer4(x) + + x = self.avgpool(x_layer4) + x = torch.flatten(x, 1) + x = self.fc(x) + + # return x, x_layer4 + return x, x_layer4, z # for bev and graph + + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) \ No newline at end of file diff --git a/PolarPointBEV/resnet_original.py b/PolarPointBEV/resnet_original.py new file mode 100644 index 0000000..de48208 --- /dev/null +++ b/PolarPointBEV/resnet_original.py @@ -0,0 +1,392 @@ +import torch +from torch import Tensor +import torch.nn as nn +from torch.hub import load_state_dict_from_url +from typing import Type, Any, Callable, Union, List, Optional + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, + stride: int = 1, dilate: bool = False) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + z = x + x_layer4 = self.layer4(x) + + x = self.avgpool(x_layer4) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x, x_layer4 + # return x, x_layer4, z # for bev and graph + + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + return model + + +def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) \ No newline at end of file diff --git a/PolarPointBEV/train.py b/PolarPointBEV/train.py new file mode 100644 index 0000000..a8b87f2 --- /dev/null +++ b/PolarPointBEV/train.py @@ -0,0 +1,287 @@ +import argparse +import os +from collections import OrderedDict + +import torch +import torch.optim as optim +from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch.distributions import Beta + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins import DDPPlugin + +from PolarPointBEV.model import XPlan +from PolarPointBEV.data import PolarPoint_Data +from PolarPointBEV.config import GlobalConfig +from PolarPointBEV.metric import Gragh_Metric, Cal_IoU + + +class XPlan_planner(pl.LightningModule): + def __init__(self, config, lr): + super().__init__() + self.lr = lr + self.config = config + self.model = XPlan(config) + self._load_weight() + self.load_ckpt() + self.val_counter = -1 + + def _load_weight(self): + rl_state_dict = torch.load(self.config.rl_ckpt, map_location='cpu')['policy_state_dict'] + self._load_state_dict(self.model.value_branch_traj, rl_state_dict, 'value_head') + self._load_state_dict(self.model.value_branch_ctrl, rl_state_dict, 'value_head') + self._load_state_dict(self.model.dist_mu, rl_state_dict, 'dist_mu') + self._load_state_dict(self.model.dist_sigma, rl_state_dict, 'dist_sigma') + + def _load_state_dict(self, il_net, rl_state_dict, key_word): + rl_keys = [k for k in rl_state_dict.keys() if key_word in k] + il_keys = il_net.state_dict().keys() + assert len(rl_keys) == len(il_net.state_dict().keys()), f'mismatch number of layers loading {key_word}' + new_state_dict = OrderedDict() + for k_il, k_rl in zip(il_keys, rl_keys): + new_state_dict[k_il] = rl_state_dict[k_rl] + il_net.load_state_dict(new_state_dict) + + def load_ckpt(self): + # load the pre-train weight + ckpt = torch.load('../pretrain_weight.ckpt') + ckpt = ckpt["state_dict"] + new_state_dict = OrderedDict() + for key, value in ckpt.items(): + new_key = key.replace("model.","") + new_state_dict[new_key] = value + self.model.load_state_dict(new_state_dict, strict=False) + + def forward(self, batch): + pass + + def training_step(self, batch, batch_idx): + front_img = batch['front_img'] + speed = batch['speed'].to(dtype=torch.float32).view(-1,1) / 12. + target_point = batch['target_point'].to(dtype=torch.float32) + command = batch['target_command'] + + state = torch.cat([speed, target_point, command], 1) + value = batch['value'].view(-1,1) + feature = batch['feature'] + + gt_waypoints = batch['waypoints'] + gt_graph = batch['graph'] + + pred = self.model(front_img, state, target_point) + + w_graph = [1, 10, 2] + w_graph = torch.FloatTensor(w_graph).cuda() + loss_fn = torch.nn.CrossEntropyLoss(weight=w_graph) + graph_loss = loss_fn(pred['graph'], gt_graph) * self.config.graph_weight + + dist_sup = Beta(batch['action_mu'], batch['action_sigma']) + dist_pred = Beta(pred['mu_branches'], pred['sigma_branches']) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + action_loss = torch.mean(kl_div[:, 0]) * 0.5 + torch.mean(kl_div[:, 1]) * 0.5 + speed_loss = F.l1_loss(pred['pred_speed'], speed) * self.config.speed_weight + value_loss = (F.mse_loss(pred['pred_value_traj'], value) + F.mse_loss(pred['pred_value_ctrl'], value)) * self.config.value_weight + feature_loss = (F.mse_loss(pred['pred_features_traj'], feature) + F.mse_loss(pred['pred_features_ctrl'], feature))* self.config.features_weight + + future_feature_loss = 0 + future_action_loss = 0 + for i in range(self.config.pred_len): + dist_sup = Beta(batch['future_action_mu'][i], batch['future_action_sigma'][i]) + dist_pred = Beta(pred['future_mu'][i], pred['future_sigma'][i]) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + future_action_loss += torch.mean(kl_div[:, 0]) *0.5 + torch.mean(kl_div[:, 1]) *0.5 + future_feature_loss += F.mse_loss(pred['future_feature'][i], batch['future_feature'][i]) * self.config.features_weight + future_feature_loss /= self.config.pred_len + future_action_loss /= self.config.pred_len + wp_loss = F.l1_loss(pred['pred_wp'], gt_waypoints, reduction='none').mean() + loss = action_loss + speed_loss + value_loss + feature_loss + wp_loss+ future_feature_loss + future_action_loss\ + + graph_loss + self.log('train_action_loss', action_loss.item()) + self.log('train_wp_loss_loss', wp_loss.item()) + self.log('train_speed_loss', speed_loss.item()) + self.log('train_value_loss', value_loss.item()) + self.log('train_feature_loss', feature_loss.item()) + self.log('train_future_feature_loss', future_feature_loss.item()) + self.log('train_future_action_loss', future_action_loss.item()) + return loss + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-7) + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 30, 0.5) + return [optimizer], [lr_scheduler] + + def validation_step(self, batch, batch_idx): + front_img = batch['front_img'] + speed = batch['speed'].to(dtype=torch.float32).view(-1,1) / 12. + target_point = batch['target_point'].to(dtype=torch.float32) + command = batch['target_command'] + state = torch.cat([speed, target_point, command], 1) + value = batch['value'].view(-1,1) + feature = batch['feature'] + gt_waypoints = batch['waypoints'] + + gt_graph = batch['graph'] + + pred = self.model(front_img, state, target_point) + + w_graph = [1, 10, 2] + w_graph = torch.FloatTensor(w_graph).cuda() + loss_fn = torch.nn.CrossEntropyLoss(weight=w_graph) + graph_loss = loss_fn(pred['graph'], gt_graph) * self.config.graph_weight + + dist_sup = Beta(batch['action_mu'], batch['action_sigma']) + dist_pred = Beta(pred['mu_branches'], pred['sigma_branches']) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + action_loss = torch.mean(kl_div[:, 0]) * 0.5 + torch.mean(kl_div[:, 1]) * 0.5 + speed_loss = F.l1_loss(pred['pred_speed'], speed) * self.config.speed_weight + value_loss = (F.mse_loss(pred['pred_value_traj'], value) + F.mse_loss(pred['pred_value_ctrl'], value)) * self.config.value_weight + feature_loss = (F.mse_loss(pred['pred_features_traj'], feature) +F.mse_loss(pred['pred_features_ctrl'], feature))* self.config.features_weight + wp_loss = F.l1_loss(pred['pred_wp'], gt_waypoints, reduction='none').mean() + + B = batch['action_mu'].shape[0] + batch_steer_l1 = 0 + batch_brake_l1 = 0 + batch_throttle_l1 = 0 + for i in range(B): + throttle, steer, brake = self.model.get_action(pred['mu_branches'][i], pred['sigma_branches'][i]) + batch_throttle_l1 += torch.abs(throttle-batch['action'][i][0]) + batch_steer_l1 += torch.abs(steer-batch['action'][i][1]) + batch_brake_l1 += torch.abs(brake-batch['action'][i][2]) + + batch_throttle_l1 /= B + batch_steer_l1 /= B + batch_brake_l1 /= B + + future_feature_loss = 0 + future_action_loss = 0 + for i in range(self.config.pred_len-1): + dist_sup = Beta(batch['future_action_mu'][i], batch['future_action_sigma'][i]) + dist_pred = Beta(pred['future_mu'][i], pred['future_sigma'][i]) + kl_div = torch.distributions.kl_divergence(dist_sup, dist_pred) + future_action_loss += torch.mean(kl_div[:, 0]) *0.5 + torch.mean(kl_div[:, 1]) *0.5 + future_feature_loss += F.mse_loss(pred['future_feature'][i], batch['future_feature'][i]) * self.config.features_weight + future_feature_loss /= self.config.pred_len + future_action_loss /= self.config.pred_len + + val_loss = wp_loss + batch_throttle_l1+5*batch_steer_l1+batch_brake_l1 + graph_loss + + self.log("val_action_loss", action_loss.item(), sync_dist=True) + self.log('val_speed_loss', speed_loss.item(), sync_dist=True) + self.log('val_value_loss', value_loss.item(), sync_dist=True) + self.log('val_feature_loss', feature_loss.item(), sync_dist=True) + self.log('val_wp_loss_loss', wp_loss.item(), sync_dist=True) + self.log('val_future_feature_loss', future_feature_loss.item(), sync_dist=True) + self.log('val_future_action_loss', future_action_loss.item(), sync_dist=True) + self.log('val_loss', val_loss.item(), sync_dist=True) + + return action_loss.item(),speed_loss.item(), value_loss.item(), feature_loss.item(),\ + wp_loss.item(), future_feature_loss.item(), future_action_loss.item(), val_loss.item(), \ + graph_loss.item(), gt_graph, pred['graph'] + + def validation_epoch_end(self, outputs): + total_action_loss = 0 + total_speed_loss = 0 + total_value_loss = 0 + total_feature_loss = 0 + total_wp_loss = 0 + total_future_feature_loss = 0 + total_future_action_loss = 0 + total_val_loss = 0 + + total_graph_loss = 0 + total_gt_graph = [] + total_pred_graph = [] + + for action_loss, speed_loss, value_loss, feature_loss, wp_loss, future_feature_loss,\ + future_action_loss, val_loss, graph_loss, gt_graph, pred_graph in outputs: + total_action_loss += action_loss + total_speed_loss += speed_loss + total_value_loss += value_loss + total_feature_loss += feature_loss + total_wp_loss += wp_loss + total_future_feature_loss += future_feature_loss + total_future_action_loss += future_action_loss + total_val_loss += val_loss + total_graph_loss += graph_loss + total_gt_graph.append(gt_graph) + total_pred_graph.append(pred_graph) + + f1_graph_overall, f1_graph_cate, f1_graph_mean = Gragh_Metric(total_gt_graph, total_pred_graph) + iou_graph = Cal_IoU(total_gt_graph, total_pred_graph) + val_info = """ + epoch {0} + ----------------------- + val_total_loss: {1} + val_graph_loss: {2} + graph_category: {3} + graph_overall: {4} + graph_mean: {5} + \ngraph_iou: {6} + + """.format(self.val_counter, total_val_loss, total_graph_loss, + f1_graph_cate, f1_graph_overall, f1_graph_mean, iou_graph) + print(val_info) + result_file = './log.txt' + with open(result_file, 'a') as f: + f.write(val_info) + self.val_counter += 2 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--id', type=str, default='XPlan', help='Unique experiment identifier.') + parser.add_argument('--epochs', type=int, default=60, help='Number of train epochs.') + parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate.') + parser.add_argument('--val_every', type=int, default=2, help='Validation frequency (epochs).') + parser.add_argument('--batch_size', type=int, default=2, help='Batch size') + parser.add_argument('--logdir', type=str, default='log', help='Directory to log data to.') + parser.add_argument('--gpus', type=int, default=1, help='number of gpus') + + args = parser.parse_args() + args.logdir = os.path.join(args.logdir, args.id) + + # Config + config = GlobalConfig() + + # Data + train_set = PolarPoint_Data(root=config.root_dir_all, data_folders=config.train_data, img_aug = config.img_aug) + print(len(train_set)) + val_set = PolarPoint_Data(root=config.root_dir_all, data_folders=config.val_data,) + print(len(val_set)) + + dataloader_train = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=8) + dataloader_val = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=8) + + TCP_model = XPlan_planner(config, args.lr) + + checkpoint_callback = ModelCheckpoint(save_weights_only=False, mode="min", monitor="val_loss", save_top_k=-1, save_last=True, + dirpath=args.logdir, filename="{epoch:02d}-{val_loss:.3f}") + checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last" + trainer = pl.Trainer.from_argparse_args(args, + default_root_dir=args.logdir, + gpus = args.gpus, + accelerator='ddp', + sync_batchnorm=True, + plugins=DDPPlugin(find_unused_parameters=False), + profiler='simple', + benchmark=True, + log_every_n_steps=1, + flush_logs_every_n_steps=5, + callbacks=[checkpoint_callback, + ], + check_val_every_n_epoch = args.val_every, + max_epochs = args.epochs + ) + + trainer.fit(TCP_model, dataloader_train, dataloader_val) + + + + + + + + +