Skip to content

Commit

Permalink
Allow selecting a schema from the pipeline dataset factory
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Nov 19, 2024
1 parent 73b79ee commit 4d32c5f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 3 deletions.
5 changes: 4 additions & 1 deletion dlt/destinations/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def _destination_client(self, schema: Schema) -> JobClientBase:

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""

# full schema given, nothing to do
if not self._schema and isinstance(self._provided_schema, Schema):
self._schema = self._provided_schema
Expand All @@ -259,6 +260,8 @@ def _ensure_client_and_schema(self) -> None:
stored_schema = client.get_stored_schema(self._provided_schema)
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))
else:
self._schema = Schema(self._provided_schema)

# no schema name given, load newest schema from destination
elif not self._schema:
Expand All @@ -268,7 +271,7 @@ def _ensure_client_and_schema(self) -> None:
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# default to empty schema with dataset name if nothing found
# default to empty schema with dataset name
if not self._schema:
self._schema = Schema(self._dataset_name)

Expand Down
9 changes: 7 additions & 2 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
cast,
get_type_hints,
ContextManager,
Union,
)

from dlt import version
Expand Down Expand Up @@ -1790,11 +1791,15 @@ def __getstate__(self) -> Any:
# pickle only the SupportsPipeline protocol fields
return {"pipeline_name": self.pipeline_name}

def _dataset(self, dataset_type: TDatasetType = "dbapi") -> SupportsReadableDataset:
def _dataset(
    self, schema: Union[Schema, str, None] = None, dataset_type: TDatasetType = "dbapi"
) -> SupportsReadableDataset:
    """Access helper to dataset

    Builds a readable dataset for this pipeline's destination. When no schema
    (or schema name) is given, falls back to the pipeline's default schema if
    one exists; otherwise the dataset resolves its schema on its own.
    """
    # fall back to the default schema only when the caller did not pick one
    if schema is None and self.default_schema_name:
        schema = self.default_schema
    return dataset(
        self._destination, self.dataset_name, schema=schema, dataset_type=dataset_type
    )
28 changes: 28 additions & 0 deletions tests/load/test_read_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,34 @@ def test_column_selection(populated_pipeline: Pipeline) -> None:
arrow_table = table_relationship.select("unknown_column").head().arrow()


@pytest.mark.no_load
@pytest.mark.essential
@pytest.mark.parametrize(
    "populated_pipeline",
    configs,
    indirect=True,
    ids=lambda x: x.name,
)
def test_schema_arg(populated_pipeline: Pipeline) -> None:
    """Simple test to ensure schemas may be selected via schema arg"""

    # without an explicit arg the pipeline's default schema is used
    ds = populated_pipeline._dataset()
    assert ds.schema.name == populated_pipeline.default_schema_name  # type: ignore
    assert "items" in ds.schema.tables  # type: ignore

    # an unknown schema name is looked up on the destination, not found,
    # and an empty schema with that name is created instead
    ds = populated_pipeline._dataset(schema="unknown_schema")
    assert ds.schema.name == "unknown_schema"  # type: ignore
    assert "items" not in ds.schema.tables  # type: ignore

    # passing the name of an existing schema loads that schema
    ds = populated_pipeline._dataset(schema=populated_pipeline.default_schema_name)
    assert ds.schema.name == populated_pipeline.default_schema_name  # type: ignore
    assert "items" in ds.schema.tables  # type: ignore


@pytest.mark.no_load
@pytest.mark.essential
@pytest.mark.parametrize(
Expand Down

0 comments on commit 4d32c5f

Please sign in to comment.