Skip reading files with incorrect extension (#318)
* filter_files_by_extension function

Signed-off-by: Sarah Yurick <[email protected]>

* add type checking

Signed-off-by: Sarah Yurick <[email protected]>

* add filter_by param to get_all_files_paths_under

Signed-off-by: Sarah Yurick <[email protected]>

* isort

Signed-off-by: Sarah Yurick <[email protected]>

* address ayush's comments

Signed-off-by: Sarah Yurick <[email protected]>

* run black

Signed-off-by: Sarah Yurick <[email protected]>

* trailing whitespace

Signed-off-by: Sarah Yurick <[email protected]>

* more whitespace

Signed-off-by: Sarah Yurick <[email protected]>

* address praateek's review

Signed-off-by: Sarah Yurick <[email protected]>

* praateek's review

Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: Sarah Yurick <[email protected]>
Signed-off-by: Sarah Yurick <[email protected]>
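
The diff for the helper and parameter named in these notes (filter_files_by_extension and the filter_by argument of get_all_files_paths_under) is not shown on this page, so the sketch below is only an illustration of what such a filter might look like; the parameter names other than those two, the dot normalization, and the warning message are assumptions rather than the code from this commit.

import warnings
from typing import List, Union


def filter_files_by_extension(
    files_list: List[str],
    keep_extensions: Union[str, List[str]],
) -> List[str]:
    # Accept a single extension or a list, and normalize each one to start with a dot.
    if isinstance(keep_extensions, str):
        keep_extensions = [keep_extensions]
    keep_extensions = [e if e.startswith(".") else "." + e for e in keep_extensions]

    # Keep only the paths whose extension matches; everything else is skipped.
    filtered = [f for f in files_list if f.endswith(tuple(keep_extensions))]
    if len(filtered) < len(files_list):
        warnings.warn("Skipped one or more files with an unexpected file extension.")
    return filtered

Under those assumptions, get_all_files_paths_under could pass its collected paths through a helper like this whenever a caller supplies filter_by, so readers never receive files with the wrong extension.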
sarahyurick authored Nov 18, 2024
1 parent 34a5372 commit d0dd30b
Showing 3 changed files with 203 additions and 67 deletions.
nemo_curator/datasets/doc_dataset.py (96 changes: 67 additions & 29 deletions)
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union
import os
from typing import Any, List, Literal, Optional, Union

import dask.dataframe as dd

@@ -29,26 +30,40 @@ class DocumentDataset:
def __init__(self, dataset_df: dd.DataFrame):
self.df = dataset_df

def __len__(self):
def __len__(self) -> int:
return len(self.df)

def persist(self):
# `def persist(self) -> Self` requires Python 3.11 or higher
def persist(self) -> "DocumentDataset":
return DocumentDataset(self.df.persist())

def head(self, n=5):
def head(self, n: int = 5) -> Any:
return self.df.head(n)

@classmethod
def read_json(
cls,
input_files: Union[str, List[str]],
backend: str = "pandas",
backend: Literal["pandas", "cudf"] = "pandas",
files_per_partition: int = 1,
add_filename: bool = False,
input_meta: Union[str, dict] = None,
columns: Optional[List[str]] = None,
**kwargs,
):
) -> "DocumentDataset":
"""
Read JSONL file(s).
Args:
input_files: The path of the input file(s).
backend: The backend to use for reading the data.
files_per_partition: The number of files to read per partition.
add_filename: Whether to add a "filename" column to the DataFrame.
input_meta: A dictionary or a string formatted as a dictionary, which outlines
the field names and their respective data types within the JSONL input file.
columns: If not None, only these columns will be read from the file.
"""
return cls(
_read_json_or_parquet(
input_files=input_files,
@@ -65,13 +80,25 @@ def read_json(
@classmethod
def read_parquet(
cls,
input_files,
backend="pandas",
files_per_partition=1,
add_filename=False,
input_files: Union[str, List[str]],
backend: Literal["pandas", "cudf"] = "pandas",
files_per_partition: int = 1,
add_filename: bool = False,
columns: Optional[List[str]] = None,
**kwargs,
):
) -> "DocumentDataset":
"""
Read Parquet file(s).
Args:
input_files: The path of the input file(s).
backend: The backend to use for reading the data.
files_per_partition: The number of files to read per partition.
add_filename: Whether to add a "filename" column to the DataFrame.
columns: If not None, only these columns will be read from the file.
There is a significant performance gain when specifying columns for Parquet files.
"""
return cls(
_read_json_or_parquet(
input_files=input_files,
@@ -87,13 +114,24 @@ def read_parquet(
@classmethod
def read_pickle(
cls,
input_files,
backend="pandas",
files_per_partition=1,
add_filename=False,
input_files: Union[str, List[str]],
backend: Literal["pandas", "cudf"] = "pandas",
files_per_partition: int = 1,
add_filename: bool = False,
columns: Optional[List[str]] = None,
**kwargs,
):
) -> "DocumentDataset":
"""
Read Pickle file(s).
Args:
input_files: The path of the input file(s).
backend: The backend to use for reading the data.
files_per_partition: The number of files to read per partition.
add_filename: Whether to add a "filename" column to the DataFrame.
columns: If not None, only these columns will be read from the file.
"""
return cls(
read_data(
input_files=input_files,
@@ -108,12 +146,12 @@ def read_pickle(

def to_json(
self,
output_file_dir,
write_to_filename=False,
keep_filename_column=False,
output_file_dir: str,
write_to_filename: bool = False,
keep_filename_column: bool = False,
):
"""
See nemo_curator.utils.distributed_utils.write_to_disk docstring for other parameters.
See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters.
"""
write_to_disk(
@@ -126,12 +164,12 @@ def to_json(

def to_parquet(
self,
output_file_dir,
write_to_filename=False,
keep_filename_column=False,
output_file_dir: str,
write_to_filename: bool = False,
keep_filename_column: bool = False,
):
"""
See nemo_curator.utils.distributed_utils.write_to_disk docstring for other parameters.
See nemo_curator.utils.distributed_utils.write_to_disk docstring for parameters.
"""
write_to_disk(
@@ -144,8 +182,8 @@ def to_parquet(

def to_pickle(
self,
output_file_dir,
write_to_filename=False,
output_file_dir: str,
write_to_filename: bool = False,
):
raise NotImplementedError("DocumentDataset does not support to_pickle yet")

@@ -190,7 +228,7 @@ def to_pandas(self):
def _read_json_or_parquet(
input_files: Union[str, List[str]],
file_type: str,
backend: str,
backend: Literal["cudf", "pandas"],
files_per_partition: int,
add_filename: bool,
input_meta: Union[str, dict] = None,
@@ -217,8 +255,8 @@ def _read_json_or_parquet(
file_ext = "." + file_type

if isinstance(input_files, list):
# List of jsonl or parquet files
if all(f.endswith(file_ext) for f in input_files):
# List of files
if all(os.path.isfile(f) for f in input_files):
raw_data = read_data(
input_files,
file_type=file_type,
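
For reference, a minimal usage sketch of the reader and writer signatures typed in this diff. The import path follows the file shown above (nemo_curator/datasets/doc_dataset.py); the data paths, column names, and input_meta dtypes are placeholders, not part of the commit.

from nemo_curator.datasets.doc_dataset import DocumentDataset

# Read a list of JSONL shards with the pandas backend, keeping only two columns
# and tagging every record with the file it came from.
dataset = DocumentDataset.read_json(
    input_files=["/data/shard-00.jsonl", "/data/shard-01.jsonl"],  # placeholder paths
    backend="pandas",  # or "cudf", per the Literal annotation
    files_per_partition=1,
    add_filename=True,
    input_meta={"id": "str", "text": "str"},  # field names and dtypes in the JSONL input
    columns=["id", "text"],
)

print(len(dataset))     # __len__ now returns an int
print(dataset.head(3))  # first rows of the underlying Dask DataFrame

# read_parquet gains the same typed signature plus a columns argument; the new
# docstring notes that selecting columns is a significant performance win for Parquet.
parquet_ds = DocumentDataset.read_parquet(
    input_files=["/data/shard-00.parquet", "/data/shard-01.parquet"],
    backend="pandas",
    columns=["id", "text"],
)
parquet_ds.to_parquet(output_file_dir="/data/out/", write_to_filename=False)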