From f7ea780f7313b9b754eaa7b0ccb8adc4a2250fd1 Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 18:37:17 +0900 Subject: [PATCH 01/10] Added converters for .float(), .bool(), .int() casting operations --- torch2trt/converters/cast.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 torch2trt/converters/cast.py diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py new file mode 100644 index 00000000..490c190c --- /dev/null +++ b/torch2trt/converters/cast.py @@ -0,0 +1,33 @@ +from torch2trt.torch2trt import * +from torch2trt.module_test import add_module_test + + +def convert_cast(ctx): + """ + A simple converter for supporting casting operations. + + IMPORTANT: Note that because TensorRT does not support + 64 bit data types, .long() will not be supported + """ + input_tensor = ctx.method_args[0] + layer = ctx.network.add_identity(input_tensor._trt) + output = ctx.method_return + output._trt = layer.get_output(0) + + +@tensorrt_converter("torch.float") +@tensorrt_converter("torch.Tensor.float") +def convert_float(ctx): + convert_cast(ctx) + + +@tensorrt_converter("torch.bool") +@tensorrt_converter("torch.Tensor.bool") +def convert_bool(ctx): + convert_cast(ctx) + + +@tensorrt_converter("torch.float") +@tensorrt_converter("torch.Tensor.float") +def convert_bool(ctx): + convert_cast(ctx) From 3bb64863bae0a42a45a307971455a9a1a892cef6 Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 18:38:58 +0900 Subject: [PATCH 02/10] Added tests for each casting case --- torch2trt/converters/cast.py | 44 +++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 490c190c..381e9e06 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -7,7 +7,7 @@ def convert_cast(ctx): A simple converter for supporting casting operations. IMPORTANT: Note that because TensorRT does not support - 64 bit data types, .long() will not be supported + 64 bit data types, .long() is not included. """ input_tensor = ctx.method_args[0] layer = ctx.network.add_identity(input_tensor._trt) @@ -31,3 +31,45 @@ def convert_bool(ctx): @tensorrt_converter("torch.Tensor.float") def convert_bool(ctx): convert_cast(ctx) + + +class ConvertToFloat(torch.nn.Module): + def __init__(self): + super(ConvertToFloat, self).__init__() + + def forward(self, x): + return x.float() + + +class ConvertToInt(torch.nn.Module): + def __init__(self): + super(ConvertToInt, self).__init__() + + def forward(self, x): + return x.int() + + +class ConvertToBool(torch.nn.Module): + def __init__(self): + super(ConvertToBool, self).__init__() + + def forward(self, x): + return x.bool() + + +@add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) +def test_float_casting(): + return ConvertToFloat() + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) +def test_int_casting(): + return ConvertToInt() + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) +def test_bool_casting(): + return ConvertToBool() From 65a193bf91e62a4890642eabf382f8196f3c0ec1 Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 18:58:10 +0900 Subject: [PATCH 03/10] Updated tests --- torch2trt/converters/cast.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 381e9e06..925ccb8a 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -7,7 +7,7 @@ def convert_cast(ctx): A simple converter for supporting casting operations. IMPORTANT: Note that because TensorRT does not support - 64 bit data types, .long() is not included. + 64 bit data types, .long() will not be supported """ input_tensor = ctx.method_args[0] layer = ctx.network.add_identity(input_tensor._trt) @@ -33,43 +33,43 @@ def convert_bool(ctx): convert_cast(ctx) -class ConvertToFloat(torch.nn.Module): +class TorchFloat(torch.nn.Module): def __init__(self): - super(ConvertToFloat, self).__init__() + super(TorchFloat, self).__init__() def forward(self, x): - return x.float() + return torch.float(x) -class ConvertToInt(torch.nn.Module): +class TorchInt(torch.nn.Module): def __init__(self): - super(ConvertToInt, self).__init__() + super(TorchInt, self).__init__() def forward(self, x): - return x.int() + return torch.int(x) -class ConvertToBool(torch.nn.Module): +class TorchBool(torch.nn.Module): def __init__(self): - super(ConvertToBool, self).__init__() + super(TorchBool, self).__init__() def forward(self, x): - return x.bool() + return torch.bool(x) @add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_float_casting(): - return ConvertToFloat() + return TorchFloat() @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) -@add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_int_casting(): - return ConvertToInt() + return TorchInt() @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_bool_casting(): - return ConvertToBool() + return TorchBool() From 4e12e927298d0057849cc6aad07b3dad21548382 Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 19:08:23 +0900 Subject: [PATCH 04/10] Added import to __init__.py and added test cases for .cast operations --- torch2trt/converters/__init__.py | 1 + torch2trt/converters/cast.py | 53 ++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py index 83465152..ced2eb7c 100644 --- a/torch2trt/converters/__init__.py +++ b/torch2trt/converters/__init__.py @@ -67,3 +67,4 @@ from .transpose import * from .unary import * from .view import * +from .cast import * diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 925ccb8a..4b8eb1cd 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -32,6 +32,8 @@ def convert_bool(ctx): def convert_bool(ctx): convert_cast(ctx) +# Used for torch.Tensor. tests +# -------------------------------------------- class TorchFloat(torch.nn.Module): def __init__(self): @@ -59,17 +61,62 @@ def forward(self, x): @add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_float_casting(): +def test_torch_float_cast(): return TorchFloat() @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_int_casting(): +def test_torch_int_cast(): return TorchInt() @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_bool_casting(): +def test_torch_bool_casting(): return TorchBool() + + +# Used for torch. tests +# -------------------------------------------- + +class DotFloat(torch.nn.Module): + def __init__(self): + super(DotFloat, self).__init__() + + def forward(self, x): + return x.float() + + +class DotInt(torch.nn.Module): + def __init__(self): + super(DotInt, self).__init__() + + def forward(self, x): + return x.int() + + +class DotBool(torch.nn.Module): + def __init__(self): + super(DotBool, self).__init__() + + def forward(self, x): + return x.bool() + + +@add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) +def test_float_cast(): + return DotFloat() + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) +def test_int_cast(): + return DotInt() + + +@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) +@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) +def test_bool_cast(): + return DotBool() From 23e1b89c81be8cd6bf96de25d66027164dc38a2f Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 19:13:52 +0900 Subject: [PATCH 05/10] Removed redundant tests and converters. --- torch2trt/converters/cast.py | 49 +----------------------------------- 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 4b8eb1cd..7ca95d49 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -15,23 +15,21 @@ def convert_cast(ctx): output._trt = layer.get_output(0) -@tensorrt_converter("torch.float") @tensorrt_converter("torch.Tensor.float") def convert_float(ctx): convert_cast(ctx) -@tensorrt_converter("torch.bool") @tensorrt_converter("torch.Tensor.bool") def convert_bool(ctx): convert_cast(ctx) -@tensorrt_converter("torch.float") @tensorrt_converter("torch.Tensor.float") def convert_bool(ctx): convert_cast(ctx) + # Used for torch.Tensor. tests # -------------------------------------------- @@ -75,48 +73,3 @@ def test_torch_int_cast(): @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_torch_bool_casting(): return TorchBool() - - -# Used for torch. tests -# -------------------------------------------- - -class DotFloat(torch.nn.Module): - def __init__(self): - super(DotFloat, self).__init__() - - def forward(self, x): - return x.float() - - -class DotInt(torch.nn.Module): - def __init__(self): - super(DotInt, self).__init__() - - def forward(self, x): - return x.int() - - -class DotBool(torch.nn.Module): - def __init__(self): - super(DotBool, self).__init__() - - def forward(self, x): - return x.bool() - - -@add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) -@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_float_cast(): - return DotFloat() - - -@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) -@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_int_cast(): - return DotInt() - - -@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) -@add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_bool_cast(): - return DotBool() From 59f308d9313209d66f79476811a31ba962debd9e Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 19:15:32 +0900 Subject: [PATCH 06/10] Added back the proper tests and class operations for testing. --- torch2trt/converters/cast.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 7ca95d49..c8b09759 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -33,43 +33,43 @@ def convert_bool(ctx): # Used for torch.Tensor. tests # -------------------------------------------- -class TorchFloat(torch.nn.Module): +class DotFloat(torch.nn.Module): def __init__(self): - super(TorchFloat, self).__init__() + super(DotFloat, self).__init__() def forward(self, x): - return torch.float(x) + return x.float() -class TorchInt(torch.nn.Module): +class DotInt(torch.nn.Module): def __init__(self): - super(TorchInt, self).__init__() + super(DotInt, self).__init__() def forward(self, x): - return torch.int(x) + return x.int() -class TorchBool(torch.nn.Module): +class DotBool(torch.nn.Module): def __init__(self): - super(TorchBool, self).__init__() + super(DotBool, self).__init__() def forward(self, x): - return torch.bool(x) + return x.bool() @add_module_test(torch.bool, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_torch_float_cast(): - return TorchFloat() + return DotFloat() @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_torch_int_cast(): - return TorchInt() + return DotInt() @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) def test_torch_bool_casting(): - return TorchBool() + return DotBool() From 43c77ae9b4a0a7e1813a3cabd6fbef619feaac86 Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 19:22:13 +0900 Subject: [PATCH 07/10] Updated methods and converter function names, as wel as tensorrt_converter arguments --- torch2trt/converters/cast.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index c8b09759..782e5535 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -20,12 +20,12 @@ def convert_float(ctx): convert_cast(ctx) -@tensorrt_converter("torch.Tensor.bool") -def convert_bool(ctx): +@tensorrt_converter("torch.Tensor.int") +def convert_int(ctx): convert_cast(ctx) -@tensorrt_converter("torch.Tensor.float") +@tensorrt_converter("torch.Tensor.bool") def convert_bool(ctx): convert_cast(ctx) From 24b344c8788155fc7d35abae198b3485c8313438 Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 19:29:17 +0900 Subject: [PATCH 08/10] renamed test_torch_bool_casting to test_torch_bool_cast for consistency --- torch2trt/converters/cast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 782e5535..e3ec7a26 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -71,5 +71,5 @@ def test_torch_int_cast(): @add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 3)]) @add_module_test(torch.int32, torch.device('cuda'), [(1, 3, 3)]) -def test_torch_bool_casting(): +def test_torch_bool_cast(): return DotBool() From d087035ba7a2b6a804f2b1311d1171cd84150f7a Mon Sep 17 00:00:00 2001 From: jaylee Date: Wed, 11 Aug 2021 19:32:57 +0900 Subject: [PATCH 09/10] Removed redundant comments --- torch2trt/converters/cast.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index e3ec7a26..98df220a 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -30,9 +30,6 @@ def convert_bool(ctx): convert_cast(ctx) -# Used for torch.Tensor. tests -# -------------------------------------------- - class DotFloat(torch.nn.Module): def __init__(self): super(DotFloat, self).__init__() From d1589448e217ba101e0112daf80eb9287ea9d822 Mon Sep 17 00:00:00 2001 From: jaylee Date: Sat, 14 Aug 2021 16:33:05 +0900 Subject: [PATCH 10/10] Update layer precision based on torch2trt preicion modes --- torch2trt/converters/cast.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/torch2trt/converters/cast.py b/torch2trt/converters/cast.py index 98df220a..200ad955 100644 --- a/torch2trt/converters/cast.py +++ b/torch2trt/converters/cast.py @@ -2,6 +2,22 @@ from torch2trt.module_test import add_module_test +def _key_sanity_check(mode_key, torch2trt_properties): + """ + Raise an error if the given key does not exist. + This error will be raised as a warning in case + in case "mode-related" keys change in the future. + Args: + mode_key: A string key for the quantization mode. + E.g. ("int8_mode", "fp16_mode") + torch2trt_properties: A python dictionary containing + the torch2trt properties such as "int8_mode". + """ + if mode_key not in torch2trt_properties: + raise KeyError("{} is not a valid torch2trt property. " + "Check the torch2trt API for any changes.".format(mode_key)) + + def convert_cast(ctx): """ A simple converter for supporting casting operations. @@ -11,6 +27,22 @@ def convert_cast(ctx): """ input_tensor = ctx.method_args[0] layer = ctx.network.add_identity(input_tensor._trt) + trt_kwargs = ctx.torch2trt_kwargs + + # Sanity checks for debugging in case torch2trt property keys change + int8_mode_key, fp16_mode_key = "int8_mode", "fp16_mode" + _key_sanity_check(int8_mode_key, trt_kwargs) + _key_sanity_check(fp16_mode_key, trt_kwargs) + + is_int8_mode = trt_kwargs[int8_mode_key] + is_fp16_mode = trt_kwargs[fp16_mode_key] + if is_int8_mode: + layer.precision = trt.int8 + layer.set_output_type(0, trt.int8) + elif is_fp16_mode: + layer.precision = trt.float16 + layer.set_output_type(0, trt.float16) + output = ctx.method_return output._trt = layer.get_output(0)