Skip to content

Commit 29a6f57

Browse files
authored
Add support for PyTorch 1.1 (#141)
1 parent cb13e9d commit 29a6f57

File tree

2 files changed

+32
-19
lines changed

2 files changed

+32
-19
lines changed

get_deps.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ tar xf ${LIBTF_ARCHIVE} --no-same-owner --strip-components=1 -C ${PREFIX}
5757

5858
## PYTORCH
5959

60-
PT_VERSION="1.0.1"
60+
PT_VERSION="1.1.0"
6161
#PT_VERSION="latest"
6262

6363
if [[ "$OSTYPE" == "linux-gnu" ]]; then
@@ -95,9 +95,9 @@ rm -rf libtorch
9595

9696
if [[ "${PT_OS}" == "macos" ]]; then
9797
# also download mkl
98-
MKL_BUNDLE=mklml_mac_2019.0.1.20180928
98+
MKL_BUNDLE=mklml_mac_2019.0.3.20190220
9999
if [ ! -e "${MKL_BUNDLE}.tgz" ]; then
100-
wget "https://github.com/intel/mkl-dnn/releases/download/v0.17.1/${MKL_BUNDLE}.tgz"
100+
wget "https://github.com/intel/mkl-dnn/releases/download/v0.18/${MKL_BUNDLE}.tgz"
101101
fi
102102
tar xf ${MKL_BUNDLE}.tgz --no-same-owner --strip-components=1 -C ${PREFIX}
103103
fi

util/libtorch_c/torch_c.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
#include "torch_c.h"
22
#include <torch/torch.h>
33
#include <torch/csrc/jit/import.h>
4+
#include <torch/csrc/jit/script/compilation_unit.h>
45
#include <iostream>
56
#include <sstream>
67

78
#include <ATen/Functions.h>
89

910
namespace {
1011

11-
static DLDataType getDLDataType(const at::Type& type) {
12+
static DLDataType getDLDataType(const at::Tensor& t) {
1213
DLDataType dtype;
1314
dtype.lanes = 1;
14-
dtype.bits = type.elementSizeInBytes() * 8;
15-
switch (type.scalarType()) {
15+
dtype.bits = t.element_size() * 8;
16+
switch (t.scalar_type()) {
1617
case at::ScalarType::Byte:
1718
dtype.code = DLDataTypeCode::kDLUInt;
1819
break;
@@ -37,6 +38,10 @@ static DLDataType getDLDataType(const at::Type& type) {
3738
case at::ScalarType::Half:
3839
dtype.code = DLDataTypeCode::kDLFloat;
3940
break;
41+
case at::ScalarType::Bool:
42+
throw std::logic_error("Bool is not supported by dlpack");
43+
case at::ScalarType::QInt8:
44+
throw std::logic_error("QInt8 is not supported by dlpack");
4045
case at::ScalarType::ComplexHalf:
4146
throw std::logic_error("ComplexHalf is not supported by dlpack");
4247
case at::ScalarType::ComplexFloat:
@@ -51,10 +56,10 @@ static DLDataType getDLDataType(const at::Type& type) {
5156
return dtype;
5257
}
5358

54-
static DLContext getDLContext(const at::Type& type, const int64_t& device_id) {
59+
static DLContext getDLContext(const at::Tensor& tensor, const int64_t& device_id) {
5560
DLContext ctx;
5661
ctx.device_id = device_id;
57-
if (type.is_cuda()) {
62+
if (tensor.is_cuda()) {
5863
ctx.device_type = DLDeviceType::kDLGPU;
5964
} else {
6065
ctx.device_type = DLDeviceType::kDLCPU;
@@ -134,8 +139,8 @@ torch::Tensor fromDLPack(const DLTensor* src) {
134139
at::DeviceType device_type = getATenDeviceType(src->ctx.device_type);
135140
at::ScalarType stype = toScalarType(src->dtype);
136141
return torch::from_blob(src->data,
137-
at::IntList(src->shape, src->ndim),
138-
at::IntList(src->strides, src->ndim),
142+
at::IntArrayRef(src->shape, src->ndim),
143+
at::IntArrayRef(src->strides, src->ndim),
139144
torch::device(device_type).dtype(stype));
140145
}
141146

@@ -158,9 +163,9 @@ DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {
158163
if (src.is_cuda()) {
159164
device_id = src.get_device();
160165
}
161-
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src.type(), device_id);
166+
atDLMTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id);
162167
atDLMTensor->tensor.dl_tensor.ndim = src.dim();
163-
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src.type());
168+
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
164169
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
165170
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
166171
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
@@ -169,6 +174,7 @@ DLManagedTensor* toManagedDLPack(const torch::Tensor& src) {
169174

170175
struct ModuleContext {
171176
std::shared_ptr<torch::jit::script::Module> module;
177+
std::shared_ptr<torch::jit::script::CompilationUnit> cu;
172178
DLDeviceType device;
173179
};
174180

@@ -191,8 +197,6 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
191197
throw std::runtime_error(std::string("Unsupported device ") + std::to_string(ctx->device));
192198
}
193199

194-
torch::jit::script::Method& method = ctx->module->get_method(fnName);
195-
196200
torch::jit::Stack stack;
197201

198202
for (int i=0; i<nInputs; i++) {
@@ -201,7 +205,14 @@ void torchRunModule(ModuleContext* ctx, const char* fnName,
201205
stack.push_back(tensor.to(device));
202206
}
203207

204-
method.run(stack);
208+
if (ctx->module) {
209+
torch::jit::script::Method& method = ctx->module->get_method(fnName);
210+
method.run(stack);
211+
}
212+
else {
213+
torch::jit::script::Function& fn = ctx->cu->get_function(fnName);
214+
fn.run(stack);
215+
}
205216

206217
torch::DeviceType output_device = torch::kCPU;
207218

@@ -254,8 +265,8 @@ extern "C" DLManagedTensor* torchNewTensor(DLDataType dtype, long ndims, int64_t
254265
at::DeviceType device_type = getATenDeviceType(kDLCPU);
255266
at::ScalarType stype = toScalarType(dtype);
256267
torch::Tensor tensor = torch::from_blob(data,
257-
at::IntList(shape, ndims),
258-
at::IntList(strides, ndims),
268+
at::IntArrayRef(shape, ndims),
269+
at::IntArrayRef(strides, ndims),
259270
torch::device(at::DeviceType::CPU).dtype(stype));
260271

261272
DLManagedTensor *dl_tensor = toManagedDLPack(tensor);
@@ -269,8 +280,9 @@ extern "C" void* torchCompileScript(const char* script, DLDeviceType device,
269280
ModuleContext* ctx = new ModuleContext();
270281
ctx->device = device;
271282
try {
272-
auto module = torch::jit::compile(script);
273-
ctx->module = module;
283+
auto cu = torch::jit::compile(script);
284+
ctx->cu = cu;
285+
ctx->module = nullptr;
274286
}
275287
catch(std::exception& e) {
276288
size_t len = strlen(e.what());
@@ -297,6 +309,7 @@ extern "C" void* torchLoadModel(const char* graph, size_t graphlen, DLDeviceType
297309
}
298310
module->to(aten_device);
299311
ctx->module = module;
312+
ctx->cu = nullptr;
300313
}
301314
catch(std::exception& e) {
302315
size_t len = strlen(e.what());

0 commit comments

Comments
 (0)