diff --git a/pySDC/implementations/hooks/log_solution.py b/pySDC/implementations/hooks/log_solution.py index 5bf5dc4b1e..42755a3287 100644 --- a/pySDC/implementations/hooks/log_solution.py +++ b/pySDC/implementations/hooks/log_solution.py @@ -81,13 +81,13 @@ class LogToFile(Hooks): Keep in mind that the hook will overwrite files without warning! You can give a custom file name by setting the ``file_name`` class attribute and give a custom way of rendering the - index associated with individual files by giving a different lambda function ``format_index`` class attribute. This - lambda should accept one index and return one string. + index associated with individual files by giving a different function ``format_index`` class attribute. This should + accept one index and return one string. You can also give a custom ``logging_condition`` function, accepting the current level if you want to log selectively. Importantly, you may need to change ``process_solution``. By default, this will return a numpy view of the solution. - Of course, if you are not using numpy, you need to change this. Again, this is a lambda accepting the level. + Of course, if you are not using numpy, you need to change this. Again, this is a function accepting the level. After the fact, you can use the classmethod `get_path` to get the path to a certain data or the `load` function to directly load the solution at a given index. Just configure the hook like you did when you recorded the data @@ -99,6 +99,7 @@ class LogToFile(Hooks): path = None file_name = 'solution' + counter = 0 def logging_condition(L): return True @@ -111,7 +112,6 @@ def format_index(index): def __init__(self): super().__init__() - self.counter = 0 if self.path is None: raise ValueError('Please set a path for logging as the class attribute `LogToFile.path`!') @@ -124,20 +124,41 @@ def __init__(self): if not os.path.isdir(self.path): os.mkdir(self.path) - def post_step(self, step, level_number): + def log_to_file(self, step, level_number, condition, process_solution=None): if level_number > 0: return None L = step.levels[level_number] - if type(self).logging_condition(L): + if condition: path = self.get_path(self.counter) - data = type(self).process_solution(L) + + if process_solution: + data = process_solution(L) + else: + data = type(self).process_solution(L) with open(path, 'wb') as file: pickle.dump(data, file) + self.logger.info(f'Stored file {path!r}') + + type(self).counter += 1 + + def post_step(self, step, level_number): + L = step.levels[level_number] + self.log_to_file(step, level_number, type(self).logging_condition(L)) + + def pre_run(self, step, level_number): + L = step.levels[level_number] + L.uend = L.u[0] + + def process_solution(L): + return { + **type(self).process_solution(L), + 't': L.time, + } - self.counter += 1 + self.log_to_file(step, level_number, True, process_solution=process_solution) @classmethod def get_path(cls, index): @@ -148,3 +169,22 @@ def load(cls, index): path = cls.get_path(index) with open(path, 'rb') as file: return pickle.load(file) + + +class LogToFileAfterXs(LogToFile): + r''' + Log to file after certain amount of time has passed instead of after every step + ''' + + time_increment = 0 + t_next_log = 0 + + def post_step(self, step, level_number): + L = step.levels[level_number] + + if self.t_next_log == 0: + self.t_next_log = self.time_increment + + if L.time + L.dt >= self.t_next_log and not step.status.restart: + super().post_step(step, level_number) + self.t_next_log = max([L.time + L.dt, self.t_next_log]) + self.time_increment diff --git a/pySDC/tests/test_hooks/test_log_to_file.py b/pySDC/tests/test_hooks/test_log_to_file.py index 0f0d48f0e2..d786b2110c 100644 --- a/pySDC/tests/test_hooks/test_log_to_file.py +++ b/pySDC/tests/test_hooks/test_log_to_file.py @@ -32,7 +32,7 @@ def run(hook, Tend=0): u0 = prob.u_exact(0) _, stats = controller.run(u0, 0, Tend) - return stats + return u0, stats @pytest.mark.base @@ -68,8 +68,8 @@ def test_logging(): LogToFile.path = path Tend = 2 - stats = run([LogToFile, LogSolution], Tend=Tend) - u = get_sorted(stats, type='u') + u0, stats = run([LogToFile, LogSolution], Tend=Tend) + u = [(0.0, u0)] + get_sorted(stats, type='u') u_file = [] for i in range(len(u)):