tl;dr: See this example.
Most TFP code has two code paths for handling shape: one prior to graph execution (static shape) and one during graph execution (dynamic shape). The following example illustrates a pattern that makes testing both paths easier. The "trick" is to always make the input a placeholder (even when testing static shape).
import numpy as np
import tensorflow as tf
tfe = tf.contrib.eager

class _DistributionTest(object):

  @tfe.run_test_in_graph_and_eager_modes
  def testSomething(self):
    input_ = ...  # Using `self.dtype`.
    # Passing shape=None hides the static shape from downstream code.
    input_ph = tf.placeholder_with_default(
        input=input_,
        shape=input_.shape if self.use_static_shape else None)
    ...
    [...] = self.evaluate([...])
    ...

class DistributionTest_StaticShape(tf.test.TestCase, _DistributionTest):
  dtype = np.float32
  use_static_shape = True

class DistributionTest_DynamicShape(tf.test.TestCase, _DistributionTest):
  dtype = np.float32
  use_static_shape = False
Notice that we use tf.placeholder_with_default rather than tf.placeholder.
This allows convenient debugging of the executing code yet still lets us
programmatically hide shape hints.
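The subclassing half of this pattern does not depend on TensorFlow, so it can be sketched with the standard library alone. The following is a minimal illustration, not TFP code: `_SumTest`, `SumTestStaticShape`, `SumTestDynamicShape`, and `run_suite` are hypothetical names, and an iterator (whose length is unknown until consumed) stands in for a tensor whose shape hints are hidden.

```python
import unittest


class _SumTest(object):
    """Shared test body. Not a TestCase, so unittest never runs it directly."""

    use_static_shape = None  # Set by each concrete subclass.

    def test_sum(self):
        values = [1.0, 2.0, 3.0]
        # Stand-in for hiding shape hints: the "dynamic" input is an iterator,
        # so its length is not available up front.
        input_ = values if self.use_static_shape else iter(values)
        self.assertEqual(hasattr(input_, "__len__"), self.use_static_shape)
        self.assertEqual(sum(input_), 6.0)


class SumTestStaticShape(unittest.TestCase, _SumTest):
    use_static_shape = True


class SumTestDynamicShape(unittest.TestCase, _SumTest):
    use_static_shape = False


def run_suite():
    """Loads both concrete classes and returns the unittest result."""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite([
        loader.loadTestsFromTestCase(SumTestStaticShape),
        loader.loadTestsFromTestCase(SumTestDynamicShape),
    ])
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Because the shared body lives in a plain `object` subclass, the one test method runs exactly twice, once per configuration, and never on the base class itself.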
This idea can be extended as appropriate. For example, in the Reshape bijector we tested error-checking paths that, given a static-shape input, raise exceptions at graph construction time, but raise op errors at runtime in the dynamic case. To handle both with unified code, the static and dynamic subclasses can implement separate versions of assertRaisesError, each performing the check appropriate to its case, i.e.,
import tensorflow as tf
tfe = tf.contrib.eager

class _DistributionTest(object):

  @tfe.run_test_in_graph_and_eager_modes
  def testSomething(self):
    input_ = ...
    ...
    with self.assertRaisesError(
        "Some error message"):
      [...] = self.evaluate(something_that_might_throw_exception([...]))
    ...

class DistributionTest_StaticShape(tf.test.TestCase, _DistributionTest):
  ...
  # Static shape: the error surfaces as an exception at graph construction.
  def assertRaisesError(self, msg):
    return self.assertRaisesRegexp(Exception, msg)

class DistributionTest_DynamicShape(tf.test.TestCase, _DistributionTest):
  ...
  # Dynamic shape: the error surfaces as an op error at runtime.
  def assertRaisesError(self, msg):
    return self.assertRaisesOpError(msg)
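The same override idea can be tried out without TensorFlow. The sketch below is an assumption-laden stand-in, not the Reshape bijector's tests: `checked_reshape` is a hypothetical function that either validates immediately (mimicking a graph-construction-time exception) or defers validation to the returned callable (mimicking a runtime op error), and `ValueError` / `RuntimeError` are chosen arbitrarily to represent the two error kinds.

```python
import unittest


def checked_reshape(values, new_size, validate_now):
    """Hypothetical op: validates eagerly, or defers validation to run()."""

    def check(error_type):
        if len(values) != new_size:
            raise error_type(
                "size mismatch: got %d, want %d" % (len(values), new_size))

    if validate_now:
        check(ValueError)  # Mimics an exception at construction time.

    def run():
        check(RuntimeError)  # Mimics an op error at execution time.
        return list(values)

    return run


class _ReshapeErrorTest(object):
    """Shared test body; each subclass supplies assertRaisesError."""

    def test_bad_size(self):
        with self.assertRaisesError("size mismatch"):
            run = checked_reshape([1, 2, 3], 2, validate_now=self.validate_now)
            run()


class ReshapeErrorTestStatic(unittest.TestCase, _ReshapeErrorTest):
    validate_now = True

    def assertRaisesError(self, msg):
        return self.assertRaisesRegex(ValueError, msg)


class ReshapeErrorTestDynamic(unittest.TestCase, _ReshapeErrorTest):
    validate_now = False

    def assertRaisesError(self, msg):
        return self.assertRaisesRegex(RuntimeError, msg)


def run_error_suite():
    """Runs both concrete classes and returns the unittest result."""
    loader = unittest.TestLoader()
    suite = unittest.TestSuite([
        loader.loadTestsFromTestCase(ReshapeErrorTestStatic),
        loader.loadTestsFromTestCase(ReshapeErrorTestDynamic),
    ])
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Wrapping both construction and execution inside the `with` block is what lets a single test body cover both cases: whichever step raises, the subclass's checker decides whether it was the expected kind of error.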
Helper class to test vector-event distributions.
VectorDistributionTestHelpers (Example Use)
Helper class to test scalar variate distributions over integers (or Booleans).