Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process text draft #179

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 88 additions & 16 deletions audinterface/core/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import itertools
import os
import pathlib
import typing

import numpy as np
Expand Down Expand Up @@ -36,6 +37,15 @@ def identity(signal, sampling_rate) -> np.ndarray:
return signal


def data_identity(data):
    r"""Return ``data`` unchanged.

    Counterpart of the signal identity function
    for non-signal data:
    no sampling rate is involved,
    so the input is simply passed through.

    """
    return data


class Process:
r"""Processing interface.

Expand Down Expand Up @@ -212,6 +222,7 @@ def __init__(
self.num_workers = num_workers
r"""Number of workers."""

# leaving process_func unaltered here, currently deferred
self.process_func = process_func
r"""Processing function."""

Expand Down Expand Up @@ -271,10 +282,14 @@ def _process_file(
end = utils.to_timedelta(end, self.sampling_rate)

ext = audeer.file_extension(file).lower()

# print(ext, exit is None)
# Text files
if ext in ["json", "txt"]:
self._processing_mode = "text" # convenience
print("set processing mode")
data = utils.read_text(file, root=root)
# should be idempotent, but is currently deferred
# self._handle_processing_func_args(data_type="text")
y, file = self._process_data(
data,
idx=idx,
Expand All @@ -288,6 +303,7 @@ def _process_file(

# Audio/video files
else:
self._processing_mode = "signal" # convenience
signal, sampling_rate = utils.read_audio(
file,
start=start,
Expand Down Expand Up @@ -568,6 +584,22 @@ def process_folder(
process_func_args=process_func_args,
)

def _set_processing_mode(self, index):
    r"""Set processing mode depending on the file extensions in ``index``.

    Sets ``self._processing_mode`` to ``"text"``
    when every file referenced by ``index``
    has a ``json`` or ``txt`` extension,
    and to ``"signal"`` otherwise.

    Args:
        index: segmented index
            (files in its ``file`` level)
            or filewise index
            (file paths as the index values)

    """
    self._processing_mode = "signal"

    # Support both index types;
    # previously the filewise branch was left unhandled,
    # which raised a NameError below
    if audformat.is_segmented_index(index):
        files = index.get_level_values("file")
    else:
        files = index

    # Lower-case to match the extension handling
    # in _process_file()
    extensions = {pathlib.Path(file).suffix[1:].lower() for file in files}

    # Require at least one file,
    # otherwise an empty index would be classified as text
    if extensions and extensions <= {"json", "txt"}:
        self._processing_mode = "text"

def _process_index_wo_segment(
self,
index: pd.Index,
Expand All @@ -592,6 +624,9 @@ def _process_index_wo_segment(
for idx, (file, start, end) in enumerate(index)
]

# modify processing mode variable when getting json or text files
self._set_processing_mode(index)

xs = audeer.run_tasks(
self._process_file,
params,
Expand All @@ -601,20 +636,38 @@ def _process_index_wo_segment(
task_description=f"Process {len(index)} segments",
)

y = list(itertools.chain.from_iterable([x[0] for x in xs]))
files = list(itertools.chain.from_iterable([x[1] for x in xs]))
starts = list(itertools.chain.from_iterable([x[2] for x in xs]))
ends = list(itertools.chain.from_iterable([x[3] for x in xs]))
if self._processing_mode == "text":
y = [x[0] for x in xs]
files = list(itertools.chain.from_iterable([x[1] for x in xs]))
starts = [x[2] for x in xs]
ends = [x[3] for x in xs]
if (
len(audeer.unique(starts)) == 1
and audeer.unique(starts)[0] is None
and len(audeer.unique(ends)) == 1
and audeer.unique(ends)[0] is None
):
index = audformat.filewise_index(files)
else:
# leave index untouched
# index = audformat.segmented_index(files, starts, ends)
pass

if (
len(audeer.unique(starts)) == 1
and audeer.unique(starts)[0] is None
and len(audeer.unique(ends)) == 1
and audeer.unique(ends)[0] is None
):
index = audformat.filewise_index(files)
else:
index = audformat.segmented_index(files, starts, ends)
y = list(itertools.chain.from_iterable([x[0] for x in xs]))
files = list(itertools.chain.from_iterable([x[1] for x in xs]))
starts = list(itertools.chain.from_iterable([x[2] for x in xs]))
ends = list(itertools.chain.from_iterable([x[3] for x in xs]))

if (
len(audeer.unique(starts)) == 1
and audeer.unique(starts)[0] is None
and len(audeer.unique(ends)) == 1
and audeer.unique(ends)[0] is None
):
index = audformat.filewise_index(files)
else:
index = audformat.segmented_index(files, starts, ends)

y = pd.Series(y, index)

Expand Down Expand Up @@ -686,9 +739,7 @@ def process_index(
)

y = self._process_index_wo_segment(
segmented_index,
root,
process_func_args=process_func_args,
segmented_index, root, process_func_args=process_func_args
)

if cache_path is not None:
Expand Down Expand Up @@ -1117,6 +1168,25 @@ def _helper(x):

return y

def _handle_processing_func_for_text(self):
    r"""Switch the processing function to the text identity.

    Text data never has a sampling rate,
    so the (signal based) processing function
    is replaced by :func:`data_identity`
    and the cached signature is updated to match.

    """
    # Assign directly instead of going through
    # redundant local aliases
    self.process_func = data_identity
    signature = inspect.signature(data_identity)
    self._process_func_signature = dict(signature.parameters)

@staticmethod
def _handle_processing_func_args_for_text(process_func_args):
    """If process_func_args has an sr, strip it."""
    # NOTE(review): not implemented yet (body is a placeholder) —
    # presumably meant to remove a sampling-rate entry
    # from ``process_func_args`` before text processing;
    # confirm intended key name before implementing.
    ...

def _call_data(
self,
data: typing.Any,
Expand Down Expand Up @@ -1146,7 +1216,9 @@ def _call_data(
"""
process_func_args = process_func_args or self.process_func_args
special_args = self._special_args(idx, root, file, process_func_args)
self._handle_processing_func_for_text()
y = self.process_func(data, **special_args, **process_func_args)

