diff --git a/com.amd.aparapi.jni/src/cpp/runKernel/Aparapi.cpp b/com.amd.aparapi.jni/src/cpp/runKernel/Aparapi.cpp index 920645fd5053d9c554504a5b047ab9ef59f0d565..cbad7e81539f48989f847e32edc5ebf643f2b413 100644 --- a/com.amd.aparapi.jni/src/cpp/runKernel/Aparapi.cpp +++ b/com.amd.aparapi.jni/src/cpp/runKernel/Aparapi.cpp @@ -50,6 +50,10 @@ #include "List.h" #include <algorithm> +static const int PASS_ID_PREPARING_EXECUTION = -2; +static const int PASS_ID_COMPLETED_EXECUTION = -1; +static const int CANCEL_STATUS_FALSE = 0; +static const int CANCEL_STATUS_TRUE = 1; //compiler dependant code /** @@ -777,8 +781,23 @@ void enqueueKernel(JNIContext* jniContext, Range& range, int passes, int argPos, jniContext->passes = passes; jniContext->exec = new ProfileInfo[passes]; + jbyte* kernelOutBytes = jniContext->runKernelOutBytes; + int* kernelOutBytesAsInts = reinterpret_cast<int*>(kernelOutBytes); + + jbyte* kernelInBytes = jniContext->runKernelInBytes; + int* kernelInBytesAsInts = reinterpret_cast<int*>(kernelInBytes); + cl_int status = CL_SUCCESS; for (int passid=0; passid < passes; passid++) { + + int cancelCode = kernelInBytesAsInts[0]; + kernelOutBytesAsInts[0] = passid; + + if (cancelCode == CANCEL_STATUS_TRUE) { + fprintf(stderr, "received cancellation, aborting at pass %d\n", passid); + kernelOutBytes[0] = -1; + break; + } //size_t offset = 1; // (size_t)((range.globalDims[0]/jniContext->deviceIdc)*dev); status = clSetKernelArg(jniContext->kernel, argPos, sizeof(passid), &(passid)); @@ -874,6 +893,7 @@ void enqueueKernel(JNIContext* jniContext, Range& range, int passes, int argPos, } } + kernelOutBytesAsInts[0] = PASS_ID_COMPLETED_EXECUTION; } @@ -1050,7 +1070,7 @@ void checkEvents(JNIEnv* jenv, JNIContext* jniContext, int writeEventCount) { } JNI_JAVA(jint, KernelRunnerJNI, runKernelJNI) - (JNIEnv *jenv, jobject jobj, jlong jniContextHandle, jobject _range, jboolean needSync, jint passes) { + (JNIEnv *jenv, jobject jobj, jlong jniContextHandle, jobject _range, jboolean needSync, jint passes, jobject inBuffer, jobject outBuffer) { if (config == NULL){ config = new Config(jenv); } @@ -1059,7 +1079,16 @@ JNI_JAVA(jint, KernelRunnerJNI, runKernelJNI) cl_int status = CL_SUCCESS; JNIContext* jniContext = JNIContext::getJNIContext(jniContextHandle); + jniContext->runKernelInBytes = (jbyte*)jenv->GetDirectBufferAddress(inBuffer); + jniContext->runKernelOutBytes = (jbyte*)jenv->GetDirectBufferAddress(outBuffer); + + jbyte* kernelInBytes = jniContext->runKernelInBytes; + int* kernelInBytesAsInts = reinterpret_cast<int*>(kernelInBytes); + kernelInBytesAsInts[0] = CANCEL_STATUS_FALSE; + jbyte* kernelOutBytes = jniContext->runKernelOutBytes; + int* kernelOutBytesAsInts = reinterpret_cast<int*>(kernelOutBytes); + kernelOutBytesAsInts[0] = PASS_ID_PREPARING_EXECUTION; if (jniContext->firstRun && config->isProfilingEnabled()){ try { diff --git a/com.amd.aparapi.jni/src/cpp/runKernel/JNIContext.h b/com.amd.aparapi.jni/src/cpp/runKernel/JNIContext.h index aebad48a54ef7767be8694fca9165a4c03b47cdd..952853582e58425f6cf4360bf2cbba69f4abaa7e 100644 --- a/com.amd.aparapi.jni/src/cpp/runKernel/JNIContext.h +++ b/com.amd.aparapi.jni/src/cpp/runKernel/JNIContext.h @@ -32,6 +32,9 @@ public: jint* writeEventArgs; jboolean firstRun; jint passes; + jbyte* runKernelInBytes; + jbyte* runKernelOutBytes; + ProfileInfo *exec; FILE* profileFile; diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java index f09dfb892d6a50ef594b53235dbbc1e21d493f5a..eeb2ef7ff13469f3fdb2743caf3c29870f3e8d47 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java @@ -37,22 +37,35 @@ under those regulations, please refer to the U.S. Bureau of Industry and Securit */ package com.amd.aparapi; -import com.amd.aparapi.annotation.*; -import com.amd.aparapi.exception.*; -import com.amd.aparapi.internal.kernel.*; +import com.amd.aparapi.annotation.Experimental; +import com.amd.aparapi.exception.DeprecatedException; +import com.amd.aparapi.internal.kernel.KernelRunner; import com.amd.aparapi.internal.model.CacheEnabler; +import com.amd.aparapi.internal.model.ClassModel.ConstantPool.MethodReferenceEntry; +import com.amd.aparapi.internal.model.ClassModel.ConstantPool.NameAndTypeEntry; import com.amd.aparapi.internal.model.ValueCache; import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer; import com.amd.aparapi.internal.model.ValueCache.ValueComputer; -import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*; -import com.amd.aparapi.internal.opencl.*; -import com.amd.aparapi.internal.util.*; - -import java.lang.annotation.*; -import java.lang.reflect.*; -import java.util.*; -import java.util.concurrent.*; -import java.util.logging.*; +import com.amd.aparapi.internal.opencl.OpenCLLoader; +import com.amd.aparapi.internal.util.UnsafeWrapper; + +import java.lang.annotation.Annotation; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.lang.reflect.Method; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CyclicBarrier; +import java.util.logging.Logger; /** * A <i>kernel</i> encapsulates a data parallel algorithm that will execute either on a GPU @@ -943,6 +956,37 @@ public abstract class Kernel implements Cloneable { */ public abstract void run(); + /** + * Invoking this method flags that once the current pass is complete execution should be abandoned. Due to the complexity of intercommunication + * between java (or C) and executing OpenCL, this is the best we can do for general cancellation of execution at present. OpenCL 2.0 should introduce + * pipe mechanisms which will support mid-pass cancellation easily. + * + * <p> + * Note that in the case of thread-pool/pure java execution we could do better already, using Thread.interrupt() (and/or other means) to abandon + * execution mid-pass. However at present this is not attempted. + * + * @see #execute(int, int) + * @see #execute(Range, int) + * @see #execute(String, Range, int) + */ + public void cancelMultiPass() { + kernelRunner.cancelMultiPass(); + } + + public int getCancelState() { + return kernelRunner == null ? KernelRunner.CANCEL_STATUS_FALSE : kernelRunner.getCancelState(); + } + + /** + * @see KernelRunner#getCurrentPass() + */ + public int getCurrentPass() { + if (kernelRunner == null) { + return KernelRunner.PASS_ID_PREPARING_EXECUTION; + } + return kernelRunner.getCurrentPass(); + } + /** * When using a Java Thread Pool Aparapi uses clone to copy the initial instance to each thread. * diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/jni/KernelRunnerJNI.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/jni/KernelRunnerJNI.java index d34926d2bb73a4bcf6afa629329bebf9e513336b..923875ee51bdf4bab77d550bf36f40b06b7883bd 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/jni/KernelRunnerJNI.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/jni/KernelRunnerJNI.java @@ -8,6 +8,7 @@ import com.amd.aparapi.device.OpenCLDevice; import com.amd.aparapi.internal.annotation.DocMe; import com.amd.aparapi.internal.annotation.UsedByJNICode; +import java.nio.ByteBuffer; import java.util.List; /** @@ -310,7 +311,7 @@ public abstract class KernelRunnerJNI{ protected native int setArgsJNI(long _jniContextHandle, KernelArgJNI[] _args, int argc); - protected native int runKernelJNI(long _jniContextHandle, Range _range, boolean _needSync, int _passes); + protected native int runKernelJNI(long _jniContextHandle, Range _range, boolean _needSync, int _passes, ByteBuffer _inBuffer, ByteBuffer _outBuffer); protected native int disposeJNI(long _jniContextHandle); diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java index 761117fbe3123d4e879a2961a4f84ec890aab13c..ddf94ea2643f85127847574a2ae3fcc7e8fadb29 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java @@ -37,24 +37,6 @@ under those regulations, please refer to the U.S. Bureau of Industry and Securit */ package com.amd.aparapi.internal.kernel; -import java.lang.reflect.Array; -import java.lang.reflect.Field; -import java.lang.reflect.Modifier; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.StringTokenizer; -import java.util.concurrent.BrokenBarrierException; -import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.ForkJoinPool; -import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory; -import java.util.concurrent.ForkJoinPool.ManagedBlocker; -import java.util.concurrent.ForkJoinWorkerThread; -import java.util.logging.Level; -import java.util.logging.Logger; - import com.amd.aparapi.Config; import com.amd.aparapi.Kernel; import com.amd.aparapi.Kernel.Constant; @@ -65,6 +47,7 @@ import com.amd.aparapi.ProfileInfo; import com.amd.aparapi.Range; import com.amd.aparapi.device.Device; import com.amd.aparapi.device.OpenCLDevice; +import com.amd.aparapi.internal.annotation.UsedByJNICode; import com.amd.aparapi.internal.exception.AparapiException; import com.amd.aparapi.internal.exception.CodeGenException; import com.amd.aparapi.internal.instruction.InstructionSet.TypeSpec; @@ -75,6 +58,25 @@ import com.amd.aparapi.internal.util.UnsafeWrapper; import com.amd.aparapi.internal.writer.KernelWriter; import com.amd.aparapi.opencl.OpenCL; +import java.lang.reflect.Array; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.IntBuffer; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.StringTokenizer; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory; +import java.util.concurrent.ForkJoinPool.ManagedBlocker; +import java.util.concurrent.ForkJoinWorkerThread; +import java.util.logging.Level; +import java.util.logging.Logger; + /** * The class is responsible for executing <code>Kernel</code> implementations. <br/> * @@ -93,6 +95,13 @@ import com.amd.aparapi.opencl.OpenCL; */ public class KernelRunner extends KernelRunnerJNI{ + /** @see #getCurrentPass() */ + @UsedByJNICode public static final int PASS_ID_PREPARING_EXECUTION = -2; + /** @see #getCurrentPass() */ + @UsedByJNICode public static final int PASS_ID_COMPLETED_EXECUTION = -1; + @UsedByJNICode public static final int CANCEL_STATUS_FALSE = 0; + @UsedByJNICode public static final int CANCEL_STATUS_TRUE = 1; + private static Logger logger = Logger.getLogger(Config.getLoggerName()); private long jniContextHandle = 0; @@ -103,6 +112,29 @@ public class KernelRunner extends KernelRunnerJNI{ private int argc; + // may be read by a thread other than the control thread, hence volatile + private volatile boolean executing; + + // may be read by a thread other than the control thread, hence volatile + private volatile int passId = PASS_ID_PREPARING_EXECUTION; + + /** + * A direct ByteBuffer used for asynchronous intercommunication between java and JNI C code. + * + * <p> + * At present this is a 4 byte buffer to be interpreted as an int[1], used for passing from java to C a single integer interpreted as a cancellation indicator. + */ + private final ByteBuffer inBufferRemote; + private final IntBuffer inBufferRemoteInt; + + /** A direct ByteBuffer used for asynchronous intercommunication between java and JNI C code. + * <p> + * At present this is a 4 byte buffer to be interpreted as an int[1], used for passing from C to java a single integer interpreted as a + * the current pass id. + */ + private final ByteBuffer outBufferRemote; + private final IntBuffer outBufferRemoteInt; + private boolean isFallBack = false; // If isFallBack, rebuild the kernel (necessary?) private static final ForkJoinWorkerThreadFactory lowPriorityThreadFactory = new ForkJoinWorkerThreadFactory(){ @@ -123,6 +155,15 @@ public class KernelRunner extends KernelRunnerJNI{ */ public KernelRunner(Kernel _kernel) { kernel = _kernel; + + inBufferRemote = ByteBuffer.allocateDirect(4); + outBufferRemote = ByteBuffer.allocateDirect(4); + + inBufferRemote.order(ByteOrder.nativeOrder()); + outBufferRemote.order(ByteOrder.nativeOrder()); + + inBufferRemoteInt = inBufferRemote.asIntBuffer(); + outBufferRemoteInt = outBufferRemote.asIntBuffer(); } /** @@ -288,289 +329,302 @@ public class KernelRunner extends KernelRunnerJNI{ logger.fine("executeJava: range = " + _range); } - final int localSize0 = _range.getLocalSize(0); - final int localSize1 = _range.getLocalSize(1); - final int localSize2 = _range.getLocalSize(2); - final int globalSize1 = _range.getGlobalSize(1); - if (kernel.getExecutionMode().equals(EXECUTION_MODE.SEQ)) { - /** - * SEQ mode is useful for testing trivial logic, but kernels which use SEQ mode cannot be used if the - * product of localSize(0..3) is >1. So we can use multi-dim ranges but only if the local size is 1 in all dimensions. - * - * As a result of this barrier is only ever 1 work item wide and probably should be turned into a no-op. - * - * So we need to check if the range is valid here. If not we have no choice but to punt. - */ - if ((localSize0 * localSize1 * localSize2) > 1) { - throw new IllegalStateException("Can't run range with group size >1 sequentially. Barriers would deadlock!"); - } - - final Kernel kernelClone = kernel.clone(); - final KernelState kernelState = kernelClone.getKernelState(); + passId = PASS_ID_PREPARING_EXECUTION; + try { + final int localSize0 = _range.getLocalSize(0); + final int localSize1 = _range.getLocalSize(1); + final int localSize2 = _range.getLocalSize(2); + final int globalSize1 = _range.getGlobalSize(1); + if (kernel.getExecutionMode().equals(EXECUTION_MODE.SEQ)) { + /** + * SEQ mode is useful for testing trivial logic, but kernels which use SEQ mode cannot be used if the + * product of localSize(0..3) is >1. So we can use multi-dim ranges but only if the local size is 1 in all dimensions. + * + * As a result of this barrier is only ever 1 work item wide and probably should be turned into a no-op. + * + * So we need to check if the range is valid here. If not we have no choice but to punt. + */ + if ((localSize0 * localSize1 * localSize2) > 1) { + throw new IllegalStateException("Can't run range with group size >1 sequentially. Barriers would deadlock!"); + } - kernelState.setRange(_range); - kernelState.setGroupId(0, 0); - kernelState.setGroupId(1, 0); - kernelState.setGroupId(2, 0); - kernelState.setLocalId(0, 0); - kernelState.setLocalId(1, 0); - kernelState.setLocalId(2, 0); - kernelState.setLocalBarrier(new FJSafeCyclicBarrier(1)); + final Kernel kernelClone = kernel.clone(); + final KernelState kernelState = kernelClone.getKernelState(); - for (int passId = 0; passId < _passes; passId++) { - kernelState.setPassId(passId); + kernelState.setRange(_range); + kernelState.setGroupId(0, 0); + kernelState.setGroupId(1, 0); + kernelState.setGroupId(2, 0); + kernelState.setLocalId(0, 0); + kernelState.setLocalId(1, 0); + kernelState.setLocalId(2, 0); + kernelState.setLocalBarrier(new FJSafeCyclicBarrier(1)); - if (_range.getDims() == 1) { - for (int id = 0; id < _range.getGlobalSize(0); id++) { - kernelState.setGlobalId(0, id); - kernelClone.run(); + for (passId = 0; passId < _passes; passId++) { + if (getCancelState() == CANCEL_STATUS_TRUE) { + break; } - } else if (_range.getDims() == 2) { - for (int x = 0; x < _range.getGlobalSize(0); x++) { - kernelState.setGlobalId(0, x); + kernelState.setPassId(passId); - for (int y = 0; y < globalSize1; y++) { - kernelState.setGlobalId(1, y); + if (_range.getDims() == 1) { + for (int id = 0; id < _range.getGlobalSize(0); id++) { + kernelState.setGlobalId(0, id); kernelClone.run(); } - } - } else if (_range.getDims() == 3) { - for (int x = 0; x < _range.getGlobalSize(0); x++) { - kernelState.setGlobalId(0, x); + } else if (_range.getDims() == 2) { + for (int x = 0; x < _range.getGlobalSize(0); x++) { + kernelState.setGlobalId(0, x); - for (int y = 0; y < globalSize1; y++) { - kernelState.setGlobalId(1, y); - - for (int z = 0; z < _range.getGlobalSize(2); z++) { - kernelState.setGlobalId(2, z); + for (int y = 0; y < globalSize1; y++) { + kernelState.setGlobalId(1, y); kernelClone.run(); } + } + } else if (_range.getDims() == 3) { + for (int x = 0; x < _range.getGlobalSize(0); x++) { + kernelState.setGlobalId(0, x); - kernelClone.run(); + for (int y = 0; y < globalSize1; y++) { + kernelState.setGlobalId(1, y); + + for (int z = 0; z < _range.getGlobalSize(2); z++) { + kernelState.setGlobalId(2, z); + kernelClone.run(); + } + + kernelClone.run(); + } } } } - } - } else { - final int threads = localSize0 * localSize1 * localSize2; - final int numGroups0 = _range.getNumGroups(0); - final int numGroups1 = _range.getNumGroups(1); - final int globalGroups = numGroups0 * numGroups1 * _range.getNumGroups(2); - /** - * This joinBarrier is the barrier that we provide for the kernel threads to rendezvous with the current dispatch thread. - * So this barrier is threadCount+1 wide (the +1 is for the dispatch thread) - */ - final CyclicBarrier joinBarrier = new FJSafeCyclicBarrier(threads + 1); - - /** - * This localBarrier is only ever used by the kernels. If the kernel does not use the barrier the threads - * can get out of sync, we promised nothing in JTP mode. - * - * As with OpenCL all threads within a group must wait at the barrier or none. It is a user error (possible deadlock!) - * if the barrier is in a conditional that is only executed by some of the threads within a group. - * - * Kernel developer must understand this. - * - * This barrier is threadCount wide. We never hit the barrier from the dispatch thread. - */ - final CyclicBarrier localBarrier = new FJSafeCyclicBarrier(threads); - - final ThreadIdSetter threadIdSetter; - - if (_range.getDims() == 1) { - threadIdSetter = new ThreadIdSetter(){ - @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { - // (kernelState, globalGroupId, threadId) ->{ - kernelState.setLocalId(0, (threadId % localSize0)); - kernelState.setGlobalId(0, (threadId + (globalGroupId * threads))); - kernelState.setGroupId(0, globalGroupId); - } - }; - } else if (_range.getDims() == 2) { + passId = PASS_ID_COMPLETED_EXECUTION; + } else { + final int threads = localSize0 * localSize1 * localSize2; + final int numGroups0 = _range.getNumGroups(0); + final int numGroups1 = _range.getNumGroups(1); + final int globalGroups = numGroups0 * numGroups1 * _range.getNumGroups(2); + /** + * This joinBarrier is the barrier that we provide for the kernel threads to rendezvous with the current dispatch thread. + * So this barrier is threadCount+1 wide (the +1 is for the dispatch thread) + */ + final CyclicBarrier joinBarrier = new FJSafeCyclicBarrier(threads + 1); /** - * Consider a 12x4 grid of 4*2 local groups - * <pre> - * threads = 4*2 = 8 - * localWidth=4 - * localHeight=2 - * globalWidth=12 - * globalHeight=4 - * - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 - * 12 13 14 15 | 16 17 18 19 | 20 21 22 23 - * ------------+-------------+------------ - * 24 25 26 27 | 28 29 30 31 | 32 33 34 35 - * 36 37 38 39 | 40 41 42 43 | 44 45 46 47 - * - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6 - * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 - * ------------+-------------+------------ - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 - * - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6 - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * ------------+-------------+------------ - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6 - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 - * ------------+-------------+------------ - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * - * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads; - * 04 05 06 07 | 12 13 14 15 | 20 21 22 23 - * ------------+-------------+------------ - * 24 25 26 27 | 32[33]34 35 | 40 41 42 43 - * 28 29 30 31 | 36 37 38 39 | 44 45 46 47 - * - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1) - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * ------------+-------------+------------ - * 00 01 02 03 | 00[01]02 03 | 00 01 02 03 - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0) - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * ------------+-------------+------------ - * 00 00 00 00 | 00[00]00 00 | 00 00 00 00 - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX= - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3) - * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1) - * 00 01 02 03 | 04[05]06 07 | 08 09 10 11 - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5) - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * ------------+-------------+------------ - * 02 02 02 02 | 02[02]02 02 | 02 02 02 02 - * 03 03 03 03 | 03 03 03 03 | 03 03 03 03 - * - * </pre> - * Assume we are trying to locate the id's for #33 + * This localBarrier is only ever used by the kernels. If the kernel does not use the barrier the threads + * can get out of sync, we promised nothing in JTP mode. + * + * As with OpenCL all threads within a group must wait at the barrier or none. It is a user error (possible deadlock!) + * if the barrier is in a conditional that is only executed by some of the threads within a group. + * + * Kernel developer must understand this. * + * This barrier is threadCount wide. We never hit the barrier from the dispatch thread. */ - threadIdSetter = new ThreadIdSetter(){ - @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { - // (kernelState, globalGroupId, threadId) ->{ - kernelState.setLocalId(0, (threadId % localSize0)); // threadId % localWidth = (for 33 = 1 % 4 = 1) - kernelState.setLocalId(1, (threadId / localSize0)); // threadId / localWidth = (for 33 = 1 / 4 == 0) - - final int groupInset = globalGroupId % numGroups0; // 4%3 = 1 - kernelState.setGlobalId(0, ((groupInset * localSize0) + kernelState.getLocalIds()[0])); // 1*4+1=5 - - final int completeLines = (globalGroupId / numGroups0) * localSize1;// (4/3) * 2 - kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2 - kernelState.setGroupId(0, (globalGroupId % numGroups0)); - kernelState.setGroupId(1, (globalGroupId / numGroups0)); - } - }; - } else if (_range.getDims() == 3) { - //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code - threadIdSetter = new ThreadIdSetter(){ - @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { - // (kernelState, globalGroupId, threadId) ->{ - kernelState.setLocalId(0, (threadId % localSize0)); + final CyclicBarrier localBarrier = new FJSafeCyclicBarrier(threads); - kernelState.setLocalId(1, ((threadId / localSize0) % localSize1)); + final ThreadIdSetter threadIdSetter; - // the thread id's span WxHxD so threadId/(WxH) should yield the local depth - kernelState.setLocalId(2, (threadId / (localSize0 * localSize1))); + if (_range.getDims() == 1) { + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); + kernelState.setGlobalId(0, (threadId + (globalGroupId * threads))); + kernelState.setGroupId(0, globalGroupId); + } + }; + } else if (_range.getDims() == 2) { - kernelState.setGlobalId(0, (((globalGroupId % numGroups0) * localSize0) + kernelState.getLocalIds()[0])); + /** + * Consider a 12x4 grid of 4*2 local groups + * <pre> + * threads = 4*2 = 8 + * localWidth=4 + * localHeight=2 + * globalWidth=12 + * globalHeight=4 + * + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 + * 12 13 14 15 | 16 17 18 19 | 20 21 22 23 + * ------------+-------------+------------ + * 24 25 26 27 | 28 29 30 31 | 32 33 34 35 + * 36 37 38 39 | 40 41 42 43 | 44 45 46 47 + * + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6 + * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 + * ------------+-------------+------------ + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 + * + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6 + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * ------------+-------------+------------ + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6 + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 + * ------------+-------------+------------ + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * + * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads; + * 04 05 06 07 | 12 13 14 15 | 20 21 22 23 + * ------------+-------------+------------ + * 24 25 26 27 | 32[33]34 35 | 40 41 42 43 + * 28 29 30 31 | 36 37 38 39 | 44 45 46 47 + * + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1) + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * ------------+-------------+------------ + * 00 01 02 03 | 00[01]02 03 | 00 01 02 03 + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0) + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * ------------+-------------+------------ + * 00 00 00 00 | 00[00]00 00 | 00 00 00 00 + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX= + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3) + * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1) + * 00 01 02 03 | 04[05]06 07 | 08 09 10 11 + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5) + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * ------------+-------------+------------ + * 02 02 02 02 | 02[02]02 02 | 02 02 02 02 + * 03 03 03 03 | 03 03 03 03 | 03 03 03 03 + * + * </pre> + * Assume we are trying to locate the id's for #33 + * + */ + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); // threadId % localWidth = (for 33 = 1 % 4 = 1) + kernelState.setLocalId(1, (threadId / localSize0)); // threadId / localWidth = (for 33 = 1 / 4 == 0) + + final int groupInset = globalGroupId % numGroups0; // 4%3 = 1 + kernelState.setGlobalId(0, ((groupInset * localSize0) + kernelState.getLocalIds()[0])); // 1*4+1=5 + + final int completeLines = (globalGroupId / numGroups0) * localSize1;// (4/3) * 2 + kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2 + kernelState.setGroupId(0, (globalGroupId % numGroups0)); + kernelState.setGroupId(1, (globalGroupId / numGroups0)); + } + }; + } else if (_range.getDims() == 3) { + //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); - kernelState.setGlobalId(1, - ((((globalGroupId / numGroups0) * localSize1) % globalSize1) + kernelState.getLocalIds()[1])); + kernelState.setLocalId(1, ((threadId / localSize0) % localSize1)); - kernelState.setGlobalId(2, - (((globalGroupId / (numGroups0 * numGroups1)) * localSize2) + kernelState.getLocalIds()[2])); + // the thread id's span WxHxD so threadId/(WxH) should yield the local depth + kernelState.setLocalId(2, (threadId / (localSize0 * localSize1))); - kernelState.setGroupId(0, (globalGroupId % numGroups0)); - kernelState.setGroupId(1, ((globalGroupId / numGroups0) % numGroups1)); - kernelState.setGroupId(2, (globalGroupId / (numGroups0 * numGroups1))); - } - }; - } else - throw new IllegalArgumentException("Expected 1,2 or 3 dimensions, found " + _range.getDims()); - for (int passId = 0; passId < _passes; passId++) { - /** - * Note that we emulate OpenCL by creating one thread per localId (across the group). - * - * So threadCount == range.getLocalSize(0)*range.getLocalSize(1)*range.getLocalSize(2); - * - * For a 1D range of 12 groups of 4 we create 4 threads. One per localId(0). - * - * We also clone the kernel 4 times. One per thread. - * - * We create local barrier which has a width of 4 - * - * Thread-0 handles localId(0) (global 0,4,8) - * Thread-1 handles localId(1) (global 1,5,7) - * Thread-2 handles localId(2) (global 2,6,10) - * Thread-3 handles localId(3) (global 3,7,11) - * - * This allows all threads to synchronize using the local barrier. - * - * Initially the use of local buffers seems broken as the buffers appears to be per Kernel. - * Thankfully Kernel.clone() performs a shallow clone of all buffers (local and global) - * So each of the cloned kernels actually still reference the same underlying local/global buffers. - * - * If the kernel uses local buffers but does not use barriers then it is possible for different groups - * to see mutations from each other (unlike OpenCL), however if the kernel does not us barriers then it - * cannot assume any coherence in OpenCL mode either (the failure mode will be different but still wrong) - * - * So even JTP mode use of local buffers will need to use barriers. Not for the same reason as OpenCL but to keep groups in lockstep. - * - **/ - for (int id = 0; id < threads; id++) { - final int threadId = id; + kernelState.setGlobalId(0, (((globalGroupId % numGroups0) * localSize0) + kernelState.getLocalIds()[0])); - /** - * We clone one kernel for each thread. - * - * They will all share references to the same range, localBarrier and global/local buffers because the clone is shallow. - * We need clones so that each thread can assign 'state' (localId/globalId/groupId) without worrying - * about other threads. - */ - final Kernel kernelClone = kernel.clone(); - final KernelState kernelState = kernelClone.getKernelState(); - kernelState.setRange(_range); - kernelState.setPassId(passId); + kernelState.setGlobalId(1, + ((((globalGroupId / numGroups0) * localSize1) % globalSize1) + kernelState.getLocalIds()[1])); - if (threads == 1) { - kernelState.disableLocalBarrier(); - } else { - kernelState.setLocalBarrier(localBarrier); + kernelState.setGlobalId(2, + (((globalGroupId / (numGroups0 * numGroups1)) * localSize2) + kernelState.getLocalIds()[2])); + + kernelState.setGroupId(0, (globalGroupId % numGroups0)); + kernelState.setGroupId(1, ((globalGroupId / numGroups0) % numGroups1)); + kernelState.setGroupId(2, (globalGroupId / (numGroups0 * numGroups1))); + } + }; + } else + throw new IllegalArgumentException("Expected 1,2 or 3 dimensions, found " + _range.getDims()); + for (passId = 0; passId < _passes; passId++) { + if (getCancelState() == CANCEL_STATUS_TRUE) { + break; } + /** + * Note that we emulate OpenCL by creating one thread per localId (across the group). + * + * So threadCount == range.getLocalSize(0)*range.getLocalSize(1)*range.getLocalSize(2); + * + * For a 1D range of 12 groups of 4 we create 4 threads. One per localId(0). + * + * We also clone the kernel 4 times. One per thread. + * + * We create local barrier which has a width of 4 + * + * Thread-0 handles localId(0) (global 0,4,8) + * Thread-1 handles localId(1) (global 1,5,7) + * Thread-2 handles localId(2) (global 2,6,10) + * Thread-3 handles localId(3) (global 3,7,11) + * + * This allows all threads to synchronize using the local barrier. + * + * Initially the use of local buffers seems broken as the buffers appears to be per Kernel. + * Thankfully Kernel.clone() performs a shallow clone of all buffers (local and global) + * So each of the cloned kernels actually still reference the same underlying local/global buffers. + * + * If the kernel uses local buffers but does not use barriers then it is possible for different groups + * to see mutations from each other (unlike OpenCL), however if the kernel does not us barriers then it + * cannot assume any coherence in OpenCL mode either (the failure mode will be different but still wrong) + * + * So even JTP mode use of local buffers will need to use barriers. Not for the same reason as OpenCL but to keep groups in lockstep. + * + **/ + for (int id = 0; id < threads; id++) { + final int threadId = id; + + /** + * We clone one kernel for each thread. + * + * They will all share references to the same range, localBarrier and global/local buffers because the clone is shallow. + * We need clones so that each thread can assign 'state' (localId/globalId/groupId) without worrying + * about other threads. + */ + final Kernel kernelClone = kernel.clone(); + final KernelState kernelState = kernelClone.getKernelState(); + kernelState.setRange(_range); + kernelState.setPassId(passId); + + if (threads == 1) { + kernelState.disableLocalBarrier(); + } else { + kernelState.setLocalBarrier(localBarrier); + } - threadPool.submit( - // () -> { - new Runnable(){ - public void run() { - try { - for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) { - threadIdSetter.set(kernelState, globalGroupId, threadId); - kernelClone.run(); + threadPool.submit( + // () -> { + new Runnable(){ + public void run() { + try { + for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) { + threadIdSetter.set(kernelState, globalGroupId, threadId); + kernelClone.run(); + } + } catch (RuntimeException | Error e) { + logger.log(Level.SEVERE, "Execution failed", e); + } finally { + await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. } - } catch (RuntimeException | Error e) { - logger.log(Level.SEVERE, "Execution failed", e); - } finally { - await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. } - } - }); - } + }); + } - await(joinBarrier); // This dispatch thread waits for all worker threads here. - } - } // execution mode == JTP + await(joinBarrier); // This dispatch thread waits for all worker threads here. + } + passId = PASS_ID_COMPLETED_EXECUTION; + } // execution mode == JTP - return 0; + return 0; + } finally { + passId = PASS_ID_COMPLETED_EXECUTION; + } } private static void await(CyclicBarrier _barrier) { @@ -963,7 +1017,7 @@ public class KernelRunner extends KernelRunnerJNI{ } // native side will reallocate array buffers if necessary - if (runKernelJNI(jniContextHandle, _range, needSync, _passes) != 0) { + if (runKernelJNI(jniContextHandle, _range, needSync, _passes, inBufferRemote, outBufferRemote) != 0) { logger.warning("### " + describeKernelClass() + " - CL exec seems to have failed. Trying to revert to Java ###"); kernel.setFallbackExecutionMode(); return execute(_entrypointName, _range, _passes); @@ -1015,297 +1069,347 @@ public class KernelRunner extends KernelRunnerJNI{ } public synchronized Kernel execute(String _entrypointName, final Range _range, final int _passes) { + clearCancelMultiPass(); + executing = true; + try { + long executeStartTime = System.currentTimeMillis(); - long executeStartTime = System.currentTimeMillis(); - - if (_range == null) { - throw new IllegalStateException("range can't be null"); - } + if (_range == null) { + throw new IllegalStateException("range can't be null"); + } - /* for backward compatibility reasons we still honor execution mode */ - if (kernel.getExecutionMode().isOpenCL()) { - // System.out.println("OpenCL"); + /* for backward compatibility reasons we still honor execution mode */ + if (kernel.getExecutionMode().isOpenCL()) { + // System.out.println("OpenCL"); - // See if user supplied a Device - Device device = _range.getDevice(); + // See if user supplied a Device + Device device = _range.getDevice(); - if ((device == null) || (device instanceof OpenCLDevice)) { - if ((entryPoint == null) || (isFallBack)) { - if (entryPoint == null) { - try { - final ClassModel classModel = ClassModel.createClassModel(kernel.getClass()); - entryPoint = classModel.getEntrypoint(_entrypointName, kernel); - } catch (final Exception exception) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, exception); + if ((device == null) || (device instanceof OpenCLDevice)) { + if ((entryPoint == null) || (isFallBack)) { + if (entryPoint == null) { + try { + final ClassModel classModel = ClassModel.createClassModel(kernel.getClass()); + entryPoint = classModel.getEntrypoint(_entrypointName, kernel); + } catch (final Exception exception) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, exception); + } } - } - if ((entryPoint != null) && !entryPoint.shouldFallback()) { - synchronized (Kernel.class) { // This seems to be needed because of a race condition uncovered with issue #68 http://code.google.com/p/aparapi/issues/detail?id=68 - if (device != null && !(device instanceof OpenCLDevice)) { - throw new IllegalStateException("range's device is not suitable for OpenCL "); - } + if ((entryPoint != null) && !entryPoint.shouldFallback()) { + synchronized (Kernel.class) { // This seems to be needed because of a race condition uncovered with issue #68 http://code.google.com/p/aparapi/issues/detail?id=68 + if (device != null && !(device instanceof OpenCLDevice)) { + throw new IllegalStateException("range's device is not suitable for OpenCL "); + } - OpenCLDevice openCLDevice = (OpenCLDevice) device; // still might be null! + OpenCLDevice openCLDevice = (OpenCLDevice) device; // still might be null! - int jniFlags = 0; - if (openCLDevice == null) { - if (kernel.getExecutionMode().equals(EXECUTION_MODE.GPU)) { - // Get the best GPU - openCLDevice = (OpenCLDevice) OpenCLDevice.bestGPU(); - jniFlags |= JNI_FLAG_USE_GPU; // this flag might be redundant now. - if (openCLDevice == null) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "GPU request can't be honored"); - } - } else if (kernel.getExecutionMode().equals(EXECUTION_MODE.ACC)) { - // Get the best ACC - openCLDevice = (OpenCLDevice) OpenCLDevice.bestACC(); - jniFlags |= JNI_FLAG_USE_ACC; // this flag might be redundant now. - if (openCLDevice == null) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "ACC request can't be honored"); + int jniFlags = 0; + if (openCLDevice == null) { + if (kernel.getExecutionMode().equals(EXECUTION_MODE.GPU)) { + // Get the best GPU + openCLDevice = (OpenCLDevice) OpenCLDevice.bestGPU(); + jniFlags |= JNI_FLAG_USE_GPU; // this flag might be redundant now. + if (openCLDevice == null) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "GPU request can't be honored"); + } + } else if (kernel.getExecutionMode().equals(EXECUTION_MODE.ACC)) { + // Get the best ACC + openCLDevice = (OpenCLDevice) OpenCLDevice.bestACC(); + jniFlags |= JNI_FLAG_USE_ACC; // this flag might be redundant now. + if (openCLDevice == null) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "ACC request can't be honored"); + } + } else { + // We fetch the first CPU device + openCLDevice = (OpenCLDevice) OpenCLDevice.firstCPU(); + if (openCLDevice == null) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, + "CPU request can't be honored not CPU device"); + } } - } else { - // We fetch the first CPU device - openCLDevice = (OpenCLDevice) OpenCLDevice.firstCPU(); - if (openCLDevice == null) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, - "CPU request can't be honored not CPU device"); + } else { // openCLDevice == null + if (openCLDevice.getType() == Device.TYPE.GPU) { + jniFlags |= JNI_FLAG_USE_GPU; // this flag might be redundant now. + } else if (openCLDevice.getType() == Device.TYPE.ACC) { + jniFlags |= JNI_FLAG_USE_ACC; // this flag might be redundant now. } } - } else { // openCLDevice == null - if (openCLDevice.getType() == Device.TYPE.GPU) { - jniFlags |= JNI_FLAG_USE_GPU; // this flag might be redundant now. - } else if (openCLDevice.getType() == Device.TYPE.ACC) { - jniFlags |= JNI_FLAG_USE_ACC; // this flag might be redundant now. - } - } - // jniFlags |= (Config.enableProfiling ? JNI_FLAG_ENABLE_PROFILING : 0); - // jniFlags |= (Config.enableProfilingCSV ? JNI_FLAG_ENABLE_PROFILING_CSV | JNI_FLAG_ENABLE_PROFILING : 0); - // jniFlags |= (Config.enableVerboseJNI ? JNI_FLAG_ENABLE_VERBOSE_JNI : 0); - // jniFlags |= (Config.enableVerboseJNIOpenCLResourceTracking ? JNI_FLAG_ENABLE_VERBOSE_JNI_OPENCL_RESOURCE_TRACKING :0); - // jniFlags |= (kernel.getExecutionMode().equals(EXECUTION_MODE.GPU) ? JNI_FLAG_USE_GPU : 0); - // Init the device to check capabilities before emitting the - // code that requires the capabilities. + // jniFlags |= (Config.enableProfiling ? JNI_FLAG_ENABLE_PROFILING : 0); + // jniFlags |= (Config.enableProfilingCSV ? JNI_FLAG_ENABLE_PROFILING_CSV | JNI_FLAG_ENABLE_PROFILING : 0); + // jniFlags |= (Config.enableVerboseJNI ? JNI_FLAG_ENABLE_VERBOSE_JNI : 0); + // jniFlags |= (Config.enableVerboseJNIOpenCLResourceTracking ? JNI_FLAG_ENABLE_VERBOSE_JNI_OPENCL_RESOURCE_TRACKING :0); + // jniFlags |= (kernel.getExecutionMode().equals(EXECUTION_MODE.GPU) ? JNI_FLAG_USE_GPU : 0); + // Init the device to check capabilities before emitting the + // code that requires the capabilities. - // synchronized(Kernel.class){ - jniContextHandle = initJNI(kernel, openCLDevice, jniFlags); // openCLDevice will not be null here - } // end of synchronized! issue 68 + // synchronized(Kernel.class){ + jniContextHandle = initJNI(kernel, openCLDevice, jniFlags); // openCLDevice will not be null here + } // end of synchronized! issue 68 - if (jniContextHandle == 0) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "initJNI failed to return a valid handle"); - } - - final String extensions = getExtensionsJNI(jniContextHandle); - capabilitiesSet = new HashSet<String>(); - - final StringTokenizer strTok = new StringTokenizer(extensions); - while (strTok.hasMoreTokens()) { - capabilitiesSet.add(strTok.nextToken()); - } - - if (logger.isLoggable(Level.FINE)) { - logger.fine("Capabilities initialized to :" + capabilitiesSet.toString()); - } + if (jniContextHandle == 0) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "initJNI failed to return a valid handle"); + } - if (entryPoint.requiresDoublePragma() && !hasFP64Support()) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "FP64 required but not supported"); - } + final String extensions = getExtensionsJNI(jniContextHandle); + capabilitiesSet = new HashSet<String>(); - if (entryPoint.requiresByteAddressableStorePragma() && !hasByteAddressableStoreSupport()) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, - "Byte addressable stores required but not supported"); - } + final StringTokenizer strTok = new StringTokenizer(extensions); + while (strTok.hasMoreTokens()) { + capabilitiesSet.add(strTok.nextToken()); + } - final boolean all32AtomicsAvailable = hasGlobalInt32BaseAtomicsSupport() - && hasGlobalInt32ExtendedAtomicsSupport() && hasLocalInt32BaseAtomicsSupport() - && hasLocalInt32ExtendedAtomicsSupport(); + if (logger.isLoggable(Level.FINE)) { + logger.fine("Capabilities initialized to :" + capabilitiesSet.toString()); + } - if (entryPoint.requiresAtomic32Pragma() && !all32AtomicsAvailable) { + if (entryPoint.requiresDoublePragma() && !hasFP64Support()) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "FP64 required but not supported"); + } - return warnFallBackAndExecute(_entrypointName, _range, _passes, "32 bit Atomics required but not supported"); - } + if (entryPoint.requiresByteAddressableStorePragma() && !hasByteAddressableStoreSupport()) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, + "Byte addressable stores required but not supported"); + } - String openCL = null; - try { - openCL = KernelWriter.writeToString(entryPoint); - } catch (final CodeGenException codeGenException) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, codeGenException); - } + final boolean all32AtomicsAvailable = hasGlobalInt32BaseAtomicsSupport() + && hasGlobalInt32ExtendedAtomicsSupport() && hasLocalInt32BaseAtomicsSupport() + && hasLocalInt32ExtendedAtomicsSupport(); - if (Config.enableShowGeneratedOpenCL) { - System.out.println(openCL); - } + if (entryPoint.requiresAtomic32Pragma() && !all32AtomicsAvailable) { - if (logger.isLoggable(Level.INFO)) { - logger.info(openCL); - } + return warnFallBackAndExecute(_entrypointName, _range, _passes, "32 bit Atomics required but not supported"); + } - // Send the string to OpenCL to compile it - if (buildProgramJNI(jniContextHandle, openCL) == 0) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "OpenCL compile failed"); - } + String openCL = null; + try { + openCL = KernelWriter.writeToString(entryPoint); + } catch (final CodeGenException codeGenException) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, codeGenException); + } - args = new KernelArg[entryPoint.getReferencedFields().size()]; - int i = 0; + if (Config.enableShowGeneratedOpenCL) { + System.out.println(openCL); + } - for (final Field field : entryPoint.getReferencedFields()) { - try { - field.setAccessible(true); - args[i] = new KernelArg(); - args[i].setName(field.getName()); - args[i].setField(field); - if ((field.getModifiers() & Modifier.STATIC) == Modifier.STATIC) { - args[i].setType(args[i].getType() | ARG_STATIC); - } + if (logger.isLoggable(Level.INFO)) { + logger.info(openCL); + } - final Class<?> type = field.getType(); - if (type.isArray()) { + // Send the string to OpenCL to compile it + if (buildProgramJNI(jniContextHandle, openCL) == 0) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "OpenCL compile failed"); + } - if (field.getAnnotation(Local.class) != null || args[i].getName().endsWith(Kernel.LOCAL_SUFFIX)) { - args[i].setType(args[i].getType() | ARG_LOCAL); - } else if ((field.getAnnotation(Constant.class) != null) - || args[i].getName().endsWith(Kernel.CONSTANT_SUFFIX)) { - args[i].setType(args[i].getType() | ARG_CONSTANT); - } else { - args[i].setType(args[i].getType() | ARG_GLOBAL); + args = new KernelArg[entryPoint.getReferencedFields().size()]; + int i = 0; + + for (final Field field : entryPoint.getReferencedFields()) { + try { + field.setAccessible(true); + args[i] = new KernelArg(); + args[i].setName(field.getName()); + args[i].setField(field); + if ((field.getModifiers() & Modifier.STATIC) == Modifier.STATIC) { + args[i].setType(args[i].getType() | ARG_STATIC); } - if (isExplicit()) { - args[i].setType(args[i].getType() | ARG_EXPLICIT); - } - // for now, treat all write arrays as read-write, see bugzilla issue 4859 - // we might come up with a better solution later - args[i].setType(args[i].getType() - | (entryPoint.getArrayFieldAssignments().contains(field.getName()) ? (ARG_WRITE | ARG_READ) : 0)); - args[i].setType(args[i].getType() - | (entryPoint.getArrayFieldAccesses().contains(field.getName()) ? ARG_READ : 0)); - // args[i].type |= ARG_GLOBAL; - - if (type.getName().startsWith("[L")) { - args[i].setType(args[i].getType() - | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER)); - if (logger.isLoggable(Level.FINE)) { - logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); - } - } else if (type.getName().startsWith("[[")) { + final Class<?> type = field.getType(); + if (type.isArray()) { - try { - setMultiArrayType(args[i], type); - } catch (AparapiException e) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " - + args[i].getName() + ". Aparapi only supports 2D and 3D arrays."); + if (field.getAnnotation(Local.class) != null || args[i].getName().endsWith(Kernel.LOCAL_SUFFIX)) { + args[i].setType(args[i].getType() | ARG_LOCAL); + } else if ((field.getAnnotation(Constant.class) != null) + || args[i].getName().endsWith(Kernel.CONSTANT_SUFFIX)) { + args[i].setType(args[i].getType() | ARG_CONSTANT); + } else { + args[i].setType(args[i].getType() | ARG_GLOBAL); } - } else { - - args[i].setArray(null); // will get updated in updateKernelArrayRefs - args[i].setType(args[i].getType() | ARG_ARRAY); - - args[i].setType(args[i].getType() | (type.isAssignableFrom(float[].class) ? ARG_FLOAT : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(int[].class) ? ARG_INT : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(boolean[].class) ? ARG_BOOLEAN : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(byte[].class) ? ARG_BYTE : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(char[].class) ? ARG_CHAR : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(double[].class) ? ARG_DOUBLE : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(long[].class) ? ARG_LONG : 0)); - args[i].setType(args[i].getType() | (type.isAssignableFrom(short[].class) ? ARG_SHORT : 0)); - - // arrays whose length is used will have an int arg holding - // the length as a kernel param - if (entryPoint.getArrayFieldArrayLengthUsed().contains(args[i].getName())) { - args[i].setType(args[i].getType() | ARG_ARRAYLENGTH); + if (isExplicit()) { + args[i].setType(args[i].getType() | ARG_EXPLICIT); } + // for now, treat all write arrays as read-write, see bugzilla issue 4859 + // we might come up with a better solution later + args[i].setType(args[i].getType() + | (entryPoint.getArrayFieldAssignments().contains(field.getName()) ? (ARG_WRITE | ARG_READ) : 0)); + args[i].setType(args[i].getType() + | (entryPoint.getArrayFieldAccesses().contains(field.getName()) ? ARG_READ : 0)); + // args[i].type |= ARG_GLOBAL; if (type.getName().startsWith("[L")) { - args[i].setType(args[i].getType() | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)); + args[i].setType(args[i].getType() + | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER)); + if (logger.isLoggable(Level.FINE)) { - logger.fine("tagging " + args[i].getName() - + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); + logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); + } + } else if (type.getName().startsWith("[[")) { + + try { + setMultiArrayType(args[i], type); + } catch (AparapiException e) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " + + args[i].getName() + ". Aparapi only supports 2D and 3D arrays."); + } + } else { + + args[i].setArray(null); // will get updated in updateKernelArrayRefs + args[i].setType(args[i].getType() | ARG_ARRAY); + + args[i].setType(args[i].getType() | (type.isAssignableFrom(float[].class) ? ARG_FLOAT : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(int[].class) ? ARG_INT : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(boolean[].class) ? ARG_BOOLEAN : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(byte[].class) ? ARG_BYTE : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(char[].class) ? ARG_CHAR : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(double[].class) ? ARG_DOUBLE : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(long[].class) ? ARG_LONG : 0)); + args[i].setType(args[i].getType() | (type.isAssignableFrom(short[].class) ? ARG_SHORT : 0)); + + // arrays whose length is used will have an int arg holding + // the length as a kernel param + if (entryPoint.getArrayFieldArrayLengthUsed().contains(args[i].getName())) { + args[i].setType(args[i].getType() | ARG_ARRAYLENGTH); + } + + if (type.getName().startsWith("[L")) { + args[i].setType(args[i].getType() | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)); + if (logger.isLoggable(Level.FINE)) { + logger.fine("tagging " + args[i].getName() + + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); + } } } + } else if (type.isAssignableFrom(float.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_FLOAT); + } else if (type.isAssignableFrom(int.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_INT); + } else if (type.isAssignableFrom(double.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_DOUBLE); + } else if (type.isAssignableFrom(long.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_LONG); + } else if (type.isAssignableFrom(boolean.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_BOOLEAN); + } else if (type.isAssignableFrom(byte.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_BYTE); + } else if (type.isAssignableFrom(char.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_CHAR); + } else if (type.isAssignableFrom(short.class)) { + args[i].setType(args[i].getType() | ARG_PRIMITIVE); + args[i].setType(args[i].getType() | ARG_SHORT); } - } else if (type.isAssignableFrom(float.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_FLOAT); - } else if (type.isAssignableFrom(int.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_INT); - } else if (type.isAssignableFrom(double.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_DOUBLE); - } else if (type.isAssignableFrom(long.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_LONG); - } else if (type.isAssignableFrom(boolean.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_BOOLEAN); - } else if (type.isAssignableFrom(byte.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_BYTE); - } else if (type.isAssignableFrom(char.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_CHAR); - } else if (type.isAssignableFrom(short.class)) { - args[i].setType(args[i].getType() | ARG_PRIMITIVE); - args[i].setType(args[i].getType() | ARG_SHORT); + // System.out.printf("in execute, arg %d %s %08x\n", i,args[i].name,args[i].type ); + } catch (final IllegalArgumentException e) { + e.printStackTrace(); } - // System.out.printf("in execute, arg %d %s %08x\n", i,args[i].name,args[i].type ); - } catch (final IllegalArgumentException e) { - e.printStackTrace(); - } - args[i].setPrimitiveSize(getPrimitiveSize(args[i].getType())); + args[i].setPrimitiveSize(getPrimitiveSize(args[i].getType())); - if (logger.isLoggable(Level.FINE)) { - logger.fine("arg " + i + ", " + args[i].getName() + ", type=" + Integer.toHexString(args[i].getType()) - + ", primitiveSize=" + args[i].getPrimitiveSize()); - } + if (logger.isLoggable(Level.FINE)) { + logger.fine("arg " + i + ", " + args[i].getName() + ", type=" + Integer.toHexString(args[i].getType()) + + ", primitiveSize=" + args[i].getPrimitiveSize()); + } - i++; - } + i++; + } - // at this point, i = the actual used number of arguments - // (private buffers do not get treated as arguments) + // at this point, i = the actual used number of arguments + // (private buffers do not get treated as arguments) - argc = i; + argc = i; - setArgsJNI(jniContextHandle, args, argc); + setArgsJNI(jniContextHandle, args, argc); - conversionTime = System.currentTimeMillis() - executeStartTime; + conversionTime = System.currentTimeMillis() - executeStartTime; + try { + executeOpenCL(_entrypointName, _range, _passes); + isFallBack = false; + } catch (final AparapiException e) { + warnFallBackAndExecute(_entrypointName, _range, _passes, e); + } + } else { // (entryPoint != null) && !entryPoint.shouldFallback() + warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to locate entrypoint"); + } + } else { // (entryPoint == null) || (isFallBack) try { executeOpenCL(_entrypointName, _range, _passes); isFallBack = false; } catch (final AparapiException e) { warnFallBackAndExecute(_entrypointName, _range, _passes, e); } - } else { // (entryPoint != null) && !entryPoint.shouldFallback() - warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to locate entrypoint"); - } - } else { // (entryPoint == null) || (isFallBack) - try { - executeOpenCL(_entrypointName, _range, _passes); - isFallBack = false; - } catch (final AparapiException e) { - warnFallBackAndExecute(_entrypointName, _range, _passes, e); } + } else { // (device == null) || (device instanceof OpenCLDevice) + warnFallBackAndExecute(_entrypointName, _range, _passes, + "OpenCL was requested but Device supplied was not an OpenCLDevice"); } - } else { // (device == null) || (device instanceof OpenCLDevice) - warnFallBackAndExecute(_entrypointName, _range, _passes, - "OpenCL was requested but Device supplied was not an OpenCLDevice"); + } else { // kernel.getExecutionMode().isOpenCL() + executeJava(_range, _passes); + } + + if (Config.enableExecutionModeReporting) { + System.out.println(describeKernelClass() + ":" + kernel.getExecutionMode()); } - } else { // kernel.getExecutionMode().isOpenCL() - executeJava(_range, _passes); + + executionTime = System.currentTimeMillis() - executeStartTime; + accumulatedExecutionTime += executionTime; + + return kernel; + } finally { + executing = false; + clearCancelMultiPass(); } + } - if (Config.enableExecutionModeReporting) { - System.out.println(describeKernelClass() + ":" + kernel.getExecutionMode()); + public int getCancelState() { + return inBufferRemoteInt.get(0); + } + + public void cancelMultiPass() { + inBufferRemoteInt.put(0, CANCEL_STATUS_TRUE); + } + + private void clearCancelMultiPass() { + inBufferRemoteInt.put(0, CANCEL_STATUS_FALSE); + } + + /** + * Returns the index of the current pass, or one of two special constants with negative values to indicate special progress states. Those constants are + * {@link #PASS_ID_PREPARING_EXECUTION} to indicate that the Kernel has not yet started executing, or {@link #PASS_ID_COMPLETED_EXECUTION} to indicate that + * execution is complete (possibly due to early termination via {@link #cancelMultiPass()}). + * + * <p>This can be used, for instance, to update a visual progress bar. + * + * @see #execute(String, Range, int) + */ + public int getCurrentPass() { + if (!executing) { + return PASS_ID_COMPLETED_EXECUTION; + } + switch (kernel.getExecutionMode()) { + case NONE: + return PASS_ID_COMPLETED_EXECUTION; + case JTP: // fallthrough + case SEQ: + return getCurrentPassLocal(); + default: + return getCurrentPassRemote(); } + } - executionTime = System.currentTimeMillis() - executeStartTime; - accumulatedExecutionTime += executionTime; + protected int getCurrentPassRemote() { + return outBufferRemoteInt.get(0); + } - return kernel; + private int getCurrentPassLocal() { + return passId; } private int getPrimitiveSize(int type) { diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/util/swing/MultiPassKernelSwingWorker.java b/com.amd.aparapi/src/java/com/amd/aparapi/util/swing/MultiPassKernelSwingWorker.java new file mode 100644 index 0000000000000000000000000000000000000000..db33946952fa6e1b1152ddd9e8761d9084398f2e --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/util/swing/MultiPassKernelSwingWorker.java @@ -0,0 +1,84 @@ +package com.amd.aparapi.util.swing; + +import com.amd.aparapi.Kernel; +import com.amd.aparapi.internal.kernel.KernelRunner; + +import javax.swing.*; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; + +/** + * Implementation of SwingWorker to assist in progress tracking and cancellation of multi-pass {@link Kernel}s. + */ +public abstract class MultiPassKernelSwingWorker extends SwingWorker<Void, Void>{ + + public static final int DEFAULT_POLL_INTERVAL = 50; + + private Kernel kernel; + private Timer timer; + + protected MultiPassKernelSwingWorker(Kernel kernel) { + this.kernel = kernel; + } + + /** Utility method which just invokes {@link Kernel#cancelMultiPass()} on the executing kernel. */ + public void cancelExecution() { + kernel.cancelMultiPass(); + } + + /** This method must invoke one of the {@code kernel}'s execute() methods. */ + protected abstract void executeKernel(Kernel kernel); + + /** This method, which is always invoked on the swing event dispatch thread, should be used to update any components (such as a {@link javax.swing.JProgressBar}) so + * as to reflect the progress of the multi-pass Kernel being executed. + * + * @param passId The passId for the Kernel's current pass, or one of the constant fields returnable by {@link KernelRunner#getCurrentPass()}. + */ + protected abstract void updatePassId(int passId); + + /** Executes the {@link #kernel} via {@link #executeKernel(Kernel)}, whilst also managing progress updates for the kernel's passId. */ + @Override + protected final Void doInBackground() throws Exception { + try { + setUpExecution(); + executeKernel(kernel); + return null; + } + finally { + cleanUpExecution(); + } + } + + private void setUpExecution() { + ActionListener listener = new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + updatePassId(); + } + }; + timer = new Timer(getPollIntervalMillis(), listener); + timer.setCoalesce(false); + timer.start(); + } + + private void cleanUpExecution() { + timer.stop(); + timer = null; + SwingUtilities.invokeLater(new Runnable() { + @Override + public void run() { + updatePassId(KernelRunner.PASS_ID_COMPLETED_EXECUTION); + } + }); + } + + private void updatePassId() { + int progress = kernel.getCurrentPass(); + updatePassId(progress); + } + + /** The interval at which the Kernel's current passId is polled. Unless overridden, returns {@link #DEFAULT_POLL_INTERVAL}. */ + protected int getPollIntervalMillis() { + return DEFAULT_POLL_INTERVAL; + } +} diff --git a/samples/median/src/com/amd/aparapi/sample/median/MedianDemo.java b/samples/median/src/com/amd/aparapi/sample/median/MedianDemo.java index 2e938d75cfe9db93f05281895023efdf5327bd17..99fa6259ff541663b292bc0e0c29aaf6709d61c3 100644 --- a/samples/median/src/com/amd/aparapi/sample/median/MedianDemo.java +++ b/samples/median/src/com/amd/aparapi/sample/median/MedianDemo.java @@ -16,7 +16,8 @@ public class MedianDemo { static { try { - testImage = ImageIO.read(new File("C:\\dev\\aparapi_live\\aparapi\\samples\\convolution\\testcard.jpg")); + File imageFile = new File("./../../../samples/convolution/testcard.jpg").getCanonicalFile(); + testImage = ImageIO.read(imageFile); } catch (IOException e) { throw new RuntimeException(e); } @@ -25,6 +26,7 @@ public class MedianDemo { private static final boolean TEST_JTP = false; public static void main(String[] ignored) { + final int size = 5; System.setProperty("com.amd.aparapi.enableShowGeneratedOpenCL", "true"); int[] argbs = testImage.getRGB(0, 0, testImage.getWidth(), testImage.getHeight(), null, 0, testImage.getWidth()); MedianKernel7x7 kernel = new MedianKernel7x7(); @@ -36,7 +38,7 @@ public class MedianDemo { if (TEST_JTP) { kernel.setExecutionMode(Kernel.EXECUTION_MODE.JTP); } - kernel.processImages(new MedianSettings(7)); + kernel.processImages(new MedianSettings(size)); BufferedImage out = new BufferedImage(testImage.getWidth(), testImage.getHeight(), BufferedImage.TYPE_INT_RGB); out.setRGB(0, 0, testImage.getWidth(), testImage.getHeight(), kernel._destPixels, 0, testImage.getWidth()); ImageIcon icon1 = new ImageIcon(testImage); @@ -55,7 +57,7 @@ public class MedianDemo { int reps = 20; for (int rep = 0; rep < reps; ++rep) { long start = System.nanoTime(); - kernel.processImages(new MedianSettings(7)); + kernel.processImages(new MedianSettings(size)); long elapsed = System.nanoTime() - start; System.out.println("elapsed = " + elapsed / 1000000f + "ms"); } diff --git a/samples/progress/src/com/amd/aparapi/sample/progress/LongRunningKernel.java b/samples/progress/src/com/amd/aparapi/sample/progress/LongRunningKernel.java new file mode 100644 index 0000000000000000000000000000000000000000..f21f37c303039225da3fa3990ab9e8667f7005ff --- /dev/null +++ b/samples/progress/src/com/amd/aparapi/sample/progress/LongRunningKernel.java @@ -0,0 +1,31 @@ +package com.amd.aparapi.sample.progress; + +import com.amd.aparapi.Kernel; + +/** + * Kernel which performs very many meaningless calculations, used to demonstrate progress tracking and cancellation of multi-pass Kernels. + */ +public class LongRunningKernel extends Kernel { + + public static final int RANGE = 20000; + private static final int REPETITIONS = 1 * 1000 * 1000; + + public final long[] data = new long[RANGE]; + + @Override + public void run() { + int id = getGlobalId(); + if (id == 0) { + report(); + } + for (int rep = 0; rep < REPETITIONS; ++rep) { + data[id] += (int) sqrt(1); + } + } + + @NoCL + public void report() { + int passId = getPassId(); + System.out.println("Java execution: passId = " + passId); + } +} diff --git a/samples/progress/src/com/amd/aparapi/sample/progress/MultiPassKernelSwingWorkerDemo.java b/samples/progress/src/com/amd/aparapi/sample/progress/MultiPassKernelSwingWorkerDemo.java new file mode 100644 index 0000000000000000000000000000000000000000..7cc2584b1cb10d054f16632dd12ff27f2102c53b --- /dev/null +++ b/samples/progress/src/com/amd/aparapi/sample/progress/MultiPassKernelSwingWorkerDemo.java @@ -0,0 +1,132 @@ +package com.amd.aparapi.sample.progress; + +import com.amd.aparapi.Kernel; +import com.amd.aparapi.internal.kernel.KernelRunner; +import com.amd.aparapi.util.swing.MultiPassKernelSwingWorker; + +import javax.swing.*; +import javax.swing.plaf.nimbus.NimbusLookAndFeel; +import java.awt.*; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; + +/** + * Demonstrates progress tracking and cancellation for multi-pass kernels, via {@link MultiPassKernelSwingWorker}. + */ +public class MultiPassKernelSwingWorkerDemo { + + private static final int PASS_COUNT = 200; + private static JButton startButton; + private static JButton cancelButton; + private static JProgressBar progress; + private static JLabel status = new JLabel("Press Start", JLabel.CENTER); + private static LongRunningKernel kernel; + private static MultiPassKernelSwingWorker worker; + + private static final boolean TEST_JTP = true; + + public static void main(String[] ignored) throws Exception { + kernel = new LongRunningKernel(); + if (TEST_JTP) { + kernel.setExecutionMode(Kernel.EXECUTION_MODE.JTP); + } + + UIManager.setLookAndFeel(NimbusLookAndFeel.class.getName()); + JPanel rootPanel = new JPanel(); + rootPanel.setLayout(new BorderLayout()); + JPanel buttons = new JPanel(new FlowLayout(FlowLayout.CENTER)); + startButton = new JButton("Start"); + cancelButton = new JButton("Cancel"); + startButton.setEnabled(true); + startButton.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + start(); + } + }); + cancelButton.setEnabled(false); + cancelButton.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + cancel(); + } + }); + buttons.add(startButton); + buttons.add(cancelButton); + rootPanel.add(buttons, BorderLayout.SOUTH); + + progress = new JProgressBar(new DefaultBoundedRangeModel(0, 0, 0, PASS_COUNT)); + + rootPanel.add(status, BorderLayout.CENTER); + rootPanel.add(progress, BorderLayout.NORTH); + + JFrame frame = new JFrame("MultiPassKernelSwingWorker Demo"); + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + frame.getContentPane().add(rootPanel); + frame.pack(); + frame.setLocationRelativeTo(null); + frame.setVisible(true); + } + + private static MultiPassKernelSwingWorker createWorker() { + return new MultiPassKernelSwingWorker(kernel) { + @Override + protected void executeKernel(Kernel kernel) { + int range; + if (TEST_JTP) { + range = LongRunningKernel.RANGE / 1000; + } else { + range = LongRunningKernel.RANGE; + } + kernel.execute(range, PASS_COUNT); + } + + @Override + protected void updatePassId(int passId) { + updateProgress(passId); + } + + @Override + protected void done() { + updateProgress(KernelRunner.PASS_ID_COMPLETED_EXECUTION); + startButton.setEnabled(true); + cancelButton.setEnabled(false); + } + }; + } + + private static void start() { + if (!SwingUtilities.isEventDispatchThread()) { + throw new IllegalStateException(); + } + + startButton.setEnabled(false); + cancelButton.setEnabled(true); + worker = createWorker(); + worker.execute(); + System.out.println("Started execution of MultiPassKernelSwingWorker"); + } + + private static void updateProgress(int passId) { + int progressValue; + if (passId >= 0) { + progressValue = passId; + status.setText("passId = " + passId); + } else if (passId == KernelRunner.PASS_ID_PREPARING_EXECUTION) { + progressValue = 0; + status.setText("Preparing"); + } else if (passId == KernelRunner.PASS_ID_COMPLETED_EXECUTION) { + progressValue = PASS_COUNT; + status.setText("Complete"); + } else { + progressValue = 0; + status.setText("Illegal status " + passId); + } + progress.getModel().setValue(progressValue); + } + + private static void cancel() { + worker.cancelExecution(); + } +} + diff --git a/samples/progress/src/com/amd/aparapi/sample/progress/ProgressAndCancelDemo.java b/samples/progress/src/com/amd/aparapi/sample/progress/ProgressAndCancelDemo.java new file mode 100644 index 0000000000000000000000000000000000000000..b114dcac4f19b5d93e6ec82b1d84da19193fa719 --- /dev/null +++ b/samples/progress/src/com/amd/aparapi/sample/progress/ProgressAndCancelDemo.java @@ -0,0 +1,168 @@ +package com.amd.aparapi.sample.progress; + +import com.amd.aparapi.Kernel; +import com.amd.aparapi.internal.kernel.KernelRunner; + +import javax.swing.*; +import javax.swing.plaf.nimbus.NimbusLookAndFeel; +import java.awt.*; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; + +/** + * Demonstrates progress tracking and cancellation for multi-pass kernels. + */ +public class ProgressAndCancelDemo { + + private static final int PASS_COUNT = 200; + private static final int POLL_SLEEP = 50; + private static JButton startButton; + private static JButton cancelButton; + private static JProgressBar progress; + private static JLabel status = new JLabel("Press Start", JLabel.CENTER); + + private static LongRunningKernel kernel; + private static Timer timer; + + private static final boolean TEST_JTP = false; + + public static void main(String[] ignored) throws Exception { + + System.setProperty("com.amd.aparapi.enableShowGeneratedOpenCL", "true"); + System.setProperty("com.amd.aparapi.enableVerboseJNI", "true"); + System.setProperty("com.amd.aparapi.dumpFlags", "true"); + System.setProperty("com.amd.aparapi.enableVerboseJNIOpenCLResourceTracking", "true"); + System.setProperty("com.amd.aparapi.enableExecutionModeReporting", "true"); + + kernel = new LongRunningKernel(); + if (TEST_JTP) { + kernel.setExecutionMode(Kernel.EXECUTION_MODE.JTP); + } + Thread asynchReader = new Thread() { + @Override + public void run() { + while (true) { + try { + int cancelState = kernel.getCancelState(); + int passId = kernel.getCurrentPass(); + System.out.println("cancel = " + cancelState + ", passId = " + passId); + Thread.sleep(50); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + } + }; + //asynchReader.start(); + UIManager.setLookAndFeel(NimbusLookAndFeel.class.getName()); + JPanel rootPanel = new JPanel(); + rootPanel.setLayout(new BorderLayout()); + JPanel buttons = new JPanel(new FlowLayout(FlowLayout.CENTER)); + startButton = new JButton("Start"); + cancelButton = new JButton("Cancel"); + startButton.setEnabled(true); + startButton.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + start(); + } + }); + cancelButton.setEnabled(false); + cancelButton.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + cancel(); + } + }); + buttons.add(startButton); + buttons.add(cancelButton); + rootPanel.add(buttons, BorderLayout.SOUTH); + + progress = new JProgressBar(new DefaultBoundedRangeModel(0, 0, 0, PASS_COUNT)); + + rootPanel.add(status, BorderLayout.CENTER); + rootPanel.add(progress, BorderLayout.NORTH); + + JFrame frame = new JFrame("Progress and Cancel Demo"); + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); + frame.getContentPane().add(rootPanel); + frame.pack(); + frame.setLocationRelativeTo(null); + frame.setVisible(true); + } + + private static void start() { + if (!SwingUtilities.isEventDispatchThread()) { + throw new IllegalStateException(); + } + Thread executionThread = new Thread() { + @Override + public void run() { + executeKernel(); + } + }; + executionThread.start(); + updateProgress(); + timer = new Timer(POLL_SLEEP, new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + updateProgress(); + } + }); + timer.setCoalesce(false); + timer.setRepeats(true); + timer.start(); + System.out.println("Started on EDT"); + } + + private static void updateProgress() { + int passId = kernel.getCurrentPass(); + int progressValue; + if (passId >= 0) { + progressValue = passId; + status.setText("passId = " + passId); + } else if (passId == KernelRunner.PASS_ID_PREPARING_EXECUTION) { + progressValue = 0; + status.setText("Preparing"); + } else if (passId == KernelRunner.PASS_ID_COMPLETED_EXECUTION) { + progressValue = PASS_COUNT; + status.setText("Complete"); + } else { + progressValue = 0; + status.setText("Illegal status " + passId); + } + progress.getModel().setValue(progressValue); + } + + private static void cancel() { + kernel.cancelMultiPass(); + } + + private static void executeKernel() { + System.out.println("Starting execution"); + startButton.setEnabled(false); + cancelButton.setEnabled(true); + try { + int range; + if (TEST_JTP) { + range = LongRunningKernel.RANGE / 1000; + } else { + range = LongRunningKernel.RANGE; + } + kernel.execute(range, PASS_COUNT); + } catch (Throwable t) { + t.printStackTrace(); + } finally { + System.out.println("Finished execution"); + System.out.println("kernel.data[0] = " + kernel.data[0]); + if (timer != null) { + timer.stop(); + timer = null; + } + startButton.setEnabled(true); + cancelButton.setEnabled(false); + updateProgress(); + } + } + +}