From 1aea74cddbc78e7f79dac07090cb157dfc24dbcc Mon Sep 17 00:00:00 2001 From: VELC Date: Sun, 4 Sep 2022 17:15:53 +0200 Subject: [PATCH] Add new `--vid-stride` inference parameter for videos (#9256) * fps feature/skip frame added * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * predict.py updates * Update dataloaders.py Signed-off-by: Glenn Jocher * Update dataloaders.py Signed-off-by: Glenn Jocher * remove unused attribute Signed-off-by: Glenn Jocher * Cleanup Signed-off-by: Glenn Jocher * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update predict.py Signed-off-by: Glenn Jocher * Update detect.py Signed-off-by: Glenn Jocher * Update dataloaders.py Signed-off-by: Glenn Jocher * Rename skip_frame to vid_stride * cleanup * cleanup2 Signed-off-by: Glenn Jocher Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- classify/predict.py | 6 ++++-- detect.py | 6 ++++-- utils/dataloaders.py | 15 +++++++++------ 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/classify/predict.py b/classify/predict.py index 76115c75029f..701b5b1ac92d 100644 --- a/classify/predict.py +++ b/classify/predict.py @@ -66,6 +66,7 @@ def run( exist_ok=False, # existing project/name ok, do not increment half=False, # use FP16 half-precision inference dnn=False, # use OpenCV DNN for ONNX inference + vid_stride=1, # video frame-rate stride ): source = str(source) save_img = not nosave and not source.endswith('.txt') # save inference images @@ -88,10 +89,10 @@ def run( # Dataloader if webcam: view_img = check_imshow() - dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0])) + dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) bs = len(dataset) # batch_size else: - dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0])) + dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) bs = 1 # batch_size vid_path, vid_writer = [None] * bs, [None] * bs @@ -196,6 +197,7 @@ def parse_opt(): parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference') + parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride') opt = parser.parse_args() opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand print_args(vars(opt)) diff --git a/detect.py b/detect.py index cf75d0f11c92..69a1bf13aac6 100644 --- a/detect.py +++ b/detect.py @@ -74,6 +74,7 @@ def run( hide_conf=False, # hide confidences half=False, # use FP16 half-precision inference dnn=False, # use OpenCV DNN for ONNX inference + vid_stride=1, # video frame-rate stride ): source = str(source) save_img = not nosave and not source.endswith('.txt') # save inference images @@ -96,10 +97,10 @@ def run( # Dataloader if webcam: view_img = check_imshow() - dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt) + dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) bs = len(dataset) # batch_size else: - dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt) + dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) bs = 1 # batch_size vid_path, vid_writer = [None] * bs, [None] * bs @@ -236,6 +237,7 @@ def parse_opt(): parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences') parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference') + parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride') opt = parser.parse_args() opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand print_args(vars(opt)) diff --git a/utils/dataloaders.py b/utils/dataloaders.py index 38ae3399ce26..c1ad1f1a4b83 100755 --- a/utils/dataloaders.py +++ b/utils/dataloaders.py @@ -187,7 +187,7 @@ def __iter__(self): class LoadImages: # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4` - def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None): + def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1): files = [] for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: p = str(Path(p).resolve()) @@ -212,6 +212,7 @@ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None): self.mode = 'image' self.auto = auto self.transforms = transforms # optional + self.vid_stride = vid_stride # video frame-rate stride if any(videos): self._new_video(videos[0]) # new video else: @@ -232,6 +233,7 @@ def __next__(self): # Read video self.mode = 'video' ret_val, im0 = self.cap.read() + self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.vid_stride * (self.frame + 1)) # read at vid_stride while not ret_val: self.count += 1 self.cap.release() @@ -242,7 +244,7 @@ def __next__(self): ret_val, im0 = self.cap.read() self.frame += 1 - # im0 = self._cv2_rotate(im0) # for use if cv2 auto rotation is False + # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: ' else: @@ -265,7 +267,7 @@ def _new_video(self, path): # Create a new video capture object self.frame = 0 self.cap = cv2.VideoCapture(path) - self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493 @@ -285,11 +287,12 @@ def __len__(self): class LoadStreams: # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` - def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None): + def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1): torch.backends.cudnn.benchmark = True # faster for fixed-size inference self.mode = 'stream' self.img_size = img_size self.stride = stride + self.vid_stride = vid_stride # video frame-rate stride sources = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources] n = len(sources) self.sources = [clean_str(x) for x in sources] # clean source names for later @@ -329,11 +332,11 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, tr def update(self, i, cap, stream): # Read stream `i` frames in daemon thread - n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame + n, f = 0, self.frames[i] # frame number, frame array while cap.isOpened() and n < f: n += 1 cap.grab() # .read() = .grab() followed by .retrieve() - if n % read == 0: + if n % self.vid_stride == 0: success, im = cap.retrieve() if success: self.imgs[i] = im