Skip to content

Commit

Permalink
add support for Claude AI
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeloffner committed Dec 1, 2024
1 parent 8fe012d commit 3bfc81d
Show file tree
Hide file tree
Showing 6 changed files with 467 additions and 2 deletions.
141 changes: 141 additions & 0 deletions core/src/main/java/lucee/runtime/ai/anthropic/ClaudeEngine.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package lucee.runtime.ai.anthropic;

import java.net.URL;
import java.util.ArrayList;
import java.util.List;

import lucee.commons.lang.StringUtil;
import lucee.commons.net.HTTPUtil;
import lucee.loader.util.Util;
import lucee.runtime.ai.AIEngine;
import lucee.runtime.ai.AIEngineFactory;
import lucee.runtime.ai.AIEngineSupport;
import lucee.runtime.ai.AIModel;
import lucee.runtime.ai.AISession;
import lucee.runtime.ai.AIUtil;
import lucee.runtime.exp.ApplicationException;
import lucee.runtime.exp.PageException;
import lucee.runtime.net.proxy.ProxyData;
import lucee.runtime.op.Caster;
import lucee.runtime.type.Struct;
import lucee.runtime.type.util.KeyConstants;

public class ClaudeEngine extends AIEngineSupport {
private static final String DEFAULT_URL = "https://api.anthropic.com/v1/";
private static final int DEFAULT_CONVERSATION_SIZE_LIMIT = 100;
private static final String DEFAULT_VERSION = "2023-06-01";
private static final String DEFAULT_CHARSET = "UTF-8";

private String label = "Claude";
private Struct properties;
private String apiKey;
private URL baseURL;
private String model;
private String systemMessage;
private String version;
private Double temperature;
private long timeout;
private long initTimeout;
private String charset;
ProxyData proxy = null;

@Override
public AIEngine init(AIEngineFactory factory, Struct properties) throws PageException {
super.init(factory);
this.properties = properties;

// API Key
apiKey = Caster.toString(properties.get(KeyConstants._apiKey, null), null);
if (Util.isEmpty(apiKey, true)) throw new ApplicationException("the property [apiKey] is required for Claude");

// Base URL
String urlStr = Caster.toString(properties.get(KeyConstants._URL, DEFAULT_URL), DEFAULT_URL);
try {
baseURL = HTTPUtil.toURL(urlStr.trim(), HTTPUtil.ENCODED_AUTO);
}
catch (Exception e) {
throw Caster.toPageException(e);
}

// timeout
timeout = Caster.toLongValue(properties.get(KeyConstants._timeout, null), DEFAULT_TIMEOUT);
initTimeout = Caster.toLongValue(properties.get("initTimeout", null), DEFAULT_TIMEOUT * 2);

// temperature
temperature = Caster.toDouble(properties.get(KeyConstants._temperature, null), null);
if (temperature != null && (temperature < 0D || temperature > 1D)) {
throw new ApplicationException("temperature has to be a number between 0 and 1, now it is [" + temperature + "]");
}

// Model
// TODO read available models and throw exception
model = Caster.toString(properties.get(KeyConstants._model, "claude-3-sonnet-20240229"), "claude-3-sonnet-20240229");

// System Message
systemMessage = Caster.toString(properties.get(KeyConstants._message, null), null);
// version
version = Caster.toString(properties.get(KeyConstants._version, DEFAULT_VERSION), DEFAULT_VERSION);
// charset
charset = Caster.toString(properties.get(KeyConstants._charset, null), DEFAULT_CHARSET);
if (Util.isEmpty(charset, true)) charset = DEFAULT_CHARSET;

return this;
}

@Override
public AISession createSession(String initialMessage, long timeout) {
return new ClaudeSession(this, StringUtil.isEmpty(initialMessage, true) ? systemMessage : initialMessage.trim(), timeout);
}

@Override
public String getLabel() {
return label;
}

@Override
public String getModel() {
return model;
}

@Override
public long getTimeout() {
return timeout;
}

public URL getBaseURL() {
return baseURL;
}

public String getApiKey() {
return apiKey;
}

public String getVersion() {
return version;
}

public String getCharset() {
return charset;
}

@Override
public List<AIModel> getModels() {
// not supported by Claude YET
return new ArrayList<>();

}

private void throwIfError(Struct raw) throws PageException {
Struct err = Caster.toStruct(raw.get(KeyConstants._error, null), null);
if (err != null) {
throw AIUtil.toException(this, Caster.toString(err.get(KeyConstants._message)), Caster.toString(err.get(KeyConstants._type, null), null),
Caster.toString(err.get(KeyConstants._code, null), null));
}
}

@Override
public int getConversationSizeLimit() {
// TODO Auto-generated method stub
return DEFAULT_CONVERSATION_SIZE_LIMIT;
}
}
83 changes: 83 additions & 0 deletions core/src/main/java/lucee/runtime/ai/anthropic/ClaudeResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package lucee.runtime.ai.anthropic;

