Skip to content

Commit

Permalink
Stream frames from resampler to non-real-time-vad (take 2) (ricky0123…
Browse files Browse the repository at this point in the history
…#108)

* Add Resampler#stream for streaming resampling in NonRealTimeVad

* Allow passing in a wav file to node example

* Apply prettier
  • Loading branch information
lox authored Jun 1, 2024
1 parent d09859e commit e33df77
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 38 deletions.
4 changes: 3 additions & 1 deletion examples/node/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ const vad = require("@ricky0123/vad-node")
const wav = require("wav-decoder")
const fs = require("fs")

const audioSamplePath = `${__dirname}/test.wav`
const audioSamplePath = process.argv[2] || `${__dirname}/test.wav`

console.log(`Processing ${audioSamplePath}`)

function loadAudio(audioPath) {
let buffer = fs.readFileSync(audioPath)
Expand Down
27 changes: 15 additions & 12 deletions packages/_common/src/non-real-time-vad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ interface NonRealTimeVADSpeechData {
end: number
}

export interface NonRealTimeVADOptions extends FrameProcessorOptions, OrtOptions {}
export interface NonRealTimeVADOptions
extends FrameProcessorOptions,
OrtOptions {}

export const defaultNonRealTimeVADOptions: NonRealTimeVADOptions = {
...defaultFrameProcessorOptions,
ortConfig: undefined
ortConfig: undefined,
}

export class PlatformAgnosticNonRealTimeVAD {
Expand Down Expand Up @@ -77,33 +79,34 @@ export class PlatformAgnosticNonRealTimeVAD {
targetFrameSize: this.options.frameSamples,
}
const resampler = new Resampler(resamplerOptions)
const frames = resampler.process(inputAudio)
let start: number, end: number
for (const i of [...Array(frames.length)].keys()) {
const f = frames[i]
const { msg, audio } = await this.frameProcessor.process(f)
let start = 0
let end = 0
let frameIndex = 0

for await (const frame of resampler.stream(inputAudio)) {
const { msg, audio } = await this.frameProcessor.process(frame)
switch (msg) {
case Message.SpeechStart:
start = (i * this.options.frameSamples) / 16
start = (frameIndex * this.options.frameSamples) / 16
break

case Message.SpeechEnd:
end = ((i + 1) * this.options.frameSamples) / 16
// @ts-ignore
end = ((frameIndex + 1) * this.options.frameSamples) / 16
yield { audio, start, end }
break

default:
break
}
frameIndex++
}

const { msg, audio } = this.frameProcessor.endSegment()
if (msg == Message.SpeechEnd) {
yield {
audio,
// @ts-ignore
start,
end: (frames.length * this.options.frameSamples) / 16,
end: (frameIndex * this.options.frameSamples) / 16,
}
}
}
Expand Down
75 changes: 51 additions & 24 deletions packages/_common/src/resampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,67 @@ export class Resampler {

process = (audioFrame: Float32Array): Float32Array[] => {
const outputFrames: Array<Float32Array> = []
this.fillInputBuffer(audioFrame)

while (this.hasEnoughDataForFrame()) {
const outputFrame = this.generateOutputFrame()
outputFrames.push(outputFrame)
}

return outputFrames
}

stream = async function* (audioFrame: Float32Array) {
this.fillInputBuffer(audioFrame)

while (this.hasEnoughDataForFrame()) {
const outputFrame = this.generateOutputFrame()
yield outputFrame
}
}

private fillInputBuffer(audioFrame: Float32Array) {
for (const sample of audioFrame) {
this.inputBuffer.push(sample)
}
}

while (
private hasEnoughDataForFrame(): boolean {
return (
(this.inputBuffer.length * this.options.targetSampleRate) /
this.options.nativeSampleRate >
this.options.nativeSampleRate >=
this.options.targetFrameSize
) {
const outputFrame = new Float32Array(this.options.targetFrameSize)
let outputIndex = 0
let inputIndex = 0
while (outputIndex < this.options.targetFrameSize) {
let sum = 0
let num = 0
while (
inputIndex <
Math.min(
this.inputBuffer.length,
((outputIndex + 1) * this.options.nativeSampleRate) /
this.options.targetSampleRate
)
) {
sum += this.inputBuffer[inputIndex] as number
)
}

private generateOutputFrame(): Float32Array {
const outputFrame = new Float32Array(this.options.targetFrameSize)
let outputIndex = 0
let inputIndex = 0

while (outputIndex < this.options.targetFrameSize) {
let sum = 0
let num = 0
while (
inputIndex <
Math.min(
this.inputBuffer.length,
((outputIndex + 1) * this.options.nativeSampleRate) /
this.options.targetSampleRate
)
) {
const value = this.inputBuffer[inputIndex]
if (value !== undefined) {
sum += value
num++
inputIndex++
}
outputFrame[outputIndex] = sum / num
outputIndex++
inputIndex++
}
this.inputBuffer = this.inputBuffer.slice(inputIndex)
outputFrames.push(outputFrame)
outputFrame[outputIndex] = sum / num
outputIndex++
}
return outputFrames

this.inputBuffer = this.inputBuffer.slice(inputIndex)
return outputFrame
}
}
3 changes: 2 additions & 1 deletion packages/node/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
FrameProcessorOptions,
Message,
NonRealTimeVADOptions,
Resampler,
} from "./_common"
import * as fs from "fs/promises"

Expand All @@ -24,5 +25,5 @@ class NonRealTimeVAD extends PlatformAgnosticNonRealTimeVAD {
}
}

export { utils, FrameProcessor, Message, NonRealTimeVAD }
export { utils, Resampler, FrameProcessor, Message, NonRealTimeVAD }
export type { FrameProcessorOptions, NonRealTimeVADOptions }
115 changes: 115 additions & 0 deletions packages/node/test/resampler.spec.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
const vad = require("@ricky0123/vad-node")
const { assert } = require("chai")
const { audioSamplePath } = require("./utils")
const fs = require("fs")
const wav = require("wav-decoder")

/**
 * Decodes a WAV file from disk and returns the samples of its first channel.
 *
 * @param audioPath Path handed to `fs.readFileSync` to load the WAV bytes.
 * @returns A two-element array: a freshly allocated Float32Array holding a
 *   copy of channel 0, followed by the sample rate reported by the decoder.
 */
function loadAudio(audioPath) {
  const decoded = wav.decode.sync(fs.readFileSync(audioPath))
  // Assuming mono channel for simplicity: only channel 0 is read.
  const firstChannel = decoded.channelData[0]
  const samples = new Float32Array(firstChannel.length)
  // Copy rather than alias, so the returned buffer is independent of the decoder's.
  samples.set(firstChannel)
  return [samples, decoded.sampleRate]
}

// Tests for vad.Resampler: both the batch API (`process`) and the async
// generator API (`stream`) must emit the same number of fixed-size frames
// for the sample WAV file.
describe("Resampler", function () {
  // Single table of target formats shared by the `process` and `stream` suites.
  const testCases = [
    { targetSampleRate: 8000, targetFrameSize: 160 },
    { targetSampleRate: 16000, targetFrameSize: 320 },
    { targetSampleRate: 22050, targetFrameSize: 441 },
    { targetSampleRate: 44100, targetFrameSize: 882 },
  ]

  // Builds a resampler converting from the file's native rate to the target format.
  function createResampler(nativeSampleRate, targetSampleRate, targetFrameSize) {
    return new vad.Resampler({
      nativeSampleRate: nativeSampleRate,
      targetSampleRate: targetSampleRate,
      targetFrameSize: targetFrameSize,
    })
  }

  // Expected number of complete frames, discarding the partial frame at the end.
  function expectedFrameCount(
    inputLength,
    nativeSampleRate,
    targetSampleRate,
    targetFrameSize
  ) {
    const duration = inputLength / nativeSampleRate
    return Math.floor((duration * targetSampleRate) / targetFrameSize)
  }

  describe("process", function () {
    testCases.forEach(({ targetSampleRate, targetFrameSize }) => {
      it(`should correctly resample audio to ${targetSampleRate} Hz with frame size ${targetFrameSize}`, async function () {
        const [audioData, nativeSampleRate] = loadAudio(audioSamplePath)
        const resampler = createResampler(
          nativeSampleRate,
          targetSampleRate,
          targetFrameSize
        )

        const outputFrames = resampler.process(audioData)

        assert.equal(
          outputFrames.length,
          expectedFrameCount(
            audioData.length,
            nativeSampleRate,
            targetSampleRate,
            targetFrameSize
          ),
          "Number of output frames does not match expected"
        )

        // Check if the frame size is correct
        outputFrames.forEach((frame) => {
          assert.equal(
            frame.length,
            targetFrameSize,
            "Output frame size is incorrect"
          )
        })
      })
    })
  })

  describe("stream", function () {
    testCases.forEach(({ targetSampleRate, targetFrameSize }) => {
      it(`should stream resampled audio frames correctly at ${targetSampleRate} Hz with frame size ${targetFrameSize}`, async function () {
        const [audioData, nativeSampleRate] = loadAudio(audioSamplePath)
        const resampler = createResampler(
          nativeSampleRate,
          targetSampleRate,
          targetFrameSize
        )

        let frameCount = 0
        for await (const frame of resampler.stream(audioData)) {
          frameCount++
          // Assert per frame instead of breaking out of the loop: an early
          // break would leave frameCount short and surface a misleading
          // "number of frames" failure instead of the real size mismatch.
          assert.equal(
            frame.length,
            targetFrameSize,
            "Not all frames are of the correct size"
          )
        }

        assert.equal(
          frameCount,
          expectedFrameCount(
            audioData.length,
            nativeSampleRate,
            targetSampleRate,
            targetFrameSize
          ),
          "Number of streamed frames does not match expected"
        )
      })
    })
  })
})

0 comments on commit e33df77

Please sign in to comment.