Check incorrect `x: ft.Var[...]` where `x` is not a parameter (#606)
* Check incorrect `x: ft.Var[...]` where `x` is not a parameter

* Fix broken tests
roastduck authored Apr 5, 2024
1 parent 11b6b04 commit e9eeb30
Showing 7 changed files with 46 additions and 12 deletions.
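In short: a `ft.Var` annotation is now only accepted on a name that is an as-yet-undeclared function parameter. A minimal sketch of what is rejected after this change, mirroring the new tests in `test/50.frontend/test_error_report.py` below:

```python
import pytest
import freetensor as ft

with pytest.raises(ft.StagingError):

    @ft.transform
    def f(a):
        a: ft.Var[(), "int32"]  # OK: `a` is an undeclared parameter
        c = 1
        c: ft.Var[(), "int32"]  # rejected: `c` is a local, not a parameter
        return a
```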
14 changes: 11 additions & 3 deletions python/freetensor/core/frontend.py
@@ -14,7 +14,7 @@
 from .. import ffi
 
 from .expr import (dtype, mtype, ndim, l_and, l_or, l_not, if_then_else, shape,
-                   VarVersionRef)
+                   VarVersionRef, UndeclaredParam)
 from .stmt import (_VarDef, VarRef, For, If, Else, ctx_stack, Assert, Invoke,
                    MarkVersion, UserGradStaged)
 from .staging import (StagedPredicate, StagedTypeAnnotation, StagedAssignable,
@@ -420,8 +420,16 @@ def __init__(self, shape, dtype, atype="input", mtype=None):
         self.shape, self.dtype, self.atype, self.mtype = shape, dtype, atype, mtype
 
     def annotate(self, name: str) -> VarRef:
-        return lang_overload.register_vardef(name, self.shape, self.dtype,
-                                             self.atype, self.mtype)
+
+        def annotate_impl(old_var):
+            if not isinstance(old_var, UndeclaredParam):
+                raise lang_overload.error(
+                    f'`ft.Var` annotation should only be used on undeclared parameters, instead of `{name}`.'
+                )
+            return lang_overload.register_vardef(name, self.shape, self.dtype,
+                                                 self.atype, self.mtype)
+
+        return annotate_impl
 
 
 @dataclass
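The key design choice: `annotate` no longer registers the variable directly, but returns a closure that receives the annotated name's current value, so the check can tell a not-yet-declared parameter apart from an ordinary value. A self-contained sketch of that pattern, with stand-ins for FreeTensor's `UndeclaredParam`, `register_vardef`, and error type (none of these stand-ins are the real API):

```python
class UndeclaredParam:           # stand-in for FreeTensor's sentinel
    pass

class StagingError(Exception):   # stand-in for the staging error type
    pass

def make_annotator(name):
    """Rough analogue of Var.annotate: returns annotate_impl."""

    def annotate_impl(old_var):
        # Only a parameter that still holds the sentinel may be annotated.
        if not isinstance(old_var, UndeclaredParam):
            raise StagingError(
                f'`ft.Var` annotation should only be used on undeclared '
                f'parameters, instead of `{name}`.')
        return f'<VarRef {name}>'  # stand-in for register_vardef(...)

    return annotate_impl

a = UndeclaredParam()          # parameter not yet declared
a = make_annotator('a')(a)     # OK: replaced by a VarRef

c = 1                          # ordinary local value
try:
    c = make_annotator('c')(c)
except StagingError as e:
    print(e)                   # the check above rejects `c`
```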
8 changes: 5 additions & 3 deletions python/freetensor/core/staging.py
@@ -374,7 +374,7 @@ def wrapped(*args, **kwargs):
     def annotate_stmt(self, name: str, ty):
         if isinstance(ty, StagedTypeAnnotation):
             return ty.annotate(name)
-        return None
+        return lambda x: x
 
     def mark_position(self, lineno: int):
         '''
@@ -799,10 +799,12 @@ def handleType_AnnAssign(self, node: ast.AnnAssign) -> Any:
         intermediate = f'freetensor__annotate__{x.id}'
         intermediate_store = ast.Name(intermediate, ast.Store())
         intermediate_load = ast.Name(intermediate, ast.Load())
+        intermediate_call = ast.Call(intermediate_load,
+                                     [ast.Name(x.id, ast.Load())], [])
         node = [
             ast.Assign([intermediate_store],
                        call_helper(StagingOverload.annotate_stmt, x_str, Ty)),
-            ast.If(intermediate_load, [ast.Assign([x], intermediate_load)], [])
+            ast.If(intermediate_load, [ast.Assign([x], intermediate_call)], [])
         ]
 
         return node
@@ -812,7 +814,7 @@ def visit_AnnAssign(self, old_node: ast.AnnAssign) -> Any:
         `x: Ty` -> ```
         freetensor__annotate__x = annotate_stmt('x', Ty)
         if freetensor__annotate__x:
-            x = freetensor__annotate__x
+            x = freetensor__annotate__x(x)
         ```: pure annotation
         '''
         node: ast.AnnAssign = self.generic_visit(old_node)
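As a quick sanity check on the AST construction above, the new `intermediate_call` node unparses to `freetensor__annotate__x(x)` — i.e. the annotator is now called with the old value of `x`, exactly as the docstring describes. A small standalone sketch using only the standard `ast` module (Python 3.9+ for `ast.unparse`):

```python
import ast

# Build the same Call node that handleType_AnnAssign now constructs:
# a load of the intermediate name applied to the old value of `x`.
intermediate_load = ast.Name('freetensor__annotate__x', ast.Load())
intermediate_call = ast.Call(intermediate_load,
                             [ast.Name('x', ast.Load())], [])

print(ast.unparse(intermediate_call))  # prints: freetensor__annotate__x(x)
```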
2 changes: 1 addition & 1 deletion test/21.autograd/test_tape.py
@@ -523,7 +523,7 @@ def func(a: ft.Var[(10,), "float32"]):
     bwd = ft.lower(bwd, verbose=1)
 
     @ft.transform
-    def expected(a, d_a, b_tape, d_c):
+    def expected(a, d_a, b_tape, d_b, d_c):
         a: ft.Var[(10,), "float32", "input"]
         d_a: ft.Var[(10,), "float32", "output"]
         b_tape: ft.Var[(1, 10), "float32>=0", "input"]
2 changes: 1 addition & 1 deletion test/30.schedule/test_fuse.py
@@ -461,7 +461,7 @@ def test(ptr, edge1, edge2):
 def test_fuse_no_deps_2():
 
     @ft.transform
-    def test(ptr, edge1, edge2):
+    def test(ptr, edge1, edge2, foobar):
         ptr: ft.Var[(11,), "int32", "input", "cpu"]
         edge1: ft.Var[(50,), "int32", "input", "cpu"]
         edge2: ft.Var[(50,), "int32", "output", "cpu"]
2 changes: 1 addition & 1 deletion test/40.codegen/gpu/test_gpu.py
@@ -1110,7 +1110,7 @@ def matcher(x):
 def test_merge_no_deps_2():
 
     @ft.transform(verbose=1)
-    def test(ptr, edge1, edge2):
+    def test(ptr, edge1, edge2, foobar):
         ptr: ft.Var[(4, 11), "int32", "input", "cpu"]
         edge1: ft.Var[(4, 50), "int32", "input", "cpu"]
         edge2: ft.Var[(4, 50), "int32", "output", "cpu"]
6 changes: 3 additions & 3 deletions test/40.codegen/gpu/test_gpu_sync.py
@@ -479,7 +479,7 @@ def test(x, y):
 def test_syncthreads_no_split_branch_out_of_dynamic_loop():
 
     @ft.transform
-    def test(x, y):
+    def test(lim, x, y):
         lim: ft.Var[(3, 4), "int32", "input", "gpu/global"]
         x: ft.Var[(10, 10, 64), "int32", "input", "gpu/global"]
         y: ft.Var[(10, 10), "int32", "output", "gpu/global"]
@@ -509,7 +509,7 @@ def test(x, y):
 def test_syncthreads_no_need_to_split_branch():
 
     @ft.transform
-    def test(x, y):
+    def test(lim, x, y):
         lim: ft.Var[(3, 4), "int32", "input", "gpu/global"]
         x: ft.Var[(12, 10, 64), "int32", "input", "gpu/global"]
         y: ft.Var[(12, 10), "int32", "output", "gpu/global"]
@@ -562,7 +562,7 @@ def test(x, y):
 def test_syncthreads_no_need_to_split_branch_warp():
 
     @ft.transform
-    def test(x, y):
+    def test(lim, x, y):
         lim: ft.Var[(3, 4), "int32", "input", "gpu/global"]
         x: ft.Var[(12, 10, 32), "int32", "input", "gpu/global"]
         y: ft.Var[(12, 10), "int32", "output", "gpu/global"]
24 changes: 24 additions & 0 deletions test/50.frontend/test_error_report.py
@@ -96,3 +96,27 @@ def test(a, b):
     assert f"File \"{file}\", line {line_foo}" in e.value.args[0]
     assert f"File \"{file}\", line {line_bar}" in e.value.args[0]
     assert f"File \"{file}\", line {line_ab}" in e.value.args[0]
+
+
+def test_invalid_type_declaration_1():
+
+    with pytest.raises(ft.StagingError):
+
+        @ft.transform(verbose=2)
+        def test(a, b):
+            c = 1
+            a: ft.Var[(), "int32"]
+            c: ft.Var[(), "int32"]  # c is not a parameter
+            return a
+
+
+def test_invalid_type_declaration_2():
+
+    with pytest.raises(ft.StagingError):
+
+        @ft.transform(verbose=2)
+        def test(a, b):
+            # c is not defined
+            a: ft.Var[(), "int32"]
+            c: ft.Var[(), "int32"]  # c is not a parameter
+            return a
