Skip to content

Commit

Permalink
chore: add support for member fns
Browse files Browse the repository at this point in the history
  • Loading branch information
bodymindarts committed Oct 9, 2024
1 parent ee0c71c commit 7895378
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 59 deletions.
23 changes: 23 additions & 0 deletions cala-cel-interpreter/src/builtins/decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use rust_decimal::Decimal;

use crate::{cel_type::*, error::*, value::*};

use super::assert_arg;

pub fn cast(args: Vec<CelValue>) -> Result<CelValue, CelError> {
match args.first() {
Some(CelValue::Decimal(d)) => Ok(CelValue::Decimal(*d)),
Some(CelValue::String(s)) => Ok(CelValue::Decimal(
s.parse()
.map_err(|e| CelError::DecimalError(format!("{e:?}")))?,
)),
Some(v) => Err(CelError::BadType(CelType::Decimal, CelType::from(v))),
None => Err(CelError::MissingArgument),
}
}

pub fn add(args: Vec<CelValue>) -> Result<CelValue, CelError> {
let a: &Decimal = assert_arg(args.first())?;
let b: &Decimal = assert_arg(args.get(1))?;
Ok(CelValue::Decimal(a + b))
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
pub(crate) mod decimal;
pub(crate) mod timestamp;

use chrono::{NaiveDate, Utc};

use std::sync::Arc;

use super::{cel_type::*, value::*};
use super::value::*;
use crate::error::*;

pub(crate) fn date(args: Vec<CelValue>) -> Result<CelValue, CelError> {
Expand All @@ -22,30 +25,6 @@ pub(crate) fn uuid(args: Vec<CelValue>) -> Result<CelValue, CelError> {
))
}

pub(crate) mod decimal {
use rust_decimal::Decimal;

use super::*;

pub fn cast(args: Vec<CelValue>) -> Result<CelValue, CelError> {
match args.first() {
Some(CelValue::Decimal(d)) => Ok(CelValue::Decimal(*d)),
Some(CelValue::String(s)) => Ok(CelValue::Decimal(
s.parse()
.map_err(|e| CelError::DecimalError(format!("{e:?}")))?,
)),
Some(v) => Err(CelError::BadType(CelType::Decimal, CelType::from(v))),
None => Err(CelError::MissingArgument),
}
}

pub fn add(args: Vec<CelValue>) -> Result<CelValue, CelError> {
let a: &Decimal = assert_arg(args.first())?;
let b: &Decimal = assert_arg(args.get(1))?;
Ok(CelValue::Decimal(a + b))
}
}

fn assert_arg<'a, T: TryFrom<&'a CelValue, Error = CelError>>(
arg: Option<&'a CelValue>,
) -> Result<T, CelError> {
Expand Down
23 changes: 23 additions & 0 deletions cala-cel-interpreter/src/builtins/timestamp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use crate::{cel_type::*, error::*, value::*};

use super::assert_arg;

pub fn cast(args: Vec<CelValue>) -> Result<CelValue, CelError> {
match args.first() {
Some(CelValue::String(s)) => Ok(CelValue::Timestamp(
s.parse()
.map_err(|e| CelError::TimestampError(format!("{e:?}")))?,
)),
Some(v) => Err(CelError::BadType(CelType::Timestamp, CelType::from(v))),
None => Err(CelError::MissingArgument),
}
}

pub fn format(target: &CelValue, args: Vec<CelValue>) -> Result<CelValue, CelError> {
if let CelValue::Timestamp(ts) = target {
let format: std::sync::Arc<String> = assert_arg(args.first())?;
Ok(CelValue::String(ts.format(&format).to_string().into()))
} else {
Err(CelError::BadType(CelType::Timestamp, CelType::from(target)))
}
}
20 changes: 20 additions & 0 deletions cala-cel-interpreter/src/cel_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,23 @@ pub enum CelType {
Uuid,
Decimal,
}

impl CelType {
pub(crate) fn package_name(&self) -> &'static str {
match self {
CelType::Map => "map",
CelType::List => "list",
CelType::Int => "int",
CelType::UInt => "uint",
CelType::Double => "double",
CelType::String => "string",
CelType::Bytes => "bytes",
CelType::Bool => "bool",
CelType::Null => "null",
CelType::Date => "date",
CelType::Timestamp => "timestamp",
CelType::Uuid => "uuid",
CelType::Decimal => "decimal",
}
}
}
5 changes: 3 additions & 2 deletions cala-cel-interpreter/src/context/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::builtins;
use super::*;

