diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 0b18aac0..437e35d5 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -13,7 +13,11 @@ from movement.utils.logging import log_error from movement.validators.datasets import ValidBboxesDataset -from movement.validators.files import ValidFile, ValidVIATracksCSV +from movement.validators.files import ( + DEFAULT_FRAME_REGEXP, + ValidFile, + ValidVIATracksCSV, +) logger = logging.getLogger(__name__) @@ -349,7 +353,7 @@ def from_via_tracks_file( def _numpy_arrays_from_via_tracks_file( - file_path: Path, frame_regexp: str = r"(0\d*)\.\w+$" + file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP ) -> dict: """Extract numpy arrays from the input VIA tracks .csv file. @@ -427,7 +431,7 @@ def _numpy_arrays_from_via_tracks_file( def _df_from_via_tracks_file( - file_path: Path, frame_regexp: str = r"(0\d*)\.\w+$" + file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP ) -> pd.DataFrame: """Load VIA tracks .csv file as a dataframe. @@ -526,7 +530,7 @@ def _extract_confidence_from_via_tracks_df(df: pd.DataFrame) -> np.ndarray: def _extract_frame_number_from_via_tracks_df( - df: pd.DataFrame, frame_regexp: str = r"(0\d*)\.\w+$" + df: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP ) -> np.ndarray: """Extract frame numbers from the VIA tracks input dataframe. diff --git a/movement/validators/files.py b/movement/validators/files.py index 6a910010..dd603b22 100644 --- a/movement/validators/files.py +++ b/movement/validators/files.py @@ -12,6 +12,8 @@ from movement.utils.logging import log_error +DEFAULT_FRAME_REGEXP = r"(0\d*)\.\w+$" + @define class ValidFile: