
Commit

Move cudnn.benchmarks(True) to LoadStreams (ultralytics#9258)
* Move cudnn.benchmarks(True) to LoadStreams

* Update dataloaders.py

Signed-off-by: Glenn Jocher <[email protected]>

* Move cudnn.benchmarks(True) to LoadStreams

Signed-off-by: Glenn Jocher <[email protected]>
glenn-jocher authored Sep 2, 2022
1 parent 9da6d0f commit ffdb58b
Showing 3 changed files with 4 additions and 54 deletions.
classify/predict.py (2 changes: 0 additions & 2 deletions)
@@ -31,7 +31,6 @@
 from pathlib import Path
 
 import torch
-import torch.backends.cudnn as cudnn
 import torch.nn.functional as F
 
 FILE = Path(__file__).resolve()
@@ -89,7 +88,6 @@ def run(
     # Dataloader
     if webcam:
         view_img = check_imshow()
-        cudnn.benchmark = True # set True to speed up constant image size inference
         dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]))
         bs = len(dataset) # batch_size
     else:
detect.py (2 changes: 0 additions & 2 deletions)
@@ -31,7 +31,6 @@
 from pathlib import Path
 
 import torch
-import torch.backends.cudnn as cudnn
 
 FILE = Path(__file__).resolve()
 ROOT = FILE.parents[0] # YOLOv5 root directory
@@ -97,7 +96,6 @@ def run(
     # Dataloader
     if webcam:
         view_img = check_imshow()
-        cudnn.benchmark = True # set True to speed up constant image size inference
         dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
         bs = len(dataset) # batch_size
     else:
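The two script diffs above drop the per-script cuDNN toggle; the dataloaders.py diff below moves it into LoadStreams.__init__ instead. As background, here is a minimal standalone sketch (not from this repository; the model, input size, and timing loop are illustrative) of what the removed comment "speed up constant image size inference" refers to: with torch.backends.cudnn.benchmark = True, cuDNN profiles its convolution algorithms on the first pass for a given input shape and reuses the fastest one on later passes with the same shape.

import time

import torch

# Illustrative only: benchmark mode pays an algorithm-search cost on the first
# forward pass for each new input shape, then reuses the cached choice, so it
# helps when every frame has the same size (as in stream inference) and can
# hurt when shapes vary.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    model = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1).cuda().eval()
    x = torch.randn(1, 3, 640, 640, device='cuda')  # fixed 640x640 input, like resized stream frames
    with torch.no_grad():
        model(x)  # first call: algorithm search + warmup
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(100):
            model(x)  # later calls reuse the cached algorithm
        torch.cuda.synchronize()
        print(f'{(time.time() - t0) * 10:.3f} ms per pass')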
utils/dataloaders.py (54 changes: 4 additions & 50 deletions)
@@ -283,62 +283,17 @@ def __len__(self):
         return self.nf # number of files
 
 
-class LoadWebcam: # for inference
-    # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
-    def __init__(self, pipe='0', img_size=640, stride=32):
-        self.img_size = img_size
-        self.stride = stride
-        self.pipe = eval(pipe) if pipe.isnumeric() else pipe
-        self.cap = cv2.VideoCapture(self.pipe) # video capture object
-        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
-
-    def __iter__(self):
-        self.count = -1
-        return self
-
-    def __next__(self):
-        self.count += 1
-        if cv2.waitKey(1) == ord('q'): # q to quit
-            self.cap.release()
-            cv2.destroyAllWindows()
-            raise StopIteration
-
-        # Read frame
-        ret_val, im0 = self.cap.read()
-        im0 = cv2.flip(im0, 1) # flip left-right
-
-        # Print
-        assert ret_val, f'Camera Error {self.pipe}'
-        img_path = 'webcam.jpg'
-        s = f'webcam {self.count}: '
-
-        # Process
-        im = letterbox(im0, self.img_size, stride=self.stride)[0] # resize
-        im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
-        im = np.ascontiguousarray(im) # contiguous
-
-        return img_path, im, im0, None, s
-
-    def __len__(self):
-        return 0
-
-
 class LoadStreams:
     # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
     def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None):
+        torch.backends.cudnn.benchmark = True # faster for fixed-size inference
         self.mode = 'stream'
         self.img_size = img_size
         self.stride = stride
-
-        if os.path.isfile(sources):
-            with open(sources) as f:
-                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
-        else:
-            sources = [sources]
-
+        sources = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources]
         n = len(sources)
-        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
         self.sources = [clean_str(x) for x in sources] # clean source names for later
+        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
         for i, s in enumerate(sources): # index, source
             # Start thread to read frames from video stream
             st = f'{i + 1}/{n}: {s}... '
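Besides the cudnn flag, the hunk above also collapses the streams-file parsing into a single expression. A small equivalence sketch (the file contents below are hypothetical, not from the repo): rsplit() with no arguments splits on any whitespace, which matches the old one-source-per-line behaviour for typical streams.txt files.

import os
import tempfile
from pathlib import Path

# Hypothetical streams list: one source per line
tmp = Path(tempfile.gettempdir()) / 'streams.txt'
tmp.write_text('rtsp://example.com/media.mp4\n0\n')
sources = str(tmp)

# Old approach (removed above): read line by line, dropping empty lines
if os.path.isfile(sources):
    with open(sources) as f:
        old = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
else:
    old = [sources]

# New approach (added above): split the whole file on whitespace
new = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources]

assert old == new  # identical for whitespace-free, one-per-line sources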
@@ -377,8 +332,7 @@ def update(self, i, cap, stream):
         n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
         while cap.isOpened() and n < f:
             n += 1
-            # _, self.imgs[index] = cap.read()
-            cap.grab()
+            cap.grab() # .read() = .grab() followed by .retrieve()
             if n % read == 0:
                 success, im = cap.retrieve()
                 if success:
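The new comment in update() documents the frame-skipping idiom used there: grab() advances the stream without decoding, and retrieve() decodes only the frames that are kept. A standalone OpenCV sketch of the same pattern (the video path and stride are placeholders):

import cv2

cap = cv2.VideoCapture('video.mp4')  # placeholder path
read_every = 4  # decode every 4th frame
n = 0
while cap.isOpened():
    if not cap.grab():  # advance one frame without decoding; False at end of stream
        break
    n += 1
    if n % read_every == 0:
        ok, frame = cap.retrieve()  # decode the last grabbed frame
        if ok:
            print(f'frame {n}: {frame.shape}')  # process the decoded frame here
cap.release()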
