feat!: clean up naming of methods to setup clients

moldhouse committed Dec 10, 2024
1 parent b9a2fd2 commit 5f18f38
Showing 4 changed files with 31 additions and 37 deletions.
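
In short, HttpClient::with_base_url becomes HttpClient::new, and the fixed-token constructor Client::with_base_url becomes Client::with_auth, while Client::new keeps its name and signature. The sketch below shows how downstream call sites migrate; it is not part of the diff, the host and token values are placeholders, and the crate is assumed to be importable as aleph_alpha_client.

use aleph_alpha_client::Client;

fn main() {
    // Placeholder host and token, mirroring the dummy values used in the tests below.
    let host = "https://inference-api.example.com";
    let token = "dummy-token";

    // Before this commit: Client::with_base_url(host, token).unwrap()
    // After: the constructor that pins one token for all requests is named with_auth.
    let client = Client::with_auth(host, token).unwrap();

    // Client::new keeps its name but now calls HttpClient::new internally;
    // passing None defers authentication to individual requests.
    let per_request_client = Client::new(host, None).unwrap();

    let _ = (client, per_request_client);
}
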
src/http.rs: 2 changes (1 addition, 1 deletion)
@@ -92,7 +92,7 @@ pub struct HttpClient {
impl HttpClient {
/// In production you typically would want set this to <https://inference-api.pharia.your-company.com>.
/// Yet you may want to use a different instance for testing.
pub fn with_base_url(host: String, api_token: Option<String>) -> Result<Self, Error> {
pub fn new(host: String, api_token: Option<String>) -> Result<Self, Error> {
let http = ClientBuilder::new().build()?;

Ok(Self {
src/lib.rs: 16 changes (5 additions, 11 deletions)
@@ -71,33 +71,27 @@ pub struct Client {

impl Client {
/// A new instance of an Aleph Alpha client helping you interact with the Aleph Alpha API.
/// For "normal" client applications you may likely rather use [`Self::with_base_url`].
///
/// Setting the token to None allows specifying it on a per request basis.
/// You may want to only use request based authentication and skip default authentication. This
/// is useful if writing an application which invokes the client on behalf of many different
/// users. Having neither request, nor default authentication is considered a bug and will cause
/// a panic.
pub fn new(host: impl Into<String>, api_token: Option<String>) -> Result<Self, Error> {
let http_client = HttpClient::with_base_url(host.into(), api_token)?;
let http_client = HttpClient::new(host.into(), api_token)?;
Ok(Self { http_client })
}

/// Use your on-premise inference with your API token for all requests.
///
/// In production you typically would want set this to <https://inference-api.pharia.your-company.com>.
/// Yet you may want to use a different instance for testing.
pub fn with_base_url(
host: impl Into<String>,
api_token: impl Into<String>,
) -> Result<Self, Error> {
/// A client instance that always uses the same token for all requests.
pub fn with_auth(host: impl Into<String>, api_token: impl Into<String>) -> Result<Self, Error> {
Self::new(host, Some(api_token.into()))
}

pub fn from_env() -> Result<Self, Error> {
let _ = dotenv();
let api_token = env::var("PHARIA_AI_TOKEN").unwrap();
let inference_url = env::var("INFERENCE_URL").unwrap();
Self::with_base_url(inference_url, api_token)
Self::with_auth(inference_url, api_token)
}

/// Execute a task with the aleph alpha API and fetch its result.
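
The from_env constructor shown above keeps its name and its environment contract (PHARIA_AI_TOKEN and INFERENCE_URL, optionally loaded from a .env file); only its internals now go through with_auth. A minimal usage sketch, assuming both variables are already set in the environment:

use aleph_alpha_client::Client;

fn main() {
    // Reads PHARIA_AI_TOKEN and INFERENCE_URL; the current implementation unwraps
    // env::var, so this panics if either variable is missing.
    let _client = Client::from_env().unwrap();
}
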
tests/integration.rs: 38 changes (19 additions, 19 deletions)
@@ -36,7 +36,7 @@ async fn chat_with_pharia_1_7b_base() {
let task = TaskChat::with_message(message);

let model = "pharia-1-llm-7b-control";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client.chat(&task, model, &How::default()).await.unwrap();

// Then
@@ -49,7 +49,7 @@ async fn completion_with_luminous_base() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -67,7 +67,7 @@ async fn request_authentication_has_priority() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

let model = "luminous-base";
let client = Client::with_base_url(inference_url(), bad_pharia_ai_token).unwrap();
let client = Client::with_auth(inference_url(), bad_pharia_ai_token).unwrap();
let response = client
.output_of(
&task.with_model(model),
@@ -140,7 +140,7 @@ async fn semanitc_search_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);
let query = Prompt::from_text("What is Pizza?");
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let robot_embedding_task = TaskSemanticEmbedding {
@@ -203,7 +203,7 @@ async fn complete_structured_prompt() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -232,7 +232,7 @@ async fn maximum_tokens_none_request() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -253,7 +253,7 @@ async fn explain_request() {
target: " How is it going?",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence),
};
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -283,7 +283,7 @@ async fn explain_request_with_auto_granularity() {
target: " How is it going?",
granularity: Granularity::default(),
};
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -315,7 +315,7 @@ async fn explain_request_with_image_modality() {
target: " a cat.",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph),
};
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -365,7 +365,7 @@ async fn describe_image_starting_from_a_path() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -394,7 +394,7 @@ async fn describe_image_starting_from_a_dyn_image() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -420,7 +420,7 @@ async fn only_answer_with_specific_animal() {
},
};
let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -447,7 +447,7 @@ async fn answer_should_continue() {
},
};
let model = "luminous-base";
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -474,7 +474,7 @@ async fn batch_semanitc_embed_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);

let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let embedding_task = TaskBatchSemanticEmbedding {
@@ -499,7 +499,7 @@ async fn tokenization_with_luminous_base() {
// Given
let input = "Hello, World!";

let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let task1 = TaskTokenization::new(input, false, true);
@@ -536,7 +536,7 @@ async fn detokenization_with_luminous_base() {
// Given
let input = vec![49222, 15, 5390, 4];

let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let task = TaskDetokenization { token_ids: &input };
@@ -553,7 +553,7 @@ async fn detokenization_with_luminous_base() {
#[tokio::test]
async fn fetch_tokenizer_for_pharia_1_llm_7b() {
// Given
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();

// When
let tokenizer = client
@@ -568,7 +568,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
#[tokio::test]
async fn stream_completion() {
// Given a streaming completion task
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let task = TaskCompletion::from_text("").with_maximum_tokens(7);

// When the events are streamed and collected
@@ -601,7 +601,7 @@ async fn stream_completion() {
#[tokio::test]
async fn stream_chat_with_pharia_1_llm_7b() {
// Given a streaming completion task
let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let message = Message::user("Hello,");
let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7);

tests/unit.rs: 12 changes (6 additions, 6 deletions)
@@ -34,7 +34,7 @@ async fn completion_with_luminous_base() {
// When
let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1);
let model = "luminous-base";
let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap();
let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -74,7 +74,7 @@ async fn detect_rate_limiting() {
// When
let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1);
let model = "luminous-base";
let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap();
let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap();
let error = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -118,7 +118,7 @@ async fn detect_queue_full() {
// When
let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1);
let model = "luminous-base";
let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap();
let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap();
let error = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -155,7 +155,7 @@ async fn detect_service_unavailable() {
// When
let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1);
let model = "luminous-base";
let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap();
let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap();
let error = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -177,7 +177,7 @@ async fn be_nice() {
// When
let task = TaskCompletion::from_text("Hello,").with_maximum_tokens(1);
let model = "luminous-base";
let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap();
let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap();
// Drop result, answer is meaningless anyway
let _ = client
.output_of(
@@ -206,7 +206,7 @@ async fn client_timeout() {
.respond_with(ResponseTemplate::new(StatusCode::OK).set_delay(response_time))
.mount(&mock_server)
.await;
let client = Client::with_base_url(mock_server.uri(), "dummy-token").unwrap();
let client = Client::with_auth(mock_server.uri(), "dummy-token").unwrap();

// When
let result = client
