Skip to content

Commit

Permalink
Add arguments key per stage
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmervdl committed Dec 23, 2023
1 parent f142670 commit 9957bd4
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 15 deletions.
16 changes: 11 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ trainer: /path/to/trainer/run.py
```
### Number of fields
If `num_fields` is provided, at read time, the trainer will strip any extra TSV fields that the dataset contains (such as optinal alignment field that you are not going to use). Furthermore, any line that doesn't have enough fields gets filtered (eg lines missing alignment info when you do actually care about alignment).
If `num_fields` is provided, at read time, the trainer will strip any extra TSV fields that the dataset contains (such as optional alignment field that you are not going to use). Furthermore, any line that doesn't have enough fields gets filtered (eg lines missing alignment info when you do actually care about alignment).

### Extended stage configuration
If you want to change which modifiers are used for a specific stage, you can the extended stage configuration format. If a `modifiers` is mentioned here, it will override the curriculum-wide defined `modifiers` for just this stage.
If you want to change which modifiers are used for a specific stage, you can use the extended stage configuration format.

In the extended format, the list of datasets is defined in the `mix` key. You can optionally add a `modifiers` key. For example:
In the extended format, the list of datasets is defined in the `mix` key. You can optionally add `modifiers` and `arguments` keys. For example:

```yaml
start:
Expand All @@ -110,10 +110,16 @@ start:
- dirty 0
- until clean 2 # Until two epochs of clean
modifiers:
- UpperCase: 0.05
- TitleCase: 0.05
- UpperCase: 0.05
- TitleCase: 0.05
arguments:
- "--stop-early"
```

If a `modifiers` key is specified here, it will override the curriculum-wide `modifiers` for just this stage.

If the optional `arguments` key is added, its values will be appended to the end of the trainer command's argument list for just this stage.

Note that you can use YAML references if you wish to extensively combine global and local modifiers.

### Modifiers
Expand Down
10 changes: 8 additions & 2 deletions src/opustrainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class Stage:
until_dataset: str
until_epoch: Optional[int]
modifiers: Optional[List[Modifier]]
arguments: List[str]


@dataclass(frozen=True)
Expand Down Expand Up @@ -448,6 +449,9 @@ def _load_stages(self, ymldata:dict, basepath:str, stages_order:List[str], datas
- until dataset3 epochs
modifiers:
- Modifier: freq
arguments:
- arg1
- arg2
```
"""
return {
Expand Down Expand Up @@ -491,7 +495,8 @@ def _load_stage(self, ymldata:dict, basepath:str, stage_name:str, available_data
datasets=datasets,
until_dataset=until_dataset_name,
until_epoch=until_epoch,
modifiers=self._load_modifiers(ymldata[stage_name], basepath) if isinstance(ymldata[stage_name], dict) and 'modifiers' in ymldata[stage_name] else None
modifiers=self._load_modifiers(ymldata[stage_name], basepath) if isinstance(ymldata[stage_name], dict) and 'modifiers' in ymldata[stage_name] else None,
arguments=[str(arg) for arg in ymldata[stage_name]['arguments']] if isinstance(ymldata[stage_name], dict) and 'arguments' in ymldata[stage_name] else [],
)
except Exception as exc:
raise CurriculumLoaderError(f"could not complete the parse of stage '{stage_name}': {exc!s}") from exc
Expand Down Expand Up @@ -835,8 +840,9 @@ def main(args:argparse.Namespace) -> None:
signal.signal(signal.SIGUSR1, lambda signum, handler: print_state(trainer.state()))

while trainer.stage is not None:
logger.log(' '.join(trainer.stage.arguments))
model_trainer = subprocess.Popen(
args.trainer or shlex.split(config['trainer']),
(args.trainer or shlex.split(config['trainer'])) + trainer.stage.arguments,
stdin=subprocess.PIPE,
encoding="utf-8",
preexec_fn=ignore_sigint) # ignore_sigint makes marian ignore Ctrl-C. We'll stop it from here.
Expand Down
24 changes: 16 additions & 8 deletions tests/test_trainer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def test_marian_log_args(self):

def test_early_stopping(self):
"""Test letting the trainer move to the next stage using early-stopping"""
head_lines = 10000

basepath = Path('contrib').absolute()

config = {
Expand All @@ -51,14 +49,20 @@ def test_early_stopping(self):
'start',
'mid',
],
'start': [
'start': {
'mix': [
'clean 1.0',
'until clean inf'
],
'mid': [
],
'arguments': ['5000']
},
'mid': {
'mix': [
'medium 1.0',
'until medium inf',
],
],
'arguments': ['10000']
},
'seed': 1111
}

Expand All @@ -72,7 +76,7 @@ def test_early_stopping(self):
'--do-not-resume',
'--no-shuffle',
'--config', str(Path(tmp) / 'config.yml'),
'head', '-n', str(head_lines)
'head', '-n', # plus value for n, per stage
], stdout=fout, stderr=ferr)

retval = child.wait(30)
Expand All @@ -84,4 +88,8 @@ def test_early_stopping(self):

# Assert we got the number of lines we'd expect
line_count = sum(1 for _ in fout)
self.assertEqual(line_count, len(config['stages']) * head_lines)
expected_line_count = sum(
int(config[stage]['arguments'][0])
for stage in config['stages']
)
self.assertEqual(line_count, expected_line_count)

0 comments on commit 9957bd4

Please sign in to comment.