Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upper case mpi4py communication in convergence controllers #343

Merged
merged 3 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions pySDC/core/ConvergenceController.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@
self.dependencies(controller, description)
self.logger = logging.getLogger(f"{type(self).__name__}")

if self.params.useMPI:
self.prepare_MPI_datatypes()

def prepare_MPI_datatypes(self):
"""
Prepare MPI datatypes so we don't need to import mpi4py all the time.

Stores ``MPI.INT``, ``MPI.DOUBLE`` and ``MPI.BOOL`` on the instance as ``self.MPI_INT``,
``self.MPI_DOUBLE`` and ``self.MPI_BOOL``.

Returns:
None
"""
# imported inside the method instead of at module level, so mpi4py is only required when this is called
from mpi4py import MPI

self.MPI_INT = MPI.INT
self.MPI_DOUBLE = MPI.DOUBLE
self.MPI_BOOL = MPI.BOOL

def log(self, msg, S, level=15, **kwargs):
"""
Shortcut that has a default level for the logger. 15 is above debug but below info.
Expand Down Expand Up @@ -329,16 +342,13 @@
kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} initiates send to step {dest} with tag {kwargs["tag"]}')
self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}sends to step {dest} with tag {kwargs["tag"]}')

Check warning on line 345 in pySDC/core/ConvergenceController.py

View check run for this annotation

Codecov / codecov/patch

pySDC/core/ConvergenceController.py#L345

Added line #L345 was not covered by tests

if blocking:
req = comm.send(data, dest=dest, **kwargs)
else:
req = comm.isend(data, dest=dest, **kwargs)

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} leaves send to step {dest} with tag {kwargs["tag"]}')

return req

def recv(self, comm, source, **kwargs):
Expand All @@ -355,13 +365,10 @@
kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} initiates receive from step {source} with tag {kwargs["tag"]}')
self.logger.debug(f'Step {comm.rank} receives from step {source} with tag {kwargs["tag"]}')

Check warning on line 368 in pySDC/core/ConvergenceController.py

View check run for this annotation

Codecov / codecov/patch

pySDC/core/ConvergenceController.py#L368

Added line #L368 was not covered by tests

data = comm.recv(source=source, **kwargs)

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} leaves receive from step {source} with tag {kwargs["tag"]}')

return data

def Send(self, comm, dest, buffer, blocking=False, **kwargs):
Expand All @@ -380,16 +387,13 @@
kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} initiates Send to step {dest} with tag {kwargs["tag"]}')
self.logger.debug(f'Step {comm.rank} {"" if blocking else "i"}Sends to step {dest} with tag {kwargs["tag"]}')

if blocking:
req = comm.Send(buffer, dest=dest, **kwargs)
else:
req = comm.Isend(buffer, dest=dest, **kwargs)

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} leaves Send to step {dest} with tag {kwargs["tag"]}')

return req

def Recv(self, comm, source, buffer, **kwargs):
Expand All @@ -406,13 +410,10 @@
kwargs['tag'] = kwargs.get('tag', abs(self.params.control_order))

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} initiates Receive from step {source} with tag {kwargs["tag"]}')
self.logger.debug(f'Step {comm.rank} Receives from step {source} with tag {kwargs["tag"]}')

data = comm.Recv(buffer, source=source, **kwargs)

# log what's happening for debug purposes
self.logger.debug(f'Step {comm.rank} leaves Receive from step {source} with tag {kwargs["tag"]}')

return data

def reset_variable(self, controller, name, MPI=False, place=None, where=None, init=None):
Expand Down
pySDC/implementations/convergence_controller_classes/basic_restarting.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"control_order": 95,
"max_restarts": 10,
"crash_after_max_restarts": True,
"restart_from_first_step": True,
"restart_from_first_step": False,
"step_size_spreader": SpreadStepSizesBlockwise.get_implementation(useMPI=params['useMPI']),
}

Expand Down Expand Up @@ -242,7 +242,6 @@
from mpi4py import MPI

self.OR = MPI.LOR
self.INT = MPI.INT

super().__init__(controller, params, description)
self.buffers = Pars({"restart": False, "max_restart_reached": False, 'restart_earlier': False})
Expand All @@ -260,7 +259,7 @@
Returns:
None
"""
assert S.status.slot == comm.rank
crash_now = False

