Skip to content

Commit

Permalink
Move reflections to rest API calls
Browse files Browse the repository at this point in the history
  • Loading branch information
bcmeireles committed Dec 19, 2024
1 parent 790a6a4 commit 82a35cf
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 166 deletions.
1 change: 1 addition & 0 deletions dbt/adapters/dremio/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# __init__.py
from .rest.endpoints import (
create_reflection,
delete_catalog,
sql_endpoint,
job_status,
Expand Down
47 changes: 46 additions & 1 deletion dbt/adapters/dremio/api/rest/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ def _post(
return _check_error(response, details)


def _put(
url,
request_headers=None,
json=None,
details="",
ssl_verify=True,
timeout=None,
):
if isinstance(json, str):
json = jsonlib.loads(json)
response = session.put(
url,
headers=request_headers,
timeout=timeout,
verify=ssl_verify,
json=json,
)
return _check_error(response, details)


def _delete(url, request_headers, details="", ssl_verify=True):
response = session.delete(url, headers=request_headers, verify=ssl_verify)
return _check_error(response, details)
Expand Down Expand Up @@ -149,7 +169,6 @@ def _check_error(response, details=""):


def login(api_parameters: Parameters, timeout=10):

if isinstance(api_parameters.authentication, DremioPatAuthentication):
return api_parameters

Expand Down Expand Up @@ -251,3 +270,29 @@ def delete_catalog(api_parameters, cid):
api_parameters.authentication.get_headers(),
ssl_verify=api_parameters.authentication.verify_ssl,
)

def get_reflection(api_parameters, dataset_id):
url = UrlBuilder.get_reflection_url(api_parameters, dataset_id)
return _get(
url,
api_parameters.authentication.get_headers(),
ssl_verify=api_parameters.authentication.verify_ssl,
)

def create_reflection(api_parameters: Parameters, name: str, type: str, payload):
url = UrlBuilder.create_reflection_url(api_parameters)
return _post(
url,
api_parameters.authentication.get_headers(),
json=payload,
ssl_verify=api_parameters.authentication.verify_ssl,
)

def update_reflection(api_parameters: Parameters, dataset_id: str, payload):
url = UrlBuilder.update_reflection_url(api_parameters, dataset_id)
return _put(
url,
api_parameters.authentication.get_headers(),
json=payload,
ssl_verify=api_parameters.authentication.verify_ssl,
)
62 changes: 53 additions & 9 deletions dbt/adapters/dremio/api/rest/url_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ class UrlBuilder:
SOFTWARE_CATALOG_ENDPOINT = "/api/v3/catalog"
CLOUD_CATALOG_ENDPOINT = CLOUD_PROJECT_ENDPOINT + "/{}/catalog"

SOFTWARE_REFLECTIONS_ENDPOINT = "/api/v3/reflection"
CLOUD_REFLECTIONS_ENDPOINT = CLOUD_PROJECT_ENDPOINT + "/{}/reflection"

SOFTWARE_DATASET_ENDPOIT = "/api/v3/dataset"
CLOUD_DATASET_ENDPOINT = CLOUD_PROJECT_ENDPOINT + "/{}/dataset"

# https://docs.dremio.com/software/rest-api/jobs/get-job/
OFFSET_DEFAULT = 0
LIMIT_DEFAULT = 100
Expand All @@ -56,10 +62,10 @@ def sql_url(cls, parameters: Parameters):
def job_status_url(cls, parameters: Parameters, job_id):
if type(parameters) is CloudParameters:
return (
parameters.base_url
+ UrlBuilder.CLOUD_JOB_ENDPOINT.format(parameters.cloud_project_id)
+ "/"
+ job_id
parameters.base_url
+ UrlBuilder.CLOUD_JOB_ENDPOINT.format(parameters.cloud_project_id)
+ "/"
+ job_id
)
return parameters.base_url + UrlBuilder.SOFTWARE_JOB_ENDPOINT + "/" + job_id

Expand All @@ -75,11 +81,11 @@ def job_cancel_url(cls, parameters: Parameters, job_id):

@classmethod
def job_results_url(
cls,
parameters: Parameters,
job_id,
offset=OFFSET_DEFAULT,
limit=LIMIT_DEFAULT,
cls,
parameters: Parameters,
job_id,
offset=OFFSET_DEFAULT,
limit=LIMIT_DEFAULT,
):
url_path = parameters.base_url
if type(parameters) is CloudParameters:
Expand Down Expand Up @@ -139,3 +145,41 @@ def catalog_item_by_path_url(cls, parameters: Parameters, path_list):
joined_path_str = "/".join(quoted_path_list).replace('"', "")
endpoint = f"/by-path/{joined_path_str}"
return url_path + endpoint

@classmethod
def create_reflection_url(cls, parameters: Parameters):
url_path = parameters.base_url
if type(parameters) is CloudParameters:
url_path += UrlBuilder.CLOUD_REFLECTIONS_ENDPOINT.format(
parameters.cloud_project_id
)
else:
url_path += UrlBuilder.SOFTWARE_REFLECTIONS_ENDPOINT

return url_path

@classmethod
def update_reflection_url(cls, parameters: Parameters, dataset_id):
url_path = parameters.base_url
if type(parameters) is CloudParameters:
url_path += UrlBuilder.CLOUD_REFLECTIONS_ENDPOINT.format(
parameters.cloud_project_id
)
else:
url_path += UrlBuilder.SOFTWARE_REFLECTIONS_ENDPOINT

endpoint = "/{}".format(dataset_id)
return url_path + endpoint

@classmethod
def get_reflection_url(cls, parameters: Parameters, dataset_id):
url_path = parameters.base_url
if type(parameters) is CloudParameters:
url_path += UrlBuilder.CLOUD_DATASET_ENDPOINT.format(
parameters.cloud_project_id
)
else:
url_path += UrlBuilder.SOFTWARE_DATASET_ENDPOIT

endpoint = "/{}/reflection".format(dataset_id)
return url_path + endpoint
90 changes: 82 additions & 8 deletions dbt/adapters/dremio/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from dbt.adapters.contracts.connection import AdapterResponse

from dbt.adapters.dremio.api.rest.endpoints import (
create_reflection,
update_reflection,
get_reflection,
delete_catalog,
create_catalog_api,
get_catalog_item,
Expand Down Expand Up @@ -133,7 +136,7 @@ def add_commit_query(self):

# Auto_begin may not be relevant with the rest_api
def add_query(
self, sql, auto_begin=True, bindings=None, abridge_sql_log=False, fetch=False
self, sql, auto_begin=True, bindings=None, abridge_sql_log=False, fetch=False
):
connection = self.get_thread_connection()
if auto_begin and connection.transaction_open is False:
Expand Down Expand Up @@ -174,11 +177,11 @@ def get_response(cls, cursor: DremioCursor) -> AdapterResponse:
return AdapterResponse(_message=message, rows_affected=rows)

def execute(
self,
sql: str,
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
self,
sql: str,
auto_begin: bool = False,
fetch: bool = False,
limit: Optional[int] = None,
) -> Tuple[AdapterResponse, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin, fetch=fetch)
Expand Down Expand Up @@ -231,6 +234,76 @@ def create_catalog(self, relation):
self._create_folders(database, schema, api_parameters)
return

def dbt_reflection_integration(self, name: str, type: str, anchor, display, dimensions, date_dimensions, measures,
computations, partition_by, partition_method, localsort_by):
thread_connection = self.get_thread_connection()
connection = self.open(thread_connection)
api_parameters = connection.handle.get_parameters()

database = anchor.database
schema = anchor.schema
path = self._create_path_list(database, schema)
identifier = anchor.identifier

path.append(identifier)

catalog_info = get_catalog_item(
api_parameters,
catalog_id=None,
catalog_path=path,
)

dataset_id = catalog_info.get("id")

payload = {
"type": type,
"name": name,
"datasetId": dataset_id,
"enabled": True,
"arrowCachingEnabled": False,
"partitionDistributionStrategy": partition_method.upper(),
"entityType": "reflection"
}

if display:
payload["displayFields"] = [{"name": field} for field in display]

if dimensions:
if not date_dimensions:
date_dimensions = []

payload["dimensionFields"] = [
{"name": dimension} if dimension not in date_dimensions else {"name": dimension, "granularity": "DATE"}
for dimension in dimensions]

if measures and computations:
payload["measureFields"] = [{"name": measure, "measureTypeList": computation.split(',')} for
measure, computation in zip(measures, computations)]

if partition_by:
payload["partitionFields"] = [{"name": partition} for partition in partition_by]

if localsort_by:
payload["sortFields"] = [{"name": sort} for sort in localsort_by]

dataset_info = get_reflection(api_parameters, dataset_id)
reflections_info = dataset_info.get("data")

updated = False
for reflection in reflections_info:
if reflection.get("name") == name:
logger.debug(f"Reflection {name} already exists. Updating it")
payload["tag"] = reflection.get("tag")
logger.info(
update_reflection(api_parameters, reflection.get("id"), payload))
updated = True
break

if not updated:
logger.debug(f"Reflection {name} does not exist. Creating it")
logger.info(
create_reflection(api_parameters, name, type, payload))

def _make_new_space_json(self, name) -> json:
python_dict = {"entityType": "space", "name": name}
return json.dumps(python_dict)
Expand Down Expand Up @@ -263,6 +336,7 @@ def _create_folders(self, database, schema, api_parameters):

def _create_path_list(self, database, schema):
path = [database]
folders = schema.split(".")
path.extend(folders)
if schema != 'no_schema':
folders = schema.split(".")
path.extend(folders)
return path
5 changes: 5 additions & 0 deletions dbt/adapters/dremio/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from typing import List
from typing import Optional
from dbt.adapters.base.meta import available
from dbt.adapters.base.relation import BaseRelation

from dbt.adapters.capability import (
Expand Down Expand Up @@ -177,6 +178,10 @@ def run_sql_for_tests(self, sql, fetch, conn):
finally:
conn.transaction_open = False

@available
def dbt_reflection_integration(self, name: str, type: str, anchor, display, dimensions, date_dimensions, measures, computations, partition_by, partition_method, localsort_by) -> None:
self.connections.dbt_reflection_integration(name, type, anchor, display, dimensions, date_dimensions, measures, computations, partition_by, partition_method, localsort_by)


COLUMNS_EQUAL_SQL = """
with diff_count as (
Expand Down
4 changes: 0 additions & 4 deletions dbt/include/dremio/dbt_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,3 @@ quoting:
identifier: true

macro-paths: ["macros"]

vars:
"dremio:reflections_enabled": false
"dremio:exact_search_enabled": false
48 changes: 0 additions & 48 deletions dbt/include/dremio/macros/adapters/metadata.sql
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,11 @@ limitations under the License.*/
{%- endmacro -%}

{% macro dremio__get_catalog_relations_result_sql(relations) %}
{%- if var('dremio:reflections_enabled', default=false) %}
{{get_catalog_reflections(relations)}}
{% else %}

select *
from t
join columns on (t.table_schema = columns.table_schema
and t.table_name = columns.table_name)
order by "column_index"
{% endif %}
{%- endmacro -%}

{% macro get_catalog_reflections(relations) %}
Expand Down Expand Up @@ -241,49 +236,6 @@ limitations under the License.*/
{%- set schema_name = database
+ (('.' + schema) if schema != 'no_schema' else '') -%}
{% call statement('list_relations_without_caching', fetch_result=True) -%}

{%- if var('dremio:reflections_enabled', default=false) -%}

with cte1 as (
select
dataset_name
,reflection_name
,type
,case when substr(dataset_name, 1, 1) = '"'
then strpos(dataset_name, '".') + 1
else strpos(dataset_name, '.')
end as first_dot
,length(dataset_name) -
case when substr(dataset_name, length(dataset_name)) = '"'
then strpos(reverse(dataset_name), '".')
else strpos(reverse(dataset_name), '.') - 1
end as last_dot
,length(dataset_name) as length
{%- if target.cloud_host and not target.software_host %}
from sys.project.reflections
{%- elif target.software_host and not target.cloud_host %}
from sys.reflections
{%- endif %}
)
, cte2 as (
select
replace(substr(dataset_name, 1, first_dot - 1), '"', '') as table_catalog
,reflection_name as table_name
,replace(case when first_dot < last_dot
then substr(dataset_name, first_dot + 1, last_dot - first_dot - 1)
else 'no_schema' end, '"', '') as table_schema
,'materialized_view' as table_type
from cte1
)
select table_catalog, table_name, table_schema, table_type
from cte2
where ilike(table_catalog, '{{ database }}')
and ilike(table_schema, '{{ schema }}')

union all

{%- endif %}

select (case when position('.' in table_schema) > 0
then substring(table_schema, 1, position('.' in table_schema) - 1)
else table_schema
Expand Down
Loading

0 comments on commit 82a35cf

Please sign in to comment.