Skip to content

Commit

Permalink
Deep arithmetic rule weights checkpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed May 13, 2024
1 parent 504e2be commit db6544f
Show file tree
Hide file tree
Showing 9 changed files with 358 additions and 122 deletions.
2 changes: 0 additions & 2 deletions psl-core/src/main/java/org/linqs/psl/database/ResultList.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.Variable;

import java.util.Map;

/**
* List of substitutions for {@link Variable Variables} in a {@link Formula}.
*/
Expand Down
9 changes: 7 additions & 2 deletions psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.DeepPredicate;

/**
* A weight for a rule.
Expand All @@ -45,7 +46,11 @@ public Weight(float constantValue, Atom atom) {
public float getValue() {
if (atom != null) {
if (!(atom instanceof GroundAtom)) {
throw new IllegalStateException("Called getValue() on " + atom + " before grounding. Atom must be a GroundAtom before it can be used in a Weight.");
throw new IllegalStateException("Called getValue() on weight with atom: " + atom + " before grounding. Atom must be a GroundAtom before it can be used in a Weight.");
}

if (!((GroundAtom) atom).isFixed()) {
throw new IllegalStateException("Called getValue() on weight with non-fixed atom: " + atom + ". Atoms must be fixed (deep or observed) if they are used as weights.");
}

return constantValue * ((GroundAtom)atom).getValue();
Expand All @@ -71,7 +76,7 @@ public Atom getAtom() {
}

/**
* Returns whether the term is constant.
* Returns whether the term is constant or if it is a function of an atom.
*/
public boolean isConstant() {
return atom == null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.linqs.psl.model.rule.AbstractRule;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.Weight;
import org.linqs.psl.model.rule.arithmetic.expression.ArithmeticRuleExpression;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtom;
import org.linqs.psl.model.rule.arithmetic.expression.SummationAtomOrAtom;
Expand Down Expand Up @@ -79,7 +80,7 @@ public abstract class AbstractArithmeticRule extends AbstractRule {
/**
* A key to store per-rule threading grounding resources under.
*/
private final String groundingResourcesKey;
protected final String groundingResourcesKey;

private volatile boolean validatedByDatabase;

Expand Down Expand Up @@ -211,11 +212,13 @@ public boolean equals(Object other) {
return this.filters.equals(otherRule.filters) && this.expression.equals(otherRule.expression);
}

protected abstract AbstractGroundArithmeticRule makeGroundRule(float[] coefficients,
GroundAtom[] atoms, FunctionComparator comparator, float constant);
protected abstract AbstractGroundArithmeticRule makeGroundRule(
float[] coefficients, GroundAtom[] atoms, FunctionComparator comparator, float constant, Weight groundedWeight
);

protected abstract AbstractGroundArithmeticRule makeGroundRule(List<Float> coefficients,
List<GroundAtom> atoms, FunctionComparator comparator, float constant);
protected abstract AbstractGroundArithmeticRule makeGroundRule(
List<Float> coefficients, List<GroundAtom> atoms, FunctionComparator comparator, float constant, Weight groundedWeight
);

@Override
public void getCoreAtoms(Set<Atom> result) {
Expand Down Expand Up @@ -339,21 +342,22 @@ private long groundAllNonSummationRule(TermStore termStore, Database database, G

private void groundSingleNonSummationRule(
Constant[] queryRow, Map<Variable, Integer> variableMap,
Database database, GroundingResources resources) {
Database database, GroundingResources resources
) {
// Ground all the atoms.
for (int atomIndex = 0; atomIndex < resources.groundAtoms.length; atomIndex++) {
QueryAtom atom = resources.queryAtoms.get(atomIndex);

// First, check if this atom is a grounding only atom (and it is valid).
// The semantics are not well defined, but any false result will invalidate this grounding.
// The semantics are not well-defined, but any false result will invalidate this grounding.
if (atom.getPredicate() instanceof GroundingOnlyPredicate) {
float result = ((GroundingOnlyPredicate)atom.getPredicate()).computeValue(atom, variableMap, queryRow);
if (MathUtils.equals(result, 0.0f)) {
return;
}
}

GroundAtom groundAtom = resources.queryAtoms.get(atomIndex).ground(
database, queryRow, variableMap, resources.argumentBuffer[atomIndex], -1.0f);
GroundAtom groundAtom = atom.ground(database, queryRow, variableMap, resources.argumentBuffer[atomIndex], -1.0f);
if (groundAtom == null) {
return;
}
Expand All @@ -365,24 +369,50 @@ private void groundSingleNonSummationRule(
}
}

// Note that unweighed rules will ground an equality, while weighted rules will instead
// Ground the deep weight if it exists.
if (resources.weightQueryAtom != null) {
GroundAtom weightGroundAtom = resources.weightQueryAtom.ground(database, queryRow, variableMap, resources.weightArgumentsBuffer, -1.0f);
if (weightGroundAtom == null) {
return;
}

resources.weightGroundAtom = weightGroundAtom;


if (weightGroundAtom instanceof UnmanagedRandomVariableAtom) {
resources.accessExceptionAtoms.add(weightGroundAtom);
}
}

// Note that unweighted rules will ground an equality, while weighted rules will instead
// ground a largerThan and lessThan.
GroundRule groundRule = null;
Weight groundedWeight = null;
if (resources.weightGroundAtom != null) {
groundedWeight = new Weight(1.0f, resources.weightGroundAtom);
}

if (isWeighted() && FunctionComparator.EQ.equals(expression.getComparator())) {
groundRule = makeGroundRule(resources.coefficients, resources.groundAtoms,
FunctionComparator.GTE, resources.finalCoefficient);
groundRule = makeGroundRule(
resources.coefficients, resources.groundAtoms,
FunctionComparator.GTE, resources.finalCoefficient, groundedWeight
);
if (verifyGroundRule(groundRule, database, resources)) {
resources.groundRules.add(groundRule);
}

groundRule = makeGroundRule(resources.coefficients, resources.groundAtoms,
FunctionComparator.LTE, resources.finalCoefficient);
groundRule = makeGroundRule(
resources.coefficients, resources.groundAtoms,
FunctionComparator.LTE, resources.finalCoefficient, groundedWeight
);
if (verifyGroundRule(groundRule, database, resources)) {
resources.groundRules.add(groundRule);
}
} else {
groundRule = makeGroundRule(resources.coefficients, resources.groundAtoms,
expression.getComparator(), resources.finalCoefficient);
groundRule = makeGroundRule(
resources.coefficients, resources.groundAtoms,
expression.getComparator(), resources.finalCoefficient, groundedWeight
);
if (verifyGroundRule(groundRule, database, resources)) {
resources.groundRules.add(groundRule);
}
Expand Down Expand Up @@ -496,28 +526,57 @@ private void groundSingleSummationRule(
return;
}

// Ground the deep weight if it exists.
if (resources.weightQueryAtom != null) {
GroundAtom weightGroundAtom = resources.weightQueryAtom.ground(database, queryRow, variableMap, resources.weightArgumentsBuffer, -1.0f);
if (weightGroundAtom == null) {
return;
}

resources.weightGroundAtom = weightGroundAtom;


if (weightGroundAtom instanceof UnmanagedRandomVariableAtom) {
resources.accessExceptionAtoms.add(weightGroundAtom);
}
}

// Compute the coefficients.
// and we don't need to pass any substitution information.
for (int i = 0; i < resources.coefficients.length; i++) {
resources.coefficients[i] = resources.flatExpression.getAtomCoefficients().get(i).getValue(resources.summationCounts);
}
resources.finalCoefficient = resources.flatExpression.getFinalCoefficient().getValue(resources.summationCounts);

// Note that unweighed rules will ground an equality, while weighted rules will instead
// Note that unweighted rules will ground an equality, while weighted rules will instead
// ground a largerThan and lessThan.
GroundRule groundRule = null;
Weight groundedWeight = null;
if (resources.weightGroundAtom != null) {
groundedWeight = new Weight(1.0f, resources.weightGroundAtom);
}

if (isWeighted() && FunctionComparator.EQ.equals(resources.flatExpression.getComparator())) {
groundRule = makeGroundRule(resources.coefficients, resources.groundAtoms, FunctionComparator.GTE, resources.finalCoefficient);
groundRule = makeGroundRule(
resources.coefficients, resources.groundAtoms, FunctionComparator.GTE,
resources.finalCoefficient, groundedWeight
);
if (verifyGroundRule(groundRule, database, resources)) {
resources.groundRules.add(groundRule);
}

groundRule = makeGroundRule(resources.coefficients, resources.groundAtoms, FunctionComparator.LTE, resources.finalCoefficient);
groundRule = makeGroundRule(
resources.coefficients, resources.groundAtoms, FunctionComparator.LTE,
resources.finalCoefficient, groundedWeight
);
if (verifyGroundRule(groundRule, database, resources)) {
resources.groundRules.add(groundRule);
}
} else {
groundRule = makeGroundRule(resources.coefficients, resources.groundAtoms, resources.flatExpression.getComparator(), resources.finalCoefficient);
groundRule = makeGroundRule(
resources.coefficients, resources.groundAtoms,
resources.flatExpression.getComparator(), resources.finalCoefficient, groundedWeight
);
if (verifyGroundRule(groundRule, database, resources)) {
resources.groundRules.add(groundRule);
}
Expand Down Expand Up @@ -798,7 +857,7 @@ private GroundingResources prepSummationGroundingResources(Database database) {
true);


resources.parseExpression(flatExpression, false);
internalPrepSummationGroundingResources(flatExpression, resources);

resources.summationDataLoaded = true;
resources.flatExpression = flatExpression;
Expand All @@ -810,6 +869,8 @@ private GroundingResources prepSummationGroundingResources(Database database) {
return resources;
}

protected abstract void internalPrepSummationGroundingResources(ArithmeticRuleExpression flatExpression, GroundingResources resources);

/**
* Query the database for the possible replacements for summation variables.
*/
Expand Down Expand Up @@ -939,91 +1000,5 @@ private ResultList fetchSummationValues(Database database, SummationVariable var
return database.executeSQL(new RawQuery(query.validate().toString(), projectionMap, variableTypes));
}

private GroundingResources getGroundingResources(ArithmeticRuleExpression expression) {
GroundingResources resources = null;
if (!Parallel.hasThreadObject(groundingResourcesKey)) {
resources = new GroundingResources();

if (expression != null) {
resources.parseExpression(expression, !hasSummation());
}

Parallel.putThreadObject(groundingResourcesKey, resources);
} else {
resources = (GroundingResources)Parallel.getThreadObject(groundingResourcesKey);
}

return resources;
}

/**
* Resources that every grounding thread will use and reuse.
*/
private static class GroundingResources {
// Because multiple ground rules can be generated from a single rule,
// we need a place to hold onto ground rules until we pass them back.
public List<GroundRule> groundRules;

// Atoms that cause trouble for the atom manager.
public Set<GroundAtom> accessExceptionAtoms;

// Shared resources.

public List<QueryAtom> queryAtoms;
public GroundAtom[] groundAtoms;
public Constant[][] argumentBuffer;
public float[] coefficients;
public float finalCoefficient;

// More resources necessary for summations.

public boolean summationDataLoaded;

// The constext expression with all summation variables expanded.
public ArithmeticRuleExpression flatExpression;

// The maximum counts of all summation variable replacements.
public Map<SummationVariable, Integer> totalSummationCounts;

// A buffer for counting actual replacements.
// If we filter out an atom, we can mark it here.
// This will allow us to make accurate coefficient computations.
public Map<SummationVariable, Integer> summationCounts;

// A marker for every variables that shows which are summation variables.
public List<SummationVariable[]> flatSummationVariables;

// True for each summation atom.
boolean[] flatSummationAtoms;

public GroundingResources() {
groundRules = new ArrayList<GroundRule>();
accessExceptionAtoms = new HashSet<GroundAtom>(4);
}

public void parseExpression(ArithmeticRuleExpression expression, boolean computeCoefficients) {
queryAtoms = new ArrayList<QueryAtom>();
for (SummationAtomOrAtom atom : expression.getAtoms()) {
queryAtoms.add((QueryAtom)atom);
}

groundAtoms = new GroundAtom[queryAtoms.size()];

argumentBuffer = new Constant[queryAtoms.size()][];
for (int i = 0; i < queryAtoms.size(); i++) {
argumentBuffer[i] = new Constant[queryAtoms.get(i).getArity()];
}

coefficients = new float[queryAtoms.size()];
finalCoefficient = 0.0f;

if (computeCoefficients) {
for (int i = 0; i < coefficients.length; i++) {
coefficients[i] = expression.getAtomCoefficients().get(i).getValue(null);
}

finalCoefficient = expression.getFinalCoefficient().getValue(null);
}
}
}
protected abstract GroundingResources getGroundingResources(ArithmeticRuleExpression expression);
}
Loading

0 comments on commit db6544f

Please sign in to comment.