diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java index 0069d4e46fcb97d834b5a4c5af48e3d84dc710f9..ff9094e856ef8cbdfed7ca896cd8717b534b2a0a 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/Kernel.java @@ -40,6 +40,10 @@ package com.amd.aparapi; import com.amd.aparapi.annotation.*; import com.amd.aparapi.exception.*; import com.amd.aparapi.internal.kernel.*; +import com.amd.aparapi.internal.model.CacheEnabler; +import com.amd.aparapi.internal.model.ValueCache; +import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer; +import com.amd.aparapi.internal.model.ValueCache.ValueComputer; import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*; import com.amd.aparapi.internal.opencl.*; import com.amd.aparapi.internal.util.*; @@ -488,6 +492,8 @@ public abstract class Kernel implements Cloneable { private volatile CyclicBarrier localBarrier; + private boolean localBarrierDisabled; + /** * Default constructor */ @@ -620,6 +626,24 @@ public abstract class Kernel implements Cloneable { public void setLocalBarrier(CyclicBarrier localBarrier) { this.localBarrier = localBarrier; } + + public void awaitOnLocalBarrier() { + if (!localBarrierDisabled) { + try { + kernelState.getLocalBarrier().await(); + } catch (final InterruptedException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } catch (final BrokenBarrierException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + } + } + + public void disableLocalBarrier() { + localBarrierDisabled = true; + } } /** @@ -1834,15 +1858,7 @@ public abstract class Kernel implements Cloneable { @OpenCLDelegate @Experimental protected final void localBarrier() { - try { - kernelState.getLocalBarrier().await(); - } catch (final InterruptedException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } catch (final BrokenBarrierException e) { - // TODO Auto-generated catch block - e.printStackTrace(); - } + kernelState.awaitOnLocalBarrier(); } /** @@ -1862,6 +1878,16 @@ public abstract class Kernel implements Cloneable { "Kernel.globalBarrier() has been deprecated. It was based an incorrect understanding of OpenCL functionality."); } + @OpenCLMapping(mapTo = "hypot") + protected float hypot(final float a, final float b) { + return (float) Math.hypot(a, b); + } + + @OpenCLMapping(mapTo = "hypot") + protected double hypot(final double a, final double b) { + return Math.hypot(a, b); + } + public KernelState getKernelState() { return kernelState; } @@ -2100,49 +2126,64 @@ public abstract class Kernel implements Cloneable { } private static String getReturnTypeLetter(Method meth) { - final Class<?> retClass = meth.getReturnType(); + return toClassShortNameIfAny(meth.getReturnType()); + } + + private static String toClassShortNameIfAny(final Class<?> retClass) { + if (retClass.isArray()) { + return "[" + toClassShortNameIfAny(retClass.getComponentType()); + } final String strRetClass = retClass.toString(); final String mapping = typeToLetterMap.get(strRetClass); // System.out.println("strRetClass = <" + strRetClass + ">, mapping = " + mapping); + if (mapping == null) + return "[" + retClass.getName() + ";"; return mapping; } public static String getMappedMethodName(MethodReferenceEntry _methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getProperty(mappedMethodNamesCache, _methodReferenceEntry, null); String mappedName = null; final String name = _methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8(); - for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { - if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { - // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int); - // for Alpha, we will just disambiguate based on the return type - if (false) { - System.out.println("kernelMethod is ... " + kernelMethod.toGenericString()); - System.out.println("returnType = " + kernelMethod.getReturnType()); - System.out.println("returnTypeLetter = " + getReturnTypeLetter(kernelMethod)); - System.out.println("kernelMethod getName = " + kernelMethod.getName()); - System.out.println("methRefName = " + name + " descriptor = " - + _methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8()); - System.out - .println("descToReturnTypeLetter = " - + descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry() - .getUTF8())); - } - if (_methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName()) - && descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8()) - .equals(getReturnTypeLetter(kernelMethod))) { - final OpenCLMapping annotation = kernelMethod.getAnnotation(OpenCLMapping.class); - final String mapTo = annotation.mapTo(); - if (!mapTo.equals("")) { - mappedName = mapTo; - // System.out.println("mapTo = " + mapTo); + Class<?> currentClass = _methodReferenceEntry.getOwnerClassModel().getClassWeAreModelling(); + while (currentClass != Object.class) { + for (final Method kernelMethod : currentClass.getDeclaredMethods()) { + if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { + // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int); + // for Alpha, we will just disambiguate based on the return type + if (false) { + System.out.println("kernelMethod is ... " + kernelMethod.toGenericString()); + System.out.println("returnType = " + kernelMethod.getReturnType()); + System.out.println("returnTypeLetter = " + getReturnTypeLetter(kernelMethod)); + System.out.println("kernelMethod getName = " + kernelMethod.getName()); + System.out.println("methRefName = " + name + " descriptor = " + + _methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8()); + System.out.println("descToReturnTypeLetter = " + + descriptorToReturnTypeLetter(_methodReferenceEntry.getNameAndTypeEntry().getDescriptorUTF8Entry() + .getUTF8())); + } + if (toSignature(_methodReferenceEntry).equals(toSignature(kernelMethod))) { + final OpenCLMapping annotation = kernelMethod.getAnnotation(OpenCLMapping.class); + final String mapTo = annotation.mapTo(); + if (!mapTo.equals("")) { + mappedName = mapTo; + // System.out.println("mapTo = " + mapTo); + } } } } + if (mappedName != null) + break; + currentClass = currentClass.getSuperclass(); } // System.out.println("... in getMappedMethodName, returning = " + mappedName); return (mappedName); } public static boolean isMappedMethod(MethodReferenceEntry methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getBoolean(mappedMethodFlags, methodReferenceEntry); boolean isMapped = false; for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { @@ -2157,6 +2198,8 @@ public abstract class Kernel implements Cloneable { } public static boolean isOpenCLDelegateMethod(MethodReferenceEntry methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getBoolean(openCLDelegateMethodFlags, methodReferenceEntry); boolean isMapped = false; for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { if (kernelMethod.isAnnotationPresent(OpenCLDelegate.class)) { @@ -2171,6 +2214,8 @@ public abstract class Kernel implements Cloneable { } public static boolean usesAtomic32(MethodReferenceEntry methodReferenceEntry) { + if (CacheEnabler.areCachesEnabled()) + return getProperty(atomic32Cache, methodReferenceEntry, false); for (final Method kernelMethod : Kernel.class.getDeclaredMethods()) { if (kernelMethod.isAnnotationPresent(OpenCLMapping.class)) { if (methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())) { @@ -2184,6 +2229,8 @@ public abstract class Kernel implements Cloneable { // For alpha release atomic64 is not supported public static boolean usesAtomic64(MethodReferenceEntry methodReferenceEntry) { + // if (CacheEnabler.areCachesEnabled()) + // return getProperty(atomic64Cache, methodReferenceEntry, false); //for (java.lang.reflect.Method kernelMethod : Kernel.class.getDeclaredMethods()) { // if (kernelMethod.isAnnotationPresent(Kernel.OpenCLMapping.class)) { // if (methodReferenceEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8().equals(kernelMethod.getName())) { @@ -2859,4 +2906,131 @@ public abstract class Kernel implements Cloneable { executionMode = currentMode.next(); } } + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> mappedMethodFlags = markedWith(OpenCLMapping.class); + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> openCLDelegateMethodFlags = markedWith(OpenCLDelegate.class); + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> atomic32Cache = cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>(){ + @Override + public Map<String, Boolean> compute(Class<?> key) { + Map<String, Boolean> properties = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) { + properties.put(toSignature(method), method.getAnnotation(OpenCLMapping.class).atomic32()); + } + } + return properties; + } + }); + + private static final ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> atomic64Cache = cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>(){ + @Override + public Map<String, Boolean> compute(Class<?> key) { + Map<String, Boolean> properties = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) { + properties.put(toSignature(method), method.getAnnotation(OpenCLMapping.class).atomic64()); + } + } + return properties; + } + }); + + private static boolean getBoolean(ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> methodNamesCache, + MethodReferenceEntry methodReferenceEntry) { + return getProperty(methodNamesCache, methodReferenceEntry, false); + } + + private static <A extends Annotation> ValueCache<Class<?>, Map<String, Boolean>, RuntimeException> markedWith( + final Class<A> annotationClass) { + return cacheProperty(new ValueComputer<Class<?>, Map<String, Boolean>>(){ + @Override + public Map<String, Boolean> compute(Class<?> key) { + Map<String, Boolean> markedMethodNames = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + markedMethodNames.put(toSignature(method), method.isAnnotationPresent(annotationClass)); + } + return markedMethodNames; + } + }); + } + + static String toSignature(Method method) { + return method.getName() + getArgumentsLetters(method) + getReturnTypeLetter(method); + } + + private static String getArgumentsLetters(Method method) { + StringBuilder sb = new StringBuilder("("); + for (Class<?> parameterClass : method.getParameterTypes()) { + sb.append(toClassShortNameIfAny(parameterClass)); + } + sb.append(")"); + return sb.toString(); + } + + private static boolean isRelevant(Method method) { + return !method.isSynthetic() && !method.isBridge(); + } + + private static <V, T extends Throwable> V getProperty(ValueCache<Class<?>, Map<String, V>, T> cache, + MethodReferenceEntry methodReferenceEntry, V defaultValue) throws T { + Map<String, V> map = cache.computeIfAbsent(methodReferenceEntry.getOwnerClassModel().getClassWeAreModelling()); + String key = toSignature(methodReferenceEntry); + if (map.containsKey(key)) + return map.get(key); + return defaultValue; + } + + private static String toSignature(MethodReferenceEntry methodReferenceEntry) { + NameAndTypeEntry nameAndTypeEntry = methodReferenceEntry.getNameAndTypeEntry(); + return nameAndTypeEntry.getNameUTF8Entry().getUTF8() + nameAndTypeEntry.getDescriptorUTF8Entry().getUTF8(); + } + + private static final ValueCache<Class<?>, Map<String, String>, RuntimeException> mappedMethodNamesCache = cacheProperty(new ValueComputer<Class<?>, Map<String, String>>(){ + @Override + public Map<String, String> compute(Class<?> key) { + Map<String, String> properties = new HashMap<>(); + for (final Method method : key.getDeclaredMethods()) { + if (isRelevant(method) && method.isAnnotationPresent(OpenCLMapping.class)) { + // ultimately, need a way to constrain this based upon signature (to disambiguate abs(float) from abs(int); + final OpenCLMapping annotation = method.getAnnotation(OpenCLMapping.class); + final String mapTo = annotation.mapTo(); + if (mapTo != null && !mapTo.equals("")) { + properties.put(toSignature(method), mapTo); + } + } + } + return properties; + } + }); + + private static <K, V, T extends Throwable> ValueCache<Class<?>, Map<K, V>, T> cacheProperty( + final ThrowingValueComputer<Class<?>, Map<K, V>, T> throwingValueComputer) { + return ValueCache.on(new ThrowingValueComputer<Class<?>, Map<K, V>, T>(){ + @Override + public Map<K, V> compute(Class<?> key) throws T { + Map<K, V> properties = new HashMap<>(); + Deque<Class<?>> superclasses = new ArrayDeque<>(); + Class<?> currentSuperClass = key; + do { + superclasses.push(currentSuperClass); + currentSuperClass = currentSuperClass.getSuperclass(); + } while (currentSuperClass != Object.class); + for (Class<?> clazz : superclasses) { + // Overwrite property values for shadowed/overriden methods + properties.putAll(throwingValueComputer.compute(clazz)); + } + return properties; + } + }); + } + + public static void invalidateCaches() { + atomic32Cache.invalidate(); + atomic64Cache.invalidate(); + mappedMethodFlags.invalidate(); + mappedMethodNamesCache.invalidate(); + openCLDelegateMethodFlags.invalidate(); + } } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java index 821364c85bccf5906c55b4aae45855a5e66b9d8c..cdb6c9d094007713ad2da09f1219b06e5f491682 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/kernel/KernelRunner.java @@ -48,8 +48,10 @@ import java.util.Set; import java.util.StringTokenizer; import java.util.concurrent.BrokenBarrierException; import java.util.concurrent.CyclicBarrier; -import java.util.concurrent.Executors; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinPool.ForkJoinWorkerThreadFactory; +import java.util.concurrent.ForkJoinPool.ManagedBlocker; +import java.util.concurrent.ForkJoinWorkerThread; import java.util.logging.Level; import java.util.logging.Logger; @@ -100,8 +102,18 @@ public class KernelRunner extends KernelRunnerJNI{ private Entrypoint entryPoint; private int argc; - - private final ExecutorService threadPool = Executors.newCachedThreadPool(); + + private static final ForkJoinWorkerThreadFactory lowPriorityThreadFactory = new ForkJoinWorkerThreadFactory(){ + @Override public ForkJoinWorkerThread newThread(ForkJoinPool pool) { + ForkJoinWorkerThread newThread = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool); + newThread.setPriority(Thread.MIN_PRIORITY); + return newThread; + } + }; + + private static final ForkJoinPool threadPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors(), + lowPriorityThreadFactory, null, false); + /** * Create a KernelRunner for a specific Kernel instance. * @@ -120,7 +132,8 @@ public class KernelRunner extends KernelRunnerJNI{ if (kernel.getExecutionMode().isOpenCL()) { disposeJNI(jniContextHandle); } - threadPool.shutdownNow(); + // We are using a shared pool, so there's no need no shutdown it when kernel is disposed + // threadPool.shutdownNow(); } private Set<String> capabilitiesSet; @@ -215,6 +228,50 @@ public class KernelRunner extends KernelRunnerJNI{ return capabilitiesSet.contains(OpenCL.CL_KHR_GL_SHARING); } + private static final class FJSafeCyclicBarrier extends CyclicBarrier{ + FJSafeCyclicBarrier(final int threads) { + super(threads); + } + + @Override public int await() throws InterruptedException, BrokenBarrierException { + class Awaiter implements ManagedBlocker{ + private int value; + + private boolean released; + + @Override public boolean block() throws InterruptedException { + try { + value = superAwait(); + released = true; + return true; + } catch (final BrokenBarrierException e) { + throw new RuntimeException(e); + } + } + + @Override public boolean isReleasable() { + return released; + } + + int getValue() { + return value; + } + } + final Awaiter awaiter = new Awaiter(); + ForkJoinPool.managedBlock(awaiter); + return awaiter.getValue(); + } + + int superAwait() throws InterruptedException, BrokenBarrierException { + return super.await(); + } + } + + // @FunctionalInterface + private interface ThreadIdSetter{ + void set(KernelState kernelState, int globalGroupId, int threadId); + } + /** * Execute using a Java thread pool. Either because we were explicitly asked to do so, or because we 'fall back' after discovering an OpenCL issue. * @@ -224,11 +281,15 @@ public class KernelRunner extends KernelRunnerJNI{ * The # of passes requested by the user (via <code>Kernel.execute(globalSize, passes)</code>). Note this is usually defaulted to 1 via <code>Kernel.execute(globalSize)</code>. * @return */ - private long executeJava(final Range _range, final int _passes) { + protected long executeJava(final Range _range, final int _passes) { if (logger.isLoggable(Level.FINE)) { logger.fine("executeJava: range = " + _range); } + final int localSize0 = _range.getLocalSize(0); + final int localSize1 = _range.getLocalSize(1); + final int localSize2 = _range.getLocalSize(2); + final int globalSize1 = _range.getGlobalSize(1); if (kernel.getExecutionMode().equals(EXECUTION_MODE.SEQ)) { /** * SEQ mode is useful for testing trivial logic, but kernels which use SEQ mode cannot be used if the @@ -238,7 +299,7 @@ public class KernelRunner extends KernelRunnerJNI{ * * So we need to check if the range is valid here. If not we have no choice but to punt. */ - if ((_range.getLocalSize(0) * _range.getLocalSize(1) * _range.getLocalSize(2)) > 1) { + if ((localSize0 * localSize1 * localSize2) > 1) { throw new IllegalStateException("Can't run range with group size >1 sequentially. Barriers would deadlock!"); } @@ -252,7 +313,7 @@ public class KernelRunner extends KernelRunnerJNI{ kernelState.setLocalId(0, 0); kernelState.setLocalId(1, 0); kernelState.setLocalId(2, 0); - kernelState.setLocalBarrier(new CyclicBarrier(1)); + kernelState.setLocalBarrier(new FJSafeCyclicBarrier(1)); for (int passId = 0; passId < _passes; passId++) { kernelState.setPassId(passId); @@ -266,7 +327,7 @@ public class KernelRunner extends KernelRunnerJNI{ for (int x = 0; x < _range.getGlobalSize(0); x++) { kernelState.setGlobalId(0, x); - for (int y = 0; y < _range.getGlobalSize(1); y++) { + for (int y = 0; y < globalSize1; y++) { kernelState.setGlobalId(1, y); kernelClone.run(); } @@ -275,7 +336,7 @@ public class KernelRunner extends KernelRunnerJNI{ for (int x = 0; x < _range.getGlobalSize(0); x++) { kernelState.setGlobalId(0, x); - for (int y = 0; y < _range.getGlobalSize(1); y++) { + for (int y = 0; y < globalSize1; y++) { kernelState.setGlobalId(1, y); for (int z = 0; z < _range.getGlobalSize(2); z++) { @@ -289,13 +350,15 @@ public class KernelRunner extends KernelRunnerJNI{ } } } else { - final int threads = _range.getLocalSize(0) * _range.getLocalSize(1) * _range.getLocalSize(2); - final int globalGroups = _range.getNumGroups(0) * _range.getNumGroups(1) * _range.getNumGroups(2); + final int threads = localSize0 * localSize1 * localSize2; + final int numGroups0 = _range.getNumGroups(0); + final int numGroups1 = _range.getNumGroups(1); + final int globalGroups = numGroups0 * numGroups1 * _range.getNumGroups(2); /** * This joinBarrier is the barrier that we provide for the kernel threads to rendezvous with the current dispatch thread. * So this barrier is threadCount+1 wide (the +1 is for the dispatch thread) */ - final CyclicBarrier joinBarrier = new CyclicBarrier(threads + 1); + final CyclicBarrier joinBarrier = new FJSafeCyclicBarrier(threads + 1); /** * This localBarrier is only ever used by the kernels. If the kernel does not use the barrier the threads @@ -308,8 +371,130 @@ public class KernelRunner extends KernelRunnerJNI{ * * This barrier is threadCount wide. We never hit the barrier from the dispatch thread. */ - final CyclicBarrier localBarrier = new CyclicBarrier(threads); + final CyclicBarrier localBarrier = new FJSafeCyclicBarrier(threads); + + final ThreadIdSetter threadIdSetter; + if (_range.getDims() == 1) { + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); + kernelState.setGlobalId(0, (threadId + (globalGroupId * threads))); + kernelState.setGroupId(0, globalGroupId); + } + }; + } else if (_range.getDims() == 2) { + + /** + * Consider a 12x4 grid of 4*2 local groups + * <pre> + * threads = 4*2 = 8 + * localWidth=4 + * localHeight=2 + * globalWidth=12 + * globalHeight=4 + * + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 + * 12 13 14 15 | 16 17 18 19 | 20 21 22 23 + * ------------+-------------+------------ + * 24 25 26 27 | 28 29 30 31 | 32 33 34 35 + * 36 37 38 39 | 40 41 42 43 | 44 45 46 47 + * + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6 + * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 + * ------------+-------------+------------ + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 + * + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6 + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * ------------+-------------+------------ + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6 + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 + * ------------+-------------+------------ + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * + * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads; + * 04 05 06 07 | 12 13 14 15 | 20 21 22 23 + * ------------+-------------+------------ + * 24 25 26 27 | 32[33]34 35 | 40 41 42 43 + * 28 29 30 31 | 36 37 38 39 | 44 45 46 47 + * + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1) + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * ------------+-------------+------------ + * 00 01 02 03 | 00[01]02 03 | 00 01 02 03 + * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0) + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * ------------+-------------+------------ + * 00 00 00 00 | 00[00]00 00 | 00 00 00 00 + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX= + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3) + * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1) + * 00 01 02 03 | 04[05]06 07 | 08 09 10 11 + * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5) + * + * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY + * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 + * ------------+-------------+------------ + * 02 02 02 02 | 02[02]02 02 | 02 02 02 02 + * 03 03 03 03 | 03 03 03 03 | 03 03 03 03 + * + * </pre> + * Assume we are trying to locate the id's for #33 + * + */ + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); // threadId % localWidth = (for 33 = 1 % 4 = 1) + kernelState.setLocalId(1, (threadId / localSize0)); // threadId / localWidth = (for 33 = 1 / 4 == 0) + + final int groupInset = globalGroupId % numGroups0; // 4%3 = 1 + kernelState.setGlobalId(0, ((groupInset * localSize0) + kernelState.getLocalIds()[0])); // 1*4+1=5 + + final int completeLines = (globalGroupId / numGroups0) * localSize1;// (4/3) * 2 + kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2 + kernelState.setGroupId(0, (globalGroupId % numGroups0)); + kernelState.setGroupId(1, (globalGroupId / numGroups0)); + } + }; + } else if (_range.getDims() == 3) { + //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code + threadIdSetter = new ThreadIdSetter(){ + @Override public void set(KernelState kernelState, int globalGroupId, int threadId) { + // (kernelState, globalGroupId, threadId) ->{ + kernelState.setLocalId(0, (threadId % localSize0)); + + kernelState.setLocalId(1, ((threadId / localSize0) % localSize1)); + + // the thread id's span WxHxD so threadId/(WxH) should yield the local depth + kernelState.setLocalId(2, (threadId / (localSize0 * localSize1))); + + kernelState.setGlobalId(0, (((globalGroupId % numGroups0) * localSize0) + kernelState.getLocalIds()[0])); + + kernelState.setGlobalId(1, + ((((globalGroupId / numGroups0) * localSize1) % globalSize1) + kernelState.getLocalIds()[1])); + + kernelState.setGlobalId(2, + (((globalGroupId / (numGroups0 * numGroups1)) * localSize2) + kernelState.getLocalIds()[2])); + + kernelState.setGroupId(0, (globalGroupId % numGroups0)); + kernelState.setGroupId(1, ((globalGroupId / numGroups0) % numGroups1)); + kernelState.setGroupId(2, (globalGroupId / (numGroups0 * numGroups1))); + } + }; + } else + throw new IllegalArgumentException("Expected 1,2 or 3 dimensions, found " + _range.getDims()); for (int passId = 0; passId < _passes; passId++) { /** * Note that we emulate OpenCL by creating one thread per localId (across the group). @@ -352,135 +537,31 @@ public class KernelRunner extends KernelRunnerJNI{ */ final Kernel kernelClone = kernel.clone(); final KernelState kernelState = kernelClone.getKernelState(); - kernelState.setRange(_range); - kernelState.setLocalBarrier(localBarrier); kernelState.setPassId(passId); - threadPool.submit(new Runnable(){ - @Override public void run() { - for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) { - - if (_range.getDims() == 1) { - kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); - kernelState.setGlobalId(0, (threadId + (globalGroupId * threads))); - kernelState.setGroupId(0, globalGroupId); - } else if (_range.getDims() == 2) { - - /** - * Consider a 12x4 grid of 4*2 local groups - * <pre> - * threads = 4*2 = 8 - * localWidth=4 - * localHeight=2 - * globalWidth=12 - * globalHeight=4 - * - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 - * 12 13 14 15 | 16 17 18 19 | 20 21 22 23 - * ------------+-------------+------------ - * 24 25 26 27 | 28 29 30 31 | 32 33 34 35 - * 36 37 38 39 | 40 41 42 43 | 44 45 46 47 - * - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 threadIds : [0..7]*6 - * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 - * ------------+-------------+------------ - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * 04 05 06 07 | 04 05 06 07 | 04 05 06 07 - * - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 groupId[0] : 0..6 - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * ------------+-------------+------------ - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * 00 00 00 00 | 01 01 01 01 | 02 02 02 02 - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 groupId[1] : 0..6 - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 - * ------------+-------------+------------ - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * - * 00 01 02 03 | 08 09 10 11 | 16 17 18 19 globalThreadIds == threadId + groupId * threads; - * 04 05 06 07 | 12 13 14 15 | 20 21 22 23 - * ------------+-------------+------------ - * 24 25 26 27 | 32[33]34 35 | 40 41 42 43 - * 28 29 30 31 | 36 37 38 39 | 44 45 46 47 - * - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 localX = threadId % localWidth; (for globalThreadId 33 = threadId = 01 : 01%4 =1) - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * ------------+-------------+------------ - * 00 01 02 03 | 00[01]02 03 | 00 01 02 03 - * 00 01 02 03 | 00 01 02 03 | 00 01 02 03 - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 localY = threadId /localWidth (for globalThreadId 33 = threadId = 01 : 01/4 =0) - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * ------------+-------------+------------ - * 00 00 00 00 | 00[00]00 00 | 00 00 00 00 - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX= - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 groupsPerLineWidth=globalWidth/localWidth (=12/4 =3) - * ------------+-------------+------------ groupInset =groupId%groupsPerLineWidth (=4%3 = 1) - * 00 01 02 03 | 04[05]06 07 | 08 09 10 11 - * 00 01 02 03 | 04 05 06 07 | 08 09 10 11 globalX = groupInset*localWidth+localX (= 1*4+1 = 5) - * - * 00 00 00 00 | 00 00 00 00 | 00 00 00 00 globalY - * 01 01 01 01 | 01 01 01 01 | 01 01 01 01 - * ------------+-------------+------------ - * 02 02 02 02 | 02[02]02 02 | 02 02 02 02 - * 03 03 03 03 | 03 03 03 03 | 03 03 03 03 - * - * </pre> - * Assume we are trying to locate the id's for #33 - * - */ - - kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); // threadId % localWidth = (for 33 = 1 % 4 = 1) - kernelState.setLocalId(1, (threadId / _range.getLocalSize(0))); // threadId / localWidth = (for 33 = 1 / 4 == 0) - - final int groupInset = globalGroupId % _range.getNumGroups(0); // 4%3 = 1 - kernelState.setGlobalId(0, ((groupInset * _range.getLocalSize(0)) + kernelState.getLocalIds()[0])); // 1*4+1=5 - - final int completeLines = (globalGroupId / _range.getNumGroups(0)) * _range.getLocalSize(1);// (4/3) * 2 - kernelState.setGlobalId(1, (completeLines + kernelState.getLocalIds()[1])); // 2+0 = 2 - kernelState.setGroupId(0, (globalGroupId % _range.getNumGroups(0))); - kernelState.setGroupId(1, (globalGroupId / _range.getNumGroups(0))); - } else if (_range.getDims() == 3) { - - //Same as 2D actually turns out that localId[0] is identical for all three dims so could be hoisted out of conditional code - - kernelState.setLocalId(0, (threadId % _range.getLocalSize(0))); - - kernelState.setLocalId(1, ((threadId / _range.getLocalSize(0)) % _range.getLocalSize(1))); - - // the thread id's span WxHxD so threadId/(WxH) should yield the local depth - kernelState.setLocalId(2, (threadId / (_range.getLocalSize(0) * _range.getLocalSize(1)))); - - kernelState.setGlobalId( - 0, - (((globalGroupId % _range.getNumGroups(0)) * _range.getLocalSize(0)) + kernelState.getLocalIds()[0])); - - kernelState.setGlobalId( - 1, - ((((globalGroupId / _range.getNumGroups(0)) * _range.getLocalSize(1)) % _range.getGlobalSize(1)) + kernelState - .getLocalIds()[1])); - - kernelState.setGlobalId( - 2, - (((globalGroupId / (_range.getNumGroups(0) * _range.getNumGroups(1))) * _range.getLocalSize(2)) + kernelState - .getLocalIds()[2])); - - kernelState.setGroupId(0, (globalGroupId % _range.getNumGroups(0))); - kernelState.setGroupId(1, ((globalGroupId / _range.getNumGroups(0)) % _range.getNumGroups(1))); - kernelState.setGroupId(2, (globalGroupId / (_range.getNumGroups(0) * _range.getNumGroups(1)))); - } - - kernelClone.run(); - } + if (threads == 1) { + kernelState.disableLocalBarrier(); + } else { + kernelState.setLocalBarrier(localBarrier); + } - await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. - } - }); + threadPool.submit( + // () -> { + new Runnable(){ + public void run() { + try { + for (int globalGroupId = 0; globalGroupId < globalGroups; globalGroupId++) { + threadIdSetter.set(kernelState, globalGroupId, threadId); + kernelClone.run(); + } + } catch (RuntimeException | Error e) { + logger.log(Level.SEVERE, "Execution failed", e); + } finally { + await(joinBarrier); // This thread will rendezvous with dispatch thread here. This is effectively a join. + } + } + }); } await(joinBarrier); // This dispatch thread waits for all worker threads here. @@ -519,7 +600,7 @@ public class KernelRunner extends KernelRunnerJNI{ boolean didReallocate = false; if (arg.getObjArrayElementModel() == null) { - final String tmp = arrayClass.getName().substring(2).replace("/", "."); + final String tmp = arrayClass.getName().substring(2).replace('/', '.'); final String arrayClassInDotForm = tmp.substring(0, tmp.length() - 1); if (logger.isLoggable(Level.FINE)) { @@ -786,7 +867,8 @@ public class KernelRunner extends KernelRunnerJNI{ } if (privateMemorySize != null) { if (arrayLength > privateMemorySize) { - throw new IllegalStateException("__private array field " + fieldName + " has illegal length " + arrayLength + " > " + privateMemorySize); + throw new IllegalStateException("__private array field " + fieldName + " has illegal length " + arrayLength + + " > " + privateMemorySize); } } @@ -943,7 +1025,7 @@ public class KernelRunner extends KernelRunnerJNI{ if ((device == null) || (device instanceof OpenCLDevice)) { if (entryPoint == null) { try { - final ClassModel classModel = new ClassModel(kernel.getClass()); + final ClassModel classModel = ClassModel.createClassModel(kernel.getClass()); entryPoint = classModel.getEntrypoint(_entrypointName, kernel); } catch (final Exception exception) { return warnFallBackAndExecute(_entrypointName, _range, _passes, exception); @@ -1079,13 +1161,9 @@ public class KernelRunner extends KernelRunnerJNI{ | (entryPoint.getArrayFieldAccesses().contains(field.getName()) ? ARG_READ : 0)); // args[i].type |= ARG_GLOBAL; - if (type.getName().startsWith("[L")) { args[i].setType(args[i].getType() - | (ARG_OBJ_ARRAY_STRUCT | - ARG_WRITE | - ARG_READ | - ARG_APARAPI_BUFFER)); + | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER)); if (logger.isLoggable(Level.FINE)) { logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); @@ -1094,8 +1172,9 @@ public class KernelRunner extends KernelRunnerJNI{ try { setMultiArrayType(args[i], type); - } catch(AparapiException e) { - return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " + args[i].getName() + ". Aparapi only supports 2D and 3D arrays."); + } catch (AparapiException e) { + return warnFallBackAndExecute(_entrypointName, _range, _passes, "failed to set kernel arguement " + + args[i].getName() + ". Aparapi only supports 2D and 3D arrays."); } } else { @@ -1120,7 +1199,8 @@ public class KernelRunner extends KernelRunnerJNI{ if (type.getName().startsWith("[L")) { args[i].setType(args[i].getType() | (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)); if (logger.isLoggable(Level.FINE)) { - logger.fine("tagging " + args[i].getName() + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); + logger.fine("tagging " + args[i].getName() + + " as (ARG_OBJ_ARRAY_STRUCT | ARG_WRITE | ARG_READ)"); } } } @@ -1206,7 +1286,6 @@ public class KernelRunner extends KernelRunnerJNI{ return kernel; } - private int getPrimitiveSize(int type) { if ((type & ARG_FLOAT) != 0) { return 4; @@ -1231,28 +1310,28 @@ public class KernelRunner extends KernelRunnerJNI{ private void setMultiArrayType(KernelArg arg, Class<?> type) throws AparapiException { arg.setType(arg.getType() | (ARG_WRITE | ARG_READ | ARG_APARAPI_BUFFER)); int numDims = 0; - while(type.getName().startsWith("[[[[")) { + while (type.getName().startsWith("[[[[")) { throw new AparapiException("Aparapi only supports 2D and 3D arrays."); } arg.setType(arg.getType() | ARG_ARRAYLENGTH); - while(type.getName().charAt(numDims) == '[') { + while (type.getName().charAt(numDims) == '[') { numDims++; } Object buffer = new Object(); try { buffer = arg.getField().get(kernel); - } catch(IllegalAccessException e) { + } catch (IllegalAccessException e) { e.printStackTrace(); } arg.setJavaBuffer(buffer); arg.setNumDims(numDims); Object subBuffer = buffer; int[] dims = new int[numDims]; - for(int i = 0; i < numDims-1; i++) { + for (int i = 0; i < numDims - 1; i++) { dims[i] = Array.getLength(subBuffer); subBuffer = Array.get(subBuffer, 0); } - dims[numDims-1] = Array.getLength(subBuffer); + dims[numDims - 1] = Array.getLength(subBuffer); arg.setDims(dims); if (subBuffer.getClass().isAssignableFrom(float[].class)) { @@ -1281,7 +1360,7 @@ public class KernelRunner extends KernelRunnerJNI{ } int primitiveSize = getPrimitiveSize(arg.getType()); int totalElements = 1; - for(int i = 0; i < numDims; i++) { + for (int i = 0; i < numDims; i++) { totalElements *= dims[i]; } arg.setSizeInBytes(totalElements * primitiveSize); diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java new file mode 100644 index 0000000000000000000000000000000000000000..9308d57ed876ac88e9e99dc1e98e4c22fa1ad117 --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/CacheEnabler.java @@ -0,0 +1,20 @@ +package com.amd.aparapi.internal.model; + +import com.amd.aparapi.Kernel; + +public class CacheEnabler{ + private static volatile boolean cachesEnabled; + + public static void setCachesEnabled(boolean cachesEnabled) { + if (CacheEnabler.cachesEnabled != cachesEnabled) { + Kernel.invalidateCaches(); + ClassModel.invalidateCaches(); + } + + CacheEnabler.cachesEnabled = cachesEnabled; + } + + public static boolean areCachesEnabled() { + return cachesEnabled; + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java index e1b269fafea752e40943b6e98a0c9d3749f327d1..95c362a357f0b09911f7f433d426809073bd67ec 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/ClassModel.java @@ -41,6 +41,7 @@ import com.amd.aparapi.*; import com.amd.aparapi.internal.annotation.*; import com.amd.aparapi.internal.exception.*; import com.amd.aparapi.internal.instruction.InstructionSet.*; +import com.amd.aparapi.internal.model.ValueCache.ThrowingValueComputer; import com.amd.aparapi.internal.model.ClassModel.AttributePool.*; import com.amd.aparapi.internal.model.ClassModel.ConstantPool.*; import com.amd.aparapi.internal.reader.*; @@ -122,10 +123,31 @@ public class ClassModel{ private ClassModel superClazz = null; - private HashSet<String> noClMethods = null; + // private Memoizer<Set<String>> noClMethods = Memoizer.of(this::computeNoCLMethods); + private Memoizer<Set<String>> noClMethods = Memoizer.Impl.of(new Supplier<Set<String>>(){ + @Override + public Set<String> get() { + return computeNoCLMethods(); + } + }); - private HashMap<String, Kernel.PrivateMemorySpace> privateMemoryFields = null; + // private Memoizer<Map<String, Kernel.PrivateMemorySpace>> privateMemoryFields = Memoizer.of(this::computePrivateMemoryFields); + private Memoizer<Map<String, Kernel.PrivateMemorySpace>> privateMemoryFields = Memoizer.Impl + .of(new Supplier<Map<String, Kernel.PrivateMemorySpace>>(){ + @Override + public Map<String, Kernel.PrivateMemorySpace> get() { + return computePrivateMemoryFields(); + } + }); + // private ValueCache<String, Integer, ClassParseException> privateMemorySizes = ValueCache.on(this::computePrivateMemorySize); + private ValueCache<String, Integer, ClassParseException> privateMemorySizes = ValueCache + .on(new ThrowingValueComputer<String, Integer, ClassParseException>(){ + @Override + public Integer compute(String fieldName) throws ClassParseException { + return computePrivateMemorySize(fieldName); + } + }); /** * Create a ClassModel representing a given Class. @@ -137,7 +159,7 @@ public class ClassModel{ * @throws ClassParseException */ - public ClassModel(Class<?> _class) throws ClassParseException { + private ClassModel(Class<?> _class) throws ClassParseException { parse(_class); @@ -147,7 +169,7 @@ public class ClassModel{ // not occur in normal use if ((mySuper != null) && (!mySuper.getName().equals(Kernel.class.getName())) && (!mySuper.getName().equals("java.lang.Object"))) { - superClazz = new ClassModel(mySuper); + superClazz = createClassModel(mySuper); } } @@ -204,7 +226,8 @@ public class ClassModel{ return superClazz; } - @DocMe public void replaceSuperClazz(ClassModel c) { + @DocMe + public void replaceSuperClazz(ClassModel c) { if (superClazz != null) { assert c.isSuperClass(getClassWeAreModelling()) == true : "not my super"; if (superClazz.getClassWeAreModelling().getName().equals(c.getClassWeAreModelling().getName())) { @@ -260,32 +283,40 @@ public class ClassModel{ * If a field does not satisfy the private memory conditions, null, otherwise the size of private memory required. */ public Integer getPrivateMemorySize(String fieldName) throws ClassParseException { - if (privateMemoryFields == null) { - privateMemoryFields = new HashMap<String, Kernel.PrivateMemorySpace>(); - HashMap<Field, Kernel.PrivateMemorySpace> privateMemoryFields = new HashMap<Field, Kernel.PrivateMemorySpace>(); - for (Field field : getClassWeAreModelling().getDeclaredFields()) { - Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); - if (privateMemorySpace != null) { - privateMemoryFields.put(field, privateMemorySpace); - } - } - for (Field field : getClassWeAreModelling().getFields()) { - Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); - if (privateMemorySpace != null) { - privateMemoryFields.put(field, privateMemorySpace); - } - } - for (Map.Entry<Field, Kernel.PrivateMemorySpace> entry : privateMemoryFields.entrySet()) { - this.privateMemoryFields.put(entry.getKey().getName(), entry.getValue()); - } - } - Kernel.PrivateMemorySpace annotation = privateMemoryFields.get(fieldName); + if (CacheEnabler.areCachesEnabled()) + return privateMemorySizes.computeIfAbsent(fieldName); + return computePrivateMemorySize(fieldName); + } + + private Integer computePrivateMemorySize(String fieldName) throws ClassParseException { + Kernel.PrivateMemorySpace annotation = privateMemoryFields.get().get(fieldName); if (annotation != null) { return annotation.value(); } return getPrivateMemorySizeFromFieldName(fieldName); } + private Map<String, Kernel.PrivateMemorySpace> computePrivateMemoryFields() { + Map<String, Kernel.PrivateMemorySpace> tempPrivateMemoryFields = new HashMap<String, Kernel.PrivateMemorySpace>(); + Map<Field, Kernel.PrivateMemorySpace> privateMemoryFields = new HashMap<Field, Kernel.PrivateMemorySpace>(); + for (Field field : getClassWeAreModelling().getDeclaredFields()) { + Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); + if (privateMemorySpace != null) { + privateMemoryFields.put(field, privateMemorySpace); + } + } + for (Field field : getClassWeAreModelling().getFields()) { + Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); + if (privateMemorySpace != null) { + privateMemoryFields.put(field, privateMemorySpace); + } + } + for (Map.Entry<Field, Kernel.PrivateMemorySpace> entry : privateMemoryFields.entrySet()) { + tempPrivateMemoryFields.put(entry.getKey().getName(), entry.getValue()); + } + return tempPrivateMemoryFields; + } + public static Integer getPrivateMemorySizeFromField(Field field) { Kernel.PrivateMemorySpace privateMemorySpace = field.getAnnotation(Kernel.PrivateMemorySpace.class); if (privateMemorySpace != null) { @@ -309,25 +340,26 @@ public class ClassModel{ } public Set<String> getNoCLMethods() { - if (this.noClMethods == null) { - noClMethods = new HashSet<String>(); - HashSet<Method> methods = new HashSet<Method>(); - for (Method method : getClassWeAreModelling().getDeclaredMethods()) { - if (method.getAnnotation(Kernel.NoCL.class) != null) { - methods.add(method); - } - } - for (Method method : getClassWeAreModelling().getMethods()) { - if (method.getAnnotation(Kernel.NoCL.class) != null) { - methods.add(method); - } + return computeNoCLMethods(); + } + + private Set<String> computeNoCLMethods() { + Set<String> tempNoClMethods = new HashSet<String>(); + HashSet<Method> methods = new HashSet<Method>(); + for (Method method : getClassWeAreModelling().getDeclaredMethods()) { + if (method.getAnnotation(Kernel.NoCL.class) != null) { + methods.add(method); } - for (Method method : methods) { - noClMethods.add(method.getName()); + } + for (Method method : getClassWeAreModelling().getMethods()) { + if (method.getAnnotation(Kernel.NoCL.class) != null) { + methods.add(method); } - } - return noClMethods; + for (Method method : methods) { + tempNoClMethods.add(method.getName()); + } + return tempNoClMethods; } public static String convert(String _string) { @@ -602,6 +634,21 @@ public class ClassModel{ return (methodDescription); } + // private static final ValueCache<Class<?>, ClassModel, ClassParseException> classModelCache = ValueCache.onIdentity(ClassModel::new); + private static final ValueCache<Class<?>, ClassModel, ClassParseException> classModelCache = ValueCache + .on(new ThrowingValueComputer<Class<?>, ClassModel, ClassParseException>(){ + @Override + public ClassModel compute(Class<?> key) throws ClassParseException { + return new ClassModel(key); + } + }); + + public static ClassModel createClassModel(Class<?> _class) throws ClassParseException { + if (CacheEnabler.areCachesEnabled()) + return classModelCache.computeIfAbsent(_class); + return new ClassModel(_class); + } + private int magic; private int minorVersion; @@ -813,7 +860,8 @@ public class ClassModel{ super(_byteReader, _slot, ConstantPoolType.METHOD); } - @Override public String toString() { + @Override + public String toString() { final StringBuilder sb = new StringBuilder(); sb.append(getClassEntry().getNameUTF8Entry().getUTF8()); sb.append("."); @@ -934,13 +982,15 @@ public class ClassModel{ private Type returnType = null; - @Override public int hashCode() { + @Override + public int hashCode() { final NameAndTypeEntry nameAndTypeEntry = getNameAndTypeEntry(); return ((((nameAndTypeEntry.getNameIndex() * 31) + nameAndTypeEntry.getDescriptorIndex()) * 31) + getClassIndex()); } - @Override public boolean equals(Object _other) { + @Override + public boolean equals(Object _other) { if ((_other == null) || !(_other instanceof MethodReferenceEntry)) { return (false); } else { @@ -1312,7 +1362,8 @@ public class ClassModel{ } - @Override public Iterator<Entry> iterator() { + @Override + public Iterator<Entry> iterator() { return (entries.iterator()); } @@ -1399,58 +1450,51 @@ public class ClassModel{ } else if (_entry instanceof ConstantPool.NameAndTypeEntry) { final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) _entry; references = new int[] { - nameAndTypeEntry.getNameIndex(), - nameAndTypeEntry.getDescriptorIndex() + nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } else if (_entry instanceof ConstantPool.MethodEntry) { final ConstantPool.MethodEntry methodEntry = (ConstantPool.MethodEntry) _entry; final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(methodEntry.getClassIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry - .getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex()); final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(methodEntry .getNameAndTypeIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getNameIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getDescriptorIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex()); references = new int[] { - methodEntry.getClassIndex(), - classEntry.getNameIndex(), - nameAndTypeEntry.getNameIndex(), + methodEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } else if (_entry instanceof ConstantPool.InterfaceMethodEntry) { final ConstantPool.InterfaceMethodEntry interfaceMethodEntry = (ConstantPool.InterfaceMethodEntry) _entry; final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(interfaceMethodEntry.getClassIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry - .getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex()); final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(interfaceMethodEntry .getNameAndTypeIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getNameIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getDescriptorIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex()); references = new int[] { - interfaceMethodEntry.getClassIndex(), - classEntry.getNameIndex(), - nameAndTypeEntry.getNameIndex(), + interfaceMethodEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } else if (_entry instanceof ConstantPool.FieldEntry) { final ConstantPool.FieldEntry fieldEntry = (ConstantPool.FieldEntry) _entry; final ConstantPool.ClassEntry classEntry = (ConstantPool.ClassEntry) get(fieldEntry.getClassIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry - .getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8Entry = (ConstantPool.UTF8Entry) get(classEntry.getNameIndex()); final ConstantPool.NameAndTypeEntry nameAndTypeEntry = (ConstantPool.NameAndTypeEntry) get(fieldEntry .getNameAndTypeIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getNameIndex()); - @SuppressWarnings("unused") final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry - .getDescriptorIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8NameEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getNameIndex()); + @SuppressWarnings("unused") + final ConstantPool.UTF8Entry utf8DescriptorEntry = (ConstantPool.UTF8Entry) get(nameAndTypeEntry.getDescriptorIndex()); references = new int[] { - fieldEntry.getClassIndex(), - classEntry.getNameIndex(), - nameAndTypeEntry.getNameIndex(), + fieldEntry.getClassIndex(), classEntry.getNameIndex(), nameAndTypeEntry.getNameIndex(), nameAndTypeEntry.getDescriptorIndex() }; } @@ -1581,7 +1625,8 @@ public class ClassModel{ codeEntryAttributePool = new AttributePool(_byteReader, getName()); } - @Override public AttributePool getAttributePool() { + @Override + public AttributePool getAttributePool() { return (codeEntryAttributePool); } @@ -1664,7 +1709,8 @@ public class ClassModel{ super(_byteReader, _nameIndex, _length); } - @Override public Iterator<T> iterator() { + @Override + public Iterator<T> iterator() { return (pool.iterator()); } } @@ -1848,27 +1894,33 @@ public class ClassModel{ return (variableNameIndex); } - @Override public int getStart() { + @Override + public int getStart() { return (start); } - @Override public int getVariableIndex() { + @Override + public int getVariableIndex() { return (variableIndex); } - @Override public String getVariableName() { + @Override + public String getVariableName() { return (constantPool.getUTF8Entry(variableNameIndex).getUTF8()); } - @Override public String getVariableDescriptor() { + @Override + public String getVariableDescriptor() { return (constantPool.getUTF8Entry(descriptorIndex).getUTF8()); } - @Override public int getEnd() { + @Override + public int getEnd() { return (start + usageLength); } - @Override public boolean isArray() { + @Override + public boolean isArray() { return (getVariableDescriptor().startsWith("[")); } } @@ -1969,7 +2021,8 @@ public class ClassModel{ return (bytes); } - @Override public String toString() { + @Override + public String toString() { return (new String(bytes)); } @@ -1987,7 +2040,8 @@ public class ClassModel{ return (bytes); } - @Override public String toString() { + @Override + public String toString() { return (new String(bytes)); } } @@ -2004,7 +2058,8 @@ public class ClassModel{ return (bytes); } - @Override public String toString() { + @Override + public String toString() { return (new String(bytes)); } } @@ -2094,9 +2149,11 @@ public class ClassModel{ } } - @SuppressWarnings("unused") private final int elementNameIndex; + @SuppressWarnings("unused") + private final int elementNameIndex; - @SuppressWarnings("unused") private Value value; + @SuppressWarnings("unused") + private Value value; public ElementValuePair(ByteReader _byteReader) { elementNameIndex = _byteReader.u2(); @@ -2626,12 +2683,23 @@ public class ClassModel{ } public ClassModelMethod getMethod(String _name, String _descriptor) { + ClassModelMethod methodOrNull = getMethodOrNull(_name, _descriptor); + if (methodOrNull == null) + return superClazz != null ? superClazz.getMethod(_name, _descriptor) : (null); + return methodOrNull; + } + + private ClassModelMethod getMethodOrNull(String _name, String _descriptor) { for (final ClassModelMethod entry : methods) { if (entry.getName().equals(_name) && entry.getDescriptor().equals(_descriptor)) { + if (logger.isLoggable(Level.FINE)) { + logger.fine("Found " + clazz.getName() + "." + entry.getName() + " " + entry.getDescriptor() + " for " + + _name.replace('/', '.')); + } return (entry); } } - return superClazz != null ? superClazz.getMethod(_name, _descriptor) : (null); + return null; } public List<ClassModelField> getFieldPoolEntries() { @@ -2647,7 +2715,9 @@ public class ClassModel{ * @return The Method or null if we fail to locate a given method. */ public ClassModelMethod getMethod(MethodEntry _methodEntry, boolean _isSpecial) { - final String entryClassNameInDotForm = _methodEntry.getClassEntry().getNameUTF8Entry().getUTF8().replace('/', '.'); + NameAndTypeEntry nameAndTypeEntry = _methodEntry.getNameAndTypeEntry(); + String utf8Name = nameAndTypeEntry.getNameUTF8Entry().getUTF8(); + final String entryClassNameInDotForm = utf8Name.replace('/', '.'); // Shortcut direct calls to supers to allow "foo() { super.foo() }" type stuff to work if (_isSpecial && (superClazz != null) && superClazz.isSuperClass(entryClassNameInDotForm)) { @@ -2658,20 +2728,21 @@ public class ClassModel{ return superClazz.getMethod(_methodEntry, false); } - for (final ClassModelMethod entry : methods) { - if (entry.getName().equals(_methodEntry.getNameAndTypeEntry().getNameUTF8Entry().getUTF8()) - && entry.getDescriptor().equals(_methodEntry.getNameAndTypeEntry().getDescriptorUTF8Entry().getUTF8())) { - if (logger.isLoggable(Level.FINE)) { - logger.fine("Found " + clazz.getName() + "." + entry.getName() + " " + entry.getDescriptor() + " for " - + entryClassNameInDotForm); - } - return (entry); - } - } - - return superClazz != null ? superClazz.getMethod(_methodEntry, false) : (null); + ClassModelMethod methodOrNull = getMethodOrNull(utf8Name, nameAndTypeEntry.getDescriptorUTF8Entry().getUTF8()); + if (methodOrNull == null) + return superClazz != null ? superClazz.getMethod(_methodEntry, false) : (null); + return methodOrNull; } + // private ValueCache<MethodKey, MethodModel, AparapiException> methodModelCache = ValueCache.on(this::computeMethodModel); + private ValueCache<MethodKey, MethodModel, AparapiException> methodModelCache = ValueCache + .on(new ThrowingValueComputer<MethodKey, MethodModel, AparapiException>(){ + @Override + public MethodModel compute(MethodKey key) throws AparapiException { + return computeMethodModel(key); + } + }); + /** * Create a MethodModel for a given method name and signature. * @@ -2680,9 +2751,17 @@ public class ClassModel{ * @return * @throws AparapiException */ - public MethodModel getMethodModel(String _name, String _signature) throws AparapiException { - final ClassModelMethod method = getMethod(_name, _signature); + if (CacheEnabler.areCachesEnabled()) + return methodModelCache.computeIfAbsent(MethodKey.of(_name, _signature)); + else { + final ClassModelMethod method = getMethod(_name, _signature); + return new MethodModel(method); + } + } + + private MethodModel computeMethodModel(MethodKey methodKey) throws AparapiException { + final ClassModelMethod method = getMethod(methodKey.getName(), methodKey.getSignature()); return new MethodModel(method); } @@ -2715,9 +2794,29 @@ public class ClassModel{ totalStructSize = x; } + // private final ValueCache<EntrypointKey, Entrypoint, AparapiException> entrypointCache = ValueCache.on(this::computeBasicEntrypoint); + private final ValueCache<EntrypointKey, Entrypoint, AparapiException> entrypointCache = ValueCache + .on(new ThrowingValueComputer<EntrypointKey, Entrypoint, AparapiException>(){ + @Override + public Entrypoint compute(EntrypointKey key) throws AparapiException { + return computeBasicEntrypoint(key); + } + }); + Entrypoint getEntrypoint(String _entrypointName, String _descriptor, Object _k) throws AparapiException { - final MethodModel method = getMethodModel(_entrypointName, _descriptor); - return (new Entrypoint(this, method, _k)); + if (CacheEnabler.areCachesEnabled()) { + EntrypointKey key = EntrypointKey.of(_entrypointName, _descriptor); + Entrypoint entrypointWithoutKernel = entrypointCache.computeIfAbsent(key); + return entrypointWithoutKernel.cloneForKernel(_k); + } else { + final MethodModel method = getMethodModel(_entrypointName, _descriptor); + return new Entrypoint(this, method, _k); + } + } + + Entrypoint computeBasicEntrypoint(EntrypointKey entrypointKey) throws AparapiException { + final MethodModel method = getMethodModel(entrypointKey.getEntrypointName(), entrypointKey.getDescriptor()); + return new Entrypoint(this, method, null); } public Class<?> getClassWeAreModelling() { @@ -2731,4 +2830,8 @@ public class ClassModel{ public Entrypoint getEntrypoint() throws AparapiException { return (getEntrypoint("run", "()V", null)); } + + public static void invalidateCaches() { + classModelCache.invalidate(); + } } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java index 28aadab21efa63eb18c63abc4930cc7ed942075f..7ae155efa905a1dca6cd88f39931977a6ea9317a 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/Entrypoint.java @@ -50,7 +50,7 @@ import java.lang.reflect.*; import java.util.*; import java.util.logging.*; -public class Entrypoint{ +public class Entrypoint implements Cloneable { private static Logger logger = Logger.getLogger(Config.getLoggerName()); @@ -81,7 +81,7 @@ public class Entrypoint{ private final List<MethodModel> calledMethods = new ArrayList<MethodModel>(); - private MethodModel methodModel; + private final MethodModel methodModel; /** True is an indication to use the fp64 pragma @@ -226,7 +226,7 @@ public class Entrypoint{ final Class<?> memberClass = Class.forName(className); // Immediately add this class and all its supers if necessary - memberClassModel = new ClassModel(memberClass); + memberClassModel = ClassModel.createClassModel(memberClass); if (logger.isLoggable(Level.FINEST)) { logger.finest("adding class " + className); } @@ -595,7 +595,7 @@ public class Entrypoint{ // Add the class model for the referenced obj array if (signature.startsWith("[L")) { // Turn [Lcom/amd/javalabs/opencl/demo/DummyOOA; into com.amd.javalabs.opencl.demo.DummyOOA for example - final String className = (signature.substring(2, signature.length() - 1)).replace("/", "."); + final String className = (signature.substring(2, signature.length() - 1)).replace('/', '.'); final ClassModel arrayFieldModel = getOrUpdateAllClassAccesses(className); if (arrayFieldModel != null) { final Class<?> memberClass = arrayFieldModel.getClassWeAreModelling(); @@ -625,7 +625,7 @@ public class Entrypoint{ } } } else { - final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace("/", "."); + final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace('/', '.'); // Look for object data member access if (!className.equals(getClassModel().getClassWeAreModelling().getName()) && (getFieldFromClassHierarchy(getClassModel().getClassWeAreModelling(), accessedFieldName) == null)) { @@ -640,7 +640,7 @@ public class Entrypoint{ fieldAssignments.add(assignedFieldName); referencedFieldNames.add(assignedFieldName); - final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace("/", "."); + final String className = (field.getClassEntry().getNameUTF8Entry().getUTF8()).replace('/', '.'); // Look for object data member access if (!className.equals(getClassModel().getClassWeAreModelling().getName()) && (getFieldFromClassHierarchy(getClassModel().getClassWeAreModelling(), assignedFieldName) == null)) { @@ -909,4 +909,14 @@ public class Entrypoint{ return null; } + + Entrypoint cloneForKernel(Object _k) throws AparapiException { + try { + Entrypoint clonedEntrypoint = (Entrypoint) clone(); + clonedEntrypoint.kernelInstance = _k; + return clonedEntrypoint; + } catch (CloneNotSupportedException e) { + throw new AparapiException(e); + } + } } diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java new file mode 100644 index 0000000000000000000000000000000000000000..bddf99d7b8ca0b289b2add855dd96bd0179a0f2f --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/EntrypointKey.java @@ -0,0 +1,57 @@ +package com.amd.aparapi.internal.model; + +final class EntrypointKey{ + public static EntrypointKey of(String entrypointName, String descriptor) { + return new EntrypointKey(entrypointName, descriptor); + } + + private String descriptor; + + private String entrypointName; + + private EntrypointKey(String entrypointName, String descriptor) { + this.entrypointName = entrypointName; + this.descriptor = descriptor; + } + + String getDescriptor() { + return descriptor; + } + + String getEntrypointName() { + return entrypointName; + } + + @Override public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((descriptor == null) ? 0 : descriptor.hashCode()); + result = prime * result + ((entrypointName == null) ? 0 : entrypointName.hashCode()); + return result; + } + + @Override public String toString() { + return "EntrypointKey [entrypointName=" + entrypointName + ", descriptor=" + descriptor + "]"; + } + + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + EntrypointKey other = (EntrypointKey) obj; + if (descriptor == null) { + if (other.descriptor != null) + return false; + } else if (!descriptor.equals(other.descriptor)) + return false; + if (entrypointName == null) { + if (other.entrypointName != null) + return false; + } else if (!entrypointName.equals(other.entrypointName)) + return false; + return true; + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java new file mode 100644 index 0000000000000000000000000000000000000000..d08e2debac32c97fddf209ce295e471df44dd29e --- /dev/null +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/model/MethodKey.java @@ -0,0 +1,57 @@ +package com.amd.aparapi.internal.model; + +final class MethodKey{ + static MethodKey of(String name, String signature) { + return new MethodKey(name, signature); + } + + private final String name; + + private final String signature; + + @Override public String toString() { + return "MethodKey [name=" + getName() + ", signature=" + getSignature() + "]"; + } + + @Override public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((getName() == null) ? 0 : getName().hashCode()); + result = prime * result + ((getSignature() == null) ? 0 : getSignature().hashCode()); + return result; + } + + @Override public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + MethodKey other = (MethodKey) obj; + if (getName() == null) { + if (other.getName() != null) + return false; + } else if (!getName().equals(other.getName())) + return false; + if (getSignature() == null) { + if (other.getSignature() != null) + return false; + } else if (!getSignature().equals(other.getSignature())) + return false; + return true; + } + + private MethodKey(String name, String signature) { + this.name = name; + this.signature = signature; + } + + public String getName() { + return name; + } + + public String getSignature() { + return signature; + } +} diff --git a/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java b/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java index 6e80defbadf45301bce302cb474343a269914660..015cf46d7492b83b5d914c77e3b206006cbe32ae 100644 --- a/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java +++ b/com.amd.aparapi/src/java/com/amd/aparapi/internal/tool/InstructionViewer.java @@ -625,7 +625,7 @@ public class InstructionViewer implements Config.InstructionListener{ public InstructionViewer(Color _background, String _name) { try { - classModel = new ClassModel(Class.forName(_name)); + classModel = ClassModel.createClassModel(Class.forName(_name)); } catch (final ClassParseException e) { // TODO Auto-generated catch block e.printStackTrace();