From 6e08e20796da5ac5ca101f644c054748bf4b06c4 Mon Sep 17 00:00:00 2001
From: Saurabh Rawat <saurabh.rawat90@gmail.com>
Date: Fri, 29 Jun 2018 11:13:50 +0530
Subject: [PATCH] add tests for scala

---
 pom.xml                                       | 21 +++++
 .../scala/com/aparapi/SimpleScalaTest.scala   | 86 +++++++++++++++++++
 2 files changed, 107 insertions(+)
 create mode 100644 src/test/scala/com/aparapi/SimpleScalaTest.scala

diff --git a/pom.xml b/pom.xml
index 7ba0cd6f..f07238ae 100644
--- a/pom.xml
+++ b/pom.xml
@@ -98,6 +98,11 @@
             <artifactId>bcel</artifactId>
             <version>6.2</version>
         </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-library</artifactId>
+            <version>2.12.6</version>
+        </dependency>
     </dependencies>
 
     <build>
@@ -107,6 +112,22 @@
                 <groupId>org.apache.maven.plugins</groupId>
                 <artifactId>maven-compiler-plugin</artifactId>
             </plugin>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+                <version>3.4.1</version>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>compile</goal>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <scalaVersion>2.12.6</scalaVersion>
+                </configuration>
+            </plugin>
             <plugin>
                 <groupId>org.apache.maven.plugins</groupId>
                 <artifactId>maven-source-plugin</artifactId>
diff --git a/src/test/scala/com/aparapi/SimpleScalaTest.scala b/src/test/scala/com/aparapi/SimpleScalaTest.scala
new file mode 100644
index 00000000..945598ac
--- /dev/null
+++ b/src/test/scala/com/aparapi/SimpleScalaTest.scala
@@ -0,0 +1,86 @@
+package com.aparapi
+
+import com.aparapi.codegen.Diff
+import com.aparapi.internal.model.ClassModel
+import com.aparapi.internal.writer.KernelWriter
+import org.junit.Test
+
+import scala.util.Random
+
+class SimpleScalaTest {
+  def runKernel(inA: Array[Float], inB: Array[Float]): Array[Float] = {
+    val result = new Array[Float](inA.length)
+    val kernel = new Kernel() {
+      override def run() {
+        val i = getGlobalId()
+        result(i) = ((inA(i) + inB(i)) / (inA(i) / inB(i))) * ((inA(i) - inB(i)) / (inA(
+          i) * inB(i))) -
+          ((inB(i) - inA(i)) * (inB(i) + inA(i))) * ((inB(i) - inA(i)) / (inB(i) * inA(
+            i)))
+      }
+    }
+
+    kernel.execute(inA.length)
+    result
+  }
+
+  def generateCL(inA: Array[Float], inB: Array[Float]): String = {
+    val result = new Array[Float](inA.length)
+    val kernel = new Kernel() {
+      override def run() {
+        val i = getGlobalId()
+        result(i) = ((inA(i) + inB(i)) / (inA(i) / inB(i))) * ((inA(i) - inB(i)) / (inA(
+          i) * inB(i))) -
+          ((inB(i) - inA(i)) * (inB(i) + inA(i))) * ((inB(i) - inA(i)) / (inB(i) * inA(
+            i)))
+      }
+    }
+
+    val classModel = ClassModel.createClassModel(kernel.getClass)
+
+    val entryPoint = classModel.getEntrypoint("run", kernel)
+    KernelWriter.writeToString(entryPoint)
+  }
+
+  @Test def testKernel(): Unit = {
+    val a = Array.fill(50000)(Random.nextFloat())
+    val b = Array.fill(50000)(Random.nextFloat())
+    assert(runKernel(a, b).length == 50000)
+  }
+
+  @Test def testCL(): Unit = {
+    val a = Array.fill(50000)(Random.nextFloat())
+    val b = Array.fill(50000)(Random.nextFloat())
+    val expected =
+      """typedef struct This_s{
+                       |   __global float *result$2;
+                       |   __global float *inA$2;
+                       |   __global float *inB$2;
+                       |   int passid;
+                       |}This;
+                       |int get_pass_id(This *this){
+                       |   return this->passid;
+                       |}
+                       |__kernel void run(
+                       |   __global float *result$2,
+                       |   __global float *inA$2,
+                       |   __global float *inB$2,
+                       |   int passid
+                       |){
+                       |   This thisStruct;
+                       |   This* this=&thisStruct;
+                       |   this->result$2 = result$2;
+                       |   this->inA$2 = inA$2;
+                       |   this->inB$2 = inB$2;
+                       |   this->passid = passid;
+                       |   {
+                       |      {
+                       |         int i = get_global_id(0);
+                       |         this->result$2[i]  = (((this->inA$2[i] + this->inB$2[i]) / (this->inA$2[i] / this->inB$2[i])) * ((this->inA$2[i] - this->inB$2[i]) / (this->inA$2[i] * this->inB$2[i]))) - (((this->inB$2[i] - this->inA$2[i]) * (this->inB$2[i] + this->inA$2[i])) * ((this->inB$2[i] - this->inA$2[i]) / (this->inB$2[i] * this->inA$2[i])));
+                       |      }
+                       |      return;
+                       |   }
+                       |}""".stripMargin
+    assert(Diff.same(expected, generateCL(a, b)))
+  }
+}
-- 
GitLab