Skip to content

Commit

Permalink
Merge pull request #27 from wrznr/fix_gp
Browse files Browse the repository at this point in the history
Fix the gp module
  • Loading branch information
wrznr authored Aug 5, 2019
2 parents b624828 + b3adf43 commit 924a4c0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
5 changes: 4 additions & 1 deletion gramophone/gp/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def align(self,g,p):
'''
Aligns a grapheme-phoneme sequence pair.
'''
alignment_fst = self.__align_fst(g,p)
if alignment_fst.start() == -1:
return []
return(self.__extract_alignments(self.__align_fst(g,p)))

def scan(self,g):
Expand All @@ -93,7 +96,7 @@ def __align_fst(self,g,p):
t4 = self.expand(p)
t4.project(project_output=True)

if t4.num_arcs(t4.start()) == 0:
if t4.start() == -1 or t4.num_arcs(t4.start()) == 0:
return fst.Fst()

t5 = fst.compose(t3,self.E)
Expand Down
3 changes: 2 additions & 1 deletion gramophone/gp/transcription.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

import wapiti
from tqdm import tqdm

patterns = '''
# Unigram
Expand Down Expand Up @@ -75,7 +76,7 @@ def train(self, training_data):
self.clear()

# parse training data
for alignment in training_data:
for alignment in tqdm(training_data):
seq = u"\n".join(u"%s %s" % (alignment[0][i],alignment[1][i]) for i in range(len(alignment[0])))
#seq = u"\n".join(u"%s %s" % (x,y) for x,y in zip(alignment[0],alignment[1]))
self.model.add_training_sequence(seq)
Expand Down
25 changes: 13 additions & 12 deletions gramophone/scripts/gramophone.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,23 @@ def train_gp(mapping,model,data):
with open(str(data),"r") as f:
training_data = f.read()

with click.progressbar(training_data.split("\n")) as bar:
for line in bar:
for line in tqdm(training_data.split("\n")):

# skip comments
if line.startswith("#"):
continue
# skip comments
if line.startswith("#"):
continue

# assume tab-separated values
fields = line.split("\t")
if len(fields) < 2:
continue
# assume tab-separated values
fields = line.split("\t")
if len(fields) < 2:
continue

# align
alignment = aligner.align(fields[0],fields[1])
# align
alignment = aligner.align(fields[0],fields[1])
if alignment:
aligned_training_data.append(alignment)
else:
click.echo("%s and %s could not be aligned." % (fields[0], fields[1]))

#
# stage 2: crf training
Expand Down Expand Up @@ -352,7 +354,6 @@ def apply_st(crf,strings):
# convert
for string in in_strings:
encodement = coder.encode(string,mode="scan")
click.echo(encodement)
labellings = labeller.label(encodement)
combination = []
for labelling in labellings:
Expand Down

0 comments on commit 924a4c0

Please sign in to comment.