Commit
Merge pull request #15 from danielfromearth/develop
update `main` linting and formatting
danielfromearth authored Sep 19, 2023
2 parents 6d9da1e + 354184e commit 7e4780e
Showing 8 changed files with 300 additions and 419 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/lint_and_test_and_bump.yml
@@ -82,8 +82,7 @@ jobs:

- name: Lint
run: |
poetry run pylint concatenator
poetry run flake8 concatenator
poetry run ruff concatenator
- name: Test with pytest
run: |
3 changes: 1 addition & 2 deletions .github/workflows/lint_and_test_on_pull_request.yml
@@ -30,8 +30,7 @@ jobs:

- name: Lint
run: |
poetry run pylint concatenator
poetry run flake8 concatenator
poetry run ruff concatenator
- name: Test with pytest
run: |
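
Both workflow files make the same swap: the two-linter step (pylint plus flake8) becomes a single ruff invocation. A sketch of reproducing the new lint step locally, written in Python for uniformity with the rest of this page (assumes poetry and the project environment are set up, as in the workflow):

```python
import subprocess

# Same command as the updated "Lint" step above; ruff exits non-zero on
# lint errors, which check=True turns into a CalledProcessError.
subprocess.run(["poetry", "run", "ruff", "concatenator"], check=True)
```
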
31 changes: 9 additions & 22 deletions .pre-commit-config.yaml
@@ -13,31 +13,18 @@ repos:
doc/data/messages/t/trailing-newlines/bad.py|
)$
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: 'v0.0.290'
hooks:
- id: isort
exclude: doc/data/messages/(r/reimported|w/wrong-import-order|u/ungrouped-imports|m/misplaced-future|m/multiple-imports)/bad.py
- id: ruff
args: [ "--fix" ]

- repo: https://github.com/pylint-dev/pylint
rev: v3.0.0a7
# https://github.com/python/black#version-control-integration
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [ python ]
args:
[
"-rn", # Only display messages
"-sn", # Don't display the score
"--rcfile=.pylintrc", # Link to your config file
]

- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
- id: black-jupyter

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
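
After this change the hook chain is ruff (with `--fix`), black-jupyter, and the unchanged mypy mirror. A minimal sketch for exercising the updated hooks across the repository (assumes `pre-commit` is installed):

```python
import subprocess

# Runs every hook in .pre-commit-config.yaml against all tracked files.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)
```
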
107 changes: 64 additions & 43 deletions concatenator/group_handling.py
@@ -5,27 +5,31 @@
between a group hierarchy and a flat structure
"""
import re
from typing import List, Tuple

import netCDF4 as nc # type: ignore
import numpy as np
import xarray as xr

from concatenator import GROUP_DELIM
from concatenator.attribute_handling import (
flatten_coordinate_attribute_paths, regroup_coordinate_attribute)
flatten_coordinate_attribute_paths,
regroup_coordinate_attribute,
)

# Match dimension names such as "__char28" or "__char16". Used for CERES datasets.
_string_dimension_name_pattern = re.compile(r"__char[0-9]+")

def walk(group_node: nc.Group,
path: str,
new_dataset: nc.Dataset,
dimensions_dict_to_populate: dict,
list_of_character_string_vars: list):

def walk(
group_node: nc.Group,
path: str,
new_dataset: nc.Dataset,
dimensions_dict_to_populate: dict,
list_of_character_string_vars: list,
):
"""Recursive function to step through each group and subgroup."""
for key, item in group_node.items():
group_path = f'{path}{GROUP_DELIM}{key}'
group_path = f"{path}{GROUP_DELIM}{key}"

if item.dimensions:
dims = list(item.dimensions.keys())
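
The `_string_dimension_name_pattern` above flags the character-array dimensions that CERES products use. A standalone check of what it matches, using only the pattern from this diff:

```python
import re

# Pattern copied from group_handling.py above.
_string_dimension_name_pattern = re.compile(r"__char[0-9]+")

assert _string_dimension_name_pattern.fullmatch("__char28") is not None
assert _string_dimension_name_pattern.fullmatch("__char16") is not None
assert _string_dimension_name_pattern.fullmatch("time") is None  # ordinary dims don't match
```
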
@@ -40,14 +44,15 @@ def walk(group_node: nc.Group,
if item.variables:
# Copy variables to root group with new name
for var_name, var in item.variables.items():

var_group_name = f'{group_path}{GROUP_DELIM}{var_name}'
var_group_name = f"{group_path}{GROUP_DELIM}{var_name}"
new_dataset.variables[var_group_name] = var

# Flatten the paths of variables referenced in the coordinates attribute
flatten_coordinate_attribute_paths(new_dataset, var, var_group_name)

if (len(var.dimensions) == 1) and _string_dimension_name_pattern.fullmatch(var.dimensions[0]):
if (len(var.dimensions) == 1) and _string_dimension_name_pattern.fullmatch(
var.dimensions[0]
):
list_of_character_string_vars.append(var_group_name)

# Delete variables
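
`walk()` moves each group variable into the root group under a delimiter-joined name. A small illustration of that naming scheme; the `"__"` value for `GROUP_DELIM` and the group/variable names are assumptions for the example, not taken from this diff:

```python
GROUP_DELIM = "__"  # assumed value of concatenator.GROUP_DELIM

def flattened_name(group_path: str, var_name: str) -> str:
    # Mirrors the f-string used in walk() above.
    return f"{group_path}{GROUP_DELIM}{var_name}"

# A variable "latitude" in group "geolocation" (group path already
# flattened to "__geolocation") lands in the root group as:
print(flattened_name("__geolocation", "latitude"))  # __geolocation__latitude
```
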
@@ -58,18 +63,23 @@
# If there are subgroups in this group, call this function
# again on that group.
if item.groups:
walk(item.groups, group_path, new_dataset, dimensions_dict_to_populate, list_of_character_string_vars)
walk(
item.groups,
group_path,
new_dataset,
dimensions_dict_to_populate,
list_of_character_string_vars,
)

# Delete non-root groups
group_names = list(group_node.keys())
for group_name in group_names:
del group_node[group_name]


def flatten_grouped_dataset(nc_dataset: nc.Dataset,
file_to_subset: str,
ensure_all_dims_are_coords: bool = False
) -> Tuple[nc.Dataset, List[str], List[str]]:
def flatten_grouped_dataset(
nc_dataset: nc.Dataset, file_to_subset: str, ensure_all_dims_are_coords: bool = False
) -> tuple[nc.Dataset, list[str], list[str]]:
"""
Transform a netCDF4 Dataset that has groups to an xarray compatible
dataset. xarray does not work with groups, so this transformation
@@ -95,7 +105,7 @@ def flatten_grouped_dataset(nc_dataset: nc.Dataset,
"""
# Close the existing read-only dataset and reopen in append mode
nc_dataset.close()
nc_dataset = nc.Dataset(file_to_subset, 'r+')
nc_dataset = nc.Dataset(file_to_subset, "r+")

dimensions = {}

@@ -107,7 +117,7 @@ def flatten_grouped_dataset(nc_dataset: nc.Dataset,
if nc_dataset.variables:
temp_copy_for_iterating = list(nc_dataset.variables.items())
for var_name, var in temp_copy_for_iterating:
new_var_name = f'{GROUP_DELIM}{var_name}'
new_var_name = f"{GROUP_DELIM}{var_name}"

# ds_new.variables[new_var_name] = ds_old.variables[var_name]

@@ -123,7 +133,7 @@ def flatten_grouped_dataset(nc_dataset: nc.Dataset,
if nc_dataset.dimensions:
temp_copy_for_iterating = list(nc_dataset.dimensions.keys())
for dim_name in temp_copy_for_iterating:
new_dim_name = f'{GROUP_DELIM}{dim_name}'
new_dim_name = f"{GROUP_DELIM}{dim_name}"
# dimensions[new_dim_name] = item.dimensions[dim_name]
# item.renameDimension(dim_name, new_dim_name)
# ds_new.dimensions[new_dim_name] = item.dimensions[dim_name]
@@ -133,20 +143,24 @@ def flatten_grouped_dataset(nc_dataset: nc.Dataset,
nc_dataset.renameDimension(dim_name, new_dim_name)

# Create a coordinate variable, if it doesn't already exist.
if ensure_all_dims_are_coords and (new_dim_name not in list(nc_dataset.variables.keys())):
if ensure_all_dims_are_coords and (
new_dim_name not in list(nc_dataset.variables.keys())
):
nc_dataset.createVariable(dim.name, datatype=np.int32, dimensions=(dim.name,))
temporary_coordinate_variables.append(dim.name)

list_of_character_string_vars: list[str] = []
walk(nc_dataset.groups, '', nc_dataset, dimensions, list_of_character_string_vars)
walk(nc_dataset.groups, "", nc_dataset, dimensions, list_of_character_string_vars)

# Update the dimensions of the dataset in the root group
nc_dataset.dimensions.update(dimensions)

return nc_dataset, temporary_coordinate_variables, list_of_character_string_vars
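
As the docstring says, flattening makes a grouped file consumable by xarray. A hedged usage sketch: the file name is hypothetical, and the `NetCDF4DataStore` hand-off is one plausible way to open the returned dataset with xarray, not something shown in this diff:

```python
import netCDF4 as nc
import xarray as xr

from concatenator.group_handling import flatten_grouped_dataset

input_path = "granule.nc4"  # hypothetical input file
nc_ds = nc.Dataset(input_path, "r")

# Closes and reopens the file in append mode, then renames grouped
# variables and dimensions into the root group.
flat_ds, temp_coords, char_vars = flatten_grouped_dataset(
    nc_ds, input_path, ensure_all_dims_are_coords=True
)

# One plausible hand-off to xarray (assumption):
xr_ds = xr.open_dataset(xr.backends.NetCDF4DataStore(flat_ds))
```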


def regroup_flattened_dataset(dataset: xr.Dataset, output_file: str) -> None: # pylint: disable=too-many-branches
def regroup_flattened_dataset(
dataset: xr.Dataset, output_file: str
) -> None: # pylint: disable=too-many-branches
"""
Given a list of xarray datasets, combine those datasets into a
single netCDF4 Dataset and write to the disk. Each dataset has been
@@ -161,15 +175,16 @@ def regroup_flattened_dataset(dataset: xr.Dataset, output_file: str) -> None: #
Name of the output file to write the resulting NetCDF file to.
"""

with nc.Dataset(output_file, mode='w', format='NETCDF4') as base_dataset:
with nc.Dataset(output_file, mode="w", format="NETCDF4") as base_dataset:
# Copy global attributes
base_dataset.setncatts(dataset.attrs)

# Create Groups
group_lst = []
for var_name, _ in dataset.variables.items(): # need logic if there is data in the top level not in a group
group_lst.append('/'.join(str(var_name).split(GROUP_DELIM)[:-1]))
group_lst = ['/' if group == '' else group for group in group_lst]
# need logic if there is data in the top level not in a group
for var_name, _ in dataset.variables.items():
group_lst.append("/".join(str(var_name).split(GROUP_DELIM)[:-1]))
group_lst = ["/" if group == "" else group for group in group_lst]
groups = set(group_lst)
for group in groups:
base_dataset.createGroup(group)
@@ -196,7 +211,7 @@ def regroup_flattened_dataset(dataset: xr.Dataset, output_file: str) -> None: #

# Get the fill value (since it's not included in xarray var.attrs)
try:
fill_value = var.encoding['_FillValue'].astype(this_dtype)
fill_value = var.encoding["_FillValue"].astype(this_dtype)
except KeyError:
fill_value = None

@@ -216,20 +231,28 @@ def regroup_flattened_dataset(dataset: xr.Dataset, output_file: str) -> None: #
vartype = "S1"
else:
vartype = str(var.dtype)
var_group.createVariable(new_var_name, vartype,
dimensions=new_var_dims, chunksizes=chunk_sizes,
compression='zlib', complevel=7,
shuffle=shuffle, fill_value=fill_value)
var_group.createVariable(
new_var_name,
vartype,
dimensions=new_var_dims,
chunksizes=chunk_sizes,
compression="zlib",
complevel=7,
shuffle=shuffle,
fill_value=fill_value,
)

# copy variable attributes all at once via dictionary
var_group[new_var_name].setncatts(var.attrs)
# copy variable values
var_group[new_var_name][:] = var.values

# Reconstruct the grouped paths of variables referenced in the coordinates attribute.
if 'coordinates' in var_group[new_var_name].ncattrs():
coord_att = var_group[new_var_name].getncattr('coordinates')
var_group[new_var_name].setncattr('coordinates', regroup_coordinate_attribute(coord_att))
if "coordinates" in var_group[new_var_name].ncattrs():
coord_att = var_group[new_var_name].getncattr("coordinates")
var_group[new_var_name].setncattr(
"coordinates", regroup_coordinate_attribute(coord_att)
)

except Exception as err:
raise err
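
Regrouping is the inverse step: each flattened name is split on `GROUP_DELIM` to rebuild the group tree, and `coordinates` attributes are converted back to slash-separated paths. A minimal usage sketch, continuing the hypothetical example above (the output file name is illustrative):

```python
from concatenator.group_handling import regroup_flattened_dataset

# xr_ds: a flattened (and possibly concatenated) xarray Dataset.
# Writes a NETCDF4 file with the original group hierarchy restored.
regroup_flattened_dataset(xr_ds, "stitched_output.nc4")
```
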
@@ -260,16 +283,14 @@ def _calculate_chunks(dim_sizes: list, default_low_dim_chunksize=4000) -> tuple:
"""
number_of_dims = len(dim_sizes)
if number_of_dims <= 3:
chunk_sizes = tuple(default_low_dim_chunksize
if ((s > default_low_dim_chunksize) and (number_of_dims > 1))
else s
for s in dim_sizes
)
chunk_sizes = tuple(
default_low_dim_chunksize
if ((s > default_low_dim_chunksize) and (number_of_dims > 1))
else s
for s in dim_sizes
)
else:
chunk_sizes = tuple(500 if s > 500
else s
for s in dim_sizes
)
chunk_sizes = tuple(500 if s > 500 else s for s in dim_sizes)

return chunk_sizes
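
To make the chunking rule concrete, a quick check of `_calculate_chunks` against the logic above (expected values computed by hand; importing a private helper here is for illustration only):

```python
from concatenator.group_handling import _calculate_chunks

# <= 3 dims: cap at default_low_dim_chunksize (4000), except that a 1-D
# variable is left whole because the cap requires number_of_dims > 1.
assert _calculate_chunks([2, 3000, 9000]) == (2, 3000, 4000)
assert _calculate_chunks([1_000_000]) == (1_000_000,)

# > 3 dims: cap every dimension at 500.
assert _calculate_chunks([10, 1000, 400, 8000]) == (10, 500, 400, 500)
```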

(The remaining four changed files are not shown.)
