From d92cd5ae037ae85ab9730499d99e5c1bd475eed2 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Tue, 6 Feb 2024 21:22:21 +0800 Subject: [PATCH] Funasr1.0 (#1362) * funasr1.0.5 * funasr1.0.5 audio samples input * batch_type token * batch_type token * huggingface model zoo * dataloader * dataloader * fbank input * vad is_final=True bugfix --- funasr/auto/auto_model.py | 3 +- funasr/datasets/audio_datasets/index_ds.py | 54 +++++- funasr/datasets/audio_datasets/samplers.py | 193 +++++++++++++++++++++ funasr/models/fsmn_vad_streaming/model.py | 3 +- funasr/models/paraformer/cif_predictor.py | 2 +- funasr/models/paraformer/model.py | 2 + funasr/train_utils/trainer.py | 18 ++ 7 files changed, 270 insertions(+), 5 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 8e00703ca..13451570f 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -171,7 +171,7 @@ def build_model(self, **kwargs): # build model model_class = tables.model_classes.get(kwargs["model"]) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) - model.eval() + model.to(device) # init_param @@ -206,6 +206,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** kwargs = self.kwargs if kwargs is None else kwargs kwargs.update(cfg) model = self.model if model is None else model + model.eval() batch_size = kwargs.get("batch_size", 1) # if kwargs.get("device", "cpu") == "cpu": diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py index c94d20961..008b08ff1 100644 --- a/funasr/datasets/audio_datasets/index_ds.py +++ b/funasr/datasets/audio_datasets/index_ds.py @@ -6,8 +6,8 @@ from funasr.register import tables -@tables.register("index_ds_classes", "IndexDSJsonl") -class IndexDSJsonl(torch.utils.data.Dataset): +@tables.register("index_ds_classes", "IndexDSJsonlRankSplit") +class IndexDSJsonlRankSplit(torch.utils.data.Dataset): def __init__(self, path): super().__init__() @@ -66,3 +66,53 @@ def get_source_len(self, data_dict): def get_target_len(self, data_dict): return data_dict["target_len"] if "target_len" in data_dict else 0 + +@tables.register("index_ds_classes", "IndexDSJsonl") +@tables.register("index_ds_classes", "IndexDSJsonlRankFull") +class IndexDSJsonlRankFull(torch.utils.data.Dataset): + + def __init__(self, path): + super().__init__() + + contents = [] + with open(path, encoding='utf-8') as fin: + for line in fin: + data = json.loads(line.strip()) + if "text" in data: # for sft + self.contents.append(data['text']) + if "source" in data: # for speech lab pretrain + prompt = data.get("prompt", "") + source = data["source"] + target = data["target"] + source_len = data.get("source_len", 1) + target_len = data.get("target_len", 0) + + contents.append({"source": source, + "prompt": prompt, + "target": target, + "source_len": source_len, + "target_len": target_len, + } + ) + + self.contents = contents + + logging.info( + "total_num of samplers across ranks: {}".format(len(self.contents))) + + def __len__(self): + return len(self.contents) + + def __getitem__(self, index): + try: + data = self.contents[index] + except: + print(index) + return data + + def get_source_len(self, data_dict): + return data_dict.get("source_len", 1) + + def get_target_len(self, data_dict): + + return data_dict.get("target_len", 0) diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index 535df5d05..914e77692 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -1,5 +1,7 @@ import torch import numpy as np +import logging +import torch.distributed as dist from funasr.register import tables @@ -82,3 +84,194 @@ def __iter__(self): max_token = sample_len_cur_raw num_sample = 1 + +@tables.register("batch_sampler_classes", "BatchSampler") +@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler") +class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = True, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 1500) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + self.rank = rank + self.world_size = world_size + + def __len__(self): + return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + batch_size_total = self.batch_size * self.world_size + + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + # if iter == iter_num -1 and self.drop_last: + # continue + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + sample_len_cur = source_len + target_len + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for item in datalen_with_index_sort: + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + # if self.batch_type != 'example': + # max_token_padding *= max_token_cur + if max_token_padding <= batch_size_total: + batch.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size] + yield batch_rank + batch = [idx] + max_token = sample_len_cur_raw + num_sample = 1 + + +@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler") +class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = True, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 1500) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + self.rank = rank + self.world_size = world_size + + def __len__(self): + return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + batch_size_total = self.batch_size * self.world_size + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch_list_all_rank = [] + batch_list_cur = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + # if iter == iter_num - 1 and self.drop_last: + # continue + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + sample_len_cur = source_len + target_len + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for ii, item in enumerate(datalen_with_index_sort): + is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + + if self.batch_type != 'example': + max_token_padding *= max_token_cur + if len(batch_list_all_rank) < self.world_size: + + if max_token_padding <= self.batch_size: + batch_list_cur.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + batch_list_all_rank.append(batch_list_cur) + batch_list_cur = [] + else: + batch_rank = batch_list_all_rank[self.rank] + yield batch_rank + batch_list_all_rank = [idx] + max_token = sample_len_cur_raw + num_sample = 1 diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 5fc6aae2f..4fd18c85f 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -575,7 +575,8 @@ def inference(self, time1 = time.perf_counter() is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True) - cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input} + is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True) + cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input} audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index a5086c3c2..60ddc24e0 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -186,7 +186,7 @@ def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk alphas = alphas.squeeze(-1) mask = mask.squeeze(-1) if target_label_length is not None: - target_length = target_label_length + target_length = target_label_length.squeeze(-1) elif target_label is not None: target_length = (target_label != ignore_id).float().sum(-1) else: diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index 6e422ad75..77471466b 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -491,6 +491,8 @@ def inference(self, b, n, d = decoder_out.size() if isinstance(key[0], (list, tuple)): key = key[0] + if len(key) < b: + key = key*b for i in range(b): x = encoder_out[i, :encoder_out_lens[i], :] am_scores = decoder_out[i, :pre_token_length[i], :] diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index 414c0d7ca..d144019aa 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -204,7 +204,25 @@ def _train_epoch(self, epoch): my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext with my_context(): time2 = time.perf_counter() + print("before, GPU, memory: {:.1} MB, " + "{:.1} MB, " + "{:.1} MB, " + "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024, + torch.cuda.max_memory_allocated()/1024/1024/1024, + torch.cuda.memory_reserved()/1024/1024/1024, + torch.cuda.max_memory_reserved()/1024/1024/1024, + )) + retval = self.model(**batch) + torch.cuda.empty_cache() + print("after, GPU, memory: {:.1} MB, " + "{:.1} MB, " + "{:.1} MB, " + "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024, + torch.cuda.max_memory_allocated()/1024/1024/1024, + torch.cuda.memory_reserved()/1024/1024/1024, + torch.cuda.max_memory_reserved()/1024/1024/1024, + )) time3 = time.perf_counter() speed_stats["forward_time"] = f"{time3 - time2:0.3f}" loss, stats, weight = retval