Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChangeDetectionTask #2422

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft

Conversation

keves1
Copy link

@keves1 keves1 commented Nov 21, 2024

This PR is to add a change detection trainer as mentioned in #2382.

Key points/items to discuss:

  • I used the OSCD dataset to test with and modified this dataset to use a temporal dimension.
  • With the added temporal dimension, Kornia’s AugmentationSequential doesn’t work, but can be combined with VideoSequential to support the temporal dimension (see Kornia docs). I overrode self.aug in the OSCDDataModule to do this but not sure if this should be incorporated into the BaseDataModule instead.
  • VideoSequential adds a temporal dimension to the mask. Not sure if there is a way to avoid this, or if this is desirable, but I added an if statement to the AugmentationSequential wrapper to check for and remove this added dimension.
  • The OSCDDataModule applies _RandomNCrop augmentation, but this does not work for time series data. I'm not sure how to modify _RandomNCrop to fix this and would appreciate some help/guidance.
  • There are a few tests that I need to still make pass.

cc @robmarkcole

@github-actions github-actions bot added datasets Geospatial or benchmark datasets testing Continuous integration testing trainers PyTorch Lightning trainers transforms Data augmentation transforms datamodules PyTorch Lightning datamodules labels Nov 21, 2024
@robmarkcole
Copy link
Contributor

I wonder if we should limit the scope to change between two timesteps and binary change - then we can use binary metrics and provide a template for the plot methods. I say this because this is the most common change detection task by a mile. Might also simplify the augmentations approach? Treating as a video sequence seems overkill.
Also I understand there is support for multitemporal coming later.

@adamjstewart
Copy link
Collaborator

I wonder if we should limit the scope to change between two timesteps

I'm personally okay with this, although @hfangcat has a recent work using multiple pre-event images that would be nice to support someday (could be a subclass if necessary).

and binary change

Again, this would probably be fine as a starting point, although I would someday like to make all trainers support binary/multiclass/multilabel, e.g., #2219.

provide a template for the plot methods.

Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers: #2184

I say this because this is the most common change detection task by a mile.

Agreed.

Might also simplify the augmentations approach? Treating as a video sequence seems overkill.

I actually like the video augmentations, but let me loop in the Kornia folks to get their opinion: @edgarriba @johnnv1

Also I understand there is support for multitemporal coming later.

Correct, see #2382 for the big picture (I think I also sent you a recording of my presented plan).

@adamjstewart
Copy link
Collaborator

VideoSequential adds a temporal dimension to the mask. Not sure if there is a way to avoid this

Can you try keepdim=True?

I added an if statement to the AugmentationSequential wrapper to check for and remove this added dimension.

@ashnair1 would this work directly with K.AugmentationSequential now? We are trying to phase out our AugmentationSequential wrapper now that upstream supports (almost?) everything we need.

@keves1
Copy link
Author

keves1 commented Nov 22, 2024

I will go ahead and make changes for this to be for binary change and two timesteps, sounds like a good starting point.

Can you try keepdim=True?

I tried this and it didn't get rid of the other dimension. I also looked into extra_args but didn't see any options to help with this.

Could also do this in the datasets (at least for benchmark NonGeoDatasets). We're also trying to remove explicit plotting in the trainers

I was going to add plotting in the trainer, but would you rather not then? What would this look like in the dataset?

@robmarkcole
Copy link
Contributor

Perhaps there should even be a base class ChangeDetection and subclasses for BinaryChangeDetection etc?

@adamjstewart
Copy link
Collaborator

That's exactly what I'm trying to undo in #2219.

@adamjstewart
Copy link
Collaborator

I was going to add plotting in the trainer, but would you rather not then?

We can copy-n-paste the validation_step plotting stuff used by other trainers, but that's probably going to disappear soon (I think we're just waiting on testing in #2184.

What would this look like in the dataset?

See OSCD.plot()

@github-actions github-actions bot added the losses Geospatial loss functions label Dec 5, 2024
@keves1
Copy link
Author

keves1 commented Dec 5, 2024

I've updated this to now support only binary change with two timesteps.
To get test_weight_file in test_change.py to work with the two images stacked on the channel dimension for Unet, I modified the pytest fixture model() in conftest.py to use timm to create the model instead of torchvision, so that an in_channels parameter can be passed.

I still haven't been able to figure out how to make transforms.transforms._RandomNCrop work with the added temporal dimension. It seems to have something to do with _NCropGenerator not properly handling the temporal dimension but I really don't understand what is going on there.

Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you resolve the merge conflicts so we can run the tests?

load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
# multiply in_chans by 2 since images are concatenated
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would it be to do late fusion, so pass each image through the encoder separately, then concatenate them, then pass them through the decoder? This would make it easier to use pre-trained models.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's definitely possible, although I think we would need a custom Unet implementation in torchgeo/models to do this. It would simplify using the pretrained weights but is late fusion a common enough approach that many people would find this useful?

monkeypatch.setattr(weights, 'url', str(path))
return weights

@pytest.mark.parametrize('model', [6], indirect=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remind me what [6] means here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Number of input channels (2 3-channel images stacked)

torchgeo/datamodules/oscd.py Outdated Show resolved Hide resolved
@@ -240,7 +242,7 @@ def _load_target(self, path: Path) -> Tensor:
array: np.typing.NDArray[np.int_] = np.array(img.convert('L'))
tensor = torch.from_numpy(array)
tensor = torch.clamp(tensor, min=0, max=1)
tensor = tensor.to(torch.long)
tensor = tensor.to(torch.float)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would the target be a float?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loss function BCEWithLogitsLoss expects the target to be a float.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to use BCEWithLogitsLoss? Can we use BCELoss instead?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both BCELoss and BCEWithLogitsLoss require float targets. Here's a brief explanation I found as to why: https://discuss.pytorch.org/t/inconsistency-between-loss-functions-input-types/138942. Is there any issue with the target being converted to a float here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. For our binary classification datasets, we convert the label to a float in MultiLabelClassificationTask instead of in the dataset. I would kind of like our datasets to be consistent (int for classification and float for regression). Let's change it in ChangeDetectionTask instead.

torchgeo/losses/__init__.py Outdated Show resolved Hide resolved
torchgeo/trainers/change.py Outdated Show resolved Hide resolved
torchgeo/trainers/change.py Show resolved Hide resolved
torchgeo/transforms/transforms.py Show resolved Hide resolved
@keves1
Copy link
Author

keves1 commented Dec 17, 2024

I'm going to need some help figuring out how to get transforms.transforms._RandomNCrop to work with the added temporal dimension (this is used by the OSCD dataset). I've delved into this a few times but with my lack of familiarity with Kornia I haven't been able to track down the source of the issue. You can see the issue by running tests/trainers/test_change.py::TestChangeDetectionTask::test_trainer.

Also, disregard my earlier comments about Kornia VideoSequential adding a dimension to the mask, this seems to have resolved with the latest Kornia version.

@adamjstewart adamjstewart added this to the 0.7.0 milestone Dec 19, 2024
@github-actions github-actions bot removed the losses Geospatial loss functions label Dec 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets testing Continuous integration testing trainers PyTorch Lightning trainers transforms Data augmentation transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants