From 4a991bd91180342df3bc926ef0d890f219951c14 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 23 Oct 2023 16:31:46 +0200 Subject: [PATCH] Add support for `text-to-speech` (w/ Speecht5) (#345) * Add vocoder to export * Add tokenizer.json export for speecht5 models * Update speecht5 supported models * Create `SpeechT5Tokenizer` * Add `ones` and `ones_like` tensor functions * Add support for speecht5 text-to-speech * Disambiguate `SpeechSeq2Seq` and `Seq2SeqLM` * Create `TextToAudioPipeline` * Add listed support for `text-to-audio` / `text-to-speech` * Use unquantized vocoder by default * Skip speecht5 unit tests for now Due to bug in transformers: https://github.com/huggingface/transformers/issues/26547 * Update example pipeline output * Create simple in-browser TTS demo * Add template README * Delete package-lock.json * Update required transformers.js version * Add link to Transformers.js * Double -> Single quotes * Add link to text-to-speech demo * Update sample speaker embeddings --- README.md | 3 +- docs/snippets/3_examples.snippet | 1 + docs/snippets/5_supported-tasks.snippet | 2 +- examples/text-to-speech-client/.eslintrc.cjs | 20 ++ examples/text-to-speech-client/.gitignore | 24 +++ examples/text-to-speech-client/README.md | 8 + examples/text-to-speech-client/index.html | 12 ++ examples/text-to-speech-client/package.json | 30 +++ .../text-to-speech-client/postcss.config.js | 6 + examples/text-to-speech-client/src/App.jsx | 162 ++++++++++++++ .../src/components/AudioPlayer.jsx | 26 +++ .../src/components/Progress.jsx | 12 ++ .../text-to-speech-client/src/constants.js | 11 + examples/text-to-speech-client/src/index.css | 21 ++ examples/text-to-speech-client/src/main.jsx | 10 + examples/text-to-speech-client/src/utils.js | 47 +++++ examples/text-to-speech-client/src/worker.js | 97 +++++++++ .../text-to-speech-client/tailwind.config.js | 12 ++ examples/text-to-speech-client/vite.config.js | 7 + scripts/convert.py | 11 + scripts/extra/speecht5.py | 116 ++++++++++ scripts/supported_models.py | 4 + src/models.js | 198 +++++++++++++++++- src/pipelines.js | 120 ++++++++++- src/processors.js | 16 ++ src/tokenizers.js | 3 + src/utils/tensor.js | 22 ++ tests/generate_tests.py | 3 + 28 files changed, 988 insertions(+), 16 deletions(-) create mode 100644 examples/text-to-speech-client/.eslintrc.cjs create mode 100644 examples/text-to-speech-client/.gitignore create mode 100644 examples/text-to-speech-client/README.md create mode 100644 examples/text-to-speech-client/index.html create mode 100644 examples/text-to-speech-client/package.json create mode 100644 examples/text-to-speech-client/postcss.config.js create mode 100644 examples/text-to-speech-client/src/App.jsx create mode 100644 examples/text-to-speech-client/src/components/AudioPlayer.jsx create mode 100644 examples/text-to-speech-client/src/components/Progress.jsx create mode 100644 examples/text-to-speech-client/src/constants.js create mode 100644 examples/text-to-speech-client/src/index.css create mode 100644 examples/text-to-speech-client/src/main.jsx create mode 100644 examples/text-to-speech-client/src/utils.js create mode 100644 examples/text-to-speech-client/src/worker.js create mode 100644 examples/text-to-speech-client/tailwind.config.js create mode 100644 examples/text-to-speech-client/vite.config.js create mode 100644 scripts/extra/speecht5.py diff --git a/README.md b/README.md index 03c17c3a0..10671a60f 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,7 @@ Want to jump straight in? Get started with one of our sample applications/templa | Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](./examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) | | Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](./examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) | | React | Multilingual translation website | [code](./examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) | +| Text to speech (client-side) | In-browser speech synthesis | [code](./examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) | | Browser extension | Text classification extension | [code](./examples/extension/) | | Electron | Text classification application | [code](./examples/electron/) | | Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](./examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) | @@ -222,7 +223,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te | [Audio Classification](https://huggingface.co/tasks/audio-classification) | `audio-classification` | Assigning a label or class to a given audio. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AudioClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=audio-classification&library=transformers.js) | | [Audio-to-Audio](https://huggingface.co/tasks/audio-to-audio) | n/a | Generating audio from an input audio source. | ❌ | | [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | `automatic-speech-recognition` | Transcribing a given audio into text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AutomaticSpeechRecognitionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&library=transformers.js) | -| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | n/a | Generating natural-sounding speech given text input. | ❌ | +| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | `text-to-speech` or `text-to-audio` | | Generating natural-sounding speech given text input. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.TextToAudioPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=text-to-audio&library=transformers.js) | #### Tabular diff --git a/docs/snippets/3_examples.snippet b/docs/snippets/3_examples.snippet index 009494342..6af3da7ec 100644 --- a/docs/snippets/3_examples.snippet +++ b/docs/snippets/3_examples.snippet @@ -9,6 +9,7 @@ Want to jump straight in? Get started with one of our sample applications/templa | Semantic Image Search (server-side) | Search for images with text (Supabase) | [code](./examples/semantic-image-search/), [demo](https://huggingface.co/spaces/Xenova/semantic-image-search) | | Vanilla JavaScript | In-browser object detection | [video](https://scrimba.com/scrim/cKm9bDAg), [code](./examples/vanilla-js/), [demo](https://huggingface.co/spaces/Scrimba/vanilla-js-object-detector) | | React | Multilingual translation website | [code](./examples/react-translator/), [demo](https://huggingface.co/spaces/Xenova/react-translator) | +| Text to speech (client-side) | In-browser speech synthesis | [code](./examples/text-to-speech-client/), [demo](https://huggingface.co/spaces/Xenova/text-to-speech-client) | | Browser extension | Text classification extension | [code](./examples/extension/) | | Electron | Text classification application | [code](./examples/electron/) | | Next.js (client-side) | Sentiment analysis (in-browser inference) | [code](./examples/next-client/), [demo](https://huggingface.co/spaces/Xenova/next-example-app) | diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet index e2da1636a..002634869 100644 --- a/docs/snippets/5_supported-tasks.snippet +++ b/docs/snippets/5_supported-tasks.snippet @@ -38,7 +38,7 @@ | [Audio Classification](https://huggingface.co/tasks/audio-classification) | `audio-classification` | Assigning a label or class to a given audio. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AudioClassificationPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=audio-classification&library=transformers.js) | | [Audio-to-Audio](https://huggingface.co/tasks/audio-to-audio) | n/a | Generating audio from an input audio source. | ❌ | | [Automatic Speech Recognition](https://huggingface.co/tasks/automatic-speech-recognition) | `automatic-speech-recognition` | Transcribing a given audio into text. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.AutomaticSpeechRecognitionPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&library=transformers.js) | -| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | n/a | Generating natural-sounding speech given text input. | ❌ | +| [Text-to-Speech](https://huggingface.co/tasks/text-to-speech) | `text-to-speech` or `text-to-audio` | | Generating natural-sounding speech given text input. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.TextToAudioPipeline)
[(models)](https://huggingface.co/models?pipeline_tag=text-to-audio&library=transformers.js) | #### Tabular diff --git a/examples/text-to-speech-client/.eslintrc.cjs b/examples/text-to-speech-client/.eslintrc.cjs new file mode 100644 index 000000000..4dcb43901 --- /dev/null +++ b/examples/text-to-speech-client/.eslintrc.cjs @@ -0,0 +1,20 @@ +module.exports = { + root: true, + env: { browser: true, es2020: true }, + extends: [ + 'eslint:recommended', + 'plugin:react/recommended', + 'plugin:react/jsx-runtime', + 'plugin:react-hooks/recommended', + ], + ignorePatterns: ['dist', '.eslintrc.cjs'], + parserOptions: { ecmaVersion: 'latest', sourceType: 'module' }, + settings: { react: { version: '18.2' } }, + plugins: ['react-refresh'], + rules: { + 'react-refresh/only-export-components': [ + 'warn', + { allowConstantExport: true }, + ], + }, +} diff --git a/examples/text-to-speech-client/.gitignore b/examples/text-to-speech-client/.gitignore new file mode 100644 index 000000000..a547bf36d --- /dev/null +++ b/examples/text-to-speech-client/.gitignore @@ -0,0 +1,24 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +dist +dist-ssr +*.local + +# Editor directories and files +.vscode/* +!.vscode/extensions.json +.idea +.DS_Store +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/examples/text-to-speech-client/README.md b/examples/text-to-speech-client/README.md new file mode 100644 index 000000000..f768e33fc --- /dev/null +++ b/examples/text-to-speech-client/README.md @@ -0,0 +1,8 @@ +# React + Vite + +This template provides a minimal setup to get React working in Vite with HMR and some ESLint rules. + +Currently, two official plugins are available: + +- [@vitejs/plugin-react](https://github.com/vitejs/vite-plugin-react/blob/main/packages/plugin-react/README.md) uses [Babel](https://babeljs.io/) for Fast Refresh +- [@vitejs/plugin-react-swc](https://github.com/vitejs/vite-plugin-react-swc) uses [SWC](https://swc.rs/) for Fast Refresh diff --git a/examples/text-to-speech-client/index.html b/examples/text-to-speech-client/index.html new file mode 100644 index 000000000..26748c68f --- /dev/null +++ b/examples/text-to-speech-client/index.html @@ -0,0 +1,12 @@ + + + + + + Transformers.js - Text-to-speech demo + + +
+ + + diff --git a/examples/text-to-speech-client/package.json b/examples/text-to-speech-client/package.json new file mode 100644 index 000000000..49949efa2 --- /dev/null +++ b/examples/text-to-speech-client/package.json @@ -0,0 +1,30 @@ +{ + "name": "text-to-speech-client", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "lint": "eslint . --ext js,jsx --report-unused-disable-directives --max-warnings 0", + "preview": "vite preview" + }, + "dependencies": { + "@xenova/transformers": "^2.7.0", + "react": "^18.2.0", + "react-dom": "^18.2.0" + }, + "devDependencies": { + "@types/react": "^18.2.15", + "@types/react-dom": "^18.2.7", + "@vitejs/plugin-react": "^4.0.3", + "autoprefixer": "^10.4.16", + "eslint": "^8.45.0", + "eslint-plugin-react": "^7.32.2", + "eslint-plugin-react-hooks": "^4.6.0", + "eslint-plugin-react-refresh": "^0.4.3", + "postcss": "^8.4.31", + "tailwindcss": "^3.3.3", + "vite": "^4.4.5" + } +} diff --git a/examples/text-to-speech-client/postcss.config.js b/examples/text-to-speech-client/postcss.config.js new file mode 100644 index 000000000..2e7af2b7f --- /dev/null +++ b/examples/text-to-speech-client/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/examples/text-to-speech-client/src/App.jsx b/examples/text-to-speech-client/src/App.jsx new file mode 100644 index 000000000..cfa53fb9a --- /dev/null +++ b/examples/text-to-speech-client/src/App.jsx @@ -0,0 +1,162 @@ +import React, { useState, useEffect, useRef } from 'react'; + +import AudioPlayer from './components/AudioPlayer'; +import Progress from './components/Progress'; +import { SPEAKERS, DEFAULT_SPEAKER } from './constants'; + +const App = () => { + + // Model loading + const [ready, setReady] = useState(null); + const [disabled, setDisabled] = useState(false); + const [progressItems, setProgressItems] = useState([]); + + // Inputs and outputs + const [text, setText] = useState('I love Hugging Face!'); + const [selectedSpeaker, setSelectedSpeaker] = useState(DEFAULT_SPEAKER); + const [output, setOutput] = useState(null); + + // Create a reference to the worker object. + const worker = useRef(null); + + // We use the `useEffect` hook to setup the worker as soon as the `App` component is mounted. + useEffect(() => { + if (!worker.current) { + // Create the worker if it does not yet exist. + worker.current = new Worker(new URL('./worker.js', import.meta.url), { + type: 'module' + }); + } + + // Create a callback function for messages from the worker thread. + const onMessageReceived = (e) => { + switch (e.data.status) { + case 'initiate': + // Model file start load: add a new progress item to the list. + setReady(false); + setProgressItems(prev => [...prev, e.data]); + break; + + case 'progress': + // Model file progress: update one of the progress items. + setProgressItems( + prev => prev.map(item => { + if (item.file === e.data.file) { + return { ...item, progress: e.data.progress } + } + return item; + }) + ); + break; + + case 'done': + // Model file loaded: remove the progress item from the list. + setProgressItems( + prev => prev.filter(item => item.file !== e.data.file) + ); + break; + + case 'ready': + // Pipeline ready: the worker is ready to accept messages. + setReady(true); + break; + + case 'complete': + // Generation complete: re-enable the "Translate" button + setDisabled(false); + + const blobUrl = URL.createObjectURL(e.data.output); + setOutput(blobUrl); + break; + } + }; + + // Attach the callback function as an event listener. + worker.current.addEventListener('message', onMessageReceived); + + // Define a cleanup function for when the component is unmounted. + return () => worker.current.removeEventListener('message', onMessageReceived); + }); + + + const handleGenerateSpeech = () => { + setDisabled(true); + worker.current.postMessage({ + text, + speaker_id: selectedSpeaker, + }); + }; + + const isLoading = ready === false; + return ( +
+
+ {isLoading && ( + + )} + {progressItems.map(data => ( +
+ +
+ ))} +
+
+

In-browser Text to Speech

+

Made with 🤗 Transformers.js

+
+ + +
+
+ + +
+
+ +
+ {output && } +
+
+ ); +}; + +export default App; diff --git a/examples/text-to-speech-client/src/components/AudioPlayer.jsx b/examples/text-to-speech-client/src/components/AudioPlayer.jsx new file mode 100644 index 000000000..a6d2daf1f --- /dev/null +++ b/examples/text-to-speech-client/src/components/AudioPlayer.jsx @@ -0,0 +1,26 @@ +import { useEffect, useRef } from "react"; + +export default function AudioPlayer({ audioUrl, mimeType }) { + const audioPlayer = useRef(null); + const audioSource = useRef(null); + + // Updates src when url changes + useEffect(() => { + if (audioPlayer.current && audioSource.current) { + audioSource.current.src = audioUrl; + audioPlayer.current.load(); + } + }, [audioUrl]); + + return ( +
+ +
+ ); +} \ No newline at end of file diff --git a/examples/text-to-speech-client/src/components/Progress.jsx b/examples/text-to-speech-client/src/components/Progress.jsx new file mode 100644 index 000000000..efaaf0a9a --- /dev/null +++ b/examples/text-to-speech-client/src/components/Progress.jsx @@ -0,0 +1,12 @@ + +export default function Progress({ text, percentage }) { + percentage ??= 0; + return ( +
+
+ {text} ({`${percentage.toFixed(2)}%`}) +
+
+ ); +} + diff --git a/examples/text-to-speech-client/src/constants.js b/examples/text-to-speech-client/src/constants.js new file mode 100644 index 000000000..ef6d848af --- /dev/null +++ b/examples/text-to-speech-client/src/constants.js @@ -0,0 +1,11 @@ +export const SPEAKERS = { + "US female 1": "cmu_us_slt_arctic-wav-arctic_a0001", + "US female 2": "cmu_us_clb_arctic-wav-arctic_a0001", + "US male 1": "cmu_us_bdl_arctic-wav-arctic_a0003", + "US male 2": "cmu_us_rms_arctic-wav-arctic_a0003", + "Canadian male": "cmu_us_jmk_arctic-wav-arctic_a0002", + "Scottish male": "cmu_us_awb_arctic-wav-arctic_b0002", + "Indian male": "cmu_us_ksp_arctic-wav-arctic_a0007", +} + +export const DEFAULT_SPEAKER = "cmu_us_slt_arctic-wav-arctic_a0001"; diff --git a/examples/text-to-speech-client/src/index.css b/examples/text-to-speech-client/src/index.css new file mode 100644 index 000000000..2ea01764b --- /dev/null +++ b/examples/text-to-speech-client/src/index.css @@ -0,0 +1,21 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +:root { + font-family: Inter, system-ui, Avenir, Helvetica, Arial, sans-serif; + line-height: 1.5; + font-weight: 400; + color: #213547; + background-color: #ffffff; + + font-synthesis: none; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + -webkit-text-size-adjust: 100%; +} + +audio::-webkit-media-controls-panel { + background-color: white; +} \ No newline at end of file diff --git a/examples/text-to-speech-client/src/main.jsx b/examples/text-to-speech-client/src/main.jsx new file mode 100644 index 000000000..54b39dd1d --- /dev/null +++ b/examples/text-to-speech-client/src/main.jsx @@ -0,0 +1,10 @@ +import React from 'react' +import ReactDOM from 'react-dom/client' +import App from './App.jsx' +import './index.css' + +ReactDOM.createRoot(document.getElementById('root')).render( + + + , +) diff --git a/examples/text-to-speech-client/src/utils.js b/examples/text-to-speech-client/src/utils.js new file mode 100644 index 000000000..b23f88bba --- /dev/null +++ b/examples/text-to-speech-client/src/utils.js @@ -0,0 +1,47 @@ +// Adapted from https://www.npmjs.com/package/audiobuffer-to-wav + +export function encodeWAV(samples) { + let offset = 44; + const buffer = new ArrayBuffer(offset + samples.length * 4); + const view = new DataView(buffer); + const sampleRate = 16000; + + /* RIFF identifier */ + writeString(view, 0, 'RIFF') + /* RIFF chunk length */ + view.setUint32(4, 36 + samples.length * 4, true) + /* RIFF type */ + writeString(view, 8, 'WAVE') + /* format chunk identifier */ + writeString(view, 12, 'fmt ') + /* format chunk length */ + view.setUint32(16, 16, true) + /* sample format (raw) */ + view.setUint16(20, 3, true) + /* channel count */ + view.setUint16(22, 1, true) + /* sample rate */ + view.setUint32(24, sampleRate, true) + /* byte rate (sample rate * block align) */ + view.setUint32(28, sampleRate * 4, true) + /* block align (channel count * bytes per sample) */ + view.setUint16(32, 4, true) + /* bits per sample */ + view.setUint16(34, 32, true) + /* data chunk identifier */ + writeString(view, 36, 'data') + /* data chunk length */ + view.setUint32(40, samples.length * 4, true) + + for (let i = 0; i < samples.length; ++i, offset += 4) { + view.setFloat32(offset, samples[i], true) + } + + return buffer +} + +function writeString(view, offset, string) { + for (let i = 0; i < string.length; ++i) { + view.setUint8(offset + i, string.charCodeAt(i)) + } +} diff --git a/examples/text-to-speech-client/src/worker.js b/examples/text-to-speech-client/src/worker.js new file mode 100644 index 000000000..76b8f76ef --- /dev/null +++ b/examples/text-to-speech-client/src/worker.js @@ -0,0 +1,97 @@ + +import { env, Tensor, AutoTokenizer, SpeechT5ForTextToSpeech, SpeechT5HifiGan } from '@xenova/transformers'; +import { encodeWAV } from './utils'; + +// Disable local model checks +env.allowLocalModels = false; + + +// Use the Singleton pattern to enable lazy construction of the pipeline. +class MyTextToSpeechPipeline { + + static BASE_URL = 'https://huggingface.co/datasets/Xenova/cmu-arctic-xvectors-extracted/resolve/main/'; + + static model_id = 'Xenova/speecht5_tts'; + static vocoder_id = 'Xenova/speecht5_hifigan'; + + static tokenizer_instance = null; + static model_instance = null; + static vocoder_instance = null; + + static async getInstance(progress_callback = null) { + if (this.tokenizer_instance === null) { + this.tokenizer = AutoTokenizer.from_pretrained(this.model_id, { progress_callback }); + } + + if (this.model_instance === null) { + this.model_instance = SpeechT5ForTextToSpeech.from_pretrained(this.model_id, { + quantized: false, + progress_callback, + }); + } + + if (this.vocoder_instance === null) { + this.vocoder_instance = SpeechT5HifiGan.from_pretrained(this.vocoder_id, { + quantized: false, + progress_callback, + }); + } + + return new Promise(async (resolve, reject) => { + const result = await Promise.all([ + this.tokenizer, + this.model_instance, + this.vocoder_instance, + ]); + self.postMessage({ + status: 'ready', + }); + resolve(result); + }); + } + + static async getSpeakerEmbeddings(speaker_id) { + // e.g., `cmu_us_awb_arctic-wav-arctic_a0001` + const speaker_embeddings_url = `${this.BASE_URL}${speaker_id}.bin`; + const speaker_embeddings = new Tensor( + 'float32', + new Float32Array(await (await fetch(speaker_embeddings_url)).arrayBuffer()), + [1, 512] + ) + return speaker_embeddings; + } +} + +// Mapping of cached speaker embeddings +const speaker_embeddings_cache = new Map(); + +// Listen for messages from the main thread +self.addEventListener('message', async (event) => { + // Load the pipeline + const [tokenizer, model, vocoder] = await MyTextToSpeechPipeline.getInstance(x => { + // We also add a progress callback so that we can track model loading. + self.postMessage(x); + }); + + // Tokenize the input + const { input_ids } = tokenizer(event.data.text); + + // Load the speaker embeddings + let speaker_embeddings = speaker_embeddings_cache.get(event.data.speaker_id); + if (speaker_embeddings === undefined) { + speaker_embeddings = await MyTextToSpeechPipeline.getSpeakerEmbeddings(event.data.speaker_id); + speaker_embeddings_cache.set(event.data.speaker_id, speaker_embeddings); + } + + // Generate the waveform + const { waveform } = await model.generate_speech(input_ids, speaker_embeddings, { vocoder }); + + // Encode the waveform as a WAV file + const wav = encodeWAV(waveform.data); + + // Send the output back to the main thread + self.postMessage({ + status: 'complete', + output: new Blob([wav], { type: 'audio/wav' }), + }); +}); diff --git a/examples/text-to-speech-client/tailwind.config.js b/examples/text-to-speech-client/tailwind.config.js new file mode 100644 index 000000000..d37737fc0 --- /dev/null +++ b/examples/text-to-speech-client/tailwind.config.js @@ -0,0 +1,12 @@ +/** @type {import('tailwindcss').Config} */ +export default { + content: [ + "./index.html", + "./src/**/*.{js,ts,jsx,tsx}", + ], + theme: { + extend: {}, + }, + plugins: [], +} + diff --git a/examples/text-to-speech-client/vite.config.js b/examples/text-to-speech-client/vite.config.js new file mode 100644 index 000000000..5a33944a9 --- /dev/null +++ b/examples/text-to-speech-client/vite.config.js @@ -0,0 +1,7 @@ +import { defineConfig } from 'vite' +import react from '@vitejs/plugin-react' + +// https://vitejs.dev/config/ +export default defineConfig({ + plugins: [react()], +}) diff --git a/scripts/convert.py b/scripts/convert.py index b86abb9b6..0a68cd5ca 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -307,6 +307,17 @@ def main(): with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp: json.dump(tokenizer_json, fp, indent=4) + elif config.model_type == 'speecht5': + # TODO allow user to specify vocoder path + export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"} + + if tokenizer is not None: + from .extra.speecht5 import generate_tokenizer_json + tokenizer_json = generate_tokenizer_json(tokenizer) + + with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp: + json.dump(tokenizer_json, fp, indent=4) + else: pass # TODO diff --git a/scripts/extra/speecht5.py b/scripts/extra/speecht5.py new file mode 100644 index 000000000..e01992c5c --- /dev/null +++ b/scripts/extra/speecht5.py @@ -0,0 +1,116 @@ +import json + + +def generate_tokenizer_json(tokenizer): + vocab = tokenizer.get_vocab() + + tokenizer_json = { + "version": "1.0", + "truncation": None, + "padding": None, + "added_tokens": [ + { + "id": vocab[token], + "content": token, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": True + } + for token in vocab + if token.startswith('<') and token.endswith('>') + ], + + "normalizer": { + "type": "Precompiled", + "precompiled_charsmap": None + }, + "pre_tokenizer": { + "type": "Sequence", + "pretokenizers": [ + { + "type": "WhitespaceSplit" + }, + { + "type": "Metaspace", + "replacement": "▁", + "add_prefix_space": True + }, + { + "type": "Split", + "pattern": { + "Regex": "" + }, + "behavior": "Isolated", + "invert": False + } + ] + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [ + 2 + ], + "tokens": [ + "" + ] + } + } + }, + "decoder": { + "type": "Metaspace", + "replacement": "▁", + "add_prefix_space": True + }, + 'model': { + # 'type': 'Char', + 'unk_id': 2, + "vocab": vocab + } + } + + return tokenizer_json diff --git a/scripts/supported_models.py b/scripts/supported_models.py index 56531ea8d..03711b806 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -375,6 +375,10 @@ # 'facebook/sam-vit-large', # 'facebook/sam-vit-huge', # ], + 'speecht5': [ + # Text-to-speech + 'microsoft/speecht5_tts', + ], 'squeezebert': [ # Feature extraction 'squeezebert/squeezebert-uncased', diff --git a/src/models.js b/src/models.js index 8e5085f62..7c3b55964 100644 --- a/src/models.js +++ b/src/models.js @@ -74,6 +74,7 @@ import { cat, dynamicTimeWarping, mean, + ones_like, stack, std_mean, Tensor, @@ -278,11 +279,7 @@ function prepareAttentionMask(self, tokens) { ) return new Tensor('int64', data, tokens.dims) } else { - return new Tensor( - 'int64', - new BigInt64Array(tokens.data.length).fill(1n), - tokens.dims - ) + return ones_like(tokens); } } @@ -928,7 +925,9 @@ export class PreTrainedModel extends Callable { const modelType = this.config.model_type; const possibleInfo = MODEL_WITH_LM_HEAD_MAPPING_NAMES.get(modelType) - ?? MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES.get(modelType) + ?? MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES.get(modelType) + ?? MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.get(modelType) + // ?? MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES.get(modelType) // TODO ?? MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.get(modelType); if (possibleInfo) { @@ -3563,6 +3562,151 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel { } } +////////////////////////////////////////////////// +// SpeechT5 models +/** + * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + */ +export class SpeechT5PreTrainedModel extends PreTrainedModel { }; + +/** + * The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets. + */ +export class SpeechT5Model extends SpeechT5PreTrainedModel { }; + +/** + * SpeechT5 Model with a speech encoder and a text decoder. + */ +export class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel { } + +/** + * SpeechT5 Model with a text encoder and a speech decoder. + */ +export class SpeechT5ForTextToSpeech extends SpeechT5PreTrainedModel { + + /** + * Creates a new instance of the `SpeechT5ForTextToSpeech` class. + * @param {Object} config The model configuration. + * @param {any} session session for the model. + * @param {any} decoder_merged_session session for the decoder. + * @param {GenerationConfig} generation_config The generation configuration. + */ + constructor(config, session, decoder_merged_session, generation_config) { + super(config, session); + this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; + + this.num_decoder_layers = this.config.decoder_layers; + this.num_decoder_heads = this.config.decoder_attention_heads; + this.decoder_dim_kv = this.config.hidden_size / this.num_decoder_heads; + + this.num_encoder_layers = this.config.encoder_layers; + this.num_encoder_heads = this.config.encoder_attention_heads; + this.encoder_dim_kv = this.config.hidden_size / this.num_encoder_heads; + } + + /** + * @typedef {Object} SpeechOutput + * @property {Tensor} [spectrogram] The predicted log-mel spectrogram of shape + * `(output_sequence_length, config.num_mel_bins)`. Returned when no `vocoder` is provided + * @property {Tensor} [waveform] The predicted waveform of shape `(num_frames,)`. Returned when a `vocoder` is provided. + * @property {Tensor} [cross_attentions] The outputs of the decoder's cross-attention layers of shape + * `(config.decoder_layers, config.decoder_attention_heads, output_sequence_length, input_sequence_length)`. returned when `output_cross_attentions` is `true`. + */ + + /** + * Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a speech waveform using a vocoder. + * @param {Tensor} input_values Indices of input sequence tokens in the vocabulary. + * @param {Tensor} speaker_embeddings Tensor containing the speaker embeddings. + * @param {Object} options Optional parameters for generating speech. + * @param {number} [options.threshold=0.5] The generated sequence ends when the predicted stop token probability exceeds this value. + * @param {number} [options.minlenratio=0.0] Used to calculate the minimum required length for the output sequence. + * @param {number} [options.maxlenratio=20.0] Used to calculate the maximum allowed length for the output sequence. + * @param {Object} [options.vocoder=null] The vocoder that converts the mel spectrogram into a speech waveform. If `null`, the output is the mel spectrogram. + * @param {boolean} [options.output_cross_attentions=false] Whether or not to return the attentions tensors of the decoder's cross-attention layers. + * @returns {Promise} A promise which resolves to an object containing the spectrogram, waveform, and cross-attention tensors. + */ + async generate_speech(input_values, speaker_embeddings, { + threshold = 0.5, + minlenratio = 0.0, + maxlenratio = 20.0, + vocoder = null, + // output_cross_attentions = false, // TODO add + } = {}) { + + const model_inputs = { + input_ids: input_values + } + + const { encoder_outputs, encoder_attention_mask } = await encoderForward(this, model_inputs); + + const r = encoder_outputs.dims[1] / this.config.reduction_factor; + const maxlen = Math.floor(r * maxlenratio); + const minlen = Math.floor(r * minlenratio); + + const num_mel_bins = this.config.num_mel_bins; + + let spectrogramParts = []; + let past_key_values = null; + let decoder_outputs = null; + let idx = 0; + + while (true) { + ++idx; + + const use_cache_branch = boolTensor(!!decoder_outputs); + let output_sequence; + if (decoder_outputs) { + output_sequence = decoder_outputs.output_sequence_out; + } else { + output_sequence = new Tensor( + 'float32', + new Float32Array(num_mel_bins), + [1, 1, num_mel_bins], + ) + } + let decoderFeeds = { + use_cache_branch, + output_sequence, + encoder_attention_mask: encoder_attention_mask, + speaker_embeddings: speaker_embeddings, + encoder_hidden_states: encoder_outputs, + }; + + this.addPastKeyValues(decoderFeeds, past_key_values); + decoder_outputs = await sessionRun(this.decoder_merged_session, decoderFeeds); + past_key_values = this.getPastKeyValues(decoder_outputs, past_key_values); + + const { prob, spectrum } = decoder_outputs; + spectrogramParts.push(spectrum); + + if (idx >= minlen && ( + // Finished when stop token or maximum length is reached. + Array.from(prob.data).filter(p => p >= threshold).length > 0 || idx >= maxlen + )) { + break; + } + } + + const spectrogram = cat(spectrogramParts); + const { waveform } = await sessionRun(vocoder.session, { spectrogram }); + + return { + spectrogram, + waveform, + // cross_attentions: null, // TODO add + } + } +} + +/** + * HiFi-GAN vocoder. + */ +export class SpeechT5HifiGan extends PreTrainedModel { + main_input_name = 'spectrogram'; +} +////////////////////////////////////////////////// + ////////////////////////////////////////////////// // AutoModels, used to simplify construction of PreTrainedModels // (uses config to instantiate correct class) @@ -3659,6 +3803,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['donut-swin', ['DonutSwinModel', DonutSwinModel]], ['yolos', ['YolosModel', YolosModel]], + ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]], + ['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly ]); @@ -3689,6 +3835,15 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ ['opt', ['OPTModel', OPTModel]], ]); +const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([ + ['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]], + ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]], +]) + +const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([ + ['speecht5', ['SpeechT5ForTextToSpeech', SpeechT5ForTextToSpeech]], +]) + const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ ['bert', ['BertForSequenceClassification', BertForSequenceClassification]], ['camembert', ['CamembertForSequenceClassification', CamembertForSequenceClassification]], @@ -3718,13 +3873,12 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ ['xlm-roberta', ['XLMRobertaForTokenClassification', XLMRobertaForTokenClassification]], ]); -const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ +const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([ ['t5', ['T5ForConditionalGeneration', T5ForConditionalGeneration]], ['longt5', ['LongT5ForConditionalGeneration', LongT5ForConditionalGeneration]], ['mt5', ['MT5ForConditionalGeneration', MT5ForConditionalGeneration]], ['bart', ['BartForConditionalGeneration', BartForConditionalGeneration]], ['mbart', ['MBartForConditionalGeneration', MBartForConditionalGeneration]], - ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]], ['marian', ['MarianMTModel', MarianMTModel]], ['m2m_100', ['M2M100ForConditionalGeneration', M2M100ForConditionalGeneration]], ['blenderbot', ['BlenderbotForConditionalGeneration', BlenderbotForConditionalGeneration]], @@ -3822,7 +3976,8 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES.DecoderOnly], [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], - [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], + [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], + [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES.DecoderOnly], [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], @@ -3833,6 +3988,7 @@ const MODEL_CLASS_TYPE_MAPPING = [ [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly], + [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq], ]; for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { @@ -3897,7 +4053,29 @@ export class AutoModelForTokenClassification extends PretrainedMixin { * let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small'); */ export class AutoModelForSeq2SeqLM extends PretrainedMixin { - static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES]; + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]; +} + +/** + * Helper class which is used to instantiate pretrained sequence-to-sequence speech-to-text models with the `from_pretrained` function. + * The chosen model class is determined by the type specified in the model config. + * + * @example + * let model = await AutoModelForSpeechSeq2Seq.from_pretrained('openai/whisper-tiny.en'); + */ +export class AutoModelForSpeechSeq2Seq extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES]; +} + +/** + * Helper class which is used to instantiate pretrained sequence-to-sequence text-to-spectrogram models with the `from_pretrained` function. + * The chosen model class is determined by the type specified in the model config. + * + * @example + * let model = await AutoModelForTextToSpectrogram.from_pretrained('microsoft/speecht5_tts'); + */ +export class AutoModelForTextToSpectrogram extends PretrainedMixin { + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES]; } /** diff --git a/src/pipelines.js b/src/pipelines.js index 0f6ddccee..4873f35a9 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -25,6 +25,8 @@ import { AutoModelForQuestionAnswering, AutoModelForMaskedLM, AutoModelForSeq2SeqLM, + AutoModelForSpeechSeq2Seq, + AutoModelForTextToSpectrogram, AutoModelForCTC, AutoModelForCausalLM, AutoModelForVision2Seq, @@ -32,6 +34,7 @@ import { AutoModelForImageSegmentation, AutoModelForObjectDetection, AutoModelForDocumentQuestionAnswering, + // AutoModelForTextToWaveform, PreTrainedModel, } from './models.js'; import { @@ -57,6 +60,7 @@ import { read_audio } from './utils/audio.js'; import { + Tensor, mean_pooling, } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; @@ -1127,8 +1131,7 @@ export class AutomaticSpeechRecognitionPipeline extends Pipeline { } /** - * @typedef {import('./utils/tensor.js').Tensor} Tensor - * @typedef {{stride: number[], input_features: Tensor, is_last: boolean, tokens?: number[], token_timestamps?: number[]}} Chunk + * @typedef {{stride: number[], input_features: import('./utils/tensor.js').Tensor, is_last: boolean, tokens?: number[], token_timestamps?: number[]}} Chunk * * @callback ChunkCallback * @param {Chunk} chunk The chunk to process. @@ -1835,6 +1838,103 @@ export class DocumentQuestionAnsweringPipeline extends Pipeline { } } +/** + * Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. + * This pipeline generates an audio file from an input text and optional other conditional inputs. + * + * **Example:** Generate audio from text with `Xenova/speecht5_tts`. + * ```js + * let speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin'; + * let synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false }); + * let out = await synthesizer('Hello, my dog is cute', { speaker_embeddings }); + * // { + * // audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...], + * // sampling_rate: 16000 + * // } + * ``` + * + * You can then save the audio to a .wav file with the `wavefile` package: + * ```js + * import wavefile from 'wavefile'; + * import fs from 'fs'; + * + * let wav = new wavefile.WaveFile(); + * wav.fromScratch(1, out.sampling_rate, '32f', out.audio); + * fs.writeFileSync('out.wav', wav.toBuffer()); + * ``` + */ +export class TextToAudioPipeline extends Pipeline { + DEFAULT_VOCODER_ID = "Xenova/speecht5_hifigan" + + /** + * Create a new TextToAudioPipeline. + * @param {Object} options An object containing the following properties: + * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks. + * @param {PreTrainedModel} [options.model] The model to use. + * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use. + * @param {Processor} [options.processor] The processor to use. + * @param {PreTrainedModel} [options.vocoder] The vocoder to use. + */ + constructor(options) { + super(options); + + // TODO: Find a better way for `pipeline` to set the default vocoder + this.vocoder = options.vocoder ?? null; + } + + /** + * Generates speech/audio from the inputs. + * @param {string|string[]} text_inputs The text(s) to generate. + * @param {Object} options Parameters passed to the model generation/forward method. + * @param {PreTrainedModel} [options.vocoder=null] The vocoder to use (if the model uses one). If not provided, use the default HifiGan vocoder. + * @param {Tensor|Float32Array|string|URL} [options.speaker_embeddings=null] + * @returns {Promise} An object containing the generated audio and sampling rate. + */ + async _call(text_inputs, { + speaker_embeddings = null, + } = {}) { + // Load vocoder, if not provided + if (!this.vocoder) { + console.log('No vocoder specified, using default HifiGan vocoder.'); + this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { quantized: false }); + } + + // Load speaker embeddings as Float32Array from path/URL + if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) { + // Load from URL with fetch + speaker_embeddings = new Float32Array( + await (await fetch(speaker_embeddings)).arrayBuffer() + ); + } + + if (speaker_embeddings instanceof Float32Array) { + speaker_embeddings = new Tensor( + 'float32', + speaker_embeddings, + [1, speaker_embeddings.length] + ) + } else if (!(speaker_embeddings instanceof Tensor)) { + throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.") + } + + // Run tokenization + const { input_ids } = this.tokenizer(text_inputs, { + padding: true, + truncation: true + }); + + // NOTE: At this point, we are guaranteed that `speaker_embeddings` is a `Tensor` + // @ts-ignore + const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder }); + + const sampling_rate = this.processor.feature_extractor.config.sampling_rate; + return { + audio: waveform.data, + sampling_rate, + } + } +} + const SUPPORTED_TASKS = { "text-classification": { "tokenizer": AutoTokenizer, @@ -1950,7 +2050,7 @@ const SUPPORTED_TASKS = { "automatic-speech-recognition": { "tokenizer": AutoTokenizer, "pipeline": AutomaticSpeechRecognitionPipeline, - "model": [AutoModelForSeq2SeqLM, AutoModelForCTC], + "model": [AutoModelForSpeechSeq2Seq, AutoModelForCTC], "processor": AutoProcessor, "default": { // TODO: replace with original @@ -1959,7 +2059,18 @@ const SUPPORTED_TASKS = { }, "type": "multimodal", }, - + "text-to-audio": { + "tokenizer": AutoTokenizer, + "pipeline": TextToAudioPipeline, + "model": [ /* TODO: AutoModelForTextToWaveform, */ AutoModelForTextToSpectrogram], + "processor": AutoProcessor, + "default": { + // TODO: replace with original + // "model": "microsoft/speecht5_tts", + "model": "Xenova/speecht5_tts", + }, + "type": "text", + }, "image-to-text": { "tokenizer": AutoTokenizer, "pipeline": ImageToTextPipeline, @@ -2058,6 +2169,7 @@ const TASK_ALIASES = { "ner": "token-classification", "vqa": "visual-question-answering", "asr": "automatic-speech-recognition", + "text-to-speech": "text-to-audio", // Add for backwards compatibility "embeddings": "feature-extraction", diff --git a/src/processors.js b/src/processors.js index 8db7004f9..0770261c1 100644 --- a/src/processors.js +++ b/src/processors.js @@ -1309,6 +1309,8 @@ export class Wav2Vec2FeatureExtractor extends FeatureExtractor { } } +export class SpeechT5FeatureExtractor extends FeatureExtractor { } + /** * Represents a Processor that extracts features from an input. * @extends Callable @@ -1381,6 +1383,18 @@ export class Wav2Vec2ProcessorWithLM extends Processor { } } +export class SpeechT5Processor extends Processor { + /** + * Calls the feature_extractor function with the given input. + * @param {any} input The input to extract features from. + * @returns {Promise} A Promise that resolves with the extracted features. + */ + async _call(input) { + return await this.feature_extractor(input) + } +} + + ////////////////////////////////////////////////// /** * Helper class which is used to instantiate pretrained processors with the `from_pretrained` function. @@ -1426,12 +1440,14 @@ export class AutoProcessor { SamImageProcessor, Wav2Vec2FeatureExtractor, + SpeechT5FeatureExtractor, } static PROCESSOR_CLASS_MAPPING = { WhisperProcessor, Wav2Vec2ProcessorWithLM, SamProcessor, + SpeechT5Processor, } /** diff --git a/src/tokenizers.js b/src/tokenizers.js index 74d179a6d..1f7bd9fcd 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -3747,6 +3747,8 @@ export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { } export class BlenderbotTokenizer extends PreTrainedTokenizer { } export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } +export class SpeechT5Tokenizer extends PreTrainedTokenizer { } + /** * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function. * The chosen tokenizer class is determined by the type specified in the tokenizer config. @@ -3788,6 +3790,7 @@ export class AutoTokenizer { Wav2Vec2CTCTokenizer, BlenderbotTokenizer, BlenderbotSmallTokenizer, + SpeechT5Tokenizer, // Base case: PreTrainedTokenizer, diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 902fe9f87..f5c6dff83 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -988,3 +988,25 @@ function dimsToStride(dims) { } return stride; } + +/** + * Returns a tensor filled with the scalar value 1, with the shape defined by the variable argument size. + * @param {number[]} size A sequence of integers defining the shape of the output tensor. + */ +export function ones(size) { + const numElements = size.reduce((a, b) => a * b, 1); + return new Tensor( + 'int64', + new BigInt64Array(numElements).fill(1n), + size + ) +} + +/** + * Returns a tensor filled with the scalar value 1, with the same size as input. + * @param {Tensor} tensor The size of input will determine size of the output tensor. + * @returns The ones tensor. + */ +export function ones_like(tensor) { + return ones(tensor.dims); +} diff --git a/tests/generate_tests.py b/tests/generate_tests.py index adf35b399..2be31d749 100644 --- a/tests/generate_tests.py +++ b/tests/generate_tests.py @@ -28,6 +28,9 @@ # TODO: remove when https://github.com/huggingface/transformers/issues/26018 is fixed 'marian', + + # TODO: remove when https://github.com/huggingface/transformers/issues/26547 is fixed + 'speecht5', ] TOKENIZERS_TO_IGNORE = [