Skip to content

Commit f7f5385

Browse files
jerryzh168soumith
authored andcommitted
Quantized Tensor support copy (pytorch#28612)
Summary: Pull Request resolved: pytorch#28612 att Test Plan: python test/test_quantized_tensor.py Imported from OSS Differential Revision: D18255247 fbshipit-source-id: 814b12640fdf9d79b27482ee642ce430dbaeea68
1 parent 432724b commit f7f5385

File tree

9 files changed

+64
-4
lines changed

9 files changed

+64
-4
lines changed

aten/src/TH/THGenerateQInt32Type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define Real QInt32
88
#define RealUnderlying Int
99
#define THQUANTIZED
10+
#define THQINT32
1011
#define TH_REAL_IS_BYTE
1112
#line 1 TH_GENERIC_FILE
1213
#include TH_GENERIC_FILE
@@ -15,6 +16,7 @@
1516
#undef Real
1617
#undef RealUnderlying
1718
#undef TH_REAL_IS_BYTE
19+
#undef THQINT32
1820
#undef THQUANTIZED
1921

2022
#ifndef THGenerateManyTypes

aten/src/TH/THGenerateQInt8Type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define Real QInt8
88
#define RealUnderlying Char
99
#define THQUANTIZED
10+
#define THQINT8
1011
#define TH_REAL_IS_BYTE
1112
#line 1 TH_GENERIC_FILE
1213
#include TH_GENERIC_FILE
@@ -15,6 +16,7 @@
1516
#undef Real
1617
#undef RealUnderlying
1718
#undef TH_REAL_IS_BYTE
19+
#undef THQINT8
1820
#undef THQUANTIZED
1921

2022
#ifndef THGenerateManyTypes

aten/src/TH/THGenerateQUInt8Type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define Real QUInt8
88
#define RealUnderlying Byte
99
#define THQUANTIZED
10+
#define THQUINT8
1011
#define TH_REAL_IS_BYTE
1112
#line 1 TH_GENERIC_FILE
1213
#include TH_GENERIC_FILE
@@ -15,6 +16,7 @@
1516
#undef Real
1617
#undef RealUnderlying
1718
#undef TH_REAL_IS_BYTE
19+
#undef THQUINT8
1820
#undef THQUANTIZED
1921

2022
#ifndef THGenerateManyTypes

aten/src/TH/generic/THStorage.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
#define THLongStorage THStorage
3636
#define THBoolStorage THStorage
3737
#define THBFloat16Storage THStorage
38+
#define THQUInt8Storage THStorage
39+
#define THQInt8Storage THStorage
40+
#define THQInt32Storage THStorage
3841

3942
TH_API scalar_t* THStorage_(data)(const THStorage*);
4043
TH_API ptrdiff_t THStorage_(size)(const THStorage*);

aten/src/TH/generic/THStorageCopy.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,14 @@ IMPLEMENT_THStorage_COPY(Double)
3535
IMPLEMENT_THStorage_COPY(Half)
3636
IMPLEMENT_THStorage_COPY(Bool)
3737
IMPLEMENT_THStorage_COPY(BFloat16)
38+
#ifdef THQUINT8
39+
IMPLEMENT_THStorage_COPY(QUInt8)
40+
#endif
41+
#ifdef THQINT8
42+
IMPLEMENT_THStorage_COPY(QInt8)
43+
#endif
44+
#ifdef THQINT32
45+
IMPLEMENT_THStorage_COPY(QInt32)
46+
#endif
3847

3948
#endif

aten/src/TH/generic/THStorageCopy.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,14 @@ TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *s
1414
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
1515
TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
1616
TH_API void THStorage_(copyBFloat16)(THStorage *storage, struct THBFloat16Storage *src);
17+
#ifdef THQUINT8
18+
TH_API void THStorage_(copyQUInt8)(THStorage *storage, struct THQUInt8Storage *src);
19+
#endif
20+
#ifdef THQINT8
21+
TH_API void THStorage_(copyQInt8)(THStorage *storage, struct THQInt8Storage *src);
22+
#endif
23+
#ifdef THQINT32
24+
TH_API void THStorage_(copyQInt32)(THStorage *storage, struct THQInt32Storage *src);
25+
#endif
1726

1827
#endif

test/test_quantized_tensor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
import io
5+
from copy import deepcopy
56

67
from common_utils import TestCase, run_tests
78
import tempfile
@@ -252,6 +253,13 @@ def test_qtensor_copy(self):
252253
q.copy_(q2)
253254
# check scale and zero_points has been copied
254255
self.assertEqual(q, q2)
256+
# deep copy
257+
scale, zero_point, dtype = 1.0, 2, torch.uint8
258+
q_int = torch.randint(0, 100, [3, 5], dtype=dtype)
259+
scale, zero_point = 2.0, 3
260+
q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
261+
qc = deepcopy(q)
262+
self.assertEqual(qc, q)
255263

256264
def test_qtensor_clone(self):
257265
numel = 10
@@ -322,7 +330,6 @@ def test_qtensor_reshape(self):
322330
c = b.reshape(1, 4, 2, 3)
323331

324332
def test_qscheme_pickle(self):
325-
326333
f = Foo()
327334
buf = io.BytesIO()
328335
torch.save(f, buf)
@@ -332,5 +339,6 @@ def test_qscheme_pickle(self):
332339

333340
self.assertEqual(f2.qscheme, torch.per_tensor_symmetric)
334341

342+
335343
if __name__ == "__main__":
336344
run_tests()

torch/csrc/generic/Storage.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,15 @@ void THPStorage_(initCopyMethods)()
318318
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
319319
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
320320
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBFloat16StorageType, h, &THWStorage_(copyBFloat16));
321+
#ifdef THQUINT8
322+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPQUInt8StorageType, h, &THWStorage_(copyQUInt8));
323+
#endif
324+
#ifdef THQINT8
325+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPQInt8StorageType, h, &THWStorage_(copyQInt8));
326+
#endif
327+
#ifdef THQINT32
328+
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPQInt32StorageType, h, &THWStorage_(copyQInt32));
329+
#endif
321330
#ifdef THC_GENERIC_FILE
322331
// copy from GPU types
323332
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));

torch/tensor.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,26 @@ def __deepcopy__(self, memo):
3131
new_tensor = self.clone()
3232
else:
3333
new_storage = self.storage().__deepcopy__(memo)
34-
new_tensor = self.new()
35-
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
34+
if self.is_quantized:
35+
if self.qscheme() == torch.per_tensor_affine:
36+
quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
37+
elif self.qscheme() == torch.per_channel_affine:
38+
quantizer_params = self.qscheme(), self.q_per_channel_scales(), self.q_per_channel_zero_points(), self.q_per_channel_axis()
39+
else:
40+
raise RuntimeError("Unsupported qscheme {} in deepcopy".format(self.qscheme()))
41+
new_tensor = torch._utils._rebuild_qtensor(
42+
new_storage,
43+
self.storage_offset(),
44+
self.size(),
45+
self.stride(),
46+
quantizer_params,
47+
self.requires_grad,
48+
self._backward_hooks)
49+
else:
50+
new_tensor = self.new()
51+
new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
52+
new_tensor.requires_grad = self.requires_grad
3653
memo[id(self)] = new_tensor
37-
new_tensor.requires_grad = self.requires_grad
3854
return new_tensor
3955

4056
def __reduce_ex__(self, proto):

0 commit comments

Comments
 (0)