Delta to JSONL conversion script cleanup and bug fix (#868)
* Small test change

* small cleanups

* lint and precommit

* lint and precommit

* comments

* another one

* pr suggestion and use input param not args
nancyhung authored Jan 13, 2024
1 parent d05c099 commit b69318e
Showing 1 changed file with 70 additions and 40 deletions.
110 changes: 70 additions & 40 deletions scripts/data_prep/convert_delta_to_json.py
@@ -33,8 +33,8 @@
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row
 
-MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0'
-MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0'
+MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
+MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'
 
 log = logging.getLogger(__name__)
 
@@ -377,64 +377,61 @@ def fetch(
     cursor.close()
 
 
-def fetch_DT(args: Namespace) -> None:
-    """Fetch UC Delta Table to local as jsonl."""
-    log.info(f'Start .... Convert delta to json')
-
-    obj = urllib.parse.urlparse(args.json_output_folder)
-    if obj.scheme != '':
-        raise ValueError(
-            f'Check the json_output_folder and verify it is a local path!')
-
-    if os.path.exists(args.json_output_folder):
-        if not os.path.isdir(args.json_output_folder) or os.listdir(
-                args.json_output_folder):
-            raise RuntimeError(
-                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
-            )
-
-    os.makedirs(args.json_output_folder, exist_ok=True)
-
-    if not args.json_output_filename.endswith('.jsonl'):
-        raise ValueError('json_output_filename needs to be a jsonl file')
-
-    log.info(f'Directory {args.json_output_folder} created.')
-
+def validate_and_get_cluster_info(cluster_id: str,
+                                  databricks_host: str,
+                                  databricks_token: str,
+                                  http_path: Optional[str],
+                                  use_serverless: bool = False) -> tuple:
+    """Validate and get cluster info for running the Delta to JSONL conversion.
+
+    Args:
+        cluster_id (str): cluster id to validate and fetch additional info for
+        databricks_host (str): databricks host name
+        databricks_token (str): databricks auth token
+        http_path (Optional[str]): http path to use for sql connect
+        use_serverless (bool): whether to use serverless or not
+    """
     method = 'dbsql'
     dbsql = None
     sparkSession = None
 
-    if args.use_serverless:
+    if use_serverless:
         method = 'dbconnect'
     else:
         w = WorkspaceClient()
-        res = w.clusters.get(cluster_id=args.cluster_id)
-        runtime_version = res.spark_version.split('-scala')[0].replace(
-            'x-snapshot', '0').replace('x', '0')
+        res = w.clusters.get(cluster_id=cluster_id)
+        if res is None:
+            raise ValueError(
+                f'Cluster id {cluster_id} does not exist. Check cluster id and try again!'
+            )
+        stripped_runtime = re.sub(
+            r'[a-zA-Z]', '',
+            res.spark_version.split('-scala')[0].replace('x-snapshot', ''))
+        runtime_version = re.sub(r'.-+$', '', stripped_runtime)
         if version.parse(runtime_version) < version.parse(
                 MINIMUM_SQ_CONNECT_DBR_VERSION):
             raise ValueError(
                 f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}'
             )
 
-        if args.http_path is None and version.parse(
+        if http_path is None and version.parse(
                 runtime_version) >= version.parse(
                     MINIMUM_DB_CONNECT_DBR_VERSION):
             method = 'dbconnect'
 
     if method == 'dbconnect':
         try:
-            if args.use_serverless:
+            if use_serverless:
                 session_id = str(uuid4())
                 sparkSession = DatabricksSession.builder.host(
-                    args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header(
+                    databricks_host).token(databricks_token).header(
                         'x-databricks-session-id', session_id).getOrCreate()
 
             else:
                 sparkSession = DatabricksSession.builder.remote(
-                    host=args.DATABRICKS_HOST,
-                    token=args.DATABRICKS_TOKEN,
-                    cluster_id=args.cluster_id).getOrCreate()
+                    host=databricks_host,
+                    token=databricks_token,
+                    cluster_id=cluster_id).getOrCreate()
 
         except Exception as e:
             raise RuntimeError(
@@ -444,15 +441,47 @@ def fetch_DT(args: Namespace) -> None:
         try:
             dbsql = sql.connect(
                 server_hostname=re.compile(r'^https?://').sub(
-                    '', args.DATABRICKS_HOST).strip(
+                    '', databricks_host).strip(
                 ), # sqlconnect hangs if hostname starts with https
-                http_path=args.http_path,
-                access_token=args.DATABRICKS_TOKEN,
+                http_path=http_path,
+                access_token=databricks_token,
             )
         except Exception as e:
             raise RuntimeError(
                 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
             ) from e
+    return method, dbsql, sparkSession
+
+
+def fetch_DT(args: Namespace) -> None:
+    """Fetch UC Delta Table to local as jsonl."""
+    log.info(f'Start .... Convert delta to json')
+
+    obj = urllib.parse.urlparse(args.json_output_folder)
+    if obj.scheme != '':
+        raise ValueError(
+            f'Check the json_output_folder and verify it is a local path!')
+
+    if os.path.exists(args.json_output_folder):
+        if not os.path.isdir(args.json_output_folder) or os.listdir(
+                args.json_output_folder):
+            raise RuntimeError(
+                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
+            )
+
+    os.makedirs(args.json_output_folder, exist_ok=True)
+
+    if not args.json_output_filename.endswith('.jsonl'):
+        raise ValueError('json_output_filename needs to be a jsonl file')
+
+    log.info(f'Directory {args.json_output_folder} created.')
+
+    method, dbsql, sparkSession = validate_and_get_cluster_info(
+        cluster_id=args.cluster_id,
+        databricks_host=args.DATABRICKS_HOST,
+        databricks_token=args.DATABRICKS_TOKEN,
+        http_path=args.http_path,
+        use_serverless=args.use_serverless)
 
     fetch(method, args.delta_table_name, args.json_output_folder,
           args.batch_size, args.processes, sparkSession, dbsql)
@@ -494,9 +523,8 @@ def fetch_DT(args: Namespace) -> None:
                         help='number of processes allowed to use')
     parser.add_argument(
         '--cluster_id',
-        required=True,
+        required=False,
         type=str,
-        default=None,
         help=
         'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.'
     )
@@ -513,7 +541,9 @@ def fetch_DT(args: Namespace) -> None:
         required=False,
         type=str,
         default='train-00000-of-00001.jsonl',
-        help='The combined final jsonl that combines all partitioned jsonl')
+        help=
+        'The name of the combined final jsonl that combines all partitioned jsonl'
+    )
     args = parser.parse_args()
 
     from databricks.sdk import WorkspaceClient
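Note on the version-parsing change above: the old code turned the cluster's spark_version into a comparable string by replacing 'x' with '0', which leaves letter suffixes such as '-gpu-ml' in place; the new code strips letters and trailing separators before comparing with packaging.version. Below is a standalone sketch of that normalization. The example runtime strings are assumptions about Databricks' naming, not values taken from this diff.

import re

from packaging import version


def normalize_runtime(spark_version: str) -> str:
    # Mirrors the normalization added in validate_and_get_cluster_info above.
    stripped = re.sub(
        r'[a-zA-Z]', '',
        spark_version.split('-scala')[0].replace('x-snapshot', ''))
    return re.sub(r'.-+$', '', stripped)


# Assumed example runtime strings:
assert normalize_runtime('12.2.x-gpu-ml-scala2.12') == '12.2'
assert normalize_runtime('14.1.x-cpu-ml-scala2.12') == '14.1'

# PEP 440 pads release segments with zeros, so shortening the constant from
# '14.1.0' to '14.1' does not change the comparison result.
assert version.parse('14.1') == version.parse('14.1.0')

Under these assumptions, GPU/ML runtime names normalize to a plain 'major.minor' string that packaging can parse and compare against the minimum-version constants.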

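For readers skimming the diff, here is a minimal, self-contained restatement of the dbconnect-versus-dbsql decision made by the new validate_and_get_cluster_info helper. It makes no Databricks calls; choose_method is an illustrative name, runtime_version is assumed to be pre-normalized as in the sketch above, and the sample http_path value is hypothetical.

from typing import Optional

from packaging import version

MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'


def choose_method(use_serverless: bool, http_path: Optional[str],
                  runtime_version: Optional[str]) -> str:
    # Illustrative restatement of the branch logic in the diff above.
    if use_serverless:
        return 'dbconnect'  # serverless always goes through Databricks Connect
    assert runtime_version is not None, 'non-serverless needs a cluster runtime'
    if version.parse(runtime_version) < version.parse(
            MINIMUM_SQ_CONNECT_DBR_VERSION):
        raise ValueError(f'DBR {runtime_version} is below the supported minimum')
    if http_path is None and version.parse(runtime_version) >= version.parse(
            MINIMUM_DB_CONNECT_DBR_VERSION):
        return 'dbconnect'  # new enough runtime and no SQL warehouse path given
    return 'dbsql'  # fall back to SQL connect (needs http_path and cluster_id)


assert choose_method(True, None, None) == 'dbconnect'
assert choose_method(False, None, '14.1') == 'dbconnect'
assert choose_method(False, '/sql/1.0/warehouses/abc', '14.1') == 'dbsql'
assert choose_method(False, None, '13.3') == 'dbsql'

This lines up with --cluster_id becoming optional in the argparse section of the diff: the serverless path never consults it.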
