From 1fc8aea3bcd5b735fd0ef5999f0adf6434989f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Wed, 4 Sep 2019 14:01:16 -0400 Subject: [PATCH 1/7] Add Hyperopt example for BERT classifier --- examples/bert/README.md | 22 ++ examples/bert/bert_tpe_config_classifier.py | 11 + examples/bert/bert_with_tpe.py | 267 ++++++++++++++++++++ examples/bert/config_data.py | 2 +- examples/bert/requirements.txt | 3 +- texar/torch/run/executor.py | 59 +++-- 6 files changed, 339 insertions(+), 25 deletions(-) create mode 100644 examples/bert/bert_tpe_config_classifier.py create mode 100644 examples/bert/bert_with_tpe.py diff --git a/examples/bert/README.md b/examples/bert/README.md index b9368e49f..681c378a4 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -15,6 +15,7 @@ To summarize, this example showcases: * Building and fine-tuning on downstream tasks * Use of Texar `RecordData` module for data loading and processing * Use of Texar `Executor` module for simplified training loops and TensorBoard visualization +* Use of Hyperopt library to tune hyperparameters with `Executor` module Future work: @@ -178,3 +179,24 @@ tensorboard --logdir runs/ ``` ![Visualizing loss/accuarcy on Tensorboard](tbx.png) + +## Hyperparameter tuning with Executor + +To run this example, please install `hyperopt` by issuing the following command + +```commandline +pip install hyperopt +``` + +`bert_with_tpe.py` shows an example of how to tune hyperparameters with Executor using `hyperopt`. +To run this example, run the following command + +```commandline +python bert_with_tpe.py +``` + +In this simple example, the hyperparameters to be tuned are provided as a `dict` in +`bert_tpe_config_classifier.py` which are fed into `objective_func()` . We use `TPE` algorithm for +tuning the hyperparams (provided in `hyperopt` library). The example runs for 3 trials to find the +best hyperparam settings. The final model is saved in `/model/{exp_number}` folder. More +information about the libary can be found at [Hyperopt](https://github.com/hyperopt/hyperopt) diff --git a/examples/bert/bert_tpe_config_classifier.py b/examples/bert/bert_tpe_config_classifier.py new file mode 100644 index 000000000..b04250943 --- /dev/null +++ b/examples/bert/bert_tpe_config_classifier.py @@ -0,0 +1,11 @@ +name = "bert_classifier" +hidden_size = 768 +clas_strategy = "cls_time" +dropout = 0.1 +num_classes = 2 + +# hyperparams +hyperparams = { + "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int}, + "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float} +} diff --git a/examples/bert/bert_with_tpe.py b/examples/bert/bert_with_tpe.py new file mode 100644 index 000000000..f6821f795 --- /dev/null +++ b/examples/bert/bert_with_tpe.py @@ -0,0 +1,267 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import functools +import importlib +import logging +import shutil +from typing import Dict + +import torch +from torch import nn +import torch.nn.functional as F + +import hyperopt as hpo + +import texar.torch as tx +from texar.torch.run import * +from texar.torch.modules import BERTClassifier + +from utils import model_utils + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--config-downstream", default="bert_tpe_config_classifier", + help="Configuration of the downstream part of the model") +parser.add_argument( + '--pretrained-model-name', type=str, default='bert-base-uncased', + choices=tx.modules.BERTEncoder.available_checkpoints(), + help="Name of the pre-trained checkpoint to load.") +parser.add_argument( + "--config-data", default="config_data", help="The dataset config.") +parser.add_argument( + "--output-dir", default="output/", + help="The output directory where the model checkpoints will be written.") +parser.add_argument( + "--checkpoint", type=str, default=None, + help="Path to a model checkpoint (including bert modules) to restore from.") +parser.add_argument( + "--do-train", action="store_true", help="Whether to run training.") +parser.add_argument( + "--do-eval", action="store_true", + help="Whether to run eval on the dev set.") +parser.add_argument( + "--do-test", action="store_true", + help="Whether to run test on the test set.") +args = parser.parse_args() + +config_data = importlib.import_module(args.config_data) +config_downstream = importlib.import_module(args.config_downstream) +config_downstream = { + k: v for k, v in config_downstream.__dict__.items() + if not k.startswith('__')} + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +logging.root.setLevel(logging.INFO) + + +class ModelWrapper(nn.Module): + def __init__(self, model: BERTClassifier): + super().__init__() + self.model = model + + def _compute_loss(self, logits, labels): + r"""Compute loss. + """ + if self.model.is_binary: + loss = F.binary_cross_entropy( + logits.view(-1), labels.view(-1), reduction='mean') + else: + loss = F.cross_entropy( + logits.view(-1, self.model.num_classes), + labels.view(-1), reduction='mean') + return loss + + def forward(self, # type: ignore + batch: tx.data.Batch) -> Dict[str, torch.Tensor]: + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + labels = batch["label_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + logits, preds = self.model(input_ids, input_length, segment_ids) + + loss = self._compute_loss(logits, labels) + + return {"loss": loss, "preds": preds} + + def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]: + input_ids = batch["input_ids"] + segment_ids = batch["segment_ids"] + + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + _, preds = self.model(input_ids, input_length, segment_ids) + + return {"preds": preds} + + +class TPE: + def __init__(self, model_config=None): + tx.utils.maybe_create_dir(args.output_dir) + + self.model_config = model_config + + # create datasets + self.train_dataset = tx.data.RecordData( + hparams=config_data.train_hparam, device=device) + self.eval_dataset = tx.data.RecordData( + hparams=config_data.eval_hparam, device=device) + + # Builds BERT + model = tx.modules.BERTClassifier( + pretrained_model_name=args.pretrained_model_name, + hparams=self.model_config) + self.model = ModelWrapper(model=model) + self.model.to(device) + + # batching + self.batching_strategy = tx.data.TokenCountBatchingStrategy( + max_tokens=config_data.max_batch_tokens) + + # logging formats + self.log_format = "{time} : Epoch {epoch:2d} @ {iteration:6d}it " \ + "({progress}%, {speed}), " \ + "lr = {lr:.9e}, loss = {loss:.3f}" + self.valid_log_format = "{time} : Epoch {epoch}, " \ + "{split} accuracy = {Accuracy:.3f}, " \ + "loss = {loss:.3f}" + self.valid_progress_log_format = "{time} : Evaluating on " \ + "{split} ({progress}%, {speed})" + + # exp number + self.exp_number = 1 + + self.optim = tx.core.BertAdam + + def objective_func(self, params: Dict): + + print(f"Using {params} for trial {self.exp_number}") + + # Loads data + num_train_data = config_data.num_train_data + num_train_steps = int(num_train_data / config_data.train_batch_size * + config_data.max_train_epoch) + + # hyperparams + num_warmup_steps = params["optimizer.warmup_steps"] + static_lr = params["optimizer.static_lr"] + + vars_with_decay = [] + vars_without_decay = [] + for name, param in self.model.named_parameters(): + if 'layer_norm' in name or name.endswith('bias'): + vars_without_decay.append(param) + else: + vars_with_decay.append(param) + + opt_params = [{ + 'params': vars_with_decay, + 'weight_decay': 0.01, + }, { + 'params': vars_without_decay, + 'weight_decay': 0.0, + }] + + optim = self.optim(opt_params, betas=(0.9, 0.999), eps=1e-6, + lr=static_lr) + + scheduler = torch.optim.lr_scheduler.LambdaLR( + optim, functools.partial(model_utils.get_lr_multiplier, + total_steps=num_train_steps, + warmup_steps=num_warmup_steps)) + + valid_metric = metric.Accuracy(pred_name="preds", + label_name="label_ids") + checkpoint_dir = f"./models/exp{self.exp_number}" + + executor = Executor( + # supply executor with the model + model=self.model, + # define datasets + train_data=self.train_dataset, + valid_data=self.eval_dataset, + batching_strategy=self.batching_strategy, + device=device, + # training and stopping details + optimizer=optim, + lr_scheduler=scheduler, + stop_training_on=cond.epoch(config_data.max_train_epoch), + # logging details + log_every=[cond.epoch(1)], + # logging format + log_format=self.log_format, + # define metrics + train_metrics=[ + ("loss", metric.RunningAverage(1)), + ("lr", metric.LR(optim))], + valid_metrics=[valid_metric, ("loss", metric.Average())], + validate_every=cond.epoch(1), + save_every=cond.epoch(config_data.max_train_epoch), + checkpoint_dir=checkpoint_dir, + max_to_keep=1, + show_live_progress=True, + print_model_arch=False + ) + + executor.train() + + print(f"Loss on the valid dataset " + f"{executor.valid_metrics['loss'].value()}") + self.exp_number += 1 + + return { + "loss": executor.valid_metrics["loss"].value(), + "status": hpo.STATUS_OK, + "model": checkpoint_dir + } + + def run(self, hyperparams: Dict): + space = {} + for k, v in hyperparams.items(): + if isinstance(v, dict): + if v["dtype"] == int: + space[k] = hpo.hp.choice( + k, range(v["start"], v["end"])) + else: + space[k] = hpo.hp.uniform(k, v["start"], v["end"]) + trials = hpo.Trials() + hpo.fmin(fn=self.objective_func, + space=space, + algo=hpo.tpe.suggest, + max_evals=3, + trials=trials) + _, best_trial = min((trial["result"]["loss"], trial) + for trial in trials.trials) + + # delete all the other models + for trial in trials.trials: + if trial is not best_trial: + shutil.rmtree(trial["result"]["model"]) + + +def main(): + model_config = {k: v for k, v in config_downstream.items() if + k != "hyperparams"} + tpe = TPE(model_config=model_config) + hyperparams = config_downstream["hyperparams"] + tpe.run(hyperparams) + + +if __name__ == '__main__': + main() diff --git a/examples/bert/config_data.py b/examples/bert/config_data.py index 256e1ef42..ee8b44636 100644 --- a/examples/bert/config_data.py +++ b/examples/bert/config_data.py @@ -7,7 +7,7 @@ max_batch_tokens = 128 train_batch_size = 32 -max_train_epoch = 5 +max_train_epoch = 3 display_steps = 50 # Print training loss every display_steps; -1 to disable # tbx config diff --git a/examples/bert/requirements.txt b/examples/bert/requirements.txt index ee8755313..5759aecef 100644 --- a/examples/bert/requirements.txt +++ b/examples/bert/requirements.txt @@ -1,2 +1,3 @@ tensorflow -tensorboardX>=1.8 \ No newline at end of file +tensorboardX>=1.8 +hyperopt \ No newline at end of file diff --git a/texar/torch/run/executor.py b/texar/torch/run/executor.py index 69e204f0d..f3d48168b 100644 --- a/texar/torch/run/executor.py +++ b/texar/torch/run/executor.py @@ -15,7 +15,6 @@ The Executor module. """ -import atexit import pickle import random import re @@ -825,28 +824,8 @@ def __init__(self, model: nn.Module, self._log_destination: List[IO[str]] = [] self._log_destination_is_tty: List[bool] = [] self._opened_files: List[IO[str]] = [] - log_destination = log_destination or self._defaults["log_destination"] - for dest in utils.to_list(log_destination): # type: ignore - if isinstance(dest, (str, Path)): - # Append to the logs to prevent accidentally overwriting - # previous logs. - file = open(dest, "a") - self._opened_files.append(file) - self._log_destination_is_tty.append(False) - else: - if not hasattr(dest, "write"): - raise ValueError(f"Log destination {dest} is not a " - f"file-like object") - try: - isatty = dest.isatty() # type: ignore - except AttributeError: - isatty = False - file = dest # type: ignore - self._log_destination_is_tty.append(isatty) - self._log_destination.append(file) - - # Close files when program exits. - atexit.register(self._close_files) + self.log_destination = log_destination or \ + self._defaults["log_destination"] # Training loop self.train_metrics = utils.to_metric_dict(train_metrics) @@ -1300,6 +1279,9 @@ def remove_action(self) -> None: def train(self): r"""Start the training loop. """ + # open the log files + self._open_files() + if self._directory_exists: self.write_log( f"Specified checkpoint directory '{self.checkpoint_dir}' " @@ -1365,6 +1347,10 @@ def _try_get_data_size(executor: 'Executor'): self.write_log("Training terminated", mode='info') finally: self._train_tracker.stop() + + # close the log files + self._close_files() + self._fire_event(Event.Training, True) def test(self, dataset: OptionalDict[DataBase] = None): @@ -1382,6 +1368,9 @@ def test(self, dataset: OptionalDict[DataBase] = None): If `None`, :attr:`test_data` from the constructor arguments is used. Defaults to `None`. """ + # open the log files + self._open_files() + if dataset is None and self.test_data is None: raise ValueError("No testing dataset is specified") if len(self.test_metrics) == 0: @@ -1424,6 +1413,10 @@ def test(self, dataset: OptionalDict[DataBase] = None): self._test_tracker.stop() self._fire_event(Event.Testing, True) + + # close the log files + self._close_files() + self.model.train(model_mode) def _register_logging_actions(self, show_live_progress: List[str]): @@ -1707,6 +1700,26 @@ def _register_hook(self, event_point: EventPoint, action: ActionFn, raise ValueError( f"Specified hook point {event_point} is invalid") from None + def _open_files(self): + for dest in utils.to_list(self.log_destination): + if isinstance(dest, (str, Path)): + # Append to the logs to prevent accidentally overwriting + # previous logs. + file = open(dest, "a") + self._opened_files.append(file) + self._log_destination_is_tty.append(False) + else: + if not hasattr(dest, "write"): + raise ValueError(f"Log destination {dest} is not a " + f"file-like object") + try: + isatty = dest.isatty() + except AttributeError: + isatty = False + file = dest + self._log_destination_is_tty.append(isatty) + self._log_destination.append(file) + def _close_files(self): for file in self._opened_files: file.close() From e175fa6fc909afa524194cc8d87618e98d9c6fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Wed, 4 Sep 2019 18:36:32 -0400 Subject: [PATCH 2/7] Address review comments --- examples/bert/README.md | 16 ++-- ... => bert_hypertuning_config_classifier.py} | 0 ...h_tpe.py => bert_with_hypertuning_main.py} | 86 +++++++++++++++++-- examples/bert/config_data.py | 2 +- examples/bert/requirements.txt | 6 +- texar/torch/run/executor.py | 12 ++- 6 files changed, 100 insertions(+), 22 deletions(-) rename examples/bert/{bert_tpe_config_classifier.py => bert_hypertuning_config_classifier.py} (100%) rename examples/bert/{bert_with_tpe.py => bert_with_hypertuning_main.py} (74%) diff --git a/examples/bert/README.md b/examples/bert/README.md index 681c378a4..07a4ae8b3 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -15,7 +15,8 @@ To summarize, this example showcases: * Building and fine-tuning on downstream tasks * Use of Texar `RecordData` module for data loading and processing * Use of Texar `Executor` module for simplified training loops and TensorBoard visualization -* Use of Hyperopt library to tune hyperparameters with `Executor` module +* Use of [Hyperopt]((https://github.com/hyperopt/hyperopt)) library to tune hyperparameters with +`Executor` module Future work: @@ -188,15 +189,16 @@ To run this example, please install `hyperopt` by issuing the following command pip install hyperopt ``` -`bert_with_tpe.py` shows an example of how to tune hyperparameters with Executor using `hyperopt`. +`bert_with_hypertuning_main.py` shows an example of how to tune hyperparameters with Executor using `hyperopt`. To run this example, run the following command ```commandline -python bert_with_tpe.py +python bert_with_hypertuning_main.py ``` In this simple example, the hyperparameters to be tuned are provided as a `dict` in -`bert_tpe_config_classifier.py` which are fed into `objective_func()` . We use `TPE` algorithm for -tuning the hyperparams (provided in `hyperopt` library). The example runs for 3 trials to find the -best hyperparam settings. The final model is saved in `/model/{exp_number}` folder. More -information about the libary can be found at [Hyperopt](https://github.com/hyperopt/hyperopt) +`bert_hypertuning_config_classifier.py` which are fed into `objective_func()` . We use `TPE` +(Tree-structured Parzen Estimator) algorithm for tuning the hyperparams (provided in `hyperopt` +library). The example runs for 3 trials to find the best hyperparam settings. The final model is +saved in `output_dir` provided by the user. More information about the libary can be +found at [Hyperopt](https://github.com/hyperopt/hyperopt) diff --git a/examples/bert/bert_tpe_config_classifier.py b/examples/bert/bert_hypertuning_config_classifier.py similarity index 100% rename from examples/bert/bert_tpe_config_classifier.py rename to examples/bert/bert_hypertuning_config_classifier.py diff --git a/examples/bert/bert_with_tpe.py b/examples/bert/bert_with_hypertuning_main.py similarity index 74% rename from examples/bert/bert_with_tpe.py rename to examples/bert/bert_with_hypertuning_main.py index f6821f795..ba09ee34c 100644 --- a/examples/bert/bert_with_tpe.py +++ b/examples/bert/bert_with_hypertuning_main.py @@ -34,7 +34,7 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--config-downstream", default="bert_tpe_config_classifier", + "--config-downstream", default="bert_hypertuning_config_classifier", help="Configuration of the downstream part of the model") parser.add_argument( '--pretrained-model-name', type=str, default='bert-base-uncased', @@ -70,13 +70,21 @@ class ModelWrapper(nn.Module): + r"""This class wraps a model (in this case a BERT classifier) and implements + :meth:`forward` and :meth:`predict` to conform to the requirements of + :class:`Executor` class. Particularly, :meth:`forward` returns a dict with + keys "loss" and "preds" and :meth:`predict` returns a dict with key "preds". + + Args: + `model`: BERTClassifier + A BERTClassifier model + """ + def __init__(self, model: BERTClassifier): super().__init__() self.model = model def _compute_loss(self, logits, labels): - r"""Compute loss. - """ if self.model.is_binary: loss = F.binary_cross_entropy( logits.view(-1), labels.view(-1), reduction='mean') @@ -88,6 +96,18 @@ def _compute_loss(self, logits, labels): def forward(self, # type: ignore batch: tx.data.Batch) -> Dict[str, torch.Tensor]: + r"""Run forward through the network and return a dict to be consumed + by the :class:`Executor`. This method will be called by + :class:``Executor` during training. + + Args: + `batch`: tx.data.Batch + A batch of inputs to be passed through the network + + Returns: + A dict with keys "loss" and "preds" containing the loss and + predictions on :attr:`batch` respectively. + """ input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] labels = batch["label_ids"] @@ -101,6 +121,15 @@ def forward(self, # type: ignore return {"loss": loss, "preds": preds} def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]: + r"""Predict the labels for the :attr:`batch` of examples. This method + will be called instead of :meth:`forward` during validation or testing, + if :class:`Executor`'s :attr:`validate_mode` or :attr:`test_mode` is set + to ``"predict"`` instead of ``"eval"``. + + Args: + `batch`: tx.data.Batch + A batch of inputs to run prediction on + """ input_ids = batch["input_ids"] segment_ids = batch["segment_ids"] @@ -112,11 +141,23 @@ def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]: class TPE: - def __init__(self, model_config=None): + r""":class:`TPE` uses Tree-structured Parzen Estimator algorithm from + `hyperopt` for hyperparameter tuning. + + Args: + model_config: Dict + A conf dict which is passed to BERT classifier + output_dir: str + A path to store the models + + """ + def __init__(self, model_config: Dict, output_dir: str = "output/"): tx.utils.maybe_create_dir(args.output_dir) self.model_config = model_config + self.output_dir = output_dir + # create datasets self.train_dataset = tx.data.RecordData( hparams=config_data.train_hparam, device=device) @@ -150,7 +191,31 @@ def __init__(self, model_config=None): self.optim = tx.core.BertAdam def objective_func(self, params: Dict): - + r"""Compute a "loss" for a given hyperparameter values. This function is + passed to hyperopt's ``"fmin"`` (see the :meth:`run` method) function + and gets repeatedly called to find the best set of hyperparam values. + Below is an example of how to use this method + + .. code-block:: python + + import hyperopt as hpo + + trials = hpo.Trials() + hpo.fmin(fn=self.objective_func, + space=space, + algo=hpo.tpe.suggest, + max_evals=3, + trials=trials) + + Args: + params: Dict + A `(key, value)` dict representing the ``"value"`` to try for + the hyperparam ``"key"`` + + Returns: + A dict with keys "loss", "status" and "model" indicating the loss + for this trial, the status, and the path to the saved model. + """ print(f"Using {params} for trial {self.exp_number}") # Loads data @@ -188,7 +253,7 @@ def objective_func(self, params: Dict): valid_metric = metric.Accuracy(pred_name="preds", label_name="label_ids") - checkpoint_dir = f"./models/exp{self.exp_number}" + checkpoint_dir = f"./{self.output_dir}/exp{self.exp_number}" executor = Executor( # supply executor with the model @@ -232,6 +297,13 @@ def objective_func(self, params: Dict): } def run(self, hyperparams: Dict): + r"""Run the TPE algorithm with hyperparameters :attr:`hyperparams` + + Args: + hyperparams: Dict + The `(key, value)` pairs of hyperparameters along their range of + values. + """ space = {} for k, v in hyperparams.items(): if isinstance(v, dict): @@ -258,7 +330,7 @@ def run(self, hyperparams: Dict): def main(): model_config = {k: v for k, v in config_downstream.items() if k != "hyperparams"} - tpe = TPE(model_config=model_config) + tpe = TPE(model_config=model_config, output_dir=args.output_dir) hyperparams = config_downstream["hyperparams"] tpe.run(hyperparams) diff --git a/examples/bert/config_data.py b/examples/bert/config_data.py index ee8b44636..256e1ef42 100644 --- a/examples/bert/config_data.py +++ b/examples/bert/config_data.py @@ -7,7 +7,7 @@ max_batch_tokens = 128 train_batch_size = 32 -max_train_epoch = 3 +max_train_epoch = 5 display_steps = 50 # Print training loss every display_steps; -1 to disable # tbx config diff --git a/examples/bert/requirements.txt b/examples/bert/requirements.txt index 5759aecef..ce3dc99fd 100644 --- a/examples/bert/requirements.txt +++ b/examples/bert/requirements.txt @@ -1,3 +1,3 @@ -tensorflow -tensorboardX>=1.8 -hyperopt \ No newline at end of file +tensorflow # used for loading BERT official model checkpoint +tensorboardX>=1.8 # used only in bert_classifier_using_executor_main.py +hyperopt # used only in bert_with_hypertuning_main.py \ No newline at end of file diff --git a/texar/torch/run/executor.py b/texar/torch/run/executor.py index f3d48168b..790111390 100644 --- a/texar/torch/run/executor.py +++ b/texar/torch/run/executor.py @@ -1348,11 +1348,11 @@ def _try_get_data_size(executor: 'Executor'): finally: self._train_tracker.stop() + self._fire_event(Event.Training, True) + # close the log files self._close_files() - self._fire_event(Event.Training, True) - def test(self, dataset: OptionalDict[DataBase] = None): r"""Start the test loop. @@ -1414,11 +1414,11 @@ def test(self, dataset: OptionalDict[DataBase] = None): self._fire_event(Event.Testing, True) + self.model.train(model_mode) + # close the log files self._close_files() - self.model.train(model_mode) - def _register_logging_actions(self, show_live_progress: List[str]): # Register logging actions. Points = Sequence[Union[Condition, Event]] @@ -1701,6 +1701,10 @@ def _register_hook(self, event_point: EventPoint, action: ActionFn, f"Specified hook point {event_point} is invalid") from None def _open_files(self): + self._opened_files = [] + self._log_destination = [] + self._log_destination_is_tty = [] + for dest in utils.to_list(self.log_destination): if isinstance(dest, (str, Path)): # Append to the logs to prevent accidentally overwriting From 61dc09492a7343c3d340fd4013809f5f2c12cb2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Thu, 5 Sep 2019 16:31:33 -0400 Subject: [PATCH 3/7] Address review comments --- .../bert_hypertuning_config_classifier.py | 11 ------ examples/bert/bert_with_hypertuning_main.py | 37 ++++++++++--------- examples/bert/config_classifier.py | 6 +++ 3 files changed, 26 insertions(+), 28 deletions(-) delete mode 100644 examples/bert/bert_hypertuning_config_classifier.py diff --git a/examples/bert/bert_hypertuning_config_classifier.py b/examples/bert/bert_hypertuning_config_classifier.py deleted file mode 100644 index b04250943..000000000 --- a/examples/bert/bert_hypertuning_config_classifier.py +++ /dev/null @@ -1,11 +0,0 @@ -name = "bert_classifier" -hidden_size = 768 -clas_strategy = "cls_time" -dropout = 0.1 -num_classes = 2 - -# hyperparams -hyperparams = { - "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int}, - "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float} -} diff --git a/examples/bert/bert_with_hypertuning_main.py b/examples/bert/bert_with_hypertuning_main.py index ba09ee34c..25fe8bcb6 100644 --- a/examples/bert/bert_with_hypertuning_main.py +++ b/examples/bert/bert_with_hypertuning_main.py @@ -34,7 +34,7 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--config-downstream", default="bert_hypertuning_config_classifier", + "--config-downstream", default="config_classifier", help="Configuration of the downstream part of the model") parser.add_argument( '--pretrained-model-name', type=str, default='bert-base-uncased', @@ -72,8 +72,9 @@ class ModelWrapper(nn.Module): r"""This class wraps a model (in this case a BERT classifier) and implements :meth:`forward` and :meth:`predict` to conform to the requirements of - :class:`Executor` class. Particularly, :meth:`forward` returns a dict with - keys "loss" and "preds" and :meth:`predict` returns a dict with key "preds". + :class:`texar.torch.run.Executor` class. Particularly, :meth:`forward` + returns a dict with keys "loss" and "preds" and :meth:`predict` returns a + dict with key "preds". Args: `model`: BERTClassifier @@ -96,13 +97,15 @@ def _compute_loss(self, logits, labels): def forward(self, # type: ignore batch: tx.data.Batch) -> Dict[str, torch.Tensor]: - r"""Run forward through the network and return a dict to be consumed - by the :class:`Executor`. This method will be called by - :class:``Executor` during training. + r"""Run forward through the model and return a dict to be consumed + by the :class:`texar.torch.run.Executor`. This method will be called by + :class:`texar.torch.run.Executor` during training. See + https://texar-pytorch.readthedocs.io/en/latest/code/run.html#executor-general-args + for more details. Args: - `batch`: tx.data.Batch - A batch of inputs to be passed through the network + `batch`: :class:`texar.data.Batch` + A batch of inputs to be passed through the model Returns: A dict with keys "loss" and "preds" containing the loss and @@ -123,8 +126,8 @@ def forward(self, # type: ignore def predict(self, batch: tx.data.Batch) -> Dict[str, torch.Tensor]: r"""Predict the labels for the :attr:`batch` of examples. This method will be called instead of :meth:`forward` during validation or testing, - if :class:`Executor`'s :attr:`validate_mode` or :attr:`test_mode` is set - to ``"predict"`` instead of ``"eval"``. + if :class:`texar.torch.run.Executor`'s :attr:`validate_mode` or + :attr:`test_mode` is set to ``"predict"`` instead of ``"eval"``. Args: `batch`: tx.data.Batch @@ -152,7 +155,7 @@ class TPE: """ def __init__(self, model_config: Dict, output_dir: str = "output/"): - tx.utils.maybe_create_dir(args.output_dir) + tx.utils.maybe_create_dir(output_dir) self.model_config = model_config @@ -190,8 +193,8 @@ def __init__(self, model_config: Dict, output_dir: str = "output/"): self.optim = tx.core.BertAdam - def objective_func(self, params: Dict): - r"""Compute a "loss" for a given hyperparameter values. This function is + def objective_func(self, hyperparams: Dict): + r"""Compute "loss" for a given hyperparameter values. This function is passed to hyperopt's ``"fmin"`` (see the :meth:`run` method) function and gets repeatedly called to find the best set of hyperparam values. Below is an example of how to use this method @@ -208,7 +211,7 @@ def objective_func(self, params: Dict): trials=trials) Args: - params: Dict + hyperparams: Dict A `(key, value)` dict representing the ``"value"`` to try for the hyperparam ``"key"`` @@ -216,7 +219,7 @@ def objective_func(self, params: Dict): A dict with keys "loss", "status" and "model" indicating the loss for this trial, the status, and the path to the saved model. """ - print(f"Using {params} for trial {self.exp_number}") + print(f"Using {hyperparams} for trial {self.exp_number}") # Loads data num_train_data = config_data.num_train_data @@ -224,8 +227,8 @@ def objective_func(self, params: Dict): config_data.max_train_epoch) # hyperparams - num_warmup_steps = params["optimizer.warmup_steps"] - static_lr = params["optimizer.static_lr"] + num_warmup_steps = hyperparams["optimizer.warmup_steps"] + static_lr = hyperparams["optimizer.static_lr"] vars_with_decay = [] vars_without_decay = [] diff --git a/examples/bert/config_classifier.py b/examples/bert/config_classifier.py index c49420589..3000603ec 100644 --- a/examples/bert/config_classifier.py +++ b/examples/bert/config_classifier.py @@ -3,3 +3,9 @@ clas_strategy = "cls_time" dropout = 0.1 num_classes = 2 + +# This hyperparams is used in bert_with_hypertuning_main.py example +hyperparams = { + "optimizer.warmup_steps": {"start": 10000, "end": 20000, "dtype": int}, + "optimizer.static_lr": {"start": 1e-3, "end": 1e-2, "dtype": float} +} From f7e7a2ad2427ab87c06cc7282f1ae3641e75d917 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Thu, 5 Sep 2019 19:43:30 -0400 Subject: [PATCH 4/7] Remove hyperparams from config_downstream --- examples/bert/bert_classifier_main.py | 2 +- examples/bert/bert_classifier_using_executor_main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bert/bert_classifier_main.py b/examples/bert/bert_classifier_main.py index 2ba0d3f3b..b04530a2b 100644 --- a/examples/bert/bert_classifier_main.py +++ b/examples/bert/bert_classifier_main.py @@ -56,7 +56,7 @@ config_downstream = importlib.import_module(args.config_downstream) config_downstream = { k: v for k, v in config_downstream.__dict__.items() - if not k.startswith('__')} + if not k.startswith('__') and k != "hyperparams"} device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/examples/bert/bert_classifier_using_executor_main.py b/examples/bert/bert_classifier_using_executor_main.py index 9b31e4249..3adccc302 100644 --- a/examples/bert/bert_classifier_using_executor_main.py +++ b/examples/bert/bert_classifier_using_executor_main.py @@ -65,7 +65,7 @@ config_downstream = importlib.import_module(args.config_downstream) config_downstream = { k: v for k, v in config_downstream.__dict__.items() - if not k.startswith('__')} + if not k.startswith('__') and k != "hyperparams"} device = torch.device("cuda" if torch.cuda.is_available() else "cpu") From 4757b6a0252f15833ec9f2b34d2253a8a1be3810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Thu, 5 Sep 2019 19:44:12 -0400 Subject: [PATCH 5/7] Add URL for Batch object. Removed unused args --- examples/bert/bert_with_hypertuning_main.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/bert/bert_with_hypertuning_main.py b/examples/bert/bert_with_hypertuning_main.py index 25fe8bcb6..bd0700064 100644 --- a/examples/bert/bert_with_hypertuning_main.py +++ b/examples/bert/bert_with_hypertuning_main.py @@ -48,14 +48,6 @@ parser.add_argument( "--checkpoint", type=str, default=None, help="Path to a model checkpoint (including bert modules) to restore from.") -parser.add_argument( - "--do-train", action="store_true", help="Whether to run training.") -parser.add_argument( - "--do-eval", action="store_true", - help="Whether to run eval on the dev set.") -parser.add_argument( - "--do-test", action="store_true", - help="Whether to run test on the test set.") args = parser.parse_args() config_data = importlib.import_module(args.config_data) @@ -104,7 +96,9 @@ def forward(self, # type: ignore for more details. Args: - `batch`: :class:`texar.data.Batch` + `batch`: :class:`texar.data.Batch`. (See + https://texar-pytorch.readthedocs.io/en/latest/code/data.html#texar.torch.data.Batch + for more details) A batch of inputs to be passed through the model Returns: @@ -287,6 +281,9 @@ def objective_func(self, hyperparams: Dict): print_model_arch=False ) + if args.checkpoint is not None: + executor.load(args.checkpoint) + executor.train() print(f"Loss on the valid dataset " From 72df18c721db098f1eb60851ecc4c46a7c15c9ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Thu, 5 Sep 2019 19:44:49 -0400 Subject: [PATCH 6/7] Add docs for Batch and FieldBatch --- docs/code/data.rst | 10 +++++++++ texar/torch/data/data/dataset_utils.py | 30 +++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/code/data.rst b/docs/code/data.rst index 9f1c18663..032d32c50 100644 --- a/docs/code/data.rst +++ b/docs/code/data.rst @@ -130,6 +130,16 @@ Data Loaders Data Iterators =============== +:hidden:`Batch` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: texar.torch.data.Batch + :members: + +:hidden:`FieldBatch` +~~~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: texar.torch.data.FieldBatch + :members: + :hidden:`DataIterator` ~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: texar.torch.data.DataIterator diff --git a/texar/torch/data/data/dataset_utils.py b/texar/torch/data/data/dataset_utils.py index edee4d2d5..d5297120d 100644 --- a/texar/torch/data/data/dataset_utils.py +++ b/texar/torch/data/data/dataset_utils.py @@ -57,9 +57,33 @@ def connect_name(lhs_name, rhs_name): class Batch: - r"""Wrapper over Python dictionaries representing a batch. This provides a - common interface with :class:`~texar.torch.data.data.dataset_utils.Batch` - that allows accessing via attributes. + r"""Wrapper over Python dictionaries representing a batch. It provides a + dictionary-like interface to access its fields. This class can be used in + the followed way + + .. code-block:: python + + hparams = { + 'dataset': { 'files': 'data.txt', 'vocab_file': 'vocab.txt' }, + 'batch_size': 1 + } + + data = MonoTextData(hparams) + iterator = DataIterator(data) + model = BERTEncoder(pretrained_model_name="bert-base-uncased") + + for batch in iterator: + # batch is Batch object and contains the following fields + # batch == { + # 'text': [['', 'example', 'sequence', '']], + # 'text_ids': [[1, 5, 10, 2]], + # 'length': [4] + # } + + input_ids = torch.tensor(batch['text_ids']) + input_length = (1 - (input_ids == 0).int()).sum(dim=1) + + bert_embeddings, _ = model(input_ids, input_length) """ def __init__(self, batch_size: int, batch: Optional[Dict[str, Any]] = None, From aa750234daf75955ca21dab51ed4ddc73ab57b05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAvinash=E2=80=9D?= Date: Thu, 5 Sep 2019 22:58:43 -0400 Subject: [PATCH 7/7] Address review comments --- docs/code/data.rst | 5 -- examples/bert/bert_with_hypertuning_main.py | 7 ++- texar/torch/data/data/dataset_utils.py | 55 ++++++--------------- 3 files changed, 21 insertions(+), 46 deletions(-) diff --git a/docs/code/data.rst b/docs/code/data.rst index 032d32c50..0c9e0d1ea 100644 --- a/docs/code/data.rst +++ b/docs/code/data.rst @@ -135,11 +135,6 @@ Data Iterators .. autoclass:: texar.torch.data.Batch :members: -:hidden:`FieldBatch` -~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: texar.torch.data.FieldBatch - :members: - :hidden:`DataIterator` ~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: texar.torch.data.DataIterator diff --git a/examples/bert/bert_with_hypertuning_main.py b/examples/bert/bert_with_hypertuning_main.py index bd0700064..a468be9ba 100644 --- a/examples/bert/bert_with_hypertuning_main.py +++ b/examples/bert/bert_with_hypertuning_main.py @@ -15,6 +15,7 @@ import argparse import functools import importlib +import sys import logging import shutil from typing import Dict @@ -97,8 +98,8 @@ def forward(self, # type: ignore Args: `batch`: :class:`texar.data.Batch`. (See - https://texar-pytorch.readthedocs.io/en/latest/code/data.html#texar.torch.data.Batch - for more details) + https://texar-pytorch.readthedocs.io/en/latest/code/data.html#texar.torch.data.Batch + for more details) A batch of inputs to be passed through the model Returns: @@ -251,6 +252,7 @@ def objective_func(self, hyperparams: Dict): valid_metric = metric.Accuracy(pred_name="preds", label_name="label_ids") checkpoint_dir = f"./{self.output_dir}/exp{self.exp_number}" + log_file = f"./{self.output_dir}/log.txt" executor = Executor( # supply executor with the model @@ -266,6 +268,7 @@ def objective_func(self, hyperparams: Dict): stop_training_on=cond.epoch(config_data.max_train_epoch), # logging details log_every=[cond.epoch(1)], + log_destination=[sys.stdout, log_file], # logging format log_format=self.log_format, # define metrics diff --git a/texar/torch/data/data/dataset_utils.py b/texar/torch/data/data/dataset_utils.py index d5297120d..bfe15f0ed 100644 --- a/texar/torch/data/data/dataset_utils.py +++ b/texar/torch/data/data/dataset_utils.py @@ -12,7 +12,6 @@ 'padded_batch', 'connect_name', 'Batch', - 'FieldBatch', '_LazyStrategy', '_CacheStrategy', ] @@ -63,27 +62,26 @@ class Batch: .. code-block:: python - hparams = { - 'dataset': { 'files': 'data.txt', 'vocab_file': 'vocab.txt' }, - 'batch_size': 1 - } + hparams = { + 'dataset': { 'files': 'data.txt', 'vocab_file': 'vocab.txt' }, + 'batch_size': 1 + } - data = MonoTextData(hparams) - iterator = DataIterator(data) - model = BERTEncoder(pretrained_model_name="bert-base-uncased") + data = MonoTextData(hparams) + iterator = DataIterator(data) - for batch in iterator: - # batch is Batch object and contains the following fields - # batch == { - # 'text': [['', 'example', 'sequence', '']], - # 'text_ids': [[1, 5, 10, 2]], - # 'length': [4] - # } + for batch in iterator: + # batch is Batch object and contains the following fields + # batch == { + # 'text': [['', 'example', 'sequence', '']], + # 'text_ids': [[1, 5, 10, 2]], + # 'length': [4] + # } - input_ids = torch.tensor(batch['text_ids']) - input_length = (1 - (input_ids == 0).int()).sum(dim=1) + input_ids = torch.tensor(batch['text_ids']) - bert_embeddings, _ = model(input_ids, input_length) + # we can also access the elements using dot notation + input_text = batch.text """ def __init__(self, batch_size: int, batch: Optional[Dict[str, Any]] = None, @@ -114,27 +112,6 @@ def items(self) -> ItemsView[str, Any]: return self._batch.items() -class FieldBatch(Batch): - r"""Defines a batch of examples with support for multiple fields. This is - a simplified version of `torchtext.data.Batch`, with all the useless stuff - removed. - """ - - def __init__(self, data=None, dataset=None, device=None): - r"""Create a Batch from a list of examples. - """ - if data is not None: - batch_size = len(data) - _batch_dict = {} - for (name, field) in dataset.fields.items(): - if field is not None: - batch = [getattr(x, name) for x in data] - _batch_dict[name] = field.process(batch, device=device) - super().__init__(batch_size, _batch_dict) - else: - super().__init__(0) - - class _LazyStrategy(Enum): NONE = "none" PROCESS = "process"