From 148f9d1098cd3b30b71e923b68fa950b7852d1a6 Mon Sep 17 00:00:00 2001 From: Feda Curic Date: Mon, 18 Dec 2023 13:54:59 +0100 Subject: [PATCH] Add type hints to local_script_lib --- .../localisation/local_script_lib.py | 83 ++++++++++--------- 1 file changed, 46 insertions(+), 37 deletions(-) diff --git a/semeio/workflows/localisation/local_script_lib.py b/semeio/workflows/localisation/local_script_lib.py index 00c1acf70..8c3f52675 100644 --- a/semeio/workflows/localisation/local_script_lib.py +++ b/semeio/workflows/localisation/local_script_lib.py @@ -6,7 +6,7 @@ import math from collections import defaultdict from dataclasses import dataclass, field -from typing import Dict, List, Any +from typing import Dict, List, Any, Union, Tuple, Optional import cwrap import numpy as np @@ -100,6 +100,11 @@ class Decay: azimuth: float grid: Grid + def __call__(self, data_index): + # Default behavior of Decay when called as a function + # This is a placeholder; you can define a more meaningful default behavior + return 1.0 + def __post_init__(self): angle = (90.0 - self.azimuth) * math.pi / 180.0 self.cosangle = math.cos(angle) @@ -134,7 +139,7 @@ def norm_dist_square(self, data_index): class GaussianDecay(Decay): cutoff: bool - def __call__(self, data_index): + def __call__(self, data_index: List[int]) -> float: d2 = super().norm_dist_square(data_index) if self.cutoff and d2 > 1.0: return 0.0 @@ -260,8 +265,8 @@ def write_qc_parameter_surface( corr_name: str, surface_scale: bool, reference_surface_file: str, - param_for_surface: object, - log_level=LogLevel.OFF, + param_for_surface: ma.MaskedArray, + log_level: LogLevel = LogLevel.OFF, ) -> None: # pylint: disable=too-many-arguments @@ -366,9 +371,11 @@ def build_decay_object( grid: object, use_cutoff: bool, tapering_range: float = 1.5, -) -> Any: +) -> Decay: # pylint: disable=too-many-arguments - decay_obj = None + decay_obj: Union[ + GaussianDecay, ExponentialDecay, ConstGaussianDecay, ConstExponentialDecay + ] if method == "gaussian_decay": decay_obj = GaussianDecay( ref_pos, @@ -420,7 +427,9 @@ def build_decay_object( return decay_obj -def calculate_scaling_vector_fields(grid: object, decay_obj: object): +def calculate_scaling_vector_fields( + grid: object, decay_obj: Union[Decay, ConstantScalingFactor] +) -> ma.MaskedArray: assert isinstance(grid, Grid) nx = grid.getNX() ny = grid.getNY() @@ -440,7 +449,9 @@ def calculate_scaling_vector_fields(grid: object, decay_obj: object): return scaling_vector -def calculate_scaling_vector_surface(grid: object, decay_obj: object): +def calculate_scaling_vector_surface( + grid: object, decay_obj: Union[ConstantScalingFactor, Decay] +) -> ma.MaskedArray: assert isinstance(grid, Surface) nx = grid.getNX() ny = grid.getNY() @@ -458,8 +469,8 @@ def calculate_scaling_vector_surface(grid: object, decay_obj: object): def apply_decay( method: str, - row_scaling: object, - grid: object, + row_scaling: RowScaling, + grid: Grid, ref_pos: list, main_range: float, perp_range: float, @@ -467,7 +478,7 @@ def apply_decay( use_cutoff: bool = False, tapering_range: float = 1.5, calculate_qc_parameter: bool = False, -): +) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]: # pylint: disable=too-many-arguments,too-many-locals """ Calculates the scaling factor, assign it to ERT instance by row_scaling @@ -505,12 +516,12 @@ def apply_decay( def apply_constant( - row_scaling: object, - grid: object, + row_scaling: RowScaling, + grid: Grid, value: float, log_level: LogLevel, calculate_qc_parameter: bool = False, -): +) -> Tuple[Optional[ma.MaskedArray], Optional[ma.MaskedArray]]: # pylint: disable=too-many-arguments,too-many-locals """ Assign constant value to the scaling factor, @@ -536,13 +547,13 @@ def apply_constant( def apply_from_file( - row_scaling: object, - grid: object, + row_scaling: RowScaling, + grid: Grid, filename: str, param_name: str, log_level: LogLevel, calculate_qc_parameter: bool = False, -): +) -> Tuple[Optional[ma.MaskedArray], None]: # pylint: disable=too-many-arguments, too-many-locals debug_print( f"Read scaling factors as parameter {param_name}", LogLevel.LEVEL3, log_level @@ -643,8 +654,11 @@ def calculate_scaling_factors_in_regions( def smooth_parameter( - grid, smooth_range_list, scaling_values, active_region_values_used -): + grid: Grid, + smooth_range_list: List[int], + scaling_values: List[float], + active_region_values_used: ma.MaskedArray, +) -> ma.MaskedArray: """ Function taking as input a 3D parameter scaling_values and calculates a new 3D parameter scaling_values_smooth using local average within a rectangular window @@ -672,13 +686,12 @@ def smooth_parameter( if active_region_values_used[index0] is not ma.masked: sumv = 0.0 nval = 0 - ilow = max(0, i0 - di) - ihigh = min(i0 + di + 1, nx) - jlow = max(0, j0 - dj) - jhigh = min(j0 + dj + 1, ny) + ilow: int = max(0, i0 - di) + ihigh: int = min(i0 + di + 1, nx) + jlow: int = max(0, j0 - dj) + jhigh: int = min(j0 + dj + 1, ny) for i in range(ilow, ihigh): for j in range(jlow, jhigh): - # index = i + j * nx + k * nx * ny index = k + j * nz + i * nz * ny if active_region_values_used[index] is not ma.masked: # Only use values from grid cells that are active @@ -692,16 +705,16 @@ def smooth_parameter( def apply_segment( - row_scaling, - grid, - region_param_dict, - active_segment_list, - scaling_factor_list, - smooth_range_list, - corr_name, - log_level=LogLevel.OFF, + row_scaling: RowScaling, + grid: Grid, + region_param_dict: Dict[str, Any], + active_segment_list: List[int], + scaling_factor_list: List[float], + smooth_range_list: List[int], + corr_name: str, + log_level: LogLevel = LogLevel.OFF, calculate_qc_parameter: bool = False, -): +) -> Tuple[Any, None]: # pylint: disable=too-many-arguments,too-many-locals """ Purpose: Use region numbers and list of scaling factors per region to @@ -750,10 +763,6 @@ def apply_segment( # Assign values to row_scaling object row_scaling.assign_vector(scaling_values) - # for index in range(data_size): - # global_index = grid.global_index(active_index=index) - # row_scaling[index] = scaling_values[global_index] - not_defined_in_region_param = [] for n in active_segment_list: if n not in regions_in_param: