diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py
index ad22fd37dca0c4..e8508a6e7c827c 100644
--- a/metadata-ingestion/setup.py
+++ b/metadata-ingestion/setup.py
@@ -733,6 +733,11 @@
         "file = datahub.ingestion.reporting.file_reporter:FileReporter",
     ],
     "datahub.custom_packages": [],
+    "datahub.fs.plugins": [
+        "s3 = datahub.ingestion.fs.s3_fs:S3FileSystem",
+        "file = datahub.ingestion.fs.local_fs:LocalFileSystem",
+        "http = datahub.ingestion.fs.http_fs:HttpFileSystem",
+    ],
 }
diff --git a/metadata-ingestion/src/datahub/ingestion/fs/__init__.py b/metadata-ingestion/src/datahub/ingestion/fs/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/metadata-ingestion/src/datahub/ingestion/fs/fs_base.py b/metadata-ingestion/src/datahub/ingestion/fs/fs_base.py
new file mode 100644
index 00000000000000..b099d4d332946a
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/fs/fs_base.py
@@ -0,0 +1,40 @@
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Iterable
+from urllib import parse
+
+
+@dataclass
+class FileInfo:
+    path: str
+    size: int
+    is_file: bool
+
+    def __str__(self):
+        return f"FileInfo({self.path}, {self.size}, {self.is_file})"
+
+
+class FileSystem(metaclass=ABCMeta):
+    @classmethod
+    def create(cls, **kwargs: Any) -> "FileSystem":
+        raise NotImplementedError('File system implementations must implement "create"')
+
+    @abstractmethod
+    def open(self, path: str, **kwargs: Any) -> Any:
+        pass
+
+    @abstractmethod
+    def file_status(self, path: str) -> FileInfo:
+        pass
+
+    @abstractmethod
+    def list(self, path: str) -> Iterable[FileInfo]:
+        pass
+
+
+def get_path_schema(path: str) -> str:
+    scheme = parse.urlparse(path).scheme
+    if scheme == "":
+        # This makes the default schema "file" for local paths.
+        scheme = "file"
+    return scheme
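Note: the new `datahub.fs.plugins` entry-point group in setup.py is what gets resolved against the `FileSystem` base class above. As a rough sketch only (the `InMemoryFileSystem` class, the `memory` scheme, and the `my_plugin` package are hypothetical and not part of this change), an external package could plug in like this:

```python
# Hypothetical third-party plugin; only FileSystem and FileInfo come from this diff.
import io
from typing import Any, Iterable

from datahub.ingestion.fs.fs_base import FileInfo, FileSystem


class InMemoryFileSystem(FileSystem):
    """Toy filesystem backed by a dict, for illustration only."""

    _files = {"memory://example.json": b"[]"}

    @classmethod
    def create(cls, **kwargs: Any) -> "InMemoryFileSystem":
        return cls()

    def open(self, path: str, **kwargs: Any) -> Any:
        # Return a binary file-like object, mirroring the other implementations.
        return io.BytesIO(self._files[path])

    def file_status(self, path: str) -> FileInfo:
        data = self._files.get(path)
        if data is None:
            return FileInfo(path, 0, is_file=False)
        return FileInfo(path, len(data), is_file=True)

    def list(self, path: str) -> Iterable[FileInfo]:
        return [self.file_status(p) for p in self._files if p.startswith(path)]


# The plugin package would then expose it under the new entry-point group, e.g.
# "memory = my_plugin.memory_fs:InMemoryFileSystem" in "datahub.fs.plugins".
```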
diff --git a/metadata-ingestion/src/datahub/ingestion/fs/fs_registry.py b/metadata-ingestion/src/datahub/ingestion/fs/fs_registry.py
new file mode 100644
index 00000000000000..cb2349723a4cdb
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/fs/fs_registry.py
@@ -0,0 +1,5 @@
+from datahub.ingestion.api.registry import PluginRegistry
+from datahub.ingestion.fs.fs_base import FileSystem
+
+fs_registry = PluginRegistry[FileSystem]()
+fs_registry.register_from_entrypoint("datahub.fs.plugins")
diff --git a/metadata-ingestion/src/datahub/ingestion/fs/http_fs.py b/metadata-ingestion/src/datahub/ingestion/fs/http_fs.py
new file mode 100644
index 00000000000000..a9153352690adc
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/fs/http_fs.py
@@ -0,0 +1,28 @@
+from typing import Any, Iterable
+
+import requests
+import smart_open
+
+from datahub.ingestion.fs.fs_base import FileInfo, FileSystem
+
+
+class HttpFileSystem(FileSystem):
+    @classmethod
+    def create(cls, **kwargs):
+        return HttpFileSystem()
+
+    def open(self, path: str, **kwargs: Any) -> Any:
+        return smart_open.open(path, mode="rb", transport_params=kwargs)
+
+    def file_status(self, path: str) -> FileInfo:
+        head = requests.head(path)
+        if head.ok:
+            return FileInfo(path, int(head.headers["Content-length"]), is_file=True)
+        elif head.status_code == 404:
+            raise FileNotFoundError(f"Requested path {path} does not exist.")
+        else:
+            raise IOError(f"Cannot get file status for the requested path {path}.")
+
+    def list(self, path: str) -> Iterable[FileInfo]:
+        status = self.file_status(path)
+        return [status]
diff --git a/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py b/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py
new file mode 100644
index 00000000000000..8a546650a3dfe8
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/fs/local_fs.py
@@ -0,0 +1,29 @@
+import os
+import pathlib
+from typing import Any, Iterable
+
+from datahub.ingestion.fs.fs_base import FileInfo, FileSystem
+
+
+class LocalFileSystem(FileSystem):
+    @classmethod
+    def create(cls, **kwargs):
+        return LocalFileSystem()
+
+    def open(self, path: str, **kwargs: Any) -> Any:
+        # Local does not support any additional kwargs
+        assert not kwargs
+        return pathlib.Path(path).open(mode="rb")
+
+    def list(self, path: str) -> Iterable[FileInfo]:
+        p = pathlib.Path(path)
+        if p.is_file():
+            return [self.file_status(path)]
+        else:
+            return iter([self.file_status(str(x)) for x in p.iterdir()])
+
+    def file_status(self, path: str) -> FileInfo:
+        if os.path.isfile(path):
+            return FileInfo(path, os.path.getsize(path), is_file=True)
+        else:
+            return FileInfo(path, 0, is_file=False)
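For reference, the intended lookup pattern (the same one the `file` source adopts later in this diff) is scheme-based resolution through `fs_registry`. A small sketch, assuming an example local path; note that the entry points above register `http` but not `https`:

```python
from datahub.ingestion.fs.fs_base import get_path_schema
from datahub.ingestion.fs.fs_registry import fs_registry

path = "/tmp/mces.json"  # example path; a bare local path defaults to the "file" scheme

scheme = get_path_schema(path)      # "file", "s3", or "http"
fs_class = fs_registry.get(scheme)  # LocalFileSystem, S3FileSystem, or HttpFileSystem
fs = fs_class.create()

for info in fs.list(path):
    if info.is_file:
        with fs.open(info.path) as fp:
            header = fp.read(64)  # every implementation opens in binary mode
```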
diff --git a/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py b/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py
new file mode 100644
index 00000000000000..a135b7b6ce8375
--- /dev/null
+++ b/metadata-ingestion/src/datahub/ingestion/fs/s3_fs.py
@@ -0,0 +1,108 @@
+from collections.abc import Iterator
+from dataclasses import dataclass
+from typing import Any, Iterable
+from urllib.parse import urlparse
+
+import boto3
+import smart_open
+
+from datahub.ingestion.fs import s3_fs
+from datahub.ingestion.fs.fs_base import FileInfo, FileSystem
+
+
+def parse_s3_path(path: str) -> "S3Path":
+    parsed = urlparse(path)
+    return S3Path(parsed.netloc, parsed.path.lstrip("/"))
+
+
+def assert_ok_status(s3_response):
+    is_ok = s3_response["ResponseMetadata"]["HTTPStatusCode"] == 200
+    assert (
+        is_ok
+    ), f"Failed to fetch S3 object, error message: {s3_response['Error']['Message']}"
+
+
+@dataclass
+class S3Path:
+    bucket: str
+    key: str
+
+    def __str__(self):
+        return f"S3Path({self.bucket}, {self.key})"
+
+
+class S3ListIterator(Iterator):
+
+    MAX_KEYS = 1000
+
+    def __init__(
+        self, s3_client: Any, bucket: str, prefix: str, max_keys: int = MAX_KEYS
+    ) -> None:
+        self._s3 = s3_client
+        self._bucket = bucket
+        self._prefix = prefix
+        self._max_keys = max_keys
+        self._file_statuses: Iterator = iter([])
+        self._token = ""
+        self.fetch()
+
+    def __next__(self) -> FileInfo:
+        try:
+            return next(self._file_statuses)
+        except StopIteration:
+            if self._token:
+                self.fetch()
+                return next(self._file_statuses)
+            else:
+                raise StopIteration()
+
+    def fetch(self):
+        params = dict(Bucket=self._bucket, Prefix=self._prefix, MaxKeys=self._max_keys)
+        if self._token:
+            params.update(ContinuationToken=self._token)
+
+        response = self._s3.list_objects_v2(**params)
+
+        s3_fs.assert_ok_status(response)
+
+        self._file_statuses = iter(
+            [
+                FileInfo(f"s3://{response['Name']}/{x['Key']}", x["Size"], is_file=True)
+                for x in response.get("Contents", [])
+            ]
+        )
+        self._token = response.get("NextContinuationToken")
+
+
+class S3FileSystem(FileSystem):
+    def __init__(self, **kwargs):
+        self.s3 = boto3.client("s3", **kwargs)
+
+    @classmethod
+    def create(cls, **kwargs):
+        return S3FileSystem(**kwargs)
+
+    def open(self, path: str, **kwargs: Any) -> Any:
+        transport_params = kwargs.update({"client": self.s3})
+        return smart_open.open(path, mode="rb", transport_params=transport_params)
+
+    def file_status(self, path: str) -> FileInfo:
+        s3_path = parse_s3_path(path)
+        try:
+            response = self.s3.get_object_attributes(
+                Bucket=s3_path.bucket, Key=s3_path.key, ObjectAttributes=["ObjectSize"]
+            )
+            assert_ok_status(response)
+            return FileInfo(path, response["ObjectSize"], is_file=True)
+        except Exception as e:
+            if (
+                hasattr(e, "response")
+                and e.response["ResponseMetadata"]["HTTPStatusCode"] == 404
+            ):
+                return FileInfo(path, 0, is_file=False)
+            else:
+                raise e
+
+    def list(self, path: str) -> Iterable[FileInfo]:
+        s3_path = parse_s3_path(path)
+        return S3ListIterator(self.s3, s3_path.bucket, s3_path.key)
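Usage sketch for the S3 implementation (the bucket, prefix, and region below are made up): `create()` forwards its kwargs to `boto3.client("s3", ...)`, and `list()` returns an `S3ListIterator` that pages through `list_objects_v2` results lazily via the continuation token.

```python
from datahub.ingestion.fs.s3_fs import S3FileSystem

# Kwargs are passed straight through to boto3.client("s3", ...).
fs = S3FileSystem.create(region_name="us-east-1")

# Further pages are fetched on demand as the iterator is consumed.
for info in fs.list("s3://example-bucket/metadata/"):
    print(info.path, info.size)

status = fs.file_status("s3://example-bucket/metadata/mces.json")
if status.is_file:
    with fs.open(status.path) as fp:
        payload = fp.read()
```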
diff --git a/metadata-ingestion/src/datahub/ingestion/source/file.py b/metadata-ingestion/src/datahub/ingestion/source/file.py
index 3e8c88b725de50..853487b1f1c9f7 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/file.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/file.py
@@ -6,16 +6,13 @@
 from dataclasses import dataclass, field
 from enum import auto
 from functools import partial
-from io import BufferedReader
 from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
-from urllib import parse
 
 import ijson
-import requests
 from pydantic import validator
 from pydantic.fields import Field
 
-from datahub.configuration.common import ConfigEnum, ConfigModel, ConfigurationError
+from datahub.configuration.common import ConfigEnum, ConfigModel
 from datahub.configuration.validate_field_deprecation import pydantic_field_deprecated
 from datahub.configuration.validate_field_rename import pydantic_renamed_field
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -35,6 +32,8 @@
 )
 from datahub.ingestion.api.source_helpers import auto_workunit_reporter
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.fs.fs_base import FileInfo, get_path_schema
+from datahub.ingestion.fs.fs_registry import fs_registry
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
     MetadataChangeEvent,
     MetadataChangeProposal,
@@ -186,34 +185,22 @@ def __init__(self, ctx: PipelineContext, config: FileSourceConfig):
         self.ctx = ctx
         self.config = config
         self.report = FileSourceReport()
-        self.fp: Optional[BufferedReader] = None
 
     @classmethod
     def create(cls, config_dict, ctx):
         config = FileSourceConfig.parse_obj(config_dict)
         return cls(ctx, config)
 
-    def get_filenames(self) -> Iterable[str]:
-        path_parsed = parse.urlparse(str(self.config.path))
-        if path_parsed.scheme in ("file", ""):
-            path = pathlib.Path(self.config.path)
-            if path.is_file():
-                self.report.total_num_files = 1
-                return [str(path)]
-            elif path.is_dir():
-                files_and_stats = [
-                    (str(x), os.path.getsize(x))
-                    for x in path.glob(f"*{self.config.file_extension}")
-                    if x.is_file()
-                ]
-                self.report.total_num_files = len(files_and_stats)
-                self.report.total_bytes_on_disk = sum([y for (x, y) in files_and_stats])
-                return [x for (x, y) in files_and_stats]
-            else:
-                raise Exception(f"Failed to process {path}")
-        else:
-            self.report.total_num_files = 1
-            return [str(self.config.path)]
+    def get_filenames(self) -> Iterable[FileInfo]:
+        path_str = str(self.config.path)
+        schema = get_path_schema(path_str)
+        fs_class = fs_registry.get(schema)
+        fs = fs_class.create()
+        for file_info in fs.list(path_str):
+            if file_info.is_file and file_info.path.endswith(
+                self.config.file_extension
+            ):
+                yield file_info
 
     def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
         # No super() call, as we don't want helpers that create / remove workunits
@@ -224,7 +211,7 @@ def get_workunits_internal(
     ) -> Iterable[MetadataWorkUnit]:
         for f in self.get_filenames():
             for i, obj in self.iterate_generic_file(f):
-                id = f"file://{f}:{i}"
+                id = f"{f.path}:{i}"
                 if isinstance(
                     obj, (MetadataChangeProposalWrapper, MetadataChangeProposal)
                 ):
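With `get_filenames()` delegating to the filesystem registry, the same recipe shape now works for local, S3, and HTTP paths. A sketch of a programmatic run (the bucket name is made up; `Pipeline.create` and the `console` sink come from existing datahub APIs, not from this diff):

```python
from datahub.ingestion.run.pipeline import Pipeline

pipeline = Pipeline.create(
    {
        "source": {
            "type": "file",
            "config": {
                "path": "s3://example-bucket/metadata/",  # previously limited to local paths
                "file_extension": ".json",
            },
        },
        "sink": {"type": "console"},
    }
)
pipeline.run()
pipeline.raise_from_status()
```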
@@ -241,99 +228,88 @@ def get_workunits_internal(
                     yield MetadataWorkUnit(id, mcp_raw=obj)
                 else:
                     yield MetadataWorkUnit(id, mce=obj)
+            self.report.total_num_files += 1
+            self.report.append_total_bytes_on_disk(f.size)
 
     def get_report(self):
         return self.report
 
     def close(self):
-        if self.fp:
-            self.fp.close()
         super().close()
 
-    def _iterate_file(self, path: str) -> Iterable[Tuple[int, Any]]:
-        self.report.current_file_name = path
-        path_parsed = parse.urlparse(path)
-        if path_parsed.scheme not in ("http", "https"):  # A local file
-            self.report.current_file_size = os.path.getsize(path)
-            if self.config.read_mode == FileReadMode.AUTO:
-                file_read_mode = (
-                    FileReadMode.BATCH
-                    if self.report.current_file_size
-                    < self.config._minsize_for_streaming_mode_in_bytes
-                    else FileReadMode.STREAM
-                )
-                logger.info(f"Reading file {path} in {file_read_mode} mode")
+    def _iterate_file(self, file_status: FileInfo) -> Iterable[Any]:
+        file_read_mode = self.config.read_mode
+        if file_read_mode == FileReadMode.AUTO:
+            if file_status.size < self.config._minsize_for_streaming_mode_in_bytes:
+                file_read_mode = FileReadMode.BATCH
             else:
-                file_read_mode = self.config.read_mode
-
-            if file_read_mode == FileReadMode.BATCH:
-                with open(path) as f:
-                    parse_start_time = datetime.datetime.now()
-                    obj_list = json.load(f)
-                    parse_end_time = datetime.datetime.now()
-                    self.report.add_parse_time(parse_end_time - parse_start_time)
-                if not isinstance(obj_list, list):
-                    obj_list = [obj_list]
-                count_start_time = datetime.datetime.now()
-                self.report.current_file_num_elements = len(obj_list)
-                self.report.add_count_time(datetime.datetime.now() - count_start_time)
-                self.report.current_file_elements_read = 0
-                for i, obj in enumerate(obj_list):
-                    yield i, obj
-                    self.report.current_file_elements_read += 1
+                file_read_mode = FileReadMode.STREAM
+
+        # Open the file.
+        schema = get_path_schema(file_status.path)
+        fs_class = fs_registry.get(schema)
+        fs = fs_class.create()
+        self.report.current_file_name = file_status.path
+        self.report.current_file_size = file_status.size
+        fp = fs.open(file_status.path)
+
+        with fp:
+            if file_read_mode == FileReadMode.STREAM:
+                yield from self._iterate_file_streaming(fp)
             else:
-                self.fp = open(path, "rb")
-                if self.config.count_all_before_starting:
-                    count_start_time = datetime.datetime.now()
-                    parse_stream = ijson.parse(self.fp, use_float=True)
-                    total_elements = 0
-                    for row in ijson.items(parse_stream, "item", use_float=True):
-                        total_elements += 1
-                    count_end_time = datetime.datetime.now()
-                    self.report.add_count_time(count_end_time - count_start_time)
-                    self.report.current_file_num_elements = total_elements
-                    self.report.current_file_elements_read = 0
-                    self.fp.seek(0)
-                parse_start_time = datetime.datetime.now()
-                parse_stream = ijson.parse(self.fp, use_float=True)
-                rows_yielded = 0
-                for row in ijson.items(parse_stream, "item", use_float=True):
-                    parse_end_time = datetime.datetime.now()
-                    self.report.add_parse_time(parse_end_time - parse_start_time)
-                    rows_yielded += 1
-                    self.report.current_file_elements_read += 1
-                    yield rows_yielded, row
-                    parse_start_time = datetime.datetime.now()
-        else:
-            try:
-                response = requests.get(path)
-                parse_start_time = datetime.datetime.now()
-                data = response.json()
-            except Exception as e:
-                raise ConfigurationError(f"Cannot read remote file {path}, error:{e}")
-            if not isinstance(data, list):
-                data = [data]
-            parse_end_time = datetime.datetime.now()
-            self.report.add_parse_time(parse_end_time - parse_start_time)
-            self.report.current_file_size = len(response.content)
-            self.report.current_file_elements_read = 0
-            for i, obj in enumerate(data):
-                yield i, obj
-                self.report.current_file_elements_read += 1
+                yield from self._iterate_file_batch(fp)
 
-        self.report.files_completed.append(path)
+        self.report.files_completed.append(file_status.path)
         self.report.num_files_completed += 1
         self.report.total_bytes_read_completed_files += self.report.current_file_size
         self.report.reset_current_file_stats()
 
+    def _iterate_file_streaming(self, fp: Any) -> Iterable[Any]:
+        # Count the number of elements in the file.
+        if self.config.count_all_before_starting:
+            count_start_time = datetime.datetime.now()
+            parse_stream = ijson.parse(fp, use_float=True)
+            total_elements = 0
+            for _row in ijson.items(parse_stream, "item", use_float=True):
+                total_elements += 1
+            count_end_time = datetime.datetime.now()
+            self.report.add_count_time(count_end_time - count_start_time)
+            self.report.current_file_num_elements = total_elements
+            fp.seek(0)
+
+        # Read the file.
+        self.report.current_file_elements_read = 0
+        parse_start_time = datetime.datetime.now()
+        parse_stream = ijson.parse(fp, use_float=True)
+        for row in ijson.items(parse_stream, "item", use_float=True):
+            parse_end_time = datetime.datetime.now()
+            self.report.add_parse_time(parse_end_time - parse_start_time)
+            self.report.current_file_elements_read += 1
+            yield row
+
+    def _iterate_file_batch(self, fp: Any) -> Iterable[Any]:
+        # Read the file.
+        contents = json.load(fp)
+
+        # Maintain backwards compatibility with the single-object format.
+        if isinstance(contents, list):
+            for row in contents:
+                yield row
+        else:
+            yield contents
+
     def iterate_mce_file(self, path: str) -> Iterator[MetadataChangeEvent]:
-        for i, obj in self._iterate_file(path):
+        # TODO: Remove this method, as it appears to be unused.
+        schema = get_path_schema(path)
+        fs_class = fs_registry.get(schema)
+        fs = fs_class.create()
+        file_status = fs.file_status(path)
+        for obj in self._iterate_file(file_status):
             mce: MetadataChangeEvent = MetadataChangeEvent.from_obj(obj)
             yield mce
 
     def iterate_generic_file(
-        self,
-        path: str,
+        self, file_status: FileInfo
     ) -> Iterator[
         Tuple[
             int,
@@ -344,7 +320,7 @@ def iterate_generic_file(
             ],
         ]
     ]:
-        for i, obj in self._iterate_file(path):
+        for i, obj in enumerate(self._iterate_file(file_status)):
             try:
                 deserialize_start_time = datetime.datetime.now()
                 item = _from_obj_for_file(obj)
@@ -389,6 +365,11 @@ def test_connection(config_dict: dict) -> TestConnectionReport:
             basic_connectivity=CapabilityReport(capable=True)
         )
 
+    @staticmethod
+    def close_if_possible(stream):
+        if hasattr(stream, "close") and callable(stream.close):
+            stream.close()
+
 
 def _from_obj_for_file(
     obj: dict,
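The batch path keeps accepting both payload shapes the source has historically supported. A minimal, standalone illustration of that contract (using plain `json`/`io` rather than the source class itself):

```python
import io
import json


def iterate_batch(fp):
    # Mirrors _iterate_file_batch: a JSON array yields its elements,
    # while a single JSON object is yielded as-is.
    contents = json.load(fp)
    if isinstance(contents, list):
        yield from contents
    else:
        yield contents


assert list(iterate_batch(io.StringIO('[{"a": 1}, {"a": 2}]'))) == [{"a": 1}, {"a": 2}]
assert list(iterate_batch(io.StringIO('{"a": 1}'))) == [{"a": 1}]
```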
diff --git a/metadata-ingestion/tests/unit/test_plugin_system.py b/metadata-ingestion/tests/unit/test_plugin_system.py
index 4d1ebce2be849f..0e12416325bf9a 100644
--- a/metadata-ingestion/tests/unit/test_plugin_system.py
+++ b/metadata-ingestion/tests/unit/test_plugin_system.py
@@ -7,6 +7,7 @@
 from datahub.ingestion.api.registry import PluginRegistry
 from datahub.ingestion.api.sink import Sink
 from datahub.ingestion.extractor.extractor_registry import extractor_registry
+from datahub.ingestion.fs.fs_registry import fs_registry
 from datahub.ingestion.reporting.reporting_provider_registry import (
     reporting_provider_registry,
 )
@@ -54,6 +55,7 @@
         (reporting_provider_registry, ["datahub", "file"]),
         (ingestion_checkpoint_provider_registry, ["datahub"]),
         (lite_registry, ["duckdb"]),
+        (fs_registry, ["file", "http", "s3"]),
     ],
 )
 def test_registry_defaults(registry: PluginRegistry, expected: List[str]) -> None:
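A quick way to exercise the new registry beyond the parametrized defaults check above (a sketch, assuming the entry points from setup.py are installed and the boto3/smart_open dependencies are available; this test is not part of the diff):

```python
from datahub.ingestion.fs.fs_registry import fs_registry


def test_fs_registry_resolves_schemes() -> None:
    # Each default scheme should resolve to a FileSystem implementation.
    for scheme in ["file", "http", "s3"]:
        assert fs_registry.get(scheme) is not None

    # The "file" implementation should be able to stat this test file itself.
    local_fs = fs_registry.get("file").create()
    assert local_fs.file_status(__file__).is_file
```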