lazy_static! {
pub static ref CEL_CONTEXT: CelContext = {
pub static ref CEL_PACKAGE: CelPackage = {
let mut idents = HashMap::new();
idents.insert(
SELF_PACKAGE_NAME,
Expand All @@ -17,6 +17,7 @@ lazy_static! {
Cow::Borrowed("Add"),
ContextItem::Function(Box::new(builtins::decimal::add)),
);
CelContext { idents }

CelPackage::new(CelContext { idents }, HashMap::new())
};
}
40 changes: 32 additions & 8 deletions cala-cel-interpreter/src/context/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
mod decimal;
mod package;
mod timestamp;

use std::{borrow::Cow, collections::HashMap};

use crate::{builtins, error::*, value::*};
use crate::{builtins, cel_type::CelType, error::*, value::*};

use package::CelPackage;

const SELF_PACKAGE_NAME: Cow<'static, str> = Cow::Borrowed("self");

type CelFunction = Box<dyn Fn(Vec<CelValue>) -> Result<CelValue, CelError> + Sync>;
pub(crate) type CelMemberFunction =
Box<dyn Fn(&CelValue, Vec<CelValue>) -> Result<CelValue, CelError> + Sync>;

#[derive(Debug)]
pub struct CelContext {
idents: HashMap<Cow<'static, str>, ContextItem>,
Expand All @@ -30,20 +37,37 @@ impl CelContext {
);
idents.insert(
Cow::Borrowed("decimal"),
ContextItem::Package(&decimal::CEL_CONTEXT),
ContextItem::Package(&decimal::CEL_PACKAGE),
);
Self { idents }
}

pub(crate) fn package_self(&self) -> Result<&ContextItem, CelError> {
self.lookup(&SELF_PACKAGE_NAME)
idents.insert(
Cow::Borrowed("timestamp"),
ContextItem::Package(&timestamp::CEL_PACKAGE),
);

Self { idents }
}

pub(crate) fn lookup(&self, name: &str) -> Result<&ContextItem, CelError> {
pub(crate) fn lookup_ident(&self, name: &str) -> Result<&ContextItem, CelError> {
self.idents
.get(name)
.ok_or_else(|| CelError::UnknownIdent(name.to_string()))
}

pub(crate) fn lookup_member_fn(
&self,
value: &CelValue,
name: &str,
) -> Result<&CelMemberFunction, CelError> {
let package_name = CelType::from(value).package_name();
let package = if let Some(ContextItem::Package(package)) = self.idents.get(package_name) {
package
} else {
return Err(CelError::UnknownPackage(package_name));
};

package.lookup_member(value, name)
}
}
impl Default for CelContext {
fn default() -> Self {
Expand All @@ -54,7 +78,7 @@ impl Default for CelContext {
pub(crate) enum ContextItem {
Value(CelValue),
Function(CelFunction),
Package(&'static CelContext),
Package(&'static CelPackage),
}

impl std::fmt::Debug for ContextItem {
Expand Down
36 changes: 36 additions & 0 deletions cala-cel-interpreter/src/context/package.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use super::*;

pub struct CelPackage {
nested_ctx: CelContext,
member_fns: HashMap<&'static str, CelMemberFunction>,
}

impl CelPackage {
pub fn new(
nested_ctx: CelContext,
member_fns: HashMap<&'static str, CelMemberFunction>,
) -> Self {
Self {
nested_ctx,
member_fns,
}
}

pub(crate) fn package_self(&self) -> Result<&ContextItem, CelError> {
self.nested_ctx.lookup_ident(&SELF_PACKAGE_NAME)
}

pub(crate) fn lookup(&self, name: &str) -> Result<&ContextItem, CelError> {
self.nested_ctx.lookup_ident(name)
}

pub(crate) fn lookup_member(
&self,
value: &CelValue,
name: &str,
) -> Result<&CelMemberFunction, CelError> {
self.member_fns
.get(name)
.ok_or_else(|| CelError::UnknownAttribute(CelType::from(value), name.to_string()))
}
}
22 changes: 22 additions & 0 deletions cala-cel-interpreter/src/context/timestamp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use lazy_static::lazy_static;

use std::collections::HashMap;

use crate::builtins;

use super::*;

lazy_static! {
pub static ref CEL_PACKAGE: CelPackage = {
let mut idents = HashMap::new();
idents.insert(
SELF_PACKAGE_NAME,
ContextItem::Function(Box::new(builtins::timestamp::cast)),
);

let mut member_fns: HashMap<_, CelMemberFunction> = HashMap::new();
member_fns.insert("format", Box::new(builtins::timestamp::format));

CelPackage::new(CelContext { idents }, member_fns)
};
}
6 changes: 6 additions & 0 deletions cala-cel-interpreter/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ pub enum CelError {
BadType(CelType, CelType),
#[error("CelError - UnknownIdentifier: {0}")]
UnknownIdent(String),
#[error("CelError - UnknownPackage: No package installed for type '{0}'")]
UnknownPackage(&'static str),
#[error("CelError - UnknownAttribute: No attribute '{1}' on type {0:?}")]
UnknownAttribute(CelType, String),
#[error("CelError - IllegalTarget")]
IllegalTarget,
#[error("CelError - MissingArgument")]
Expand All @@ -33,6 +37,8 @@ pub enum CelError {
UuidError(String),
#[error("CelError - DecimalError: {0}")]
DecimalError(String),
#[error("CelError - TimestampError: {0}")]
TimestampError(String),
#[error("CelError - NoMatchingOverload: {0}")]
NoMatchingOverload(String),
#[error("CelError - Unexpected: {0}")]
Expand Down
32 changes: 29 additions & 3 deletions cala-cel-interpreter/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ impl std::fmt::Display for CelExpression {
}
}

#[derive(Debug)]
enum EvalType<'a> {
Value(CelValue),
ContextItem(&'a ContextItem),
MemberFn(&'a CelValue, &'a CelMemberFunction),
}

impl<'a> EvalType<'a> {
Expand Down Expand Up @@ -124,7 +124,7 @@ fn evaluate_expression_inner<'a>(
}
Ok(EvalType::Value(CelValue::from(map)))
}
Ident(name) => Ok(EvalType::ContextItem(ctx.lookup(name)?)),
Ident(name) => Ok(EvalType::ContextItem(ctx.lookup_ident(name)?)),
Literal(val) => Ok(EvalType::Value(CelValue::from(val))),
Arithmetic(op, left, right) => {
let left = evaluate_expression(left, ctx)?;
Expand Down Expand Up @@ -156,13 +156,18 @@ fn evaluate_member<'a>(
use ast::Member::*;
match member {
Attribute(name) => match target {
EvalType::Value(CelValue::Map(map)) => Ok(EvalType::Value(map.get(name))),
EvalType::Value(CelValue::Map(map)) if map.contains_key(name) => {
Ok(EvalType::Value(map.get(name)))
}
EvalType::ContextItem(ContextItem::Value(CelValue::Map(map))) => {
Ok(EvalType::Value(map.get(name)))
}
EvalType::ContextItem(ContextItem::Package(p)) => {
Ok(EvalType::ContextItem(p.lookup(name)?))
}
EvalType::ContextItem(ContextItem::Value(v)) => {
Ok(EvalType::MemberFn(v, ctx.lookup_member_fn(v, name)?))
}
_ => Err(CelError::IllegalTarget),
},
FunctionCall(exprs) => match target {
Expand All @@ -176,6 +181,13 @@ fn evaluate_member<'a>(
EvalType::ContextItem(ContextItem::Package(p)) => {
evaluate_member(EvalType::ContextItem(p.package_self()?), member, ctx)
}
EvalType::MemberFn(v, f) => {
let mut args = Vec::new();
for e in exprs {
args.push(evaluate_expression(e, ctx)?.try_value()?)
}
Ok(EvalType::Value(f(v, args)?))
}
_ => Err(CelError::IllegalTarget),
},
_ => unimplemented!(),
Expand Down Expand Up @@ -419,4 +431,18 @@ mod tests {
assert_eq!(expression.evaluate(&context)?, CelValue::Decimal(3.into()));
Ok(())
}

#[test]
fn function_on_timestamp() -> anyhow::Result<()> {
use chrono::{DateTime, Utc};

let time: DateTime<Utc> = "1940-12-21T00:00:00Z".parse().unwrap();
let mut context = CelContext::new();
context.add_variable("now", time);

let expression = "now.format('%d/%m/%Y')".parse::<CelExpression>().unwrap();
assert_eq!(expression.evaluate(&context)?, CelValue::from("21/12/1940"));

Ok(())
}
}
Loading

0 comments on commit 7895378

Please sign in to comment.