Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fuzz testing for chexify. #272

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions chex/_src/asserts_chexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dataclasses
import functools
import re
import threading
from typing import Any, Callable, FrozenSet

from absl import logging
Expand Down Expand Up @@ -178,6 +179,8 @@ def logp1_abs_safe(x: chex.Array) -> chex.Array:
thread_pool = futures.ThreadPoolExecutor(1, f'async_chex_{func_name}')
# A deque for futures.
async_check_futures = collections.deque()
# Protect the futures from concurrent access.
async_check_futures_lock = threading.Lock()

# Checkification.
checkified_fn = checkify.checkify(fn, errors=errors)
Expand All @@ -191,8 +194,9 @@ def _chexified_fn(*args, **kwargs):

if async_check:
# Check completed calls.
while async_check_futures and async_check_futures[0].done():
_check_error(async_check_futures.popleft().result(async_timeout))
with async_check_futures_lock:
while async_check_futures and async_check_futures[0].done():
_check_error(async_check_futures.popleft().result(async_timeout))

# Run the checkified function.
_ai.CHEXIFY_STORAGE.level += 1
Expand All @@ -214,8 +218,9 @@ def _chexified_fn(*args, **kwargs):

def _wait_checks():
if async_check:
while async_check_futures:
_check_error(async_check_futures.popleft().result(async_timeout))
with async_check_futures_lock:
while async_check_futures:
_check_error(async_check_futures.popleft().result(async_timeout))

# Add a barrier callback to the global storage.
_ai.CHEXIFY_STORAGE.wait_fns.append(_wait_checks)
Expand Down
66 changes: 66 additions & 0 deletions chex/_src/asserts_chexify_fuzz_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fuzz test for `asserts_chexify.py`."""

import concurrent.futures
import random
import time

from absl.testing import absltest
from chex._src import asserts
from chex._src import asserts_chexify
from chex._src import variants
import jax
import jax.numpy as jnp


class AssertsChexifyFuzzTest(variants.TestCase):
"""Fuzz test for thread safety of chexify."""

def test_thread_safety(self):

def assert_negative():
result = jnp.ones(shape=())
# This assert will always fail.
asserts.assert_scalar_negative(result)
return result

def chexified_assert_negative():
fn = asserts_chexify.chexify(assert_negative, async_check=True)
fn()
# Introduce random delay between the two calls, otherwise we will not
# get interleaving of the two operations between threads because they
# happen too quickly.
time.sleep(random.uniform(0.01, 0.02))
asserts_chexify.block_until_chexify_assertions_complete()

with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
futures = []
for _ in range(1000):
future = executor.submit(chexified_assert_negative)
futures.append(future)

for future in concurrent.futures.as_completed(futures):
try:
future.result()
except AssertionError:
pass

asserts_chexify.block_until_chexify_assertions_complete()


if __name__ == '__main__':
jax.config.update('jax_numpy_rank_promotion', 'raise')
absltest.main()