Skip to content

Commit

Permalink
macro average for specified number of test cases fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
sven-h committed Nov 20, 2023
1 parent 4240228 commit 43fdc31
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,27 +245,27 @@ public ConfusionMatrix getMacroAveragesForResults(Iterable<ExecutionResult> resu
Alignment truePositive = new Alignment();
Alignment falsePositive = new Alignment();
Alignment falseNegative = new Alignment();

double precision = 0.0; // dummy init
double recall = 0.0; // dummy init

// for aggregation:
int numberOfCorrespondences = 0;
double aggregatedPrecision = 0.0;
double aggregatedRecall = 0.0;
double aggregatedF1 = 0.0;

for (ConfusionMatrix individualConfusionMatrix : confusionMatrices) {
truePositive.addAll(individualConfusionMatrix.getTruePositive());
falsePositive.addAll(individualConfusionMatrix.getFalsePositive());
falseNegative.addAll(individualConfusionMatrix.getFalseNegative());
}

double aggregatedPrecision = 0.0;
double aggregatedRecall = 0.0;
for (ConfusionMatrix individualConfusionMatrix : confusionMatrices) {

numberOfCorrespondences += individualConfusionMatrix.getNumberOfCorrespondences();

aggregatedPrecision = aggregatedPrecision + individualConfusionMatrix.getPrecision();
aggregatedRecall = aggregatedRecall + individualConfusionMatrix.getRecall();
aggregatedF1 = aggregatedF1 + individualConfusionMatrix.getF1measure();
}
precision = aggregatedPrecision / numberOfTestCases;
recall = aggregatedRecall / numberOfTestCases;

return new ConfusionMatrix(truePositive, falsePositive, falseNegative, precision, recall);
double precision = aggregatedPrecision / numberOfTestCases;
double recall = aggregatedRecall / numberOfTestCases;
double f1 = aggregatedF1 / numberOfTestCases;

return new ConfusionMatrixMacroAveraged(truePositive, falsePositive, falseNegative, numberOfCorrespondences, precision, recall, f1);
}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package de.uni_mannheim.informatik.dws.melt.matching_eval.evaluator.metric.cm;

import de.uni_mannheim.informatik.dws.melt.matching_data.GoldStandardCompleteness;
import de.uni_mannheim.informatik.dws.melt.matching_data.LocalTrack;
import de.uni_mannheim.informatik.dws.melt.matching_data.TestCase;
import de.uni_mannheim.informatik.dws.melt.matching_data.TrackRepository;
import de.uni_mannheim.informatik.dws.melt.matching_eval.ExecutionResultSet;
Expand All @@ -9,6 +11,14 @@
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Predicate;
import org.apache.jena.ontology.OntModel;
import org.apache.jena.ontology.OntModelSpec;
import org.apache.jena.rdf.model.ModelFactory;
import org.apache.jena.vocabulary.OWL;

import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -242,4 +252,172 @@ void realTest() {
assertEquals(0.615, confusionMatrix1Dome.getRecall(), 0.001);
}


