From 14c37cbab6d1a4041a1e9482df9b9cbef98fb071 Mon Sep 17 00:00:00 2001
From: Jeffrey Phillips Freeman <jeffrey.freeman@syncleus.com>
Date: Sun, 4 Sep 2011 10:04:05 -0400
Subject: [PATCH] Fixed a bug in the Markov Random Field's joint probability
 function causing it to calculate incorrect results.

---
 .../graphicalmodel/GraphicalModelNode.java    |  1 +
 .../SimpleGraphicalModelNode.java             | 34 ++++++++++++++++
 ...stractMarkovRandomFieldAdjacencyGraph.java | 40 ++++---------------
 3 files changed, 42 insertions(+), 33 deletions(-)

diff --git a/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java b/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java
index b006bccd..075d9174 100644
--- a/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java
+++ b/src/main/java/com/syncleus/dann/graphicalmodel/GraphicalModelNode.java
@@ -29,5 +29,6 @@ public interface GraphicalModelNode<S> extends XmlSerializable<GraphicalModelNod
 	S getState();
 	void learnState();
 	double stateProbability();
+	double stateProbability(Set<? extends GraphicalModelNode> ignoredInfluences);
 	void reset();
 }
diff --git a/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java b/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java
index 5548769c..4e85acdc 100644
--- a/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java
+++ b/src/main/java/com/syncleus/dann/graphicalmodel/SimpleGraphicalModelNode.java
@@ -107,6 +107,40 @@ public class SimpleGraphicalModelNode<S> extends AbstractContextNode<GraphicalMo
 		return ((stateEvidence == null) ? 0.0 : stateEvidence.getPercentage(this.state));
 	}
 
+	@Override
+	public double stateProbability(Set<? extends GraphicalModelNode> ignoredInfluences)
+	{
+		final Set<GraphicalModelNode> influences = new HashSet<GraphicalModelNode>(this.getInfluencingNodes());
+		influences.removeAll(ignoredInfluences);
+
+		int evidenceOccurrence = 0;
+		int totalOccurrence = 0;
+
+		NextEvidence:
+		for(final Map.Entry<Map<GraphicalModelNode, Object>, StateEvidence<S>> evidenceEntry : this.evidence.entrySet())
+		{
+			final Map<GraphicalModelNode, Object> influencingEvidence = evidenceEntry.getKey();
+			for(GraphicalModelNode influence : influences)
+			{
+				final Object influencingEvidenceState = influencingEvidence.get(influence);
+				if( (influencingEvidenceState == null)||(!influencingEvidenceState.equals(influence.getState())) )
+					continue NextEvidence;
+			}
+
+			final StateEvidence<S> evidence = evidenceEntry.getValue();
+
+			final Integer currentEvidenceOccurrence = evidence.get(this.getState());
+			if( currentEvidenceOccurrence != null )
+				evidenceOccurrence += evidence.get(this.getState());
+			totalOccurrence += evidence.getTotalEvidence();
+		}
+
+		if( totalOccurrence == 0 )
+			return 0.0;
+
+		return ((double)evidenceOccurrence) / ((double)totalOccurrence);
+	}
+
 	private Map<GraphicalModelNode, Object> getInfluencingStates()
 	{
 		//TODO change this so it only cares if it has edges to work with and doesnt care what networks its a part of
diff --git a/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java b/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java
index faa4aa40..829c6665 100644
--- a/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java
+++ b/src/main/java/com/syncleus/dann/graphicalmodel/markovrandomfield/AbstractMarkovRandomFieldAdjacencyGraph.java
@@ -43,42 +43,16 @@ public abstract class AbstractMarkovRandomFieldAdjacencyGraph<N extends Graphica
 	@Override
 	public double jointProbability()
 	{
-		//first we need to preserve a map of all the starting states so we can reset the network back to its starting
-		//point when we are done
-		Map<N, Object> startingStates = new HashMap<N, Object>();
-		for(N node : this.getNodes())
-			startingStates.put(node, node.getState());
-
-		try
+		final Set<N> seenNodes = new HashSet<N>();
+		double probabilityProduct = 1.0;
+		for(final N node : this.getNodes())
 		{
-			final Set<N> seenNodes = new HashSet<N>();
-			double probabilityProduct = 1.0;
-			for(final N node : this.getNodes())
-			{
-				assert !seenNodes.contains(node);
-
-				//if none of its neighbors have been seen, then calculate it normally
-				final Set<N> nodesToVary = new HashSet<N>(seenNodes);
-				nodesToVary.retainAll(this.getAdjacentNodes(node));
-				resetNodeStates(nodesToVary);
-				double nodeStateProbability = 0.0;
-				do
-				{
-					nodeStateProbability += node.stateProbability();
-				}
-				while( !incrementNodeStates(nodesToVary) );
+			assert !seenNodes.contains(node);
 
-				seenNodes.add(node);
+			probabilityProduct *= node.stateProbability(seenNodes);
 
-				probabilityProduct *= nodeStateProbability;
-			}
-			return probabilityProduct;
-		}
-		finally
-		{
-			//restore the initial states when we are done
-			for(Map.Entry<N,Object> nodeState : startingStates.entrySet())
-				nodeState.getKey().setState(nodeState.getValue());
+			seenNodes.add(node);
 		}
+		return probabilityProduct;
 	}
 }
\ No newline at end of file
-- 
GitLab