diff --git a/docs/code/data.rst b/docs/code/data.rst
index a2d978025..9b9271fd1 100644
--- a/docs/code/data.rst
+++ b/docs/code/data.rst
@@ -132,9 +132,6 @@ Data Loaders
 .. autoclass:: texar.torch.data.DatasetBase
     :members:
 
-    .. automethod:: process
-    .. automethod:: collate
-
 :hidden:`MonoTextData`
 ~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: texar.torch.data.MonoTextData
diff --git a/tests/run/executor_test.py b/tests/run/executor_test.py
index 2af1b3057..fd0b10106 100644
--- a/tests/run/executor_test.py
+++ b/tests/run/executor_test.py
@@ -99,6 +99,7 @@ def tearDown(self) -> None:
         shutil.rmtree(self.tbx_logging_dir)
 
     def test_train_loop(self):
+        optimizer = torch.optim.Adam(self.model.parameters())
         executor = Executor(
             model=self.model,
             train_data=self.datasets["train"],
@@ -110,8 +111,9 @@
             save_every=[cond.time(seconds=10), cond.validation(better=True)],
             train_metrics=[("loss", metric.RunningAverage(20)),
                            metric.F1(pred_name="preds", mode="macro"),
-                           metric.Accuracy(pred_name="preds")],
-            optimizer={"type": torch.optim.Adam, "kwargs": {}},
+                           metric.Accuracy(pred_name="preds"),
+                           metric.LR(optimizer)],
+            optimizer=optimizer,
             stop_training_on=cond.epoch(10),
             valid_metrics=[metric.F1(pred_name="preds", mode="micro"),
                            ("loss", metric.Average())],
@@ -129,6 +131,9 @@
         executor.train()
         executor.test()
 
+        executor.save()
+        executor.load()
+
     def test_tbx_logging(self):
         executor = Executor(
             model=self.model,
diff --git a/texar/torch/run/executor.py b/texar/torch/run/executor.py
index 2b8823f53..596f87730 100644
--- a/texar/torch/run/executor.py
+++ b/texar/torch/run/executor.py
@@ -1280,8 +1280,12 @@ def remove_action(self) -> None:
     def train(self):
         r"""Start the training loop.
         """
-        # open the log files
-        self._open_files()
+        # Check whether files have been opened, to avoid re-opening and closing.
+        # This could happen when, e.g., `test` is called in a registered hook
+        # during training.
+        should_open_file = (len(self._opened_files) == 0)
+        if should_open_file:
+            self._open_files()
 
         if self._directory_exists:
             self.write_log(
@@ -1351,8 +1355,9 @@ def _try_get_data_size(executor: 'Executor'):
 
         self._fire_event(Event.Training, True)
 
-        # close the log files
-        self._close_files()
+        # Close the log files if we opened them here.
+        if should_open_file:
+            self._close_files()
 
     def test(self, dataset: OptionalDict[DatasetBase] = None):
         r"""Start the test loop.
@@ -1369,8 +1374,12 @@ def test(self, dataset: OptionalDict[DatasetBase] = None):
                 If `None`, :attr:`test_data` from the constructor arguments
                 is used. Defaults to `None`.
         """
-        # open the log files
-        self._open_files()
+        # Check whether files have been opened, to avoid re-opening and closing.
+        # This could happen when, e.g., `test` is called in a registered hook
+        # during training.
+        should_open_file = (len(self._opened_files) == 0)
+        if should_open_file:
+            self._open_files()
 
         if dataset is None and self.test_data is None:
             raise ValueError("No testing dataset is specified")
@@ -1417,8 +1426,9 @@ def test(self, dataset: OptionalDict[DatasetBase] = None):
 
         self.model.train(model_mode)
 
-        # close the log files
-        self._close_files()
+        # Close the log files if we opened them here.
+        if should_open_file:
+            self._close_files()
 
     def _register_logging_actions(self, show_live_progress: List[str]):
         # Register logging actions.
@@ -1728,6 +1738,7 @@ def _open_files(self):
     def _close_files(self):
         for file in self._opened_files:
             file.close()
+        self._opened_files = []
 
         if hasattr(self, 'summary_writer'):
             self.summary_writer.close()
@@ -1890,7 +1901,7 @@ def _validate_loop(self, iterator: DataIterator) -> None:
             self._fire_event(Event.ValidationIteration, False)
             return_dict = self._validate_step(batch)
 
-            # Update metrics.
+            self._valid_tracker.add(len(batch))
             utils.update_metrics(return_dict, batch, self.valid_metrics)
 
             self._fire_event(Event.ValidationIteration, True)
@@ -1906,8 +1917,7 @@ def _test_loop(self, iterator: DataIterator) -> None:
             return_dict = self._test_step(batch)
 
             self._test_tracker.add(len(batch))
-            utils.update_metrics(
-                return_dict, batch, self.test_metrics)
+            utils.update_metrics(return_dict, batch, self.test_metrics)
 
             self._fire_event(Event.TestingIteration, True)
 
diff --git a/texar/torch/run/metric/summary.py b/texar/torch/run/metric/summary.py
index 0774616ea..104f0d752 100644
--- a/texar/torch/run/metric/summary.py
+++ b/texar/torch/run/metric/summary.py
@@ -17,6 +17,7 @@
 
 from collections import deque
 from typing import Any, Deque, Optional, Sequence
+import weakref
 
 import numpy as np
 from torch.optim.optimizer import Optimizer
@@ -152,15 +153,25 @@ class LR(StreamingMetric[Any, float]):
 
     def __init__(self, optimizer: Optimizer, param_group: int = 0):
         super().__init__(pred_name=None)
-        self.optimizer = optimizer
+        self.optimizer = weakref.ref(optimizer)
         self.group = param_group
 
     def add(self, _, __):
         pass
 
     def value(self) -> float:
-        return self.optimizer.param_groups[self.group]['lr']  # type: ignore
+        return self.optimizer().param_groups[self.group]['lr']  # type: ignore
 
     def better(self, cur: float, prev: float) -> Optional[bool]:
         # Always return `None` to indicate values are uncomparable.
         return None
+
+    def __getstate__(self):
+        # There's no point in pickling an `LR` metric; just ignore it.
+        return None
+
+    def __getnewargs__(self):
+        # But when unpickling, we need to make sure we can construct something.
+        # This requires passing a dummy `optimizer` to which a weakref can be
+        # constructed. In this case, we use an arbitrary built-in class.
+        return (int,)
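
For reference, the test change above boils down to the usage pattern below. This is a minimal sketch, not part of the patch: `SomeModel` and `datasets` are hypothetical placeholders, the `texar.torch.run` import path is assumed, and the validation/test arguments used in the actual test are omitted. The point is that the optimizer is now built up front and passed to `Executor` as an instance (rather than a `{"type": ..., "kwargs": ...}` spec), so the same object can be shared with `metric.LR`, which only keeps a weak reference to it, and `save()`/`load()` can round-trip executor state without pickling the optimizer through that metric:

    import torch
    from texar.torch.run import Executor, cond, metric   # import path assumed

    model = SomeModel()                                   # hypothetical nn.Module
    optimizer = torch.optim.Adam(model.parameters())

    executor = Executor(
        model=model,
        train_data=datasets["train"],      # hypothetical dataset
        optimizer=optimizer,               # instance, not {"type": ..., "kwargs": ...}
        train_metrics=[("loss", metric.RunningAverage(20)),
                       metric.LR(optimizer)],   # holds only a weakref to `optimizer`
        stop_training_on=cond.epoch(10),
    )
    executor.train()
    executor.save()    # assumes a checkpoint directory was configured
    executor.load()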