Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
qdelamea-aneo committed Nov 2, 2024
1 parent 5db102c commit 3108111
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 0 deletions.
139 changes: 139 additions & 0 deletions bitest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import ast
import operator

import armonik.common
import pytest
import rich_click as click

from datetime import datetime, timedelta
from functools import reduce

from armonik.common import Partition, Result, Session, SessionStatus, Task
import armonik.common.filter as filters


DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"


class FilterParam(click.ParamType):
name = "filter"

def __init__(self, filter_wrapper) -> None:
self._filter_wrapper = filter_wrapper

def convert(self, value: str, param, ctx) -> filters.Filter:
try:
if value:
value = value.replace("==", "=").replace("=", "==")
return _build_filter(ast.parse(value), obj_type=self._filter_wrapper)
return None
except SyntaxError as error:
self.fail(
f"{value} is not a valid filter. Syntax error: {str(error)}.",
param,
ctx,
)

def _build_filter(node, obj_type=None, target_type=None):
try:
match type(node):
case ast.Module:
if not len(node.body) == 1:
raise SyntaxError("Filter definition must be a single expression.")
return _build_filter(node.body[0], obj_type)
case ast.Expr:
return _build_filter(node.value, obj_type)
case ast.BoolOp:
match type(node.op):
case ast.And:
return reduce(operator.and_, [_build_filter(val, obj_type) for val in node.values])
case ast.Or:
return reduce(operator.or_, [_build_filter(val, obj_type) for val in node.values])
case _:
raise SyntaxError("Invalid boolean operator.")
case ast.Compare:
if len(node.comparators) != 1 or not (isinstance(node.left, ast.Name) or isinstance(node.left, ast.Attribute)) or not isinstance(node.comparators[0], ast.Constant):
raise SyntaxError()
filter_descriptor = _build_filter(node.left, obj_type)
return _build_filter(node.ops[0])(filter_descriptor, _build_filter(node.comparators[0], obj_type, type(filter_descriptor)))
case ast.Name:
if node.id in obj_type._fields.keys() and obj_type._fields[node.id][0] != filters.filter.FType.NA or filters.filter.FType.UNKNOWN:
return getattr(obj_type, node.id)
raise SyntaxError()
case ast.Constant:
match target_type:
case filters.StringFilter:
return str(node.value)
case filters.StatusFilter:
try:
return getattr(getattr(armonik.common, f"{obj_type}Status"), str(node.value).upper())
except AttributeError:
raise SyntaxError()
case filters.DateFilter:
try:
return datetime.strptime(str(node.value), DATETIME_FORMAT)
except ValueError:
raise SyntaxError()
case filters.DurationFilter:
raise NotImplementedError()
case filters.NumberFilter:
if isinstance(node.value, int):
return node.value
raise SyntaxError()
case filters.BooleanFilter:
match str(node.value).capitalize():
case "True":
return True
case "False":
return False
case _:
raise SyntaxError()
case filters.ArrayFilter:
if isinstance(node.value, list):
return node.value
raise SyntaxError()
case ast.Attribute:
match obj_type:
case "Session":
return Session.options[node.attr]
case "Task":
return Task.options[node.attr]
raise SyntaxError()
case ast.Eq:
return operator.eq
case ast.NotEq:
return operator.ne
case ast.Lt:
return operator.lt
case ast.LtE:
return operator.le
case ast.Gt:
return operator.gt
case ast.GtE:
return operator.ge
case ast.In:
return operator.contains
case ast.NotIn:
return lambda a, b: not operator.contains(a, b)
case _:
raise SyntaxError("Invalid")
except InterruptedError:#filters.FilterError:
raise SyntaxError()


# TODO : test array, int, duration filter types
@pytest.mark.parametrize(("obj_type", "filter_str", "filter"),
[
(filters.TaskFilter, "parent_task_ids == 'id'", Task.parent_task_ids == "id"),
# ("Session", "session_id == 'id'", Session.session_id == "id"),
# ("Session", "status == 'Running'", Session.status == SessionStatus.RUNNING),
# ("Session", "created_at > '2024-10-26 18:49:47'", Session.created_at > datetime(year=2024, month=10, day=26, hour=18, minute=49, second=47)),
# ("Session", "client_submission == 'true'", Session.client_submission == True),
# ("Session", "options.key == 'value'", Session.options["key"] == "value"),
# ("Session", "(session_id == 'id') and (status == 'Cancelled')", (Session.session_id == "id") & (Session.status == SessionStatus.CANCELLED)),
# ("Session", "(session_id == 'id') and (status == 'Cancelled') and (options.key == 'value')", (Session.session_id == "id") & (Session.status == SessionStatus.CANCELLED) & (Session.options["key"] == "value")),
# ("Session", "(session_id == 'id') and ((status == 'Cancelled') or (options.key == 'value'))", (Session.session_id == "id") & ((Session.status == SessionStatus.CANCELLED) | (Session.options["key"] == "value"))),
]
)
def test_build_filter(obj_type, filter_str, filter):
assert _build_filter(ast.parse(filter_str), obj_type=obj_type).to_dict() == filter.to_dict()
58 changes: 58 additions & 0 deletions filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import ast

import armonik.common as ak_obj

from armonik.common import Filter


class FilterTransformer(ast.NodeTransformer):
def __init__(self, api_obj_type: str):
self._api_obj_type = api_obj_type
self._const_types = {}

def generic_visit(self, node):
super().generic_visit(node)
match type(node):
case ast.Module:
assert len(node.body) == 1
assert isinstance(node.body[0], ast.Expr) or isinstance(node.body[0], ast.BinOp)
return node
case ast.Expr:
assert isinstance(node.value, ast.Compare)
return node
case ast.Compare:
assert isinstance(node.left, ast.Name)
assert any([isinstance(node.ops[0], op_type) for op_type in [ast.Eq, ast.Lt, ast.LtE, ast.Gt, ast.GtE]])
assert len(node.ops) == 1
assert isinstance(node.comparators[0], ast.Constant)

obj = getattr(ak_obj, self._api_obj_type)
assert node.left.id in [m for m in obj.__dict__.keys() if isinstance(getattr(obj, m), Filter)]
self._const_types[node.comparators[0].value] = getattr(obj, node.left.id)
return node
case ast.Name:
return ast.Attribute(
value=ast.Name(id=self._api_obj_type, ctx=ast.Load()),
attr=node.id,
ctx=ast.Load()
)
case ast.Constant:
from armonik.common.filter import StringFilter, StatusFilter, DateFilter, DurationFilter, NumberFilter, BooleanFilter, ArrayFilter
filter_type = self._const_types[node.value]
case ast.Load:
return node
case _:
raise SyntaxError(f"Invalid node {node}".)
return


def from_str_to_filter(filter_str: str) -> Filter:
tree = ast.parse(filter_str)
tree = FilterTransformer("Session").visit(tree)
global_ctx = {}
local_ctx = {}
exec(f"filter={compile(tree, mode="exec")}", global_ctx, local_ctx)
return local_ctx["filter"]


assert from_str_to_filter("session_id == 'id'") == (Session.session_id == "id")

0 comments on commit 3108111

Please sign in to comment.