diff --git a/src/main/java/com/aparapi/internal/kernel/KernelRunner.java b/src/main/java/com/aparapi/internal/kernel/KernelRunner.java index 56d17f98477e2704226c4eeb94b541eb1bc93534..4df1233d510d03268e79b67033ff932688fcb8d6 100644 --- a/src/main/java/com/aparapi/internal/kernel/KernelRunner.java +++ b/src/main/java/com/aparapi/internal/kernel/KernelRunner.java @@ -65,12 +65,14 @@ import com.aparapi.internal.util.*; import com.aparapi.internal.writer.*; import com.aparapi.opencl.*; +import java.lang.Thread.UncaughtExceptionHandler; import java.lang.reflect.*; import java.nio.*; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.ForkJoinPool.*; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.logging.*; /** @@ -145,12 +147,23 @@ public class KernelRunner extends KernelRunnerJNI{ return newThread; } }; - - private static final ForkJoinPool threadPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors(), - lowPriorityThreadFactory, null, false); + + /* Declared before threadPool because the pool's initializer below passes this handler to the ForkJoinPool constructor. */ private final ThreadDiedHandler handler = new ThreadDiedHandler(); + //Allow a thread pool per KernelRunner which will also be per Kernel instance + private final ForkJoinPool threadPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors(), + lowPriorityThreadFactory, handler, false); private static HashMap<Class<? extends Kernel>, String> openCLCache = new HashMap<>(); private static LinkedHashSet<String> seenBinaryKeys = new LinkedHashSet<>(); + /** Uncaught-exception handler installed on this runner's thread pool: logs each reported throwable at SEVERE and counts it. */ private class ThreadDiedHandler implements UncaughtExceptionHandler { + /** Running total of uncaught-exception reports; the reference is never reassigned, hence final. */ private final AtomicLong threadsDiedCounter = new AtomicLong(0); + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log(Level.SEVERE, "Thread died in thread pool of kernel runner for kernel: " + kernel.getClass(), e); + threadsDiedCounter.incrementAndGet(); + } + } + /** * Create a KernelRunner for a specific Kernel instance. 
* @@ -444,11 +457,6 @@ public class KernelRunner extends KernelRunnerJNI{ final int numGroups0 = _settings.range.getNumGroups(0); final int numGroups1 = _settings.range.getNumGroups(1); final int globalGroups = numGroups0 * numGroups1 * _settings.range.getNumGroups(2); - /** - * This joinBarrier is the barrier that we provide for the kernel threads to rendezvous with the current dispatch thread. - * So this barrier is threadCount+1 wide (the +1 is for the dispatch thread) - */ - final CyclicBarrier joinBarrier = new FJSafeCyclicBarrier(threads + 1); /** * This localBarrier is only ever used by the kernels. If the kernel does not use the barrier the threads @@ -589,12 +597,17 @@ public class KernelRunner extends KernelRunnerJNI{ } }; } - else + else { throw new IllegalArgumentException("Expected 1,2 or 3 dimensions, found " + _settings.range.getDims()); + } + + /* Sized by the worker-thread count; each submitted task is stored at tasks[id] and joined by the dispatch thread within the same pass. */ ForkJoinTask<?>[] tasks = new ForkJoinTask<?>[threads]; for (passId = 0; passId < _settings.passes; passId++) { if (getCancelState() == CANCEL_STATUS_TRUE) { break; } + + /* Snapshot of the died-thread counter, so deaths during this pass can be computed as a difference after the joins. */ long deadThreadCount = handler.threadsDiedCounter.get(); /** * Note that we emulate OpenCL by creating one thread per localId (across the group). * @@ -646,7 +659,7 @@ public class KernelRunner extends KernelRunnerJNI{ kernelState.setLocalBarrier(localBarrier); } - threadPool.submit( + ForkJoinTask<?> fjt = threadPool.submit( // () -> { new Runnable() { public void run() { @@ -659,14 +672,21 @@ public class KernelRunner extends KernelRunnerJNI{ catch (RuntimeException | Error e) { logger.log(Level.SEVERE, "Execution failed", e); } - finally { - await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. - } } }); + + tasks[id] = fjt; } - await(joinBarrier); // This dispatch thread waits for all worker threads here. + + for (ForkJoinTask<?> task : tasks) { // This dispatch thread waits for all worker threads here. 
+ task.join(); + } + + /* Worker-thread deaths recorded since the counter snapshot taken at the start of this pass. */ long deathCount = handler.threadsDiedCounter.get() - deadThreadCount; + if (deathCount > 0) { + logger.log(Level.SEVERE, deathCount + " Pool threads died during execution of kernel: " + kernel.getClass().getName() + " at pass: " + passId); + } } passId = PASS_ID_COMPLETED_EXECUTION; } // execution mode == JTP @@ -1521,7 +1541,7 @@ public class KernelRunner extends KernelRunnerJNI{ if (handle == 0) { return fallBackToNextDevice(_settings, "OpenCL compile failed"); } - + args = new KernelArg[entryPoint.getReferencedFields().size()]; int i = 0;