Skip to content

Commit

Permalink
Support reversed(dynamic_range) (#597)
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck authored Feb 18, 2024
1 parent 2d404a5 commit 72c8448
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 1 deletion.
10 changes: 10 additions & 0 deletions python/freetensor/core/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
'''

import sys
import builtins
import numpy as np
import inspect
import traceback
Expand Down Expand Up @@ -552,3 +553,12 @@ def foreach(self, name, body: Callable[[Any], None]) -> None:


static_range = range


def reversed(rng):
if isinstance(rng, dynamic_range):
return dynamic_range(
rng.start + rng.step * ((rng.stop - rng.start - 1) // rng.step),
rng.start - rng.step, -rng.step)
else:
return builtins.reversed(rng)
6 changes: 5 additions & 1 deletion python/freetensor/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .expr import UndeclaredParam
from .stmt import VarRef
from .func import Func
from . import frontend
from .frontend import lang_overload, staged_callable, LifetimeScope, dynamic_range
from .context import pop_ast_and_user_grads, ctx_stack
from .staging import StagingError, TransformError
Expand All @@ -19,7 +20,10 @@


def _prepare_extra_locals(default_dynamic_range):
extra_locals = {'__ft__': sys.modules['freetensor']}
extra_locals = {
'__ft__': sys.modules['freetensor'],
'reversed': frontend.reversed,
}
if default_dynamic_range:
extra_locals['range'] = dynamic_range
return extra_locals
Expand Down
38 changes: 38 additions & 0 deletions test/50.frontend/test_transformer_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,44 @@ def test_expected(x: ft.Var[(), 'float32']):
assert test.body.match(test_expected.body)


def test_reversed_dynamic_range_1():

@ft.transform(verbose=2)
def test(x, y):
x: ft.Var[(16,), "int32", "input"]
y: ft.Var[(16,), "int32", "output"]
for i in reversed(range(0, 16, 2)):
y[i] = x[i] + 1

@ft.transform
def test_expected(x, y):
x: ft.Var[(16,), "int32", "input"]
y: ft.Var[(16,), "int32", "output"]
for i in range(14, -2, -2):
y[i] = x[i] + 1

assert test.body.match(test_expected.body)


def test_reversed_dynamic_range_2():

@ft.transform(verbose=2)
def test(x, y):
x: ft.Var[(16,), "int32", "input"]
y: ft.Var[(16,), "int32", "output"]
for i in reversed(range(0, 15, 2)):
y[i] = x[i] + 1

@ft.transform
def test_expected(x, y):
x: ft.Var[(16,), "int32", "input"]
y: ft.Var[(16,), "int32", "output"]
for i in range(14, -2, -2):
y[i] = x[i] + 1

assert test.body.match(test_expected.body)


@dataclass
class DummyAssigned:
attr = None
Expand Down

0 comments on commit 72c8448

Please sign in to comment.