Merge --use-ontology-for-* into --use-ontology
Summary:
There are three options for the ontology:
* `--use-ontology-for-training`
* `--use-ontology-for-validation`
* `--use-ontology-for-balancing`

The first two must always be set together.

In the past, I observed that it was best not to use the ontology for data balancing even when using it for training and validation, but I no longer observe this effect.

Therefore, I'm merging these three options into a single one (`--use-ontology`).
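A rough sketch of the consolidated flag (assuming argparse-style option registration; only the flag names come from this change, the parser setup and help text below are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# One flag replacing --use-ontology-for-training, --use-ontology-for-validation,
# and --use-ontology-for-balancing (illustrative registration, not the actual code).
parser.add_argument(
    "--use-ontology",
    action="store_true",
    help="use the ontology for training, validation, and data balancing",
)

args = parser.parse_args(["--use-ontology"])
print(args.use_ontology)  # True
```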

In addition, I'm moving the logic that skips loading teacher models out of `checkpoint_utils.py`. To load a student model without loading its teachers (e.g. for prediction only), specify `arg_overrides={"ignore_teachers": True}` when calling `load_model_ensemble`.
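A minimal usage sketch (the checkpoint path is a placeholder; `ignore_teachers` is the override introduced here):

```python
from fairseq import checkpoint_utils

# Load only the student model, skipping its teachers (e.g. for prediction).
models, cfg = checkpoint_utils.load_model_ensemble(
    ["/path/to/student_checkpoint.pt"],  # placeholder path
    arg_overrides={"ignore_teachers": True},
)
model = models[0]
model.eval()
```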

Reviewed By: xiaoxiao26

Differential Revision: D32518830

fbshipit-source-id: 103c6458f7927ec5ca7470109c8f956c00f514a2
MaigoAkisame authored and facebook-github-bot committed Nov 19, 2021
1 parent 7105d7f commit bf61974
24 changes: 7 additions & 17 deletions fairseq/checkpoint_utils.py
@@ -578,14 +578,6 @@ def _upgrade_state_dict(state):
     # keep track of number of updates
     if "num_updates" not in state["optimizer_history"][-1]:
         state["optimizer_history"][-1]["num_updates"] = 0
-    # old model checkpoints may not have separate source/target positions
-    if (
-        "args" in state
-        and hasattr(state["args"], "max_positions")
-        and not hasattr(state["args"], "max_source_positions")
-    ):
-        state["args"].max_source_positions = state["args"].max_positions
-        state["args"].max_target_positions = state["args"].max_positions
     # use stateful training data iterator
     if "train_iterator" not in state["extra_state"]:
         state["extra_state"]["train_iterator"] = {
@@ -595,6 +587,13 @@ def _upgrade_state_dict(state):
 
     # backward compatibility, cfg updates
     if "args" in state and state["args"] is not None:
+        # old model checkpoints may not have separate source/target positions
+        if (
+            hasattr(state["args"], "max_positions")
+            and not hasattr(state["args"], "max_source_positions")
+        ):
+            state["args"].max_source_positions = state["args"].max_positions
+            state["args"].max_target_positions = state["args"].max_positions
         # default to translation task
         if not hasattr(state["args"], "task"):
             state["args"].task = "translation"
@@ -646,15 +645,6 @@ def _upgrade_state_dict(state):
             and len(state["args"].data) > 0
         ):
             state["args"].data = state["args"].data[0]
-        # remove keys in state["args"] related to teacher-student learning
-        for key in [
-            "static_teachers",
-            "static_teacher_weights",
-            "dynamic_teachers",
-            "dynamic_teacher_weights",
-        ]:
-            if key in state["args"]:
-                delattr(state["args"], key)
 
         state["cfg"] = convert_namespace_to_omegaconf(state["args"])
 
