From 816612afd753d9a5c2aee9fce5e3eaa6a3d54d3a Mon Sep 17 00:00:00 2001
From: Harshal Sheth
Date: Fri, 9 Feb 2024 17:00:29 -0500
Subject: [PATCH] fix tests

---
 .../src/datahub/ingestion/fs/local_fs.py |  4 +-
 .../src/datahub/ingestion/source/file.py | 61 +++++++++++++------
 2 files changed, 46 insertions(+), 19 deletions(-)

diff --git a/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py b/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py
index 3f4467ee1c762f..b361ee2f350a7a 100644
--- a/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py
+++ b/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py
@@ -11,7 +11,9 @@ def create(cls, **kwargs):
         return LocalFileSystem()
 
     def open(self, path: str, **kwargs: Any) -> Any:
-        return pathlib.Path(path).open(mode="rb", transport_params=kwargs)
+        # Local does not support any additional kwargs
+        assert not kwargs
+        return pathlib.Path(path).open(mode="rb")
 
     def list(self, path: str) -> Iterable[FileInfo]:
         p = pathlib.Path(path)
diff --git a/metadata-ingestion/src/datahub/ingestion/source/file.py b/metadata-ingestion/src/datahub/ingestion/source/file.py
index c3b2a99632daf5..4afeaafcac770f 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/file.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/file.py
@@ -174,7 +174,6 @@ def __init__(self, ctx: PipelineContext, config: FileSourceConfig):
         self.ctx = ctx
         self.config = config
         self.report = FileSourceReport()
-        self.fp: Any = None
 
     @classmethod
     def create(cls, config_dict, ctx):
@@ -225,49 +224,75 @@ def get_report(self):
         return self.report
 
     def close(self):
-        self.close_if_possible(self.fp)
         super().close()
 
-    def _iterate_file(self, file_status: FileInfo) -> Iterable[Tuple[int, Any]]:
+    def _iterate_file(self, file_status: FileInfo) -> Iterable[Any]:
+        if self.config.read_mode == FileReadMode.AUTO:
+            if file_status.size < self.config._minsize_for_streaming_mode_in_bytes:
+                self.config.read_mode = FileReadMode.BATCH
+            else:
+                self.config.read_mode = FileReadMode.STREAM
+
+        # Open the file.
         schema = get_path_schema(file_status.path)
         fs_class = fs_registry.get(schema)
         fs = fs_class.create()
         self.report.current_file_name = file_status.path
         self.report.current_file_size = file_status.size
-        self.fp = fs.open(file_status.path)
+        fp = fs.open(file_status.path)
+
+        with fp:
+            if self.config.read_mode == FileReadMode.STREAM:
+                yield from self._iterate_file_streaming(fp)
+            else:
+                yield from self._iterate_file_batch(fp)
+
+        self.report.files_completed.append(file_status.path)
+        self.report.num_files_completed += 1
+        self.report.total_bytes_read_completed_files += self.report.current_file_size
+        self.report.reset_current_file_stats()
+
+    def _iterate_file_streaming(self, fp: Any) -> Iterable[Any]:
+        # Count the number of elements in the file.
         if self.config.count_all_before_starting:
             count_start_time = datetime.datetime.now()
-            parse_stream = ijson.parse(self.fp, use_float=True)
+            parse_stream = ijson.parse(fp, use_float=True)
             total_elements = 0
-            for row in ijson.items(parse_stream, "item", use_float=True):
+            for _row in ijson.items(parse_stream, "item", use_float=True):
                 total_elements += 1
             count_end_time = datetime.datetime.now()
             self.report.add_count_time(count_end_time - count_start_time)
             self.report.current_file_num_elements = total_elements
-        self.fp.seek(0)
+            fp.seek(0)
+
+        # Read the file.
         self.report.current_file_elements_read = 0
         parse_start_time = datetime.datetime.now()
-        parse_stream = ijson.parse(self.fp, use_float=True)
-        rows_yielded = 0
+        parse_stream = ijson.parse(fp, use_float=True)
         for row in ijson.items(parse_stream, "item", use_float=True):
             parse_end_time = datetime.datetime.now()
             self.report.add_parse_time(parse_end_time - parse_start_time)
-            rows_yielded += 1
             self.report.current_file_elements_read += 1
-            yield rows_yielded, row
-            parse_start_time = datetime.datetime.now()
+            yield row
 
-        self.report.files_completed.append(file_status.path)
-        self.report.num_files_completed += 1
-        self.report.total_bytes_read_completed_files += self.report.current_file_size
-        self.report.reset_current_file_stats()
+    def _iterate_file_batch(self, fp: Any) -> Iterable[Any]:
+        # Read the file.
+        contents = json.load(fp)
+
+        # Maintain backwards compatibility with the single-object format.
+        if isinstance(contents, list):
+            for row in contents:
+                yield row
+        else:
+            yield contents
 
     def iterate_mce_file(self, path: str) -> Iterator[MetadataChangeEvent]:
+        # TODO: Remove this method, as it appears to be unused.
         schema = get_path_schema(path)
         fs_class = fs_registry.get(schema)
         fs = fs_class.create()
         file_status = fs.file_status(path)
-        for i, obj in self._iterate_file(file_status):
+        for obj in self._iterate_file(file_status):
             mce: MetadataChangeEvent = MetadataChangeEvent.from_obj(obj)
             yield mce
 
@@ -283,7 +308,7 @@ def iterate_generic_file(
             ],
         ]
     ]:
-        for i, obj in self._iterate_file(file_status):
+        for i, obj in enumerate(self._iterate_file(file_status)):
             try:
                 deserialize_start_time = datetime.datetime.now()
                 item = _from_obj_for_file(obj)
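
Note for reviewers: if the batch-vs-stream split introduced above is
unfamiliar, here is a minimal standalone sketch of the same idea. The
function name and size threshold below are hypothetical and for
illustration only; the real code takes the cutoff from
config._minsize_for_streaming_mode_in_bytes and opens files through the
fs_registry abstraction rather than a plain open().

    import json
    import pathlib
    from typing import Any, Iterable

    import ijson

    # Hypothetical cutoff; not the value used by the patch.
    _STREAMING_THRESHOLD_BYTES = 100 * 1024 * 1024

    def iterate_json_file(path: str) -> Iterable[Any]:
        """Yield top-level items from a JSON file.

        Small files are loaded with a single json.load() call; large
        files are streamed with ijson so the full array is never held
        in memory at once.
        """
        size = pathlib.Path(path).stat().st_size
        with open(path, "rb") as fp:
            if size < _STREAMING_THRESHOLD_BYTES:
                contents = json.load(fp)
                # A bare object yields a single item, matching the
                # backwards-compatibility branch in _iterate_file_batch.
                if isinstance(contents, list):
                    yield from contents
                else:
                    yield contents
            else:
                # The "item" prefix matches each element of a top-level
                # JSON array; use_float=True yields floats instead of
                # ijson's default Decimal values.
                yield from ijson.items(fp, "item", use_float=True)

The streaming branch mirrors _iterate_file_streaming: ijson parses the
array incrementally, so memory use stays flat even for very large files,
at the cost of not knowing the element count up front (hence the optional
count_all_before_starting pre-pass in the patch).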