From 82b241a03f5ef13b38b9123b2f19d8784f393a71 Mon Sep 17 00:00:00 2001
From: CoreRasurae <luis.p.mendes@gmail.com>
Date: Fri, 27 Apr 2018 12:57:46 +0100
Subject: [PATCH] Fix: Fixes issue #118 and improves OpenCLDevice.configure()
 exception handling

---
 .../java/com/aparapi/device/OpenCLDevice.java |  20 ++-
 .../runtime/AtomicsSupportAdvTest.java        |   6 +
 .../aparapi/runtime/AtomicsSupportTest.java   |   6 +
 .../aparapi/runtime/BarrierSupportTest.java   |   6 +
 .../MultiDimensionalLocalArrayTest.java       |   6 +
 .../aparapi/runtime/NegativeIntegerTest.java  |   6 +
 .../runtime/OpenCLDeviceConfiguratorTest.java | 155 +++++++++++++++++-
 .../runtime/OriginalKernelManager.java        |  25 +++
 .../ProfileReportBackwardsCompatTest.java     |  12 +-
 .../runtime/ProfileReportNewAPITest.java      |   7 +
 src/test/java/com/aparapi/runtime/Util.java   |   8 +-
 11 files changed, 244 insertions(+), 13 deletions(-)
 create mode 100644 src/test/java/com/aparapi/runtime/OriginalKernelManager.java

diff --git a/src/main/java/com/aparapi/device/OpenCLDevice.java b/src/main/java/com/aparapi/device/OpenCLDevice.java
index 56cd1760..c962ba8f 100644
--- a/src/main/java/com/aparapi/device/OpenCLDevice.java
+++ b/src/main/java/com/aparapi/device/OpenCLDevice.java
@@ -28,7 +28,10 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.logging.Level;
+import java.util.logging.Logger;
 
+import com.aparapi.Config;
 import com.aparapi.Range;
 import com.aparapi.internal.opencl.OpenCLArgDescriptor;
 import com.aparapi.internal.opencl.OpenCLKernel;
