diff --git a/src/main/java/org/jabref/gui/ai/components/errorstate/ErrorStateComponent.java b/src/main/java/org/jabref/gui/ai/components/util/errorstate/ErrorStateComponent.java
similarity index 97%
rename from src/main/java/org/jabref/gui/ai/components/errorstate/ErrorStateComponent.java
rename to src/main/java/org/jabref/gui/ai/components/util/errorstate/ErrorStateComponent.java
index ca58b1f6943..da640261ef7 100644
--- a/src/main/java/org/jabref/gui/ai/components/errorstate/ErrorStateComponent.java
+++ b/src/main/java/org/jabref/gui/ai/components/util/errorstate/ErrorStateComponent.java
@@ -1,4 +1,4 @@
-package org.jabref.gui.ai.components.errorstate;
+package org.jabref.gui.ai.components.util.errorstate;
import javafx.fxml.FXML;
import javafx.scene.control.Button;
diff --git a/src/main/java/org/jabref/gui/ai/components/util/notifications/Notification.java b/src/main/java/org/jabref/gui/ai/components/util/notifications/Notification.java
new file mode 100644
index 00000000000..c8123974a16
--- /dev/null
+++ b/src/main/java/org/jabref/gui/ai/components/util/notifications/Notification.java
@@ -0,0 +1,12 @@
+package org.jabref.gui.ai.components.util.notifications;
+
+/**
+ * Record that is used to display errors and warnings in the AI chat. If you need global notifications,
+ * see {@link org.jabref.gui.DialogService#notify(String)}.
+ *
+ * This type is used to represent errors for: no files in {@link org.jabref.model.entry.BibEntry}, files are processing,
+ * etc. This is made via notifications to support chat with groups: on one hand we need to be able to notify users
+ * about possible problems with entries (because that will affect LLM output), but on the other hand the user would
+ * like to chat with all available entries in the group, even if some of them are not valid.
+ */
+public record Notification(String title, String message) { }
diff --git a/src/main/java/org/jabref/gui/ai/components/util/notifications/NotificationComponent.java b/src/main/java/org/jabref/gui/ai/components/util/notifications/NotificationComponent.java
new file mode 100644
index 00000000000..a711641f949
--- /dev/null
+++ b/src/main/java/org/jabref/gui/ai/components/util/notifications/NotificationComponent.java
@@ -0,0 +1,33 @@
+package org.jabref.gui.ai.components.util.notifications;
+
+import javafx.geometry.Insets;
+import javafx.scene.control.Label;
+import javafx.scene.layout.VBox;
+import javafx.scene.text.Font;
+
+/**
+ * Component used to display {@link Notification} in AI chat. See the documentation of {@link Notification} for more
+ * details.
+ */
+public class NotificationComponent extends VBox {
+ private final Label title = new Label("Title");
+ private final Label message = new Label("Message");
+
+ public NotificationComponent() {
+ setSpacing(10);
+ setPadding(new Insets(10));
+
+ title.setFont(new Font("System Bold", title.getFont().getSize()));
+ this.getChildren().addAll(title, message);
+ }
+
+ public NotificationComponent(Notification notification) {
+ this();
+ setNotification(notification);
+ }
+
+ public void setNotification(Notification notification) {
+ title.setText(notification.title());
+ message.setText(notification.message());
+ }
+}
diff --git a/src/main/java/org/jabref/gui/ai/components/util/notifications/NotificationsComponent.java b/src/main/java/org/jabref/gui/ai/components/util/notifications/NotificationsComponent.java
new file mode 100644
index 00000000000..72ec5f271ce
--- /dev/null
+++ b/src/main/java/org/jabref/gui/ai/components/util/notifications/NotificationsComponent.java
@@ -0,0 +1,31 @@
+package org.jabref.gui.ai.components.util.notifications;
+
+import java.util.List;
+
+import javafx.collections.ListChangeListener;
+import javafx.collections.ObservableList;
+import javafx.scene.control.ScrollPane;
+import javafx.scene.layout.VBox;
+
+/**
+ * A {@link ScrollPane} for displaying AI chat {@link Notification}s. See the documentation of {@link Notification} for
+ * more details.
+ */
+public class NotificationsComponent extends ScrollPane {
+ private static final double SCROLL_PANE_MAX_HEIGHT = 300;
+
+ private final VBox vBox = new VBox(10);
+
+ public NotificationsComponent(ObservableList notifications) {
+ setContent(vBox);
+ setMaxHeight(SCROLL_PANE_MAX_HEIGHT);
+
+ fill(notifications);
+ notifications.addListener((ListChangeListener super Notification>) change -> fill(notifications));
+ }
+
+ private void fill(List notifications) {
+ vBox.getChildren().clear();
+ notifications.stream().map(NotificationComponent::new).forEach(vBox.getChildren()::add);
+ }
+}
diff --git a/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java b/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java
index 4eb66e7ea17..eead1789d95 100644
--- a/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java
+++ b/src/main/java/org/jabref/gui/entryeditor/AiChatTab.java
@@ -5,24 +5,18 @@
import java.util.List;
import java.util.Optional;
-import javafx.scene.Node;
+import javafx.beans.property.SimpleStringProperty;
+import javafx.beans.property.StringProperty;
+import javafx.collections.FXCollections;
import javafx.scene.control.Tooltip;
import org.jabref.gui.DialogService;
-import org.jabref.gui.LibraryTabContainer;
-import org.jabref.gui.ai.components.aichat.AiChatComponent;
-import org.jabref.gui.ai.components.apikeymissing.ApiKeyMissingComponent;
-import org.jabref.gui.ai.components.errorstate.ErrorStateComponent;
+import org.jabref.gui.ai.components.aichat.AiChatGuardedComponent;
import org.jabref.gui.ai.components.privacynotice.PrivacyNoticeComponent;
+import org.jabref.gui.ai.components.util.errorstate.ErrorStateComponent;
import org.jabref.gui.util.TaskExecutor;
-import org.jabref.gui.util.UiTaskExecutor;
-import org.jabref.logic.ai.AiChatLogic;
import org.jabref.logic.ai.AiService;
-import org.jabref.logic.ai.GenerateEmbeddingsTask;
-import org.jabref.logic.ai.chathistory.AiChatHistory;
-import org.jabref.logic.ai.chathistory.InMemoryAiChatHistory;
-import org.jabref.logic.ai.embeddings.FullyIngestedDocumentsTracker;
-import org.jabref.logic.ai.models.JabRefEmbeddingModel;
+import org.jabref.logic.ai.util.CitationKeyCheck;
import org.jabref.logic.citationkeypattern.CitationKeyGenerator;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.io.FileUtil;
@@ -31,51 +25,41 @@
import org.jabref.model.entry.LinkedFile;
import org.jabref.preferences.FilePreferences;
import org.jabref.preferences.PreferencesService;
-import org.jabref.preferences.ai.AiApiKeyProvider;
-
-import com.google.common.eventbus.Subscribe;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.jabref.preferences.ai.AiPreferences;
public class AiChatTab extends EntryEditorTab {
- private static final Logger LOGGER = LoggerFactory.getLogger(AiChatTab.class);
-
- private final LibraryTabContainer libraryTabContainer;
+ private final BibDatabaseContext bibDatabaseContext;
+ private final AiService aiService;
private final DialogService dialogService;
+ private final AiPreferences aiPreferences;
private final FilePreferences filePreferences;
private final EntryEditorPreferences entryEditorPreferences;
- private final BibDatabaseContext bibDatabaseContext;
- private final TaskExecutor taskExecutor;
private final CitationKeyGenerator citationKeyGenerator;
- private final AiService aiService;
- private final AiApiKeyProvider aiApiKeyProvider;
+ private final TaskExecutor taskExecutor;
- private final List entriesUnderIngestion = new ArrayList<>();
+ private Optional previousBibEntry = Optional.empty();
- public AiChatTab(LibraryTabContainer libraryTabContainer,
+ public AiChatTab(BibDatabaseContext bibDatabaseContext,
+ AiService aiService,
DialogService dialogService,
PreferencesService preferencesService,
- AiApiKeyProvider aiApiKeyProvider,
- AiService aiService,
- BibDatabaseContext bibDatabaseContext,
- TaskExecutor taskExecutor) {
- this.libraryTabContainer = libraryTabContainer;
+ TaskExecutor taskExecutor
+ ) {
+ this.bibDatabaseContext = bibDatabaseContext;
+
+ this.aiService = aiService;
this.dialogService = dialogService;
+ this.aiPreferences = preferencesService.getAiPreferences();
this.filePreferences = preferencesService.getFilePreferences();
this.entryEditorPreferences = preferencesService.getEntryEditorPreferences();
- this.aiApiKeyProvider = aiApiKeyProvider;
- this.aiService = aiService;
- this.bibDatabaseContext = bibDatabaseContext;
- this.taskExecutor = taskExecutor;
this.citationKeyGenerator = new CitationKeyGenerator(bibDatabaseContext, preferencesService.getCitationKeyPatternPreferences());
+ this.taskExecutor = taskExecutor;
+
setText(Localization.lang("AI chat"));
setTooltip(new Tooltip(Localization.lang("Chat with AI about content of attached file(s)")));
-
- aiService.getEmbeddingsManager().registerListener(this);
- aiService.getEmbeddingModel().registerListener(this);
}
@Override
@@ -83,59 +67,31 @@ public boolean shouldShow(BibEntry entry) {
return entryEditorPreferences.shouldShowAiChatTab();
}
- @Override
- protected void handleFocus() {
- if (currentEntry != null) {
- bindToEntry(currentEntry);
- }
- }
-
/**
* @implNote Method similar to {@link AiSummaryTab#bindToEntry(BibEntry)}
*/
@Override
protected void bindToEntry(BibEntry entry) {
- if (!aiService.getPreferences().getEnableAi()) {
+ previousBibEntry.ifPresent(previousBibEntry ->
+ aiService.getChatHistoryService().closeChatHistoryForEntry(previousBibEntry));
+
+ previousBibEntry = Optional.of(entry);
+
+ if (!aiPreferences.getEnableAi()) {
showPrivacyNotice(entry);
- } else if (aiApiKeyProvider.getApiKeyForAiProvider(aiService.getPreferences().getAiProvider()).isEmpty()) {
- showApiKeyMissing();
} else if (entry.getFiles().isEmpty()) {
showErrorNoFiles();
} else if (entry.getFiles().stream().map(LinkedFile::getLink).map(Path::of).noneMatch(FileUtil::isPDFFile)) {
showErrorNotPdfs();
- } else if (!citationKeyIsValid(bibDatabaseContext, entry)) {
+ } else if (!CitationKeyCheck.citationKeyIsPresentAndUnique(bibDatabaseContext, entry)) {
tryToGenerateCitationKeyThenBind(entry);
- } else if (!aiService.getEmbeddingModel().isPresent()) {
- if (aiService.getEmbeddingModel().hadErrorWhileBuildingModel()) {
- showErrorWhileBuildingEmbeddingModel();
- } else {
- showBuildingEmbeddingModel();
- }
- } else if (!aiService.getEmbeddingsManager().hasIngestedLinkedFiles(entry.getFiles())) {
- startIngesting(entry);
} else {
- entriesUnderIngestion.remove(entry);
bindToCorrectEntry(entry);
}
}
private void showPrivacyNotice(BibEntry entry) {
- setContent(new PrivacyNoticeComponent(dialogService, aiService.getPreferences(), filePreferences, () -> {
- bindToEntry(entry);
- }));
- }
-
- private void showApiKeyMissing() {
- setContent(new ApiKeyMissingComponent(libraryTabContainer, dialogService));
- }
-
- private void showErrorNotIngested() {
- setContent(
- ErrorStateComponent.withSpinner(
- Localization.lang("Processing..."),
- Localization.lang("The embeddings of the file(s) are currently being generated. Please wait, and at the end you will be able to chat.")
- )
- );
+ setContent(new PrivacyNoticeComponent(aiPreferences, () -> bindToEntry(entry), filePreferences, dialogService));
}
private void showErrorNotPdfs() {
@@ -169,103 +125,22 @@ private void tryToGenerateCitationKeyThenBind(BibEntry entry) {
}
}
- private void showErrorWhileBuildingEmbeddingModel() {
- setContent(
- ErrorStateComponent.withTextAreaAndButton(
- Localization.lang("Unable to chat"),
- Localization.lang("An error occurred while building the embedding model"),
- aiService.getEmbeddingModel().getErrorWhileBuildingModel(),
- Localization.lang("Rebuild"),
- () -> aiService.getEmbeddingModel().startRebuildingTask()
- )
- );
- }
-
- public void showBuildingEmbeddingModel() {
- setContent(
- ErrorStateComponent.withSpinner(
- Localization.lang("Downloading..."),
- Localization.lang("Downloading embedding model... Afterward, you will be able to chat with your files.")
- )
- );
- }
-
- private static boolean citationKeyIsValid(BibDatabaseContext bibDatabaseContext, BibEntry bibEntry) {
- return !hasEmptyCitationKey(bibEntry) && bibEntry.getCitationKey().map(key -> citationKeyIsUnique(bibDatabaseContext, key)).orElse(false);
- }
-
- private static boolean hasEmptyCitationKey(BibEntry bibEntry) {
- return bibEntry.getCitationKey().map(String::isEmpty).orElse(true);
- }
-
- private static boolean citationKeyIsUnique(BibDatabaseContext bibDatabaseContext, String citationKey) {
- return bibDatabaseContext.getDatabase().getNumberOfCitationKeyOccurrences(citationKey) == 1;
- }
-
- private void startIngesting(BibEntry entry) {
- // This method should be called if entry is fully prepared for chatting.
- assert entry.getCitationKey().isPresent();
-
- showErrorNotIngested();
-
- if (!entriesUnderIngestion.contains(entry)) {
- entriesUnderIngestion.add(entry);
-
- new GenerateEmbeddingsTask(entry.getCitationKey().get(), entry.getFiles(), aiService.getEmbeddingsManager(), bibDatabaseContext, filePreferences)
- .onSuccess(res -> handleFocus())
- .onFailure(this::showErrorWhileIngesting)
- .executeWith(taskExecutor);
- }
- }
-
- private void showErrorWhileIngesting(Exception e) {
- LOGGER.error("Got an error while generating embeddings for entry {}", currentEntry.getCitationKey(), e);
-
- setContent(ErrorStateComponent.withTextArea(Localization.lang("Unable to chat"), Localization.lang("Got error while processing the file:"), e.getMessage()));
-
- entriesUnderIngestion.remove(currentEntry);
-
- currentEntry.getFiles().stream().map(LinkedFile::getLink).forEach(link -> aiService.getEmbeddingsManager().removeDocument(link));
- }
-
private void bindToCorrectEntry(BibEntry entry) {
- assert entry.getCitationKey().isPresent();
-
- AiChatHistory aiChatHistory = getAiChatHistory(aiService, entry, bibDatabaseContext);
-
- AiChatLogic aiChatLogic = AiChatLogic.forBibEntry(aiService, aiChatHistory, entry);
-
- Node content = new AiChatComponent(aiService.getPreferences(), aiChatLogic, entry.getCitationKey().get(), dialogService, taskExecutor);
-
- setContent(content);
- }
-
- private static AiChatHistory getAiChatHistory(AiService aiService, BibEntry entry, BibDatabaseContext bibDatabaseContext) {
- Optional databasePath = bibDatabaseContext.getDatabasePath();
-
- if (databasePath.isEmpty() || entry.getCitationKey().isEmpty()) {
- LOGGER.warn("AI chat is constructed, but the database path is empty. Cannot store chat history");
- return new InMemoryAiChatHistory();
- } else if (entry.getCitationKey().isEmpty()) {
- LOGGER.warn("AI chat is constructed, but the entry citation key is empty. Cannot store chat history");
- return new InMemoryAiChatHistory();
- } else {
- return aiService.getChatHistoryManager().getChatHistory(databasePath.get(), entry.getCitationKey().get());
- }
- }
-
- @Subscribe
- public void listen(FullyIngestedDocumentsTracker.DocumentIngestedEvent event) {
- UiTaskExecutor.runInJavaFXThread(AiChatTab.this::handleFocus);
- }
-
- @Subscribe
- public void listen(JabRefEmbeddingModel.EmbeddingModelBuiltEvent event) {
- UiTaskExecutor.runInJavaFXThread(AiChatTab.this::handleFocus);
- }
-
- @Subscribe
- public void listen(JabRefEmbeddingModel.EmbeddingModelBuildingErrorEvent event) {
- UiTaskExecutor.runInJavaFXThread(AiChatTab.this::handleFocus);
+ // We omit the localization here, because it is only a chat with one entry in the {@link EntryEditor}.
+ // See documentation for {@link AiChatGuardedComponent#name}.
+ StringProperty chatName = new SimpleStringProperty("entry " + entry.getCitationKey().orElse(""));
+ entry.getCiteKeyBinding().addListener((observable, oldValue, newValue) -> chatName.setValue("entry " + newValue));
+
+ setContent(new AiChatGuardedComponent(
+ chatName,
+ aiService.getChatHistoryService().getChatHistoryForEntry(entry),
+ bibDatabaseContext,
+ FXCollections.observableArrayList(new ArrayList<>(List.of(entry))),
+ aiService,
+ dialogService,
+ aiPreferences,
+ filePreferences,
+ taskExecutor
+ ));
}
}
diff --git a/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java b/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java
index f787d91d5bf..f61f69eb1bc 100644
--- a/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java
+++ b/src/main/java/org/jabref/gui/entryeditor/AiSummaryTab.java
@@ -1,73 +1,44 @@
package org.jabref.gui.entryeditor;
-import java.nio.file.Path;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-
import javafx.scene.control.Tooltip;
import org.jabref.gui.DialogService;
-import org.jabref.gui.LibraryTabContainer;
-import org.jabref.gui.ai.components.apikeymissing.ApiKeyMissingComponent;
-import org.jabref.gui.ai.components.errorstate.ErrorStateComponent;
-import org.jabref.gui.ai.components.privacynotice.PrivacyNoticeComponent;
import org.jabref.gui.ai.components.summary.SummaryComponent;
-import org.jabref.gui.util.TaskExecutor;
-import org.jabref.gui.util.UiTaskExecutor;
import org.jabref.logic.ai.AiService;
-import org.jabref.logic.ai.summarization.GenerateSummaryTask;
-import org.jabref.logic.ai.summarization.SummariesStorage;
-import org.jabref.logic.citationkeypattern.CitationKeyGenerator;
+import org.jabref.logic.citationkeypattern.CitationKeyPatternPreferences;
import org.jabref.logic.l10n.Localization;
-import org.jabref.logic.util.io.FileUtil;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
-import org.jabref.model.entry.LinkedFile;
import org.jabref.preferences.FilePreferences;
import org.jabref.preferences.PreferencesService;
-import org.jabref.preferences.ai.AiApiKeyProvider;
-
-import com.google.common.eventbus.Subscribe;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.jabref.preferences.ai.AiPreferences;
public class AiSummaryTab extends EntryEditorTab {
- private static final Logger LOGGER = LoggerFactory.getLogger(AiSummaryTab.class);
-
- private final LibraryTabContainer libraryTabContainer;
+ private final BibDatabaseContext bibDatabaseContext;
+ private final AiService aiService;
private final DialogService dialogService;
+ private final AiPreferences aiPreferences;
private final FilePreferences filePreferences;
+ private final CitationKeyPatternPreferences citationKeyPatternPreferences;
private final EntryEditorPreferences entryEditorPreferences;
- private final BibDatabaseContext bibDatabaseContext;
- private final TaskExecutor taskExecutor;
- private final CitationKeyGenerator citationKeyGenerator;
- private final AiApiKeyProvider aiApiKeyProvider;
- private final AiService aiService;
-
- private final List entriesUnderSummarization = new ArrayList<>();
- public AiSummaryTab(LibraryTabContainer libraryTabContainer,
- DialogService dialogService,
- PreferencesService preferencesService,
- AiApiKeyProvider aiApiKeyProvider,
+ public AiSummaryTab(BibDatabaseContext bibDatabaseContext,
AiService aiService,
- BibDatabaseContext bibDatabaseContext,
- TaskExecutor taskExecutor) {
- this.libraryTabContainer = libraryTabContainer;
+ DialogService dialogService,
+ PreferencesService preferencesService
+ ) {
+ this.bibDatabaseContext = bibDatabaseContext;
+
+ this.aiService = aiService;
this.dialogService = dialogService;
+
+ this.aiPreferences = preferencesService.getAiPreferences();
this.filePreferences = preferencesService.getFilePreferences();
+ this.citationKeyPatternPreferences = preferencesService.getCitationKeyPatternPreferences();
this.entryEditorPreferences = preferencesService.getEntryEditorPreferences();
- this.aiApiKeyProvider = aiApiKeyProvider;
- this.aiService = aiService;
- this.bibDatabaseContext = bibDatabaseContext;
- this.taskExecutor = taskExecutor;
- this.citationKeyGenerator = new CitationKeyGenerator(bibDatabaseContext, preferencesService.getCitationKeyPatternPreferences());
setText(Localization.lang("AI summary"));
setTooltip(new Tooltip(Localization.lang("AI-generated summary of attached file(s)")));
-
- aiService.getSummariesStorage().registerListener(new SummarySetListener());
}
@Override
@@ -75,169 +46,19 @@ public boolean shouldShow(BibEntry entry) {
return entryEditorPreferences.shouldShowAiSummaryTab();
}
- @Override
- protected void handleFocus() {
- if (currentEntry != null) {
- bindToEntry(currentEntry);
- }
- }
-
/**
* @implNote Method similar to {@link AiChatTab#bindToEntry(BibEntry)}
*/
@Override
protected void bindToEntry(BibEntry entry) {
- if (!aiService.getPreferences().getEnableAi()) {
- showPrivacyNotice(entry);
- } else if (aiApiKeyProvider.getApiKeyForAiProvider(aiService.getPreferences().getAiProvider()).isEmpty()) {
- showApiKeyMissing();
- } else if (bibDatabaseContext.getDatabasePath().isEmpty()) {
- showErrorNoDatabasePath();
- } else if (entry.getFiles().isEmpty()) {
- showErrorNoFiles();
- } else if (entry.getFiles().stream().map(LinkedFile::getLink).map(Path::of).noneMatch(FileUtil::isPDFFile)) {
- showErrorNotPdfs();
- } else if (entry.getCitationKey().isEmpty() || !citationKeyIsValid(bibDatabaseContext, entry)) {
- // There is no need for additional check `entry.getCitationKey().isEmpty()` because method `citationKeyIsValid`,
- // will check this. But with this call the linter is happy for the next expression in else if.
- tryToGenerateCitationKeyThenBind(entry);
- } else {
- Optional summary = aiService.getSummariesStorage().get(bibDatabaseContext.getDatabasePath().get(), entry.getCitationKey().get());
- if (summary.isEmpty()) {
- startGeneratingSummary(entry);
- } else {
- bindToCorrectEntry(summary.get());
- }
- }
- }
-
- private void showPrivacyNotice(BibEntry entry) {
- setContent(new PrivacyNoticeComponent(dialogService, aiService.getPreferences(), filePreferences, () -> {
- bindToEntry(entry);
- }));
- }
-
- private void showApiKeyMissing() {
- setContent(new ApiKeyMissingComponent(libraryTabContainer, dialogService));
- }
-
- private void showErrorNoDatabasePath() {
- setContent(
- new ErrorStateComponent(
- Localization.lang("Unable to chat"),
- Localization.lang("The path of the current library is not set, but it is required for summarization")
- )
- );
- }
-
- private void showErrorNotPdfs() {
- setContent(
- new ErrorStateComponent(
- Localization.lang("Unable to chat"),
- Localization.lang("Only PDF files are supported.")
- )
- );
- }
-
- private void showErrorNoFiles() {
- setContent(
- new ErrorStateComponent(
- Localization.lang("Unable to chat"),
- Localization.lang("Please attach at least one PDF file to enable chatting with PDF file(s).")
- )
- );
- }
-
- private void tryToGenerateCitationKeyThenBind(BibEntry entry) {
- if (citationKeyGenerator.generateAndSetKey(entry).isEmpty()) {
- setContent(
- new ErrorStateComponent(
- Localization.lang("Unable to chat"),
- Localization.lang("Please provide a non-empty and unique citation key for this entry.")
- )
- );
- } else {
- bindToEntry(entry);
- }
- }
-
- private static boolean citationKeyIsValid(BibDatabaseContext bibDatabaseContext, BibEntry bibEntry) {
- return !hasEmptyCitationKey(bibEntry) && bibEntry.getCitationKey().map(key -> citationKeyIsUnique(bibDatabaseContext, key)).orElse(false);
- }
-
- private static boolean hasEmptyCitationKey(BibEntry bibEntry) {
- return bibEntry.getCitationKey().map(String::isEmpty).orElse(true);
- }
-
- private static boolean citationKeyIsUnique(BibDatabaseContext bibDatabaseContext, String citationKey) {
- return bibDatabaseContext.getDatabase().getNumberOfCitationKeyOccurrences(citationKey) == 1;
- }
-
- private void startGeneratingSummary(BibEntry entry) {
- assert entry.getCitationKey().isPresent();
-
- showErrorNotSummarized();
-
- if (!entriesUnderSummarization.contains(entry)) {
- entriesUnderSummarization.add(entry);
-
- new GenerateSummaryTask(bibDatabaseContext, entry.getCitationKey().get(), entry.getFiles(), aiService, filePreferences)
- .onSuccess(res -> handleFocus())
- .onFailure(this::showErrorWhileSummarizing)
- .executeWith(taskExecutor);
- }
- }
-
- private void showErrorWhileSummarizing(Exception e) {
- LOGGER.error("Got an error while generating a summary for entry {}", currentEntry.getCitationKey(), e);
-
- setContent(
- ErrorStateComponent.withTextAreaAndButton(
- Localization.lang("Unable to chat"),
- Localization.lang("Got error while processing the file:"),
- e.getMessage(),
- Localization.lang("Regenerate"),
- () -> bindToEntry(currentEntry)
- )
- );
-
- entriesUnderSummarization.remove(currentEntry);
- }
-
- private void showErrorNotSummarized() {
- setContent(
- ErrorStateComponent.withSpinner(
- Localization.lang("Processing..."),
- Localization.lang("The attached file(s) are currently being processed by %0. Once completed, you will be able to see the summary.", aiService.getPreferences().getAiProvider().getLabel())
- )
- );
- }
-
- private void bindToCorrectEntry(SummariesStorage.SummarizationRecord summary) {
- entriesUnderSummarization.remove(currentEntry);
-
- SummaryComponent summaryComponent = new SummaryComponent(summary, () -> {
- if (bibDatabaseContext.getDatabasePath().isEmpty()) {
- LOGGER.error("Bib database path is not set, but it was expected to be present. Unable to regenerate summary");
- return;
- }
-
- if (currentEntry.getCitationKey().isEmpty()) {
- LOGGER.error("Citation key is not set, but it was expected to be present. Unable to regenerate summary");
- return;
- }
-
- aiService.getSummariesStorage().clear(bibDatabaseContext.getDatabasePath().get(), currentEntry.getCitationKey().get());
- bindToEntry(currentEntry);
- });
-
- setContent(summaryComponent);
- }
-
- private class SummarySetListener {
- @Subscribe
- public void listen(SummariesStorage.SummarySetEvent event) {
- UiTaskExecutor.runInJavaFXThread(AiSummaryTab.this::handleFocus);
- }
+ setContent(new SummaryComponent(
+ bibDatabaseContext,
+ entry,
+ aiService,
+ aiPreferences,
+ filePreferences,
+ citationKeyPatternPreferences,
+ dialogService
+ ));
}
}
diff --git a/src/main/java/org/jabref/gui/entryeditor/EntryEditor.java b/src/main/java/org/jabref/gui/entryeditor/EntryEditor.java
index 08eb0c11780..67be8a5f908 100644
--- a/src/main/java/org/jabref/gui/entryeditor/EntryEditor.java
+++ b/src/main/java/org/jabref/gui/entryeditor/EntryEditor.java
@@ -314,8 +314,8 @@ private List createTabs() {
tabs.add(sourceTab);
tabs.add(new LatexCitationsTab(databaseContext, preferencesService, dialogService, directoryMonitorManager));
tabs.add(new FulltextSearchResultsTab(stateManager, preferencesService, dialogService, databaseContext, taskExecutor, libraryTab.searchQueryProperty()));
- tabs.add(new AiSummaryTab(libraryTab.getLibraryTabContainer(), dialogService, preferencesService, aiApiKeyProvider, aiService, libraryTab.getBibDatabaseContext(), taskExecutor));
- tabs.add(new AiChatTab(libraryTab.getLibraryTabContainer(), dialogService, preferencesService, aiApiKeyProvider, aiService, libraryTab.getBibDatabaseContext(), taskExecutor));
+ tabs.add(new AiSummaryTab(libraryTab.getBibDatabaseContext(), aiService, dialogService, preferencesService));
+ tabs.add(new AiChatTab(libraryTab.getBibDatabaseContext(), aiService, dialogService, preferencesService, taskExecutor));
return tabs;
}
diff --git a/src/main/java/org/jabref/gui/groups/GroupTreeView.java b/src/main/java/org/jabref/gui/groups/GroupTreeView.java
index 0b50b7d6d20..97eb5d686c9 100644
--- a/src/main/java/org/jabref/gui/groups/GroupTreeView.java
+++ b/src/main/java/org/jabref/gui/groups/GroupTreeView.java
@@ -58,6 +58,7 @@
import org.jabref.gui.util.TaskExecutor;
import org.jabref.gui.util.ViewModelTreeTableCellFactory;
import org.jabref.gui.util.ViewModelTreeTableRowFactory;
+import org.jabref.logic.ai.AiService;
import org.jabref.logic.l10n.Localization;
import org.jabref.model.entry.BibEntry;
import org.jabref.preferences.PreferencesService;
@@ -81,6 +82,7 @@ public class GroupTreeView extends BorderPane {
private final StateManager stateManager;
private final DialogService dialogService;
+ private final AiService aiService;
private final TaskExecutor taskExecutor;
private final PreferencesService preferencesService;
@@ -105,11 +107,14 @@ public class GroupTreeView extends BorderPane {
public GroupTreeView(TaskExecutor taskExecutor,
StateManager stateManager,
PreferencesService preferencesService,
- DialogService dialogService) {
+ DialogService dialogService,
+ AiService aiService
+ ) {
this.taskExecutor = taskExecutor;
this.stateManager = stateManager;
this.preferencesService = preferencesService;
this.dialogService = dialogService;
+ this.aiService = aiService;
createNodes();
this.getStylesheets().add(Objects.requireNonNull(GroupTreeView.class.getResource("GroupTree.css")).toExternalForm());
@@ -159,7 +164,7 @@ private void createNodes() {
private void initialize() {
this.localDragboard = stateManager.getLocalDragboard();
- viewModel = new GroupTreeViewModel(stateManager, dialogService, preferencesService, taskExecutor, localDragboard);
+ viewModel = new GroupTreeViewModel(stateManager, dialogService, aiService, preferencesService, taskExecutor, localDragboard);
// Set-up groups tree
groupTree.getSelectionModel().setSelectionMode(SelectionMode.MULTIPLE);
@@ -544,6 +549,10 @@ private ContextMenu createContextMenuForGroup(GroupNodeViewModel group) {
removeGroup = factory.createMenuItem(StandardActions.GROUP_REMOVE, new GroupTreeView.ContextAction(StandardActions.GROUP_REMOVE, group));
}
+ if (preferencesService.getAiPreferences().getEnableAi()) {
+ contextMenu.getItems().add(factory.createMenuItem(StandardActions.GROUP_CHAT, new ContextAction(StandardActions.GROUP_CHAT, group)));
+ }
+
contextMenu.getItems().addAll(
factory.createMenuItem(StandardActions.GROUP_EDIT, new ContextAction(StandardActions.GROUP_EDIT, group)),
removeGroup,
@@ -659,6 +668,8 @@ public void execute() {
viewModel.editGroup(group);
groupTree.refresh();
}
+ case GROUP_CHAT ->
+ viewModel.chatWithGroup(group);
case GROUP_SUBGROUP_ADD ->
viewModel.addNewSubgroup(group, GroupDialogHeader.SUBGROUP);
case GROUP_SUBGROUP_REMOVE ->
diff --git a/src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java b/src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java
index 5593b7052e5..b47ea65a01d 100644
--- a/src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java
+++ b/src/main/java/org/jabref/gui/groups/GroupTreeViewModel.java
@@ -26,6 +26,7 @@
import org.jabref.gui.StateManager;
import org.jabref.gui.util.CustomLocalDragboard;
import org.jabref.gui.util.TaskExecutor;
+import org.jabref.logic.ai.AiService;
import org.jabref.logic.l10n.Localization;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.BibEntry;
@@ -42,6 +43,7 @@
import org.jabref.preferences.PreferencesService;
import com.tobiasdiez.easybind.EasyBind;
+import dev.langchain4j.data.message.ChatMessage;
public class GroupTreeViewModel extends AbstractViewModel {
@@ -49,6 +51,7 @@ public class GroupTreeViewModel extends AbstractViewModel {
private final ListProperty selectedGroups = new SimpleListProperty<>(FXCollections.observableArrayList());
private final StateManager stateManager;
private final DialogService dialogService;
+ private final AiService aiService;
private final PreferencesService preferences;
private final TaskExecutor taskExecutor;
private final CustomLocalDragboard localDragboard;
@@ -72,9 +75,10 @@ public class GroupTreeViewModel extends AbstractViewModel {
};
private Optional currentDatabase = Optional.empty();
- public GroupTreeViewModel(StateManager stateManager, DialogService dialogService, PreferencesService preferencesService, TaskExecutor taskExecutor, CustomLocalDragboard localDragboard) {
+ public GroupTreeViewModel(StateManager stateManager, DialogService dialogService, AiService aiService, PreferencesService preferencesService, TaskExecutor taskExecutor, CustomLocalDragboard localDragboard) {
this.stateManager = Objects.requireNonNull(stateManager);
this.dialogService = Objects.requireNonNull(dialogService);
+ this.aiService = Objects.requireNonNull(aiService);
this.preferences = Objects.requireNonNull(preferencesService);
this.taskExecutor = Objects.requireNonNull(taskExecutor);
this.localDragboard = Objects.requireNonNull(localDragboard);
@@ -258,8 +262,8 @@ public void editGroup(GroupNodeViewModel oldGroup) {
oldGroup.getGroupNode().getParent().orElse(null),
oldGroup.getGroupNode().getGroup(),
GroupDialogHeader.SUBGROUP));
- newGroup.ifPresent(group -> {
+ newGroup.ifPresent(group -> {
AbstractGroup oldGroupDef = oldGroup.getGroupNode().getGroup();
String oldGroupName = oldGroupDef.getName();
@@ -286,7 +290,7 @@ public void editGroup(GroupNodeViewModel oldGroup) {
dialogService.notify(Localization.lang("Modified group \"%0\".", group.getName()));
writeGroupChangesToMetaData();
- // This is ugly but we have no proper update mechanism in place to propagate the changes, so redraw everything
+ // This is ugly, but we have no proper update mechanism in place to propagate the changes, so redraw everything
refresh();
return;
}
@@ -372,12 +376,32 @@ public void editGroup(GroupNodeViewModel oldGroup) {
dialogService.notify(Localization.lang("Modified group \"%0\".", group.getName()));
writeGroupChangesToMetaData();
- // This is ugly but we have no proper update mechanism in place to propagate the changes, so redraw everything
+ // This is ugly, but we have no proper update mechanism in place to propagate the changes, so redraw everything
refresh();
});
});
}
+ public void chatWithGroup(GroupNodeViewModel group) {
+ // This should probably be done some other way. Please don't blame, it's just a thing to make it quick and fast.
+ if (currentDatabase.isEmpty()) {
+ dialogService.showErrorDialogAndWait(Localization.lang("Unable to chat with group"), Localization.lang("No database is set."));
+ return;
+ }
+
+ StringProperty groupNameProperty = group.getGroupNode().getGroup().nameProperty();
+
+ // We localize the name here, because it is used as the title of the window.
+ // See documentation for {@link AiChatGuardedComponent#name}.
+ StringProperty nameProperty = new SimpleStringProperty(Localization.lang("Group %0", groupNameProperty.get()));
+ groupNameProperty.addListener((obs, oldValue, newValue) -> nameProperty.setValue(Localization.lang("Group %0", groupNameProperty.get())));
+
+ ObservableList chatHistory = aiService.getChatHistoryService().getChatHistoryForGroup(group.getGroupNode());
+ ObservableList bibEntries = FXCollections.observableArrayList(group.getGroupNode().findMatches(currentDatabase.get().getDatabase()));
+
+ aiService.openAiChat(nameProperty, chatHistory, currentDatabase.get(), bibEntries);
+ }
+
public void removeSubgroups(GroupNodeViewModel group) {
boolean confirmation = dialogService.showConfirmationDialogAndWait(
Localization.lang("Remove subgroups"),
diff --git a/src/main/java/org/jabref/gui/importer/actions/ListenForCitationKeyChangeForAiAction.java b/src/main/java/org/jabref/gui/importer/actions/ListenForCitationKeyChangeForAiAction.java
deleted file mode 100644
index 095e2840cce..00000000000
--- a/src/main/java/org/jabref/gui/importer/actions/ListenForCitationKeyChangeForAiAction.java
+++ /dev/null
@@ -1,23 +0,0 @@
-package org.jabref.gui.importer.actions;
-
-import org.jabref.gui.DialogService;
-import org.jabref.logic.ai.AiService;
-import org.jabref.logic.importer.ParserResult;
-import org.jabref.preferences.PreferencesService;
-
-import com.airhacks.afterburner.injection.Injector;
-
-public class ListenForCitationKeyChangeForAiAction implements GUIPostOpenAction {
- private final AiService aiService = Injector.instantiateModelOrService(AiService.class);
-
- @Override
- public boolean isActionNecessary(ParserResult pr, PreferencesService preferencesService) {
- return true;
- }
-
- @Override
- public void performAction(ParserResult pr, DialogService dialogService, PreferencesService preferencesService) {
- pr.getDatabase().registerListener(aiService.getChatHistoryManager());
- pr.getDatabase().registerListener(aiService.getSummariesStorage());
- }
-}
diff --git a/src/main/java/org/jabref/gui/importer/actions/OpenDatabaseAction.java b/src/main/java/org/jabref/gui/importer/actions/OpenDatabaseAction.java
index c01e63c1570..97ccf23758d 100644
--- a/src/main/java/org/jabref/gui/importer/actions/OpenDatabaseAction.java
+++ b/src/main/java/org/jabref/gui/importer/actions/OpenDatabaseAction.java
@@ -56,9 +56,7 @@ public class OpenDatabaseAction extends SimpleCommand {
// Check for new custom entry types loaded from the BIB file:
new CheckForNewEntryTypesAction(),
// Migrate search groups from Search.g4 to Lucene syntax
- new SearchGroupsMigrationAction(),
- // AI chat history links BibEntry with citation key. When citation key is changed, chat history should be transferred from old citation key to new citation key
- new ListenForCitationKeyChangeForAiAction());
+ new SearchGroupsMigrationAction());
private final LibraryTabContainer tabContainer;
private final PreferencesService preferencesService;
diff --git a/src/main/java/org/jabref/gui/sidepane/SidePaneContentFactory.java b/src/main/java/org/jabref/gui/sidepane/SidePaneContentFactory.java
index 2249dbae452..c6da1040ac6 100644
--- a/src/main/java/org/jabref/gui/sidepane/SidePaneContentFactory.java
+++ b/src/main/java/org/jabref/gui/sidepane/SidePaneContentFactory.java
@@ -61,7 +61,8 @@ public Node create(SidePaneType sidePaneType) {
taskExecutor,
stateManager,
preferences,
- dialogService);
+ dialogService,
+ aiService);
case OPEN_OFFICE -> new OpenOfficePanel(
tabContainer,
preferences,
diff --git a/src/main/java/org/jabref/gui/util/BaseWindow.java b/src/main/java/org/jabref/gui/util/BaseWindow.java
new file mode 100644
index 00000000000..5669ab8b5a0
--- /dev/null
+++ b/src/main/java/org/jabref/gui/util/BaseWindow.java
@@ -0,0 +1,51 @@
+package org.jabref.gui.util;
+
+import javafx.collections.FXCollections;
+import javafx.collections.ListChangeListener;
+import javafx.collections.ObservableList;
+import javafx.scene.Scene;
+import javafx.scene.layout.Pane;
+import javafx.stage.Modality;
+import javafx.stage.Stage;
+
+import org.jabref.gui.icon.IconTheme;
+import org.jabref.gui.keyboard.KeyBinding;
+import org.jabref.gui.keyboard.KeyBindingRepository;
+
+import com.airhacks.afterburner.injection.Injector;
+
+/**
+ * A base class for non-modal windows of JabRef.
+ *
+ * You can create a new instance of this class and set the title in the constructor. After that you can call
+ * {@link org.jabref.gui.DialogService#showCustomWindow(BaseWindow)} in order to show the window. All the JabRef styles
+ * will be applied.
+ *
+ * See {@link org.jabref.gui.ai.components.aichat.AiChatWindow} for example.
+ */
+public class BaseWindow extends Stage {
+ private final ObservableList stylesheets = FXCollections.observableArrayList();
+
+ public BaseWindow() {
+ this.initModality(Modality.NONE);
+ this.getIcons().add(IconTheme.getJabRefImage());
+
+ setScene(new Scene(new Pane()));
+
+ stylesheets.addListener((ListChangeListener) c -> getScene().getStylesheets().setAll(stylesheets));
+ sceneProperty().addListener((obs, oldValue, newValue) -> {
+ newValue.getStylesheets().setAll(stylesheets);
+ newValue.setOnKeyPressed(event -> {
+ KeyBindingRepository keyBindingRepository = Injector.instantiateModelOrService(KeyBindingRepository.class);
+ if (keyBindingRepository.checkKeyCombinationEquality(KeyBinding.CLOSE, event)) {
+ close();
+ onCloseRequestProperty().get().handle(null);
+ }
+ });
+ });
+ }
+
+ public void applyStylesheets(ObservableList stylesheets) {
+ this.stylesheets.setAll(stylesheets);
+ }
+}
diff --git a/src/main/java/org/jabref/gui/util/DynamicallyChangeableNode.java b/src/main/java/org/jabref/gui/util/DynamicallyChangeableNode.java
new file mode 100644
index 00000000000..af4150644b9
--- /dev/null
+++ b/src/main/java/org/jabref/gui/util/DynamicallyChangeableNode.java
@@ -0,0 +1,24 @@
+package org.jabref.gui.util;
+
+import javafx.scene.Node;
+import javafx.scene.control.Tab;
+import javafx.scene.layout.Priority;
+import javafx.scene.layout.VBox;
+
+import org.jabref.gui.ai.components.privacynotice.AiPrivacyNoticeGuardedComponent;
+
+/**
+ * A node that can change its content using a setContent(Node) method, similar to {@link Tab}.
+ *
+ * It is used in places where the content is changed dynamically, but you have to provide a one {@link Node} and set it
+ * only once.
+ *
+ * See {@link AiPrivacyNoticeGuardedComponent#rebuildUi()} for example.
+ */
+public class DynamicallyChangeableNode extends VBox {
+ protected void setContent(Node node) {
+ getChildren().clear();
+ VBox.setVgrow(node, Priority.ALWAYS);
+ getChildren().add(node);
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/AiChatLogic.java b/src/main/java/org/jabref/logic/ai/AiChatLogic.java
deleted file mode 100644
index b14b2eebc48..00000000000
--- a/src/main/java/org/jabref/logic/ai/AiChatLogic.java
+++ /dev/null
@@ -1,147 +0,0 @@
-package org.jabref.logic.ai;
-
-import java.util.List;
-
-import org.jabref.logic.ai.chathistory.AiChatHistory;
-import org.jabref.logic.ai.misc.ErrorMessage;
-import org.jabref.model.entry.BibEntry;
-import org.jabref.model.entry.CanonicalBibEntry;
-import org.jabref.model.entry.LinkedFile;
-import org.jabref.preferences.ai.AiPreferences;
-
-import dev.langchain4j.chain.Chain;
-import dev.langchain4j.chain.ConversationalRetrievalChain;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.data.message.ChatMessage;
-import dev.langchain4j.data.message.SystemMessage;
-import dev.langchain4j.data.message.UserMessage;
-import dev.langchain4j.memory.ChatMemory;
-import dev.langchain4j.memory.chat.TokenWindowChatMemory;
-import dev.langchain4j.model.openai.OpenAiTokenizer;
-import dev.langchain4j.rag.DefaultRetrievalAugmentor;
-import dev.langchain4j.rag.RetrievalAugmentor;
-import dev.langchain4j.rag.content.retriever.ContentRetriever;
-import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
-import dev.langchain4j.store.embedding.filter.Filter;
-import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class AiChatLogic {
- private static final Logger LOGGER = LoggerFactory.getLogger(AiChatLogic.class);
-
- private final AiService aiService;
- private final AiChatHistory aiChatHistory;
- private final Filter embeddingsFilter;
- private final BibEntry entry;
-
- private ChatMemory chatMemory;
- private Chain chain;
-
- public AiChatLogic(AiService aiService, AiChatHistory aiChatHistory, Filter embeddingsFilter, BibEntry entry) {
- this.aiService = aiService;
- this.aiChatHistory = aiChatHistory;
- this.embeddingsFilter = embeddingsFilter;
- this.entry = entry;
-
- setupListeningToPreferencesChanges();
- rebuildFull(aiChatHistory.getMessages());
- }
-
- public static AiChatLogic forBibEntry(AiService aiService, AiChatHistory aiChatHistory, BibEntry entry) {
- Filter filter = MetadataFilterBuilder
- .metadataKey(FileEmbeddingsManager.LINK_METADATA_KEY)
- .isIn(entry
- .getFiles()
- .stream()
- .map(LinkedFile::getLink)
- .toList()
- );
-
- if (entry.getCitationKey().isEmpty()) {
- LOGGER.error("AiChatLogic should not be derived from BibEntry with no citation key");
- }
-
- return new AiChatLogic(aiService, aiChatHistory, filter, entry);
- }
-
- private void setupListeningToPreferencesChanges() {
- AiPreferences aiPreferences = aiService.getPreferences();
-
- aiPreferences.instructionProperty().addListener(obs -> setSystemMessage(aiPreferences.getInstruction()));
- aiPreferences.contextWindowSizeProperty().addListener(obs -> rebuildFull(chatMemory.messages()));
- }
-
- private void rebuildFull(List chatMessages) {
- rebuildChatMemory(chatMessages);
- rebuildChain();
- }
-
- private void rebuildChatMemory(List chatMessages) {
- AiPreferences aiPreferences = aiService.getPreferences();
-
- this.chatMemory = TokenWindowChatMemory
- .builder()
- .maxTokens(aiPreferences.getContextWindowSize(), new OpenAiTokenizer())
- .build();
-
- chatMessages.stream().filter(chatMessage -> !(chatMessage instanceof ErrorMessage)).forEach(chatMemory::add);
-
- setSystemMessage(aiPreferences.getInstruction());
- }
-
- private void rebuildChain() {
- AiPreferences aiPreferences = aiService.getPreferences();
-
- ContentRetriever contentRetriever = EmbeddingStoreContentRetriever
- .builder()
- .embeddingStore(aiService.getEmbeddingsManager().getEmbeddingsStore())
- .filter(embeddingsFilter)
- .embeddingModel(aiService.getEmbeddingModel())
- .maxResults(aiPreferences.getRagMaxResultsCount())
- .minScore(aiPreferences.getRagMinScore())
- .build();
-
- RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor
- .builder()
- .contentRetriever(contentRetriever)
- .executor(aiService.getCachedThreadPool())
- .build();
-
- this.chain = ConversationalRetrievalChain
- .builder()
- .chatLanguageModel(aiService.getChatLanguageModel())
- .retrievalAugmentor(retrievalAugmentor)
- .chatMemory(chatMemory)
- .build();
- }
-
- private void setSystemMessage(String systemMessage) {
- chatMemory.add(new SystemMessage(augmentSystemMessage(systemMessage)));
- }
-
- private String augmentSystemMessage(String systemMessage) {
- return systemMessage + "\n" + CanonicalBibEntry.getCanonicalRepresentation(entry);
- }
-
- public AiMessage execute(UserMessage message) {
- // Message will be automatically added to ChatMemory through ConversationalRetrievalChain.
-
- LOGGER.info("Sending message to AI provider ({}) for answering in entry {}: {}",
- AiDefaultPreferences.PROVIDERS_API_URLS.get(aiService.getPreferences().getAiProvider()),
- entry.getCitationKey().orElse(""),
- message.singleText());
-
- aiChatHistory.add(message);
- AiMessage result = new AiMessage(chain.execute(message.singleText()));
- aiChatHistory.add(result);
-
- LOGGER.debug("Message was answered by the AI provider for entry {}: {}", entry.getCitationKey().orElse(""), result.text());
-
- return result;
- }
-
- public AiChatHistory getChatHistory() {
- return aiChatHistory;
- }
-}
diff --git a/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java b/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
index 13a12ef0c47..206fa0ae57d 100644
--- a/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
+++ b/src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
@@ -50,7 +50,7 @@ public class AiDefaultPreferences {
public static final boolean CUSTOMIZE_SETTINGS = false;
public static final EmbeddingModel EMBEDDING_MODEL = EmbeddingModel.SENTENCE_TRANSFORMERS_ALL_MINILM_L12_V2;
- public static final String SYSTEM_MESSAGE = "You are an AI assistant that analyses research papers.";
+ public static final String SYSTEM_MESSAGE = "You are an AI assistant that analyses research papers. You answer questions about papers. You will be supplied with the necessary information. The supplied information will contain mentions of papers in form '@citationKey'. Whenever you refer to a paper, use its citation key in the same form with @ symbol. Whenever you find relevant information, always use the citation key. Here are the papers you are analyzing:\n";
public static final double TEMPERATURE = 0.7;
public static final int DOCUMENT_SPLITTER_CHUNK_SIZE = 300;
public static final int DOCUMENT_SPLITTER_OVERLAP = 100;
diff --git a/src/main/java/org/jabref/logic/ai/AiService.java b/src/main/java/org/jabref/logic/ai/AiService.java
index 7ba31fe269f..0503307df27 100644
--- a/src/main/java/org/jabref/logic/ai/AiService.java
+++ b/src/main/java/org/jabref/logic/ai/AiService.java
@@ -1,29 +1,39 @@
package org.jabref.logic.ai;
-import java.nio.file.Files;
-import java.nio.file.Path;
+import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javafx.beans.property.BooleanProperty;
-import javafx.beans.property.ReadOnlyBooleanProperty;
import javafx.beans.property.SimpleBooleanProperty;
+import javafx.beans.property.StringProperty;
+import javafx.collections.ObservableList;
import org.jabref.gui.DialogService;
+import org.jabref.gui.StateManager;
+import org.jabref.gui.ai.components.aichat.AiChatWindow;
import org.jabref.gui.desktop.JabRefDesktop;
import org.jabref.gui.util.TaskExecutor;
-import org.jabref.logic.ai.chathistory.BibDatabaseChatHistoryManager;
-import org.jabref.logic.ai.models.JabRefChatLanguageModel;
-import org.jabref.logic.ai.models.JabRefEmbeddingModel;
-import org.jabref.logic.ai.summarization.SummariesStorage;
-import org.jabref.logic.l10n.Localization;
+import org.jabref.logic.ai.chatting.AiChatService;
+import org.jabref.logic.ai.chatting.chathistory.ChatHistoryService;
+import org.jabref.logic.ai.chatting.chathistory.storages.MVStoreChatHistoryStorage;
+import org.jabref.logic.ai.chatting.model.JabRefChatLanguageModel;
+import org.jabref.logic.ai.ingestion.IngestionService;
+import org.jabref.logic.ai.ingestion.MVStoreEmbeddingStore;
+import org.jabref.logic.ai.ingestion.model.JabRefEmbeddingModel;
+import org.jabref.logic.ai.ingestion.storages.MVStoreFullyIngestedDocumentsTracker;
+import org.jabref.logic.ai.summarization.SummariesService;
+import org.jabref.logic.ai.summarization.storages.MVStoreSummariesStorage;
+import org.jabref.logic.citationkeypattern.CitationKeyPatternPreferences;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+import org.jabref.preferences.FilePreferences;
import org.jabref.preferences.ai.AiApiKeyProvider;
import org.jabref.preferences.ai.AiPreferences;
+import com.airhacks.afterburner.injection.Injector;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
-import org.h2.mvstore.MVStore;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import dev.langchain4j.data.message.ChatMessage;
/**
* The main class for the AI functionality.
@@ -33,10 +43,17 @@
public class AiService implements AutoCloseable {
public static final String VERSION = "1";
- private static final Logger LOGGER = LoggerFactory.getLogger(AiService.class);
- private static final String AI_SERVICE_MVSTORE_FILE_NAME = "ai.mv";
+ private static final String CHAT_HISTORY_FILE_NAME = "chat-histories.mv";
+ private static final String EMBEDDINGS_FILE_NAME = "embeddings.mv";
+ private static final String FULLY_INGESTED_FILE_NAME = "fully-ingested.mv";
+ private static final String SUMMARIES_FILE_NAME = "summaries.mv";
+
+ private final StateManager stateManager = Injector.instantiateModelOrService(StateManager.class);
private final AiPreferences aiPreferences;
+ private final FilePreferences filePreferences;
+ private final DialogService dialogService;
+ private final TaskExecutor taskExecutor;
// This field is used to shut down AI-related background tasks.
// If a background task processes a big document and has a loop, then the task should check the status
@@ -47,80 +64,113 @@ public class AiService implements AutoCloseable {
new ThreadFactoryBuilder().setNameFormat("ai-retrieval-pool-%d").build()
);
- private final MVStore mvStore;
+ private final MVStoreEmbeddingStore mvStoreEmbeddingStore;
+ private final MVStoreFullyIngestedDocumentsTracker mvStoreFullyIngestedDocumentsTracker;
+ private final MVStoreChatHistoryStorage mvStoreChatHistoryStorage;
+ private final MVStoreSummariesStorage mvStoreSummariesStorage;
private final JabRefChatLanguageModel jabRefChatLanguageModel;
- private final BibDatabaseChatHistoryManager bibDatabaseChatHistoryManager;
-
+ private final ChatHistoryService chatHistoryService;
private final JabRefEmbeddingModel jabRefEmbeddingModel;
- private final FileEmbeddingsManager fileEmbeddingsManager;
-
- private final SummariesStorage summariesStorage;
-
- public AiService(AiPreferences aiPreferences, AiApiKeyProvider aiApiKeyProvider, DialogService dialogService, TaskExecutor taskExecutor) {
+ private final AiChatService aiChatService;
+ private final IngestionService ingestionService;
+ private final SummariesService summariesService;
+
+ public AiService(AiPreferences aiPreferences,
+ FilePreferences filePreferences,
+ CitationKeyPatternPreferences citationKeyPatternPreferences,
+ AiApiKeyProvider aiApiKeyProvider,
+ DialogService dialogService,
+ TaskExecutor taskExecutor
+ ) {
this.aiPreferences = aiPreferences;
+ this.filePreferences = filePreferences;
+ this.dialogService = dialogService;
+ this.taskExecutor = taskExecutor;
- MVStore mvStore;
- try {
- Files.createDirectories(JabRefDesktop.getAiFilesDirectory());
-
- Path mvStorePath = JabRefDesktop.getAiFilesDirectory().resolve(AI_SERVICE_MVSTORE_FILE_NAME);
-
- mvStore = MVStore.open(mvStorePath.toString());
- } catch (Exception e) {
- LOGGER.error("An error occurred while creating directories for AI cache and chat history. Chat history will not be remembered in next session.", e);
- dialogService.notify(Localization.lang("An error occurred while creating directories for AI cache and chat history. Chat history will not be remembered in next session."));
- mvStore = MVStore.open(null);
- }
+ this.jabRefChatLanguageModel = new JabRefChatLanguageModel(aiPreferences, aiApiKeyProvider);
- this.mvStore = mvStore;
+ this.mvStoreEmbeddingStore = new MVStoreEmbeddingStore(JabRefDesktop.getAiFilesDirectory().resolve(EMBEDDINGS_FILE_NAME), dialogService);
+ this.mvStoreFullyIngestedDocumentsTracker = new MVStoreFullyIngestedDocumentsTracker(JabRefDesktop.getAiFilesDirectory().resolve(FULLY_INGESTED_FILE_NAME), dialogService);
+ this.mvStoreSummariesStorage = new MVStoreSummariesStorage(JabRefDesktop.getAiFilesDirectory().resolve(SUMMARIES_FILE_NAME), dialogService);
+ this.mvStoreChatHistoryStorage = new MVStoreChatHistoryStorage(JabRefDesktop.getAiFilesDirectory().resolve(CHAT_HISTORY_FILE_NAME), dialogService);
- this.jabRefChatLanguageModel = new JabRefChatLanguageModel(aiPreferences, aiApiKeyProvider);
- this.bibDatabaseChatHistoryManager = new BibDatabaseChatHistoryManager(mvStore);
+ this.chatHistoryService = new ChatHistoryService(citationKeyPatternPreferences, mvStoreChatHistoryStorage);
this.jabRefEmbeddingModel = new JabRefEmbeddingModel(aiPreferences, dialogService, taskExecutor);
- this.fileEmbeddingsManager = new FileEmbeddingsManager(aiPreferences, shutdownSignal, jabRefEmbeddingModel, mvStore);
- this.summariesStorage = new SummariesStorage(aiPreferences, mvStore);
+ this.aiChatService = new AiChatService(aiPreferences, jabRefChatLanguageModel, jabRefEmbeddingModel, mvStoreEmbeddingStore, cachedThreadPool);
+ this.ingestionService = new IngestionService(
+ aiPreferences,
+ shutdownSignal,
+ jabRefEmbeddingModel,
+ mvStoreEmbeddingStore,
+ mvStoreFullyIngestedDocumentsTracker,
+ filePreferences,
+ taskExecutor
+ );
+ this.summariesService = new SummariesService(aiPreferences, mvStoreSummariesStorage, jabRefChatLanguageModel, shutdownSignal, filePreferences, taskExecutor);
}
- @Override
- public void close() {
- shutdownSignal.set(true);
-
- this.cachedThreadPool.shutdownNow();
- this.jabRefChatLanguageModel.close();
- this.jabRefEmbeddingModel.close();
- this.mvStore.close();
+ public JabRefChatLanguageModel getChatLanguageModel() {
+ return jabRefChatLanguageModel;
}
- public AiPreferences getPreferences() {
- return aiPreferences;
+ public JabRefEmbeddingModel getEmbeddingModel() {
+ return jabRefEmbeddingModel;
}
- public ExecutorService getCachedThreadPool() {
- return cachedThreadPool;
+ public AiChatService getAiChatService() {
+ return aiChatService;
}
- public JabRefChatLanguageModel getChatLanguageModel() {
- return jabRefChatLanguageModel;
+ public ChatHistoryService getChatHistoryService() {
+ return chatHistoryService;
}
- public JabRefEmbeddingModel getEmbeddingModel() {
- return jabRefEmbeddingModel;
+ public IngestionService getIngestionService() {
+ return ingestionService;
}
- public BibDatabaseChatHistoryManager getChatHistoryManager() {
- return bibDatabaseChatHistoryManager;
+ public SummariesService getSummariesService() {
+ return summariesService;
}
- public FileEmbeddingsManager getEmbeddingsManager() {
- return fileEmbeddingsManager;
+ public void openAiChat(StringProperty name, ObservableList chatHistory, BibDatabaseContext bibDatabaseContext, ObservableList entries) {
+ Optional existingWindow = stateManager.getAiChatWindows().stream().filter(window -> window.getChatName().equals(name.get())).findFirst();
+
+ if (existingWindow.isPresent()) {
+ existingWindow.get().requestFocus();
+ } else {
+ AiChatWindow aiChatWindow = new AiChatWindow(
+ this,
+ dialogService,
+ aiPreferences,
+ filePreferences,
+ taskExecutor
+ );
+
+ aiChatWindow.setOnCloseRequest(event ->
+ stateManager.getAiChatWindows().remove(aiChatWindow)
+ );
+
+ stateManager.getAiChatWindows().add(aiChatWindow);
+ dialogService.showCustomWindow(aiChatWindow);
+ aiChatWindow.setChat(name, chatHistory, bibDatabaseContext, entries);
+ aiChatWindow.requestFocus();
+ }
}
- public SummariesStorage getSummariesStorage() {
- return summariesStorage;
- }
+ @Override
+ public void close() {
+ shutdownSignal.set(true);
+
+ cachedThreadPool.shutdownNow();
+ jabRefChatLanguageModel.close();
+ jabRefEmbeddingModel.close();
+ chatHistoryService.close();
- public ReadOnlyBooleanProperty getShutdownSignal() {
- return shutdownSignal;
+ mvStoreFullyIngestedDocumentsTracker.close();
+ mvStoreEmbeddingStore.close();
+ mvStoreChatHistoryStorage.close();
+ mvStoreSummariesStorage.close();
}
}
diff --git a/src/main/java/org/jabref/logic/ai/GenerateEmbeddingsTask.java b/src/main/java/org/jabref/logic/ai/GenerateEmbeddingsTask.java
deleted file mode 100644
index 0ec89939626..00000000000
--- a/src/main/java/org/jabref/logic/ai/GenerateEmbeddingsTask.java
+++ /dev/null
@@ -1,125 +0,0 @@
-package org.jabref.logic.ai;
-
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.attribute.BasicFileAttributes;
-import java.util.List;
-import java.util.Optional;
-import java.util.concurrent.TimeUnit;
-
-import org.jabref.gui.util.BackgroundTask;
-import org.jabref.logic.ai.embeddings.FileToDocument;
-import org.jabref.logic.l10n.Localization;
-import org.jabref.logic.util.ProgressCounter;
-import org.jabref.model.database.BibDatabaseContext;
-import org.jabref.model.entry.LinkedFile;
-import org.jabref.preferences.FilePreferences;
-
-import dev.langchain4j.data.document.Document;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class GenerateEmbeddingsTask extends BackgroundTask {
- private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsTask.class);
-
- private final String citationKey;
- private final List linkedFiles;
- private final FileEmbeddingsManager fileEmbeddingsManager;
- private final BibDatabaseContext bibDatabaseContext;
- private final FilePreferences filePreferences;
-
- private final ProgressCounter progressCounter = new ProgressCounter();
-
- public GenerateEmbeddingsTask(String citationKey,
- List linkedFiles,
- FileEmbeddingsManager fileEmbeddingsManager,
- BibDatabaseContext bibDatabaseContext,
- FilePreferences filePreferences) {
- this.citationKey = citationKey;
- this.linkedFiles = linkedFiles;
- this.fileEmbeddingsManager = fileEmbeddingsManager;
- this.bibDatabaseContext = bibDatabaseContext;
- this.filePreferences = filePreferences;
-
- titleProperty().set(Localization.lang("Generating embeddings for %0", citationKey));
- showToUser(true);
-
- progressCounter.listenToAllProperties(this::updateProgress);
- }
-
- @Override
- protected Void call() throws Exception {
- LOGGER.info("Starting embeddings generation task for entry {}", citationKey);
-
- try {
- // forEach() method would look better here, but we need to catch the {@link InterruptedException}.
- for (LinkedFile linkedFile : linkedFiles) {
- ingestLinkedFile(linkedFile);
- }
- } catch (InterruptedException e) {
- LOGGER.info("There is a embeddings generation task for {}. It will be cancelled, because user quits JabRef.", citationKey);
- }
-
- showToUser(false);
-
- LOGGER.info("Finished embeddings generation task for entry {}", citationKey);
-
- progressCounter.stop();
-
- return null;
- }
-
- private void ingestLinkedFile(LinkedFile linkedFile) throws InterruptedException {
- // Rationale for RuntimeException here:
- // See org.jabref.logic.ai.summarization.GenerateSummaryTask.summarizeAll
-
- LOGGER.info("Generating embeddings for file \"{}\" of entry {}", linkedFile.getLink(), citationKey);
-
- Optional path = linkedFile.findIn(bibDatabaseContext, filePreferences);
-
- if (path.isEmpty()) {
- LOGGER.error("Could not find path for a linked file \"{}\", while generating embeddings for entry {}", linkedFile.getLink(), citationKey);
- LOGGER.info("Unable to generate embeddings for file \"{}\", because it was not found while generating embeddings for entry {}", linkedFile.getLink(), citationKey);
- throw new RuntimeException(Localization.lang("Could not find path for a linked file %0 while generating embeddings for entry %1", linkedFile.getLink(), citationKey));
- }
-
- try {
- BasicFileAttributes attributes = Files.readAttributes(path.get(), BasicFileAttributes.class);
-
- long currentModificationTimeInSeconds = attributes.lastModifiedTime().to(TimeUnit.SECONDS);
-
- Optional ingestedModificationTimeInSeconds = fileEmbeddingsManager.getIngestedDocumentModificationTimeInSeconds(linkedFile.getLink());
-
- if (ingestedModificationTimeInSeconds.isPresent() && currentModificationTimeInSeconds <= ingestedModificationTimeInSeconds.get()) {
- LOGGER.info("No need to generate embeddings for entry {} for file \"{}\", because it was already generated", citationKey, linkedFile.getLink());
- return;
- }
-
- Optional document = FileToDocument.fromFile(path.get());
- if (document.isPresent()) {
- fileEmbeddingsManager.addDocument(linkedFile.getLink(), document.get(), currentModificationTimeInSeconds, progressCounter.workDoneProperty(), progressCounter.workMaxProperty());
- LOGGER.info("Embeddings for file \"{}\" were generated successfully, while processing entry {}", linkedFile.getLink(), citationKey);
- } else {
- LOGGER.error("Unable to generate embeddings for file \"{}\", because JabRef was unable to extract text from the file, while processing entry {}", linkedFile.getLink(), citationKey);
- throw new RuntimeException(Localization.lang("Unable to generate embeddings for file %0, because JabRef was unable to extract text from the file, while processing entry %1", linkedFile.getLink(), citationKey));
- }
- } catch (IOException e) {
- LOGGER.error("Couldn't retrieve attributes of a linked file \"{}\", while generating embeddings for entry {}", linkedFile.getLink(), citationKey, e);
- LOGGER.warn("Regenerating embeddings for linked file \"{}\", while processing entry {}", linkedFile.getLink(), citationKey);
-
- Optional document = FileToDocument.fromFile(path.get());
- if (document.isPresent()) {
- fileEmbeddingsManager.addDocument(linkedFile.getLink(), document.get(), 0, progressCounter.workDoneProperty(), progressCounter.workMaxProperty());
- LOGGER.info("Embeddings for file \"{}\" were generated successfully while processing entry {}, but the JabRef couldn't check if the file was changed", linkedFile.getLink(), citationKey);
- } else {
- LOGGER.info("Unable to generate embeddings for file \"{}\" while processing entry {}, because JabRef was unable to extract text from the file", linkedFile.getLink(), citationKey);
- }
- }
- }
-
- private void updateProgress() {
- updateProgress(progressCounter.getWorkDone(), progressCounter.getWorkMax());
- updateMessage(progressCounter.getMessage());
- }
-}
diff --git a/src/main/java/org/jabref/logic/ai/chathistory/AiChatHistory.java b/src/main/java/org/jabref/logic/ai/chathistory/AiChatHistory.java
deleted file mode 100644
index 9cf3ef797ac..00000000000
--- a/src/main/java/org/jabref/logic/ai/chathistory/AiChatHistory.java
+++ /dev/null
@@ -1,15 +0,0 @@
-package org.jabref.logic.ai.chathistory;
-
-import java.util.List;
-
-import dev.langchain4j.data.message.ChatMessage;
-
-public interface AiChatHistory {
- List getMessages();
-
- void add(ChatMessage chatMessage);
-
- void remove(int index);
-
- void clear();
-}
diff --git a/src/main/java/org/jabref/logic/ai/chathistory/BibDatabaseChatHistoryManager.java b/src/main/java/org/jabref/logic/ai/chathistory/BibDatabaseChatHistoryManager.java
deleted file mode 100644
index f03aaba511a..00000000000
--- a/src/main/java/org/jabref/logic/ai/chathistory/BibDatabaseChatHistoryManager.java
+++ /dev/null
@@ -1,165 +0,0 @@
-package org.jabref.logic.ai.chathistory;
-
-import java.io.Serializable;
-import java.nio.file.Path;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-
-import org.jabref.gui.StateManager;
-import org.jabref.logic.ai.misc.ErrorMessage;
-import org.jabref.model.database.BibDatabaseContext;
-import org.jabref.model.entry.event.FieldChangedEvent;
-import org.jabref.model.entry.field.InternalField;
-
-import com.airhacks.afterburner.injection.Injector;
-import com.google.common.eventbus.Subscribe;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.data.message.ChatMessage;
-import dev.langchain4j.data.message.UserMessage;
-import org.h2.mvstore.MVStore;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * This class stores the chat history with AI for .bib files. The chat history file is stored in a user local folder.
- *
- * It uses MVStore for serializing the messages. In case any error occurs while opening the MVStore,
- * the class will notify the user of this error and continue with in-memory store (meaning all messages will
- * be thrown away on exit).
- *
- * @implNote If something is changed in the data model, increase {@link org.jabref.logic.ai.AiService#VERSION}
- */
-public class BibDatabaseChatHistoryManager {
- private static final Logger LOGGER = LoggerFactory.getLogger(BibDatabaseChatHistoryManager.class);
-
- private final StateManager stateManager = Injector.instantiateModelOrService(StateManager.class);
-
- private record ChatHistoryRecord(String className, String content) implements Serializable {
- public ChatMessage toLangchainMessage() {
- if (className.equals(AiMessage.class.getName())) {
- return new AiMessage(content);
- } else if (className.equals(UserMessage.class.getName())) {
- return new UserMessage(content);
- } else if (className.equals(ErrorMessage.class.getName())) {
- return new ErrorMessage(content);
- } else {
- LOGGER.warn("BibDatabaseChatHistoryManager supports only AI and user messages, but retrieved message has other type: {}. Will treat as an AI message.", className);
- return new AiMessage(content);
- }
- }
- }
-
- private final MVStore mvStore;
-
- public BibDatabaseChatHistoryManager(MVStore mvStore) {
- this.mvStore = mvStore;
- }
-
- private Map getMap(Path bibDatabasePath, String citationKey) {
- return mvStore.openMap("chathistory-" + bibDatabasePath + "-" + citationKey);
- }
-
- public AiChatHistory getChatHistory(Path bibDatabasePath, String citationKey) {
- return new AiChatHistory() {
- @Override
- public List getMessages() {
- Map messages = getMap(bibDatabasePath, citationKey);
-
- return messages
- .entrySet()
- // we need to check all keys, because upon deletion, there can be "holes" in the integer
- .stream()
- .sorted(Comparator.comparingInt(Map.Entry::getKey))
- .map(entry -> entry.getValue().toLangchainMessage())
- .toList();
- }
-
- @Override
- public void add(ChatMessage chatMessage) {
- Map map = getMap(bibDatabasePath, citationKey);
-
- // We count 0-based, thus "size()" is the next number.
- // 0 entries -> 0 is the first new id.
- // 1 entry -> 0 is assigned, 1 is the next number, which is also the size.
- // But if an entry is removed, keys are not updated, so we have to find the maximum key.
- int id = map.keySet().stream().max(Integer::compareTo).orElse(0) + 1;
-
- String content = getContentFromLangchainMessage(chatMessage);
-
- map.put(id, new ChatHistoryRecord(chatMessage.getClass().getName(), content));
- }
-
- @Override
- public void remove(int index) {
- Map map = getMap(bibDatabasePath, citationKey);
-
- Optional id = map
- .entrySet()
- .stream()
- .sorted(Comparator.comparingInt(Map.Entry::getKey))
- .skip(index)
- .map(Map.Entry::getKey)
- .findFirst();
-
- if (id.isPresent()) {
- map.remove(id.get());
- } else {
- LOGGER.error("Attempted to delete a message that does not exist in the chat history at index {}", index);
- }
- }
-
- @Override
- public void clear() {
- getMap(bibDatabasePath, citationKey).clear();
- }
- };
- }
-
- private static String getContentFromLangchainMessage(ChatMessage chatMessage) {
- String content;
-
- switch (chatMessage) {
- case AiMessage aiMessage ->
- content = aiMessage.text();
- case UserMessage userMessage ->
- content = userMessage.singleText();
- case ErrorMessage errorMessage ->
- content = errorMessage.getText();
- default -> {
- LOGGER.warn("BibDatabaseChatHistoryManager supports only AI, user. and error messages, but added message has other type: {}", chatMessage.type().name());
- return "";
- }
- }
-
- return content;
- }
-
- @Subscribe
- private void fieldChangedEventListener(FieldChangedEvent event) {
- // TODO: This methods doesn't take into account if the new citation key is valid.
-
- if (event.getField() != InternalField.KEY_FIELD) {
- return;
- }
-
- Optional bibDatabaseContext = stateManager.getOpenDatabases().stream().filter(dbContext -> dbContext.getDatabase().getEntries().contains(event.getBibEntry())).findFirst();
-
- if (bibDatabaseContext.isEmpty()) {
- LOGGER.error("Could not listen to field change event because no database context was found. BibEntry: {}", event.getBibEntry());
- return;
- }
-
- Optional bibDatabasePath = bibDatabaseContext.get().getDatabasePath();
-
- if (bibDatabasePath.isEmpty()) {
- LOGGER.error("Could not listen to field change event because no database path was found. BibEntry: {}", event.getBibEntry());
- return;
- }
-
- Map oldMap = getMap(bibDatabasePath.get(), event.getOldValue());
- getMap(bibDatabasePath.get(), event.getNewValue()).putAll(oldMap);
- oldMap.clear();
- }
-}
diff --git a/src/main/java/org/jabref/logic/ai/chathistory/InMemoryAiChatHistory.java b/src/main/java/org/jabref/logic/ai/chathistory/InMemoryAiChatHistory.java
deleted file mode 100644
index 980b939b79b..00000000000
--- a/src/main/java/org/jabref/logic/ai/chathistory/InMemoryAiChatHistory.java
+++ /dev/null
@@ -1,30 +0,0 @@
-package org.jabref.logic.ai.chathistory;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import dev.langchain4j.data.message.ChatMessage;
-
-public class InMemoryAiChatHistory implements AiChatHistory {
- private final List chatMessages = new ArrayList<>();
-
- @Override
- public List getMessages() {
- return chatMessages;
- }
-
- @Override
- public void add(ChatMessage chatMessage) {
- chatMessages.add(chatMessage);
- }
-
- @Override
- public void remove(int index) {
- chatMessages.remove(index);
- }
-
- @Override
- public void clear() {
- chatMessages.clear();
- }
-}
diff --git a/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java b/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java
new file mode 100644
index 00000000000..2aacd133db7
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/chatting/AiChatLogic.java
@@ -0,0 +1,179 @@
+package org.jabref.logic.ai.chatting;
+
+import java.util.List;
+import java.util.concurrent.Executor;
+import java.util.stream.Collectors;
+
+import javafx.beans.property.StringProperty;
+import javafx.collections.ListChangeListener;
+import javafx.collections.ObservableList;
+
+import org.jabref.logic.ai.AiDefaultPreferences;
+import org.jabref.logic.ai.ingestion.FileEmbeddingsManager;
+import org.jabref.logic.ai.util.ErrorMessage;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+import org.jabref.model.entry.CanonicalBibEntry;
+import org.jabref.model.entry.LinkedFile;
+import org.jabref.model.util.ListUtil;
+import org.jabref.preferences.ai.AiPreferences;
+
+import dev.langchain4j.chain.Chain;
+import dev.langchain4j.chain.ConversationalRetrievalChain;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.SystemMessage;
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.memory.ChatMemory;
+import dev.langchain4j.memory.chat.TokenWindowChatMemory;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.openai.OpenAiTokenizer;
+import dev.langchain4j.rag.DefaultRetrievalAugmentor;
+import dev.langchain4j.rag.RetrievalAugmentor;
+import dev.langchain4j.rag.content.retriever.ContentRetriever;
+import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.filter.Filter;
+import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
+import jakarta.annotation.Nullable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class AiChatLogic {
+ private static final Logger LOGGER = LoggerFactory.getLogger(AiChatLogic.class);
+
+ private final AiPreferences aiPreferences;
+ private final ChatLanguageModel chatLanguageModel;
+ private final EmbeddingModel embeddingModel;
+ private final EmbeddingStore embeddingStore;
+ private final Executor cachedThreadPool;
+
+ private final ObservableList chatHistory;
+ private final ObservableList entries;
+ private final StringProperty name;
+ private final BibDatabaseContext bibDatabaseContext;
+
+ private ChatMemory chatMemory;
+ private Chain chain;
+
+ public AiChatLogic(AiPreferences aiPreferences,
+ ChatLanguageModel chatLanguageModel,
+ EmbeddingModel embeddingModel,
+ EmbeddingStore embeddingStore,
+ Executor cachedThreadPool,
+ StringProperty name,
+ ObservableList chatHistory,
+ ObservableList entries,
+ BibDatabaseContext bibDatabaseContext
+ ) {
+ this.aiPreferences = aiPreferences;
+ this.chatLanguageModel = chatLanguageModel;
+ this.embeddingModel = embeddingModel;
+ this.embeddingStore = embeddingStore;
+ this.cachedThreadPool = cachedThreadPool;
+ this.chatHistory = chatHistory;
+ this.entries = entries;
+ this.name = name;
+ this.bibDatabaseContext = bibDatabaseContext;
+
+ this.entries.addListener((ListChangeListener) change -> rebuildChain());
+
+ setupListeningToPreferencesChanges();
+ rebuildFull(chatHistory);
+ }
+
+ private void setupListeningToPreferencesChanges() {
+ aiPreferences.instructionProperty().addListener(obs -> setSystemMessage(aiPreferences.getInstruction()));
+ aiPreferences.contextWindowSizeProperty().addListener(obs -> rebuildFull(chatMemory.messages()));
+ }
+
+ private void rebuildFull(List chatMessages) {
+ rebuildChatMemory(chatMessages);
+ rebuildChain();
+ }
+
+ private void rebuildChatMemory(List chatMessages) {
+ this.chatMemory = TokenWindowChatMemory
+ .builder()
+ .maxTokens(aiPreferences.getContextWindowSize(), new OpenAiTokenizer())
+ .build();
+
+ chatMessages.stream().filter(chatMessage -> !(chatMessage instanceof ErrorMessage)).forEach(chatMemory::add);
+
+ setSystemMessage(aiPreferences.getInstruction());
+ }
+
+ private void rebuildChain() {
+ List linkedFiles = ListUtil.getLinkedFiles(entries).toList();
+ @Nullable Filter filter;
+
+ if (linkedFiles.isEmpty()) {
+ // You must not pass an empty list to langchain4j {@link IsIn} filter.
+ filter = null;
+ } else {
+ filter = MetadataFilterBuilder
+ .metadataKey(FileEmbeddingsManager.LINK_METADATA_KEY)
+ .isIn(linkedFiles
+ .stream()
+ .map(LinkedFile::getLink)
+ .toList()
+ );
+ }
+
+ ContentRetriever contentRetriever = EmbeddingStoreContentRetriever
+ .builder()
+ .embeddingStore(embeddingStore)
+ .filter(filter)
+ .embeddingModel(embeddingModel)
+ .maxResults(aiPreferences.getRagMaxResultsCount())
+ .minScore(aiPreferences.getRagMinScore())
+ .build();
+
+ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor
+ .builder()
+ .contentRetriever(contentRetriever)
+ .contentInjector(new JabRefContentInjector(bibDatabaseContext))
+ .executor(cachedThreadPool)
+ .build();
+
+ this.chain = ConversationalRetrievalChain
+ .builder()
+ .chatLanguageModel(chatLanguageModel)
+ .retrievalAugmentor(retrievalAugmentor)
+ .chatMemory(chatMemory)
+ .build();
+ }
+
+ private void setSystemMessage(String systemMessage) {
+ chatMemory.add(new SystemMessage(augmentSystemMessage(systemMessage)));
+ }
+
+ private String augmentSystemMessage(String systemMessage) {
+ String entriesInfo = entries.stream().map(CanonicalBibEntry::getCanonicalRepresentation).collect(Collectors.joining("\n"));
+
+ return systemMessage + "\n" + entriesInfo;
+ }
+
+ public AiMessage execute(UserMessage message) {
+ // Message will be automatically added to ChatMemory through ConversationalRetrievalChain.
+
+ LOGGER.info("Sending message to AI provider ({}) for answering in {}: {}",
+ AiDefaultPreferences.PROVIDERS_API_URLS.get(aiPreferences.getAiProvider()),
+ name.get(),
+ message.singleText());
+
+ chatHistory.add(message);
+ AiMessage result = new AiMessage(chain.execute(message.singleText()));
+ chatHistory.add(result);
+
+ LOGGER.debug("Message was answered by the AI provider for {}: {}", name.get(), result.text());
+
+ return result;
+ }
+
+ public ObservableList getChatHistory() {
+ return chatHistory;
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java b/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java
new file mode 100644
index 00000000000..7309a422add
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/chatting/AiChatService.java
@@ -0,0 +1,56 @@
+package org.jabref.logic.ai.chatting;
+
+import java.util.concurrent.Executor;
+
+import javafx.beans.property.StringProperty;
+import javafx.collections.ObservableList;
+
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+import org.jabref.preferences.ai.AiPreferences;
+
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+
+public class AiChatService {
+ private final AiPreferences aiPreferences;
+ private final ChatLanguageModel chatLanguageModel;
+ private final EmbeddingModel embeddingModel;
+ private final EmbeddingStore embeddingStore;
+ private final Executor cachedThreadPool;
+
+ public AiChatService(AiPreferences aiPreferences,
+ ChatLanguageModel chatLanguageModel,
+ EmbeddingModel embeddingModel,
+ EmbeddingStore embeddingStore,
+ Executor cachedThreadPool
+ ) {
+ this.aiPreferences = aiPreferences;
+ this.chatLanguageModel = chatLanguageModel;
+ this.embeddingModel = embeddingModel;
+ this.embeddingStore = embeddingStore;
+ this.cachedThreadPool = cachedThreadPool;
+ }
+
+ public AiChatLogic makeChat(
+ StringProperty name,
+ ObservableList chatHistory,
+ ObservableList entries,
+ BibDatabaseContext bibDatabaseContext
+ ) {
+ return new AiChatLogic(
+ aiPreferences,
+ chatLanguageModel,
+ embeddingModel,
+ embeddingStore,
+ cachedThreadPool,
+ name,
+ chatHistory,
+ entries,
+ bibDatabaseContext
+ );
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/chatting/JabRefContentInjector.java b/src/main/java/org/jabref/logic/ai/chatting/JabRefContentInjector.java
new file mode 100644
index 00000000000..d9762f12e7c
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/chatting/JabRefContentInjector.java
@@ -0,0 +1,68 @@
+package org.jabref.logic.ai.chatting;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+
+import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.input.PromptTemplate;
+import dev.langchain4j.rag.content.Content;
+import dev.langchain4j.rag.content.injector.ContentInjector;
+
+import static org.jabref.logic.ai.ingestion.FileEmbeddingsManager.LINK_METADATA_KEY;
+
+public class JabRefContentInjector implements ContentInjector {
+ public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from("{{userMessage}}\n\nAnswer using the following information:\n{{contents}}");
+
+ private final BibDatabaseContext bibDatabaseContext;
+
+ public JabRefContentInjector(BibDatabaseContext bibDatabaseContext) {
+ this.bibDatabaseContext = bibDatabaseContext;
+ }
+
+ @Override
+ public UserMessage inject(List list, UserMessage userMessage) {
+ String contentText = list.stream().map(this::contentToString).collect(Collectors.joining("\n\n"));
+
+ String res = applyPrompt(userMessage.singleText(), contentText);
+ return new UserMessage(res);
+ }
+
+ private String contentToString(Content content) {
+ String text = content.textSegment().text();
+
+ String link = content.textSegment().metadata().getString(LINK_METADATA_KEY);
+ if (link == null) {
+ return text;
+ }
+
+ String keys = findEntriesByLink(link)
+ .filter(entry -> entry.getCitationKey().isPresent())
+ .map(entry -> "@" + entry.getCitationKey().get())
+ .collect(Collectors.joining(", "));
+
+ if (keys.isEmpty()) {
+ return text;
+ } else {
+ return keys + ":\n" + text;
+ }
+ }
+
+ private Stream findEntriesByLink(String link) {
+ return bibDatabaseContext.getEntries().stream().filter(entry -> entry.getFiles().stream().anyMatch(file -> file.getLink().equals(link)));
+ }
+
+ private String applyPrompt(String userMessage, String contents) {
+ Map variables = new HashMap<>();
+
+ variables.put("userMessage", userMessage);
+ variables.put("contents", contents);
+
+ return DEFAULT_PROMPT_TEMPLATE.apply(variables).text();
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/chatting/chathistory/ChatHistoryService.java b/src/main/java/org/jabref/logic/ai/chatting/chathistory/ChatHistoryService.java
new file mode 100644
index 00000000000..acbde14c3a3
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/chatting/chathistory/ChatHistoryService.java
@@ -0,0 +1,298 @@
+package org.jabref.logic.ai.chatting.chathistory;
+
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.TreeMap;
+
+import javafx.collections.FXCollections;
+import javafx.collections.ListChangeListener;
+import javafx.collections.ObservableList;
+
+import org.jabref.gui.StateManager;
+import org.jabref.logic.ai.util.CitationKeyCheck;
+import org.jabref.logic.citationkeypattern.CitationKeyGenerator;
+import org.jabref.logic.citationkeypattern.CitationKeyPatternPreferences;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+import org.jabref.model.entry.event.FieldChangedEvent;
+import org.jabref.model.entry.field.InternalField;
+import org.jabref.model.groups.AbstractGroup;
+import org.jabref.model.groups.GroupTreeNode;
+
+import com.airhacks.afterburner.injection.Injector;
+import com.google.common.eventbus.Subscribe;
+import dev.langchain4j.data.message.ChatMessage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Main class for getting and storing chat history for entries and groups.
+ * Use this class in the logic and UI.
+ *
+ * The returned chat history is a {@link ObservableList}. So chat history exists for every possible
+ * {@link BibEntry} and {@link AbstractGroup}. The chat history is stored in runtime.
+ *
+ * To save and load chat history, {@link BibEntry} and {@link AbstractGroup} must satisfy several constraints.
+ * Serialization and deserialization is handled in {@link ChatHistoryStorage}.
+ *
+ * Constraints for serialization and deserialization of a chat history of a {@link BibEntry}:
+ * 1. There should exist an associated {@link BibDatabaseContext} for the {@link BibEntry}.
+ * 2. The database path of the associated {@link BibDatabaseContext} must be set.
+ * 3. The citation key of the {@link BibEntry} must be set and unique.
+ *
+ * Constraints for serialization and deserialization of a chat history of an {@link GroupTreeNode}:
+ * 1. There should exist an associated {@link BibDatabaseContext} for the {@link GroupTreeNode}.
+ * 2. The database path of the associated {@link BibDatabaseContext} must be set.
+ * 3. The name of an {@link GroupTreeNode} must be set and unique (this requirement is possibly already satisfied in
+ * JabRef, but for {@link BibEntry} it is definitely not).
+ */
+public class ChatHistoryService implements AutoCloseable {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ChatHistoryService.class);
+
+ private final StateManager stateManager = Injector.instantiateModelOrService(StateManager.class);
+
+ private final CitationKeyPatternPreferences citationKeyPatternPreferences;
+
+ private final ChatHistoryStorage implementation;
+
+ private record ChatHistoryManagementRecord(Optional bibDatabaseContext, ObservableList chatHistory) { }
+
+ // We use a {@link TreeMap} here to store {@link BibEntry} chat histories by their id.
+ // When you compare {@link BibEntry} instances, they are compared by value, not by reference.
+ // And when you store {@link BibEntry} instances in a {@link HashMap}, an old hash may be stored when the {@link BibEntry} is changed.
+ // See also ADR-38.
+ private final TreeMap bibEntriesChatHistory = new TreeMap<>(Comparator.comparing(BibEntry::getId));
+
+ // We use {@link TreeMap} for group chat history for the same reason as for {@link BibEntry}ies.
+ private final TreeMap groupsChatHistory = new TreeMap<>((o1, o2) -> {
+ // The most important thing is to catch equality/non-equality.
+ // For "less" or "bigger" comparison, we will fall back to group names.
+ return o1 == o2 ? 0 : o1.getGroup().getName().compareTo(o2.getGroup().getName());
+ });
+
+ public ChatHistoryService(CitationKeyPatternPreferences citationKeyPatternPreferences,
+ ChatHistoryStorage implementation) {
+ this.citationKeyPatternPreferences = citationKeyPatternPreferences;
+ this.implementation = implementation;
+
+ configureHistoryTransfer();
+ }
+
+ private void configureHistoryTransfer() {
+ stateManager.getOpenDatabases().addListener((ListChangeListener) change -> {
+ while (change.next()) {
+ if (change.wasAdded()) {
+ change.getAddedSubList().forEach(this::configureHistoryTransfer);
+ }
+ }
+ });
+ }
+
+ private void configureHistoryTransfer(BibDatabaseContext bibDatabaseContext) {
+ bibDatabaseContext.getMetaData().getGroups().ifPresent(rootGroupTreeNode -> {
+ rootGroupTreeNode.iterateOverTree().forEach(groupNode -> {
+ groupNode.getGroup().nameProperty().addListener((observable, oldValue, newValue) -> {
+ if (newValue != null && oldValue != null) {
+ transferGroupHistory(bibDatabaseContext, groupNode, oldValue, newValue);
+ }
+ });
+
+ groupNode.getGroupProperty().addListener((obs, oldValue, newValue) -> {
+ if (oldValue != null && newValue != null) {
+ transferGroupHistory(bibDatabaseContext, groupNode, oldValue.getName(), newValue.getName());
+ }
+ });
+ });
+ });
+
+ bibDatabaseContext.getDatabase().getEntries().forEach(entry -> {
+ entry.registerListener(new CitationKeyChangeListener(bibDatabaseContext));
+ });
+ }
+
+ public ObservableList getChatHistoryForEntry(BibEntry entry) {
+ return bibEntriesChatHistory.computeIfAbsent(entry, entryArg -> {
+ Optional bibDatabaseContext = findBibDatabaseForEntry(entry);
+
+ ObservableList chatHistory;
+
+ if (bibDatabaseContext.isEmpty() || entry.getCitationKey().isEmpty() || !correctCitationKey(bibDatabaseContext.get(), entry) || bibDatabaseContext.get().getDatabasePath().isEmpty()) {
+ chatHistory = FXCollections.observableArrayList();
+ } else {
+ List chatMessagesList = implementation.loadMessagesForEntry(bibDatabaseContext.get().getDatabasePath().get(), entry.getCitationKey().get());
+ chatHistory = FXCollections.observableArrayList(chatMessagesList);
+ }
+
+ return new ChatHistoryManagementRecord(bibDatabaseContext, chatHistory);
+ }).chatHistory;
+ }
+
+ /**
+ * Removes the chat history for the given {@link BibEntry} from the internal RAM map.
+ * If the {@link BibEntry} satisfies requirements for serialization and deserialization of chat history (see
+ * the docstring for the {@link ChatHistoryService}), then the chat history will be stored via the
+ * {@link ChatHistoryStorage}.
+ *
+ * It is not necessary to call this method (everything will be stored in {@link ChatHistoryService#close()},
+ * but it's best to call it when the chat history {@link BibEntry} is no longer needed.
+ */
+ public void closeChatHistoryForEntry(BibEntry entry) {
+ ChatHistoryManagementRecord chatHistoryManagementRecord = bibEntriesChatHistory.get(entry);
+ if (chatHistoryManagementRecord == null) {
+ return;
+ }
+
+ Optional bibDatabaseContext = chatHistoryManagementRecord.bibDatabaseContext();
+
+ if (bibDatabaseContext.isPresent() && entry.getCitationKey().isPresent() && correctCitationKey(bibDatabaseContext.get(), entry) && bibDatabaseContext.get().getDatabasePath().isPresent()) {
+ // Method `correctCitationKey` will already check `entry.getCitationKey().isPresent()`, but it is still
+ // there, to suppress warning from IntelliJ IDEA on `entry.getCitationKey().get()`.
+ implementation.storeMessagesForEntry(
+ bibDatabaseContext.get().getDatabasePath().get(),
+ entry.getCitationKey().get(),
+ chatHistoryManagementRecord.chatHistory()
+ );
+ }
+
+ // TODO: What if there is two AI chats for the same entry? And one is closed and one is not?
+ bibEntriesChatHistory.remove(entry);
+ }
+
+ public ObservableList getChatHistoryForGroup(GroupTreeNode group) {
+ return groupsChatHistory.computeIfAbsent(group, groupArg -> {
+ Optional bibDatabaseContext = findBibDatabaseForGroup(group);
+
+ ObservableList chatHistory;
+
+ if (bibDatabaseContext.isEmpty() || bibDatabaseContext.get().getDatabasePath().isEmpty()) {
+ chatHistory = FXCollections.observableArrayList();
+ } else {
+ List chatMessagesList = implementation.loadMessagesForGroup(
+ bibDatabaseContext.get().getDatabasePath().get(),
+ group.getGroup().getName()
+ );
+
+ chatHistory = FXCollections.observableArrayList(chatMessagesList);
+ }
+
+ return new ChatHistoryManagementRecord(bibDatabaseContext, chatHistory);
+ }).chatHistory;
+ }
+
+ /**
+ * Removes the chat history for the given {@link GroupTreeNode} from the internal RAM map.
+ * If the {@link GroupTreeNode} satisfies requirements for serialization and deserialization of chat history (see
+ * the docstring for the {@link ChatHistoryService}), then the chat history will be stored via the
+ * {@link ChatHistoryStorage}.
+ *
+ * It is not necessary to call this method (everything will be stored in {@link ChatHistoryService#close()},
+ * but it's best to call it when the chat history {@link GroupTreeNode} is no longer needed.
+ */
+ public void closeChatHistoryForGroup(GroupTreeNode group) {
+ ChatHistoryManagementRecord chatHistoryManagementRecord = groupsChatHistory.get(group);
+ if (chatHistoryManagementRecord == null) {
+ return;
+ }
+
+ Optional bibDatabaseContext = chatHistoryManagementRecord.bibDatabaseContext();
+
+ if (bibDatabaseContext.isPresent() && bibDatabaseContext.get().getDatabasePath().isPresent()) {
+ implementation.storeMessagesForGroup(
+ bibDatabaseContext.get().getDatabasePath().get(),
+ group.getGroup().getName(),
+ chatHistoryManagementRecord.chatHistory()
+ );
+ }
+
+ // TODO: What if there is two AI chats for the same entry? And one is closed and one is not?
+ groupsChatHistory.remove(group);
+ }
+
+ private boolean correctCitationKey(BibDatabaseContext bibDatabaseContext, BibEntry bibEntry) {
+ if (!CitationKeyCheck.citationKeyIsPresentAndUnique(bibDatabaseContext, bibEntry)) {
+ tryToGenerateCitationKey(bibDatabaseContext, bibEntry);
+ }
+
+ return CitationKeyCheck.citationKeyIsPresentAndUnique(bibDatabaseContext, bibEntry);
+ }
+
+ private void tryToGenerateCitationKey(BibDatabaseContext bibDatabaseContext, BibEntry bibEntry) {
+ new CitationKeyGenerator(bibDatabaseContext, citationKeyPatternPreferences).generateAndSetKey(bibEntry);
+ }
+
+ private Optional findBibDatabaseForEntry(BibEntry entry) {
+ return stateManager
+ .getOpenDatabases()
+ .stream()
+ .filter(dbContext -> dbContext.getDatabase().getEntries().contains(entry))
+ .findFirst();
+ }
+
+ private Optional findBibDatabaseForGroup(GroupTreeNode group) {
+ return stateManager
+ .getOpenDatabases()
+ .stream()
+ .filter(dbContext ->
+ dbContext.getMetaData().groupsBinding().get().map(groupTreeNode ->
+ groupTreeNode.containsGroup(group.getGroup())
+ ).orElse(false)
+ )
+ .findFirst();
+ }
+
+ @Override
+ public void close() {
+ // We need to clone `bibEntriesChatHistory.keySet()` because closeChatHistoryForEntry() modifies the `bibEntriesChatHistory` map.
+ new HashSet<>(bibEntriesChatHistory.keySet()).forEach(this::closeChatHistoryForEntry);
+
+ // Clone is for the same reason, as written above.
+ new HashSet<>(groupsChatHistory.keySet()).forEach(this::closeChatHistoryForGroup);
+
+ implementation.commit();
+ }
+
+ private void transferGroupHistory(BibDatabaseContext bibDatabaseContext, GroupTreeNode groupTreeNode, String oldName, String newName) {
+ if (bibDatabaseContext.getDatabasePath().isEmpty()) {
+ LOGGER.warn("Could not transfer chat history of group {} (old name: {}): database path is empty.", newName, oldName);
+ return;
+ }
+
+ List chatMessages = groupsChatHistory.computeIfAbsent(groupTreeNode,
+ e -> new ChatHistoryManagementRecord(Optional.of(bibDatabaseContext), FXCollections.observableArrayList())).chatHistory;
+ implementation.storeMessagesForGroup(bibDatabaseContext.getDatabasePath().get(), oldName, List.of());
+ implementation.storeMessagesForGroup(bibDatabaseContext.getDatabasePath().get(), newName, chatMessages);
+ }
+
+ private void transferEntryHistory(BibDatabaseContext bibDatabaseContext, BibEntry entry, String oldCitationKey, String newCitationKey) {
+ // TODO: This method does not check if the citation key is valid.
+
+ if (bibDatabaseContext.getDatabasePath().isEmpty()) {
+ LOGGER.warn("Could not transfer chat history of entry {} (old key: {}): database path is empty.", newCitationKey, oldCitationKey);
+ return;
+ }
+
+ List chatMessages = bibEntriesChatHistory.computeIfAbsent(entry,
+ e -> new ChatHistoryManagementRecord(Optional.of(bibDatabaseContext), FXCollections.observableArrayList())).chatHistory;
+ implementation.storeMessagesForGroup(bibDatabaseContext.getDatabasePath().get(), oldCitationKey, List.of());
+ implementation.storeMessagesForEntry(bibDatabaseContext.getDatabasePath().get(), newCitationKey, chatMessages);
+ }
+
+ private class CitationKeyChangeListener {
+ private final BibDatabaseContext bibDatabaseContext;
+
+ public CitationKeyChangeListener(BibDatabaseContext bibDatabaseContext) {
+ this.bibDatabaseContext = bibDatabaseContext;
+ }
+
+ @Subscribe
+ void listen(FieldChangedEvent e) {
+ if (e.getField() != InternalField.KEY_FIELD) {
+ return;
+ }
+
+ transferEntryHistory(bibDatabaseContext, e.getBibEntry(), e.getOldValue(), e.getNewValue());
+ }
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/chatting/chathistory/ChatHistoryStorage.java b/src/main/java/org/jabref/logic/ai/chatting/chathistory/ChatHistoryStorage.java
new file mode 100644
index 00000000000..1a30a4332e7
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/chatting/chathistory/ChatHistoryStorage.java
@@ -0,0 +1,20 @@
+package org.jabref.logic.ai.chatting.chathistory;
+
+import java.nio.file.Path;
+import java.util.List;
+
+import dev.langchain4j.data.message.ChatMessage;
+
+public interface ChatHistoryStorage {
+ List loadMessagesForEntry(Path bibDatabasePath, String citationKey);
+
+ void storeMessagesForEntry(Path bibDatabasePath, String citationKey, List messages);
+
+ List loadMessagesForGroup(Path bibDatabasePath, String name);
+
+ void storeMessagesForGroup(Path bibDatabasePath, String name, List messages);
+
+ void commit();
+
+ void close();
+}
diff --git a/src/main/java/org/jabref/logic/ai/chatting/chathistory/storages/MVStoreChatHistoryStorage.java b/src/main/java/org/jabref/logic/ai/chatting/chathistory/storages/MVStoreChatHistoryStorage.java
new file mode 100644
index 00000000000..af92ea69a2e
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/chatting/chathistory/storages/MVStoreChatHistoryStorage.java
@@ -0,0 +1,135 @@
+package org.jabref.logic.ai.chatting.chathistory.storages;
+
+import java.io.Serializable;
+import java.nio.file.Path;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+import org.jabref.gui.DialogService;
+import org.jabref.logic.ai.chatting.chathistory.ChatHistoryStorage;
+import org.jabref.logic.ai.util.ErrorMessage;
+import org.jabref.logic.ai.util.MVStoreBase;
+import org.jabref.logic.l10n.Localization;
+
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.message.ChatMessage;
+import dev.langchain4j.data.message.UserMessage;
+import kotlin.ranges.IntRange;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MVStoreChatHistoryStorage extends MVStoreBase implements ChatHistoryStorage {
+ private static final String ENTRY_CHAT_HISTORY_PREFIX = "entry";
+ private static final String GROUP_CHAT_HISTORY_PREFIX = "group";
+
+ private record ChatHistoryRecord(String className, String content) implements Serializable {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ChatHistoryRecord.class);
+
+ public static ChatHistoryRecord fromLangchainMessage(ChatMessage chatMessage) {
+ String className = chatMessage.getClass().getName();
+ String content = getContentFromLangchainMessage(chatMessage);
+ return new ChatHistoryRecord(className, content);
+ }
+
+ private static String getContentFromLangchainMessage(ChatMessage chatMessage) {
+ String content;
+
+ switch (chatMessage) {
+ case AiMessage aiMessage ->
+ content = aiMessage.text();
+ case UserMessage userMessage ->
+ content = userMessage.singleText();
+ case ErrorMessage errorMessage ->
+ content = errorMessage.getText();
+ default -> {
+ LOGGER.warn("ChatHistoryRecord supports only AI, user. and error messages, but added message has other type: {}", chatMessage.type().name());
+ return "";
+ }
+ }
+
+ return content;
+ }
+
+ public ChatMessage toLangchainMessage() {
+ if (className.equals(AiMessage.class.getName())) {
+ return new AiMessage(content);
+ } else if (className.equals(UserMessage.class.getName())) {
+ return new UserMessage(content);
+ } else if (className.equals(ErrorMessage.class.getName())) {
+ return new ErrorMessage(content);
+ } else {
+ LOGGER.warn("ChatHistoryRecord supports only AI and user messages, but retrieved message has other type: {}. Will treat as an AI message.", className);
+ return new AiMessage(content);
+ }
+ }
+ }
+
+ public MVStoreChatHistoryStorage(Path path, DialogService dialogService) {
+ super(path, dialogService);
+ }
+
+ @Override
+ public List loadMessagesForEntry(Path bibDatabasePath, String citationKey) {
+ return loadMessagesFromMap(getMapForEntry(bibDatabasePath, citationKey));
+ }
+
+ @Override
+ public void storeMessagesForEntry(Path bibDatabasePath, String citationKey, List messages) {
+ storeMessagesForMap(getMapForEntry(bibDatabasePath, citationKey), messages);
+ }
+
+ @Override
+ public List loadMessagesForGroup(Path bibDatabasePath, String name) {
+ return loadMessagesFromMap(getMapForGroup(bibDatabasePath, name));
+ }
+
+ @Override
+ public void storeMessagesForGroup(Path bibDatabasePath, String name, List messages) {
+ storeMessagesForMap(getMapForGroup(bibDatabasePath, name), messages);
+ }
+
+ private List loadMessagesFromMap(Map map) {
+ return map
+ .entrySet()
+ // We need to check all keys, because upon deletion, there can be "holes" in the integer.
+ .stream()
+ .sorted(Comparator.comparingInt(Map.Entry::getKey))
+ .map(entry -> entry.getValue().toLangchainMessage())
+ .toList();
+ }
+
+ private void storeMessagesForMap(Map map, List messages) {
+ map.clear();
+
+ new IntRange(0, messages.size() - 1).forEach(i ->
+ map.put(i, ChatHistoryRecord.fromLangchainMessage(messages.get(i)))
+ );
+ }
+
+ private Map getMapForEntry(Path bibDatabasePath, String citationKey) {
+ return getMap(bibDatabasePath, ENTRY_CHAT_HISTORY_PREFIX, citationKey);
+ }
+
+ private Map getMapForGroup(Path bibDatabasePath, String name) {
+ return getMap(bibDatabasePath, GROUP_CHAT_HISTORY_PREFIX, name);
+ }
+
+ private Map getMap(Path bibDatabasePath, String type, String name) {
+ return mvStore.openMap(bibDatabasePath + "-" + type + "-" + name);
+ }
+
+ public void commit() {
+ mvStore.commit();
+ }
+
+ @Override
+ protected String errorMessageForOpening() {
+ return "An error occurred while opening chat history storage. Chat history of entries and groups will not be stored in the next session.";
+ }
+
+ @Override
+ protected String errorMessageForOpeningLocalized() {
+ return Localization.lang("An error occurred while opening chat history storage. Chat history of entries and groups will not be stored in the next session.");
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java b/src/main/java/org/jabref/logic/ai/chatting/model/JabRefChatLanguageModel.java
similarity index 96%
rename from src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java
rename to src/main/java/org/jabref/logic/ai/chatting/model/JabRefChatLanguageModel.java
index 9739b7a6d64..8bd71112cb8 100644
--- a/src/main/java/org/jabref/logic/ai/models/JabRefChatLanguageModel.java
+++ b/src/main/java/org/jabref/logic/ai/chatting/model/JabRefChatLanguageModel.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.models;
+package org.jabref.logic.ai.chatting.model;
import java.net.http.HttpClient;
import java.time.Duration;
@@ -7,7 +7,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
-import org.jabref.logic.ai.AiChatLogic;
+import org.jabref.logic.ai.chatting.AiChatLogic;
import org.jabref.logic.l10n.Localization;
import org.jabref.preferences.ai.AiApiKeyProvider;
import org.jabref.preferences.ai.AiPreferences;
@@ -20,7 +20,6 @@
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.output.Response;
-import org.h2.mvstore.MVStore;
/**
* Wrapper around langchain4j chat language model.
@@ -52,7 +51,7 @@ public JabRefChatLanguageModel(AiPreferences aiPreferences, AiApiKeyProvider api
* Update the underlying {@link dev.langchain4j.model.chat.ChatLanguageModel} by current {@link AiPreferences} parameters.
* When the model is updated, the chat messages are not lost.
* See {@link AiChatLogic}, where messages are stored in {@link ChatMemory},
- * and using {@link org.jabref.logic.ai.chathistory.BibDatabaseChatHistoryManager}, where messages are stored in {@link MVStore}.
+ * and see {@link org.jabref.logic.ai.chatting.chathistory.ChatHistoryStorage}.
*/
private void rebuild() {
String apiKey = apiKeyProvider.getApiKeyForAiProvider(aiPreferences.getAiProvider());
diff --git a/src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java b/src/main/java/org/jabref/logic/ai/chatting/model/JvmOpenAiChatLanguageModel.java
similarity index 98%
rename from src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java
rename to src/main/java/org/jabref/logic/ai/chatting/model/JvmOpenAiChatLanguageModel.java
index 19e1dfd1f93..f38435aaf8b 100644
--- a/src/main/java/org/jabref/logic/ai/models/JvmOpenAiChatLanguageModel.java
+++ b/src/main/java/org/jabref/logic/ai/chatting/model/JvmOpenAiChatLanguageModel.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.models;
+package org.jabref.logic.ai.chatting.model;
import java.net.http.HttpClient;
import java.util.List;
diff --git a/src/main/java/org/jabref/logic/ai/FileEmbeddingsManager.java b/src/main/java/org/jabref/logic/ai/ingestion/FileEmbeddingsManager.java
similarity index 62%
rename from src/main/java/org/jabref/logic/ai/FileEmbeddingsManager.java
rename to src/main/java/org/jabref/logic/ai/ingestion/FileEmbeddingsManager.java
index 48e07d61a6b..c366874885f 100644
--- a/src/main/java/org/jabref/logic/ai/FileEmbeddingsManager.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/FileEmbeddingsManager.java
@@ -1,24 +1,19 @@
-package org.jabref.logic.ai;
+package org.jabref.logic.ai.ingestion;
import java.util.List;
import java.util.Optional;
-import java.util.Set;
import javafx.beans.property.IntegerProperty;
import javafx.beans.property.ReadOnlyBooleanProperty;
-import org.jabref.logic.ai.embeddings.FullyIngestedDocumentsTracker;
-import org.jabref.logic.ai.embeddings.LowLevelIngestor;
-import org.jabref.logic.ai.embeddings.MVStoreEmbeddingStore;
-import org.jabref.logic.ai.models.JabRefEmbeddingModel;
import org.jabref.model.entry.LinkedFile;
import org.jabref.preferences.ai.AiPreferences;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder;
-import org.h2.mvstore.MVStore;
/**
* This class is responsible for managing the embeddings cache. The cache is saved in a local user directory.
@@ -36,19 +31,23 @@ public class FileEmbeddingsManager {
public static final String LINK_METADATA_KEY = "link";
private final AiPreferences aiPreferences;
- private final ReadOnlyBooleanProperty shutdownProperty;
+ private final ReadOnlyBooleanProperty shutdownSignal;
- private final MVStoreEmbeddingStore embeddingStore;
+ private final EmbeddingStore embeddingStore;
private final FullyIngestedDocumentsTracker fullyIngestedDocumentsTracker;
private final LowLevelIngestor lowLevelIngestor;
- public FileEmbeddingsManager(AiPreferences aiPreferences, ReadOnlyBooleanProperty shutdownProperty, JabRefEmbeddingModel jabRefEmbeddingModel, MVStore mvStore) {
+ public FileEmbeddingsManager(AiPreferences aiPreferences,
+ ReadOnlyBooleanProperty shutdownSignal,
+ EmbeddingModel embeddingModel,
+ EmbeddingStore embeddingStore,
+ FullyIngestedDocumentsTracker fullyIngestedDocumentsTracker
+ ) {
this.aiPreferences = aiPreferences;
- this.shutdownProperty = shutdownProperty;
-
- this.embeddingStore = new MVStoreEmbeddingStore(mvStore);
- this.fullyIngestedDocumentsTracker = new FullyIngestedDocumentsTracker(mvStore);
- this.lowLevelIngestor = new LowLevelIngestor(aiPreferences, embeddingStore, jabRefEmbeddingModel);
+ this.shutdownSignal = shutdownSignal;
+ this.embeddingStore = embeddingStore;
+ this.fullyIngestedDocumentsTracker = fullyIngestedDocumentsTracker;
+ this.lowLevelIngestor = new LowLevelIngestor(aiPreferences, embeddingStore, embeddingModel);
setupListeningToPreferencesChanges();
}
@@ -59,9 +58,9 @@ private void setupListeningToPreferencesChanges() {
public void addDocument(String link, Document document, long modificationTimeInSeconds, IntegerProperty workDone, IntegerProperty workMax) throws InterruptedException {
document.metadata().put(LINK_METADATA_KEY, link);
- lowLevelIngestor.ingestDocument(document, shutdownProperty, workDone, workMax);
+ lowLevelIngestor.ingestDocument(document, shutdownSignal, workDone, workMax);
- if (!shutdownProperty.get()) {
+ if (!shutdownSignal.get()) {
fullyIngestedDocumentsTracker.markDocumentAsFullyIngested(link, modificationTimeInSeconds);
}
}
@@ -75,30 +74,10 @@ public EmbeddingStore getEmbeddingsStore() {
return embeddingStore;
}
- public Set getIngestedDocuments() {
- return fullyIngestedDocumentsTracker.getFullyIngestedDocuments();
- }
-
public Optional getIngestedDocumentModificationTimeInSeconds(String link) {
return fullyIngestedDocumentsTracker.getIngestedDocumentModificationTimeInSeconds(link);
}
- public void registerListener(Object object) {
- fullyIngestedDocumentsTracker.registerListener(object);
- }
-
- public boolean hasIngestedDocument(String link) {
- return fullyIngestedDocumentsTracker.hasIngestedDocument(link);
- }
-
- public boolean hasIngestedDocuments(List links) {
- return links.stream().allMatch(this::hasIngestedDocument);
- }
-
- public boolean hasIngestedLinkedFiles(List linkedFiles) {
- return hasIngestedDocuments(linkedFiles.stream().map(LinkedFile::getLink).toList());
- }
-
public void clearEmbeddingsFor(List linkedFiles) {
linkedFiles.stream().map(LinkedFile::getLink).forEach(this::removeDocument);
}
diff --git a/src/main/java/org/jabref/logic/ai/embeddings/FileToDocument.java b/src/main/java/org/jabref/logic/ai/ingestion/FileToDocument.java
similarity index 63%
rename from src/main/java/org/jabref/logic/ai/embeddings/FileToDocument.java
rename to src/main/java/org/jabref/logic/ai/ingestion/FileToDocument.java
index 3277021dc09..0d8cb119906 100644
--- a/src/main/java/org/jabref/logic/ai/embeddings/FileToDocument.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/FileToDocument.java
@@ -1,50 +1,62 @@
-package org.jabref.logic.ai.embeddings;
+package org.jabref.logic.ai.ingestion;
import java.io.StringWriter;
import java.nio.file.Path;
import java.util.Optional;
+import javafx.beans.property.ReadOnlyBooleanProperty;
+
+import org.jabref.logic.pdf.InterruptablePDFTextStripper;
import org.jabref.logic.util.io.FileUtil;
import org.jabref.logic.xmp.XmpUtilReader;
import dev.langchain4j.data.document.Document;
import org.apache.pdfbox.pdmodel.PDDocument;
-import org.apache.pdfbox.text.PDFTextStripper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class FileToDocument {
private static final Logger LOGGER = LoggerFactory.getLogger(FileToDocument.class);
- public static Optional fromFile(Path path) {
+ private final ReadOnlyBooleanProperty shutdownSignal;
+
+ public FileToDocument(ReadOnlyBooleanProperty shutdownSignal) {
+ this.shutdownSignal = shutdownSignal;
+ }
+
+ public Optional fromFile(Path path) {
if (FileUtil.isPDFFile(path)) {
- return FileToDocument.fromPdfFile(path);
+ return fromPdfFile(path);
} else {
LOGGER.info("Unsupported file type of file: {}. Currently, only PDF files are supported", path);
return Optional.empty();
}
}
- private static Optional fromPdfFile(Path path) {
+ private Optional fromPdfFile(Path path) {
// This method is private to ensure that the path is really pointing to PDF file (determined by extension).
try (PDDocument document = new XmpUtilReader().loadWithAutomaticDecryption(path)) {
int lastPage = document.getNumberOfPages();
StringWriter writer = new StringWriter();
- PDFTextStripper stripper = new PDFTextStripper();
+ InterruptablePDFTextStripper stripper = new InterruptablePDFTextStripper(shutdownSignal);
stripper.setStartPage(1);
stripper.setEndPage(lastPage);
stripper.writeText(document, writer);
- return FileToDocument.fromString(writer.toString());
+ if (shutdownSignal.get()) {
+ return Optional.empty();
+ }
+
+ return fromString(writer.toString());
} catch (Exception e) {
LOGGER.error("An error occurred while reading the PDF file: {}", path, e);
return Optional.empty();
}
}
- public static Optional fromString(String content) {
+ public Optional fromString(String content) {
return Optional.of(new Document(content));
}
}
diff --git a/src/main/java/org/jabref/logic/ai/ingestion/FullyIngestedDocumentsTracker.java b/src/main/java/org/jabref/logic/ai/ingestion/FullyIngestedDocumentsTracker.java
new file mode 100644
index 00000000000..9ba066e7900
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/ingestion/FullyIngestedDocumentsTracker.java
@@ -0,0 +1,20 @@
+package org.jabref.logic.ai.ingestion;
+
+import java.util.Optional;
+
+/**
+ * This class is responsible for recording the information about which documents (or documents) have been fully ingested.
+ *
+ * The class also records the document modification time.
+ */
+public interface FullyIngestedDocumentsTracker {
+ void markDocumentAsFullyIngested(String link, long modificationTimeInSeconds);
+
+ Optional getIngestedDocumentModificationTimeInSeconds(String link);
+
+ void unmarkDocumentAsFullyIngested(String link);
+
+ void commit();
+
+ void close();
+}
diff --git a/src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsForSeveralTask.java b/src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsForSeveralTask.java
new file mode 100644
index 00000000000..43a0c813e7d
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsForSeveralTask.java
@@ -0,0 +1,113 @@
+package org.jabref.logic.ai.ingestion;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Future;
+
+import javafx.beans.property.ReadOnlyBooleanProperty;
+import javafx.beans.property.StringProperty;
+import javafx.util.Pair;
+
+import org.jabref.gui.util.BackgroundTask;
+import org.jabref.gui.util.TaskExecutor;
+import org.jabref.logic.ai.processingstatus.ProcessingInfo;
+import org.jabref.logic.ai.processingstatus.ProcessingState;
+import org.jabref.logic.l10n.Localization;
+import org.jabref.logic.util.ProgressCounter;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.LinkedFile;
+import org.jabref.preferences.FilePreferences;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This task generates embeddings for several {@link LinkedFile} (typically used for groups).
+ * It will check if embeddings were already generated.
+ * And it also will store the embeddings.
+ */
+public class GenerateEmbeddingsForSeveralTask extends BackgroundTask {
+ private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsTask.class);
+
+ private final StringProperty name;
+ private final List> linkedFiles;
+ private final FileEmbeddingsManager fileEmbeddingsManager;
+ private final BibDatabaseContext bibDatabaseContext;
+ private final FilePreferences filePreferences;
+ private final TaskExecutor taskExecutor;
+ private final ReadOnlyBooleanProperty shutdownSignal;
+
+ private final ProgressCounter progressCounter = new ProgressCounter();
+
+ private String currentFile = "";
+
+ public GenerateEmbeddingsForSeveralTask(
+ StringProperty name,
+ List> linkedFiles,
+ FileEmbeddingsManager fileEmbeddingsManager,
+ BibDatabaseContext bibDatabaseContext,
+ FilePreferences filePreferences,
+ TaskExecutor taskExecutor,
+ ReadOnlyBooleanProperty shutdownSignal
+ ) {
+ this.name = name;
+ this.linkedFiles = linkedFiles;
+ this.fileEmbeddingsManager = fileEmbeddingsManager;
+ this.bibDatabaseContext = bibDatabaseContext;
+ this.filePreferences = filePreferences;
+ this.taskExecutor = taskExecutor;
+ this.shutdownSignal = shutdownSignal;
+
+ configure(name);
+ }
+
+ private void configure(StringProperty name) {
+ showToUser(true);
+ titleProperty().set(Localization.lang("Generating embeddings for %0", name.get()));
+ name.addListener((o, oldValue, newValue) -> titleProperty().set(Localization.lang("Generating embeddings for %0", newValue)));
+
+ progressCounter.increaseWorkMax(linkedFiles.size());
+ progressCounter.listenToAllProperties(this::updateProgress);
+ updateProgress();
+ }
+
+ @Override
+ protected Void call() throws Exception {
+ LOGGER.debug("Starting embeddings generation of several files for {}", name.get());
+
+ List, String>> futures = new ArrayList<>();
+ linkedFiles
+ .stream()
+ .map(processingInfo -> {
+ processingInfo.setState(ProcessingState.PROCESSING);
+ return new Pair<>(
+ new GenerateEmbeddingsTask(
+ processingInfo.getObject(),
+ fileEmbeddingsManager,
+ bibDatabaseContext,
+ filePreferences,
+ shutdownSignal
+ )
+ .onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
+ .onFailure(processingInfo::setException)
+ .onFinished(() -> progressCounter.increaseWorkDone(1))
+ .executeWith(taskExecutor),
+ processingInfo.getObject().getLink());
+ })
+ .forEach(futures::add);
+
+ for (Pair extends Future>, String> pair : futures) {
+ currentFile = pair.getValue();
+ pair.getKey().get();
+ }
+
+ LOGGER.debug("Finished embeddings generation task of several files for {}", name.get());
+ progressCounter.stop();
+ return null;
+ }
+
+ private void updateProgress() {
+ updateProgress(progressCounter.getWorkDone(), progressCounter.getWorkMax());
+ updateMessage(progressCounter.getMessage() + " - " + currentFile + ", ...");
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsTask.java b/src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsTask.java
new file mode 100644
index 00000000000..e0e1561f267
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/ingestion/GenerateEmbeddingsTask.java
@@ -0,0 +1,132 @@
+package org.jabref.logic.ai.ingestion;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+
+import javafx.beans.property.ReadOnlyBooleanProperty;
+
+import org.jabref.gui.util.BackgroundTask;
+import org.jabref.logic.l10n.Localization;
+import org.jabref.logic.util.ProgressCounter;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.LinkedFile;
+import org.jabref.preferences.FilePreferences;
+
+import dev.langchain4j.data.document.Document;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This task generates embeddings for a {@link LinkedFile}.
+ * It will check if embeddings were already generated.
+ * And it also will store the embeddings.
+ */
+public class GenerateEmbeddingsTask extends BackgroundTask {
+ private static final Logger LOGGER = LoggerFactory.getLogger(GenerateEmbeddingsTask.class);
+
+ private final LinkedFile linkedFile;
+ private final FileEmbeddingsManager fileEmbeddingsManager;
+ private final BibDatabaseContext bibDatabaseContext;
+ private final FilePreferences filePreferences;
+ private final ReadOnlyBooleanProperty shutdownSignal;
+
+ private final ProgressCounter progressCounter = new ProgressCounter();
+
+ public GenerateEmbeddingsTask(LinkedFile linkedFile,
+ FileEmbeddingsManager fileEmbeddingsManager,
+ BibDatabaseContext bibDatabaseContext,
+ FilePreferences filePreferences,
+ ReadOnlyBooleanProperty shutdownSignal
+ ) {
+ this.linkedFile = linkedFile;
+ this.fileEmbeddingsManager = fileEmbeddingsManager;
+ this.bibDatabaseContext = bibDatabaseContext;
+ this.filePreferences = filePreferences;
+ this.shutdownSignal = shutdownSignal;
+
+ configure(linkedFile);
+ }
+
+ private void configure(LinkedFile linkedFile) {
+ titleProperty().set(Localization.lang("Generating embeddings for file '%0'", linkedFile.getLink()));
+
+ progressCounter.listenToAllProperties(this::updateProgress);
+ }
+
+ @Override
+ protected Void call() throws Exception {
+ LOGGER.debug("Starting embeddings generation task for file \"{}\"", linkedFile.getLink());
+
+ try {
+ ingestLinkedFile(linkedFile);
+ } catch (InterruptedException e) {
+ LOGGER.debug("There is a embeddings generation task for file \"{}\". It will be cancelled, because user quits JabRef.", linkedFile.getLink());
+ }
+
+ LOGGER.debug("Finished embeddings generation task for file \"{}\"", linkedFile.getLink());
+ progressCounter.stop();
+ return null;
+ }
+
+ private void ingestLinkedFile(LinkedFile linkedFile) throws InterruptedException {
+ // Rationale for RuntimeException here:
+ // See org.jabref.logic.ai.summarization.GenerateSummaryTask.summarizeAll
+
+ LOGGER.debug("Generating embeddings for file \"{}\"", linkedFile.getLink());
+
+ Optional path = linkedFile.findIn(bibDatabaseContext, filePreferences);
+
+ if (path.isEmpty()) {
+ LOGGER.error("Could not find path for a linked file \"{}\", while generating embeddings", linkedFile.getLink());
+ LOGGER.debug("Unable to generate embeddings for file \"{}\", because it was not found while generating embeddings", linkedFile.getLink());
+ throw new RuntimeException(Localization.lang("Could not find path for a linked file '%0' while generating embeddings.", linkedFile.getLink()));
+ }
+
+ Optional modTime = Optional.empty();
+ boolean shouldIngest = true;
+
+ try {
+ BasicFileAttributes attributes = Files.readAttributes(path.get(), BasicFileAttributes.class);
+
+ long currentModificationTimeInSeconds = attributes.lastModifiedTime().to(TimeUnit.SECONDS);
+
+ Optional ingestedModificationTimeInSeconds = fileEmbeddingsManager.getIngestedDocumentModificationTimeInSeconds(linkedFile.getLink());
+
+ if (ingestedModificationTimeInSeconds.isEmpty()) {
+ modTime = Optional.of(currentModificationTimeInSeconds);
+ } else {
+ if (currentModificationTimeInSeconds > ingestedModificationTimeInSeconds.get()) {
+ modTime = Optional.of(currentModificationTimeInSeconds);
+ } else {
+ LOGGER.debug("No need to generate embeddings for file \"{}\", because it was already generated", linkedFile.getLink());
+ shouldIngest = false;
+ }
+ }
+ } catch (IOException e) {
+ LOGGER.error("Could not retrieve attributes of a linked file \"{}\"", linkedFile.getLink(), e);
+ LOGGER.warn("Possibly regenerating embeddings for linked file \"{}\"", linkedFile.getLink());
+ }
+
+ if (!shouldIngest) {
+ return;
+ }
+
+ Optional document = new FileToDocument(shutdownSignal).fromFile(path.get());
+ if (document.isPresent()) {
+ fileEmbeddingsManager.addDocument(linkedFile.getLink(), document.get(), modTime.orElse(0L), progressCounter.workDoneProperty(), progressCounter.workMaxProperty());
+ LOGGER.debug("Embeddings for file \"{}\" were generated successfully", linkedFile.getLink());
+ } else {
+ LOGGER.error("Unable to generate embeddings for file \"{}\", because JabRef was unable to extract text from the file", linkedFile.getLink());
+ throw new RuntimeException(Localization.lang("Unable to generate embeddings for file '%0', because JabRef was unable to extract text from the file", linkedFile.getLink()));
+ }
+ }
+
+ private void updateProgress() {
+ updateProgress(progressCounter.getWorkDone(), progressCounter.getWorkMax());
+ updateMessage(progressCounter.getMessage());
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java b/src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
new file mode 100644
index 00000000000..9e692a04b36
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
@@ -0,0 +1,118 @@
+package org.jabref.logic.ai.ingestion;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javafx.beans.property.ReadOnlyBooleanProperty;
+import javafx.beans.property.StringProperty;
+
+import org.jabref.gui.util.TaskExecutor;
+import org.jabref.logic.ai.processingstatus.ProcessingInfo;
+import org.jabref.logic.ai.processingstatus.ProcessingState;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.LinkedFile;
+import org.jabref.preferences.FilePreferences;
+import org.jabref.preferences.ai.AiPreferences;
+
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+
+/**
+ * Main class for generating embedding for files.
+ * Use this class in the logic and UI.
+ */
+public class IngestionService {
+ private final Map> ingestionStatusMap = new HashMap<>();
+
+ private final List> listsUnderIngestion = new ArrayList<>();
+
+ private final FilePreferences filePreferences;
+ private final TaskExecutor taskExecutor;
+
+ private final FileEmbeddingsManager fileEmbeddingsManager;
+
+ private final ReadOnlyBooleanProperty shutdownSignal;
+
+ public IngestionService(AiPreferences aiPreferences,
+ ReadOnlyBooleanProperty shutdownSignal,
+ EmbeddingModel embeddingModel,
+ EmbeddingStore embeddingStore,
+ FullyIngestedDocumentsTracker fullyIngestedDocumentsTracker,
+ FilePreferences filePreferences,
+ TaskExecutor taskExecutor
+ ) {
+ this.filePreferences = filePreferences;
+ this.taskExecutor = taskExecutor;
+
+ this.fileEmbeddingsManager = new FileEmbeddingsManager(
+ aiPreferences,
+ shutdownSignal,
+ embeddingModel,
+ embeddingStore,
+ fullyIngestedDocumentsTracker
+ );
+
+ this.shutdownSignal = shutdownSignal;
+ }
+
+ /**
+ * Start ingesting of a {@link LinkedFile}, if it was not ingested.
+ * This method returns a {@link ProcessingInfo} that can be used for tracking state of the ingestion.
+ * Returned {@link ProcessingInfo} is related to the passed {@link LinkedFile}, so if you call this method twice
+ * on the same {@link LinkedFile}, the method will return the same {@link ProcessingInfo}.
+ */
+ public ProcessingInfo ingest(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext) {
+ ProcessingInfo processingInfo = getProcessingInfo(linkedFile);
+
+ if (processingInfo.getState() == ProcessingState.STOPPED) {
+ startEmbeddingsGenerationTask(linkedFile, bibDatabaseContext, processingInfo);
+ }
+
+ return processingInfo;
+ }
+
+ /**
+ * Get {@link ProcessingInfo} of a {@link LinkedFile}. Initially, it is in state {@link ProcessingState#STOPPED}.
+ * This method will not start ingesting. If you need to start it, use {@link IngestionService#ingest(LinkedFile, BibDatabaseContext)}.
+ */
+ public ProcessingInfo getProcessingInfo(LinkedFile linkedFile) {
+ return ingestionStatusMap.computeIfAbsent(linkedFile, file -> new ProcessingInfo<>(linkedFile, ProcessingState.STOPPED));
+ }
+
+ public List> getProcessingInfo(List linkedFiles) {
+ return linkedFiles.stream().map(this::getProcessingInfo).toList();
+ }
+
+ public List> ingest(StringProperty name, List linkedFiles, BibDatabaseContext bibDatabaseContext) {
+ List> result = getProcessingInfo(linkedFiles);
+
+ if (listsUnderIngestion.contains(linkedFiles)) {
+ return result;
+ }
+
+ List> needToProcess = result.stream().filter(processingInfo -> processingInfo.getState() == ProcessingState.STOPPED).toList();
+ startEmbeddingsGenerationTask(name, needToProcess, bibDatabaseContext);
+
+ return result;
+ }
+
+ private void startEmbeddingsGenerationTask(LinkedFile linkedFile, BibDatabaseContext bibDatabaseContext, ProcessingInfo processingInfo) {
+ new GenerateEmbeddingsTask(linkedFile, fileEmbeddingsManager, bibDatabaseContext, filePreferences, shutdownSignal)
+ .onSuccess(v -> processingInfo.setState(ProcessingState.SUCCESS))
+ .onFailure(processingInfo::setException)
+ .executeWith(taskExecutor);
+ }
+
+ private void startEmbeddingsGenerationTask(StringProperty name, List> linkedFiles, BibDatabaseContext bibDatabaseContext) {
+ new GenerateEmbeddingsForSeveralTask(name, linkedFiles, fileEmbeddingsManager, bibDatabaseContext, filePreferences, taskExecutor, shutdownSignal)
+ .executeWith(taskExecutor);
+ }
+
+ public void clearEmbeddingsFor(List linkedFiles) {
+ fileEmbeddingsManager.clearEmbeddingsFor(linkedFiles);
+ ingestionStatusMap.values().forEach(processingInfo -> processingInfo.setState(ProcessingState.STOPPED));
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/embeddings/LowLevelIngestor.java b/src/main/java/org/jabref/logic/ai/ingestion/LowLevelIngestor.java
similarity index 98%
rename from src/main/java/org/jabref/logic/ai/embeddings/LowLevelIngestor.java
rename to src/main/java/org/jabref/logic/ai/ingestion/LowLevelIngestor.java
index abd023ea063..28589cb053c 100644
--- a/src/main/java/org/jabref/logic/ai/embeddings/LowLevelIngestor.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/LowLevelIngestor.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.embeddings;
+package org.jabref.logic.ai.ingestion;
import java.util.List;
diff --git a/src/main/java/org/jabref/logic/ai/embeddings/MVStoreEmbeddingStore.java b/src/main/java/org/jabref/logic/ai/ingestion/MVStoreEmbeddingStore.java
similarity index 75%
rename from src/main/java/org/jabref/logic/ai/embeddings/MVStoreEmbeddingStore.java
rename to src/main/java/org/jabref/logic/ai/ingestion/MVStoreEmbeddingStore.java
index 16c334eef37..8a1e4e76c70 100644
--- a/src/main/java/org/jabref/logic/ai/embeddings/MVStoreEmbeddingStore.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/MVStoreEmbeddingStore.java
@@ -1,6 +1,7 @@
-package org.jabref.logic.ai.embeddings;
+package org.jabref.logic.ai.ingestion;
import java.io.Serializable;
+import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
@@ -13,7 +14,9 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;
-import org.jabref.logic.ai.FileEmbeddingsManager;
+import org.jabref.gui.DialogService;
+import org.jabref.logic.ai.util.MVStoreBase;
+import org.jabref.logic.l10n.Localization;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
@@ -31,6 +34,7 @@
import org.h2.mvstore.MVStore;
import static java.util.Comparator.comparingDouble;
+import static org.jabref.logic.ai.ingestion.FileEmbeddingsManager.LINK_METADATA_KEY;
/**
* A custom implementation of langchain4j's {@link EmbeddingStore} that uses a {@link MVStore} as an embedded database.
@@ -39,14 +43,18 @@
* string (the content).
*
*/
-public class MVStoreEmbeddingStore implements EmbeddingStore {
+public class MVStoreEmbeddingStore extends MVStoreBase implements EmbeddingStore {
// `file` field is nullable, because {@link Optional} can't be serialized.
private record EmbeddingRecord(@Nullable String file, String content, float[] embeddingVector) implements Serializable { }
+ private static final String EMBEDDINGS_MAP_NAME = "embeddings";
+
private final Map embeddingsMap;
- public MVStoreEmbeddingStore(MVStore mvStore) {
- this.embeddingsMap = mvStore.openMap("embeddingsMap");
+ public MVStoreEmbeddingStore(Path path, DialogService dialogService) {
+ super(path, dialogService);
+
+ this.embeddingsMap = this.mvStore.openMap(EMBEDDINGS_MAP_NAME);
}
@Override
@@ -74,7 +82,7 @@ public void add(String id, Embedding embedding) {
@Override
public String add(Embedding embedding, TextSegment textSegment) {
String id = String.valueOf(UUID.randomUUID());
- String linkedFile = textSegment.metadata().getString(FileEmbeddingsManager.LINK_METADATA_KEY);
+ String linkedFile = textSegment.metadata().getString(LINK_METADATA_KEY);
embeddingsMap.put(id, new EmbeddingRecord(linkedFile, textSegment.text(), embedding.vector()));
return id;
}
@@ -103,8 +111,8 @@ public void removeAll() {
/**
* The main function of finding most relevant text segments.
* Note: the only filters supported are:
- * - {@link IsIn} with key {@link FileEmbeddingsManager.LINK_METADATA_KEY}
- * - {@link IsEqualTo} with key {@link FileEmbeddingsManager.LINK_METADATA_KEY}
+ * - {@link IsIn} with key {@link LINK_METADATA_KEY}
+ * - {@link IsEqualTo} with key {@link LINK_METADATA_KEY}
*
* @param request embedding search request
*
@@ -124,7 +132,15 @@ public EmbeddingSearchResult search(EmbeddingSearchRequest request)
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= request.minScore()) {
- matches.add(new EmbeddingMatch<>(score, id, Embedding.from(eRecord.embeddingVector), new TextSegment(eRecord.content, new Metadata())));
+ matches.add(
+ new EmbeddingMatch<>(
+ score,
+ id,
+ Embedding.from(eRecord.embeddingVector),
+ new TextSegment(
+ eRecord.content,
+ new Metadata(
+ eRecord.file == null ? Map.of() : Map.of(LINK_METADATA_KEY, eRecord.file)))));
if (matches.size() > request.maxResults()) {
matches.poll();
@@ -147,10 +163,10 @@ private Stream applyFilter(@Nullable Filter filter) {
return switch (filter) {
case null -> embeddingsMap.keySet().stream();
- case IsIn isInFilter when Objects.equals(isInFilter.key(), FileEmbeddingsManager.LINK_METADATA_KEY) ->
+ case IsIn isInFilter when Objects.equals(isInFilter.key(), LINK_METADATA_KEY) ->
filterEntries(entry -> isInFilter.comparisonValues().contains(entry.getValue().file));
- case IsEqualTo isEqualToFilter when Objects.equals(isEqualToFilter.key(), FileEmbeddingsManager.LINK_METADATA_KEY) ->
+ case IsEqualTo isEqualToFilter when Objects.equals(isEqualToFilter.key(), LINK_METADATA_KEY) ->
filterEntries(entry -> isEqualToFilter.comparisonValue().equals(entry.getValue().file));
default -> throw new IllegalArgumentException("Wrong filter passed to MVStoreEmbeddingStore");
@@ -160,4 +176,14 @@ private Stream applyFilter(@Nullable Filter filter) {
private Stream filterEntries(Predicate> predicate) {
return embeddingsMap.entrySet().stream().filter(predicate).map(Map.Entry::getKey);
}
+
+ @Override
+ protected String errorMessageForOpening() {
+ return "An error occurred while opening the embeddings cache file. Embeddings will not be stored in the next session.";
+ }
+
+ @Override
+ protected String errorMessageForOpeningLocalized() {
+ return Localization.lang("An error occurred while opening the embeddings cache file. Embeddings will not be stored in the next session.");
+ }
}
diff --git a/src/main/java/org/jabref/logic/ai/models/DeepJavaEmbeddingModel.java b/src/main/java/org/jabref/logic/ai/ingestion/model/DeepJavaEmbeddingModel.java
similarity index 97%
rename from src/main/java/org/jabref/logic/ai/models/DeepJavaEmbeddingModel.java
rename to src/main/java/org/jabref/logic/ai/ingestion/model/DeepJavaEmbeddingModel.java
index c8eba6c6131..560a953a6e3 100644
--- a/src/main/java/org/jabref/logic/ai/models/DeepJavaEmbeddingModel.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/model/DeepJavaEmbeddingModel.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.models;
+package org.jabref.logic.ai.ingestion.model;
import java.io.IOException;
import java.util.ArrayList;
diff --git a/src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java b/src/main/java/org/jabref/logic/ai/ingestion/model/JabRefEmbeddingModel.java
similarity index 99%
rename from src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java
rename to src/main/java/org/jabref/logic/ai/ingestion/model/JabRefEmbeddingModel.java
index e9600567309..83870a691aa 100644
--- a/src/main/java/org/jabref/logic/ai/models/JabRefEmbeddingModel.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/model/JabRefEmbeddingModel.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.models;
+package org.jabref.logic.ai.ingestion.model;
import java.util.List;
import java.util.Optional;
diff --git a/src/main/java/org/jabref/logic/ai/models/UpdateEmbeddingModelTask.java b/src/main/java/org/jabref/logic/ai/ingestion/model/UpdateEmbeddingModelTask.java
similarity index 98%
rename from src/main/java/org/jabref/logic/ai/models/UpdateEmbeddingModelTask.java
rename to src/main/java/org/jabref/logic/ai/ingestion/model/UpdateEmbeddingModelTask.java
index 0a487aeccb6..d972e66a2ef 100644
--- a/src/main/java/org/jabref/logic/ai/models/UpdateEmbeddingModelTask.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/model/UpdateEmbeddingModelTask.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.models;
+package org.jabref.logic.ai.ingestion.model;
import java.io.IOException;
import java.util.Optional;
diff --git a/src/main/java/org/jabref/logic/ai/embeddings/FullyIngestedDocumentsTracker.java b/src/main/java/org/jabref/logic/ai/ingestion/storages/MVStoreFullyIngestedDocumentsTracker.java
similarity index 53%
rename from src/main/java/org/jabref/logic/ai/embeddings/FullyIngestedDocumentsTracker.java
rename to src/main/java/org/jabref/logic/ai/ingestion/storages/MVStoreFullyIngestedDocumentsTracker.java
index 90c6e17813c..a9d3ff55cbe 100644
--- a/src/main/java/org/jabref/logic/ai/embeddings/FullyIngestedDocumentsTracker.java
+++ b/src/main/java/org/jabref/logic/ai/ingestion/storages/MVStoreFullyIngestedDocumentsTracker.java
@@ -1,21 +1,22 @@
-package org.jabref.logic.ai.embeddings;
+package org.jabref.logic.ai.ingestion.storages;
-import java.util.HashSet;
+import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
-import com.google.common.eventbus.EventBus;
-import org.h2.mvstore.MVStore;
+import org.jabref.gui.DialogService;
+import org.jabref.logic.ai.ingestion.FullyIngestedDocumentsTracker;
+import org.jabref.logic.ai.util.MVStoreBase;
+import org.jabref.logic.l10n.Localization;
/**
* This class is responsible for recording the information about which documents (or documents) have been fully ingested.
*
- * It will also post an {@link DocumentIngestedEvent} to its event bus when a document is fully ingested.
- *
* The class also records the document modification time.
*/
-public class FullyIngestedDocumentsTracker {
+public class MVStoreFullyIngestedDocumentsTracker extends MVStoreBase implements FullyIngestedDocumentsTracker {
+ private static final String INGESTED_MAP_NAME = "ingested";
+
// This map stores the ingested documents. The key is LinkedDocument.getLink(), and the value is the modification time in seconds.
// If an entry is present, then it means the document was ingested. Otherwise, document was not ingested.
// The reason why we need to track ingested documents is because we cannot use AiEmbeddingsManager and see if there are
@@ -23,37 +24,31 @@ public class FullyIngestedDocumentsTracker {
// it doesn't mean the document is fully ingested.
private final Map ingestedMap;
- // Used to update the tab content after the data is available
- private final EventBus eventBus = new EventBus();
+ public MVStoreFullyIngestedDocumentsTracker(Path path, DialogService dialogService) {
+ super(path, dialogService);
- public FullyIngestedDocumentsTracker(MVStore mvStore) {
- this.ingestedMap = mvStore.openMap("ingestedMap");
+ this.ingestedMap = this.mvStore.openMap(INGESTED_MAP_NAME);
}
- public boolean hasIngestedDocument(String link) {
- return ingestedMap.containsKey(link);
- }
-
- public static class DocumentIngestedEvent { }
-
public void markDocumentAsFullyIngested(String link, long modificationTimeInSeconds) {
ingestedMap.put(link, modificationTimeInSeconds);
- eventBus.post(new DocumentIngestedEvent());
}
public Optional getIngestedDocumentModificationTimeInSeconds(String link) {
return Optional.ofNullable(ingestedMap.get(link));
}
- public void registerListener(Object listener) {
- eventBus.register(listener);
- }
-
public void unmarkDocumentAsFullyIngested(String link) {
ingestedMap.remove(link);
}
- public Set getFullyIngestedDocuments() {
- return new HashSet<>(ingestedMap.keySet());
+ @Override
+ protected String errorMessageForOpening() {
+ return "An error occurred while opening the fully ingested documents cache file. Fully ingested documents will not be stored in the next session.";
+ }
+
+ @Override
+ protected String errorMessageForOpeningLocalized() {
+ return Localization.lang("An error occurred while opening the fully ingested documents cache file. Fully ingested documents will not be stored in the next session.");
}
}
diff --git a/src/main/java/org/jabref/logic/ai/processingstatus/ProcessingInfo.java b/src/main/java/org/jabref/logic/ai/processingstatus/ProcessingInfo.java
new file mode 100644
index 00000000000..080d7dddef8
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/processingstatus/ProcessingInfo.java
@@ -0,0 +1,57 @@
+package org.jabref.logic.ai.processingstatus;
+
+import java.util.Optional;
+
+import javafx.beans.property.ObjectProperty;
+import javafx.beans.property.ReadOnlyObjectProperty;
+import javafx.beans.property.SimpleObjectProperty;
+
+import jakarta.annotation.Nullable;
+
+public class ProcessingInfo {
+ private final O object;
+ private final ObjectProperty state;
+ private Optional exception = Optional.empty();
+ private Optional data = Optional.empty();
+
+ public ProcessingInfo(O object, ProcessingState state) {
+ this.object = object;
+ this.state = new SimpleObjectProperty<>(state);
+ }
+
+ public void setSuccess(@Nullable D data) {
+ // Listeners will probably handle only state property, so be careful to set the data BEFORE setting the state.
+ this.data = Optional.ofNullable(data);
+ this.state.set(ProcessingState.SUCCESS);
+ }
+
+ public void setException(Exception exception) {
+ // Listeners will probably handle only state property, so be careful to set the error message BEFORE setting the state.
+ this.exception = Optional.of(exception);
+ this.state.set(ProcessingState.ERROR);
+ }
+
+ public O getObject() {
+ return object;
+ }
+
+ public ProcessingState getState() {
+ return state.get();
+ }
+
+ public void setState(ProcessingState state) {
+ this.state.set(state);
+ }
+
+ public ReadOnlyObjectProperty stateProperty() {
+ return state;
+ }
+
+ public Optional getException() {
+ return exception;
+ }
+
+ public Optional getData() {
+ return data;
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/processingstatus/ProcessingState.java b/src/main/java/org/jabref/logic/ai/processingstatus/ProcessingState.java
new file mode 100644
index 00000000000..466685bcc78
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/processingstatus/ProcessingState.java
@@ -0,0 +1,8 @@
+package org.jabref.logic.ai.processingstatus;
+
+public enum ProcessingState {
+ PROCESSING,
+ SUCCESS,
+ ERROR,
+ STOPPED
+}
diff --git a/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java b/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java
index c99076a3653..1dfb447ecc5 100644
--- a/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java
+++ b/src/main/java/org/jabref/logic/ai/summarization/GenerateSummaryTask.java
@@ -1,7 +1,6 @@
package org.jabref.logic.ai.summarization;
import java.nio.file.Path;
-import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -9,25 +8,35 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import javafx.beans.property.BooleanProperty;
+
import org.jabref.gui.util.BackgroundTask;
-import org.jabref.logic.ai.AiService;
-import org.jabref.logic.ai.embeddings.FileToDocument;
+import org.jabref.logic.ai.ingestion.FileToDocument;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.ProgressCounter;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.entry.LinkedFile;
import org.jabref.preferences.FilePreferences;
+import org.jabref.preferences.ai.AiPreferences;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class GenerateSummaryTask extends BackgroundTask {
+/**
+ * This task generates a new summary for an entry.
+ * It will not check if summary was already generated.
+ * And it also does not store the summary.
+ *
+ * This task is created in the {@link SummariesService}, and stored then in a {@link SummariesStorage}.
+ */
+public class GenerateSummaryTask extends BackgroundTask {
private static final Logger LOGGER = LoggerFactory.getLogger(GenerateSummaryTask.class);
// Be careful when constructing prompt.
@@ -62,7 +71,9 @@ public class GenerateSummaryTask extends BackgroundTask {
private final BibDatabaseContext bibDatabaseContext;
private final String citationKey;
private final List linkedFiles;
- private final AiService aiService;
+ private final ChatLanguageModel chatLanguageModel;
+ private final BooleanProperty shutdownSignal;
+ private final AiPreferences aiPreferences;
private final FilePreferences filePreferences;
private final ProgressCounter progressCounter = new ProgressCounter();
@@ -70,14 +81,23 @@ public class GenerateSummaryTask extends BackgroundTask {
public GenerateSummaryTask(BibDatabaseContext bibDatabaseContext,
String citationKey,
List linkedFiles,
- AiService aiService,
- FilePreferences filePreferences) {
+ ChatLanguageModel chatLanguageModel,
+ BooleanProperty shutdownSignal,
+ AiPreferences aiPreferences,
+ FilePreferences filePreferences
+ ) {
this.bibDatabaseContext = bibDatabaseContext;
this.citationKey = citationKey;
this.linkedFiles = linkedFiles;
- this.aiService = aiService;
+ this.chatLanguageModel = chatLanguageModel;
+ this.shutdownSignal = shutdownSignal;
+ this.aiPreferences = aiPreferences;
this.filePreferences = filePreferences;
+ configure(citationKey);
+ }
+
+ private void configure(String citationKey) {
titleProperty().set(Localization.lang("Waiting summary for %0...", citationKey));
showToUser(true);
@@ -85,33 +105,28 @@ public GenerateSummaryTask(BibDatabaseContext bibDatabaseContext,
}
@Override
- protected Void call() throws Exception {
- LOGGER.info("Starting summarization task for entry {}", citationKey);
+ protected String call() throws Exception {
+ LOGGER.debug("Starting summarization task for entry {}", citationKey);
+
+ String result = null;
try {
- summarizeAll();
+ result = summarizeAll();
} catch (InterruptedException e) {
- LOGGER.info("There was a summarization task for {}. It will be canceled, because user quits JabRef.", citationKey);
+ LOGGER.debug("There was a summarization task for {}. It will be canceled, because user quits JabRef.", citationKey);
}
showToUser(false);
-
- LOGGER.info("Finished summarization task for entry {}", citationKey);
-
+ LOGGER.debug("Finished summarization task for entry {}", citationKey);
progressCounter.stop();
-
- return null;
+ return result;
}
- private void summarizeAll() throws InterruptedException {
+ private String summarizeAll() throws InterruptedException {
// Rationale for RuntimeException here:
// It follows the same idiom as in langchain4j. See {@link JabRefChatLanguageModel.generate}, this method
// is used internally in the summarization, and it also throws RuntimeExceptions.
- if (bibDatabaseContext.getDatabasePath().isEmpty()) {
- throw new RuntimeException(Localization.lang("No summary can be generated for entry '%0' as the database does not have path", citationKey));
- }
-
// Stream API would look better here, but we need to catch InterruptedException.
List linkedFilesSummary = new ArrayList<>();
for (LinkedFile linkedFile : linkedFiles) {
@@ -127,7 +142,7 @@ private void summarizeAll() throws InterruptedException {
throw new RuntimeException(Localization.lang("No summary can be generated for entry '%0'. Could not find attached linked files.", citationKey));
}
- LOGGER.info("All summaries for attached files of entry {} are generated. Generating final summary.", citationKey);
+ LOGGER.debug("All summaries for attached files of entry {} are generated. Generating final summary.", citationKey);
String finalSummary;
@@ -141,93 +156,86 @@ private void summarizeAll() throws InterruptedException {
doneOneWork();
- SummariesStorage.SummarizationRecord summaryRecord = new SummariesStorage.SummarizationRecord(
- LocalDateTime.now(),
- aiService.getPreferences().getAiProvider(),
- aiService.getPreferences().getSelectedChatModel(),
- finalSummary
- );
-
- aiService.getSummariesStorage().set(bibDatabaseContext.getDatabasePath().get(), citationKey, summaryRecord);
+ return finalSummary;
}
private Optional generateSummary(LinkedFile linkedFile) throws InterruptedException {
- LOGGER.info("Generating summary for file \"{}\" of entry {}", linkedFile.getLink(), citationKey);
+ LOGGER.debug("Generating summary for file \"{}\" of entry {}", linkedFile.getLink(), citationKey);
Optional path = linkedFile.findIn(bibDatabaseContext, filePreferences);
if (path.isEmpty()) {
LOGGER.error("Could not find path for a linked file \"{}\" of entry {}", linkedFile.getLink(), citationKey);
- LOGGER.info("Unable to generate summary for file \"{}\" of entry {}, because it was not found", linkedFile.getLink(), citationKey);
+ LOGGER.debug("Unable to generate summary for file \"{}\" of entry {}, because it was not found", linkedFile.getLink(), citationKey);
return Optional.empty();
}
- Optional document = FileToDocument.fromFile(path.get());
+ Optional document = new FileToDocument(shutdownSignal).fromFile(path.get());
if (document.isEmpty()) {
LOGGER.warn("Could not extract text from a linked file \"{}\" of entry {}. It will be skipped when generating a summary.", linkedFile.getLink(), citationKey);
- LOGGER.info("Unable to generate summary for file \"{}\" of entry {}, because it was not found", linkedFile.getLink(), citationKey);
+ LOGGER.debug("Unable to generate summary for file \"{}\" of entry {}, because it was not found", linkedFile.getLink(), citationKey);
return Optional.empty();
}
String linkedFileSummary = summarizeOneDocument(path.get().toString(), document.get().text());
- LOGGER.info("Summary for file \"{}\" of entry {} was generated successfully", linkedFile.getLink(), citationKey);
+ LOGGER.debug("Summary for file \"{}\" of entry {} was generated successfully", linkedFile.getLink(), citationKey);
return Optional.of(linkedFileSummary);
}
public String summarizeOneDocument(String filePath, String document) throws InterruptedException {
addMoreWork(1); // For the combination of summary chunks.
- DocumentSplitter documentSplitter = DocumentSplitters.recursive(aiService.getPreferences().getContextWindowSize() - MAX_OVERLAP_SIZE_IN_CHARS * 2 - estimateTokenCount(CHUNK_PROMPT_TEMPLATE), MAX_OVERLAP_SIZE_IN_CHARS);
+ DocumentSplitter documentSplitter = DocumentSplitters.recursive(aiPreferences.getContextWindowSize() - MAX_OVERLAP_SIZE_IN_CHARS * 2 - estimateTokenCount(CHUNK_PROMPT_TEMPLATE), MAX_OVERLAP_SIZE_IN_CHARS);
List chunkSummaries = documentSplitter.split(new Document(document)).stream().map(TextSegment::text).toList();
- LOGGER.info("The file \"{}\" of entry {} was split into {} chunk(s)", filePath, citationKey, chunkSummaries.size());
+ LOGGER.debug("The file \"{}\" of entry {} was split into {} chunk(s)", filePath, citationKey, chunkSummaries.size());
int passes = 0;
do {
passes++;
- LOGGER.info("Summarizing chunk(s) for file \"{}\" of entry {} ({} pass)", filePath, citationKey, passes);
+ LOGGER.debug("Summarizing chunk(s) for file \"{}\" of entry {} ({} pass)", filePath, citationKey, passes);
addMoreWork(chunkSummaries.size());
List list = new ArrayList<>();
for (String chunkSummary : chunkSummaries) {
- if (aiService.getShutdownSignal().get()) {
+ if (shutdownSignal.get()) {
throw new InterruptedException();
}
Prompt prompt = CHUNK_PROMPT_TEMPLATE.apply(Collections.singletonMap("document", chunkSummary));
- LOGGER.info("Sending request to AI provider to summarize a chunk from file \"{}\" of entry {}", filePath, citationKey);
- String chunk = aiService.getChatLanguageModel().generate(prompt.toString());
- LOGGER.info("Chunk summary for file \"{}\" of entry {} was generated successfully", filePath, citationKey);
+ LOGGER.debug("Sending request to AI provider to summarize a chunk from file \"{}\" of entry {}", filePath, citationKey);
+ String chunk = chatLanguageModel.generate(prompt.toString());
+ LOGGER.debug("Chunk summary for file \"{}\" of entry {} was generated successfully", filePath, citationKey);
list.add(chunk);
doneOneWork();
}
chunkSummaries = list;
- } while (estimateTokenCount(chunkSummaries) > aiService.getPreferences().getContextWindowSize() - estimateTokenCount(COMBINE_PROMPT_TEMPLATE));
+ } while (estimateTokenCount(chunkSummaries) > aiPreferences.getContextWindowSize() - estimateTokenCount(COMBINE_PROMPT_TEMPLATE));
if (chunkSummaries.size() == 1) {
doneOneWork(); // No need to call LLM for combination of summary chunks.
- LOGGER.info("Summary of the file \"{}\" of entry {} was generated successfully", filePath, citationKey);
+ LOGGER.debug("Summary of the file \"{}\" of entry {} was generated successfully", filePath, citationKey);
return chunkSummaries.getFirst();
}
Prompt prompt = COMBINE_PROMPT_TEMPLATE.apply(Collections.singletonMap("summaries", String.join("\n\n", chunkSummaries)));
- if (aiService.getShutdownSignal().get()) {
+ if (shutdownSignal.get()) {
throw new InterruptedException();
}
- LOGGER.info("Sending request to AI provider to combine summary chunk(s) for file \"{}\" of entry {}", filePath, citationKey);
- String result = aiService.getChatLanguageModel().generate(prompt.toString());
- LOGGER.info("Summary of the file \"{}\" of entry {} was generated successfully", filePath, citationKey);
+ LOGGER.debug("Sending request to AI provider to combine summary chunk(s) for file \"{}\" of entry {}", filePath, citationKey);
+ String result = chatLanguageModel.generate(prompt.toString());
+ LOGGER.debug("Summary of the file \"{}\" of entry {} was generated successfully", filePath, citationKey);
doneOneWork();
return result;
diff --git a/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java b/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java
new file mode 100644
index 00000000000..5734aadb212
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/summarization/SummariesService.java
@@ -0,0 +1,137 @@
+package org.jabref.logic.ai.summarization;
+
+import java.time.LocalDateTime;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import javafx.beans.property.BooleanProperty;
+
+import org.jabref.gui.util.TaskExecutor;
+import org.jabref.logic.ai.processingstatus.ProcessingInfo;
+import org.jabref.logic.ai.processingstatus.ProcessingState;
+import org.jabref.logic.ai.util.CitationKeyCheck;
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+import org.jabref.preferences.FilePreferences;
+import org.jabref.preferences.ai.AiPreferences;
+
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Main class for generating summaries of {@link BibEntry}ies.
+ * Use this class in the logic and UI.
+ *
+ * In order for summary to be stored and loaded, the {@link BibEntry} must satisfy the following requirements:
+ * 1. There should exist an associated {@link BibDatabaseContext} for the {@link BibEntry}.
+ * 2. The database path of the associated {@link BibDatabaseContext} must be set.
+ * 3. The citation key of the {@link BibEntry} must be set and unique.
+ */
+public class SummariesService {
+ private static final Logger LOGGER = LoggerFactory.getLogger(SummariesService.class);
+
+ private final Map> summariesStatusMap = new HashMap<>();
+
+ private final AiPreferences aiPreferences;
+ private final SummariesStorage summariesStorage;
+ private final ChatLanguageModel chatLanguageModel;
+ private final BooleanProperty shutdownSignal;
+ private final FilePreferences filePreferences;
+ private final TaskExecutor taskExecutor;
+
+ public SummariesService(AiPreferences aiPreferences,
+ SummariesStorage summariesStorage,
+ ChatLanguageModel chatLanguageModel,
+ BooleanProperty shutdownSignal,
+ FilePreferences filePreferences,
+ TaskExecutor taskExecutor
+ ) {
+ this.aiPreferences = aiPreferences;
+ this.summariesStorage = summariesStorage;
+ this.chatLanguageModel = chatLanguageModel;
+ this.shutdownSignal = shutdownSignal;
+ this.filePreferences = filePreferences;
+ this.taskExecutor = taskExecutor;
+ }
+
+ /**
+ * Start generating summary of a {@link BibEntry}, if it was already generated.
+ * This method returns a {@link ProcessingInfo} that can be used for tracking state of the summarization.
+ * Returned {@link ProcessingInfo} is related to the passed {@link BibEntry}, so if you call this method twice
+ * on the same {@link BibEntry}, the method will return the same {@link ProcessingInfo}.
+ */
+ public ProcessingInfo summarize(BibEntry bibEntry, BibDatabaseContext bibDatabaseContext) {
+ return summariesStatusMap.computeIfAbsent(bibEntry, file -> {
+ ProcessingInfo processingInfo = new ProcessingInfo<>(bibEntry, ProcessingState.PROCESSING);
+ generateSummary(bibEntry, bibDatabaseContext, processingInfo);
+ return processingInfo;
+ });
+ }
+
+ private void generateSummary(BibEntry bibEntry, BibDatabaseContext bibDatabaseContext, ProcessingInfo processingInfo) {
+ if (bibDatabaseContext.getDatabasePath().isEmpty()) {
+ runGenerateSummaryTask(processingInfo, bibEntry, bibDatabaseContext);
+ } else if (bibEntry.getCitationKey().isEmpty() || CitationKeyCheck.citationKeyIsPresentAndUnique(bibDatabaseContext, bibEntry)) {
+ runGenerateSummaryTask(processingInfo, bibEntry, bibDatabaseContext);
+ } else {
+ Optional summary = summariesStorage.get(bibDatabaseContext.getDatabasePath().get(), bibEntry.getCitationKey().get());
+
+ if (summary.isEmpty()) {
+ runGenerateSummaryTask(processingInfo, bibEntry, bibDatabaseContext);
+ } else {
+ processingInfo.setSuccess(summary.get());
+ }
+ }
+ }
+
+ /**
+ * Method, similar to {@link #summarize(BibEntry, BibDatabaseContext)}, but it allows you to regenerate summary.
+ */
+ public void regenerateSummary(BibEntry bibEntry, BibDatabaseContext bibDatabaseContext) {
+ ProcessingInfo processingInfo = summarize(bibEntry, bibDatabaseContext);
+ processingInfo.setState(ProcessingState.PROCESSING);
+
+ if (bibDatabaseContext.getDatabasePath().isEmpty()) {
+ LOGGER.info("No database path is present. Could not clear stored summary for regeneration");
+ } else if (bibEntry.getCitationKey().isEmpty() || CitationKeyCheck.citationKeyIsPresentAndUnique(bibDatabaseContext, bibEntry)) {
+ LOGGER.info("No valid citation key is present. Could not clear stored summary for regeneration");
+ } else {
+ summariesStorage.clear(bibDatabaseContext.getDatabasePath().get(), bibEntry.getCitationKey().get());
+ }
+
+ generateSummary(bibEntry, bibDatabaseContext, processingInfo);
+ }
+
+ private void runGenerateSummaryTask(ProcessingInfo processingInfo, BibEntry bibEntry, BibDatabaseContext bibDatabaseContext) {
+ new GenerateSummaryTask(
+ bibDatabaseContext,
+ bibEntry.getCitationKey().orElse(""),
+ bibEntry.getFiles(),
+ chatLanguageModel,
+ shutdownSignal,
+ aiPreferences,
+ filePreferences)
+ .onSuccess(summary -> {
+ Summary Summary = new Summary(
+ LocalDateTime.now(),
+ aiPreferences.getAiProvider(),
+ aiPreferences.getSelectedChatModel(),
+ summary
+ );
+
+ processingInfo.setSuccess(Summary);
+
+ if (bibDatabaseContext.getDatabasePath().isEmpty()) {
+ LOGGER.info("No database path is present. Summary will not be stored in the next sessions");
+ } else if (CitationKeyCheck.citationKeyIsPresentAndUnique(bibDatabaseContext, bibEntry)) {
+ LOGGER.info("No valid citation key is present. Summary will not be stored in the next sessions");
+ } else {
+ summariesStorage.set(bibDatabaseContext.getDatabasePath().get(), bibEntry.getCitationKey().get(), Summary);
+ }
+ })
+ .onFailure(processingInfo::setException)
+ .executeWith(taskExecutor);
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/summarization/SummariesStorage.java b/src/main/java/org/jabref/logic/ai/summarization/SummariesStorage.java
index 6b1ff964475..70e8d959813 100644
--- a/src/main/java/org/jabref/logic/ai/summarization/SummariesStorage.java
+++ b/src/main/java/org/jabref/logic/ai/summarization/SummariesStorage.java
@@ -1,94 +1,12 @@
package org.jabref.logic.ai.summarization;
-import java.io.Serializable;
import java.nio.file.Path;
-import java.time.LocalDateTime;
-import java.util.Map;
import java.util.Optional;
-import org.jabref.gui.StateManager;
-import org.jabref.model.database.BibDatabaseContext;
-import org.jabref.model.entry.event.FieldChangedEvent;
-import org.jabref.model.entry.field.InternalField;
-import org.jabref.preferences.ai.AiPreferences;
-import org.jabref.preferences.ai.AiProvider;
+public interface SummariesStorage {
+ void set(Path bibDatabasePath, String citationKey, Summary summary);
-import com.airhacks.afterburner.injection.Injector;
-import com.google.common.eventbus.EventBus;
-import com.google.common.eventbus.Subscribe;
-import jakarta.inject.Inject;
-import org.h2.mvstore.MVStore;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+ Optional get(Path bibDatabasePath, String citationKey);
-public class SummariesStorage {
- private final static Logger LOGGER = LoggerFactory.getLogger(SummariesStorage.class);
-
- private final MVStore mvStore;
-
- private final EventBus eventBus = new EventBus();
-
- @Inject private StateManager stateManager = Injector.instantiateModelOrService(StateManager.class);
-
- public record SummarizationRecord(LocalDateTime timestamp, AiProvider aiProvider, String model, String content) implements Serializable { }
-
- public SummariesStorage(AiPreferences aiPreferences, MVStore mvStore) {
- this.mvStore = mvStore;
- }
-
- public void registerListener(Object object) {
- eventBus.register(object);
- }
-
- public static class SummarySetEvent { }
-
- private Map getMap(Path bibDatabasePath) {
- return mvStore.openMap("summarizationRecords-" + bibDatabasePath.toString());
- }
-
- public void set(Path bibDatabasePath, String citationKey, SummarizationRecord summary) {
- getMap(bibDatabasePath).put(citationKey, summary);
- eventBus.post(new SummarySetEvent());
- }
-
- public Optional get(Path bibDatabasePath, String citationKey) {
- return Optional.ofNullable(getMap(bibDatabasePath).get(citationKey));
- }
-
- public void clear(Path bibDatabasePath, String citationKey) {
- getMap(bibDatabasePath).remove(citationKey);
- }
-
- @Subscribe
- private void fieldChangedEventListener(FieldChangedEvent event) {
- // TODO: This methods doesn't take into account if the new citation key is valid.
-
- if (event.getField() != InternalField.KEY_FIELD) {
- return;
- }
-
- Optional bibDatabaseContext = stateManager.getOpenDatabases().stream().filter(dbContext -> dbContext.getDatabase().getEntries().contains(event.getBibEntry())).findFirst();
-
- if (bibDatabaseContext.isEmpty()) {
- LOGGER.error("Could not listen to field change event because no database context was found. BibEntry: {}", event.getBibEntry());
- return;
- }
-
- Optional bibDatabasePath = bibDatabaseContext.get().getDatabasePath();
-
- if (bibDatabasePath.isEmpty()) {
- LOGGER.error("Could not listen to field change event because no database path was found. BibEntry: {}", event.getBibEntry());
- return;
- }
-
- Optional oldSummary = get(bibDatabasePath.get(), event.getOldValue());
-
- if (oldSummary.isEmpty()) {
- LOGGER.info("Old summary not found for {}", event.getNewValue());
- return;
- }
-
- set(bibDatabasePath.get(), event.getNewValue(), oldSummary.get());
- clear(bibDatabasePath.get(), event.getOldValue());
- }
+ void clear(Path bibDatabasePath, String citationKey);
}
diff --git a/src/main/java/org/jabref/logic/ai/summarization/Summary.java b/src/main/java/org/jabref/logic/ai/summarization/Summary.java
new file mode 100644
index 00000000000..eb2ee4d314d
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/summarization/Summary.java
@@ -0,0 +1,8 @@
+package org.jabref.logic.ai.summarization;
+
+import java.io.Serializable;
+import java.time.LocalDateTime;
+
+import org.jabref.preferences.ai.AiProvider;
+
+public record Summary(LocalDateTime timestamp, AiProvider aiProvider, String model, String content) implements Serializable { }
diff --git a/src/main/java/org/jabref/logic/ai/summarization/storages/MVStoreSummariesStorage.java b/src/main/java/org/jabref/logic/ai/summarization/storages/MVStoreSummariesStorage.java
new file mode 100644
index 00000000000..47a1e1df995
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/summarization/storages/MVStoreSummariesStorage.java
@@ -0,0 +1,45 @@
+package org.jabref.logic.ai.summarization.storages;
+
+import java.nio.file.Path;
+import java.util.Map;
+import java.util.Optional;
+
+import org.jabref.gui.DialogService;
+import org.jabref.logic.ai.summarization.SummariesStorage;
+import org.jabref.logic.ai.summarization.Summary;
+import org.jabref.logic.ai.util.MVStoreBase;
+import org.jabref.logic.l10n.Localization;
+
+public class MVStoreSummariesStorage extends MVStoreBase implements SummariesStorage {
+ private static final String SUMMARIES_MAP_PREFIX = "summaries";
+
+ public MVStoreSummariesStorage(Path path, DialogService dialogService) {
+ super(path, dialogService);
+ }
+
+ public void set(Path bibDatabasePath, String citationKey, Summary summary) {
+ getMap(bibDatabasePath).put(citationKey, summary);
+ }
+
+ public Optional get(Path bibDatabasePath, String citationKey) {
+ return Optional.ofNullable(getMap(bibDatabasePath).get(citationKey));
+ }
+
+ public void clear(Path bibDatabasePath, String citationKey) {
+ getMap(bibDatabasePath).remove(citationKey);
+ }
+
+ private Map getMap(Path bibDatabasePath) {
+ return mvStore.openMap(SUMMARIES_MAP_PREFIX + "-" + bibDatabasePath.toString());
+ }
+
+ @Override
+ protected String errorMessageForOpening() {
+ return "An error occurred while opening summary storage. Summaries of entries will not be stored in the next session.";
+ }
+
+ @Override
+ protected String errorMessageForOpeningLocalized() {
+ return Localization.lang("An error occurred while opening summary storage. Summaries of entries will not be stored in the next session.");
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/util/CitationKeyCheck.java b/src/main/java/org/jabref/logic/ai/util/CitationKeyCheck.java
new file mode 100644
index 00000000000..0fc7678fcb3
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/util/CitationKeyCheck.java
@@ -0,0 +1,18 @@
+package org.jabref.logic.ai.util;
+
+import org.jabref.model.database.BibDatabaseContext;
+import org.jabref.model.entry.BibEntry;
+
+public class CitationKeyCheck {
+ public static boolean citationKeyIsPresentAndUnique(BibDatabaseContext bibDatabaseContext, BibEntry bibEntry) {
+ return !hasEmptyCitationKey(bibEntry) && bibEntry.getCitationKey().map(key -> citationKeyIsUnique(bibDatabaseContext, key)).orElse(false);
+ }
+
+ private static boolean hasEmptyCitationKey(BibEntry bibEntry) {
+ return bibEntry.getCitationKey().map(String::isEmpty).orElse(true);
+ }
+
+ private static boolean citationKeyIsUnique(BibDatabaseContext bibDatabaseContext, String citationKey) {
+ return bibDatabaseContext.getDatabase().getNumberOfCitationKeyOccurrences(citationKey) == 1;
+ }
+}
diff --git a/src/main/java/org/jabref/logic/ai/misc/ErrorMessage.java b/src/main/java/org/jabref/logic/ai/util/ErrorMessage.java
similarity index 96%
rename from src/main/java/org/jabref/logic/ai/misc/ErrorMessage.java
rename to src/main/java/org/jabref/logic/ai/util/ErrorMessage.java
index 5a1fba81ffd..c15bdf7332c 100644
--- a/src/main/java/org/jabref/logic/ai/misc/ErrorMessage.java
+++ b/src/main/java/org/jabref/logic/ai/util/ErrorMessage.java
@@ -1,4 +1,4 @@
-package org.jabref.logic.ai.misc;
+package org.jabref.logic.ai.util;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
diff --git a/src/main/java/org/jabref/logic/ai/util/MVStoreBase.java b/src/main/java/org/jabref/logic/ai/util/MVStoreBase.java
new file mode 100644
index 00000000000..f7e14e9c72f
--- /dev/null
+++ b/src/main/java/org/jabref/logic/ai/util/MVStoreBase.java
@@ -0,0 +1,46 @@
+package org.jabref.logic.ai.util;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import org.jabref.gui.DialogService;
+
+import jakarta.annotation.Nullable;
+import org.h2.mvstore.MVStore;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public abstract class MVStoreBase implements AutoCloseable {
+ private static final Logger LOGGER = LoggerFactory.getLogger(MVStoreBase.class);
+
+ protected final MVStore mvStore;
+
+ public MVStoreBase(Path path, DialogService dialogService) {
+ @Nullable Path mvStorePath = path;
+
+ try {
+ Files.createDirectories(path.getParent());
+ } catch (Exception e) {
+ LOGGER.error(errorMessageForOpening(), e);
+ dialogService.notify(errorMessageForOpeningLocalized());
+ mvStorePath = null;
+ }
+
+ this.mvStore = new MVStore.Builder()
+ .autoCommitDisabled()
+ .fileName(mvStorePath == null ? null : mvStorePath.toString())
+ .open();
+ }
+
+ public void commit() {
+ mvStore.commit();
+ }
+
+ public void close() {
+ mvStore.close();
+ }
+
+ protected abstract String errorMessageForOpening();
+
+ protected abstract String errorMessageForOpeningLocalized();
+}
diff --git a/src/main/java/org/jabref/logic/pdf/InterruptablePDFTextStripper.java b/src/main/java/org/jabref/logic/pdf/InterruptablePDFTextStripper.java
new file mode 100644
index 00000000000..ec136da8959
--- /dev/null
+++ b/src/main/java/org/jabref/logic/pdf/InterruptablePDFTextStripper.java
@@ -0,0 +1,26 @@
+package org.jabref.logic.pdf;
+
+import java.io.IOException;
+
+import javafx.beans.property.ReadOnlyBooleanProperty;
+
+import org.apache.pdfbox.pdmodel.PDPage;
+import org.apache.pdfbox.text.PDFTextStripper;
+
+public class InterruptablePDFTextStripper extends PDFTextStripper {
+ private final ReadOnlyBooleanProperty shutdownSignal;
+
+ public InterruptablePDFTextStripper(ReadOnlyBooleanProperty shutdownSignal) {
+ super();
+ this.shutdownSignal = shutdownSignal;
+ }
+
+ @Override
+ public void processPage(PDPage page) throws IOException {
+ if (shutdownSignal.get()) {
+ return;
+ }
+
+ super.processPage(page);
+ }
+}
diff --git a/src/main/java/org/jabref/logic/util/ProgressCounter.java b/src/main/java/org/jabref/logic/util/ProgressCounter.java
index 06093c09044..abd9b21cb1b 100644
--- a/src/main/java/org/jabref/logic/util/ProgressCounter.java
+++ b/src/main/java/org/jabref/logic/util/ProgressCounter.java
@@ -16,6 +16,12 @@
import ai.djl.util.Progress;
+/**
+ * Convenient class for managing ETA for background tasks.
+ *
+ * Always call {@link ProgressCounter#stop()} when your task is done, because there is a background timer that
+ * periodically updates the ETA.
+ */
public class ProgressCounter implements Progress {
private record ProgressMessage(int maxTime, String message) { }
@@ -96,10 +102,6 @@ private void update() {
Duration eta = oneWorkTime.multipliedBy(workMax.get() - workDone.get() <= 0 ? 1 : workMax.get() - workDone.get());
updateMessage(eta);
-
- if (workDone.get() != 0 && workMax.get() != 0 && workDone.get() == workMax.get()) {
- stop();
- }
}
@Override
diff --git a/src/main/java/org/jabref/model/TreeNode.java b/src/main/java/org/jabref/model/TreeNode.java
index fcf0b8ab26e..b34e2011a0f 100644
--- a/src/main/java/org/jabref/model/TreeNode.java
+++ b/src/main/java/org/jabref/model/TreeNode.java
@@ -7,6 +7,7 @@
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Predicate;
+import java.util.stream.Stream;
import javafx.collections.FXCollections;
import javafx.collections.ObservableList;
@@ -626,4 +627,8 @@ public List findChildrenSatisfying(Predicate matcher) {
return hits;
}
+
+ public Stream iterateOverTree() {
+ return Stream.concat(Stream.of((T) this), getChildren().stream().flatMap(TreeNode::iterateOverTree));
+ }
}
diff --git a/src/main/java/org/jabref/model/entry/BibEntry.java b/src/main/java/org/jabref/model/entry/BibEntry.java
index e23c0268a43..f0d2d43689e 100644
--- a/src/main/java/org/jabref/model/entry/BibEntry.java
+++ b/src/main/java/org/jabref/model/entry/BibEntry.java
@@ -730,7 +730,15 @@ public String toString() {
return CanonicalBibEntry.getCanonicalRepresentation(this);
}
+ public String getAuthorTitleYear() {
+ return getAuthorTitleYear(0);
+ }
+
/**
+ * Creates a short textual description of the entry in the format: Author1, Author2: Title (Year)
+ *
+ * If 0
is passed as maxCharacters
, the description is not truncated.
+ *
* @param maxCharacters The maximum number of characters (additional
* characters are replaced with "..."). Set to 0 to disable truncation.
* @return A short textual description of the entry in the format:
diff --git a/src/main/java/org/jabref/model/groups/GroupTreeNode.java b/src/main/java/org/jabref/model/groups/GroupTreeNode.java
index d5954181d9d..6e5b3284a68 100644
--- a/src/main/java/org/jabref/model/groups/GroupTreeNode.java
+++ b/src/main/java/org/jabref/model/groups/GroupTreeNode.java
@@ -8,6 +8,9 @@
import java.util.Optional;
import java.util.stream.Collectors;
+import javafx.beans.property.ObjectProperty;
+import javafx.beans.property.SimpleObjectProperty;
+
import org.jabref.model.FieldChange;
import org.jabref.model.TreeNode;
import org.jabref.model.database.BibDatabase;
@@ -22,7 +25,7 @@
public class GroupTreeNode extends TreeNode {
private static final String PATH_DELIMITER = " > ";
- private AbstractGroup group;
+ private ObjectProperty groupProperty = new SimpleObjectProperty<>();
/**
* Creates this node and associates the specified group with it.
@@ -44,7 +47,11 @@ public static GroupTreeNode fromGroup(AbstractGroup group) {
* @return the group associated with this node
*/
public AbstractGroup getGroup() {
- return group;
+ return groupProperty.get();
+ }
+
+ public ObjectProperty getGroupProperty() {
+ return groupProperty;
}
/**
@@ -55,7 +62,7 @@ public AbstractGroup getGroup() {
*/
@Deprecated
public void setGroup(AbstractGroup newGroup) {
- this.group = Objects.requireNonNull(newGroup);
+ this.groupProperty.set(Objects.requireNonNull(newGroup));
}
/**
@@ -69,7 +76,7 @@ public void setGroup(AbstractGroup newGroup) {
public List setGroup(AbstractGroup newGroup, boolean shouldKeepPreviousAssignments,
boolean shouldRemovePreviousAssignments, List entriesInDatabase) {
AbstractGroup oldGroup = getGroup();
- group = Objects.requireNonNull(newGroup);
+ groupProperty.set(Objects.requireNonNull(newGroup));
List