Skip to content

Commit

Permalink
Customizable AI templates (#11884)
Browse files Browse the repository at this point in the history
* Start to work

* Fix runtime problems and checkers

* Merge with main

* Fix from code review

* Update from merge

* Fix compiler errors

* Fix from code review
  • Loading branch information
InAnYan authored Oct 30, 2024
1 parent aaff6f5 commit 709386a
Show file tree
Hide file tree
Showing 19 changed files with 386 additions and 141 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ dependencies {
exclude group: 'org.jetbrains.kotlin'
}


implementation 'org.apache.velocity:velocity-engine-core:2.3'
implementation platform('ai.djl:bom:0.30.0')
implementation 'ai.djl:api'
implementation 'ai.djl.huggingface:tokenizers'
Expand Down
1 change: 1 addition & 0 deletions src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
uses ai.djl.repository.RepositoryFactory;
uses ai.djl.repository.zoo.ZooProvider;
uses dev.langchain4j.spi.prompt.PromptTemplateFactory;
requires velocity.engine.core;
// endregion

// region: Lucene
Expand Down
39 changes: 35 additions & 4 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
<?import com.dlsc.unitfx.IntegerInputField?>
<?import org.controlsfx.control.SearchableComboBox?>
<?import org.controlsfx.control.textfield.CustomPasswordField?>
<?import javafx.scene.control.TabPane?>
<?import javafx.scene.control.Tab?>
<?import javafx.scene.control.TextArea?>
<fx:root
spacing="10.0"
type="VBox"
Expand Down Expand Up @@ -162,10 +165,6 @@
</children>
</HBox>

<ResizableTextArea
fx:id="instructionTextArea"
wrapText="true"/>

<GridPane hgap="10" vgap="10">
<columnConstraints>
<ColumnConstraints hgrow="ALWAYS" percentWidth="50" />
Expand Down Expand Up @@ -235,5 +234,37 @@
glyph="REFRESH"/>
</graphic>
</Button>

<HBox alignment="BASELINE_CENTER">
<Label styleClass="sectionHeader"
text="%Templates"
maxWidth="Infinity"
HBox.hgrow="ALWAYS"/>
<Button fx:id="templatesHelp"
prefWidth="20.0"/>
</HBox>

<TabPane>
<Tab text="%System message for chatting" closable="false">
<TextArea fx:id="systemMessageTextArea"/>
</Tab>
<Tab text="User message for chatting" closable="false">
<TextArea fx:id="userMessageTextArea"/>
</Tab>
<Tab text="Completion text for summarization of a chunk" closable="false">
<TextArea fx:id="summarizationChunkTextArea"/>
</Tab>
<Tab text="Completion text for summarization of several chunks" closable="false">
<TextArea fx:id="summarizationCombineTextArea"/>
</Tab>
</TabPane>

<Button onAction="#onResetTemplatesButtonClick"
text="%Reset templates to default">
<graphic>
<JabRefIconView
glyph="REFRESH"/>
</graphic>
</Button>
</children>
</fx:root>
29 changes: 21 additions & 8 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import javafx.scene.control.Button;
import javafx.scene.control.CheckBox;
import javafx.scene.control.ComboBox;
import javafx.scene.control.TextArea;
import javafx.scene.control.TextField;

import org.jabref.gui.actions.ActionFactory;
Expand All @@ -15,13 +16,13 @@
import org.jabref.gui.preferences.AbstractPreferenceTabView;
import org.jabref.gui.preferences.PreferencesTab;
import org.jabref.gui.util.ViewModelListCellFactory;
import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.help.HelpFile;
import org.jabref.logic.l10n.Localization;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;

import com.airhacks.afterburner.views.ViewLoader;
import com.dlsc.gemsfx.ResizableTextArea;
import com.dlsc.unitfx.IntegerInputField;
import de.saxsys.mvvmfx.utils.validation.visualization.ControlsFxVisualizer;
import org.controlsfx.control.SearchableComboBox;
Expand All @@ -43,16 +44,21 @@ public class AiTab extends AbstractPreferenceTabView<AiTabViewModel> implements

@FXML private TextField apiBaseUrlTextField;
@FXML private SearchableComboBox<EmbeddingModel> embeddingModelComboBox;
@FXML private ResizableTextArea instructionTextArea;
@FXML private TextField temperatureTextField;
@FXML private IntegerInputField contextWindowSizeTextField;
@FXML private IntegerInputField documentSplitterChunkSizeTextField;
@FXML private IntegerInputField documentSplitterOverlapSizeTextField;
@FXML private IntegerInputField ragMaxResultsCountTextField;
@FXML private TextField ragMinScoreTextField;

@FXML private TextArea systemMessageTextArea;
@FXML private TextArea userMessageTextArea;
@FXML private TextArea summarizationChunkTextArea;
@FXML private TextArea summarizationCombineTextArea;

@FXML private Button generalSettingsHelp;
@FXML private Button expertSettingsHelp;
@FXML private Button templatesHelp;

private final ControlsFxVisualizer visualizer = new ControlsFxVisualizer();

Expand All @@ -74,14 +80,14 @@ public void initialize() {
new ViewModelListCellFactory<AiProvider>()
.withText(AiProvider::toString)
.install(aiProviderComboBox);
aiProviderComboBox.setItems(viewModel.aiProvidersProperty());
aiProviderComboBox.itemsProperty().bind(viewModel.aiProvidersProperty());
aiProviderComboBox.valueProperty().bindBidirectional(viewModel.selectedAiProviderProperty());
aiProviderComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());

new ViewModelListCellFactory<String>()
.withText(text -> text)
.install(chatModelComboBox);
chatModelComboBox.setItems(viewModel.chatModelsProperty());
chatModelComboBox.itemsProperty().bind(viewModel.chatModelsProperty());
chatModelComboBox.valueProperty().bindBidirectional(viewModel.selectedChatModelProperty());
chatModelComboBox.disableProperty().bind(viewModel.disableBasicSettingsProperty());

Expand Down Expand Up @@ -123,9 +129,6 @@ public void initialize() {
apiBaseUrlTextField.setDisable(newValue || viewModel.disableExpertSettingsProperty().get())
);

instructionTextArea.textProperty().bindBidirectional(viewModel.instructionProperty());
instructionTextArea.disableProperty().bind(viewModel.disableExpertSettingsProperty());

// bindBidirectional doesn't work well with number input fields ({@link IntegerInputField}, {@link DoubleInputField}),
// so they are expanded into `addListener` calls.

Expand Down Expand Up @@ -180,7 +183,6 @@ public void initialize() {
visualizer.initVisualization(viewModel.getChatModelValidationStatus(), chatModelComboBox);
visualizer.initVisualization(viewModel.getApiBaseUrlValidationStatus(), apiBaseUrlTextField);
visualizer.initVisualization(viewModel.getEmbeddingModelValidationStatus(), embeddingModelComboBox);
visualizer.initVisualization(viewModel.getSystemMessageValidationStatus(), instructionTextArea);
visualizer.initVisualization(viewModel.getTemperatureTypeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getTemperatureRangeValidationStatus(), temperatureTextField);
visualizer.initVisualization(viewModel.getMessageWindowSizeValidationStatus(), contextWindowSizeTextField);
Expand All @@ -191,9 +193,15 @@ public void initialize() {
visualizer.initVisualization(viewModel.getRagMinScoreRangeValidationStatus(), ragMinScoreTextField);
});

systemMessageTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.CHATTING_SYSTEM_MESSAGE));
userMessageTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.CHATTING_USER_MESSAGE));
summarizationChunkTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.SUMMARIZATION_CHUNK));
summarizationCombineTextArea.textProperty().bindBidirectional(viewModel.getTemplateSources().get(AiTemplate.SUMMARIZATION_COMBINE));

