diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java
index 0069d4e46fcb97d834b5a4c5af48e3d84dc710f9..ff9094e856ef8cbdfed7ca896cd8717b534b2a0a 100644
--- a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java
@@ -40,6 +40,10 @@ package com.amd.aparapi;
import com.amd.aparapi.annotation.*;
import com.amd.aparapi.exception.*;
import com.amd.aparapi.internal.kernel.*;
+import com.amd.aparapi.internal.model.CacheEnabler;
+import com.amd.aparapi.internal.model.ValueCache;
+import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer;
+import com.amd.aparapi.internal.model.ValueCache.ValueComputer;
import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*;
import com.amd.aparapi.internal.opencl.*;
import com.amd.aparapi.internal.util.*;
@@ -488,6 +492,8 @@ public abstract class Kernel implements Cloneable {
private volatile CyclicBarrier localBarrier;
+ private boolean localBarrierDisabled;
+
/**
* Default constructor
*/
@@ -620,6 +626,24 @@ public abstract class Kernel implements Cloneable {
public void setLocalBarrier(CyclicBarrier localBarrier) {
this.localBarrier = localBarrier;
}
+
+ public void awaitOnLocalBarrier() {
+ if (!localBarrierDisabled) {
+ try {
+ kernelState.getLocalBarrier().await();
+ } catch (final InterruptedException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ } catch (final BrokenBarrierException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ }
+ }
+
+ public void disableLocalBarrier() {
+ localBarrierDisabled = true;
+ }
}
/**
@@ -1834,15 +1858,7 @@ public abstract class Kernel implements Cloneable {
@OpenCLDelegate
@Experimental
protected final void localBarrier() {
- try {
- kernelState.getLocalBarrier().await();
- } catch (final InterruptedException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (final BrokenBarrierException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
+ kernelState.awaitOnLocalBarrier();
}
/**
@@ -1862,6 +1878,16 @@ public abstract class Kernel implements Cloneable {
"Kernel.globalBarrier() has been deprecated. It was based an incorrect understanding of OpenCL functionality.");
}
+ @OpenCLMapping(mapTo = "hypot")
+ protected float hypot(final float a, final float b) {
+ return (float) Math.hypot(a, b);
+ }
+
+ @OpenCLMapping(mapTo = "hypot")
+ protected double hypot(final double a, final double b) {
+ return Math.hypot(a, b);
+ }
+
public KernelState getKernelState() {
return kernelState;
}
@@ -2100,49 +2126,64 @@ public abstract class Kernel implements Cloneable {
}
private static String getReturnTypeLetter(Method meth) {
- final Class<?> retClass = meth.getReturnType();
+ return toClassShortNameIfAny(meth.getReturnType());
+ }
+
+ private static String toClassShortNameIfAny(final Class<?> retClass) {
+ if (retClass.isArray()) {
+ return "[" + toClassShortNameIfAny(retClass.getComponentType());
+ }
final String strRetClass = retClass.toString();
final String mapping = typeToLetterMap.get(strRetClass);
// System.out.println("strRetClass = <" + strRetClass + ">, mapping = " + mapping);
+ if (mapping == null)
+ return "[" + retClass.getName() + ";";
return mapping;
}
public static String getMappedMethodName(MethodReferenceEntry _methodReferenceEntry) {
+ if (CacheEnabler.areCachesEnabled())
+ return getProperty(mappedMethodNamesCache, _methodReferenceEntry, null);
String mappedName = null;
final String name = _methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8();
- for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) {
- if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) {
- // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int);
- // for Alpha, we will just disambiguate based on the return type
- if (false) {
- System.out.println("kernelMethod is ... " + kernelMethod.toGenericString());
- System.out.println("returnType = " + kernelMethod.getReturnType());
- System.out.println("returnTypeLetter = " + getReturnTypeLetter(kernelMethod));
- System.out.println("kernelMethod getName = " + kernelMethod.getName());
- System.out.println("methRefName = " + name + " descriptor = "
- + _methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8());
- System.out
- .println("descToReturnTypeLetter = "
- + descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry()
- .getUTF8()));
- }
- if (_methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())
- && descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8())
- .equals(getReturnTypeLetter(kernelMethod))) {
- final OpenCLMapping annotation = kernelMethod.getAnnotation(OpenCLMapping.class);
- final String mapTo = annotation.mapTo();
- if (!mapTo.equals("")) {
- mappedName = mapTo;
- // System.out.println("mapTo = " + mapTo);
+ Class<?> currentClass = _methodReferenceEntry.getOwnerClassModel().getClassWeAreModelling();
+ while (currentClass != Object.class) {
+ for (final Method kernelMethod : currentClass.getDeclaredMethods()) {
+ if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) {
+ // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int);
+ // for Alpha, we will just disambiguate based on the return type
+ if (false) {
+ System.out.println("kernelMethod is ... " + kernelMethod.toGenericString());
+ System.out.println("returnType = " + kernelMethod.getReturnType());
+ System.out.println("returnTypeLetter = " + getReturnTypeLetter(kernelMethod));
+ System.out.println("kernelMethod getName = " + kernelMethod.getName());
+ System.out.println("methRefName = " + name + " descriptor = "
+ + _methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8());
+ System.out.println("descToReturnTypeLetter = "
+ + descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry()
+ .getUTF8()));
+ }
+ if (toSignature(_methodReferenceEntry).equals(toSignature(kernelMethod))) {
+ final OpenCLMapping annotation = kernelMethod.getAnnotation(OpenCLMapping.class);
+ final String mapTo = annotation.mapTo();
+ if (!mapTo.equals("")) {
+ mappedName = mapTo;
+ // System.out.println("mapTo = " + mapTo);
+ }
}
}
}
+ if (mappedName != null)
+ break;
+ currentClass = currentClass.getSuperclass();
}
// System.out.println("... in getMappedMethodName, returning = " + mappedName);
return (mappedName);
}
public static boolean isMappedMethod(MethodReferenceEntry methodReferenceEntry) {
+ if (CacheEnabler.areCachesEnabled())
+ return getBoolean(mappedMethodFlags, methodReferenceEntry);
boolean isMapped = false;
for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) {
if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) {
@@ -2157,6 +2198,8 @@ public abstract class Kernel implements Cloneable {
}
public static boolean isOpenCLDelegateMethod(MethodReferenceEntry methodReferenceEntry) {
+ if (CacheEnabler.areCachesEnabled())
+ return getBoolean(openCLDelegateMethodFlags, methodReferenceEntry);
boolean isMapped = false;
for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) {
if (kernelMethod.isAnnotationPresent(OpenCLDelegate.class)) {
@@ -2171,6 +2214,8 @@ public abstract class Kernel implements Cloneable {
}
public static boolean usesAtomic32(MethodReferenceEntry methodReferenceEntry) {
+ if (CacheEnabler.areCachesEnabled())
+ return getProperty(atomic32Cache, methodReferenceEntry, false);
for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) {
if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) {
if (methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())) {
@@ -2184,6 +2229,8 @@ public abstract class Kernel implements Cloneable {
// For alpha release atomic64 is not supported
public static boolean usesAtomic64(MethodReferenceEntry methodReferenceEntry) {
+ // if (CacheEnabler.areCachesEnabled())
+ // return getProperty(atomic64Cache, methodReferenceEntry, false);
//for (java.lang.reflect.Method kernelMethod : Kernel.class.getDeclaredMethods()) {
// if (kernelMethod.isAnnotationPresent(Kernel.OpenCLMapping.class)) {
// if (methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())) {
@@ -2859,4 +2906,131 @@ public abstract class Kernel implements Cloneable {
executionMode = currentMode.next();
}
}
+
+ private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> mappedMethodFlags = markedWith(OpenCLMapping.class);
+
+ private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> openCLDelegateMethodFlags = markedWith(OpenCLDelegate.class);
+
+ private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> atomic32Cache = cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>(){
+ @Override
+ public Map<String, Boolean> compute(Class<?> key) {
+ Map<String, Boolean> properties = new HashMap<>();
+ for (final Method method : key.getDeclaredMethods()) {
+ if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) {
+ properties.put(toSignature(method), method.getAnnotation(OpenCLMapping.class).atomic32());
+ }
+ }
+ return properties;
+ }
+ });
+
+ private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> atomic64Cache = cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>(){
+ @Override
+ public Map<String, Boolean> compute(Class<?> key) {
+ Map<String, Boolean> properties = new HashMap<>();
+ for (final Method method : key.getDeclaredMethods()) {
+ if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) {
+ properties.put(toSignature(method), method.getAnnotation(OpenCLMapping.class).atomic64());
+ }
+ }
+ return properties;
+ }
+ });
+
+ private static boolean getBoolean(ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> methodNamesCache,
+ MethodReferenceEntry methodReferenceEntry) {
+ return getProperty(methodNamesCache, methodReferenceEntry, false);
+ }
+
+ private static <A extends Annotation> ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> markedWith(
+ final Class<A> annotationClass) {
+ return cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>(){
+ @Override
+ public Map<String, Boolean> compute(Class<?> key) {
+ Map<String, Boolean> markedMethodNames = new HashMap<>();
+ for (final Method method : key.getDeclaredMethods()) {
+ markedMethodNames.put(toSignature(method), method.isAnnotationPresent(annotationClass));
+ }
+ return markedMethodNames;
+ }
+ });
+ }
+
+ static String toSignature(Method method) {
+ return method.getName() + getArgumentsLetters(method) + getReturnTypeLetter(method);
+ }
+
+ private static String getArgumentsLetters(Method method) {
+ StringBuilder sb = new StringBuilder("(");
+ for (Class<?> parameterClass : method.getParameterTypes()) {
+ sb.append(toClassShortNameIfAny(parameterClass));
+ }
+ sb.append(")");
+ return sb.toString();
+ }
+
+ private static boolean isRelevant(Method method) {
+ return !method.isSynthetic() && !method.isBridge();
+ }
+
+ private static <V, T extends Throwable> V getProperty(ValueCache<Class<?>, Map<String, V>, T> cache,
+ MethodReferenceEntry methodReferenceEntry, V defaultValue) throws T {
+ Map<String, V> map = cache.computeIfAbsent(methodReferenceEntry.getOwnerClassModel().getClassWeAreModelling());
+ String key = toSignature(methodReferenceEntry);
+ if (map.containsKey(key))
+ return map.get(key);
+ return defaultValue;
+ }
+
+ private static String toSignature(MethodReferenceEntry methodReferenceEntry) {
+ NameAndTypeEntry nameAndTypeEntry = methodReferenceEntry.getNameAndTypeEntry();
+ return nameAndTypeEntry.getNameUTF8Entry().getUTF8() + nameAndTypeEntry.getDescriptorUTF8Entry().getUTF8();
+ }
+
+ private static final ValueCache<Class<?>, Map<String, String>, RuntimeException> mappedMethodNamesCache = cacheProperty(new ValueComputer<Class<?>, Map<String, String>>(){
+ @Override
+ public Map<String, String> compute(Class<?> key) {
+ Map<String, String> properties = new HashMap<>();
+ for (final Method method : key.getDeclaredMethods()) {
+ if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) {
+ // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int);
+ final OpenCLMapping annotation = method.getAnnotation(OpenCLMapping.class);
+ final String mapTo = annotation.mapTo();
+ if (mapTo != null && !mapTo.equals("")) {
+ properties.put(toSignature(method), mapTo);
+ }
+ }
+ }
+ return properties;
+ }
+ });
+
+ private static <K, V, T extends Throwable> ValueCache<Class<?>, Map<K, V>, T> cacheProperty(
+ final ThrowingValueComputer<Class<?>, Map<K, V>, T> throwingValueComputer) {
+ return ValueCache.on(new ThrowingValueComputer<Class<?>, Map<K, V>, T>(){
+ @Override
+ public Map<K, V> compute(Class<?> key) throws T {
+ Map<K, V> properties = new HashMap<>();
+ Deque<Class<?>> superclasses = new ArrayDeque<>();
+ Class<?> currentSuperClass = key;
+ do {
+ superclasses.push(currentSuperClass);
+ currentSuperClass = currentSuperClass.getSuperclass();
+ } while (currentSuperClass != Object.class);
+ for (Class<?> clazz : superclasses) {
+ // Overwrite property values for shadowed/overriden methods
+ properties.putAll(throwingValueComputer.compute(clazz));
+ }
+ return properties;
+ }
+ });
+ }
+
+ public static void invalidateCaches() {
+ atomic32Cache.invalidate();
+ atomic64Cache.invalidate();
+ mappedMethodFlags.invalidate();
+ mappedMethodNamesCache.invalidate();
+ openCLDelegateMethodFlags.invalidate();
+ }
}
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java
index 821364c85bccf5906c55b4aae45855a5e66b9d8c..cdb6c9d094007713ad2da09f1219b06e5f491682 100644
--- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java
@@ -48,8 +48,10 @@ import java.util.Set;
import java.util.StringTokenizer;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory;
+import java.util.concurrent.ForkJoinPool.ManagedBlocker;
+import java.util.concurrent.ForkJoinWorkerThread;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -100,8 +102,18 @@ public class KernelRunner extends KernelRunnerJNI{
private Entrypoint entryPoint;
private int argc;
-
- private final ExecutorService threadPool = Executors.newCachedThreadPool();
+
+ private static final ForkJoinWorkerThreadFactory lowPriorityThreadFactory = new ForkJoinWorkerThreadFactory(){
+ @Override public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
+ ForkJoinWorkerThread newThread = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
+ newThread.setPriority(Thread.MIN_PRIORITY);
+ return newThread;
+ }
+ };
+
+ private static final ForkJoinPool threadPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors(),
+ lowPriorityThreadFactory, null, false);
+
/**
* Create a KernelRunner for a specific Kernel instance.
*
@@ -120,7 +132,8 @@ public class KernelRunner extends KernelRunnerJNI{
if (kernel.getExecutionMode().isOpenCL()) {
disposeJNI(jniContextHandle);
}
- threadPool.shutdownNow();
+ // We are using a shared pool, so there's no need no shutdown it when kernel is disposed
+ // threadPool.shutdownNow();
}
private Set<String> capabilitiesSet;
@@ -215,6 +228,50 @@ public class KernelRunner extends KernelRunnerJNI{
return capabilitiesSet.contains(OpenCL.CL_KHR_GL_SHARING);
}
+ private static final class FJSafeCyclicBarrier extends CyclicBarrier{
+ FJSafeCyclicBarrier(final int threads) {
+ super(threads);
+ }
+
+ @Override public int await() throws InterruptedException, BrokenBarrierException {
+ class Awaiter implements ManagedBlocker{
+ private int value;
+
+ private boolean released;
+
+ @Override public boolean block() throws InterruptedException {
+ try {
+ value = superAwait();
+ released = true;
+ return true;
+ } catch (final BrokenBarrierException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override public boolean isReleasable() {
+ return released;
+ }
+
+ int getValue() {
+ return value;
+ }
+ }
+ final Awaiter awaiter = new Awaiter();
+ ForkJoinPool.managedBlock(awaiter);
+ return awaiter.getValue();
+ }
+
+ int superAwait() throws InterruptedException, BrokenBarrierException {
+ return super.await();
+ }
+ }
+
+ // @FunctionalInterface
+ private interface ThreadIdSetter{
+ void set(KernelState kernelState, int globalGroupId, int threadId);
+ }
+
/**
* Execute using a Java thread pool. Either because we were explicitly asked to do so, or because we 'fall back' after discovering an OpenCL issue.
*
@@ -224,11 +281,15 @@ public class KernelRunner extends KernelRunnerJNI{
* The # of passes requested by the user (via <code>Kernel.execute(globalSize, passes)</code>). Note this is usually defaulted to 1 via <code>Kernel.execute(globalSize)</code>.
* @return
*/
- private long executeJava(final Range _range, final int _passes) {
+ protected long executeJava(final Range _range, final int _passes) {
if (logger.isLoggable(Level.FINE)) {
logger.fine("executeJava: range = " + _range);
}
+ final int localSize0 = _range.getLocalSize(0);
+ final int localSize1 = _range.getLocalSize(1);
+ final int localSize2 = _range.getLocalSize(2);
+ final int globalSize1 = _range.getGlobalSize(1);
if (kernel.getExecutionMode().equals(EXECUTION_MODE.SEQ)) {
/**
* SEQ mode is useful for testing trivial logic, but kernels which use SEQ mode cannot be used if the
@@ -238,7 +299,7 @@ public class KernelRunner extends KernelRunnerJNI{
*
* So we need to check if the range is valid here. If not we have no choice but to punt.
*/
- if ((_range.getLocalSize(0) * _range.getLocalSize(1) * _range.getLocalSize(2)) > 1) {
+ if ((localSize0 * localSize1 * localSize2) > 1) {
throw new IllegalStateException("Can't run range with group size >1 sequentially. Barriers would deadlock!");
}
@@ -252,7 +313,7 @@ public class KernelRunner extends KernelRunnerJNI{
kernelState.setLocalId(0, 0);
kernelState.setLocalId(1, 0);
kernelState.setLocalId(2, 0);
- kernelState.setLocalBarrier(new CyclicBarrier(1));
+ kernelState.setLocalBarrier(new FJSafeCyclicBarrier(1));
for (int passId = 0; passId < _passes; passId++) {
kernelState.setPassId(passId);
@@ -266,7 +327,7 @@ public class KernelRunner extends KernelRunnerJNI{
for (int x = 0; x < _range.getGlobalSize(0); x++) {
kernelState.setGlobalId(0, x);
- for (int y = 0; y < _range.getGlobalSize(1); y++) {
+ for (int y = 0; y < globalSize1; y++) {
kernelState.setGlobalId(1, y);
kernelClone.run();
}
@@ -275,7 +336,7 @@ public class KernelRunner extends KernelRunnerJNI{
for (int x = 0; x < _range.getGlobalSize(0); x++) {
kernelState.setGlobalId(0, x);
- for (int y = 0; y < _range.getGlobalSize(1); y++) {
+ for (int y = 0; y < globalSize1; y++) {
kernelState.setGlobalId(1, y);
for (int z = 0; z < _range.getGlobalSize(2); z++) {
@@ -289,13 +350,15 @@ public class KernelRunner extends KernelRunnerJNI{
}
}
} else {
- final int threads = _range.getLocalSize(0) * _range.getLocalSize(1) * _range.getLocalSize(2);
- final int globalGroups = _range.getNumGroups(0) * _range.getNumGroups(1) * _range.getNumGroups(2);
+ final int threads = localSize0 * localSize1 * localSize2;
+ final int numGroups0 = _range.getNumGroups(0);
+ final int numGroups1 = _range.getNumGroups(1);
+ final int globalGroups = numGroups0 * numGroups1 * _range.getNumGroups(2);
/**
* This joinBarrier is the barrier that we provide for the kernel threads to rendezvous with the current dispatch thread.
* So this barrier is threadCount+1 wide (the +1 is for the dispatch thread)
*/
- final CyclicBarrier joinBarrier = new CyclicBarrier(threads + 1);
+ final CyclicBarrier joinBarrier = new FJSafeCyclicBarrier(threads + 1);
/**
* This localBarrier is only ever used by the kernels. If the kernel does not use the barrier the threads
@@ -308,8 +371,130 @@ public class KernelRunner extends KernelRunnerJNI{
*
* This barrier is threadCount wide. We never hit the barrier from the dispatch thread.
*/
- final CyclicBarrier localBarrier = new CyclicBarrier(threads);
+ final CyclicBarrier localBarrier = new FJSafeCyclicBarrier(threads);
+
+ final ThreadIdSetter threadIdSetter;
+ if (_range.getDims() == 1) {
+ threadIdSetter = new ThreadIdSetter(){
+ @Override public void set(KernelState kernelState, int globalGroupId, int threadId) {
+ // (kernelState, globalGroupId, threadId) ->{
+ kernelState.setLocalId(0, (threadId % localSize0));
+ kernelState.setGlobalId(0, (threadId + (globalGroupId * threads)));
+ kernelState.setGroupId(0, globalGroupId);
+ }
+ };
+ } else if (_range.getDims() == 2) {
+
+ /**
+ * Consider a 12x4 grid of 4*2 local groups
+ * <pre>
+ * threads = 4*2 = 8
+ * localWidth=4
+ * localHeight=2
+ * globalWidth=12
+ * globalHeight=4
+ *
+ * 00 01 02 03 | 04 05 06 07 | 08 09 10 11
+ * 12 13 14 15 | 16 17 18 19 | 20 21 22 23
+ * ------------+-------------+------------
+ * 24 25 26 27 | 28 29 30 31 | 32 33 34 35
+ * 36 37 38 39 | 40 41 42 43 | 44 45 46 47
+ *
+ * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6
+ * 04 05 06 07 | 04 05 06 07 | 04 05 06 07
+ * ------------+-------------+------------
+ * 00 01 02 03 | 00 01 02 03 | 00 01 02 03
+ * 04 05 06 07 | 04 05 06 07 | 04 05 06 07
+ *
+ * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6
+ * 00 00 00 00 | 01 01 01 01 | 02 02 02 02
+ * ------------+-------------+------------
+ * 00 00 00 00 | 01 01 01 01 | 02 02 02 02
+ * 00 00 00 00 | 01 01 01 01 | 02 02 02 02
+ *
+ * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6
+ * 00 00 00 00 | 00 00 00 00 | 00 00 00 00
+ * ------------+-------------+------------
+ * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
+ * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
+ *
+ * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads;
+ * 04 05 06 07 | 12 13 14 15 | 20 21 22 23
+ * ------------+-------------+------------
+ * 24 25 26 27 | 32[33]34 35 | 40 41 42 43
+ * 28 29 30 31 | 36 37 38 39 | 44 45 46 47
+ *
+ * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1)
+ * 00 01 02 03 | 00 01 02 03 | 00 01 02 03
+ * ------------+-------------+------------
+ * 00 01 02 03 | 00[01]02 03 | 00 01 02 03
+ * 00 01 02 03 | 00 01 02 03 | 00 01 02 03
+ *
+ * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0)
+ * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
+ * ------------+-------------+------------
+ * 00 00 00 00 | 00[00]00 00 | 00 00 00 00
+ * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
+ *
+ * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX=
+ * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3)
+ * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1)
+ * 00 01 02 03 | 04[05]06 07 | 08 09 10 11
+ * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5)
+ *
+ * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY
+ * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
+ * ------------+-------------+------------
+ * 02 02 02 02 | 02[02]02 02 | 02 02 02 02
+ * 03 03 03 03 | 03 03 03 03 | 03 03 03 03
+ *
+ * </pre>
+ * Assume we are trying to locate the id's for #33
+ *
+ */
+ threadIdSetter = new ThreadIdSetter(){
+ @Override public void set(KernelState kernelState, int globalGroupId, int threadId) {
+ // (kernelState, globalGroupId, threadId) ->{
+ kernelState.setLocalId(0, (threadId % localSize0)); // threadId % localWidth = (for 33 = 1 % 4 = 1)
+ kernelState.setLocalId(1, (threadId / localSize0)); // threadId / localWidth = (for 33 = 1 / 4 == 0)
+
+ final int groupInset = globalGroupId % numGroups0; // 4%3 = 1
+ kernelState.setGlobalId(0, ((groupInset * localSize0) + kernelState.getLocalIds()[0])); // 1*4+1=5
+
+ final int completeLines = (globalGroupId / numGroups0) * localSize1;// (4/3) * 2
+ kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2
+ kernelState.setGroupId(0, (globalGroupId % numGroups0));
+ kernelState.setGroupId(1, (globalGroupId / numGroups0));
+ }
+ };
+ } else if (_range.getDims() == 3) {
+ //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code
+ threadIdSetter = new ThreadIdSetter(){
+ @Override public void set(KernelState kernelState, int globalGroupId, int threadId) {
+ // (kernelState, globalGroupId, threadId) ->{
+ kernelState.setLocalId(0, (threadId % localSize0));
+
+ kernelState.setLocalId(1, ((threadId / localSize0) % localSize1));
+
+ // the thread id's span WxHxD so threadId/(WxH) should yield the local depth
+ kernelState.setLocalId(2, (threadId / (localSize0 * localSize1)));
+
+ kernelState.setGlobalId(0, (((globalGroupId % numGroups0) * localSize0) + kernelState.getLocalIds()[0]));
+
+ kernelState.setGlobalId(1,
+ ((((globalGroupId / numGroups0) * localSize1) % globalSize1) + kernelState.getLocalIds()[1]));
+
+ kernelState.setGlobalId(2,
+ (((globalGroupId / (numGroups0 * numGroups1)) * localSize2) + kernelState.getLocalIds()[2]));
+
+ kernelState.setGroupId(0, (globalGroupId % numGroups0));
+ kernelState.setGroupId(1, ((globalGroupId / numGroups0) % numGroups1));
+ kernelState.setGroupId(2, (globalGroupId / (numGroups0 * numGroups1)));
+ }
+ };
+ } else
+ throw new IllegalArgumentException("Expected 1,2 or 3 dimensions, found " + _range.getDims());
for (int passId = 0; passId < _passes; passId++) {
/**
* Note that we emulate OpenCL by creating one thread per localId (across the group).
@@ -352,135 +537,31 @@ public class KernelRunner extends KernelRunnerJNI{
*/
final Kernel kernelClone = kernel.clone();
final KernelState kernelState = kernelClone.getKernelState();
-
kernelState.setRange(_range);
- kernelState.setLocalBarrier(localBarrier);
kernelState.setPassId(passId);
- threadPool.submit(new Runnable(){
- @Override public void run() {
- for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) {
-
- if (_range.getDims() == 1) {
- kernelState.setLocalId(0, (threadId % _range.getLocalSize(0)));
- kernelState.setGlobalId(0, (threadId + (globalGroupId * threads)));
- kernelState.setGroupId(0, globalGroupId);
- } else if (_range.getDims() == 2) {
-
- /**
- * Consider a 12x4 grid of 4*2 local groups
- * <pre>
- * threads = 4*2 = 8
- * localWidth=4
- * localHeight=2
- * globalWidth=12
- * globalHeight=4
- *
- * 00 01 02 03 | 04 05 06 07 | 08 09 10 11
- * 12 13 14 15 | 16 17 18 19 | 20 21 22 23
- * ------------+-------------+------------
- * 24 25 26 27 | 28 29 30 31 | 32 33 34 35
- * 36 37 38 39 | 40 41 42 43 | 44 45 46 47
- *
- * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6
- * 04 05 06 07 | 04 05 06 07 | 04 05 06 07
- * ------------+-------------+------------
- * 00 01 02 03 | 00 01 02 03 | 00 01 02 03
- * 04 05 06 07 | 04 05 06 07 | 04 05 06 07
- *
- * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6
- * 00 00 00 00 | 01 01 01 01 | 02 02 02 02
- * ------------+-------------+------------
- * 00 00 00 00 | 01 01 01 01 | 02 02 02 02
- * 00 00 00 00 | 01 01 01 01 | 02 02 02 02
- *
- * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6
- * 00 00 00 00 | 00 00 00 00 | 00 00 00 00
- * ------------+-------------+------------
- * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
- * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
- *
- * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads;
- * 04 05 06 07 | 12 13 14 15 | 20 21 22 23
- * ------------+-------------+------------
- * 24 25 26 27 | 32[33]34 35 | 40 41 42 43
- * 28 29 30 31 | 36 37 38 39 | 44 45 46 47
- *
- * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1)
- * 00 01 02 03 | 00 01 02 03 | 00 01 02 03
- * ------------+-------------+------------
- * 00 01 02 03 | 00[01]02 03 | 00 01 02 03
- * 00 01 02 03 | 00 01 02 03 | 00 01 02 03
- *
- * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0)
- * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
- * ------------+-------------+------------
- * 00 00 00 00 | 00[00]00 00 | 00 00 00 00
- * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
- *
- * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX=
- * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3)
- * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1)
- * 00 01 02 03 | 04[05]06 07 | 08 09 10 11
- * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5)
- *
- * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY
- * 01 01 01 01 | 01 01 01 01 | 01 01 01 01
- * ------------+-------------+------------
- * 02 02 02 02 | 02[02]02 02 | 02 02 02 02
- * 03 03 03 03 | 03 03 03 03 | 03 03 03 03
- *
- * </pre>
- * Assume we are trying to locate the id's for #33
- *
- */
-
- kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); // threadId % localWidth = (for 33 = 1 % 4 = 1)
- kernelState.setLocalId(1, (threadId / _range.getLocalSize(0))); // threadId / localWidth = (for 33 = 1 / 4 == 0)
-
- final int groupInset = globalGroupId % _range.getNumGroups(0); // 4%3 = 1
- kernelState.setGlobalId(0, ((groupInset * _range.getLocalSize(0)) + kernelState.getLocalIds()[0])); // 1*4+1=5
-
- final int completeLines = (globalGroupId / _range.getNumGroups(0)) * _range.getLocalSize(1);// (4/3) * 2
- kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2
- kernelState.setGroupId(0, (globalGroupId % _range.getNumGroups(0)));
- kernelState.setGroupId(1, (globalGroupId / _range.getNumGroups(0)));
- } else if (_range.getDims() == 3) {
-
- //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code
-
- kernelState.setLocalId(0, (threadId % _range.getLocalSize(0)));
-
- kernelState.setLocalId(1, ((threadId / _range.getLocalSize(0)) % _range.getLocalSize(1)));
-
- // the thread id's span WxHxD so threadId/(WxH) should yield the local depth
- kernelState.setLocalId(2, (threadId / (_range.getLocalSize(0) * _range.getLocalSize(1))));
-
- kernelState.setGlobalId(
- 0,
- (((globalGroupId % _range.getNumGroups(0)) * _range.getLocalSize(0)) + kernelState.getLocalIds()[0]));
-
- kernelState.setGlobalId(
- 1,
- ((((globalGroupId / _range.getNumGroups(0)) * _range.getLocalSize(1)) % _range.getGlobalSize(1)) + kernelState
- .getLocalIds()[1]));
-
- kernelState.setGlobalId(
- 2,
- (((globalGroupId / (_range.getNumGroups(0) * _range.getNumGroups(1))) * _range.getLocalSize(2)) + kernelState
- .getLocalIds()[2]));
-
- kernelState.setGroupId(0, (globalGroupId % _range.getNumGroups(0)));
- kernelState.setGroupId(1, ((globalGroupId / _range.getNumGroups(0)) % _range.getNumGroups(1)));
- kernelState.setGroupId(2, (globalGroupId / (_range.getNumGroups(0) * _range.getNumGroups(1))));
- }
-
- kernelClone.run();
- }
+ if (threads == 1) {
+ kernelState.disableLocalBarrier();
+ } else {
+ kernelState.setLocalBarrier(localBarrier);
+ }
- await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join.
- }
- });
+ threadPool.submit(
+ // () -> {
+ new Runnable(){
+ public void run() {
+ try {
+ for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) {
+ threadIdSetter.set(kernelState, globalGroupId, threadId);
+ kernelClone.run();
+ }
+ } catch (RuntimeException | Error e) {
+ logger.log(Level.SEVERE, "Execution failed", e);
+ } finally {
+ await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join.
+ }
+ }
+ });
}
await(joinBarrier); // This dispatch thread waits for all worker threads here.
@@ -519,7 +600,7 @@ public class KernelRunner extends KernelRunnerJNI{
boolean didReallocate = false;
if (arg.getObjArrayElementModel() == null) {
- final String tmp = arrayClass.getName().substring(2).replace("/", ".");
+ final String tmp = arrayClass.getName().substring(2).replace('/', '.');
final String arrayClassInDotForm = tmp.substring(0, tmp.length() - 1);
if (logger.isLoggable(Level.FINE)) {
@@ -786,7 +867,8 @@ public class KernelRunner extends KernelRunnerJNI{
}
if (privateMemorySize != null) {
if (arrayLength > privateMemorySize) {
- throw new IllegalStateException("__private array field " + fieldName + " has illegal length " + arrayLength + " > " + privateMemorySize);
+ throw new IllegalStateException("__private array field " + fieldName + " has illegal length " + arrayLength
+ + " > " + privateMemorySize);
}
}
@@ -943,7 +1025,7 @@ public class KernelRunner extends KernelRunnerJNI{
if ((device == null) || (device instanceof OpenCLDevice)) {
if (entryPoint == null) {
try {
- final ClassModel classModel = new ClassModel(kernel.getClass());
+ final ClassModel classModel = ClassModel.createClassModel(kernel.getClass());
entryPoint = classModel.getEntrypoint(_entrypointName, kernel);
} catch (final Exception exception) {
return warnFallBackAndExecute(_entrypointName, _range, _passes, exception);
@@ -1079,13 +1161,9 @@ public class KernelRunner extends KernelRunnerJNI{
| (entryPoint.getArrayFieldAccesses().contains(field.getName()) ? ARG_READ : 0));
// args[i].type |= ARG_GLOBAL;
-
if (type.getName().startsWith("[L")) {
args[i].setType(args[i].getType()
- | (ARG_OBJ_ARRAY_STRUCT |
- ARG_WRITE |
- ARG_READ |
- ARG_APARAPI_BUFFER));
+ | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER));
if (logger.isLoggable(Level.FINE)) {
logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)");
@@ -1094,8 +1172,9 @@ public class KernelRunner extends KernelRunnerJNI{
try {
setMultiArrayType(args[i], type);
- } catch(AparapiException e) {
- return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " + args[i].getName() + ". Aparapi only supports 2D and 3D arrays.");
+ } catch (AparapiException e) {
+ return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement "
+ + args[i].getName() + ". Aparapi only supports 2D and 3D arrays.");
}
} else {
@@ -1120,7 +1199,8 @@ public class KernelRunner extends KernelRunnerJNI{
if (type.getName().startsWith("[L")) {
args[i].setType(args[i].getType() | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ));
if (logger.isLoggable(Level.FINE)) {
- logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)");
+ logger.fine("tagging " + args[i].getName()
+ + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)");
}
}
}
@@ -1206,7 +1286,6 @@ public class KernelRunner extends KernelRunnerJNI{
return kernel;
}
-
private int getPrimitiveSize(int type) {
if ((type & ARG_FLOAT) != 0) {
return 4;
@@ -1231,28 +1310,28 @@ public class KernelRunner extends KernelRunnerJNI{
private void setMultiArrayType(KernelArg arg, Class<?> type) throws AparapiException {
arg.setType(arg.getType() | (ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER));
int numDims = 0;
- while(type.getName().startsWith("[[[[")) {
+ while (type.getName().startsWith("[[[[")) {
throw new AparapiException("Aparapi only supports 2D and 3D arrays.");
}
arg.setType(arg.getType() | ARG_ARRAYLENGTH);
- while(type.getName().charAt(numDims) == '[') {
+ while (type.getName().charAt(numDims) == '[') {
numDims++;
}
Object buffer = new Object();
try {
buffer = arg.getField().get(kernel);
- } catch(IllegalAccessException e) {
+ } catch (IllegalAccessException e) {
e.printStackTrace();
}
arg.setJavaBuffer(buffer);
arg.setNumDims(numDims);
Object subBuffer = buffer;
int[] dims = new int[numDims];
- for(int i = 0; i < numDims-1; i++) {
+ for (int i = 0; i < numDims - 1; i++) {
dims[i] = Array.getLength(subBuffer);
subBuffer = Array.get(subBuffer, 0);
}
- dims[numDims-1] = Array.getLength(subBuffer);
+ dims[numDims - 1] = Array.getLength(subBuffer);
arg.setDims(dims);
if (subBuffer.getClass().isAssignableFrom(float[].class)) {
@@ -1281,7 +1360,7 @@ public class KernelRunner extends KernelRunnerJNI{
}
int primitiveSize = getPrimitiveSize(arg.getType());
int totalElements = 1;
- for(int i = 0; i < numDims; i++) {
+ for (int i = 0; i < numDims; i++) {
totalElements *= dims[i];
}
arg.setSizeInBytes(totalElements * primitiveSize);
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java
new file mode 100644
index 0000000000000000000000000000000000000000..9308d57ed876ac88e9e99dc1e98e4c22fa1ad117
--- /dev/null
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java
@@ -0,0 +1,20 @@
+package com.amd.aparapi.internal.model;
+
+import com.amd.aparapi.Kernel;
+
+public class CacheEnabler{
+ private static volatile boolean cachesEnabled;
+
+ public static void setCachesEnabled(boolean cachesEnabled) {
+ if (CacheEnabler.cachesEnabled != cachesEnabled) {
+ Kernel.invalidateCaches();
+ ClassModel.invalidateCaches();
+ }
+
+ CacheEnabler.cachesEnabled = cachesEnabled;
+ }
+
+ public static boolean areCachesEnabled() {
+ return cachesEnabled;
+ }
+}
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java
index e1b269fafea752e40943b6e98a0c9d3749f327d1..95c362a357f0b09911f7f433d426809073bd67ec 100644
--- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java
@@ -41,6 +41,7 @@ import com.amd.aparapi.*;
import com.amd.aparapi.internal.annotation.*;
import com.amd.aparapi.internal.exception.*;
import com.amd.aparapi.internal.instruction.InstructionSet.*;
+import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer;
import com.amd.aparapi.internal.model.ClassModel.AttributePool.*;
import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*;
import com.amd.aparapi.internal.reader.*;
@@ -122,10 +123,31 @@ public class ClassModel{
private ClassModel superClazz = null;
- private HashSet<String> noClMethods = null;
+ // private Memoizer<Set<String>> noClMethods = Memoizer.of(this::computeNoCLMethods);
+ private Memoizer<Set<String>> noClMethods = Memoizer.Impl.of(new Supplier<Set<String>>(){
+ @Override
+ public Set<String> get() {
+ return computeNoCLMethods();
+ }
+ });
- private HashMap<String, Kernel.PrivateMemorySpace> privateMemoryFields = null;
+ // private Memoizer<Map<String, Kernel.PrivateMemorySpace>> privateMemoryFields = Memoizer.of(this::computePrivateMemoryFields);
+ private Memoizer<Map<String, Kernel.PrivateMemorySpace>> privateMemoryFields = Memoizer.Impl
+ .of(new Supplier<Map<String, Kernel.PrivateMemorySpace>>(){
+ @Override
+ public Map<String, Kernel.PrivateMemorySpace> get() {
+ return computePrivateMemoryFields();
+ }
+ });
+ // private ValueCache<String, Integer, ClassParseException> privateMemorySizes = ValueCache.on(this::computePrivateMemorySize);
+ private ValueCache<String, Integer, ClassParseException> privateMemorySizes = ValueCache
+ .on(new ThrowingValueComputer<String, Integer, ClassParseException>(){
+ @Override
+ public Integer compute(String fieldName) throws ClassParseException {
+ return computePrivateMemorySize(fieldName);
+ }
+ });
/**
* Create a ClassModel representing a given Class.
@@ -137,7 +159,7 @@ public class ClassModel{
* @throws ClassParseException
*/
- public ClassModel(Class<?> _class) throws ClassParseException {
+ private ClassModel(Class<?> _class) throws ClassParseException {
parse(_class);
@@ -147,7 +169,7 @@ public class ClassModel{
// not occur in normal use
if ((mySuper != null) && (!mySuper.getName().equals(Kernel.class.getName()))
&& (!mySuper.getName().equals("java.lang.Object"))) {
- superClazz = new ClassModel(mySuper);
+ superClazz = createClassModel(mySuper);
}
}
@@ -204,7 +226,8 @@ public class ClassModel{
return superClazz;
}
- @DocMe public void replaceSuperClazz(ClassModel c) {
+ @DocMe
+ public void replaceSuperClazz(ClassModel c) {
if (superClazz != null) {
assert c.isSuperClass(getClassWeAreModelling()) == true : "not my super";
if (superClazz.getClassWeAreModelling().getName().equals(c.getClassWeAreModelling().getName())) {
@@ -260,32 +283,40 @@ public class ClassModel{
* If a field does not satisfy the private memory conditions, null, otherwise the size of private memory required.
*/
public Integer getPrivateMemorySize(String fieldName) throws ClassParseException {
- if (privateMemoryFields == null) {
- privateMemoryFields = new HashMap<String, Kernel.PrivateMemorySpace>();
- HashMap<Field, Kernel.PrivateMemorySpace> privateMemoryFields = new HashMap<Field, Kernel.PrivateMemorySpace>();
- for (Field field : getClassWeAreModelling().getDeclaredFields()) {
- Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class);
- if (privateMemorySpace != null) {
- privateMemoryFields.put(field, privateMemorySpace);
- }
- }
- for (Field field : getClassWeAreModelling().getFields()) {
- Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class);
- if (privateMemorySpace != null) {
- privateMemoryFields.put(field, privateMemorySpace);
- }
- }
- for (Map.Entry<Field, Kernel.PrivateMemorySpace> entry : privateMemoryFields.entrySet()) {
- this.privateMemoryFields.put(entry.getKey().getName(), entry.getValue());
- }
- }
- Kernel.PrivateMemorySpace annotation = privateMemoryFields.get(fieldName);
+ if (CacheEnabler.areCachesEnabled())
+ return privateMemorySizes.computeIfAbsent(fieldName);
+ return computePrivateMemorySize(fieldName);
+ }
+
+ private Integer computePrivateMemorySize(String fieldName) throws ClassParseException {
+ Kernel.PrivateMemorySpace annotation = privateMemoryFields.get().get(fieldName);
if (annotation != null) {
return annotation.value();
}
return getPrivateMemorySizeFromFieldName(fieldName);
}
+ private Map<String, Kernel.PrivateMemorySpace> computePrivateMemoryFields() {
+ Map<String, Kernel.PrivateMemorySpace> tempPrivateMemoryFields = new HashMap<String, Kernel.PrivateMemorySpace>();
+ Map<Field, Kernel.PrivateMemorySpace> privateMemoryFields = new HashMap<Field, Kernel.PrivateMemorySpace>();
+ for (Field field : getClassWeAreModelling().getDeclaredFields()) {
+ Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class);
+ if (privateMemorySpace != null) {
+ privateMemoryFields.put(field, privateMemorySpace);
+ }
+ }
+ for (Field field : getClassWeAreModelling().getFields()) {
+ Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class);
+ if (privateMemorySpace != null) {
+ privateMemoryFields.put(field, privateMemorySpace);
+ }
+ }
+ for (Map.Entry<Field, Kernel.PrivateMemorySpace> entry : privateMemoryFields.entrySet()) {
+ tempPrivateMemoryFields.put(entry.getKey().getName(), entry.getValue());
+ }
+ return tempPrivateMemoryFields;
+ }
+
public static Integer getPrivateMemorySizeFromField(Field field) {
Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class);
if (privateMemorySpace != null) {
@@ -309,25 +340,26 @@ public class ClassModel{
}
public Set<String> getNoCLMethods() {
- if (this.noClMethods == null) {
- noClMethods = new HashSet<String>();
- HashSet<Method> methods = new HashSet<Method>();
- for (Method method : getClassWeAreModelling().getDeclaredMethods()) {
- if (method.getAnnotation(Kernel.NoCL.class) != null) {
- methods.add(method);
- }
- }
- for (Method method : getClassWeAreModelling().getMethods()) {
- if (method.getAnnotation(Kernel.NoCL.class) != null) {
- methods.add(method);
- }
+ return computeNoCLMethods();
+ }
+
+ private Set<String> computeNoCLMethods() {
+ Set<String> tempNoClMethods = new HashSet<String>();
+ HashSet<Method> methods = new HashSet<Method>();
+ for (Method method : getClassWeAreModelling().getDeclaredMethods()) {
+ if (method.getAnnotation(Kernel.NoCL.class) != null) {
+ methods.add(method);
}
- for (Method method : methods) {
- noClMethods.add(method.getName());
+ }
+ for (Method method : getClassWeAreModelling().getMethods()) {
+ if (method.getAnnotation(Kernel.NoCL.class) != null) {
+ methods.add(method);
}
-
}
- return noClMethods;
+ for (Method method : methods) {
+ tempNoClMethods.add(method.getName());
+ }
+ return tempNoClMethods;
}
public static String convert(String _string) {
@@ -602,6 +634,21 @@ public class ClassModel{
return (methodDescription);
}
+ // private static final ValueCache<Class<?>, ClassModel, ClassParseException> classModelCache = ValueCache.onIdentity(ClassModel::new);
+ private static final ValueCache<Class<?>, ClassModel, ClassParseException> classModelCache = ValueCache
+ .on(new ThrowingValueComputer<Class<?>, ClassModel, ClassParseException>(){
+ @Override
+ public ClassModel compute(Class<?> key) throws ClassParseException {
+ return new ClassModel(key);
+ }
+ });
+
+ public static ClassModel createClassModel(Class<?> _class) throws ClassParseException {
+ if (CacheEnabler.areCachesEnabled())
+ return classModelCache.computeIfAbsent(_class);
+ return new ClassModel(_class);
+ }
+
private int magic;
private int minorVersion;
@@ -813,7 +860,8 @@ public class ClassModel{
super(_byteReader, _slot, ConstantPoolType.METHOD);
}
- @Override public String toString() {
+ @Override
+ public String toString() {
final StringBuilder sb = new StringBuilder();
sb.append(getClassEntry().getNameUTF8Entry().getUTF8());
sb.append(".");
@@ -934,13 +982,15 @@ public class ClassModel{
private Type returnType = null;
- @Override public int hashCode() {
+ @Override
+ public int hashCode() {
final NameAndTypeEntry nameAndTypeEntry = getNameAndTypeEntry();
return ((((nameAndTypeEntry.getNameIndex() * 31) + nameAndTypeEntry.getDescriptorIndex()) * 31) + getClassIndex());
}
- @Override public boolean equals(Object _other) {
+ @Override
+ public boolean equals(Object _other) {
if ((_other == null) || !(_other instanceof MethodReferenceEntry)) {
return (false);
} else {
@@ -1312,7 +1362,8 @@ public class ClassModel{
}
- @Override public Iterator<Entry> iterator() {
+ @Override
+ public Iterator<Entry> iterator() {
return (entries.iterator());
}
@@ -1399,58 +1450,51 @@ public class ClassModel{
} else if (_entry instanceof ConstantPool.NameAndTypeEntry) {
final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) _entry;
references = new int[] {
- nameAndTypeEntry.getNameIndex(),
- nameAndTypeEntry.getDescriptorIndex()
+ nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex()
};
} else if (_entry instanceof ConstantPool.MethodEntry) {
final ConstantPool.MethodEntry methodEntry = (ConstantPool.MethodEntry) _entry;
final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(methodEntry.getClassIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry
- .getNameIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex());
final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(methodEntry
.getNameAndTypeIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry
- .getNameIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry
- .getDescriptorIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex());
references = new int[] {
- methodEntry.getClassIndex(),
- classEntry.getNameIndex(),
- nameAndTypeEntry.getNameIndex(),
+ methodEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(),
nameAndTypeEntry.getDescriptorIndex()
};
} else if (_entry instanceof ConstantPool.InterfaceMethodEntry) {
final ConstantPool.InterfaceMethodEntry interfaceMethodEntry = (ConstantPool.InterfaceMethodEntry) _entry;
final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(interfaceMethodEntry.getClassIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry
- .getNameIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex());
final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(interfaceMethodEntry
.getNameAndTypeIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry
- .getNameIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry
- .getDescriptorIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex());
references = new int[] {
- interfaceMethodEntry.getClassIndex(),
- classEntry.getNameIndex(),
- nameAndTypeEntry.getNameIndex(),
+ interfaceMethodEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(),
nameAndTypeEntry.getDescriptorIndex()
};
} else if (_entry instanceof ConstantPool.FieldEntry) {
final ConstantPool.FieldEntry fieldEntry = (ConstantPool.FieldEntry) _entry;
final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(fieldEntry.getClassIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry
- .getNameIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex());
final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(fieldEntry
.getNameAndTypeIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry
- .getNameIndex());
- @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry
- .getDescriptorIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex());
+ @SuppressWarnings("unused")
+ final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex());
references = new int[] {
- fieldEntry.getClassIndex(),
- classEntry.getNameIndex(),
- nameAndTypeEntry.getNameIndex(),
+ fieldEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(),
nameAndTypeEntry.getDescriptorIndex()
};
}
@@ -1581,7 +1625,8 @@ public class ClassModel{
codeEntryAttributePool = new AttributePool(_byteReader, getName());
}
- @Override public AttributePool getAttributePool() {
+ @Override
+ public AttributePool getAttributePool() {
return (codeEntryAttributePool);
}
@@ -1664,7 +1709,8 @@ public class ClassModel{
super(_byteReader, _nameIndex, _length);
}
- @Override public Iterator<T> iterator() {
+ @Override
+ public Iterator<T> iterator() {
return (pool.iterator());
}
}
@@ -1848,27 +1894,33 @@ public class ClassModel{
return (variableNameIndex);
}
- @Override public int getStart() {
+ @Override
+ public int getStart() {
return (start);
}
- @Override public int getVariableIndex() {
+ @Override
+ public int getVariableIndex() {
return (variableIndex);
}
- @Override public String getVariableName() {
+ @Override
+ public String getVariableName() {
return (constantPool.getUTF8Entry(variableNameIndex).getUTF8());
}
- @Override public String getVariableDescriptor() {
+ @Override
+ public String getVariableDescriptor() {
return (constantPool.getUTF8Entry(descriptorIndex).getUTF8());
}
- @Override public int getEnd() {
+ @Override
+ public int getEnd() {
return (start + usageLength);
}
- @Override public boolean isArray() {
+ @Override
+ public boolean isArray() {
return (getVariableDescriptor().startsWith("["));
}
}
@@ -1969,7 +2021,8 @@ public class ClassModel{
return (bytes);
}
- @Override public String toString() {
+ @Override
+ public String toString() {
return (new String(bytes));
}
@@ -1987,7 +2040,8 @@ public class ClassModel{
return (bytes);
}
- @Override public String toString() {
+ @Override
+ public String toString() {
return (new String(bytes));
}
}
@@ -2004,7 +2058,8 @@ public class ClassModel{
return (bytes);
}
- @Override public String toString() {
+ @Override
+ public String toString() {
return (new String(bytes));
}
}
@@ -2094,9 +2149,11 @@ public class ClassModel{
}
}
- @SuppressWarnings("unused") private final int elementNameIndex;
+ @SuppressWarnings("unused")
+ private final int elementNameIndex;
- @SuppressWarnings("unused") private Value value;
+ @SuppressWarnings("unused")
+ private Value value;
public ElementValuePair(ByteReader _byteReader) {
elementNameIndex = _byteReader.u2();
@@ -2626,12 +2683,23 @@ public class ClassModel{
}
public ClassModelMethod getMethod(String _name, String _descriptor) {
+ ClassModelMethod methodOrNull = getMethodOrNull(_name, _descriptor);
+ if (methodOrNull == null)
+ return superClazz != null ? superClazz.getMethod(_name, _descriptor) : (null);
+ return methodOrNull;
+ }
+
+ private ClassModelMethod getMethodOrNull(String _name, String _descriptor) {
for (final ClassModelMethod entry : methods) {
if (entry.getName().equals(_name) && entry.getDescriptor().equals(_descriptor)) {
+ if (logger.isLoggable(Level.FINE)) {
+ logger.fine("Found " + clazz.getName() + "." + entry.getName() + " " + entry.getDescriptor() + " for "
+ + _name.replace('/', '.'));
+ }
return (entry);
}
}
- return superClazz != null ? superClazz.getMethod(_name, _descriptor) : (null);
+ return null;
}
public List<ClassModelField> getFieldPoolEntries() {
@@ -2647,7 +2715,9 @@ public class ClassModel{
* @return The Method or null if we fail to locate a given method.
*/
public ClassModelMethod getMethod(MethodEntry _methodEntry, boolean _isSpecial) {
- final String entryClassNameInDotForm = _methodEntry.getClassEntry().getNameUTF8Entry().getUTF8().replace('/', '.');
+ NameAndTypeEntry nameAndTypeEntry = _methodEntry.getNameAndTypeEntry();
+ String utf8Name = nameAndTypeEntry.getNameUTF8Entry().getUTF8();
+ final String entryClassNameInDotForm = utf8Name.replace('/', '.');
// Shortcut direct calls to supers to allow "foo() { super.foo() }" type stuff to work
if (_isSpecial && (superClazz != null) && superClazz.isSuperClass(entryClassNameInDotForm)) {
@@ -2658,20 +2728,21 @@ public class ClassModel{
return superClazz.getMethod(_methodEntry, false);
}
- for (final ClassModelMethod entry : methods) {
- if (entry.getName().equals(_methodEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8())
- && entry.getDescriptor().equals(_methodEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8())) {
- if (logger.isLoggable(Level.FINE)) {
- logger.fine("Found " + clazz.getName() + "." + entry.getName() + " " + entry.getDescriptor() + " for "
- + entryClassNameInDotForm);
- }
- return (entry);
- }
- }
-
- return superClazz != null ? superClazz.getMethod(_methodEntry, false) : (null);
+ ClassModelMethod methodOrNull = getMethodOrNull(utf8Name, nameAndTypeEntry.getDescriptorUTF8Entry().getUTF8());
+ if (methodOrNull == null)
+ return superClazz != null ? superClazz.getMethod(_methodEntry, false) : (null);
+ return methodOrNull;
}
+ // private ValueCache<MethodKey, MethodModel, AparapiException> methodModelCache = ValueCache.on(this::computeMethodModel);
+ private ValueCache<MethodKey, MethodModel, AparapiException> methodModelCache = ValueCache
+ .on(new ThrowingValueComputer<MethodKey, MethodModel, AparapiException>(){
+ @Override
+ public MethodModel compute(MethodKey key) throws AparapiException {
+ return computeMethodModel(key);
+ }
+ });
+
/**
* Create a MethodModel for a given method name and signature.
*
@@ -2680,9 +2751,17 @@ public class ClassModel{
* @return
* @throws AparapiException
*/
-
public MethodModel getMethodModel(String _name, String _signature) throws AparapiException {
- final ClassModelMethod method = getMethod(_name, _signature);
+ if (CacheEnabler.areCachesEnabled())
+ return methodModelCache.computeIfAbsent(MethodKey.of(_name, _signature));
+ else {
+ final ClassModelMethod method = getMethod(_name, _signature);
+ return new MethodModel(method);
+ }
+ }
+
+ private MethodModel computeMethodModel(MethodKey methodKey) throws AparapiException {
+ final ClassModelMethod method = getMethod(methodKey.getName(), methodKey.getSignature());
return new MethodModel(method);
}
@@ -2715,9 +2794,29 @@ public class ClassModel{
totalStructSize = x;
}
+ // private final ValueCache<EntrypointKey, Entrypoint, AparapiException> entrypointCache = ValueCache.on(this::computeBasicEntrypoint);
+ private final ValueCache<EntrypointKey, Entrypoint, AparapiException> entrypointCache = ValueCache
+ .on(new ThrowingValueComputer<EntrypointKey, Entrypoint, AparapiException>(){
+ @Override
+ public Entrypoint compute(EntrypointKey key) throws AparapiException {
+ return computeBasicEntrypoint(key);
+ }
+ });
+
Entrypoint getEntrypoint(String _entrypointName, String _descriptor, Object _k) throws AparapiException {
- final MethodModel method = getMethodModel(_entrypointName, _descriptor);
- return (new Entrypoint(this, method, _k));
+ if (CacheEnabler.areCachesEnabled()) {
+ EntrypointKey key = EntrypointKey.of(_entrypointName, _descriptor);
+ Entrypoint entrypointWithoutKernel = entrypointCache.computeIfAbsent(key);
+ return entrypointWithoutKernel.cloneForKernel(_k);
+ } else {
+ final MethodModel method = getMethodModel(_entrypointName, _descriptor);
+ return new Entrypoint(this, method, _k);
+ }
+ }
+
+ Entrypoint computeBasicEntrypoint(EntrypointKey entrypointKey) throws AparapiException {
+ final MethodModel method = getMethodModel(entrypointKey.getEntrypointName(), entrypointKey.getDescriptor());
+ return new Entrypoint(this, method, null);
}
public Class<?> getClassWeAreModelling() {
@@ -2731,4 +2830,8 @@ public class ClassModel{
public Entrypoint getEntrypoint() throws AparapiException {
return (getEntrypoint("run", "()V", null));
}
+
+ public static void invalidateCaches() {
+ classModelCache.invalidate();
+ }
}
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java
index 28aadab21efa63eb18c63abc4930cc7ed942075f..7ae155efa905a1dca6cd88f39931977a6ea9317a 100644
--- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java
@@ -50,7 +50,7 @@ import java.lang.reflect.*;
import java.util.*;
import java.util.logging.*;
-public class Entrypoint{
+public class Entrypoint implements Cloneable {
private static Logger logger = Logger.getLogger(Config.getLoggerName());
@@ -81,7 +81,7 @@ public class Entrypoint{
private final List<MethodModel> calledMethods = new ArrayList<MethodModel>();
- private MethodModel methodModel;
+ private final MethodModel methodModel;
/**
True is an indication to use the fp64 pragma
@@ -226,7 +226,7 @@ public class Entrypoint{
final Class<?> memberClass = Class.forName(className);
// Immediately add this class and all its supers if necessary
- memberClassModel = new ClassModel(memberClass);
+ memberClassModel = ClassModel.createClassModel(memberClass);
if (logger.isLoggable(Level.FINEST)) {
logger.finest("adding class " + className);
}
@@ -595,7 +595,7 @@ public class Entrypoint{
// Add the class model for the referenced obj array
if (signature.startsWith("[L")) {
// Turn [Lcom/amd/javalabs/opencl/demo/DummyOOA; into com.amd.javalabs.opencl.demo.DummyOOA for example
- final String className = (signature.substring(2, signature.length() - 1)).replace("/", ".");
+ final String className = (signature.substring(2, signature.length() - 1)).replace('/', '.');
final ClassModel arrayFieldModel = getOrUpdateAllClassAccesses(className);
if (arrayFieldModel != null) {
final Class<?> memberClass = arrayFieldModel.getClassWeAreModelling();
@@ -625,7 +625,7 @@ public class Entrypoint{
}
}
} else {
- final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace("/", ".");
+ final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace('/', '.');
// Look for object data member access
if (!className.equals(getClassModel().getClassWeAreModelling().getName())
&& (getFieldFromClassHierarchy(getClassModel().getClassWeAreModelling(), accessedFieldName) == null)) {
@@ -640,7 +640,7 @@ public class Entrypoint{
fieldAssignments.add(assignedFieldName);
referencedFieldNames.add(assignedFieldName);
- final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace("/", ".");
+ final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace('/', '.');
// Look for object data member access
if (!className.equals(getClassModel().getClassWeAreModelling().getName())
&& (getFieldFromClassHierarchy(getClassModel().getClassWeAreModelling(), assignedFieldName) == null)) {
@@ -909,4 +909,14 @@ public class Entrypoint{
return null;
}
+
+ Entrypoint cloneForKernel(Object _k) throws AparapiException {
+ try {
+ Entrypoint clonedEntrypoint = (Entrypoint) clone();
+ clonedEntrypoint.kernelInstance = _k;
+ return clonedEntrypoint;
+ } catch (CloneNotSupportedException e) {
+ throw new AparapiException(e);
+ }
+ }
}
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java
new file mode 100644
index 0000000000000000000000000000000000000000..bddf99d7b8ca0b289b2add855dd96bd0179a0f2f
--- /dev/null
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java
@@ -0,0 +1,57 @@
+package com.amd.aparapi.internal.model;
+
+final class EntrypointKey{
+ public static EntrypointKey of(String entrypointName, String descriptor) {
+ return new EntrypointKey(entrypointName, descriptor);
+ }
+
+ private String descriptor;
+
+ private String entrypointName;
+
+ private EntrypointKey(String entrypointName, String descriptor) {
+ this.entrypointName = entrypointName;
+ this.descriptor = descriptor;
+ }
+
+ String getDescriptor() {
+ return descriptor;
+ }
+
+ String getEntrypointName() {
+ return entrypointName;
+ }
+
+ @Override public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + ((descriptor == null) ? 0 : descriptor.hashCode());
+ result = prime * result + ((entrypointName == null) ? 0 : entrypointName.hashCode());
+ return result;
+ }
+
+ @Override public String toString() {
+ return "EntrypointKey [entrypointName=" + entrypointName + ", descriptor=" + descriptor + "]";
+ }
+
+ @Override public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ EntrypointKey other = (EntrypointKey) obj;
+ if (descriptor == null) {
+ if (other.descriptor != null)
+ return false;
+ } else if (!descriptor.equals(other.descriptor))
+ return false;
+ if (entrypointName == null) {
+ if (other.entrypointName != null)
+ return false;
+ } else if (!entrypointName.equals(other.entrypointName))
+ return false;
+ return true;
+ }
+}
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java
new file mode 100644
index 0000000000000000000000000000000000000000..d08e2debac32c97fddf209ce295e471df44dd29e
--- /dev/null
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java
@@ -0,0 +1,57 @@
+package com.amd.aparapi.internal.model;
+
+final class MethodKey{
+ static MethodKey of(String name, String signature) {
+ return new MethodKey(name, signature);
+ }
+
+ private final String name;
+
+ private final String signature;
+
+ @Override public String toString() {
+ return "MethodKey [name=" + getName() + ", signature=" + getSignature() + "]";
+ }
+
+ @Override public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + ((getName() == null) ? 0 : getName().hashCode());
+ result = prime * result + ((getSignature() == null) ? 0 : getSignature().hashCode());
+ return result;
+ }
+
+ @Override public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ MethodKey other = (MethodKey) obj;
+ if (getName() == null) {
+ if (other.getName() != null)
+ return false;
+ } else if (!getName().equals(other.getName()))
+ return false;
+ if (getSignature() == null) {
+ if (other.getSignature() != null)
+ return false;
+ } else if (!getSignature().equals(other.getSignature()))
+ return false;
+ return true;
+ }
+
+ private MethodKey(String name, String signature) {
+ this.name = name;
+ this.signature = signature;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public String getSignature() {
+ return signature;
+ }
+}
diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java
index 6e80defbadf45301bce302cb474343a269914660..015cf46d7492b83b5d914c77e3b206006cbe32ae 100644
--- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java
+++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java
@@ -625,7 +625,7 @@ public class InstructionViewer implements Config.InstructionListener{
public InstructionViewer(Color _background, String _name) {
try {
- classModel = new ClassModel(Class.forName(_name));
+ classModel = ClassModel.createClassModel(Class.forName(_name));
} catch (final ClassParseException e) {
// TODO Auto-generated catch block
e.printStackTrace();