diff --git a/drools-drl/drools-drl-parser-tests/src/test/java/org/drools/drl/parser/antlr4/MiscDRLParserTest.java b/drools-drl/drools-drl-parser-tests/src/test/java/org/drools/drl/parser/antlr4/MiscDRLParserTest.java index ab2f7070606..875315157d1 100644 --- a/drools-drl/drools-drl-parser-tests/src/test/java/org/drools/drl/parser/antlr4/MiscDRLParserTest.java +++ b/drools-drl/drools-drl-parser-tests/src/test/java/org/drools/drl/parser/antlr4/MiscDRLParserTest.java @@ -2174,12 +2174,44 @@ public void parse_QualifiedClassname() throws Exception { } @Test - public void parse_Accumulate() throws Exception { - final PackageDescr pkg = parseAndGetPackageDescrFromFile( - "accumulate.drl" ); + void accumulate() { + final String drl = "rule R\n" + + "when\n" + + " accumulate( Person( $age : age );\n" + + " $avg : average( $age ) );\n" + + "then\n" + + "end"; + RuleDescr rule = parseAndGetFirstRuleDescr(drl); + + PatternDescr out = (PatternDescr) rule.getLhs().getDescrs().get( 0 ); + assertThat(out.getObjectType()).isEqualTo("Object"); + AccumulateDescr accum = (AccumulateDescr) out.getSource(); + assertThat(accum.isExternalFunction()).isTrue(); + + List functions = accum.getFunctions(); + assertThat(functions.size()).isEqualTo(1); + assertThat(functions.get(0).getFunction()).isEqualTo("average"); + assertThat(functions.get(0).getBind()).isEqualTo("$avg"); + assertThat(functions.get(0).getParams()[0]).isEqualTo("$age"); + + final PatternDescr pattern = accum.getInputPattern(); + assertThat(pattern.getObjectType()).isEqualTo("Person"); + + // accum.getInput() is always AndDescr + assertThat(accum.getInput()).isInstanceOfSatisfying(AndDescr.class, and -> { + assertThat(and.getDescrs()).hasSize(1); + assertThat(and.getDescrs().get(0)).isInstanceOfSatisfying(PatternDescr.class, patternDescr -> { + assertThat(patternDescr.getObjectType()).isEqualTo("Person"); + }); + }); + } + + @Test + void fromAccumulate() { + final PackageDescr pkg = parseAndGetPackageDescrFromFile("from_accumulate.drl" ); assertThat(pkg.getRules().size()).isEqualTo(1); - final RuleDescr rule = (RuleDescr) pkg.getRules().get( 0 ); + final RuleDescr rule = pkg.getRules().get( 0 ); assertThat(rule.getLhs().getDescrs().size()).isEqualTo(1); final PatternDescr outPattern = (PatternDescr) rule.getLhs().getDescrs().get( 0 ); @@ -2191,8 +2223,16 @@ public void parse_Accumulate() throws Exception { assertThat(accum.isExternalFunction()).isFalse(); - final PatternDescr pattern = (PatternDescr) accum.getInputPattern(); + final PatternDescr pattern = accum.getInputPattern(); assertThat(pattern.getObjectType()).isEqualTo("Person"); + + // accum.getInput() is always AndDescr + assertThat(accum.getInput()).isInstanceOfSatisfying(AndDescr.class, and -> { + assertThat(and.getDescrs()).hasSize(1); + assertThat(and.getDescrs().get(0)).isInstanceOfSatisfying(PatternDescr.class, patternDescr -> { + assertThat(patternDescr.getObjectType()).isEqualTo("Person"); + }); + }); } @Test @@ -5232,4 +5272,36 @@ void durationChunk() { // At the moment, the parser accepts any input and let the compile phase validate it. assertThat(rule.getAttributes().get("duration").getValue()).isEqualTo("wrong input"); } + + @Test + void accumulateWithEmptyActionAndReverse() { + final String drl = "rule R when\n" + + " Number() from accumulate( Number(),\n" + + " init( double total = 0; ),\n" + + " action( ),\n" + + " reverse( ),\n" + + " result( new Double( total ) )\n" + + " )\n" + + "then end"; + RuleDescr rule = parseAndGetFirstRuleDescr(drl); + + final PatternDescr outPattern = (PatternDescr) rule.getLhs().getDescrs().get( 0 ); + final AccumulateDescr accum = (AccumulateDescr) outPattern.getSource(); + assertThat(accum.getInitCode()).isEqualTo( "double total = 0;"); + assertThat(accum.getActionCode()).isEmpty(); + assertThat(accum.getReverseCode()).isEmpty(); + assertThat(accum.getResultCode()).isEqualTo( "new Double( total )"); + + assertThat(accum.isExternalFunction()).isFalse(); + + final PatternDescr pattern = accum.getInputPattern(); + assertThat(pattern.getObjectType()).isEqualTo("Number"); + + assertThat(accum.getInput()).isInstanceOfSatisfying(AndDescr.class, and -> { + assertThat(and.getDescrs()).hasSize(1); + assertThat(and.getDescrs().get(0)).isInstanceOfSatisfying(PatternDescr.class, patternDescr -> { + assertThat(patternDescr.getObjectType()).isEqualTo("Number"); + }); + }); + } } diff --git a/drools-drl/drools-drl-parser-tests/src/test/resources/org/drools/drl/parser/antlr4/accumulate.drl b/drools-drl/drools-drl-parser-tests/src/test/resources/org/drools/drl/parser/antlr4/from_accumulate.drl similarity index 100% rename from drools-drl/drools-drl-parser-tests/src/test/resources/org/drools/drl/parser/antlr4/accumulate.drl rename to drools-drl/drools-drl-parser-tests/src/test/resources/org/drools/drl/parser/antlr4/from_accumulate.drl diff --git a/drools-drl/drools-drl-parser/src/main/antlr4/org/drools/drl/parser/antlr4/DRLParser.g4 b/drools-drl/drools-drl-parser/src/main/antlr4/org/drools/drl/parser/antlr4/DRLParser.g4 index cf1afb97862..cabdaaa1f34 100644 --- a/drools-drl/drools-drl-parser/src/main/antlr4/org/drools/drl/parser/antlr4/DRLParser.g4 +++ b/drools-drl/drools-drl-parser/src/main/antlr4/org/drools/drl/parser/antlr4/DRLParser.g4 @@ -340,7 +340,8 @@ fromAccumulate := ACCUMULATE LEFT_PAREN lhsAnd (COMMA|SEMICOLON) ) RIGHT_PAREN */ fromAccumulate : (DRL_ACCUMULATE|DRL_ACC) LPAREN lhsAndDef (COMMA|SEMI) - ( DRL_INIT LPAREN initBlockStatements=chunk? RPAREN COMMA? DRL_ACTION LPAREN actionBlockStatements=chunk? RPAREN COMMA? ( DRL_REVERSE LPAREN reverseBlockStatements=chunk? RPAREN COMMA?)? DRL_RESULT LPAREN resultBlockStatements=chunk RPAREN + ( DRL_INIT LPAREN initBlockStatements=chunk? RPAREN COMMA? DRL_ACTION LPAREN actionBlockStatements=chunk? RPAREN COMMA? DRL_REVERSE LPAREN reverseBlockStatements=chunk? RPAREN COMMA? DRL_RESULT LPAREN resultBlockStatements=chunk RPAREN + | DRL_INIT LPAREN initBlockStatements=chunk? RPAREN COMMA? DRL_ACTION LPAREN actionBlockStatements=chunk? RPAREN COMMA? DRL_RESULT LPAREN resultBlockStatements=chunk RPAREN | accumulateFunction ) RPAREN (SEMI)? diff --git a/drools-drl/drools-drl-parser/src/main/java/org/drools/drl/parser/antlr4/DRLVisitorImpl.java b/drools-drl/drools-drl-parser/src/main/java/org/drools/drl/parser/antlr4/DRLVisitorImpl.java index 39aa5e930e8..d67450e9193 100644 --- a/drools-drl/drools-drl-parser/src/main/java/org/drools/drl/parser/antlr4/DRLVisitorImpl.java +++ b/drools-drl/drools-drl-parser/src/main/java/org/drools/drl/parser/antlr4/DRLVisitorImpl.java @@ -699,7 +699,8 @@ public PatternDescr visitLhsAccumulate(DRLParser.LhsAccumulateContext ctx) { AccumulateDescr accumulateDescr = BaseDescrFactory.builder(new AccumulateDescr()) .withParserRuleContext(ctx) .build(); - accumulateDescr.setInput(visitLhsAndDef(ctx.lhsAndDef())); + // accumulateDescr.input is always AndDescr + accumulateDescr.setInput(wrapWithAndDescr(visitLhsAndDef(ctx.lhsAndDef()), ctx.lhsAndDef())); // accumulate function for (DRLParser.AccumulateFunctionContext accumulateFunctionContext : ctx.accumulateFunction()) { @@ -714,6 +715,18 @@ public PatternDescr visitLhsAccumulate(DRLParser.LhsAccumulateContext ctx) { return patternDescr; } + private AndDescr wrapWithAndDescr(BaseDescr baseDescr, ParserRuleContext ctx) { + if (baseDescr instanceof AndDescr andDescr) { + return andDescr; + } else { + AndDescr andDescr = BaseDescrFactory.builder(new AndDescr()) + .withParserRuleContext(ctx) + .build(); + andDescr.addDescr(baseDescr); + return andDescr; + } + } + @Override public Object visitLhsGroupBy(DRLParser.LhsGroupByContext ctx) { GroupByDescr groupByDescr = BaseDescrFactory.builder(new GroupByDescr()) @@ -775,7 +788,8 @@ public AccumulateDescr visitFromAccumulate(DRLParser.FromAccumulateContext ctx) AccumulateDescr accumulateDescr = BaseDescrFactory.builder(new AccumulateDescr()) .withParserRuleContext(ctx) .build(); - accumulateDescr.setInput(visitLhsAndDef(ctx.lhsAndDef())); + // accumulateDescr.input is always AndDescr + accumulateDescr.setInput(wrapWithAndDescr(visitLhsAndDef(ctx.lhsAndDef()), ctx.lhsAndDef())); if (ctx.DRL_INIT() != null) { // inline custom accumulate accumulateDescr.setInitCode(getTextPreservingWhitespace(ctx.initBlockStatements));