Skip to content

Commit

Permalink
Merge branch 'main' into mixed_attention_modules
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML authored Jun 24, 2024
2 parents cc1f2f3 + fd7b187 commit 9f7b346
Showing 1 changed file with 44 additions and 12 deletions.
56 changes: 44 additions & 12 deletions scripts/data_prep/convert_delta_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.cloud_pb2 as cloud_pb2
import requests
from composer.utils import retry
from databricks import sql
from databricks.connect import DatabricksSession
from databricks.sdk import WorkspaceClient
Expand Down Expand Up @@ -347,6 +348,44 @@ def fetch_data(
)


@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
def get_total_rows(
    tablename: str,
    method: str,
    cursor: Optional[Cursor],
    sparkSession: Optional[SparkSession],
):
    """Run a ``COUNT(*)`` query against a table and return the result.

    The function is wrapped in ``retry`` configured for ``Exception`` with
    ``num_attempts=5``, so any failure in the query or in unpacking its
    result is handed to that decorator before reaching the caller.

    Args:
        tablename (str): Name of the table, interpolated directly into the
            SQL text.
        method (str): Forwarded unchanged to ``run_query``.
        cursor (Optional[Cursor]): Forwarded unchanged to ``run_query``.
        sparkSession (Optional[SparkSession]): Forwarded unchanged to
            ``run_query``.

    Returns:
        The value of the last field of the first row returned by the query
        (presumably the integer row count, since ``COUNT(*)`` should yield a
        single column -- confirm against ``run_query``).
    """
    count_query = f'SELECT COUNT(*) FROM {tablename}'
    result = run_query(count_query, method, cursor, sparkSession)

    # Convert every returned row to a dict first; only the first one is used.
    row_dicts = [row.asDict() for row in result]  # pyright: ignore
    _, nrows = row_dicts[0].popitem()

    log.info(f'total_rows = {nrows}')
    return nrows


@retry(Exception, num_attempts=5, initial_backoff=1.0, max_jitter=0.5)
def get_columns_info(
    tablename: str,
    method: str,
    cursor: Optional[Cursor],
    sparkSession: Optional[SparkSession],
):
    """Run ``SHOW COLUMNS`` on a table and summarize the column names.

    The function is wrapped in ``retry`` configured for ``Exception`` with
    ``num_attempts=5``, so any failure in the query or in unpacking its
    result is handed to that decorator before reaching the caller.

    Args:
        tablename (str): Name of the table, interpolated directly into the
            SQL text.
        method (str): Forwarded unchanged to ``run_query``.
        cursor (Optional[Cursor]): Forwarded unchanged to ``run_query``.
        sparkSession (Optional[SparkSession]): Forwarded unchanged to
            ``run_query``.

    Returns:
        A 3-tuple ``(columns, order_by, columns_str)`` where ``columns`` is
        the list of values collected from the result rows (one per row),
        ``order_by`` is the first entry of that list, and ``columns_str`` is
        the entries joined with ``','``.
    """
    show_query = f'SHOW COLUMNS IN {tablename}'
    result = run_query(show_query, method, cursor, sparkSession)

    # Each row contributes the value of its last dict field, which is
    # presumably the column name -- confirm against run_query's output.
    columns = []
    for row in result:  # pyright: ignore
        _, column_name = row.asDict().popitem()
        columns.append(column_name)

    order_by = columns[0]
    columns_str = ','.join(columns)
    log.info(f'order by column {order_by}')
    return columns, order_by, columns_str


def fetch(
method: str,
tablename: str,
Expand All @@ -368,32 +407,25 @@ def fetch(
dbsql (databricks.sql.connect): dbsql session
"""
cursor = dbsql.cursor() if dbsql is not None else None

try:
ans = run_query(
f'SELECT COUNT(*) FROM {tablename}',
nrows = get_total_rows(
tablename,
method,
cursor,
sparkSession,
)
nrows = [row.asDict() for row in ans][0].popitem()[1] # pyright: ignore
log.info(f'total_rows = {nrows}')
except Exception as e:
raise RuntimeError(
f'Error in get total rows from {tablename}. Restart sparkSession and try again',
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e

try:
ans = run_query(
f'SHOW COLUMNS IN {tablename}',
columns, order_by, columns_str = get_columns_info(
tablename,
method,
cursor,
sparkSession,
)
columns = [row.asDict().popitem()[1] for row in ans] # pyright: ignore
order_by = columns[0]
columns_str = ','.join(columns)
log.info(f'order by column {order_by}')
except Exception as e:
raise RuntimeError(
f'Error in get columns from {tablename}. Restart sparkSession and try again',
Expand Down

0 comments on commit 9f7b346

Please sign in to comment.