This repository has been archived by the owner on Aug 28, 2024. It is now read-only.

script and README update
jeffxtang committed Jul 6, 2022
1 parent c4d5463 commit 4ca2314
Showing 3 changed files with 4,199 additions and 9 deletions.
16 changes: 7 additions & 9 deletions StreamingASR/README.md
@@ -22,30 +22,28 @@ git clone https://github.com/pytorch/android-demo-app
cd android-demo-app/StreamingASR
```

-If you don't have PyTorch 1.12 and torchaudio 0.12 installed or want to have a quick try of the demo app, you can download the optimized scripted model file [streaming_asrv2.ptl](TOREPLACE), then drag and drop it to the `StreamingASR/app/src/main/assets` folder inside `android-demo-app/StreamingASR`, and continue to Step 3.

+If you don't have PyTorch 1.12 and torchaudio 0.12 installed or want to have a quick try of the demo app, you can download the optimized scripted model file [streaming_asrv2.ptl](https://drive.google.com/file/d/1XRCAFpMqOSz5e7VP0mhiACMGCCcYfpk-/view?usp=sharing), then drag and drop it to the `StreamingASR/app/src/main/assets` folder inside `android-demo-app/StreamingASR`, and continue to Step 3.

### 2. Test and Prepare the Model

-To install PyTorch 1.12, torchaudio 0.12, and other required numpy and pyaudio library, do something like this:
+To install PyTorch 1.12, torchaudio 0.12, and other required packages (numpy, pyaudio, and fairseq), do something like this:

```
conda create -n pt1.12 python=3.8.5
conda activate pt1.12
-pip install torch torchaudio numpy pyaudio
+pip install torch torchaudio numpy pyaudio fairseq
```
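
Before moving on, it may help to confirm that the packages from the install step are visible to the active environment. The following is only a sketch: `REQUIRED` and `installed_versions` are hypothetical names that are not part of this repo, and the lookup assumes each package's distribution name matches the name passed to `pip install`.

```python
from importlib import metadata
from typing import Dict, List, Optional

# Package names taken from the `pip install` command above.
# Assumption: the installed distribution names match these names.
REQUIRED: List[str] = ["torch", "torchaudio", "numpy", "pyaudio", "fairseq"]


def installed_versions(packages: List[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'not installed'}")
```

If `torch` or `torchaudio` reports a version other than 1.12 or 0.12, the scripted model may not behave as described here.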

-Now download the streaming ASR model file
-[scripted_wrapper_tuple.pt](TOREPLACE) to the `android-demo-app/StreamingASR` directory.
+First, create the model file `scripted_wrapper_tuple.pt` by running `python generate_ts.py`.

-To test the model, run `python run_sasr.py`. After you see:
+Then, to test the model, run `python run_sasr.py`. After you see:
```
Initializing model...
Initialization complete.
```
say something like "good afternoon happy new year", and you'll likely see the streaming recognition results `good afternoon happy new year` while you speak. Hit Ctrl-C to end.

-To optimize and convert the model to the format that can run on Android, run the following commands:
+Finally, to optimize and convert the model to the format that can run on Android, run the following commands:
```
mkdir -p StreamingASR/app/src/main/assets
python save_model_for_mobile.py
@@ -64,4 +62,4 @@ Start Android Studio, open the project located in `android-demo-app/StreamingASR

The first version of this demo uses a [C++ port](https://github.com/ewan-xu/LibrosaCpp/) of [Librosa](https://librosa.org), a popular audio processing library in Python, to perform the MelSpectrogram transform, because torchaudio before version 0.11 doesn't support fft on Android (see [here](https://github.com/pytorch/audio/issues/408)). Using the Librosa C++ port and [JNI](https://developer.android.com/training/articles/perf-jni) (Java Native Interface) on Android makes the MelSpectrogram possible on Android. Furthermore, the Librosa C++ port requires [Eigen](https://eigen.tuxfamily.org/), a C++ template library for linear algebra, so both the port and the Eigen library are included in the first version of the demo app and built as JNI.

-See [here](https://github.com/jeffxtang/android-demo-app/tree/librosa_jni/StreamingASR) for the first version of the demo if interested in an example of using native C++ to expand operations not yet supported in PyTorch or one of its domain libraries.
+See [here](https://github.com/jeffxtang/android-demo-app/tree/librosa_jni/StreamingASR) for the first version of the demo if interested in an example of using native C++ to expand operations not yet supported in PyTorch or one of its domain libraries.
100 changes: 100 additions & 0 deletions StreamingASR/generate_ts.py
@@ -0,0 +1,100 @@
from typing import Dict, List, Optional, Tuple
import json
import math

from fairseq.data import Dictionary
import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.models import Hypothesis


def get_hypo_tokens(hypo: Hypothesis) -> List[int]:
    return hypo[0]


def get_hypo_score(hypo: Hypothesis) -> float:
    return hypo[3]


def to_string(input: List[int], tgt_dict: List[str], bos_idx: int = 0, eos_idx: int = 2, separator: str = "",) -> str:
    # torchscript dislikes sets
    extra_symbols_to_ignore: Dict[int, int] = {}
    extra_symbols_to_ignore[eos_idx] = 1
    extra_symbols_to_ignore[bos_idx] = 1

    # it also dislikes comprehensions with conditionals
    filtered_idx: List[int] = []
    for idx in input:
        if idx not in extra_symbols_to_ignore:
            filtered_idx.append(idx)

    return separator.join([tgt_dict[idx] for idx in filtered_idx]).replace("\u2581", " ")


def post_process_hypos(
    hypos: List[Hypothesis], tgt_dict: List[str],
) -> List[Tuple[str, List[float], List[int]]]:
    post_process_remove_list = [
        3,  # unk
        2,  # eos
        1,  # pad
    ]
    hypos_str: List[str] = []
    for h in hypos:
        filtered_tokens: List[int] = []
        for token_index in get_hypo_tokens(h)[1:]:
            if token_index not in post_process_remove_list:
                filtered_tokens.append(token_index)
        string = to_string(filtered_tokens, tgt_dict)
        hypos_str.append(string)

    hypos_ids = [get_hypo_tokens(h)[1:] for h in hypos]
    hypos_score = [[math.exp(get_hypo_score(h))] for h in hypos]

    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))

    return nbest_batch


def _piecewise_linear_log(x):
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


class ModelWrapper(torch.nn.Module):
    def __init__(self, tgt_dict: List[str]):
        super().__init__()
        self.transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)

        self.decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()

        self.tgt_dict = tgt_dict

        with open("global_stats.json") as f:
            blob = json.loads(f.read())

        self.mean = torch.tensor(blob["mean"])
        self.invstddev = torch.tensor(blob["invstddev"])

        self.decibel = 2 * 20 * math.log10(32767)
        self.gain = pow(10, 0.05 * self.decibel)

    def forward(
        self, input: torch.Tensor, prev_hypo: Optional[Hypothesis], prev_state: Optional[List[List[torch.Tensor]]]
    ) -> Tuple[str, Hypothesis, Optional[List[List[torch.Tensor]]]]:
        spectrogram = self.transform(input).transpose(1, 0)
        features = _piecewise_linear_log(spectrogram * self.gain).unsqueeze(0)[:, :-1]
        features = (features - self.mean) * self.invstddev
        length = torch.tensor([features.shape[1]])

        hypotheses, state = self.decoder.infer(features, length, 10, state=prev_state, hypothesis=prev_hypo)
        transcript = post_process_hypos(hypotheses[:1], self.tgt_dict)[0][0]
        return transcript, hypotheses[0], state


tgt_dict = Dictionary.load("spm_bpe_4096_fairseq.dict")
wrapper = ModelWrapper(tgt_dict.symbols)
wrapper = torch.jit.script(wrapper)
wrapper.save("scripted_wrapper_tuple.pt")
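
The feature scaling and token decoding in `generate_ts.py` can be followed without torch installed. The sketch below is a plain-Python illustration, not code from this commit: `piecewise_linear_log_scalar`, `to_string_plain`, and `toy_dict` are hypothetical names, and the toy vocabulary is invented for the example (the real one comes from `spm_bpe_4096_fairseq.dict`).

```python
import math
from typing import List

# Gain as computed in ModelWrapper.__init__.
# 0.05 * (2 * 20 * log10(32767)) == 2 * log10(32767), so gain is about 32767 ** 2.
decibel = 2 * 20 * math.log10(32767)
gain = pow(10, 0.05 * decibel)


def piecewise_linear_log_scalar(x: float) -> float:
    """Scalar mirror of the two in-place masked assignments in _piecewise_linear_log.

    The second mask is evaluated after the first assignment, so a value whose
    log is <= e is also divided by e. This follows the code as written.
    """
    if x > math.e:
        x = math.log(x)
    if x <= math.e:
        x = x / math.e
    return x


def to_string_plain(tokens: List[int], tgt_dict: List[str],
                    bos_idx: int = 0, eos_idx: int = 2, separator: str = "") -> str:
    """Plain-Python mirror of to_string; a set is fine here because this is not scripted."""
    ignore = {bos_idx, eos_idx}
    kept = [tgt_dict[idx] for idx in tokens if idx not in ignore]
    # "\u2581" is the SentencePiece word-boundary marker.
    return separator.join(kept).replace("\u2581", " ")


# Hypothetical toy vocabulary; indices 0-3 follow the bos/pad/eos/unk layout
# implied by the comments in post_process_hypos.
toy_dict = ["<s>", "<pad>", "</s>", "<unk>", "\u2581good", "\u2581after", "noon"]

if __name__ == "__main__":
    print(round(gain))                               # close to 32767 ** 2
    print(piecewise_linear_log_scalar(1.0))          # 1 / e
    print(repr(to_string_plain([4, 5, 6, 2], toy_dict)))  # prints ' good afternoon'
```

The leading space in the decoded string comes from the word-boundary marker on the first token. `to_string` in the commit leaves it in place.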