@@ -1,18 +1,19 @@
 #include "torch_c.h"
 #include <torch/torch.h>
 #include <torch/csrc/jit/import.h>
+#include <torch/csrc/jit/script/compilation_unit.h>
 #include <iostream>
 #include <sstream>

 #include <ATen/Functions.h>

 namespace {

-static DLDataType getDLDataType(const at::Type& type) {
+static DLDataType getDLDataType(const at::Tensor& t) {
   DLDataType dtype;
   dtype.lanes = 1;
-  dtype.bits = type.elementSizeInBytes() * 8;
-  switch (type.scalarType()) {
+  dtype.bits = t.element_size() * 8;
+  switch (t.scalar_type()) {
     case at::ScalarType::Byte:
       dtype.code = DLDataTypeCode::kDLUInt;
       break;
@@ -37,6 +38,10 @@ static DLDataType getDLDataType(const at::Type& type) {
     case at::ScalarType::Half:
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
+    case at::ScalarType::Bool:
+      throw std::logic_error("Bool is not supported by dlpack");
+    case at::ScalarType::QInt8:
+      throw std::logic_error("QInt8 is not supported by dlpack");
     case at::ScalarType::ComplexHalf:
       throw std::logic_error("ComplexHalf is not supported by dlpack");
     case at::ScalarType::ComplexFloat:
@@ -51,10 +56,10 @@ static DLDataType getDLDataType(const at::Type& type) {
   return dtype;
 }

-static DLContext getDLContext(const at::Type& type, const int64_t& device_id) {
+static DLContext getDLContext(const at::Tensor& tensor, const int64_t& device_id) {
   DLContext ctx;
   ctx.device_id = device_id;
-  if (type.is_cuda()) {
+  if (tensor.is_cuda()) {
     ctx.device_type = DLDeviceType::kDLGPU;
   } else {
     ctx.device_type = DLDeviceType::kDLCPU;
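The two changes above track the removal of the `at::Type`-based accessors in newer libtorch: per-tensor metadata is now queried directly on `at::Tensor`. A minimal sketch of the replacement calls, assuming the libtorch version this patch targets:

```cpp
#include <torch/torch.h>

// Sketch only: the old at::Type methods (elementSizeInBytes(), scalarType(),
// is_cuda() on the type object) map to the equivalent Tensor accessors.
void accessor_example() {
  torch::Tensor t = torch::zeros({4}, torch::kFloat32);
  int64_t bits = t.element_size() * 8;   // was type.elementSizeInBytes() * 8
  at::ScalarType st = t.scalar_type();   // was type.scalarType()
  bool on_gpu = t.is_cuda();             // was type.is_cuda()
  (void)bits; (void)st; (void)on_gpu;
}
```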
@@ -134,8 +139,8 @@ torch::Tensor fromDLPack(const DLTensor* src) {
   at::DeviceType device_type = getATenDeviceType(src->ctx.device_type);
   at::ScalarType stype = toScalarType(src->dtype);
   return torch::from_blob(src->data,
-      at::IntList(src->shape, src->ndim),
-      at::IntList(src->strides, src->ndim),
+      at::IntArrayRef(src->shape, src->ndim),
+      at::IntArrayRef(src->strides, src->ndim),
       torch::device(device_type).dtype(stype));
 }

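`at::IntList` is the deprecated alias for `at::IntArrayRef` (a non-owning `ArrayRef<int64_t>` view), so the call sites are spelled with the new name. A small sketch of the (pointer, length) construction used here:

```cpp
#include <torch/torch.h>

// Sketch only: IntArrayRef does not own the underlying buffer, so `dims`
// must outlive any use of the view built from it.
void int_array_ref_example() {
  int64_t dims[2] = {3, 4};
  at::IntArrayRef sizes(dims, 2);         // (pointer, length) constructor
  torch::Tensor t = torch::empty(sizes);  // accepted wherever sizes/strides go
  (void)t;
}
```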
@@ -158,9 +163,9 @@ DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {
   if (src.is_cuda()) {
     device_id = src.get_device();
   }
-  atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
+  atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id);
   atDLMTensor->tensor.dl_tensor.ndim = src.dim();
-  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src.type());
+  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
   atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
   atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
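For context, a minimal sketch of the export/import round trip through the helpers touched here. `toManagedDLPack` and `fromDLPack` are this file's own functions; the deleter call follows the usual DLPack ownership contract from dlpack.h and is an assumption about how the exported tensor is released:

```cpp
// Sketch only: the consumer of a DLManagedTensor calls its deleter when done.
void dlpack_round_trip_example() {
  torch::Tensor src = torch::rand({2, 3});
  DLManagedTensor* managed = toManagedDLPack(src);        // export (shares storage)
  torch::Tensor back = fromDLPack(&managed->dl_tensor);   // zero-copy import
  // ... use `back` while `managed` is still alive ...
  if (managed->deleter) {
    managed->deleter(managed);                            // release exported reference
  }
}
```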
@@ -169,6 +174,7 @@ DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {

 struct ModuleContext {
   std::shared_ptr<torch::jit::script::Module> module;
+  std::shared_ptr<torch::jit::script::CompilationUnit> cu;
   DLDeviceType device;
 };

@@ -191,8 +197,6 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
     throw std::runtime_error(std::string("Unsupported device ") + std::to_string(ctx->device));
   }

-  torch::jit::script::Method& method = ctx->module->get_method(fnName);
-
   torch::jit::Stack stack;

   for (int i=0; i<nInputs; i++) {
@@ -201,7 +205,14 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
     stack.push_back(tensor.to(device));
   }

-  method.run(stack);
+  if (ctx->module) {
+    torch::jit::script::Method& method = ctx->module->get_method(fnName);
+    method.run(stack);
+  }
+  else {
+    torch::jit::script::Function& fn = ctx->cu->get_function(fnName);
+    fn.run(stack);
+  }

   torch::DeviceType output_device = torch::kCPU;

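The branch is needed because, in the libtorch version targeted here, `torch::jit::compile()` returns a shared `CompilationUnit` rather than a `script::Module`, so free-standing script functions are looked up with `get_function()` while serialized models still go through `get_method()`. A minimal sketch of the CompilationUnit path, with the script text and function name purely illustrative:

```cpp
#include <torch/torch.h>
#include <torch/csrc/jit/script/compilation_unit.h>

// Sketch only: compile a tiny TorchScript snippet and run one of its
// functions through the same Stack-based calling convention used above.
void compilation_unit_example() {
  auto cu = torch::jit::compile("def add_one(x):\n    return x + 1\n");
  torch::jit::Stack stack;
  stack.push_back(torch::ones({2, 2}));
  torch::jit::script::Function& fn = cu->get_function("add_one");
  fn.run(stack);                                   // outputs are left on the stack
  torch::Tensor out = stack.back().toTensor();
  (void)out;
}
```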
@@ -254,8 +265,8 @@ extern "C" DLManagedTensor* torchNewTensor(DLDataType dtype, long ndims, int64_t
   at::DeviceType device_type = getATenDeviceType(kDLCPU);
   at::ScalarType stype = toScalarType(dtype);
   torch::Tensor tensor = torch::from_blob(data,
-      at::IntList(shape, ndims),
-      at::IntList(strides, ndims),
+      at::IntArrayRef(shape, ndims),
+      at::IntArrayRef(strides, ndims),
       torch::device(at::DeviceType::CPU).dtype(stype));

   DLManagedTensor *dl_tensor = toManagedDLPack(tensor);
@@ -269,8 +280,9 @@ extern "C" void* torchCompileScript(const char* script, DLDeviceType device,
   ModuleContext* ctx = new ModuleContext();
   ctx->device = device;
   try {
-    auto module = torch::jit::compile(script);
-    ctx->module = module;
+    auto cu = torch::jit::compile(script);
+    ctx->cu = cu;
+    ctx->module = nullptr;
   }
   catch (std::exception& e) {
     size_t len = strlen(e.what());
@@ -297,6 +309,7 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType
     }
     module->to(aten_device);
     ctx->module = module;
+    ctx->cu = nullptr;
   }
   catch (std::exception& e) {
     size_t len = strlen(e.what());