Refactor hparams in tests (mosaicml#1498)
This PR builds on mosaicml#1491 and completes the transition away from hparams in unit tests.
hanlint authored Sep 3, 2022
1 parent 52f9428 commit e1654b4
Showing 31 changed files with 506 additions and 547 deletions.
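
The pattern this refactor applies throughout the test suite: construct test models and datasets directly instead of going through yahp hparams classes. A minimal sketch of the before/after (hypothetical test code, not part of this diff; the "before" mirrors the deleted SimpleModelHparams below):

```python
# Hypothetical sketch of the pattern adopted in this PR; not taken from the diff.
from torch.utils.data import DataLoader

from tests.common import RandomClassificationDataset, SimpleModel

# Before (removed in this PR): build test objects through yahp hparams classes.
#   model = SimpleModelHparams(num_features=1, num_classes=2).initialize_object()

# After: build the test model and dataset directly.
model = SimpleModel(num_features=1, num_classes=2)
dataset = RandomClassificationDataset()
dataloader = DataLoader(dataset, batch_size=4)
```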
21 changes: 12 additions & 9 deletions tests/algorithms/test_gradient_clipping.py
@@ -10,12 +10,8 @@
import composer.algorithms.gradient_clipping.gradient_clipping as gc_module
from composer.algorithms.gradient_clipping import GradientClipping, apply_gradient_clipping
from composer.algorithms.gradient_clipping.gradient_clipping import _apply_agc, _get_clipped_gradient_coeff
from composer.core import Engine
from composer.core import Engine, State
from composer.core.event import Event
from tests.fixtures import dummy_fixtures

# To satisfy pyright.
dummy_state = dummy_fixtures.dummy_state


@pytest.fixture
@@ -81,7 +77,7 @@ def test_gradient_clipping_functional(monkeypatch):


@pytest.mark.parametrize('clipping_type', [('adaptive',), ('norm',), ('value',)])
def test_gradient_clipping_algorithm(monkeypatch, clipping_type, simple_model_with_grads, dummy_state):
def test_gradient_clipping_algorithm(monkeypatch, clipping_type, simple_model_with_grads, dummy_state: State):
model = simple_model_with_grads
apply_gc_fn = Mock()
monkeypatch.setattr(gc_module, 'apply_gradient_clipping', apply_gc_fn)
@@ -98,8 +94,11 @@ def test_gradient_clipping_algorithm(monkeypatch, clipping_type, simple_model_with_grads, dummy_state):
apply_gc_fn.assert_called_once()


def test_gradient_clipping_algorithm_with_deepspeed_enabled(monkeypatch: pytest.MonkeyPatch, simple_model_with_grads,
dummy_state):
def test_gradient_clipping_algorithm_with_deepspeed_enabled(
monkeypatch: pytest.MonkeyPatch,
simple_model_with_grads,
dummy_state: State,
):
clipping_threshold = 0.1191
apply_gc_fn = Mock()
monkeypatch.setattr(gc_module, 'apply_gradient_clipping', apply_gc_fn)
@@ -128,7 +127,11 @@ def test_gradient_clipping_algorithm_with_deepspeed_enabled(monkeypatch: pytest.MonkeyPatch, simple_model_with_grads,
apply_gc_fn.assert_not_called()


def test_algorithm_with_deepspeed_enabled_errors_out_for_non_norm(monkeypatch: pytest.MonkeyPatch, dummy_state):
def test_algorithm_with_deepspeed_enabled_errors_out_for_non_norm(
monkeypatch: pytest.MonkeyPatch,
dummy_state: State,
simple_model_with_grads,
):
clipping_threshold = 0.1191
apply_gc_fn = Mock()
monkeypatch.setattr(gc_module, 'apply_gradient_clipping', apply_gc_fn)
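
With the deprecated dummy_fixtures import gone, these tests take dummy_state as a typed composer State fixture and drive the algorithm directly. A hedged sketch of that style — the fixture sources and the exact call sequence are assumptions, not copied from this diff:

```python
# Sketch of the refactored test style; assumes the dummy_state and
# simple_model_with_grads fixtures behave as their names suggest.
from unittest.mock import Mock

from composer.algorithms.gradient_clipping import GradientClipping
from composer.core import State
from composer.core.event import Event


def test_gradient_clipping_sketch(dummy_state: State, simple_model_with_grads):
    algorithm = GradientClipping(clipping_type='norm', clipping_threshold=0.1)
    dummy_state.model = simple_model_with_grads
    # The algorithm is applied against the State directly, no hparams involved.
    algorithm.apply(Event.AFTER_TRAIN_BATCH, dummy_state, logger=Mock())
```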
11 changes: 2 additions & 9 deletions tests/common/__init__.py
@@ -5,12 +5,10 @@
from typing import List, Type

from tests.common.compare import deep_compare
from tests.common.datasets import (RandomClassificationDataset, RandomClassificationDatasetHparams, RandomImageDataset,
configure_dataset_hparams_for_synthetic)
from tests.common.datasets import RandomClassificationDataset, RandomImageDataset
from tests.common.events import EventCounterCallback
from tests.common.markers import device, world_size
from tests.common.models import (SimpleConvModel, SimpleConvModelHparams, SimpleModel, SimpleModelHparams,
configure_model_hparams_for_synthetic)
from tests.common.models import SimpleConvModel, SimpleModel
from tests.common.state import assert_state_equivalent


@@ -22,17 +20,12 @@ def get_module_subclasses(module: types.ModuleType, cls: Type) -> List[Type]:
__all__ = [
'assert_state_equivalent',
'RandomClassificationDataset',
'RandomClassificationDatasetHparams',
'RandomImageDataset',
'configure_dataset_hparams_for_synthetic',
'SimpleConvModel',
'SimpleModel',
'SimpleModelHparams',
'SimpleConvModelHparams',
'EventCounterCallback',
'deep_compare',
'device',
'world_size',
'configure_model_hparams_for_synthetic',
'get_module_subclasses',
]
56 changes: 1 addition & 55 deletions tests/common/datasets.py
@@ -1,24 +1,14 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from typing import List, Optional, Sequence
from typing import Sequence

import pytest
import torch
import torch.utils.data
import yahp as hp
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset

from composer.datasets.dataset_hparams import DataLoaderHparams, DatasetHparams
from composer.datasets.glue_hparams import GLUEHparams
from composer.datasets.lm_dataset_hparams import LMDatasetHparams
from composer.datasets.synthetic_hparams import SyntheticHparamsMixin
from composer.models import ModelHparams
from tests.common.models import model_hparams_to_tokenizer_family


class RandomClassificationDataset(Dataset):
"""Classification dataset drawn from a normal distribution.
@@ -41,32 +31,6 @@ def __getitem__(self, index: int):
return self.x[index], self.y[index]


@dataclasses.dataclass
class RandomClassificationDatasetHparams(DatasetHparams, SyntheticHparamsMixin):

data_shape: List[int] = hp.optional('data shape', default_factory=lambda: [1, 1, 1])
num_classes: int = hp.optional('num_classes', default=2)

def initialize_object(self, batch_size: int, dataloader_hparams: DataLoaderHparams):
assert self.data_shape is not None
assert self.num_classes is not None
dataset = RandomClassificationDataset(
size=self.synthetic_num_unique_samples,
shape=self.data_shape,
num_classes=self.num_classes,
)
if self.shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
return dataloader_hparams.initialize_object(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
drop_last=self.drop_last,
)


class RandomImageDataset(VisionDataset):
""" Image Classification dataset with values drawn from a normal distribution
Args:
@@ -110,21 +74,3 @@ def __getitem__(self, index: int):
return self.transform(x), y
else:
return x, y


def configure_dataset_hparams_for_synthetic(
dataset_hparams: DatasetHparams,
model_hparams: Optional[ModelHparams] = None,
) -> None:
if not isinstance(dataset_hparams, SyntheticHparamsMixin):
pytest.xfail(f'{dataset_hparams.__class__.__name__} does not support synthetic data or num_total_batches')

assert isinstance(dataset_hparams, SyntheticHparamsMixin)

dataset_hparams.use_synthetic = True

if model_hparams and type(model_hparams) in model_hparams_to_tokenizer_family:
tokenizer_family = model_hparams_to_tokenizer_family[type(model_hparams)]
assert isinstance(dataset_hparams, (GLUEHparams, LMDatasetHparams))
dataset_hparams.tokenizer_name = tokenizer_family
dataset_hparams.max_seq_length = 128
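
What the deleted RandomClassificationDatasetHparams.initialize_object did can now be written inline where a test needs it. A rough equivalent, with the dataset size, batch size, and sampler choice as illustrative assumptions:

```python
# Direct replacement for the deleted hparams-driven dataloader construction;
# size, batch_size, and the sampler choice are illustrative assumptions.
from torch.utils.data import DataLoader, RandomSampler

from tests.common import RandomClassificationDataset

dataset = RandomClassificationDataset(size=16, shape=(1, 1, 1), num_classes=2)
sampler = RandomSampler(dataset)  # SequentialSampler when shuffling is off
dataloader = DataLoader(dataset, batch_size=4, sampler=sampler, drop_last=False)
```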
33 changes: 0 additions & 33 deletions tests/common/hparams.py

This file was deleted.

137 changes: 0 additions & 137 deletions tests/common/models.py
@@ -3,24 +3,9 @@

"""Contains commonly used models that are shared across the test suite."""

import dataclasses
from typing import Any, Dict, Type

import torch
import yahp as hp

from composer.datasets.synthetic_lm import generate_synthetic_tokenizer
from composer.models import ComposerClassifier
from composer.models.bert.bert_hparams import BERTForClassificationHparams, BERTHparams
from composer.models.deeplabv3.deeplabv3_hparams import DeepLabV3Hparams
from composer.models.gpt2.gpt2_hparams import GPT2Hparams
from composer.models.model_hparams import ModelHparams

model_hparams_to_tokenizer_family: Dict[Type[ModelHparams], str] = {
GPT2Hparams: 'gpt2',
BERTForClassificationHparams: 'bert',
BERTHparams: 'bert'
}


class SimpleModel(ComposerClassifier):
@@ -58,18 +43,6 @@ def __init__(self, num_features: int = 1, num_classes: int = 2) -> None:
self.fc2 = fc2


@dataclasses.dataclass
class SimpleModelHparams(ModelHparams):
num_features: int = hp.optional('number of features', default=1)
num_classes: int = hp.optional('number of output classes', default=2)

def initialize_object(self) -> SimpleModel:
return SimpleModel(
num_features=self.num_features,
num_classes=self.num_classes,
)


class SimpleConvModel(ComposerClassifier):
"""Small convolutional classifer.
@@ -105,113 +78,3 @@ def __init__(self, num_channels: int = 3, num_classes: int = 2) -> None:
# surgery tests
self.conv1 = conv1
self.conv2 = conv2


@dataclasses.dataclass
class SimpleConvModelHparams(ModelHparams):
num_channels: int = hp.optional('number of channels', default=3)
num_classes: int = hp.optional('number of output classes', default=2)

def initialize_object(self) -> SimpleConvModel:
return SimpleConvModel(
num_channels=self.num_channels,
num_classes=self.num_classes,
)


def configure_model_hparams_for_synthetic(model_hparams: ModelHparams) -> None:
# configure Transformer-based models for synthetic testing
if type(model_hparams) in model_hparams_to_tokenizer_family.keys():
assert isinstance(model_hparams, (BERTHparams, GPT2Hparams, BERTForClassificationHparams))
tokenizer_family = model_hparams_to_tokenizer_family[type(model_hparams)]

# force a non-pretrained model
model_hparams.use_pretrained = False
model_hparams.pretrained_model_name = None

# generate tokenizers and synthetic models
tokenizer = generate_synthetic_tokenizer(tokenizer_family=tokenizer_family)
model_hparams.model_config = generate_dummy_model_config(type(model_hparams), tokenizer)

# configure DeepLabV3 models for synthetic testing
if isinstance(model_hparams, DeepLabV3Hparams):
model_hparams.backbone_weights = None # prevent downloading pretrained weights during test
model_hparams.sync_bn = False # sync_bn throws an error when run on CPU


def generate_dummy_model_config(cls: Type[hp.Hparams], tokenizer) -> Dict[str, Any]:
model_to_dummy_mapping: Dict[Type[hp.Hparams], Dict[str, Any]] = {
BERTHparams: {
'architectures': ['BertForMaskedLM'],
'attention_probs_dropout_prob': 0.1,
'gradient_checkpointing': False,
'hidden_act': 'gelu',
'hidden_dropout_prob': 0.1,
'hidden_size': 64,
'initializer_range': 0.02,
'intermediate_size': 256,
'layer_norm_eps': 1e-12,
'max_position_embeddings': 512,
'model_type': 'bert',
'num_attention_heads': 1,
'num_hidden_layers': 1,
'pad_token_id': tokenizer.pad_token_id,
'position_embedding_type': 'absolute',
'transformers_version': '4.6.0.dev0',
'type_vocab_size': 2,
'use_cache': True,
'vocab_size': tokenizer.vocab_size,
},
GPT2Hparams: {
'activation_function': 'gelu_new',
'architectures': ['GPT2LMHeadModel'],
'attn_pdrop': 0.1,
'bos_token_id': tokenizer.cls_token_id,
'embd_pdrop': 0.1,
'eos_token_id': tokenizer.cls_token_id,
'initializer_range': 0.02,
'layer_norm_epsilon': 0.00001,
'model_type': 'gpt2',
'n_ctx': 128,
'n_embd': 64,
'n_head': 1,
'n_layer': 1,
'n_positions': 128,
'resid_pdrop': 0.1,
'summary_activation': None,
'summary_first_dropout': 0.1,
'summary_proj_to_labels': True,
'summary_type': 'cls_index',
'summary_use_proj': True,
'task_specific_params': {
'text-generation': {
'do_sample': True,
'max_length': 50
}
},
'vocab_size': tokenizer.vocab_size
},
BERTForClassificationHparams: {
'architectures': ['BertForSequenceClassification'],
'attention_probs_dropout_prob': 0.1,
'classifier_dropout': None,
'gradient_checkpointing': False,
'hidden_act': 'gelu',
'hidden_dropout_prob': 0.1,
'hidden_size': 64,
'initializer_range': 0.02,
'intermediate_size': 256,
'layer_norm_eps': 1e-12,
'max_position_embeddings': 512,
'model_type': 'bert',
'num_attention_heads': 1,
'num_hidden_layers': 1,
'pad_token_id': tokenizer.pad_token_id,
'position_embedding_type': 'absolute',
'transformers_version': '4.16.2',
'type_vocab_size': 2,
'use_cache': True,
'vocab_size': tokenizer.vocab_size
}
}
return model_to_dummy_mapping[cls]
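
The deleted configure_model_hparams_for_synthetic and generate_dummy_model_config helpers built tiny transformer configs keyed by hparams class. In the hparams-free style, an equivalent tiny model can be constructed directly with transformers; a sketch reusing the values from the deleted BERTHparams entry (the direct-construction style itself is an assumption, not something shown in this diff):

```python
# Sketch only: builds a tiny BERT the way the deleted dummy config described one.
# vocab_size is a placeholder; the deleted helper read it from a synthetic tokenizer.
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(
    hidden_size=64,
    intermediate_size=256,
    num_attention_heads=1,
    num_hidden_layers=1,
    max_position_embeddings=512,
    type_vocab_size=2,
    vocab_size=30522,
)
tiny_bert = BertForMaskedLM(config)
```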
10 changes: 0 additions & 10 deletions tests/conftest.py
@@ -21,22 +21,12 @@
# Enforce deterministic mode before any tests start.
reproducibility.configure_deterministic_mode()

# during the pytest refactor transition, this flag
# indicates whether to include the deprecated fixtures.
# used for internal development.
_include_deprecated_fixtures = True

# Add the path of any pytest fixture files you want to make global
pytest_plugins = [
'tests.fixtures.new_fixtures',
'tests.fixtures.synthetic_hf_state',
]

if _include_deprecated_fixtures:
pytest_plugins += [
'tests.fixtures.dummy_fixtures',
]


def _add_option(parser: pytest.Parser, name: str, help: str, choices: Optional[List[str]] = None):
parser.addoption(
