From 4e99c199927e0f79c60489e77a49eb00fe139205 Mon Sep 17 00:00:00 2001
From: frankliee
Date: Fri, 22 Sep 2023 19:25:38 +0800
Subject: [PATCH] Python: Create HadoopFileSystem from netloc (merge request !1060) (#8596)

---
 python/pyiceberg/io/pyarrow.py  | 36 +++++++++++++++++----------------
 python/tests/io/test_pyarrow.py | 27 ++++++++++++-------------
 2 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index 2cc20549feb6..7f6045abeda4 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -297,24 +297,24 @@ def to_input_file(self) -> PyArrowFile:
 
 
 class PyArrowFileIO(FileIO):
-    fs_by_scheme: Callable[[str], FileSystem]
+    fs_by_scheme: Callable[[str, Optional[str]], FileSystem]
 
     def __init__(self, properties: Properties = EMPTY_DICT):
-        self.fs_by_scheme: Callable[[str], FileSystem] = lru_cache(self._initialize_fs)
+        self.fs_by_scheme: Callable[[str, Optional[str]], FileSystem] = lru_cache(self._initialize_fs)
         super().__init__(properties=properties)
 
     @staticmethod
-    def parse_location(location: str) -> Tuple[str, str]:
+    def parse_location(location: str) -> Tuple[str, str, str]:
         """Return the path without the scheme."""
         uri = urlparse(location)
         if not uri.scheme:
-            return "file", os.path.abspath(location)
+            return "file", uri.netloc, os.path.abspath(location)
         elif uri.scheme == "hdfs":
-            return uri.scheme, location
+            return uri.scheme, uri.netloc, location
         else:
-            return uri.scheme, f"{uri.netloc}{uri.path}"
+            return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"
 
-    def _initialize_fs(self, scheme: str) -> FileSystem:
+    def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem:
         if scheme in {"s3", "s3a", "s3n"}:
             from pyarrow.fs import S3FileSystem
 
@@ -334,6 +334,8 @@ def _initialize_fs(self, scheme: str) -> FileSystem:
             from pyarrow.fs import HadoopFileSystem
 
             hdfs_kwargs: Dict[str, Any] = {}
+            if netloc:
+                return HadoopFileSystem.from_uri(f"hdfs://{netloc}")
             if host := self.properties.get(HDFS_HOST):
                 hdfs_kwargs["host"] = host
             if port := self.properties.get(HDFS_PORT):
@@ -377,9 +379,9 @@ def new_input(self, location: str) -> PyArrowFile:
         Returns:
             PyArrowFile: A PyArrowFile instance for the given location.
         """
-        scheme, path = self.parse_location(location)
+        scheme, netloc, path = self.parse_location(location)
         return PyArrowFile(
-            fs=self.fs_by_scheme(scheme),
+            fs=self.fs_by_scheme(scheme, netloc),
             location=location,
             path=path,
             buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
@@ -394,9 +396,9 @@ def new_output(self, location: str) -> PyArrowFile:
         Returns:
             PyArrowFile: A PyArrowFile instance for the given location.
         """
-        scheme, path = self.parse_location(location)
+        scheme, netloc, path = self.parse_location(location)
         return PyArrowFile(
-            fs=self.fs_by_scheme(scheme),
+            fs=self.fs_by_scheme(scheme, netloc),
             location=location,
             path=path,
             buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
@@ -415,8 +417,8 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
                 an AWS error code 15.
         """
         str_location = location.location if isinstance(location, (InputFile, OutputFile)) else location
-        scheme, path = self.parse_location(str_location)
-        fs = self.fs_by_scheme(scheme)
+        scheme, netloc, path = self.parse_location(str_location)
+        fs = self.fs_by_scheme(scheme, netloc)
 
         try:
             fs.delete_file(path)
@@ -588,7 +590,7 @@ def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.Fi
 
 
 def _construct_fragment(fs: FileSystem, data_file: DataFile, file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
-    _, path = PyArrowFileIO.parse_location(data_file.file_path)
+    _, _, path = PyArrowFileIO.parse_location(data_file.file_path)
     return _get_file_format(data_file.file_format, **file_format_kwargs).make_fragment(path, fs)
 
 
@@ -810,7 +812,7 @@ def _task_to_table(
         if limit and sum(row_counts) >= limit:
             return None
 
-    _, path = PyArrowFileIO.parse_location(task.file.file_path)
+    _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
         fragment = arrow_format.make_fragment(fin)
@@ -919,9 +921,9 @@ def project_table(
     Raises:
         ResolveError: When an incompatible query is done.
     """
-    scheme, _ = PyArrowFileIO.parse_location(table.location())
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table.location())
     if isinstance(table.io, PyArrowFileIO):
-        fs = table.io.fs_by_scheme(scheme)
+        fs = table.io.fs_by_scheme(scheme, netloc)
     else:
         try:
             from pyiceberg.io.fsspec import FsspecFileIO
diff --git a/python/tests/io/test_pyarrow.py b/python/tests/io/test_pyarrow.py
index 49e1c8bca8c0..8b622125932d 100644
--- a/python/tests/io/test_pyarrow.py
+++ b/python/tests/io/test_pyarrow.py
@@ -1529,17 +1529,16 @@ def test_writing_avro_file_gcs(generated_manifest_entry_file: str, pyarrow_filei
     pyarrow_fileio_gcs.delete(f"gs://warehouse/{filename}")
 
 
-def test_parse_hdfs_location() -> None:
-    locations = ["hdfs://127.0.0.1:9000/root/foo.txt", "hdfs://127.0.0.1/root/foo.txt"]
-    for location in locations:
-        schema, path = PyArrowFileIO.parse_location(location)
-        assert schema == "hdfs"
-        assert location == path
-
-
-def test_parse_local_location() -> None:
-    locations = ["/root/foo.txt", "/root/tmp/foo.txt"]
-    for location in locations:
-        schema, path = PyArrowFileIO.parse_location(location)
-        assert schema == "file"
-        assert location == path
+def test_parse_location() -> None:
+    def check_results(location: str, expected_schema: str, expected_netloc: str, expected_uri: str) -> None:
+        schema, netloc, uri = PyArrowFileIO.parse_location(location)
+        assert schema == expected_schema
+        assert netloc == expected_netloc
+        assert uri == expected_uri
+
+    check_results("hdfs://127.0.0.1:9000/root/foo.txt", "hdfs", "127.0.0.1:9000", "hdfs://127.0.0.1:9000/root/foo.txt")
+    check_results("hdfs://127.0.0.1/root/foo.txt", "hdfs", "127.0.0.1", "hdfs://127.0.0.1/root/foo.txt")
+    check_results("hdfs://clusterA/root/foo.txt", "hdfs", "clusterA", "hdfs://clusterA/root/foo.txt")
+
+    check_results("/root/foo.txt", "file", "", "/root/foo.txt")
+    check_results("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt")