diff --git a/gramophone/gp/alignment.py b/gramophone/gp/alignment.py index 91fb08b..e144960 100644 --- a/gramophone/gp/alignment.py +++ b/gramophone/gp/alignment.py @@ -73,6 +73,9 @@ def align(self,g,p): ''' Aligns a grapheme-phoneme sequence pair. ''' + alignment_fst = self.__align_fst(g,p) + if alignment_fst.start() == -1: + return [] return(self.__extract_alignments(self.__align_fst(g,p))) def scan(self,g): @@ -93,7 +96,7 @@ def __align_fst(self,g,p): t4 = self.expand(p) t4.project(project_output=True) - if t4.num_arcs(t4.start()) == 0: + if t4.start() == -1 or t4.num_arcs(t4.start()) == 0: return fst.Fst() t5 = fst.compose(t3,self.E) diff --git a/gramophone/gp/transcription.py b/gramophone/gp/transcription.py index 8635046..09b92b0 100644 --- a/gramophone/gp/transcription.py +++ b/gramophone/gp/transcription.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import wapiti +from tqdm import tqdm patterns = ''' # Unigram @@ -75,7 +76,7 @@ def train(self, training_data): self.clear() # parse training data - for alignment in training_data: + for alignment in tqdm(training_data): seq = u"\n".join(u"%s %s" % (alignment[0][i],alignment[1][i]) for i in range(len(alignment[0]))) #seq = u"\n".join(u"%s %s" % (x,y) for x,y in zip(alignment[0],alignment[1])) self.model.add_training_sequence(seq) diff --git a/gramophone/scripts/gramophone.py b/gramophone/scripts/gramophone.py index a04d9e8..01d82e4 100644 --- a/gramophone/scripts/gramophone.py +++ b/gramophone/scripts/gramophone.py @@ -92,21 +92,23 @@ def train_gp(mapping,model,data): with open(str(data),"r") as f: training_data = f.read() - with click.progressbar(training_data.split("\n")) as bar: - for line in bar: + for line in tqdm(training_data.split("\n")): - # skip comments - if line.startswith("#"): - continue + # skip comments + if line.startswith("#"): + continue - # assume tab-separated values - fields = line.split("\t") - if len(fields) < 2: - continue + # assume tab-separated values + fields = line.split("\t") + if len(fields) < 2: + continue - # align - alignment = aligner.align(fields[0],fields[1]) + # align + alignment = aligner.align(fields[0],fields[1]) + if alignment: aligned_training_data.append(alignment) + else: + click.echo("%s and %s could not be aligned." % (fields[0], fields[1])) # # stage 2: crf training @@ -352,7 +354,6 @@ def apply_st(crf,strings): # convert for string in in_strings: encodement = coder.encode(string,mode="scan") - click.echo(encodement) labellings = labeller.label(encodement) combination = [] for labelling in labellings: