Skip to content

Commit

Permalink
linting and debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Nov 22, 2024
1 parent 3da4fc2 commit 63b7c27
Show file tree
Hide file tree
Showing 21 changed files with 199 additions and 55 deletions.
2 changes: 1 addition & 1 deletion save_new_shot_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
print("PAUSE")

with open(file_name, "wb") as f:
pickle.dump(sliced_array, f)
pickle.dump(sliced_array, f)
27 changes: 14 additions & 13 deletions shot_filters.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@
import spyro
import numpy as np
from scipy.signal import butter, filtfilt, sosfilt
from scipy.signal import butter, filtfilt, sosfilt
from scipy.signal import sosfilt
import sys

filter_type = 'butter'

filter_frequency = 7.0

def filter_shot(shot, cutoff, fs, filter_type = 'butter'):

def filter_shot(shot, cutoff, fs, filter_type='butter'):
if filter_type == 'butter':
return butter_filter(shot,cutoff, fs)
return butter_filter(shot, cutoff, fs)


def butter_filter(shot, cutoff, fs, order=1):

""" Low-pass filter the shot record with sampling-rate fs Hz
and cutoff freq. Hz
"""
nyq = 0.5*fs # Nyquist Frequency

nyq = 0.5*fs # Nyquist Frequency
normal_cutoff = (cutoff) / nyq
# Get the filter coefficients

# Get the filter coefficients
b, a = butter(order, normal_cutoff, btype="low", analog=False)

nc, nr = np.shape(shot)

for rec in range(nr):
shot[:,rec] = filtfilt(b, a, shot[:,rec])
shot[:, rec] = filtfilt(b, a, shot[:, rec])

return shot


frequency = 7.0
dt = 0.0001
degree = 4
Expand Down Expand Up @@ -91,7 +92,7 @@ def butter_filter(shot, cutoff, fs, order=1):
spyro.io.load_shots(fwi, file_name="shots/shot_record_")
shots = fwi.forward_solution_receivers
shots *= 5.455538535049624
print(f'Applying {filter_type} filter for {filter_frequency}Hz', flush = True)
p_filtered = filter_shot(shots, filter_frequency, fs, filter_type = filter_type)
print(f'Applying {filter_type} filter for {filter_frequency}Hz', flush=True)
p_filtered = filter_shot(shots, filter_frequency, fs, filter_type=filter_type)
shot_filename = f"shots/shot_record_{filter_frequency}_"
spyro.io.save_shots(fwi, file_name=shot_filename)
1 change: 0 additions & 1 deletion spyro/examples/camembert.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,3 @@ def _camembert_velocity_model(self):
)
self.set_initial_velocity_model(conditional=cond, dg_velocity_model=False)
return None

3 changes: 2 additions & 1 deletion spyro/examples/immersed_polygon.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def _polygon_velocity_model(self):
"shot_record_file": None,
}


class Polygon_acoustic_FWI(Rectangle_acoustic_FWI):
"""polygon model.
This class is a child of the Example_model class.
Expand Down Expand Up @@ -226,4 +227,4 @@ def _polygon_velocity_model(self):
cond = fire.conditional(z <= middle_of_pad, v0, cond)

self.set_initial_velocity_model(conditional=cond, dg_velocity_model=False)
return None
return None
1 change: 1 addition & 0 deletions spyro/examples/rectangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class Rectangle_acoustic_FWI(Example_model_acoustic_FWI):
If True, the mesh will be periodic in all directions. The default is
False.
"""

def __init__(
self,
dictionary=None,
Expand Down
3 changes: 1 addition & 2 deletions spyro/io/basicio.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def wrapper(*args, **kwargs):
J_total[0] /= comm.comm.size

elif args[0].parallelism_type == "spatial" and args[0].number_of_sources > 1:
num = args[0].number_of_sources
residual_list = args[1]
J_total = np.zeros((1))

Expand Down Expand Up @@ -205,7 +204,7 @@ def wrapper(*args, **kwargs):
kwargs,
misfit=current_misfit,
)
)
)
grad_total += grad