import java.util.Iterator;

import lucee.commons.io.CharsetUtil;
import lucee.commons.lang.StringUtil;
import lucee.runtime.ai.Response;
import lucee.runtime.converter.ConverterException;
import lucee.runtime.converter.JSONConverter;
import lucee.runtime.converter.JSONDateFormat;
import lucee.runtime.listener.SerializationSettings;
import lucee.runtime.op.Caster;
import lucee.runtime.type.Array;
import lucee.runtime.type.Struct;
import lucee.runtime.type.util.KeyConstants;

public class ClaudeResponse implements Response {
private Struct raw;
private String charset;
private long tokens = -1L;

public ClaudeResponse(Struct raw, String charset) {
this.raw = raw;
this.charset = charset;
}

@Override
public String toString() {
try {
JSONConverter json = new JSONConverter(false, CharsetUtil.toCharset(charset), JSONDateFormat.PATTERN_CF, false);
return json.serialize(null, raw, SerializationSettings.SERIALIZE_AS_UNDEFINED, true);
}
catch (ConverterException e) {
return raw.toString();
}
}

@Override
public String getAnswer() {
// Claude's response structure is different from OpenAI
Array arr = Caster.toArray(raw.get(KeyConstants._content, null), null);
if (arr == null || arr.size() == 0) return null;
Iterator<Object> it = arr.valueIterator();
Struct sct;
String type, text;
StringBuilder sb = new StringBuilder();
while (it.hasNext()) {
sct = Caster.toStruct(it.next(), null);
type = Caster.toString(sct.get(KeyConstants._type, null), null);
if ("text".equals(type) || "code".equals(type)) {
text = Caster.toString(sct.get(KeyConstants._text, null), null);
if (!StringUtil.isEmpty(text, true)) {
if (sb.length() > 0) sb.append('\n');
sb.append(text);
}
}
}
if (sb.length() > 0) return sb.toString();
return null;
// TODO support image?

}

public Struct getData() {
return raw;
}

@Override
public long getTotalTokenUsed() {
if (tokens == -1L) {
// Claude's usage structure is different from OpenAI
Struct usage = Caster.toStruct(raw.get("usage", null), null);
if (usage == null) return tokens = 0L;

// Claude reports input_tokens and output_tokens separately
long inputTokens = Caster.toLongValue(usage.get("input_tokens", null), 0L);
long outputTokens = Caster.toLongValue(usage.get("output_tokens", null), 0L);

return tokens = inputTokens + outputTokens;
}
return tokens;
}
}
169 changes: 169 additions & 0 deletions core/src/main/java/lucee/runtime/ai/anthropic/ClaudeSession.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
// Session implementation
package lucee.runtime.ai.anthropic;

import java.io.BufferedReader;
import java.io.InputStreamReader;

import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;

import lucee.commons.io.CharsetUtil;
import lucee.commons.lang.StringUtil;
import lucee.commons.lang.mimetype.MimeType;
import lucee.loader.util.Util;
import lucee.runtime.ai.AIResponseListener;
import lucee.runtime.ai.AISessionSupport;
import lucee.runtime.ai.AIUtil;
import lucee.runtime.ai.Conversation;
import lucee.runtime.ai.ConversationImpl;
import lucee.runtime.ai.RequestSupport;
import lucee.runtime.ai.Response;
import lucee.runtime.converter.JSONConverter;
import lucee.runtime.converter.JSONDateFormat;
import lucee.runtime.exp.ApplicationException;
import lucee.runtime.exp.PageException;
import lucee.runtime.interpreter.JSONExpressionInterpreter;
import lucee.runtime.listener.SerializationSettings;
import lucee.runtime.op.Caster;
import lucee.runtime.type.Array;
import lucee.runtime.type.ArrayImpl;
import lucee.runtime.type.Struct;
import lucee.runtime.type.StructImpl;
import lucee.runtime.type.util.KeyConstants;

