From 72c8448208fc3e9b303acedc45b047efb9d3ea0b Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Sun, 18 Feb 2024 09:11:05 +0800 Subject: [PATCH] Support reversed(dynamic_range) (#597) --- python/freetensor/core/frontend.py | 10 ++++++ python/freetensor/core/transform.py | 6 +++- test/50.frontend/test_transformer_basic.py | 38 ++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/python/freetensor/core/frontend.py b/python/freetensor/core/frontend.py index ffdfb7ef9..e73ac1280 100644 --- a/python/freetensor/core/frontend.py +++ b/python/freetensor/core/frontend.py @@ -3,6 +3,7 @@ ''' import sys +import builtins import numpy as np import inspect import traceback @@ -552,3 +553,12 @@ def foreach(self, name, body: Callable[[Any], None]) -> None: static_range = range + + +def reversed(rng): + if isinstance(rng, dynamic_range): + return dynamic_range( + rng.start + rng.step * ((rng.stop - rng.start - 1) // rng.step), + rng.start - rng.step, -rng.step) + else: + return builtins.reversed(rng) diff --git a/python/freetensor/core/transform.py b/python/freetensor/core/transform.py index 95b8d4462..c916776a3 100644 --- a/python/freetensor/core/transform.py +++ b/python/freetensor/core/transform.py @@ -10,6 +10,7 @@ from .expr import UndeclaredParam from .stmt import VarRef from .func import Func +from . import frontend from .frontend import lang_overload, staged_callable, LifetimeScope, dynamic_range from .context import pop_ast_and_user_grads, ctx_stack from .staging import StagingError, TransformError @@ -19,7 +20,10 @@ def _prepare_extra_locals(default_dynamic_range): - extra_locals = {'__ft__': sys.modules['freetensor']} + extra_locals = { + '__ft__': sys.modules['freetensor'], + 'reversed': frontend.reversed, + } if default_dynamic_range: extra_locals['range'] = dynamic_range return extra_locals diff --git a/test/50.frontend/test_transformer_basic.py b/test/50.frontend/test_transformer_basic.py index 47ab450f8..ea9c9ef1a 100644 --- a/test/50.frontend/test_transformer_basic.py +++ b/test/50.frontend/test_transformer_basic.py @@ -547,6 +547,44 @@ def test_expected(x: ft.Var[(), 'float32']): assert test.body.match(test_expected.body) +def test_reversed_dynamic_range_1(): + + @ft.transform(verbose=2) + def test(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in reversed(range(0, 16, 2)): + y[i] = x[i] + 1 + + @ft.transform + def test_expected(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in range(14, -2, -2): + y[i] = x[i] + 1 + + assert test.body.match(test_expected.body) + + +def test_reversed_dynamic_range_2(): + + @ft.transform(verbose=2) + def test(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in reversed(range(0, 15, 2)): + y[i] = x[i] + 1 + + @ft.transform + def test_expected(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in range(14, -2, -2): + y[i] = x[i] + 1 + + assert test.body.match(test_expected.body) + + @dataclass class DummyAssigned: attr = None