Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
computermacgyver committed Nov 29, 2023
1 parent d3dc05f commit cd191d8
Showing 1 changed file with 7 additions and 33 deletions.
40 changes: 7 additions & 33 deletions test/lib/model/test_image_sscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,42 +140,16 @@

class TestModel(unittest.TestCase):

@patch("torchvision.transforms")
def test_compute_sscd(self, mock_pdq_hasher):
def test_compute_sscd(self):
with open("img/presto_flowchart.jpg", "rb") as file:
image_content = file.read()
# mock_hasher_instance = mock_pdq_hasher.return_value
# mock_hasher_instance.fromBufferedImage.return_value.getHash.return_value.dumpBitsFlat.return_value = '1001'
result = Model().compute_sscd(io.BytesIO(image_content))
self.assertTrue(np.allclose(result, result_should_be))

@patch("urllib.request.urlopen")
def test_get_iobytes_for_image(self, mock_urlopen):
with open("img/presto_flowchart.jpg", "rb") as file:
image_content = file.read()
mock_response = Mock()
mock_response.read.return_value = image_content
mock_urlopen.return_value = mock_response
image = schemas.Message(body={"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, model_name="audio__Model")
result = Model().get_iobytes_for_image(image)
self.assertIsInstance(result, io.BytesIO)
self.assertEqual(result.read(), image_content)

@patch("urllib.request.urlopen")
def test_get_iobytes_for_image_raises_error(self, mock_urlopen):
mock_urlopen.side_effect = URLError('test error')
image = schemas.Message(body={"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, model_name="audio__Model")
with self.assertRaises(URLError):
Model().get_iobytes_for_image(image)

@patch.object(Model, "get_iobytes_for_image")
@patch.object(Model, "compute_sscd")
def test_process(self, mock_compute_pdq, mock_get_iobytes_for_image):
mock_compute_pdq.return_value = result_should_be
mock_get_iobytes_for_image.return_value = io.BytesIO(b"image_bytes")
image = schemas.Message(body={"id": "123", "callback_url": "http://example.com?callback", "url": "http://example.com/image.jpg"}, model_name="audio__Model")
result = Model().process(image)
self.assertEqual(result, result_should_be)
# The least significant digits differ between chipsets (arm64 and amd64)
# There fore we do not use self.assertEqual
# self.assertEqual(result, result_should_be)
# Instead, we assert that all the values are with an absolute tolerance
# given by atol in the following assertion.
self.assertTrue(np.allclose(result, result_should_be, rtol=0, atol=0.00001))


if __name__ == "__main__":
Expand Down

0 comments on commit cd191d8

Please sign in to comment.