Skip to content

Commit

Permalink
feat: Multi texts translation per request
Browse files Browse the repository at this point in the history
  • Loading branch information
brenoepics committed Feb 1, 2024
1 parent 23e64da commit bf5f4e1
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 77 deletions.
3 changes: 2 additions & 1 deletion src/main/java/io/github/brenoepics/at4j/AzureApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.github.brenoepics.at4j.data.response.TranslationResponse;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

Expand Down Expand Up @@ -72,7 +73,7 @@ public interface AzureApi {
* @param params The {@link TranslateParams} to translate.
* @return The {@link TranslationResponse} containing the translation.
*/
CompletableFuture<Optional<TranslationResponse>> translate(TranslateParams params);
CompletableFuture<Optional<List<TranslationResponse>>> translate(TranslateParams params);

/**
* Gets the available languages for translation.
Expand Down
35 changes: 7 additions & 28 deletions src/main/java/io/github/brenoepics/at4j/core/AzureApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import io.github.brenoepics.at4j.util.rest.RestEndpoint;
import io.github.brenoepics.at4j.util.rest.RestMethod;
import io.github.brenoepics.at4j.util.rest.RestRequest;
import io.github.brenoepics.at4j.util.rest.RestRequestResult;

import java.net.http.HttpClient;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

Expand Down Expand Up @@ -89,42 +91,19 @@ public ThreadPool getThreadPool() {
}

@Override
public CompletableFuture<Optional<TranslationResponse>> translate(TranslateParams params) {
if (params.getText() == null || params.getText().isEmpty()) {
public CompletableFuture<Optional<List<TranslationResponse>>> translate(TranslateParams params) {
if (params.getTexts() == null || params.getTexts().isEmpty()) {
return CompletableFuture.completedFuture(Optional.empty());
}

RestRequest<Optional<TranslationResponse>> request =
new RestRequest<Optional<TranslationResponse>>(
RestRequest<Optional<List<TranslationResponse>>> request =
new RestRequest<Optional<List<TranslationResponse>>>(
this, RestMethod.POST, RestEndpoint.TRANSLATE)
.setBody(params.getBody());
params.getQueryParameters().forEach(request::addQueryParameter);
params.getTargetLanguages().forEach(lang -> request.addQueryParameter("to", lang));

return request.execute(
response -> {
if (response.getJsonBody().isNull()
|| !response.getJsonBody().has(0)
|| !response.getJsonBody().get(0).has("translations")) return Optional.empty();

JsonNode jsonNode = response.getJsonBody().get(0);
Collection<Translation> translations = new ArrayList<>();
jsonNode
.get("translations")
.forEach(node -> translations.add(Translation.ofJSON((ObjectNode) node)));

TranslationResponse translationResponse;
if (jsonNode.has("detectedLanguage")) {
JsonNode detectedLanguage = jsonNode.get("detectedLanguage");
translationResponse =
new TranslationResponse(
DetectedLanguage.ofJSON((ObjectNode) detectedLanguage), translations);
} else {
translationResponse = new TranslationResponse(translations);
}

return Optional.of(translationResponse);
});
return request.execute(params::handleTranslations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.github.brenoepics.at4j.azure.lang.Language;
import io.github.brenoepics.at4j.data.DetectedLanguage;
import io.github.brenoepics.at4j.data.Translation;
import io.github.brenoepics.at4j.data.request.optional.ProfanityAction;
import io.github.brenoepics.at4j.data.request.optional.ProfanityMarker;
import io.github.brenoepics.at4j.data.request.optional.TextType;
import io.github.brenoepics.at4j.data.response.TranslationResponse;
import io.github.brenoepics.at4j.util.rest.RestRequestResult;

import java.util.*;
import java.util.stream.Collectors;

Expand All @@ -18,7 +23,7 @@
*/
public class TranslateParams {
// The text to be translated
private String text;
private LinkedHashMap<Integer, String> toTranslate;
// The type of the text to be translated (plain or HTML)
private TextType textType;
// The action to be taken on profanities in the text
Expand All @@ -43,18 +48,32 @@ public class TranslateParams {
* @param targetLanguages The target languages for the translation.
*/
public TranslateParams(String text, Collection<String> targetLanguages) {
this.text = text;
this.toTranslate = new LinkedHashMap<>();
this.toTranslate.put(1, text);
this.targetLanguages = targetLanguages;
}

/**
* Constructor that initializes the text to be translated.
*
* @param texts The text list to be translated.
* @param targetLanguages The target languages for the translation.
*/
public TranslateParams(Collection<String> texts, Collection<String> targetLanguages) {
this.toTranslate = new LinkedHashMap<>();
texts.forEach(t -> this.toTranslate.put(this.toTranslate.size() + 1, t));
this.targetLanguages = targetLanguages;
}

/**
* Sets the text to be translated.
*
* @param text The text to be translated.
* @param texts The texts to be translated.
* @return This instance.
*/
public TranslateParams setText(String text) {
this.text = text;
public TranslateParams setTexts(Collection<String> texts) {
this.toTranslate = new LinkedHashMap<>();
texts.forEach(t -> this.toTranslate.put(this.toTranslate.size() + 1, t));
return this;
}

Expand Down Expand Up @@ -167,10 +186,7 @@ public TranslateParams setSourceLanguage(String sourceLanguage) {
*/
public TranslateParams setTargetLanguages(Collection<Language> targetLanguages) {
this.targetLanguages =
Collections.unmodifiableCollection(
targetLanguages.stream()
.map(Language::getCode)
.collect(Collectors.toCollection(ArrayList::new)));
targetLanguages.stream().map(Language::getCode).collect(Collectors.toUnmodifiableList());
return this;
}

Expand All @@ -185,8 +201,8 @@ public TranslateParams setTargetLanguages(String... targetLanguages) {
return this;
}

public String getText() {
return text;
public Map<Integer, String> getTexts() {
return toTranslate;
}

public Boolean getIncludeAlignment() {
Expand Down Expand Up @@ -264,11 +280,45 @@ public Map<String, String> getQueryParameters() {
*/
public JsonNode getBody() {
ArrayNode body = JsonNodeFactory.instance.arrayNode();
if (getText() != null && !getText().isEmpty()) {

for (String text : getTexts().values()) {
ObjectNode textNode = JsonNodeFactory.instance.objectNode();
textNode.put("Text", getText());
textNode.put("Text", text);
body.add(textNode);
}
return body;
}

public Optional<List<TranslationResponse>> handleTranslations(
RestRequestResult<Optional<List<TranslationResponse>>> response) {
if (response.getJsonBody().isNull() || response.getJsonBody().isEmpty())
return Optional.empty();

List<TranslationResponse> responses = new ArrayList<>();
getTexts()
.forEach(
(index, baseText) -> {
JsonNode jsonNode = response.getJsonBody().get(index - 1);
if (!jsonNode.has("translations")) return;

Collection<Translation> translations = new ArrayList<>();
jsonNode
.get("translations")
.forEach(node -> translations.add(Translation.ofJSON((ObjectNode) node)));

if (jsonNode.has("detectedLanguage")) {
JsonNode detectedLanguage = jsonNode.get("detectedLanguage");
responses.add(
new TranslationResponse(
baseText,
DetectedLanguage.ofJSON((ObjectNode) detectedLanguage),
translations));
return;
}

responses.add(new TranslationResponse(baseText, translations));
});

return Optional.of(responses);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public class TranslationResponse {
*/
private DetectedLanguage detectedLanguage = null;

private final String baseText;

// A collection of translations for the input text.
private final Collection<Translation> translations;

Expand All @@ -28,7 +30,8 @@ public class TranslationResponse {
* @param translations A collection of translations for the input text.
*/
public TranslationResponse(
DetectedLanguage detectedLanguage, Collection<Translation> translations) {
String baseText, DetectedLanguage detectedLanguage, Collection<Translation> translations) {
this.baseText = baseText;
this.detectedLanguage = detectedLanguage;
this.translations = translations;
}
Expand All @@ -39,7 +42,8 @@ public TranslationResponse(
*
* @param translations A collection of translations for the input text.
*/
public TranslationResponse(Collection<Translation> translations) {
public TranslationResponse(String baseText, Collection<Translation> translations) {
this.baseText = baseText;
this.translations = translations;
}

Expand All @@ -60,4 +64,13 @@ public DetectedLanguage getDetectedLanguage() {
public Collection<Translation> getTranslations() {
return translations;
}

/**
* Returns the base texts that were translated
*
* @return the base text
*/
public String getBaseText() {
return baseText;
}
}
42 changes: 20 additions & 22 deletions src/test/java/io/github/brenoepics/at4j/AzureApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,35 +70,31 @@ void translateEmptyKey() {
AzureApi api = new AzureApiBuilder().baseURL(BaseURL.GLOBAL).setKey("").region("test").build();

TranslateParams params = new TranslateParams("test", List.of("pt")).setSourceLanguage("en");
CompletableFuture<Optional<TranslationResponse>> translation = api.translate(params);
CompletableFuture<Optional<List<TranslationResponse>>> translation = api.translate(params);
assertThrows(CompletionException.class, translation::join);
api.disconnect();
}

@Test
void translateEmptyText() {
AzureApi api = new AzureApiBuilder().baseURL(BaseURL.GLOBAL).setKey("test").build();

TranslateParams params = new TranslateParams("", List.of("pt")).setSourceLanguage("en");
CompletableFuture<Optional<TranslationResponse>> translation = api.translate(params);
Optional<TranslationResponse> tr = translation.join();
tr.ifPresent(translations -> assertEquals(0, translations.getTranslations().size()));
api.disconnect();
}
void translateHelloWorld() {
String azureKey = System.getenv("AZURE_KEY");
String region = System.getenv("AZURE_REGION");
Assumptions.assumeTrue(
azureKey != null && region != null, "Azure Credentials are null, skipping the test");
Assumptions.assumeTrue(
!azureKey.isEmpty() && !region.isEmpty(), "Azure Credentials are empty, skipping the test");

@Test
void translateEmptySourceLanguage() {
AzureApi api = new AzureApiBuilder().baseURL(BaseURL.GLOBAL).setKey("test").build();
AzureApiBuilder builder = new AzureApiBuilder().setKey(azureKey).region(region);
AzureApi api = builder.build();

TranslateParams params = new TranslateParams("", List.of("pt"));
CompletableFuture<Optional<TranslationResponse>> translation = api.translate(params);
Optional<TranslationResponse> tr = translation.join();
tr.ifPresent(translations -> assertEquals(0, translations.getTranslations().size()));
api.disconnect();
TranslateParams params = new TranslateParams("Hello World!", List.of("pt", "es"));
Optional<List<TranslationResponse>> translate = api.translate(params).join();
assertTrue(translate.isPresent());
assertEquals(2, translate.get().get(0).getTranslations().size());
}

@Test
void translateHelloWorld() {
void translateMultiText() {
String azureKey = System.getenv("AZURE_KEY");
String region = System.getenv("AZURE_REGION");
Assumptions.assumeTrue(
Expand All @@ -109,10 +105,12 @@ void translateHelloWorld() {
AzureApiBuilder builder = new AzureApiBuilder().setKey(azureKey).region(region);
AzureApi api = builder.build();

TranslateParams params = new TranslateParams("Hello World!", List.of("pt", "es"));
Optional<TranslationResponse> translate = api.translate(params).join();
TranslateParams params =
new TranslateParams(List.of("Hello World!", "How are you?"), List.of("pt", "es"));
Optional<List<TranslationResponse>> translate = api.translate(params).join();
assertTrue(translate.isPresent());
assertEquals(2, translate.get().getTranslations().size());

assertEquals(2, translate.get().size());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import org.junit.jupiter.api.Test;
import org.mockito.Mock;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

Expand All @@ -28,7 +30,7 @@ public void setup() {

@Test
void returnsEmptyOnInvalidInput() {
translateParams.setText(null);
translateParams.setTexts(Collections.emptyList());
azureApi
.translate(translateParams)
.whenComplete(
Expand All @@ -39,7 +41,7 @@ void returnsEmptyOnInvalidInput() {
}
});

CompletableFuture<Optional<TranslationResponse>> response = azureApi.translate(translateParams);
CompletableFuture<Optional<List<TranslationResponse>>> response = azureApi.translate(translateParams);

assertFalse(response.join().isPresent());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;

import static org.junit.jupiter.api.Assertions.*;

Expand All @@ -18,8 +15,8 @@ class TranslateParamsTest {
@Test
void shouldSetAndGetText() {
TranslateParams params = new TranslateParams("Hello", List.of("fr"));
params.setText("Bonjour");
assertEquals("Bonjour", params.getText());
params.setTexts(Collections.singleton("Bonjour"));
assertEquals("Bonjour", params.getTexts().get(1));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TranslationResponseTest {
void createsTranslationResponseWithDetectedLanguageAndTranslations() {
DetectedLanguage detectedLanguage = new DetectedLanguage("en", 1.0f);
Translation translation = new Translation("pt", "Olá, mundo!");
TranslationResponse response = new TranslationResponse(detectedLanguage, List.of(translation));
TranslationResponse response = new TranslationResponse(translation.getText(), detectedLanguage, List.of(translation));

assertEquals(detectedLanguage, response.getDetectedLanguage());
assertEquals(1, response.getTranslations().size());
Expand All @@ -26,7 +26,7 @@ void createsTranslationResponseWithDetectedLanguageAndTranslations() {
@Test
void createsTranslationResponseWithTranslationsOnly() {
Translation translation = new Translation("pt", "Olá, mundo!");
TranslationResponse response = new TranslationResponse(List.of(translation));
TranslationResponse response = new TranslationResponse(translation.getText(), List.of(translation));

assertNull(response.getDetectedLanguage());
assertEquals(1, response.getTranslations().size());
Expand All @@ -35,7 +35,7 @@ void createsTranslationResponseWithTranslationsOnly() {

@Test
void returnsEmptyTranslationsWhenNoneProvided() {
TranslationResponse response = new TranslationResponse(Collections.emptyList());
TranslationResponse response = new TranslationResponse("", Collections.emptyList());
assertEquals(0, response.getTranslations().size());
}
}

0 comments on commit bf5f4e1

Please sign in to comment.