grad_total /= num
Expand Down
10 changes: 6 additions & 4 deletions spyro/io/model_parameters.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
import uuid
from mpi4py import MPI
from firedrake import COMM_WORLD
from mpi4py import MPI # noqa:F401
from firedrake import COMM_WORLD # noqa:
import warnings
from .. import io
from .. import utils
from .. import meshing
from ..meshing.meshing_functions import cells_per_wavelength

# default_optimization_parameters = {
# "General": {"Secant": {"Type": "Limited-Memory BFGS",
Expand Down Expand Up @@ -518,7 +519,6 @@ def _sanitize_comm(self, comm):
if self.parallelism_type == "custom":
self.shot_ids_per_propagation = dictionary["parallelism"]["shot_ids_per_propagation"]
elif self.parallelism_type == "automatic":
available_cores = COMM_WORLD.size
self.shot_ids_per_propagation = [[i] for i in range(0, self.number_of_sources)]
elif self.parallelism_type == "spatial":
self.shot_ids_per_propagation = [[i] for i in range(0, self.number_of_sources)]
Expand Down Expand Up @@ -606,7 +606,7 @@ def _sanitize_optimization_and_velocity_for_fwi(self):
self.initial_velocity_model_file = dictionary["inversion"][
"initial_guess_model_file"
]
except:
except KeyError:
self.initial_velocity_model_file = None
self.fwi_output_folder = "fwi/"
self.control_output_file = self.fwi_output_folder + "control"
Expand Down Expand Up @@ -638,6 +638,7 @@ def _sanitize_optimization_and_velocity_for_fwi(self):
if "shot_record_file" in dictionary["inversion"]:
if dictionary["inversion"]["shot_record_file"] is not None:
self.real_shot_record = np.load(dictionary["inversion"]["shot_record_file"])