ActionFactory actionFactory = new ActionFactory();
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_GENERAL_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), generalSettingsHelp);
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_EXPERT_SETTINGS, dialogService, preferences.getExternalApplicationsPreferences()), expertSettingsHelp);
actionFactory.configureIconButton(StandardActions.HELP, new HelpAction(HelpFile.AI_TEMPLATES, dialogService, preferences.getExternalApplicationsPreferences()), templatesHelp);
}

@Override
Expand All @@ -206,6 +214,11 @@ private void onResetExpertSettingsButtonClick() {
viewModel.resetExpertSettings();
}

@FXML
private void onResetTemplatesButtonClick() {
viewModel.resetTemplates();
}

public ReadOnlyBooleanProperty aiEnabledProperty() {
return enableAi.selectedProperty();
}
Expand Down
42 changes: 24 additions & 18 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package org.jabref.gui.preferences.ai;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

import javafx.beans.property.BooleanProperty;
Expand All @@ -20,6 +22,7 @@
import org.jabref.gui.preferences.PreferenceTabViewModel;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.preferences.CliPreferences;
import org.jabref.logic.util.LocalizedNumbers;
Expand Down Expand Up @@ -79,7 +82,13 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final StringProperty huggingFaceApiBaseUrl = new SimpleStringProperty();
private final StringProperty gpt4AllApiBaseUrl = new SimpleStringProperty();