if S.status.first:
# check if we performed too many restarts
Expand All @@ -269,27 +268,38 @@

if self.buffers.max_restart_reached and S.status.restart:
if self.params.crash_after_max_restarts:
raise ConvergenceError(f"Restarted {S.status.restarts_in_a_row} time(s) already, surrendering now.")
crash_now = True

Check warning on line 271 in pySDC/implementations/convergence_controller_classes/basic_restarting.py

View check run for this annotation

Codecov / codecov/patch

pySDC/implementations/convergence_controller_classes/basic_restarting.py#L271

Added line #L271 was not covered by tests
self.log(
f"Step(s) restarted {S.status.restarts_in_a_row} time(s) already, maximum reached, moving \
on...",
S,
)
elif not S.status.prev_done and not self.params.restart_from_first_step:
# receive information about restarts from earlier ranks
self.buffers.restart_earlier, self.buffers.max_restart_reached = self.recv(comm, source=S.status.slot - 1)
buff = np.empty(3, dtype=bool)
self.Recv(comm=comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
self.buffers.restart_earlier = buff[0]
self.buffers.max_restart_reached = buff[1]
crash_now = buff[2]

# decide whether to restart
S.status.restart = (S.status.restart or self.buffers.restart_earlier) and not self.buffers.max_restart_reached

# send information about restarts forward
if not S.status.last and not self.params.restart_from_first_step:
self.send(comm, dest=S.status.slot + 1, data=(S.status.restart, self.buffers.max_restart_reached))
buff = np.empty(3, dtype=bool)
buff[0] = S.status.restart
buff[1] = self.buffers.max_restart_reached
buff[2] = crash_now
self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

if self.params.restart_from_first_step:
max_restart_reached = comm.bcast(S.status.restarts_in_a_row > self.params.max_restarts, root=0)
S.status.restart = comm.allreduce(S.status.restart, op=self.OR) and not max_restart_reached

if crash_now:
raise ConvergenceError("Surrendering because of too many restarts...")

return None

def prepare_next_block(self, controller, S, size, time, Tend, comm, **kwargs):
Expand Down Expand Up @@ -317,14 +327,14 @@
self.Send(
comm,
dest=S.status.slot - restart_from,
buffer=[buff, self.INT],
buffer=[buff, self.MPI_INT],
blocking=False,
)

# receive new number of restarts in a row
if S.status.slot + restart_from < size:
buff = np.empty(1, dtype=int)
self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.INT])
self.Recv(comm, source=(S.status.slot + restart_from), buffer=[buff, self.MPI_INT])
S.status.restarts_in_a_row = buff[0]
else:
S.status.restarts_in_a_row = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def dependencies(self, controller, description, **kwargs):
return None

@staticmethod
def check_convergence(S):
def check_convergence(S, self=None):
"""
Check the convergence of a single step.
Test the residual and max. number of iterations as well as allowing overrides to both stop and continue.
Expand All @@ -81,6 +81,13 @@ def check_convergence(S):
) and not S.status.force_continue
if converged is None:
converged = False

# print information for debugging
if converged and self:
self.debug(
f'Declared convergence: maxiter reached[{"x" if iter_converged else " "}] restol reached[{"x" if res_converged else " "}] e_tol reached[{"x" if e_tol_converged else " "}]',
S,
)
return converged

def check_iteration_status(self, controller, S, **kwargs):
Expand All @@ -94,7 +101,7 @@ def check_iteration_status(self, controller, S, **kwargs):
Returns:
None
"""
S.status.done = self.check_convergence(S)
S.status.done = self.check_convergence(S, self)

if "comm" in kwargs.keys():
self.communicate_convergence(controller, S, **kwargs)
Expand Down Expand Up @@ -136,12 +143,16 @@ def communicate_convergence(self, controller, S, comm):

# recv status
if not S.status.first and not S.status.prev_done:
S.status.prev_done = self.recv(comm, source=S.status.slot - 1)
buff = np.empty(1, dtype=bool)
self.Recv(comm, source=S.status.slot - 1, buffer=[buff, self.MPI_BOOL])
S.status.prev_done = buff[0]
S.status.done = S.status.done and S.status.prev_done

