diff --git a/mist/analyze_data/analyze.py b/mist/analyze_data/analyze.py index 8d1c5fb..f3afaf3 100755 --- a/mist/analyze_data/analyze.py +++ b/mist/analyze_data/analyze.py @@ -112,9 +112,13 @@ def compute_class_weights(self): """Compute class weights on original data.""" # Either compute class weights or use user provided weights. - # Check that number of class weights matches the number of labels. + # Check that number of class weights matches the number of labels if + # provided. n_labels = len(self.dataset_information["labels"]) - if len(self.mist_arguments.class_weights) != n_labels: + if ( + self.mist_arguments.class_weights and + len(self.mist_arguments.class_weights) != n_labels + ): raise ValueError( "Number of class weights must match number of labels." ) @@ -185,7 +189,7 @@ def get_target_spacing(self): # to make sure that all of the axes in the spacings match up. # We load the masks because they are smaller and faster to load. mask = ants.image_read(patient["mask"]) - mask = mask.reorient_image2(mask, "RAI") + mask = ants.reorient_image2(mask, "RAI") mask.set_direction( analyzer_constants.AnalyzeConstants.RAI_ANTS_DIRECTION )