From e33df77362d4fffa1d5d80caca9af2fc9f2f64d5 Mon Sep 17 00:00:00 2001 From: Lachlan Donald Date: Sun, 2 Jun 2024 09:22:03 +1000 Subject: [PATCH] Stream frames from resampler to non-real-time-vad (take 2) (#108) * Add Resampler#stream for streaming resampling in NonRealTimeVad * Allow passing in a wav file to node example * Apply prettier --- examples/node/index.js | 4 +- packages/_common/src/non-real-time-vad.ts | 27 ++--- packages/_common/src/resampler.ts | 75 +++++++++----- packages/node/src/index.ts | 3 +- packages/node/test/resampler.spec.js | 115 ++++++++++++++++++++++ 5 files changed, 186 insertions(+), 38 deletions(-) create mode 100644 packages/node/test/resampler.spec.js diff --git a/examples/node/index.js b/examples/node/index.js index 38eeba5..16cec91 100644 --- a/examples/node/index.js +++ b/examples/node/index.js @@ -2,7 +2,9 @@ const vad = require("@ricky0123/vad-node") const wav = require("wav-decoder") const fs = require("fs") -const audioSamplePath = `${__dirname}/test.wav` +const audioSamplePath = process.argv[2] || `${__dirname}/test.wav` + +console.log(`Processing ${audioSamplePath}`) function loadAudio(audioPath) { let buffer = fs.readFileSync(audioPath) diff --git a/packages/_common/src/non-real-time-vad.ts b/packages/_common/src/non-real-time-vad.ts index 6c3caf7..8541ab3 100644 --- a/packages/_common/src/non-real-time-vad.ts +++ b/packages/_common/src/non-real-time-vad.ts @@ -15,11 +15,13 @@ interface NonRealTimeVADSpeechData { end: number } -export interface NonRealTimeVADOptions extends FrameProcessorOptions, OrtOptions {} +export interface NonRealTimeVADOptions + extends FrameProcessorOptions, + OrtOptions {} export const defaultNonRealTimeVADOptions: NonRealTimeVADOptions = { ...defaultFrameProcessorOptions, - ortConfig: undefined + ortConfig: undefined, } export class PlatformAgnosticNonRealTimeVAD { @@ -77,33 +79,34 @@ export class PlatformAgnosticNonRealTimeVAD { targetFrameSize: this.options.frameSamples, } const resampler = new Resampler(resamplerOptions) - const frames = resampler.process(inputAudio) - let start: number, end: number - for (const i of [...Array(frames.length)].keys()) { - const f = frames[i] - const { msg, audio } = await this.frameProcessor.process(f) + let start = 0 + let end = 0 + let frameIndex = 0 + + for await (const frame of resampler.stream(inputAudio)) { + const { msg, audio } = await this.frameProcessor.process(frame) switch (msg) { case Message.SpeechStart: - start = (i * this.options.frameSamples) / 16 + start = (frameIndex * this.options.frameSamples) / 16 break case Message.SpeechEnd: - end = ((i + 1) * this.options.frameSamples) / 16 - // @ts-ignore + end = ((frameIndex + 1) * this.options.frameSamples) / 16 yield { audio, start, end } break default: break } + frameIndex++ } + const { msg, audio } = this.frameProcessor.endSegment() if (msg == Message.SpeechEnd) { yield { audio, - // @ts-ignore start, - end: (frames.length * this.options.frameSamples) / 16, + end: (frameIndex * this.options.frameSamples) / 16, } } } diff --git a/packages/_common/src/resampler.ts b/packages/_common/src/resampler.ts index 4972c6e..71259f6 100644 --- a/packages/_common/src/resampler.ts +++ b/packages/_common/src/resampler.ts @@ -20,40 +20,67 @@ export class Resampler { process = (audioFrame: Float32Array): Float32Array[] => { const outputFrames: Array = [] + this.fillInputBuffer(audioFrame) + while (this.hasEnoughDataForFrame()) { + const outputFrame = this.generateOutputFrame() + outputFrames.push(outputFrame) + } + + return outputFrames + } + + stream = async function* (audioFrame: Float32Array) { + this.fillInputBuffer(audioFrame) + + while (this.hasEnoughDataForFrame()) { + const outputFrame = this.generateOutputFrame() + yield outputFrame + } + } + + private fillInputBuffer(audioFrame: Float32Array) { for (const sample of audioFrame) { this.inputBuffer.push(sample) } + } - while ( + private hasEnoughDataForFrame(): boolean { + return ( (this.inputBuffer.length * this.options.targetSampleRate) / - this.options.nativeSampleRate > + this.options.nativeSampleRate >= this.options.targetFrameSize - ) { - const outputFrame = new Float32Array(this.options.targetFrameSize) - let outputIndex = 0 - let inputIndex = 0 - while (outputIndex < this.options.targetFrameSize) { - let sum = 0 - let num = 0 - while ( - inputIndex < - Math.min( - this.inputBuffer.length, - ((outputIndex + 1) * this.options.nativeSampleRate) / - this.options.targetSampleRate - ) - ) { - sum += this.inputBuffer[inputIndex] as number + ) + } + + private generateOutputFrame(): Float32Array { + const outputFrame = new Float32Array(this.options.targetFrameSize) + let outputIndex = 0 + let inputIndex = 0 + + while (outputIndex < this.options.targetFrameSize) { + let sum = 0 + let num = 0 + while ( + inputIndex < + Math.min( + this.inputBuffer.length, + ((outputIndex + 1) * this.options.nativeSampleRate) / + this.options.targetSampleRate + ) + ) { + const value = this.inputBuffer[inputIndex] + if (value !== undefined) { + sum += value num++ - inputIndex++ } - outputFrame[outputIndex] = sum / num - outputIndex++ + inputIndex++ } - this.inputBuffer = this.inputBuffer.slice(inputIndex) - outputFrames.push(outputFrame) + outputFrame[outputIndex] = sum / num + outputIndex++ } - return outputFrames + + this.inputBuffer = this.inputBuffer.slice(inputIndex) + return outputFrame } } diff --git a/packages/node/src/index.ts b/packages/node/src/index.ts index 6375199..58dc370 100644 --- a/packages/node/src/index.ts +++ b/packages/node/src/index.ts @@ -6,6 +6,7 @@ import { FrameProcessorOptions, Message, NonRealTimeVADOptions, + Resampler, } from "./_common" import * as fs from "fs/promises" @@ -24,5 +25,5 @@ class NonRealTimeVAD extends PlatformAgnosticNonRealTimeVAD { } } -export { utils, FrameProcessor, Message, NonRealTimeVAD } +export { utils, Resampler, FrameProcessor, Message, NonRealTimeVAD } export type { FrameProcessorOptions, NonRealTimeVADOptions } diff --git a/packages/node/test/resampler.spec.js b/packages/node/test/resampler.spec.js new file mode 100644 index 0000000..80a3666 --- /dev/null +++ b/packages/node/test/resampler.spec.js @@ -0,0 +1,115 @@ +const vad = require("@ricky0123/vad-node") +const { assert } = require("chai") +const { audioSamplePath } = require("./utils") +const fs = require("fs") +const wav = require("wav-decoder") + +function loadAudio(audioPath) { + let buffer = fs.readFileSync(audioPath) + let result = wav.decode.sync(buffer) + let audioData = new Float32Array(result.channelData[0].length) + for (let i = 0; i < audioData.length; i++) { + audioData[i] = result.channelData[0][i] // Assuming mono channel for simplicity + } + return [audioData, result.sampleRate] +} + +describe("Resampler", function () { + const testCases = [ + { targetSampleRate: 8000, targetFrameSize: 160 }, + { targetSampleRate: 16000, targetFrameSize: 320 }, + { targetSampleRate: 22050, targetFrameSize: 441 }, + { targetSampleRate: 44100, targetFrameSize: 882 }, + ] + + describe("process", function () { + const testCases = [ + { targetSampleRate: 8000, targetFrameSize: 160 }, + { targetSampleRate: 16000, targetFrameSize: 320 }, + { targetSampleRate: 22050, targetFrameSize: 441 }, + { targetSampleRate: 44100, targetFrameSize: 882 }, + ] + + testCases.forEach(({ targetSampleRate, targetFrameSize }) => { + it(`should correctly resample audio to ${targetSampleRate} Hz with frame size ${targetFrameSize}`, async function () { + const [audioData, nativeSampleRate] = loadAudio(audioSamplePath) + + const resampler = new vad.Resampler({ + nativeSampleRate: nativeSampleRate, + targetSampleRate: targetSampleRate, + targetFrameSize: targetFrameSize, + }) + + const outputFrames = resampler.process(audioData) + + // Calculate expected number of frames, discarding partial frame at the end + const duration = audioData.length / nativeSampleRate + const expectedNumberOfFrames = Math.floor( + (duration * targetSampleRate) / targetFrameSize + ) + + assert.equal( + outputFrames.length, + expectedNumberOfFrames, + "Number of output frames does not match expected" + ) + + // Check if the frame size is correct + outputFrames.forEach((frame) => { + assert.equal( + frame.length, + targetFrameSize, + "Output frame size is incorrect" + ) + }) + }) + }) + }) + + describe("stream", function () { + const testCases = [ + { targetSampleRate: 8000, targetFrameSize: 160 }, + { targetSampleRate: 16000, targetFrameSize: 320 }, + { targetSampleRate: 22050, targetFrameSize: 441 }, + { targetSampleRate: 44100, targetFrameSize: 882 }, + ] + + testCases.forEach(({ targetSampleRate, targetFrameSize }) => { + it(`should stream resampled audio frames correctly at ${targetSampleRate} Hz with frame size ${targetFrameSize}`, async function () { + const [audioData, nativeSampleRate] = loadAudio(audioSamplePath) + + const resampler = new vad.Resampler({ + nativeSampleRate: nativeSampleRate, + targetSampleRate: targetSampleRate, + targetFrameSize: targetFrameSize, + }) + + const frameStream = resampler.stream(audioData) + let frameCount = 0 + let allFramesCorrectSize = true + + for await (const frame of frameStream) { + frameCount++ + if (frame.length !== targetFrameSize) { + allFramesCorrectSize = false + break + } + } + + const expectedNumberOfFrames = Math.floor( + ((audioData.length / nativeSampleRate) * targetSampleRate) / + targetFrameSize + ) + assert.equal( + frameCount, + expectedNumberOfFrames, + "Number of streamed frames does not match expected" + ) + assert.isTrue( + allFramesCorrectSize, + "Not all frames are of the correct size" + ) + }) + }) + }) +})