Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor output #170

Merged
merged 15 commits into from
May 19, 2024
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ classifiers =
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Natural Language :: English
Intended Audience :: Science/Research
Topic :: Scientific/Engineering
Expand Down
17 changes: 7 additions & 10 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,33 +72,31 @@
)

output_eleme = [
toughio.Output(
"element",
toughio.ElementOutput(
float(time),
np.array([f"AAA0{i}" for i in range(10)]),
{
"X": np.random.rand(10),
"Y": np.random.rand(10),
"Z": np.random.rand(10),
"PRES": np.random.rand(10),
"TEMP": np.random.rand(10),
},
np.array([f"AAA0{i}" for i in range(10)]),
)
for time in range(3)
]

output_conne = [
toughio.Output(
"connection",
toughio.ConnectionOutput(
float(time),
np.array([[f"AAA0{i}", f"AAA0{i}"] for i in range(10)]),
{
"X": np.random.rand(10),
"Y": np.random.rand(10),
"Z": np.random.rand(10),
"HEAT": np.random.rand(10),
"FLOW": np.random.rand(10),
},
np.array([[f"AAA0{i}", f"AAA0{i}"] for i in range(10)]),
)
for time in range(3)
]
Expand Down Expand Up @@ -170,10 +168,9 @@ def allclose(x, y, atol=1.0e-8, ignore_keys=None, ignore_none=False):
if x.cell_data:
assert allclose(x.cell_data, y.cell_data, atol=atol)

elif isinstance(x, toughio.Output):
assert isinstance(y, toughio.Output)
elif isinstance(x, (toughio.ElementOutput, toughio.ConnectionOutput)):
assert isinstance(y, (toughio.ElementOutput, toughio.ConnectionOutput))

assert allclose(x.type, y.type, atol=atol)
assert allclose(x.time, y.time, atol=atol)
assert allclose(x.data, y.data, atol=atol)

Expand Down Expand Up @@ -205,7 +202,7 @@ def convert_outputs_labels(outputs, connection=False):
output.labels[:] = toughio.convert_labels(output.labels)

else:
labels = toughio.convert_labels(output.labels.ravel())
labels = toughio.convert_labels(np.ravel(output.labels))
output.labels[:] = labels.reshape((labels.size // 2, 2))

except TypeError:
Expand Down
7 changes: 6 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,18 +166,23 @@ def test_extract(file_format, split, connection):
this_dir, "support_files", "outputs", f"{base_filename}.csv"
)
outputs_ref = toughio.read_output(filename_ref)
outputs_ref = outputs_ref if isinstance(outputs_ref, list) else [outputs_ref]

if not split:
outputs = toughio.read_output(output_filename, connection=connection)
outputs = outputs if isinstance(outputs, list) else [outputs]

for output_ref, output in zip(outputs_ref, outputs):
assert output_ref.time == output.time
for k, v in output_ref.data.items():
assert helpers.allclose(v.mean(), output.data[k].mean(), atol=1.0e-2)

else:
filenames = glob.glob(os.path.join(tempdir, f"{base_filename}_*.csv"))

for i, output_filename in enumerate(sorted(filenames)):
outputs = toughio.read_output(output_filename)
outputs = outputs if isinstance(outputs, list) else [outputs]

assert len(outputs) == 1

Expand Down Expand Up @@ -275,5 +280,5 @@ def test_save2incon(reset):

incon = toughio.read_output(output_filename)

assert save.labels.tolist() == incon.labels.tolist()
assert list(save.labels) == list(incon.labels)
helpers.allclose(save.data, incon.data)
66 changes: 64 additions & 2 deletions tests/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def test_output_eleme(filename, filename_ref, file_format):
for output, time_ref in zip(outputs, times_ref):
assert time_ref == output.time
assert (
save.labels.tolist() == output.labels.tolist()
list(save.labels) == list(output.labels)
if file_format in {"csv", "petrasim", "tough"}
else len(output.labels) == 0
else not output.labels
)
if file_format != "tough":
assert keys_ref == sorted(list(output.data))
Expand Down Expand Up @@ -110,12 +110,36 @@ def test_output(output_ref, file_format):
reader_kws={},
)

output = output if isinstance(output, list) else [output]
output_ref = output_ref if isinstance(output_ref, list) else [output_ref]
for out_ref, out in zip(output_ref, output):
# Careful here, tecplot format has no label
helpers.allclose(out, out_ref)


@pytest.mark.parametrize(
"filename",
[
"OUTPUT_ELEME.csv",
"OUTPUT_ELEME.tec",
"OUTPUT_ELEME_PETRASIM.csv",
"OUTPUT.out",
"OUTPUT_CONNE.csv",
],
)
def test_output_time_steps(filename):
this_dir = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(this_dir, "support_files", "outputs", filename)
outputs_ref = toughio.read_output(filename)

time_steps = [0, 2, -1]
outputs = toughio.read_output(filename, time_steps=time_steps)
outputs_ref = [outputs_ref[time_step] for time_step in time_steps]

for out_ref, out in zip(outputs_ref, outputs):
helpers.allclose(out, out_ref)


def test_save():
this_dir = os.path.dirname(os.path.abspath(__file__))
filename = os.path.join(this_dir, "support_files", "outputs", "SAVE.out")
Expand All @@ -129,3 +153,41 @@ def test_save():
assert helpers.allclose(0.01, save.data["porosity"].mean())

assert "userx" not in save.data


@pytest.mark.parametrize(
"output_ref, islice",
[
(helpers.output_eleme[0], 0),
(helpers.output_eleme[0], [0, 2]),
(helpers.output_eleme[0], "AAA00"),
(helpers.output_eleme[0], ["AAA00", "AAA02"]),
(helpers.output_conne[0], 0),
(helpers.output_conne[0], [0, 2]),
(helpers.output_conne[0], "AAA00"),
],
)
def test_getitem(output_ref, islice):
output = output_ref[islice]

idx = [islice] if isinstance(islice, (int, str)) else islice
idx = [i if isinstance(i, int) else int(i[-1]) for i in idx]

if not isinstance(output, dict):
assert np.allclose(output.time, output_ref.time)
assert len(idx) == output.n_data

for i, iref in enumerate(idx):
if isinstance(output.labels[i], str):
assert output.labels[i] == output_ref.labels[iref]

else:
for label, label_ref in zip(output.labels[i], output_ref.labels[iref]):
assert label == label_ref

for k, v in output.data.items():
assert np.allclose(v[i], output_ref.data[k][iref])

else:
for k, v in output.items():
assert np.allclose(v, output_ref.data[k][idx[0]])
2 changes: 1 addition & 1 deletion toughio/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.14.2
1.15.0
6 changes: 4 additions & 2 deletions toughio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from .__about__ import __version__
from ._helpers import convert_labels
from ._io import (
Output,
ConnectionOutput,
ElementOutput,
read_input,
read_output,
read_table,
Expand All @@ -24,7 +25,8 @@
__all__ = [
"Mesh",
"CellBlock",
"Output",
"ElementOutput",
"ConnectionOutput",
"meshmaker",
"register_input",
"register_output",
Expand Down
18 changes: 8 additions & 10 deletions toughio/_cli/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import sys

from .. import read_mesh, read_output, write_time_series
from .._io.output._common import reorder_labels
from ..meshmaker import triangulate, voxelize

parser = _get_parser()
Expand All @@ -45,18 +44,17 @@

# Read output file
print(f"Reading file '{args.infile}' ...", end="")

sys.stdout.flush()
output = read_output(args.infile)
if args.file_format != "xdmf":
if args.time_step is not None:
if not (-len(output) <= args.time_step < len(output)):
raise ValueError("Inconsistent time step value.")
output = output[args.time_step]
else:
output = output[-1]
time_step = args.time_step if args.time_step is not None else -1
output = read_output(args.infile, time_steps=time_step)
labels = output.labels

else:
output = read_output(args.infile)
labels = output[-1].labels

print(" Done!")

with_mesh = bool(args.mesh)
Expand Down Expand Up @@ -144,13 +142,13 @@

if args.file_format != "xdmf":
mesh.point_data = {}
mesh.cell_dada = {}
mesh.cell_data = {}

Check warning on line 145 in toughio/_cli/_export.py

View check run for this annotation

Codecov / codecov/patch

toughio/_cli/_export.py#L145

Added line #L145 was not covered by tests
mesh.field_data = {}
mesh.point_sets = {}
mesh.cell_sets = {}
mesh.read_output(output)
else:
output = [reorder_labels(data, mesh.labels) for data in output]
output = [out[mesh.labels] for out in output]

Check warning on line 151 in toughio/_cli/_export.py

View check run for this annotation

Codecov / codecov/patch

toughio/_cli/_export.py#L151

Added line #L151 was not covered by tests
print(" Done!")

# Output file name
Expand Down
5 changes: 3 additions & 2 deletions toughio/_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from .input import read as read_input
from .input import register as register_input
from .input import write as write_input
from .output import Output
from .output import ConnectionOutput, ElementOutput
from .output import read as read_output
from .output import register as register_output
from .output import write as write_output
from .table import read as read_table
from .table import register as register_table

__all__ = [
"Output",
"ElementOutput",
"ConnectionOutput",
"register_input",
"register_output",
"read_input",
Expand Down
12 changes: 6 additions & 6 deletions toughio/_io/h5/_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import h5py

from ..output import Output
from ..output import ConnectionOutput, ElementOutput
from ..output import read as read_output
from ..table import read as read_table

Expand All @@ -25,9 +25,9 @@ def write(
----------
filename : str or pathlike
Output file name.
elements : namedtuple, list of namedtuple, str, pathlike or None, optional, default None
elements : str, pathlike, :class:`toughio.ElementOutput`, sequence of :class:`toughio.ElementOutput` or None, optional, default None
Element outputs to export.
connections : namedtuple, list of namedtuple, str, pathlike or None, optional, default None
connections : str, pathlike, :class:`toughio.ConnectionOutput`, sequence of :class:`toughio.ConnectionOutput` or None, optional, default None
Connection outputs to export.
element_history : dict or None, optional, default None
Element history to export.
Expand Down Expand Up @@ -79,20 +79,20 @@ def _write_output(f, outputs, labels_order, connection, **kwargs):
if isinstance(outputs, (str, pathlib.Path)):
outputs = read_output(outputs, labels_order=labels_order, connection=connection)

if isinstance(outputs, Output):
if isinstance(outputs, (ElementOutput, ConnectionOutput)):
outputs = [outputs]

elif isinstance(outputs, (list, tuple)):
for output in outputs:
if not isinstance(output, Output):
if not isinstance(output, (ElementOutput, ConnectionOutput)):
raise ValueError()

else:
raise ValueError()

for output in outputs:
group = f.create_group(f"time={output.time}")
group.create_dataset("labels", data=output.labels.astype("S"), **kwargs)
group.create_dataset("labels", data=list(output.labels), **kwargs)

for k, v in output.data.items():
group.create_dataset(k, data=v, **kwargs)
Expand Down
5 changes: 3 additions & 2 deletions toughio/_io/output/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from . import csv, petrasim, save, tecplot, tough
from ._common import Output
from ._common import ConnectionOutput, ElementOutput
from ._helpers import read, register, write

__all__ = [
"Output",
"ElementOutput",
"ConnectionOutput",
"register",
"read",
"write",
Expand Down
Loading
Loading