Skip to content

Commit

Permalink
Restrict build options to only update model hyperparameters (#88)
Browse files Browse the repository at this point in the history
- fixes #86
  • Loading branch information
ddaspit authored Dec 14, 2023
1 parent 5203d87 commit a105036
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 41 deletions.
54 changes: 37 additions & 17 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,46 @@
"request": "launch",
"module": "machine.jobs.build_nmt_engine",
"justMyCode": false,
"args": [
"--model-type",
"huggingface",
"--build-id",
"build1",
"--engine-id",
"engine1",
"--src-lang",
"spa_Latn",
"--trg-lang",
"eng_Latn",
"--clearml",
"--build-options",
"{\"max_steps\": 10}"
]
"windows": {
"args": [
"--model-type",
"huggingface",
"--build-id",
"build1",
"--engine-id",
"engine1",
"--src-lang",
"spa_Latn",
"--trg-lang",
"eng_Latn",
"--build-options",
"{\\\"train_params\\\": {\\\"max_steps\\\": 10}}"
]
},
"linux": {
"args": [
"--model-type",
"huggingface",
"--build-id",
"build1",
"--engine-id",
"engine1",
"--src-lang",
"spa_Latn",
"--trg-lang",
"eng_Latn",
"--build-options",
"{\"train_params\": {\"max_steps\": 10}}"
]
}
},
{
"name": "Debug Unit Test",
"name": "Python: Debug Tests",
"type": "python",
"request": "test",
"request": "launch",
"program": "${file}",
"purpose": ["debug-test"],
"console": "integratedTerminal",
"justMyCode": false
}
]
Expand Down
7 changes: 3 additions & 4 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
{
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
},
"python.formatting.provider": "none"
}
}
}
20 changes: 10 additions & 10 deletions machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,20 @@ def clearml_progress(status: ProgressStatus) -> None:
logger.info("NMT Engine Build Job started")

SETTINGS.update(args)
try:
SETTINGS.build_options = json.loads(args["build_options"])
except ValueError as e:
raise ValueError("Build options could not be parsed: Invalid JSON") from e
except TypeError as e:
raise TypeError(f"Build options could not be parsed: {e}") from e
if SETTINGS.build_options:
SETTINGS.update(SETTINGS.build_options)
model_type = cast(str, SETTINGS.model_type).lower()
if "build_options" in SETTINGS:
try:
build_options = json.loads(cast(str, SETTINGS.build_options))
except ValueError as e:
raise ValueError("Build options could not be parsed: Invalid JSON") from e
except TypeError as e:
raise TypeError(f"Build options could not be parsed: {e}") from e
SETTINGS.update({model_type: build_options})
SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir))

logger.info(f"Config: {SETTINGS.as_dict()}")

shared_file_service = ClearMLSharedFileService(SETTINGS)
model_type = cast(str, SETTINGS.model_type).lower()
nmt_model_factory: NmtModelFactory
if model_type == "huggingface":
from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory
Expand Down Expand Up @@ -87,7 +87,7 @@ def main() -> None:
parser.add_argument("--src-lang", required=True, type=str, help="Source language tag")
parser.add_argument("--trg-lang", required=True, type=str, help="Target language tag")
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
parser.add_argument("--build-options", default="{}", type=str, help="Build configurations")
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
args = parser.parse_args()

run({k: v for k, v in vars(args).items() if v is not None})
Expand Down
9 changes: 5 additions & 4 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
args = config.huggingface.train_params.to_dict()
args["output_dir"] = str(self._model_dir)
args["overwrite_output_dir"] = True
if "max_steps" in self._config:
if self._config.max_steps > 50000:
raise ValueError("max_steps must be less than or equal to 50000")
args["max_steps"] = self._config.max_steps
# Use "max_steps" from the root of the "huggingface" settings (rather than "train_params") for backward compatibility
if "max_steps" in self._config.huggingface:
args["max_steps"] = self._config.huggingface.max_steps
parser = HfArgumentParser(cast(Any, Seq2SeqTrainingArguments))
self._training_args = cast(Seq2SeqTrainingArguments, parser.parse_dict(args)[0])
if self._training_args.max_steps > 50000:
raise ValueError("max_steps must be less than or equal to 50000")
if (
not config.clearml
and self._training_args.report_to is not None
Expand Down
2 changes: 1 addition & 1 deletion machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def run(
writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["batch_size"]
batch_size = self._config["pretranslation_batch_size"]
for pi_batch in batch(src_pretranslations, batch_size):
if check_canceled is not None:
check_canceled()
Expand Down
9 changes: 5 additions & 4 deletions machine/jobs/settings.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
default:
model_type: huggingface
max_steps: 20000
data_dir: ~/machine
batch_size: 1024
pretranslation_batch_size: 1024
huggingface:
parent_model_name: facebook/nllb-200-distilled-1.3B
train_params:
Expand All @@ -16,6 +15,7 @@ default:
gradient_checkpointing: true
fp16: true
save_strategy: no
max_steps: 20000
generate_params:
device: 0
num_beams: 2
Expand All @@ -31,8 +31,9 @@ development:
generate_params:
num_beams: 1
staging:
max_steps: 10
huggingface:
parent_model_name: facebook/nllb-200-distilled-600M
train_params:
max_steps: 10
generate_params:
num_beams: 1
num_beams: 1
2 changes: 1 addition & 1 deletion tests/jobs/test_nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_cancel(decoy: Decoy) -> None:

class _TestEnvironment:
def __init__(self, decoy: Decoy) -> None:
config = {"src_lang": "es", "trg_lang": "en", "batch_size": 100}
config = {"src_lang": "es", "trg_lang": "en", "pretranslation_batch_size": 100}
self.source_tokenizer_trainer = decoy.mock(cls=Trainer)
self.target_tokenizer_trainer = decoy.mock(cls=Trainer)

Expand Down

0 comments on commit a105036

Please sign in to comment.