Skip to content
Snippets Groups Projects
Commit b10dfb6a authored by Pr0methean's avatar Pr0methean
Browse files

Increase runtime estimation tolerance to 250ms

parent a227992b
No related branches found
No related tags found
No related merge requests found
...@@ -55,33 +55,33 @@ import com.aparapi.internal.kernel.KernelManager; ...@@ -55,33 +55,33 @@ import com.aparapi.internal.kernel.KernelManager;
/** /**
* Provides integration tests to help in assuring that new APIs for ProfileReports are working, * Provides integration tests to help in assuring that new APIs for ProfileReports are working,
* in single threaded and multi-threaded environments. * in single threaded and multi-threaded environments.
* *
* @author CoreRasurae * @author CoreRasurae
*/ */
public class ProfileReportNewAPITest { public class ProfileReportNewAPITest {
private static OpenCLDevice openCLDevice; private static OpenCLDevice openCLDevice;
private static Logger logger = Logger.getLogger(Config.getLoggerName()); private static Logger logger = Logger.getLogger(Config.getLoggerName());
@Rule @Rule
public TestName name = new TestName(); public TestName name = new TestName();
@After @After
public void classTeardown() { public void classTeardown() {
Util.resetKernelManager(); Util.resetKernelManager();
} }
private class CLKernelManager extends KernelManager { private class CLKernelManager extends KernelManager {
@Override @Override
protected List<Device.TYPE> getPreferredDeviceTypes() { protected List<Device.TYPE> getPreferredDeviceTypes() {
return Arrays.asList(Device.TYPE.ACC, Device.TYPE.GPU, Device.TYPE.CPU); return Arrays.asList(Device.TYPE.ACC, Device.TYPE.GPU, Device.TYPE.CPU);
} }
} }
private class JTPKernelManager extends KernelManager { private class JTPKernelManager extends KernelManager {
private JTPKernelManager() { private JTPKernelManager() {
LinkedHashSet<Device> preferredDevices = new LinkedHashSet<Device>(1); LinkedHashSet<Device> preferredDevices = new LinkedHashSet<Device>(1);
...@@ -93,7 +93,7 @@ public class ProfileReportNewAPITest { ...@@ -93,7 +93,7 @@ public class ProfileReportNewAPITest {
return Arrays.asList(Device.TYPE.JTP); return Arrays.asList(Device.TYPE.JTP);
} }
} }
public void setUpBefore() throws Exception { public void setUpBefore() throws Exception {
KernelManager.setKernelManager(new CLKernelManager()); KernelManager.setKernelManager(new CLKernelManager());
Device device = KernelManager.instance().bestDevice(); Device device = KernelManager.instance().bestDevice();
...@@ -105,9 +105,9 @@ public class ProfileReportNewAPITest { ...@@ -105,9 +105,9 @@ public class ProfileReportNewAPITest {
} }
/** /**
* Tests the ProfileReport observer interface in a single threaded, single kernel environment running on * Tests the ProfileReport observer interface in a single threaded, single kernel environment running on
* an OpenCL device. * an OpenCL device.
* @throws Exception * @throws Exception
*/ */
@Test @Test
public void singleThreadedSingleKernelObserverOpenCLTest() throws Exception { public void singleThreadedSingleKernelObserverOpenCLTest() throws Exception {
...@@ -117,7 +117,7 @@ public class ProfileReportNewAPITest { ...@@ -117,7 +117,7 @@ public class ProfileReportNewAPITest {
} }
/** /**
* Tests the ProfileReport observer interface in a single threaded, single kernel environment running on * Tests the ProfileReport observer interface in a single threaded, single kernel environment running on
* Java Thread Pool. * Java Thread Pool.
*/ */
@Test @Test
...@@ -131,7 +131,7 @@ public class ProfileReportNewAPITest { ...@@ -131,7 +131,7 @@ public class ProfileReportNewAPITest {
private double accumulatedElapsedTime = 0.0; private double accumulatedElapsedTime = 0.0;
private long receivedReportsCount = 0; private long receivedReportsCount = 0;
} }
private class ReportObserver implements IProfileReportObserver { private class ReportObserver implements IProfileReportObserver {
private final ConcurrentSkipListSet<Long> expectedThreadsIds = new ConcurrentSkipListSet<>(); private final ConcurrentSkipListSet<Long> expectedThreadsIds = new ConcurrentSkipListSet<>();
private final ConcurrentSkipListMap<Long, ThreadTestState> observedThreadsIds = new ConcurrentSkipListMap<>(); private final ConcurrentSkipListMap<Long, ThreadTestState> observedThreadsIds = new ConcurrentSkipListMap<>();
...@@ -139,30 +139,30 @@ public class ProfileReportNewAPITest { ...@@ -139,30 +139,30 @@ public class ProfileReportNewAPITest {
private final int threads; private final int threads;
private final int runs; private final int runs;
private final boolean[] receivedReportIds; private final boolean[] receivedReportIds;
private ReportObserver(Device _device, int _threads, int _runs) { private ReportObserver(Device _device, int _threads, int _runs) {
device = _device; device = _device;
threads = _threads; threads = _threads;
runs = _runs; runs = _runs;
receivedReportIds = new boolean[threads * runs]; receivedReportIds = new boolean[threads * runs];
} }
private void addAcceptedThreadId(long threadId) { private void addAcceptedThreadId(long threadId) {
expectedThreadsIds.add(threadId); expectedThreadsIds.add(threadId);
} }
private ConcurrentSkipListMap<Long, ThreadTestState> getObservedThreadsIds() { private ConcurrentSkipListMap<Long, ThreadTestState> getObservedThreadsIds() {
return observedThreadsIds; return observedThreadsIds;
} }
@Override @Override
public void receiveReport(Class<? extends Kernel> kernelClass, Device _device, WeakReference<ProfileReport> profileInfoRef) { public void receiveReport(Class<? extends Kernel> kernelClass, Device _device, WeakReference<ProfileReport> profileInfoRef) {
ProfileReport profileInfo = profileInfoRef.get(); ProfileReport profileInfo = profileInfoRef.get();
assertEquals("Kernel class does not match", Basic1Kernel.class, kernelClass); assertEquals("Kernel class does not match", Basic1Kernel.class, kernelClass);
assertEquals("Device does not match", device, _device); assertEquals("Device does not match", device, _device);
boolean isThreadAccepted = expectedThreadsIds.contains(profileInfo.getThreadId()); boolean isThreadAccepted = expectedThreadsIds.contains(profileInfo.getThreadId());
assertTrue("Thread generating the report (" + profileInfo.getThreadId() + assertTrue("Thread generating the report (" + profileInfo.getThreadId() +
") is not among the accepted ones: " + expectedThreadsIds.toString(), isThreadAccepted); ") is not among the accepted ones: " + expectedThreadsIds.toString(), isThreadAccepted);
Long threadId = profileInfo.getThreadId(); Long threadId = profileInfo.getThreadId();
ThreadTestState state = observedThreadsIds.computeIfAbsent(threadId, k -> new ThreadTestState()); ThreadTestState state = observedThreadsIds.computeIfAbsent(threadId, k -> new ThreadTestState());
...@@ -171,15 +171,15 @@ public class ProfileReportNewAPITest { ...@@ -171,15 +171,15 @@ public class ProfileReportNewAPITest {
receivedReportIds[(int)profileInfo.getReportId() - 1] = true; receivedReportIds[(int)profileInfo.getReportId() - 1] = true;
} }
} }
public boolean singleThreadedSingleKernelReportObserverTestHelper(Device device, int size) { public boolean singleThreadedSingleKernelReportObserverTestHelper(Device device, int size) {
final int runs = 100; final int runs = 100;
final int inputArray[] = new int[size]; final int inputArray[] = new int[size];
final Basic1Kernel kernel = new Basic1Kernel(); final Basic1Kernel kernel = new Basic1Kernel();
int[] outputArray = null; int[] outputArray = null;
Range range = device.createRange(size, size); Range range = device.createRange(size, size);
ReportObserver observer = new ReportObserver(device, 1, runs); ReportObserver observer = new ReportObserver(device, 1, runs);
observer.addAcceptedThreadId(Thread.currentThread().getId()); observer.addAcceptedThreadId(Thread.currentThread().getId());
kernel.registerProfileReportObserver(observer); kernel.registerProfileReportObserver(observer);
...@@ -187,7 +187,7 @@ public class ProfileReportNewAPITest { ...@@ -187,7 +187,7 @@ public class ProfileReportNewAPITest {
for (int i = 0; i < runs; i++) { for (int i = 0; i < runs; i++) {
assertFalse("Report with id " + i + " shouldn't have been received yet", observer.receivedReportIds[i]); assertFalse("Report with id " + i + " shouldn't have been received yet", observer.receivedReportIds[i]);
} }
long startOfExecution = System.currentTimeMillis(); long startOfExecution = System.currentTimeMillis();
try { try {
for (int i = 0; i < runs; i++) { for (int i = 0; i < runs; i++) {
...@@ -202,7 +202,7 @@ public class ProfileReportNewAPITest { ...@@ -202,7 +202,7 @@ public class ProfileReportNewAPITest {
assertEquals("Number of profiling reports doesn't match the expected", runs, state.receivedReportsCount); assertEquals("Number of profiling reports doesn't match the expected", runs, state.receivedReportsCount);
assertEquals("Aparapi Accumulated execution time doesn't match", kernel.getAccumulatedExecutionTimeAllThreads(device), state.accumulatedElapsedTime, 1e-10); assertEquals("Aparapi Accumulated execution time doesn't match", kernel.getAccumulatedExecutionTimeAllThreads(device), state.accumulatedElapsedTime, 1e-10);
assertEquals("Test estimated accumulated time doesn't match within 200ms window", runTime, kernel.getAccumulatedExecutionTimeAllThreads(device), 200); assertEquals("Test estimated accumulated time doesn't match within 250ms window", runTime, kernel.getAccumulatedExecutionTimeAllThreads(device), 250);
for (int i = 0; i < runs; i++) { for (int i = 0; i < runs; i++) {
assertTrue("Report with id " + i + " wasn't received", observer.receivedReportIds[i]); assertTrue("Report with id " + i + " wasn't received", observer.receivedReportIds[i]);
} }
...@@ -211,12 +211,12 @@ public class ProfileReportNewAPITest { ...@@ -211,12 +211,12 @@ public class ProfileReportNewAPITest {
kernel.registerProfileReportObserver(null); kernel.registerProfileReportObserver(null);
kernel.dispose(); kernel.dispose();
} }
return true; return true;
} }
/** /**
* Tests the ProfileReport observer interface in a multi threaded, single kernel environment running on * Tests the ProfileReport observer interface in a multi threaded, single kernel environment running on
* an OpenCL device. * an OpenCL device.
*/ */
@Test @Test
...@@ -227,7 +227,7 @@ public class ProfileReportNewAPITest { ...@@ -227,7 +227,7 @@ public class ProfileReportNewAPITest {
} }
/** /**
* Tests the ProfileReport observer interface in a multi threaded, single kernel environment running on * Tests the ProfileReport observer interface in a multi threaded, single kernel environment running on
* Java Thread Pool. * Java Thread Pool.
*/ */
@Test @Test
...@@ -244,15 +244,15 @@ public class ProfileReportNewAPITest { ...@@ -244,15 +244,15 @@ public class ProfileReportNewAPITest {
private double accumulatedExecutionTime; private double accumulatedExecutionTime;
private int[] outputArray; private int[] outputArray;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public boolean multiThreadedSingleKernelReportObserverTestRunner(final ExecutorService executorService, public boolean multiThreadedSingleKernelReportObserverTestRunner(final ExecutorService executorService,
final List<Basic1Kernel> kernels, final ThreadResults[] results, int[] inputArray, int runs, int javaThreads, final List<Basic1Kernel> kernels, final ThreadResults[] results, int[] inputArray, int runs, int javaThreads,
final Device device, final ReportObserver observer, int size) throws InterruptedException, ExecutionException { final Device device, final ReportObserver observer, int size) throws InterruptedException, ExecutionException {
final AtomicInteger atomicResultId = new AtomicInteger(0); final AtomicInteger atomicResultId = new AtomicInteger(0);
boolean terminatedOk = false; boolean terminatedOk = false;
try { try {
List<Future<Runnable>> futures = new ArrayList<>(javaThreads); List<Future<Runnable>> futures = new ArrayList<>(javaThreads);
for (Basic1Kernel k : kernels) { for (Basic1Kernel k : kernels) {
futures.add((Future<Runnable>)executorService.submit(new Runnable() { futures.add((Future<Runnable>)executorService.submit(new Runnable() {
@Override @Override
...@@ -291,40 +291,40 @@ public class ProfileReportNewAPITest { ...@@ -291,40 +291,40 @@ public class ProfileReportNewAPITest {
executorService.shutdownNow(); executorService.shutdownNow();
} }
} }
return terminatedOk; return terminatedOk;
} }
public boolean multiThreadedSingleKernelReportObserverTestHelper(Device device, int size) throws InterruptedException, ExecutionException { public boolean multiThreadedSingleKernelReportObserverTestHelper(Device device, int size) throws InterruptedException, ExecutionException {
final int runs = 100; final int runs = 100;
final int javaThreads = 10; final int javaThreads = 10;
final int inputArray[] = new int[size]; final int inputArray[] = new int[size];
ExecutorService executorService = Executors.newFixedThreadPool(javaThreads); ExecutorService executorService = Executors.newFixedThreadPool(javaThreads);
final ReportObserver observer = new ReportObserver(device, javaThreads, runs); final ReportObserver observer = new ReportObserver(device, javaThreads, runs);
for (int i = 0; i < runs; i++) { for (int i = 0; i < runs; i++) {
assertFalse("Report with id " + i + " shouldn't have been received yet", observer.receivedReportIds[i]); assertFalse("Report with id " + i + " shouldn't have been received yet", observer.receivedReportIds[i]);
} }
final List<Basic1Kernel> kernels = new ArrayList<Basic1Kernel>(javaThreads); final List<Basic1Kernel> kernels = new ArrayList<Basic1Kernel>(javaThreads);
for (int i = 0; i < javaThreads; i++) { for (int i = 0; i < javaThreads; i++) {
final Basic1Kernel kernel = new Basic1Kernel(); final Basic1Kernel kernel = new Basic1Kernel();
kernel.registerProfileReportObserver(observer); kernel.registerProfileReportObserver(observer);
kernels.add(kernel); kernels.add(kernel);
} }
final ThreadResults[] results = new ThreadResults[javaThreads]; final ThreadResults[] results = new ThreadResults[javaThreads];
for (int i = 0; i < results.length; i++) { for (int i = 0; i < results.length; i++) {
results[i] = new ThreadResults(); results[i] = new ThreadResults();
} }
boolean terminatedOk = multiThreadedSingleKernelReportObserverTestRunner(executorService, kernels, results, boolean terminatedOk = multiThreadedSingleKernelReportObserverTestRunner(executorService, kernels, results,
inputArray, runs, javaThreads, device, observer, size); inputArray, runs, javaThreads, device, observer, size);
assertTrue("Threads did not terminate correctly", terminatedOk); assertTrue("Threads did not terminate correctly", terminatedOk);
double allThreadsAccumulatedTime = 0; double allThreadsAccumulatedTime = 0;
ConcurrentSkipListMap<Long, ThreadTestState> states = observer.getObservedThreadsIds(); ConcurrentSkipListMap<Long, ThreadTestState> states = observer.getObservedThreadsIds();
assertEquals("Number of Java threads sending profile reports should match the number of JavaThreads", javaThreads, states.values().size()); assertEquals("Number of Java threads sending profile reports should match the number of JavaThreads", javaThreads, states.values().size());
...@@ -339,27 +339,27 @@ public class ProfileReportNewAPITest { ...@@ -339,27 +339,27 @@ public class ProfileReportNewAPITest {
assertTrue("Thread index " + i + " kernel computation doesn't match the expected", validateBasic1Kernel(inputArray, results[i].outputArray)); assertTrue("Thread index " + i + " kernel computation doesn't match the expected", validateBasic1Kernel(inputArray, results[i].outputArray));
assertEquals("Runtime is not within 600ms of the kernel estimated", results[i].runTime, state.accumulatedElapsedTime, 600); assertEquals("Runtime is not within 600ms of the kernel estimated", results[i].runTime, state.accumulatedElapsedTime, 600);
} }
assertEquals("Overall kernel execution time doesn't match", assertEquals("Overall kernel execution time doesn't match",
kernels.get(0).getAccumulatedExecutionTimeAllThreads(device), allThreadsAccumulatedTime, 1e10); kernels.get(0).getAccumulatedExecutionTimeAllThreads(device), allThreadsAccumulatedTime, 1e10);
return true; return true;
} }
private boolean validateBasic1Kernel(final int[] inputArray, final int[] resultArray) { private boolean validateBasic1Kernel(final int[] inputArray, final int[] resultArray) {
int[] expecteds = Arrays.copyOf(inputArray, inputArray.length); int[] expecteds = Arrays.copyOf(inputArray, inputArray.length);
for (int threadId = 0; threadId < inputArray.length; threadId++) { for (int threadId = 0; threadId < inputArray.length; threadId++) {
expecteds[threadId] += threadId; expecteds[threadId] += threadId;
} }
assertArrayEquals(expecteds, resultArray); assertArrayEquals(expecteds, resultArray);
return true; return true;
} }
private class Basic1Kernel extends Kernel { private class Basic1Kernel extends Kernel {
protected int[] workArray; protected int[] workArray;
@NoCL @NoCL
public void setInputOuputArray(int[] array) { public void setInputOuputArray(int[] array) {
workArray = array; workArray = array;
...@@ -369,12 +369,12 @@ public class ProfileReportNewAPITest { ...@@ -369,12 +369,12 @@ public class ProfileReportNewAPITest {
public int getId() { public int getId() {
return 1; return 1;
} }
@Override @Override
public void run() { public void run() {
int id = getLocalId(); int id = getLocalId();
workArray[id]+=id; workArray[id]+=id;
} }
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment