
merge in protocol
samster25 committed Mar 20, 2024
1 parent a4f2607 commit a3bbc34
Showing 2 changed files with 40 additions and 13 deletions.
23 changes: 11 additions & 12 deletions daft/table/table_io.py
```diff
@@ -577,15 +577,23 @@ def write_iceberg(
     from pyiceberg.manifest import FileFormat as IcebergFileFormat
     from pyiceberg.typedef import Record
 
+    [resolved_path], fs = _resolve_paths_and_filesystem(base_path, io_config=io_config)
+    if isinstance(base_path, pathlib.Path):
+        path_str = str(base_path)
+    else:
+        path_str = base_path
+
+    protocol = get_protocol_from_path(path_str)
+    canonicalized_protocol = canonicalize_protocol(protocol)
+
     data_files = []
 
-    def file_visitor(written_file):
-
-        file_path = written_file.path
+    def file_visitor(written_file, protocol=protocol):
+        file_path = f"{protocol}://{written_file.path}"
         size = written_file.size
         metadata = written_file.metadata
         # TODO Version guard pyarrow version
 
         data_file = DataFile(
             content=DataFileContent.DATA,
             file_path=file_path,
@@ -609,15 +617,6 @@ def file_visitor(written_file):
         )
         data_files.append(data_file)
 
-    [resolved_path], fs = _resolve_paths_and_filesystem(base_path, io_config=io_config)
-    if isinstance(base_path, pathlib.Path):
-        path_str = str(base_path)
-    else:
-        path_str = base_path
-
-    protocol = get_protocol_from_path(path_str)
-    canonicalized_protocol = canonicalize_protocol(protocol)
-
     is_local_fs = canonicalized_protocol == "file"
 
     execution_config = get_context().daft_execution_config
```
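For readers unfamiliar with the moved block, here is a minimal, self-contained sketch of the idea behind the change (not Daft's actual implementation: `_protocol_from_path`, the sample paths, and the `SimpleNamespace` stand-in for the writer's file object are illustrative): the filesystem scheme is resolved once from the table's base path, captured by the visitor as a default argument at definition time, and prepended to each written file's path so the recorded Iceberg `DataFile` paths are full URIs. It assumes the writer reports scheme-less paths, which is the reason this commit adds the prefix.

```python
from types import SimpleNamespace
from urllib.parse import urlparse


def _protocol_from_path(path: str) -> str:
    # Hypothetical stand-in for Daft's protocol resolution: take the URL scheme,
    # falling back to "file" for local paths without one.
    return urlparse(path).scheme or "file"


base_path = "s3://my-bucket/warehouse/db/table"  # illustrative value
protocol = _protocol_from_path(base_path)        # -> "s3"


def file_visitor(written_file, protocol=protocol):
    # The default argument binds the resolved protocol at definition time, so
    # the visitor can still be invoked with only the written-file object.
    # Prefixing the scheme turns the writer's scheme-less path into a full URI.
    return f"{protocol}://{written_file.path}"


written = SimpleNamespace(path="my-bucket/warehouse/db/table/part-0.parquet")
print(file_visitor(written))  # s3://my-bucket/warehouse/db/table/part-0.parquet
```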
30 changes: 29 additions & 1 deletion tests/integration/iceberg/test_pyiceberg_written_table_load.py
```diff
@@ -28,10 +28,38 @@ def table_written_by_pyiceberg(local_iceberg_catalog):
         local_iceberg_catalog.drop_table("pyiceberg.map_table")
 
 
+@contextlib.contextmanager
+def table_written_by_daft(local_iceberg_catalog):
+    schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))])
+
+    data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]}
+    arrow_table = pa.Table.from_pydict(data, schema=schema)
+    try:
+        table = local_iceberg_catalog.create_table("pyiceberg.map_table", schema=schema)
+        df = daft.from_arrow(arrow_table)
+        df.write_iceberg(table, mode="overwrite")
+        table.refresh()
+        yield table
+    except Exception as e:
+        raise e
+    finally:
+        local_iceberg_catalog.drop_table("pyiceberg.map_table")
+
+
 @pytest.mark.integration()
-def test_localdb_catalog(local_iceberg_catalog):
+def test_pyiceberg_written_catalog(local_iceberg_catalog):
     with table_written_by_pyiceberg(local_iceberg_catalog) as catalog_table:
         df = daft.read_iceberg(catalog_table)
         daft_pandas = df.to_pandas()
         iceberg_pandas = catalog_table.scan().to_arrow().to_pandas()
         assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
+
+
+@pytest.mark.integration()
+@pytest.mark.skip
+def test_daft_written_catalog(local_iceberg_catalog):
+    with table_written_by_daft(local_iceberg_catalog) as catalog_table:
+        df = daft.read_iceberg(catalog_table)
+        daft_pandas = df.to_pandas()
+        iceberg_pandas = catalog_table.scan().to_arrow().to_pandas()
+        assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
```

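The Arrow → Daft round trip used by the new fixture can be tried without an Iceberg catalog. Below is a minimal sketch, assuming only `pyarrow` and `daft` are installed; the schema and sample values mirror the fixture above, and the pandas frames are simply printed rather than compared with the test suite's `assert_df_equals` helper.

```python
import pyarrow as pa

import daft

# Same shape as the table_written_by_daft fixture: an int column plus a map column.
schema = pa.schema([("col", pa.int64()), ("mapCol", pa.map_(pa.int32(), pa.string()))])
data = {"col": [1, 2, 3], "mapCol": [[(1, "foo"), (2, "bar")], [(3, "baz")], [(4, "foobar")]]}
arrow_table = pa.Table.from_pydict(data, schema=schema)

# Round-trip through Daft and inspect the pandas views, mirroring the
# daft_pandas / iceberg_pandas comparison in the tests above.
df = daft.from_arrow(arrow_table)
daft_pandas = df.to_pandas()
arrow_pandas = arrow_table.to_pandas()
print(daft_pandas)
print(arrow_pandas)
```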