-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor SimilarityMapper and prevent duplicates
- Loading branch information
Showing
2 changed files
with
74 additions
and
34 deletions.
There are no files selected for viewing
102 changes: 70 additions & 32 deletions
102
tutor/src/main/java/org/tudalgo/algoutils/transform/util/SimilarityMapper.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,106 @@ | ||
package org.tudalgo.algoutils.transform.util; | ||
|
||
import kotlin.Pair; | ||
import org.tudalgo.algoutils.tutor.general.match.MatchingUtils; | ||
|
||
import java.util.*; | ||
import java.util.function.Function; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.Stream; | ||
|
||
/** | ||
* Computes the similarity for the cross product of two given collections. | ||
* This creates a mapping of values in the first collection to the best match | ||
* in the second collection, if any. | ||
* | ||
* @param <T> the type of the collections' elements | ||
*/ | ||
public class SimilarityMapper<T> { | ||
|
||
private final List<? extends T> rowMapping; | ||
private final List<? extends T> columnMapping; | ||
private final double[][] similarityMatrix; | ||
private final Map<T, T> bestMatches = new HashMap<>(); | ||
private final Map<T, Pair<T, Double>> bestMatches = new HashMap<>(); | ||
|
||
@SuppressWarnings("unchecked") | ||
public SimilarityMapper(Collection<String> from, Map<String, Collection<String>> to, double similarityThreshold) { | ||
List<String> rowMapping = new ArrayList<>(from); | ||
List<String> columnMapping = new ArrayList<>(to.keySet()); | ||
/** | ||
* Creates a new {@link SimilarityMapper} instance, allowing columns to have aliases. | ||
* | ||
* @param from the values to map from (rows) | ||
* @param to the values to map to (columns), with the map's values being aliases of the key | ||
* @param similarityThreshold the minimum similarity two values need to have to be considered a match | ||
* @param mappingFunction a function for mapping the collection's elements to strings | ||
*/ | ||
public SimilarityMapper(Collection<? extends T> from, | ||
Map<T, Collection<? extends T>> to, | ||
double similarityThreshold, | ||
Function<? super T, String> mappingFunction) { | ||
this.rowMapping = new ArrayList<>(from); | ||
this.columnMapping = new ArrayList<>(to.keySet()); | ||
this.similarityMatrix = new double[from.size()][to.size()]; | ||
|
||
for (int i = 0; i < similarityMatrix.length; i++) { | ||
String row = rowMapping.get(i); | ||
int bestMatchIndex = -1; | ||
double bestSimilarity = similarityThreshold; | ||
for (int j = 0; j < similarityMatrix[i].length; j++) { | ||
similarityMatrix[i][j] = Stream.concat(Stream.of(columnMapping.get(j)), to.get(columnMapping.get(j)).stream()) | ||
.mapToDouble(value -> MatchingUtils.similarity(row, value)) | ||
.max() | ||
.orElseThrow(); | ||
if (similarityMatrix[i][j] >= bestSimilarity) { | ||
bestMatchIndex = j; | ||
bestSimilarity = similarityMatrix[i][j]; | ||
} | ||
} | ||
if (bestMatchIndex >= 0) { | ||
bestMatches.put((T) rowMapping.get(i), (T) columnMapping.get(bestMatchIndex)); | ||
} | ||
} | ||
computeSimilarity(to, similarityThreshold, mappingFunction); | ||
} | ||
|
||
/** | ||
* Creates a new {@link SimilarityMapper} instance. | ||
* | ||
* @param from the values to map from (rows) | ||
* @param to the values to map to (columns) | ||
* @param similarityThreshold the minimum similarity two values need to have to be considered a match | ||
* @param mappingFunction a function for mapping the collection's elements to strings | ||
*/ | ||
public SimilarityMapper(Collection<? extends T> from, | ||
Collection<? extends T> to, | ||
double similarityThreshold, | ||
Function<? super T, String> mappingFunction) { | ||
List<T> rowMapping = new ArrayList<>(from); | ||
List<T> columnMapping = new ArrayList<>(to); | ||
this.rowMapping = new ArrayList<>(from); | ||
this.columnMapping = new ArrayList<>(to); | ||
this.similarityMatrix = new double[from.size()][to.size()]; | ||
computeSimilarity(to.stream().collect(Collectors.toMap(Function.identity(), t -> Collections.emptyList())), | ||
similarityThreshold, | ||
mappingFunction); | ||
} | ||
|
||
/** | ||
* Returns the best match for the given value, wrapped in an optional. | ||
* | ||
* @param t the value to find the best match for | ||
* @return an optional wrapping the best match, or an empty optional if none was found | ||
*/ | ||
public Optional<T> getBestMatch(T t) { | ||
return Optional.ofNullable(bestMatches.get(t)).map(Pair::getFirst); | ||
} | ||
|
||
/** | ||
* Computes the similarity for each entry in the cross product of the two input collections. | ||
* Also extracts the best matches and stores them in {@link #bestMatches} for easy access. | ||
* | ||
* @param to a mapping of columns to their aliases | ||
* @param similarityThreshold the minimum similarity two values need to have to be considered a match | ||
* @param mappingFunction a function for mapping the collection's elements to strings | ||
*/ | ||
private void computeSimilarity(Map<? extends T, Collection<? extends T>> to, | ||
double similarityThreshold, | ||
Function<? super T, String> mappingFunction) { | ||
for (int i = 0; i < similarityMatrix.length; i++) { | ||
String row = mappingFunction.apply(rowMapping.get(i)); | ||
int bestMatchIndex = -1; | ||
double bestSimilarity = similarityThreshold; | ||
for (int j = 0; j < similarityMatrix[i].length; j++) { | ||
similarityMatrix[i][j] = MatchingUtils.similarity(row, mappingFunction.apply(columnMapping.get(j))); | ||
similarityMatrix[i][j] = Stream.concat(Stream.of(columnMapping.get(j)), to.get(columnMapping.get(j)).stream()) | ||
.map(mappingFunction) | ||
.mapToDouble(value -> MatchingUtils.similarity(row, value)) | ||
.max() | ||
.orElseThrow(); | ||
if (similarityMatrix[i][j] >= bestSimilarity) { | ||
bestMatchIndex = j; | ||
bestSimilarity = similarityMatrix[i][j]; | ||
} | ||
} | ||
if (bestMatchIndex >= 0) { | ||
bestMatches.put(rowMapping.get(i), columnMapping.get(bestMatchIndex)); | ||
Pair<T, Double> pair = new Pair<>(columnMapping.get(bestMatchIndex), bestSimilarity); | ||
bestMatches.merge(rowMapping.get(i), pair, (oldPair, newPair) -> | ||
newPair.getSecond() > oldPair.getSecond() ? newPair : oldPair); | ||
} | ||
} | ||
} | ||
|
||
public Optional<T> getBestMatch(T t) { | ||
return Optional.ofNullable(bestMatches.get(t)); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters