diff --git a/src/sdf_xarray/__init__.py b/src/sdf_xarray/__init__.py index 711fbef..09ab200 100644 --- a/src/sdf_xarray/__init__.py +++ b/src/sdf_xarray/__init__.py @@ -1,6 +1,8 @@ import os import pathlib +import re from collections import Counter, defaultdict +from itertools import product from typing import Iterable import numpy as np @@ -21,6 +23,29 @@ def _rename_with_underscore(name: str) -> str: return name.replace("/", "_").replace(" ", "_").replace("-", "_") +def _process_latex_name(variable_name: str) -> str: + """Converts variable names to LaTeX format where possible + using the following rules: + - E -> $E_x$ + - E -> $E_y$ + - E -> $E_z$ + + This repeats for B, J and P. It only changes the variable + name if there are spaces around the affix (prefix + suffix) + or if there is no trailing space. This is to avoid changing variable + names that may contain these affixes as part of the variable name itself. + """ + prefixes = ["E", "B", "J", "P"] + suffixes = ["x", "y", "z"] + for prefix, suffix in product(prefixes, suffixes): + # Match affix with preceding space and trailing space or end of string + affix_pattern = rf"\b{prefix}{suffix}\b" + # Insert LaTeX format while preserving spaces + replacement = rf"${prefix}_{suffix}$" + variable_name = re.sub(affix_pattern, replacement, variable_name) + return variable_name + + def combine_datasets(path_glob: Iterable | str, **kwargs) -> xr.Dataset: """Combine all datasets using a single time dimension""" @@ -271,7 +296,7 @@ def _process_grid_name(grid_name: str, transform_func) -> str: dim_name, coord, { - "long_name": label, + "long_name": label.replace("_", " "), "units": unit, "point_data": value.is_point_data, "full_name": value.name, @@ -290,11 +315,6 @@ def _process_grid_name(grid_name: str, transform_func) -> str: continue if isinstance(value, Constant) or value.grid is None: - data_attrs = {} - data_attrs["full_name"] = key - if value.units is not None: - data_attrs["units"] = value.units - # We don't have a grid, either because it's just a # scalar, or because it's an array over something # else. We have no more information, so just make up @@ -303,6 +323,12 @@ def _process_grid_name(grid_name: str, transform_func) -> str: dims = [f"dim_{key}_{n}" for n, _ in enumerate(shape)] base_name = _rename_with_underscore(key) + data_attrs = {} + data_attrs["full_name"] = key + data_attrs["long_name"] = base_name.replace("_", " ") + if value.units is not None: + data_attrs["units"] = value.units + data_vars[base_name] = Variable(dims, value.data, attrs=data_attrs) continue @@ -341,13 +367,15 @@ def _process_grid_name(grid_name: str, transform_func) -> str: ] # TODO: error handling here? other attributes? + base_name = _rename_with_underscore(key) + long_name = _process_latex_name(base_name.replace("_", " ")) data_attrs = { "units": value.units, "point_data": value.is_point_data, "full_name": key, + "long_name": long_name, } lazy_data = indexing.LazilyIndexedArray(SDFBackendArray(key, self)) - base_name = _rename_with_underscore(key) data_vars[base_name] = Variable(var_coords, lazy_data, data_attrs) # TODO: might need to decode if mult is set? diff --git a/tests/test_basic.py b/tests/test_basic.py index 83a5f43..f9eb1db 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -3,7 +3,7 @@ import pytest import xarray as xr -from sdf_xarray import SDFPreprocess, open_mfdataset +from sdf_xarray import SDFPreprocess, _process_latex_name, open_mfdataset EXAMPLE_FILES_DIR = pathlib.Path(__file__).parent / "example_files" EXAMPLE_MISMATCHED_FILES_DIR = ( @@ -115,6 +115,38 @@ def test_time_dim_units(): assert df["time"].full_name == "time" +def test_latex_rename_variables(): + df = xr.open_mfdataset( + EXAMPLE_ARRAYS_DIR.glob("*.sdf"), + preprocess=SDFPreprocess(), + keep_particles=True, + ) + assert df["Electric_Field_Ex"].attrs["long_name"] == "Electric Field $E_x$" + assert df["Electric_Field_Ey"].attrs["long_name"] == "Electric Field $E_y$" + assert df["Electric_Field_Ez"].attrs["long_name"] == "Electric Field $E_z$" + assert df["Magnetic_Field_Bx"].attrs["long_name"] == "Magnetic Field $B_x$" + assert df["Magnetic_Field_By"].attrs["long_name"] == "Magnetic Field $B_y$" + assert df["Magnetic_Field_Bz"].attrs["long_name"] == "Magnetic Field $B_z$" + assert df["Current_Jx"].attrs["long_name"] == "Current $J_x$" + assert df["Current_Jy"].attrs["long_name"] == "Current $J_y$" + assert df["Current_Jz"].attrs["long_name"] == "Current $J_z$" + assert df["Particles_Px_Electron"].attrs["long_name"] == "Particles $P_x$ Electron" + assert df["Particles_Py_Electron"].attrs["long_name"] == "Particles $P_y$ Electron" + assert df["Particles_Pz_Electron"].attrs["long_name"] == "Particles $P_z$ Electron" + + assert _process_latex_name("Example") == "Example" + assert _process_latex_name("PxTest") == "PxTest" + + assert ( + df["Absorption_Fraction_of_Laser_Energy_Absorbed"].attrs["long_name"] + == "Absorption Fraction of Laser Energy Absorbed" + ) + assert ( + df["Derived_Average_Particle_Energy"].attrs["long_name"] + == "Derived Average Particle Energy" + ) + + def test_arrays_with_no_grids(): with xr.open_dataset(EXAMPLE_ARRAYS_DIR / "0001.sdf") as df: laser_phase = "laser_x_min_phase"