From f04b18e27158df42884cb64cb17f471873f63a33 Mon Sep 17 00:00:00 2001 From: "Dr. Ernie Prabhakar" Date: Sat, 24 Aug 2024 15:23:06 -0700 Subject: [PATCH] ignore from_arrays type error --- athena_federation/batch_writer.py | 7 ++----- example/sample_data_source.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/athena_federation/batch_writer.py b/athena_federation/batch_writer.py index fcf50d2..f7b8bbc 100644 --- a/athena_federation/batch_writer.py +++ b/athena_federation/batch_writer.py @@ -29,11 +29,8 @@ def spilled(self) -> bool: def write_rows(self, data: dict[str, list[Any]]): array_data = [ pa.array(data[name]) for name in self._schema.names - ] # type: ignore - record_batch = pa.RecordBatch.from_arrays( - array_data, # type: ignore - schema=[self._schema], - ) + ] + record_batch = pa.RecordBatch.from_arrays(arrays=array_data, schema=self._schema) # type: ignore assert record_batch is not str self._batch_size += record_batch.nbytes self._writer.write_batch(record_batch) diff --git a/example/sample_data_source.py b/example/sample_data_source.py index 3620426..ad2e505 100644 --- a/example/sample_data_source.py +++ b/example/sample_data_source.py @@ -11,6 +11,13 @@ class SampleDataSource(AthenaDataSource): A hard-coded example that shows the different methods you can implement. """ + @staticmethod + def TransposeData(columns: List[str], data: List[List[Any]]) -> Dict[str, List[Any]]: + """ + Transpose the data so that it is a dictionary of columns. + """ + return dict(zip(columns, list(zip(*data)))) + def __init__(self): super().__init__() @@ -32,9 +39,6 @@ def splits(self, database_name: str, table_name: str) -> List[Dict]: {"name": "split2", "action": "spill"}, ] - def transpose_data(self, cols: List[str], records: List[List[Any]]) -> Dict[str, List[Any]]: - return dict(zip(cols, list(zip(*records)))) - def records( self, database: str, table: str, split: Mapping[str, str] ) -> Dict[str, List[Any]]: @@ -50,4 +54,4 @@ def records( # Demonstrate how splits work by generating a huge response. :) if split.get("action", "") == "spill": records = records * 4000 - return self.transpose_data(self.columns(database, table), records) + return self.TransposeData(self.columns(database, table), records)