private final StringProperty instruction = new SimpleStringProperty();
private final Map<AiTemplate, StringProperty> templateSources = Map.of(
AiTemplate.CHATTING_SYSTEM_MESSAGE, new SimpleStringProperty(),
AiTemplate.CHATTING_USER_MESSAGE, new SimpleStringProperty(),
AiTemplate.SUMMARIZATION_CHUNK, new SimpleStringProperty(),
AiTemplate.SUMMARIZATION_COMBINE, new SimpleStringProperty()
);

private final StringProperty temperature = new SimpleStringProperty();
private final IntegerProperty contextWindowSize = new SimpleIntegerProperty();
private final IntegerProperty documentSplitterChunkSize = new SimpleIntegerProperty();
Expand All @@ -96,7 +105,6 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final Validator chatModelValidator;
private final Validator apiBaseUrlValidator;
private final Validator embeddingModelValidator;
private final Validator instructionValidator;
private final Validator temperatureTypeValidator;
private final Validator temperatureRangeValidator;
private final Validator contextWindowSizeValidator;
Expand Down Expand Up @@ -242,11 +250,6 @@ public AiTabViewModel(CliPreferences preferences) {
Objects::nonNull,
ValidationMessage.error(Localization.lang("Embedding model has to be provided")));

this.instructionValidator = new FunctionBasedValidator<>(
instruction,
message -> !StringUtil.isBlank(message),
ValidationMessage.error(Localization.lang("The instruction has to be provided")));

