Commit d8af015: test central erroring
v-chen_data committed Sep 27, 2024 (1 parent: ae1810a)

Showing 1 changed file with 73 additions and 91 deletions:
llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -217,6 +217,7 @@ def run_query(
     collect: bool = True,
 ) -> Optional[Union[list['Row'], 'DataFrame', 'SparkDataFrame']]:
     """Run SQL query via databricks-connect or databricks-sql.
+
     Args:
         query (str): sql query
         method (str): select from dbsql and dbconnect
@@ -227,43 +228,36 @@ def run_query(
     if method == 'dbsql':
         if cursor is None:
             raise ValueError(f'cursor cannot be None if using method dbsql')
-        try:
-            cursor.execute(query)
-            if collect:
-                return cursor.fetchall()
-        except Exception as e:
-            from databricks.sql.exc import ServerOperationError
-            if isinstance(e, ServerOperationError):
-                if 'INSUFFICIENT_PERMISSIONS' in str(e):
-                    match = re.search(r"'([^']+)'", str(e))
-                    if match:
-                        table_name = match.group(1)
-                        action = f'accessing table {table_name}'
-                    else:
-                        action = 'accessing table'
-                    raise InsufficientPermissionsError(action=action) from e
-            raise
+        cursor.execute(query)
+        if collect:
+            return cursor.fetchall()
     elif method == 'dbconnect':
         if spark == None:
             raise ValueError(f'sparkSession is required for dbconnect')
 
         try:
             df = spark.sql(query)
-            if collect:
-                return df.collect()
-            return df
         except Exception as e:
             from pyspark.errors import AnalysisException
             if isinstance(e, AnalysisException):
-                if 'INSUFFICIENT_PERMISSIONS' in str(e):
-                    match = re.search(r"Table '([^']+)'", str(e))
+                if 'INSUFFICIENT_PERMISSIONS' in e.message:  # pyright: ignore
+                    match = re.search(
+                        r"Schema\s+'([^']+)'",
+                        e.message,  # pyright: ignore
+                    )
                     if match:
-                        table_name = match.group(1)
-                        action = f'accessing table {table_name}'
+                        schema_name = match.group(1)
+                        action = f'using the schema {schema_name}'
                     else:
-                        action = 'accessing table'
-                    raise InsufficientPermissionsError(action=action) from e
-            raise
+                        action = 'using the schema'
+                    raise InsufficientPermissionsError(action=action,) from e
+            raise RuntimeError(
+                f'Error in querying into schema. Restart sparkSession and try again',
+            ) from e
+
+        if collect:
+            return df.collect()
+        return df
     else:
         raise ValueError(f'Unrecognized method: {method}')
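
Net effect of this hunk: the dbsql branch no longer does its own error handling, and the dbconnect branch now pulls the schema name (rather than a table name) out of an INSUFFICIENT_PERMISSIONS failure, wrapping every other Spark error in a RuntimeError. A minimal, runnable sketch of the new matching logic, using a hypothetical message (real AnalysisException text comes from Databricks and may be worded differently):

import re

# Hypothetical error text; only the INSUFFICIENT_PERMISSIONS marker and the
# quoted name after the word Schema matter to the matching logic.
msg = (
    "[INSUFFICIENT_PERMISSIONS] User does not have USE SCHEMA "
    "on Schema 'main.my_schema'."
)

if 'INSUFFICIENT_PERMISSIONS' in msg:
    match = re.search(r"Schema\s+'([^']+)'", msg)
    if match:
        action = f'using the schema {match.group(1)}'
    else:
        action = 'using the schema'
    print(action)  # prints: using the schema main.my_schema

If the message names no schema, the regex misses and the generic 'using the schema' action is reported instead.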

@@ -475,79 +469,67 @@ def fetch(
     """
     cursor = dbsql.cursor() if dbsql is not None else None
     try:
-        nrows = get_total_rows(
-            tablename,
-            method,
-            cursor,
-            sparkSession,
-        )
-    except Exception as e:
-        from databricks.sql.exc import ServerOperationError
-        from pyspark.errors import AnalysisException
-
-        if isinstance(
-            e,
-            AnalysisException,
-        ) or isinstance(e, ServerOperationError):
-            raise InsufficientPermissionsError(
-                action=f'reading from {tablename}',
-            ) from e
-        raise RuntimeError(
-            f'Error in get rows from {tablename}. Restart sparkSession and try again',
-        ) from e
-
-    try:
-        columns, order_by, columns_str = get_columns_info(
-            tablename,
-            method,
-            cursor,
-            sparkSession,
-        )
-    except Exception as e:
-        raise RuntimeError(
-            f'Error in get columns from {tablename}. Restart sparkSession and try again',
-        ) from e
-
-    if method == 'dbconnect' and sparkSession is not None:
-        log.info(f'{processes=}')
-        df = sparkSession.table(tablename)
-
-        # Running the query and collecting the data as arrow or json.
-        signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
-        log.info(f'len(signed) = {len(signed)}')
-
-        args = get_args(signed, json_output_folder, columns)
-
-        # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
-        sparkSession.stop()
-
-        with ProcessPoolExecutor(max_workers=processes) as executor:
-            list(executor.map(download_starargs, args))
-
-    elif method == 'dbsql' and cursor is not None:
-        for start in range(0, nrows, batch_size):
-            log.warning(f'batch {start}')
-            end = min(start + batch_size, nrows)
-            fetch_data(
-                method,
-                cursor,
-                sparkSession,
-                start,
-                end,
-                order_by,
-                tablename,
-                columns_str,
-                json_output_folder,
-            )
-
-    if cursor is not None:
-        cursor.close()
+        # Get total rows
+        nrows = get_total_rows(tablename, method, cursor, sparkSession)
+
+        # Get columns info
+        columns, order_by, columns_str = get_columns_info(tablename, method, cursor, sparkSession)
+
+        if method == 'dbconnect' and sparkSession is not None:
+            log.info(f'{processes=}')
+            df = sparkSession.table(tablename)
+
+            # Running the query and collecting the data as arrow or json.
+            signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
+            log.info(f'len(signed) = {len(signed)}')
+
+            args = get_args(signed, json_output_folder, columns)
+
+            # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
+            sparkSession.stop()
+
+            with ProcessPoolExecutor(max_workers=processes) as executor:
+                list(executor.map(download_starargs, args))
+
+        elif method == 'dbsql' and cursor is not None:
+            for start in range(0, nrows, batch_size):
+                log.warning(f'batch {start}')
+                end = min(start + batch_size, nrows)
+                fetch_data(
+                    method,
+                    cursor,
+                    sparkSession,
+                    start,
+                    end,
+                    order_by,
+                    tablename,
+                    columns_str,
+                    json_output_folder,
+                )
+
+    except Exception as e:
+        from pyspark.errors import AnalysisException
+        from databricks.sql.exc import ServerOperationError
+
+        if isinstance(e, (AnalysisException, ServerOperationError)):
+            if 'INSUFFICIENT_PERMISSIONS' in str(e):
+                match = re.search(r"(?:Table|Schema)\s+'([^']+)'", str(e))
+                if match:
+                    object_name = match.group(1)
+                    action = f'accessing {object_name}'
+                else:
+                    action = f'accessing {tablename}'
+                raise InsufficientPermissionsError(action=action) from e
+
+        if isinstance(e, InsufficientPermissionsError):
+            raise
+
+        # For any other exception, raise a general error
+        raise RuntimeError(f"Error processing {tablename}: {str(e)}") from e
+
+    finally:
+        if cursor is not None:
+            cursor.close()
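
This hunk is the "central erroring" of the commit title: the per-step try/except blocks in fetch are collapsed into one try around the whole body, one except that classifies the failure, and a finally that closes the cursor on every path. A stripped-down, self-contained sketch of that shape; InsufficientPermissionsError below is a stand-in for the class defined in llmfoundry, and the ValueError stands in for whatever the body actually raises:

import re


class InsufficientPermissionsError(Exception):
    """Stand-in for llmfoundry's exception class."""

    def __init__(self, action: str) -> None:
        super().__init__(f'Insufficient permissions when {action}.')


def fetch_like(tablename: str) -> None:
    cursor = None  # would be dbsql.cursor() in the real function
    try:
        # Placeholder for the real body: get_total_rows, get_columns_info,
        # and the dbconnect/dbsql download paths all live inside this try.
        raise ValueError(
            f"[INSUFFICIENT_PERMISSIONS] cannot read Table '{tablename}'",
        )
    except Exception as e:
        # Central classifier: permission failures become a typed error that
        # names the table or schema pulled out of the message.
        if 'INSUFFICIENT_PERMISSIONS' in str(e):
            match = re.search(r"(?:Table|Schema)\s+'([^']+)'", str(e))
            action = (
                f'accessing {match.group(1)}'
                if match else f'accessing {tablename}'
            )
            raise InsufficientPermissionsError(action=action) from e
        # Anything else becomes a generic RuntimeError.
        raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e
    finally:
        # Cleanup now runs on success and failure alike.
        if cursor is not None:
            cursor.close()


# fetch_like('main.default.t') raises InsufficientPermissionsError:
#   Insufficient permissions when accessing main.default.t.

One consequence of the combined (?:Table|Schema) pattern is that a single handler covers both the table-level dbsql errors and the schema-level dbconnect errors that run_query used to parse separately.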


@@ -814,4 +796,4 @@ def convert_delta_to_json_from_args(
         DATABRICKS_HOST=DATABRICKS_HOST,
         DATABRICKS_TOKEN=DATABRICKS_TOKEN,
     )
-    log.info(f'Elapsed time {time.time() - tik}')
+    log.info(f'Elapsed time {time.time() - tik}')
