diff --git a/configs/train_css/local/conformer_v0.51_mc.yaml b/configs/train_css/local/conformer_v0.51_mc.yaml
new file mode 100644
index 0000000..0d07d29
--- /dev/null
+++ b/configs/train_css/local/conformer_v0.51_mc.yaml
@@ -0,0 +1,41 @@
+# Note: there is newer and better data; do not download v1.2.
+train_dir: ./v1.2/100hrs/train
+val_dir: ./v1.2/100hrs/val
+out_dir: ./
+
+train_set_cfg:
+  sample_frac: 1.0
+  max_urls: null # null means no limit
+val_set_cfg:
+  sample_frac: 1.0
+  max_urls: null # null means no limit
+
+clip_gt_to_mixture: True
+
+log_params_mlflow: True
+log_metrics_mlflow: True
+
+scheduler_step_every: [1, iterations]
+scheduler_name: step_lr
+scheduler_linear_warmup_decay_cfg:
+  warmup: 10000
+  decay: 260000
+scheduler_step_lr_cfg:
+  step_size: 1
+  gamma: 1.0 # no decay
+
+stop_after: [260000, iterations]
+eval_every: [1000, iterations]
+save_every: [1000, iterations]
+
+global_batch_size: 256
+learning_rate: 1e-5
+weight_decay: 1e-2 # set to 1e-2 per the paper
+
+# Large model per the CSS with Conformer definition
+conformer_css_cfg:
+  nnet_conf:
+    conformer_conf:
+      attention_dim: 512 # default 256
+      attention_heads: 8 # default 4
+      num_blocks: 18 # default 16
diff --git a/configs/train_css/local/conformer_v0.51_sc.yaml b/configs/train_css/local/conformer_v0.51_sc.yaml
new file mode 100644
index 0000000..3a7f61f
--- /dev/null
+++ b/configs/train_css/local/conformer_v0.51_sc.yaml
@@ -0,0 +1,46 @@
+# Note: there is newer and better data; do not download v1.2.
+train_dir: ./v1.2/100hrs/train
+val_dir: ./v1.2/100hrs/val
+out_dir: ./
+
+single_channel: True
+
+train_set_cfg:
+  sample_frac: 1.0
+  max_urls: null # null means no limit
+val_set_cfg:
+  sample_frac: 1.0
+  max_urls: null # null means no limit
+
+clip_gt_to_mixture: True
+
+log_params_mlflow: True
+log_metrics_mlflow: True
+
+scheduler_step_every: [1, iterations]
+scheduler_name: step_lr
+scheduler_linear_warmup_decay_cfg:
+  warmup: 10000
+  decay: 260000
+scheduler_step_lr_cfg:
+  step_size: 1
+  gamma: 1.0 # no decay
+
+stop_after: [260000, iterations]
+eval_every: [1000, iterations]
+save_every: [1000, iterations]
+
+global_batch_size: 256
+learning_rate: 1e-5
+weight_decay: 1e-2 # set to 1e-2 per the paper
+
+# Large model per the CSS with Conformer definition
+conformer_css_cfg:
+  extractor_conf:
+    ipd_index: '' # For MC: '1,0;2,0;3,0;4,0;5,0;6,0'. For SC: ''.
+  nnet_conf:
+    conformer_conf:
+      attention_dim: 512 # default 256
+      attention_heads: 8 # default 4
+      num_blocks: 18 # default 16
+      in_features: 257 # For MC: 1799. For SC: 257.
diff --git a/configs/train_css/local/conformer_v0.5_mc.yaml b/configs/train_css/local/conformer_v0.5_mc.yaml
index 7d08a2c..35d33b2 100644
--- a/configs/train_css/local/conformer_v0.5_mc.yaml
+++ b/configs/train_css/local/conformer_v0.5_mc.yaml
@@ -10,7 +10,8 @@ val_set_cfg:
   sample_frac: 1.0
   max_urls: null # null means no limit
 
-# This model was actually trained with clip_gt_to_mixture: False, but we recommend to set it to True.
+# This model was trained with clip_gt_to_mixture=False, but we recommend setting it to True.
+clip_gt_to_mixture: False
 
 log_params_mlflow: True
 log_metrics_mlflow: True
diff --git a/configs/train_css/local/conformer_v0.5_sc.yaml b/configs/train_css/local/conformer_v0.5_sc.yaml
index a70b31f..23332bb 100644
--- a/configs/train_css/local/conformer_v0.5_sc.yaml
+++ b/configs/train_css/local/conformer_v0.5_sc.yaml
@@ -12,7 +12,8 @@ val_set_cfg:
   sample_frac: 1.0
   max_urls: null # null means no limit
 
-# This model was actually trained with clip_gt_to_mixture: False, but we recommend to set it to True.
+# This model was trained with clip_gt_to_mixture=False, but we recommend setting it to True.
+clip_gt_to_mixture: False
 
 log_params_mlflow: True
 log_metrics_mlflow: True
diff --git a/configs/train_css/local/debug_mc.yaml b/configs/train_css/local/debug_mc.yaml
index 8b18f73..fd6e3a3 100644
--- a/configs/train_css/local/debug_mc.yaml
+++ b/configs/train_css/local/debug_mc.yaml
@@ -10,6 +10,8 @@ val_set_cfg:
   sample_frac: 1.0
   max_urls: 2 # null means no limit
 
+clip_gt_to_mixture: True
+
 log_params_mlflow: False
 log_metrics_mlflow: False
diff --git a/configs/train_css/local/debug_sc.yaml b/configs/train_css/local/debug_sc.yaml
index 2df8bee..94dbbac 100644
--- a/configs/train_css/local/debug_sc.yaml
+++ b/configs/train_css/local/debug_sc.yaml
@@ -17,6 +17,8 @@ val_set_cfg:
   sample_frac: 1.0
   max_urls: 2 # null means no limit
 
+clip_gt_to_mixture: True
+
 log_params_mlflow: False
 log_metrics_mlflow: False
diff --git a/css/training/train.py b/css/training/train.py
index 8e40bfb..c79c304 100644
--- a/css/training/train.py
+++ b/css/training/train.py
@@ -64,7 +64,8 @@ class TrainCfg:
     learning_rate: float = 1e-3
     global_batch_size: int = 32  # global means across all GPUs, local means per GPU
     clip_grad_norm: float = 0.01
-    clip_gt_to_mixture: bool = True  # clips the ground truth to the mixture to avoid trying to drive the mask above 1
+    # Clips the ground truth to the mixture to avoid trying to drive the mask above 1; True is recommended.
+    clip_gt_to_mixture: bool = False
     weight_decay: float = 1e-4
     is_debug: bool = False  # no data workers, no DataParallel, etc.
     log_params_mlflow: bool = True
diff --git a/utils/text_norm_whisper_like/__init__.py b/utils/text_norm_whisper_like/__init__.py
index b06b125..b130d05 100644
--- a/utils/text_norm_whisper_like/__init__.py
+++ b/utils/text_norm_whisper_like/__init__.py
@@ -6,13 +6,15 @@
 from .basic import BasicTextNormalizer as BasicTextNormalizer
 from .english import EnglishTextNormalizer as EnglishTextNormalizer
+from whisper.normalizers import EnglishTextNormalizer as OriginalEnglishTextNormalizer
 
 
 def get_txt_norm(txt_norm):
-    assert txt_norm in ["chime8", None]
     if txt_norm is None:
         return None
     elif txt_norm == "chime8":
         return EnglishTextNormalizer()
+    elif txt_norm == "whisper":
+        return OriginalEnglishTextNormalizer()
     else:
-        raise NotImplementedError
+        raise NotImplementedError()
diff --git a/utils/text_norm_whisper_like/english.py b/utils/text_norm_whisper_like/english.py
index edd2930..b9bf2e1 100644
--- a/utils/text_norm_whisper_like/english.py
+++ b/utils/text_norm_whisper_like/english.py
@@ -448,6 +448,82 @@ def __call__(self, s: str):
         return s
 
 
+class EnglishReverseNumberNormalizer(EnglishNumberNormalizer):
+    """
+    An approximate inverse of EnglishNumberNormalizer that converts Arabic numerals
+    into spelled-out numbers.
+
+    Motivation: Whisper's original EnglishNumberNormalizer produces numerals that match
+    Whisper's rich token set, which many ASRs cannot output.
+    This class takes an alternative normalization approach, converting Whisper's numerals back
+    into spelled-out numbers. This ensures compatibility with the token sets of other ASR systems
+    while avoiding penalizing Whisper for outputting numerals.
+
+    Examples of cases handled:
+    - "365" -> "three hundred sixty five"
+    - "$20" -> "twenty dollars"
+    - "50%" -> "fifty percent"
+    - "12th" -> "twelfth", "12s" -> "twelves"
+    - "90th" -> "ninetieth", "90s" -> "nineties"
+    - The special case "70 000" -> "seventy thousand" (but not larger numbers of this form).
+
+    Caveats: this class handles the majority of cases, but it is not perfect.
+    - Only numerals in the 0-1000 range are handled.
+    - Minus/plus signs are not handled.
+    - There is inherent ambiguity, e.g. "100" -> "one hundred" or "a hundred".
+    """
+
+    def __init__(self):
+        super().__init__()
+        # Reverse dictionaries
+        self.int_to_ones = {v: k for k, v in self.ones.items()}
+        self.int_to_tens = {v: k for k, v in self.tens.items()}
+
+        # "11th" -> "eleventh" etc.
+        self.str_to_ones_suffixed = {str(n) + s: k for k, (n, s) in self.ones_suffixed.items()}
+        # "20s" -> "twenties" etc.
+        self.str_to_tens_suffixed = {str(n) + s: k for k, (n, s) in self.tens_suffixed.items()}
+
+    def __call__(self, s: str):
+        # "$x[.y]" -> "x[.y] dollars"
+        s = re.sub(r"\$(\d+(\.\d+)?)", r"\1 dollars", s)
+        # "x[.y]%" -> "x[.y] percent"
+        s = re.sub(r"(\d+(\.\d+)?)%", r"\1 percent", s)
+        # Note: this does not handle cases such as -x or +x.
+
+        def number_to_words(w: str):
+            if w.isdigit():
+                num = int(w)
+                if w == "000":
+                    return "thousand"  # handles the "70 000" -> "seventy thousand" case
+                if num == 0:
+                    return "zero"
+                elif num == 100:
+                    return "hundred"
+                elif 0 < num < 1000:
+                    hundreds, remainder = divmod(num, 100)
+                    tens, ones = divmod(remainder, 10)
+                    h = [f"{self.int_to_ones[hundreds]} hundred"] if hundreds > 0 else []
+                    if 0 < remainder <= 19:
+                        # 1-19 have their own names, including the teens
+                        t = [self.int_to_ones[remainder]]
+                        o = []
+                    else:
+                        t = [self.int_to_tens[tens * 10]] if tens > 0 else []
+                        o = [self.int_to_ones[ones]] if ones > 0 else []
+                    return " ".join(h + t + o)
+                elif num == 1000:
+                    return "thousand"
+                else:
+                    return w  # case not handled
+            else:
+                # suffixed numbers, e.g. "12th", "20s"
+                w = self.str_to_ones_suffixed.get(w, w)
+                w = self.str_to_tens_suffixed.get(w, w)
+                return w
+
+        return " ".join(number_to_words(w) for w in s.split())
+
+
 class EnglishSpellingNormalizer:
     """
     Applies British-American spelling mappings as listed in [1].
@@ -464,7 +540,7 @@ def __call__(self, s: str):
 
 
 class EnglishTextNormalizer:
-    def __init__(self, standardize_numbers=False):
+    def __init__(self, standardize_numbers=False, standardize_numbers_rev=True, remove_fillers=True):
         self.replacers = {
             # common non verbal sounds are mapped to the similar ones
             r"\b(hm+)\b|\b(mhm)\b|\b(mm+)\b|\b(m+h)\b|\b(hm+)\b|\b(um+)\b|\b(uhm+)\b": (  # noqa e501
@@ -492,6 +568,9 @@ def __init__(self, standardize_numbers=False):
             r"\bcoulda\b": "could have",
             r"\bshoulda\b": "should have",
             r"\bma'am\b": "madam",
+            r"\bokay\b": "ok",
+            r"\bsetup\b": "set up",
+            r"\beveryday\b": "every day",
             # contractions in titles/prefixes
             r"\bmr\b": "mister ",
             r"\bmrs\b": "missus ",
@@ -532,11 +611,23 @@ def __init__(self, standardize_numbers=False):
         }
         if standardize_numbers:
             self.standardize_numbers = EnglishNumberNormalizer()
+            assert not standardize_numbers_rev, "standardize_numbers and standardize_numbers_rev are mutually exclusive"
         else:
             self.standardize_numbers = None
+
+        if standardize_numbers_rev:
+            self.standardize_numbers_rev = EnglishReverseNumberNormalizer()
+        else:
+            self.standardize_numbers_rev = None
+
         self.standardize_spellings = EnglishSpellingNormalizer()
         self.pre_standardize_spellings = EnglishSpellingNormalizer("pre_english.json")
 
+        if remove_fillers:
+            self.fillers = ["hmm", "uh", "ah", "eh"]  # assumes self.replacers have already been applied
+        else:
+            self.fillers = None
+
     def __call__(self, s: str):
         s = s.lower()
@@ -561,13 +652,25 @@ def __call__(self, s: str):
         if self.standardize_numbers is not None:
             s = self.standardize_numbers(s)
 
+        if self.standardize_numbers_rev is not None:
+            s = self.standardize_numbers_rev(s)
+
         s = self.standardize_spellings(s)
 
         # now remove prefix/suffix symbols
         # that are not preceded/followed by numbers
         s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
         s = re.sub(r"([^0-9])%", r"\1 ", s)
 
+        # remove filler words
+        # motivation: these words are very common, yet carry little information in the majority of cases.
+        # some ASR systems may omit them by convention and would otherwise be penalized unfairly.
+        if self.fillers:
+            s = re.sub(r"\b(" + "|".join(self.fillers) + r")\b", "", s)
+
         s = re.sub(r"\s+", " ", s)  # replace any run of successive whitespace with a single space
+        s = re.sub(r"^\s+|\s+$", "", s)  # remove leading and trailing whitespace
 
         return s
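For illustration, a minimal usage sketch of the new normalization entry points introduced by this patch, assuming the repository root is on PYTHONPATH and openai-whisper is installed (required by the new whisper.normalizers import); the example sentence is made up and the expected outputs are approximate:

# Minimal sketch, not part of the patch; outputs are approximate.
from utils.text_norm_whisper_like import get_txt_norm
from utils.text_norm_whisper_like.english import EnglishReverseNumberNormalizer

# "chime8" now applies reverse number normalization and filler removal by default.
chime8_norm = get_txt_norm("chime8")
print(chime8_norm("Okay, I paid $20 on the 12th."))
# -> roughly "ok i paid twenty dollars on the twelfth"

# "whisper" returns the unmodified upstream normalizer for comparison;
# it maps spelled-out numbers toward numerals instead.
whisper_norm = get_txt_norm("whisper")
print(whisper_norm("Okay, I paid twenty dollars."))
# -> roughly "okay i paid $20"

# The reverse number normalizer can also be used standalone.
rev = EnglishReverseNumberNormalizer()
print(rev("365"))     # "three hundred sixty five"
print(rev("70 000"))  # "seventy thousand"
print(rev("50%"))     # "fifty percent"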