feature request: diarize with whisper #4
Hi there! I use Whisper along with pyannote-rs in my project, Vibe. If anyone is interested in creating an example for it, contributions are more than welcome! It shouldn't be too difficult. The dependencies for Whisper.cpp should match those in the project I mentioned (a custom setup).
```rust
use eyre::{eyre, Result};
use pyannote_rs::{EmbeddingExtractor, EmbeddingManager};
use std::time::Instant;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};

fn create_whisper_context(model_path: &str) -> Result<WhisperContext> {
    let mut ctx_params = WhisperContextParameters::default();
    ctx_params.use_gpu = true;
    WhisperContext::new_with_params(model_path, ctx_params)
        .map_err(|e| eyre!("Failed to create Whisper context: {:?}", e))
}

fn setup_whisper_params() -> FullParams<'static, 'static> {
    let mut params = FullParams::new(SamplingStrategy::default());
    params.set_print_progress(false);
    params.set_print_timestamps(false);
    params.set_print_special(false);
    params.set_suppress_blank(true);
    params.set_single_segment(true);
    params
}

fn transcribe_segment(
    state: &mut whisper_rs::WhisperState,
    params: &FullParams,
    samples: &[f32],
) -> Result<String> {
    state.full(params.clone(), samples)?;
    let num_segments = state.full_n_segments()?;
    let mut transcription = String::new();
    for i in 0..num_segments {
        transcription.push_str(&state.full_get_segment_text(i)?);
        transcription.push(' ');
    }
    Ok(transcription.trim().to_string())
}

fn main() -> Result<()> {
    let audio_path = "samples/output_16000.wav";
    let max_speakers = 2;
    let search_threshold = 0.5;
    let embedding_model_path = "wespeaker_en_voxceleb_CAM++.onnx";
    let segmentation_model_path = "segmentation-3.0.onnx";
    let whisper_model_path: &str = "ggml-base.bin";

    // Initialize Whisper
    let ctx = create_whisper_context(whisper_model_path)?;
    let mut state = ctx.create_state()?;
    let params = setup_whisper_params();

    // Set up pyannote-rs
    let (samples, sample_rate) = pyannote_rs::read_wav(audio_path)?;
    let mut embedding_extractor = EmbeddingExtractor::new(embedding_model_path)?;
    let mut embedding_manager = EmbeddingManager::new(max_speakers);
    let segments = pyannote_rs::segment(&samples, sample_rate, segmentation_model_path)?;

    // Convert samples to f32 for Whisper
    let mut whisper_samples = vec![0.0f32; samples.len()];
    whisper_rs::convert_integer_to_float_audio(&samples, &mut whisper_samples)?;

    let start_time = Instant::now();

    // Process each segment
    for segment in segments {
        // Compute the embedding for speaker identification
        let embedding_result: Vec<f32> = match embedding_extractor.compute(&segment.samples) {
            Ok(result) => result.collect(),
            Err(error) => {
                eprintln!(
                    "Error in {:.2}s - {:.2}s: {:?}",
                    segment.start, segment.end, error
                );
                continue; // Skip to the next segment
            }
        };

        // Once all speaker slots are taken, match against known speakers only;
        // otherwise, search with a threshold and possibly register a new speaker.
        let speaker = if embedding_manager.get_all_speakers().len() == max_speakers {
            embedding_manager
                .get_best_speaker_match(embedding_result.clone())
                .map(|r| r.to_string())
                .unwrap_or_else(|_| "Unknown".to_string())
        } else {
            embedding_manager
                .search_speaker(embedding_result, search_threshold)
                .map(|r| r.to_string())
                .unwrap_or_else(|| "Unknown".to_string())
        };

        // Extract the segment's samples for Whisper
        let start_sample = (segment.start * sample_rate as f64) as usize;
        let end_sample = (segment.end * sample_rate as f64) as usize;
        let segment_samples = &whisper_samples[start_sample..end_sample];

        // Transcribe the segment using Whisper
        let transcription = transcribe_segment(&mut state, &params, segment_samples)?;
        println!(
            "start = {:.2}, end = {:.2}, speaker = {}, text = {}",
            segment.start, segment.end, speaker, transcription
        );
    }

    let processing_time = start_time.elapsed().as_secs();
    println!("Processing time: {} seconds", processing_time);
    Ok(())
}
```

You can simply try this. It's not perfect, but it should give you an idea, I believe :-) If you are on Windows, you should also look at #3.
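One fragile spot in the example above is the second-to-sample conversion: `segment.end * sample_rate` can round past the end of the audio buffer, in which case the slice `&whisper_samples[start_sample..end_sample]` panics. A small clamped helper (hypothetical name `sample_range`, plain Rust, no external crates) sidesteps that:

```rust
/// Convert a [start, end) interval in seconds into a sample range clamped
/// to `len`, so slicing the audio buffer never goes out of bounds.
fn sample_range(start_s: f64, end_s: f64, sample_rate: u32, len: usize) -> (usize, usize) {
    let start = ((start_s * sample_rate as f64) as usize).min(len);
    let end = ((end_s * sample_rate as f64) as usize).clamp(start, len);
    (start, end)
}

fn main() {
    let samples = vec![0.0f32; 16_000]; // one second of 16 kHz audio
    // A segment that overruns the buffer by half a second:
    let (start, end) = sample_range(0.8, 1.5, 16_000, samples.len());
    let slice = &samples[start..end]; // safe: end is clamped to samples.len()
    println!("{}..{} ({} samples)", start, end, slice.len());
}
```

Clamping `end` to at least `start` also guards against degenerate segments where the reported end precedes the start.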
Maybe we can adapt this one too: https://github.com/thewh1teagle/vad-rs/tree/main/examples/whisper
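The linked example drives Whisper from a learned voice-activity model (vad-rs). Just to illustrate the underlying idea of VAD-gated transcription (this is *not* vad-rs's API), a toy energy-gate over fixed-size frames can be sketched in plain Rust:

```rust
/// Naive energy-based voice activity detection over fixed-size frames.
/// Returns (start, end) sample ranges whose RMS energy exceeds `threshold`.
/// A toy illustration only; vad-rs uses a trained model instead.
fn energy_vad(samples: &[f32], frame_len: usize, threshold: f32) -> Vec<(usize, usize)> {
    let mut regions: Vec<(usize, usize)> = Vec::new();
    for (i, frame) in samples.chunks(frame_len).enumerate() {
        let rms = (frame.iter().map(|s| s * s).sum::<f32>() / frame.len() as f32).sqrt();
        if rms > threshold {
            let (start, end) = (i * frame_len, i * frame_len + frame.len());
            // Merge with the previous region when frames are adjacent.
            match regions.last_mut() {
                Some(last) if last.1 == start => last.1 = end,
                _ => regions.push((start, end)),
            }
        }
    }
    regions
}

fn main() {
    // 3 s of 16 kHz audio: silence, then a loud constant "tone", then silence.
    let mut samples = vec![0.0f32; 48_000];
    for s in &mut samples[16_000..32_000] {
        *s = 0.5;
    }
    // 100 ms frames; adjacent voiced frames merge into one region.
    let regions = energy_vad(&samples, 1_600, 0.1);
    println!("{:?}", regions);
}
```

Each detected region could then be fed to Whisper the same way the diarization segments are in the example above.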
Isn't that a real-time version?
Hello, thank you!
Is it possible to add a simple example showing how this works with Whisper?
I have tried it myself, but I'm not familiar with this area and not good with Rust, so I eventually failed.
Regards.