Skip to content

Commit

Permalink
fix(metadata-io): use spanningTree as algorithm to get lineage
Browse files Browse the repository at this point in the history
  • Loading branch information
lix-mms committed Aug 22, 2023
1 parent 6559148 commit 4895a55
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 112 deletions.
1 change: 1 addition & 0 deletions docker/neo4j/env/docker.env
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
NEO4J_AUTH=neo4j/datahub
NEO4J_dbms_default__database=graph.db
NEO4J_dbms_allow__upgrade=true
NEO4JLABS_PLUGINS="[\"apoc\"]"
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.datahub.util.exception.RetryLimitReached;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.linkedin.common.UrnArray;
import com.linkedin.common.UrnArrayArray;
import com.linkedin.common.urn.Urn;
Expand All @@ -25,17 +26,20 @@
import com.linkedin.metadata.query.filter.RelationshipDirection;
import com.linkedin.metadata.query.filter.RelationshipFilter;
import com.linkedin.metadata.utils.metrics.MetricUtils;
import com.linkedin.util.Pair;
import io.opentelemetry.extension.annotations.WithSpan;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.AllArgsConstructor;
Expand All @@ -50,8 +54,7 @@
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.exceptions.Neo4jException;
import org.neo4j.driver.internal.InternalRelationship;
import org.neo4j.driver.types.Node;
import org.neo4j.driver.types.Relationship;


