From 7895378ee0116c87d9239575b94e3546a838496a Mon Sep 17 00:00:00 2001 From: bodymindarts Date: Wed, 9 Oct 2024 15:03:49 +0200 Subject: [PATCH 1/2] chore: add support for member fns --- cala-cel-interpreter/src/builtins/decimal.rs | 23 ++++++++++ .../src/{builtins.rs => builtins/mod.rs} | 29 ++---------- .../src/builtins/timestamp.rs | 23 ++++++++++ cala-cel-interpreter/src/cel_type.rs | 20 ++++++++ cala-cel-interpreter/src/context/decimal.rs | 5 +- cala-cel-interpreter/src/context/mod.rs | 40 ++++++++++++---- cala-cel-interpreter/src/context/package.rs | 36 +++++++++++++++ cala-cel-interpreter/src/context/timestamp.rs | 22 +++++++++ cala-cel-interpreter/src/error.rs | 6 +++ cala-cel-interpreter/src/interpreter.rs | 32 +++++++++++-- cala-cel-interpreter/src/value.rs | 46 ++++++++++--------- 11 files changed, 223 insertions(+), 59 deletions(-) create mode 100644 cala-cel-interpreter/src/builtins/decimal.rs rename cala-cel-interpreter/src/{builtins.rs => builtins/mod.rs} (50%) create mode 100644 cala-cel-interpreter/src/builtins/timestamp.rs create mode 100644 cala-cel-interpreter/src/context/package.rs create mode 100644 cala-cel-interpreter/src/context/timestamp.rs diff --git a/cala-cel-interpreter/src/builtins/decimal.rs b/cala-cel-interpreter/src/builtins/decimal.rs new file mode 100644 index 00000000..1e3cd020 --- /dev/null +++ b/cala-cel-interpreter/src/builtins/decimal.rs @@ -0,0 +1,23 @@ +use rust_decimal::Decimal; + +use crate::{cel_type::*, error::*, value::*}; + +use super::assert_arg; + +pub fn cast(args: Vec) -> Result { + match args.first() { + Some(CelValue::Decimal(d)) => Ok(CelValue::Decimal(*d)), + Some(CelValue::String(s)) => Ok(CelValue::Decimal( + s.parse() + .map_err(|e| CelError::DecimalError(format!("{e:?}")))?, + )), + Some(v) => Err(CelError::BadType(CelType::Decimal, CelType::from(v))), + None => Err(CelError::MissingArgument), + } +} + +pub fn add(args: Vec) -> Result { + let a: &Decimal = assert_arg(args.first())?; + let b: &Decimal = assert_arg(args.get(1))?; + Ok(CelValue::Decimal(a + b)) +} diff --git a/cala-cel-interpreter/src/builtins.rs b/cala-cel-interpreter/src/builtins/mod.rs similarity index 50% rename from cala-cel-interpreter/src/builtins.rs rename to cala-cel-interpreter/src/builtins/mod.rs index e5db2fe0..84ea5b77 100644 --- a/cala-cel-interpreter/src/builtins.rs +++ b/cala-cel-interpreter/src/builtins/mod.rs @@ -1,8 +1,11 @@ +pub(crate) mod decimal; +pub(crate) mod timestamp; + use chrono::{NaiveDate, Utc}; use std::sync::Arc; -use super::{cel_type::*, value::*}; +use super::value::*; use crate::error::*; pub(crate) fn date(args: Vec) -> Result { @@ -22,30 +25,6 @@ pub(crate) fn uuid(args: Vec) -> Result { )) } -pub(crate) mod decimal { - use rust_decimal::Decimal; - - use super::*; - - pub fn cast(args: Vec) -> Result { - match args.first() { - Some(CelValue::Decimal(d)) => Ok(CelValue::Decimal(*d)), - Some(CelValue::String(s)) => Ok(CelValue::Decimal( - s.parse() - .map_err(|e| CelError::DecimalError(format!("{e:?}")))?, - )), - Some(v) => Err(CelError::BadType(CelType::Decimal, CelType::from(v))), - None => Err(CelError::MissingArgument), - } - } - - pub fn add(args: Vec) -> Result { - let a: &Decimal = assert_arg(args.first())?; - let b: &Decimal = assert_arg(args.get(1))?; - Ok(CelValue::Decimal(a + b)) - } -} - fn assert_arg<'a, T: TryFrom<&'a CelValue, Error = CelError>>( arg: Option<&'a CelValue>, ) -> Result { diff --git a/cala-cel-interpreter/src/builtins/timestamp.rs b/cala-cel-interpreter/src/builtins/timestamp.rs new file mode 100644 index 00000000..d4365ce7 --- /dev/null +++ b/cala-cel-interpreter/src/builtins/timestamp.rs @@ -0,0 +1,23 @@ +use crate::{cel_type::*, error::*, value::*}; + +use super::assert_arg; + +pub fn cast(args: Vec) -> Result { + match args.first() { + Some(CelValue::String(s)) => Ok(CelValue::Timestamp( + s.parse() + .map_err(|e| CelError::TimestampError(format!("{e:?}")))?, + )), + Some(v) => Err(CelError::BadType(CelType::Timestamp, CelType::from(v))), + None => Err(CelError::MissingArgument), + } +} + +pub fn format(target: &CelValue, args: Vec) -> Result { + if let CelValue::Timestamp(ts) = target { + let format: std::sync::Arc = assert_arg(args.first())?; + Ok(CelValue::String(ts.format(&format).to_string().into())) + } else { + Err(CelError::BadType(CelType::Timestamp, CelType::from(target))) + } +} diff --git a/cala-cel-interpreter/src/cel_type.rs b/cala-cel-interpreter/src/cel_type.rs index cadf7599..aa3f6a5f 100644 --- a/cala-cel-interpreter/src/cel_type.rs +++ b/cala-cel-interpreter/src/cel_type.rs @@ -17,3 +17,23 @@ pub enum CelType { Uuid, Decimal, } + +impl CelType { + pub(crate) fn package_name(&self) -> &'static str { + match self { + CelType::Map => "map", + CelType::List => "list", + CelType::Int => "int", + CelType::UInt => "uint", + CelType::Double => "double", + CelType::String => "string", + CelType::Bytes => "bytes", + CelType::Bool => "bool", + CelType::Null => "null", + CelType::Date => "date", + CelType::Timestamp => "timestamp", + CelType::Uuid => "uuid", + CelType::Decimal => "decimal", + } + } +} diff --git a/cala-cel-interpreter/src/context/decimal.rs b/cala-cel-interpreter/src/context/decimal.rs index d14e9b03..21db864d 100644 --- a/cala-cel-interpreter/src/context/decimal.rs +++ b/cala-cel-interpreter/src/context/decimal.rs @@ -7,7 +7,7 @@ use crate::builtins; use super::*; lazy_static! { - pub static ref CEL_CONTEXT: CelContext = { + pub static ref CEL_PACKAGE: CelPackage = { let mut idents = HashMap::new(); idents.insert( SELF_PACKAGE_NAME, @@ -17,6 +17,7 @@ lazy_static! { Cow::Borrowed("Add"), ContextItem::Function(Box::new(builtins::decimal::add)), ); - CelContext { idents } + + CelPackage::new(CelContext { idents }, HashMap::new()) }; } diff --git a/cala-cel-interpreter/src/context/mod.rs b/cala-cel-interpreter/src/context/mod.rs index 219b57c5..72727bbb 100644 --- a/cala-cel-interpreter/src/context/mod.rs +++ b/cala-cel-interpreter/src/context/mod.rs @@ -1,12 +1,19 @@ mod decimal; +mod package; +mod timestamp; use std::{borrow::Cow, collections::HashMap}; -use crate::{builtins, error::*, value::*}; +use crate::{builtins, cel_type::CelType, error::*, value::*}; + +use package::CelPackage; const SELF_PACKAGE_NAME: Cow<'static, str> = Cow::Borrowed("self"); type CelFunction = Box) -> Result + Sync>; +pub(crate) type CelMemberFunction = + Box) -> Result + Sync>; + #[derive(Debug)] pub struct CelContext { idents: HashMap, ContextItem>, @@ -30,20 +37,37 @@ impl CelContext { ); idents.insert( Cow::Borrowed("decimal"), - ContextItem::Package(&decimal::CEL_CONTEXT), + ContextItem::Package(&decimal::CEL_PACKAGE), ); - Self { idents } - } - pub(crate) fn package_self(&self) -> Result<&ContextItem, CelError> { - self.lookup(&SELF_PACKAGE_NAME) + idents.insert( + Cow::Borrowed("timestamp"), + ContextItem::Package(×tamp::CEL_PACKAGE), + ); + + Self { idents } } - pub(crate) fn lookup(&self, name: &str) -> Result<&ContextItem, CelError> { + pub(crate) fn lookup_ident(&self, name: &str) -> Result<&ContextItem, CelError> { self.idents .get(name) .ok_or_else(|| CelError::UnknownIdent(name.to_string())) } + + pub(crate) fn lookup_member_fn( + &self, + value: &CelValue, + name: &str, + ) -> Result<&CelMemberFunction, CelError> { + let package_name = CelType::from(value).package_name(); + let package = if let Some(ContextItem::Package(package)) = self.idents.get(package_name) { + package + } else { + return Err(CelError::UnknownPackage(package_name)); + }; + + package.lookup_member(value, name) + } } impl Default for CelContext { fn default() -> Self { @@ -54,7 +78,7 @@ impl Default for CelContext { pub(crate) enum ContextItem { Value(CelValue), Function(CelFunction), - Package(&'static CelContext), + Package(&'static CelPackage), } impl std::fmt::Debug for ContextItem { diff --git a/cala-cel-interpreter/src/context/package.rs b/cala-cel-interpreter/src/context/package.rs new file mode 100644 index 00000000..2d2be639 --- /dev/null +++ b/cala-cel-interpreter/src/context/package.rs @@ -0,0 +1,36 @@ +use super::*; + +pub struct CelPackage { + nested_ctx: CelContext, + member_fns: HashMap<&'static str, CelMemberFunction>, +} + +impl CelPackage { + pub fn new( + nested_ctx: CelContext, + member_fns: HashMap<&'static str, CelMemberFunction>, + ) -> Self { + Self { + nested_ctx, + member_fns, + } + } + + pub(crate) fn package_self(&self) -> Result<&ContextItem, CelError> { + self.nested_ctx.lookup_ident(&SELF_PACKAGE_NAME) + } + + pub(crate) fn lookup(&self, name: &str) -> Result<&ContextItem, CelError> { + self.nested_ctx.lookup_ident(name) + } + + pub(crate) fn lookup_member( + &self, + value: &CelValue, + name: &str, + ) -> Result<&CelMemberFunction, CelError> { + self.member_fns + .get(name) + .ok_or_else(|| CelError::UnknownAttribute(CelType::from(value), name.to_string())) + } +} diff --git a/cala-cel-interpreter/src/context/timestamp.rs b/cala-cel-interpreter/src/context/timestamp.rs new file mode 100644 index 00000000..1acb33f9 --- /dev/null +++ b/cala-cel-interpreter/src/context/timestamp.rs @@ -0,0 +1,22 @@ +use lazy_static::lazy_static; + +use std::collections::HashMap; + +use crate::builtins; + +use super::*; + +lazy_static! { + pub static ref CEL_PACKAGE: CelPackage = { + let mut idents = HashMap::new(); + idents.insert( + SELF_PACKAGE_NAME, + ContextItem::Function(Box::new(builtins::timestamp::cast)), + ); + + let mut member_fns: HashMap<_, CelMemberFunction> = HashMap::new(); + member_fns.insert("format", Box::new(builtins::timestamp::format)); + + CelPackage::new(CelContext { idents }, member_fns) + }; +} diff --git a/cala-cel-interpreter/src/error.rs b/cala-cel-interpreter/src/error.rs index 6f2aeef5..556b132b 100644 --- a/cala-cel-interpreter/src/error.rs +++ b/cala-cel-interpreter/src/error.rs @@ -21,6 +21,10 @@ pub enum CelError { BadType(CelType, CelType), #[error("CelError - UnknownIdentifier: {0}")] UnknownIdent(String), + #[error("CelError - UnknownPackage: No package installed for type '{0}'")] + UnknownPackage(&'static str), + #[error("CelError - UnknownAttribute: No attribute '{1}' on type {0:?}")] + UnknownAttribute(CelType, String), #[error("CelError - IllegalTarget")] IllegalTarget, #[error("CelError - MissingArgument")] @@ -33,6 +37,8 @@ pub enum CelError { UuidError(String), #[error("CelError - DecimalError: {0}")] DecimalError(String), + #[error("CelError - TimestampError: {0}")] + TimestampError(String), #[error("CelError - NoMatchingOverload: {0}")] NoMatchingOverload(String), #[error("CelError - Unexpected: {0}")] diff --git a/cala-cel-interpreter/src/interpreter.rs b/cala-cel-interpreter/src/interpreter.rs index 5c027469..bb69dcfe 100755 --- a/cala-cel-interpreter/src/interpreter.rs +++ b/cala-cel-interpreter/src/interpreter.rs @@ -44,10 +44,10 @@ impl std::fmt::Display for CelExpression { } } -#[derive(Debug)] enum EvalType<'a> { Value(CelValue), ContextItem(&'a ContextItem), + MemberFn(&'a CelValue, &'a CelMemberFunction), } impl<'a> EvalType<'a> { @@ -124,7 +124,7 @@ fn evaluate_expression_inner<'a>( } Ok(EvalType::Value(CelValue::from(map))) } - Ident(name) => Ok(EvalType::ContextItem(ctx.lookup(name)?)), + Ident(name) => Ok(EvalType::ContextItem(ctx.lookup_ident(name)?)), Literal(val) => Ok(EvalType::Value(CelValue::from(val))), Arithmetic(op, left, right) => { let left = evaluate_expression(left, ctx)?; @@ -156,13 +156,18 @@ fn evaluate_member<'a>( use ast::Member::*; match member { Attribute(name) => match target { - EvalType::Value(CelValue::Map(map)) => Ok(EvalType::Value(map.get(name))), + EvalType::Value(CelValue::Map(map)) if map.contains_key(name) => { + Ok(EvalType::Value(map.get(name))) + } EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => { Ok(EvalType::Value(map.get(name))) } EvalType::ContextItem(ContextItem::Package(p)) => { Ok(EvalType::ContextItem(p.lookup(name)?)) } + EvalType::ContextItem(ContextItem::Value(v)) => { + Ok(EvalType::MemberFn(v, ctx.lookup_member_fn(v, name)?)) + } _ => Err(CelError::IllegalTarget), }, FunctionCall(exprs) => match target { @@ -176,6 +181,13 @@ fn evaluate_member<'a>( EvalType::ContextItem(ContextItem::Package(p)) => { evaluate_member(EvalType::ContextItem(p.package_self()?), member, ctx) } + EvalType::MemberFn(v, f) => { + let mut args = Vec::new(); + for e in exprs { + args.push(evaluate_expression(e, ctx)?.try_value()?) + } + Ok(EvalType::Value(f(v, args)?)) + } _ => Err(CelError::IllegalTarget), }, _ => unimplemented!(), @@ -419,4 +431,18 @@ mod tests { assert_eq!(expression.evaluate(&context)?, CelValue::Decimal(3.into())); Ok(()) } + + #[test] + fn function_on_timestamp() -> anyhow::Result<()> { + use chrono::{DateTime, Utc}; + + let time: DateTime = "1940-12-21T00:00:00Z".parse().unwrap(); + let mut context = CelContext::new(); + context.add_variable("now", time); + + let expression = "now.format('%d/%m/%Y')".parse::().unwrap(); + assert_eq!(expression.evaluate(&context)?, CelValue::from("21/12/1940")); + + Ok(()) + } } diff --git a/cala-cel-interpreter/src/value.rs b/cala-cel-interpreter/src/value.rs index 7ad65222..724114b9 100644 --- a/cala-cel-interpreter/src/value.rs +++ b/cala-cel-interpreter/src/value.rs @@ -47,27 +47,6 @@ pub struct CelMap { inner: HashMap, } -#[derive(Debug, PartialEq)] -pub struct CelArray { - inner: Vec, -} - -impl CelArray { - pub fn new() -> Self { - Self { inner: Vec::new() } - } - - pub fn push(&mut self, elem: impl Into) { - self.inner.push(elem.into()); - } -} - -impl Default for CelArray { - fn default() -> Self { - Self::new() - } -} - impl CelMap { pub fn new() -> Self { Self { @@ -85,6 +64,10 @@ impl CelMap { .cloned() .unwrap_or(CelValue::Null) } + + pub fn contains_key(&self, key: impl Into) -> bool { + self.inner.contains_key(&key.into()) + } } impl Default for CelMap { @@ -109,6 +92,27 @@ impl From for CelValue { } } +#[derive(Debug, PartialEq)] +pub struct CelArray { + inner: Vec, +} + +impl CelArray { + pub fn new() -> Self { + Self { inner: Vec::new() } + } + + pub fn push(&mut self, elem: impl Into) { + self.inner.push(elem.into()); + } +} + +impl Default for CelArray { + fn default() -> Self { + Self::new() + } +} + impl From for CelValue { fn from(i: i64) -> Self { CelValue::Int(i) From 662092224121692cb0dcb777398f2f86bb50a724 Mon Sep 17 00:00:00 2001 From: bodymindarts Date: Wed, 9 Oct 2024 15:22:26 +0200 Subject: [PATCH 2/2] refactor: use try_into_xxx on EvalType to avoid clone --- cala-cel-interpreter/src/interpreter.rs | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/cala-cel-interpreter/src/interpreter.rs b/cala-cel-interpreter/src/interpreter.rs index bb69dcfe..bf8d342b 100755 --- a/cala-cel-interpreter/src/interpreter.rs +++ b/cala-cel-interpreter/src/interpreter.rs @@ -51,7 +51,7 @@ enum EvalType<'a> { } impl<'a> EvalType<'a> { - fn try_bool(&self) -> Result { + fn try_into_bool(self) -> Result { if let EvalType::Value(val) = self { val.try_bool() } else { @@ -61,13 +61,13 @@ impl<'a> EvalType<'a> { } } - fn try_key(&self) -> Result { + fn try_into_key(self) -> Result { if let EvalType::Value(val) = self { match val { - CelValue::Int(i) => Ok(CelKey::Int(*i)), - CelValue::UInt(u) => Ok(CelKey::UInt(*u)), - CelValue::Bool(b) => Ok(CelKey::Bool(*b)), - CelValue::String(s) => Ok(CelKey::String(s.clone())), + CelValue::Int(i) => Ok(CelKey::Int(i)), + CelValue::UInt(u) => Ok(CelKey::UInt(u)), + CelValue::Bool(b) => Ok(CelKey::Bool(b)), + CelValue::String(s) => Ok(CelKey::String(s)), _ => Err(CelError::Unexpected( "Expression didn't resolve to a valid key".to_string(), )), @@ -79,9 +79,9 @@ impl<'a> EvalType<'a> { } } - fn try_value(&self) -> Result { + fn try_into_value(self) -> Result { if let EvalType::Value(val) = self { - Ok(val.clone()) + Ok(val) } else { Err(CelError::Unexpected("Couldn't unwrap value".to_string())) } @@ -105,7 +105,7 @@ fn evaluate_expression_inner<'a>( use Expression::*; match expr { Ternary(cond, left, right) => { - if evaluate_expression(cond, ctx)?.try_bool()? { + if evaluate_expression(cond, ctx)?.try_into_bool()? { evaluate_expression(left, ctx) } else { evaluate_expression(right, ctx) @@ -120,7 +120,7 @@ fn evaluate_expression_inner<'a>( for (k, v) in entries { let key = evaluate_expression(k, ctx)?; let value = evaluate_expression(v, ctx)?; - map.insert(key.try_key()?, value.try_value()?) + map.insert(key.try_into_key()?, value.try_into_value()?) } Ok(EvalType::Value(CelValue::from(map))) } @@ -131,8 +131,8 @@ fn evaluate_expression_inner<'a>( let right = evaluate_expression(right, ctx)?; Ok(EvalType::Value(evaluate_arithmetic( *op, - left.try_value()?, - right.try_value()?, + left.try_into_value()?, + right.try_into_value()?, )?)) } Relation(op, left, right) => { @@ -140,8 +140,8 @@ fn evaluate_expression_inner<'a>( let right = evaluate_expression(right, ctx)?; Ok(EvalType::Value(evaluate_relation( *op, - left.try_value()?, - right.try_value()?, + left.try_into_value()?, + right.try_into_value()?, )?)) } e => Err(CelError::Unexpected(format!("unimplemented {e:?}"))), @@ -174,7 +174,7 @@ fn evaluate_member<'a>( EvalType::ContextItem(ContextItem::Function(f)) => { let mut args = Vec::new(); for e in exprs { - args.push(evaluate_expression(e, ctx)?.try_value()?) + args.push(evaluate_expression(e, ctx)?.try_into_value()?) } Ok(EvalType::Value(f(args)?)) } @@ -184,7 +184,7 @@ fn evaluate_member<'a>( EvalType::MemberFn(v, f) => { let mut args = Vec::new(); for e in exprs { - args.push(evaluate_expression(e, ctx)?.try_value()?) + args.push(evaluate_expression(e, ctx)?.try_into_value()?) } Ok(EvalType::Value(f(v, args)?)) }