@Test
void micromacroTest() {
ConfusionMatrixMetric metric = new ConfusionMatrixMetric();
double delta = 0.01;
//https://iamirmasoud.com/2022/06/19/understanding-micro-macro-and-weighted-averages-for-scikit-learn-metrics-in-multi-class-classification-with-example/
ExecutionResult resultFirst = createResultWith("testCaseA", GoldStandardCompleteness.COMPLETE,
2,1,1,0,0,0,0,0,0);
ExecutionResult resultSecond = createResultWith("testCaseB", GoldStandardCompleteness.COMPLETE,
1,3,0,0,0,0,0,0,0);
ExecutionResult resultThird = createResultWith("testCaseC", GoldStandardCompleteness.COMPLETE,
3,0,3,0,0,0,0,0,0);
ExecutionResultSet all = new ExecutionResultSet();
all.add(resultFirst);
all.add(resultSecond);
all.add(resultThird);

ConfusionMatrix confusionMatrixFirst = metric.compute(resultFirst);
ConfusionMatrix confusionMatrixSecond = metric.compute(resultSecond);
ConfusionMatrix confusionMatrixThird = metric.compute(resultThird);

assertEquals(2, confusionMatrixFirst.getTruePositiveSize());
assertEquals(1, confusionMatrixFirst.getFalsePositiveSize());
assertEquals(1, confusionMatrixFirst.getFalseNegativeSize());
assertEquals(0.67, confusionMatrixFirst.getPrecision(), delta);
assertEquals(0.67, confusionMatrixFirst.getRecall(), delta);
assertEquals(0.67, confusionMatrixFirst.getF1measure(), delta);


assertEquals(1, confusionMatrixSecond.getTruePositiveSize());
assertEquals(3, confusionMatrixSecond.getFalsePositiveSize());
assertEquals(0, confusionMatrixSecond.getFalseNegativeSize());
assertEquals(0.25, confusionMatrixSecond.getPrecision(), delta);
assertEquals(1.0, confusionMatrixSecond.getRecall(), delta);
assertEquals(0.4, confusionMatrixSecond.getF1measure(), delta);

assertEquals(3, confusionMatrixThird.getTruePositiveSize());
assertEquals(0, confusionMatrixThird.getFalsePositiveSize());
assertEquals(3, confusionMatrixThird.getFalseNegativeSize());
assertEquals(1.0, confusionMatrixThird.getPrecision(), delta);
assertEquals(0.5, confusionMatrixThird.getRecall(), delta);
assertEquals(0.67, confusionMatrixThird.getF1measure(), delta);



ConfusionMatrix microAll = metric.getMicroAveragesForResults(all);

assertEquals(6, microAll.getTruePositiveSize());
assertEquals(4, microAll.getFalsePositiveSize());
assertEquals(4, microAll.getFalseNegativeSize());
assertEquals(0.6, microAll.getPrecision(), delta);
assertEquals(0.6, microAll.getRecall(), delta);
assertEquals(0.6, microAll.getF1measure(), delta);


ConfusionMatrix macroAll = metric.getMacroAveragesForResults(all);
assertEquals(6, macroAll.getTruePositiveSize());
assertEquals(4, macroAll.getFalsePositiveSize());
assertEquals(4, macroAll.getFalseNegativeSize());
assertEquals(0.64, macroAll.getPrecision(), delta);
assertEquals(0.72, macroAll.getRecall(), delta);
assertEquals(0.58, macroAll.getF1measure(), delta);

ConfusionMatrix macroSpecifiedNumber = metric.getMacroAveragesForResults(all, 3);
assertEquals(6, macroSpecifiedNumber.getTruePositiveSize());
assertEquals(4, macroSpecifiedNumber.getFalsePositiveSize());
assertEquals(4, macroSpecifiedNumber.getFalseNegativeSize());
assertEquals(0.64, macroSpecifiedNumber.getPrecision(), delta);
assertEquals(0.72, macroSpecifiedNumber.getRecall(), delta);
assertEquals(0.58, macroSpecifiedNumber.getF1measure(), delta);
}





