Coerce timestamp units in metadata #107

Merged
7 changes: 7 additions & 0 deletions CHANGES.rst
@@ -2,6 +2,13 @@
Changelog
=========

Plateau 4.2.1 (2023-10-31)
==========================

* Add support for pandas 2.1
* Fix a bug in timestamp dtype conversion
* Add timestamp unit coercion, since Plateau currently only supports nanosecond timestamp units

Plateau 4.2.0 (2023-10-10)
==========================

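For context, here is a minimal illustration (not part of this changelog) of the pandas >= 2 behaviour the coercion entry refers to: Timestamps now carry an explicit resolution, and `Timestamp.as_unit` converts between resolutions, so data written by different callers is no longer guaranteed to be nanosecond-backed.

import pandas as pd

ts = pd.Timestamp(2000, 1, 1)
print(ts.unit)                # resolution chosen by pandas; not necessarily "ns"
print(ts.as_unit("ns").unit)  # "ns" -- the only unit Plateau's metadata supports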
14 changes: 13 additions & 1 deletion plateau/io_components/write.py
@@ -1,6 +1,7 @@
from functools import partial
from typing import Dict, Iterable, List, Optional, cast

import pyarrow as pa
from minimalkv import KeyValueStore

from plateau.core import naming
@@ -126,6 +127,17 @@ def persist_common_metadata(
    return result


# Currently we only support nanosecond timestamps.
def coerce_schema_timestamps(wrapper: SchemaWrapper) -> SchemaWrapper:
    schema = wrapper.internal()
    fields = []
    for field in schema:
        if field.type in [pa.timestamp("s"), pa.timestamp("ms"), pa.timestamp("us")]:
            field = pa.field(field.name, pa.timestamp("ns"))
        fields.append(field)
    return SchemaWrapper(pa.schema(fields, schema.metadata), wrapper.origin)


def store_dataset_from_partitions(
    partition_list,
    store: StoreInput,
@@ -161,7 +173,7 @@ def store_dataset_from_partitions(

    for mp in partition_list:
        if mp.schema:
-            schemas.add(mp.schema)
+            schemas.add(coerce_schema_timestamps(mp.schema))

    dataset_builder.schema = persist_common_metadata(
        schemas=schemas,
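As a standalone sketch (the field names are made up, and SchemaWrapper is omitted), the field-level mapping performed by coerce_schema_timestamps above can be reproduced with plain pyarrow:

import pyarrow as pa

schema = pa.schema([("ts", pa.timestamp("us")), ("x", pa.int64())])
coerced = pa.schema(
    [
        pa.field(f.name, pa.timestamp("ns"))
        if f.type in (pa.timestamp("s"), pa.timestamp("ms"), pa.timestamp("us"))
        else f
        for f in schema
    ],
    schema.metadata,
)
assert coerced.field("ts").type == pa.timestamp("ns")  # coerced to nanoseconds
assert coerced.field("x").type == pa.int64()           # non-timestamp fields untouched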
45 changes: 45 additions & 0 deletions tests/io_components/test_write.py
@@ -3,11 +3,13 @@

import pandas as pd
import pytest
from packaging import version

from plateau.core.dataset import DatasetMetadata
from plateau.core.index import ExplicitSecondaryIndex
from plateau.core.testing import TIME_TO_FREEZE_ISO
from plateau.io_components.metapartition import MetaPartition
from plateau.io_components.read import dispatch_metapartitions
from plateau.io_components.write import (
    raise_if_dataset_exists,
    store_dataset_from_partitions,
@@ -117,3 +119,46 @@ def test_raise_if_dataset_exists(store_factory, dataset_function):
    raise_if_dataset_exists(dataset_uuid="ThisDoesNotExist", store=store_factory)
    with pytest.raises(RuntimeError):
        raise_if_dataset_exists(dataset_uuid=dataset_function.uuid, store=store_factory)


@pytest.mark.skipif(
    version.parse(pd.__version__) < version.parse("2"),
    reason="Timestamp unit coercion is only relevant in pandas >= 2",
)
def test_coerce_schema_timestamp_units(store):
    date = pd.Timestamp(2000, 1, 1)

    mps_original = [
        MetaPartition(label="one", data=pd.DataFrame({"a": date, "b": [date]})),
        MetaPartition(
            label="two",
            data=pd.DataFrame({"a": date.as_unit("ns"), "b": [date.as_unit("ns")]}),
        ),
    ]

    mps = map(
        lambda mp: mp.store_dataframes(store, dataset_uuid="dataset_uuid"), mps_original
    )

    # Expect this not to fail even though the metapartitions have different
    # timestamp units, because all units should be coerced to nanoseconds.
    dataset = store_dataset_from_partitions(
        partition_list=mps,
        dataset_uuid="dataset_uuid",
        store=store,
        dataset_metadata={"some": "metadata"},
    )

    # Ensure the dataset can be loaded properly
    stored_dataset = DatasetMetadata.load_from_store("dataset_uuid", store)
    assert dataset == stored_dataset

    mps = dispatch_metapartitions("dataset_uuid", store)
    mps_loaded = map(lambda mp: mp.load_dataframes(store), mps)

    # Ensure the values and dtypes of the loaded datasets are correct
    for mp in mps_loaded:
        assert mp.data["a"].dtype == "datetime64[ns]"
        assert mp.data["b"].dtype == "datetime64[ns]"
        assert mp.data["a"][0] == date
        assert mp.data["b"][0] == date
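For anyone reproducing the test locally, a quick way (illustrative, not part of the diff) to see which arrow resolution a DataFrame column would carry before the coercion runs is pyarrow's schema inference:

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"a": [pd.Timestamp(2000, 1, 1).as_unit("us")]})
print(pa.Schema.from_pandas(df).field("a").type)  # timestamp[us] when the microsecond resolution is preserved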