Skip to content

Commit

Permalink
Allow PyTorch Hub results to display in notebooks (ultralytics#9825)
Browse files Browse the repository at this point in the history
* Allow PyTorch Hub results to display in notebooks

* fix CI

* fix CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI

* fix CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI

* fix CI

* fix CI

* fix CI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI

* fix CI

* fix CI

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] authored Oct 17, 2022
1 parent e42c89d commit e3ff780
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 14 deletions.
2 changes: 1 addition & 1 deletion classify/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
Expand Down
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
Expand Down
13 changes: 9 additions & 4 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@
import requests
import torch
import torch.nn as nn
from IPython.display import display
from PIL import Image
from torch.cuda import amp

from utils import TryExcept
from utils.dataloaders import exif_transpose, letterbox
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy, xyxy2xywh,
yaml_load)
from utils.general import (LOGGER, ROOT, Profile, check_imshow, check_requirements, check_suffix, check_version,
colorstr, increment_path, make_divisible, non_max_suppression, scale_boxes, xywh2xyxy,
xyxy2xywh, yaml_load)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import copy_attr, smart_inference_mode

CHECK_IMSHOW = check_imshow()


def autopad(k, p=None, d=1): # kernel, padding, dilation
# Pad to 'same' shape outputs
Expand Down Expand Up @@ -756,7 +760,7 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l

im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if show:
im.show(self.files[i]) # show
im.show(self.files[i]) if CHECK_IMSHOW else display(im)
if save:
f = self.files[i]
im.save(save_dir / f) # save
Expand All @@ -772,6 +776,7 @@ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, l
LOGGER.info(f'Saved results to {save_dir}\n')
return crops

@TryExcept('Showing images is not supported in this environment')
def show(self, labels=True):
self._run(show=True, labels=labels) # show results

Expand Down
2 changes: 1 addition & 1 deletion segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def run(
# Dataloader
bs = 1 # batch_size
if webcam:
view_img = check_imshow()
view_img = check_imshow(warn=True)
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset)
elif screenshot:
Expand Down
2 changes: 1 addition & 1 deletion utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __enter__(self):

def __exit__(self, exc_type, value, traceback):
if value:
print(emojis(f'{self.msg}{value}'))
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True


Expand Down
2 changes: 1 addition & 1 deletion utils/autoanchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def check_anchor_order(m):
m.anchors[:] = m.anchors.flip(0)


@TryExcept(f'{PREFIX}ERROR: ')
@TryExcept(f'{PREFIX}ERROR')
def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
Expand Down
17 changes: 13 additions & 4 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from zipfile import ZipFile

import cv2
import IPython
import numpy as np
import pandas as pd
import pkg_resources as pkg
Expand Down Expand Up @@ -73,6 +74,12 @@ def is_colab():
return 'COLAB_GPU' in os.environ


def is_notebook():
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
ipython_type = str(type(IPython.get_ipython()))
return 'colab' in ipython_type or 'zmqshell' in ipython_type


def is_kaggle():
# Is environment a Kaggle Notebook?
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
Expand Down Expand Up @@ -383,18 +390,20 @@ def check_img_size(imgsz, s=32, floor=0):
return new_size


def check_imshow():
def check_imshow(warn=False):
# Check if environment supports image displays
try:
assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'
assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'
assert not is_notebook()
assert not is_docker()
assert 'NoneType' not in str(type(IPython.get_ipython())) # SSH terminals, GitHub CI
cv2.imshow('test', np.zeros((1, 1, 3)))
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}')
if warn:
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
return False


Expand Down
2 changes: 1 addition & 1 deletion utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def tp_fp(self):
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure: ')
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
import seaborn as sn

Expand Down

0 comments on commit e3ff780

Please sign in to comment.