From fb70bc8b4be89ed8c5bfbd4b886b4181a6046909 Mon Sep 17 00:00:00 2001 From: Ayush Shri Date: Thu, 24 Oct 2024 23:26:00 +0530 Subject: [PATCH] required_changes --- lightly/models/modules/ijepa.py | 8 ++++---- lightly/models/modules/ijepa_timm.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lightly/models/modules/ijepa.py b/lightly/models/modules/ijepa.py index d891d1f33..aa7ed3f17 100644 --- a/lightly/models/modules/ijepa.py +++ b/lightly/models/modules/ijepa.py @@ -117,14 +117,14 @@ def forward(self, x, masks_x, masks): Args: x: - Input tensor. + Input tensor of shape (batch_size, sequence_length, feature_dim). masks_x: - Mask indices for the input tensor. + Mask indices for the input tensor of shape (batch_size, num_patches). masks: - Mask indices for the predicted tokens. + Mask indices for the predicted tokens of shape (batch_size, num_patches). Returns: - The predicted output tensor. + The predicted output tensor of shape (batch_size, num_patches, output_dim). """ assert (masks is not None) and ( masks_x is not None diff --git a/lightly/models/modules/ijepa_timm.py b/lightly/models/modules/ijepa_timm.py index 89e3dca76..dc5c5dea4 100644 --- a/lightly/models/modules/ijepa_timm.py +++ b/lightly/models/modules/ijepa_timm.py @@ -103,14 +103,14 @@ def forward( Args: x: - Input tensor. + Input tensor of shape (batch_size, sequence_length, feature_dim). masks_x: - Mask indices for the input tensor. + Mask indices for the input tensor of shape (batch_size, num_patches). masks: - Mask indices for the predicted tokens. + Mask indices for the predicted tokens of shape (batch_size, num_patches). Returns: - The predicted output tensor. + The predicted output tensor of shape (batch_size, num_patches, output_dim). """ assert (masks is not None) and (