Skip to content

Commit

Permalink
Merge branch 'main' into cl_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu authored Jun 24, 2024
2 parents a5fa8a5 + 2267bc7 commit d745a12
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
56 changes: 44 additions & 12 deletions scripts/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2
import requests
from composer.utils import retry
from databricks import sql
from databricks.connect import DatabricksSession
from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -347,6 +348,44 @@ def fetch_data(
)


@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
def get_total_rows(
    tablename: str,
    method: str,
    cursor: Optional[Cursor],
    sparkSession: Optional[SparkSession],
):
    """Return the row count of ``tablename`` via a ``COUNT(*)`` query.

    Retried up to 5 times with backoff on any exception.

    Args:
        tablename (str): Fully qualified Delta table name.
        method (str): Query execution method understood by ``run_query``.
        cursor (Optional[Cursor]): DB-SQL cursor, if using the dbsql path.
        sparkSession (Optional[SparkSession]): Spark session, if using the
            Spark path.

    Returns:
        The total number of rows in the table.
    """
    result = run_query(
        f'SELECT COUNT(*) FROM {tablename}',
        method,
        cursor,
        sparkSession,
    )
    # The query yields a single row with a single column; popitem()
    # pulls out that lone (name, value) pair and [1] takes the count.
    rows = [entry.asDict() for entry in result]  # pyright: ignore
    total = rows[0].popitem()[1]
    log.info(f'total_rows = {total}')
    return total


@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
def get_columns_info(
    tablename: str,
    method: str,
    cursor: Optional[Cursor],
    sparkSession: Optional[SparkSession],
):
    """Fetch the column names of ``tablename`` via ``SHOW COLUMNS``.

    Retried up to 5 times with backoff on any exception.

    Args:
        tablename (str): Fully qualified Delta table name.
        method (str): Query execution method understood by ``run_query``.
        cursor (Optional[Cursor]): DB-SQL cursor, if using the dbsql path.
        sparkSession (Optional[SparkSession]): Spark session, if using the
            Spark path.

    Returns:
        A tuple of (list of column names, first column name to order by,
        comma-joined column-name string).
    """
    result = run_query(
        f'SHOW COLUMNS IN {tablename}',
        method,
        cursor,
        sparkSession,
    )
    # Each row carries one column name; popitem() extracts its value.
    names = []
    for entry in result:  # pyright: ignore
        names.append(entry.asDict().popitem()[1])
    # The first column is used downstream as the deterministic sort key.
    first_col = names[0]
    joined = ','.join(names)
    log.info(f'order by column {first_col}')
    return names, first_col, joined


def fetch(
method: str,
tablename: str,
Expand All @@ -368,32 +407,25 @@ def fetch(
dbsql (databricks.sql.connect): dbsql session
"""
cursor = dbsql.cursor() if dbsql is not None else None

try:
ans = run_query(
f'SELECT COUNT(*) FROM {tablename}',
nrows = get_total_rows(
tablename,
method,
cursor,
sparkSession,
)
nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore
log.info(f'total_rows = {nrows}')
except Exception as e:
raise RuntimeError(
f'Error in get total rows from {tablename}. Restart sparkSession and try again',
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e

try:
ans = run_query(
f'SHOW COLUMNS IN {tablename}',
columns, order_by, columns_str = get_columns_info(
tablename,
method,
cursor,
sparkSession,
)
columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore
order_by = columns[0]
columns_str = ','.join(columns)
log.info(f'order by column {order_by}')
except Exception as e:
raise RuntimeError(
f'Error in get columns from {tablename}. Restart sparkSession and try again',
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

install_requires = [
'mosaicml[libcloud,wandb,oci,gcs,mlflow]>=0.23.4,<0.24',
'mlflow>=2.13.2,<2.14',
'mlflow>=2.14.1,<2.15',
'accelerate>=0.25,<0.26', # for HF inference `device_map`
'transformers>=4.40,<4.41',
'mosaicml-streaming>=0.7.6,<0.8',
Expand Down

0 comments on commit d745a12

Please sign in to comment.