Skip to content

Commit

Permalink
Merge pull request #106 from JoHof/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
JoHof authored Apr 6, 2024
2 parents b0d4ab8 + 5e69260 commit 0653c70
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 74 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ __pycache__/
**/.ipynb_checkpoints/
*.pyc
**/.coverage*
build/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,5 @@ The regular U-net(R231) model works very well for COVID-19 CT scans. However, co
![alt text](figures/example_covid.jpg "COVID examples")

## jpg, png and non HU images
**This feature is only available in versions between 0.2.5 and 0.2.14**
As of version 0.2.5 these images are supported. Use the ```--noHU``` tag if you process images that are not encoded in HU. Keep in mind that the models were trained on proper CT scans encoded in HU. The results on cropped, annotated, very high and very low intensity shifted images may not be very reliable. When using the ```--noHU``` tag only single slices can be processed.
62 changes: 35 additions & 27 deletions lungmask/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ def main():
parser.add_argument(
"output", metavar="output", type=str, help="Filepath for output lungmask"
)
parser.add_argument(
"--modeltype", help="Default: unet", type=str, choices=["unet"], default="unet"
)
parser.add_argument(
"--modelname",
help="spcifies the trained model, Default: R231",
Expand All @@ -45,10 +42,6 @@ def main():
parser.add_argument(
"--modelpath", help="spcifies the path to the trained model", default=None
)
parser.add_argument(
"--classes",
help="spcifies the number of output classes of the model",
)
parser.add_argument(
"--cpu",
help="Force using the CPU even when a GPU is available, will override batchsize to 1",
Expand All @@ -59,11 +52,6 @@ def main():
help="Deactivates postprocessing (removal of unconnected components and hole filling)",
action="store_true",
)
parser.add_argument(
"--noHU",
help="For processing of images that are not encoded in hounsfield units (HU). E.g. png or jpg images from the web. Be aware, results may be substantially worse on these images",
action="store_true",
)
parser.add_argument(
"--batchsize",
type=int,
Expand All @@ -81,22 +69,28 @@ def main():
action="version",
version=version,
)
parser.add_argument(
"--removemetadata",
action="store_true",
help="Do not keep study/patient related metadata of the input, if any. Only affects output file formats that can store such information (e.g. DICOM).",
)

argsin = sys.argv[1:]
args = parser.parse_args(argsin)

if args.classes is not None:
logger.warn(
"!!! Warning: The `classes` parameter is deprecated and will be removed in the next version !!!"
)

batchsize = args.batchsize
if args.cpu:
batchsize = 1

# keeping any Patient / Study info is the default, deactivate in case of arg specified or non-HU data
keepmetadata = not args.removemetadata

logger.info("Load model")

input_image = utils.load_input_image(args.input, disable_tqdm=args.noprogress)
input_image = utils.load_input_image(
args.input, disable_tqdm=args.noprogress, read_metadata=keepmetadata
)

logger.info("Infer lungmask")
if args.modelname == "LTRCLobes_R231":
assert (
Expand All @@ -108,7 +102,6 @@ def main():
fillmodel="R231",
batch_size=batchsize,
volume_postprocessing=not (args.nopostprocess),
noHU=args.noHU,
tqdm_disable=args.noprogress,
)
result = inferer.apply(input_image)
Expand All @@ -119,21 +112,36 @@ def main():
force_cpu=args.cpu,
batch_size=batchsize,
volume_postprocessing=not (args.nopostprocess),
noHU=args.noHU,
tqdm_disable=args.noprogress,
)
result = inferer.apply(input_image)

if args.noHU:
file_ending = args.output.split(".")[-1]
if file_ending in ["jpg", "jpeg", "png"]:
result = (result / (result.max()) * 255).astype(np.uint8)
result = result[0]

result_out = sitk.GetImageFromArray(result)
result_out.CopyInformation(input_image)

writer = sitk.ImageFileWriter()
writer.SetFileName(args.output)

if keepmetadata:
# keep the Study Instance UID
writer.SetKeepOriginalImageUID(True)

DICOM_tags_to_keep = utils.get_DICOM_tags_to_keep()

# copy the DICOM tags we want to keep
for key in input_image.GetMetaDataKeys():
if key in DICOM_tags_to_keep:
result_out.SetMetaData(key, input_image.GetMetaData(key))

# set the Series Description tag
result_out.SetMetaData("0008|103e", "Created with lungmask")

# set WL/WW
result_out.SetMetaData("0028|1050", "1") # Window Center
result_out.SetMetaData("0028|1051", "2") # Window Width

logger.info(f"Save result to: {args.output}")
sitk.WriteImage(result_out, args.output)
writer.Execute(result_out)


