Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,8 @@ public <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15> TaskGr
}

/**
* Add a pre-built OpenCL task into a task-schedule.
* Add a pre-built kernel (OpenCL, PTX or SPIR-V) task into a task-graph. The access to the pre-built file is provided
* as a String with the name of the file (path).
*
* @param id
* Task-Id
Expand All @@ -630,6 +631,40 @@ public TaskGraph prebuiltTask(String id, String entryPoint, String filename, Obj
return this;
}

/**
* Add a pre-built kernel (OpenCL, PTX or SPIR-V) task into a task-graph. The kernel is stored in a JAR file accessible
* from the CLASSPATH. The access to the pre-built file is provided as a combination of the klass (Class) with the
* JAR File, and the name of the file that contains the prebuilt kernel.
*
* @param id
* Task-Id
* @param entryPoint
* Name of the method to be executed on the target device
* @param klass
* Klass that can access the resource within the JAR File
* @param resource
* Input file with the source kernel
* @param args
* Arguments to the kernel
* @param accesses
* Accesses ({@link uk.ac.manchester.tornado.api.common.Access} for
* each input parameter to the method
* @param device
* Device to be executed
* @param dimensions
* Select number of dimensions of the kernel (1D, 2D or 3D)
* @return {@link TaskGraph}
*/
@Override
public TaskGraph prebuiltTask(String id, String entryPoint, Class<?> klass, String resource, Object[] args, Access[] accesses, TornadoDevice device, int[] dimensions) {
System.out.println("[Warning] This API call is experimental and it may be removed in future versions");
checkTaskName(id);
PrebuiltTaskPackage prebuiltTask = TaskPackage.createPrebuiltTask(id, entryPoint, resource, args, accesses, device, dimensions);
prebuiltTask.withClass(klass);
taskGraphImpl.addPrebuiltTask(prebuiltTask);
return this;
}

/**
* Add a pre-built OpenCL task into a task-schedule with atomics region.
*
Expand Down Expand Up @@ -868,7 +903,7 @@ protected String getProfileLog() {
return taskGraphImpl.getProfileLog();
}

public Collection<?> getOutputs() {
Collection<?> getOutputs() {
return taskGraphImpl.getOutputs();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ <T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15> TaskGraphInte
*/
TaskGraphInterface prebuiltTask(String id, String entryPoint, String filename, Object[] args, Access[] accesses, TornadoDevice device, int[] dimensions, int[] atomics);

TaskGraph prebuiltTask(String id, String entryPoint, Class<?> klass, String resource, Object[] args, Access[] accesses, TornadoDevice device, int[] dimensions);

/**
* Obtains the task-schedule name that was assigned.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class PrebuiltTaskPackage extends TaskPackage {
private final TornadoDevice device;
private final int[] dimensions;
private int[] atomics;
private Class<?> klass;

public PrebuiltTaskPackage(String id, String entryPoint, String fileName, Object[] args, Access[] accesses, TornadoDevice device, int[] dimensions) {
super(id, null);
Expand All @@ -42,6 +43,11 @@ public PrebuiltTaskPackage withAtomics(int[] atomics) {
return this;
}

public PrebuiltTaskPackage withClass(Class<?> klass) {
this.klass = klass;
return this;
}

public String getEntryPoint() {
return entryPoint;
}
Expand Down Expand Up @@ -74,4 +80,8 @@ public boolean isPrebuiltTask() {
public int[] getAtomics() {
return atomics;
}

public Class<?> getKlassJar() {
return this.klass;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
* </p>
*
* <p>
* The constant {@link ARRAY_HEADER} represents the size of the header in bytes.
* The constant {@link TornadoNativeArray#ARRAY_HEADER} represents the size of the header in bytes.
* </p>
*/
public abstract sealed class TornadoNativeArray //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package uk.ac.manchester.tornado.drivers.opencl.runtime;

import java.io.IOException;
import java.io.InputStream;
import java.lang.foreign.MemorySegment;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -53,14 +54,8 @@
import uk.ac.manchester.tornado.api.memory.XPUBuffer;
import uk.ac.manchester.tornado.api.profiler.ProfilerType;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.CharArray;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.tensors.Tensor;
import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider;
import uk.ac.manchester.tornado.drivers.opencl.OCLBackendImpl;
import uk.ac.manchester.tornado.drivers.opencl.OCLCodeCache;
Expand Down Expand Up @@ -309,31 +304,46 @@ private TornadoInstalledCode compileTask(SchedulableTask task) {
}
}

