Skip to content

Commit

Permalink
name refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Nov 19, 2024
1 parent f51e834 commit 8286ee3
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion lightly/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
if torchvision_transforms_v2_available():
from lightly.transforms.add_grid_transform import AddGridTransform
from lightly.transforms.detcon_transform import (
DetConSimCLRViewTransform,
DetConSTransform,
DetConSViewTransform,
)
from lightly.transforms.multi_view_transform_v2 import MultiViewTransformV2
8 changes: 4 additions & 4 deletions lightly/transforms/detcon_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def __init__(
) -> None:
self.grid_size = grid_size

tr1: List[Union[AddGridTransform, DetConSimCLRViewTransform]] = []
tr2: List[Union[AddGridTransform, DetConSimCLRViewTransform]] = []
tr1: List[Union[AddGridTransform, DetConSViewTransform]] = []
tr2: List[Union[AddGridTransform, DetConSViewTransform]] = []

if self.grid_size is not None:
grid_tr1 = AddGridTransform(
Expand All @@ -120,7 +120,7 @@ def __init__(
tr2 += [grid_tr2]

tr1 += [
DetConSimCLRViewTransform(
DetConSViewTransform(
gaussian_blur=gaussian_blur_t1,
input_size=input_size,
cj_prob=cj_prob,
Expand All @@ -141,7 +141,7 @@ def __init__(
)
]
tr2 += [
DetConSimCLRViewTransform(
DetConSViewTransform(
gaussian_blur=gaussian_blur_t2,
input_size=input_size,
cj_prob=cj_prob,
Expand Down
7 changes: 2 additions & 5 deletions tests/transforms/test_detcon_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
pytest.skip("torchvision.transforms.v2 not available", allow_module_level=True)
from torchvision.tv_tensors import Image, Mask

from lightly.transforms.detcon_transform import (
DetConSimCLRViewTransform,
DetConSTransform,
)
from lightly.transforms.detcon_transform import DetConSTransform, DetConSViewTransform


# ignore typing due to Any type used in torchvison.transforms.v2.Transform
Expand All @@ -27,7 +24,7 @@ def mask() -> Mask: # type: ignore[misc]

class TestDetConSimCLRViewTransform:
def test_given_masks(self, img: Image, mask: Mask) -> None:
tr = DetConSimCLRViewTransform(input_size=(224, 224))
tr = DetConSViewTransform(input_size=(224, 224))

img_tr, mask_tr = tr(img, mask)
assert img_tr.shape == (3, 224, 224)
Expand Down

0 comments on commit 8286ee3

Please sign in to comment.