Skip to content

Commit

Permalink
feat: Added Mistral/Codestral custom providers.
Browse files Browse the repository at this point in the history
  • Loading branch information
HavenDV committed Oct 2, 2024
1 parent 9dbf6ab commit ab3dacd
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ using var api = CustomProviders.OpenRouter("API_KEY");
using var api = CustomProviders.Together("API_KEY");
using var api = CustomProviders.Perplexity("API_KEY");
using var api = CustomProviders.SambaNova("API_KEY");
using var api = CustomProviders.Mistral("API_KEY");
using var api = CustomProviders.Codestral("API_KEY");
using var api = CustomProviders.Ollama();
using var api = CustomProviders.LmStudio();
```
Expand Down
28 changes: 28 additions & 0 deletions src/libs/OpenAI/CustomProviders.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ public static class CustomProviders
/// </summary>
public const string SambaNovaBaseUrl = "https://api.sambanova.ai/v1";

/// <summary>
///
/// </summary>
public const string MistralBaseUrl = "https://api.mistral.ai/v1";

/// <summary>
///
/// </summary>
public const string CodestralBaseUrl = "https://codestral.mistral.ai/v1";

/// <summary>
///
/// </summary>
Expand Down Expand Up @@ -151,6 +161,24 @@ public static OpenAiApi SambaNova(string apiKey)
return new OpenAiApi(apiKey, baseUri: new Uri(SambaNovaBaseUrl));
}

/// <summary>
/// Create an API to use for Mistral.
/// </summary>
/// <returns></returns>
public static OpenAiApi Mistral(string apiKey)
{
return new OpenAiApi(apiKey, baseUri: new Uri(MistralBaseUrl));
}

/// <summary>
/// Create an API to use for Codestral.
/// </summary>
/// <returns></returns>
public static OpenAiApi Codestral(string apiKey)
{
return new OpenAiApi(apiKey, baseUri: new Uri(CodestralBaseUrl));
}

/// <summary>
/// Create an API to use for Ollama.
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/tests/OpenAI.IntegrationTests/CustomProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ public enum CustomProvider
Ollama,
LmStudio,
Groq,
Mistral,
Codestral,
}
18 changes: 16 additions & 2 deletions src/tests/OpenAI.IntegrationTests/Tests.Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ public partial class Tests
[DataRow(CustomProvider.Ollama)]
[DataRow(CustomProvider.LmStudio)]
[DataRow(CustomProvider.Groq)]
[DataRow(CustomProvider.Mistral)]
[DataRow(CustomProvider.Codestral)]
public async Task GenerateFiveRandomWords(CustomProvider customProvider)
{
var pair = GetAuthorizedChatApi(customProvider);
Expand All @@ -23,12 +25,17 @@ public async Task GenerateFiveRandomWords(CustomProvider customProvider)
string response = await api.Chat.CreateChatCompletionAsync(
messages: ["Generate five random words."],
model: pair.Model,
user: "tryAGI.OpenAI.IntegrationTests.Tests.CreateChatCompletion",
user: customProvider switch
{
CustomProvider.Mistral or CustomProvider.Codestral => null,
_ => "tryAGI.OpenAI.Tests.GenerateFiveRandomWords",
},
frequencyPenalty: customProvider switch
{
CustomProvider.Perplexity => 0.5,
_ => null,
},
presencePenalty: null,
logprobs: null);
response.Should().NotBeEmpty();

Expand All @@ -47,6 +54,8 @@ public async Task GenerateFiveRandomWords(CustomProvider customProvider)
[DataRow(CustomProvider.Ollama)]
[DataRow(CustomProvider.LmStudio)]
[DataRow(CustomProvider.Groq)]
[DataRow(CustomProvider.Mistral)]
//[DataRow(CustomProvider.Codestral)]
public async Task GenerateFiveRandomWordsAsStream(CustomProvider customProvider)
{
var pair = GetAuthorizedChatApi(customProvider);
Expand All @@ -55,7 +64,12 @@ public async Task GenerateFiveRandomWordsAsStream(CustomProvider customProvider)
var enumerable = api.Chat.CreateChatCompletionAsStreamAsync(
messages: ["Generate five random words."],
model: pair.Model,
user: "tryAGI.OpenAI.IntegrationTests.Tests.CreateChatCompletion");
presencePenalty: null,
user: customProvider switch
{
CustomProvider.Mistral or CustomProvider.Codestral => null,
_ => "tryAGI.OpenAI.Tests.GenerateFiveRandomWordsAsStream",
});

await foreach (string response in enumerable)
{
Expand Down
14 changes: 14 additions & 0 deletions src/tests/OpenAI.IntegrationTests/Tests.Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ internal static (OpenAiApi Api, string Model) GetAuthorizedChatApi(CustomProvide
throw new AssertInconclusiveException("SAMBANOVA_API_KEY environment variable is not found.")),
model ?? "Meta-Llama-3.1-8B-Instruct");
}
if (customProvider == CustomProvider.Mistral)
{
return (CustomProviders.Mistral(apiKey:
Environment.GetEnvironmentVariable("MISTRAL_API_KEY") ??
throw new AssertInconclusiveException("MISTRAL_API_KEY environment variable is not found.")),
model ?? "mistral-large-latest");
}
if (customProvider == CustomProvider.Codestral)
{
return (CustomProviders.Codestral(apiKey:
Environment.GetEnvironmentVariable("CODESTRAL_API_KEY") ??
throw new AssertInconclusiveException("CODESTRAL_API_KEY environment variable is not found.")),
model ?? "codestral-latest");
}

if (customProvider == CustomProvider.Ollama)
{
Expand Down

0 comments on commit ab3dacd

Please sign in to comment.