private byte[] getSource(PrebuiltTask prebuiltTask) {
byte[] source;
Class<?> klass = prebuiltTask.getPackageClass();
if (klass != null) {
try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) {
TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename());
source = inputStream.readAllBytes();
} catch (IOException e) {
throw new TornadoBailoutRuntimeException("[Error] I/O Exception in readAllBytes", e);
}
} else {
final Path path = Paths.get(prebuiltTask.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename());
try {
source = Files.readAllBytes(path);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return source;
}

private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
final OCLDeviceContextInterface deviceContext = getDeviceContext();
final PrebuiltTask executable = (PrebuiltTask) task;
if (deviceContext.isCached(task.getId(), executable.getEntryPoint())) {
return deviceContext.getInstalledCode(task.getId(), executable.getEntryPoint());
final PrebuiltTask prebuiltTask = (PrebuiltTask) task;
if (deviceContext.isCached(task.getId(), prebuiltTask.getEntryPoint())) {
return deviceContext.getInstalledCode(task.getId(), prebuiltTask.getEntryPoint());
}

final Path path = Paths.get(executable.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename());
try {
final byte[] source = Files.readAllBytes(path);

OCLInstalledCode installedCode;
if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) {
// A) for FPGA
installedCode = deviceContext.installCode(task.getId(), executable.getEntryPoint(), source, task.meta().isPrintKernelEnabled());
} else {
// B) for CPU multi-core or GPU
installedCode = deviceContext.installCode(executable.meta(), task.getId(), executable.getEntryPoint(), source);
}
return installedCode;
} catch (IOException e) {
e.printStackTrace();
byte[] source = getSource(prebuiltTask);
OCLInstalledCode installedCode;
if (OCLBackend.isDeviceAnFPGAAccelerator(deviceContext)) {
// A) for FPGA
installedCode = deviceContext.installCode(task.getId(), prebuiltTask.getEntryPoint(), source, task.meta().isPrintKernelEnabled());
} else {
// B) for CPU multi-core or GPU
installedCode = deviceContext.installCode(prebuiltTask.meta(), task.getId(), prebuiltTask.getEntryPoint(), source);
}
return null;
return installedCode;

}

private TornadoInstalledCode compileJavaToAccelerator(SchedulableTask task) {
Expand Down Expand Up @@ -535,21 +545,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, OCLDeviceCont
result = new OCLVectorWrapper(deviceContext, object, batchSize);
} else if (object instanceof MemorySegment) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof IntArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof FloatArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof DoubleArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof LongArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof ShortArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof ByteArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof CharArray) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else if (object instanceof HalfFloatArray) {
} else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) {
result = new OCLMemorySegmentWrapper(deviceContext, batchSize);
} else {
result = new OCLXPUBuffer(deviceContext, object);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static uk.ac.manchester.tornado.drivers.ptx.graal.PTXCodeUtil.buildKernelName;

import java.io.IOException;
import java.io.InputStream;
import java.lang.foreign.MemorySegment;
import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -49,14 +50,8 @@
import uk.ac.manchester.tornado.api.memory.XPUBuffer;
import uk.ac.manchester.tornado.api.profiler.ProfilerType;
import uk.ac.manchester.tornado.api.profiler.TornadoProfiler;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.CharArray;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.tensors.Tensor;
import uk.ac.manchester.tornado.drivers.common.TornadoBufferProvider;
import uk.ac.manchester.tornado.drivers.ptx.PTX;
import uk.ac.manchester.tornado.drivers.ptx.PTXBackendImpl;
Expand Down Expand Up @@ -201,24 +196,40 @@ private TornadoInstalledCode compileTask(SchedulableTask task) {
}
}

