Skip to content

Commit

Permalink
try random section
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Nov 11, 2024
1 parent 46b431b commit e9b51a9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/classify/clipclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def classify_clip(self, clip, model, meta_data, reuse_frames=None):
predictions.model_load_time = time.time() - start

for i, track in enumerate(clip.tracks):
logging.info("Track id is %s", track.get_id())

segment_frames = None
if reuse_frames:
tracks = meta_data.get("tracks")
Expand Down
37 changes: 27 additions & 10 deletions src/ml_tools/datasetstructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,7 @@ def get_segments(
frame_min_mass=None,
fp_frames=None,
repeat_frame_indices=True,
min_segments=None,
):
if min_frames is None:
min_frames = segment_width / 4.0
Expand Down Expand Up @@ -1059,7 +1060,9 @@ def get_segments(
segments.extend(new_segments)
filtered_stats.merge(filtered)
continue
if len(frame_indices) < min_frames:
if len(frame_indices) < min_frames and (
min_segments == 0 or min_segments is None
):
filtered_stats["too short"] += 1
continue

Expand All @@ -1069,7 +1072,6 @@ def get_segments(
# probably only counts for all random
if max_segments is not None and segment_type not in [SegmentType.ALL_SECTIONS]:
segment_count = min(max_segments, segment_count)

# take any segment_width frames, this could be done each epoch
whole_indices = frame_indices
random_frames = segment_type in [
Expand All @@ -1079,17 +1081,31 @@ def get_segments(
SegmentType.TOP_RANDOM,
None,
]
random_mask = True
for _ in range(repeats):
frame_indices = whole_indices.copy()
if random_frames:
# random_frames and not random_sections:
np.random.shuffle(frame_indices)
used_indices = []
if not random_mask:
frame_indices = whole_indices.copy()

if random_frames:
# random_frames and not random_sections:
np.random.shuffle(frame_indices)

for i in range(segment_count):
if random_mask:
mask_start = i * 25
frame_indices = list(whole_indices[0:mask_start].copy())
frame_indices.extend(whole_indices[mask_start + 25 :].copy())
frame_indices = [f for f in frame_indices if f not in used_indices]
frame_indices = np.uint32(frame_indices)
np.random.shuffle(frame_indices)

# always get atleast one segment, not doing annymore
if (
len(frame_indices) < segment_width / 2.0 and len(segments) > 1
) or len(frame_indices) < segment_width / 4:
break
if len(frame_indices) == 0 or len(segments) >= min_segments:
if (
len(frame_indices) < segment_width / 2.0 and len(segments) > 1
) or len(frame_indices) < segment_width / 4:
break

if segment_type == SegmentType.ALL_SECTIONS:
# random frames from section 2.2 * segment_width
Expand All @@ -1106,6 +1122,7 @@ def get_segments(
elif random_frames:
# frame indices already randomized so just need to grab some
frames = frame_indices[:segment_width]
used_indices.extend(frames)
frame_indices = frame_indices[segment_width:]
else:
segment_start = i * segment_frame_spacing
Expand Down
1 change: 1 addition & 0 deletions src/ml_tools/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ def preprocess_segments(
max_segments=max_segments,
dont_filter=dont_filter,
filter_by_fp=False,
min_segments=1,
)
frame_indices = set()
for segment in segments:
Expand Down
2 changes: 2 additions & 0 deletions src/track/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def get_segments(
ffc_frames=None,
dont_filter=False,
filter_by_fp=False,
min_segments=1,
):
if from_last is not None:
if from_last == 0:
Expand Down Expand Up @@ -490,6 +491,7 @@ def get_segments(
segment_types=segment_types,
max_segments=max_segments,
dont_filter=dont_filter,
min_segments=min_segments,
)

return segments
Expand Down

0 comments on commit e9b51a9

Please sign in to comment.