Skip to content

Commit

Permalink
refactor(defaults): better defaults (#5)
Browse files Browse the repository at this point in the history
This commit moves some constants over from `env` to their own classes,
and makes the clients have a `PROTOCOL` instance in them. This will make
it easier to extend further later.
  • Loading branch information
serramatutu authored Jun 17, 2024
1 parent 609fca6 commit 2d2a48e
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 10 deletions.
4 changes: 1 addition & 3 deletions dbtsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,4 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103

SemanticLayerClient = err_factory

__all__ = [
"SemanticLayerClient"
]
__all__ = ["SemanticLayerClient"]
4 changes: 2 additions & 2 deletions dbtsl/api/adbc/client/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Self, Unpack

from dbtsl.api.adbc.client.base import BaseADBCClient
from dbtsl.api.adbc.protocol import ADBCProtocol, QueryParameters
from dbtsl.api.adbc.protocol import QueryParameters


class AsyncADBCClient(BaseADBCClient):
Expand Down Expand Up @@ -52,7 +52,7 @@ async def session(self) -> AsyncIterator[Self]:

async def query(self, **query_params: Unpack[QueryParameters]) -> pa.Table:
"""Query for a dataframe in the Semantic Layer."""
query_sql = ADBCProtocol.get_query_sql(query_params)
query_sql = self.PROTOCOL.get_query_sql(query_params)

# NOTE: We don't need to wrap this in a `loop.run_in_executor` since
# just creating the cursor object doesn't perform any blocking IO.
Expand Down
6 changes: 5 additions & 1 deletion dbtsl/api/adbc/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@
from adbc_driver_manager import AdbcStatusCode, ProgrammingError

import dbtsl.env as env
from dbtsl.api.adbc.protocol import ADBCProtocol
from dbtsl.error import AuthError, QueryFailedError


class BaseADBCClient:
"""Base class for the ADBC API client."""

PROTOCOL = ADBCProtocol
DEFAULT_URL_FORMAT = env.DEFAULT_ADBC_URL_FORMAT

def __init__( # noqa: D107
self,
server_host: str,
environment_id: int,
auth_token: str,
url_format: Optional[str] = None,
) -> None:
url_format = url_format or env.DEFAULT_ADBC_URL_FORMAT
url_format = url_format or self.DEFAULT_URL_FORMAT
self._conn_str = url_format.format(server_host=server_host)
self._environment_id = environment_id
self._auth_token = auth_token
Expand Down
4 changes: 2 additions & 2 deletions dbtsl/api/adbc/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Self, Unpack

from dbtsl.api.adbc.client.base import BaseADBCClient
from dbtsl.api.adbc.protocol import ADBCProtocol, QueryParameters
from dbtsl.api.adbc.protocol import QueryParameters


class SyncADBCClient(BaseADBCClient):
Expand Down Expand Up @@ -48,7 +48,7 @@ def session(self) -> Iterator[Self]:

def query(self, **query_params: Unpack[QueryParameters]) -> pa.Table:
"""Query for a dataframe in the Semantic Layer."""
query_sql = ADBCProtocol.get_query_sql(query_params)
query_sql = self.PROTOCOL.get_query_sql(query_params)

with self._conn.cursor() as cur:
try:
Expand Down
7 changes: 5 additions & 2 deletions dbtsl/api/graphql/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class BaseGraphQLClient(Generic[TTransport, TSession]):
will choose if IO is sync or async.
"""

PROTOCOL = GraphQLProtocol
DEFAULT_URL_FORMAT = env.DEFAULT_GRAPHQL_URL_FORMAT

@classmethod
def _default_backoff(cls) -> ExponentialBackoff:
"""Get the default backoff behavior when polling."""
Expand All @@ -46,7 +49,7 @@ def __init__( # noqa: D107
):
self.environment_id = environment_id

url_format = url_format or env.DEFAULT_GRAPHQL_URL_FORMAT
url_format = url_format or self.DEFAULT_URL_FORMAT
server_url = url_format.format(server_host=server_host)

transport = self._create_transport(url=server_url, headers={"authorization": f"bearer {auth_token}"})
Expand Down Expand Up @@ -87,7 +90,7 @@ def _gql_session(self) -> TSession:

def __getattr__(self, attr: str) -> Any:
"""Run an underlying GraphQLOperation if it exists in GraphQLProtocol."""
op = getattr(GraphQLProtocol, attr)
op = getattr(self.PROTOCOL, attr)
if op is None:
raise AttributeError()

Expand Down

0 comments on commit 2d2a48e

Please sign in to comment.