feat(interface): implement voice
efugier committed May 16, 2024
1 parent 0368d37 commit 44ecae8
Showing 8 changed files with 269 additions and 84 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
@@ -21,9 +21,13 @@ clap = { version = "4", features = ["derive"] }
glob = "0"
log = "0"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0"
ureq = { version="2", features = ["json"] }
env_logger = "0"
# device_query = { version = "2", optional = true }
reqwest = { version = "0", features = ["json", "blocking", "multipart"] }
device_query = "2"

[dev-dependencies]
tempfile = "3"
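The added reqwest features (blocking, json, multipart) and the new device_query dependency hint at how voice input is wired up: record audio to a file, then upload it for transcription. The sketch below shows the blocking multipart upload pattern those features enable; the transcription URL, field names, and whisper-1 model are assumptions, not something this diff shows.

use std::{error::Error, path::Path};

// Hypothetical sketch of what the new reqwest features allow: a blocking
// multipart upload of a recorded WAV file to a transcription endpoint.
// The URL, field names, and model are assumptions, not taken from this commit.
fn transcribe_audio(path: &Path, api_key: &str) -> Result<String, Box<dyn Error>> {
    let form = reqwest::blocking::multipart::Form::new()
        .text("model", "whisper-1") // assumed model name
        .file("file", path)?;       // attaches the recorded audio file

    let response = reqwest::blocking::Client::new()
        .post("https://api.openai.com/v1/audio/transcriptions") // assumed endpoint
        .bearer_auth(api_key)
        .multipart(form)
        .send()?
        .error_for_status()?;

    // Assumes the endpoint answers with a JSON body containing a "text" field.
    let body: serde_json::Value = response.json()?;
    Ok(body["text"].as_str().unwrap_or_default().to_owned())
}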
5 changes: 5 additions & 0 deletions src/config/prompt.rs
@@ -10,6 +10,7 @@ use crate::config::{api::Api, resolve_config_path};

const PROMPT_FILE: &str = "prompts.toml";
const CONVERSATION_FILE: &str = "conversation.toml";
const AUDIO_FILE: &str = "audio.wav";

#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
pub struct Prompt {
@@ -100,6 +101,10 @@ pub fn conversation_file_path() -> PathBuf {
resolve_config_path().join(CONVERSATION_FILE)
}

pub fn audio_file_path() -> PathBuf {
resolve_config_path().join(AUDIO_FILE)
}

pub fn get_last_conversation_as_prompt() -> Prompt {
let content = fs::read_to_string(conversation_file_path()).unwrap_or_else(|error| {
panic!(
10 changes: 7 additions & 3 deletions src/input_processing.rs
@@ -31,9 +31,13 @@ pub fn process_input_with_request<W: Write>(
// fetch the api config tied to the prompt
let api_config = get_api_config(&prompt.api.to_string());

// make the request
let response_message = make_api_request(api_config, &prompt)?;

let response_message = match make_api_request(api_config, &prompt) {
Ok(message) => message,
Err(e) => {
eprintln!("Failed to make API request: {:?}", e);
std::process::exit(1);
}
};
debug!("{}", &response_message.content);

prompt.messages.push(response_message.clone());
69 changes: 55 additions & 14 deletions src/main.rs
@@ -1,8 +1,8 @@
mod config;
mod input_processing;
mod prompt_customization;
mod third_party;

mod config;
mod voice;

use crate::config::{
api::Api,
@@ -52,6 +52,9 @@ struct Cli {
/// whether to repeat the input before the output, useful to extend instead of replacing
#[arg(short, long)]
repeat_input: bool,
/// whether to use voice for input, may require admin permissions
#[arg(short, long)]
voice: bool,
#[command(flatten)]
prompt_params: PromptParams,
}
Expand Down Expand Up @@ -107,19 +110,31 @@ fn main() {
.expect("Unable to verify that the config files exist or to generate new ones.");

let is_piped = !stdin.is_terminal();
let mut custom_prompt: Option<String> = None;
let mut prompt_customizaton_text: Option<String> = None;

let prompt: Prompt = if args.extend_conversation {
custom_prompt = args.input_or_config_ref;
let prompt: Prompt = if !args.extend_conversation {
// try to get the prompt matching the first arg and use the second arg as customization text
// if there is no match, use the default prompt and treat the first arg as customization text
get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text)
} else {
if args.voice {
if args.input_or_config_ref.is_some() {
panic!(
"Invalid parameters, when using voice, either provide a valid ref to a config prompt or nothing at all.\n\
Use `sc -v <config_ref>` or `sc -v`"
);
}
prompt_customizaton_text = voice::get_voice_transcript();
} else {
prompt_customizaton_text = args.input_or_config_ref;
}
if args.input_if_config_ref.is_some() {
panic!(
"Invalid parameters, cannot provide a config ref when extending a conversation.\n\
Use `sc -e \"<your_prompt>.\"`"
);
}
get_last_conversation_as_prompt()
} else {
get_default_and_or_custom_prompt(&args, &mut custom_prompt)
};

// if no text was piped, use the custom prompt as input
@@ -128,14 +143,14 @@
}

if input.is_empty() {
input.push_str(&custom_prompt.unwrap_or_default());
custom_prompt = None;
input.push_str(&prompt_customizaton_text.unwrap_or_default());
prompt_customizaton_text = None;
}

debug!("input: {}", input);
debug!("custom_prompt: {:?}", custom_prompt);
debug!("promt_customization_text: {:?}", prompt_customizaton_text);

let prompt = customize_prompt(prompt, &args.prompt_params, custom_prompt);
let prompt = customize_prompt(prompt, &args.prompt_params, prompt_customizaton_text);

debug!("{:?}", prompt);

@@ -156,7 +171,16 @@ fn main() {
}
}

fn get_default_and_or_custom_prompt(args: &Cli, custom_prompt: &mut Option<String>) -> Prompt {
/// Fills prompt_customization_text with the correct part of the args.
/// First arg -> input_or_config_ref, second arg -> input_if_config_ref.
/// If the first arg is a prompt name, use that prompt and the second arg as input;
/// otherwise use the default prompt, treat the first arg as input, and forbid a second arg.
/// When using voice, only a prompt name can be provided.
fn get_default_and_or_custom_prompt(
args: &Cli,
prompt_customization_text: &mut Option<String>,
) -> Prompt {
let mut prompts = get_prompts();
let available_prompts: Vec<&String> = prompts.keys().collect();
let prompt_not_found_error = format!(
@@ -171,17 +195,34 @@ fn get_default_and_or_custom_prompt(args: &Cli, custom_prompt: &mut Option<Strin

if let Some(prompt) = prompts.remove(&input_or_config_ref) {
if args.input_if_config_ref.is_some() {
*custom_prompt = args.input_if_config_ref.clone()
// first arg matching a prompt and second one is customization
if args.voice {
panic!(
"Invalid parameters, when using voice, either provide a valid ref to a config prompt or nothing at all.\n\
Use `sc -v <config_ref>` or `sc -v`"
);
}
*prompt_customization_text = args.input_if_config_ref.clone()
}
if args.voice {
*prompt_customization_text = voice::get_voice_transcript();
}
prompt
} else {
*custom_prompt = Some(input_or_config_ref);
*prompt_customization_text = Some(input_or_config_ref);
if args.input_if_config_ref.is_some() {
// first arg isn't a prompt and a second one was provided
panic!(
"Invalid parameters, either provide a valid ref to a config prompt then an input, or only an input.\n\
Use `sc <config_ref> \"<your_prompt>\"` or `sc \"<your_prompt>\"`"
);
} else if args.voice {
panic!(
"Invalid parameters, when using voice, either provide a valid ref to a config prompt or nothing at all.\n\
Use `sc -v <config_ref>` or `sc -v`"
);
}

prompts
.remove(DEFAULT_PROMPT_NAME)
.expect(&prompt_not_found_error)
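src/voice.rs itself is among the files not expanded in this view, so only its call sites appear above; judging from the assignment into prompt_customizaton_text, voice::get_voice_transcript() returns an Option<String>. A minimal sketch of the kind of key-driven capture window it could use with the newly added device_query crate follows; the space-key binding and the polling interval are assumptions, and the actual recording and transcription steps are left out.

use device_query::{DeviceQuery, DeviceState, Keycode};
use std::{thread, time::Duration};

// Hypothetical sketch: block until a key is pressed and then released,
// bounding the window during which audio would be recorded. Reading global
// keyboard state is what may require admin/accessibility permissions.
fn wait_for_key_press_and_release(key: Keycode) {
    let device_state = DeviceState::new();
    // Wait for the key to go down...
    while !device_state.get_keys().contains(&key) {
        thread::sleep(Duration::from_millis(20));
    }
    // ...then wait for it to come back up.
    while device_state.get_keys().contains(&key) {
        thread::sleep(Duration::from_millis(20));
    }
}

fn main() {
    println!("Hold and release SPACE to delimit the recording (assumed binding).");
    wait_for_key_press_and_release(Keycode::Space);
    // Recording and transcription would happen around this window.
}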
108 changes: 59 additions & 49 deletions src/third_party/mod.rs
@@ -2,9 +2,8 @@ mod prompt_adapters;
mod response_parsing;

use self::prompt_adapters::{AnthropicPrompt, OpenAiPrompt};
use self::response_parsing::{AnthropicResponse, OpenAiResponse};
use self::response_parsing::{AnthropicResponse, OllamaResponse, OpenAiResponse};
use crate::input_processing::is_interactive;
use crate::third_party::response_parsing::OllamaResponse;
use crate::{
config::{
api::{Api, ApiConfig},
@@ -14,15 +13,22 @@ use crate::{
};

use log::debug;
use std::io;

pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> io::Result<Message> {
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum PromptFormat {
OpenAi(OpenAiPrompt),
Anthropic(AnthropicPrompt),
}

pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> reqwest::Result<Message> {
debug!(
"Trying to reach {:?} with key {:?}",
api_config.url, api_config.api_key
);
debug!("request content: {:?}", prompt);

debug!("Prompt: {:?}", prompt);
validate_prompt_size(prompt);

let mut prompt = prompt.clone();
@@ -34,56 +40,60 @@ pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> io::Result<Me
// currently not compatible with streams
prompt.stream = Some(false);

let request = ureq::post(&api_config.url);
let response_text: String = match prompt.api {
Api::Ollama => {
let request = request.set("Content-Type", "application/json");
let response: OllamaResponse =
read_response(request.send_json(OpenAiPrompt::from(prompt)))?.into_json()?;
response.into()
}
Api::Openai | Api::Mistral | Api::Groq => {
let request = request.set("Content-Type", "application/json").set(
"Authorization",
&format!("Bearer {}", &api_config.get_api_key()),
);
let response: OpenAiResponse =
read_response(request.send_json(OpenAiPrompt::from(prompt)))?.into_json()?;
response.into()
}
Api::Anthropic => {
let request = request
.set("Content-Type", "application/json")
.set("x-api-key", &api_config.get_api_key())
.set(
"anthropic-version",
&api_config.version.expect(
"version required for Anthropic, please add version key to your api config",
),
);
let response: AnthropicResponse =
read_response(request.send_json(AnthropicPrompt::from(prompt)))?.into_json()?;
response.into()
let client = reqwest::blocking::Client::new();

let prompt_format = match prompt.api {
Api::Ollama | Api::Openai | Api::Mistral | Api::Groq => {
PromptFormat::OpenAi(OpenAiPrompt::from(prompt.clone()))
}
Api::Anthropic => PromptFormat::Anthropic(AnthropicPrompt::from(prompt.clone())),
Api::AnotherApiForTests => panic!("This api is not made for actual use."),
};

let request = client
.post(&api_config.url)
.header("Content-Type", "application/json")
.json(&prompt_format);

// Add auth if necessary
let request = match prompt.api {
Api::Openai | Api::Mistral | Api::Groq => request.header(
"Authorization",
&format!("Bearer {}", &api_config.get_api_key()),
),
Api::Anthropic => request
.header("x-api-key", &api_config.get_api_key())
.header(
"anthropic-version",
&api_config.version.expect(
"version required for Anthropic, please add version key to your api config",
),
),
_ => request,
};

let response_text: String = match prompt.api {
Api::Ollama => handle_api_response::<OllamaResponse>(request.send()?),
Api::Openai | Api::Mistral | Api::Groq => {
handle_api_response::<OpenAiResponse>(request.send()?)
}
Api::Anthropic => handle_api_response::<AnthropicResponse>(request.send()?),
Api::AnotherApiForTests => unreachable!(),
};
Ok(Message::assistant(&response_text))
}

fn read_response(response: Result<ureq::Response, ureq::Error>) -> io::Result<ureq::Response> {
response.map_err(|e| match e {
ureq::Error::Status(status, response) => {
let content = match response.into_string() {
Ok(content) => content,
Err(_) => "(non-UTF-8 response)".to_owned(),
};
io::Error::other(format!(
"API call failed with status code {status} and body: {content}"
))
}
ureq::Error::Transport(transport) => io::Error::other(transport),
})
/// Deserialize a successful response into its message text, panicking with the status and body on error.
pub fn handle_api_response<T: serde::de::DeserializeOwned + Into<String>>(
response: reqwest::blocking::Response,
) -> String {
let status = response.status();
if status.is_success() {
response.json::<T>().unwrap().into()
} else {
let error_text = response.text().unwrap();
panic!("API request failed with status {}: {}", status, error_text);
}
}

fn validate_prompt_size(prompt: &Prompt) {
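The new PromptFormat enum is marked #[serde(untagged)], so serializing it produces the inner variant's JSON directly, with no {"OpenAi": ...} wrapper; that is what lets a single .json(&prompt_format) call build the request body for both OpenAI-style and Anthropic-style APIs. A standalone illustration of that serde behavior, with made-up field names rather than the real prompt adapters:

use serde::Serialize;

#[derive(Serialize)]
struct OpenAiLike {
    model: String,
    messages: Vec<String>,
}

#[derive(Serialize)]
struct AnthropicLike {
    model: String,
    max_tokens: u32,
}

// Untagged: the serialized JSON is just the variant's own fields.
#[derive(Serialize)]
#[serde(untagged)]
enum Payload {
    OpenAi(OpenAiLike),
    Anthropic(AnthropicLike),
}

fn main() {
    let payload = Payload::Anthropic(AnthropicLike {
        model: "claude".into(),
        max_tokens: 256,
    });
    // Prints {"model":"claude","max_tokens":256} with no enum tag around it.
    println!("{}", serde_json::to_string(&payload).unwrap());
}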
39 changes: 21 additions & 18 deletions src/third_party/response_parsing.rs
@@ -2,28 +2,43 @@ use crate::config::prompt::Message;
use serde::Deserialize;
use std::fmt::Debug;

// OpenAi
#[derive(Debug, Deserialize)]
pub(super) struct AnthropicMessage {
pub text: String,
#[serde(rename(serialize = "type", deserialize = "type"))]
pub _type: String,
pub(super) struct OpenAiResponse {
pub choices: Vec<MessageWrapper>,
}

#[derive(Debug, Deserialize)]
pub(super) struct MessageWrapper {
pub message: Message,
}

impl From<OpenAiResponse> for String {
fn from(value: OpenAiResponse) -> Self {
value.choices.first().unwrap().message.content.to_owned()
}
}

// Anthropic
#[derive(Debug, Deserialize)]
pub(super) struct OpenAiResponse {
pub choices: Vec<MessageWrapper>,
pub(super) struct AnthropicMessage {
pub text: String,
#[serde(rename(serialize = "type", deserialize = "type"))]
pub _type: String,
}

impl From<AnthropicResponse> for String {
fn from(value: AnthropicResponse) -> Self {
value.content.first().unwrap().text.to_owned()
}
}

#[derive(Debug, Deserialize)]
pub(super) struct AnthropicResponse {
pub content: Vec<AnthropicMessage>,
}

// Ollama
#[derive(Debug, Deserialize)]
pub(super) struct OllamaResponse {
pub message: Message,
@@ -34,15 +49,3 @@ impl From<OllamaResponse> for String {
value.message.content
}
}

impl From<AnthropicResponse> for String {
fn from(value: AnthropicResponse) -> Self {
value.content.first().unwrap().text.to_owned()
}
}

impl From<OpenAiResponse> for String {
fn from(value: OpenAiResponse) -> Self {
value.choices.first().unwrap().message.content.to_owned()
}
}
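All three From<...> for String impls do the same job: pull the first message's text out of each provider's response envelope, which is what handle_api_response relies on through its Into<String> bound. As a reference for the OpenAI-style shape, the trimmed example below parses a hypothetical response body with a local mirror of these types; real responses carry many more fields, which serde ignores by default.

use serde::Deserialize;

// Local, trimmed mirrors of the response types, for illustration only.
#[derive(Deserialize)]
struct Message {
    content: String,
}

#[derive(Deserialize)]
struct Choice {
    message: Message,
}

#[derive(Deserialize)]
struct OpenAiResponse {
    choices: Vec<Choice>,
}

fn main() {
    // Hypothetical, trimmed response body; unknown fields are ignored.
    let body = r#"{"choices":[{"message":{"role":"assistant","content":"hello"}}]}"#;
    let parsed: OpenAiResponse = serde_json::from_str(body).unwrap();
    // Same extraction the From<OpenAiResponse> for String impl performs.
    println!("{}", parsed.choices.first().unwrap().message.content);
}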