Skip to content

Commit

Permalink
Fix ai issue (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
RealYusufIsmail authored Sep 11, 2024
1 parent cbfafe6 commit 68627be
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.github.yusufsdiscordbot.mystiguardian.utils.PermChecker;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import lombok.val;
import net.dv8tion.jda.api.Permission;
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
Expand All @@ -39,13 +40,20 @@ public void onSlashCommandInteractionEvent(
@NotNull SlashCommandInteractionEvent event,
@NotNull MystiGuardianUtils.ReplyUtils replyUtils,
PermChecker permChecker) {

val question = event.getOption("question", OptionMapping::getAsString);
var newChat = Optional.ofNullable(event.getOption("new-chat", OptionMapping::getAsBoolean));
val model = Optional.ofNullable(event.getOption("model", OptionMapping::getAsString));

val githubAIModel =
MystiGuardianUtils.getGithubAIModel(
event.getGuild().getIdLong(), event.getMember().getIdLong(), model);

val githubAIModel = MystiGuardianUtils.getGithubAIModel(event.getGuild().getIdLong());
event.deferReply().queue();

githubAIModel
.askQuestion(question)
.thenAccept(replyUtils::sendSuccess)
.askQuestion(question, event.getMember().getIdLong(), newChat.orElse(Boolean.FALSE))
.thenAccept((answer) -> event.getHook().editOriginal(answer).queue())
.exceptionally(
throwable -> {
replyUtils.sendError("An error occurred while asking the question");
Expand All @@ -69,6 +77,7 @@ public String getDescription() {
public List<OptionData> getOptions() {
return List.of(
new OptionData(OptionType.STRING, "question", "The question to ask the AI model", true),
new OptionData(OptionType.BOOLEAN, "new-chat", "Start a new chat session", false),
new OptionData(OptionType.STRING, "model", "The model to use", false));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,51 @@
import io.github.yusufsdiscordbot.mystiguardian.utils.MystiGuardianUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import okhttp3.*;
import org.jetbrains.annotations.NotNull;

public class GithubAIModel {
private final String model;
private String model;
private final String token;
private final List<Message> context = new ArrayList<>();
private final Map<Long, List<Message>> context = new HashMap<>();
private final List<Message> initialMessages = new ArrayList<>();

private final OkHttpClient client;
private final ObjectMapper mapper;

public GithubAIModel(String model, String initialPrompt) {
public GithubAIModel(String model, String initialPrompt, Long memberId) {
this.model = model;
this.token = MystiGuardianUtils.getMainConfig().githubToken();
this.client = new OkHttpClient();
this.mapper = new ObjectMapper();
this.context.add(new Message("system", initialPrompt));
initialMessages.add(new Message("system", initialPrompt));
context.put(memberId, new ArrayList<>(initialMessages));
}

public CompletableFuture<String> askQuestion(String question) {
context.add(new Message("user", question));
return sendRequest();
public CompletableFuture<String> askQuestion(String question, Long memberId, boolean newChat) {
if (!context.containsKey(memberId)) {
context.put(memberId, new ArrayList<>(initialMessages));
}

if (newChat) {
context.put(memberId, new ArrayList<>(initialMessages));
}

context.get(memberId).add(new Message("user", question));
return sendRequest(memberId);
}

private CompletableFuture<String> sendRequest() {
@NotNull
private CompletableFuture<String> sendRequest(long memberId) {
CompletableFuture<String> future = new CompletableFuture<>();
try {
String url = "https://models.inference.ai.azure.com/chat/completions";

RequestBody requestBody = getRequestBody();
RequestBody requestBody = getRequestBody(memberId);
Request request =
new Request.Builder()
.url(url)
Expand All @@ -70,7 +84,6 @@ private CompletableFuture<String> sendRequest() {
new Callback() {
@Override
public void onFailure(@NotNull Call call, @NotNull IOException e) {
MystiGuardianUtils.logger.error("Error while sending request to AI model", e);
future.completeExceptionally(e);
}

Expand All @@ -79,57 +92,50 @@ public void onResponse(@NotNull Call call, @NotNull Response response)
throws IOException {
if (response.isSuccessful()) {
String responseBody = response.body().string();
MystiGuardianUtils.logger.debug("Request sent to AI model successfully");
future.complete(parseResponse(responseBody));
future.complete(parseResponse(responseBody, memberId));
} else {
MystiGuardianUtils.logger.error(
"Error while sending request to AI model. Response code: {}, Response body: {}",
response.code(),
response.body().string());
future.completeExceptionally(
new RuntimeException(
"Request failed with status code: " + response.code()));
}
}
});
} catch (JsonProcessingException e) {
MystiGuardianUtils.logger.error("Error while processing JSON for AI model request", e);
future.completeExceptionally(e);
}
return future;
}

private @NotNull RequestBody getRequestBody() throws JsonProcessingException {
private @NotNull RequestBody getRequestBody(Long userId) throws JsonProcessingException {
ObjectNode payload = mapper.createObjectNode();
ArrayNode messages = payload.putArray("messages");

for (Message message : context) {
for (Message message : context.get(userId)) {
ObjectNode messageNode = messages.addObject();
messageNode.put("role", message.role());
messageNode.put("content", message.content());
}

payload.put("model", model);

String jsonInputString = mapper.writeValueAsString(payload);

MystiGuardianUtils.logger.debug("Request to AI model: {}", jsonInputString);
return RequestBody.create(jsonInputString, MediaType.get("application/json"));
return RequestBody.create(
mapper.writeValueAsString(payload), MediaType.get("application/json"));
}

private String parseResponse(String responseBody) {
MystiGuardianUtils.logger.debug("Response from AI model: {}", responseBody);
private String parseResponse(String responseBody, long memberId) {
try {
ObjectNode responseJson = (ObjectNode) mapper.readTree(responseBody);
String assistantResponse =
responseJson.get("choices").get(0).get("message").get("content").asText();
context.add(new Message("assistant", assistantResponse));
context.get(memberId).add(new Message("assistant", assistantResponse));
return assistantResponse;
} catch (JsonProcessingException e) {
MystiGuardianUtils.logger.error("Error while parsing AI model response", e);
return null;
}
}

public void setNewModel(String model) {
this.model = model;
}

private record Message(String role, String content) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public class MystiGuardianUtils {
private static final SystemInfo systemInfo = new SystemInfo();
private static final CentralProcessor processor = systemInfo.getHardware().getProcessor();
private static final Map<Long, GithubAIModel> githubAIModel = new HashMap<>();
private static final String AI_PROMPT =
"You are MystiGuardian, your server’s mystical protector and entertainment extraordinaire, created by RealYusufIsmail. As an experienced Java developer active on Discord, your mission is to unite moderation with fun, ensuring a secure and delightful experience for all. You provide helpful, accurate, and timely assistance to users, solving their programming challenges while offering valuable insights to improve their skills. Beyond your technical expertise, you strive to foster a positive and supportive environment, making every interaction productive and uplifting. With your unique combination of wisdom and charm, you guide the server with balance, ensuring both order and entertainment for everyone.";

@Getter
private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(3);
Expand Down Expand Up @@ -242,14 +244,17 @@ public static String getJavaVendor() {
}

@NotNull
public static GithubAIModel getGithubAIModel(long id) {
if (!githubAIModel.containsKey(id)) {
return new GithubAIModel(
"meta-llama-3-8b-instruct",
"You are a java developer, existing on discord. You aim to help others with their problems and make their day better.");
public static GithubAIModel getGithubAIModel(long guildId, long userId, Optional<String> model) {
if (!githubAIModel.containsKey(guildId)) {
githubAIModel.put(
guildId, new GithubAIModel(model.orElse("meta-llama-3-8b-instruct"), AI_PROMPT, userId));
}

return getGithubAIModel(id);
val githubMode = githubAIModel.get(guildId);

model.ifPresent(githubMode::setNewModel);

return githubMode;
}

public static synchronized void clearGithubAIModel() {
Expand Down

0 comments on commit 68627be

Please sign in to comment.