public class ClaudeSession extends AISessionSupport {
private ClaudeEngine engine;
private String systemMessage;

public ClaudeSession(ClaudeEngine engine, String systemMessage, long timeout) {
super(engine, timeout);
this.engine = engine;
this.systemMessage = systemMessage;
}

@Override
public Response inquiry(String message, AIResponseListener listener) throws PageException {
try {

Struct requestBody = new StructImpl();
requestBody.set(KeyConstants._model, engine.getModel());
requestBody.set("max_tokens", 4096);
requestBody.set("stream", listener != null);

// Set system message at top level if exists
if (!StringUtil.isEmpty(systemMessage)) {
requestBody.set(KeyConstants._system, systemMessage);
}

// Build messages array with system and conversation history
Array messages = new ArrayImpl();

// Add conversation history
for (Conversation c: getHistoryAsList()) {
messages.append(createMessage("user", c.getRequest().getQuestion()));
messages.append(createMessage("assistant", c.getResponse().getAnswer()));
}

// Add new message
messages.append(createMessage("user", message));
requestBody.set("messages", messages);

// Make API request
HttpPost post = new HttpPost(engine.getBaseURL().toExternalForm() + "messages");
post.setHeader("Content-Type", "application/json");
post.setHeader("x-api-key", engine.getApiKey());
post.setHeader("anthropic-version", engine.getVersion());

// Convert request body to JSON
JSONConverter json = new JSONConverter(true, CharsetUtil.UTF8, JSONDateFormat.PATTERN_CF, false);
String str = json.serialize(null, requestBody, SerializationSettings.SERIALIZE_AS_COLUMN, null);

// Create entity and set it to the post request
StringEntity entity = new StringEntity(str, engine.getCharset());
post.setEntity(entity);

// Set timeout
int timeout = Caster.toIntValue(getTimeout());
RequestConfig config = RequestConfig.custom().setConnectTimeout(timeout).setSocketTimeout(timeout).build();
post.setConfig(config);

// Execute request
try (CloseableHttpClient httpClient = HttpClients.createDefault(); CloseableHttpResponse response = httpClient.execute(post)) {

HttpEntity responseEntity = response.getEntity();
Header ct = responseEntity.getContentType();
MimeType mt = MimeType.getInstance(ct.getValue());

String t = mt.getType() + "/" + mt.getSubtype();
String cs = mt.getCharset() != null ? mt.getCharset().toString() : engine.getCharset();
// Handle JSON response
if ("application/json".equals(t)) {
if (Util.isEmpty(cs, true)) cs = engine.getCharset();
String rawStr = EntityUtils.toString(responseEntity, engine.getCharset());

Struct raw = Caster.toStruct(new JSONExpressionInterpreter().interpret(null, rawStr));

// Check for errors
Struct err = Caster.toStruct(raw.get(KeyConstants._error, null), null);
if (err != null) {
throw AIUtil.toException(this.getEngine(), Caster.toString(err.get(KeyConstants._message)), Caster.toString(err.get(KeyConstants._type, null), null),
Caster.toString(err.get(KeyConstants._code, null), null));
}

// Create response object
Response r = new ClaudeResponse(raw, cs);
AIUtil.addConversation(engine, getHistoryAsList(), new ConversationImpl(new RequestSupport(message), r));

return r;
}
// Handle streaming response if needed
else if ("text/event-stream".equals(t)) {
if (Util.isEmpty(cs, true)) cs = engine.getCharset();
JSONExpressionInterpreter interpreter = new JSONExpressionInterpreter();
Response r = new ClaudeStreamResponse(cs, listener);

try (BufferedReader reader = new BufferedReader(
cs == null ? new InputStreamReader(responseEntity.getContent()) : new InputStreamReader(responseEntity.getContent(), cs))) {
String line;
while ((line = reader.readLine()) != null) {
if (!line.startsWith("data: ")) continue;
line = line.substring(6);
if ("[DONE]".equals(line)) break;
((ClaudeStreamResponse) r).addPart(Caster.toStruct(interpreter.interpret(null, line)));
}
}

AIUtil.addConversation(engine, getHistoryAsList(), new ConversationImpl(new RequestSupport(message), r));
return r;
}
else {
throw new ApplicationException("The AI did answer with the mime type [" + t + "] that is not supported, only [application/json] is supported");
}
}

}
catch (Exception e) {
throw Caster.toPageException(e);
}
}

private Struct createMessage(String role, String content) {
Struct message = new StructImpl();
message.setEL(KeyConstants._role, role);
message.setEL(KeyConstants._content, content);
return message;
}

@Override
public void release() throws PageException {
// TODO Auto-generated method stub

}
}
Loading

0 comments on commit 3bfc81d

Please sign in to comment.