@Slf4j
Expand All @@ -62,9 +65,6 @@ public class Neo4jGraphService implements GraphService {
private final Driver _driver;
private SessionConfig _sessionConfig;

private static final String SOURCE = "source";
private static final String UI = "UI";

public Neo4jGraphService(@Nonnull LineageRegistry lineageRegistry, @Nonnull Driver driver) {
this(lineageRegistry, driver, SessionConfig.defaultConfig());
}
Expand Down Expand Up @@ -234,53 +234,36 @@ public EntityLineageResult getLineage(@Nonnull Urn entityUrn, @Nonnull LineageDi
@Nullable Long endTimeMillis) {
log.debug(String.format("Neo4j getLineage maxHops = %d", maxHops));

final String statement =
generateLineageStatement(entityUrn, direction, graphFilters, maxHops, startTimeMillis, endTimeMillis);
final var statementAndParams =
generateLineageStatementAndParameters(entityUrn, direction, graphFilters, maxHops, startTimeMillis, endTimeMillis);

final var statement = statementAndParams.getFirst();
final var parameters = statementAndParams.getSecond();

List<Record> neo4jResult =
statement != null ? runQuery(buildStatement(statement, new HashMap<>())).list() : new ArrayList<>();

// It is possible to have more than 1 path from node A to node B in the graph and previous query returns all the paths.
// We convert the List into Map with only the shortest paths. "item.get(i).size()" is the path size between two nodes in relation.
// The key for mapping is the destination node as the source node is always the same, and it is defined by parameter.
neo4jResult = neo4jResult.stream()
.collect(Collectors.toMap(item -> item.values().get(2).asNode().get("urn").asString(), Function.identity(),
(item1, item2) -> item1.get(1).size() < item2.get(1).size() ? item1 : item2))
.values()
.stream()
.collect(Collectors.toList());
statement != null ? runQuery(buildStatement(statement, parameters)).list() : new ArrayList<>();

LineageRelationshipArray relations = new LineageRelationshipArray();
neo4jResult.stream().skip(offset).limit(count).forEach(item -> {
String urn = item.values().get(2).asNode().get("urn").asString();
String relationType = ((InternalRelationship) item.get(1).asList().get(0)).type().split("r_")[1];
int numHops = item.get(1).size();
try {
// Generate path from r in neo4jResult
List<Urn> pathFromRelationships =
item.values().get(1).asList(Collections.singletonList(new ArrayList<Node>())).stream().map(t -> createFromString(
// Get real upstream node/downstream node by direction
((InternalRelationship) t).get(direction == LineageDirection.UPSTREAM ? "startUrn" : "endUrn")
.asString())).collect(Collectors.toList());
if (direction == LineageDirection.UPSTREAM) {
// For ui to show path correctly, reverse path for UPSTREAM direction
Collections.reverse(pathFromRelationships);
// Add missing original node to the end since we generate path from relationships
pathFromRelationships.add(Urn.createFromString(item.values().get(0).asNode().get("urn").asString()));
} else {
// Add missing original node to the beginning since we generate path from relationships
pathFromRelationships.add(0, Urn.createFromString(item.values().get(0).asNode().get("urn").asString()));
}
final var path = item.get(1).asPath();
final List<Urn> nodeListAsPath = StreamSupport.stream(
path.nodes().spliterator(), false)
.map(node -> createFromString(node.get("urn").asString()))
.collect(Collectors.toList());

final var firstRelationship = Optional.ofNullable(Iterables.getFirst(path.relationships(), null));

relations.add(new LineageRelationship().setEntity(Urn.createFromString(urn))
.setType(relationType)
.setDegree(numHops)
.setPaths(new UrnArrayArray(new UrnArray(pathFromRelationships))));
// although firstRelationship should never be absent, provide "" as fallback value
.setType(firstRelationship.map(Relationship::type).orElse(""))
.setDegree(path.length())
.setPaths(new UrnArrayArray(new UrnArray(nodeListAsPath))));
} catch (URISyntaxException ignored) {
log.warn(String.format("Can't convert urn = %s, Error = %s", urn, ignored.getMessage()));
}
});

EntityLineageResult result = new EntityLineageResult().setStart(offset)
.setCount(relations.size())
.setRelationships(relations)
Expand All @@ -290,31 +273,104 @@ public EntityLineageResult getLineage(@Nonnull Urn entityUrn, @Nonnull LineageDi
return result;
}

private String generateLineageStatement(@Nonnull Urn entityUrn, @Nonnull LineageDirection direction,
GraphFilters graphFilters, int maxHops, @Nullable Long startTimeMillis, @Nullable Long endTimeMillis) {
String statement;
final String allowedEntityTypes = String.join(" OR b:", graphFilters.getAllowedEntityTypes());

final String multiHopMatchTemplateIndirect = "MATCH p = shortestPath((a {urn: '%s'})<-[r*1..%d]-(b)) ";
final String multiHopMatchTemplateDirect = "MATCH p = shortestPath((a {urn: '%s'})-[r*1..%d]->(b)) ";
// directionFilterTemplate should apply to all condition.
final String multiHopMatchTemplate =
direction == LineageDirection.UPSTREAM ? multiHopMatchTemplateIndirect : multiHopMatchTemplateDirect;
final String fullQueryTemplate = generateFullQueryTemplate(multiHopMatchTemplate, startTimeMillis, endTimeMillis);

if (startTimeMillis != null && endTimeMillis != null) {
statement =
String.format(fullQueryTemplate, startTimeMillis, endTimeMillis, entityUrn, maxHops, allowedEntityTypes,
entityUrn);
} else if (startTimeMillis != null) {
statement = String.format(fullQueryTemplate, startTimeMillis, entityUrn, maxHops, allowedEntityTypes, entityUrn);
} else if (endTimeMillis != null) {
statement = String.format(fullQueryTemplate, endTimeMillis, entityUrn, maxHops, allowedEntityTypes, entityUrn);
private String getPathFindingLabelFilter(List<String> entityNames) {
return entityNames.stream().map(x -> String.format("+%s", x)).collect(Collectors.joining("|"));
}

private String getPathFindingRelationshipFilter(@Nonnull List<String> entityNames, @Nullable LineageDirection direction) {
// relationshipFilter supports mixing different directions for various relation types,
// so simply transform entries lineage registry into format of filter
final var filterComponents = new HashSet<String>();
for (final var entityName : entityNames) {
if (direction != null) {
for (final var edgeInfo : _lineageRegistry.getLineageRelationships(entityName, direction)) {
final var type = edgeInfo.getType();
if (edgeInfo.getDirection() == RelationshipDirection.INCOMING) {
filterComponents.add("<" + type);
} else {
filterComponents.add(type + ">");
}
}
} else {
// return disjunctive combination of edge types regardless of direction
for (final var direction1 : List.of(LineageDirection.UPSTREAM, LineageDirection.DOWNSTREAM)) {
for (final var edgeInfo : _lineageRegistry.getLineageRelationships(entityName, direction1)) {
filterComponents.add(edgeInfo.getType());
}
}
}
}
return String.join("|", filterComponents);
}

private Pair<String, Map<String, Object>> generateLineageStatementAndParameters(
@Nonnull Urn entityUrn, @Nonnull LineageDirection direction,
GraphFilters graphFilters, int maxHops,
@Nullable Long startTimeMillis, @Nullable Long endTimeMillis) {

final var parameterMap = new HashMap<String, Object>(Map.of(
"urn", entityUrn.toString(),
"labelFilter", getPathFindingLabelFilter(graphFilters.getAllowedEntityTypes()),
"relationshipFilter", getPathFindingRelationshipFilter(graphFilters.getAllowedEntityTypes(), direction),
"maxHops", maxHops
));

if (startTimeMillis == null && endTimeMillis == null) {
// if no time filtering required, simply find all expansion paths to other nodes
final var statement = "MATCH (a {urn: $urn}) "
+ "CALL apoc.path.spanningTree(a, { "
+ " relationshipFilter: $relationshipFilter, "
+ " labelFilter: $labelFilter, "
+ " minLevel: 1, "
+ " maxLevel: $maxHops "
+ "}) "
+ "YIELD path "
+ "WITH a, path AS path "
+ "RETURN a, path, last(nodes(path));";
return Pair.of(statement, parameterMap);
} else {
statement = String.format(fullQueryTemplate, entityUrn, maxHops, allowedEntityTypes, entityUrn);
// when needing time filtering, possibility on multiple paths between two
// nodes must be considered, and we need to construct more complex query

// use r_ edges until they are no longer useful
final var relationFilter = getPathFindingRelationshipFilter(graphFilters.getAllowedEntityTypes(), null)
.replaceAll("(\\w+)", "r_$1");
final var relationshipPattern =
String.format(
(direction == LineageDirection.UPSTREAM ? "<-[:%s*1..%d]-" : "-[:%s*1..%d]->"),
relationFilter, maxHops);

// two steps:
// 1. find list of nodes reachable within maxHops
// 2. find the shortest paths from start node to every other node in these nodes
// (note: according to the docs of shortestPath, WHERE conditions are applied during path exploration, not
// after path exploration is done)
final var statement = "MATCH (a {urn: $urn}) "
+ "CALL apoc.path.subgraphNodes(a, { "
+ " relationshipFilter: $relationshipFilter, "
+ " labelFilter: $labelFilter, "
+ " minLevel: 1, "
+ " maxLevel: $maxHops "
+ "}) "
+ "YIELD node AS b "
+ "WITH a, b "
+ "MATCH path = shortestPath((a)" + relationshipPattern + "(b)) "
+ "WHERE a <> b "
+ " AND ALL(rt IN relationships(path) WHERE "
+ " (EXISTS(rt.source) AND rt.source = 'UI') OR "
+ " (NOT EXISTS(rt.createdOn) AND NOT EXISTS(rt.updatedOn)) OR "
+ " ($startTimeMillis <= rt.createdOn <= $endTimeMillis OR "
+ " $startTimeMillis <= rt.updatedOn <= $endTimeMillis) "
+ " ) "
+ "RETURN a, path, b;";

// provide dummy start/end time when not provided, so no need to
// format clause differently if either of them is missing
parameterMap.put("startTimeMillis", startTimeMillis == null ? 0 : startTimeMillis);
parameterMap.put("endTimeMillis", endTimeMillis == null ? System.currentTimeMillis() : endTimeMillis);

return Pair.of(statement, parameterMap);
}

return statement;
}

@Nonnull
Expand Down Expand Up @@ -583,15 +639,6 @@ private Result runQuery(@Nonnull Statement statement) {
}
}

@Nonnull
private static String toCriterionWhereString(@Nonnull String key, @Nonnull Object value) {
if (ClassUtils.isPrimitiveOrWrapper(value.getClass())) {
return key + " = " + value;
}

return key + " = \"" + value.toString() + "\"";
}

// Returns "key:value" String, if value is not primitive, then use toString() and double quote it
@Nonnull
private static String toCriterionString(@Nonnull String key, @Nonnull Object value) {
Expand Down Expand Up @@ -715,44 +762,4 @@ Urn createFromString(@Nonnull String rawUrn) {
return null;
}
}

private String generateFullQueryTemplate(@Nonnull String multiHopMatchTemplate, @Nullable Long startTimeMillis, @Nullable Long endTimeMillis) {
final String sourceUiCheck = String.format("(EXISTS(rt.%s) AND rt.%s = '%s') ", SOURCE, SOURCE, UI);
final String whereTemplate = "WHERE (b:%s) AND b.urn <> '%s' ";
final String returnTemplate = "RETURN a,r,b";
String withTimeTemplate = "";
String timeFilterConditionTemplate = "AND ALL(rt IN relationships(p) WHERE left(type(rt), 2)='r_')";

if (startTimeMillis != null && endTimeMillis != null) {
withTimeTemplate = "WITH %d as startTimeMillis, %d as endTimeMillis ";
timeFilterConditionTemplate =
"AND ALL(rt IN relationships(p) WHERE " + sourceUiCheck + "OR "
+ "(NOT EXISTS(rt.createdOn) AND NOT EXISTS(rt.updatedOn)) OR "
+ "((rt.createdOn >= startTimeMillis AND rt.createdOn <= endTimeMillis) OR "
+ "(rt.updatedOn >= startTimeMillis AND rt.updatedOn <= endTimeMillis))) "
+ "AND ALL(rt IN relationships(p) WHERE left(type(rt), 2)='r_')";
} else if (startTimeMillis != null) {
withTimeTemplate = "WITH %d as startTimeMillis ";
timeFilterConditionTemplate =
"AND ALL(rt IN relationships(p) WHERE " + sourceUiCheck + "OR "
+ "(NOT EXISTS(rt.createdOn) AND NOT EXISTS(rt.updatedOn)) OR "
+ "(rt.createdOn >= startTimeMillis OR rt.updatedOn >= startTimeMillis)) "
+ "AND ALL(rt IN relationships(p) WHERE left(type(rt), 2)='r_')";
} else if (endTimeMillis != null) {
withTimeTemplate = "WITH %d as endTimeMillis ";
timeFilterConditionTemplate =
"AND ALL(rt IN relationships(p) WHERE " + sourceUiCheck + "OR "
+ "(NOT EXISTS(rt.createdOn) AND NOT EXISTS(rt.updatedOn)) OR "
+ "(rt.createdOn <= endTimeMillis OR rt.updatedOn <= endTimeMillis)) "
+ "AND ALL(rt IN relationships(p) WHERE left(type(rt), 2)='r_')";
}
final StringJoiner fullQueryTemplateJoiner = new StringJoiner(" ");
fullQueryTemplateJoiner.add(withTimeTemplate);
fullQueryTemplateJoiner.add(multiHopMatchTemplate);
fullQueryTemplateJoiner.add(whereTemplate);
fullQueryTemplateJoiner.add(timeFilterConditionTemplate);
fullQueryTemplateJoiner.add(returnTemplate);

return fullQueryTemplateJoiner.toString();
}
}

0 comments on commit 4895a55

Please sign in to comment.