Skip to content

Commit

Permalink
Fix bugs in Executor (#306)
Browse files Browse the repository at this point in the history
* Fix: saved `meta_dict` contains excessive data

- `meta_dict` contains training and evaluation metrics objects, but it's
  unnecessary to save objects -- only their values are needed.
- A more serious problem is with the `LR` metric, which stores a
  reference to the optimizer. This will result in the entire optimizer
  being saved in the meta-info.
- The solution has two parts: 1) save only the metric values; 2) store the
  optimizer as a weakref.

* Fix: don't close files when calling test in train

- `_open_files` and `_close_files` are called at the beginning and end
  of `train` and `test`, to prevent holding on to an open file object
  for an unnecessarily long amount of time.
- However, it's possible that we call `test` within `train`. For
  instance, calling `test` in an action triggered by the validation
  event. In this case, the file will be closed before training ends.
- The solution is to check whether we need to open the files, and if we
  don't, then don't open or close them.

* Fix: missing call to tracker in `_validate_loop`

This is so stupid: for some reason I forgot to call `_valid_tracker.add`
in `_validate_loop`, so the status is never updated during validation.

* Revert c89e0e4: fix `meta_dict` issue

- It turns out we must store the metric objects -- otherwise we can't
  even compare two metric values.
- So I just changed the pickle behavior for `LR` so that it doesn't save
  the optimizer. Seems like a hack, but let's just leave it at this.

* Fix doc building issues
  • Loading branch information
huzecong authored Apr 20, 2020
1 parent 27fe398 commit 95f3d64
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 18 deletions.
3 changes: 0 additions & 3 deletions docs/code/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ Data Loaders
.. autoclass:: texar.torch.data.DatasetBase
:members:

.. automethod:: process
.. automethod:: collate

:hidden:`MonoTextData`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.torch.data.MonoTextData
Expand Down
9 changes: 7 additions & 2 deletions tests/run/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def tearDown(self) -> None:
shutil.rmtree(self.tbx_logging_dir)

def test_train_loop(self):
optimizer = torch.optim.Adam(self.model.parameters())
executor = Executor(
model=self.model,
train_data=self.datasets["train"],
Expand All @@ -110,8 +111,9 @@ def test_train_loop(self):
save_every=[cond.time(seconds=10), cond.validation(better=True)],
train_metrics=[("loss", metric.RunningAverage(20)),
metric.F1(pred_name="preds", mode="macro"),
metric.Accuracy(pred_name="preds")],
optimizer={"type": torch.optim.Adam, "kwargs": {}},
metric.Accuracy(pred_name="preds"),
metric.LR(optimizer)],
optimizer=optimizer,
stop_training_on=cond.epoch(10),
valid_metrics=[metric.F1(pred_name="preds", mode="micro"),
("loss", metric.Average())],
Expand All @@ -129,6 +131,9 @@ def test_train_loop(self):
executor.train()
executor.test()

executor.save()
executor.load()

def test_tbx_logging(self):
executor = Executor(
model=self.model,
Expand Down
32 changes: 21 additions & 11 deletions texar/torch/run/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,8 +1280,12 @@ def remove_action(self) -> None:
def train(self):
r"""Start the training loop.
"""
# open the log files
self._open_files()
# Check whether files have been opened, to avoid re-opening and closing.
# This could happen when, e.g., `test` is called in a registered hook
# during training.
should_open_file = (len(self._opened_files) == 0)
if should_open_file:
self._open_files()

if self._directory_exists:
self.write_log(
Expand Down Expand Up @@ -1351,8 +1355,9 @@ def _try_get_data_size(executor: 'Executor'):

self._fire_event(Event.Training, True)

# close the log files
self._close_files()
# Close the log files if we opened them here.
if should_open_file:
self._close_files()

def test(self, dataset: OptionalDict[DatasetBase] = None):
r"""Start the test loop.
Expand All @@ -1369,8 +1374,12 @@ def test(self, dataset: OptionalDict[DatasetBase] = None):
If `None`, :attr:`test_data` from the constructor arguments is
used. Defaults to `None`.
"""
# open the log files
self._open_files()
# Check whether files have been opened, to avoid re-opening and closing.
# This could happen when, e.g., `test` is called in a registered hook
# during training.
should_open_file = (len(self._opened_files) == 0)
if should_open_file:
self._open_files()

if dataset is None and self.test_data is None:
raise ValueError("No testing dataset is specified")
Expand Down Expand Up @@ -1417,8 +1426,9 @@ def test(self, dataset: OptionalDict[DatasetBase] = None):

self.model.train(model_mode)

# close the log files
self._close_files()
# Close the log files if we opened them here.
if should_open_file:
self._close_files()

def _register_logging_actions(self, show_live_progress: List[str]):
# Register logging actions.
Expand Down Expand Up @@ -1728,6 +1738,7 @@ def _open_files(self):
def _close_files(self):
for file in self._opened_files:
file.close()
self._opened_files = []

if hasattr(self, 'summary_writer'):
self.summary_writer.close()
Expand Down Expand Up @@ -1890,7 +1901,7 @@ def _validate_loop(self, iterator: DataIterator) -> None:
self._fire_event(Event.ValidationIteration, False)
return_dict = self._validate_step(batch)

# Update metrics.
self._valid_tracker.add(len(batch))
utils.update_metrics(return_dict, batch, self.valid_metrics)

self._fire_event(Event.ValidationIteration, True)
Expand All @@ -1906,8 +1917,7 @@ def _test_loop(self, iterator: DataIterator) -> None:
return_dict = self._test_step(batch)

self._test_tracker.add(len(batch))
utils.update_metrics(
return_dict, batch, self.test_metrics)
utils.update_metrics(return_dict, batch, self.test_metrics)

self._fire_event(Event.TestingIteration, True)

Expand Down
15 changes: 13 additions & 2 deletions texar/torch/run/metric/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from collections import deque
from typing import Any, Deque, Optional, Sequence
import weakref

import numpy as np
from torch.optim.optimizer import Optimizer
Expand Down Expand Up @@ -152,15 +153,25 @@ class LR(StreamingMetric[Any, float]):

def __init__(self, optimizer: Optimizer, param_group: int = 0):
super().__init__(pred_name=None)
self.optimizer = optimizer
self.optimizer = weakref.ref(optimizer)
self.group = param_group

def add(self, _, __):
pass

def value(self) -> float:
return self.optimizer.param_groups[self.group]['lr'] # type: ignore
return self.optimizer().param_groups[self.group]['lr'] # type: ignore

def better(self, cur: float, prev: float) -> Optional[bool]:
# Always return `None` to indicate values are uncomparable.
return None

def __getstate__(self):
# There's no point in pickling an `LR` metric; just ignore it.
return None

def __getnewargs__(self):
# But when unpickling, we need to make sure we can construct something.
# This requires passing a dummy `optimizer` to which a weakref can be
# constructed. In this case, we use an arbitrary built-in class.
return (int,)

0 comments on commit 95f3d64

Please sign in to comment.