Skip to content

Commit

Permalink
fix to small tracks
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Nov 11, 2024
1 parent 25ae03b commit 7cd7efc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
15 changes: 10 additions & 5 deletions src/ml_tools/datasetstructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,11 +1095,16 @@ def get_segments(

for i in range(segment_count):
if random_mask:
mask_start = i * 25
frame_indices = list(whole_indices[0:mask_start].copy())
frame_indices.extend(whole_indices[mask_start + 25 :].copy())
frame_indices = [f for f in frame_indices if f not in used_indices]
frame_indices = np.uint32(frame_indices)
if len(whole_indices) < 40:
frame_indices = whole_indices.copy()
else:
mask_start = i * 25
frame_indices = list(whole_indices[0:mask_start].copy())
frame_indices.extend(whole_indices[mask_start + 25 :].copy())
frame_indices = [
f for f in frame_indices if f not in used_indices
]
frame_indices = np.uint32(frame_indices)
np.random.shuffle(frame_indices)

# always get atleast one segment, not doing annymore
Expand Down
4 changes: 3 additions & 1 deletion src/ml_tools/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def preprocess(self, clip, track, **args):
predict_from_last,
segment_frames=segment_frames,
dont_filter=args.get("dont_filter", False),
min_segments=args.get("min_segments"),
)
else:
frames, preprocessed, masses = self.preprocess_frames(
Expand Down Expand Up @@ -290,6 +291,7 @@ def preprocess_segments(
predict_from_last=None,
segment_frames=None,
dont_filter=False,
min_segments=None,
):
from ml_tools.preprocess import preprocess_frame, preprocess_movement

Expand All @@ -304,7 +306,7 @@ def preprocess_segments(
max_segments=max_segments,
dont_filter=dont_filter,
filter_by_fp=False,
min_segments=1,
min_segments=min_segments,
)
frame_indices = set()
for segment in segments:
Expand Down
3 changes: 1 addition & 2 deletions src/modelevaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,9 +381,8 @@ def load_clip_data(cptv_file):
for track in clip.tracks:
try:
frames, preprocessed, masses = worker_model.preprocess(
clip_db, track, frames_per_classify=25, dont_filter=True
clip_db, track, frames_per_classify=25, dont_filter=True, min_segments=1
)

data.append(
(
f"{track.clip_id}-{track.get_id()}",
Expand Down

0 comments on commit 7cd7efc

Please sign in to comment.