From f61d33c9c667d881d25a033d87def2120e09f2ee Mon Sep 17 00:00:00 2001 From: Jeffrey Sorensen Date: Tue, 26 Oct 2021 20:39:01 -0400 Subject: [PATCH] Alter exception handling to behave more like scipy. --- .../python/stats/kendalls_tau.py | 115 +++++++++--------- .../python/stats/kendalls_tau_test.py | 35 +++++- 2 files changed, 88 insertions(+), 62 deletions(-) diff --git a/tensorflow_probability/python/stats/kendalls_tau.py b/tensorflow_probability/python/stats/kendalls_tau.py index 9818cba4d1..75b7535fc9 100644 --- a/tensorflow_probability/python/stats/kendalls_tau.py +++ b/tensorflow_probability/python/stats/kendalls_tau.py @@ -14,9 +14,9 @@ # ============================================================================== """Implements Kendall's Tau metric and loss.""" +import numpy as np import tensorflow as tf -from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -46,7 +46,7 @@ def iterative_mergesort(y, permutation, name=None): permutation, name='permutation', dtype=tf.int32) shape = permutation.shape tensorshape_util.assert_is_compatible_with(y.shape, shape) - n = ps.size(y) + y_size = ps.size(y) def outer_body(k, exchanges, permutation): # The outer body progressively merges lists as k grows by powers of 2, @@ -58,46 +58,47 @@ def middle_body(left, exchanges, permutation): # the middle body advances through the sublists of size k, advancing # the left edge until the end of the input is reached. right = left + k - end = tf.minimum(right + k, n) + end = tf.minimum(right + k, y_size) # See explanation here # https://www.geeksforgeeks.org/counting-inversions/. - def inner_body(i, j, x, np, p): + def inner_body(i, j, x, n, p): # The [left, right) and [right, end) lists are merged sorted, with # i and j tracking the advance through each range. x records the # number of order (bubble-sort equivalent) swaps that are happening - # with each insertion, and np represents the size of the output + # with each insertion, and n represents the size of the output # permutation that's been filled in using the p tensor. y_less = y_ordered[i] <= y_ordered[j] element = tf.where(y_less, [permutation[i]], [permutation[j]]) - new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0) + new_p = tf.concat([p[0:n], element, p[n + 1:y_size]], axis=0) tensorshape_util.set_shape(new_p, p.shape) return (tf.where(y_less, i + 1, i), tf.where(y_less, j, j + 1), - tf.where(y_less, x, x + right - i), np + 1, new_p) + tf.where(y_less, x, x + right - i), n + 1, new_p) - i_j_x_np_p = (left, right, exchanges, 0, tf.zeros([n], dtype=tf.int32)) - (i, j, exchanges, np, p) = tf.while_loop( - cond=lambda i, j, x, np, p: tf.math.logical_and(i < right, j < end), + i_j_x_n_p = (left, right, exchanges, 0, + tf.zeros([y_size], dtype=tf.int32)) + (i, j, exchanges, n, p) = tf.while_loop( + cond=lambda i, j, x, n, p: tf.math.logical_and(i < right, j < end), body=inner_body, - loop_vars=i_j_x_np_p) + loop_vars=i_j_x_n_p) permutation = tf.concat([ - permutation[0:left], p[0:np], permutation[i:right], - permutation[j:end], permutation[end:n] + permutation[0:left], p[0:n], permutation[i:right], + permutation[j:end], permutation[end:y_size] ], axis=0) tensorshape_util.set_shape(permutation, shape) return left + 2 * k, exchanges, permutation _, exchanges, permutation = tf.while_loop( - cond=lambda left, exchanges, permutation: left < n - k, + cond=lambda left, exchanges, permutation: left < y_size - k, body=middle_body, loop_vars=(0, exchanges, permutation)) k *= 2 return k, exchanges, permutation _, exchanges, permutation = tf.while_loop( - cond=lambda k, exchanges, permutation: k < n, + cond=lambda k, exchanges, permutation: k < y_size, body=outer_body, loop_vars=(1, 0, permutation)) return exchanges, permutation @@ -159,37 +160,9 @@ def secondary_sort(): axis=0) -def kendalls_tau(y_true, y_pred, name=None): - """Computes Kendall's Tau for two ordered lists. - - Kendall's Tau measures the correlation between ordinal rankings. This - implementation is similar to the one used in scipy.stats.kendalltau. - The provided values may be of any type that is sortable, with the - argsort indices indicating the true or proposed ordinal sequence. - - Args: - y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking. - y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the - same N items. - name: Optional Python `str` name for ops created by this method. - Default value: `None` (i.e., 'kendalls_tau'). - - Returns: - kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores - ordering of ties, as a `float32` scalar Tensor. - """ - with tf.name_scope(name or 'kendalls_tau'): - in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32) - y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type) - y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type) - tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape) - assertions = [ - assert_util.assert_rank(y_true, 1), - assert_util.assert_greater( - ps.size(y_true), 1, 'Ordering requires at least 2 elements.') - ] - with tf.control_dependencies(assertions): - lexa = lexicographical_indirect_sort(y_true, y_pred) +def _compute_kendalls_tau(y_true, y_pred): + """Kendall's Tau Implementation.""" + lexa = lexicographical_indirect_sort(y_true, y_pred) # See A Computer Method for Calculating Kendall's Tau with Ungrouped Data # by William Night, Journal of the American Statistical Association, @@ -238,13 +211,43 @@ def ties_in_y_pred_body(first, u, i): loop_vars=(0, 0, 1)) u += ((n - first) * (n - first - 1)) // 2 n0 = (n * (n - 1)) // 2 - assertions = [ - assert_util.assert_less(v, tf.cast(n0, tf.int32), - 'All ranks are ties for y_true.'), - assert_util.assert_less(u, tf.cast(n0, tf.int32), - 'All ranks are ties for y_pred.') - ] - with tf.control_dependencies(assertions): - return (tf.cast(n0 - (u + v - t), tf.float32) - - 2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt( - tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32)) + n0i = tf.cast(n0, tf.int32) + return tf.where( + tf.logical_or(tf.greater_equal(v, n0i), tf.greater_equal(u, n0i)), + tf.constant(np.nan), + ((tf.cast(n0 - (u + v - t), tf.float32) - + 2.0 * tf.cast(exchanges, tf.float32)) / + tf.math.sqrt(tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32)))) + + +def kendalls_tau(y_true, y_pred, name=None): + """Computes Kendall's Tau for two ordered lists. + + Kendall's Tau measures the correlation between ordinal rankings. This + implementation is similar to the one used in scipy.stats.kendalltau. + The provided values may be of any type that is sortable, with the + argsort indices indicating the true or proposed ordinal sequence. + + Args: + y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking. + y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the + same N items. + name: Optional Python `str` name for ops created by this method. + Default value: `None` (i.e., 'kendalls_tau'). + + Returns: + kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores + ordering of ties, as a `float32` scalar Tensor. + Will return np.nan under conditions when the order is undefined, such + as when all the elements of y_true or y_pred are the same, or when the + number of elements is less than 2. + """ + with tf.name_scope(name or 'kendalls_tau'): + in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32) + y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type) + y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type) + tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape) + return tf.where( + tf.logical_or( + tf.not_equal(ps.rank(y_true), 1), tf.less(ps.size(y_true), 2)), + tf.constant(np.nan), _compute_kendalls_tau(y_true, y_pred)) diff --git a/tensorflow_probability/python/stats/kendalls_tau_test.py b/tensorflow_probability/python/stats/kendalls_tau_test.py index 8d01261cf9..fbe986c638 100644 --- a/tensorflow_probability/python/stats/kendalls_tau_test.py +++ b/tensorflow_probability/python/stats/kendalls_tau_test.py @@ -15,6 +15,7 @@ """Tests Kendall's Tau metric.""" import random +import numpy as np from scipy import stats @@ -74,21 +75,43 @@ def test_kendall_random_lists(self): self.assertAllClose(expected, res, atol=1e-5) def test_kendall_tau_assert_all_ties_y_true(self): - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - self.evaluate(tfp.stats.kendalls_tau([12, 12, 12], [1, 4, 7])) + self.assertTrue( + self.evaluate( + tf.math.is_nan(tfp.stats.kendalls_tau([12, 12, 12], [1, 4, 7])))) def test_kendall_tau_assert_all_ties_y_pred(self): - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - self.evaluate(tfp.stats.kendalls_tau([1, 2, 3], [4, 4, 4])) + self.assertTrue( + self.evaluate( + tf.math.is_nan(tfp.stats.kendalls_tau([1, 2, 3], [4, 4, 4])))) def test_kendall_tau_assert_scalar(self): - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - tfp.stats.kendalls_tau([1], [4]) + self.assertTrue( + self.evaluate(tf.math.is_nan(tfp.stats.kendalls_tau([1], [4])))) def test_kendall_tau_assert_unmatched(self): with self.assertRaises(ValueError): tfp.stats.kendalls_tau([1, 2], [3, 4, 5]) + def test_kendall_tau_edge_case_behavior(self): + self.assertTrue( + self.evaluate( + tf.math.is_nan( + tfp.stats.kendalls_tau( + tf.constant([0, 0]), tf.constant([3, 5]))))) + self.assertTrue( + self.evaluate( + tf.math.is_nan( + tfp.stats.kendalls_tau( + tf.constant([0, 1]), tf.constant([3, 3]))))) + self.assertTrue( + self.evaluate( + tf.math.is_nan( + tfp.stats.kendalls_tau(tf.constant([0]), tf.constant([3]))))) + self.assertTrue( + self.evaluate( + tf.math.is_nan( + tfp.stats.kendalls_tau(tf.constant([]), tf.constant([]))))) + if __name__ == '__main__': test_util.main()