this.temperatureTypeValidator = new FunctionBasedValidator<>(
temperature,
temp -> LocalizedNumbers.stringToDouble(temp).isPresent(),
Expand Down Expand Up @@ -318,7 +321,10 @@ public void setValues() {
customizeExpertSettings.setValue(aiPreferences.getCustomizeExpertSettings());

selectedEmbeddingModel.setValue(aiPreferences.getEmbeddingModel());
instruction.setValue(aiPreferences.getInstruction());

Arrays.stream(AiTemplate.values()).forEach(template ->
templateSources.get(template).set(aiPreferences.getTemplate(template)));

temperature.setValue(LocalizedNumbers.doubleToString(aiPreferences.getTemperature()));
contextWindowSize.setValue(aiPreferences.getContextWindowSize());
documentSplitterChunkSize.setValue(aiPreferences.getDocumentSplitterChunkSize());
Expand Down Expand Up @@ -359,7 +365,9 @@ public void storeSettings() {
aiPreferences.setHuggingFaceApiBaseUrl(huggingFaceApiBaseUrl.get() == null ? "" : huggingFaceApiBaseUrl.get());
aiPreferences.setGpt4AllApiBaseUrl(gpt4AllApiBaseUrl.get() == null ? "" : gpt4AllApiBaseUrl.get());

aiPreferences.setInstruction(instruction.get());
Arrays.stream(AiTemplate.values()).forEach(template ->
aiPreferences.setTemplate(template, templateSources.get(template).get()));

// We already check the correctness of temperature and RAG minimum score in validators, so we don't need to check it here.
aiPreferences.setTemperature(LocalizedNumbers.stringToDouble(oldLocale, temperature.get()).get());
aiPreferences.setContextWindowSize(contextWindowSize.get());
Expand All @@ -373,8 +381,6 @@ public void resetExpertSettings() {
String resetApiBaseUrl = selectedAiProvider.get().getApiUrl();
currentApiBaseUrl.set(resetApiBaseUrl);

instruction.set(AiDefaultPreferences.SYSTEM_MESSAGE);

contextWindowSize.set(AiDefaultPreferences.getContextWindowSize(selectedAiProvider.get(), currentChatModel.get()));

temperature.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.TEMPERATURE));
Expand All @@ -384,6 +390,11 @@ public void resetExpertSettings() {
ragMinScore.set(LocalizedNumbers.doubleToString(AiDefaultPreferences.RAG_MIN_SCORE));
}

public void resetTemplates() {
Arrays.stream(AiTemplate.values()).forEach(template ->
templateSources.get(template).set(AiDefaultPreferences.TEMPLATES.get(template)));
}

@Override
public boolean validateSettings() {
if (enableAi.get()) {
Expand All @@ -410,7 +421,6 @@ public boolean validateExpertSettings() {
List<Validator> validators = List.of(
apiBaseUrlValidator,
embeddingModelValidator,
instructionValidator,
temperatureTypeValidator,
temperatureRangeValidator,
contextWindowSizeValidator,
Expand Down Expand Up @@ -484,8 +494,8 @@ public BooleanProperty disableApiBaseUrlProperty() {
return disableApiBaseUrl;
}

public StringProperty instructionProperty() {
return instruction;
public Map<AiTemplate, StringProperty> getTemplateSources() {
return templateSources;
}

public StringProperty temperatureProperty() {
Expand Down Expand Up @@ -536,10 +546,6 @@ public ValidationStatus getEmbeddingModelValidationStatus() {
return embeddingModelValidator.getValidationStatus();
}

public ValidationStatus getSystemMessageValidationStatus() {
return instructionValidator.getValidationStatus();
}

public ValidationStatus getTemperatureTypeValidationStatus() {
return temperatureTypeValidator.getValidationStatus();
}
Expand Down
41 changes: 41 additions & 0 deletions src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.List;
import java.util.Map;

import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;

Expand Down Expand Up @@ -80,6 +81,46 @@ public String toString() {

public static final int FALLBACK_CONTEXT_WINDOW_SIZE = 8196;

public static final Map<AiTemplate, String> TEMPLATES = Map.of(
AiTemplate.CHATTING_SYSTEM_MESSAGE, """
You are an AI assistant that analyses research papers. You answer questions about papers.
You will be supplied with the necessary information. The supplied information will contain mentions of papers in form '@citationKey'.
Whenever you refer to a paper, use its citation key in the same form with @ symbol. Whenever you find relevant information, always use the citation key.
Here are the papers you are analyzing:
#foreach( $entry in $entries )
${CanonicalBibEntry.getCanonicalRepresentation($entry)}
#end""",

AiTemplate.CHATTING_USER_MESSAGE, """
$message
Here is some relevant information for you:
#foreach( $excerpt in $excerpts )
${excerpt.citationKey()}:
${excerpt.text()}
#end""",

AiTemplate.SUMMARIZATION_CHUNK, """
Please provide an overview of the following text. It is a part of a scientific paper.
The summary should include the main objectives, methodologies used, key findings, and conclusions.
Mention any significant experiments, data, or discussions presented in the paper.
DOCUMENT:
$document
OVERVIEW:""",

AiTemplate.SUMMARIZATION_COMBINE, """
You have written an overview of a scientific paper. You have been collecting notes from various parts
of the paper. Now your task is to combine all of the notes in one structured message.
SUMMARIES:
$summaries
FINAL OVERVIEW:"""
);

public static List<String> getAvailableModels(AiProvider aiProvider) {
return Arrays.stream(AiDefaultPreferences.PredefinedChatModel.values())
.filter(model -> model.getAiProvider() == aiProvider)
Expand Down
Loading

0 comments on commit 709386a

Please sign in to comment.