Skip to content

Commit

Permalink
pythonGH-127809: Fix the JIT's understanding of ** (pythonGH-127844)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandtbucher authored Jan 8, 2025
1 parent e08b282 commit 65ae3d5
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 26 deletions.
44 changes: 44 additions & 0 deletions Lib/test/test_capi/test_opt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import itertools
import sys
import textwrap
import unittest
Expand Down Expand Up @@ -1511,6 +1512,49 @@ def test_jit_error_pops(self):
with self.assertRaises(TypeError):
{item for item in items}

def test_power_type_depends_on_input_values(self):
template = textwrap.dedent("""
import _testinternalcapi
L, R, X, Y = {l}, {r}, {x}, {y}
def check(actual: complex, expected: complex) -> None:
assert actual == expected, (actual, expected)
assert type(actual) is type(expected), (actual, expected)
def f(l: complex, r: complex) -> None:
expected_local_local = pow(l, r) + pow(l, r)
expected_const_local = pow(L, r) + pow(L, r)
expected_local_const = pow(l, R) + pow(l, R)
expected_const_const = pow(L, R) + pow(L, R)
for _ in range(_testinternalcapi.TIER2_THRESHOLD):
# Narrow types:
l + l, r + r
# The powers produce results, and the addition is unguarded:
check(l ** r + l ** r, expected_local_local)
check(L ** r + L ** r, expected_const_local)
check(l ** R + l ** R, expected_local_const)
check(L ** R + L ** R, expected_const_const)
# JIT for one pair of values...
f(L, R)
# ...then run with another:
f(X, Y)
""")
interesting = [
(1, 1), # int ** int -> int
(1, -1), # int ** int -> float
(1.0, 1), # float ** int -> float
(1, 1.0), # int ** float -> float
(-1, 0.5), # int ** float -> complex
(1.0, 1.0), # float ** float -> float
(-1.0, 0.5), # float ** float -> complex
]
for (l, r), (x, y) in itertools.product(interesting, repeat=2):
s = template.format(l=l, r=r, x=x, y=y)
with self.subTest(l=l, r=r, x=x, y=y):
script_helper.assert_python_ok("-c", s)


def global_identity(x):
return x
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix an issue where the experimental JIT may infer an incorrect result type
for exponentiation (``**`` and ``**=``), leading to bugs or crashes.
16 changes: 16 additions & 0 deletions Python/bytecodes.c
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ dummy_func(
pure op(_BINARY_OP_MULTIPLY_INT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Multiply((PyLongObject *)left_o, (PyLongObject *)right_o);
Expand All @@ -543,6 +545,8 @@ dummy_func(
pure op(_BINARY_OP_ADD_INT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Add((PyLongObject *)left_o, (PyLongObject *)right_o);
Expand All @@ -556,6 +560,8 @@ dummy_func(
pure op(_BINARY_OP_SUBTRACT_INT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyLong_CheckExact(left_o));
assert(PyLong_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
PyObject *res_o = _PyLong_Subtract((PyLongObject *)left_o, (PyLongObject *)right_o);
Expand Down Expand Up @@ -593,6 +599,8 @@ dummy_func(
pure op(_BINARY_OP_MULTIPLY_FLOAT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
double dres =
Expand All @@ -607,6 +615,8 @@ dummy_func(
pure op(_BINARY_OP_ADD_FLOAT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
double dres =
Expand All @@ -621,6 +631,8 @@ dummy_func(
pure op(_BINARY_OP_SUBTRACT_FLOAT, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyFloat_CheckExact(left_o));
assert(PyFloat_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
double dres =
Expand Down Expand Up @@ -650,6 +662,8 @@ dummy_func(
pure op(_BINARY_OP_ADD_UNICODE, (left, right -- res)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));

STAT_INC(BINARY_OP, hit);
PyObject *res_o = PyUnicode_Concat(left_o, right_o);
Expand All @@ -672,6 +686,8 @@ dummy_func(
op(_BINARY_OP_INPLACE_ADD_UNICODE, (left, right --)) {
PyObject *left_o = PyStackRef_AsPyObjectBorrow(left);
PyObject *right_o = PyStackRef_AsPyObjectBorrow(right);
assert(PyUnicode_CheckExact(left_o));
assert(PyUnicode_CheckExact(right_o));

int next_oparg;
#if TIER_ONE
Expand Down
16 changes: 16 additions & 0 deletions Python/executor_cases.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions Python/generated_cases.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 45 additions & 12 deletions Python/optimizer_bytecodes.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,23 +167,56 @@ dummy_func(void) {
}

op(_BINARY_OP, (left, right -- res)) {
PyTypeObject *ltype = sym_get_type(left);
PyTypeObject *rtype = sym_get_type(right);
if (ltype != NULL && (ltype == &PyLong_Type || ltype == &PyFloat_Type) &&
rtype != NULL && (rtype == &PyLong_Type || rtype == &PyFloat_Type))
{
if (oparg != NB_TRUE_DIVIDE && oparg != NB_INPLACE_TRUE_DIVIDE &&
ltype == &PyLong_Type && rtype == &PyLong_Type) {
/* If both inputs are ints and the op is not division the result is an int */
res = sym_new_type(ctx, &PyLong_Type);
bool lhs_int = sym_matches_type(left, &PyLong_Type);
bool rhs_int = sym_matches_type(right, &PyLong_Type);
bool lhs_float = sym_matches_type(left, &PyFloat_Type);
bool rhs_float = sym_matches_type(right, &PyFloat_Type);
if (!((lhs_int || lhs_float) && (rhs_int || rhs_float))) {
// There's something other than an int or float involved:
res = sym_new_unknown(ctx);
}
else if (oparg == NB_POWER || oparg == NB_INPLACE_POWER) {
// This one's fun... the *type* of the result depends on the
// *values* being exponentiated. However, exponents with one
// constant part are reasonably common, so it's probably worth
// trying to infer some simple cases:
// - A: 1 ** 1 -> 1 (int ** int -> int)
// - B: 1 ** -1 -> 1.0 (int ** int -> float)
// - C: 1.0 ** 1 -> 1.0 (float ** int -> float)
// - D: 1 ** 1.0 -> 1.0 (int ** float -> float)
// - E: -1 ** 0.5 ~> 1j (int ** float -> complex)
// - F: 1.0 ** 1.0 -> 1.0 (float ** float -> float)
// - G: -1.0 ** 0.5 ~> 1j (float ** float -> complex)
if (rhs_float) {
// Case D, E, F, or G... can't know without the sign of the LHS
// or whether the RHS is whole, which isn't worth the effort:
res = sym_new_unknown(ctx);
}
else {
/* For any other op combining ints/floats the result is a float */
else if (lhs_float) {
// Case C:
res = sym_new_type(ctx, &PyFloat_Type);
}
else if (!sym_is_const(right)) {
// Case A or B... can't know without the sign of the RHS:
res = sym_new_unknown(ctx);
}
else if (_PyLong_IsNegative((PyLongObject *)sym_get_const(right))) {
// Case B:
res = sym_new_type(ctx, &PyFloat_Type);
}
else {
// Case A:
res = sym_new_type(ctx, &PyLong_Type);
}
}
else if (oparg == NB_TRUE_DIVIDE || oparg == NB_INPLACE_TRUE_DIVIDE) {
res = sym_new_type(ctx, &PyFloat_Type);
}
else if (lhs_int && rhs_int) {
res = sym_new_type(ctx, &PyLong_Type);
}
else {
res = sym_new_unknown(ctx);
res = sym_new_type(ctx, &PyFloat_Type);
}
}

Expand Down
Loading

0 comments on commit 65ae3d5

Please sign in to comment.