Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infrastructure for memory usage tests #299

Merged
merged 16 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/299.general.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add infrastructure for testing memory usage
36 changes: 36 additions & 0 deletions src/stcal/testing_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import tracemalloc

class MemoryThresholdExceeded(Exception):
pass


class MemoryThreshold:
"""
Context manager to check peak memory usage against an expected threshold.

example usage:
with MemoryThreshold(expected_usage):
# code that should not exceed expected
emolter marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, expected_usage):
"""
Parameters
----------
expected_usage : int
Expected peak memory usage in bytes
"""
self.expected_usage = expected_usage

def __enter__(self):
tracemalloc.start()
return self

def __exit__(self, exc_type, exc_value, traceback):
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

if peak > self.expected_usage:
msg = ("Peak memory usage exceeded expected usage: "
f"{peak / 1024:.2f} KB > {self.expected_usage / 1024:.2f} KB")
emolter marked this conversation as resolved.
Show resolved Hide resolved
raise MemoryThresholdExceeded(msg)
38 changes: 38 additions & 0 deletions tests/outlier_detection/test_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_OnDiskMedian,
nanmedian3D,
)
from stcal.testing_helpers import MemoryThreshold


def test_disk_appendable_array(tmp_path):
Expand Down Expand Up @@ -194,3 +195,40 @@ def test_nanmedian3D():

assert med.dtype == np.float32
assert np.allclose(med, np.nanmedian(cube, axis=0), equal_nan=True)


@pytest.mark.parametrize("in_memory", [True, False])
def test_memory_computer(in_memory, tmp_path):
"""
Analytically calculate how much memory the median computation
is supposed to take, then ensure that the implementation
stays near that.

in_memory=True case allocates the following memory:
- one cube size
- median array == one frame size

in_memory=False case allocates the following memory:
- one buffer size, which by default is the frame size
- median array == one frame size

add a half-frame-size buffer to the expected memory usage in both cases
"""
shp = (20, 500, 500)
cube_size = np.dtype("float32").itemsize * shp[0] * shp[1] * shp[2] #bytes
frame_size = cube_size / shp[0]

# calculate expected memory usage
if in_memory:
expected_mem = cube_size + frame_size*1.5
else:
expected_mem = frame_size * 2.5

# compute the median while tracking memory usage
with MemoryThreshold(expected_mem):
computer = MedianComputer(shp, in_memory=in_memory, tempdir=tmp_path)
for i in range(shp[0]):
frame = np.full(shp[1:], i, dtype=np.float32)
computer.append(frame, i)
del frame
computer.evaluate()
16 changes: 16 additions & 0 deletions tests/test_infrastructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Tests of custom testing infrastructure"""

import pytest
import numpy as np
from stcal.testing_helpers import MemoryThreshold, MemoryThresholdExceeded


def test_memory_threshold():
with MemoryThreshold(1000):
braingram marked this conversation as resolved.
Show resolved Hide resolved
buff = np.empty(200, dtype=np.uint8)


def test_memory_threshold_raise():
with pytest.raises(MemoryThresholdExceeded):
with MemoryThreshold(1000):
buff = np.empty(2000, dtype=np.uint8)
Loading