Skip to content

Commit

Permalink
Update text2sql.py according to linter
Browse files Browse the repository at this point in the history
  • Loading branch information
SichengStevenLi authored Aug 6, 2024
1 parent bb47ef5 commit 327d3c4
Showing 1 changed file with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, insert
import argparse


def create_database_schema():
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
Expand All @@ -23,6 +24,7 @@ def create_database_schema():
metadata_obj.create_all(engine)
return engine, city_stats_table


def define_sql_database(engine, city_stats_table):
sql_database = SQLDatabase(engine, include_tables=["city_stats"])

Expand All @@ -43,21 +45,19 @@ def define_sql_database(engine, city_stats_table):

return sql_database


def main(args):
engine, city_stats_table = create_database_schema()

sql_database = define_sql_database(engine, city_stats_table)

model_id=args.embedding_model_path
model_id = args.embedding_model_path
device_map = args.device


embed_model = IpexLLMEmbedding(
model_id,
device=device_map
)
embed_model = IpexLLMEmbedding(model_id, device=device_map)

llm = IpexLLM.from_model_id(
llm = IpexLLM.from_model_id(
model_name=args.model_path,
tokenizer_name=args.model_path,
context_window=512,
Expand All @@ -70,11 +70,11 @@ def main(args):

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
sql_database,
tables=["city_stats"],
llm=llm,
embed_model=embed_model,
return_raw=True
sql_database,
tables=["city_stats"],
llm=llm,
embed_model=embed_model,
return_raw=True
)

query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever, llm=llm)
Expand All @@ -84,13 +84,13 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='LlamaIndex IpexLLM Example')
parser = argparse.ArgumentParser(description="LlamaIndex IpexLLM Example")
parser.add_argument(
'-m',
'--model-path',
"-m",
"--model-path",
type=str,
required=True,
help='the path to transformers model'
help="the path to transformers model"
)
parser.add_argument(
"--device",
Expand All @@ -101,24 +101,24 @@ def main(args):
help="The device (Intel CPU or Intel GPU) the LLM model runs on",
)
parser.add_argument(
'-q',
'--question',
"-q",
"--question",
type=str,
default='Which city has the highest population?',
help='question you want to ask.'
default="Which city has the highest population?",
help="question you want to ask."
)
parser.add_argument(
'-e',
'--embedding-model-path',
"-e",
"--embedding-model-path",
default="BAAI/bge-small-en",
help="the path to embedding model path"
)
parser.add_argument(
'-n',
'--n-predict',
"-n",
"--n-predict",
type=int,
default=32,
help='max number of predict tokens'
help="max number of predict tokens"
)
args = parser.parse_args()

Expand Down

0 comments on commit 327d3c4

Please sign in to comment.