From 76ad8e82ff79649104e075ce6fd2346089c07185 Mon Sep 17 00:00:00 2001 From: Jaume Date: Fri, 8 Nov 2024 19:10:06 +0100 Subject: [PATCH] The rule engine is working --- src/swift-server-tests/rules.tests.swift | 261 ++++++++++++++++-- src/swift-server/configure.swift | 2 +- src/swift-server/core/cache.swift | 86 ++++++ src/swift-server/core/error.swift | 10 +- .../migrations/initial-migration.swift | 2 +- .../rules/new-transaction.job.swift | 4 +- src/swift-server/rules/rule-engine.swift | 106 +++++++ src/swift-server/rules/rules.models.swift | 84 ++++-- src/swift-server/rules/rules.service.swift | 23 +- 9 files changed, 503 insertions(+), 75 deletions(-) create mode 100644 src/swift-server/core/cache.swift create mode 100644 src/swift-server/rules/rule-engine.swift diff --git a/src/swift-server-tests/rules.tests.swift b/src/swift-server-tests/rules.tests.swift index d0483db..e8ddf4d 100644 --- a/src/swift-server-tests/rules.tests.swift +++ b/src/swift-server-tests/rules.tests.swift @@ -1,4 +1,5 @@ import Fluent +import QueuesFluentDriver import XCTQueues import XCTVapor import XCTest @@ -8,23 +9,24 @@ import XCTest struct BaseCondition { let operation: ConditionOperation let valueStr: String? - let valueFloat: Double? + let valueDouble: Double? init(operation: ConditionOperation, valueStr: String) { self.operation = operation self.valueStr = valueStr - self.valueFloat = nil + self.valueDouble = nil } - init(operation: ConditionOperation, valueFloat: Double) { + init(operation: ConditionOperation, valueDouble: Double) { self.operation = operation self.valueStr = nil - self.valueFloat = valueFloat + self.valueDouble = valueDouble } } final class RulesTests: AbstractBaseTestsClass { - private func createBasicRule( + var operationToLabel: [ConditionOperation: UUID] = [:] + private func createRule( on db: Database, for userGroup: UserGroup, with conditions: BaseCondition..., toApply labels: Label... ) async throws -> Rule { @@ -39,7 +41,7 @@ final class RulesTests: AbstractBaseTestsClass { ruleId: ruleId, operation: cond.operation, valueStr: cond.valueStr, - valueFloat: cond.valueFloat + valueDouble: cond.valueDouble ), on: db) } @@ -54,37 +56,59 @@ final class RulesTests: AbstractBaseTestsClass { return rule } + private func createBasicRuleAndRegister( + on db: Database, for userGroup: UserGroup, with condition: BaseCondition, + toApply label: Label + ) async throws -> Rule { + let rule = try await self.createRule( + on: db, for: userGroup, with: condition, toApply: label) + operationToLabel[condition.operation] = label.id! + return rule + } + private func createBasicRules() async throws { let db = app!.db // String operations - let _ = try await createBasicRule( + let _ = try await createBasicRuleAndRegister( on: db, for: testGroup, with: .init(operation: .prefix, valueStr: "needle"), toApply: labels[0]) - let _ = try await createBasicRule( + let _ = try await createBasicRuleAndRegister( on: db, for: testGroup, - with: .init(operation: .regularExpression, valueStr: "needle"), + with: .init(operation: .regularExpression, valueStr: "A{2}B{2}"), toApply: labels[1]) - let _ = try await createBasicRule( + let _ = try await createBasicRuleAndRegister( on: db, for: testGroup, with: .init(operation: .suffix, valueStr: "needle"), toApply: labels[2]) - let _ = try await createBasicRule( + let _ = try await createBasicRuleAndRegister( on: db, for: testGroup, with: .init(operation: .contains, valueStr: "needle"), toApply: labels[3]) // Float operations - let _ = try await createBasicRule( - on: db, for: testGroup, with: .init(operation: .greater, valueFloat: 1), + let _ = try await createBasicRuleAndRegister( + on: db, for: testGroup, with: .init(operation: .greater, valueDouble: 1), toApply: labels[4]) - let _ = try await createBasicRule( + let _ = try await createBasicRuleAndRegister( on: db, for: testGroup, - with: .init(operation: .greaterEqual, valueFloat: 1), toApply: labels[5]) - let _ = try await createBasicRule( - on: db, for: testGroup, with: .init(operation: .less, valueFloat: -1), + with: .init(operation: .greaterEqual, valueDouble: 1), toApply: labels[5]) + let _ = try await createBasicRuleAndRegister( + on: db, for: testGroup, with: .init(operation: .less, valueDouble: -1), toApply: labels[6]) - let _ = try await createBasicRule( - on: db, for: testGroup, with: .init(operation: .lessEqual, valueFloat: -1), + let _ = try await createBasicRuleAndRegister( + on: db, for: testGroup, with: .init(operation: .lessEqual, valueDouble: -1), toApply: labels[7]) + } + private func checkContainsOperationLabel( + _ labels: [UUID], for operation: ConditionOperation, context: String + ) throws { + let label = operationToLabel[operation] + guard let label else { + print("Label for operation \(operation) not found!") + throw TestError() + } + XCTAssertTrue( + labels.contains(label), "For operation \(operation) with context \(context)" + ) } func testBaseLabelAssignation() async throws { @@ -97,6 +121,7 @@ final class RulesTests: AbstractBaseTestsClass { transactionFactory.build { $0.$groupOwner.id = testGroupId $0.movementName = "no-match" + $0.value = 0 return $0 }, @@ -112,6 +137,25 @@ final class RulesTests: AbstractBaseTestsClass { $0.value = 1 return $0 }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "in the middle is the needle only contains" + $0.value = 2 + return $0 + }, + + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "suffix with the needle" + $0.value = -2 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "AABBCCDD" + $0.value = -2 + return $0 + }, ] for transaction in transactions { @@ -119,6 +163,8 @@ final class RulesTests: AbstractBaseTestsClass { on: app.db, withQueue: app.queues.queue, transaction: transaction) } + try await app.queues.queue.worker.run() + for transaction in transactions { try await transaction.$labels.load(on: app.db) } @@ -130,15 +176,178 @@ final class RulesTests: AbstractBaseTestsClass { } XCTAssertEqual(labelsForTransactions[0].count, 0) - XCTAssertEqual(labelsForTransactions[1].count, 3) XCTAssertEqual(labelsForTransactions[2].count, 0) - /*XCTAssertEqual( - app.queues.asyncTest.jobs.count { $0.value.jobName == String(describing: NewTransactionJob.self)}, - 2, - "Expected 2 jobs to have been dispatched" - )*/ - XCTAssertEqual(app.queues.asyncTest.jobs.count, 3) + XCTAssertEqual(labelsForTransactions[1].count, 3) + try checkContainsOperationLabel( + labelsForTransactions[1], for: .prefix, + context: transactions[1].movementName) + try checkContainsOperationLabel( + labelsForTransactions[1], for: .contains, + context: transactions[1].movementName) + try checkContainsOperationLabel( + labelsForTransactions[1], for: .greaterEqual, + context: transactions[1].movementName) + + XCTAssertEqual(labelsForTransactions[3].count, 3) + try checkContainsOperationLabel( + labelsForTransactions[3], for: .greater, + context: transactions[3].movementName) + try checkContainsOperationLabel( + labelsForTransactions[3], for: .contains, + context: transactions[3].movementName) + try checkContainsOperationLabel( + labelsForTransactions[3], for: .greaterEqual, + context: transactions[3].movementName) + + XCTAssertEqual(labelsForTransactions[4].count, 4) + try checkContainsOperationLabel( + labelsForTransactions[4], for: .contains, + context: transactions[4].movementName) + try checkContainsOperationLabel( + labelsForTransactions[4], for: .suffix, + context: transactions[4].movementName) + try checkContainsOperationLabel( + labelsForTransactions[4], for: .lessEqual, + context: transactions[4].movementName) + try checkContainsOperationLabel( + labelsForTransactions[4], for: .less, context: transactions[4].movementName) + } + + func testConditionsRelation() async throws { + let app = try getApp() + let testGroupId = try testGroup.requireID() + + let _ = try await createRule( + on: app.db, for: testGroup, + with: .init(operation: .prefix, valueStr: "needle"), + .init(operation: .less, valueDouble: 0), + toApply: labels[0]) + let rule = try await createRule( + on: app.db, for: testGroup, + with: .init(operation: .prefix, valueStr: "needle"), + .init(operation: .less, valueDouble: 0), + toApply: labels[1]) + rule.conditionsRelation = .notAnd + try await rule.save(on: app.db) + + let transactions = [ + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "needle" + $0.value = -11 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "there is no match on the needle" + $0.value = -11 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "needle" + $0.value = 2 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "there is no match on the needle" + $0.value = 2 + return $0 + }, + ] + + for transaction in transactions { + let _ = try await bankTransactionService.addTransaction( + on: app.db, withQueue: app.queues.queue, transaction: transaction) + } + + try await app.queues.queue.worker.run() + + for transaction in transactions { + try await transaction.$labels.load(on: app.db) + } + + let labelsForTransactions = transactions.map { + $0.labels.map { label in + return label.id! + } + } + + XCTAssertEqual(labelsForTransactions[0].count, 1) + XCTAssertEqual(labelsForTransactions[0].first, labels[0].id) + XCTAssertEqual(labelsForTransactions[1].count, 1) + XCTAssertEqual(labelsForTransactions[1].first, labels[0].id) + XCTAssertEqual(labelsForTransactions[2].count, 1) + XCTAssertEqual(labelsForTransactions[2].first, labels[0].id) + + XCTAssertEqual(labelsForTransactions[3].count, 1) + XCTAssertEqual(labelsForTransactions[3].first, labels[1].id) + } + + func testRuleParentChildrenRelation() async throws { + let app = try getApp() + let testGroupId = try testGroup.requireID() + + let rule1 = try await createRule( + on: app.db, for: testGroup, + with: .init(operation: .prefix, valueStr: "needle"), + toApply: labels[0]) + let rule2 = try await createRule( + on: app.db, for: testGroup, + with: .init(operation: .less, valueDouble: 0), + toApply: labels[1]) + rule2.$parent.id = try rule1.requireID() + try await rule2.save(on: app.db) + + let transactions = [ + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "needle" + $0.value = -11 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "there is no match on the needle" + $0.value = -11 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "needle" + $0.value = 2 + return $0 + }, + transactionFactory.build { + $0.$groupOwner.id = testGroupId + $0.movementName = "there is no match on the needle" + $0.value = 2 + return $0 + }, + ] + + for transaction in transactions { + let _ = try await bankTransactionService.addTransaction( + on: app.db, withQueue: app.queues.queue, transaction: transaction) + } + + try await app.queues.queue.worker.run() + + for transaction in transactions { + try await transaction.$labels.load(on: app.db) + } + + let labelsForTransactions = transactions.map { + $0.labels.map { label in + return label.id! + } + } + XCTAssertEqual(labelsForTransactions[0].count, 2) + XCTAssertEqual(labelsForTransactions[1].count, 0) + XCTAssertEqual(labelsForTransactions[2].count, 1) + XCTAssertEqual(labelsForTransactions[3].count, 0) } } diff --git a/src/swift-server/configure.swift b/src/swift-server/configure.swift index c903508..03f3ac8 100644 --- a/src/swift-server/configure.swift +++ b/src/swift-server/configure.swift @@ -72,7 +72,7 @@ public func configure(_ app: Application) async throws { app.queues.add(NewTransactionJob()) app.queues.configuration.workerCount = 1 try app.queues.startInProcessJobs(on: .default) - try app.queues.startScheduledJobs() + //try app.queues.startScheduledJobs() } catch { print(error) throw error diff --git a/src/swift-server/core/cache.swift b/src/swift-server/core/cache.swift new file mode 100644 index 0000000..ed26ebb --- /dev/null +++ b/src/swift-server/core/cache.swift @@ -0,0 +1,86 @@ +import Foundation + +final class Cache { + private let wrapped = NSCache() + + private let dateProvider: () -> Date + private let entryLifetime: TimeInterval + + init( + dateProvider: @escaping () -> Date = Date.init, + entryLifetime: TimeInterval = 12 * 60 * 60 + ) { + self.dateProvider = dateProvider + self.entryLifetime = entryLifetime + } + + func insert(_ value: Value, forKey key: Key) { + let date = dateProvider().addingTimeInterval(entryLifetime) + let entry = Entry(value: value, expirationDate: date) + wrapped.setObject(entry, forKey: WrappedKey(key)) + } + + func value(forKey key: Key) -> Value? { + guard let entry = wrapped.object(forKey: WrappedKey(key)) else { + return nil + } + + guard dateProvider() < entry.expirationDate else { + // Discard values that have expired + removeValue(forKey: key) + return nil + } + + return entry.value + } + + func removeValue(forKey key: Key) { + wrapped.removeObject(forKey: WrappedKey(key)) + } +} + +extension Cache { + fileprivate final class Entry { + let value: Value + let expirationDate: Date + + init(value: Value, expirationDate: Date) { + self.value = value + self.expirationDate = expirationDate + } + } +} + +extension Cache { + fileprivate final class WrappedKey: NSObject { + let key: Key + + init(_ key: Key) { self.key = key } + + override var hash: Int { return key.hashValue } + + override func isEqual(_ object: Any?) -> Bool { + guard let value = object as? WrappedKey else { + return false + } + + return value.key == key + } + } +} + +extension Cache { + subscript(key: Key) -> Value? { + get { return value(forKey: key) } + set { + guard let value = newValue else { + // If nil was assigned using our subscript, + // then we remove any value for that key: + removeValue(forKey: key) + return + } + + insert(value, forKey: key) + } + } +} diff --git a/src/swift-server/core/error.swift b/src/swift-server/core/error.swift index a318df5..5447ab1 100644 --- a/src/swift-server/core/error.swift +++ b/src/swift-server/core/error.swift @@ -1,6 +1,6 @@ enum ErrorCode: String, CaseIterable { case E10000, E10001, E10002, E10003, E10004, E10005, E10006, E10007, E10008, E10009 - case E10010, E10011, E10012, E10013 + case E10010, E10011, E10012, E10013, E10014, E10015, E10016 } enum ApiError: String { @@ -39,6 +39,14 @@ let errorDictionary: [ErrorCode: ErrorInfo] = [ .E10011: ErrorInfo(message: "Csv row seems cannot be processed for Commerz Bank En"), .E10012: ErrorInfo(message: "Csv is invalid and doesn't contain the original movement row"), .E10013: ErrorInfo(message: "Csv is invalid and doesn't contain the original value row"), + .E10014: ErrorInfo( + message: "Retrieving an String from a condition when it doesn't have it"), + .E10014: ErrorInfo( + message: "Retrieving the Double from a condition when it doesn't have it"), + .E10016: ErrorInfo( + message: "Rule parent cannot be found by the ID", + additionalInfo: "this can be because the parent rule is with another groupOwnerId" + ), ] extension ErrorCode { diff --git a/src/swift-server/migrations/initial-migration.swift b/src/swift-server/migrations/initial-migration.swift index f9e6406..136bc31 100644 --- a/src/swift-server/migrations/initial-migration.swift +++ b/src/swift-server/migrations/initial-migration.swift @@ -78,7 +78,7 @@ struct InitialMigration: AsyncMigration { .field("rule_id", .uuid, .required, .references("core_rule", "id")) .field("operation", .string, .required) .field("value_str", .string) - .field("value_float", .float) + .field("value_double", .double) // .index("rule_id") .create() diff --git a/src/swift-server/rules/new-transaction.job.swift b/src/swift-server/rules/new-transaction.job.swift index 2482f8f..b03f83c 100644 --- a/src/swift-server/rules/new-transaction.job.swift +++ b/src/swift-server/rules/new-transaction.job.swift @@ -14,12 +14,14 @@ final class NewTransactionJob: AsyncJob { func dequeue(_ context: QueueContext, _ payload: TransactionSummary) async throws { context.logger.info("Processing rules for \(payload.id)") - try await rulesService.applyRules(on: context.application.db, for: payload) + try await rulesService.ruleEngine.applyRules( + on: context.application.db, for: payload) } func error(_ context: QueueContext, _ error: Error, _ payload: TransactionSummary) async throws { print("Some error happened processing \(payload.id) transaction") + print(error) } } diff --git a/src/swift-server/rules/rule-engine.swift b/src/swift-server/rules/rule-engine.swift new file mode 100644 index 0000000..79165ae --- /dev/null +++ b/src/swift-server/rules/rule-engine.swift @@ -0,0 +1,106 @@ +import Fluent +import Foundation + +class CompiledRule { + let id: UUID + let parent: UUID? + var children: [CompiledRule] + let conditions: [(_ transaction: TransactionSummary) -> Bool] + let labelIds: [UUID] + let conditionsRelation: ConditionalRelationType + + init(_ rule: Rule) throws { + id = try rule.requireID() + parent = rule.$parent.id + children = [] + labelIds = rule.labels.map { $0.$label.id } + conditions = try rule.conditions.map { try $0.toClousure() } + conditionsRelation = rule.conditionsRelation + } + + func checkConditionals(for transaction: TransactionSummary) + -> Bool + { + switch conditionsRelation { + case .or: + return conditions.contains { $0(transaction) } + case .notAnd: + return conditions.map { $0(transaction) }.filter { $0 }.isEmpty + } + } +} + +class RuleEngine { + var cache: Cache = .init(entryLifetime: 10) + + private func compileRules(on db: Database, forGroup groupOwnerId: UUID) async throws + -> [CompiledRule] + { + let dbRules = try await Rule.query(on: db).filter( + \.$groupOwner.$id == groupOwnerId + ).with(\.$conditions).with(\.$labels).all() + let rules = try dbRules.map { try CompiledRule($0) } + let rulesHash: [UUID: CompiledRule] = rules.reduce(into: [UUID: CompiledRule]()) { + $0[$1.id] = $1 + } + try rules.filter { $0.parent != nil }.forEach { rule in + // We are already filtering for rules with parent + let parentId = rule.parent! + guard let parent = rulesHash[parentId] else { + throw Exception( + .E10016, context: ["id": rule.id, "parentId": parentId]) + } + + parent.children.append(rule) + } + + return rules.filter { $0.parent == nil } + } + + private func retrieveRules(on db: Database, forGroup groupOwnerId: UUID) async throws + -> [CompiledRule] + { + let cachedData = cache[groupOwnerId] + guard let cachedData else { + let compiledData = try await compileRules(on: db, forGroup: groupOwnerId) + cache[groupOwnerId] = compiledData + return compiledData + } + return cachedData + } + + private func applyLabels( + on db: Database, ruleId: UUID, labelIds: [UUID], for transaction: TransactionSummary + ) async throws { + for labelId in labelIds { + let labelTransaction = LabelTransaction( + labelId: labelId, + transactionId: transaction.id, + linkReason: .automatic) + try await labelTransaction.save(on: db) + try await RuleLabelPivot( + ruleId: ruleId, + labelTransactionId: labelTransaction.id! + ).save(on: db) + } + } + + private func checkAndApply( + on db: Database, rules: [CompiledRule], for transaction: TransactionSummary + ) async throws { + for rule in rules { + if rule.checkConditionals(for: transaction) { + try await applyLabels( + on: db, ruleId: rule.id, labelIds: rule.labelIds, + for: transaction) + try await checkAndApply( + on: db, rules: rule.children, for: transaction) + } + } + } + + func applyRules(on db: Database, for transaction: TransactionSummary) async throws { + let rules = try await retrieveRules(on: db, forGroup: transaction.groupOwnerId) + try await checkAndApply(on: db, rules: rules, for: transaction) + } +} diff --git a/src/swift-server/rules/rules.models.swift b/src/swift-server/rules/rules.models.swift index 584f9c4..c4fe507 100644 --- a/src/swift-server/rules/rules.models.swift +++ b/src/swift-server/rules/rules.models.swift @@ -35,29 +35,81 @@ final class Condition: Model, Content, @unchecked Sendable { @Field(key: "value_str") var valueStr: String? - @Field(key: "value_float") - var valueFloat: Double? + @Field(key: "value_double") + var valueDouble: Double? init() {} init( id: UUID? = nil, ruleId: UUID, operation: ConditionOperation, - valueStr: String? = nil, valueFloat: Double? = nil + valueStr: String? = nil, valueDouble: Double? = nil ) { self.id = id self.$rule.id = ruleId self.operation = operation self.valueStr = valueStr - self.valueFloat = valueFloat + self.valueDouble = valueDouble } - func checkTransaction(_ transaction: TransactionSummary) -> Bool { + func getStr() throws -> String { + guard let valueStr = self.valueStr else { + throw Exception( + .E10013, context: ["condition": id as Any, "parent": $rule.id]) + } + return valueStr + } + + func getDouble() throws -> Double { + guard let valueDouble = self.valueDouble else { + throw Exception( + .E10014, context: ["condition": id as Any, "parent": $rule.id]) + } + return valueDouble + } + + func toClousure() throws -> (_ transaction: TransactionSummary) -> Bool { switch operation { case .prefix: - print(self.valueStr!, transaction.movementName) - return transaction.movementName.hasPrefix(self.valueStr!) - default: - return false + let value = try getStr() + return { transaction in + transaction.movementName.hasPrefix(value) + } + case .contains: + let value = try getStr() + return { transaction in + transaction.movementName.contains(value) + } + case .suffix: + let value = try getStr() + return { transaction in + transaction.movementName.hasSuffix(value) + } + case .regularExpression: + let value = try getStr() + let regex = try Regex(value) + return { transaction in + transaction.movementName.contains(regex) + } + case .greater: + let value = try getDouble() + return { transaction in + transaction.value > value + } + case .greaterEqual: + let value = try getDouble() + return { transaction in + transaction.value >= value + } + case .lessEqual: + let value = try getDouble() + return { transaction in + transaction.value <= value + } + case .less: + let value = try getDouble() + return { transaction in + transaction.value < value + } } } } @@ -101,20 +153,6 @@ final class Rule: Model, Content, @unchecked Sendable { self.conditionsRelation = conditionsRelation self.$parent.id = parentId } - - func checkConditionals(on db: Database, for transaction: TransactionSummary) async throws - -> Bool - { - try await $conditions.load(on: db) - let summary = conditions.map { $0.checkTransaction(transaction) } - print(summary) - switch conditionsRelation { - case .or: - return summary.contains { $0 } - case .notAnd: - return false - } - } } // This one is to know which labels will be linked diff --git a/src/swift-server/rules/rules.service.swift b/src/swift-server/rules/rules.service.swift index 3a3cc94..cd9a0be 100644 --- a/src/swift-server/rules/rules.service.swift +++ b/src/swift-server/rules/rules.service.swift @@ -1,27 +1,6 @@ import Fluent class RulesService { - func applyRules(on db: Database, for transaction: TransactionSummary) async throws { - let rules = try await Rule.query(on: db).filter( - \.$groupOwner.$id == transaction.groupOwnerId - ).with(\.$conditions).all() - for rule in rules { - let isMatch = try await rule.checkConditionals(on: db, for: transaction) - if isMatch { - try await rule.$labels.load(on: db) - for label in rule.labels { - let labelTransaction = LabelTransaction( - labelId: label.id!, transactionId: transaction.id, - linkReason: .automatic) - try await labelTransaction.save(on: db) - try await RuleLabelPivot( - ruleId: rule.id!, - labelTransactionId: labelTransaction.id! - ).save(on: db) - } - // rules.extends(rule.children) - } - } - } + let ruleEngine = RuleEngine() } let rulesService = RulesService()