Skip to content

Commit

Permalink
[mypyc] Fixing index variable in for-loop with builtins.enumerate. (#…
Browse files Browse the repository at this point in the history
…18202)

Fixes [mypyc/mypyc#1046](mypyc/mypyc#1046)

This change fixes two problems:
1. The index variable was getting instantiated even while enumerating an
empty iterable.
2. After exiting the for-loop, the value of the index variable is off by
1 (see issue linked above).

This change fixes both problems by assigning the temporary register to
the index variable at the beginning of the for-loop body. Before this
change, this assignment was happening before the for-loop and at the end
of the for-loop body.
  • Loading branch information
advait-dixit authored Nov 27, 2024
1 parent 2842e8f commit d39eacc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
5 changes: 3 additions & 2 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,6 @@ def init(self) -> None:
zero = Integer(0)
self.index_reg = builder.maybe_spill_assignable(zero)
self.index_target: Register | AssignmentTarget = builder.get_assignment_target(self.index)
builder.assign(self.index_target, zero, self.line)

def gen_step(self) -> None:
builder = self.builder
Expand All @@ -997,7 +996,9 @@ def gen_step(self) -> None:
short_int_rprimitive, builder.read(self.index_reg, line), Integer(1), IntOp.ADD, line
)
builder.assign(self.index_reg, new_val, line)
builder.assign(self.index_target, new_val, line)

def begin_body(self) -> None:
self.builder.assign(self.index_target, self.builder.read(self.index_reg), self.line)


class ForEnumerate(ForGenerator):
Expand Down
14 changes: 5 additions & 9 deletions mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -864,33 +864,31 @@ def g(x: Iterable[int]) -> None:
[out]
def f(a):
a :: list
r0 :: short_int
i :: int
r1 :: short_int
r0, r1 :: short_int
r2 :: native_int
r3 :: short_int
r4 :: bit
i :: int
r5 :: object
r6, x, r7 :: int
r8, r9 :: short_int
L0:
r0 = 0
i = 0
r1 = 0
L1:
r2 = var_object_size a
r3 = r2 << 1
r4 = int_lt r1, r3
if r4 goto L2 else goto L4 :: bool
L2:
i = r0
r5 = CPyList_GetItemUnsafe(a, r1)
r6 = unbox(int, r5)
x = r6
r7 = CPyTagged_Add(i, x)
L3:
r8 = r0 + 2
r0 = r8
i = r8
r9 = r1 + 2
r1 = r9
goto L1
Expand All @@ -900,25 +898,23 @@ L5:
def g(x):
x :: object
r0 :: short_int
i :: int
r1, r2 :: object
r3, n :: int
i, r3, n :: int
r4 :: short_int
r5 :: bit
L0:
r0 = 0
i = 0
r1 = PyObject_GetIter(x)
L1:
r2 = PyIter_Next(r1)
if is_error(r2) goto L4 else goto L2
L2:
i = r0
r3 = unbox(int, r2)
n = r3
L3:
r4 = r0 + 2
r0 = r4
i = r4
goto L1
L4:
r5 = CPy_NoErrOccured()
Expand Down
24 changes: 24 additions & 0 deletions mypyc/test-data/run-loops.test
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def nested_enumerate() -> None:
assert i == inner
inner += 1
outer += 1
assert i == 2
assert outer_seen == l1

def nested_range() -> None:
Expand Down Expand Up @@ -465,6 +466,29 @@ assert g([6, 7], ['a', 'b']) == [(0, 6, 'a'), (1, 7, 'b')]
assert f([6, 7], [8]) == [(0, 6, 8)]
assert f([6], [8, 9]) == [(0, 6, 8)]

[case testEnumerateEmptyList]
from typing import List

def get_enumerate_locals(iterable: List[int]) -> int:
for i, j in enumerate(iterable):
pass
try:
return i
except NameError:
return -100

[file driver.py]
from native import get_enumerate_locals

print(get_enumerate_locals([]))
print(get_enumerate_locals([55]))
print(get_enumerate_locals([551, 552]))

[out]
-100
0
1

[case testIterTypeTrickiness]
# Test inferring the type of a for loop body doesn't cause us grief
# Extracted from somethings that broke in mypy
Expand Down

0 comments on commit d39eacc

Please sign in to comment.