Skip to content

Commit

Permalink
Merge pull request #38 from PlasmaFAIR/latex-variable-name-processing
Browse files Browse the repository at this point in the history
Latex variable name processing
  • Loading branch information
JoelLucaAdams authored Nov 21, 2024
2 parents 1172e6f + 362e2ab commit 810510d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
42 changes: 35 additions & 7 deletions src/sdf_xarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import pathlib
import re
from collections import Counter, defaultdict
from itertools import product
from typing import Iterable

import numpy as np
Expand All @@ -21,6 +23,29 @@ def _rename_with_underscore(name: str) -> str:
return name.replace("/", "_").replace(" ", "_").replace("-", "_")


def _process_latex_name(variable_name: str) -> str:
"""Converts variable names to LaTeX format where possible
using the following rules:
- E -> $E_x$
- E -> $E_y$
- E -> $E_z$
This repeats for B, J and P. It only changes the variable
name if there are spaces around the affix (prefix + suffix)
or if there is no trailing space. This is to avoid changing variable
names that may contain these affixes as part of the variable name itself.
"""
prefixes = ["E", "B", "J", "P"]
suffixes = ["x", "y", "z"]
for prefix, suffix in product(prefixes, suffixes):
# Match affix with preceding space and trailing space or end of string
affix_pattern = rf"\b{prefix}{suffix}\b"
# Insert LaTeX format while preserving spaces
replacement = rf"${prefix}_{suffix}$"
variable_name = re.sub(affix_pattern, replacement, variable_name)
return variable_name


def combine_datasets(path_glob: Iterable | str, **kwargs) -> xr.Dataset:
"""Combine all datasets using a single time dimension"""

Expand Down Expand Up @@ -271,7 +296,7 @@ def _process_grid_name(grid_name: str, transform_func) -> str:
dim_name,
coord,
{
"long_name": label,
"long_name": label.replace("_", " "),
"units": unit,
"point_data": value.is_point_data,
"full_name": value.name,
Expand All @@ -290,11 +315,6 @@ def _process_grid_name(grid_name: str, transform_func) -> str:
continue

if isinstance(value, Constant) or value.grid is None:
data_attrs = {}
data_attrs["full_name"] = key
if value.units is not None:
data_attrs["units"] = value.units

# We don't have a grid, either because it's just a
# scalar, or because it's an array over something
# else. We have no more information, so just make up
Expand All @@ -303,6 +323,12 @@ def _process_grid_name(grid_name: str, transform_func) -> str:
dims = [f"dim_{key}_{n}" for n, _ in enumerate(shape)]
base_name = _rename_with_underscore(key)

data_attrs = {}
data_attrs["full_name"] = key
data_attrs["long_name"] = base_name.replace("_", " ")
if value.units is not None:
data_attrs["units"] = value.units

data_vars[base_name] = Variable(dims, value.data, attrs=data_attrs)
continue

Expand Down Expand Up @@ -341,13 +367,15 @@ def _process_grid_name(grid_name: str, transform_func) -> str:
]

# TODO: error handling here? other attributes?
base_name = _rename_with_underscore(key)
long_name = _process_latex_name(base_name.replace("_", " "))
data_attrs = {
"units": value.units,
"point_data": value.is_point_data,
"full_name": key,
"long_name": long_name,
}
lazy_data = indexing.LazilyIndexedArray(SDFBackendArray(key, self))
base_name = _rename_with_underscore(key)
data_vars[base_name] = Variable(var_coords, lazy_data, data_attrs)

# TODO: might need to decode if mult is set?
Expand Down
34 changes: 33 additions & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import xarray as xr

from sdf_xarray import SDFPreprocess, open_mfdataset
from sdf_xarray import SDFPreprocess, _process_latex_name, open_mfdataset

EXAMPLE_FILES_DIR = pathlib.Path(__file__).parent / "example_files"
EXAMPLE_MISMATCHED_FILES_DIR = (
Expand Down Expand Up @@ -115,6 +115,38 @@ def test_time_dim_units():
assert df["time"].full_name == "time"


def test_latex_rename_variables():
    """Check that ``long_name`` attributes use LaTeX for field components."""
    dataset = xr.open_mfdataset(
        EXAMPLE_ARRAYS_DIR.glob("*.sdf"),
        preprocess=SDFPreprocess(),
        keep_particles=True,
    )

    def long_name_of(variable):
        return dataset[variable].attrs["long_name"]

    # Variables containing a component token that must be rendered as LaTeX.
    latex_expected = {
        "Electric_Field_Ex": "Electric Field $E_x$",
        "Electric_Field_Ey": "Electric Field $E_y$",
        "Electric_Field_Ez": "Electric Field $E_z$",
        "Magnetic_Field_Bx": "Magnetic Field $B_x$",
        "Magnetic_Field_By": "Magnetic Field $B_y$",
        "Magnetic_Field_Bz": "Magnetic Field $B_z$",
        "Current_Jx": "Current $J_x$",
        "Current_Jy": "Current $J_y$",
        "Current_Jz": "Current $J_z$",
        "Particles_Px_Electron": "Particles $P_x$ Electron",
        "Particles_Py_Electron": "Particles $P_y$ Electron",
        "Particles_Pz_Electron": "Particles $P_z$ Electron",
    }
    for variable, expected in latex_expected.items():
        assert long_name_of(variable) == expected

    # Names that only embed a component pair inside a longer word stay as-is.
    assert _process_latex_name("Example") == "Example"
    assert _process_latex_name("PxTest") == "PxTest"

    # Variables without any component token only have underscores replaced.
    plain_expected = {
        "Absorption_Fraction_of_Laser_Energy_Absorbed": (
            "Absorption Fraction of Laser Energy Absorbed"
        ),
        "Derived_Average_Particle_Energy": "Derived Average Particle Energy",
    }
    for variable, expected in plain_expected.items():
        assert long_name_of(variable) == expected


def test_arrays_with_no_grids():
with xr.open_dataset(EXAMPLE_ARRAYS_DIR / "0001.sdf") as df:
laser_phase = "laser_x_min_phase"
Expand Down

0 comments on commit 810510d

Please sign in to comment.