Skip to content

Commit

Permalink
added mask segment type as default
Browse files Browse the repository at this point in the history
  • Loading branch information
gferraro committed Nov 12, 2024
1 parent 7cd7efc commit 0ca2c93
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/ml_tools/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
self.excluded_tags = config.build.excluded_tags
self.min_frame_mass = config.build.min_frame_mass
self.filter_by_lq = config.build.filter_by_lq
self.segment_types = [SegmentType.ALL_RANDOM]
self.segment_types = [SegmentType.ALL_RANDOM_MASKED]
self.max_segments = config.build.max_segments
self.country = config.build.country
self.max_frames = config.build.max_frames
Expand Down
34 changes: 23 additions & 11 deletions src/ml_tools/datasetstructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SegmentType(Enum):
ALL_SECTIONS = 5
TOP_RANDOM = 6
ALL_RANDOM_NOMIN = 7
ALL_RANDOM_MASKED = 8


class BaseSample(ABC):
Expand Down Expand Up @@ -1071,44 +1072,55 @@ def get_segments(
frame_indices = np.array(frame_indices)
segment_count = max(1, len(frame_indices) // segment_frame_spacing)
segment_count = int(segment_count)
mask_length = 25

# probably only counts for all random
if max_segments is not None and segment_type not in [SegmentType.ALL_SECTIONS]:
segment_count = min(max_segments, segment_count)
# adjust the size of the mask if we take fewer segments
mask_length = max(mask_length, len(frame_indices) // segment_count)
# take any segment_width frames, this could be done each epoch
whole_indices = frame_indices
random_frames = segment_type in [
SegmentType.IMPORTANT_RANDOM,
SegmentType.ALL_RANDOM,
SegmentType.ALL_RANDOM_NOMIN,
SegmentType.TOP_RANDOM,
SegmentType.ALL_RANDOM_MASKED,
None,
]
random_mask = True

for _ in range(repeats):
used_indices = []
if not random_mask:
if segment_type != SegmentType.ALL_RANDOM_MASKED or len(whole_indices) < 40:
frame_indices = whole_indices.copy()

if random_frames:
# random_frames and not random_sections:
np.random.shuffle(frame_indices)

for i in range(segment_count):
if random_mask:
if len(whole_indices) < 40:
frame_indices = whole_indices.copy()
else:
mask_start = i * 25
frame_indices = list(whole_indices[0:mask_start].copy())
frame_indices.extend(whole_indices[mask_start + 25 :].copy())
if segment_type == SegmentType.ALL_RANDOM_MASKED:
if len(whole_indices) > 40:
mask_start = i * mask_length
frame_indices = whole_indices[0:mask_start]
frame_indices = np.concatenate(
[frame_indices, whole_indices[mask_start + mask_length :]],
axis=0,
)
# maybe some faster way of doing this...
frame_indices = [
f for f in frame_indices if f not in used_indices
]
frame_indices = np.uint32(frame_indices)
np.random.shuffle(frame_indices)
np.random.shuffle(frame_indices)

# always get at least one segment, not doing this anymore
if len(frame_indices) == 0 or len(segments) >= min_segments:
if (
len(frame_indices) == 0
or min_segments is None
or len(segments) >= min_segments
):
if (
len(frame_indices) < segment_width / 2.0 and len(segments) > 1
) or len(frame_indices) < segment_width / 4:
Expand Down

0 comments on commit 0ca2c93

Please sign in to comment.