Skip to content

Commit

Permalink
Merge pull request #5 from clane9/bold-diff
Browse files Browse the repository at this point in the history
Add volume difference plotting scripts
  • Loading branch information
gkiar authored Nov 11, 2022
2 parents 52210ec + e68a6d5 commit 1822fe4
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Short-and-sweet cookie cutter visualization scripts or utilities, intended to lo
| [`plot_nii_overlay.py`](./code/plot_nii_overlay.py) | Compare placement/alignment of two images | Structural Nii, Mean functional Nii, other 3D contrast | ![plot_nii_overlay example](./examples/plot_nii_overlay.png) |
| [`plot_nii_similarity.py`](./code/plot_nii_similarity.py) | Compare signal distributions of two images | Any two Nii images of the same shape | ![plot_nii_similarity example](./examples/plot_nii_similarity.png) |
| [`plot_gii_surface.py`](./code/plot_gii_surface.py) | Visualize data on the surface | Surface mesh file, volumetric nii data to be displayed on surface mesh | ![plot_gii_surface example](./examples/plot_gii_surface.png) |
| [`plot_nii_difference.py`](./code/plot_nii_difference.py) | Visualize the difference between two nii images | Any two Nii images of the same shape | ![plot_nii_difference example](./examples/plot_nii_difference.png) |