if __name__ == "__main__":
Expand Down
47 changes: 10 additions & 37 deletions lungmask/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def __init__(
force_cpu: bool = False,
batch_size: int = 20,
volume_postprocessing: bool = True,
noHU: bool = False,
tqdm_disable: bool = False,
):
"""LungMaskInference
Expand All @@ -91,7 +90,6 @@ def __init__(
force_cpu (bool, optional): Will not use GPU is `True`. Defaults to False.
batch_size (int, optional): Batch size. Defaults to 20.
volume_postprocessing (bool, optional): If `Fales` will not perform postprocessing (connected component analysis). Defaults to True.
noHU (bool, optional): If `True` no HU intensities are expected. Not recommended. Defaults to False.
tqdm_disable (bool, optional): If `True`, will disable progress bar. Defaults to False.
"""
assert (
Expand All @@ -113,7 +111,6 @@ def __init__(
self.force_cpu = force_cpu
self.batch_size = batch_size
self.volume_postprocessing = volume_postprocessing
self.noHU = noHU
self.tqdm_disable = tqdm_disable

self.model = get_model(self.modelname, modelpath)
Expand Down Expand Up @@ -166,20 +163,9 @@ def _inference(
image = sitk.DICOMOrient(image, "LPS")
inimg_raw = sitk.GetArrayFromImage(image)

if self.noHU:
# support for non HU images. This is just a hack. The models were not trained with this in mind
tvolslices = skimage.color.rgb2gray(inimg_raw)
tvolslices = skimage.transform.resize(tvolslices, [256, 256])
tvolslices = np.asarray([tvolslices * x for x in np.linspace(0.3, 2, 20)])
tvolslices[tvolslices > 1] = 1
sanity = [
(tvolslices[x] > 0.6).sum() > 25000 for x in range(len(tvolslices))
]
tvolslices = tvolslices[sanity]
else:
tvolslices, xnew_box = utils.preprocess(inimg_raw, resolution=[256, 256])
tvolslices[tvolslices > 600] = 600
tvolslices = np.divide((tvolslices + 1024), 1624)
tvolslices, xnew_box = utils.preprocess(inimg_raw, resolution=[256, 256])
tvolslices[tvolslices > 600] = 600
tvolslices = np.divide((tvolslices + 1024), 1624)

timage_res = np.empty((np.append(0, tvolslices[0].shape)), dtype=np.uint8)

Expand Down Expand Up @@ -207,22 +193,13 @@ def _inference(
else:
outmask = timage_res

if self.noHU:
outmask = skimage.transform.resize(
outmask[np.argmax((outmask == 1).sum(axis=(1, 2)))],
inimg_raw.shape[:2],
order=0,
anti_aliasing=False,
preserve_range=True,
)[None, :, :]
else:
outmask = np.asarray(
[
utils.reshape_mask(outmask[i], xnew_box[i], inimg_raw.shape[1:])
for i in range(outmask.shape[0])
],
dtype=np.uint8,
)
outmask = np.asarray(
[
utils.reshape_mask(outmask[i], xnew_box[i], inimg_raw.shape[1:])
for i in range(outmask.shape[0])
],
dtype=np.uint8,
)

if not numpy_mode:
if curr_orient != "LPS":
Expand Down Expand Up @@ -261,7 +238,6 @@ def apply(
force_cpu=False,
batch_size=20,
volume_postprocessing=True,
noHU=False,
tqdm_disable=False,
):
warnings.warn(
Expand All @@ -272,7 +248,6 @@ def apply(
force_cpu=force_cpu,
batch_size=batch_size,
volume_postprocessing=volume_postprocessing,
noHU=noHU,
tqdm_disable=tqdm_disable,
)
if model is not None:
Expand All @@ -287,7 +262,6 @@ def apply_fused(
force_cpu=False,
batch_size=20,
volume_postprocessing=True,
noHU=False,
tqdm_disable=False,
):
warnings.warn(
Expand All @@ -300,7 +274,6 @@ def apply_fused(
fillmodel=fillmodel,
batch_size=batch_size,
volume_postprocessing=volume_postprocessing,
noHU=noHU,
tqdm_disable=tqdm_disable,
)
return inferer.apply(image)
53 changes: 49 additions & 4 deletions lungmask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@

from lungmask.logger import logger

DICOM_METADATA_TO_KEEP = (
'0008|0020', # StudyDate
'0008|0030', # StudyTime
'0008|0050', # AccessionNumber
'0008|0090', # ReferringPhysicianName
'0008|1030', # StudyDescription
'0010|0010', # PatientName
'0010|0020', # PatientID
'0010|0030', # PatientBirthDate
'0010|0040', # PatientSex
'0018|5100', # Patient Position
'0020|000d', # StudyInstanceUID
'0020|0010' # StudyID
)

def preprocess(
img: np.ndarray, resolution: list = [192, 192]
Expand Down Expand Up @@ -115,8 +129,9 @@ def reshape_mask(mask: np.ndarray, tbox: np.ndarray, origsize: tuple) -> np.ndar
return res


def read_dicoms(path, primary=True, original=True, disable_tqdm=False):
def read_dicoms(path, primary=True, original=True, disable_tqdm=False, read_metadata=False):
allfnames = []

for dir, _, fnames in os.walk(path):
[allfnames.append(os.path.join(dir, fname)) for fname in fnames]

Expand Down Expand Up @@ -199,29 +214,48 @@ def read_dicoms(path, primary=True, original=True, disable_tqdm=False):
relevant_series.append(vol_files)
reader = sitk.ImageSeriesReader()
reader.SetFileNames(vol_files)

if read_metadata:
reader.SetMetaDataDictionaryArrayUpdate(True)
reader.LoadPrivateTagsOn()

vol = reader.Execute()

if read_metadata:
for key in reader.GetMetaDataKeys(0):
vol.SetMetaData(key, reader.GetMetaData(0, key))

relevant_volumes.append(vol)

return relevant_volumes


def load_input_image(path: str, disable_tqdm=False) -> sitk.Image:
def load_input_image(path: str, disable_tqdm=False, read_metadata=False) -> sitk.Image:
"""Loads image, if path points to a file, file will be loaded. If path points ot a folder, a DICOM series will be loaded. If multiple series are present, the largest series (higher number of slices) will be loaded.
Args:
path (str): File or folderpath to be loaded. If folder, DICOM series is expected
disable_tqdm (bool, optional): Disable tqdm progress bar. Defaults to False.
read_metadata (bool, optional): Read the metadata - including DICOM tags - from the input and store in the loaded image
Returns:
sitk.Image: Loaded image
"""
if os.path.isfile(path):
logger.info(f"Read input: {path}")
input_image = sitk.ReadImage(path)

reader = sitk.ImageFileReader()
reader.SetFileName(path)
input_image = reader.Execute()

if read_metadata:
for key in reader.GetMetaDataKeys():
input_image.SetMetaData(key, reader.GetMetaData(key))

else:
logger.info(f"Looking for dicoms in {path}")
dicom_vols = read_dicoms(
path, original=False, primary=False, disable_tqdm=disable_tqdm
path, original=False, primary=False, disable_tqdm=disable_tqdm, read_metadata=read_metadata
)
if len(dicom_vols) < 1:
sys.exit("No dicoms found!")
Expand Down Expand Up @@ -368,3 +402,14 @@ def keep_largest_connected_component(mask: np.ndarray) -> np.ndarray:
max_region = np.argsort(resizes)[-1] + 1
mask = mask == max_region
return mask

def get_DICOM_tags_to_keep():
"""Returns the DICOM metadata to keep
Args:
none
Returns:
Tuple with the DICOM tags to keep
"""
return DICOM_METADATA_TO_KEEP
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

setuptools.setup(
name="lungmask",
version="0.2.19",
version="0.2.20",
author="Johannes Hofmanninger",
author_email="johannes[email protected]",
author_email="j[email protected]",
description="Package for automated lung segmentation in CT",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
13 changes: 9 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import SimpleITK as sitk
import pytest

from lungmask.utils import (
bbox_3D,
Expand Down Expand Up @@ -62,9 +63,11 @@ def test_bbox_3D():
assert tuple(bb) == (0, 10, 1, 9, 2, 8)


def test_read_dicoms():
d = read_dicoms(os.path.join(os.path.dirname(__file__), "testdata"))
@pytest.mark.parametrize("read_metadata,exp_len_metadata", [(True, 22),(False, 0)])
def test_read_dicoms(read_metadata, exp_len_metadata):
d = read_dicoms(os.path.join(os.path.dirname(__file__), "testdata"), read_metadata=read_metadata)
assert d[0].GetSize() == (512, 512, 2)
assert len(d[0].GetMetaDataKeys()) == exp_len_metadata


def test_simple_bodymask():
Expand Down Expand Up @@ -104,10 +107,12 @@ def test_reshape_mask():
assert np.sum(cropped_mask) == 400


def test_load_input_image(tmp_path):
@pytest.mark.parametrize("read_metadata,exp_len_metadata", [(True, 22),(False, 0)])
def test_load_input_image(tmp_path, read_metadata, exp_len_metadata):
# test dicom
d = load_input_image(os.path.join(os.path.dirname(__file__), "testdata"))
d = load_input_image(os.path.join(os.path.dirname(__file__), "testdata"), read_metadata=read_metadata)
assert d.GetSize() == (512, 512, 2)
assert len(d.GetMetaDataKeys()) == exp_len_metadata

# test nifti
fp_testnii = str(tmp_path / "test.nii.gz")
Expand Down

0 comments on commit 0653c70

Please sign in to comment.