def _sanitize_optimization_and_velocity_without_fwi(self):
dictionary = self.input_dictionary
if "synthetic_data" in dictionary:
Expand Down Expand Up @@ -721,6 +722,7 @@ def set_mesh(
mesh_parameters.setdefault("degree", self.degree)
mesh_parameters.setdefault("velocity_model_file", self.initial_velocity_model_file)
mesh_parameters.setdefault("cell_type", self.cell_type)
print(f"Method: {self.method}, Degree: {self.degree}, Dimension: {self.dimension}")
mesh_parameters.setdefault("cells_per_wavelength", cells_per_wavelength(self.method, self.degree, self.dimension))

self._set_mesh_length(
Expand Down
4 changes: 2 additions & 2 deletions spyro/meshing/meshing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def cells_per_wavelength(method, degree, dimension):
'mlt3tet': 3.72,
}

if dimension == 2 and (method == 'MLT' or method == 'CG'):
if dimension == 2 and (method == 'mass_lumped_triangle'):
cell_type = 'tri'
if dimension == 3 and (method == 'MLT' or method == 'CG'):
if dimension == 3 and (method == 'mass_lumped_triangle'):
cell_type = 'tet'

key = method.lower()+str(degree)+cell_type
Expand Down
8 changes: 3 additions & 5 deletions spyro/solvers/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from scipy.optimize import minimize as scipy_minimize
from mpi4py import MPI # noqa: F401
import numpy as np
from copy import deepcopy
import resource

from .acoustic_wave import AcousticWave
from ..utils import compute_functional
from ..utils import Gradient_mask_for_pml, Mask
from ..plots import plot_model as spyro_plot_model
from ..io.basicio import ensemble_shot_record
from ..io.basicio import switch_serial_shot
from ..io.basicio import load_shots, save_shots

Expand Down Expand Up @@ -200,7 +198,7 @@ def calculate_misfit(self, c=None):
if self.parallelism_type == "spatial" and self.number_of_sources > 1:
misfit_list = []
guess_shot_record_list = []
for snum in range (self.number_of_sources):
for snum in range(self.number_of_sources):
switch_serial_shot(self, snum)
guess_shot_record_list.append(self.forward_solution_receivers)
misfit_list.append(self.real_shot_record[snum] - self.forward_solution_receivers)
Expand Down Expand Up @@ -396,7 +394,7 @@ def get_functional(self, c=None):
print(f"Functional: {Jm} at iteration: {self.current_iteration}", flush=True)
with open("functional_values.txt", "a") as file:
file.write(f"Iteration: {self.current_iteration}, Functional: {Jm}\n")

with open("peak_memory.txt", "a") as file:
file.write(f"Peak memory usage: {peak_memory_mb:.2f} MB \n")

Expand Down Expand Up @@ -619,7 +617,7 @@ def forward_solve(self):
super().forward_solve()
if self.parallelism_type == "spatial" and self.number_of_sources > 1:
real_shot_record_list = []
for snum in range (self.number_of_sources):
for snum in range(self.number_of_sources):
switch_serial_shot(self, snum)
real_shot_record_list.append(self.receivers_output)
self.real_shot_record = real_shot_record_list
Expand Down
2 changes: 1 addition & 1 deletion spyro/solvers/time_integration_central_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def central_difference(wave, source_ids=[0]):
wave.sources.current_sources = source_ids
rhs_forcing = fire.Cofunction(wave.function_space.dual())

wave.field_logger.start_logging(source_id)
wave.field_logger.start_logging(source_ids)

wave.comm.comm.barrier()

Expand Down
12 changes: 6 additions & 6 deletions spyro/solvers/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def set_mesh(
self,
user_mesh=None,
mesh_parameters={},
):
"""
Set the mesh for the solver.
):
"""
Set the mesh for the solver.
Args:
user_mesh (optional): User-defined mesh. Defaults to None.
Expand Down Expand Up @@ -373,7 +373,7 @@ def update_source_expression(self, t):
pass

@ensemble_propagator
def wave_propagator(self, dt=None, final_time=None, source_num=0):
def wave_propagator(self, dt=None, final_time=None, source_nums=[0]):
"""Propagates the wave forward in time.
Currently uses central differences.
Expand All @@ -398,8 +398,8 @@ def wave_propagator(self, dt=None, final_time=None, source_num=0):
if dt is not None:
self.dt = dt

self.current_source = source_num
usol, usol_recv = time_integrator(self, source_num)
self.current_sources = source_nums
usol, usol_recv = time_integrator(self, source_nums)

return usol, usol_recv

Expand Down
2 changes: 1 addition & 1 deletion spyro/tools/velocity_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def smooth_velocity_field_file(input_filename, output_filename, sigma, show=Fals
vp[:, index] = trace
else:
raise ValueError("Not yet implemented!")

vp_min = np.min(vp)
vp_max = np.max(vp)
print(f"Velocity model has minimum vp of {vp_min}, and max of {vp_max}")
Expand Down
2 changes: 0 additions & 2 deletions spyro/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def compute_functional(Wave_object, residual):
"""
num_receivers = Wave_object.number_of_receivers
dt = Wave_object.dt
comm = Wave_object.comm

J = 0
for rn in range(num_receivers):
Expand Down Expand Up @@ -95,7 +94,6 @@ def mpi_init(model):
num_cores_per_propagation = available_cores
elif model.parallelism_type == "custom":
shot_ids_per_propagation = model.shot_ids_per_propagation
num_max_shots_per_core = max(len(sublist) for sublist in shot_ids_per_propagation)
num_propagations = len(shot_ids_per_propagation)
num_cores_per_propagation = available_cores / num_propagations

Expand Down
Loading

0 comments on commit 63b7c27

Please sign in to comment.