return y

def _special_args(
Expand Down
1 change: 1 addition & 0 deletions audinterface/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from audinterface.core.utils import read_audio
from audinterface.core.utils import read_text
from audinterface.core.utils import signal_index
from audinterface.core.utils import sliding_window
from audinterface.core.utils import to_timedelta
28 changes: 18 additions & 10 deletions tests/test_process_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ def test_process_index(
):
cache_root = os.path.join(tmpdir, "cache")

process_func = None
process = audinterface.Process(
process_func=None,
process_func=process_func,
num_workers=num_workers,
multiprocessing=multiprocessing,
verbose=False,
Expand Down Expand Up @@ -239,9 +240,16 @@ def test_process_index(
)
if preserve_index:
pd.testing.assert_index_equal(y.index, index)
for (path, _, _), value in y.items():
assert audinterface.utils.read_text(path) == data
assert value == data

# only works for preserved index, otherwise too many to unpack
if preserve_index:
for (path, _, _), value in y.items():
assert audinterface.utils.read_text(path) == data
assert value == data
else:
for path, value in y.items():
assert audinterface.utils.read_text(path) == data
assert value == data

# Segmented index with relative paths
index = audformat.segmented_index(
Expand All @@ -256,9 +264,9 @@ def test_process_index(
)
if preserve_index:
pd.testing.assert_index_equal(y.index, index)
for (file, _, _), value in y.items():
assert audinterface.utils.read_text(file, root=root) == data
assert value == data
for (file, _, _), value in y.items():
assert audinterface.utils.read_text(file, root=root) == data
assert value == data

# Filewise index with absolute paths
index = audformat.filewise_index(path)
Expand All @@ -274,7 +282,7 @@ def test_process_index(
else:
expected_index = audformat.filewise_index(files=list(index))
pd.testing.assert_index_equal(y.index, expected_index)
for (path, _, _), value in y.items():
for path, value in y.items():
assert audinterface.utils.read_text(path) == data
assert value == data

Expand All @@ -291,7 +299,7 @@ def test_process_index(
assert audinterface.utils.read_text(file, root=root) == data
assert value == data
else:
for (file, _, _), value in y.items():
for file, value in y.items():
assert audinterface.utils.read_text(file, root=root) == data
assert value == data

Expand All @@ -305,7 +313,7 @@ def test_process_index(
os.remove(path)

# Fails because second file does not exist
with pytest.raises(RuntimeError):
with pytest.raises(FileNotFoundError):
process.process_index(
index,
preserve_index=preserve_index,
Expand Down