Skip to content

Commit

Permalink
ignore from_arrays type error
Browse files: browse the repository at this point in the history
  • Loading branch information
drernie committed Aug 24, 2024
1 parent 0f6994f commit f04b18e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
7 changes: 2 additions & 5 deletions athena_federation/batch_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ def spilled(self) -> bool:
def write_rows(self, data: dict[str, list[Any]]):
array_data = [
pa.array(data[name]) for name in self._schema.names
] # type: ignore
record_batch = pa.RecordBatch.from_arrays(
array_data, # type: ignore
schema=[self._schema],
)
]
record_batch = pa.RecordBatch.from_arrays(arrays=array_data, schema=self._schema) # type: ignore
assert record_batch is not str
self._batch_size += record_batch.nbytes
self._writer.write_batch(record_batch)
Expand Down
12 changes: 8 additions & 4 deletions example/sample_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ class SampleDataSource(AthenaDataSource):
A hard-coded example that shows the different methods you can implement.
"""

@staticmethod
def TransposeData(columns: List[str], data: List[List[Any]]) -> Dict[str, List[Any]]:
"""
Transpose the data so that it is a dictionary of columns.
"""
return dict(zip(columns, list(zip(*data))))

def __init__(self):
super().__init__()

Expand All @@ -32,9 +39,6 @@ def splits(self, database_name: str, table_name: str) -> List[Dict]:
{"name": "split2", "action": "spill"},
]

def transpose_data(self, cols: List[str], records: List[List[Any]]) -> Dict[str, List[Any]]:
return dict(zip(cols, list(zip(*records))))

def records(
self, database: str, table: str, split: Mapping[str, str]
) -> Dict[str, List[Any]]:
Expand All @@ -50,4 +54,4 @@ def records(
# Demonstrate how splits work by generating a huge response. :)
if split.get("action", "") == "spill":
records = records * 4000
return self.transpose_data(self.columns(database, table), records)
return self.TransposeData(self.columns(database, table), records)

0 comments on commit f04b18e

Please sign in to comment.