Skip to content

Commit

Permalink
Small fixes - testing works
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Aug 16, 2024
1 parent 58eb4da commit 88abab2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
13 changes: 13 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@
"build1"
]
},
{
"name": "build_word_alignment_model",
"type": "debugpy",
"request": "launch",
"module": "machine.jobs.build_word_alignment_model",
"justMyCode": false,
"args": [
"--model-type",
"thot",
"--build-id",
"build1"
]
},
{
"name": "Python: Debug Tests",
"type": "debugpy",
Expand Down
2 changes: 2 additions & 0 deletions machine/jobs/shared_file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class PretranslationInfo(TypedDict):

class WordAlignmentInfo(TypedDict):
refs: List[str]
column_count: int
row_count: int
alignmnent: str


Expand Down
13 changes: 10 additions & 3 deletions machine/jobs/word_alignment_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def train_model(
trainer.train(progress=phase_progress, check_canceled=check_canceled)
trainer.save()
self._train_corpus_size = trainer.stats.train_corpus_size
self._confidence = trainer.stats.metrics["bleu"] * 100
self._confidence = -1

if check_canceled is not None:
check_canceled()
Expand All @@ -68,14 +68,21 @@ def batch_inference(
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["inference_batch_size"]
segment_batch = list(self._parallel_corpus.lowercase().take(batch_size))
segment_batch = list(self._parallel_corpus.lowercase().tokenize(self._tokenizer).take(batch_size))
if check_canceled is not None:
check_canceled()
alignments = alignment_model.align_batch(segment_batch)
if check_canceled is not None:
check_canceled()
for row, alignment in zip(self._parallel_corpus.get_rows(), alignments):
writer.write(WordAlignmentInfo(refs=row.source_refs, alignment=str(alignment))) # type: ignore
writer.write(
WordAlignmentInfo(
refs=[str(ref) for ref in row.source_refs],
column_count=alignment.column_count,
row_count=alignment.row_count,
alignment=str(alignment),
) # type: ignore
)

def save_model(self) -> None:
logger.info("Saving model")
Expand Down

0 comments on commit 88abab2

Please sign in to comment.