diff --git a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java
index 8be18a3b68..044612f3bc 100644
--- a/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java
+++ b/tornado-api/src/main/java/uk/ac/manchester/tornado/api/TaskGraph.java
@@ -603,7 +603,8 @@ public
- * The constant {@link ARRAY_HEADER} represents the size of the header in bytes. + * The constant {@link TornadoNativeArray#ARRAY_HEADER} represents the size of the header in bytes. *
*/ public abstract sealed class TornadoNativeArray // diff --git a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java index d09912004f..cce77723c6 100644 --- a/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java +++ b/tornado-drivers/opencl/src/main/java/uk/ac/manchester/tornado/drivers/opencl/runtime/OCLTornadoDevice.java @@ -24,6 +24,7 @@ package uk.ac.manchester.tornado.drivers.opencl.runtime; import java.io.IOException; +import java.io.InputStream; import java.lang.foreign.MemorySegment; import java.nio.file.Files; import java.nio.file.Path; @@ -53,14 +54,8 @@ import uk.ac.manchester.tornado.api.memory.XPUBuffer; import uk.ac.manchester.tornado.api.profiler.ProfilerType; import uk.ac.manchester.tornado.api.profiler.TornadoProfiler; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.CharArray; -import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; -import uk.ac.manchester.tornado.api.types.arrays.IntArray; -import uk.ac.manchester.tornado.api.types.arrays.LongArray; -import uk.ac.manchester.tornado.api.types.arrays.ShortArray; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; +import uk.ac.manchester.tornado.api.types.tensors.Tensor; import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider; import uk.ac.manchester.tornado.drivers.opencl.OCLBackendImpl; import uk.ac.manchester.tornado.drivers.opencl.OCLCodeCache; @@ -309,31 +304,46 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } } + private byte[] getSource(PrebuiltTask prebuiltTask) { + byte[] source; + Class> klass = prebuiltTask.getPackageClass(); + if (klass != null) { + try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) { + TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename()); + source = inputStream.readAllBytes(); + } catch (IOException e) { + throw new TornadoBailoutRuntimeException("[Error] I/O Exception in readAllBytes", e); + } + } else { + final Path path = Paths.get(prebuiltTask.getFilename()); + TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename()); + try { + source = Files.readAllBytes(path); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return source; + } + private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { final OCLDeviceContextInterface deviceContext = getDeviceContext(); - final PrebuiltTask executable = (PrebuiltTask) task; - if (deviceContext.isCached(task.getId(), executable.getEntryPoint())) { - return deviceContext.getInstalledCode(task.getId(), executable.getEntryPoint()); + final PrebuiltTask prebuiltTask = (PrebuiltTask) task; + if (deviceContext.isCached(task.getId(), prebuiltTask.getEntryPoint())) { + return deviceContext.getInstalledCode(task.getId(), prebuiltTask.getEntryPoint()); } - final Path path = Paths.get(executable.getFilename()); - TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename()); - try { - final byte[] source = Files.readAllBytes(path); - - OCLInstalledCode installedCode; - if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) { - // A) for FPGA - installedCode = deviceContext.installCode(task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled()); - } else { - // B) for CPU multi-core or GPU - installedCode = deviceContext.installCode(executable.meta(), task.getId(), executable.getEntryPoint(), source); - } - return installedCode; - } catch (IOException e) { - e.printStackTrace(); + byte[] source = getSource(prebuiltTask); + OCLInstalledCode installedCode; + if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) { + // A) for FPGA + installedCode = deviceContext.installCode(task.getId(), prebuiltTask.getEntryPoint(), source, task.meta().isPrintKernelEnabled()); + } else { + // B) for CPU multi-core or GPU + installedCode = deviceContext.installCode(prebuiltTask.meta(), task.getId(), prebuiltTask.getEntryPoint(), source); } - return null; + return installedCode; + } private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) { @@ -535,21 +545,7 @@ private XPUBuffer createDeviceBuffer(Class> type, Object object, OCLDeviceCont result = new OCLVectorWrapper(deviceContext, object, batchSize); } else if (object instanceof MemorySegment) { result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof IntArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof FloatArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof DoubleArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof LongArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof ShortArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof ByteArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof CharArray) { - result = new OCLMemorySegmentWrapper(deviceContext, batchSize); - } else if (object instanceof HalfFloatArray) { + } else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) { result = new OCLMemorySegmentWrapper(deviceContext, batchSize); } else { result = new OCLXPUBuffer(deviceContext, object); diff --git a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java index 2632bf16cf..1d2f2e57f7 100644 --- a/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java +++ b/tornado-drivers/ptx/src/main/java/uk/ac/manchester/tornado/drivers/ptx/runtime/PTXTornadoDevice.java @@ -26,6 +26,7 @@ import static uk.ac.manchester.tornado.drivers.ptx.graal.PTXCodeUtil.buildKernelName; import java.io.IOException; +import java.io.InputStream; import java.lang.foreign.MemorySegment; import java.nio.file.Files; import java.nio.file.Path; @@ -49,14 +50,8 @@ import uk.ac.manchester.tornado.api.memory.XPUBuffer; import uk.ac.manchester.tornado.api.profiler.ProfilerType; import uk.ac.manchester.tornado.api.profiler.TornadoProfiler; -import uk.ac.manchester.tornado.api.types.arrays.ByteArray; -import uk.ac.manchester.tornado.api.types.arrays.CharArray; -import uk.ac.manchester.tornado.api.types.arrays.DoubleArray; -import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; -import uk.ac.manchester.tornado.api.types.arrays.IntArray; -import uk.ac.manchester.tornado.api.types.arrays.LongArray; -import uk.ac.manchester.tornado.api.types.arrays.ShortArray; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; +import uk.ac.manchester.tornado.api.types.tensors.Tensor; import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider; import uk.ac.manchester.tornado.drivers.ptx.PTX; import uk.ac.manchester.tornado.drivers.ptx.PTXBackendImpl; @@ -201,24 +196,40 @@ private TornadoInstalledCode compileTask(SchedulableTask task) { } } + private byte[] getSource(PrebuiltTask prebuiltTask) { + byte[] source; + Class> klass = prebuiltTask.getPackageClass(); + if (klass != null) { + try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) { + TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename()); + source = inputStream.readAllBytes(); + } catch (IOException e) { + throw new TornadoBailoutRuntimeException("[Error] I/O Exception in readAllBytes", e); + } + } else { + final Path path = Paths.get(prebuiltTask.getFilename()); + TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename()); + try { + source = Files.readAllBytes(path); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return source; + } + private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) { final PTXDeviceContext deviceContext = getDeviceContext(); - final PrebuiltTask executable = (PrebuiltTask) task; - String functionName = buildKernelName(executable.getEntryPoint(), executable); - if (deviceContext.isCached(executable.getEntryPoint(), executable)) { + final PrebuiltTask prebuiltTask = (PrebuiltTask) task; + String functionName = buildKernelName(prebuiltTask.getEntryPoint(), prebuiltTask); + + if (deviceContext.isCached(prebuiltTask.getEntryPoint(), prebuiltTask)) { return deviceContext.getInstalledCode(functionName); } - final Path path = Paths.get(executable.getFilename()); - TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename()); - try { - byte[] source = Files.readAllBytes(path); - source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend()); - return deviceContext.installCode(functionName, source, executable.getEntryPoint(), task.meta().isPrintKernelEnabled()); - } catch (IOException e) { - e.printStackTrace(); - } - return null; + byte[] source = getSource(prebuiltTask); + byte[] binary = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend()); + return deviceContext.installCode(functionName, binary, prebuiltTask.getEntryPoint(), task.meta().isPrintKernelEnabled()); } @Override @@ -229,8 +240,7 @@ public boolean isFullJITMode(SchedulableTask task) { @Override public TornadoInstalledCode getCodeFromCache(SchedulableTask task) { String methodName; - if (task instanceof PrebuiltTask) { - PrebuiltTask prebuiltTask = (PrebuiltTask) task; + if (task instanceof PrebuiltTask prebuiltTask) { methodName = prebuiltTask.getEntryPoint(); } else { CompilableTask compilableTask = (CompilableTask) task; @@ -260,21 +270,7 @@ private XPUBuffer createDeviceBuffer(Class> type, Object object, long batchSiz result = new PTXVectorWrapper(getDeviceContext(), object, batchSize); } else if (object instanceof MemorySegment) { result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof IntArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof FloatArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof DoubleArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof LongArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof ShortArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof ByteArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof CharArray) { - result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); - } else if (object instanceof HalfFloatArray) { + } else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) { result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize); } else { result = new PTXObjectWrapper(getDeviceContext(), object); @@ -527,8 +523,7 @@ public void sync(long executionPlanId) { @Override public boolean equals(Object obj) { - if (obj instanceof PTXTornadoDevice) { - final PTXTornadoDevice other = (PTXTornadoDevice) obj; + if (obj instanceof PTXTornadoDevice other) { return (other.deviceIndex == deviceIndex); } return false; diff --git a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java index 2f6c9f2cb2..1562fbda15 100644 --- a/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java +++ b/tornado-drivers/spirv/src/main/java/uk/ac/manchester/tornado/drivers/spirv/runtime/SPIRVTornadoDevice.java @@ -24,6 +24,7 @@ package uk.ac.manchester.tornado.drivers.spirv.runtime; import java.lang.foreign.MemorySegment; +import java.net.URL; import java.nio.file.Path; import java.nio.file.Paths; import java.util.List; @@ -140,14 +141,28 @@ public TornadoInstalledCode installCode(SchedulableTask task) { } } - private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask task) { + private Path getSource(PrebuiltTask prebuiltTask) { + Class> klass = prebuiltTask.getPackageClass(); + if (klass != null) { + URL url = klass.getClassLoader().getResource(prebuiltTask.getFilename()); + if (url != null) { + return Paths.get(url.getPath()); + } else { + throw new TornadoRuntimeException("Prebuilt file path file not found: " + url.getPath()); + } + } else { + return Paths.get(prebuiltTask.getFilename()); + } + } + + private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask prebuiltTask) { final SPIRVDeviceContext deviceContext = getDeviceContext(); - if (deviceContext.isCached(task.getId(), task.getEntryPoint())) { - return deviceContext.getInstalledCode(task.getId(), task.getEntryPoint()); + if (deviceContext.isCached(prebuiltTask.getId(), prebuiltTask.getEntryPoint())) { + return deviceContext.getInstalledCode(prebuiltTask.getId(), prebuiltTask.getEntryPoint()); } - final Path pathToSPIRVBin = Paths.get(task.getFilename()); - TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", task.getFilename()); - return deviceContext.installBinary(task.meta(), task.getId(), task.getEntryPoint(), task.getFilename()); + final Path pathToSPIRVBin = getSource(prebuiltTask); + TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", prebuiltTask.getFilename()); + return deviceContext.installBinary(prebuiltTask.meta(), prebuiltTask.getId(), prebuiltTask.getEntryPoint(), prebuiltTask.getFilename()); } public SPIRVBackend getBackend() { diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java index 9368c7fc7c..647fa92b35 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/analyzer/TaskUtils.java @@ -321,9 +321,14 @@ public static PrebuiltTask createTask(ScheduleMetaData meta, PrebuiltTaskPackage taskPackage.getAccesses(), // taskPackage.getDevice(), // domain); + // Attach atomics if (taskPackage.getAtomics() != null) { prebuiltTask.withAtomics(taskPackage.getAtomics()); } + // Attach class if in a JAR file + if (taskPackage.getKlassJar() != null) { + prebuiltTask.withKlass(taskPackage.getKlassJar()); + } return prebuiltTask; } diff --git a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java index e254dc0105..e590f7c2bc 100644 --- a/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java +++ b/tornado-runtime/src/main/java/uk/ac/manchester/tornado/runtime/tasks/PrebuiltTask.java @@ -49,6 +49,7 @@ public class PrebuiltTask implements SchedulableTask { private TornadoProfiler profiler; private boolean forceCompiler; private int[] atomics; + private Class> klass; public PrebuiltTask(ScheduleMetaData scheduleMeta, String id, String entryPoint, String filename, Object[] args, Access[] access, TornadoDevice device, DomainTree domain) { this.entryPoint = entryPoint; @@ -75,6 +76,11 @@ public PrebuiltTask withAtomics(int[] atomics) { return this; } + public PrebuiltTask withKlass(Class> klass) { + this.klass = klass; + return this; + } + @Override public String toString() { final StringBuilder sb = new StringBuilder(); @@ -234,4 +240,7 @@ public int[] getAtomics() { return atomics; } + public Class> getPackageClass() { + return this.klass; + } } diff --git a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java index c9e52a7077..5194809485 100644 --- a/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java +++ b/tornado-unittests/src/main/java/uk/ac/manchester/tornado/unittests/prebuilt/PrebuiltTest.java @@ -88,16 +88,19 @@ public void testPrebuilt01() throws TornadoExecutionPlanException { throw new TornadoRuntimeException("Backend not supported"); } + // @formatter:off TaskGraph taskGraph = new TaskGraph("s0") // .transferToDevice(DataTransferMode.FIRST_EXECUTION, a, b) // - .prebuiltTask("t0", // - "add", // - FILE_PATH, // - new Object[] { a, b, c }, // - new Access[] { Access.READ_ONLY, Access.READ_ONLY, Access.WRITE_ONLY }, // - defaultDevice, // - new int[] { numElements })// + .prebuiltTask("t0", // task-name + "add", // method name (entry point in the prebuilt source, e.g., the OpenCL kernel name). + FILE_PATH, // Path to the file + new Object[] { a, b, c }, // Parameters to the function + new Access[] { Access.READ_ONLY, Access.READ_ONLY, Access.WRITE_ONLY }, // Data access to the function + defaultDevice, // Device in which the application will be executed. This is at the task graph + // level because TornadoVM performs code specialization per device. + new int[] { numElements }) // Number of threads to deploy .transferToHost(DataTransferMode.EVERY_EXECUTION, c); + // @formatter:on ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot(); try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {