From 4efc20af1d83c9f0867f1b241ac5111d200c3137 Mon Sep 17 00:00:00 2001 From: efugier Date: Wed, 15 May 2024 11:37:50 +0200 Subject: [PATCH] feat(interface): implement voice --- Cargo.toml | 4 ++ src/config/prompt.rs | 5 ++ src/input_processing.rs | 10 ++- src/main.rs | 69 ++++++++++++++---- src/third_party/mod.rs | 108 +++++++++++++++------------- src/third_party/response_parsing.rs | 39 +++++----- src/voice/mod.rs | 83 +++++++++++++++++++++ src/voice/schemas.rs | 35 +++++++++ 8 files changed, 269 insertions(+), 84 deletions(-) create mode 100644 src/voice/mod.rs create mode 100644 src/voice/schemas.rs diff --git a/Cargo.toml b/Cargo.toml index f442966..d3b71da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,9 +21,13 @@ clap = { version = "4", features = ["derive"] } glob = "0" log = "0" serde = { version = "1", features = ["derive"] } +serde_json = "1" toml = "0" ureq = { version="2", features = ["json"] } env_logger = "0" +# device_query = { version = "2", optional = true } +reqwest = { version = "0", features = ["json", "blocking", "multipart"] } +device_query = "2" [dev-dependencies] tempfile = "3" diff --git a/src/config/prompt.rs b/src/config/prompt.rs index 699a87b..3d9a9e6 100644 --- a/src/config/prompt.rs +++ b/src/config/prompt.rs @@ -10,6 +10,7 @@ use crate::config::{api::Api, resolve_config_path}; const PROMPT_FILE: &str = "prompts.toml"; const CONVERSATION_FILE: &str = "conversation.toml"; +const AUDIO_FILE: &str = "audio.wav"; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct Prompt { @@ -100,6 +101,10 @@ pub fn conversation_file_path() -> PathBuf { resolve_config_path().join(CONVERSATION_FILE) } +pub fn audio_file_path() -> PathBuf { + resolve_config_path().join(AUDIO_FILE) +} + pub fn get_last_conversation_as_prompt() -> Prompt { let content = fs::read_to_string(conversation_file_path()).unwrap_or_else(|error| { panic!( diff --git a/src/input_processing.rs b/src/input_processing.rs index 8481e54..69c64b3 100644 --- 
a/src/input_processing.rs +++ b/src/input_processing.rs @@ -31,9 +31,13 @@ pub fn process_input_with_request( // fetch the api config tied to the prompt let api_config = get_api_config(&prompt.api.to_string()); - // make the request - let response_message = make_api_request(api_config, &prompt)?; - + let response_message = match make_api_request(api_config, &prompt) { + Ok(message) => message, + Err(e) => { + eprintln!("Failed to make API request: {:?}", e); + std::process::exit(1); + } + }; debug!("{}", &response_message.content); prompt.messages.push(response_message.clone()); diff --git a/src/main.rs b/src/main.rs index f01e1f0..0355a01 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,8 @@ +mod config; mod input_processing; mod prompt_customization; mod third_party; - -mod config; +mod voice; use crate::config::{ api::Api, @@ -52,6 +52,9 @@ struct Cli { /// whether to repeat the input before the output, useful to extend instead of replacing #[arg(short, long)] repeat_input: bool, + /// whether to use voice for input, may require admin permissions + #[arg(short, long)] + voice: bool, #[command(flatten)] prompt_params: PromptParams, } @@ -107,10 +110,24 @@ fn main() { .expect("Unable to verify that the config files exist or to generate new ones."); let is_piped = !stdin.is_terminal(); - let mut custom_prompt: Option = None; + let mut prompt_customizaton_text: Option = None; - let prompt: Prompt = if args.extend_conversation { - custom_prompt = args.input_or_config_ref; + let prompt: Prompt = if !args.extend_conversation { + // try to get prompt matching the first arg and use second arg as customization text + // if it doesn't, use default prompt and treat that first arg as customization text + get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) + } else { + if args.voice { + if args.input_or_config_ref.is_some() { + panic!( + "Invalid parameters, when using voice, either provide a valid ref to a config prompt or nothing at all.\n\ + 
Use `sc -v ` or `sc -v`" + ); + } + prompt_customizaton_text = voice::get_voice_transcript(); + } else { + prompt_customizaton_text = args.input_or_config_ref; + } if args.input_if_config_ref.is_some() { panic!( "Invalid parameters, cannot provide a config ref when extending a conversation.\n\ @@ -118,8 +135,6 @@ fn main() { ); } get_last_conversation_as_prompt() - } else { - get_default_and_or_custom_prompt(&args, &mut custom_prompt) }; // if no text was piped, use the custom prompt as input @@ -128,14 +143,14 @@ fn main() { } if input.is_empty() { - input.push_str(&custom_prompt.unwrap_or_default()); - custom_prompt = None; + input.push_str(&prompt_customizaton_text.unwrap_or_default()); + prompt_customizaton_text = None; } debug!("input: {}", input); - debug!("custom_prompt: {:?}", custom_prompt); + debug!("promt_customization_text: {:?}", prompt_customizaton_text); - let prompt = customize_prompt(prompt, &args.prompt_params, custom_prompt); + let prompt = customize_prompt(prompt, &args.prompt_params, prompt_customizaton_text); debug!("{:?}", prompt); @@ -156,7 +171,16 @@ fn main() { } } -fn get_default_and_or_custom_prompt(args: &Cli, custom_prompt: &mut Option) -> Prompt { +/// Fills prompt_customization_text with the correct part of the args +/// first arg -> input_or_config_ref +/// second arg -> input_if_config_ref +/// if first arg is a prompt name, get that prompt and use second arg as input +/// if not, use default prompt, use first arg as input and forbid second arg +/// when using voice, only a prompt name can be provided +fn get_default_and_or_custom_prompt( + args: &Cli, + prompt_customization_text: &mut Option, +) -> Prompt { let mut prompts = get_prompts(); let available_prompts: Vec<&String> = prompts.keys().collect(); let prompt_not_found_error = format!( @@ -171,17 +195,34 @@ fn get_default_and_or_custom_prompt(args: &Cli, custom_prompt: &mut Option` or `sc -v`" + ); + } + *prompt_customization_text = args.input_if_config_ref.clone() + } + if 
args.voice { + *prompt_customization_text = voice::get_voice_transcript(); } prompt } else { - *custom_prompt = Some(input_or_config_ref); + *prompt_customization_text = Some(input_or_config_ref); if args.input_if_config_ref.is_some() { + // first arg isn't a prompt and a second one was provided panic!( "Invalid parameters, either provide a valid ref to a config prompt then an input, or only an input.\n\ Use `sc \"\"`" ); + } else if args.voice { + panic!( + "Invalid parameters, when using voice, either provide a valid ref to a config prompt or nothing at all.\n\ + Use `sc -v ` or `sc -v`" + ); } + prompts .remove(DEFAULT_PROMPT_NAME) .expect(&prompt_not_found_error) diff --git a/src/third_party/mod.rs b/src/third_party/mod.rs index 6e1342a..653f5e0 100644 --- a/src/third_party/mod.rs +++ b/src/third_party/mod.rs @@ -2,9 +2,8 @@ mod prompt_adapters; mod response_parsing; use self::prompt_adapters::{AnthropicPrompt, OpenAiPrompt}; -use self::response_parsing::{AnthropicResponse, OpenAiResponse}; +use self::response_parsing::{AnthropicResponse, OllamaResponse, OpenAiResponse}; use crate::input_processing::is_interactive; -use crate::third_party::response_parsing::OllamaResponse; use crate::{ config::{ api::{Api, ApiConfig}, @@ -14,15 +13,22 @@ use crate::{ }; use log::debug; -use std::io; -pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> io::Result { +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +#[serde(untagged)] +enum PromptFormat { + OpenAi(OpenAiPrompt), + Anthropic(AnthropicPrompt), +} + +pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> reqwest::Result { debug!( "Trying to reach {:?} with key {:?}", api_config.url, api_config.api_key ); - debug!("request content: {:?}", prompt); - + debug!("Prompt: {:?}", prompt); validate_prompt_size(prompt); let mut prompt = prompt.clone(); @@ -34,56 +40,60 @@ pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> io::Result { - let request = 
request.set("Content-Type", "application/json"); - let response: OllamaResponse = - read_response(request.send_json(OpenAiPrompt::from(prompt)))?.into_json()?; - response.into() - } - Api::Openai | Api::Mistral | Api::Groq => { - let request = request.set("Content-Type", "application/json").set( - "Authorization", - &format!("Bearer {}", &api_config.get_api_key()), - ); - let response: OpenAiResponse = - read_response(request.send_json(OpenAiPrompt::from(prompt)))?.into_json()?; - response.into() - } - Api::Anthropic => { - let request = request - .set("Content-Type", "application/json") - .set("x-api-key", &api_config.get_api_key()) - .set( - "anthropic-version", - &api_config.version.expect( - "version required for Anthropic, please add version key to your api config", - ), - ); - let response: AnthropicResponse = - read_response(request.send_json(AnthropicPrompt::from(prompt)))?.into_json()?; - response.into() + let client = reqwest::blocking::Client::new(); + + let prompt_format = match prompt.api { + Api::Ollama | Api::Openai | Api::Mistral | Api::Groq => { + PromptFormat::OpenAi(OpenAiPrompt::from(prompt.clone())) } + Api::Anthropic => PromptFormat::Anthropic(AnthropicPrompt::from(prompt.clone())), Api::AnotherApiForTests => panic!("This api is not made for actual use."), }; + let request = client + .post(&api_config.url) + .header("Content-Type", "application/json") + .json(&prompt_format); + + // Add auth if necessary + let request = match prompt.api { + Api::Openai | Api::Mistral | Api::Groq => request.header( + "Authorization", + &format!("Bearer {}", &api_config.get_api_key()), + ), + Api::Anthropic => request + .header("x-api-key", &api_config.get_api_key()) + .header( + "anthropic-version", + &api_config.version.expect( + "version required for Anthropic, please add version key to your api config", + ), + ), + _ => request, + }; + + let response_text: String = match prompt.api { + Api::Ollama => handle_api_response::(request.send()?), + Api::Openai | 
Api::Mistral | Api::Groq => { + handle_api_response::(request.send()?) + } + Api::Anthropic => handle_api_response::(request.send()?), + Api::AnotherApiForTests => unreachable!(), + }; Ok(Message::assistant(&response_text)) } -fn read_response(response: Result) -> io::Result { - response.map_err(|e| match e { - ureq::Error::Status(status, response) => { - let content = match response.into_string() { - Ok(content) => content, - Err(_) => "(non-UTF-8 response)".to_owned(), - }; - io::Error::other(format!( - "API call failed with status code {status} and body: {content}" - )) - } - ureq::Error::Transport(transport) => io::Error::other(transport), - }) +/// clean error management +pub fn handle_api_response>( + response: reqwest::blocking::Response, +) -> String { + let status = response.status(); + if response.status().is_success() { + response.json::().unwrap().into() + } else { + let error_text = response.text().unwrap(); + panic!("API request failed with status {}: {}", status, error_text); + } } fn validate_prompt_size(prompt: &Prompt) { diff --git a/src/third_party/response_parsing.rs b/src/third_party/response_parsing.rs index ec599d3..0d32d38 100644 --- a/src/third_party/response_parsing.rs +++ b/src/third_party/response_parsing.rs @@ -2,11 +2,10 @@ use crate::config::prompt::Message; use serde::Deserialize; use std::fmt::Debug; +// OpenAi #[derive(Debug, Deserialize)] -pub(super) struct AnthropicMessage { - pub text: String, - #[serde(rename(serialize = "type", deserialize = "type"))] - pub _type: String, +pub(super) struct OpenAiResponse { + pub choices: Vec, } #[derive(Debug, Deserialize)] @@ -14,9 +13,24 @@ pub(super) struct MessageWrapper { pub message: Message, } +impl From for String { + fn from(value: OpenAiResponse) -> Self { + value.choices.first().unwrap().message.content.to_owned() + } +} + +// Anthropic #[derive(Debug, Deserialize)] -pub(super) struct OpenAiResponse { - pub choices: Vec, +pub(super) struct AnthropicMessage { + pub text: String, + 
#[serde(rename(serialize = "type", deserialize = "type"))] + pub _type: String, +} + +impl From for String { + fn from(value: AnthropicResponse) -> Self { + value.content.first().unwrap().text.to_owned() + } } #[derive(Debug, Deserialize)] @@ -24,6 +38,7 @@ pub(super) struct AnthropicResponse { pub content: Vec, } +// Ollama #[derive(Debug, Deserialize)] pub(super) struct OllamaResponse { pub message: Message, @@ -34,15 +49,3 @@ impl From for String { value.message.content } } - -impl From for String { - fn from(value: AnthropicResponse) -> Self { - value.content.first().unwrap().text.to_owned() - } -} - -impl From for String { - fn from(value: OpenAiResponse) -> Self { - value.choices.first().unwrap().message.content.to_owned() - } -} diff --git a/src/voice/mod.rs b/src/voice/mod.rs new file mode 100644 index 0000000..b084311 --- /dev/null +++ b/src/voice/mod.rs @@ -0,0 +1,83 @@ +pub mod schemas; + +use device_query::{DeviceQuery, DeviceState, Keycode}; +use std::process::{Child, Command}; + +use self::schemas::{OpenAiVoiceResponse, VoiceConfig}; +use super::config::{api::get_api_config, prompt::audio_file_path}; +use super::third_party::handle_api_response; + +fn start_recording() -> Option { + let os_string = audio_file_path().into_os_string(); + let audio_file_path = os_string; + + match std::env::consts::OS { + "windows" => Command::new("cmd") + .args(["/C", "start", "rec"]) + .arg(audio_file_path) + .spawn() + .ok(), + "macos" => Command::new("sox") + .arg("-d") + .arg(audio_file_path) + .spawn() + .ok(), + "linux" => Command::new("arecord") + .arg("-f") + .arg("S16_LE") + .arg("--quiet") + .arg(audio_file_path) + .spawn() + .ok(), + os => panic!("Unexpected os: {}", os), + } +} + +fn stop_recording(process: &mut Child) { + process.kill().expect("Failed to stop recording."); +} + +pub fn get_voice_transcript() -> Option { + use std::time::Instant; + + let mut process = start_recording()?; + let device_state = DeviceState::new(); + let start_time = 
Instant::now(); + + loop { + let keys: Vec = device_state.get_keys(); + if keys.contains(&Keycode::Space) { + break; + } + std::thread::sleep(std::time::Duration::from_millis(100)); + } + + stop_recording(&mut process); + let duration = start_time.elapsed(); + println!("Recording duration: {:?}", duration); + std::thread::sleep(std::time::Duration::from_secs(1)); + let voice_config = VoiceConfig::default(); + + let api_config = get_api_config(&voice_config.api.to_string()); + + let transcript = + post_audio(&api_config.get_api_key()).expect("Failed to send audio file to API"); + + Some(transcript) +} + +fn post_audio(api_key: &str) -> reqwest::Result { + let client = reqwest::blocking::Client::new(); + let form = reqwest::blocking::multipart::Form::new() + .text("model", "whisper-1") + .file("file", audio_file_path()) + .expect("Failed to read audio file."); + + let response = client + .post("https://api.openai.com/v1/audio/transcriptions") + .bearer_auth(api_key) + .multipart(form) + .send()?; + + Ok(handle_api_response::(response)) +} diff --git a/src/voice/schemas.rs b/src/voice/schemas.rs new file mode 100644 index 0000000..019219b --- /dev/null +++ b/src/voice/schemas.rs @@ -0,0 +1,35 @@ +use serde::{Deserialize, Serialize}; + +use crate::config::api::Api; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] +pub(super) struct VoiceConfig { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub recording_command: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + pub api: Api, +} + +impl Default for VoiceConfig { + fn default() -> Self { + VoiceConfig { + url: "https://api.openai.com/v1/audio/transcriptions".to_string(), + recording_command: None, + model: Some("whisper-1".to_string()), + api: Api::Openai, + } + } +} + +#[derive(Debug, Deserialize)] +pub struct OpenAiVoiceResponse { + text: String, +} + +impl From for String { + fn from(response: OpenAiVoiceResponse) -> Self { + response.text + } +}