From 14c37cbab6d1a4041a1e9482df9b9cbef98fb071 Mon Sep 17 00:00:00 2001 From: Jeffrey Phillips Freeman <jeffrey.freeman@syncleus.com> Date: Sun, 4 Sep 2011 10:04:05 -0400 Subject: [PATCH] Fixed a bug in the Markov Random Field's joint probability function causing it to calculate incorrect results. --- .../graphicalmodel/GraphicalModelNode.java | 1 + .../SimpleGraphicalModelNode.java | 34 ++++++++++++++++ ...stractMarkovRandomFieldAdjacencyGraph.java | 40 ++++--------------- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java b/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java index b006bccd..075d9174 100644 --- a/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java +++ b/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java @@ -29,5 +29,6 @@ public interface GraphicalModelNode<S> extends XmlSerializable<GraphicalModelNod S getState(); void learnState(); double stateProbability(); + double stateProbability(Set<? extends GraphicalModelNode> ignoredInfluences); void reset(); } diff --git a/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java b/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java index 5548769c..4e85acdc 100644 --- a/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java +++ b/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java @@ -107,6 +107,40 @@ public class SimpleGraphicalModelNode<S> extends AbstractContextNode<GraphicalMo return ((stateEvidence == null) ? 0.0 : stateEvidence.getPercentage(this.state)); } + @Override + public double stateProbability(Set<? extends GraphicalModelNode> ignoredInfluences) + { + final Set<GraphicalModelNode> influences = new HashSet<GraphicalModelNode>(this.getInfluencingNodes()); + influences.removeAll(ignoredInfluences); + + int evidenceOccurrence = 0; + int totalOccurrence = 0; + + NextEvidence: + for(final Map.Entry<Map<GraphicalModelNode, Object>, StateEvidence<S>> evidenceEntry : this.evidence.entrySet()) + { + final Map<GraphicalModelNode, Object> influencingEvidence = evidenceEntry.getKey(); + for(GraphicalModelNode influence : influences) + { + final Object influencingEvidenceState = influencingEvidence.get(influence); + if( (influencingEvidenceState == null)||(!influencingEvidenceState.equals(influence.getState())) ) + continue NextEvidence; + } + + final StateEvidence<S> evidence = evidenceEntry.getValue(); + + final Integer currentEvidenceOccurrence = evidence.get(this.getState()); + if( currentEvidenceOccurrence != null ) + evidenceOccurrence += evidence.get(this.getState()); + totalOccurrence += evidence.getTotalEvidence(); + } + + if( totalOccurrence == 0 ) + return 0.0; + + return ((double)evidenceOccurrence) / ((double)totalOccurrence); + } + private Map<GraphicalModelNode, Object> getInfluencingStates() { //TODO change this so it only cares if it has edges to work with and doesnt care what networks its a part of diff --git a/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java b/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java index faa4aa40..829c6665 100644 --- a/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java +++ b/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java @@ -43,42 +43,16 @@ public abstract class AbstractMarkovRandomFieldAdjacencyGraph<N extends Graphica @Override public double jointProbability() { - //first we need to preserve a map of all the starting states so we can reset the network back to its starting - //point when we are done - Map<N, Object> startingStates = new HashMap<N, Object>(); - for(N node : this.getNodes()) - startingStates.put(node, node.getState()); - - try + final Set<N> seenNodes = new HashSet<N>(); + double probabilityProduct = 1.0; + for(final N node : this.getNodes()) { - final Set<N> seenNodes = new HashSet<N>(); - double probabilityProduct = 1.0; - for(final N node : this.getNodes()) - { - assert !seenNodes.contains(node); - - //if none of its neighbors have been seen, then calculate it normally - final Set<N> nodesToVary = new HashSet<N>(seenNodes); - nodesToVary.retainAll(this.getAdjacentNodes(node)); - resetNodeStates(nodesToVary); - double nodeStateProbability = 0.0; - do - { - nodeStateProbability += node.stateProbability(); - } - while( !incrementNodeStates(nodesToVary) ); + assert !seenNodes.contains(node); - seenNodes.add(node); + probabilityProduct *= node.stateProbability(seenNodes); - probabilityProduct *= nodeStateProbability; - } - return probabilityProduct; - } - finally - { - //restore the initial states when we are done - for(Map.Entry<N,Object> nodeState : startingStates.entrySet()) - nodeState.getKey().setState(nodeState.getValue()); + seenNodes.add(node); } + return probabilityProduct; } } \ No newline at end of file -- GitLab