diff --git a/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java b/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java index 61e06dc8722fe0ce70f773feb6214e1e6f6789ed..99c1e70ea780e2e32ae49871ea24f3122a6bfa94 100644 --- a/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java +++ b/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java @@ -11,7 +11,7 @@ public abstract class AbstractActivationNeuron implements ActivationNeuron { @Initializer public void init() { - this.setActivationFunctionClass(SineActivationFunction.class); + this.setActivationFunctionClass(HyperbolicTangentActivationFunction.class); this.setActivity(0.0); } diff --git a/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java b/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java index f3988655f27dcbca5ac4f3fd825703d6c175973f..eff2694d2c8f8477c8a47a6ed9a684d37e17d1fa 100644 --- a/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java +++ b/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java @@ -11,7 +11,7 @@ import java.util.*; public class ActionTriggerXor3InputTest { - private static final ActivationFunction ACTIVATION_FUNCTION = new SineActivationFunction(); + private static final ActivationFunction ACTIVATION_FUNCTION = new HyperbolicTangentActivationFunction(); @Test public void testXor() { @@ -107,59 +107,45 @@ public class ActionTriggerXor3InputTest { // // Graph is constructed, just need to train and test our network now. // - int t, maxCycles = 2000; - int completionPeriod = 50; - double maxError = 0.1; - for (t = maxCycles; t >= 0; t--) { + final int maxCycles = 10000; + final int completionPeriod = 50; + final double maxError = 0.75; + for (int cycle = maxCycles; cycle >= 0; cycle--) { int finished = 0; - for (int i = -1; i <= 1; i += 2) { - for (int j = -1; j <= 1; j += 2) { - for (int k = -1; k <= 1; k += 2) { - boolean bi = i >= 0; - boolean bj = j >= 0; - boolean bk = k >= 0; + for (int in1 = -1; in1 <= 1; in1 += 2) { + for (int in2 = -1; in2 <= 1; in2 += 2) { + for (int in3 = -1; in3 <= 1; in3 += 2) { + boolean bi = in1 >= 0; + boolean bj = in2 >= 0; + boolean bk = in3 >= 0; boolean expect = bi ^ bj ^ bk; - double expectD = expect ? +1 : -1; + double expectD = expect ? 1.0 : -1.0; - train(graph, i, j, k, expectD); + train(graph, in1, in2, in3, expectD); - - if (t % completionPeriod == 0 && - calculateError(graph, i, j, k, expectD) < maxError) { - finished++; + if (cycle % completionPeriod == 0 && calculateError(graph, in1, in2, in3, expectD) < maxError) { + finished++; } } } } - if (finished == 8) break; + if (finished == 8) + break; } - //System.out.println("Cycles: " + (maxCycles - t)); - -// for(int i = 0; i < 10000; i++) { -// ActionTriggerXor3InputTest.train(graph, 1.0, 1.0, 1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, 1.0, 1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, 1.0, -1.0, 1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, 1.0, 1.0, -1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, -1.0, 1.0, 1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, 1.0, -1.0, 1.0); -// ActionTriggerXor3InputTest.train(graph, 1.0, -1.0, -1.0, 1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, -1.0, -1.0, -1.0); -// if( i%50 == 0 && ActionTriggerXor3InputTest.calculateError(graph) < 0.1 ) -// break; -// } -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, 1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, 1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, 1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, -1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, 1.0) > 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, -1.0) > 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, -1.0) > 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, -1.0) < 0.0); + + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, 1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, 1.0) < 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, 1.0) < 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, -1.0) < 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, 1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, -1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, -1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, -1.0) < 0.0); } - private static double calculateError(FramedTransactionalGraph<?> graph, double in0, double in1, double in2, double expect) { - double actual = ActionTriggerXor3InputTest.propagate(graph, in0, in1, in2); - return Math.abs(actual - expect) / Math.abs(actual); + private static double calculateError(FramedTransactionalGraph<?> graph, double in1, double in2, double in3, double expect) { + double actual = ActionTriggerXor3InputTest.propagate(graph, in1, in2, in3); + return Math.abs(actual - expect) / Math.abs(expect); } private static void train(final FramedTransactionalGraph<?> graph, final double input1, final double input2, final double input3, final double expected) {