diff --git a/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChain.java b/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChain.java index a57c01cecf4915579783920bf25973206217cad9..4cb3535087ff27d868cd47c42ceb318c521dd183 100644 --- a/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChain.java +++ b/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChain.java @@ -51,66 +51,11 @@ public class SimpleMarkovChain<S> extends AbstractMarkovChain<S> this.states = Collections.unmodifiableSet(states); this.rowMapping = new ArrayList<List<S>>(); -// final int rows = transitionProbabilities.size(); -// final int columns = this.states.size(); final int columns = (this.states.size() > transitionProbabilities.size() ? this.states.size() : transitionProbabilities.size()); final int rows = columns; final double[][] matrixValues = new double[rows][columns]; - //this.columnMapping = new ArrayList<S>(this.states); // <-- did not get the states in the correct order - -/* - //Generate the column mapping from the first element of each transition probability's entry key (which is a list of states) - //iterates through column mapping, matching rowheading leading elements - this.columnMapping = new ArrayList<S>(); - for(final Entry<List<S>, Map<S, Double>> transitionProbability : transitionProbabilities.entrySet()) - { - final List<S> rowHeader = Collections.unmodifiableList(new ArrayList<S>(transitionProbability.getKey())); - - S nextColumn; - if (rowHeader.isEmpty()) - { - continue; - //nextColumn = null; - } - else - { - nextColumn = rowHeader.get(0); - } - - if (!columnMapping.contains(nextColumn)) - { - columnMapping.add(nextColumn); - } - } - - //iterate through all the new rows - int row = 0; - for(final Entry<List<S>, Map<S, Double>> transitionProbability : transitionProbabilities.entrySet()) - { - final List<S> rowHeader = Collections.unmodifiableList(new ArrayList<S>(transitionProbability.getKey())); - final Map<S, Double> rowTransition = Collections.unmodifiableMap(new LinkedHashMap<S, Double>(transitionProbability.getValue())); - - assert !rowMapping.contains(rowHeader); - - this.rowMapping.add(rowHeader); - - double rowSum = 0.0; - for(final Entry<S, Double> stateTransition : rowTransition.entrySet()) - { - final int column = this.columnMapping.indexOf(stateTransition.getKey()); - matrixValues[row][column] = stateTransition.getValue(); - rowSum += matrixValues[row][column]; - } - - if( Math.abs(rowSum - 1.0) > MAXIMUM_ROW_ERROR ) - throw new IllegalArgumentException("One of the rows does not sum to 1"); - - row++; - } -*/ - this.columnMapping = new ArrayList<S>(this.states); @@ -291,21 +236,12 @@ public class SimpleMarkovChain<S> extends AbstractMarkovChain<S> } final RealMatrix simultaneousMatrix = new SimpleRealMatrix(simultaneousValues); -System.out.println(); -System.out.println("transitionProbabilityMatrix matrix:\n" + transitionProbabilityMatrix.toString()); -//System.out.println("steadyState matrix:\n" + steadyStateMatrix.toString()); -System.out.println("simultaneous matrix:\n" + simultaneousMatrix.toString()); - final double[][] solutionValues = new double[simultaneousValues.length][1]; solutionValues[simultaneousValues.length - 1][0] = 1.0; final RealMatrix solutionMatrix = new SimpleRealMatrix(solutionValues); -System.out.println("solution matrix:\n" + solutionMatrix.toString()); - final RealMatrix simultaneousSolved = simultaneousMatrix.solve(solutionMatrix); -System.out.println("simultaneous solved:\n" + simultaneousSolved.toString()); - final Map<S, Double> stateProbabilities = new LinkedHashMap<S, Double>(); for(int stateIndex = 0; stateIndex < this.columnMapping.size(); stateIndex++) { @@ -314,9 +250,6 @@ System.out.println("simultaneous solved:\n" + simultaneousSolved.toString()); stateProbabilities.put(currentState, currentProbability); } - //System.out.println("xstate probabilities:\n" + stateProbabilities); - //System.out.println("---"); - return Collections.unmodifiableMap(stateProbabilities); } } diff --git a/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChainEvidence.java b/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChainEvidence.java index ab1468370553c273c6bc33b86325744e727ad58a..72381dd6e12bed4c46723a546350a3edbae04567 100644 --- a/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChainEvidence.java +++ b/src/main/java/com/syncleus/dann/math/statistics/SimpleMarkovChainEvidence.java @@ -122,28 +122,6 @@ public class SimpleMarkovChainEvidence<S> implements MarkovChainEvidence<S> for(Map.Entry<List<S>, StateCounter<S>> countEntry : this.evidence.entrySet()) transitionProbabilities.put(countEntry.getKey(), countEntry.getValue().probabilities()); -System.out.println("all influences:"); -for(List<S> influences : transitionProbabilities.keySet()) -{ - System.out.print(influences.size() + " influences: "); - for( S influence : influences ) - { - if( influence != null ) - System.out.print(influence + " "); - else - System.out.print("null "); - } - System.out.print(" -> "); - -// StateCounter<S> counter = this.evidence.get(influences); - Map<S,Double> probabilities = transitionProbabilities.get(influences); - for( Map.Entry<S,Double> probabilityEntry : probabilities.entrySet()) - { - System.out.print(probabilityEntry.getKey() + ":" + probabilityEntry.getValue() + " "); - } - System.out.println(); -} - return new SimpleMarkovChain<S>(transitionProbabilities, this.order, this.observedStates); } diff --git a/src/test/java/com/syncleus/dann/math/statistics/TestSimpleMarkovChain.java b/src/test/java/com/syncleus/dann/math/statistics/TestSimpleMarkovChain.java index 9fd792f6a2684c05f73768bde9ee03cda9def910..624375b9c1be9571ff4191c45208b78073cf6066 100644 --- a/src/test/java/com/syncleus/dann/math/statistics/TestSimpleMarkovChain.java +++ b/src/test/java/com/syncleus/dann/math/statistics/TestSimpleMarkovChain.java @@ -85,7 +85,6 @@ public class TestSimpleMarkovChain @Test public void testExplicitChainSecondOrder() { -System.out.println("===ESO begining==="); final Map<List<WeatherState>, Map<WeatherState, Double>> transitionProbabilities = new HashMap<List<WeatherState>, Map<WeatherState, Double>>(); /* @@ -96,7 +95,6 @@ System.out.println("===ESO begining==="); transitionProbabilities.put(initialState, initialTransitions); */ - final List<WeatherState> sunnyState = new ArrayList<WeatherState>(); sunnyState.add(WeatherState.SUNNY); final Map<WeatherState, Double> sunnyTransitions = new TreeMap<WeatherState, Double>(); @@ -155,15 +153,13 @@ System.out.println("===ESO begining==="); LOGGER.info("transition rows: " + simpleChain.getTransitionProbabilityRows()); LOGGER.info("transition matrix: " + simpleChain.getTransitionProbabilityMatrix()); LOGGER.info("steady state: " + simpleChain.getSteadyStateProbability(WeatherState.SUNNY) + " , " + simpleChain.getSteadyStateProbability(WeatherState.RAINY)); -System.out.println("ESO testing sunny steady:"); + Assert.assertEquals("Sunny steady state incorrect", 0.83333333333, Math.abs(simpleChain.getSteadyStateProbability(WeatherState.SUNNY)), 0.001); -System.out.println("ESO sunny test done"); Assert.assertEquals("Rainy steady state incorrect", 0.16666666666, Math.abs(simpleChain.getSteadyStateProbability(WeatherState.RAINY)), 0.001); Assert.assertEquals("Sunny 1 step incorrect", 0.9, Math.abs(simpleChain.getProbability(WeatherState.SUNNY, 1)), 0.001); Assert.assertEquals("Rainy 1 step incorrect", 0.1, Math.abs(simpleChain.getProbability(WeatherState.RAINY, 1)), 0.001); Assert.assertEquals("Sunny 2 step incorrect", 0.86, Math.abs(simpleChain.getProbability(WeatherState.SUNNY, 2)), 0.001); Assert.assertEquals("Rainy 2 step incorrect", 0.14, Math.abs(simpleChain.getProbability(WeatherState.RAINY, 2)), 0.001); -System.out.println("===ESO ending==="); } @Test @@ -275,27 +271,4 @@ System.out.println("===ESO ending==="); Assert.assertEquals("Sunny 2 step incorrect", 0.86, Math.abs(simpleChain.getProbability(WeatherState.SUNNY, 2)), 0.025); Assert.assertEquals("Rainy 2 step incorrect", 0.14, Math.abs(simpleChain.getProbability(WeatherState.RAINY, 2)), 0.025); } - - @Test - public void testFailProbability() - { - final int RUNS = 10; - int failures = 0; - for(int run = 0; run < RUNS; run++) - { - try - { - testExplicitChainFirstOrder(); - testExplicitChainSecondOrder(); - testSampledChainFirstOrder(); - testSampledChainSecondOrder(); - } - catch (java.lang.AssertionError err) - { - err.printStackTrace(); - failures++; - } - } - Assert.assertTrue("testSampledChainSecondOrder - failed runs: " + failures + "/" + RUNS + " -> " + ((double)failures / RUNS), failures == 0); - } }