Skip to content

Commit

Permalink
fixed Flake issues
Browse files Browse the repository at this point in the history
  • Loading branch information
MinuraPunchihewa committed May 30, 2024
1 parent 63adf72 commit fd8ac84
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 48 deletions.
1 change: 1 addition & 0 deletions mindsdb/integrations/handlers/openai_handler/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class PendingFT(openai.OpenAIError):
Custom exception to handle pending fine-tuning status.
"""
message: str

def __init__(self, message) -> None:
super().__init__()
self.message = message
Expand Down
65 changes: 33 additions & 32 deletions mindsdb/integrations/handlers/openai_handler/openai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, *args, **kwargs):
self.max_batch_size = 20
self.default_max_tokens = 100
self.chat_completion_models = CHAT_MODELS
self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning
self.supported_ft_models = FINETUNING_MODELS # base models compatible with finetuning
# For now these are only used by handlers that inherit from OpenAIHandler and don't need to override base methods
self.api_key_name = getattr(self, 'api_key_name', self.name)
self.api_base = getattr(self, 'api_base', OPENAI_API_BASE)
Expand Down Expand Up @@ -265,11 +265,11 @@ def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame
args = self.model_storage.json_get('args')
connection_args = self.engine_storage.get_connection_args()

args['api_base'] = (pred_args.get('api_base') or
self.api_base or
connection_args.get('api_base') or
args.get('api_base') or
os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
args['api_base'] = (pred_args.get('api_base')
or self.api_base
or connection_args.get('api_base')
or args.get('api_base')
or os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
if pred_args.get('api_organization'):
args['api_organization'] = pred_args['api_organization']
df = df.reset_index(drop=True)
Expand Down Expand Up @@ -525,7 +525,7 @@ def _submit_completion(model_name: Text, prompts: List[Text], api_args: Dict, ar
df (pd.DataFrame): Input data to run completion on.
Returns:
List[Text]: List of completions.
List[Text]: List of completions.
"""
kwargs = {
'model': model_name,
Expand Down Expand Up @@ -554,7 +554,7 @@ def _log_api_call(params: Dict, response: Any) -> None:
response (Any): Response from the API.
Returns:
None
None
"""
after_openai_query(params, response)

Expand Down Expand Up @@ -587,11 +587,11 @@ def _tidy(comp: openai.types.completion.Completion) -> List[Text]:
comp (openai.types.completion.Completion): Completion object.
Returns:
List[Text]: List of completions as text.
List[Text]: List of completions as text.
"""
tidy_comps = []
for c in comp.choices:
if hasattr(c,'text'):
if hasattr(c, 'text'):
tidy_comps.append(c.text.strip('\n').strip(''))
return tidy_comps

Expand Down Expand Up @@ -625,13 +625,13 @@ def _tidy(comp: openai.types.create_embedding_response.CreateEmbeddingResponse)
Args:
comp (openai.types.create_embedding_response.CreateEmbeddingResponse): Embedding object.
Returns:
List[float]: List of embeddings as numbers.
"""
tidy_comps = []
for c in comp.data:
if hasattr(c,'embedding'):
if hasattr(c, 'embedding'):
tidy_comps.append([c.embedding])
return tidy_comps

Expand Down Expand Up @@ -673,7 +673,7 @@ def _tidy(comp: openai.types.chat.chat_completion.ChatCompletion) -> List[Text]:
"""
tidy_comps = []
for c in comp.choices:
if hasattr(c,'message'):
if hasattr(c, 'message'):
tidy_comps.append(c.message.content.strip('\n').strip(''))
return tidy_comps

Expand Down Expand Up @@ -775,7 +775,7 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
List[Text]: List of image completions as URLs or base64 encoded images.
"""
return [
c.url if hasattr(c,'url') else c.b64_json
c.url if hasattr(c, 'url') else c.b64_json
for c in comp
]

Expand All @@ -784,13 +784,13 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
for p in prompts
]
return _tidy(completions)

client = self._get_client(
api_key=api_key,
base_url=args.get('api_base'),
org=args.pop('api_organization') if 'api_organization' in args else None,
)
)

try:
# check if simple completion works
completion = _submit_completion(
Expand All @@ -801,7 +801,7 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
# else, we get the max batch size
if 'you can currently request up to at most a total of' in str(e):
pattern = 'a total of'
max_batch_size = int(e[e.find(pattern) + len(pattern) :].split(').')[0])
max_batch_size = int(e[e.find(pattern) + len(pattern):].split(').')[0])
else:
max_batch_size = (
self.max_batch_size
Expand All @@ -812,7 +812,7 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
for i in range(math.ceil(len(prompts) / max_batch_size)):
partial = _submit_completion(
model_name,
prompts[i * max_batch_size : (i + 1) * max_batch_size],
prompts[i * max_batch_size: (i + 1) * max_batch_size],
api_args,
args,
df,
Expand All @@ -833,7 +833,7 @@ def _tidy(comp: List[openai.types.image.Image]) -> List[Text]:
future = executor.submit(
_submit_completion,
model_name,
prompts[i * max_batch_size : (i + 1) * max_batch_size],
prompts[i * max_batch_size: (i + 1) * max_batch_size],
api_args,
args,
df,
Expand All @@ -856,7 +856,7 @@ def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
attribute (Optional[Text]): Attribute to describe. Can be 'args' or 'metadata'.
Returns:
pd.DataFrame: Model metadata or model arguments.
pd.DataFrame: Model metadata or model arguments.
"""
# TODO: Update to use update() artifacts

Expand All @@ -867,7 +867,7 @@ def describe(self, attribute: Optional[Text] = None) -> pd.DataFrame:
elif attribute == 'metadata':
model_name = args.get('model_name', self.default_model)
try:
client= self._get_client(
client = self._get_client(
api_key=api_key,
base_url=args.get('api_base'),
org=args.get('api_organization')
Expand Down Expand Up @@ -911,12 +911,11 @@ def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = Non
api_key = get_api_key(self.api_key_name, args, self.engine_storage)

using_args = args.pop('using') if 'using' in args else {}

api_base = using_args.get('api_base', os.environ.get('OPENAI_API_BASE', OPENAI_API_BASE))
org = using_args.get('api_organization')
client = self._get_client(api_key=api_key, base_url=api_base, org=org)


args = {**using_args, **args}
prev_model_name = self.base_model_storage.json_get('args').get('model_name', '')

Expand All @@ -943,8 +942,10 @@ def finetune(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = Non
jsons = {k: None for k in file_names.keys()}
for split, file_name in file_names.items():
if os.path.isfile(os.path.join(temp_storage_path, file_name)):
jsons[split] = client.files.create(file=open(f"{temp_storage_path}/{file_name}", "rb"),
purpose='fine-tune')
jsons[split] = client.files.create(
file=open(f"{temp_storage_path}/{file_name}", "rb"),
purpose='fine-tune'
)

if type(jsons['train']) is openai.types.FileObject:
train_file_id = jsons['train'].id
Expand Down Expand Up @@ -1015,7 +1016,7 @@ def _prepare_ft_jsonl(df: pd.DataFrame, _, temp_filename: Text, temp_model_path:
temp_model_path (Text): Temporary model path.
Returns:
Dict: File names for the fine-tuning process.
Dict: File names for the fine-tuning process.
"""
df.to_json(temp_model_path, orient='records', lines=True)

Expand Down Expand Up @@ -1050,7 +1051,7 @@ def _get_ft_model_type(self, model_name: Text) -> Text:
model_name (Text): Model name.
Returns:
Text: Model to use for fine-tuning.
Text: Model to use for fine-tuning.
"""
for model_type in self.supported_ft_models:
if model_type in model_name.lower():
Expand All @@ -1067,7 +1068,7 @@ def _add_extra_ft_params(ft_params: Dict, using_args: Dict) -> Dict:
using_args (Dict): Parameters passed when calling the fine-tuning process via a model.
Returns:
Dict: Fine-tuning parameters with extra parameters.
Dict: Fine-tuning parameters with extra parameters.
"""
extra_params = {
'n_epochs': using_args.get('n_epochs', None),
Expand Down Expand Up @@ -1127,7 +1128,7 @@ def _check_ft_status(job_id: Text) -> openai.types.fine_tuning.FineTuningJob:
PendingFT: If the fine-tuning process is still pending.
Returns:
openai.types.fine_tuning.FineTuningJob: Fine-tuning stats.
openai.types.fine_tuning.FineTuningJob: Fine-tuning stats.
"""
ft_retrieved = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)
if ft_retrieved.status in ('succeeded', 'failed', 'cancelled'):
Expand All @@ -1149,7 +1150,7 @@ def _check_ft_status(job_id: Text) -> openai.types.fine_tuning.FineTuningJob:
result_file_id = result_file_id.id # legacy endpoint

return ft_stats, result_file_id

@staticmethod
def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> OpenAI:
"""
Expand All @@ -1161,6 +1162,6 @@ def _get_client(api_key: Text, base_url: Text, org: Optional[Text] = None) -> Op
org (Optional[Text]): OpenAI organization.
Returns:
openai.OpenAI: OpenAI client.
openai.OpenAI: OpenAI client.
"""
return OpenAI(api_key=api_key, base_url=base_url, organization=org)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_create_model_with_unsupported_model_raises_exception(self):
Test if CREATE MODEL raises an exception with an unsupported model.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openaai_unsupported_model_model
PREDICT answer
USING
Expand All @@ -51,7 +51,7 @@ def test_full_flow_in_default_mode_with_question_column_for_single_prediction_ru
Test the full flow in default mode with a question column for a single prediction.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_single_full_flow_default_mode_question_column
PREDICT answer
USING
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_full_flow_in_default_mode_with_question_column_for_bulk_predictions_run
self.set_handler(mock_handler, name="pg", tables={"df": df})

self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_bulk_full_flow_default_mode_question_column
PREDICT answer
USING
Expand All @@ -109,7 +109,7 @@ def test_full_flow_in_default_mode_with_prompt_template_for_single_prediction_ru
Test the full flow in default mode with a prompt template for a single prediction.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_single_full_flow_default_mode_prompt_template
PREDICT answer
USING
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_full_flow_in_default_mode_with_prompt_template_for_bulk_predictions_run
self.set_handler(mock_handler, name="pg", tables={"df": df})

self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_bulk_full_flow_default_mode_prompt_template
PREDICT answer
USING
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_full_flow_in_embedding_mode_for_single_prediction_runs_no_errors(self):
Test the full flow in embedding mode for a single prediction.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_single_full_flow_embedding_mode
PREDICT answer
USING
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_full_flow_in_embedding_mode_for_bulk_predictions_runs_no_errors(self, m
self.set_handler(mock_handler, name="pg", tables={"df": df})

self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_bulk_full_flow_embedding_mode
PREDICT answer
USING
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_full_flow_in_image_mode_for_single_prediction_runs_no_errors(self):
Test the full flow in image mode for a single prediction.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_single_full_flow_image_mode
PREDICT answer
USING
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_full_flow_in_image_mode_for_bulk_predictions_runs_no_errors(self, mock_
self.set_handler(mock_handler, name="pg", tables={"df": df})

self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_bulk_full_flow_image_mode
PREDICT answer
USING
Expand All @@ -295,7 +295,7 @@ def test_full_flow_in_conversational_for_single_prediction_mode_runs_no_errors(s
Test the full flow in conversational mode for a single prediction.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_single_full_flow_conversational_mode
PREDICT answer
USING
Expand Down Expand Up @@ -330,7 +330,7 @@ def test_full_flow_in_conversational_mode_for_bulk_predictions_runs_no_errors(se
self.set_handler(mock_handler, name="pg", tables={"df": df})

self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_bulk_full_flow_conversational_mode
PREDICT answer
USING
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_full_flow_in_conversational_full_mode_for_single_prediction_runs_no_err
Test the full flow in conversational-full mode for a single prediction.
"""
self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_single_full_flow_conversational_full_mode
PREDICT answer
USING
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_full_flow_in_conversational_full_mode_for_bulk_predictions_runs_no_erro
self.set_handler(mock_handler, name="pg", tables={"df": df})

self.run_sql(
f"""
"""
CREATE MODEL proj.test_openai_bulk_full_flow_conversational_full_mode
PREDICT answer
USING
Expand Down Expand Up @@ -454,7 +454,7 @@ def test_full_flow_in_conversational_full_mode_for_bulk_predictions_runs_no_erro
# };
# """,
# """
# CREATE MODEL
# CREATE MODEL
# mindsdb.home_rentals_model
# FROM example_db
# (SELECT * FROM demo_data.home_rentals)
Expand All @@ -466,7 +466,7 @@ def test_full_flow_in_conversational_full_mode_for_bulk_predictions_runs_no_erro
# JOIN project_name.model_name [AS] p;
# """
# ]

# }
# )
# self.set_handler(mock_handler, name="pg", tables={"df": df})
Expand Down Expand Up @@ -506,4 +506,4 @@ def test_full_flow_in_conversational_full_mode_for_bulk_predictions_runs_no_erro


if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__])

0 comments on commit fd8ac84

Please sign in to comment.