Skip to content

Commit

Permalink
Merge pull request #209 from AdaptiveScale/release-2.2.1
Browse files Browse the repository at this point in the history
Release 2.2.1
  • Loading branch information
nbesimi authored Apr 16, 2024
2 parents 3c001e3 + 48226b5 commit f35fac1
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ Parameter | Description
--scala | Generates the Scala SQL file.

#### query
The query command allows you to use natural language commands to query your databases, transforming these commands into SQL SELECT statements. By leveraging the capabilities of AI and LLMs, specifically OpenAI models, it interprets user queries and generates the corresponding SQL queries. For effective use of this command, users need to provide their OpenAI API Key and specify the OpenAI model to be utilized. The output will be written to a CSV file. The max number of rows that will be returned is 200. You can overwrite this value, or remove completely the limit.The default openai model that is used is gpt-3.5-turbo.
The query command allows you to use natural language commands to query your databases, transforming these commands into SQL SELECT statements. By leveraging the capabilities of AI and LLMs, specifically OpenAI models, it interprets user queries and generates the corresponding SQL queries. For effective use of this command, users need to provide their OpenAI API Key and specify the OpenAI model to be utilized. The output will be written to a CSV file. The max number of rows that will be returned is 200. You can overwrite this value, or remove completely the limit. The default openai model that is used is gpt-3.5-turbo.

rosetta [-c, --config CONFIG_FILE] query [-h, --help] [-s, --source CONNECTION_NAME] [-q, --query "Natural language QUERY"]

Expand Down
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repositories {

allprojects {
group = 'com.adaptivescale'
version = '2.2.0'
version = '2.2.1'
sourceCompatibility = 11
targetCompatibility = 11
}
Expand Down
2 changes: 1 addition & 1 deletion cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
@Slf4j
@CommandLine.Command(name = "cli",
mixinStandardHelpOptions = true,
version = "2.2.0",
version = "2.2.1",
description = "Declarative Database Management - DDL Transpiler"
)
class Cli implements Callable<Void> {
Expand Down
64 changes: 36 additions & 28 deletions queryhelper/src/main/java/queryhelper/service/AIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import queryhelper.utils.FileUtils;
import queryhelper.utils.PromptUtils;

import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.Driver;
import java.sql.SQLException;
Expand All @@ -33,42 +32,18 @@ public class AIService {

public static GenericResponse generateQuery(String userQueryRequest, String apiKey, String aiModel, String databaseDDL, Connection source, Integer showRowLimit, Path dataDirectory) {

Gson gson = new Gson();
GenericResponse response = new GenericResponse();
QueryDataResponse data = new QueryDataResponse();
QueryRequest queryRequest = new QueryRequest();
queryRequest.setQuery(userQueryRequest);

String query;
String aiOutputStr;
String prompt;

OpenAiChatModel.OpenAiChatModelBuilder model = OpenAiChatModel
.builder()
.temperature(0.1)
.apiKey(apiKey)
.modelName(AI_MODEL);

if (aiModel != null && !aiModel.isEmpty()) {
model.modelName(aiModel);
}

prompt = PromptUtils.queryPrompt(queryRequest, databaseDDL, source);

try { // Check if we have a properly set API key & that openAI services aren't down
aiOutputStr = model.build().generate(prompt);
QueryRequest aiOutputObj = gson.fromJson(aiOutputStr, QueryRequest.class);
query = aiOutputObj.getQuery();
} catch (JsonSyntaxException e) {
return ErrorUtils.invalidResponseError(e);
} catch (Exception e) {
return ErrorUtils.openAIError(e);
}
query = generateAIOutput(apiKey, aiModel, queryRequest, source, databaseDDL);

boolean selectStatement = isSelectStatement(query);
if (!selectStatement) {
GenericResponse errorResponse = new GenericResponse();
errorResponse.setMessage("Generated query, execute on your own will: " + aiOutputStr);
errorResponse.setMessage("Generated query, execute on your own will: " + query);
errorResponse.setStatusCode(200);
}

Expand All @@ -83,7 +58,7 @@ public static GenericResponse generateQuery(String userQueryRequest, String apiK
String csvFile = createCSVFile(queryDataResponse, queryRequest.getQuery(), dataDirectory);

response.setMessage(
aiOutputStr + "\n" +
query + "\n" +
"Total rows: " + data.getRecords().size() + "\n" +
"Your response is saved to a CSV file named '" + csvFile + "'!"
);
Expand Down Expand Up @@ -133,4 +108,37 @@ private static String createCSVFile(QueryDataResponse queryDataResponse, String
throw new RuntimeException(genericResponse.getMessage());
}
}

public static String generateAIOutput(String apiKey, String aiModel, QueryRequest queryRequest, Connection source, String databaseDDL) {
Gson gson = new Gson();
String aiOutputStr;
String query;

OpenAiChatModel.OpenAiChatModelBuilder model = OpenAiChatModel
.builder()
.temperature(0.1)
.apiKey(apiKey)
.modelName(AI_MODEL);

if (aiModel != null && !aiModel.isEmpty()) {
model.modelName(aiModel);
}

String prompt = PromptUtils.queryPrompt(queryRequest, databaseDDL, source);

try {
aiOutputStr = model.build().generate(prompt);
QueryRequest aiOutputObj = gson.fromJson(aiOutputStr, QueryRequest.class);
query = aiOutputObj.getQuery();
} catch (JsonSyntaxException e) {
GenericResponse genericResponse = ErrorUtils.invalidResponseError(e);
throw new RuntimeException(genericResponse.getMessage());
} catch (Exception e) {
GenericResponse genericResponse = ErrorUtils.openAIError(e);
throw new RuntimeException(genericResponse.getMessage());
}

return query;
}

}

0 comments on commit f35fac1

Please sign in to comment.