From 2787ad701fbb308cfb494ae8fb68b0fcea0e4077 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 25 Sep 2022 14:52:49 +0200 Subject: [PATCH] Add segment line predictions (#9571) * Add segment line predictions Signed-off-by: Glenn Jocher * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update Signed-off-by: Glenn Jocher Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- segment/predict.py | 20 ++++++++++++-------- utils/segment/general.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/segment/predict.py b/segment/predict.py index 2241204715b5..607a8697d731 100644 --- a/segment/predict.py +++ b/segment/predict.py @@ -42,9 +42,10 @@ from models.common import DetectMultiBackend from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, - increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh) + increment_path, non_max_suppression, print_args, scale_boxes, scale_segments, + strip_optimizer, xyxy2xywh) from utils.plots import Annotator, colors, save_one_box -from utils.segment.general import process_mask +from utils.segment.general import masks2segments, process_mask from utils.torch_utils import select_device, smart_inference_mode @@ -145,14 +146,16 @@ def run( save_path = str(save_dir / p.name) # im.jpg txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt s += '%gx%g ' % im.shape[2:] # print string - gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh imc = im0.copy() if save_crop else im0 # for save_crop annotator = Annotator(im0, line_width=line_thickness, example=str(names)) if len(det): masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC + det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size - # Rescale boxes from img_size to im0 size - det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + # Segments + if save_txt: + segments = reversed(masks2segments(masks)) + segments = [scale_segments(im.shape[2:], x, im0.shape).round() for x in segments] # Print results for c in det[:, 5].unique(): @@ -165,10 +168,10 @@ def run( im_gpu=None if retina_masks else im[i]) # Write results - for *xyxy, conf, cls in reversed(det[:, :6]): + for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])): if save_txt: # Write to file - xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh - line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + segj = segments[j].reshape(-1) # (n,2) to (n*2) + line = (cls, *segj, conf) if save_conf else (cls, *segj) # label format with open(f'{txt_path}.txt', 'a') as f: f.write(('%g ' * len(line)).rstrip() % line + '\n') @@ -176,6 +179,7 @@ def run( c = int(cls) # integer class label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}') annotator.box_label(xyxy, label, color=colors(c, True)) + annotator.draw.polygon(segments[j], outline=colors(c, True), width=3) if save_crop: save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) diff --git a/utils/segment/general.py b/utils/segment/general.py index 36547ed0889c..655123bdcfeb 100644 --- a/utils/segment/general.py +++ b/utils/segment/general.py @@ -1,4 +1,5 @@ import cv2 +import numpy as np import torch import torch.nn.functional as F @@ -118,3 +119,16 @@ def masks_iou(mask1, mask2, eps=1e-7): intersection = (mask1 * mask2).sum(1).clamp(0) # (N, ) union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection return intersection / (union + eps) + + +def masks2segments(masks, strategy='largest'): + # Convert masks(n,160,160) into segments(n,xy) + segments = [] + for x in masks.int().numpy().astype('uint8'): + c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + if strategy == 'concat': # concatenate all segments + c = np.concatenate([x.reshape(-1, 2) for x in c]) + elif strategy == 'largest': # select largest segment + c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) + segments.append(c.astype('float32')) + return segments