From f71ef6e97fbd21098203ee659aa0d02b5965ef9e Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Tue, 23 Mar 2021 09:08:32 +0200 Subject: [PATCH] Refactored the evaluation of scorecard models --- .../NearestNeighborModelEvaluator.java | 3 + .../scorecard/ScorecardEvaluator.java | 105 +++++++++--------- 2 files changed, 55 insertions(+), 53 deletions(-) diff --git a/pmml-evaluator/src/main/java/org/jpmml/evaluator/nearest_neighbor/NearestNeighborModelEvaluator.java b/pmml-evaluator/src/main/java/org/jpmml/evaluator/nearest_neighbor/NearestNeighborModelEvaluator.java index 332ac792..7af5b449 100644 --- a/pmml-evaluator/src/main/java/org/jpmml/evaluator/nearest_neighbor/NearestNeighborModelEvaluator.java +++ b/pmml-evaluator/src/main/java/org/jpmml/evaluator/nearest_neighbor/NearestNeighborModelEvaluator.java @@ -357,6 +357,7 @@ private V calculateContinuousTarget(ValueFactory valueFact for(InstanceResult instanceResult : instanceResults){ FieldValue value = table.get(instanceResult.getId(), name); + if(FieldValueUtil.isMissing(value)){ throw new MissingValueException(name); } @@ -406,6 +407,7 @@ private Object calculateCategoricalTarget(ValueFactory val for(InstanceResult instanceResult : instanceResults){ FieldValue value = table.get(instanceResult.getId(), name); + if(FieldValueUtil.isMissing(value)){ throw new MissingValueException(name); } @@ -461,6 +463,7 @@ private Function createIdentifierResolver(FieldName name, Table @Override public String apply(Integer row){ FieldValue value = table.get(row, name); + if(FieldValueUtil.isMissing(value)){ throw new MissingValueException(name); } diff --git a/pmml-evaluator/src/main/java/org/jpmml/evaluator/scorecard/ScorecardEvaluator.java b/pmml-evaluator/src/main/java/org/jpmml/evaluator/scorecard/ScorecardEvaluator.java index 9cd3bfe7..71156907 100644 --- a/pmml-evaluator/src/main/java/org/jpmml/evaluator/scorecard/ScorecardEvaluator.java +++ b/pmml-evaluator/src/main/java/org/jpmml/evaluator/scorecard/ScorecardEvaluator.java @@ -110,66 +110,43 @@ public String getSummary(){ Characteristics characteristics = scorecard.getCharacteristics(); for(Characteristic characteristic : characteristics){ - Number baselineScore = null; + PartialScore partialScore = evaluateCharacteristic(characteristic, context); + + Number score = partialScore.getValue(); + + value.add(score); + + partialScores.add(partialScore); if(useReasonCodes){ - baselineScore = characteristic.getBaselineScore(scorecard.getBaselineScore()); + Number baselineScore = characteristic.getBaselineScore(scorecard.getBaselineScore()); if(baselineScore == null){ throw new MissingAttributeException(characteristic, PMMLAttributes.CHARACTERISTIC_BASELINESCORE); } - } - PartialScore partialScore = null; + String reasonCode = partialScore.getReasonCode(); + if(reasonCode == null){ + Attribute attribute = partialScore.getAttribute(); - List attributes = characteristic.getAttributes(); - for(Attribute attribute : attributes){ - Boolean status = PredicateUtil.evaluatePredicateContainer(attribute, context); - if(status == null || !status.booleanValue()){ - continue; + throw new MissingAttributeException(attribute, PMMLAttributes.ATTRIBUTE_REASONCODE); } - Number score = evaluatePartialScore(attribute, context); - if(score == null){ - return TargetUtil.evaluateRegressionDefault(valueFactory, targetField); + Number difference; + + Scorecard.ReasonCodeAlgorithm reasonCodeAlgorithm = scorecard.getReasonCodeAlgorithm(); + switch(reasonCodeAlgorithm){ + case POINTS_ABOVE: + difference = Functions.SUBTRACT.evaluate(score, baselineScore); + break; + case POINTS_BELOW: + difference = Functions.SUBTRACT.evaluate(baselineScore, score); + break; + default: + throw new UnsupportedAttributeException(scorecard, reasonCodeAlgorithm); } - partialScore = new PartialScore(characteristic, attribute, score); - - value.add(score); - - if(useReasonCodes){ - String reasonCode = attribute.getReasonCode(characteristic.getReasonCode()); - if(reasonCode == null){ - throw new MissingAttributeException(attribute, PMMLAttributes.ATTRIBUTE_REASONCODE); - } - - Number difference; - - Scorecard.ReasonCodeAlgorithm reasonCodeAlgorithm = scorecard.getReasonCodeAlgorithm(); - switch(reasonCodeAlgorithm){ - case POINTS_ABOVE: - difference = Functions.SUBTRACT.evaluate(score, baselineScore); - break; - case POINTS_BELOW: - difference = Functions.SUBTRACT.evaluate(baselineScore, score); - break; - default: - throw new UnsupportedAttributeException(scorecard, reasonCodeAlgorithm); - } - - reasonCodePoints.add(reasonCode, difference); - } - - break; + reasonCodePoints.add(reasonCode, difference); } - - // "If not even a single Attribute evaluates to "true" for a given Characteristic, then the scorecard as a whole returns an invalid value" - if(partialScore == null){ - throw new UndefinedResultException() - .ensureContext(characteristic); - } - - partialScores.add(partialScore); } if(useReasonCodes){ @@ -184,17 +161,39 @@ public String getSummary(){ } static - private Number evaluatePartialScore(Attribute attribute, EvaluationContext context){ + private PartialScore evaluateCharacteristic(Characteristic characteristic, EvaluationContext context){ + List attributes = characteristic.getAttributes(); + + for(Attribute attribute : attributes){ + Boolean status = PredicateUtil.evaluatePredicateContainer(attribute, context); + if(status == null || !status.booleanValue()){ + continue; + } + + Number value = evaluateAttribute(attribute, context); + + return new PartialScore(characteristic, attribute, value); + } + + // "If not even a single Attribute evaluates to "true" for a given Characteristic, then the scorecard as a whole returns an invalid value" + throw new UndefinedResultException() + .ensureContext(characteristic); + } + + static + private Number evaluateAttribute(Attribute attribute, EvaluationContext context){ ComplexPartialScore complexPartialScore = attribute.getComplexPartialScore(); // "If both are defined, the ComplexPartialScore element takes precedence over the partialScore attribute for computing the score points" if(complexPartialScore != null){ - FieldValue computedValue = ExpressionUtil.evaluateExpressionContainer(complexPartialScore, context); - if(FieldValueUtil.isMissing(computedValue)){ - return null; + FieldValue value = ExpressionUtil.evaluateExpressionContainer(complexPartialScore, context); + + if(FieldValueUtil.isMissing(value)){ + throw new UndefinedResultException() + .ensureContext(complexPartialScore); } - return computedValue.asNumber(); + return value.asNumber(); } else {