diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8f73c78..64e1d7e 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -12,42 +12,57 @@ Codebase quality improvements are very welcome as I hadn't really used rust since
 src/
 │ # args parsing logic
 ├── main.rs
+│ # a (manageable) handful of utility functions used in several other places
+├── utils.rs
 │ # logic to customize the template prompt with the args
 ├── prompt_customization.rs
 │ # logic to insert the input into the prompt
-├── input_processing.rs
-│ # smartcat-related config structs
 ├── config
 │ │ # function to check config
 │   ├── mod.rs
-│ │ # config structs for API definition (url, key...)
+│ │ # config structs for API config definition (url, key...)
 │   ├── api.rs
 │ │ # config structs for prompt definition (messages, model, temperature...)
-│   └── prompt.rs
-│ # third-party-related code (request, adapters)
-└── third_party
+│   ├── prompt.rs
+│ │ # config structs for voice config (model, url, voice recording command...)
+│   └── voice.rs
+│ # voice-api-related code (request, adapters)
+├── voice
+│ │ # orchestrate the voice recording and request
+│ ├── mod.rs
+│ │ # start and stop the recording program
+│ ├── recording.rs
+│ │ # make the request to the api and read the result
+│ ├── api_call.rs
+│ │ # structs to parse and extract the message from third party answers
+│ └── response_schemas.rs
+└── text
 │ # make third party requests and read the result
 ├── mod.rs
+ │ # make the request to the api and read the result
+ ├── api_call.rs
 │ # logic to adapt smartcat prompts to third party ones
- ├── prompt_adapters.rs
+ ├── request_schemas.rs
 │ # structs to parse and extract the message from third party answers
- └── response_parsing.rs
+ └── response_schemas.rs
 ```

 #### Logic flow

 The prompt object is passed through the entire program, enriched with the input (from stdin) and then the third party response. The third party response is then written to stdout and the whole conversation (including the input and the response) is then saved as the last prompt for re-use.

+**Regular**
+
 ```python
 main
 # parse the args and get the template prompt / continue with last conversation as prompt
 -> prompt_customization::customize_prompt
 ╎# update the templated prompt with the information from the args
 <-
--> input_processing::process_input_with_request
+-> text::process_input_with_request
 ╎# insert the input in the prompt
 ╎# load the api config
- -> third_party::make_api_request
+ -> text::api_call::post_prompt_and_get_answer
 ╎# translate the smartcat prompt to api-specific prompt
 ╎# make the request
 ╎# get the message from api-specific response
@@ -59,15 +74,33 @@ main
 # exit
 ```

+**Voice**
+
+```python
+main
+-> prompt_customization::customize_prompt
+-> voice::record_voice_and_get_transcript
+ -> voice::recording::start_recording
+ -> voice::recording::stop_recording
+ -> voice::api_call::post_audio_and_get_transcript
+<-
+-> text::process_input_with_request
+ -> text::api_call::post_prompt_and_get_answer
+<-
+```
+
 ### Testing

 Some tests rely on environment variables and don't behave well with multi-threading. They are marked with `#[serial]` from the [serial_test](https://docs.rs/serial_test/latest/serial_test/index.html) crate.
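Editor's note on the testing constraint above: a minimal sketch of why `#[serial]` is needed. This is a hypothetical test, not taken from the codebase, and `SOME_CONFIG_ENV_VAR` is a made-up name; it only illustrates that two tests mutating the same process-wide environment variable must not run on parallel test threads.

```rust
// Hypothetical example (not from the repo): both tests touch the same
// process-wide environment variable, so both are marked #[serial] to
// prevent the default multi-threaded test runner from interleaving them.
#[cfg(test)]
mod tests {
    use serial_test::serial;

    const SOME_CONFIG_ENV_VAR: &str = "SMARTCAT_TEST_ONLY_VAR";

    #[test]
    #[serial]
    fn uses_custom_value() {
        std::env::set_var(SOME_CONFIG_ENV_VAR, "custom");
        assert_eq!(std::env::var(SOME_CONFIG_ENV_VAR).unwrap(), "custom");
        std::env::remove_var(SOME_CONFIG_ENV_VAR);
    }

    #[test]
    #[serial]
    fn falls_back_when_unset() {
        std::env::remove_var(SOME_CONFIG_ENV_VAR);
        assert!(std::env::var(SOME_CONFIG_ENV_VAR).is_err());
    }
}
```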
+### DOING + +- Voice intergation + ### TODO - [ ] make it available on homebrew - [ ] handle streams - [ ] automagical context fetches (might be out of scope) - [ ] add RAG capabilities (might be out of scope) -- [ ] refactor to remove content logic from the `mod.rs` files diff --git a/Cargo.lock b/Cargo.lock index 610a23b..41ae4f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,25 +114,6 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" -[[package]] -name = "block-sys" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae85a0696e7ea3b835a453750bf002770776609115e6d25c6d2ff28a8200f7e7" -dependencies = [ - "objc-sys", -] - -[[package]] -name = "block2" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e58aa60e59d8dbfcc36138f5f18be5f24394d33b38b24f7fd0b1caa33095f22f" -dependencies = [ - "block-sys", - "objc2", -] - [[package]] name = "bumpalo" version = "3.16.0" @@ -219,30 +200,6 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" -[[package]] -name = "core-graphics" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "core-graphics-types", - "foreign-types 0.5.0", - "libc", -] - -[[package]] -name = "core-graphics-types" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "libc", -] - [[package]] name = "crc32fast" version = "1.4.0" @@ -266,29 +223,27 @@ dependencies = [ ] [[package]] -name = "encoding_rs" -version = "0.8.34" +name = "device_query" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "d56aeed4b70abf636f9211a155a937ee52cf2ec97cbca99ebeb65deed6217f14" dependencies = [ - "cfg-if", + "lazy_static", + "macos-accessibility-client", + "pkg-config", + "readkey", + "readmouse", + "windows", + "x11", ] [[package]] -name = "enigo" -version = "0.2.0" +name = "encoding_rs" +version = "0.8.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a4655cdcad61e57cf28922aaa3221b06ce29e644422bba506851a03b8817468" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" dependencies = [ - "core-graphics", - "foreign-types-shared 0.3.1", - "icrate", - "libc", - "log", - "objc2", - "windows", - "xkbcommon", - "xkeysym", + "cfg-if", ] [[package]] @@ -358,28 +313,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared 0.1.1", -] - -[[package]] -name = "foreign-types" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" -dependencies = [ - "foreign-types-macros", - "foreign-types-shared 0.3.1", -] - -[[package]] -name = "foreign-types-macros" -version = "0.2.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" -dependencies = [ - "proc-macro2", - "quote", - "syn", + "foreign-types-shared", ] [[package]] @@ -388,12 +322,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" -[[package]] -name = "foreign-types-shared" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -642,16 +570,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "icrate" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb69199826926eb864697bddd27f73d9fddcffc004f5733131e15b465e30642" -dependencies = [ - "block2", - "objc2", -] - [[package]] name = "idna" version = "0.5.0" @@ -734,19 +652,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] -name = "memchr" -version = "2.7.2" +name = "macos-accessibility-client" +version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "edf7710fbff50c24124331760978fb9086d6de6288dcdb38b25a97f8b1bdebbb" +dependencies = [ + "core-foundation", + "core-foundation-sys", +] [[package]] -name = "memmap2" -version = "0.8.0" +name = "memchr" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a5a03cefb0d953ec0be133036f14e109412fa594edc2f77227249db66cc3ed" -dependencies = [ - "libc", -] +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "mime" @@ -812,28 +731,6 @@ dependencies = [ "libc", ] -[[package]] -name = "objc-sys" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da284c198fb9b7b0603f8635185e85fbd5b64ee154b1ed406d489077de2d6d60" - -[[package]] -name = "objc2" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4b25e1034d0e636cd84707ccdaa9f81243d399196b8a773946dcffec0401659" -dependencies = [ - "objc-sys", - "objc2-encode", -] - -[[package]] -name = "objc2-encode" -version = "4.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88658da63e4cc2c8adb1262902cd6af51094df0488b760d6fd27194269c0950a" - [[package]] name = "object" version = "0.32.2" @@ -857,7 +754,7 @@ checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ "bitflags 2.5.0", "cfg-if", - "foreign-types 0.3.2", + "foreign-types", "libc", "once_cell", "openssl-macros", @@ -978,6 +875,18 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "readkey" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7677f98ca49bc9bb26e04c8abf80ba579e2cb98e8a384a0ff8128ad70670d249" + +[[package]] +name = "readmouse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be105c72a1e6a5a1198acee3d5b506a15676b74a02ecd78060042a447f408d94" + [[package]] name = "redox_syscall" version = "0.5.1" @@ -1276,7 +1185,7 @@ name = "smartcat" version = "1.2.2" dependencies = [ "clap", - "enigo", + "device_query", "env_logger", "glob", "log", @@ 
-1695,55 +1604,11 @@ dependencies = [

 [[package]]
 name = "windows"
-version = "0.56.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1de69df01bdf1ead2f4ac895dc77c9351aefff65b2f3db429a343f9cbf05e132"
-dependencies = [
- "windows-core",
- "windows-targets 0.52.5",
-]
-
-[[package]]
-name = "windows-core"
-version = "0.56.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4698e52ed2d08f8658ab0c39512a7c00ee5fe2688c65f8c0a4f06750d729f2a6"
-dependencies = [
- "windows-implement",
- "windows-interface",
- "windows-result",
- "windows-targets 0.52.5",
-]
-
-[[package]]
-name = "windows-implement"
-version = "0.56.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f6fc35f58ecd95a9b71c4f2329b911016e6bec66b3f2e6a4aad86bd2e99e2f9b"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]
-
-[[package]]
-name = "windows-interface"
-version = "0.56.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08990546bf4edef8f431fa6326e032865f27138718c587dc21bc0265bbcb57cc"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]
-
-[[package]]
-name = "windows-result"
-version = "0.1.1"
+version = "0.48.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "749f0da9cc72d82e600d8d2e44cadd0b9eedb9038f71a1c58556ac1c5791813b"
+checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f"
 dependencies = [
- "windows-targets 0.52.5",
+ "windows-targets 0.48.5",
 ]

 [[package]]
@@ -1905,22 +1770,15 @@ dependencies = [
 ]

 [[package]]
-name = "xkbcommon"
-version = "0.7.0"
+name = "x11"
+version = "2.21.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "13867d259930edc7091a6c41b4ce6eee464328c6ff9659b7e4c668ca20d4c91e"
+checksum = "502da5464ccd04011667b11c435cb992822c2c0dbde1770c988480d312a0db2e"
 dependencies = [
 "libc",
- "memmap2",
- "xkeysym",
+ "pkg-config",
 ]

-[[package]]
-name = "xkeysym"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "054a8e68b76250b253f671d1268cb7f1ae089ec35e195b2efb2a4e9a836d0621"
-
 [[package]]
 name = "zeroize"
 version = "1.7.0"
diff --git a/README.md b/README.md
index b4d184c..cb6cd1c 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,7 @@ Answers might be slow depending on your setup, you may want to try the third par

 - [Installation](#installation-)
 - [Usage](#usage)
+- [Voice](#voice)
 - [A few examples to get started 🐈‍⬛](#a-few-examples-to-get-started-)
 - [Integrating with editors](#integrating-with-editors)
 - [Example workflows](#example-workflows)
@@ -60,6 +61,16 @@ The minimum config requirement is a `default` prompt calling a setup API (either

 Now on how to get it.

+### To use voice input
+
+**Only tested on Linux so far.** Any help from users of other platforms is appreciated.
+
+- On Linux, make sure `arecord` is installed and that `arecord --quiet audio.wav` records your audio until you ctrl-c it, producing no output on stdout.
+- On Mac, make sure `sox` is installed and that `sox -t waveaudio 0 audio.wav` records your audio until you ctrl-c it, producing no output on stdout.
+- On Windows, make sure `sox` is installed and that `sox -t waveaudio 0 audio.wav` records your audio until you ctrl-c it, producing no output on stdout.
+
+If it doesn't work, please open an issue.
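Editor's note on why the recording command must keep recording until it is interrupted: later in this PR, `src/voice/recording.rs` simply spawns the configured command and kills the child process when the user ends the recording. Below is a simplified, standalone sketch of that control flow; the hard-coded `arecord` invocation, the output path, and the fixed sleep are stand-ins for the configurable `recording_command` and the space-key listener.

```rust
// Simplified sketch of the recording flow added in this PR
// (see src/voice/recording.rs): spawn the recording program,
// let it run, then kill it to stop the capture.
use std::process::Command;
use std::thread::sleep;
use std::time::Duration;

fn main() -> std::io::Result<()> {
    // Stand-in for the configurable recording command on Linux.
    let mut child = Command::new("arecord")
        .args(["-f", "S16_LE", "--quiet", "/tmp/audio.wav"])
        .spawn()?;

    // In smartcat this wait is replaced by polling the keyboard for space.
    sleep(Duration::from_secs(5));

    // Stopping the recording is just killing the child process.
    child.kill()?;
    child.wait()?;
    Ok(())
}
```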
+
 ### With Cargo

 With an **up to date** [rust and cargo](https://www.rust-lang.org/tools/install) setup (you might consider running `rustup update`):

@@ -77,16 +88,17 @@ Choose the one compiled for your platform on the [release page](https://github.co

 ## Usage

 ```text
-Usage: sc [OPTIONS] [INPUT_OR_CONFIG_REF] [INPUT_IF_CONFIG_REF]
+Usage: sc [OPTIONS] [INPUT_OR_TEMPLATE_REF] [INPUT_IF_TEMPLATE_REF]

 Arguments:
-  [INPUT_OR_CONFIG_REF]  ref to a prompt from config or straight input (will use `default` prompt template)
-  [INPUT_IF_CONFIG_REF]  if the first arg matches a config ref, the second will be used as input
+  [INPUT_OR_TEMPLATE_REF]  ref to a prompt template from config or straight input (will use `default` prompt template if input)
+  [INPUT_IF_TEMPLATE_REF]  if the first arg matches a config template, the second will be used as input

 Options:
   -e, --extend-conversation  whether to extend the previous conversation or start a new one
   -r, --repeat-input         whether to repeat the input before the output, useful to extend instead of replacing
-      --api                  overrides which api to hit [possible values: openai, mistral, groq, anthropic, ollama]
+  -v, --voice                whether to use voice for input
+      --api                  overrides which api to hit [possible values: another-api-for-tests, ollama, anthropic, groq, mistral, openai]
   -m, --model                overrides which model (of the api) to use
   -t, --temperature          temperature higher means answer further from the average
   -l, --char-limit           max number of chars to include, ask for user approval if more, 0 = no limit
@@ -100,10 +112,43 @@ You can use it to **accomplish tasks in the CLI** but **also in your editors** (

 The key to make this work seamlessly is a good default prompt that tells the model to behave like a CLI tool and not write any unwanted text like markdown formatting or explanations.

+## Voice
+
+⚠️ **Testing in progress.** I only have a Linux system and wasn't able to test the recording commands on other OSes. The good news is that you can make up your own command that works and plug it into the config.
+
+Use the `-v` flag to ask for voice input, then press space to end the recording. The transcript replaces the prompt customization arg.
+
+- uses openai whisper
+- make sure your `recording_command` field works in your terminal; it should create a wav file
+- requires you to have an openai key in your `.api_keys.toml`
+- you can still use any prompt template or text model to get your output
+
+```
+sc -v
+
+sc test -v
+
+sc test -v -c src/**/*
+```
+
+## How does it work?
+
+`smartcat` calls an external program that handles the voice recording and instructs it to save the result in a wav file. It then listens to keyboard inputs and stops the recording when space is pressed.
+
+The recording is then sent to a speech-to-text model, and the resulting transcript is added to the prompt and sent to the text model to get an answer.
+
+On Linux:
+On Mac:
+On Windows:
+
+To debug, you can check the `conversation.toml` file or listen to the `audio.wav` in the smartcat config home to see what the model heard and transcribed.
+
 ## A few examples to get started 🐈‍⬛

 ```
-sc "say hi"  # just ask
+sc "say hi"  # just ask (uses default prompt template)
+
+sc -v  # use your voice to ask (then press space to stop the recording)

 sc test  # use templated prompts
 sc test "and parametrize them"  # extend them on the fly

@@ -301,6 +346,14 @@ content ='''Write tests using pytest for the following code.
Parametrize it if a ''' ``` +```toml +url = "https://api.openai.com/v1/audio/transcriptions" +# make sure this command fit you OS and works on its own +recording_command = "arecord -f S16_LE --quiet " +model = "whisper-1" +api = "openai" +``` + see [the config setup file](./src/config/mod.rs) for more details. ## Ollama setup diff --git a/src/config/api.rs b/src/config/api.rs index b782ee0..51fdce7 100644 --- a/src/config/api.rs +++ b/src/config/api.rs @@ -7,7 +7,7 @@ use std::io::Write; use std::path::PathBuf; use std::str::FromStr; -use crate::config::{prompt::Prompt, resolve_config_path}; +use super::{prompt::Prompt, resolve_config_path}; const API_KEYS_FILE: &str = ".api_configs.toml"; diff --git a/src/config/mod.rs b/src/config/mod.rs index 798752f..4d16a53 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,13 +1,15 @@ pub mod api; pub mod prompt; +pub mod voice; + +use std::{path::PathBuf, process::Command}; use self::{ api::{api_keys_path, generate_api_keys_file, get_api_config}, prompt::{generate_prompts_file, get_prompts, prompts_path}, + voice::{generate_voice_file, voice_config_path}, }; -use crate::input_processing::is_interactive; - -use std::{path::PathBuf, process::Command}; +use crate::utils::is_interactive; pub const PLACEHOLDER_TOKEN: &str = "#[]"; @@ -39,6 +41,14 @@ pub fn ensure_config_files() -> std::io::Result<()> { generate_prompts_file()? } + if !voice_config_path().exists() { + println!( + "Voice config file not found at {:?}, generating one.\n...", + () + ); + generate_voice_file().expect("Unable to generate config files"); + }; + if !api_keys_path().exists() { println!( "API config file not found at {:?}, generating one.\n...", @@ -100,9 +110,11 @@ mod tests { api::{api_keys_path, Api, ApiConfig}, ensure_config_files, prompt::{prompts_path, Prompt}, - resolve_config_path, CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH, + resolve_config_path, + voice::{voice_config_path, VoiceConfig}, + CUSTOM_CONFIG_ENV_VAR, DEFAULT_CONFIG_PATH, }, - input_processing::IS_NONINTERACTIVE_ENV_VAR, + utils::IS_NONINTERACTIVE_ENV_VAR, }; use serial_test::serial; use std::collections::HashMap; @@ -158,9 +170,11 @@ mod tests { let api_keys_path = api_keys_path(); let prompts_path = prompts_path(); + let voice_path = voice_config_path(); assert!(!api_keys_path.exists()); assert!(!prompts_path.exists()); + assert!(!voice_path.exists()); let result = ensure_config_files(); @@ -173,6 +187,7 @@ mod tests { assert!(api_keys_path.exists()); assert!(prompts_path.exists()); + assert!(voice_path.exists()); Ok(()) } @@ -188,6 +203,7 @@ mod tests { let api_keys_path = api_keys_path(); let prompts_path = prompts_path(); + let voice_path = voice_config_path(); // Precreate files with some content let mut api_keys_file = fs::File::create(&api_keys_path)?; @@ -196,6 +212,9 @@ mod tests { let mut prompts_file = fs::File::create(&prompts_path)?; prompts_file.write_all(b"Some prompts data")?; + let mut voice_file = fs::File::create(&voice_path)?; + voice_file.write_all(b"Some voice data")?; + let result = ensure_config_files(); // Restoring the original environment variable @@ -209,6 +228,7 @@ mod tests { // Check if files still exist assert!(api_keys_path.exists()); assert!(prompts_path.exists()); + assert!(voice_path.exists()); // Check if the contents remain unchanged let mut api_keys_content = String::new(); @@ -219,6 +239,10 @@ mod tests { fs::File::open(&prompts_path)?.read_to_string(&mut prompts_content)?; assert_eq!(prompts_content, "Some prompts data".to_string()); + let mut voice_content = 
String::new(); + fs::File::open(&voice_path)?.read_to_string(&mut voice_content)?; + assert_eq!(voice_content, "Some voice data".to_string()); + Ok(()) } @@ -233,9 +257,11 @@ mod tests { let api_keys_path = api_keys_path(); let prompts_path = prompts_path(); + let voice_path = voice_config_path(); assert!(!api_keys_path.exists()); assert!(!prompts_path.exists()); + assert!(!voice_path.exists()); let result = ensure_config_files(); @@ -249,6 +275,7 @@ mod tests { // Read back the files and deserialize let api_config_contents = fs::read_to_string(&api_keys_path)?; let prompts_config_contents = fs::read_to_string(&prompts_path)?; + let voice_file_content = fs::read_to_string(&voice_path)?; // Deserialize contents to expected data structures // TODO: would be better to use `get_config` and `get_prompts` but @@ -260,7 +287,12 @@ mod tests { let prompt_config: HashMap = toml::from_str(&prompts_config_contents).expect("Failed to deserialize prompts config"); + let voice_config: VoiceConfig = + toml::from_str(&voice_file_content).expect("Failed to deserialize voice config"); + // Check if the content matches the default values + + // API assert_eq!( api_config.get(&Prompt::default().api.to_string()), Some(&ApiConfig::default()) @@ -280,12 +312,16 @@ mod tests { Some(&ApiConfig::anthropic()) ); + // Prompts let default_prompt = Prompt::default(); assert_eq!(prompt_config.get("default"), Some(&default_prompt)); let empty_prompt = Prompt::empty(); assert_eq!(prompt_config.get("empty"), Some(&empty_prompt)); + // Voice + assert_eq!(voice_config, VoiceConfig::default()); + Ok(()) } } diff --git a/src/config/voice.rs b/src/config/voice.rs new file mode 100644 index 0000000..d994ebf --- /dev/null +++ b/src/config/voice.rs @@ -0,0 +1,61 @@ +use serde::{Deserialize, Serialize}; +use std::fs; +use std::io::Write; +use std::path::PathBuf; + +const VOICE_CONFIG_FILE: &str = "voice.toml"; +pub const AUDIO_FILE_PATH_PLACEHOLDER: &str = ""; + +use super::{api::Api, resolve_config_path}; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] +pub struct VoiceConfig { + pub url: String, + pub recording_command: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + pub api: Api, +} + +impl Default for VoiceConfig { + fn default() -> Self { + let recording_command: String = match std::env::consts::OS { + "windows" => format!("sox -t waveaudio 0 -d {}", AUDIO_FILE_PATH_PLACEHOLDER), + "macos" => format!("sox -t waveaudio 0 -d {}", AUDIO_FILE_PATH_PLACEHOLDER), + "linux" => format!("arecord -f S16_LE --quiet {}", AUDIO_FILE_PATH_PLACEHOLDER), + os => panic!("Unexpected os: {}", os), + }; + + VoiceConfig { + url: "https://api.openai.com/v1/audio/transcriptions".to_string(), + recording_command, + model: Some("whisper-1".to_string()), + api: Api::Openai, + } + } +} + +pub(super) fn voice_config_path() -> PathBuf { + resolve_config_path().join(VOICE_CONFIG_FILE) +} + +pub(super) fn generate_voice_file() -> std::io::Result<()> { + let voice_config = VoiceConfig::default(); + + std::fs::create_dir_all(voice_config_path().parent().unwrap())?; + + let mut voice_config_file = fs::File::create(voice_config_path())?; + + let voice_config_str = toml::to_string_pretty(&voice_config) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + voice_config_file.write_all(voice_config_str.as_bytes())?; + Ok(()) +} + +pub fn get_voice_config() -> VoiceConfig { + let content = fs::read_to_string(voice_config_path()).unwrap_or_else(|error| { + panic!("Could not read file {:?}, {:?}", 
voice_config_path(), error) + }); + + toml::from_str(&content).expect("Unble to parse voice file content into config struct") +} diff --git a/src/main.rs b/src/main.rs index 0355a01..097b269 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ mod config; -mod input_processing; mod prompt_customization; -mod third_party; +mod text; +mod utils; mod voice; use crate::config::{ @@ -12,11 +12,12 @@ use crate::config::{ use prompt_customization::customize_prompt; use clap::{Args, Parser}; -use input_processing::process_input_with_request; use log::debug; use std::fs; use std::io::{self, IsTerminal, Read, Write}; +use text::process_input_with_request; + const DEFAULT_PROMPT_NAME: &str = "default"; #[derive(Debug, Parser)] @@ -28,6 +29,10 @@ const DEFAULT_PROMPT_NAME: &str = "default"; long_about = None, after_help = "Examples: ========= + +- sc +- sc + - sc \"say hi\" # just ask - sc test # use templated prompts @@ -42,17 +47,17 @@ const DEFAULT_PROMPT_NAME: &str = "default"; " )] struct Cli { - /// ref to a prompt from config or straight input (will use `default` prompt template) - input_or_config_ref: Option, - /// if the first arg matches a config ref, the second will be used as input - input_if_config_ref: Option, + /// ref to a prompt template from config or straight input (will use `default` prompt template if input) + input_or_template_ref: Option, + /// if the first arg matches a config template, the second will be used as input + input_if_template_ref: Option, /// whether to extend the previous conversation or start a new one #[arg(short, long)] extend_conversation: bool, /// whether to repeat the input before the output, useful to extend instead of replacing #[arg(short, long)] repeat_input: bool, - /// whether to use voice for input, requires may require admin permissions + /// whether to use voice for input #[arg(short, long)] voice: bool, #[command(flatten)] @@ -118,17 +123,17 @@ fn main() { get_default_and_or_custom_prompt(&args, &mut prompt_customizaton_text) } else { if args.voice { - if args.input_or_config_ref.is_some() { + if args.input_or_template_ref.is_some() { panic!( - "Invalid parameters, when using voice, either provide a valid ref to a config prompt or nothing at all.\n\ - Use `sc -v ` or `sc -v`" + "Invalid parameters, when extending conversation and using voice you can't provide additional input args.\n\ + Use `sc -e -v`" ); } - prompt_customizaton_text = voice::get_voice_transcript(); + prompt_customizaton_text = voice::record_voice_and_get_transcript(); } else { - prompt_customizaton_text = args.input_or_config_ref; + prompt_customizaton_text = args.input_or_template_ref; } - if args.input_if_config_ref.is_some() { + if args.input_if_template_ref.is_some() { panic!( "Invalid parameters, cannot provide a config ref when extending a conversation.\n\ Use `sc -e \".\"`" @@ -189,12 +194,12 @@ fn get_default_and_or_custom_prompt( ); let input_or_config_ref = args - .input_or_config_ref + .input_or_template_ref .clone() .unwrap_or_else(|| String::from("default")); if let Some(prompt) = prompts.remove(&input_or_config_ref) { - if args.input_if_config_ref.is_some() { + if args.input_if_template_ref.is_some() { // first arg matching a prompt and second one is customization if args.voice { panic!( @@ -202,15 +207,15 @@ fn get_default_and_or_custom_prompt( Use `sc -v ` or `sc -v`" ); } - *prompt_customization_text = args.input_if_config_ref.clone() + *prompt_customization_text = args.input_if_template_ref.clone() } if args.voice { - *prompt_customization_text = 
voice::get_voice_transcript(); + *prompt_customization_text = voice::record_voice_and_get_transcript(); } prompt } else { *prompt_customization_text = Some(input_or_config_ref); - if args.input_if_config_ref.is_some() { + if args.input_if_template_ref.is_some() { // first arg isn't a prompt and a second one was provided panic!( "Invalid parameters, either provide a valid ref to a config prompt then an input, or only an input.\n\ diff --git a/src/third_party/mod.rs b/src/text/api_call.rs similarity index 52% rename from src/third_party/mod.rs rename to src/text/api_call.rs index 653f5e0..c2746da 100644 --- a/src/third_party/mod.rs +++ b/src/text/api_call.rs @@ -1,16 +1,11 @@ -mod prompt_adapters; -mod response_parsing; +use super::request_schemas::{AnthropicPrompt, OpenAiPrompt}; +use super::response_schemas::{AnthropicResponse, OllamaResponse, OpenAiResponse}; -use self::prompt_adapters::{AnthropicPrompt, OpenAiPrompt}; -use self::response_parsing::{AnthropicResponse, OllamaResponse, OpenAiResponse}; -use crate::input_processing::is_interactive; -use crate::{ - config::{ - api::{Api, ApiConfig}, - prompt::{Message, Prompt}, - }, - input_processing::read_user_input, +use crate::config::{ + api::{Api, ApiConfig}, + prompt::{Message, Prompt}, }; +use crate::utils::handle_api_response; use log::debug; @@ -23,13 +18,15 @@ enum PromptFormat { Anthropic(AnthropicPrompt), } -pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> reqwest::Result { +pub fn post_prompt_and_get_answer( + api_config: ApiConfig, + prompt: &Prompt, +) -> reqwest::Result { debug!( "Trying to reach {:?} with key {:?}", api_config.url, api_config.api_key ); debug!("Prompt: {:?}", prompt); - validate_prompt_size(prompt); let mut prompt = prompt.clone(); @@ -82,47 +79,3 @@ pub fn make_api_request(api_config: ApiConfig, prompt: &Prompt) -> reqwest::Resu }; Ok(Message::assistant(&response_text)) } - -/// clean error management -pub fn handle_api_response>( - response: reqwest::blocking::Response, -) -> String { - let status = response.status(); - if response.status().is_success() { - response.json::().unwrap().into() - } else { - let error_text = response.text().unwrap(); - panic!("API request failed with status {}: {}", status, error_text); - } -} - -fn validate_prompt_size(prompt: &Prompt) { - let char_limit = prompt.char_limit.unwrap_or_default(); - let number_of_chars: u32 = prompt - .messages - .iter() - .map(|message| message.content.len() as u32) - .sum(); - - debug!("Number of chars is prompt: {}", number_of_chars); - - if char_limit > 0 && number_of_chars > char_limit { - if is_interactive() { - println!( - "The number of chars in the input {} is greater than the set limit {}\n\ - Do you want to continue? High costs may ensue.\n[Y/n]", - number_of_chars, char_limit, - ); - let input = read_user_input(); - if input.trim() != "Y" { - println!("exiting..."); - std::process::exit(0); - } - } else { - panic!( - "Input {} larger than limit {} in non-interactive mode. 
Exiting.", - number_of_chars, char_limit - ); - } - } -} diff --git a/src/input_processing.rs b/src/text/mod.rs similarity index 63% rename from src/input_processing.rs rename to src/text/mod.rs index 69c64b3..83d05cc 100644 --- a/src/input_processing.rs +++ b/src/text/mod.rs @@ -1,23 +1,15 @@ +mod api_call; +mod request_schemas; +mod response_schemas; + use log::debug; use std::io::{Result, Write}; +use self::api_call::post_prompt_and_get_answer; use crate::config::{api::get_api_config, prompt::Prompt, PLACEHOLDER_TOKEN}; -use crate::third_party::make_api_request; - -pub const IS_NONINTERACTIVE_ENV_VAR: &str = "SMARTCAT_NONINTERACTIVE"; - -pub fn is_interactive() -> bool { - std::env::var(IS_NONINTERACTIVE_ENV_VAR).unwrap_or_default() != "1" -} - -pub fn read_user_input() -> String { - let mut user_input = String::new(); - std::io::stdin() - .read_line(&mut user_input) - .expect("Failed to read line"); - user_input.trim().to_string() -} +use crate::utils::{is_interactive, read_user_input}; +/// insert the input in the prompt, validate the length and make the request pub fn process_input_with_request( mut prompt: Prompt, mut input: String, @@ -31,7 +23,8 @@ pub fn process_input_with_request( // fetch the api config tied to the prompt let api_config = get_api_config(&prompt.api.to_string()); - let response_message = match make_api_request(api_config, &prompt) { + validate_prompt_size(&prompt); + let response_message = match post_prompt_and_get_answer(api_config, &prompt) { Ok(message) => message, Err(e) => { eprintln!("Failed to make API request: {:?}", e); @@ -52,6 +45,37 @@ pub fn process_input_with_request( Ok(prompt) } +fn validate_prompt_size(prompt: &Prompt) { + let char_limit = prompt.char_limit.unwrap_or_default(); + let number_of_chars: u32 = prompt + .messages + .iter() + .map(|message| message.content.len() as u32) + .sum(); + + debug!("Number of chars is prompt: {}", number_of_chars); + + if char_limit > 0 && number_of_chars > char_limit { + if is_interactive() { + println!( + "The number of chars in the input {} is greater than the set limit {}\n\ + Do you want to continue? High costs may ensue.\n[Y/n]", + number_of_chars, char_limit, + ); + let input = read_user_input(); + if input.trim() != "Y" { + println!("exiting..."); + std::process::exit(0); + } + } else { + panic!( + "Input {} larger than limit {} in non-interactive mode. 
Exiting.", + number_of_chars, char_limit + ); + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/third_party/prompt_adapters.rs b/src/text/request_schemas.rs similarity index 100% rename from src/third_party/prompt_adapters.rs rename to src/text/request_schemas.rs diff --git a/src/third_party/response_parsing.rs b/src/text/response_schemas.rs similarity index 100% rename from src/third_party/response_parsing.rs rename to src/text/response_schemas.rs diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..1ce813a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,26 @@ +pub const IS_NONINTERACTIVE_ENV_VAR: &str = "SMARTCAT_NONINTERACTIVE"; + +/// clean error logging +pub fn handle_api_response>( + response: reqwest::blocking::Response, +) -> String { + let status = response.status(); + if response.status().is_success() { + response.json::().unwrap().into() + } else { + let error_text = response.text().unwrap(); + panic!("API request failed with status {}: {}", status, error_text); + } +} + +pub fn is_interactive() -> bool { + std::env::var(IS_NONINTERACTIVE_ENV_VAR).unwrap_or_default() != "1" +} + +pub fn read_user_input() -> String { + let mut user_input = String::new(); + std::io::stdin() + .read_line(&mut user_input) + .expect("Failed to read line"); + user_input.trim().to_string() +} diff --git a/src/voice/api_call.rs b/src/voice/api_call.rs new file mode 100644 index 0000000..10645f1 --- /dev/null +++ b/src/voice/api_call.rs @@ -0,0 +1,20 @@ +use crate::config::prompt::audio_file_path; +use crate::utils::handle_api_response; + +use super::response_schemas::OpenAiVoiceResponse; + +pub(super) fn post_audio_and_get_transcript(api_key: &str) -> reqwest::Result { + let client = reqwest::blocking::Client::new(); + let form = reqwest::blocking::multipart::Form::new() + .text("model", "whisper-1") + .file("file", audio_file_path()) + .expect("Failed to read audio file."); + + let response = client + .post("https://api.openai.com/v1/audio/transcriptions") + .bearer_auth(api_key) + .multipart(form) + .send()?; + + Ok(handle_api_response::(response)) +} diff --git a/src/voice/mod.rs b/src/voice/mod.rs index b084311..633aa32 100644 --- a/src/voice/mod.rs +++ b/src/voice/mod.rs @@ -1,49 +1,33 @@ -pub mod schemas; +mod api_call; +mod recording; +mod response_schemas; use device_query::{DeviceQuery, DeviceState, Keycode}; -use std::process::{Child, Command}; +use log::debug; +use std::time::Instant; -use self::schemas::{OpenAiVoiceResponse, VoiceConfig}; -use super::config::{api::get_api_config, prompt::audio_file_path}; -use super::third_party::handle_api_response; +use crate::config::api::get_api_config; +use crate::config::prompt::audio_file_path; +use crate::config::voice::{get_voice_config, AUDIO_FILE_PATH_PLACEHOLDER}; -fn start_recording() -> Option { - let os_string = audio_file_path().into_os_string(); - let audio_file_path = os_string; +use self::api_call::post_audio_and_get_transcript; +use self::recording::{start_recording, stop_recording}; - match std::env::consts::OS { - "windows" => Command::new("cmd") - .args(["/C", "start", "rec"]) - .arg(audio_file_path) - .spawn() - .ok(), - "macos" => Command::new("sox") - .arg("-d") - .arg(audio_file_path) - .spawn() - .ok(), - "linux" => Command::new("arecord") - .arg("-f") - .arg("S16_LE") - .arg("--quiet") - .arg(audio_file_path) - .spawn() - .ok(), - os => panic!("Unexpected os: {}", os), - } -} +pub fn record_voice_and_get_transcript() -> Option { + let voice_config = get_voice_config(); -fn 
stop_recording(process: &mut Child) { - process.kill().expect("Failed to stop recording."); -} + let recording_command = voice_config.recording_command.replace( + AUDIO_FILE_PATH_PLACEHOLDER, + audio_file_path() + .to_str() + .expect("Unable to parse audio file path to str."), + ); -pub fn get_voice_transcript() -> Option { - use std::time::Instant; + let mut process = start_recording(recording_command)?; - let mut process = start_recording()?; let device_state = DeviceState::new(); - let start_time = Instant::now(); + let start_time = Instant::now(); loop { let keys: Vec = device_state.get_keys(); if keys.contains(&Keycode::Space) { @@ -53,31 +37,16 @@ pub fn get_voice_transcript() -> Option { } stop_recording(&mut process); - let duration = start_time.elapsed(); - println!("Recording duration: {:?}", duration); - std::thread::sleep(std::time::Duration::from_secs(1)); - let voice_config = VoiceConfig::default(); + debug!("Recording duration: {:?}", start_time.elapsed()); - let api_config = get_api_config(&voice_config.api.to_string()); + std::thread::sleep(std::time::Duration::from_millis(250)); - let transcript = - post_audio(&api_config.get_api_key()).expect("Failed to send audio file to API"); - - Some(transcript) -} + let api_config = get_api_config(&voice_config.api.to_string()); -fn post_audio(api_key: &str) -> reqwest::Result { - let client = reqwest::blocking::Client::new(); - let form = reqwest::blocking::multipart::Form::new() - .text("model", "whisper-1") - .file("file", audio_file_path()) - .expect("Failed to read audio file."); + let transcript = post_audio_and_get_transcript(&api_config.get_api_key()) + .expect("Failed to send audio file to API"); - let response = client - .post("https://api.openai.com/v1/audio/transcriptions") - .bearer_auth(api_key) - .multipart(form) - .send()?; + debug!("Audio transcript: {}", transcript); - Ok(handle_api_response::(response)) + Some(transcript) } diff --git a/src/voice/recording.rs b/src/voice/recording.rs new file mode 100644 index 0000000..0d3cfc3 --- /dev/null +++ b/src/voice/recording.rs @@ -0,0 +1,13 @@ +use std::process::{Child, Command}; + +pub(super) fn start_recording(recording_command: String) -> Option { + // default commands for each os are defined in src/config/voice.rs + Command::new(recording_command.split_whitespace().next().unwrap()) + .args(recording_command.split_whitespace().skip(1)) + .spawn() + .ok() +} + +pub(super) fn stop_recording(process: &mut Child) { + process.kill().expect("Failed to stop recording."); +} diff --git a/src/voice/response_schemas.rs b/src/voice/response_schemas.rs new file mode 100644 index 0000000..4508f71 --- /dev/null +++ b/src/voice/response_schemas.rs @@ -0,0 +1,12 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct OpenAiVoiceResponse { + text: String, +} + +impl From for String { + fn from(response: OpenAiVoiceResponse) -> Self { + response.text + } +} diff --git a/src/voice/schemas.rs b/src/voice/schemas.rs deleted file mode 100644 index 019219b..0000000 --- a/src/voice/schemas.rs +++ /dev/null @@ -1,35 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use crate::config::api::Api; - -#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] -pub(super) struct VoiceConfig { - pub url: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub recording_command: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub model: Option, - pub api: Api, -} - -impl Default for VoiceConfig { - fn default() -> Self { - VoiceConfig { - url: 
"https://api.openai.com/v1/audio/transcriptions".to_string(), - recording_command: None, - model: Some("whisper-1".to_string()), - api: Api::Openai, - } - } -} - -#[derive(Debug, Deserialize)] -pub struct OpenAiVoiceResponse { - text: String, -} - -impl From for String { - fn from(response: OpenAiVoiceResponse) -> Self { - response.text - } -}