From c8d30e37a3627f5738acc85ee9f38f80cac655b3 Mon Sep 17 00:00:00 2001 From: Jeffrey Phillips Freeman <jeffrey.freeman@syncleus.com> Date: Mon, 10 Nov 2014 06:02:00 -0500 Subject: [PATCH] Cleaned up the code from the last commit, removed unused code and improved the variable names. Issue: GRL-7 Change-Id: Iafe7a4a9b9e12e4e5b08183e2c9deb81b0c4df52 --- .../neural/AbstractActivationNeuron.java | 2 +- .../backprop/ActionTriggerXor3InputTest.java | 72 ++++++++----------- 2 files changed, 30 insertions(+), 44 deletions(-) diff --git a/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java b/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java index 61e06dc..99c1e70 100644 --- a/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java +++ b/src/main/java/com/syncleus/grail/neural/AbstractActivationNeuron.java @@ -11,7 +11,7 @@ public abstract class AbstractActivationNeuron implements ActivationNeuron { @Initializer public void init() { - this.setActivationFunctionClass(SineActivationFunction.class); + this.setActivationFunctionClass(HyperbolicTangentActivationFunction.class); this.setActivity(0.0); } diff --git a/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java b/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java index f398865..eff2694 100644 --- a/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java +++ b/src/test/java/com/syncleus/grail/neural/backprop/ActionTriggerXor3InputTest.java @@ -11,7 +11,7 @@ import java.util.*; public class ActionTriggerXor3InputTest { - private static final ActivationFunction ACTIVATION_FUNCTION = new SineActivationFunction(); + private static final ActivationFunction ACTIVATION_FUNCTION = new HyperbolicTangentActivationFunction(); @Test public void testXor() { @@ -107,59 +107,45 @@ public class ActionTriggerXor3InputTest { // // Graph is constructed, just need to train and test our network now. // - int t, maxCycles = 2000; - int completionPeriod = 50; - double maxError = 0.1; - for (t = maxCycles; t >= 0; t--) { + final int maxCycles = 10000; + final int completionPeriod = 50; + final double maxError = 0.75; + for (int cycle = maxCycles; cycle >= 0; cycle--) { int finished = 0; - for (int i = -1; i <= 1; i += 2) { - for (int j = -1; j <= 1; j += 2) { - for (int k = -1; k <= 1; k += 2) { - boolean bi = i >= 0; - boolean bj = j >= 0; - boolean bk = k >= 0; + for (int in1 = -1; in1 <= 1; in1 += 2) { + for (int in2 = -1; in2 <= 1; in2 += 2) { + for (int in3 = -1; in3 <= 1; in3 += 2) { + boolean bi = in1 >= 0; + boolean bj = in2 >= 0; + boolean bk = in3 >= 0; boolean expect = bi ^ bj ^ bk; - double expectD = expect ? +1 : -1; + double expectD = expect ? 1.0 : -1.0; - train(graph, i, j, k, expectD); + train(graph, in1, in2, in3, expectD); - - if (t % completionPeriod == 0 && - calculateError(graph, i, j, k, expectD) < maxError) { - finished++; + if (cycle % completionPeriod == 0 && calculateError(graph, in1, in2, in3, expectD) < maxError) { + finished++; } } } } - if (finished == 8) break; + if (finished == 8) + break; } - //System.out.println("Cycles: " + (maxCycles - t)); - -// for(int i = 0; i < 10000; i++) { -// ActionTriggerXor3InputTest.train(graph, 1.0, 1.0, 1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, 1.0, 1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, 1.0, -1.0, 1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, 1.0, 1.0, -1.0, -1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, -1.0, 1.0, 1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, 1.0, -1.0, 1.0); -// ActionTriggerXor3InputTest.train(graph, 1.0, -1.0, -1.0, 1.0); -// ActionTriggerXor3InputTest.train(graph, -1.0, -1.0, -1.0, -1.0); -// if( i%50 == 0 && ActionTriggerXor3InputTest.calculateError(graph) < 0.1 ) -// break; -// } -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, 1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, 1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, 1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, -1.0) < 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, 1.0) > 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, -1.0) > 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, -1.0) > 0.0); -// Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, -1.0) < 0.0); + + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, 1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, 1.0) < 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, 1.0) < 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, 1.0, -1.0) < 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, 1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, 1.0, -1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, 1.0, -1.0, -1.0) > 0.0); + Assert.assertTrue(ActionTriggerXor3InputTest.propagate(graph, -1.0, -1.0, -1.0) < 0.0); } - private static double calculateError(FramedTransactionalGraph<?> graph, double in0, double in1, double in2, double expect) { - double actual = ActionTriggerXor3InputTest.propagate(graph, in0, in1, in2); - return Math.abs(actual - expect) / Math.abs(actual); + private static double calculateError(FramedTransactionalGraph<?> graph, double in1, double in2, double in3, double expect) { + double actual = ActionTriggerXor3InputTest.propagate(graph, in1, in2, in3); + return Math.abs(actual - expect) / Math.abs(expect); } private static void train(final FramedTransactionalGraph<?> graph, final double input1, final double input2, final double input3, final double expected) { -- GitLab