diff --git a/tutor/src/main/java/org/tudalgo/algoutils/transform/util/SimilarityMapper.java b/tutor/src/main/java/org/tudalgo/algoutils/transform/util/SimilarityMapper.java index 14912b9..4690041 100644 --- a/tutor/src/main/java/org/tudalgo/algoutils/transform/util/SimilarityMapper.java +++ b/tutor/src/main/java/org/tudalgo/algoutils/transform/util/SimilarityMapper.java @@ -1,68 +1,106 @@ package org.tudalgo.algoutils.transform.util; +import kotlin.Pair; import org.tudalgo.algoutils.tutor.general.match.MatchingUtils; import java.util.*; import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.Stream; +/** + * Computes the similarity for the cross product of two given collections. + * This creates a mapping of values in the first collection to the best match + * in the second collection, if any. + * + * @param the type of the collection's elements + */ public class SimilarityMapper { + private final List rowMapping; + private final List columnMapping; private final double[][] similarityMatrix; - private final Map bestMatches = new HashMap<>(); + private final Map> bestMatches = new HashMap<>(); - @SuppressWarnings("unchecked") - public SimilarityMapper(Collection from, Map> to, double similarityThreshold) { - List rowMapping = new ArrayList<>(from); - List columnMapping = new ArrayList<>(to.keySet()); + /** + * Creates a new {@link SimilarityMapper} instance, allowing columns to have aliases. + * + * @param from the values to map from (rows) + * @param to the values to map to (columns), with the map's values being aliases of the key + * @param similarityThreshold the minimum similarity two values need to have to be considered a match + * @param mappingFunction a function for mapping the collection's elements to strings + */ + public SimilarityMapper(Collection from, + Map> to, + double similarityThreshold, + Function mappingFunction) { + this.rowMapping = new ArrayList<>(from); + this.columnMapping = new ArrayList<>(to.keySet()); this.similarityMatrix = new double[from.size()][to.size()]; - - for (int i = 0; i < similarityMatrix.length; i++) { - String row = rowMapping.get(i); - int bestMatchIndex = -1; - double bestSimilarity = similarityThreshold; - for (int j = 0; j < similarityMatrix[i].length; j++) { - similarityMatrix[i][j] = Stream.concat(Stream.of(columnMapping.get(j)), to.get(columnMapping.get(j)).stream()) - .mapToDouble(value -> MatchingUtils.similarity(row, value)) - .max() - .orElseThrow(); - if (similarityMatrix[i][j] >= bestSimilarity) { - bestMatchIndex = j; - bestSimilarity = similarityMatrix[i][j]; - } - } - if (bestMatchIndex >= 0) { - bestMatches.put((T) rowMapping.get(i), (T) columnMapping.get(bestMatchIndex)); - } - } + computeSimilarity(to, similarityThreshold, mappingFunction); } + /** + * Creates a new {@link SimilarityMapper} instance. + * + * @param from the values to map from (rows) + * @param to the values to map to (columns) + * @param similarityThreshold the minimum similarity two values need to have to be considered a match + * @param mappingFunction a function for mapping the collection's elements to strings + */ public SimilarityMapper(Collection from, Collection to, double similarityThreshold, Function mappingFunction) { - List rowMapping = new ArrayList<>(from); - List columnMapping = new ArrayList<>(to); + this.rowMapping = new ArrayList<>(from); + this.columnMapping = new ArrayList<>(to); this.similarityMatrix = new double[from.size()][to.size()]; + computeSimilarity(to.stream().collect(Collectors.toMap(Function.identity(), t -> Collections.emptyList())), + similarityThreshold, + mappingFunction); + } + + /** + * Returns the best match for the given value, wrapped in an optional. + * + * @param t the value to find the best match for + * @return an optional wrapping the best match + */ + public Optional getBestMatch(T t) { + return Optional.ofNullable(bestMatches.get(t)).map(Pair::getFirst); + } + /** + * Computes the similarity for each entry in the cross product of the two input collections. + * Also extracts the best matches and stores them in {@link #bestMatches} for easy access. + * + * @param to a mapping of columns to their aliases + * @param similarityThreshold the minimum similarity two values need to have to be considered a match + * @param mappingFunction a function for mapping the collection's elements to strings + */ + private void computeSimilarity(Map> to, + double similarityThreshold, + Function mappingFunction) { for (int i = 0; i < similarityMatrix.length; i++) { String row = mappingFunction.apply(rowMapping.get(i)); int bestMatchIndex = -1; double bestSimilarity = similarityThreshold; for (int j = 0; j < similarityMatrix[i].length; j++) { - similarityMatrix[i][j] = MatchingUtils.similarity(row, mappingFunction.apply(columnMapping.get(j))); + similarityMatrix[i][j] = Stream.concat(Stream.of(columnMapping.get(j)), to.get(columnMapping.get(j)).stream()) + .map(mappingFunction) + .mapToDouble(value -> MatchingUtils.similarity(row, value)) + .max() + .orElseThrow(); if (similarityMatrix[i][j] >= bestSimilarity) { bestMatchIndex = j; bestSimilarity = similarityMatrix[i][j]; } } if (bestMatchIndex >= 0) { - bestMatches.put(rowMapping.get(i), columnMapping.get(bestMatchIndex)); + Pair pair = new Pair<>(columnMapping.get(bestMatchIndex), bestSimilarity); + bestMatches.merge(rowMapping.get(i), pair, (oldPair, newPair) -> + newPair.getSecond() > oldPair.getSecond() ? newPair : oldPair); } } } - - public Optional getBestMatch(T t) { - return Optional.ofNullable(bestMatches.get(t)); - } } diff --git a/tutor/src/main/java/org/tudalgo/algoutils/transform/util/TransformationContext.java b/tutor/src/main/java/org/tudalgo/algoutils/transform/util/TransformationContext.java index 165c223..d1185de 100644 --- a/tutor/src/main/java/org/tudalgo/algoutils/transform/util/TransformationContext.java +++ b/tutor/src/main/java/org/tudalgo/algoutils/transform/util/TransformationContext.java @@ -9,6 +9,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.*; +import java.util.function.Function; /** * A record for holding context information for the transformation process. @@ -108,8 +109,9 @@ public void setSubmissionClassNames(Set submissionClassNames) { @SuppressWarnings("unchecked") public void computeClassesSimilarity() { classSimilarityMapper = new SimilarityMapper<>(submissionClassNames, - (Map>) configuration.get(SolutionMergingClassTransformer.Config.SOLUTION_CLASSES), - getSimilarity()); + (Map>) configuration.get(SolutionMergingClassTransformer.Config.SOLUTION_CLASSES), + getSimilarity(), + Function.identity()); } // Submission classes