private byte[] getSource(PrebuiltTask prebuiltTask) {
byte[] source;
Class<?> klass = prebuiltTask.getPackageClass();
if (klass != null) {
try (InputStream inputStream = klass.getClassLoader().getResourceAsStream(prebuiltTask.getFilename())) {
TornadoInternalError.guarantee(inputStream != null, "file does not exist: %s", prebuiltTask.getFilename());
source = inputStream.readAllBytes();
} catch (IOException e) {
throw new TornadoBailoutRuntimeException("[Error] I/O Exception in readAllBytes", e);
}
} else {
final Path path = Paths.get(prebuiltTask.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", prebuiltTask.getFilename());
try {
source = Files.readAllBytes(path);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return source;
}

private TornadoInstalledCode compilePreBuiltTask(SchedulableTask task) {
final PTXDeviceContext deviceContext = getDeviceContext();
final PrebuiltTask executable = (PrebuiltTask) task;
String functionName = buildKernelName(executable.getEntryPoint(), executable);
if (deviceContext.isCached(executable.getEntryPoint(), executable)) {
final PrebuiltTask prebuiltTask = (PrebuiltTask) task;
String functionName = buildKernelName(prebuiltTask.getEntryPoint(), prebuiltTask);

if (deviceContext.isCached(prebuiltTask.getEntryPoint(), prebuiltTask)) {
return deviceContext.getInstalledCode(functionName);
}

final Path path = Paths.get(executable.getFilename());
TornadoInternalError.guarantee(path.toFile().exists(), "file does not exist: %s", executable.getFilename());
try {
byte[] source = Files.readAllBytes(path);
source = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend());
return deviceContext.installCode(functionName, source, executable.getEntryPoint(), task.meta().isPrintKernelEnabled());
} catch (IOException e) {
e.printStackTrace();
}
return null;
byte[] source = getSource(prebuiltTask);
byte[] binary = PTXCodeUtil.getCodeWithAttachedPTXHeader(source, getBackend());
return deviceContext.installCode(functionName, binary, prebuiltTask.getEntryPoint(), task.meta().isPrintKernelEnabled());
}

@Override
Expand All @@ -229,8 +240,7 @@ public boolean isFullJITMode(SchedulableTask task) {
@Override
public TornadoInstalledCode getCodeFromCache(SchedulableTask task) {
String methodName;
if (task instanceof PrebuiltTask) {
PrebuiltTask prebuiltTask = (PrebuiltTask) task;
if (task instanceof PrebuiltTask prebuiltTask) {
methodName = prebuiltTask.getEntryPoint();
} else {
CompilableTask compilableTask = (CompilableTask) task;
Expand Down Expand Up @@ -260,21 +270,7 @@ private XPUBuffer createDeviceBuffer(Class<?> type, Object object, long batchSiz
result = new PTXVectorWrapper(getDeviceContext(), object, batchSize);
} else if (object instanceof MemorySegment) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof IntArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof FloatArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof DoubleArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof LongArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof ShortArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof ByteArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof CharArray) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else if (object instanceof HalfFloatArray) {
} else if (object instanceof TornadoNativeArray && !(object instanceof Tensor)) {
result = new PTXMemorySegmentWrapper(getDeviceContext(), batchSize);
} else {
result = new PTXObjectWrapper(getDeviceContext(), object);
Expand Down Expand Up @@ -527,8 +523,7 @@ public void sync(long executionPlanId) {

@Override
public boolean equals(Object obj) {
if (obj instanceof PTXTornadoDevice) {
final PTXTornadoDevice other = (PTXTornadoDevice) obj;
if (obj instanceof PTXTornadoDevice other) {
return (other.deviceIndex == deviceIndex);
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package uk.ac.manchester.tornado.drivers.spirv.runtime;

import java.lang.foreign.MemorySegment;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
Expand Down Expand Up @@ -140,14 +141,28 @@ public TornadoInstalledCode installCode(SchedulableTask task) {
}
}

private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask task) {
private Path getSource(PrebuiltTask prebuiltTask) {
Class<?> klass = prebuiltTask.getPackageClass();
if (klass != null) {
URL url = klass.getClassLoader().getResource(prebuiltTask.getFilename());
if (url != null) {
return Paths.get(url.getPath());
} else {
throw new TornadoRuntimeException("Prebuilt file path file not found: " + url.getPath());
}
} else {
return Paths.get(prebuiltTask.getFilename());
}
}

private TornadoInstalledCode compilePreBuiltTask(PrebuiltTask prebuiltTask) {
final SPIRVDeviceContext deviceContext = getDeviceContext();
if (deviceContext.isCached(task.getId(), task.getEntryPoint())) {
return deviceContext.getInstalledCode(task.getId(), task.getEntryPoint());
if (deviceContext.isCached(prebuiltTask.getId(), prebuiltTask.getEntryPoint())) {
return deviceContext.getInstalledCode(prebuiltTask.getId(), prebuiltTask.getEntryPoint());
}
final Path pathToSPIRVBin = Paths.get(task.getFilename());
TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", task.getFilename());
return deviceContext.installBinary(task.meta(), task.getId(), task.getEntryPoint(), task.getFilename());
final Path pathToSPIRVBin = getSource(prebuiltTask);
TornadoInternalError.guarantee(pathToSPIRVBin.toFile().exists(), "files does not exists %s", prebuiltTask.getFilename());
return deviceContext.installBinary(prebuiltTask.meta(), prebuiltTask.getId(), prebuiltTask.getEntryPoint(), prebuiltTask.getFilename());
}

public SPIRVBackend getBackend() {
Expand Down
Loading