# send status forward
if not S.status.last:
self.send(comm, dest=S.status.slot + 1, data=S.status.done)
buff = np.empty(1, dtype=bool)
buff[0] = S.status.done
self.Send(comm, dest=S.status.slot + 1, buffer=[buff, self.MPI_BOOL])

for hook in controller.hooks:
hook.post_comm(step=S, level_number=0, add_to_stats=True)
23 changes: 18 additions & 5 deletions pySDC/implementations/convergence_controller_classes/hotrod.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import numpy as np

from pySDC.core.ConvergenceController import ConvergenceController
from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import (
EstimateExtrapolationErrorNonMPI,
)


class HotRod(ConvergenceController):
Expand Down Expand Up @@ -55,9 +52,19 @@ def dependencies(self, controller, description, **kwargs):
description=description,
)
if not self.params.useMPI:
from pySDC.implementations.convergence_controller_classes.estimate_extrapolation_error import (
EstimateExtrapolationErrorNonMPI,
)
from pySDC.implementations.convergence_controller_classes.basic_restarting import BasicRestartingNonMPI

controller.add_convergence_controller(
EstimateExtrapolationErrorNonMPI, description=description, params={'no_storage': self.params.no_storage}
)
controller.add_convergence_controller(
BasicRestartingNonMPI,
description=description,
params={'restart_from_first_step': True},
)
else:
raise NotImplementedError("Don't know how to estimate extrapolated error with MPI")

Expand Down Expand Up @@ -95,13 +102,14 @@ def check_parameters(self, controller, params, description, **kwargs):

return True, ""

def determine_restart(self, controller, S, **kwargs):
def determine_restart(self, controller, S, MS, **kwargs):
"""
Check if the difference between the error estimates exceeds the allowed tolerance

Args:
controller (pySDC.Controller): The controller
S (pySDC.Step): The current step
MS (list): List of steps

Returns:
None
Expand All @@ -119,7 +127,12 @@ def determine_restart(self, controller, S, **kwargs):
if diff > self.params.HotRod_tol:
S.status.restart = True
self.log(
f"Triggering restart: delta={diff:.2e}, tol={self.params.HotRod_tol:.2e}",
f"Triggering restart: e_em={L.status.error_embedded_estimate:.2e}, e_ex={L.status.error_extrapolation_estimate:.2e} -> delta={diff:.2e}, tol={self.params.HotRod_tol:.2e}",
S,
)
else:
self.debug(
f"Not triggering restart: e_em={L.status.error_embedded_estimate:.2e}, e_ex={L.status.error_extrapolation_estimate:.2e} -> delta={diff:.2e}, tol={self.params.HotRod_tol:.2e}",
S,
)

Expand Down
4 changes: 2 additions & 2 deletions pySDC/implementations/sweeper_classes/Runge_Kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,10 @@ def update_nodes(self):
for j in range(1, m + 1):
rhs += lvl.dt * self.QI[m + 1, j] * self.get_full_f(lvl.f[j])

# implicit solve with prefactor stemming from the diagonal of Qd
# implicit solve with prefactor stemming from the diagonal of Qd, use previous stage as initial guess
if self.coll.implicit:
lvl.u[m + 1][:] = prob.solve_system(
rhs, lvl.dt * self.QI[m + 1, m + 1], lvl.u[0], lvl.time + lvl.dt * self.coll.nodes[m]
rhs, lvl.dt * self.QI[m + 1, m + 1], lvl.u[m], lvl.time + lvl.dt * self.coll.nodes[m]
)
else:
lvl.u[m + 1][:] = rhs[:]
Expand Down
4 changes: 2 additions & 2 deletions pySDC/tests/test_Runge_Kutta_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def test_sweeper_equivalence(sweeper_name):
u_all += [get_sorted(stats, type='u')[-1][1]]
assert np.allclose(
u_all[0], u_all[1]
), f'Solution when using RK sweeper does not match with solution generated by generic_implicit sweeper with RK collocation problem for {sweeper_name} method!'
), f'Solution when using RK sweeper does not match solution generated by generic_implicit sweeper with RK collocation problem for {sweeper_name} method!'


if __name__ == '__main__':
test_sweeper_equivalence('ESDIRK53')
test_stability('ESDIRK53')