Skip to content

Commit

Permalink
Merge branch 'main' into change-detection-task
Browse files Browse the repository at this point in the history
  • Loading branch information
keves1 authored Jan 7, 2025
2 parents 7425575 + 9819625 commit 75e0e5e
Show file tree
Hide file tree
Showing 27 changed files with 402 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Build project
run: python3 -m build
- name: Upload artifacts
uses: actions/upload-artifact@v4.4.3
uses: actions/upload-artifact@v4.5.0
with:
name: pypi-dist
path: dist/
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
pytest --cov --cov-report=xml
python3 -m torchgeo --help
- name: Report coverage
uses: codecov/[email protected].1
uses: codecov/[email protected].2
with:
token: ${{ secrets.CODECOV_TOKEN }}
minimum:
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
pytest --cov --cov-report=xml
python3 -m torchgeo --help
- name: Report coverage
uses: codecov/[email protected].1
uses: codecov/[email protected].2
with:
token: ${{ secrets.CODECOV_TOKEN }}
datasets:
Expand Down Expand Up @@ -114,7 +114,7 @@ jobs:
pytest --cov --cov-report=xml
python3 -m torchgeo --help
- name: Report coverage
uses: codecov/[email protected].1
uses: codecov/[email protected].2
with:
token: ${{ secrets.CODECOV_TOKEN }}
concurrency:
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ Testing:

The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):

```console
$ pip install torchgeo
```sh
pip install torchgeo
```

For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation instructions, see the [documentation](https://torchgeo.readthedocs.io/en/stable/user/installation.html).
Expand Down Expand Up @@ -192,7 +192,7 @@ trainer.fit(model=task, datamodule=datamodule)

TorchGeo also supports command-line interface training using [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways:

```console
```sh
# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
Expand All @@ -201,7 +201,7 @@ python3 -m torchgeo

It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:

```console
```sh
# See valid stages
torchgeo --help
# See valid trainer options
Expand Down Expand Up @@ -233,7 +233,7 @@ data:
we can see the script in action:
```console
```sh
# Train and validate a model
torchgeo fit --config config.yaml
# Validate-only
Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
('py:class', 'fiona.model.Feature'),
('py:class', 'kornia.augmentation._2d.intensity.base.IntensityAugmentationBase2D'),
('py:class', 'kornia.augmentation.base._AugmentationBase'),
('py:class', 'lightning.pytorch.utilities.types.LRSchedulerConfig'),
('py:class', 'lightning.pytorch.utilities.types.OptimizerConfig'),
('py:class', 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig'),
('py:class', 'segmentation_models_pytorch.base.model.SegmentationModel'),
('py:class', 'timm.models.resnet.ResNet'),
Expand Down
4 changes: 3 additions & 1 deletion docs/tutorials/basic_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ Basic Usage
The following tutorials introduce the basic concepts and components of TorchGeo:

* `Transforms <transforms.ipynb>`_: Preprocessing and data augmentation transforms for geospatial data
* `Indices <indices.ipynb>`_: Spectral indices
* `Spectral Indices <indices.ipynb>`_: Visualizing and appending spectral indices
* `Pretrained Weights <pretrained_weights.ipynb>`_: Models and pretrained weights
* `Lightning Trainers <trainers.ipynb>`_: PyTorch Lightning data modules and trainers
* `Command-Line Interface <cli.ipynb>`_: TorchGeo's command-line interface

.. toctree::
:hidden:
Expand All @@ -16,3 +17,4 @@ The following tutorials introduce the basic concepts and components of TorchGeo:
indices
pretrained_weights
trainers
cli
292 changes: 292 additions & 0 deletions docs/tutorials/cli.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "16421d50-8d7a-4972-b06f-160fd890cc86",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Microsoft Corporation. All rights reserved.\n",
"# Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"id": "e563313d",
"metadata": {},
"source": [
"# Command-Line Interface\n",
"\n",
"_Written by: Adam J. Stewart_\n",
"\n",
"TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface."
]
},
{
"cell_type": "markdown",
"id": "8c1f4156",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f0d31a8",
"metadata": {},
"outputs": [],
"source": [
"%pip install torchgeo"
]
},
{
"cell_type": "markdown",
"id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b",
"metadata": {},
"source": [
"## Subcommands\n",
"\n",
"The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo --help"
]
},
{
"cell_type": "markdown",
"id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6",
"metadata": {},
"source": [
"## Trainer\n",
"\n",
"Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --help"
]
},
{
"cell_type": "markdown",
"id": "b437860c-b406-4150-b30b-8aa895eebfcd",
"metadata": {},
"source": [
"## Model\n",
"\n",
"We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's builtin trainers, or trainers written by the user, can be used in this way."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --model.help ClassificationTask"
]
},
{
"cell_type": "markdown",
"id": "3daacd8d-64f4-4357-bdf3-759295a14224",
"metadata": {},
"source": [
"## Data\n",
"\n",
"We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--help` flag to see what options are available for the `EuroSAT100` dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "136eb59f-6662-44af-82e9-c55bdb3f17ac",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --data.help EuroSAT100DataModule"
]
},
{
"cell_type": "markdown",
"id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb",
"metadata": {},
"source": [
"## Config\n",
"\n",
"Now that we have seen all important configuration options, we can put them together in a YAML file. LightingCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n",
"config = f\"\"\"\n",
"trainer:\n",
" max_epochs: 1\n",
" default_root_dir: '{root}'\n",
"model:\n",
" class_path: ClassificationTask\n",
" init_args:\n",
" model: 'resnet18'\n",
" in_channels: 13\n",
" num_classes: 10\n",
"data:\n",
" class_path: EuroSAT100DataModule\n",
" init_args:\n",
" batch_size: 8\n",
" dict_kwargs:\n",
" root: '{root}'\n",
" download: true\n",
"\"\"\"\n",
"os.makedirs(root, exist_ok=True)\n",
"with open(os.path.join(root, 'config.yaml'), 'w') as f:\n",
" f.write(config)"
]
},
{
"cell_type": "markdown",
"id": "a661b8d7-2dc9-4a30-8842-bd52d130e080",
"metadata": {},
"source": [
"This YAML file has three sections:\n",
"\n",
"* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n",
"* model: Arguments to pass to the task\n",
"* data: Arguments to pass to the data module\n",
"\n",
"The `class_path` gives the class to instantiate, `init_args` lists standard arguments, and `dict_kwargs` lists keyword arguments."
]
},
{
"cell_type": "markdown",
"id": "e132f933-4edf-42bb-b585-e0d8ceb65eab",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We can now train our model like so."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f84b0739-c9e7-4057-8864-98ab69a11f64",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo fit --config {root}/config.yaml"
]
},
{
"cell_type": "markdown",
"id": "cb1557f1-6cc0-46da-909c-836911acb248",
"metadata": {},
"source": [
"## Validation\n",
"\n",
"Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"\n",
"checkpoint = glob.glob(\n",
" os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n",
")[0]\n",
"\n",
"!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}"
]
},
{
"cell_type": "markdown",
"id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61",
"metadata": {},
"source": [
"## Testing\n",
"\n",
"After finishing our hyperparameter tuning, we can calculate and report the final test performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1faa997-9f81-4847-94fc-5a8bb7687369",
"metadata": {},
"outputs": [],
"source": [
"!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}"
]
},
{
"cell_type": "markdown",
"id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042",
"metadata": {},
"source": [
"## Additional Reading\n",
"\n",
"Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n",
"\n",
"* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"execution": {
"timeout": 1200
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 75e0e5e

Please sign in to comment.