Assign version to pytorch-lightning (baal-org#226)
* Assign version to pytorch-lightning

* Fix mypy

* Bump count

* Fix test PL
Dref360 authored Jul 10, 2022
1 parent e0db2c8 commit cdcb3ae
Showing 15 changed files with 1,606 additions and 1,587 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -21,7 +21,7 @@ mypy:
 
 .PHONY: check-mypy-error-count
 check-mypy-error-count: MYPY_INFO = $(shell expr `poetry run mypy baal | grep ": error" | wc -l`)
-check-mypy-error-count: MYPY_ERROR_COUNT = 9
+check-mypy-error-count: MYPY_ERROR_COUNT = 6
 
 check-mypy-error-count:
 	@if [ ${MYPY_INFO} -gt ${MYPY_ERROR_COUNT} ]; then \
@@ -38,4 +38,4 @@ REPORT_FOLDER ?= ./reports
 bandit: ./reports/security/bandit/ ## SECURITY - Run bandit
 	poetry run bandit ${SRC_FOLDER}/* -r -x "*.pyi,*/_generated/*,*__pycache__*" -v -ll -f json > ${REPORT_FOLDER}/security/bandit/index.json
 
-.PHONY: bandit
\ No newline at end of file
+.PHONY: bandit
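The check-mypy-error-count target is a ratchet: it counts lines of `poetry run mypy baal` output containing ": error" and fails the build when the count exceeds the allowed baseline, which this commit lowers from 9 to 6. A minimal Python sketch of the same check (assumes poetry and mypy are installed; this script is not part of the repo):

import subprocess

MYPY_ERROR_COUNT = 6  # allowed baseline, lowered from 9 by this commit

result = subprocess.run(
    ["poetry", "run", "mypy", "baal"], capture_output=True, text=True
)
n_errors = sum(": error" in line for line in result.stdout.splitlines())
if n_errors > MYPY_ERROR_COUNT:
    raise SystemExit(f"mypy regression: {n_errors} errors > {MYPY_ERROR_COUNT} allowed")
print(f"OK: {n_errors} mypy errors (baseline {MYPY_ERROR_COUNT})")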
15 changes: 13 additions & 2 deletions baal/active/dataset/base.py
@@ -1,12 +1,23 @@
 import warnings
-from typing import Union, List, Optional, Any
+from typing import Union, List, Optional, Any, TYPE_CHECKING, Protocol
 
 import numpy as np
 from sklearn.utils import check_random_state
 from torch.utils import data as torchdata
 
 
-class SplittedDataset(torchdata.Dataset):
+class SizeableDataset(torchdata.Dataset):
+    def __len__(self):
+        pass
+
+
+if TYPE_CHECKING:
+    Dataset = SizeableDataset
+else:
+    Dataset = torchdata.Dataset
+
+
+class SplittedDataset(Dataset):
     """Abstract class for Dataset that can be splitted.
 
     Args:
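The alias above is a typing shim: torch.utils.data.Dataset does not declare __len__, so mypy rejects len(dataset) on values typed with it. Under TYPE_CHECKING the name Dataset resolves to SizeableDataset, which declares __len__; at runtime it stays the ordinary torch Dataset, so behaviour is unchanged. A self-contained sketch of the pattern (the rationale is inferred from the commit, not stated in it):

from typing import TYPE_CHECKING

from torch.utils import data as torchdata


class SizeableDataset(torchdata.Dataset):
    def __len__(self) -> int:
        ...


if TYPE_CHECKING:
    Dataset = SizeableDataset  # what mypy sees: a dataset that supports len()
else:
    Dataset = torchdata.Dataset  # what actually runs: the plain torch class


def num_items(dataset: Dataset) -> int:
    return len(dataset)  # type-checks; annotating with torchdata.Dataset would not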
2 changes: 1 addition & 1 deletion baal/active/dataset/nlp_datasets.py
@@ -2,8 +2,8 @@
 
 import numpy as np
 import torch
+from baal.active.dataset.base import Dataset
 from datasets import Dataset as HFDataset
-from torch.utils.data import Dataset
 
 from baal.active import ActiveLearningDataset
29 changes: 14 additions & 15 deletions baal/active/dataset/pytorch_dataset.py
@@ -5,9 +5,8 @@
 
 import numpy as np
 import torch.utils.data as torchdata
-from sklearn.utils import check_random_state
 
-from baal.active.dataset.base import SplittedDataset
+from baal.active.dataset.base import SplittedDataset, Dataset
 
 
 def _identity(x):
@@ -33,7 +32,7 @@ class ActiveLearningDataset(SplittedDataset):
 
     def __init__(
         self,
-        dataset: torchdata.Dataset,
+        dataset: Dataset,
         labelled: Optional[np.ndarray] = None,
         make_unlabelled: Callable = _identity,
         random_state=None,
@@ -70,9 +69,9 @@ def check_dataset_can_label(self):
         with definition: `label(self, idx, value)` where `value`
         is the label for indice `idx`.
         """
-        has_label_attr = hasattr(self._dataset, "label")
+        has_label_attr = getattr(self._dataset, "label", None)
         if has_label_attr:
-            if callable(self._dataset.label):
+            if callable(has_label_attr):
                 return True
             else:
                 warnings.warn(
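The getattr form above fetches the attribute once and reuses the bound value, instead of a hasattr check followed by a second attribute lookup, and it gives mypy a concrete local to reason about. A small standalone sketch of the pattern (MyDataset is a hypothetical example, not from the repo):

class MyDataset:
    def label(self, idx, value):
        print(f"labelling {idx} with {value}")


def can_label(dataset) -> bool:
    # One lookup: the bound method (or None) is captured in a single call.
    label_fn = getattr(dataset, "label", None)
    return callable(label_fn)


print(can_label(MyDataset()))  # True
print(can_label(object()))     # False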
@@ -109,22 +108,22 @@ def __iter__(self):
         return self.ActiveIter(self)
 
     @property
-    def pool(self) -> torchdata.Dataset:
+    def pool(self) -> "ActiveLearningPool":
         """Returns a new Dataset made from unlabelled samples.
 
         Raises:
             ValueError if a pool specific attribute cannot be set.
         """
-        pool_dataset = deepcopy(self._dataset)
+        current_dataset = deepcopy(self._dataset)
 
         for attr, new_val in self.pool_specifics.items():
-            if hasattr(pool_dataset, attr):
-                setattr(pool_dataset, attr, new_val)
+            if hasattr(current_dataset, attr):
+                setattr(current_dataset, attr, new_val)
             else:
-                raise ValueError(f"{pool_dataset} doesn't have {attr}")
+                raise ValueError(f"{current_dataset} doesn't have {attr}")
 
-        pool_dataset = torchdata.Subset(
-            pool_dataset, (~self.labelled).nonzero()[0].reshape([-1]).tolist()
+        pool_dataset: torchdata.Subset = torchdata.Subset(
+            current_dataset, (~self.labelled).nonzero()[0].reshape([-1]).tolist()
         )
         ald = ActiveLearningPool(pool_dataset, make_unlabelled=self.make_unlabelled)
         return ald
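The pool property deep-copies the wrapped dataset, overrides any pool-specific attributes (for example swapping a train transform for a test transform), and then selects the indices where the boolean labelled mask is False. A toy illustration of the index selection (data is made up):

import numpy as np
import torch
import torch.utils.data as torchdata

base = torchdata.TensorDataset(torch.arange(5))
labelled = np.array([True, False, False, True, False])

# Same index computation as in `pool` above.
pool_indices = (~labelled).nonzero()[0].reshape([-1]).tolist()
pool = torchdata.Subset(base, pool_indices)
print(pool_indices, len(pool))  # [1, 2, 4] 3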
@@ -163,7 +162,7 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
         active_step = self.current_al_step + 1
         for idx, val in zip_longest(indexes, value_lst, fillvalue=None):
             if self.can_label and val is not None:
-                self._dataset.label(idx, val)
+                self._dataset.label(idx, val)  # type: ignore
                 self.labelled_map[idx] = active_step
             elif self.can_label and val is None:
                 raise ValueError(
@@ -211,8 +210,8 @@ class ActiveLearningPool(torchdata.Dataset):
     """
 
-    def __init__(self, dataset: torchdata.Dataset, make_unlabelled: Callable = _identity) -> None:
-        self._dataset: torchdata.Dataset = dataset
+    def __init__(self, dataset: torchdata.Subset, make_unlabelled: Callable = _identity) -> None:
+        self._dataset: torchdata.Subset = dataset
         self.make_unlabelled = make_unlabelled
 
     def __getitem__(self, index: int) -> Any:
2 changes: 1 addition & 1 deletion baal/active/heuristics/heuristics_gpu.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch.utils.data import Dataset
+from baal.active.dataset.base import Dataset
 
 from baal import ModelWrapper
2 changes: 1 addition & 1 deletion baal/calibration/calibration.py
@@ -2,9 +2,9 @@
 
 import structlog
 import torch
+from baal.active.dataset.base import Dataset
 from torch import nn
 from torch.optim import Adam
-from torch.utils.data import Dataset
 
 from baal import ModelWrapper
 from baal.utils.metrics import ECE, ECE_PerCLs
3 changes: 2 additions & 1 deletion baal/modelwrapper.py
@@ -7,11 +7,12 @@
 import structlog
 import torch
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 from tqdm import tqdm
 
 from baal.utils.array_utils import stack_in_memory
+from baal.active.dataset.base import Dataset
 from baal.utils.cuda_utils import to_cuda
 from baal.utils.iterutils import map_on_tensor
 from baal.utils.metrics import Loss
2 changes: 1 addition & 1 deletion baal/transformers_trainer_wrapper.py
@@ -5,7 +5,7 @@
 
 # These packages are optional and not needed for BaaL main package.
 try:
-    from transformers import Trainer
+    from transformers.trainer import Trainer
 except ImportError:
     raise ImportError(
         "`transformers` library is required to use this module."
6 changes: 4 additions & 2 deletions baal/utils/pytorch_lightning.py
@@ -76,7 +76,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx: Optional[int] = None):
         # Get the input only.
         x, _ = batch
         # Perform Monte-Carlo Inference fro I iterations.
-        out = mc_inference(self, x, self.hparams.iterations, self.hparams.replicate_in_memory)
+        out = mc_inference(self, x,
+                           self.hparams.iterations,  # type: ignore
+                           self.hparams.replicate_in_memory)  # type: ignore
         return out
@@ -151,7 +153,7 @@ def predict_on_dataset_generator(
         model = model or self.lightning_module
         model.eval()
         if isinstance(self.accelerator, GPUAccelerator):
-            model.cuda(self.accelerator.root_device)
+            model.cuda(self.strategy.root_device.index)
         dataloader = dataloader or model.pool_dataloader()
         if len(dataloader) == 0:
             return None
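mc_inference runs I stochastic forward passes with dropout kept active; replicate_in_memory trades memory for speed by batching the repeats. The sketch below only mirrors the call signature seen above; the body is an illustrative assumption, not baal's actual implementation:

import torch


def mc_inference_sketch(model, x, iterations: int, replicate_in_memory: bool):
    if replicate_in_memory:
        # Fast path: repeat the batch `iterations` times, one forward pass.
        big_batch = torch.cat([x] * iterations, dim=0)
        out = model(big_batch)
        return torch.stack(out.chunk(iterations, dim=0), dim=-1)
    # Low-memory path: `iterations` separate stochastic forward passes.
    return torch.stack([model(x) for _ in range(iterations)], dim=-1)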
2 changes: 2 additions & 0 deletions baal/utils/ssl_module.py
@@ -1,4 +1,5 @@
 import argparse
+import typing
 from argparse import Namespace
 from typing import Dict
 
@@ -35,6 +36,7 @@ def training_step(self, batch, *args):
         else:
             return self.unsupervised_training_step(SemiSupervisedIterator.get_batch(batch), *args)
 
+    @typing.no_type_check
     def train_dataloader(self) -> SemiSupervisedIterator:
        """SemiSupervisedIterator for train set.
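@typing.no_type_check makes mypy skip the decorated function entirely; here that presumably silences errors from train_dataloader returning a SemiSupervisedIterator where the Lightning base class annotation expects a plain DataLoader. A standalone example of the decorator's effect (hypothetical function, not from the repo):

import typing


@typing.no_type_check
def halve(n: int) -> int:
    # mypy skips this body, so the float return is not reported.
    return n / 2


print(halve(5))  # 2.5 at runtime; no mypy complaint either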
2 changes: 1 addition & 1 deletion experiments/segmentation/unet_mcdropout_pascal.py
@@ -31,7 +31,7 @@ def mean_regions(n, grid_size=16):
     n = torch.from_numpy(n[:, None, ...])
     # [Batch_size, 1, grid, grid]
     out = F.adaptive_avg_pool2d(n, grid_size)
-    return np.mean(out.view([-1, grid_size ** 2]).numpy(), -1)
+    return np.mean(out.view([-1, grid_size**2]).numpy(), -1)
 
 
 def parse_args():
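The change is black-style spacing only; mean_regions itself pools a per-pixel uncertainty map down to a grid_size × grid_size grid and averages the regions into one score per image. A quick check on random data (shapes assumed from the code):

import numpy as np
import torch
import torch.nn.functional as F

n = np.random.rand(4, 64, 64).astype(np.float32)   # [batch, H, W] uncertainty maps
t = torch.from_numpy(n[:, None, ...])              # [batch, 1, H, W]
out = F.adaptive_avg_pool2d(t, 16)                 # [batch, 1, 16, 16]
scores = np.mean(out.view([-1, 16**2]).numpy(), -1)
print(scores.shape)  # (4,)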
2 changes: 1 addition & 1 deletion experiments/ssl_experiments/pimodel_cifar10.py
@@ -178,7 +178,7 @@ def rampup_value(self):
     def rampdown_value(self):
         if self.current_epoch >= self.epoch - self.hparams.rampup_stop - 1:
             T = (1 / (self.epoch - self.hparams.rampup_stop - 1)) * self.current_epoch
-            return np.exp(-12.5 * T ** 2)
+            return np.exp(-12.5 * T**2)
         else:
             return 0
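Again a spacing-only change; the expression itself is the usual SSL Gaussian ramp-down weight exp(-12.5 * T**2), which decays from 1 at T = 0 to about 3.7e-6 at T = 1. A quick evaluation of the formula:

import numpy as np

for T in (0.0, 0.25, 0.5, 1.0):
    print(T, float(np.exp(-12.5 * T**2)))
# 0.0 1.0 | 0.25 ~0.458 | 0.5 ~0.0439 | 1.0 ~3.7e-06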