## Usage Instructions
Expand Down
225 changes: 225 additions & 0 deletions code/plot_nii_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""
Example usage:
python plot_nii_difference.py \
-o vol_diff.png \
--images /path/to/vol1.nii.gz /path/to/vol2.nii.gz \
--masks /path/to/mask1.nii.gz /path/to/mask2.nii.gz \
--labels vol1 vol2
"""

import argparse
import logging
import os
import pprint
from typing import List, Tuple, Optional

from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
import nibabel as nib
from nilearn import plotting as nilplt
from nilearn import image as nilimg

from viz_utils import find_value_lims, apply_mask

plt.rcParams["figure.dpi"] = 150

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)


def main(args: argparse.Namespace):
logging.info("Args:\n%s", pprint.pformat(args.__dict__))

logging.info("Loading images")
imgs = [nib.load(path) for path in args.images]
assert len(imgs) == 2, "only two images supported"
assert imgs[0].ndim == imgs[1].ndim, "both images should have the same ndim"
assert len(args.labels) == 2, "two labels are required"

if imgs[0].ndim == 4:
# NOTE: indexing a 4d '.nii.gz' is slow bc you have to gunzip the whole file.
# h5 with compressed chunks might be better..
logging.info("Indexing images at %d", args.index)
imgs[0] = nilimg.index_img(imgs[0], args.index)
imgs[1] = nilimg.index_img(imgs[1], args.index + args.index_offset)
index = args.index
else:
index = None

if args.masks is not None:
for ii, path in enumerate(args.masks):
if os.path.exists(path):
logging.info("Applying mask to image %d", ii)
mask = nib.load(path)
imgs[ii] = apply_mask(imgs[ii], mask)

logging.info("Resampling img2 -> img1")
img1, img2 = imgs
img1 = nilimg.reorder_img(img1, resample=False)
img2 = nilimg.reorder_img(img2, resample=False)
img2 = nilimg.resample_to_img(img2, img1)

cut_coords = tuple(float(val.strip()) for val in args.cut_coords.split(","))

f, axs = plt.subplots(3, 1, figsize=(9, 9))
logging.info("Plotting image")
plot_difference_triplet(
img1=img1,
img2=img2,
labels=args.labels,
fig=f,
axs=axs,
index=index,
cut_coords=cut_coords,
vmax=args.vmax,
colorbar=args.colorbar,
fname=args.out,
)

logging.info("Done")


def plot_difference_triplet(
*,
img1: nib.Nifti1Image,
img2: nib.Nifti1Image,
labels: List[str],
fig: Optional[Figure] = None,
axs: Optional[List[Axes]] = None,
index: Optional[int] = None,
cut_coords: Tuple[float, float, float] = (1.0, 0.0, 0.0),
vmin: Optional[float] = None,
vmax: Optional[float] = None,
colorbar: bool = False,
fname: Optional[str] = None,
):
"""
Plot `img1`, `img2`, and the difference `img1 - img2`.
Args:
img1: First volume.
img2: Second volume.
labels: List of two labels, one per volume.
fig: Optional figure to plot into.
axs: Optional list of three axes to plot into.
index: Optional volume index (just for labeling).
cut_coords: Ortho viewer cut coordinates. See
`nilearn.plotting.plot_epi` for details.
vmin: Optional vmin.
vmax: Optional vmax.
colorbar: Show the colorbar.
fname: Optional image filename.
"""
assert len(labels) == 2, "two labels expected"
if fig is None:
fig, axs = plt.subplots(3, 1, figsize=(9, 9))
else:
assert axs is not None, "axs is required with fig is provided"
assert len(axs) == 3, "three Axes required"
fig.clear()

if vmin is None:
vmin = 0.0
if vmax is None:
_, vmax1 = find_value_lims(img1.get_fdata())
_, vmax2 = find_value_lims(img2.get_fdata())
vmax = max(vmax1, vmax2)

title = labels[0] if index is None else f"{labels[0]} ({index:04d})"
nilplt.plot_epi(
img1,
figure=fig,
axes=axs[0],
colorbar=colorbar,
cut_coords=cut_coords,
draw_cross=True,
vmin=vmin,
vmax=vmax,
cmap="gray",
title=title,
)

nilplt.plot_epi(
img2,
figure=fig,
axes=axs[1],
colorbar=colorbar,
cut_coords=cut_coords,
draw_cross=True,
vmin=vmin,
vmax=vmax,
cmap="gray",
title=labels[1],
)

diff = nib.Nifti1Image(img1.dataobj - img2.dataobj, img1.affine)
nilplt.plot_epi(
diff,
figure=fig,
axes=axs[2],
colorbar=colorbar,
cut_coords=cut_coords,
draw_cross=True,
vmin=-vmax,
vmax=vmax,
cmap=LinearSegmentedColormap.from_list(
"cold_hot",
["cyan", "blue", "black", "red", "yellow"],
),
title="difference",
)

if fname is not None:
fig.savefig(fname, bbox_inches="tight", facecolor="black")


if __name__ == "__main__":
parser = argparse.ArgumentParser("plot_nii_difference")
parser.add_argument(
"--out", "-o", metavar="PATH", required=True, type=str,
help="path to output image"
)
parser.add_argument(
"--images", "-i", metavar="PATH", required=True, type=str, nargs=2,
help="paths to two images"
)
parser.add_argument(
"--masks", metavar="PATH", type=str, nargs=2,
help="paths to two corresponding mask images"
)
parser.add_argument(
"--labels", metavar="LABEL", type=str, nargs=2,
help="labels for the two series"
)
parser.add_argument(
"--index", metavar="IND", type=int, default=0,
help="volume index for 4d data"
)
parser.add_argument(
"--index-offset", metavar="IND", type=int, default=0,
help=(
"Offset between the two image series. "
"`index1 = index; index2 = index + offset"
)
)
parser.add_argument(
"--cut-coords", metavar="X,Y,Z", type=str, default="1.0, 0.0, 0.0",
help='ortho cut coordinates (default: "1.0, 0.0, 0.0")'
)
parser.add_argument(
"--vmax", metavar="VAL", type=float, default=None,
help="plotting vmax"
)
parser.add_argument(
"--colorbar", action="store_true",
help="show colorbar"
)

args = parser.parse_args()
main(args)
138 changes: 138 additions & 0 deletions code/plot_nii_difference_movie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Example usage:
python plot_nii_difference_movie.py \
-o bold_diff_movie \
--images /path/to/bold1.nii.gz /path/to/bold2.nii.gz \
--masks /path/to/mask1.nii.gz /path/to/mask2.nii.gz \
--labels bold1 bold2
"""

import argparse
import logging
import os
import pprint
import subprocess
from pathlib import Path

from matplotlib import pyplot as plt
import nibabel as nib
from nilearn import image as nilimg

