docs/source/framework/pytorch_integration/autograd_with_tc.rst: 5 additions & 8 deletions
@@ -6,14 +6,11 @@ a training layer with TC and be able to run backwards as well if the layer is part
of a network. In order to write a training layer with TC, you need to follow the
steps below:

- 1. Define your TC language that has two definitions: one for the forward layer
- and the other for the backward layer and pass it to :code:`tc.define` call. In
- addition, also pass :code:`training=True` and the name of the backward TC :code:`backward`.
+ 1. Define your TC language that has two definitions: one for the forward layer and the other for the backward layer, and pass it to the :code:`tc.define` call. In addition, also pass :code:`training=True` and the name of the backward TC as :code:`backward`.

- 2. Create the Input Variables and Parameters. For example, weights should be marked
- as Parameters and the inputs tensors as Variables.
+ 2. Create the Input Variables and Parameters. For example, weights should be marked as Parameters and the input tensors as Variables.

- 3. Run the layer and get the output of forward pass
+ 3. Run the layer and get the output of the forward pass.

4. To see that the backward call works fine, you can call backward on the outputs.
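As a rough illustration of these four steps, here is a minimal sketch in the spirit of the convolution example used in these docs; the TC definitions, tensor sizes, and exact call signatures below are illustrative assumptions, not a verbatim copy of the documented example:

    import torch
    from torch.autograd import Variable
    from torch.nn.parameter import Parameter
    import tensor_comprehensions as tc

    # Step 1: one TC string with both a forward and a backward definition
    # (placeholder definitions, not a tuned kernel).
    lang = """
    def convolution(float(N, C, H, W) I, float(M, C, KH, KW) W1) -> (O) {
        O(n, m, h, w) +=! I(n, c, h + kh, w + kw) * W1(m, c, kh, kw)
    }
    def convolution_grad(float(N, C, H, W) I, float(M, C, KH, KW) W1, float(N, M, H, W) O_grad)
    -> (I_grad, W1_grad) {
        I_grad(n, c, h, w) +=! O_grad(n, m, h - kh, w - kw) * W1(m, c, kh, kw)
        W1_grad(m, c, kh, kw) +=! O_grad(n, m, h - kh, w - kw) * I(n, c, h, w)
    }
    """
    # pass training=True and the name of the backward TC via `backward`
    convolution = tc.define(lang, training=True, name="convolution", backward="convolution_grad")

    # Step 2: inputs as Variables, weights as Parameters
    I = Variable(torch.randn(32, 4, 56, 56).cuda(), requires_grad=True)
    W1 = Parameter(torch.randn(16, 4, 3, 3).cuda())

    # Step 3: run the layer to get the forward output
    out = convolution(I, W1)

    # Step 4: check the backward pass by calling backward on the output
    # (if the layer returns a list of outputs, use e.g. out[0].sum() instead)
    out.sum().backward()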
@@ -79,7 +76,7 @@ them, the example for that would be:
In order to obtain options via autotuning for the backward and forward layers, keep reading further.

- Autotuning Training Layer
+ Autotuning training layer
-------------------------

You can autotune a training layer easily. The forward and backward layers will
@@ -114,7 +111,7 @@ You will find two cache files created: :code:`convolution_train.cuda/options` has
options for the forward layer and :code:`convolution_train_backward.cuda/options` file
has options for the grad layer.

- Reordering Grad Outputs
+ Reordering grad outputs
-----------------------

In the backward pass, TC uses the list of input tensors in the forward pass and appends
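A sketch of what autotuning this training layer and caching the tuned options could look like, continuing the snippet above; the tuner keyword arguments and values here are assumptions, and a real run is what produces the :code:`convolution_train.cuda/options` and :code:`convolution_train_backward.cuda/options` files mentioned above:

    # `convolution`, I and W1 come from the training-layer sketch above;
    # generations/pop_size are illustrative tuner settings.
    convolution.autotune(I, W1, cache="convolution_train.cuda",
                         options=tc.Options("conv"), generations=5, pop_size=20)

    # after tuning, the layer runs with the tuned options for both passes
    out = convolution(I, W1)
    out.sum().backward()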
docs/source/framework/pytorch_integration/getting_started.rst: 3 additions & 1 deletion
@@ -16,7 +16,9 @@ A **few cases** where TC can be useful:

* you are interested in fusing layers like group convolution, ReLU, FC *or*

- * if you have a different new layer, let's call it :code:`hconv` (a variant of convolution), for which you wish you had an efficient kernel available
+ * if you have a different new layer, let's call it :code:`hconv` (a variant of convolution), for which you wish you had an efficient kernel available *or*
+
+ * if you have standard operations on different data layouts that you didn't want to use because you couldn't get good kernels for them

TC makes it very trivial to get CUDA code for such cases and many more. By providing
TC integration with PyTorch, we hope to make it even easier for PyTorch users
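As a rough illustration of how little code such a case takes with the PyTorch bindings, here is a sketch of a fused fully-connected + ReLU layer; the :code:`fcrelu` definition and the tensor sizes are illustrative placeholders rather than a tuned kernel:

    import torch
    import tensor_comprehensions as tc

    # a small fused FC + ReLU comprehension, written just to show the workflow
    lang = """
    def fcrelu(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) {
        O1(b, n) +=! I(b, m) * W1(n, m)
        O1(b, n) = O1(b, n) + B1(n)
        O1(b, n) = fmax(O1(b, n), 0)
    }
    """
    fcrelu = tc.define(lang, name="fcrelu")
    I = torch.randn(100, 128).cuda()
    W1 = torch.randn(64, 128).cuda()
    B1 = torch.randn(64).cuda()
    out = fcrelu(I, W1, B1)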
docs/source/framework/pytorch_integration/writing_layers.rst: 21 additions & 16 deletions
@@ -57,13 +57,13 @@ There are two ways to set the :code:`Options`:

* **Autotuning**: You can autotune the kernel on certain input tensor sizes, cache the options and use them to run the layer. See :ref:`pytorch_autotune_layers` for how to autotune kernels.

- * **Default Mapping**: We provide various default options that can be chosen to closely represent the kernel. THe defaults provided are:
+ * **Default Mapping**: We provide various default options that can be chosen to closely represent the kernel. The defaults provided are:

* :code:`pointwise`: if kernel resembles a pointwise operation
* :code:`mlp`: if kernel resembles a Linear layer operation
* :code:`conv`: if kernel resembles a convolution operation
- * :code:`group_conv`: if kernel resembles a convolution operation
- * :code:`naive`: if none of the above, then chose naive Default
+ * :code:`group_conv`: if kernel resembles a group convolution operation
+ * :code:`naive`: if none of the above, then choose the naive default

An example for how to pass options:
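That example itself is not part of this hunk; a rough sketch of passing one of the default mappings, using a placeholder :code:`matmul` definition, could look like this (the exact :code:`tc.Options` spelling is an assumption based on the option names listed above):

    import torch
    import tensor_comprehensions as tc

    lang = """
    def matmul(float(M, K) A, float(K, N) B) -> (C) {
        C(m, n) +=! A(m, kk) * B(kk, n)
    }
    """
    matmul = tc.define(lang, name="matmul")
    A, B = torch.randn(32, 64).cuda(), torch.randn(64, 16).cuda()

    # pick the default mapping that most resembles the kernel, here `mlp`
    out = matmul(A, B, options=tc.Options("mlp"))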
@@ -126,9 +126,12 @@ happens only once and then you can keep running the layer.
Multiple TC definitions in language
-----------------------------------

- Let's say you want to define all of your TCs in one string and later keep running
- them. You an do so easily. Every time you want to run a different layer, you can
- make a :code:`tc.define` call and get the layer.
+ Let's say you want to define all of your TCs in one string and later use that string
+ for running different operations defined in the string. You can do so easily. You
+ can define a :code:`lang` variable that holds the TC definition for all your operations.
+ Every time you want to run a different operation, you can make a :code:`tc.define` call
+ on the :code:`lang` variable, specify the :code:`name` corresponding to the operation
+ definition and get the TC layer for it. Below is an example for how to do this:

.. code-block:: python
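The code block referenced above falls outside this hunk; a rough sketch of the pattern it describes, with placeholder :code:`matmul` and :code:`relu` definitions held in one :code:`lang` string, might be:

    import torch
    import tensor_comprehensions as tc

    # one `lang` string holding several TC definitions
    lang = """
    def matmul(float(M, K) A, float(K, N) B) -> (C) {
        C(m, n) +=! A(m, kk) * B(kk, n)
    }
    def relu(float(B, M) I) -> (O1) {
        O1(b, m) = fmax(I(b, m), 0)
    }
    """

    # select the operation to run by passing its `name` to tc.define
    matmul = tc.define(lang, name="matmul")
    out1 = matmul(torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda())

    relu = tc.define(lang, name="relu")
    out2 = relu(torch.randn(3, 4).cuda())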
@@ -215,7 +218,7 @@ adopt whatever feels more convenient.
out = avgpool(inp)


- Manually Injecting external CUDA code
+ Manually injecting external CUDA code
-------------------------------------

If you have efficient external CUDA code that you want to use rather than
@@ -248,17 +251,19 @@ call. For example:
a, b = torch.randn(100).cuda(), torch.randn(100).cuda()
out = add(a, b, grid=[1, 1, 1], block=[100, 1, 1])

- In such cases, please note that TC doesn't modify the injected CUDA kernel. It will
- simply run the kernel injected as is and TC will also not guarantee the performance
- of the kernel. User needs to specify the :code:`grid` and :code:`block` values
- when running the layer and TC will simply use those settings.
+ .. note::
+
+     In such cases, please note that TC doesn't modify the injected CUDA kernel. It will
+     simply run the injected kernel as is, and TC will also not guarantee the performance
+     of the kernel. The user needs to specify the :code:`grid` and :code:`block` values
+     when running the layer and TC will simply use those settings.


- Builtin Functions
- -----------------
+ Built-in Functions
+ ------------------

- TC allows using some CUDA builtin functions as well when defining the TC language.
- During the execution, CUDA API will be called for those builtin functions. For example,
+ TC allows using some CUDA built-in functions as well when defining the TC language.
+ During the execution, the CUDA API will be called for those built-in functions. For example,
let's say we want to use the :code:`fmax` CUDA function in our TC language. An example
for how this would be done is below:
@@ -275,7 +280,7 @@ for how this would be done is below:
inp = torch.randn(100, 128).cuda()
out = relu(inp)

- TC supports only a few builtin CUDA functions and not all. You can find the documentation
+ TC only supports a subset of built-in CUDA functions. You can find the documentation
for these functions at the official CUDA documentation `here <http://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html#group__CUDA__MATH__SINGLE>`_.
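The TC definition behind that :code:`relu` example sits outside the hunk; a sketch of what an :code:`fmax`-based definition and the surrounding calls could look like (the definition is an assumed reconstruction, not quoted from the file):

    import torch
    import tensor_comprehensions as tc

    # a ReLU written with the CUDA built-in `fmax`, matching the usage shown above
    lang = """
    def relu(float(B, M) I) -> (O1) {
        O1(b, m) = fmax(I(b, m), 0)
    }
    """
    relu = tc.define(lang, name="relu")
    inp = torch.randn(100, 128).cuda()
    out = relu(inp)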
Tensor Comprehensions (TC) is a framework-agnostic library to **automatically**
- synthesize high-performance Machine Learning kernels. TC relies on
+ synthesize high-performance machine learning kernels. TC relies on
`Halide <https://github.com/halide/Halide>`_ IR to express computation and analysis
tools to reason about it. TC uses :code:`polyhedral` compilation techniques to
(semi-)automatically decide how to perform this computation efficiently and produce
fast code. We also provide TC integration with PyTorch and Caffe2.

+ To automatically tune the performance of the kernel, we provide a genetic-algorithm-based
+ **Autotuner**, details of which are available at :ref:`pytorch_autotune_layers`.
+
To read more about Tensor Comprehensions, see our documentation available
at https://facebookresearch.github.io/TensorComprehensions/ and C++ API documentation is
available at https://facebookresearch.github.io/TensorComprehensions/api.

We provide many **python examples** for expressing and running various ML layers
with TC. The examples can be found `here <https://github.com/facebookresearch/TensorComprehensions/tree/master/test_python/layers>`_.

- To read more about Framework integrations, checkout our documentation on `PyTorch <https://facebookresearch.github.io/TensorComprehensions/framework/pytorch_integration/getting_started.html>`_ integration
- and `Caffe2 <https://facebookresearch.github.io/TensorComprehensions/framework/caffe2_integration/integration_with_example.html>`_
- integration.
+ To read more about framework integrations, check out our documentation on `PyTorch integration <https://facebookresearch.github.io/TensorComprehensions/framework/pytorch_integration/getting_started.html>`_
+ and `Caffe2 integration <https://facebookresearch.github.io/TensorComprehensions/framework/caffe2_integration/integration_with_example.html>`_.

If you want to **integrate your framework** with TC, it's easy and the instructions are
available at https://facebookresearch.github.io/TensorComprehensions/integrating_any_ml_framework.html