diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java index e01dba3f29b6fbb83bb7fd1f1ace19e805c447a0..f09dfb892d6a50ef594b53235dbbc1e21d493f5a 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java @@ -40,6 +40,10 @@ package com.amd.aparapi; import com.amd.aparapi.annotation.*; import com.amd.aparapi.exception.*; import com.amd.aparapi.internal.kernel.*; +import com.amd.aparapi.internal.model.CacheEnabler; +import com.amd.aparapi.internal.model.ValueCache; +import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer; +import com.amd.aparapi.internal.model.ValueCache.ValueComputer; import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*; import com.amd.aparapi.internal.opencl.*; import com.amd.aparapi.internal.util.*; @@ -209,13 +213,13 @@ public abstract class Kernel implements Cloneable { int value(); } - /** - * Annotation which can be applied to either a getter (with usual java bean naming convention relative to an instance field), or to any method - * with void return type, which prevents both the method body and any calls to the method being emitted in the generated OpenCL. (In the case of a getter, the - * underlying field is used in place of the NoCL getter method.) This allows for code specialization within a java/JTP execution path, for example to - * allow logging/breakpointing when debugging, or to apply ThreadLocal processing (see {@link PrivateMemorySpace}) in java to simulate OpenCL __private - * memory. - */ + /** + * Annotation which can be applied to either a getter (with usual java bean naming convention relative to an instance field), or to any method + * with void return type, which prevents both the method body and any calls to the method being emitted in the generated OpenCL. (In the case of a getter, the + * underlying field is used in place of the NoCL getter method.) 
This allows for code specialization within a java/JTP execution path, for example to + * allow logging/breakpointing when debugging, or to apply ThreadLocal processing (see {@link PrivateMemorySpace}) in java to simulate OpenCL __private + * memory. + */ @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.METHOD, ElementType.FIELD}) public @interface NoCL { @@ -469,23 +473,11 @@ public abstract class Kernel implements Cloneable { */ public final class KernelState { - private int[] globalIds = new int[] { - 0, - 0, - 0 - }; + private int[] globalIds = new int[] {0, 0, 0}; - private int[] localIds = new int[] { - 0, - 0, - 0 - }; + private int[] localIds = new int[] {0, 0, 0}; - private int[] groupIds = new int[] { - 0, - 0, - 0 - }; + private int[] groupIds = new int[] {0, 0, 0}; private Range range; @@ -493,6 +485,8 @@ public abstract class Kernel implements Cloneable { private volatile CyclicBarrier localBarrier; + private boolean localBarrierDisabled; + /** * Default constructor */ @@ -625,6 +619,24 @@ public abstract class Kernel implements Cloneable { public void setLocalBarrier(CyclicBarrier localBarrier) { this.localBarrier = localBarrier; } + + public void awaitOnLocalBarrier() { + if (!localBarrierDisabled) { + try { + kernelState.getLocalBarrier().await(); + } catch (final InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (final BrokenBarrierException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + + public void disableLocalBarrier() { + localBarrierDisabled = true; + } } /** @@ -945,23 +957,11 @@ public abstract class Kernel implements Cloneable { // We need to be careful to also clone the KernelState worker.kernelState = worker.new KernelState(kernelState); // Qualified copy constructor - worker.kernelState.setGroupIds(new int[] { - 0, - 0, - 0 - }); + worker.kernelState.setGroupIds(new int[] {0, 0, 0}); - worker.kernelState.setLocalIds(new int[] { - 0, - 0, - 0 - }); + 
worker.kernelState.setLocalIds(new int[] {0, 0, 0}); - worker.kernelState.setGlobalIds(new int[] { - 0, - 0, - 0 - }); + worker.kernelState.setGlobalIds(new int[] {0, 0, 0}); return worker; } catch (final CloneNotSupportedException e) { @@ -1839,15 +1839,7 @@ public abstract class Kernel implements Cloneable { @OpenCLDelegate @Experimental protected final void localBarrier() { - try { - kernelState.getLocalBarrier().await(); - } catch (final InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } catch (final BrokenBarrierException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } + kernelState.awaitOnLocalBarrier(); } /** @@ -1867,6 +1859,16 @@ public abstract class Kernel implements Cloneable { "Kernel.globalBarrier() has been deprecated. It was based an incorrect understanding of OpenCL functionality."); } + @OpenCLMapping(mapTo = "hypot") + protected float hypot(final float a, final float b) { + return (float) Math.hypot(a, b); + } + + @OpenCLMapping(mapTo = "hypot") + protected double hypot(final double a, final double b) { + return Math.hypot(a, b); + } + public KernelState getKernelState() { return kernelState; } @@ -1883,11 +1885,14 @@ public abstract class Kernel implements Cloneable { * */ public synchronized long getExecutionTime() { + return prepareKernelRunner().getExecutionTime(); + } + + private KernelRunner prepareKernelRunner() { if (kernelRunner == null) { kernelRunner = new KernelRunner(this); } - - return (kernelRunner.getExecutionTime()); + return kernelRunner; } /** @@ -1902,11 +1907,7 @@ public abstract class Kernel implements Cloneable { * */ public synchronized long getAccumulatedExecutionTime() { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - return (kernelRunner.getAccumulatedExecutionTime()); + return prepareKernelRunner().getAccumulatedExecutionTime(); } /** @@ -1917,11 +1918,7 @@ public abstract class Kernel implements Cloneable { * @see 
#getAccumulatedExecutionTime(); */ public synchronized long getConversionTime() { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - return (kernelRunner.getConversionTime()); + return prepareKernelRunner().getConversionTime(); } /** @@ -1993,11 +1990,7 @@ public abstract class Kernel implements Cloneable { * */ public synchronized Kernel execute(Entry _entry, Range _range) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - return (kernelRunner.execute(_entry, _range, 1)); + return prepareKernelRunner().execute(_entry, _range, 1); } /** @@ -2025,12 +2018,7 @@ public abstract class Kernel implements Cloneable { * */ public synchronized Kernel execute(String _entrypoint, Range _range, int _passes) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - - } - - return (kernelRunner.execute(_entrypoint, _range, _passes)); + return prepareKernelRunner().execute(_entrypoint, _range, _passes); } /** @@ -2105,49 +2093,64 @@ public abstract class Kernel implements Cloneable { } private static String getReturnTypeLetter(Method meth) { - final Class<?> retClass = meth.getReturnType(); + return toClassShortNameIfAny(meth.getReturnType()); + } + + private static String toClassShortNameIfAny(final Class<?> retClass) { + if (retClass.isArray()) { + return "[" + toClassShortNameIfAny(retClass.getComponentType()); + } final String strRetClass = retClass.toString(); final String mapping = typeToLetterMap.get(strRetClass); // System.out.println("strRetClass = <" + strRetClass + ">, mapping = " + mapping); + if (mapping == null) + return "[" + retClass.getName() + ";"; return mapping; } public static String getMappedMethodName(MethodReferenceEntry _methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getProperty(mappedMethodNamesCache, _methodReferenceEntry, null); String mappedName = null; final String name = _methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8(); - for 
(final Method kernelMethod : Kernel.class.getDeclaredMethods()) { - if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { - // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int); - // for Alpha, we will just disambiguate based on the return type - if (false) { - System.out.println("kernelMethod is ... " + kernelMethod.toGenericString()); - System.out.println("returnType = " + kernelMethod.getReturnType()); - System.out.println("returnTypeLetter = " + getReturnTypeLetter(kernelMethod)); - System.out.println("kernelMethod getName = " + kernelMethod.getName()); - System.out.println("methRefName = " + name + " descriptor = " - + _methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8()); - System.out - .println("descToReturnTypeLetter = " - + descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry() - .getUTF8())); - } - if (_methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName()) - && descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8()) - .equals(getReturnTypeLetter(kernelMethod))) { - final OpenCLMapping annotation = kernelMethod.getAnnotation(OpenCLMapping.class); - final String mapTo = annotation.mapTo(); - if (!mapTo.equals("")) { - mappedName = mapTo; - // System.out.println("mapTo = " + mapTo); + Class<?> currentClass = _methodReferenceEntry.getOwnerClassModel().getClassWeAreModelling(); + while (currentClass != Object.class) { + for (final Method kernelMethod : currentClass.getDeclaredMethods()) { + if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { + // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int); + // for Alpha, we will just disambiguate based on the return type + if (false) { + System.out.println("kernelMethod is ... 
" + kernelMethod.toGenericString()); + System.out.println("returnType = " + kernelMethod.getReturnType()); + System.out.println("returnTypeLetter = " + getReturnTypeLetter(kernelMethod)); + System.out.println("kernelMethod getName = " + kernelMethod.getName()); + System.out.println("methRefName = " + name + " descriptor = " + + _methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8()); + System.out.println("descToReturnTypeLetter = " + + descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry() + .getUTF8())); + } + if (toSignature(_methodReferenceEntry).equals(toSignature(kernelMethod))) { + final OpenCLMapping annotation = kernelMethod.getAnnotation(OpenCLMapping.class); + final String mapTo = annotation.mapTo(); + if (!mapTo.equals("")) { + mappedName = mapTo; + // System.out.println("mapTo = " + mapTo); + } } } } + if (mappedName != null) + break; + currentClass = currentClass.getSuperclass(); } // System.out.println("... 
in getMappedMethodName, returning = " + mappedName); return (mappedName); } public static boolean isMappedMethod(MethodReferenceEntry methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getBoolean(mappedMethodFlags, methodReferenceEntry); boolean isMapped = false; for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { @@ -2162,6 +2165,8 @@ public abstract class Kernel implements Cloneable { } public static boolean isOpenCLDelegateMethod(MethodReferenceEntry methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getBoolean(openCLDelegateMethodFlags, methodReferenceEntry); boolean isMapped = false; for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { if (kernelMethod.isAnnotationPresent(OpenCLDelegate.class)) { @@ -2176,6 +2181,8 @@ public abstract class Kernel implements Cloneable { } public static boolean usesAtomic32(MethodReferenceEntry methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getProperty(atomic32Cache, methodReferenceEntry, false); for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { if (methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())) { @@ -2189,6 +2196,8 @@ public abstract class Kernel implements Cloneable { // For alpha release atomic64 is not supported public static boolean usesAtomic64(MethodReferenceEntry methodReferenceEntry) { + // if (CacheEnabler.areCachesEnabled()) + // return getProperty(atomic64Cache, methodReferenceEntry, false); //for (java.lang.reflect.Method kernelMethod : Kernel.class.getDeclaredMethods()) { // if (kernelMethod.isAnnotationPresent(Kernel.OpenCLMapping.class)) { // if (methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())) { @@ -2213,11 +2222,7 @@ public abstract class Kernel implements 
Cloneable { * @param _explicit (true if we want explicit memory management) */ public void setExplicit(boolean _explicit) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.setExplicit(_explicit); + prepareKernelRunner().setExplicit(_explicit); } /** @@ -2225,11 +2230,7 @@ public abstract class Kernel implements Cloneable { * @return (true if we kernel is using explicit memory management) */ public boolean isExplicit() { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - return (kernelRunner.isExplicit()); + return prepareKernelRunner().isExplicit(); } /** @@ -2238,11 +2239,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(long[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2252,11 +2249,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(long[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2266,11 +2259,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(long[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2280,11 +2269,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(double[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2294,11 +2279,7 @@ 
public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(double[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2308,11 +2289,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(double[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2322,11 +2299,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(float[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2336,11 +2309,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(float[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2350,11 +2319,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(float[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2364,11 +2329,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(int[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ 
-2378,11 +2339,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(int[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2392,11 +2349,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(int[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2406,11 +2359,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(byte[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2420,11 +2369,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(byte[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2434,11 +2379,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(byte[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2448,11 +2389,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(char[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return 
(this); } @@ -2462,11 +2399,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(char[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2476,11 +2409,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(char[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2490,11 +2419,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(boolean[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2504,11 +2429,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(boolean[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2518,11 +2439,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel put(boolean[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.put(array); + prepareKernelRunner().put(array); return (this); } @@ -2532,11 +2449,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(long[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + 
prepareKernelRunner().get(array); return (this); } @@ -2546,11 +2459,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(long[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2560,11 +2469,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(long[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2574,11 +2479,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(double[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2588,11 +2489,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(double[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2602,11 +2499,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(double[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2616,11 +2509,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(float[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - 
kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2630,11 +2519,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(float[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2644,11 +2529,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(float[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2658,11 +2539,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(int[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2672,11 +2549,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(int[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2686,11 +2559,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(int[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2700,11 +2569,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(byte[] array) { - if (kernelRunner == null) { - kernelRunner = new 
KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2714,11 +2579,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(byte[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2728,11 +2589,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(byte[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2742,11 +2599,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(char[] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2756,11 +2609,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(char[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2770,11 +2619,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(char[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2784,11 +2629,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(boolean[] array) { - if (kernelRunner == null) { - 
kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2798,11 +2639,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(boolean[][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2812,11 +2649,7 @@ public abstract class Kernel implements Cloneable { * @return This kernel so that we can use the 'fluent' style API */ public Kernel get(boolean[][][] array) { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - kernelRunner.get(array); + prepareKernelRunner().get(array); return (this); } @@ -2825,11 +2658,7 @@ public abstract class Kernel implements Cloneable { * @return A list of ProfileInfo records */ public List<ProfileInfo> getProfileInfo() { - if (kernelRunner == null) { - kernelRunner = new KernelRunner(this); - } - - return (kernelRunner.getProfileInfo()); + return prepareKernelRunner().getProfileInfo(); } private final LinkedHashSet<EXECUTION_MODE> executionModes = EXECUTION_MODE.getDefaultExecutionModes(); @@ -2864,4 +2693,131 @@ public abstract class Kernel implements Cloneable { executionMode = currentMode.next(); } } + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> mappedMethodFlags = markedWith(OpenCLMapping.class); + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> openCLDelegateMethodFlags = markedWith(OpenCLDelegate.class); + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> atomic32Cache = cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>() { + @Override + public Map<String, Boolean> compute(Class<?> key) { + Map<String, Boolean> properties = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + if 
(isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) { + properties.put(toSignature(method), method.getAnnotation(OpenCLMapping.class).atomic32()); + } + } + return properties; + } + }); + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> atomic64Cache = cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>() { + @Override + public Map<String, Boolean> compute(Class<?> key) { + Map<String, Boolean> properties = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) { + properties.put(toSignature(method), method.getAnnotation(OpenCLMapping.class).atomic64()); + } + } + return properties; + } + }); + + private static boolean getBoolean(ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> methodNamesCache, + MethodReferenceEntry methodReferenceEntry) { + return getProperty(methodNamesCache, methodReferenceEntry, false); + } + + private static <A extends Annotation> ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> markedWith( + final Class<A> annotationClass) { + return cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>() { + @Override + public Map<String, Boolean> compute(Class<?> key) { + Map<String, Boolean> markedMethodNames = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + markedMethodNames.put(toSignature(method), method.isAnnotationPresent(annotationClass)); + } + return markedMethodNames; + } + }); + } + + static String toSignature(Method method) { + return method.getName() + getArgumentsLetters(method) + getReturnTypeLetter(method); + } + + private static String getArgumentsLetters(Method method) { + StringBuilder sb = new StringBuilder("("); + for (Class<?> parameterClass : method.getParameterTypes()) { + sb.append(toClassShortNameIfAny(parameterClass)); + } + sb.append(")"); + return sb.toString(); + } + + private static boolean 
isRelevant(Method method) { + return !method.isSynthetic() && !method.isBridge(); + } + + private static <V, T extends Throwable> V getProperty(ValueCache<Class<?>, Map<String, V>, T> cache, + MethodReferenceEntry methodReferenceEntry, V defaultValue) throws T { + Map<String, V> map = cache.computeIfAbsent(methodReferenceEntry.getOwnerClassModel().getClassWeAreModelling()); + String key = toSignature(methodReferenceEntry); + if (map.containsKey(key)) + return map.get(key); + return defaultValue; + } + + private static String toSignature(MethodReferenceEntry methodReferenceEntry) { + NameAndTypeEntry nameAndTypeEntry = methodReferenceEntry.getNameAndTypeEntry(); + return nameAndTypeEntry.getNameUTF8Entry().getUTF8() + nameAndTypeEntry.getDescriptorUTF8Entry().getUTF8(); + } + + private static final ValueCache<Class<?>, Map<String, String>, RuntimeException> mappedMethodNamesCache = cacheProperty(new ValueComputer<Class<?>, Map<String, String>>() { + @Override + public Map<String, String> compute(Class<?> key) { + Map<String, String> properties = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) { + // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int); + final OpenCLMapping annotation = method.getAnnotation(OpenCLMapping.class); + final String mapTo = annotation.mapTo(); + if (mapTo != null && !mapTo.equals("")) { + properties.put(toSignature(method), mapTo); + } + } + } + return properties; + } + }); + + private static <K, V, T extends Throwable> ValueCache<Class<?>, Map<K, V>, T> cacheProperty( + final ThrowingValueComputer<Class<?>, Map<K, V>, T> throwingValueComputer) { + return ValueCache.on(new ThrowingValueComputer<Class<?>, Map<K, V>, T>() { + @Override + public Map<K, V> compute(Class<?> key) throws T { + Map<K, V> properties = new HashMap<>(); + Deque<Class<?>> superclasses = new ArrayDeque<>(); + 
Class<?> currentSuperClass = key; + do { + superclasses.push(currentSuperClass); + currentSuperClass = currentSuperClass.getSuperclass(); + } while (currentSuperClass != Object.class); + for (Class<?> clazz : superclasses) { + // Overwrite property values for shadowed/overridden methods + properties.putAll(throwingValueComputer.compute(clazz)); + } + return properties; + } + }); + } + + public static void invalidateCaches() { + atomic32Cache.invalidate(); + atomic64Cache.invalidate(); + mappedMethodFlags.invalidate(); + mappedMethodNamesCache.invalidate(); + openCLDelegateMethodFlags.invalidate(); + } } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/Instruction.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/Instruction.java index cbb30fde717f1ac15743aa255e91b41aae54481c..68a7688a0b859c41e35f4ae6856c9ae6d2ade3f1 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/Instruction.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/Instruction.java @@ -181,7 +181,7 @@ public abstract class Instruction{ } @Override public String toString() { - return (String.format("%d %s", pc, byteCode.getName())); + return pc + " " + byteCode.getName(); } public boolean isBranch() { diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/InstructionSet.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/InstructionSet.java index 4faf88dfeb972909ecce0d1961ab1275f50f67b1..976ef32dd287cca4ca91009d8c239ddc0d4e278a 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/InstructionSet.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/instruction/InstructionSet.java @@ -629,6 +629,8 @@ public class InstructionSet{ private StoreSpec storeSpec; + private Constructor<?> constructor; + private ByteCode(Class<?> _class, LoadSpec _loadSpec, StoreSpec _storeSpec, ImmediateSpec _immediate, PopSpec _pop, PushSpec _push, Operator _operator) { 
clazz = _class; @@ -639,6 +641,21 @@ public class InstructionSet{ loadSpec = _loadSpec; storeSpec = _storeSpec; + if (clazz != null) { + + try { + constructor = clazz.getDeclaredConstructor(MethodModel.class, ByteReader.class, boolean.class); + } catch (final SecurityException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (final NoSuchMethodException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (final IllegalArgumentException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } } private ByteCode(Class<?> _class, ImmediateSpec _immediate) { @@ -731,18 +748,13 @@ public class InstructionSet{ public Instruction newInstruction(MethodModel _methodModel, ByteReader byteReader, boolean _isWide) { Instruction newInstruction = null; - if (clazz != null) { - + if (constructor != null) { try { - final Constructor<?> constructor = clazz.getDeclaredConstructor(MethodModel.class, ByteReader.class, boolean.class); newInstruction = (Instruction) constructor.newInstance(_methodModel, byteReader, _isWide); newInstruction.setLength(byteReader.getOffset() - newInstruction.getThisPC()); } catch (final SecurityException e) { // TODO Auto-generated catch block e.printStackTrace(); - } catch (final NoSuchMethodException e) { - // TODO Auto-generated catch block - e.printStackTrace(); } catch (final IllegalArgumentException e) { // TODO Auto-generated catch block e.printStackTrace(); diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java index e51c9ef59bdc0d0dfdb8b4b71d9a0fee084c5452..ad66f9b39f69d8cc77473941971d7efefcb321b5 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java @@ -48,8 +48,10 @@ import java.util.Set; import java.util.StringTokenizer; import 
java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.Executors; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory; +import java.util.concurrent.ForkJoinPool.ManagedBlocker; +import java.util.concurrent.ForkJoinWorkerThread; import java.util.logging.Level; import java.util.logging.Logger; @@ -100,10 +102,20 @@ public class KernelRunner extends KernelRunnerJNI{ private Entrypoint entryPoint; private int argc; - - private boolean isFallBack = false; // If isFallBack, rebuild the kernel (necessary?) - - private final ExecutorService threadPool = Executors.newCachedThreadPool(); + + private boolean isFallBack = false; // If isFallBack, rebuild the kernel (necessary?) + + private static final ForkJoinWorkerThreadFactory lowPriorityThreadFactory = new ForkJoinWorkerThreadFactory(){ + @Override public ForkJoinWorkerThread newThread(ForkJoinPool pool) { + ForkJoinWorkerThread newThread = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool); + newThread.setPriority(Thread.MIN_PRIORITY); + return newThread; + } + }; + + private static final ForkJoinPool threadPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors(), + lowPriorityThreadFactory, null, false); + /** * Create a KernelRunner for a specific Kernel instance. 
* @@ -122,7 +134,8 @@ public class KernelRunner extends KernelRunnerJNI{ if (kernel.getExecutionMode().isOpenCL()) { disposeJNI(jniContextHandle); } - threadPool.shutdownNow(); + // We are using a shared pool, so there's no need to shut it down when the kernel is disposed + // threadPool.shutdownNow(); } private Set<String> capabilitiesSet; @@ -217,6 +230,50 @@ public class KernelRunner extends KernelRunnerJNI{ return capabilitiesSet.contains(OpenCL.CL_KHR_GL_SHARING); } + private static final class FJSafeCyclicBarrier extends CyclicBarrier{ + FJSafeCyclicBarrier(final int threads) { + super(threads); + } + + @Override public int await() throws InterruptedException, BrokenBarrierException { + class Awaiter implements ManagedBlocker{ + private int value; + + private boolean released; + + @Override public boolean block() throws InterruptedException { + try { + value = superAwait(); + released = true; + return true; + } catch (final BrokenBarrierException e) { + throw new RuntimeException(e); + } + } + + @Override public boolean isReleasable() { + return released; + } + + int getValue() { + return value; + } + } + final Awaiter awaiter = new Awaiter(); + ForkJoinPool.managedBlock(awaiter); + return awaiter.getValue(); + } + + int superAwait() throws InterruptedException, BrokenBarrierException { + return super.await(); + } + } + + // @FunctionalInterface + private interface ThreadIdSetter{ + void set(KernelState kernelState, int globalGroupId, int threadId); + } + /** * Execute using a Java thread pool. Either because we were explicitly asked to do so, or because we 'fall back' after discovering an OpenCL issue. * @@ -226,11 +283,15 @@ public class KernelRunner extends KernelRunnerJNI{ * The # of passes requested by the user (via <code>Kernel.execute(globalSize, passes)</code>). Note this is usually defaulted to 1 via <code>Kernel.execute(globalSize)</code>. 
* @return */ - private long executeJava(final Range _range, final int _passes) { + protected long executeJava(final Range _range, final int _passes) { if (logger.isLoggable(Level.FINE)) { logger.fine("executeJava: range = " + _range); } + final int localSize0 = _range.getLocalSize(0); + final int localSize1 = _range.getLocalSize(1); + final int localSize2 = _range.getLocalSize(2); + final int globalSize1 = _range.getGlobalSize(1); if (kernel.getExecutionMode().equals(EXECUTION_MODE.SEQ)) { /** * SEQ mode is useful for testing trivial logic, but kernels which use SEQ mode cannot be used if the @@ -240,7 +301,7 @@ public class KernelRunner extends KernelRunnerJNI{ * * So we need to check if the range is valid here. If not we have no choice but to punt. */ - if ((_range.getLocalSize(0) * _range.getLocalSize(1) * _range.getLocalSize(2)) > 1) { + if ((localSize0 * localSize1 * localSize2) > 1) { throw new IllegalStateException("Can't run range with group size >1 sequentially. Barriers would deadlock!"); } @@ -254,7 +315,7 @@ public class KernelRunner extends KernelRunnerJNI{ kernelState.setLocalId(0, 0); kernelState.setLocalId(1, 0); kernelState.setLocalId(2, 0); - kernelState.setLocalBarrier(new CyclicBarrier(1)); + kernelState.setLocalBarrier(new FJSafeCyclicBarrier(1)); for (int passId = 0; passId < _passes; passId++) { kernelState.setPassId(passId); @@ -268,7 +329,7 @@ public class KernelRunner extends KernelRunnerJNI{ for (int x = 0; x < _range.getGlobalSize(0); x++) { kernelState.setGlobalId(0, x); - for (int y = 0; y < _range.getGlobalSize(1); y++) { + for (int y = 0; y < globalSize1; y++) { kernelState.setGlobalId(1, y); kernelClone.run(); } @@ -277,7 +338,7 @@ public class KernelRunner extends KernelRunnerJNI{ for (int x = 0; x < _range.getGlobalSize(0); x++) { kernelState.setGlobalId(0, x); - for (int y = 0; y < _range.getGlobalSize(1); y++) { + for (int y = 0; y < globalSize1; y++) { kernelState.setGlobalId(1, y); for (int z = 0; z < _range.getGlobalSize(2); 
z++) { @@ -291,13 +352,15 @@ public class KernelRunner extends KernelRunnerJNI{ } } } else { - final int threads = _range.getLocalSize(0) * _range.getLocalSize(1) * _range.getLocalSize(2); - final int globalGroups = _range.getNumGroups(0) * _range.getNumGroups(1) * _range.getNumGroups(2); + final int threads = localSize0 * localSize1 * localSize2; + final int numGroups0 = _range.getNumGroups(0); + final int numGroups1 = _range.getNumGroups(1); + final int globalGroups = numGroups0 * numGroups1 * _range.getNumGroups(2); /** * This joinBarrier is the barrier that we provide for the kernel threads to rendezvous with the current dispatch thread. * So this barrier is threadCount+1 wide (the +1 is for the dispatch thread) */ - final CyclicBarrier joinBarrier = new CyclicBarrier(threads + 1); + final CyclicBarrier joinBarrier = new FJSafeCyclicBarrier(threads + 1); /** * This localBarrier is only ever used by the kernels. If the kernel does not use the barrier the threads @@ -310,8 +373,130 @@ public class KernelRunner extends KernelRunnerJNI{ * * This barrier is threadCount wide. We never hit the barrier from the dispatch thread. 
*/ - final CyclicBarrier localBarrier = new CyclicBarrier(threads); + final CyclicBarrier localBarrier = new FJSafeCyclicBarrier(threads); + + final ThreadIdSetter threadIdSetter; + + if (_range.getDims() == 1) { + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); + kernelState.setGlobalId(0, (threadId + (globalGroupId * threads))); + kernelState.setGroupId(0, globalGroupId); + } + }; + } else if (_range.getDims() == 2) { + + /** + * Consider a 12x4 grid of 4*2 local groups + * <pre> + * threads = 4*2 = 8 + * localWidth=4 + * localHeight=2 + * globalWidth=12 + * globalHeight=4 + * + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 + * 12 13 14 15 | 16 17 18 19 | 20 21 22 23 + * ------------+-------------+------------ + * 24 25 26 27 | 28 29 30 31 | 32 33 34 35 + * 36 37 38 39 | 40 41 42 43 | 44 45 46 47 + * + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6 + * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 + * ------------+-------------+------------ + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 + * + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6 + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * ------------+-------------+------------ + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6 + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 + * ------------+-------------+------------ + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * + * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads; + * 04 05 06 07 | 12 13 14 15 | 20 21 22 23 + * ------------+-------------+------------ + * 24 25 26 27 | 32[33]34 35 | 40 41 42 43 + * 28 29 30 31 | 36 37 38 39 | 44 45 46 47 + * + * 00 
01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1) + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * ------------+-------------+------------ + * 00 01 02 03 | 00[01]02 03 | 00 01 02 03 + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0) + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * ------------+-------------+------------ + * 00 00 00 00 | 00[00]00 00 | 00 00 00 00 + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX= + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3) + * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1) + * 00 01 02 03 | 04[05]06 07 | 08 09 10 11 + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5) + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * ------------+-------------+------------ + * 02 02 02 02 | 02[02]02 02 | 02 02 02 02 + * 03 03 03 03 | 03 03 03 03 | 03 03 03 03 + * + * </pre> + * Assume we are trying to locate the id's for #33 + * + */ + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); // threadId % localWidth = (for 33 = 1 % 4 = 1) + kernelState.setLocalId(1, (threadId / localSize0)); // threadId / localWidth = (for 33 = 1 / 4 == 0) + + final int groupInset = globalGroupId % numGroups0; // 4%3 = 1 + kernelState.setGlobalId(0, ((groupInset * localSize0) + kernelState.getLocalIds()[0])); // 1*4+1=5 + + final int completeLines = (globalGroupId / numGroups0) * localSize1;// (4/3) * 2 + kernelState.setGlobalId(1, (completeLines + 
kernelState.getLocalIds()[1])); // 2+0 = 2 + kernelState.setGroupId(0, (globalGroupId % numGroups0)); + kernelState.setGroupId(1, (globalGroupId / numGroups0)); + } + }; + } else if (_range.getDims() == 3) { + //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); + + kernelState.setLocalId(1, ((threadId / localSize0) % localSize1)); + + // the thread id's span WxHxD so threadId/(WxH) should yield the local depth + kernelState.setLocalId(2, (threadId / (localSize0 * localSize1))); + + kernelState.setGlobalId(0, (((globalGroupId % numGroups0) * localSize0) + kernelState.getLocalIds()[0])); + kernelState.setGlobalId(1, + ((((globalGroupId / numGroups0) * localSize1) % globalSize1) + kernelState.getLocalIds()[1])); + + kernelState.setGlobalId(2, + (((globalGroupId / (numGroups0 * numGroups1)) * localSize2) + kernelState.getLocalIds()[2])); + + kernelState.setGroupId(0, (globalGroupId % numGroups0)); + kernelState.setGroupId(1, ((globalGroupId / numGroups0) % numGroups1)); + kernelState.setGroupId(2, (globalGroupId / (numGroups0 * numGroups1))); + } + }; + } else + throw new IllegalArgumentException("Expected 1,2 or 3 dimensions, found " + _range.getDims()); for (int passId = 0; passId < _passes; passId++) { /** * Note that we emulate OpenCL by creating one thread per localId (across the group). 
@@ -354,135 +539,31 @@ public class KernelRunner extends KernelRunnerJNI{ */ final Kernel kernelClone = kernel.clone(); final KernelState kernelState = kernelClone.getKernelState(); - kernelState.setRange(_range); - kernelState.setLocalBarrier(localBarrier); kernelState.setPassId(passId); - threadPool.submit(new Runnable(){ - @Override public void run() { - for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) { - - if (_range.getDims() == 1) { - kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); - kernelState.setGlobalId(0, (threadId + (globalGroupId * threads))); - kernelState.setGroupId(0, globalGroupId); - } else if (_range.getDims() == 2) { - - /** - * Consider a 12x4 grid of 4*2 local groups - * <pre> - * threads = 4*2 = 8 - * localWidth=4 - * localHeight=2 - * globalWidth=12 - * globalHeight=4 - * - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 - * 12 13 14 15 | 16 17 18 19 | 20 21 22 23 - * ------------+-------------+------------ - * 24 25 26 27 | 28 29 30 31 | 32 33 34 35 - * 36 37 38 39 | 40 41 42 43 | 44 45 46 47 - * - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6 - * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 - * ------------+-------------+------------ - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 - * - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6 - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * ------------+-------------+------------ - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6 - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 - * ------------+-------------+------------ - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * - * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads; - * 04 05 06 07 | 12 13 14 15 | 20 21 22 23 - * ------------+-------------+------------ - * 24 25 26 27 
| 32[33]34 35 | 40 41 42 43 - * 28 29 30 31 | 36 37 38 39 | 44 45 46 47 - * - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1) - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * ------------+-------------+------------ - * 00 01 02 03 | 00[01]02 03 | 00 01 02 03 - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0) - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * ------------+-------------+------------ - * 00 00 00 00 | 00[00]00 00 | 00 00 00 00 - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX= - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3) - * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1) - * 00 01 02 03 | 04[05]06 07 | 08 09 10 11 - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5) - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * ------------+-------------+------------ - * 02 02 02 02 | 02[02]02 02 | 02 02 02 02 - * 03 03 03 03 | 03 03 03 03 | 03 03 03 03 - * - * </pre> - * Assume we are trying to locate the id's for #33 - * - */ - - kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); // threadId % localWidth = (for 33 = 1 % 4 = 1) - kernelState.setLocalId(1, (threadId / _range.getLocalSize(0))); // threadId / localWidth = (for 33 = 1 / 4 == 0) - - final int groupInset = globalGroupId % _range.getNumGroups(0); // 4%3 = 1 - kernelState.setGlobalId(0, ((groupInset * _range.getLocalSize(0)) + kernelState.getLocalIds()[0])); // 1*4+1=5 - - final int completeLines = (globalGroupId / _range.getNumGroups(0)) * _range.getLocalSize(1);// (4/3) * 2 - kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2 
- kernelState.setGroupId(0, (globalGroupId % _range.getNumGroups(0))); - kernelState.setGroupId(1, (globalGroupId / _range.getNumGroups(0))); - } else if (_range.getDims() == 3) { - - //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code - - kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); - - kernelState.setLocalId(1, ((threadId / _range.getLocalSize(0)) % _range.getLocalSize(1))); - - // the thread id's span WxHxD so threadId/(WxH) should yield the local depth - kernelState.setLocalId(2, (threadId / (_range.getLocalSize(0) * _range.getLocalSize(1)))); - - kernelState.setGlobalId( - 0, - (((globalGroupId % _range.getNumGroups(0)) * _range.getLocalSize(0)) + kernelState.getLocalIds()[0])); - - kernelState.setGlobalId( - 1, - ((((globalGroupId / _range.getNumGroups(0)) * _range.getLocalSize(1)) % _range.getGlobalSize(1)) + kernelState - .getLocalIds()[1])); - - kernelState.setGlobalId( - 2, - (((globalGroupId / (_range.getNumGroups(0) * _range.getNumGroups(1))) * _range.getLocalSize(2)) + kernelState - .getLocalIds()[2])); - - kernelState.setGroupId(0, (globalGroupId % _range.getNumGroups(0))); - kernelState.setGroupId(1, ((globalGroupId / _range.getNumGroups(0)) % _range.getNumGroups(1))); - kernelState.setGroupId(2, (globalGroupId / (_range.getNumGroups(0) * _range.getNumGroups(1)))); - } - - kernelClone.run(); - } + if (threads == 1) { + kernelState.disableLocalBarrier(); + } else { + kernelState.setLocalBarrier(localBarrier); + } - await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. 
- } - }); + threadPool.submit( + // () -> { + new Runnable(){ + public void run() { + try { + for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) { + threadIdSetter.set(kernelState, globalGroupId, threadId); + kernelClone.run(); + } + } catch (RuntimeException | Error e) { + logger.log(Level.SEVERE, "Execution failed", e); + } finally { + await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. + } + } + }); } await(joinBarrier); // This dispatch thread waits for all worker threads here. @@ -521,7 +602,7 @@ public class KernelRunner extends KernelRunnerJNI{ boolean didReallocate = false; if (arg.getObjArrayElementModel() == null) { - final String tmp = arrayClass.getName().substring(2).replace("/", "."); + final String tmp = arrayClass.getName().substring(2).replace('/', '.'); final String arrayClassInDotForm = tmp.substring(0, tmp.length() - 1); if (logger.isLoggable(Level.FINE)) { @@ -788,7 +869,8 @@ public class KernelRunner extends KernelRunnerJNI{ } if (privateMemorySize != null) { if (arrayLength > privateMemorySize) { - throw new IllegalStateException("__private array field " + fieldName + " has illegal length " + arrayLength + " > " + privateMemorySize); + throw new IllegalStateException("__private array field " + fieldName + " has illegal length " + arrayLength + + " > " + privateMemorySize); } } @@ -947,7 +1029,7 @@ public class KernelRunner extends KernelRunnerJNI{ if ((entryPoint == null) || (isFallBack)) { if (entryPoint == null) { try { - final ClassModel classModel = new ClassModel(kernel.getClass()); + final ClassModel classModel = ClassModel.createClassModel(kernel.getClass()); entryPoint = classModel.getEntrypoint(_entrypointName, kernel); } catch (final Exception exception) { return warnFallBackAndExecute(_entrypointName, _range, _passes, exception); @@ -1035,7 +1117,6 @@ public class KernelRunner extends KernelRunnerJNI{ && hasGlobalInt32ExtendedAtomicsSupport() && 
hasLocalInt32BaseAtomicsSupport() && hasLocalInt32ExtendedAtomicsSupport(); - if (entryPoint.requiresAtomic32Pragma() && !all32AtomicsAvailable) { return warnFallBackAndExecute(_entrypointName, _range, _passes, "32 bit Atomics required but not supported"); @@ -1096,13 +1177,9 @@ public class KernelRunner extends KernelRunnerJNI{ | (entryPoint.getArrayFieldAccesses().contains(field.getName()) ? ARG_READ : 0)); // args[i].type |= ARG_GLOBAL; - if (type.getName().startsWith("[L")) { args[i].setType(args[i].getType() - | (ARG_OBJ_ARRAY_STRUCT | - ARG_WRITE | - ARG_READ | - ARG_APARAPI_BUFFER)); + | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER)); if (logger.isLoggable(Level.FINE)) { logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); @@ -1111,8 +1188,9 @@ public class KernelRunner extends KernelRunnerJNI{ try { setMultiArrayType(args[i], type); - } catch(AparapiException e) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " + args[i].getName() + ". Aparapi only supports 2D and 3D arrays."); + } catch (AparapiException e) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " + + args[i].getName() + ". 
Aparapi only supports 2D and 3D arrays."); } } else { @@ -1137,7 +1215,8 @@ public class KernelRunner extends KernelRunnerJNI{ if (type.getName().startsWith("[L")) { args[i].setType(args[i].getType() | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)); if (logger.isLoggable(Level.FINE)) { - logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); + logger.fine("tagging " + args[i].getName() + + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); } } } @@ -1179,11 +1258,11 @@ public class KernelRunner extends KernelRunnerJNI{ } i++; - } + } // at this point, i = the actual used number of arguments // (private buffers do not get treated as arguments) - + argc = i; setArgsJNI(jniContextHandle, args, argc); @@ -1225,7 +1304,6 @@ public class KernelRunner extends KernelRunnerJNI{ return kernel; } - private int getPrimitiveSize(int type) { if ((type & ARG_FLOAT) != 0) { return 4; @@ -1250,28 +1328,28 @@ public class KernelRunner extends KernelRunnerJNI{ private void setMultiArrayType(KernelArg arg, Class<?> type) throws AparapiException { arg.setType(arg.getType() | (ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER)); int numDims = 0; - while(type.getName().startsWith("[[[[")) { + while (type.getName().startsWith("[[[[")) { throw new AparapiException("Aparapi only supports 2D and 3D arrays."); } arg.setType(arg.getType() | ARG_ARRAYLENGTH); - while(type.getName().charAt(numDims) == '[') { + while (type.getName().charAt(numDims) == '[') { numDims++; } Object buffer = new Object(); try { buffer = arg.getField().get(kernel); - } catch(IllegalAccessException e) { + } catch (IllegalAccessException e) { e.printStackTrace(); } arg.setJavaBuffer(buffer); arg.setNumDims(numDims); Object subBuffer = buffer; int[] dims = new int[numDims]; - for(int i = 0; i < numDims-1; i++) { + for (int i = 0; i < numDims - 1; i++) { dims[i] = Array.getLength(subBuffer); subBuffer = Array.get(subBuffer, 0); } - dims[numDims-1] = Array.getLength(subBuffer); + 
dims[numDims - 1] = Array.getLength(subBuffer); arg.setDims(dims); if (subBuffer.getClass().isAssignableFrom(float[].class)) { @@ -1300,7 +1378,7 @@ public class KernelRunner extends KernelRunnerJNI{ } int primitiveSize = getPrimitiveSize(arg.getType()); int totalElements = 1; - for(int i = 0; i < numDims; i++) { + for (int i = 0; i < numDims; i++) { totalElements *= dims[i]; } arg.setSizeInBytes(totalElements * primitiveSize); @@ -1327,14 +1405,16 @@ public class KernelRunner extends KernelRunnerJNI{ */ public void get(Object array) { if (explicit - && ((kernel.getExecutionMode() == Kernel.EXECUTION_MODE.GPU) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.ACC) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.CPU))) { + && ((kernel.getExecutionMode() == Kernel.EXECUTION_MODE.GPU) + || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.ACC) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.CPU))) { // Only makes sense when we are using OpenCL getJNI(jniContextHandle, array); } } public List<ProfileInfo> getProfileInfo() { - if (((kernel.getExecutionMode() == Kernel.EXECUTION_MODE.GPU) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.ACC) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.CPU))) { + if (((kernel.getExecutionMode() == Kernel.EXECUTION_MODE.GPU) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.ACC) || (kernel + .getExecutionMode() == Kernel.EXECUTION_MODE.CPU))) { // Only makes sense when we are using OpenCL return (getProfileInfoJNI(jniContextHandle)); } else { @@ -1359,7 +1439,8 @@ public class KernelRunner extends KernelRunnerJNI{ public void put(Object array) { if (explicit - && ((kernel.getExecutionMode() == Kernel.EXECUTION_MODE.GPU) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.ACC) || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.CPU))) { + && ((kernel.getExecutionMode() == Kernel.EXECUTION_MODE.GPU) + || (kernel.getExecutionMode() == Kernel.EXECUTION_MODE.ACC) || (kernel.getExecutionMode() == 
Kernel.EXECUTION_MODE.CPU))) { // Only makes sense when we are using OpenCL puts.add(array); } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java new file mode 100644 index 0000000000000000000000000000000000000000..000fe014a8aefcba89ad436caef2573848f59fd7 --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java @@ -0,0 +1,20 @@ +package com.amd.aparapi.internal.model; + +import com.amd.aparapi.Kernel; + +public class CacheEnabler{ + private static volatile boolean cachesEnabled = true; + + public static void setCachesEnabled(boolean cachesEnabled) { + if (CacheEnabler.cachesEnabled != cachesEnabled) { + Kernel.invalidateCaches(); + ClassModel.invalidateCaches(); + } + + CacheEnabler.cachesEnabled = cachesEnabled; + } + + public static boolean areCachesEnabled() { + return cachesEnabled; + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java index e1b269fafea752e40943b6e98a0c9d3749f327d1..95c362a357f0b09911f7f433d426809073bd67ec 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java @@ -41,6 +41,7 @@ import com.amd.aparapi.*; import com.amd.aparapi.internal.annotation.*; import com.amd.aparapi.internal.exception.*; import com.amd.aparapi.internal.instruction.InstructionSet.*; +import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer; import com.amd.aparapi.internal.model.ClassModel.AttributePool.*; import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*; import com.amd.aparapi.internal.reader.*; @@ -122,10 +123,31 @@ public class ClassModel{ private ClassModel superClazz = null; - private HashSet<String> noClMethods = null; + // private Memoizer<Set<String>> noClMethods = 
Memoizer.of(this::computeNoCLMethods); + private Memoizer<Set<String>> noClMethods = Memoizer.Impl.of(new Supplier<Set<String>>(){ + @Override + public Set<String> get() { + return computeNoCLMethods(); + } + }); - private HashMap<String, Kernel.PrivateMemorySpace> privateMemoryFields = null; + // private Memoizer<Map<String, Kernel.PrivateMemorySpace>> privateMemoryFields = Memoizer.of(this::computePrivateMemoryFields); + private Memoizer<Map<String, Kernel.PrivateMemorySpace>> privateMemoryFields = Memoizer.Impl + .of(new Supplier<Map<String, Kernel.PrivateMemorySpace>>(){ + @Override + public Map<String, Kernel.PrivateMemorySpace> get() { + return computePrivateMemoryFields(); + } + }); + // private ValueCache<String, Integer, ClassParseException> privateMemorySizes = ValueCache.on(this::computePrivateMemorySize); + private ValueCache<String, Integer, ClassParseException> privateMemorySizes = ValueCache + .on(new ThrowingValueComputer<String, Integer, ClassParseException>(){ + @Override + public Integer compute(String fieldName) throws ClassParseException { + return computePrivateMemorySize(fieldName); + } + }); /** * Create a ClassModel representing a given Class. 
@@ -137,7 +159,7 @@ public class ClassModel{ * @throws ClassParseException */ - public ClassModel(Class<?> _class) throws ClassParseException { + private ClassModel(Class<?> _class) throws ClassParseException { parse(_class); @@ -147,7 +169,7 @@ public class ClassModel{ // not occur in normal use if ((mySuper != null) && (!mySuper.getName().equals(Kernel.class.getName())) && (!mySuper.getName().equals("java.lang.Object"))) { - superClazz = new ClassModel(mySuper); + superClazz = createClassModel(mySuper); } } @@ -204,7 +226,8 @@ public class ClassModel{ return superClazz; } - @DocMe public void replaceSuperClazz(ClassModel c) { + @DocMe + public void replaceSuperClazz(ClassModel c) { if (superClazz != null) { assert c.isSuperClass(getClassWeAreModelling()) == true : "not my super"; if (superClazz.getClassWeAreModelling().getName().equals(c.getClassWeAreModelling().getName())) { @@ -260,32 +283,40 @@ public class ClassModel{ * If a field does not satisfy the private memory conditions, null, otherwise the size of private memory required. 
*/ public Integer getPrivateMemorySize(String fieldName) throws ClassParseException { - if (privateMemoryFields == null) { - privateMemoryFields = new HashMap<String, Kernel.PrivateMemorySpace>(); - HashMap<Field, Kernel.PrivateMemorySpace> privateMemoryFields = new HashMap<Field, Kernel.PrivateMemorySpace>(); - for (Field field : getClassWeAreModelling().getDeclaredFields()) { - Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); - if (privateMemorySpace != null) { - privateMemoryFields.put(field, privateMemorySpace); - } - } - for (Field field : getClassWeAreModelling().getFields()) { - Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); - if (privateMemorySpace != null) { - privateMemoryFields.put(field, privateMemorySpace); - } - } - for (Map.Entry<Field, Kernel.PrivateMemorySpace> entry : privateMemoryFields.entrySet()) { - this.privateMemoryFields.put(entry.getKey().getName(), entry.getValue()); - } - } - Kernel.PrivateMemorySpace annotation = privateMemoryFields.get(fieldName); + if (CacheEnabler.areCachesEnabled()) + return privateMemorySizes.computeIfAbsent(fieldName); + return computePrivateMemorySize(fieldName); + } + + private Integer computePrivateMemorySize(String fieldName) throws ClassParseException { + Kernel.PrivateMemorySpace annotation = privateMemoryFields.get().get(fieldName); if (annotation != null) { return annotation.value(); } return getPrivateMemorySizeFromFieldName(fieldName); } + private Map<String, Kernel.PrivateMemorySpace> computePrivateMemoryFields() { + Map<String, Kernel.PrivateMemorySpace> tempPrivateMemoryFields = new HashMap<String, Kernel.PrivateMemorySpace>(); + Map<Field, Kernel.PrivateMemorySpace> privateMemoryFields = new HashMap<Field, Kernel.PrivateMemorySpace>(); + for (Field field : getClassWeAreModelling().getDeclaredFields()) { + Kernel.PrivateMemorySpace privateMemorySpace = 
field.getAnnotation(Kernel.PrivateMemorySpace.class); + if (privateMemorySpace != null) { + privateMemoryFields.put(field, privateMemorySpace); + } + } + for (Field field : getClassWeAreModelling().getFields()) { + Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); + if (privateMemorySpace != null) { + privateMemoryFields.put(field, privateMemorySpace); + } + } + for (Map.Entry<Field, Kernel.PrivateMemorySpace> entry : privateMemoryFields.entrySet()) { + tempPrivateMemoryFields.put(entry.getKey().getName(), entry.getValue()); + } + return tempPrivateMemoryFields; + } + public static Integer getPrivateMemorySizeFromField(Field field) { Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); if (privateMemorySpace != null) { @@ -309,25 +340,26 @@ public class ClassModel{ } public Set<String> getNoCLMethods() { - if (this.noClMethods == null) { - noClMethods = new HashSet<String>(); - HashSet<Method> methods = new HashSet<Method>(); - for (Method method : getClassWeAreModelling().getDeclaredMethods()) { - if (method.getAnnotation(Kernel.NoCL.class) != null) { - methods.add(method); - } - } - for (Method method : getClassWeAreModelling().getMethods()) { - if (method.getAnnotation(Kernel.NoCL.class) != null) { - methods.add(method); - } + return noClMethods.get(); // memoized result of computeNoCLMethods() + } + + private Set<String> computeNoCLMethods() { + Set<String> tempNoClMethods = new HashSet<String>(); + HashSet<Method> methods = new HashSet<Method>(); + for (Method method : getClassWeAreModelling().getDeclaredMethods()) { + if (method.getAnnotation(Kernel.NoCL.class) != null) { + methods.add(method); } - for (Method method : methods) { - noClMethods.add(method.getName()); + } + for (Method method : getClassWeAreModelling().getMethods()) { + if (method.getAnnotation(Kernel.NoCL.class) != null) { + methods.add(method); } - } - return noClMethods; + for (Method method : methods) { + 
tempNoClMethods.add(method.getName()); + } + return tempNoClMethods; } public static String convert(String _string) { @@ -602,6 +634,21 @@ public class ClassModel{ return (methodDescription); } + // private static final ValueCache<Class<?>, ClassModel, ClassParseException> classModelCache = ValueCache.onIdentity(ClassModel::new); + private static final ValueCache<Class<?>, ClassModel, ClassParseException> classModelCache = ValueCache + .on(new ThrowingValueComputer<Class<?>, ClassModel, ClassParseException>(){ + @Override + public ClassModel compute(Class<?> key) throws ClassParseException { + return new ClassModel(key); + } + }); + + public static ClassModel createClassModel(Class<?> _class) throws ClassParseException { + if (CacheEnabler.areCachesEnabled()) + return classModelCache.computeIfAbsent(_class); + return new ClassModel(_class); + } + private int magic; private int minorVersion; @@ -813,7 +860,8 @@ public class ClassModel{ super(_byteReader, _slot, ConstantPoolType.METHOD); } - @Override public String toString() { + @Override + public String toString() { final StringBuilder sb = new StringBuilder(); sb.append(getClassEntry().getNameUTF8Entry().getUTF8()); sb.append("."); @@ -934,13 +982,15 @@ public class ClassModel{ private Type returnType = null; - @Override public int hashCode() { + @Override + public int hashCode() { final NameAndTypeEntry nameAndTypeEntry = getNameAndTypeEntry(); return ((((nameAndTypeEntry.getNameIndex() * 31) + nameAndTypeEntry.getDescriptorIndex()) * 31) + getClassIndex()); } - @Override public boolean equals(Object _other) { + @Override + public boolean equals(Object _other) { if ((_other == null) || !(_other instanceof MethodReferenceEntry)) { return (false); } else { @@ -1312,7 +1362,8 @@ public class ClassModel{ } - @Override public Iterator<Entry> iterator() { + @Override + public Iterator<Entry> iterator() { return (entries.iterator()); } @@ -1399,58 +1450,51 @@ public class ClassModel{ } else if (_entry instanceof 
ConstantPool.NameAndTypeEntry) { final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) _entry; references = new int[] { - nameAndTypeEntry.getNameIndex(), - nameAndTypeEntry.getDescriptorIndex() + nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } else if (_entry instanceof ConstantPool.MethodEntry) { final ConstantPool.MethodEntry methodEntry = (ConstantPool.MethodEntry) _entry; final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(methodEntry.getClassIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry - .getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex()); final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(methodEntry .getNameAndTypeIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getNameIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getDescriptorIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex()); references = new int[] { - methodEntry.getClassIndex(), - classEntry.getNameIndex(), - nameAndTypeEntry.getNameIndex(), + methodEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } else if (_entry instanceof ConstantPool.InterfaceMethodEntry) { final ConstantPool.InterfaceMethodEntry interfaceMethodEntry = (ConstantPool.InterfaceMethodEntry) _entry; final ConstantPool.ClassEntry classEntry = 
(ConstantPool.ClassEntry) get(interfaceMethodEntry.getClassIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry - .getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex()); final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(interfaceMethodEntry .getNameAndTypeIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getNameIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getDescriptorIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex()); references = new int[] { - interfaceMethodEntry.getClassIndex(), - classEntry.getNameIndex(), - nameAndTypeEntry.getNameIndex(), + interfaceMethodEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } else if (_entry instanceof ConstantPool.FieldEntry) { final ConstantPool.FieldEntry fieldEntry = (ConstantPool.FieldEntry) _entry; final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(fieldEntry.getClassIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry - .getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex()); final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(fieldEntry .getNameAndTypeIndex()); - @SuppressWarnings("unused") final 
ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getNameIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getDescriptorIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex()); references = new int[] { - fieldEntry.getClassIndex(), - classEntry.getNameIndex(), - nameAndTypeEntry.getNameIndex(), + fieldEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } @@ -1581,7 +1625,8 @@ public class ClassModel{ codeEntryAttributePool = new AttributePool(_byteReader, getName()); } - @Override public AttributePool getAttributePool() { + @Override + public AttributePool getAttributePool() { return (codeEntryAttributePool); } @@ -1664,7 +1709,8 @@ public class ClassModel{ super(_byteReader, _nameIndex, _length); } - @Override public Iterator<T> iterator() { + @Override + public Iterator<T> iterator() { return (pool.iterator()); } } @@ -1848,27 +1894,33 @@ public class ClassModel{ return (variableNameIndex); } - @Override public int getStart() { + @Override + public int getStart() { return (start); } - @Override public int getVariableIndex() { + @Override + public int getVariableIndex() { return (variableIndex); } - @Override public String getVariableName() { + @Override + public String getVariableName() { return (constantPool.getUTF8Entry(variableNameIndex).getUTF8()); } - @Override public String getVariableDescriptor() { + @Override + public String getVariableDescriptor() { return (constantPool.getUTF8Entry(descriptorIndex).getUTF8()); } - @Override public int getEnd() { + @Override + public int getEnd() { return (start + usageLength); } - 
@Override public boolean isArray() { + @Override + public boolean isArray() { return (getVariableDescriptor().startsWith("[")); } } @@ -1969,7 +2021,8 @@ public class ClassModel{ return (bytes); } - @Override public String toString() { + @Override + public String toString() { return (new String(bytes)); } @@ -1987,7 +2040,8 @@ public class ClassModel{ return (bytes); } - @Override public String toString() { + @Override + public String toString() { return (new String(bytes)); } } @@ -2004,7 +2058,8 @@ public class ClassModel{ return (bytes); } - @Override public String toString() { + @Override + public String toString() { return (new String(bytes)); } } @@ -2094,9 +2149,11 @@ public class ClassModel{ } } - @SuppressWarnings("unused") private final int elementNameIndex; + @SuppressWarnings("unused") + private final int elementNameIndex; - @SuppressWarnings("unused") private Value value; + @SuppressWarnings("unused") + private Value value; public ElementValuePair(ByteReader _byteReader) { elementNameIndex = _byteReader.u2(); @@ -2626,12 +2683,23 @@ public class ClassModel{ } public ClassModelMethod getMethod(String _name, String _descriptor) { + ClassModelMethod methodOrNull = getMethodOrNull(_name, _descriptor); + if (methodOrNull == null) + return superClazz != null ? superClazz.getMethod(_name, _descriptor) : (null); + return methodOrNull; + } + + private ClassModelMethod getMethodOrNull(String _name, String _descriptor) { for (final ClassModelMethod entry : methods) { if (entry.getName().equals(_name) && entry.getDescriptor().equals(_descriptor)) { + if (logger.isLoggable(Level.FINE)) { + logger.fine("Found " + clazz.getName() + "." + entry.getName() + " " + entry.getDescriptor() + " for " + + _name.replace('/', '.')); + } return (entry); } } - return superClazz != null ? 
superClazz.getMethod(_name, _descriptor) : (null); + return null; } public List<ClassModelField> getFieldPoolEntries() { @@ -2647,7 +2715,9 @@ public class ClassModel{ * @return The Method or null if we fail to locate a given method. */ public ClassModelMethod getMethod(MethodEntry _methodEntry, boolean _isSpecial) { - final String entryClassNameInDotForm = _methodEntry.getClassEntry().getNameUTF8Entry().getUTF8().replace('/', '.'); + NameAndTypeEntry nameAndTypeEntry = _methodEntry.getNameAndTypeEntry(); + String utf8Name = nameAndTypeEntry.getNameUTF8Entry().getUTF8(); + final String entryClassNameInDotForm = _methodEntry.getClassEntry().getNameUTF8Entry().getUTF8().replace('/', '.'); // class named by the method reference (not the method name) // Shortcut direct calls to supers to allow "foo() { super.foo() }" type stuff to work if (_isSpecial && (superClazz != null) && superClazz.isSuperClass(entryClassNameInDotForm)) { @@ -2658,20 +2728,21 @@ public class ClassModel{ return superClazz.getMethod(_methodEntry, false); } - for (final ClassModelMethod entry : methods) { - if (entry.getName().equals(_methodEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8()) - && entry.getDescriptor().equals(_methodEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8())) { - if (logger.isLoggable(Level.FINE)) { - logger.fine("Found " + clazz.getName() + "." + entry.getName() + " " + entry.getDescriptor() + " for " - + entryClassNameInDotForm); - } - return (entry); - } - } - - return superClazz != null ? superClazz.getMethod(_methodEntry, false) : (null); + ClassModelMethod methodOrNull = getMethodOrNull(utf8Name, nameAndTypeEntry.getDescriptorUTF8Entry().getUTF8()); + if (methodOrNull == null) + return superClazz != null ? 
superClazz.getMethod(_methodEntry, false) : (null); + return methodOrNull; } + // private ValueCache<MethodKey, MethodModel, AparapiException> methodModelCache = ValueCache.on(this::computeMethodModel); + private ValueCache<MethodKey, MethodModel, AparapiException> methodModelCache = ValueCache + .on(new ThrowingValueComputer<MethodKey, MethodModel, AparapiException>(){ + @Override + public MethodModel compute(MethodKey key) throws AparapiException { + return computeMethodModel(key); + } + }); + /** * Create a MethodModel for a given method name and signature. * @@ -2680,9 +2751,17 @@ public class ClassModel{ * @return * @throws AparapiException */ - public MethodModel getMethodModel(String _name, String _signature) throws AparapiException { - final ClassModelMethod method = getMethod(_name, _signature); + if (CacheEnabler.areCachesEnabled()) + return methodModelCache.computeIfAbsent(MethodKey.of(_name, _signature)); + else { + final ClassModelMethod method = getMethod(_name, _signature); + return new MethodModel(method); + } + } + + private MethodModel computeMethodModel(MethodKey methodKey) throws AparapiException { + final ClassModelMethod method = getMethod(methodKey.getName(), methodKey.getSignature()); return new MethodModel(method); } @@ -2715,9 +2794,29 @@ public class ClassModel{ totalStructSize = x; } + // private final ValueCache<EntrypointKey, Entrypoint, AparapiException> entrypointCache = ValueCache.on(this::computeBasicEntrypoint); + private final ValueCache<EntrypointKey, Entrypoint, AparapiException> entrypointCache = ValueCache + .on(new ThrowingValueComputer<EntrypointKey, Entrypoint, AparapiException>(){ + @Override + public Entrypoint compute(EntrypointKey key) throws AparapiException { + return computeBasicEntrypoint(key); + } + }); + Entrypoint getEntrypoint(String _entrypointName, String _descriptor, Object _k) throws AparapiException { - final MethodModel method = getMethodModel(_entrypointName, _descriptor); - return (new Entrypoint(this, 
method, _k)); + if (CacheEnabler.areCachesEnabled()) { + EntrypointKey key = EntrypointKey.of(_entrypointName, _descriptor); + Entrypoint entrypointWithoutKernel = entrypointCache.computeIfAbsent(key); + return entrypointWithoutKernel.cloneForKernel(_k); + } else { + final MethodModel method = getMethodModel(_entrypointName, _descriptor); + return new Entrypoint(this, method, _k); + } + } + + Entrypoint computeBasicEntrypoint(EntrypointKey entrypointKey) throws AparapiException { + final MethodModel method = getMethodModel(entrypointKey.getEntrypointName(), entrypointKey.getDescriptor()); + return new Entrypoint(this, method, null); } public Class<?> getClassWeAreModelling() { @@ -2731,4 +2830,8 @@ public class ClassModel{ public Entrypoint getEntrypoint() throws AparapiException { return (getEntrypoint("run", "()V", null)); } + + public static void invalidateCaches() { + classModelCache.invalidate(); + } } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java index 28aadab21efa63eb18c63abc4930cc7ed942075f..7ae155efa905a1dca6cd88f39931977a6ea9317a 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java @@ -50,7 +50,7 @@ import java.lang.reflect.*; import java.util.*; import java.util.logging.*; -public class Entrypoint{ +public class Entrypoint implements Cloneable { private static Logger logger = Logger.getLogger(Config.getLoggerName()); @@ -81,7 +81,7 @@ public class Entrypoint{ private final List<MethodModel> calledMethods = new ArrayList<MethodModel>(); - private MethodModel methodModel; + private final MethodModel methodModel; /** True is an indication to use the fp64 pragma @@ -226,7 +226,7 @@ public class Entrypoint{ final Class<?> memberClass = Class.forName(className); // Immediately add this class and all its supers if necessary - memberClassModel 
= new ClassModel(memberClass); + memberClassModel = ClassModel.createClassModel(memberClass); if (logger.isLoggable(Level.FINEST)) { logger.finest("adding class " + className); } @@ -595,7 +595,7 @@ public class Entrypoint{ // Add the class model for the referenced obj array if (signature.startsWith("[L")) { // Turn [Lcom/amd/javalabs/opencl/demo/DummyOOA; into com.amd.javalabs.opencl.demo.DummyOOA for example - final String className = (signature.substring(2, signature.length() - 1)).replace("/", "."); + final String className = (signature.substring(2, signature.length() - 1)).replace('/', '.'); final ClassModel arrayFieldModel = getOrUpdateAllClassAccesses(className); if (arrayFieldModel != null) { final Class<?> memberClass = arrayFieldModel.getClassWeAreModelling(); @@ -625,7 +625,7 @@ public class Entrypoint{ } } } else { - final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace("/", "."); + final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace('/', '.'); // Look for object data member access if (!className.equals(getClassModel().getClassWeAreModelling().getName()) && (getFieldFromClassHierarchy(getClassModel().getClassWeAreModelling(), accessedFieldName) == null)) { @@ -640,7 +640,7 @@ public class Entrypoint{ fieldAssignments.add(assignedFieldName); referencedFieldNames.add(assignedFieldName); - final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace("/", "."); + final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace('/', '.'); // Look for object data member access if (!className.equals(getClassModel().getClassWeAreModelling().getName()) && (getFieldFromClassHierarchy(getClassModel().getClassWeAreModelling(), assignedFieldName) == null)) { @@ -909,4 +909,14 @@ public class Entrypoint{ return null; } + + Entrypoint cloneForKernel(Object _k) throws AparapiException { + try { + Entrypoint clonedEntrypoint = (Entrypoint) clone(); + 
clonedEntrypoint.kernelInstance = _k; + return clonedEntrypoint; + } catch (CloneNotSupportedException e) { + throw new AparapiException(e); + } + } } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java new file mode 100644 index 0000000000000000000000000000000000000000..bddf99d7b8ca0b289b2add855dd96bd0179a0f2f --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java @@ -0,0 +1,57 @@ +package com.amd.aparapi.internal.model; + +final class EntrypointKey{ + public static EntrypointKey of(String entrypointName, String descriptor) { + return new EntrypointKey(entrypointName, descriptor); + } + + private String descriptor; + + private String entrypointName; + + private EntrypointKey(String entrypointName, String descriptor) { + this.entrypointName = entrypointName; + this.descriptor = descriptor; + } + + String getDescriptor() { + return descriptor; + } + + String getEntrypointName() { + return entrypointName; + } + + @Override public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((descriptor == null) ? 0 : descriptor.hashCode()); + result = prime * result + ((entrypointName == null) ? 
0 : entrypointName.hashCode()); + return result; + } + + @Override public String toString() { + return "EntrypointKey [entrypointName=" + entrypointName + ", descriptor=" + descriptor + "]"; + } + + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + EntrypointKey other = (EntrypointKey) obj; + if (descriptor == null) { + if (other.descriptor != null) + return false; + } else if (!descriptor.equals(other.descriptor)) + return false; + if (entrypointName == null) { + if (other.entrypointName != null) + return false; + } else if (!entrypointName.equals(other.entrypointName)) + return false; + return true; + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Memoizer.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Memoizer.java new file mode 100644 index 0000000000000000000000000000000000000000..ece7e391574fb962f7f28d06e876e97693b2d970 --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Memoizer.java @@ -0,0 +1,84 @@ +package com.amd.aparapi.internal.model; + +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicReference; + +interface Optional<E> { + final class Some<E> implements Optional<E>{ + private final E value; + + static final <E> Optional<E> of(E value) { + return new Some<>(value); + } + + private Some(E value) { + this.value = value; + } + + @Override public E get() { + return value; + } + + @Override public boolean isPresent() { + return true; + } + } + + final class None<E> implements Optional<E>{ + @SuppressWarnings("unchecked") static <E> Optional<E> none() { + return none; + } + + @SuppressWarnings("rawtypes") private static final None none = new None(); + + private None() { + // Do nothing + } + + @Override public E get() { + throw new NoSuchElementException("No value present"); + } + + @Override public boolean isPresent() { + return false; + } 
+ } + + E get(); + + boolean isPresent(); +} + +public interface Memoizer<T> extends Supplier<T>{ + public final class Impl<T> implements Memoizer<T>{ + private final Supplier<T> supplier; + + private final AtomicReference<Optional<T>> valueRef = new AtomicReference<>(Optional.None.<T> none()); + + Impl(Supplier<T> supplier) { + this.supplier = supplier; + } + + @Override public T get() { + Optional<T> value = valueRef.get(); + while (!value.isPresent()) { + Optional<T> newValue = Optional.Some.of(supplier.get()); + if (valueRef.compareAndSet(value, newValue)) { + value = newValue; + break; + } + value = valueRef.get(); + } + return value.get(); + } + + public static <T> Memoizer<T> of(Supplier<T> supplier) { + return new Impl<>(supplier); + } + } + + // static <T> Memoizer<T> of(Supplier<T> supplier) + // { + // return new Impl<>(supplier); + // } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java new file mode 100644 index 0000000000000000000000000000000000000000..d08e2debac32c97fddf209ce295e471df44dd29e --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java @@ -0,0 +1,57 @@ +package com.amd.aparapi.internal.model; + +final class MethodKey{ + static MethodKey of(String name, String signature) { + return new MethodKey(name, signature); + } + + private final String name; + + private final String signature; + + @Override public String toString() { + return "MethodKey [name=" + getName() + ", signature=" + getSignature() + "]"; + } + + @Override public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((getName() == null) ? 0 : getName().hashCode()); + result = prime * result + ((getSignature() == null) ? 
0 : getSignature().hashCode()); + return result; + } + + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + MethodKey other = (MethodKey) obj; + if (getName() == null) { + if (other.getName() != null) + return false; + } else if (!getName().equals(other.getName())) + return false; + if (getSignature() == null) { + if (other.getSignature() != null) + return false; + } else if (!getSignature().equals(other.getSignature())) + return false; + return true; + } + + private MethodKey(String name, String signature) { + this.name = name; + this.signature = signature; + } + + public String getName() { + return name; + } + + public String getSignature() { + return signature; + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Supplier.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Supplier.java new file mode 100644 index 0000000000000000000000000000000000000000..737345b65efb73e8a8938c9f428d5164e0cec685 --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Supplier.java @@ -0,0 +1,8 @@ +package com.amd.aparapi.internal.model; + +/** + * Substitute of Java8's Supplier<V> interface, used in Java7 backport of caches. 
+ */ +public interface Supplier<V> { + V get(); +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ValueCache.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ValueCache.java new file mode 100644 index 0000000000000000000000000000000000000000..ef66a53fdeca66f8da816f12f1d88e360d749303 --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ValueCache.java @@ -0,0 +1,46 @@ +package com.amd.aparapi.internal.model; + +import java.lang.ref.Reference; +import java.lang.ref.SoftReference; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +//import java.util.function.Supplier; + +public final class ValueCache<K, V, T extends Throwable> { + // @FunctionalInterface + public interface ThrowingValueComputer<K, V, T extends Throwable> { + V compute(K key) throws T; + } + + // @FunctionalInterface + public interface ValueComputer<K, V> extends ThrowingValueComputer<K, V, RuntimeException>{ + // Marker interface + } + + public static <K, V, T extends Throwable> ValueCache<K, V, T> on(ThrowingValueComputer<K, V, T> computer) { + return new ValueCache<K, V, T>(computer); + } + + private final ConcurrentMap<K, SoftReference<V>> map = new ConcurrentHashMap<>(); + + private final ThrowingValueComputer<K, V, T> computer; + + private ValueCache(ThrowingValueComputer<K, V, T> computer) { + this.computer = computer; + } + + public V computeIfAbsent(K key) throws T { + Reference<V> reference = map.get(key); + V value = reference == null ? 
null : reference.get(); + if (value == null) { + value = computer.compute(key); + map.put(key, new SoftReference<>(value)); + } + return value; + } + + public void invalidate() { + map.clear(); + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java index 6e80defbadf45301bce302cb474343a269914660..015cf46d7492b83b5d914c77e3b206006cbe32ae 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java @@ -625,7 +625,7 @@ public class InstructionViewer implements Config.InstructionListener{ public InstructionViewer(Color _background, String _name) { try { - classModel = new ClassModel(Class.forName(_name)); + classModel = ClassModel.createClassModel(Class.forName(_name)); } catch (final ClassParseException e) { // TODO Auto-generated catch block e.printStackTrace(); diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/writer/KernelWriter.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/writer/KernelWriter.java index 23d5c1588103e5e2101bd19cd2187b65d31d5161..613b091fdbf4b09fda0bcb2dcc6a0e69db902582 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/writer/KernelWriter.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/writer/KernelWriter.java @@ -273,9 +273,9 @@ public abstract class KernelWriter extends BlockWriter{ public final static String __private = "__private"; - public final static String LOCAL_ANNOTATION_NAME = "L" + Local.class.getName().replace(".", "/") + ";"; + public final static String LOCAL_ANNOTATION_NAME = "L" + Local.class.getName().replace('.', '/') + ";"; - public final static String CONSTANT_ANNOTATION_NAME = "L" + Constant.class.getName().replace(".", "/") + ";"; + public final static String CONSTANT_ANNOTATION_NAME = "L" + Constant.class.getName().replace('.', '/') + ";"; 
@Override public void write(Entrypoint _entryPoint) throws CodeGenException { final List<String> thisStruct = new ArrayList<String>(); @@ -341,7 +341,7 @@ public abstract class KernelWriter extends BlockWriter{ String className = null; if (signature.startsWith("L")) { // Turn Lcom/amd/javalabs/opencl/demo/DummyOOA; into com_amd_javalabs_opencl_demo_DummyOOA for example - className = (signature.substring(1, signature.length() - 1)).replace("/", "_"); + className = (signature.substring(1, signature.length() - 1)).replace('/', '_'); // if (logger.isLoggable(Level.FINE)) { // logger.fine("Examining object parameter: " + signature + " new: " + className); // } @@ -475,7 +475,7 @@ public abstract class KernelWriter extends BlockWriter{ for (final ClassModel cm : _entryPoint.getObjectArrayFieldsClasses().values()) { final ArrayList<FieldEntry> fieldSet = cm.getStructMembers(); if (fieldSet.size() > 0) { - final String mangledClassName = cm.getClassWeAreModelling().getName().replace(".", "_"); + final String mangledClassName = cm.getClassWeAreModelling().getName().replace('.', '_'); newLine(); write("typedef struct " + mangledClassName + "_s{"); in(); @@ -572,11 +572,11 @@ public abstract class KernelWriter extends BlockWriter{ // Call to an object member or superclass of member for (final ClassModel c : _entryPoint.getObjectArrayFieldsClasses().values()) { if (mm.getMethod().getClassModel() == c) { - write("__global " + mm.getMethod().getClassModel().getClassWeAreModelling().getName().replace(".", "_") + write("__global " + mm.getMethod().getClassModel().getClassWeAreModelling().getName().replace('.', '_') + " *this"); break; } else if (mm.getMethod().getClassModel().isSuperClass(c.getClassWeAreModelling())) { - write("__global " + c.getClassWeAreModelling().getName().replace(".", "_") + " *this"); + write("__global " + c.getClassWeAreModelling().getName().replace('.', '_') + " *this"); break; } } diff --git a/com.amd.aparapi/src/test/ConvolutionLargeTest.java 
b/com.amd.aparapi/src/test/ConvolutionLargeTest.java new file mode 100644 index 0000000000000000000000000000000000000000..fca9aab9376ff3d91256b4f7de7f7f6776cff6e4 --- /dev/null +++ b/com.amd.aparapi/src/test/ConvolutionLargeTest.java @@ -0,0 +1,276 @@ +/* +Copyright (c) 2010-2011, Advanced Micro Devices, Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following +disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following +disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +If you use the software (in whole or in part), you shall adhere to all applicable U.S., European, and other export +laws, including but not limited to the U.S. Export Administration Regulations ("EAR"), (15 C.F.R. 
Sections 730 through +774), and E.U. Council Regulation (EC) No 1334/2000 of 22 June 2000. Further, pursuant to Section 740.6 of the EAR, +you hereby certify that, except pursuant to a license granted by the United States Department of Commerce Bureau of +Industry and Security or as otherwise permitted pursuant to a License Exception under the U.S. Export Administration +Regulations ("EAR"), you will not (1) export, re-export or release to a national of a country in Country Groups D:1, +E:1 or E:2 any restricted technology, software, or source code you receive hereunder, or (2) export to Country Groups +D:1, E:1 or E:2 the direct product of such technology or software, if such foreign produced direct product is subject +to national security controls as identified on the Commerce Control List (currently found in Supplement 1 to Part 774 +of EAR). For the most current Country Group listings, or for additional information about the EAR or your obligations +under those regulations, please refer to the U.S. Bureau of Industry and Security's website at http://www.bis.doc.gov/. + + */ + +import java.io.IOException; +import java.text.MessageFormat; +import java.util.concurrent.TimeUnit; + +import com.amd.aparapi.Kernel; +import com.amd.aparapi.internal.model.CacheEnabler; +import com.amd.aparapi.internal.model.Supplier; + +public class ConvolutionLargeTest{ + + private byte[] inBytes; + + private byte[] outBytes; + + private int width; + + private int height; + + private float _convMatrix3x3[]; + + public ConvolutionLargeTest(String[] _args) throws IOException { + // final File _file = new File(_args.length == 1 ? 
_args[0] : "testcard.jpg"); + + _convMatrix3x3 = new float[] { + 0f, + -10f, + 0f, + -10f, + 40f, + -10f, + 0f, + -10f, + 0f, + }; + + // BufferedImage inputImage = ImageIO.read(_file); + + // System.out.println(inputImage); + + // height = inputImage.getHeight(); + // + // width = inputImage.getWidth(); + // + // BufferedImage outputImage = new BufferedImage(width, height, inputImage.getType()); + // + // inBytes = ((DataBufferByte) inputImage.getRaster().getDataBuffer()).getData(); + // outBytes = ((DataBufferByte) outputImage.getRaster().getDataBuffer()).getData(); + } + + private void prepareForSize(int pixels) { + int side = (int) Math.sqrt(pixels); + width = side; + height = side; + + inBytes = new byte[width * height * 3]; + outBytes = new byte[width * height * 3]; + } + + public static void main(final String[] _args) throws IOException { + new ConvolutionLargeTest(_args).go(); + } + + private void go() { + boolean testWithoutCaches = true; + int maxRounds = 5; + for (int i = 1; i <= maxRounds; i++) { + System.out.println("-----------------------------"); + int pixels = (1_000 * 1_000 * (1 << (i - 1))) & ~(1 << 10 - 1); + System.out.println(MessageFormat.format("Round #{0}/{1} ({2} pixels)", i, maxRounds, pixels)); + prepareForSize(pixels); + System.out.println("-----------------------------"); + System.out.println(); + testWithSupplier(new ImageConvolutionCreationContext(){ + private ImageConvolution convolution = new ImageConvolution(); + + @Override public Supplier<ImageConvolution> getSupplier() { + return new Supplier<ImageConvolution>(){ + @Override public ImageConvolution get() { + return convolution; + } + }; + } + + @Override public Consumer<ImageConvolution> getDisposer() { + return new Consumer<ImageConvolution>(){ + @Override public void accept(ImageConvolution k) { + // Do nothing + } + }; + } + + @Override public void shutdown() { + convolution.dispose(); + } + + @Override public String getName() { + return "single kernel"; + } + }, 10, 
testWithoutCaches); + testWithSupplier(new ImageConvolutionCreationContext(){ + @Override public Supplier<ImageConvolution> getSupplier() { + return new Supplier<ImageConvolution>(){ + @Override public ImageConvolution get() { + return new ImageConvolution(); + } + }; + } + + @Override public Consumer<ImageConvolution> getDisposer() { + return new Consumer<ImageConvolution>(){ + @Override public void accept(ImageConvolution k) { + k.dispose(); + } + }; + } + + @Override public void shutdown() { + // Do nothing + } + + @Override public String getName() { + return "multiple kernels"; + } + }, 10, testWithoutCaches); + } + } + + private void testWithSupplier(ImageConvolutionCreationContext imageConvolutionCreationContext, int seconds, + boolean testWithoutCaches) { + System.out.println("Test context: " + imageConvolutionCreationContext.getName()); + CacheEnabler.setCachesEnabled(!testWithoutCaches); + // Warmup + doTest("Warmup (caches " + (testWithoutCaches ? "not " : "") + "enabled)", 2, imageConvolutionCreationContext); + if (testWithoutCaches) { + long timeWithoutCaches = doTest("Without caches", seconds, imageConvolutionCreationContext); + CacheEnabler.setCachesEnabled(true); + long timeWithCaches = doTest("With caches", seconds, imageConvolutionCreationContext); + System.out.println(MessageFormat.format("\tSpeedup: {0} %", 100d * (timeWithoutCaches - timeWithCaches) + / timeWithoutCaches)); + } else { + doTest("With caches", seconds, imageConvolutionCreationContext); + } + } + + // @FunctionalInterface + private interface Consumer<K> { + void accept(K k); + } + + private interface ImageConvolutionCreationContext{ + Supplier<ImageConvolution> getSupplier(); + + Consumer<ImageConvolution> getDisposer(); + + void shutdown(); + + String getName(); + + } + + private long doTest(String name, int seconds, ImageConvolutionCreationContext imageConvolutionCreationContext) { + long totalTime = 0; + Supplier<ImageConvolution> imageConvolutionSupplier = 
imageConvolutionCreationContext.getSupplier(); + Consumer<ImageConvolution> disposer = imageConvolutionCreationContext.getDisposer(); + System.out.print("\tTesting " + name + "[" + imageConvolutionCreationContext.getName() + "] (" + seconds + " seconds) "); + int calls = 0; + long initialTime = System.nanoTime(); + long maxElapsedNs = TimeUnit.SECONDS.toNanos(seconds); + for (;;) { + long start = System.nanoTime(); + if (start - initialTime > maxElapsedNs) + break; + ImageConvolution imageConvolution = imageConvolutionSupplier.get(); + try { + imageConvolution.applyConvolution(_convMatrix3x3, inBytes, outBytes, width, height); + } finally { + disposer.accept(imageConvolution); + } + + long end = System.nanoTime(); + long roundTime = end - start; + totalTime += roundTime; + // System.out.print("#" + i + " - " + roundTime + "ms "); + // System.out.print(roundTime + " "); + System.out.print("."); + calls++; + } + imageConvolutionCreationContext.shutdown(); + System.out.println(); + System.out.println(MessageFormat.format("\tFinished in {0} s ({1} ms/call, {2} calls)", totalTime / 1e9d, + (totalTime / (calls * 1e6d)), calls)); + System.out.println(); + return totalTime / calls; + } + + final static class ImageConvolution extends Kernel{ + + private float convMatrix3x3[]; + + private int width, height; + + private byte imageIn[], imageOut[]; + + public void processPixel(int x, int y, int w, int h) { + float accum = 0f; + int count = 0; + for (int dx = -3; dx < 6; dx += 3) { + for (int dy = -1; dy < 2; dy += 1) { + final int rgb = 0xff & imageIn[((y + dy) * w) + (x + dx)]; + + accum += rgb * convMatrix3x3[count++]; + } + } + final byte value = (byte) (max(0, min((int) accum, 255))); + imageOut[(y * w) + x] = value; + + } + + @Override public void run() { + final int x = getGlobalId(0) % (width * 3); + final int y = getGlobalId(0) / (width * 3); + + if ((x > 3) && (x < ((width * 3) - 3)) && (y > 1) && (y < (height - 1))) { + processPixel(x, y, width * 3, height); + } + + 
} + + public void applyConvolution(float[] _convMatrix3x3, byte[] _imageIn, byte[] _imageOut, int _width, int _height) { + imageIn = _imageIn; + imageOut = _imageOut; + width = _width; + height = _height; + convMatrix3x3 = _convMatrix3x3; + execute(3 * width * height); + } + } +}