Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alter exception handling to behave more like scipy. #1455

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 59 additions & 56 deletions tensorflow_probability/python/stats/kendalls_tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# ==============================================================================
"""Implements Kendall's Tau metric and loss."""

import numpy as np
import tensorflow as tf

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
Expand Down Expand Up @@ -46,7 +46,7 @@ def iterative_mergesort(y, permutation, name=None):
permutation, name='permutation', dtype=tf.int32)
shape = permutation.shape
tensorshape_util.assert_is_compatible_with(y.shape, shape)
n = ps.size(y)
y_size = ps.size(y)

def outer_body(k, exchanges, permutation):
# The outer body progressively merges lists as k grows by powers of 2,
Expand All @@ -58,46 +58,47 @@ def middle_body(left, exchanges, permutation):
# the middle body advances through the sublists of size k, advancing
# the left edge until the end of the input is reached.
right = left + k
end = tf.minimum(right + k, n)
end = tf.minimum(right + k, y_size)

# See explanation here
# https://www.geeksforgeeks.org/counting-inversions/.

def inner_body(i, j, x, np, p):
def inner_body(i, j, x, n, p):
# The [left, right) and [right, end) lists are merged sorted, with
# i and j tracking the advance through each range. x records the
# number of order (bubble-sort equivalent) swaps that are happening
# with each insertion, and np represents the size of the output
# with each insertion, and n represents the size of the output
# permutation that's been filled in using the p tensor.
y_less = y_ordered[i] <= y_ordered[j]
element = tf.where(y_less, [permutation[i]], [permutation[j]])
new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0)
new_p = tf.concat([p[0:n], element, p[n + 1:y_size]], axis=0)
tensorshape_util.set_shape(new_p, p.shape)
return (tf.where(y_less, i + 1, i), tf.where(y_less, j, j + 1),
tf.where(y_less, x, x + right - i), np + 1, new_p)
tf.where(y_less, x, x + right - i), n + 1, new_p)

i_j_x_np_p = (left, right, exchanges, 0, tf.zeros([n], dtype=tf.int32))
(i, j, exchanges, np, p) = tf.while_loop(
cond=lambda i, j, x, np, p: tf.math.logical_and(i < right, j < end),
i_j_x_n_p = (left, right, exchanges, 0,
tf.zeros([y_size], dtype=tf.int32))
(i, j, exchanges, n, p) = tf.while_loop(
cond=lambda i, j, x, n, p: tf.math.logical_and(i < right, j < end),
body=inner_body,
loop_vars=i_j_x_np_p)
loop_vars=i_j_x_n_p)
permutation = tf.concat([
permutation[0:left], p[0:np], permutation[i:right],
permutation[j:end], permutation[end:n]
permutation[0:left], p[0:n], permutation[i:right],
permutation[j:end], permutation[end:y_size]
],
axis=0)
tensorshape_util.set_shape(permutation, shape)
return left + 2 * k, exchanges, permutation

_, exchanges, permutation = tf.while_loop(
cond=lambda left, exchanges, permutation: left < n - k,
cond=lambda left, exchanges, permutation: left < y_size - k,
body=middle_body,
loop_vars=(0, exchanges, permutation))
k *= 2
return k, exchanges, permutation

_, exchanges, permutation = tf.while_loop(
cond=lambda k, exchanges, permutation: k < n,
cond=lambda k, exchanges, permutation: k < y_size,
body=outer_body,
loop_vars=(1, 0, permutation))
return exchanges, permutation
Expand Down Expand Up @@ -159,37 +160,9 @@ def secondary_sort():
axis=0)


def kendalls_tau(y_true, y_pred, name=None):
"""Computes Kendall's Tau for two ordered lists.

Kendall's Tau measures the correlation between ordinal rankings. This
implementation is similar to the one used in scipy.stats.kendalltau.
The provided values may be of any type that is sortable, with the
argsort indices indicating the true or proposed ordinal sequence.

Args:
y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking.
y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the
same N items.
name: Optional Python `str` name for ops created by this method.
Default value: `None` (i.e., 'kendalls_tau').

Returns:
kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
ordering of ties, as a `float32` scalar Tensor.
"""
with tf.name_scope(name or 'kendalls_tau'):
in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32)
y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
assertions = [
assert_util.assert_rank(y_true, 1),
assert_util.assert_greater(
ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
]
with tf.control_dependencies(assertions):
lexa = lexicographical_indirect_sort(y_true, y_pred)
def _compute_kendalls_tau(y_true, y_pred):
"""Kendall's Tau Implementation."""
lexa = lexicographical_indirect_sort(y_true, y_pred)

# See A Computer Method for Calculating Kendall's Tau with Ungrouped Data
# by William R. Knight, Journal of the American Statistical Association,
Expand Down Expand Up @@ -238,13 +211,43 @@ def ties_in_y_pred_body(first, u, i):
loop_vars=(0, 0, 1))
u += ((n - first) * (n - first - 1)) // 2
n0 = (n * (n - 1)) // 2
assertions = [
assert_util.assert_less(v, tf.cast(n0, tf.int32),
'All ranks are ties for y_true.'),
assert_util.assert_less(u, tf.cast(n0, tf.int32),
'All ranks are ties for y_pred.')
]
with tf.control_dependencies(assertions):
return (tf.cast(n0 - (u + v - t), tf.float32) -
2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt(
tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))
n0i = tf.cast(n0, tf.int32)
return tf.where(
tf.logical_or(tf.greater_equal(v, n0i), tf.greater_equal(u, n0i)),
tf.constant(np.nan),
((tf.cast(n0 - (u + v - t), tf.float32) -
2.0 * tf.cast(exchanges, tf.float32)) /
tf.math.sqrt(tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))))


