From 4ed16bee1d88fe781b6fad5bff73aa9d49720898 Mon Sep 17 00:00:00 2001 From: Pedro Bressan Date: Tue, 19 Dec 2023 18:41:00 -0300 Subject: [PATCH] MNT: avoid code interpolation code repetition. --- rocketpy/mathutils/function.py | 93 ++++++++++++++++------------------ 1 file changed, 43 insertions(+), 50 deletions(-) diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index 0e8213ac1..ceeb6a2e2 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -484,31 +484,7 @@ def get_value_opt(x): elif self.__interpolation__ == "shepard": # change the function's name to avoid mypy's error def get_value_opt_multiple(*args): - x_data = self.source[:, 0:-1] # Support for N-Dimensions - y_data = self.source[:, -1] - - arg_stack = np.column_stack(args) - arg_qty, arg_dim = arg_stack.shape - result = np.zeros(arg_qty) - - # Reshape to vectorize calculations - x = arg_stack.reshape(arg_qty, 1, arg_dim) - - sub_matrix = x_data - x - distances_squared = np.sum(sub_matrix**2, axis=2) - - # Remove zero distances from further calculations - zero_distances = np.where(distances_squared == 0) - valid_indexes = np.ones(arg_qty, dtype=bool) - valid_indexes[zero_distances[0]] = False - - weights = distances_squared[valid_indexes] ** (-1.5) - numerator_sum = np.sum(y_data * weights, axis=1) - denominator_sum = np.sum(weights, axis=1) - result[valid_indexes] = numerator_sum / denominator_sum - result[~valid_indexes] = y_data[zero_distances[1]] - - return result if len(result) > 1 else result[0] + return self.__interpolate_shepard__(args) get_value_opt = get_value_opt_multiple @@ -880,31 +856,7 @@ def get_value(self, *args): # Returns value for shepard interpolation elif self.__interpolation__ == "shepard": - x_data = self.source[:, 0:-1] # Support for N-Dimensions - y_data = self.source[:, -1] - - arg_stack = np.column_stack(args) - arg_qty, arg_dim = arg_stack.shape - result = np.zeros(arg_qty) - - # Reshape to vectorize calculations - x = arg_stack.reshape(arg_qty, 1, arg_dim) - - sub_matrix = x_data - x - distances_squared = np.sum(sub_matrix**2, axis=2) - - # Remove zero distances from further calculations - zero_distances = np.where(distances_squared == 0) - valid_indexes = np.ones(arg_qty, dtype=bool) - valid_indexes[zero_distances[0]] = False - - weights = distances_squared[valid_indexes] ** (-1.5) - numerator_sum = np.sum(y_data * weights, axis=1) - denominator_sum = np.sum(weights, axis=1) - result[valid_indexes] = numerator_sum / denominator_sum - result[~valid_indexes] = y_data[zero_distances[1]] - - return result if len(result) > 1 else result[0] + return self.__interpolate_shepard__(args) # Returns value for polynomial interpolation function type elif self.__interpolation__ == "polynomial": @@ -1656,6 +1608,47 @@ def __interpolate_akima__(self): coeffs[4 * i : 4 * i + 4] = np.linalg.solve(matrix, result) self.__akima_coefficients__ = coeffs + def __interpolate_shepard__(self, args): + """Calculates the shepard interpolation from the given arguments. + The shepard interpolation is computed by a inverse distance weighting + in a vectorized manner. + + Parameters + ---------- + args : scalar, list + Values where the Function is to be evaluated. + + Returns + ------- + result : scalar, list + The result of the interpolation. + """ + x_data = self.source[:, 0:-1] # Support for N-Dimensions + y_data = self.source[:, -1] + + arg_stack = np.column_stack(args) + arg_qty, arg_dim = arg_stack.shape + result = np.zeros(arg_qty) + + # Reshape to vectorize calculations + x = arg_stack.reshape(arg_qty, 1, arg_dim) + + sub_matrix = x_data - x + distances_squared = np.sum(sub_matrix**2, axis=2) + + # Remove zero distances from further calculations + zero_distances = np.where(distances_squared == 0) + valid_indexes = np.ones(arg_qty, dtype=bool) + valid_indexes[zero_distances[0]] = False + + weights = distances_squared[valid_indexes] ** (-1.5) + numerator_sum = np.sum(y_data * weights, axis=1) + denominator_sum = np.sum(weights, axis=1) + result[valid_indexes] = numerator_sum / denominator_sum + result[~valid_indexes] = y_data[zero_distances[1]] + + return result if len(result) > 1 else result[0] + def __neg__(self): """Negates the Function object. The result has the same effect as multiplying the Function by -1.