From 38a67b4c17f0149ab83806b728d975794e01ea99 Mon Sep 17 00:00:00 2001 From: Fehmi Havziu <106117501+Femi3211@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:22:24 +0200 Subject: [PATCH] Feature/query command (#205) added: AI integration to query the data using natural language through query command --- README.md | 60 ++++++++ cli/build.gradle | 4 +- .../com/adaptivescale/rosetta/cli/Cli.java | 56 +++++++- .../rosetta/cli/model/Config.java | 12 ++ queryhelper/build.gradle | 26 ++++ .../queryhelper/pojo/GenericResponse.java | 38 +++++ .../OpenAIHttpExceptionErrorResponse.java | 63 ++++++++ .../queryhelper/pojo/QueryDataResponse.java | 50 +++++++ .../java/queryhelper/pojo/QueryRequest.java | 13 ++ .../java/queryhelper/service/AIService.java | 136 ++++++++++++++++++ .../java/queryhelper/utils/ErrorUtils.java | 56 ++++++++ .../java/queryhelper/utils/FileUtils.java | 44 ++++++ .../java/queryhelper/utils/PromptUtils.java | 20 +++ .../main/resources/static/output_format.json | 3 + .../java/com/adaptivescale/AIServiceTest.java | 22 +++ settings.gradle | 1 + 16 files changed, 596 insertions(+), 8 deletions(-) create mode 100644 queryhelper/build.gradle create mode 100644 queryhelper/src/main/java/queryhelper/pojo/GenericResponse.java create mode 100644 queryhelper/src/main/java/queryhelper/pojo/OpenAIHttpExceptionErrorResponse.java create mode 100644 queryhelper/src/main/java/queryhelper/pojo/QueryDataResponse.java create mode 100644 queryhelper/src/main/java/queryhelper/pojo/QueryRequest.java create mode 100644 queryhelper/src/main/java/queryhelper/service/AIService.java create mode 100644 queryhelper/src/main/java/queryhelper/utils/ErrorUtils.java create mode 100644 queryhelper/src/main/java/queryhelper/utils/FileUtils.java create mode 100644 queryhelper/src/main/java/queryhelper/utils/PromptUtils.java create mode 100644 queryhelper/src/main/resources/static/output_format.json create mode 100644 queryhelper/src/test/java/com/adaptivescale/AIServiceTest.java diff --git 
a/README.md b/README.md index b9ddaee8..26ccdae1 100644 --- a/README.md +++ b/README.md @@ -653,6 +653,66 @@ Parameter | Description --pyspark | Generates the Spark SQL file. --scala | Generates the Scala SQL file. +#### query +The query command allows you to use natural language commands to query your databases, transforming these commands into SQL SELECT statements. By leveraging the capabilities of AI and LLMs, specifically OpenAI models, it interprets user queries and generates the corresponding SQL queries. For effective use of this command, users need to provide their OpenAI API Key and specify the OpenAI model to be utilized. The output will be written to a CSV file. By default, at most 200 rows are returned. You can override this value or remove the limit completely. The default OpenAI model is gpt-3.5-turbo. + + rosetta [-c, --config CONFIG_FILE] query [-h, --help] [-s, --source CONNECTION_NAME] [-q, --query "Natural language QUERY"] [-l, --limit ROW_LIMIT] [--no-limit] + +Parameter | Description +--- | --- +-h, --help | Show the help message and exit. +-c, --config CONFIG_FILE | YAML config file. If none is supplied it will use main.conf in the current directory if it exists. +-s, --source CONNECTION_NAME | The source connection is used to specify which models and connection to use. +-q, --query "Natural language QUERY" | Specifies the natural language query to be transformed into an SQL SELECT statement. +-l, --limit Response Row limit (Optional) | Limits the number of rows in the generated CSV file. If not specified, the default limit is set to 200 rows. +--no-limit (Optional) | Specifies that there should be no limit on the number of rows in the generated CSV file. 
+ + +**Example** (Setting the key and model): + +(Config file) +``` +openai_api_key: "sk-abcdefghijklmno1234567890" +openai_model: "gpt-4" +connections: + - name: mysql + databaseName: sakila + schemaName: + dbType: mysql + url: jdbc:mysql://root:sakila@localhost:3306/sakila + userName: root + password: sakila + - name: pg + databaseName: postgres + schemaName: public + dbType: postgres + url: jdbc:postgresql://localhost:5432/postgres?user=postgres&password=sakila + userName: postgres + password: sakila +``` + +***Example*** (Query) +``` + rosetta query -s mysql -q "Show me the top 10 customers by revenue." +``` +***CSV Output Example*** +```CSV +customer_name,total_revenue,location,email +John Doe,50000,New York,johndoe@example.com +Jane Smith,45000,Los Angeles,janesmith@example.com +David Johnson,40000,Chicago,davidjohnson@example.com +Emily Brown,35000,San Francisco,emilybrown@example.com +Michael Lee,30000,Miami,michaellee@example.com +Sarah Taylor,25000,Seattle,sarahtaylor@example.com +Robert Clark,20000,Boston,robertclark@example.com +Lisa Martinez,15000,Denver,lisamartinez@example.com +Christopher Anderson,10000,Austin,christopheranderson@example.com +Amanda Wilson,5000,Atlanta,amandawilson@example.com + +``` +**Note:** If a request does not produce a SELECT statement, the query is still generated but is not executed; it is returned to the user to run on their own. + + ### Safety Operation In `model.yaml` you can find the attribute `safeMode` which is by default disabled (false). 
If you want to prevent any DROP operation during diff --git a/cli/build.gradle b/cli/build.gradle index 30a1c094..d65ccb72 100644 --- a/cli/build.gradle +++ b/cli/build.gradle @@ -13,6 +13,7 @@ dependencies { implementation project(':ddl') implementation project(':diff') implementation project(':test') + implementation project(':queryhelper') implementation group: 'info.picocli', name: 'picocli', version: '4.6.3' implementation group: 'org.slf4j', name: 'slf4j-simple', version: '2.0.5' @@ -24,7 +25,8 @@ dependencies { implementation group: 'org.reflections', name: 'reflections', version: '0.10.2' implementation group: 'org.thymeleaf', name: 'thymeleaf', version: '3.1.0.RELEASE' implementation group: 'com.h2database', name: 'h2', version: '2.1.214' - + implementation 'dev.langchain4j:langchain4j-open-ai:0.25.0' + implementation 'com.github.jsqlparser:jsqlparser:4.9' compileOnly 'org.projectlombok:lombok:1.18.12' annotationProcessor 'org.projectlombok:lombok:1.18.12' diff --git a/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java b/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java index 2a86ddb1..6888471a 100644 --- a/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java +++ b/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java @@ -36,6 +36,9 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import picocli.CommandLine; +import queryhelper.pojo.GenericResponse; +import queryhelper.pojo.QueryRequest; +import queryhelper.service.AIService; import java.io.IOException; import java.io.InputStream; @@ -201,7 +204,7 @@ private void apply(@CommandLine.Option(names = {"-s", "--source"}, required = tr } if (changes.stream().filter(change -> change.getStatus().equals(Change.Status.DROP)).findFirst().isPresent() && - expectedDatabase.getSafeMode()) { + expectedDatabase.getSafeMode()) { log.info("Not going to perform the changes because there are DROP operations and the safe mode is enabled."); return; } @@ -252,8 +255,8 @@ 
private void test(@CommandLine.Option(names = {"-s", "--source"}) String sourceN } List collect = getDatabases(sourceWorkspace) - .map(AbstractMap.SimpleImmutableEntry::getValue) - .collect(Collectors.toList()); + .map(AbstractMap.SimpleImmutableEntry::getValue) + .collect(Collectors.toList()); for (Database database : collect) { AssertionSqlGenerator assertionSqlGenerator = AssertionSqlGeneratorFactory.generatorFor(source.get()); DefaultSqlExecution defaultSqlExecution = new DefaultSqlExecution(source.get(), new DriverManagerDriverProvider()); @@ -263,7 +266,7 @@ private void test(@CommandLine.Option(names = {"-s", "--source"}) String sourceN @CommandLine.Command(name = "init", description = "Creates a sample config (main.conf) and model directory.", mixinStandardHelpOptions = true) private void init(@CommandLine.Parameters(index = "0", description = "Project name.", defaultValue = "") - String projectName) throws IOException { + String projectName) throws IOException { Path fileName = Paths.get(projectName, CONFIG_NAME); InputStream resourceAsStream = getClass().getResourceAsStream("/" + TEMPLATE_CONFIG_NAME); Path projectDirectory = Path.of(projectName); @@ -297,8 +300,8 @@ private void dbt(@CommandLine.Option(names = {"-s", "--source"}, required = true @CommandLine.Command(name = "generate", description = "Generate code", mixinStandardHelpOptions = true) private void generate(@CommandLine.Option(names = {"-s", "--source"}, required = true) String sourceName, - @CommandLine.Option(names = {"-t", "--target"}, required = true) String targetName, - @CommandLine.Option(names = {"--pyspark"}) boolean generateSpark, + @CommandLine.Option(names = {"-t", "--target"}, required = true) String targetName, + @CommandLine.Option(names = {"--pyspark"}) boolean generateSpark, @CommandLine.Option(names = {"--scala"}) boolean generateScala ) throws Exception { requireConfig(config); @@ -375,7 +378,7 @@ private void extractDbtModels(Connection connection, Path sourceWorkspace) 
throw @CommandLine.Command(name = "diff", description = "Show difference between local model and database", mixinStandardHelpOptions = true) private void diff(@CommandLine.Option(names = {"-s", "--source"}) String sourceName, - @CommandLine.Option(names = {"-m", "--model"}, defaultValue=DEFAULT_MODEL_YAML) String model) throws Exception { + @CommandLine.Option(names = {"-m", "--model"}, defaultValue = DEFAULT_MODEL_YAML) String model) throws Exception { requireConfig(config); Connection sourceConnection = getSourceConnection(sourceName); @@ -517,4 +520,43 @@ public FileNameAndDatabasePair(String key, Database value) { super(key, value); } } + + + @CommandLine.Command(name = "query", description = "Query schema", mixinStandardHelpOptions = true) + private void query(@CommandLine.Option(names = {"-s", "--source"}, required = true) String sourceName, + @CommandLine.Option(names = {"-q", "--query"}, required = true) String userQueryRequest, + @CommandLine.Option(names = {"-l", "--limit"}, required = false, defaultValue = "200") Integer showRowLimit, + @CommandLine.Option(names = {"--no-limit"}, required = false, defaultValue = "false") Boolean noRowLimit + ) + throws Exception { + requireConfig(config); + + if (config.getOpenAIApiKey() == null) { + log.info("Open AI API key has to be provided in the config file"); + return; + } + + Connection source = getSourceConnection(sourceName); + + Path sourceWorkspace = Paths.get("./", sourceName); + + if (!Files.exists(sourceWorkspace)) { + Files.createDirectories(sourceWorkspace); + } + + Path dataDirectory = sourceWorkspace.resolve("data"); + if (!Files.exists(dataDirectory)) { + Files.createDirectories(dataDirectory); + } + + Database db = SourceGeneratorFactory.sourceGenerator(source).generate(source); + + DDL modelDDL = DDLFactory.ddlForDatabaseType(source.getDbType()); + String DDL = modelDDL.createDatabase(db, false); + + // If `noRowLimit` is true, set the row limit to 0 (no limit), otherwise use the value of 
`showRowLimit` + GenericResponse response = AIService.generateQuery(userQueryRequest, config.getOpenAIApiKey(), config.getOpenAIModel(), DDL, source, noRowLimit ? 0 : showRowLimit, dataDirectory); + log.info(response.getMessage()); + } + } diff --git a/cli/src/main/java/com/adaptivescale/rosetta/cli/model/Config.java b/cli/src/main/java/com/adaptivescale/rosetta/cli/model/Config.java index fd7048d1..a3633da3 100644 --- a/cli/src/main/java/com/adaptivescale/rosetta/cli/model/Config.java +++ b/cli/src/main/java/com/adaptivescale/rosetta/cli/model/Config.java @@ -1,6 +1,7 @@ package com.adaptivescale.rosetta.cli.model; import com.adaptivescale.rosetta.common.models.input.Connection; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import java.util.Optional; @@ -8,6 +9,12 @@ public class Config { private List connections; + @JsonProperty("openai_api_key") + private String openAIApiKey; + + @JsonProperty("openai_model") + private String openAIModel; + public List getConnections() { return connections; } @@ -19,4 +26,9 @@ public void setConnections(List connection) { public Optional getConnection(String name) { return connections.stream().filter(target -> target.getName().equals(name)).findFirst(); } + + public String getOpenAIApiKey() { + return openAIApiKey; + } + public String getOpenAIModel() { return openAIModel; } } diff --git a/queryhelper/build.gradle b/queryhelper/build.gradle new file mode 100644 index 00000000..232236b8 --- /dev/null +++ b/queryhelper/build.gradle @@ -0,0 +1,26 @@ +plugins { + id 'java' +} + +group = 'com.adaptivescale' +version = '2.1.3' + +repositories { + mavenCentral() +} + +dependencies { + implementation project(':common') + implementation project(':source') + + testImplementation platform('org.junit:junit-bom:5.9.1') + testImplementation 'org.junit.jupiter:junit-jupiter' + + implementation 'dev.langchain4j:langchain4j-open-ai:0.25.0' + implementation 'com.google.code.gson:gson:2.10.1' + implementation 
'com.github.jsqlparser:jsqlparser:4.9' +} + +test { + useJUnitPlatform() +} \ No newline at end of file diff --git a/queryhelper/src/main/java/queryhelper/pojo/GenericResponse.java b/queryhelper/src/main/java/queryhelper/pojo/GenericResponse.java new file mode 100644 index 00000000..57d86157 --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/pojo/GenericResponse.java @@ -0,0 +1,38 @@ +package queryhelper.pojo; + +public class GenericResponse { + private Object data; + private String message; + private Integer statusCode; + + public GenericResponse() { + this.data = null; + this.message = null; + this.statusCode = null; + } + + public Object getData() { + return data; + } + + public void setData(Object data) { + this.data = data; + } + + public String getMessage() { + return message; + } + + public void setMessage(String message) { + this.message = message; + } + + public Integer getStatusCode() { + return statusCode; + } + + public void setStatusCode(Integer statusCode) { + this.statusCode = statusCode; + } + +} diff --git a/queryhelper/src/main/java/queryhelper/pojo/OpenAIHttpExceptionErrorResponse.java b/queryhelper/src/main/java/queryhelper/pojo/OpenAIHttpExceptionErrorResponse.java new file mode 100644 index 00000000..1620d591 --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/pojo/OpenAIHttpExceptionErrorResponse.java @@ -0,0 +1,63 @@ +package queryhelper.pojo; + +public class OpenAIHttpExceptionErrorResponse { + private ErrorDetails error; + + public OpenAIHttpExceptionErrorResponse() { + this.error = null; + } + + public ErrorDetails getError() { + return error; + } + + public void setError(ErrorDetails error) { + this.error = error; + } + + public static class ErrorDetails { + private String message; + private String type; + private String param; + private String code; + + public ErrorDetails() { + this.message = null; + this.type = null; + this.param = null; + this.code = null; + } + + public String getMessage() { + return message; + } + + 
public void setMessage(String message) { + this.message = message; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public String getParam() { + return param; + } + + public void setParam(String param) { + this.param = param; + } + + public String getCode() { + return code; + } + + public void setCode(String code) { + this.code = code; + } + } +} diff --git a/queryhelper/src/main/java/queryhelper/pojo/QueryDataResponse.java b/queryhelper/src/main/java/queryhelper/pojo/QueryDataResponse.java new file mode 100644 index 00000000..64a3562b --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/pojo/QueryDataResponse.java @@ -0,0 +1,50 @@ +package queryhelper.pojo; + +import java.util.List; +import java.util.Map; + +public class QueryDataResponse { + private String schema; + private Double responseTime; + private String query; + private List> records; + + public QueryDataResponse() { + this.schema = null; + this.responseTime = null; + this.query = null; + this.records = null; + } + + public String getSchema() { + return schema; + } + + public void setSchema(String schema) { + this.schema = schema; + } + + public Double getResponseTime() { + return responseTime; + } + + public void setResponseTime(Double responseTime) { + this.responseTime = responseTime; + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public List> getRecords() { + return records; + } + + public void setRecords(List> records) { + this.records = records; + } +} diff --git a/queryhelper/src/main/java/queryhelper/pojo/QueryRequest.java b/queryhelper/src/main/java/queryhelper/pojo/QueryRequest.java new file mode 100644 index 00000000..ca4b819a --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/pojo/QueryRequest.java @@ -0,0 +1,13 @@ +package queryhelper.pojo; + +public class QueryRequest { + private String query; + + public String getQuery() { + return 
query; + } + + public void setQuery(String query) { + this.query = query; + } +} diff --git a/queryhelper/src/main/java/queryhelper/service/AIService.java b/queryhelper/src/main/java/queryhelper/service/AIService.java new file mode 100644 index 00000000..b42b9174 --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/service/AIService.java @@ -0,0 +1,136 @@ +package queryhelper.service; + +import com.adaptivescale.rosetta.common.DriverManagerDriverProvider; +import com.adaptivescale.rosetta.common.JDBCUtils; +import com.adaptivescale.rosetta.common.models.input.Connection; +import com.adataptivescale.rosetta.source.common.QueryHelper; +import com.google.gson.Gson; +import com.google.gson.JsonSyntaxException; +import dev.langchain4j.model.openai.OpenAiChatModel; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.select.Select; +import queryhelper.pojo.GenericResponse; +import queryhelper.pojo.QueryDataResponse; +import queryhelper.pojo.QueryRequest; +import queryhelper.utils.ErrorUtils; +import queryhelper.utils.FileUtils; +import queryhelper.utils.PromptUtils; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.Driver; +import java.sql.SQLException; +import java.sql.Statement; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Properties; + + +public class AIService { + private final static String AI_MODEL = "gpt-3.5-turbo"; + + public static GenericResponse generateQuery(String userQueryRequest, String apiKey, String aiModel, String databaseDDL, Connection source, Integer showRowLimit, Path dataDirectory) { + + Gson gson = new Gson(); + GenericResponse response = new GenericResponse(); + QueryDataResponse data = new QueryDataResponse(); + QueryRequest queryRequest = new QueryRequest(); + queryRequest.setQuery(userQueryRequest); + + String query; + String aiOutputStr; + String prompt; + + OpenAiChatModel.OpenAiChatModelBuilder 
model = OpenAiChatModel + .builder() + .temperature(0.1) + .apiKey(apiKey) + .modelName(AI_MODEL); + + if (aiModel != null && !aiModel.isEmpty()) { + model.modelName(aiModel); + } + + prompt = PromptUtils.queryPrompt(queryRequest, databaseDDL, source); + + try { // Check if we have a properly set API key & that openAI services aren't down + aiOutputStr = model.build().generate(prompt); + QueryRequest aiOutputObj = gson.fromJson(aiOutputStr, QueryRequest.class); + query = aiOutputObj.getQuery(); + } catch (JsonSyntaxException e) { + return ErrorUtils.invalidResponseError(e); + } catch (Exception e) { + return ErrorUtils.openAIError(e); + } + + boolean selectStatement = isSelectStatement(query); + if (!selectStatement) { + GenericResponse errorResponse = new GenericResponse(); + errorResponse.setMessage("Generated query, execute on your own will: " + aiOutputStr); + errorResponse.setStatusCode(200); + } + + List> records = executeQueryAndGetRecords(query, source, showRowLimit); + data.setRecords(records); + + response.setData(data); + response.setStatusCode(200); + + + QueryDataResponse queryDataResponse = (QueryDataResponse) response.getData(); + String csvFile = createCSVFile(queryDataResponse, queryRequest.getQuery(), dataDirectory); + + response.setMessage( + aiOutputStr + "\n" + + "Total rows: " + data.getRecords().size() + "\n" + + "Your response is saved to a CSV file named '" + csvFile + "'!" 
+ ); + + return response; + } + + private static List> executeQueryAndGetRecords(String query, Connection source, Integer showRowLimit) { + try { + DriverManagerDriverProvider driverManagerDriverProvider = new DriverManagerDriverProvider(); + Driver driver = driverManagerDriverProvider.getDriver(source); + Properties properties = JDBCUtils.setJDBCAuth(source); + java.sql.Connection jdbcConnection = driver.connect(source.getUrl(), properties); + Statement statement = jdbcConnection.createStatement(); + statement.setMaxRows(showRowLimit); + List> select = QueryHelper.select(statement, query); + return select; + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + public static boolean isSelectStatement(String query) { + boolean isSelectStatement = true; + try { + net.sf.jsqlparser.statement.Statement statement = CCJSqlParserUtil.parse(query); + if (!(statement instanceof Select)) { + return false; + } + } catch (Exception e) { + return false; + } + return isSelectStatement; + } + + private static String createCSVFile(QueryDataResponse queryDataResponse, String csvFileName, Path dataDirectory) { + try { + String timestamp = new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date()); + String fileName = csvFileName.replaceAll("\\s+", "_") + "_" + timestamp + ".csv"; + Path csvFilePath = dataDirectory.resolve(fileName); + + FileUtils.convertToCSV(csvFilePath.toString(), queryDataResponse.getRecords()); + + return csvFilePath.toString(); + } catch (Exception e) { + GenericResponse genericResponse = ErrorUtils.csvFileError(e); + throw new RuntimeException(genericResponse.getMessage()); + } + } +} \ No newline at end of file diff --git a/queryhelper/src/main/java/queryhelper/utils/ErrorUtils.java b/queryhelper/src/main/java/queryhelper/utils/ErrorUtils.java new file mode 100644 index 00000000..3a7921aa --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/utils/ErrorUtils.java @@ -0,0 +1,56 @@ +package queryhelper.utils; + + +import 
com.google.gson.Gson; +import dev.ai4j.openai4j.OpenAiHttpException; +import queryhelper.pojo.GenericResponse; +import queryhelper.pojo.OpenAIHttpExceptionErrorResponse; + +public class ErrorUtils { + private static final GenericResponse errorResponse = new GenericResponse(); + + public static GenericResponse fileError(Exception e) { + return genericErrorResponse(e, "There was an error while reading from file."); + } + + public static GenericResponse csvFileError(Exception e) { + return genericErrorResponse(e, "There was an error while creating the csv file!"); + } + + public static GenericResponse openAIError(Exception e) { + if (e.getCause() instanceof OpenAiHttpException) { // When API key is wrong (currently the only supported exception by langchain) + Gson gson = new Gson(); + String errorMessage = e.getMessage(); + int startIndex = errorMessage.indexOf('{'); + int endIndex = errorMessage.lastIndexOf('}'); + + if (startIndex != -1 && endIndex != -1) { + String jsonError = errorMessage.substring(startIndex, endIndex + 1); + OpenAIHttpExceptionErrorResponse data = gson.fromJson(jsonError, OpenAIHttpExceptionErrorResponse.class); + + errorResponse.setData(data); + errorResponse.setMessage("Error occurred while communicating with OpenAI's API."); + errorResponse.setStatusCode(500); + return errorResponse; + } + } + + return genericErrorResponse(e, "Error occurred while communicating with OpenAI's API."); + } + + public static GenericResponse invalidResponseError(Exception e) { + return genericErrorResponse(e, "The generated response by chatGPT is invalid."); + } + + public static GenericResponse invalidSQLError(Exception e) { + return genericErrorResponse(e, "The SQL query generated by chatGPT failed during execution."); + } + + public static GenericResponse genericErrorResponse(Exception e, String message) { + errorResponse.setData(e.getMessage()); + errorResponse.setMessage(message); + errorResponse.setStatusCode(500); + return errorResponse; + } + +} diff --git 
/**
 * File helpers of the query helper module: reading the bundled output-format template and
 * exporting query results as CSV.
 */
public class FileUtils {

    private static final String OUTPUT_FORMAT_RESOURCE = "static/output_format.json";

    /**
     * Reads the {@code static/output_format.json} classpath resource as UTF-8 text.
     *
     * @return the resource content with lines joined by {@code \n}
     * @throws IllegalStateException if the resource is not on the classpath
     * @throws UncheckedIOException  if reading the resource fails
     */
    public static String readJsonFile() {
        InputStream resource = FileUtils.class.getClassLoader().getResourceAsStream(OUTPUT_FORMAT_RESOURCE);
        if (resource == null) {
            // Fail loudly instead of letting a null end up inside the prompt text.
            throw new IllegalStateException("Classpath resource not found: " + OUTPUT_FORMAT_RESOURCE);
        }
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource, StandardCharsets.UTF_8))) {
            return reader.lines().collect(Collectors.joining("\n"));
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read classpath resource: " + OUTPUT_FORMAT_RESOURCE, e);
        }
    }

    /**
     * Writes {@code list} to {@code fileName} as UTF-8 CSV.
     * The keys of the first row define the columns and the header line. A value that is
     * {@code null} or missing in a row becomes an empty field. Fields containing a comma, a
     * double quote or a line break are quoted, with embedded quotes doubled. An empty or
     * {@code null} list produces an empty file.
     *
     * @param fileName path of the file to create or overwrite
     * @param list     rows to export
     * @throws UncheckedIOException if the file cannot be written
     */
    public static void convertToCSV(String fileName, List<? extends Map<String, ?>> list) {
        try (Writer csvWriter = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(fileName), StandardCharsets.UTF_8))) {
            if (list == null || list.isEmpty()) {
                return;
            }

            List<String> headers = new ArrayList<>(list.get(0).keySet());

            StringJoiner headerLine = new StringJoiner(",");
            for (String header : headers) {
                headerLine.add(escapeCsv(header));
            }
            csvWriter.write(headerLine.toString());
            csvWriter.write("\n");

            for (Map<String, ?> row : list) {
                StringJoiner line = new StringJoiner(",");
                for (String header : headers) {
                    line.add(escapeCsv(row.get(header)));
                }
                csvWriter.write(line.toString());
                csvWriter.write("\n");
            }
        } catch (IOException e) {
            // Propagate so the caller does not report a file that was never written.
            throw new UncheckedIOException("Failed to write CSV file: " + fileName, e);
        }
    }

    /** Renders one CSV field, quoting it when it contains a delimiter, quote or line break. */
    private static String escapeCsv(Object value) {
        if (value == null) {
            return "";
        }
        String text = String.valueOf(value);
        boolean needsQuoting = text.indexOf(',') >= 0 || text.indexOf('"') >= 0
                || text.indexOf('\n') >= 0 || text.indexOf('\r') >= 0;
        if (!needsQuoting) {
            return text;
        }
        return "\"" + text.replace("\"", "\"\"") + "\"";
    }

}
queryPrompt(QueryRequest queryRequest, String databaseDDL, Connection source) { + String query = queryRequest.getQuery(); + + String outputFormat = FileUtils.readJsonFile(); + + return "You are a system that generates and outputs " + source.getDbType() + " SQL queries.\n" + + "The following is the DDL of the database:\n\n" + + "START OF DDL\n" + databaseDDL + "\nEND OF DDL" + + "\n\nI want you to generate a SQL query based on the following description: " + + query + + " Respond only by giving me the SQL code with no other accompanying text in the following format:\n" + + outputFormat; + } + +} diff --git a/queryhelper/src/main/resources/static/output_format.json b/queryhelper/src/main/resources/static/output_format.json new file mode 100644 index 00000000..ddd22ada --- /dev/null +++ b/queryhelper/src/main/resources/static/output_format.json @@ -0,0 +1,3 @@ +{ + "query": "AI generated query" +} \ No newline at end of file diff --git a/queryhelper/src/test/java/com/adaptivescale/AIServiceTest.java b/queryhelper/src/test/java/com/adaptivescale/AIServiceTest.java new file mode 100644 index 00000000..ae2a9cd2 --- /dev/null +++ b/queryhelper/src/test/java/com/adaptivescale/AIServiceTest.java @@ -0,0 +1,22 @@ +package com.adaptivescale; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; +import queryhelper.service.AIService; + + +public class AIServiceTest { + + @Test + void testIsSelectStatementGoodCase() { + String goodQuery = "SELECT * FROM table_name;"; + assertTrue(AIService.isSelectStatement(goodQuery)); + } + + @Test + void testIsSelectStatementBadCase() { + String badQuery = "UPDATE table_name SET column_name = value WHERE condition;"; + assertFalse(AIService.isSelectStatement(badQuery)); + } + +} diff --git a/settings.gradle b/settings.gradle index 83e54a5c..f5fed1ed 100644 --- a/settings.gradle +++ b/settings.gradle @@ -10,4 +10,5 @@ if (!System.env.JITPACK) include 'binary' include 'diff' include 'test' +include 
'queryhelper'