diff --git a/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java b/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java index d50d55834..a5c9ee62b 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java +++ b/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java @@ -264,8 +264,10 @@ protected void relaxHardConstraints() { for (Rule rule : rules) { if (rule instanceof WeightedRule) { Weight weight = ((WeightedRule)rule).getWeight(); - if (weight.getValue() > largestWeight) { - largestWeight = weight.getValue(); + + if ((!weight.isDeep()) && (1.0f > largestWeight)) { + // 1.0f is the largest possible value of a deep weight. + largestWeight = 1.0f; } } else { hasUnweightedRule = true; diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java b/psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java index 84ab40e09..550c34b23 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java @@ -21,12 +21,14 @@ import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.util.HashCode; +import java.io.Serializable; + /** * A weight for a rule. * A weight is a constant value and can be associated with a GroundAtom. * The value of the weight is the constant value multiplied by the value of the GroundAtom. */ -public class Weight { +public class Weight implements Serializable { private float constantValue; private Atom atom; diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/WeightedArithmeticRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/WeightedArithmeticRule.java index b1ec1c1db..56df81773 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/WeightedArithmeticRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/arithmetic/WeightedArithmeticRule.java @@ -37,7 +37,7 @@ public class WeightedArithmeticRule extends AbstractArithmeticRule implements We protected boolean squared; public WeightedArithmeticRule(ArithmeticRuleExpression expression, Weight weight, boolean squared) { - this(expression, weight, squared, expression.toString()); + this(expression, weight, squared, weight.toString() + ": " + expression.toString()); } public WeightedArithmeticRule(ArithmeticRuleExpression expression, Weight weight, boolean squared, String name) { @@ -74,7 +74,7 @@ protected AbstractGroundArithmeticRule makeGroundRule( ); WeightedArithmeticRule groundedDeepWeightedRule = new WeightedArithmeticRule( - newExpression, groundedWeight, squared, groundedWeight.getAtom().toString() + ": " + name + newExpression, groundedWeight, squared ); groundedDeepWeightedRule.setParentHashCode(hashCode()); diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java index 1230a1691..4c4fade00 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/AbstractLogicalRule.java @@ -72,7 +72,7 @@ protected AbstractLogicalRule(Formula formula, String name, int hashcode) { this.hashcode = hashcode; this.formula = formula; - groundingResourcesKey = AbstractLogicalRule.class.getName() + ";" + formula + ";GroundingResources"; + groundingResourcesKey = AbstractLogicalRule.class.getName() + ";" + name + ";GroundingResources"; // Do the formula analysis so we know what atoms to query for grounding. // We will query for all positive atoms in the negated DNF. diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/GroundingResources.java b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/GroundingResources.java index be9640876..ce7155c48 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/GroundingResources.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/GroundingResources.java @@ -80,7 +80,7 @@ public void parseNegatedDNF(FormulaAnalysis.DNFClause negatedDNF, Weight weight) if ((weight != null) && (weight.isDeep())) { assert (weight.getAtom() instanceof QueryAtom); - weightQueryAtom = (QueryAtom)weight.getAtom(); + weightQueryAtom = (QueryAtom) weight.getAtom(); weightArgumentsBuffer = new Constant[weightQueryAtom.getArity()]; } } diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/UnweightedLogicalRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/UnweightedLogicalRule.java index dfe618a2c..c0dd24103 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/UnweightedLogicalRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/UnweightedLogicalRule.java @@ -56,7 +56,7 @@ protected GroundRule ground(Constant[] constants, Map variabl @Override public WeightedRule relax(Weight weight, boolean squared) { unregister(); - return new WeightedLogicalRule(formula, weight, squared, name); + return new WeightedLogicalRule(formula, weight, squared); } @Override diff --git a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/WeightedLogicalRule.java b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/WeightedLogicalRule.java index a9400a95c..3c1758b97 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/rule/logical/WeightedLogicalRule.java +++ b/psl-core/src/main/java/org/linqs/psl/model/rule/logical/WeightedLogicalRule.java @@ -36,7 +36,7 @@ public class WeightedLogicalRule extends AbstractLogicalRule implements Weighted protected boolean squared; public WeightedLogicalRule(Formula formula, Weight weight, boolean squared) { - this(formula, weight, squared, formula.toString()); + this(formula, weight, squared, weight.toString() + ": " + formula.toString()); } public WeightedLogicalRule(Formula formula, Weight weight, boolean squared, String name) { @@ -67,7 +67,7 @@ protected WeightedGroundLogicalRule groundFormulaInstance(List posLi if (groundedWeight == null) { return new WeightedGroundLogicalRule(this, posLiterals, negLiterals); } else { - WeightedLogicalRule groundedDeepWeightedRule = new WeightedLogicalRule(formula, groundedWeight, squared, groundedWeight.getAtom().toString() + ": " + name); + WeightedLogicalRule groundedDeepWeightedRule = new WeightedLogicalRule(formula, groundedWeight, squared); groundedDeepWeightedRule.setParentHashCode(hashCode()); addChildHashCode(groundedDeepWeightedRule.hashCode()); diff --git a/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java b/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java index 849264753..646fe5c71 100644 --- a/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java +++ b/psl-core/src/main/java/org/linqs/psl/reasoner/term/SimpleTermStore.java @@ -100,6 +100,10 @@ private void mergeAtomComponents(GroundAtom atom1, GroundAtom atom2) { return; } + if (atom1RootIndex == -1 || atom2RootIndex == -1) { + throw new IllegalArgumentException("Atoms must be in the atom store before they can be merged."); + } + GroundAtom atom1Root = atomStore.getAtom(atom1RootIndex); GroundAtom atom2Root = atomStore.getAtom(atom2RootIndex); diff --git a/psl-core/src/test/java/org/linqs/psl/application/inference/InferenceTest.java b/psl-core/src/test/java/org/linqs/psl/application/inference/InferenceTest.java index e12aa3ccd..586567bbb 100644 --- a/psl-core/src/test/java/org/linqs/psl/application/inference/InferenceTest.java +++ b/psl-core/src/test/java/org/linqs/psl/application/inference/InferenceTest.java @@ -435,6 +435,17 @@ public void testSimpleModels() { inference.close(); inferDB.close(); + + // Exogenous model with observed atom weight. + info = TestModel.getExogenousModelWithObservedAtomWeight(); + inferDB = info.dataStore.getDatabase(info.targetPartition, new HashSet(), info.observationPartition); + inference = getInference(info.model.getRules(), inferDB); + + // Test the inference application is able to find the MAP state. + assertEquals(0.0, inference.inference(), 0.1f); + + inference.close(); + inferDB.close(); } /** @@ -459,7 +470,7 @@ public void reasonerEvaluateTest() { @Test public void testAtomWithConstant() { - // Nice(A) & Nice(B) & Friends('Alice', B) && (A != B) -> Friends(A, B) + // 1.0: Nice(A) & Nice(B) & Friends('Alice', B) && (A != B) -> Friends(A, B) info.model.addRule(new WeightedLogicalRule( new Implication( new Conjunction( diff --git a/psl-core/src/test/java/org/linqs/psl/test/TestModel.java b/psl-core/src/test/java/org/linqs/psl/test/TestModel.java index 22b500920..8b9dcc6e1 100644 --- a/psl-core/src/test/java/org/linqs/psl/test/TestModel.java +++ b/psl-core/src/test/java/org/linqs/psl/test/TestModel.java @@ -395,6 +395,127 @@ public static ModelInformation getExogenousModel() { return getModel(DatabaseTestUtil.getH2Driver(), predicates, rules, observations, targets, truths); } + /** + * A model with only a single exogenous rule with a observed atom weight. + * Weight(A, B): Nice(A) & Nice(B) & (A != B) -> Friends(A, B) ^2 + * Such that the weight is 1.0 for all instance of the rule. + */ + public static ModelInformation getExogenousModelWithObservedAtomWeight() { + // Define Predicates + Map predicatesInfo = new HashMap(); + predicatesInfo.put("Nice", new ConstantType[]{ConstantType.UniqueStringID}); + predicatesInfo.put("Friends", new ConstantType[]{ConstantType.UniqueStringID, ConstantType.UniqueStringID}); + predicatesInfo.put("Weight", new ConstantType[]{ConstantType.UniqueStringID, ConstantType.UniqueStringID}); + + Map predicates = new HashMap(); + for (Map.Entry predicateEntry : predicatesInfo.entrySet()) { + StandardPredicate predicate = StandardPredicate.get(predicateEntry.getKey(), predicateEntry.getValue()); + predicates.put(predicateEntry.getKey(), predicate); + } + + // Define Rules + List rules = new ArrayList(); + rules.add(new WeightedLogicalRule( + new Implication( + new Conjunction( + new QueryAtom(predicates.get("Nice"), new Variable("A")), + new QueryAtom(predicates.get("Nice"), new Variable("B")), + new QueryAtom(GroundingOnlyPredicate.NotEqual, new Variable("A"), new Variable("B")) + ), + new QueryAtom(predicates.get("Friends"), new Variable("A"), new Variable("B")) + ), + new Weight(1.0f, new QueryAtom(predicates.get("Weight"), new Variable("A"), new Variable("B"))), + true)); + + // Data + Map> observations = new HashMap>(); + Map> targets = new HashMap>(); + Map> truths = new HashMap>(); + + // Nice + observations.put(predicates.get("Nice"), new ArrayList(Arrays.asList( + new PredicateData(1.0, new Object[]{"Alice"}), + new PredicateData(1.0, new Object[]{"Bob"}), + new PredicateData(1.0, new Object[]{"Charlie"}), + new PredicateData(1.0, new Object[]{"Derek"}), + new PredicateData(1.0, new Object[]{"Eugene"}) + ))); + + // Weight + observations.put(predicates.get("Weight"), new ArrayList(Arrays.asList( + new PredicateData(1.0, new Object[]{"Alice", "Bob"}), + new PredicateData(1.0, new Object[]{"Bob", "Alice"}), + new PredicateData(1.0, new Object[]{"Alice", "Charlie"}), + new PredicateData(1.0, new Object[]{"Charlie", "Alice"}), + new PredicateData(1.0, new Object[]{"Alice", "Derek"}), + new PredicateData(1.0, new Object[]{"Derek", "Alice"}), + new PredicateData(1.0, new Object[]{"Alice", "Eugene"}), + new PredicateData(1.0, new Object[]{"Eugene", "Alice"}), + new PredicateData(1.0, new Object[]{"Bob", "Charlie"}), + new PredicateData(1.0, new Object[]{"Charlie", "Bob"}), + new PredicateData(1.0, new Object[]{"Bob", "Derek"}), + new PredicateData(1.0, new Object[]{"Derek", "Bob"}), + new PredicateData(1.0, new Object[]{"Bob", "Eugene"}), + new PredicateData(1.0, new Object[]{"Eugene", "Bob"}), + new PredicateData(1.0, new Object[]{"Charlie", "Derek"}), + new PredicateData(1.0, new Object[]{"Derek", "Charlie"}), + new PredicateData(1.0, new Object[]{"Charlie", "Eugene"}), + new PredicateData(1.0, new Object[]{"Eugene", "Charlie"}), + new PredicateData(1.0, new Object[]{"Derek", "Eugene"}), + new PredicateData(1.0, new Object[]{"Eugene", "Derek"}) + ))); + + // Friends + targets.put(predicates.get("Friends"), new ArrayList(Arrays.asList( + new PredicateData(new Object[]{"Alice", "Bob"}), + new PredicateData(new Object[]{"Bob", "Alice"}), + new PredicateData(new Object[]{"Alice", "Charlie"}), + new PredicateData(new Object[]{"Charlie", "Alice"}), + new PredicateData(new Object[]{"Alice", "Derek"}), + new PredicateData(new Object[]{"Derek", "Alice"}), + new PredicateData(new Object[]{"Alice", "Eugene"}), + new PredicateData(new Object[]{"Eugene", "Alice"}), + new PredicateData(new Object[]{"Bob", "Charlie"}), + new PredicateData(new Object[]{"Charlie", "Bob"}), + new PredicateData(new Object[]{"Bob", "Derek"}), + new PredicateData(new Object[]{"Derek", "Bob"}), + new PredicateData(new Object[]{"Bob", "Eugene"}), + new PredicateData(new Object[]{"Eugene", "Bob"}), + new PredicateData(new Object[]{"Charlie", "Derek"}), + new PredicateData(new Object[]{"Derek", "Charlie"}), + new PredicateData(new Object[]{"Charlie", "Eugene"}), + new PredicateData(new Object[]{"Eugene", "Charlie"}), + new PredicateData(new Object[]{"Derek", "Eugene"}), + new PredicateData(new Object[]{"Eugene", "Derek"}) + ))); + + truths.put(predicates.get("Friends"), new ArrayList(Arrays.asList( + new PredicateData(1, new Object[]{"Alice", "Bob"}), + new PredicateData(1, new Object[]{"Bob", "Alice"}), + new PredicateData(1, new Object[]{"Alice", "Charlie"}), + new PredicateData(1, new Object[]{"Charlie", "Alice"}), + new PredicateData(1, new Object[]{"Alice", "Derek"}), + new PredicateData(1, new Object[]{"Derek", "Alice"}), + new PredicateData(1, new Object[]{"Alice", "Eugene"}), + new PredicateData(1, new Object[]{"Eugene", "Alice"}), + new PredicateData(1, new Object[]{"Bob", "Charlie"}), + new PredicateData(1, new Object[]{"Charlie", "Bob"}), + new PredicateData(1, new Object[]{"Bob", "Derek"}), + new PredicateData(1, new Object[]{"Derek", "Bob"}), + new PredicateData(0, new Object[]{"Bob", "Eugene"}), + new PredicateData(0, new Object[]{"Eugene", "Bob"}), + new PredicateData(1, new Object[]{"Charlie", "Derek"}), + new PredicateData(1, new Object[]{"Derek", "Charlie"}), + new PredicateData(0, new Object[]{"Charlie", "Eugene"}), + new PredicateData(0, new Object[]{"Eugene", "Charlie"}), + new PredicateData(0, new Object[]{"Derek", "Eugene"}), + new PredicateData(0, new Object[]{"Eugene", "Derek"}) + ))); + + return getModel(DatabaseTestUtil.getH2Driver(), predicates, rules, observations, targets, truths); + } + + /** * A model with only a single symmetry rule. * 10: Person(A) & Person(B) & Friends(A, B) & (A != B) -> Friends(B, A) ^2