def kendalls_tau(y_true, y_pred, name=None):
  """Returns Kendall's Tau rank correlation between two orderings.

  Kendall's Tau quantifies how well two ordinal rankings agree. The approach
  resembles the one used in scipy.stats.kendalltau. The inputs may hold
  values of any sortable type; the argsort indices of each input indicate the
  true or proposed ordinal sequence.

  Args:
    y_true: a `Tensor` of shape `[n]` giving the true ordinal ranking.
    y_pred: a `Tensor` of shape `[n]` giving the predicted ordering of the
      same `n` items.
    name: Optional Python `str` name for ops created by this method.
      Default value: `None` (i.e., 'kendalls_tau').

  Returns:
    kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
      ordering of ties, as a `float32` scalar Tensor.
      The result is np.nan whenever the ordering is undefined, such as when
      every element of y_true or of y_pred is the same, or when fewer than 2
      elements are supplied.
  """
  scope_name = name or 'kendalls_tau'
  with tf.name_scope(scope_name):
    common = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32)
    y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=common)
    y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=common)
    tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)

    # Inputs that are not rank 1, or that hold fewer than 2 elements, map to
    # NaN instead of tripping an assertion.
    not_a_vector = tf.not_equal(ps.rank(y_true), 1)
    too_few = tf.less(ps.size(y_true), 2)
    undefined = tf.logical_or(not_a_vector, too_few)
    nan = tf.constant(np.nan)
    tau = _compute_kendalls_tau(y_true, y_pred)
    return tf.where(undefined, nan, tau)
35 changes: 29 additions & 6 deletions tensorflow_probability/python/stats/kendalls_tau_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests Kendall's Tau metric."""

import random
import numpy as np

from scipy import stats

Expand Down Expand Up @@ -74,21 +75,43 @@ def test_kendall_random_lists(self):
self.assertAllClose(expected, res, atol=1e-5)

def test_kendall_tau_assert_all_ties_y_true(self):
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
self.evaluate(tfp.stats.kendalls_tau([12, 12, 12], [1, 4, 7]))
self.assertTrue(
self.evaluate(
tf.math.is_nan(tfp.stats.kendalls_tau([12, 12, 12], [1, 4, 7]))))

def test_kendall_tau_assert_all_ties_y_pred(self):
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
self.evaluate(tfp.stats.kendalls_tau([1, 2, 3], [4, 4, 4]))
self.assertTrue(
self.evaluate(
tf.math.is_nan(tfp.stats.kendalls_tau([1, 2, 3], [4, 4, 4]))))

def test_kendall_tau_assert_scalar(self):
with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)):
tfp.stats.kendalls_tau([1], [4])
self.assertTrue(
self.evaluate(tf.math.is_nan(tfp.stats.kendalls_tau([1], [4]))))

def test_kendall_tau_assert_unmatched(self):
  """Inputs of different lengths must be rejected with a ValueError."""
  two_items, three_items = [1, 2], [3, 4, 5]
  with self.assertRaises(ValueError):
    tfp.stats.kendalls_tau(two_items, three_items)

def test_kendall_tau_edge_case_behavior(self):
  """Checks that each degenerate input pair evaluates to NaN."""
  # In order: all ties in y_true, all ties in y_pred, a single element, and
  # no elements at all.
  degenerate_inputs = (
      ([0, 0], [3, 5]),
      ([0, 1], [3, 3]),
      ([0], [3]),
      ([], []),
  )
  for true_values, pred_values in degenerate_inputs:
    tau = tfp.stats.kendalls_tau(
        tf.constant(true_values), tf.constant(pred_values))
    self.assertTrue(self.evaluate(tf.math.is_nan(tau)))


# Call test_util.main() only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
  test_util.main()