# TFP Unit-Test Best Practices

## Recipe to easily test static and dynamic shape

tl;dr: See this example.

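Most TFP code has two code paths for handling shape: one that runs prior to graph execution (using static shape) and one that runs during it (using dynamic shape). The following example illustrates a pattern that makes testing both paths easier. The "trick" is to always make the input a placeholder, even when testing static shape, and let a class attribute decide whether that placeholder exposes its shape.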

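```python
import numpy as np
import tensorflow as tf
tfe = tf.contrib.eager  # TF 1.x only; `tf.contrib` was removed in TF 2.


class _DistributionTest(object):
  # Mixin holding the shared tests. It deliberately does not subclass
  # `tf.test.TestCase`, so the runner only collects the concrete subclasses.

  @tfe.run_test_in_graph_and_eager_modes
  def testSomething(self):
    input_ = ...  # Using `self.dtype`.
    # Expose the static shape only when the subclass asks for it.
    input_ph = tf.placeholder_with_default(
        input=input_,
        shape=input_.shape if self.use_static_shape else None)
    ...
    [...] = self.evaluate([...])
    ...


class DistributionTest_StaticShape(tf.test.TestCase, _DistributionTest):
  dtype = np.float32
  use_static_shape = True


class DistributionTest_DynamicShape(tf.test.TestCase, _DistributionTest):
  dtype = np.float32
  use_static_shape = False


if __name__ == "__main__":
  tf.test.main()
```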

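Notice that the pattern uses `tf.placeholder_with_default` rather than `tf.placeholder`. Because it falls back to its default value when nothing is fed, the test needs no `feed_dict` and also runs in eager mode, where `tf.placeholder` is unsupported. This keeps the executing code convenient to debug while still letting us programmatically hide shape hints.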
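To see what hiding the shape hint means in practice, the sketch below builds the same placeholder with and without `shape` and compares its static shape (known while the graph is built) with its dynamic shape (known only when it runs). It assumes TF 2.x and uses the `tf.compat.v1` graph and session APIs, since `tf.contrib` no longer exists; `placeholder_shapes` is a hypothetical helper.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def placeholder_shapes(use_static_shape):
  """Returns (static tf.TensorShape, dynamic shape as a NumPy array)."""
  graph = tf.Graph()
  with graph.as_default():
    input_ = np.zeros([2, 3], dtype=np.float32)
    input_ph = tf1.placeholder_with_default(
        input_, shape=input_.shape if use_static_shape else None)
    static_shape = input_ph.shape  # Unknown rank when shape=None.
    with tf1.Session(graph=graph) as sess:
      # No feed_dict needed: the placeholder falls back to `input_`.
      dynamic_shape = sess.run(tf.shape(input_ph))
  return static_shape, dynamic_shape
```

With `use_static_shape=True` the static shape is `[2, 3]`; with `False` it is fully unknown, yet both versions report `[2, 3]` at run time, so only the shape-handling code path differs.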

This idea can be extended as appropriate. For example, the `Reshape` bijector tests exercise error-checking paths that raise an exception at graph construction time when the input has static shape, but raise an op error at run time when the shape is dynamic. To handle both with shared test code, the static and dynamic subclasses can each implement their own `assertRaisesError` that performs the appropriate check:

```python
import tensorflow as tf
tfe = tf.contrib.eager  # TF 1.x only.


class _DistributionTest(object):

  @tfe.run_test_in_graph_and_eager_modes
  def testSomething(self):
    input_ = ...
    ...
    # Each subclass decides what kind of failure to expect here.
    with self.assertRaisesError(
        "Some error message"):
      [...] = self.evaluate(something_that_might_throw_exception([...]))
    ...


class DistributionTest_StaticShape(tf.test.TestCase, _DistributionTest):
  ...
  def assertRaisesError(self, msg):
    # Static shape: the check fires while the graph is being built.
    return self.assertRaisesRegex(Exception, msg)


class DistributionTest_DynamicShape(tf.test.TestCase, _DistributionTest):
  ...
  def assertRaisesError(self, msg):
    # Dynamic shape: the check only fires when the op runs.
    return self.assertRaisesOpError(msg)
```
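The two failure modes that `assertRaisesError` hides can be observed directly with `tf.reshape`, which cannot turn a 6-element input into shape `[5]`. This sketch assumes TF 2.x with the `tf.compat.v1` graph and session APIs rather than the TF 1.x APIs above; `where_reshape_fails` is a hypothetical helper.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def where_reshape_fails(use_static_shape):
  """Returns "construction", "runtime", or "no error" for a bad reshape."""
  graph = tf.Graph()
  with graph.as_default():
    input_ = np.zeros([2, 3], dtype=np.float32)  # 6 elements.
    input_ph = tf1.placeholder_with_default(
        input_, shape=input_.shape if use_static_shape else None)
    try:
      # Static shape: shape inference rejects 6 -> 5 elements right here.
      output = tf.reshape(input_ph, [5])
    except ValueError:
      return "construction"
    with tf1.Session(graph=graph) as sess:
      try:
        # Dynamic shape: the Reshape kernel only fails when it executes.
        sess.run(output)
      except tf.errors.InvalidArgumentError:  # A subclass of tf.errors.OpError.
        return "runtime"
  return "no error"
```

The static case produces a plain Python `ValueError`, matched by `assertRaisesRegex(Exception, ...)`, while the dynamic case produces an `OpError`, matched by `assertRaisesOpError`.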

## Testing ℝ<sup>d</sup>-Variate Distributions

Helper class to test vector-event distributions.

`VectorDistributionTestHelpers` (Example Use)
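One way to check that a vector distribution's `sample` and `log_prob` describe the same distribution is to estimate the volume of a d-dimensional ball by Monte Carlo: for samples x ~ p, the mean of 1{‖x − c‖ < r} / p(x) converges to the ball's volume. We believe `VectorDistributionTestHelpers` uses a variant of this idea, but the code below is an independent NumPy sketch with hypothetical function names, not the helper's implementation or API.

```python
import math

import numpy as np


def ball_volume(dims, radius):
  """Exact volume of a `dims`-dimensional ball with the given radius."""
  return math.pi ** (dims / 2) * radius ** dims / math.gamma(dims / 2 + 1)


def mc_ball_volume(samples, log_prob, center, radius):
  """Monte Carlo estimate of the same volume from x ~ p and log p(x).

  E_p[1{|x - center| < radius} / p(x)] equals the ball's volume when p > 0
  on the ball, so inconsistent sample/log_prob pairs give a wrong volume.
  """
  inside = np.linalg.norm(samples - center, axis=-1) < radius
  return np.sum(np.exp(-log_prob[inside])) / len(samples)


# Example: a 2-D standard normal, checked against the unit disk (area pi).
rng = np.random.default_rng(0)
dims = 2
x = rng.standard_normal((100_000, dims))
log_p = -0.5 * np.sum(x ** 2, axis=-1) - 0.5 * dims * math.log(2 * math.pi)
estimate = mc_ball_volume(x, log_p, center=np.zeros(dims), radius=1.0)
exact = ball_volume(dims, 1.0)
```

The ball is a convenient integration region because its volume has a closed form in any dimension.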

## Testing Discrete, Scalar Distributions

Helper class to test scalar-variate distributions over integers (or Booleans).

`DiscreteScalarDistributionTestHelpers` (Example Use)
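For an integer-valued distribution, a natural consistency check compares the empirical frequency of each value among many samples with the probability implied by `log_prob`. The sketch below shows that idea in plain NumPy with hypothetical names and a Binomial(10, 0.3) example; `DiscreteScalarDistributionTestHelpers` may use a different method and different tolerances.

```python
import math

import numpy as np


def max_pmf_error(samples, pmf, support):
  """Largest gap between empirical frequencies and `pmf` over `support`.

  `samples` must be non-negative integers (as `np.bincount` requires).
  Mass falling outside `support` is not checked here.
  """
  support = list(support)
  counts = np.bincount(samples, minlength=max(support) + 1)
  freqs = counts[support] / len(samples)
  expected = np.array([pmf(k) for k in support])
  return float(np.max(np.abs(freqs - expected)))


def binomial_pmf(k, n=10, p=0.3):
  return math.comb(n, k) * p ** k * (1 - p) ** (n - k)


# Example: Binomial(10, 0.3) samples checked against the same pmf.
rng = np.random.default_rng(0)
samples = rng.binomial(10, 0.3, size=200_000)
error = max_pmf_error(samples, binomial_pmf, range(11))
```

With 200,000 samples the sampling noise on each frequency is on the order of 0.001, so a pmf that is off by a few percent stands out clearly.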