diff --git a/cedar-integration-tests/corpus_tests/7b316784cf9e60631768b35cb7a7e15ba6d01c05.json b/cedar-integration-tests/corpus_tests/7b316784cf9e60631768b35cb7a7e15ba6d01c05.json index 3c4c23680..735fc6d19 100644 --- a/cedar-integration-tests/corpus_tests/7b316784cf9e60631768b35cb7a7e15ba6d01c05.json +++ b/cedar-integration-tests/corpus_tests/7b316784cf9e60631768b35cb7a7e15ba6d01c05.json @@ -13,7 +13,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -25,7 +25,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -37,7 +37,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -49,7 +49,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -61,7 +61,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -73,7 +73,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -85,7 +85,7 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] }, { @@ -97,8 +97,8 @@ "decision": "Deny", "reasons": [], "errors": [ - "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply -1537158028109086738 by 60138" + "error occurred while evaluating policy `policy0`: integer overflow while attempting to multiply the values -1537158028109086738 and 60138" ] } ] -} \ No newline at end of file +} diff --git a/cedar-policy-core/src/ast/expr.rs b/cedar-policy-core/src/ast/expr.rs index d615d47e8..904ebe073 100644 --- a/cedar-policy-core/src/ast/expr.rs +++ b/cedar-policy-core/src/ast/expr.rs @@ -101,17 +101,6 @@ pub enum ExprKind { /// Second arg arg2: Arc>, }, - /// Multiplication by constant - /// - /// This isn't just a BinaryOp because its arguments aren't both expressions. - /// (Similar to how `like` isn't a BinaryOp and has its own AST node as well.) - MulByConst { - /// first argument, which may be an arbitrary expression, but must - /// evaluate to Long type - arg: Arc>, - /// second argument, which must be an integer constant - constant: i64, - }, /// Application of an extension function to n arguments /// INVARIANT (MethodStyleArgs): /// if op.style is MethodStyle then args _cannot_ be empty. @@ -382,9 +371,9 @@ impl Expr { ExprBuilder::new().sub(e1, e2) } - /// Create a 'mul' expression. First argument must evaluate to Long type. - pub fn mul(e: Expr, c: i64) -> Self { - ExprBuilder::new().mul(e, c) + /// Create a 'mul' expression. Arguments must evaluate to Long type + pub fn mul(e1: Expr, e2: Expr) -> Self { + ExprBuilder::new().mul(e1, e2) } /// Create a 'neg' expression. `e` must evaluate to Long type. @@ -566,9 +555,6 @@ impl Expr { .collect::, _>>()?; Ok(Expr::record(pairs)) } - ExprKind::MulByConst { arg, constant } => { - Ok(Expr::mul(arg.substitute(definitions)?, *constant)) - } } } } @@ -652,6 +638,12 @@ impl std::fmt::Display for Expr { maybe_with_parens(arg1), maybe_with_parens(arg2), ), + BinaryOp::Mul => write!( + f, + "{} * {}", + maybe_with_parens(arg1), + maybe_with_parens(arg2), + ), BinaryOp::In => write!( f, "{} in {}", @@ -668,9 +660,6 @@ impl std::fmt::Display for Expr { write!(f, "{}.containsAny({})", maybe_with_parens(arg1), &arg2) } }, - ExprKind::MulByConst { arg, constant } => { - write!(f, "{} * {}", maybe_with_parens(arg), constant) - } ExprKind::ExtensionFunctionApp { fn_name, args } => { // search for the name and callstyle let style = Extensions::all_available().all_funcs().find_map(|f| { @@ -756,7 +745,6 @@ fn maybe_with_parens(expr: &Expr) -> String { format!("({})", expr) } ExprKind::BinaryApp { .. } => format!("({})", expr), - ExprKind::MulByConst { .. } => format!("({})", expr), ExprKind::ExtensionFunctionApp { .. } => format!("({})", expr), ExprKind::GetAttr { .. } => format!("({})", expr), ExprKind::HasAttr { .. } => format!("({})", expr), @@ -994,11 +982,12 @@ impl ExprBuilder { }) } - /// Create a 'mul' expression. First argument must evaluate to Long type. - pub fn mul(self, e: Expr, c: i64) -> Expr { - self.with_expr_kind(ExprKind::MulByConst { - arg: Arc::new(e), - constant: c, + /// Create a 'mul' expression. Arguments must evaluate to Long type + pub fn mul(self, e1: Expr, e2: Expr) -> Expr { + self.with_expr_kind(ExprKind::BinaryApp { + op: BinaryOp::Mul, + arg1: Arc::new(e1), + arg2: Arc::new(e2), }) } @@ -1245,13 +1234,6 @@ impl Expr { arg2: arg21, }, ) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21), - ( - MulByConst { arg, constant }, - MulByConst { - arg: arg1, - constant: constant1, - }, - ) => constant == constant1 && arg.eq_shape(arg1), ( ExtensionFunctionApp { fn_name, args }, ExtensionFunctionApp { @@ -1337,10 +1319,6 @@ impl Expr { arg1.hash_shape(state); arg2.hash_shape(state); } - ExprKind::MulByConst { arg, constant } => { - arg.hash_shape(state); - constant.hash(state); - } ExprKind::ExtensionFunctionApp { fn_name, args } => { fn_name.hash(state); state.write_usize(args.len()); @@ -1682,8 +1660,8 @@ mod test { Expr::sub(Expr::val(1), Expr::val(1)), ), ( - ExprBuilder::with_data(1).mul(temp.clone(), 1), - Expr::mul(Expr::val(1), 1), + ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()), + Expr::mul(Expr::val(1), Expr::val(1)), ), ( ExprBuilder::with_data(1).neg(temp.clone()), diff --git a/cedar-policy-core/src/ast/expr_iterator.rs b/cedar-policy-core/src/ast/expr_iterator.rs index 16d0cc303..0847f3864 100644 --- a/cedar-policy-core/src/ast/expr_iterator.rs +++ b/cedar-policy-core/src/ast/expr_iterator.rs @@ -71,9 +71,6 @@ impl<'a, T> Iterator for ExprIterator<'a, T> { self.expression_stack.push(arg1); self.expression_stack.push(arg2); } - ExprKind::MulByConst { arg, .. } => { - self.expression_stack.push(arg); - } ExprKind::ExtensionFunctionApp { args, .. } => { for arg in args.as_ref() { self.expression_stack.push(arg); diff --git a/cedar-policy-core/src/ast/ops.rs b/cedar-policy-core/src/ast/ops.rs index 07842a0e9..7163c0f76 100644 --- a/cedar-policy-core/src/ast/ops.rs +++ b/cedar-policy-core/src/ast/ops.rs @@ -61,6 +61,11 @@ pub enum BinaryOp { /// Arguments must have Long type Sub, + /// Integer multiplication + /// + /// Arguments must have Long type + Mul, + /// Hierarchy membership. Specifically, is the first arg a member of the /// second. /// @@ -103,6 +108,7 @@ impl std::fmt::Display for BinaryOp { BinaryOp::LessEq => write!(f, "_<=_"), BinaryOp::Add => write!(f, "_+_"), BinaryOp::Sub => write!(f, "_-_"), + BinaryOp::Mul => write!(f, "_*_"), BinaryOp::In => write!(f, "_in_"), BinaryOp::Contains => write!(f, "contains"), BinaryOp::ContainsAll => write!(f, "containsAll"), diff --git a/cedar-policy-core/src/ast/policy.rs b/cedar-policy-core/src/ast/policy.rs index e5873f085..948b8875e 100644 --- a/cedar-policy-core/src/ast/policy.rs +++ b/cedar-policy-core/src/ast/policy.rs @@ -24,6 +24,8 @@ use thiserror::Error; /// Top level structure for a policy template. /// Contains both the AST for template, and the list of open slots in the template. +/// +/// Note that this "template" may have no slots, in which case this `Template` represents a static policy #[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize)] #[serde(from = "TemplateBody")] #[serde(into = "TemplateBody")] @@ -31,6 +33,8 @@ pub struct Template { body: TemplateBody, /// INVARIANT (slot cache correctness): This Vec must contain _all_ of the open slots in `body` /// This is maintained by the only two public constructors, `new` and `instantiate_inline_policy` + /// + /// Note that `slots` may be empty, in which case this `Template` represents a static policy slots: Vec, } diff --git a/cedar-policy-core/src/ast/restricted_expr.rs b/cedar-policy-core/src/ast/restricted_expr.rs index 94f622f62..af5e30182 100644 --- a/cedar-policy-core/src/ast/restricted_expr.rs +++ b/cedar-policy-core/src/ast/restricted_expr.rs @@ -205,13 +205,8 @@ fn is_restricted(expr: &Expr) -> Result<(), RestrictedExpressionError> { feature: op.to_string(), }) } - ExprKind::MulByConst { .. } => { - Err(RestrictedExpressionError::InvalidRestrictedExpression { - feature: "multiplication".into(), - }) - } ExprKind::GetAttr { .. } => Err(RestrictedExpressionError::InvalidRestrictedExpression { - feature: "get-attribute".into(), + feature: "attribute accesses".into(), }), ExprKind::HasAttr { .. } => Err(RestrictedExpressionError::InvalidRestrictedExpression { feature: "'has'".into(), diff --git a/cedar-policy-core/src/ast/value.rs b/cedar-policy-core/src/ast/value.rs index 6b3004be0..135b52c21 100644 --- a/cedar-policy-core/src/ast/value.rs +++ b/cedar-policy-core/src/ast/value.rs @@ -62,7 +62,6 @@ impl TryFrom for Value { ExprKind::Or { .. } => Err(NotValue::NotValue), ExprKind::UnaryApp { .. } => Err(NotValue::NotValue), ExprKind::BinaryApp { .. } => Err(NotValue::NotValue), - ExprKind::MulByConst { .. } => Err(NotValue::NotValue), ExprKind::ExtensionFunctionApp { .. } => Err(NotValue::NotValue), ExprKind::GetAttr { .. } => Err(NotValue::NotValue), ExprKind::HasAttr { .. } => Err(NotValue::NotValue), diff --git a/cedar-policy-core/src/est/err.rs b/cedar-policy-core/src/est/err.rs index 6348cf0e1..4e1cf5341 100644 --- a/cedar-policy-core/src/est/err.rs +++ b/cedar-policy-core/src/est/err.rs @@ -14,7 +14,7 @@ * limitations under the License. */ -use crate::ast::{self, SlotId}; +use crate::ast; use crate::entities::JsonDeserializationError; use crate::parser::unescape; use smol_str::SmolStr; @@ -40,7 +40,7 @@ pub enum EstToAstError { #[error("found template slot {slot} in a `{clausetype}` clause")] SlotsInConditionClause { /// Slot that was found in a when/unless clause - slot: SlotId, + slot: ast::SlotId, /// Clause type, e.g. "when" or "unless" clausetype: &'static str, }, diff --git a/cedar-policy-core/src/est/expr.rs b/cedar-policy-core/src/est/expr.rs index 0f0dfc497..67fff0cf9 100644 --- a/cedar-policy-core/src/est/expr.rs +++ b/cedar-policy-core/src/est/expr.rs @@ -522,26 +522,10 @@ impl TryFrom for ast::Expr { (*left).clone().try_into()?, (*right).clone().try_into()?, )), - Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => { - let left: ast::Expr = (*left).clone().try_into()?; - let right: ast::Expr = (*right).clone().try_into()?; - let left_c = match left.expr_kind() { - ast::ExprKind::Lit(ast::Literal::Long(c)) => Some(c), - _ => None, - }; - let right_c = match right.expr_kind() { - ast::ExprKind::Lit(ast::Literal::Long(c)) => Some(c), - _ => None, - }; - match (left_c, right_c) { - (_, Some(c)) => Ok(ast::Expr::mul(left, *c)), - (Some(c), _) => Ok(ast::Expr::mul(right, *c)), - (None, None) => Err(EstToAstError::MultiplicationByNonConstant { - arg1: left, - arg2: right, - })?, - } - } + Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul( + (*left).clone().try_into()?, + (*right).clone().try_into()?, + )), Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains( (*left).clone().try_into()?, (*right).clone().try_into()?, @@ -658,15 +642,12 @@ impl From for Expr { ast::BinaryOp::LessEq => Expr::lesseq(arg1, arg2), ast::BinaryOp::Add => Expr::add(arg1, arg2), ast::BinaryOp::Sub => Expr::sub(arg1, arg2), + ast::BinaryOp::Mul => Expr::mul(arg1, arg2), ast::BinaryOp::Contains => Expr::contains(Arc::new(arg1), arg2), ast::BinaryOp::ContainsAll => Expr::contains_all(Arc::new(arg1), arg2), ast::BinaryOp::ContainsAny => Expr::contains_any(Arc::new(arg1), arg2), } } - ast::ExprKind::MulByConst { arg, constant } => Expr::mul( - unwrap_or_clone(arg).into(), - Expr::lit(JSONValue::Long(constant)), - ), ast::ExprKind::ExtensionFunctionApp { fn_name, args } => { let args = unwrap_or_clone(args).into_iter().map(Into::into).collect(); Expr::ext_call(fn_name.to_string().into(), args) diff --git a/cedar-policy-core/src/evaluator.rs b/cedar-policy-core/src/evaluator.rs index d8f1a34cd..dc4751e02 100644 --- a/cedar-policy-core/src/evaluator.rs +++ b/cedar-policy-core/src/evaluator.rs @@ -406,7 +406,11 @@ impl<'q, 'e> Evaluator<'e> { match op { BinaryOp::Eq => Ok((arg1 == arg2).into()), // comparison and arithmetic operators, which only work on Longs - BinaryOp::Less | BinaryOp::LessEq | BinaryOp::Add | BinaryOp::Sub => { + BinaryOp::Less + | BinaryOp::LessEq + | BinaryOp::Add + | BinaryOp::Sub + | BinaryOp::Mul => { let i1 = arg1.get_as_long()?; let i2 = arg2.get_as_long()?; match op { @@ -432,6 +436,16 @@ impl<'q, 'e> Evaluator<'e> { }, )), }, + BinaryOp::Mul => match i1.checked_mul(i2) { + Some(prod) => Ok(prod.into()), + None => Err(EvaluationError::IntegerOverflow( + IntegerOverflowError::BinaryOp { + op: *op, + arg1, + arg2, + }, + )), + }, // PANIC SAFETY `op` is checked to be one of the above #[allow(clippy::unreachable)] _ => { @@ -515,21 +529,6 @@ impl<'q, 'e> Evaluator<'e> { } } } - ExprKind::MulByConst { arg, constant } => match self.partial_interpret(arg, slots)? { - PartialValue::Value(arg) => { - let i1 = arg.get_as_long()?; - match i1.checked_mul(*constant) { - Some(prod) => Ok(prod.into()), - None => Err(EvaluationError::IntegerOverflow( - IntegerOverflowError::Multiplication { - arg, - constant: *constant, - }, - )), - } - } - PartialValue::Residual(r) => Ok(PartialValue::Residual(Expr::mul(r, *constant))), - }, ExprKind::ExtensionFunctionApp { fn_name, args } => { let args = args .iter() @@ -863,6 +862,8 @@ fn stack_size_check() -> Result<()> { pub mod test { use std::str::FromStr; + use cool_asserts::assert_matches; + use super::*; use crate::{ @@ -2582,30 +2583,33 @@ pub mod test { ); // 5 * (-3) assert_eq!( - eval.interpret_inline_policy(&Expr::mul(Expr::val(5), -3)), - Ok(Value::Lit(Literal::Long(-15))) + eval.interpret_inline_policy(&Expr::mul(Expr::val(5), Expr::val(-3))), + Ok(Value::from(-15)) ); // 5 * 0 assert_eq!( - eval.interpret_inline_policy(&Expr::mul(Expr::val(5), 0)), - Ok(Value::Lit(Literal::Long(0))) + eval.interpret_inline_policy(&Expr::mul(Expr::val(5), Expr::val(0))), + Ok(Value::from(0)) ); // "5" * 0 - assert_eq!( - eval.interpret_inline_policy(&Expr::mul(Expr::val("5"), 0)), - Err(EvaluationError::TypeError { - expected: vec![Type::Long], - actual: Type::String - }) + assert_matches!( + eval.interpret_inline_policy(&Expr::mul(Expr::val("5"), Expr::val(0))), + Err(e) => assert_eq!(e, + EvaluationError::TypeError { + expected: vec![Type::Long], + actual: Type::String + } + ) ); // overflow assert_eq!( - eval.interpret_inline_policy(&Expr::mul(Expr::val(std::i64::MAX - 1), 3)), + eval.interpret_inline_policy(&Expr::mul(Expr::val(i64::MAX - 1), Expr::val(3))), Err(EvaluationError::IntegerOverflow( - IntegerOverflowError::Multiplication { - arg: Value::from(std::i64::MAX - 1), - constant: 3, - } + IntegerOverflowError::BinaryOp { + op: BinaryOp::Mul, + arg1: Value::from(i64::MAX - 1), + arg2: Value::from(3), + }, )) ); } @@ -4916,7 +4920,7 @@ pub mod test { let exts = Extensions::none(); let eval = Evaluator::new(&empty_request(), &es, &exts).unwrap(); - let e = Expr::mul(Expr::unknown("a"), 32); + let e = Expr::mul(Expr::unknown("a"), Expr::val(32)); let r = eval.partial_interpret(&e, &HashMap::new()).unwrap(); assert_eq!(r, PartialValue::Residual(e)); } diff --git a/cedar-policy-core/src/evaluator/err.rs b/cedar-policy-core/src/evaluator/err.rs index f29c67e0f..bf89189e6 100644 --- a/cedar-policy-core/src/evaluator/err.rs +++ b/cedar-policy-core/src/evaluator/err.rs @@ -127,7 +127,8 @@ fn pretty_type_error(expected: &[Type], actual: &Type) -> String { #[derive(Debug, PartialEq, Clone, Error)] pub enum IntegerOverflowError { - #[error("integer overflow while attempting to {} the values {arg1} and {arg2}", match .op { BinaryOp::Add => "add", BinaryOp::Sub => "subtract", _ => "perform an operation on" })] + /// Overflow during a binary operation + #[error("integer overflow while attempting to {} the values {arg1} and {arg2}", match .op { BinaryOp::Add => "add", BinaryOp::Sub => "subtract", BinaryOp::Mul => "multiply", _ => "perform an operation on" })] BinaryOp { /// overflow while evaluating this operator op: BinaryOp, @@ -137,14 +138,6 @@ pub enum IntegerOverflowError { arg2: Value, }, - #[error("integer overflow while attempting to multiply {arg} by {constant}")] - Multiplication { - /// first argument, which wasn't necessarily a constant in the policy - arg: Value, - /// second argument, which was a constant in the policy - constant: i64, - }, - /// Overflow during an integer negation operation #[error("integer overflow while attempting to {} the value {arg}", match .op { UnaryOp::Neg => "negate", _ => "perform an operation on" })] UnaryOp { diff --git a/cedar-policy-core/src/parser/cst_to_ast.rs b/cedar-policy-core/src/parser/cst_to_ast.rs index f2f2e984c..9cbfb1713 100644 --- a/cedar-policy-core/src/parser/cst_to_ast.rs +++ b/cedar-policy-core/src/parser/cst_to_ast.rs @@ -43,7 +43,7 @@ use crate::ast::{ self, ActionConstraint, CallStyle, EntityReference, EntityType, EntityUID, PatternElem, PolicySetError, PrincipalConstraint, PrincipalOrResourceConstraint, ResourceConstraint, }; -use itertools::Either; +use itertools::{Either, Itertools}; use smol_str::SmolStr; use std::cmp::Ordering; use std::collections::{BTreeMap, HashSet}; @@ -1189,78 +1189,31 @@ impl ASTNode> { let mult = maybe_mult?; let maybe_first = mult.initial.to_expr_or_special(errs); - // collect() preforms all the conversions, generating any errors - let more: Vec<(cst::MultOp, _)> = mult + let more = mult .extended .iter() - .filter_map(|&(op, ref i)| i.to_expr(errs).map(|e| (op, e))) - .collect(); + .filter_map(|&(op, ref i)| i.to_expr(errs).map(|e| (op, e))); + let (more, new_errs): (Vec<_>, Vec<_>) = more + .map(|(op, expr)| match op { + cst::MultOp::Times => Ok(expr), + cst::MultOp::Divide => { + Err(ParseError::ToAST("division is not supported".to_string())) + } + cst::MultOp::Mod => Err(ParseError::ToAST( + "remainder/modulo is not supported".to_string(), + )), + }) + .partition_result(); + errs.extend(new_errs); if !more.is_empty() { + // in this case, `first` must be an expr, we should collect any errors there as well let first = maybe_first?.into_expr(errs)?; - // enforce that division and remainder/modulo are not supported - for (op, _) in &more { - match op { - cst::MultOp::Times => {} - cst::MultOp::Divide => { - errs.push(ParseError::ToAST("division is not supported".to_string())); - return None; - } - cst::MultOp::Mod => { - errs.push(ParseError::ToAST( - "remainder/modulo is not supported".to_string(), - )); - return None; - } - } - } - // split all the operands into constantints and nonconstantints. - // also, remove the opcodes -- from here on we assume they're all - // `Times`, having checked above that this is the case - let (constantints, nonconstantints): (Vec, Vec) = - std::iter::once(first) - .chain(more.into_iter().map(|(_, e)| e)) - .partition(|e| { - matches!(e.expr_kind(), ast::ExprKind::Lit(ast::Literal::Long(_))) - }); - let constantints = constantints - .into_iter() - .map(|e| match e.expr_kind() { - ast::ExprKind::Lit(ast::Literal::Long(i)) => *i, - // PANIC SAFETY Checked the match above via the call to `partition` - #[allow(clippy::unreachable)] - _ => unreachable!( - "checked it matched ast::ExprKind::Lit(ast::Literal::Long(_)) above" - ), - }) - .collect::>(); - if nonconstantints.len() > 1 { - // at most one of the operands in `a * b * c * d * ...` can be a nonconstantint - errs.push(err::ParseError::ToAST( - "Multiplication must be by a constant int".to_string(), // you could see this error for division by a nonconstant as well, but this error message seems like the appropriate one, it will be the common case - )); - None - } else if nonconstantints.is_empty() { - // PANIC SAFETY If nonconstantints is empty then constantints must have at least one value - #[allow(clippy::indexing_slicing)] - Some(ExprOrSpecial::Expr(construct_expr_mul( - construct_expr_num(constantints[0], src.clone()), - constantints[1..].iter().copied(), - src.clone(), - ))) - } else { - // PANIC SAFETY Checked above that `nonconstantints` has at least one element - #[allow(clippy::expect_used)] - let nonconstantint: ast::Expr = nonconstantints - .into_iter() - .next() - .expect("already checked that it's not empty"); - Some(ExprOrSpecial::Expr(construct_expr_mul( - nonconstantint, - constantints, - src.clone(), - ))) - } + Some(ExprOrSpecial::Expr(construct_expr_mul( + first, + more, + src.clone(), + ))) } else { maybe_first } @@ -2096,14 +2049,14 @@ fn construct_expr_add( /// used for a chain of multiplication only (no division or mod) fn construct_expr_mul( f: ast::Expr, - chained: impl IntoIterator, - l: SourceInfo, + chained: impl IntoIterator, + loc: SourceInfo, ) -> ast::Expr { let mut expr = f; for next_expr in chained { expr = ast::ExprBuilder::new() - .with_source_info(l.clone()) - .mul(expr, next_expr) + .with_source_info(loc.clone()) + .mul(expr, next_expr); } expr } @@ -3204,8 +3157,8 @@ mod tests { // the cst should be acceptable .expect("parse error") .to_expr(&mut errs); - // conversion should fail: only multiplication by a constant is allowed - assert!(e.is_none()); + // conversion should succeed + assert!(e.is_some()); let e = text_to_cst::parse_expr(r#" 5 + 10 + 90 "#) // the cst should be acceptable @@ -3246,8 +3199,8 @@ mod tests { // the cst should be acceptable .expect("parse error") .to_expr(&mut errs); - // conversion should fail: only multiplication by a constant is allowed - assert!(e.is_none()); + // conversion should succeed + assert!(e.is_some()); } const CORRECT_TEMPLATES: [&str; 7] = [ @@ -3466,48 +3419,112 @@ mod tests { #[test] fn test_mul() { - for (es, expr) in [ - ("--2*3", Expr::mul(Expr::neg(Expr::val(-2)), 3)), + for (str, expected) in [ + ("--2*3", Expr::mul(Expr::neg(Expr::val(-2)), Expr::val(3))), ( "1 * 2 * false", - Expr::mul(Expr::mul(Expr::val(false), 1), 2), + Expr::mul(Expr::mul(Expr::val(1), Expr::val(2)), Expr::val(false)), ), ( "0 * 1 * principal", - Expr::mul(Expr::mul(Expr::var(ast::Var::Principal), 0), 1), + Expr::mul( + Expr::mul(Expr::val(0), Expr::val(1)), + Expr::var(ast::Var::Principal), + ), ), ( "0 * (-1) * principal", - Expr::mul(Expr::mul(Expr::var(ast::Var::Principal), 0), -1), + Expr::mul( + Expr::mul(Expr::val(0), Expr::val(-1)), + Expr::var(ast::Var::Principal), + ), + ), + ( + "0 * 6 * context.foo", + Expr::mul( + Expr::mul(Expr::val(0), Expr::val(6)), + Expr::get_attr(Expr::var(ast::Var::Context), "foo".into()), + ), + ), + ( + "(0 * 6) * context.foo", + Expr::mul( + Expr::mul(Expr::val(0), Expr::val(6)), + Expr::get_attr(Expr::var(ast::Var::Context), "foo".into()), + ), + ), + ( + "0 * (6 * context.foo)", + Expr::mul( + Expr::val(0), + Expr::mul( + Expr::val(6), + Expr::get_attr(Expr::var(ast::Var::Context), "foo".into()), + ), + ), + ), + ( + "0 * (context.foo * 6)", + Expr::mul( + Expr::val(0), + Expr::mul( + Expr::get_attr(Expr::var(ast::Var::Context), "foo".into()), + Expr::val(6), + ), + ), + ), + ( + "1 * 2 * 3 * context.foo * 4 * 5 * 6", + Expr::mul( + Expr::mul( + Expr::mul( + Expr::mul( + Expr::mul(Expr::mul(Expr::val(1), Expr::val(2)), Expr::val(3)), + Expr::get_attr(Expr::var(ast::Var::Context), "foo".into()), + ), + Expr::val(4), + ), + Expr::val(5), + ), + Expr::val(6), + ), + ), + ( + "principal * (1 + 2)", + Expr::mul( + Expr::var(ast::Var::Principal), + Expr::add(Expr::val(1), Expr::val(2)), + ), + ), + ( + "principal * -(-1)", + Expr::mul(Expr::var(ast::Var::Principal), Expr::neg(Expr::val(-1))), + ), + ( + "principal * --1", + Expr::mul(Expr::var(ast::Var::Principal), Expr::neg(Expr::val(-1))), + ), + ( + r#"false * "bob""#, + Expr::mul(Expr::val(false), Expr::val("bob")), ), ] { let mut errs = ParseErrors::new(); - let e = text_to_cst::parse_expr(es) + let e = text_to_cst::parse_expr(str) .expect("should construct a CST") .to_expr(&mut errs) - .expect("should convert to AST"); + .unwrap_or_else(|| { + panic!( + "failed convert to AST:\n{:?}", + miette::Report::new(errs.clone()) + ) + }); + assert!(errs.is_empty()); assert!( - e.eq_shape(&expr), - "{:?} and {:?} should have the same shape.", - e, - expr + e.eq_shape(&expected), + "{e:?} and {expected:?} should have the same shape", ); } - - for es in [ - r#"false * "bob""#, - "principal * (1 + 2)", - "principal * -(-1)", - // --1 is parsed as Expr::neg(Expr::val(-1)) and thus is not - // considered as a constant. - "principal * --1", - ] { - let mut errs = ParseErrors::new(); - let e = text_to_cst::parse_expr(es) - .expect("should construct a CST") - .to_expr(&mut errs); - assert!(e.is_none()); - } } #[test] diff --git a/cedar-policy-validator/src/typecheck.rs b/cedar-policy-validator/src/typecheck.rs index 7877fba32..5972071d4 100644 --- a/cedar-policy-validator/src/typecheck.rs +++ b/cedar-policy-validator/src/typecheck.rs @@ -836,13 +836,6 @@ impl<'a> Typechecker<'a> { ExprBuilder::with_data(e.data().clone()).unary_app(*op, strict_arg), ) }), - ExprKind::MulByConst { arg, constant } => self - .strict_transform(arg, type_errors) - .then_typecheck(|strict_arg, _| { - TypecheckAnswer::success( - ExprBuilder::with_data(e.data().clone()).mul(strict_arg, *constant), - ) - }), ExprKind::GetAttr { expr, attr } => self .strict_transform(expr, type_errors) @@ -1571,10 +1564,6 @@ impl<'a> Typechecker<'a> { // INVARIANT `e` is a `BinaryApp`, as required self.typecheck_binary(request_env, prior_eff, e, type_errors) } - ExprKind::MulByConst { .. } => { - // INVARIANT `e` is a `MulByConst`, as required - self.typecheck_mul(request_env, prior_eff, e, type_errors) - } ExprKind::ExtensionFunctionApp { .. } => { // INVARIANT `e` is a `ExtensionFunctionApp`, as required self.typecheck_extension(request_env, prior_eff, e, type_errors) @@ -1917,7 +1906,7 @@ impl<'a> Typechecker<'a> { }) } - BinaryOp::Add | BinaryOp::Sub => { + BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul => { let ans_arg1 = self.expect_type( request_env, prior_eff, @@ -1980,38 +1969,6 @@ impl<'a> Typechecker<'a> { } } - /// Like `typecheck_binary()`, but for multiplication, which isn't - /// technically a `BinaryOp` - /// INVARIANT `mul_expr` must be a `MulByConst` - fn typecheck_mul<'b>( - &self, - request_env: &RequestEnv, - prior_eff: &EffectSet<'b>, - mul_expr: &'b Expr, - type_errors: &mut Vec, - ) -> TypecheckAnswer<'b> { - // PANIC SAFETY by invariant on method - #[allow(clippy::panic)] - let ExprKind::MulByConst { arg, constant } = mul_expr.expr_kind() else { - panic!("`typecheck_mul` called with an expression kind other than `MulByConst`"); - }; - - let ans_arg = self.expect_type( - request_env, - prior_eff, - arg, - Type::primitive_long(), - type_errors, - ); - ans_arg.then_typecheck(|arg_expr_ty, _| { - TypecheckAnswer::success({ - ExprBuilder::with_data(Some(Type::primitive_long())) - .with_same_source_info(mul_expr) - .mul(arg_expr_ty, *constant) - }) - }) - } - /// Get the type for an `==` expression given the input types. fn type_of_equality<'b>( &self, diff --git a/cedar-policy-validator/src/typecheck/test_expr.rs b/cedar-policy-validator/src/typecheck/test_expr.rs index 817927d15..039aab0bd 100644 --- a/cedar-policy-validator/src/typecheck/test_expr.rs +++ b/cedar-policy-validator/src/typecheck/test_expr.rs @@ -1010,13 +1010,13 @@ fn neg_typecheck_fails() { #[test] fn mul_typechecks() { - let neg_expr = Expr::mul(Expr::val(1), 2); + let neg_expr = Expr::mul(Expr::val(1), Expr::val(2)); assert_typechecks_empty_schema(neg_expr, Type::primitive_long()); } #[test] fn mul_typecheck_fails() { - let neg_expr = Expr::mul(Expr::val("foo"), 2); + let neg_expr = Expr::mul(Expr::val("foo"), Expr::val(2)); assert_typecheck_fails_empty_schema( neg_expr, Type::primitive_long(), diff --git a/cedar-policy-validator/src/typecheck/test_type_annotation.rs b/cedar-policy-validator/src/typecheck/test_type_annotation.rs index 6fda6d1c8..857baf132 100644 --- a/cedar-policy-validator/src/typecheck/test_type_annotation.rs +++ b/cedar-policy-validator/src/typecheck/test_type_annotation.rs @@ -103,10 +103,10 @@ fn expr_typechecks_with_correct_annotation() { .not(ExprBuilder::with_data(Some(Type::singleton_boolean(false))).val(false)), ); assert_expr_has_annotated_ast( - &Expr::mul(Expr::val(3), 4), + &Expr::mul(Expr::val(3), Expr::val(4)), &ExprBuilder::with_data(Some(Type::primitive_long())).mul( ExprBuilder::with_data(Some(Type::primitive_long())).val(3), - 4, + ExprBuilder::with_data(Some(Type::primitive_long())).val(4), ), ); assert_expr_has_annotated_ast( diff --git a/cedar-policy/CHANGELOG.md b/cedar-policy/CHANGELOG.md index 1d4099044..2b7499359 100644 --- a/cedar-policy/CHANGELOG.md +++ b/cedar-policy/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 2.4.5 - Coming soon + +### Changed + +- Implement [RFC 57](https://github.com/cedar-policy/rfcs/pull/57): policies can + now include multiplication of arbitrary expressions, not just multiplication of + an expression and a constant. + ## 2.4.4 Cedar Language Version: 2.1.3