From fc7029944f8e5c603dc41f7ee75d5a189327e2fe Mon Sep 17 00:00:00 2001 From: Luis Mendes <luis.p.mendes@gmail.com> Date: Sat, 14 Apr 2018 16:58:55 +0100 Subject: [PATCH] Update: Add support for Local arguments in kernel functions (refs #79) --- CONTRIBUTORS.md | 2 + .../aparapi/internal/model/ClassModel.java | 166 ++++++++++++++++++ .../aparapi/internal/writer/KernelWriter.java | 33 +++- .../runtime/LocalArrayArgsIssue79Test.java | 113 ++++++++++++ 4 files changed, 311 insertions(+), 3 deletions(-) create mode 100644 src/test/java/com/aparapi/runtime/LocalArrayArgsIssue79Test.java diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index edc9e67b..47890903 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -21,6 +21,7 @@ * AMD Corporation * Dmitriy Shabanov <shabanovd@gmail.com> * Toon Baeyens <toon.baeyens@gmail.com> +* Luis Mendes <luis.p.mendes@gmail.com> # Details @@ -40,3 +41,4 @@ Below are some of the specific details of various contributions. * lgalluci for his fix for issue #121 (incorrect toString for 3D ranges) July 6th 2013 * George Vinokhodov submited a fix for a bug regarding forward references. * Dmitriy Shabanov submited PR for inline array feature. +* Luis Mendes submited PR to support passing functions arguments containing Local arrays - issue #79 diff --git a/src/main/java/com/aparapi/internal/model/ClassModel.java b/src/main/java/com/aparapi/internal/model/ClassModel.java index db20a86b..73415a26 100644 --- a/src/main/java/com/aparapi/internal/model/ClassModel.java +++ b/src/main/java/com/aparapi/internal/model/ClassModel.java @@ -2240,6 +2240,161 @@ public class ClassModel { } + public class RuntimeParameterAnnotationsEntry extends PoolEntry<RuntimeParameterAnnotationsEntry.ParameterInfo>{ + public class ParameterInfo { + private final int methodArgumentIndex; + private final List<AnnotationInfo> annotations; + + public class AnnotationInfo { + private final int typeIndex; + private final int elementValuePairCount; + private final int methodArgumentIndex; + + public class ElementValuePair{ + class Value { + Value(int _tag) { + tag = _tag; + } + + int tag; + } + + public class PrimitiveValue extends Value{ + private final int typeNameIndex; + + private final int constNameIndex; + + public PrimitiveValue(int _tag, ByteReader _byteReader) { + super(_tag); + typeNameIndex = _byteReader.u2(); + //constNameIndex = _byteReader.u2(); + constNameIndex = 0; + } + + public int getConstNameIndex() { + return (constNameIndex); + } + + public int getTypeNameIndex() { + return (typeNameIndex); + } + } + + public class EnumValue extends Value{ + EnumValue(int _tag, ByteReader _byteReader) { + super(_tag); + } + } + + public class ArrayValue extends Value{ + ArrayValue(int _tag, ByteReader _byteReader) { + super(_tag); + } + } + + public class ClassValue extends Value{ + ClassValue(int _tag, ByteReader _byteReader) { + super(_tag); + } + } + + public class AnnotationValue extends Value{ + AnnotationValue(int _tag, ByteReader _byteReader) { + super(_tag); + } + } + + @SuppressWarnings("unused") + private final int elementNameIndex; + + @SuppressWarnings("unused") + private Value value; + + public ElementValuePair(ByteReader _byteReader) { + elementNameIndex = _byteReader.u2(); + final int tag = _byteReader.u1(); + + switch (tag) { + case SIGC_BYTE: + case SIGC_CHAR: + case SIGC_INT: + case SIGC_LONG: + case SIGC_DOUBLE: + case SIGC_FLOAT: + case SIGC_SHORT: + case SIGC_BOOLEAN: + case 's': // special for String + value = new PrimitiveValue(tag, _byteReader); + break; + case 'e': // special for Enum + value = new EnumValue(tag, _byteReader); + break; + case 'c': // special for class + value = new ClassValue(tag, _byteReader); + break; + case '@': // special for Annotation + value = new AnnotationValue(tag, _byteReader); + break; + case 'a': // special for array + value = new ArrayValue(tag, _byteReader); + break; + } + } + } + + private final ElementValuePair[] elementValuePairs; + + public AnnotationInfo(ByteReader _byteReader, int argumentIndex) { + methodArgumentIndex = argumentIndex; + typeIndex = _byteReader.u2(); + elementValuePairCount = _byteReader.u2(); + elementValuePairs = new ElementValuePair[elementValuePairCount]; + for (int i = 0; i < elementValuePairCount; i++) { + elementValuePairs[i] = new ElementValuePair(_byteReader); + } + } + + public int getMethodArgumentIndex() { + return methodArgumentIndex; + } + + public int getTypeIndex() { + return (typeIndex); + } + + public String getTypeDescriptor() { + return (constantPool.getUTF8Entry(typeIndex).getUTF8()); + } + } + + public ParameterInfo(ByteReader _byteReader, int argumentIndex) { + methodArgumentIndex = argumentIndex; + final int numberOfAnnotations = _byteReader.u2(); + annotations = new ArrayList<AnnotationInfo>(numberOfAnnotations); + for (int i = 0; i < numberOfAnnotations; i++) { + annotations.add(new AnnotationInfo(_byteReader, argumentIndex)); + } + } + + public int getMethodArgumentIndex() { + return methodArgumentIndex; + } + + public List<AnnotationInfo> getAnnotations() { + return annotations; + } + } + + //See https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-4.html#jvms-4.7.18 + public RuntimeParameterAnnotationsEntry(ByteReader _byteReader, int _nameIndex, int _length) { + super(_byteReader, _nameIndex, _length); + final int numberOfParameters = _byteReader.u1(); + for (int paramIndex = 0; paramIndex < numberOfParameters; paramIndex++) { + getPool().add(new ParameterInfo(_byteReader, paramIndex)); + } + } + } + private CodeEntry codeEntry = null; private EnclosingMethodEntry enclosingMethodEntry = null; @@ -2253,6 +2408,8 @@ public class ClassModel { private LocalVariableTableEntry localVariableTableEntry = null; private RuntimeAnnotationsEntry runtimeVisibleAnnotationsEntry; + + private RuntimeParameterAnnotationsEntry runtimeVisibleParameterAnnotationsEntry; private RuntimeAnnotationsEntry runtimeInvisibleAnnotationsEntry; @@ -2285,6 +2442,8 @@ public class ClassModel { private final static String SIGNATURE_TAG = "Signature"; private final static String RUNTIMEINVISIBLEANNOTATIONS_TAG = "RuntimeInvisibleAnnotations"; + + private final static String RUNTIMEVISIBLEPARAMETERANNOTATIONS_TAG = "RuntimeVisibleParameterAnnotations"; private final static String RUNTIMEVISIBLEANNOTATIONS_TAG = "RuntimeVisibleAnnotations"; @@ -2341,6 +2500,9 @@ public class ClassModel { } else if (attributeName.equals(RUNTIMEVISIBLEANNOTATIONS_TAG)) { runtimeVisibleAnnotationsEntry = new RuntimeAnnotationsEntry(_byteReader, attributeNameIndex, length); entry = runtimeVisibleAnnotationsEntry; + } else if (attributeName.equals(RUNTIMEVISIBLEPARAMETERANNOTATIONS_TAG)) { + runtimeVisibleParameterAnnotationsEntry = new RuntimeParameterAnnotationsEntry(_byteReader, attributeNameIndex, length); + entry = runtimeVisibleParameterAnnotationsEntry; } else if (attributeName.equals(BOOTSTRAPMETHODS_TAG)) { bootstrapMethodsEntry = new BootstrapMethodsEntry(_byteReader, attributeNameIndex, length); entry = bootstrapMethodsEntry; @@ -2393,6 +2555,10 @@ public class ClassModel { return (runtimeInvisibleAnnotationsEntry); } + public RuntimeParameterAnnotationsEntry getRuntimeVisibleParameterAnnotationsEntry() { + return (runtimeVisibleParameterAnnotationsEntry); + } + public RuntimeAnnotationsEntry getRuntimeVisibleAnnotationsEntry() { return (runtimeVisibleAnnotationsEntry); } diff --git a/src/main/java/com/aparapi/internal/writer/KernelWriter.java b/src/main/java/com/aparapi/internal/writer/KernelWriter.java index 8099c591..31d6629d 100644 --- a/src/main/java/com/aparapi/internal/writer/KernelWriter.java +++ b/src/main/java/com/aparapi/internal/writer/KernelWriter.java @@ -61,6 +61,7 @@ import com.aparapi.internal.instruction.InstructionSet.*; import com.aparapi.internal.model.*; import com.aparapi.internal.model.ClassModel.AttributePool.*; import com.aparapi.internal.model.ClassModel.AttributePool.RuntimeAnnotationsEntry.*; +import com.aparapi.internal.model.ClassModel.AttributePool.RuntimeParameterAnnotationsEntry.ParameterInfo; import com.aparapi.internal.model.ClassModel.*; import com.aparapi.internal.model.ClassModel.ConstantPool.*; @@ -628,7 +629,10 @@ public abstract class KernelWriter extends BlockWriter{ boolean alreadyHasFirstArg = !mm.getMethod().isStatic(); + final RuntimeParameterAnnotationsEntry parameterAnnotations = + mm.getMethod().getAttributePool().getRuntimeVisibleParameterAnnotationsEntry(); final LocalVariableTableEntry<LocalVariableInfo> lvte = mm.getLocalVariableTableEntry(); + int localVariableIndex = 0; for (final LocalVariableInfo lvi : lvte) { if ((lvi.getStart() == 0) && ((lvi.getVariableIndex() != 0) || mm.getMethod().isStatic())) { // full scope but skip this final String descriptor = lvi.getVariableDescriptor(); @@ -636,13 +640,36 @@ public abstract class KernelWriter extends BlockWriter{ write(", "); } - if (descriptor.startsWith("[") && !lvi.getVariableName().endsWith(PRIVATE_SUFFIX)) { - write(" __global "); + if (descriptor.startsWith("[")) { + boolean isPrivate = false; + boolean isLocal = false; + if(lvi.getVariableName().endsWith(PRIVATE_SUFFIX)) { + isPrivate = true; + } else if (lvi.getVariableName().endsWith(Kernel.LOCAL_SUFFIX)) { + isLocal = true; + } else if (parameterAnnotations != null) { + ParameterInfo paramInfo = parameterAnnotations.getPool().get(localVariableIndex); + List<ParameterInfo.AnnotationInfo> paramAnnotations = paramInfo.getAnnotations(); + for (ParameterInfo.AnnotationInfo annotation : paramAnnotations) { + if (annotation.getTypeDescriptor().equals(LOCAL_ANNOTATION_NAME)) { + isLocal = true; + break; + } + } + } + + if (isLocal) { + write(" __local "); + } else if (!isPrivate) { + write(" __global "); + } } - + write(convertType(descriptor, true, false)); write(lvi.getVariableName()); alreadyHasFirstArg = true; + + localVariableIndex++; } } write(")"); diff --git a/src/test/java/com/aparapi/runtime/LocalArrayArgsIssue79Test.java b/src/test/java/com/aparapi/runtime/LocalArrayArgsIssue79Test.java new file mode 100644 index 00000000..6dc014ff --- /dev/null +++ b/src/test/java/com/aparapi/runtime/LocalArrayArgsIssue79Test.java @@ -0,0 +1,113 @@ +/** + * Copyright (c) 2016 - 2017 Syncleus, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.aparapi.runtime; + +import static org.junit.Assert.assertArrayEquals; + +import com.aparapi.Kernel; +import com.aparapi.Range; +import com.aparapi.device.Device; +import com.aparapi.device.OpenCLDevice; +import com.aparapi.internal.kernel.KernelManager; +import static org.junit.Assume.*; +import org.junit.BeforeClass; +import org.junit.Test; + +public class LocalArrayArgsIssue79Test { + static OpenCLDevice openCLDevice = null; + private static final int SIZE = 32; + private int[] targetArray; + + @BeforeClass + public static void setUpBeforeClass() throws Exception { + Device device = KernelManager.instance().bestDevice(); + assumeTrue (device != null && device instanceof OpenCLDevice); + openCLDevice = (OpenCLDevice) device; + } + + @Test + public void test() { + final LocalArrayArgsKernel kernel = new LocalArrayArgsKernel(); + final Range range = openCLDevice.createRange(SIZE, SIZE); + targetArray = new int[SIZE]; + kernel.setExplicit(false); + kernel.setArray(targetArray); + kernel.execute(range); + validate(); + } + + @Test + public void testExplicit() { + final LocalArrayArgsKernel kernel = new LocalArrayArgsKernel(); + final Range range = openCLDevice.createRange(SIZE, SIZE); + targetArray = new int[SIZE]; + kernel.setExplicit(true); + kernel.setArray(targetArray); + kernel.put(targetArray); + kernel.execute(range); + kernel.get(targetArray); + validate(); + } + + void validate() { + int[] expected = new int[SIZE]; + for (int threadId = 0; threadId < SIZE; threadId++) { + for (int i = 0; i < SIZE; i++) { + expected[threadId] += i + threadId; + } + expected[threadId] *= threadId; + } + + assertArrayEquals("destArray", expected, targetArray); + } + + public static class LocalArrayArgsKernel extends Kernel { + private int[] destArray; + + @Local + private int[] myArray = new int[SIZE]; + + public LocalArrayArgsKernel() { + } + + @NoCL + public void setArray(int[] target) { + destArray = target; + } + + private void doComputation1(@Local int[] arr, int id) { + for (int i = 0; i < SIZE; i++) { + arr[id] += i + id; + } + } + + private void doComputation2(int[] arr_$local$, int id) { + arr_$local$[id] *= id; + } + + @Override + public void run() { + int id = getLocalId(); + + myArray[id] = destArray[id]; + + doComputation1(myArray, id); + doComputation2(myArray, id); + + destArray[id] = myArray[id]; + } + } +} -- GitLab