diff --git a/foolbox/plot.py b/foolbox/plot.py index 7d4141fa..4ce32df0 100644 --- a/foolbox/plot.py +++ b/foolbox/plot.py @@ -13,8 +13,10 @@ def images( nrows: Optional[int] = None, figsize: Optional[Tuple[float, float]] = None, scale: float = 1, + labels: Any = None, + return_fig: bool = False, **kwargs: Any, -) -> None: +) -> Optional[Tuple[Any, Any]]: import matplotlib.pyplot as plt x: ep.Tensor = ep.astensor(images) @@ -57,7 +59,7 @@ def images( nrows=nrows, figsize=figsize, squeeze=False, - constrained_layout=True, + constrained_layout=False, **kwargs, ) @@ -68,5 +70,11 @@ def images( ax.set_yticks([]) ax.axis("off") i = row * ncols + col + if labels is not None: + ax.set_title(labels[i]) if i < len(x): ax.imshow(x[i]) + + if return_fig: + return fig, axes + return None