private static ExecutionResult createResultWith(String testCase, GoldStandardCompleteness goldStandardCompleteness,
int classTP, int classFP, int classFN,
int propTP, int propFP, int propFN,
int instTP, int instFP, int instFN){
int counter = 0;
String sourceBase = "http://source.com/" + testCase + "/";
String targetBase = "http://target.com/" + testCase + "/";

Alignment systemAlignment = new Alignment();
Alignment refAlignment = new Alignment();

OntModel src = ModelFactory.createOntologyModel(OntModelSpec.OWL_MEM);
OntModel tgt = ModelFactory.createOntologyModel(OntModelSpec.OWL_MEM);

//TP
for(int i = 0; i < classTP; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createClass(sourceURI);
tgt.createClass(targetURI);
systemAlignment.add(sourceURI, targetURI);
refAlignment.add(sourceURI, targetURI);
}
for(int i = 0; i < propTP; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createProperty(sourceURI);
tgt.createProperty(targetURI);
systemAlignment.add(sourceURI, targetURI);
refAlignment.add(sourceURI, targetURI);
}
for(int i = 0; i < instTP; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createIndividual(sourceURI, OWL.Thing);
tgt.createIndividual(targetURI, OWL.Thing);
systemAlignment.add(sourceURI, targetURI);
refAlignment.add(sourceURI, targetURI);
}

//FP
for(int i = 0; i < classFP; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createClass(sourceURI);
tgt.createClass(targetURI);
systemAlignment.add(sourceURI, targetURI);
}
for(int i = 0; i < propFP; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createProperty(sourceURI);
tgt.createProperty(targetURI);
systemAlignment.add(sourceURI, targetURI);
}
for(int i = 0; i < instFP; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createIndividual(sourceURI, OWL.Thing);
tgt.createIndividual(targetURI, OWL.Thing);
systemAlignment.add(sourceURI, targetURI);
}

//FN
for(int i = 0; i < classFN; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createClass(sourceURI);
tgt.createClass(targetURI);
refAlignment.add(sourceURI, targetURI);
}
for(int i = 0; i < propFN; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createProperty(sourceURI);
tgt.createProperty(targetURI);
refAlignment.add(sourceURI, targetURI);
}
for(int i = 0; i < instFN; i++){
String sourceURI = sourceBase + counter++;
String targetURI = targetBase + counter++;
src.createIndividual(sourceURI, OWL.Thing);
tgt.createIndividual(targetURI, OWL.Thing);
refAlignment.add(sourceURI, targetURI);
}


LocalTrack track = new LocalTrack("testtrack", "1.0");
TestCase tc = new TestCaseWithModel(testCase, src, tgt, refAlignment, track, goldStandardCompleteness);

return new ExecutionResult(tc, "myTestMatcher", systemAlignment, refAlignment);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package de.uni_mannheim.informatik.dws.melt.matching_eval.evaluator.metric.cm;

import de.uni_mannheim.informatik.dws.melt.matching_data.GoldStandardCompleteness;
import de.uni_mannheim.informatik.dws.melt.matching_data.TestCase;
import de.uni_mannheim.informatik.dws.melt.matching_data.Track;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import java.util.Properties;
import org.apache.jena.ontology.OntModel;

public class TestCaseWithModel extends TestCase {
private OntModel sourceModel;
private OntModel targetModel;
private Alignment referenceAlignment;
public TestCaseWithModel(String name, OntModel source, OntModel target, Alignment reference, Track track, GoldStandardCompleteness goldStandardCompleteness) {
super(name, null, null, null, track, null, goldStandardCompleteness, null, null);
this.sourceModel = source;
this.targetModel = target;
this.referenceAlignment = reference;
}

@Override
public <T> T getSourceOntology(Class<T> clazz){
return getSourceOntology(clazz, null);
}
@Override
@SuppressWarnings("unchecked")
public <T> T getSourceOntology(Class<T> clazz, Properties parameters){
if(clazz.equals(OntModel.class)){
return (T) sourceModel;
}else{
throw new IllegalArgumentException("Wrong ontology type");
}
}


@Override
public <T> T getTargetOntology(Class<T> clazz){
return getTargetOntology(clazz, null);
}
@Override
@SuppressWarnings("unchecked")
public <T> T getTargetOntology(Class<T> clazz, Properties parameters){
if(clazz.equals(OntModel.class)){
return (T) targetModel;
}else{
throw new IllegalArgumentException("Wrong ontology type");
}
}

@Override
public Alignment getParsedReferenceAlignment() {
return referenceAlignment;
}

}

0 comments on commit 43fdc31

Please sign in to comment.