@@ -46,6 +49,7 @@ import com.aparapi.opencl.OpenCL.Resource;
 import com.aparapi.opencl.OpenCL.Source;
 
 public class OpenCLDevice extends Device implements Comparable<Device> {
+   private static Logger logger = Logger.getLogger(Config.getLoggerName()); 	
 	
    private static IOpenCLDeviceConfigurator configurator = null;
 
@@ -137,6 +141,15 @@ public class OpenCLDevice extends Device implements Comparable<Device> {
     this.name = name;
   }
 
+  private static void configuratorWrapper(final IOpenCLDeviceConfigurator configurator, final OpenCLDevice device) {
+	  try {
+		  configurator.configure(device);
+	  } catch (Throwable ex) {
+		  logger.log(Level.WARNING, "Failed to configure device - Id: " + device.deviceId + 
+				  ", Name: " + device.getName(), ex);
+	  }
+  }
+  
   /**
    * Called by the underlying Aparapi OpenCL platform, upon device
    * detection.
@@ -144,8 +157,11 @@ public class OpenCLDevice extends Device implements Comparable<Device> {
   public void configure() {
 	  if (configurator != null && !underConfiguration.get() &&
 			  underConfiguration.compareAndSet(false, true)) {
-		 configurator.configure(this);
-		 underConfiguration.set(false);
+		 try {
+			 configuratorWrapper(configurator, this);
+		 } finally {
+			 underConfiguration.set(false);
+		 }
 	  }
   }
   
diff --git a/src/test/java/com/aparapi/runtime/AtomicsSupportAdvTest.java b/src/test/java/com/aparapi/runtime/AtomicsSupportAdvTest.java
index 9a15ae58..e9daf33a 100644
--- a/src/test/java/com/aparapi/runtime/AtomicsSupportAdvTest.java
+++ b/src/test/java/com/aparapi/runtime/AtomicsSupportAdvTest.java
@@ -24,6 +24,7 @@ import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -75,6 +76,11 @@ public class AtomicsSupportAdvTest {
         openCLDevice = (OpenCLDevice) device;
     }
 
+    @AfterClass
+    public static void classTeardown() {
+    	Util.resetKernelManager();
+    }
+    
     @Test
     public void testOpenCLExplicit() {
     	final int in[] = new int[SIZE];
diff --git a/src/test/java/com/aparapi/runtime/AtomicsSupportTest.java b/src/test/java/com/aparapi/runtime/AtomicsSupportTest.java
index d6e4b2a5..ef9bd22f 100644
--- a/src/test/java/com/aparapi/runtime/AtomicsSupportTest.java
+++ b/src/test/java/com/aparapi/runtime/AtomicsSupportTest.java
@@ -23,6 +23,7 @@ import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -60,6 +61,11 @@ public class AtomicsSupportTest {
     	}
     }
     
+    @AfterClass
+    public static void classTeardown() {
+    	Util.resetKernelManager();
+    }
+    
     @Before
     public void setUpBeforeClass() throws Exception {
     	KernelManager.setKernelManager(new CLKernelManager());
diff --git a/src/test/java/com/aparapi/runtime/BarrierSupportTest.java b/src/test/java/com/aparapi/runtime/BarrierSupportTest.java
index c0403aba..b6b65c67 100644
--- a/src/test/java/com/aparapi/runtime/BarrierSupportTest.java
+++ b/src/test/java/com/aparapi/runtime/BarrierSupportTest.java
@@ -29,6 +29,7 @@ import java.util.Arrays;
 import java.util.LinkedHashSet;
 import java.util.List;
 
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -56,6 +57,11 @@ public class BarrierSupportTest {
     	}
     }
 
+    @AfterClass
+    public static void classTeardown() {
+    	Util.resetKernelManager();
+    }
+    
     @Before
     public void setUpBefore() throws Exception {
     	KernelManager.setKernelManager(new CLKernelManager());
diff --git a/src/test/java/com/aparapi/runtime/MultiDimensionalLocalArrayTest.java b/src/test/java/com/aparapi/runtime/MultiDimensionalLocalArrayTest.java
index 8cc14518..7eb5461c 100644
--- a/src/test/java/com/aparapi/runtime/MultiDimensionalLocalArrayTest.java
+++ b/src/test/java/com/aparapi/runtime/MultiDimensionalLocalArrayTest.java
@@ -22,6 +22,7 @@ import java.util.Arrays;
 import java.util.LinkedHashSet;
 import java.util.List;
 
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -59,6 +60,11 @@ public class MultiDimensionalLocalArrayTest
     		return Arrays.asList(Device.TYPE.JTP);
     	}
     }
+
+    @AfterClass
+    public static void classTeardown() {
+    	Util.resetKernelManager();
+    }
     
     @Before
     public void setUpBeforeClass() throws Exception {
diff --git a/src/test/java/com/aparapi/runtime/NegativeIntegerTest.java b/src/test/java/com/aparapi/runtime/NegativeIntegerTest.java
index 1575cdf7..183bc772 100644
--- a/src/test/java/com/aparapi/runtime/NegativeIntegerTest.java
+++ b/src/test/java/com/aparapi/runtime/NegativeIntegerTest.java
@@ -21,6 +21,7 @@ import static org.junit.Assume.assumeTrue;
 import java.util.Arrays;
 import java.util.List;
 
+import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -58,6 +59,11 @@ public class NegativeIntegerTest
         openCLDevice = (OpenCLDevice) device;
     }
     
+    @AfterClass
+    public static void classTeardown() {
+    	Util.resetKernelManager();
+    }
+    
     @Test
     public void negativeIntegerTestPass()
     {
diff --git a/src/test/java/com/aparapi/runtime/OpenCLDeviceConfiguratorTest.java b/src/test/java/com/aparapi/runtime/OpenCLDeviceConfiguratorTest.java
index 942dd2ef..16c4dc00 100644
--- a/src/test/java/com/aparapi/runtime/OpenCLDeviceConfiguratorTest.java
+++ b/src/test/java/com/aparapi/runtime/OpenCLDeviceConfiguratorTest.java
@@ -17,19 +17,26 @@ package com.aparapi.runtime;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assume.assumeTrue;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import org.junit.After;
 import org.junit.Test;
 
 import com.aparapi.device.Device;
 import com.aparapi.device.IOpenCLDeviceConfigurator;
+import com.aparapi.device.JavaDevice;
 import com.aparapi.device.OpenCLDevice;
 import com.aparapi.internal.kernel.KernelManager;
+import com.aparapi.internal.kernel.KernelPreferences;
 import com.aparapi.internal.opencl.OpenCLPlatform;
 
 /**
@@ -39,16 +46,99 @@ import com.aparapi.internal.opencl.OpenCLPlatform;
  */
 public class OpenCLDeviceConfiguratorTest {
     private static OpenCLDevice openCLDevice = null;
+    private final AtomicInteger callCounter = new AtomicInteger(0);
+    
+	public static List<OpenCLDevice> listDevices(OpenCLDevice.TYPE type) {
+		final ArrayList<OpenCLDevice> results = new ArrayList<>();
 
-    private class CLKernelManager extends KernelManager {
-    	@Override
-    	protected List<Device.TYPE> getPreferredDeviceTypes() {
-    		return Arrays.asList(Device.TYPE.ACC, Device.TYPE.GPU, Device.TYPE.CPU);
-    	}
-    }
+		for (final OpenCLPlatform p : OpenCLPlatform.getUncachedOpenCLPlatforms()) {
+			for (final OpenCLDevice device : p.getOpenCLDevices()) {
+				if (type == null || device.getType() == type) {
+					results.add(device);
+				}
+			}
+		}
+
+		return results;
+	}
+
+	private class UncachedCLKernelManager extends KernelManager {
+		private KernelPreferences defaultPreferences;
+		
+		@Override
+		protected void setup() {
+			callCounter.set(0);
+			defaultPreferences = createDefaultPreferences();
+		}
+		
+		@Override
+		public KernelPreferences getDefaultPreferences() {
+			return defaultPreferences;
+		}
+	
+		private List<OpenCLDevice> filter(OpenCLDevice.TYPE type, List<OpenCLDevice> devices) {
+			final ArrayList<OpenCLDevice> results = new ArrayList<>();
+
+			for (final OpenCLDevice device : devices) {
+				if (type == null || device.getType() == type) {
+					results.add(device);
+				}
+			}
+
+			return results;
+		}
+		
+		@Override
+		protected LinkedHashSet<Device> createDefaultPreferredDevices() {
+			LinkedHashSet<Device> devices = new LinkedHashSet<>();
+
+			List<OpenCLDevice> all = listDevices(null);
+			
+			List<OpenCLDevice> accelerators = filter(Device.TYPE.ACC, all);
+			List<OpenCLDevice> gpus = filter(Device.TYPE.GPU, all);
+			List<OpenCLDevice> cpus = filter(Device.TYPE.CPU, all);
+
+			Collections.sort(accelerators, getDefaultAcceleratorComparator());
+			Collections.sort(gpus, getDefaultGPUComparator());
+
+			List<Device.TYPE> preferredDeviceTypes = getPreferredDeviceTypes();
+
+			for (Device.TYPE type : preferredDeviceTypes) {
+				switch (type) {
+				case UNKNOWN:
+					throw new AssertionError("UNKNOWN device type not supported");
+				case GPU:
+					devices.addAll(gpus);
+					break;
+				case CPU:
+					devices.addAll(cpus);
+					break;
+				case JTP:
+					devices.add(JavaDevice.THREAD_POOL);
+					break;
+				case SEQ:
+					devices.add(JavaDevice.SEQUENTIAL);
+					break;
+				case ACC:
+					devices.addAll(accelerators);
+					break;
+				case ALT:
+					devices.add(JavaDevice.ALTERNATIVE_ALGORITHM);
+					break;
+				}
+			}
+			
+			return devices;
+		}
+		
+		@Override
+		protected List<Device.TYPE> getPreferredDeviceTypes() {
+			return Arrays.asList(Device.TYPE.ACC, Device.TYPE.GPU, Device.TYPE.CPU);
+		}
+	}
         
     public void setUp() throws Exception {
-    	KernelManager.setKernelManager(new CLKernelManager());
+    	KernelManager.setKernelManager(new UncachedCLKernelManager());
         Device device = KernelManager.instance().bestDevice();
         if (device == null || !(device instanceof OpenCLDevice)) {
         	System.out.println("!!!No OpenCLDevice available for running the integration test");
@@ -57,6 +147,12 @@ public class OpenCLDeviceConfiguratorTest {
         openCLDevice = (OpenCLDevice) device;
     }
 
+    @After
+    public void teardDown() {
+    	Util.resetKernelManager();
+    }
+    
+    
     public void setUpWithConfigurator(IOpenCLDeviceConfigurator configurator) throws Exception {
     	OpenCLDevice.setConfigurator(configurator);
     	setUp();
@@ -64,7 +160,6 @@ public class OpenCLDeviceConfiguratorTest {
     
     @Test
     public void configuratorCallbackTest() throws Exception {
-    	final AtomicInteger callCounter = new AtomicInteger(0);
     	IOpenCLDeviceConfigurator configurator = new IOpenCLDeviceConfigurator() {
 			@Override
 			public void configure(OpenCLDevice device) {
@@ -93,4 +188,48 @@ public class OpenCLDeviceConfiguratorTest {
     	assertEquals("Number of configured devices should match numnber of devices", numberOfDevices, numberOfConfiguredDevices);
     	assertEquals("Number of calls doesn't match the expected", numberOfDevices*2, callCounter.get());
     }
+    
+    @Test
+    public void noConfiguratorTest() throws Exception {
+    	setUp();
+    	assertTrue("Device isShareMempory() should return true", openCLDevice.isSharedMemory());
+		assertNotEquals("Device name should not be \"Configured\"", "Configured", openCLDevice.getName());
+    	List<OpenCLPlatform> platforms = OpenCLPlatform.getUncachedOpenCLPlatforms();
+    	for (OpenCLPlatform platform : platforms) {
+    		for (OpenCLDevice device : platform.getOpenCLDevices()) {
+    			assertTrue("Device isSharedMempory() should return true", device.isSharedMemory());
+    			assertNotEquals("Device name should not be \"Configured\"", "Configured", device.getName());
+    		}
+    	}
+    }
+    
+    @Test
+    public void protectionAgainstRecursiveConfiguresTest() {
+    	OpenCLDevice dev = new OpenCLDevice(null, 101L, Device.TYPE.CPU);
+    	final AtomicInteger callCounter = new AtomicInteger(0);
+    	IOpenCLDeviceConfigurator configurator = new IOpenCLDeviceConfigurator() {
+			@Override
+			public void configure(OpenCLDevice device) {
+				callCounter.incrementAndGet();
+				device.configure();
+			}
+    	};
+    	OpenCLDevice.setConfigurator(configurator);
+    	dev.configure();
+    	
+    	assertEquals("Number of confgure() calls should be one", 1, callCounter.get());
+    }
+    
+    @Test
+    public void noExceptionConfiguratorTest() {
+    	OpenCLDevice dev = new OpenCLDevice(null, 101L, Device.TYPE.CPU);
+    	IOpenCLDeviceConfigurator configurator = new IOpenCLDeviceConfigurator() {
+			@Override
+			public void configure(OpenCLDevice device) {
+				throw new IllegalArgumentException("Should be catched exception");
+			}
+    	};
+    	OpenCLDevice.setConfigurator(configurator);
+    	dev.configure();    	
+    }
  }
diff --git a/src/test/java/com/aparapi/runtime/OriginalKernelManager.java b/src/test/java/com/aparapi/runtime/OriginalKernelManager.java
new file mode 100644
index 00000000..423c8034
--- /dev/null
+++ b/src/test/java/com/aparapi/runtime/OriginalKernelManager.java
@@ -0,0 +1,25 @@
+/**
+ * Copyright (c) 2016 - 2018 Syncleus, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.aparapi.runtime;
+
+import com.aparapi.internal.kernel.KernelManager;
+
+/**
+ * Provides a way for re-establishing the default Aparapi KernelManager
+ * @author CoreRasurae
+ */
+public class OriginalKernelManager extends KernelManager {
+}
diff --git a/src/test/java/com/aparapi/runtime/ProfileReportBackwardsCompatTest.java b/src/test/java/com/aparapi/runtime/ProfileReportBackwardsCompatTest.java
index a69c3ae2..85ce02fc 100644
--- a/src/test/java/com/aparapi/runtime/ProfileReportBackwardsCompatTest.java
+++ b/src/test/java/com/aparapi/runtime/ProfileReportBackwardsCompatTest.java
@@ -15,6 +15,9 @@
  */
 package com.aparapi.runtime;
 
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assume.assumeTrue;
 
 import java.lang.ref.WeakReference;
@@ -28,7 +31,8 @@ import java.util.concurrent.TimeUnit;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 
-import static org.junit.Assert.*;
+import org.junit.After;
+import org.junit.AfterClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TestName;
@@ -57,7 +61,6 @@ public class ProfileReportBackwardsCompatTest {
 	@Rule 
 	public TestName name = new TestName();
 
-	
     private class CLKernelManager extends KernelManager {
     	@Override
     	protected List<Device.TYPE> getPreferredDeviceTypes() {
@@ -77,6 +80,11 @@ public class ProfileReportBackwardsCompatTest {
     	}
     }
     
+    @After
+    public void classTeardown() {
+    	Util.resetKernelManager();
+    }
+    
     public void setUpBefore() throws Exception {
     	KernelManager.setKernelManager(new CLKernelManager());
         Device device = KernelManager.instance().bestDevice();
diff --git a/src/test/java/com/aparapi/runtime/ProfileReportNewAPITest.java b/src/test/java/com/aparapi/runtime/ProfileReportNewAPITest.java
index 0abfd26a..7c25051a 100644
--- a/src/test/java/com/aparapi/runtime/ProfileReportNewAPITest.java
+++ b/src/test/java/com/aparapi/runtime/ProfileReportNewAPITest.java
@@ -38,6 +38,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 
+import org.junit.After;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TestName;
@@ -68,6 +69,12 @@ public class ProfileReportNewAPITest {
 	public TestName name = new TestName();
 
 	
+    @After
+    public void classTeardown() {
+    	Util.resetKernelManager();
+    }
+
+	
     private class CLKernelManager extends KernelManager {
     	@Override
     	protected List<Device.TYPE> getPreferredDeviceTypes() {
diff --git a/src/test/java/com/aparapi/runtime/Util.java b/src/test/java/com/aparapi/runtime/Util.java
index adae0bdf..1f999482 100644
--- a/src/test/java/com/aparapi/runtime/Util.java
+++ b/src/test/java/com/aparapi/runtime/Util.java
@@ -17,7 +17,13 @@ package com.aparapi.runtime;
 
 import java.util.Arrays;
 
-public class Util {
+import com.aparapi.internal.kernel.KernelManager;
+
+public class Util {	
+	static void resetKernelManager() {
+		KernelManager.setKernelManager(new OriginalKernelManager());
+	}
+	
     static void fill(int[] array, Filler _filler) {
         for (int i = 0; i < array.length; i++) {
             _filler.fill(array, i);
-- 
GitLab