From db1bc896db5b29f94dd730f0fc5b814530462b3d Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Tue, 16 Jan 2024 15:25:21 +0800 Subject: [PATCH] Improve AST serialization grammar for For nodes (#582) * Improve AST serialization grammar for For nodes * Fix broken tests --- grammar/ast_lexer.g | 5 ++++- grammar/ast_parser.g | 2 +- include/serialize/print_ast.h | 3 ++- src/serialize/print_ast.cc | 10 +++++----- test/40.codegen/gpu/test_gpu_sync.py | 20 ++++++++++---------- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/grammar/ast_lexer.g b/grammar/ast_lexer.g index 9c477cb5c..4a0527e7f 100644 --- a/grammar/ast_lexer.g +++ b/grammar/ast_lexer.g @@ -6,7 +6,10 @@ Comment: '/*' .*? '*/' -> skip; IF: 'if'; ELSE: 'else'; FOR: 'for'; -IN: 'in'; +FROM: 'from'; +UNTIL: 'until'; +STEP: 'step'; +LENGTH: 'length'; ASSERT_TOKEN: 'assert'; ASSUME: 'assume'; FUNC: 'func'; diff --git a/grammar/ast_parser.g b/grammar/ast_parser.g index e62341fd1..17b381eba 100644 --- a/grammar/ast_parser.g +++ b/grammar/ast_parser.g @@ -347,7 +347,7 @@ forProperty returns [Ref property] for returns [Stmt node] : forProperty - FOR var IN begin=expr ':' end=expr ':' step=expr ':' len=expr + FOR var FROM begin=expr UNTIL end=expr STEP step=expr LENGTH len=expr LBRACE stmts RBRACE { $node = makeFor($var.name, $begin.node, $end.node, $step.node, $len.node, diff --git a/include/serialize/print_ast.h b/include/serialize/print_ast.h index f96f7f173..bd69f83d5 100644 --- a/include/serialize/print_ast.h +++ b/include/serialize/print_ast.h @@ -15,7 +15,8 @@ class PrintVisitor : public CodeGen { hexFloat_ = false, parenDespitePriority_ = false, printSourceLocation_ = false; const std::unordered_set keywords = { - "if", "else", "for", "in", "assert", "assume", "func", "true", "false", + "if", "else", "for", "from", "until", "step", + "length", "assert", "assume", "func", "true", "false", }; /** diff --git a/src/serialize/print_ast.cc b/src/serialize/print_ast.cc index d29783a69..57e5bb72b 100644 --- a/src/serialize/print_ast.cc +++ b/src/serialize/print_ast.cc @@ -657,14 +657,14 @@ void PrintVisitor::visit(const For &op) { os() << "@!prefer_libs" << std::endl; } makeIndent(); - os() << prettyKeyword("for ") << prettyIterName(op->iter_) - << prettyKeyword(" in "); + os() << prettyKeyword("for") << " " << prettyIterName(op->iter_) << " " + << prettyKeyword("from") << " "; recur(op->begin_); - os() << " : "; + os() << " " << prettyKeyword("until") << " "; recur(op->end_); - os() << " : "; + os() << " " << prettyKeyword("step") << " "; recur(op->step_); - os() << " : "; + os() << " " << prettyKeyword("length") << " "; recur(op->len_); os() << " "; beginBlock(); diff --git a/test/40.codegen/gpu/test_gpu_sync.py b/test/40.codegen/gpu/test_gpu_sync.py index 33719f81a..c867b9fb2 100644 --- a/test/40.codegen/gpu/test_gpu_sync.py +++ b/test/40.codegen/gpu/test_gpu_sync.py @@ -349,11 +349,11 @@ def test_syncthreads_between_cond_and_body_of_a_branch(): # Already normalized AST: ast = ft.load_ast(''' @!parallel : @threadIdx.y -for `.threadIdx.y` in 0 : 8 : 1 : 8 { +for `.threadIdx.y` from 0 until 8 step 1 length 8 { @!parallel : @threadIdx.x - for `.threadIdx.x` in 0 : 32 : 1 : 32 { + for `.threadIdx.x` from 0 until 32 step 1 length 32 { @inout @gpu/global x: float32[8] @!pinned { - for i in 0 : 5856 : 1 : 5856 { + for i from 0 until 5856 step 1 length 5856 { if x[`.threadIdx.y`] > 0 { if `.threadIdx.x` == 0 { x[`.threadIdx.y`] = 1 @@ -1012,13 +1012,13 @@ def test_reject_dependence_between_blocks(): @cache @gpu/global t: float32[4, 4] { @output @gpu/global y: float32[4, 4] { @!parallel : @blockIdx.x - for i in 0 : 4 : 1 : 4 { + for i from 0 until 4 step 1 length 4 { @!parallel : @blockIdx.y - for j in 0 : 4 : 1 : 4 { + for j from 0 until 4 step 1 length 4 { t[i, j] = x[i, j] * 2 } @!parallel : @blockIdx.y - for j_1 in 0 : 4 : 1 : 4 { + for j_1 from 0 until 4 step 1 length 4 { y[i, j_1] = t[i, (j_1 + 1) % 4] } } @@ -1034,16 +1034,16 @@ def test_dont_reject_false_dependence_between_blocks(): @input @gpu/global x: float32[4, 4, 4] { @cache @gpu/global t: float32[4, 4, 4] { @output @gpu/global y: float32[4, 4, 4] { - for p in 0 : 4 : 1 : 4 { + for p from 0 until 4 step 1 length 4 { @!parallel : @blockIdx.x - for i in 0 : 4 : 1 : 4 { + for i from 0 until 4 step 1 length 4 { @!parallel : @blockIdx.y - for j in 0 : 4 : 1 : 4 { + for j from 0 until 4 step 1 length 4 { t[p, i, j] = x[p, i, j] * 2 } if p > 0 { @!parallel : @blockIdx.y - for j_1 in 0 : 4 : 1 : 4 { + for j_1 from 0 until 4 step 1 length 4 { y[p, i, j_1] = t[p - 1, i, (j_1 + 1) % 4] } }