Skip to content

Commit

Permalink
Improve AST serialization grammar for For nodes (#582)
Browse files Browse the repository at this point in the history
* Improve AST serialization grammar for For nodes

* Fix broken tests
  • Loading branch information
roastduck authored Jan 16, 2024
1 parent e768fb2 commit db1bc89
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 18 deletions.
5 changes: 4 additions & 1 deletion grammar/ast_lexer.g
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ Comment: '/*' .*? '*/' -> skip;
IF: 'if';
ELSE: 'else';
FOR: 'for';
IN: 'in';
FROM: 'from';
UNTIL: 'until';
STEP: 'step';
LENGTH: 'length';
ASSERT_TOKEN: 'assert';
ASSUME: 'assume';
FUNC: 'func';
Expand Down
2 changes: 1 addition & 1 deletion grammar/ast_parser.g
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ forProperty returns [Ref<ForProperty> property]

for returns [Stmt node]
: forProperty
FOR var IN begin=expr ':' end=expr ':' step=expr ':' len=expr
FOR var FROM begin=expr UNTIL end=expr STEP step=expr LENGTH len=expr
LBRACE stmts RBRACE
{
$node = makeFor($var.name, $begin.node, $end.node, $step.node, $len.node,
Expand Down
3 changes: 2 additions & 1 deletion include/serialize/print_ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class PrintVisitor : public CodeGen<CodeGenStream> {
hexFloat_ = false, parenDespitePriority_ = false,
printSourceLocation_ = false;
const std::unordered_set<std::string> keywords = {
"if", "else", "for", "in", "assert", "assume", "func", "true", "false",
"if", "else", "for", "from", "until", "step",
"length", "assert", "assume", "func", "true", "false",
};

/**
Expand Down
10 changes: 5 additions & 5 deletions src/serialize/print_ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -657,14 +657,14 @@ void PrintVisitor::visit(const For &op) {
os() << "@!prefer_libs" << std::endl;
}
makeIndent();
os() << prettyKeyword("for ") << prettyIterName(op->iter_)
<< prettyKeyword(" in ");
os() << prettyKeyword("for") << " " << prettyIterName(op->iter_) << " "
<< prettyKeyword("from") << " ";
recur(op->begin_);
os() << " : ";
os() << " " << prettyKeyword("until") << " ";
recur(op->end_);
os() << " : ";
os() << " " << prettyKeyword("step") << " ";
recur(op->step_);
os() << " : ";
os() << " " << prettyKeyword("length") << " ";
recur(op->len_);
os() << " ";
beginBlock();
Expand Down
20 changes: 10 additions & 10 deletions test/40.codegen/gpu/test_gpu_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,11 @@ def test_syncthreads_between_cond_and_body_of_a_branch():
# Already normalized AST:
ast = ft.load_ast('''
@!parallel : @threadIdx.y
for `.threadIdx.y` in 0 : 8 : 1 : 8 {
for `.threadIdx.y` from 0 until 8 step 1 length 8 {
@!parallel : @threadIdx.x
for `.threadIdx.x` in 0 : 32 : 1 : 32 {
for `.threadIdx.x` from 0 until 32 step 1 length 32 {
@inout @gpu/global x: float32[8] @!pinned {
for i in 0 : 5856 : 1 : 5856 {
for i from 0 until 5856 step 1 length 5856 {
if x[`.threadIdx.y`] > 0 {
if `.threadIdx.x` == 0 {
x[`.threadIdx.y`] = 1
Expand Down Expand Up @@ -1012,13 +1012,13 @@ def test_reject_dependence_between_blocks():
@cache @gpu/global t: float32[4, 4] {
@output @gpu/global y: float32[4, 4] {
@!parallel : @blockIdx.x
for i in 0 : 4 : 1 : 4 {
for i from 0 until 4 step 1 length 4 {
@!parallel : @blockIdx.y
for j in 0 : 4 : 1 : 4 {
for j from 0 until 4 step 1 length 4 {
t[i, j] = x[i, j] * 2
}
@!parallel : @blockIdx.y
for j_1 in 0 : 4 : 1 : 4 {
for j_1 from 0 until 4 step 1 length 4 {
y[i, j_1] = t[i, (j_1 + 1) % 4]
}
}
Expand All @@ -1034,16 +1034,16 @@ def test_dont_reject_false_dependence_between_blocks():
@input @gpu/global x: float32[4, 4, 4] {
@cache @gpu/global t: float32[4, 4, 4] {
@output @gpu/global y: float32[4, 4, 4] {
for p in 0 : 4 : 1 : 4 {
for p from 0 until 4 step 1 length 4 {
@!parallel : @blockIdx.x
for i in 0 : 4 : 1 : 4 {
for i from 0 until 4 step 1 length 4 {
@!parallel : @blockIdx.y
for j in 0 : 4 : 1 : 4 {
for j from 0 until 4 step 1 length 4 {
t[p, i, j] = x[p, i, j] * 2
}
if p > 0 {
@!parallel : @blockIdx.y
for j_1 in 0 : 4 : 1 : 4 {
for j_1 from 0 until 4 step 1 length 4 {
y[p, i, j_1] = t[p - 1, i, (j_1 + 1) % 4]
}
}
Expand Down

0 comments on commit db1bc89

Please sign in to comment.