from plot_nii_difference import plot_difference_triplet
from viz_utils import find_value_lims, apply_mask

plt.rcParams["figure.dpi"] = 150

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)


def main(args: argparse.Namespace):
logging.info("Args:\n%s", pprint.pformat(args.__dict__))
outdir = Path(args.outdir)
(outdir / "frames").mkdir(parents=True, exist_ok=True)

logging.info("Loading images")
imgs = [nib.load(path) for path in args.images]
assert len(imgs) == 2, "only two images supported"
assert imgs[0].ndim == imgs[1].ndim == 4, "both images should be 4d"
assert len(args.labels) == 2, "two labels are required"
if args.masks is not None:
for ii, path in enumerate(args.masks):
if os.path.exists(path):
logging.info("Applying mask to image %d", ii)
mask = nib.load(path)
imgs[ii] = apply_mask(imgs[ii], mask)

logging.info("Resampling img2 -> img1")
img1, img2 = imgs
img1 = nilimg.reorder_img(img1, resample=False)
img2 = nilimg.reorder_img(img2, resample=False)
img2 = nilimg.resample_to_img(img2, img1)

# truncate to last volumes
if img1.shape[3] != img2.shape[3]:
ntpts1, ntpts2 = img1.shape[3], img2.shape[3]
ntpts = min(ntpts1, ntpts2)
logging.info("Truncating to %d time points", ntpts)
img1 = nib.Nifti1Image(img1.dataobj[..., -ntpts:], img1.affine)
img2 = nib.Nifti1Image(img2.dataobj[..., -ntpts:], img2.affine)
else:
ntpts = img1.shape[3]

logging.info("Finding the vmax")
_, vmax1 = find_value_lims(img1.get_fdata())
_, vmax2 = find_value_lims(img2.get_fdata())
vmax = max(vmax1, vmax2)

cut_coords = tuple(float(val.strip()) for val in args.cut_coords.split(","))

f, axs = plt.subplots(3, 1, figsize=(9, 9))

for tpt in range(ntpts):
logging.info("Plotting frame %d", tpt)
fname = outdir / "frames" / f"{tpt:04d}.png"
plot_difference_triplet(
img1=nilimg.index_img(img1, tpt),
img2=nilimg.index_img(img2, tpt),
labels=args.labels,
fig=f,
axs=axs,
index=tpt,
cut_coords=cut_coords,
vmax=vmax,
colorbar=args.colorbar,
fname=fname,
)

cmd = (
"ffmpeg -y -framerate 2 -pattern_type glob -i '{frames}' "
"-vf 'pad=ceil(iw/2)*2:ceil(ih/2)*2' "
"-c:v libx264 -r 30 -pix_fmt yuv420p {out}"
).format(
frames=str(outdir / "frames" / "*.png"),
out=str(outdir / args.fname),
)
logging.info("Combining frames with ffmpeg\n\t%s", cmd)
subprocess.call(cmd, shell=True)

logging.info("Done")


if __name__ == "__main__":
parser = argparse.ArgumentParser("plot_nii_difference_movie")
parser.add_argument(
"--outdir", "-o", metavar="PATH", required=True, type=str,
help="path to output directory"
)
parser.add_argument(
"--images", "-i", metavar="PATH", required=True, type=str, nargs=2,
help="paths to two 4d image series"
)
parser.add_argument(
"--masks", metavar="PATH", type=str, nargs=2,
help="paths to two corresponding mask images"
)
parser.add_argument(
"--labels", metavar="LABEL", type=str, nargs=2,
help="labels for the two series"
)
parser.add_argument(
"--cut-coords", metavar="X,Y,Z", type=str, default="1.0, 0.0, 0.0",
help='ortho cut coordinates (default: "1.0, 0.0, 0.0")'
)
parser.add_argument(
"--colorbar", action="store_true",
help="show colorbar"
)
parser.add_argument(
"--fname", metavar="NAME", type=str, default="out.mp4",
help='output video filename (default: "out.mp4")'
)

args = parser.parse_args()
main(args)
Loading

0 comments on commit 1822fe4

Please sign in to comment.