Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Nov 12, 2024
1 parent 0ca2c93 commit d39902e
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 18 deletions.
13 changes: 0 additions & 13 deletions src/classify/clipclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ def classify_clip(self, clip, model, meta_data, reuse_frames=None):
predictions.model_load_time = time.time() - start

for i, track in enumerate(clip.tracks):
logging.info("Track id is %s", track.get_id())

segment_frames = None
if reuse_frames:
tracks = meta_data.get("tracks")
Expand Down Expand Up @@ -247,17 +245,6 @@ def save_metadata(
prediction = predictions.prediction_for(track.get_id())
if prediction is None:
continue
# DEBUGGING STUFF REMOVE ME
# logging.info("Track predictions %s", track)
# for p in prediction.predictions:
# logging.info(
# "Have %s sum %s smoothed %s mass %s",
# p,
# np.sum(p.prediction),
# np.round(p.smoothed_prediction),
# p.mass,
# )
# logging.info("smoothed %s", np.round(100 * prediction.class_best_score))
prediction_meta = prediction.get_metadata()
prediction_meta["model_id"] = model_id
prediction_info.append(prediction_meta)
Expand Down
1 change: 0 additions & 1 deletion src/ml_tools/hyperparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def segment_width(self):

@property
def segment_types(self):

segment_types = self.get("segment_type", [SegmentType.ALL_RANDOM])
# convert string to enum type
if isinstance(segment_types[0], str):
Expand Down
14 changes: 11 additions & 3 deletions src/ml_tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,17 @@ def saveclassify_image(data, filename):
# saves image channels side by side, expected data to be values in the range of 0->1
Path(filename).parent.mkdir(parents=True, exist_ok=True)
r = Image.fromarray(np.uint8(data[:, :, 0]))
g = Image.fromarray(np.uint8(data[:, :, 1]))
b = g
# b = Image.fromarray(np.uint8(data[:, :, 2]))
_, _, channels = data.shape

if channels == 1:
g = r
else:
g = Image.fromarray(np.uint8(data[:, :, 1]))

if channels == 2:
b = r
else:
b = Image.fromarray(np.uint8(data[:, :, 2]))
concat = np.concatenate((r, g, b), axis=1) # horizontally
img = Image.fromarray(np.uint8(concat))
img.save(filename + ".png")
Expand Down
2 changes: 1 addition & 1 deletion src/rebuildDate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dateutil.parser import parse as parse_date

parser = argparse.ArgumentParser()
parser.add_argument("data_dir", help="Directory of hdf5 files")
parser.add_argument("data_dir", help="Directory of cptv files")
args = parser.parse_args()
args.data_dir = Path(args.data_dir)
latest_date = None
Expand Down

0 comments on commit d39902e

Please sign in to comment.