From e9b51a977237cd848df3a58cf6048f6292240857 Mon Sep 17 00:00:00 2001 From: gferraro Date: Mon, 11 Nov 2024 18:19:42 +0100 Subject: [PATCH] try random section --- src/classify/clipclassifier.py | 2 ++ src/ml_tools/datasetstructures.py | 37 ++++++++++++++++++++++--------- src/ml_tools/interpreter.py | 1 + src/track/track.py | 2 ++ 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/classify/clipclassifier.py b/src/classify/clipclassifier.py index 9e7dd279..4de7ce0c 100644 --- a/src/classify/clipclassifier.py +++ b/src/classify/clipclassifier.py @@ -182,6 +182,8 @@ def classify_clip(self, clip, model, meta_data, reuse_frames=None): predictions.model_load_time = time.time() - start for i, track in enumerate(clip.tracks): + logging.info("Track id is %s", track.get_id()) + segment_frames = None if reuse_frames: tracks = meta_data.get("tracks") diff --git a/src/ml_tools/datasetstructures.py b/src/ml_tools/datasetstructures.py index daa41741..554fc7d5 100644 --- a/src/ml_tools/datasetstructures.py +++ b/src/ml_tools/datasetstructures.py @@ -987,6 +987,7 @@ def get_segments( frame_min_mass=None, fp_frames=None, repeat_frame_indices=True, + min_segments=None, ): if min_frames is None: min_frames = segment_width / 4.0 @@ -1059,7 +1060,9 @@ def get_segments( segments.extend(new_segments) filtered_stats.merge(filtered) continue - if len(frame_indices) < min_frames: + if len(frame_indices) < min_frames and ( + min_segments == 0 or min_segments is None + ): filtered_stats["too short"] += 1 continue @@ -1069,7 +1072,6 @@ def get_segments( # probably only counts for all random if max_segments is not None and segment_type not in [SegmentType.ALL_SECTIONS]: segment_count = min(max_segments, segment_count) - # take any segment_width frames, this could be done each epoch whole_indices = frame_indices random_frames = segment_type in [ @@ -1079,17 +1081,31 @@ def get_segments( SegmentType.TOP_RANDOM, None, ] + random_mask = True for _ in range(repeats): - frame_indices = whole_indices.copy() - if random_frames: - # random_frames and not random_sections: - np.random.shuffle(frame_indices) + used_indices = [] + if not random_mask: + frame_indices = whole_indices.copy() + + if random_frames: + # random_frames and not random_sections: + np.random.shuffle(frame_indices) + for i in range(segment_count): + if random_mask: + mask_start = i * 25 + frame_indices = list(whole_indices[0:mask_start].copy()) + frame_indices.extend(whole_indices[mask_start + 25 :].copy()) + frame_indices = [f for f in frame_indices if f not in used_indices] + frame_indices = np.uint32(frame_indices) + np.random.shuffle(frame_indices) + # always get atleast one segment, not doing annymore - if ( - len(frame_indices) < segment_width / 2.0 and len(segments) > 1 - ) or len(frame_indices) < segment_width / 4: - break + if len(frame_indices) == 0 or len(segments) >= min_segments: + if ( + len(frame_indices) < segment_width / 2.0 and len(segments) > 1 + ) or len(frame_indices) < segment_width / 4: + break if segment_type == SegmentType.ALL_SECTIONS: # random frames from section 2.2 * segment_width @@ -1106,6 +1122,7 @@ def get_segments( elif random_frames: # frame indices already randomized so just need to grab some frames = frame_indices[:segment_width] + used_indices.extend(frames) frame_indices = frame_indices[segment_width:] else: segment_start = i * segment_frame_spacing diff --git a/src/ml_tools/interpreter.py b/src/ml_tools/interpreter.py index bdac4f53..b2dc166d 100644 --- a/src/ml_tools/interpreter.py +++ b/src/ml_tools/interpreter.py @@ -304,6 +304,7 @@ def preprocess_segments( max_segments=max_segments, dont_filter=dont_filter, filter_by_fp=False, + min_segments=1, ) frame_indices = set() for segment in segments: diff --git a/src/track/track.py b/src/track/track.py index b8264c35..f265b014 100644 --- a/src/track/track.py +++ b/src/track/track.py @@ -445,6 +445,7 @@ def get_segments( ffc_frames=None, dont_filter=False, filter_by_fp=False, + min_segments=1, ): if from_last is not None: if from_last == 0: @@ -490,6 +491,7 @@ def get_segments( segment_types=segment_types, max_segments=max_segments, dont_filter=dont_filter, + min_segments=min_segments, ) return segments