Skip to content

Commit

Permalink
Feature/query command (#205)
Browse files Browse the repository at this point in the history
added: AI integration to query the data using natural language through query command
  • Loading branch information
Femi3211 authored Apr 11, 2024
1 parent 14b2711 commit 38a67b4
Show file tree
Hide file tree
Showing 16 changed files with 596 additions and 8 deletions.
60 changes: 60 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,66 @@ Parameter | Description
--pyspark | Generates the Spark SQL file.
--scala | Generates the Scala SQL file.

#### query
The query command allows you to use natural language commands to query your databases, transforming these commands into SQL SELECT statements. By leveraging the capabilities of AI and LLMs, specifically OpenAI models, it interprets user queries and generates the corresponding SQL queries. For effective use of this command, users need to provide their OpenAI API Key and specify the OpenAI model to be utilized. The output will be written to a CSV file. The max number of rows that will be returned is 200. You can overwrite this value, or remove completely the limit.The default openai model that is used is gpt-3.5-turbo.

rosetta [-c, --config CONFIG_FILE] query [-h, --help] [-s, --source CONNECTION_NAME] [-q, --query "Natural language QUERY"]

Parameter | Description
--- | ---
-h, --help | Show the help message and exit.
-c, --config CONFIG_FILE | YAML config file. If none is supplied it will use main.conf in the current directory if it exists.
-s, --source CONNECTION_NAME | The source connection is used to specify which models and connection to use.
-q --query "Natural language QUERY" | pecifies the natural language query to be transformed into an SQL SELECT statement.
-l --limit Response Row limit (Optional) | Limits the number of rows in the generated CSV file. If not specified, the default limit is set to 200 rows.
--no-limit (Optional) | Specifies that there should be no limit on the number of rows in the generated CSV file.


**Example** (Setting the key and model) :

(Config file)
```
openai_api_key: "sk-abcdefghijklmno1234567890"
openai_model: "gpt-4"
connections:
- name: mysql
databaseName: sakila
schemaName:
dbType: mysql
url: jdbc:mysql://root:sakila@localhost:3306/sakila
userName: root
password: sakila
- name: pg
databaseName: postgres
schemaName: public
dbType: postgres
url: jdbc:postgresql://localhost:5432/postgres?user=postgres&password=sakila
userName: postgres
password: sakila
```
***Example*** (Query)
```
rosetta query -s mysql -q "Show me the top 10 customers by revenue."
```
***CSV Output Example***
```CSV
customer_name,total_revenue,location,email
John Doe,50000,New York,[email protected]
Jane Smith,45000,Los Angeles,[email protected]
David Johnson,40000,Chicago,[email protected]
Emily Brown,35000,San Francisco,[email protected]
Michael Lee,30000,Miami,[email protected]
Sarah Taylor,25000,Seattle,[email protected]
Robert Clark,20000,Boston,[email protected]
Lisa Martinez,15000,Denver,[email protected]
Christopher Anderson,10000,Austin,[email protected]
Amanda Wilson,5000,Atlanta,[email protected]
```
**Note:** When giving a request that will not generate a SELECT statement the query will be generated but will not be executed rather be given to the user to execute on their own.



### Safety Operation
In `model.yaml` you can find the attribute `safeMode` which is by default disabled (false). If you want to prevent any DROP operation during
Expand Down
4 changes: 3 additions & 1 deletion cli/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies {
implementation project(':ddl')
implementation project(':diff')
implementation project(':test')
implementation project(':queryhelper')

implementation group: 'info.picocli', name: 'picocli', version: '4.6.3'
implementation group: 'org.slf4j', name: 'slf4j-simple', version: '2.0.5'
Expand All @@ -24,7 +25,8 @@ dependencies {
implementation group: 'org.reflections', name: 'reflections', version: '0.10.2'
implementation group: 'org.thymeleaf', name: 'thymeleaf', version: '3.1.0.RELEASE'
implementation group: 'com.h2database', name: 'h2', version: '2.1.214'

implementation 'dev.langchain4j:langchain4j-open-ai:0.25.0'
implementation 'com.github.jsqlparser:jsqlparser:4.9'
compileOnly 'org.projectlombok:lombok:1.18.12'
annotationProcessor 'org.projectlombok:lombok:1.18.12'

Expand Down
56 changes: 49 additions & 7 deletions cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import picocli.CommandLine;
import queryhelper.pojo.GenericResponse;
import queryhelper.pojo.QueryRequest;
import queryhelper.service.AIService;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -201,7 +204,7 @@ private void apply(@CommandLine.Option(names = {"-s", "--source"}, required = tr
}

if (changes.stream().filter(change -> change.getStatus().equals(Change.Status.DROP)).findFirst().isPresent() &&
expectedDatabase.getSafeMode()) {
expectedDatabase.getSafeMode()) {
log.info("Not going to perform the changes because there are DROP operations and the safe mode is enabled.");
return;
}
Expand Down Expand Up @@ -252,8 +255,8 @@ private void test(@CommandLine.Option(names = {"-s", "--source"}) String sourceN
}

List<Database> collect = getDatabases(sourceWorkspace)
.map(AbstractMap.SimpleImmutableEntry::getValue)
.collect(Collectors.toList());
.map(AbstractMap.SimpleImmutableEntry::getValue)
.collect(Collectors.toList());
for (Database database : collect) {
AssertionSqlGenerator assertionSqlGenerator = AssertionSqlGeneratorFactory.generatorFor(source.get());
DefaultSqlExecution defaultSqlExecution = new DefaultSqlExecution(source.get(), new DriverManagerDriverProvider());
Expand All @@ -263,7 +266,7 @@ private void test(@CommandLine.Option(names = {"-s", "--source"}) String sourceN

@CommandLine.Command(name = "init", description = "Creates a sample config (main.conf) and model directory.", mixinStandardHelpOptions = true)
private void init(@CommandLine.Parameters(index = "0", description = "Project name.", defaultValue = "")
String projectName) throws IOException {
String projectName) throws IOException {
Path fileName = Paths.get(projectName, CONFIG_NAME);
InputStream resourceAsStream = getClass().getResourceAsStream("/" + TEMPLATE_CONFIG_NAME);
Path projectDirectory = Path.of(projectName);
Expand Down Expand Up @@ -297,8 +300,8 @@ private void dbt(@CommandLine.Option(names = {"-s", "--source"}, required = true

@CommandLine.Command(name = "generate", description = "Generate code", mixinStandardHelpOptions = true)
private void generate(@CommandLine.Option(names = {"-s", "--source"}, required = true) String sourceName,
@CommandLine.Option(names = {"-t", "--target"}, required = true) String targetName,
@CommandLine.Option(names = {"--pyspark"}) boolean generateSpark,
@CommandLine.Option(names = {"-t", "--target"}, required = true) String targetName,
@CommandLine.Option(names = {"--pyspark"}) boolean generateSpark,
@CommandLine.Option(names = {"--scala"}) boolean generateScala
) throws Exception {
requireConfig(config);
Expand Down Expand Up @@ -375,7 +378,7 @@ private void extractDbtModels(Connection connection, Path sourceWorkspace) throw

@CommandLine.Command(name = "diff", description = "Show difference between local model and database", mixinStandardHelpOptions = true)
private void diff(@CommandLine.Option(names = {"-s", "--source"}) String sourceName,
@CommandLine.Option(names = {"-m", "--model"}, defaultValue=DEFAULT_MODEL_YAML) String model) throws Exception {
@CommandLine.Option(names = {"-m", "--model"}, defaultValue = DEFAULT_MODEL_YAML) String model) throws Exception {
requireConfig(config);
Connection sourceConnection = getSourceConnection(sourceName);

Expand Down Expand Up @@ -517,4 +520,43 @@ public FileNameAndDatabasePair(String key, Database value) {
super(key, value);
}
}


@CommandLine.Command(name = "query", description = "Query schema", mixinStandardHelpOptions = true)
private void query(@CommandLine.Option(names = {"-s", "--source"}, required = true) String sourceName,
@CommandLine.Option(names = {"-q", "--query"}, required = true) String userQueryRequest,
@CommandLine.Option(names = {"-l", "--limit"}, required = false, defaultValue = "200") Integer showRowLimit,
@CommandLine.Option(names = {"--no-limit"}, required = false, defaultValue = "false") Boolean noRowLimit
)
throws Exception {
requireConfig(config);

if (config.getOpenAIApiKey() == null) {
log.info("Open AI API key has to be provided in the config file");
return;
}

Connection source = getSourceConnection(sourceName);

Path sourceWorkspace = Paths.get("./", sourceName);

if (!Files.exists(sourceWorkspace)) {
Files.createDirectories(sourceWorkspace);
}

Path dataDirectory = sourceWorkspace.resolve("data");
if (!Files.exists(dataDirectory)) {
Files.createDirectories(dataDirectory);
}

Database db = SourceGeneratorFactory.sourceGenerator(source).generate(source);

DDL modelDDL = DDLFactory.ddlForDatabaseType(source.getDbType());
String DDL = modelDDL.createDatabase(db, false);

// If `noRowLimit` is true, set the row limit to 0 (no limit), otherwise use the value of `showRowLimit`
GenericResponse response = AIService.generateQuery(userQueryRequest, config.getOpenAIApiKey(), config.getOpenAIModel(), DDL, source, noRowLimit ? 0 : showRowLimit, dataDirectory);
log.info(response.getMessage());
}

}
12 changes: 12 additions & 0 deletions cli/src/main/java/com/adaptivescale/rosetta/cli/model/Config.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package com.adaptivescale.rosetta.cli.model;

import com.adaptivescale.rosetta.common.models.input.Connection;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.List;
import java.util.Optional;

public class Config {
private List<Connection> connections;

@JsonProperty("openai_api_key")
private String openAIApiKey;

@JsonProperty("openai_model")
private String openAIModel;

public List<Connection> getConnections() {
return connections;
}
Expand All @@ -19,4 +26,9 @@ public void setConnections(List<Connection> connection) {
public Optional<Connection> getConnection(String name) {
return connections.stream().filter(target -> target.getName().equals(name)).findFirst();
}

public String getOpenAIApiKey() {
return openAIApiKey;
}
public String getOpenAIModel() { return openAIModel; }
}
26 changes: 26 additions & 0 deletions queryhelper/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
plugins {
id 'java'
}

group = 'com.adaptivescale'
version = '2.1.3'

repositories {
mavenCentral()
}

dependencies {
implementation project(':common')
implementation project(':source')

testImplementation platform('org.junit:junit-bom:5.9.1')
testImplementation 'org.junit.jupiter:junit-jupiter'

implementation 'dev.langchain4j:langchain4j-open-ai:0.25.0'
implementation 'com.google.code.gson:gson:2.10.1'
implementation 'com.github.jsqlparser:jsqlparser:4.9'
}

test {
useJUnitPlatform()
}
38 changes: 38 additions & 0 deletions queryhelper/src/main/java/queryhelper/pojo/GenericResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package queryhelper.pojo;

public class GenericResponse {
private Object data;
private String message;
private Integer statusCode;

public GenericResponse() {
this.data = null;
this.message = null;
this.statusCode = null;
}

public Object getData() {
return data;
}

public void setData(Object data) {
this.data = data;
}

public String getMessage() {
return message;
}

public void setMessage(String message) {
this.message = message;
}

public Integer getStatusCode() {
return statusCode;
}

public void setStatusCode(Integer statusCode) {
this.statusCode = statusCode;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package queryhelper.pojo;

public class OpenAIHttpExceptionErrorResponse {
private ErrorDetails error;

public OpenAIHttpExceptionErrorResponse() {
this.error = null;
}

public ErrorDetails getError() {
return error;
}

public void setError(ErrorDetails error) {
this.error = error;
}

public static class ErrorDetails {
private String message;
private String type;
private String param;
private String code;

public ErrorDetails() {
this.message = null;
this.type = null;
this.param = null;
this.code = null;
}

public String getMessage() {
return message;
}

public void setMessage(String message) {
this.message = message;
}

public String getType() {
return type;
}

public void setType(String type) {
this.type = type;
}

public String getParam() {
return param;
}

public void setParam(String param) {
this.param = param;
}

public String getCode() {
return code;
}

public void setCode(String code) {
this.code = code;
}
}
}
Loading

0 comments on commit 38a67b4

Please sign in to comment.