Add new --vid-stride inference parameter for videos (ultralytics#9256)
* fps feature/skip frame added

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* predict.py updates

* Update dataloaders.py

Signed-off-by: Glenn Jocher <[email protected]>

* Update dataloaders.py

Signed-off-by: Glenn Jocher <[email protected]>

* remove unused attribute

Signed-off-by: Glenn Jocher <[email protected]>

* Cleanup

Signed-off-by: Glenn Jocher <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update predict.py

Signed-off-by: Glenn Jocher <[email protected]>

* Update detect.py

Signed-off-by: Glenn Jocher <[email protected]>

* Update dataloaders.py

Signed-off-by: Glenn Jocher <[email protected]>

* Rename skip_frame to vid_stride

* cleanup

* cleanup2

Signed-off-by: Glenn Jocher <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
3 people authored Sep 4, 2022
1 parent e45d335 commit 1aea74c
Showing 3 changed files with 17 additions and 10 deletions.
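
The new parameter subsamples video input: with vid_stride=N, inference runs on every Nth frame instead of every frame. A minimal usage sketch, assuming the repository root is on sys.path (the source path and stride value are placeholders; the CLI equivalent added in this commit is `python detect.py --source video.mp4 --vid-stride 3`):

    # Hypothetical usage sketch: run detection on a local video, processing every 3rd frame.
    # 'video.mp4' is a placeholder; all other run() arguments keep their defaults.
    from detect import run

    run(source='video.mp4', vid_stride=3)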
classify/predict.py: 6 changes (4 additions, 2 deletions)
@@ -66,6 +66,7 @@ def run(
exist_ok=False, # existing project/name ok, do not increment
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
vid_stride=1, # video frame-rate stride
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
@@ -88,10 +89,10 @@ def run(
# Dataloader
if webcam:
view_img = check_imshow()
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]))
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]))
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

@@ -196,6 +197,7 @@ def parse_opt():
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(vars(opt))
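The classification predictor gains the same keyword. A sketch of calling it directly, again assuming the repository root is on sys.path (the path and stride are placeholders):

    # Hypothetical sketch: classify every 2nd frame of a video
    from classify.predict import run

    run(source='video.mp4', vid_stride=2)
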
detect.py: 6 changes (4 additions, 2 deletions)
@@ -74,6 +74,7 @@ def run(
hide_conf=False, # hide confidences
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
vid_stride=1, # video frame-rate stride
):
source = str(source)
save_img = not nosave and not source.endswith('.txt') # save inference images
@@ -96,10 +97,10 @@ def run(
# Dataloader
if webcam:
view_img = check_imshow()
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) # batch_size
else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = 1 # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs

@@ -236,6 +237,7 @@ def parse_opt():
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
opt = parser.parse_args()
opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
print_args(vars(opt))
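For reference, argparse maps the dashed flag to an underscored attribute, so `--vid-stride` becomes `opt.vid_stride` and reaches the function as `vid_stride=...`. A self-contained sketch of that wiring (the call through `run(**vars(opt))` is an assumption based on the usual YOLOv5 entry-point pattern, not part of this diff; the run() here is a stand-in):

    import argparse

    def run(vid_stride=1, **kwargs):  # stand-in for detect.run
        print(f'vid_stride={vid_stride}')

    def parse_opt():
        parser = argparse.ArgumentParser()
        parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
        return parser.parse_args()

    if __name__ == '__main__':
        run(**vars(parse_opt()))  # e.g. `python sketch.py --vid-stride 5` prints vid_stride=5
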
utils/dataloaders.py: 15 changes (9 additions, 6 deletions)
@@ -187,7 +187,7 @@ def __iter__(self):

class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None):
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
p = str(Path(p).resolve())
@@ -212,6 +212,7 @@ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None):
self.mode = 'image'
self.auto = auto
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
if any(videos):
self._new_video(videos[0]) # new video
else:
@@ -232,6 +233,7 @@ def __next__(self):
# Read video
self.mode = 'video'
ret_val, im0 = self.cap.read()
self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.vid_stride * (self.frame + 1)) # read at vid_stride
while not ret_val:
self.count += 1
self.cap.release()
@@ -242,7 +244,7 @@ def _new_video(self, path):
ret_val, im0 = self.cap.read()

self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 auto rotation is False
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

else:
@@ -265,7 +267,7 @@ def _new_video(self, path):
# Create a new video capture object
self.frame = 0
self.cap = cv2.VideoCapture(path)
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
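
Taken together, the LoadImages changes seek the capture forward after every read and scale the reported frame count down by the stride. A standalone sketch of that mechanism ('sample.mp4' is a placeholder path):

    import cv2

    vid_stride = 4  # process every 4th frame
    cap = cv2.VideoCapture('sample.mp4')
    frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) / vid_stride)  # effective length after striding
    frame = 0
    while frame < frames:
        ret_val, im0 = cap.read()
        if not ret_val:
            break
        cap.set(cv2.CAP_PROP_POS_FRAMES, vid_stride * (frame + 1))  # jump ahead to the next strided frame
        frame += 1
    cap.release()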

@@ -285,11 +287,12 @@ def __len__(self):

class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None):
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream'
self.img_size = img_size
self.stride = stride
self.vid_stride = vid_stride # video frame-rate stride
sources = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources]
n = len(sources)
self.sources = [clean_str(x) for x in sources] # clean source names for later
@@ -329,11 +332,11 @@ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, tr

def update(self, i, cap, stream):
# Read stream `i` frames in daemon thread
n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
n, f = 0, self.frames[i] # frame number, frame array
while cap.isOpened() and n < f:
n += 1
cap.grab() # .read() = .grab() followed by .retrieve()
if n % read == 0:
if n % self.vid_stride == 0:
success, im = cap.retrieve()
if success:
self.imgs[i] = im
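The stream loader takes a different approach: every frame is still grabbed so the stream stays current, but only every vid_stride-th frame is decoded with retrieve(). A standalone sketch of that pattern (the RTSP URL is a placeholder and the loop is bounded so the sketch terminates):

    import cv2

    vid_stride = 3
    cap = cv2.VideoCapture('rtsp://example.com/media.mp4')
    n, latest = 0, None
    while cap.isOpened() and n < 300:
        n += 1
        cap.grab()  # .read() = .grab() followed by .retrieve()
        if n % vid_stride == 0:
            success, im = cap.retrieve()
            if success:
                latest = im  # most recent decoded frame for inference
    cap.release()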
