Skip to content

Commit

Permalink
Improves various guides, hides incomplete Executable APIs
Browse files Browse the repository at this point in the history
- Improves the introduction, compiler, and quickstart guides

- Fixes various docstrings

- Removes `Executable.get_input_info()` and `Executable.get_output_info()`
    because they weren't returning complete information (e.g. names missing).
    We should flesh out the API before making it publicly visible.
pranavm-nvidia committed Nov 9, 2024
1 parent e1d4974 commit efaf830
Showing 9 changed files with 258 additions and 251 deletions.
73 changes: 41 additions & 32 deletions tripy/README.md
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
# Tripy: A Python Programming Model For TensorRT

<!-- Tripy: DOC: OMIT Start -->
[**Installation**](#installation) | [**Quickstart**](#quickstart) | [**Documentation**](https://nvidia.github.io/TensorRT-Incubator/) | [**Examples**](./examples) | [**Contributing**](./CONTRIBUTING.md)
[**Installation**](#installation) | [**Getting Started**](#getting-started) | [**Documentation**](https://nvidia.github.io/TensorRT-Incubator/) | [**Examples**](./examples) | [**Contributing**](./CONTRIBUTING.md)

[![Tripy L1](https://github.com/NVIDIA/TensorRT-Incubator/actions/workflows/tripy-l1.yml/badge.svg)](https://github.com/NVIDIA/TensorRT-Incubator/actions/workflows/tripy-l1.yml)
<!-- Tripy: DOC: OMIT End -->
@@ -66,39 +66,48 @@ To get the latest changes in the repository, you can build Tripy wheels from sou

<!-- Tripy: DOC: OMIT End -->

## Quickstart
## Getting Started

In eager mode, Tripy works just like you'd expect:
```py
# doc: no-print-locals
import tripy as tp
a = tp.Tensor([1.0, 2.0])
print(a + 1)
```
We've included several guides in Tripy to make it easy to get started.
We recommend starting with the
[Introduction To Tripy](https://nvidia.github.io/TensorRT-Incubator/pre0_user_guides/00-introduction-to-tripy.html)
guide.
Tripy can also compile functions to generate efficient machine code for faster execution:
To get an idea of the look and feel of Tripy, let's take a look at a short code example.
All of the features used in this example are explained in more detail in the
introduction guide mentioned above.

```py
# doc: no-print-locals
def add(a, b):
return a + b
# When compiling, we need to specify shape and data type constraints on the inputs:
# a is a 1D dynamic shape tensor of shape (d,), where `d` can range from 1 to 5.
# `[1, 2, 5]` indicates a range from 1 to 5, with optimization for `d = 2`.
a_info = tp.InputInfo(shape=([1, 2, 5],), dtype=tp.float32)
# `b` is a 1D tensor of shape (1,).
b_info = tp.InputInfo((1,), dtype=tp.float32)
compiled_add = tp.compile(add, args=[a_info, b_info])
print(compiled_add(tp.Tensor([1., 2., 3.]), tp.Tensor([3.])))
# Define our model:
class Model(tp.Module):
def __init__(self):
self.conv = tp.Conv(in_channels=1, out_channels=1, kernel_dims=[3, 3])
def __call__(self, x):
x = self.conv(x)
x = tp.relu(x)
return x
# Initialize the model and populate weights:
model = Model()
model.load_state_dict(
{
"conv.weight": tp.ones((1, 1, 3, 3)),
"conv.bias": tp.ones((1,)),
}
)
inp = tp.ones((1, 1, 4, 4))
# Eager mode:
eager_out = model(inp)
# Compiled mode:
compiled_model = tp.compile(
model,
args=[tp.InputInfo(shape=(1, 1, 4, 4), dtype=tp.float32)],
)
compiled_out = compiled_model(inp)
```
For more details, see the
[Introduction To Tripy](https://nvidia.github.io/TensorRT-Incubator/pre0_user_guides/00-introduction-to-tripy.html)
guide.
142 changes: 72 additions & 70 deletions tripy/docs/pre0_user_guides/00-introduction-to-tripy.md
Original file line number Diff line number Diff line change
@@ -7,8 +7,6 @@ It aims to be fast, easy to debug, and provide an easy-to-use Pythonic interface

## Your First Tripy Program

But enough talk; let's see some code:

```py
# doc: no-print-locals
a = tp.arange(5)
@@ -18,54 +16,7 @@ assert np.array_equal(cp.from_dlpack(c).get(), np.arange(5, dtype=np.float32) +
```

This should look familiar if you've used linear algebra or deep learning libraries like
NumPy and PyTorch.


### Lazy Evaluation: Putting Off Work

One important point is that Tripy uses a lazy evaluation model; that is,
no computation is performed until a value is actually needed.

In the example above, that means that `c` will not be evaluated until it is used,
such as when we print its values.

In most cases, this is simply an implementation detail that you will not notice.
One exception to this is when attempting to time code. Consider the following code:

```py
# doc: no-print-locals
import time

start = time.time()
a = tp.arange(5)
b = tp.arange(5)
c = a + b + tp.tanh(a)
end = time.time()

print(f"Time to create 'c': {(end - start) * 1000:.3f} ms.")
```

It looks like Tripy is very fast! While Tripy *execution* is very fast, compiling the program
takes some time. The reason the time is so low relative to what we'd expect for initializing
and running the compiler is that *we're not doing that yet*.

The actual compilation and computation only happens when we evaluate `c`:

```py
# doc: no-print-locals
start = time.time()
print(c)
end = time.time()

print(f"Time to print 'c': {(end - start) * 1000:.3f} ms.")
```

That is why the time to print `c` is so much higher than the time to create it.

If we wanted to time individual parts of the model, we would insert calls to `.eval()`;
for example, adding a `c.eval()` prior to checking the end time would tell us how
long it took to compile and run the subgraph that computes `c`.

NumPy and PyTorch. Hopefully, the code above is self-explanatory, so we won't go into details.

## Organizing Code Using Modules

@@ -77,10 +28,10 @@ For example, we can define a Transfomer MLP block like so:

```py
class MLP(tp.Module):
def __init__(self, embedding_size, dtype=tp.float32):
def __init__(self, embd_size, dtype=tp.float32):
super().__init__()
self.c_fc = tp.Linear(embedding_size, 4 * embedding_size, bias=True, dtype=dtype)
self.c_proj = tp.Linear(4 * embedding_size, embedding_size, bias=True, dtype=dtype)
self.c_fc = tp.Linear(embd_size, 4 * embd_size, bias=True, dtype=dtype)
self.c_proj = tp.Linear(4 * embd_size, embd_size, bias=True, dtype=dtype)

def __call__(self, x):
x = self.c_fc(x)
@@ -92,14 +43,14 @@ class MLP(tp.Module):
To use it, we just need to construct and call it:

```py
mlp = MLP(embedding_size=2)
# doc: no-print-locals mlp
mlp = MLP(embd_size=2)

inp = tp.iota(shape=(1, 2), dim=1, dtype=tp.float32)
out = mlp(inp)
```


## To `compile` Or Not To `compile`
## Compiling Code

All the code we've seen so far has been using Tripy's eager mode. It is also possible to compile
functions or modules ahead of time, which can result in significantly better performance.
@@ -111,37 +62,88 @@ Let's compile the MLP module we defined above as an example:

```py
# doc: no-print-locals
# When we compile, we need to indicate which parameters to the function should be runtime inputs.
# In this case, MLP takes a single input tensor for which we can specify our desired shape and datatype.
# When we compile, we need to indicate which parameters to the function
# should be runtime inputs. In this case, MLP takes a single input tensor
# for which we can specify our desired shape and datatype.
fast_mlp = tp.compile(mlp, args=[tp.InputInfo(shape=(1, 2), dtype=tp.float32)])
```

It is also possible to compile for a range of possible input shapes.
See {func}`tripy.compile` for details.

Now let's benchmark the compiled version against eager mode:
```py
# doc: no-print-locals
import time

start = time.time()
out = mlp(inp)
out.eval() # Recall that we need to evaluate in order to actually materialize `out`
# We need to evaluate in order to actually materialize `out`.
# See the section on lazy evaluation below for details.
out.eval()
end = time.time()

eager_time = (end - start) * 1000
print(f"Eager mode time: {eager_time:.4f} ms")

ITERS = 10
start = time.time()
for _ in range(ITERS):
out = fast_mlp(inp)
out.eval()
out = fast_mlp(inp)
out.eval()
end = time.time()

compiled_time = ((end - start) / ITERS) * 1000
print(f"Compiled mode average time: {compiled_time:.4f} ms")
compiled_time = (end - start) * 1000
print(f"Compiled mode time: {compiled_time:.4f} ms")
# Make sure compiled mode is actually faster # doc: omit
assert compiled_time < 0.01 * eager_time # doc: omit
```

As you can see, the compiled module is significantly faster than running the module
in eager mode.
For more information on the compiler, compiled functions/modules, and dynamic shapes,
see the [compiler guide](project:./02-compiler.md).

## Things To Note

### Eager Mode: How Does It Work?

If you've used TensorRT before, you may know that it does not support an eager mode.
In order to provide eager mode support in Tripy, we actually need to compile the graph
under the hood.

Although we employ several tricks to make compile times faster when using eager mode,
we do still need to compile, and so eager mode will likely be slower than other
comparable frameworks.

Consequently, we suggest that you use eager mode primarily for debugging and
compiled mode for deployments.

### Lazy Evaluation: Putting Off Work

One important point is that Tripy uses a lazy evaluation model; that is,
no computation is performed until a value is actually needed.

In most cases, this is simply an implementation detail that you will not notice.
One exception to this is when attempting to time code. Consider the following code:

```py
# doc: no-print-locals
import time

start = time.time()
a = tp.arange(5)
b = tp.arange(5)
c = a + b + tp.tanh(a)
end = time.time()

print(f"Time to create 'c': {(end - start) * 1000:.3f} ms.")
```

Given what we said above about eager mode, it seems like Tripy is very fast!
Of course, this is because *we haven't actually done anything yet*.
The actual compilation and execution only happens when we evaluate `c`:

```py
# doc: no-print-locals
start = time.time()
print(c)
end = time.time()

print(f"Time to print 'c': {(end - start) * 1000:.3f} ms.")
```

That is why the time to print `c` is so much higher than the time to create it.
91 changes: 46 additions & 45 deletions tripy/docs/pre0_user_guides/02-compiler.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Using the Compiler

Modules and functions can be compiled ahead of time for better runtime performance.

*Note that the compiler imposes some requirements on the functions/modules it can compile.*
*See {func}`tripy.compile` for details.*

## Model Compilation And Deployment

Let's walk through a simple example of a [GEGLU](https://arxiv.org/abs/2002.05202v1) module defined below:
In this guide, we'll work with the [GEGLU](https://arxiv.org/abs/2002.05202v1) module defined below:

```py
# doc: no-print-locals
@@ -19,98 +20,98 @@ class GEGLU(tp.Module):
return x * tp.gelu(gate)
```

To run `GEGLU` in eager mode:
We can run this in eager mode like usual:

```py
# doc: no-print-locals
# doc: no-print-locals layer
layer = GEGLU(2, 8)

inp = tp.ones((1, 2))
out = layer(inp)
```

Now, let's try to optimize this model for inference using Tripy's {func}`tripy.compile`.
## Compiling

When we compile our module, we need to provide information about each input using {class}`tripy.InputInfo`.
The first argument for `InputInfo` is `shape`, where we specify either the static or
dynamic shape information for each dimension. In the example below, we assume the
shape of `inp` is static (`(1, 2)`). The second argument specifies the `dtype` for the input:
Let's optimize the module using {func}`tripy.compile`.

When we compile in Tripy, we need to provide shape and data type information about each runtime input
using the {class}`tripy.InputInfo` API. Other parameters to the function will be considered compile-time
constants and will be folded into the compiled function.

`GEGLU` only has one input, for which we'll create an `InputInfo` like so:
```py
# doc: no-print-locals
inp_info = tp.InputInfo(shape=(1, 2), dtype=tp.float32)
```
Now, we can call the `compile` function to obtain a compiled function and use it for inference:

Then we'll compile, which will give us a {class}`tripy.Executable` that we can run:

```py
# doc: no-print-locals
# doc: no-print-locals fast_geglu
fast_geglu = tp.compile(layer, args=[inp_info])
fast_geglu(inp).eval()

out = fast_geglu(inp)
```

### Optimization Profiles
## Dynamic Shapes

In the example above, we assumed `inp` has a static shape of `(1, 2)`.
Now, let's assume that the shape of `inp` can vary from `(1, 2)` to `(16, 2)`, with `(8, 2)`
being the shape we'd like to optimize for. To express this constraint to the compiler,
we can provide the range of shapes to `InputInfo` using `shape=([1, 8, 16], 2)`.
This indicates to the compiler that the first dimension can vary from 1 to 16,
and it should optimize for a size of 8.
When we compiled above, we used a static shape of `(1, 2)` for the input.
Tripy also supports specifying a range of possible values for each dimension like so:

```py
# doc: print-locals out out_change_shape
inp_info = tp.InputInfo(shape=([1, 8, 16], 2), dtype=tp.float32)
```

The shape we used above indicates that the 0th dimension should support a range of values
from `1` to `16`, optimizing for a value of `8`. For the 1st dimension, we continue using
a fixed value of `2`.

Let's compile again with our updated `InputInfo` and try changing the input shape:

```py
# doc: no-print-locals fast_geglu
fast_geglu = tp.compile(layer, args=[inp_info])
out = fast_geglu(inp)

# Let's change the shape of input to (2, 2)
inp = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
out_change_shape = fast_geglu(inp)
# We'll run with the input we created above, which is of shape (1, 2)
out0 = fast_geglu(inp)

# Now let's try an input of shape (2, 2):
inp1 = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
out1 = fast_geglu(inp1)
```

If we provide an input that does not comply with the dynamic shape constraint
given to the compiler, `Tripy` will produce an error with relevant information:
If we try using a shape outside of the valid range, the executable will throw a nice error:

<!-- Tripy: TEST: IGNORE Start -->
<!-- Tripy: TEST: XFAIL Start -->
```py
# doc: allow-exception
inp = tp.ones((32, 2), dtype=tp.float32)
print(fast_geglu(inp))
```
<!-- Tripy: TEST: IGNORE End -->
<!-- Tripy: TEST: XFAIL End -->

### Saving The Executable

A compiled executable can be saved to disk and then used for deployment.
## Saving The Executable

Saving an executable to disk:
You can serialize and save executables like so:

```py
# doc: no-print-locals
import tempfile # doc: omit
import os

out_dir = tempfile.mkdtemp() # doc: omit
# Assuming `out_dir` is the directory where you'd like to save the executable:
executable_file_path = os.path.join(out_dir, "executable.json")
fast_geglu.save(executable_file_path)
```

Reading an executable and running inference:
Then you can load and run it again:

```py
# doc: no-print-locals
# doc: no-print-locals loaded_fast_geglu
loaded_fast_geglu = tp.Executable.load(executable_file_path)

inp = tp.Tensor([[1., 2.], [2., 3.]], dtype=tp.float32)
out = loaded_fast_geglu(inp)
os.remove(executable_file_path) # doc: omit
```

### Querying Executable Properties

We can also query properties about the executable:

```py
# doc: print-locals input_info output_info
input_info = loaded_fast_geglu.get_input_info()
output_info = loaded_fast_geglu.get_output_info()
```
2 changes: 1 addition & 1 deletion tripy/examples/nanogpt/README.md
Original file line number Diff line number Diff line change
@@ -62,7 +62,7 @@ To run with a quantization mode, pass `--quant-mode` to `example.py`. The suppor
<!--
```
(?s).*?
What is the answer to life, the universe, and everything\? How can one explain the existence of the universe to
What is the answer to life, the universe, and everything\? How can one explain what's important to us to
```
-->
<!-- Tripy: TEST: EXPECTED_STDOUT End -->
7 changes: 3 additions & 4 deletions tripy/tests/backend/api/test_executable.py
Original file line number Diff line number Diff line change
@@ -95,12 +95,12 @@ def test_signature_multiple_return_values(self, multiple_return_executable):
assert signature.return_annotation == Sequence[tp.Tensor]

def test_io_tensor_info(self, multiple_return_executable):
input_info = multiple_return_executable.get_input_info()
input_info = multiple_return_executable._get_input_info()
assert len(input_info) == 2
for i in range(2):
assert input_info[i].shape_bounds == ((2, 2), (2, 2))
assert input_info[i].dtype == tp.float32
output_info = multiple_return_executable.get_output_info()
output_info = multiple_return_executable._get_output_info()
assert len(output_info) == 2
for i in range(2):
assert output_info[i].shape_bounds == ((2, 2), (2, 2))
@@ -112,8 +112,7 @@ def test_file_io(self, single_return_executable):
single_return_executable.save(exe_file)
assert os.path.exists(exe_file)
loaded_executable = tp.Executable.load(exe_file)
assert loaded_executable.get_input_info() == single_return_executable.get_input_info()
assert loaded_executable.get_output_info() == single_return_executable.get_output_info()
assert loaded_executable.__signature__ == single_return_executable.__signature__

inp = tp.iota((2, 2), dtype=tp.float32)
out1 = single_return_executable(inp, inp)
8 changes: 6 additions & 2 deletions tripy/tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -101,8 +101,12 @@ def test_examples(example, sandboxed_install_run):

block_text = str(block)
if block.has_marker("test: expected_stdout"):
print("Checking command output against expected output:")
assert re.match(dedent(block_text).strip(), statuses[-1].stdout.strip())
print("Checking command output against expected output: ", end="")
out = statuses[-1].stdout.strip()
matched = re.match(dedent(block_text).strip(), out)
print("matched!" if matched else "did not match!")
print(f"==== STDOUT ====\n{out}")
assert matched
else:
status = example.run(block_text, sandboxed_install_run)

6 changes: 3 additions & 3 deletions tripy/tripy/backend/api/compile.py
Original file line number Diff line number Diff line change
@@ -53,9 +53,9 @@ def compile(
The compiled function will have the following constraints:
- Only :class:`Tensor` parameters to the function can become runtime inputs. All other types of parameters,
even collections of :class:`Tensor` s (e.g. ``List[Tensor]`` or ``Dict[str, Tensor]``), will be baked into
the compiled function as constants.
- Only :class:`Tensor` parameters to the function can become runtime inputs.
All other types of parameters, even collections of :class:`Tensor` s (e.g. ``List[Tensor]`` or ``Dict[str, Tensor]``),
will be baked into the compiled function as constants.
optimization_level: The optimization level to use when compiling. Higher optimization levels can lead to better
runtime performance at the cost of longer compile times.
165 changes: 84 additions & 81 deletions tripy/tripy/backend/api/executable.py
Original file line number Diff line number Diff line change
@@ -14,18 +14,27 @@
# limitations under the License.
import base64
import inspect
from typing import Sequence, Union
from typing import Sequence, Union, Tuple

import mlir_tensorrt.runtime.api as runtime

from tripy import export
from tripy.backend.api.input_info import ArgInfo
from tripy.backend.mlir import Executor
from tripy.backend.mlir import utils as mlir_utils
from tripy.common.exception import raise_error
from tripy.frontend import Tensor
from tripy.function_registry import sanitize_name
from tripy.utils import json as json_utils
from tripy.utils.stack_info import StackInfo
from dataclasses import dataclass


# TODO(MLIR-TRT #923): Can generalize `InputInfo` and drop this class.
@dataclass
class ArgInfo:
shape_bounds: Sequence[Tuple[int, int]]
"""A sequence of tuple(min, max) indicating the bounds of each dimension"""
dtype: "tripy.dtype"
"""The datatype of the argument"""


@export.public_api(document_under="compiling_code")
@@ -63,6 +72,51 @@ def stream(self):
def stream(self, stream):
self._executor.stream = stream

def __str__(self) -> str:
params = [f"{name}: {sanitize_name(param.annotation)}" for name, param in self.__signature__.parameters.items()]
return f"Executable({', '.join(params)}) -> {sanitize_name(self.__signature__.return_annotation)}"

@staticmethod
def load(path: str) -> "tripy.Executable":
"""
Loads a executable from the provided path.
Args:
path: The path from which to load the exectuable.
Returns:
The executable object loaded from the file.
.. code-block:: python
:linenos:
:caption: Save and load executable
import os
import tempfile # doc: omit
def add(a, b):
return a + b
# doc: no-print-locals compiled_add executable_file
compiled_add = tp.compile(
add,
args=[
tp.InputInfo(([1, 2, 3],), dtype=tp.float32),
tp.InputInfo(([1, 2, 3],), dtype=tp.float32),
],
)
out_dir = tempfile.TemporaryDirectory().name # doc: omit
# Assuming `out_dir` is the directory containing the executable:
executable_file = os.path.join(out_dir, "executable.json")
compiled_add.save(executable_file) # doc: omit
assert os.path.exists(executable_file)
loaded_executable = tp.Executable.load(executable_file)
"""

return json_utils.load(path)

def __call__(self, *args, **kwargs) -> Union[Tensor, Sequence[Tensor]]:
"""
Invokes the executable with the specified tensor arguments.
@@ -83,7 +137,13 @@ def add(a, b):
return a + b
# doc: no-print-locals compiled_add
compiled_add = tp.compile(add, args=[tp.InputInfo((1,), dtype=tp.float32), tp.InputInfo((1,), dtype=tp.float32)])
compiled_add = tp.compile(
add,
args=[
tp.InputInfo((1,), dtype=tp.float32),
tp.InputInfo((1,), dtype=tp.float32),
],
)
a = tp.ones((1,), dtype=tp.float32)
b = tp.ones((1,), dtype=tp.float32)
@@ -133,7 +193,7 @@ def add(a, b):
# TODO: Evaluate whether this should be moved into the executor
if "function expects a memref type with element type" in str(err):
# If the problem is a mismatched data type, we can provide a better error message than the executor can.
expected_input_dtypes = [info.dtype for info in self.get_input_info()]
expected_input_dtypes = [info.dtype for info in self._get_input_info()]
for tensor, dtype, arg_name in zip(input_tensors, expected_input_dtypes, self._arg_names):
if tensor.dtype != dtype:
raise_error(
@@ -144,7 +204,7 @@ def add(a, b):
],
)
elif "InternalError: failed to set input shape" in str(err) or "Runtime shape mismatch" in str(err):
expected_input_shapes = [info.shape_bounds for info in self.get_input_info()]
expected_input_shapes = [info.shape_bounds for info in self._get_input_info()]
for tensor, expected_bounds, arg_name in zip(input_tensors, expected_input_shapes, self._arg_names):
shape = tensor.shape
for i in range(len(shape)):
@@ -180,47 +240,13 @@ def _get_arg_info(self, idx):
shape_bounds = tuple((x, x) for x in arg.shape)
return ArgInfo(shape_bounds, mlir_utils.convert_runtime_dtype_to_tripy_dtype(arg.dtype))

def get_input_info(self) -> Sequence[ArgInfo]:
"""
Returns input tensors' information.
Returns:
A list containing one `ArgInfo` per input.
.. code-block:: python
:linenos:
:caption: Get input info
def add(a, b):
return a + b
# doc: no-print-locals compiled_add
compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
print(compiled_add.get_input_info())
"""
def _get_input_info(self) -> Sequence[ArgInfo]:
input_info = []
for idx in range(self._executable_signature.get_num_input_args()):
input_info.append(self._get_arg_info(idx))
return input_info

def get_output_info(self) -> Sequence[ArgInfo]:
"""
Returns output tensors' information.
Returns:
A list containing one `ArgInfo` per input.
.. code-block:: python
:linenos:
:caption: Get output info
def add(a, b):
return a + b
# doc: no-print-locals compiled_add
compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
print(compiled_add.get_output_info())
"""
def _get_output_info(self) -> Sequence[ArgInfo]:
output_info = []
offset = self._executable_signature.get_num_input_args()
for idx in range(self._executable_signature.get_num_output_args()):
@@ -229,61 +255,38 @@ def add(a, b):

def save(self, path: str) -> None:
"""
Saves the compiled executable to the given file.
Saves this executable to the provided path.
Args:
path: The name of file to save the executable.
path: The path at which to save the executable.
.. code-block:: python
:linenos:
:caption: Save executable
import os, tempfile
import os
import tempfile # doc: omit
def add(a, b):
return a + b
# doc: no-print-locals compiled_add executable_file
compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
compiled_add = tp.compile(
add,
args=[
tp.InputInfo(([1, 2, 3],), dtype=tp.float32),
tp.InputInfo(([1, 2, 3],), dtype=tp.float32),
],
)
with tempfile.TemporaryDirectory() as temp_dir:
executable_file = os.path.join(temp_dir, "executable.json")
compiled_add.save(executable_file)
assert os.path.exists(executable_file)
out_dir = tempfile.TemporaryDirectory().name # doc: omit
# Assuming `out_dir` is the desired output directory:
executable_file = os.path.join(out_dir, "executable.json")
compiled_add.save(executable_file)
assert os.path.exists(executable_file)
"""
json_utils.save(self, path)

@classmethod
def load(cls, path: str) -> "tripy.Executable":
"""
Loads a compiled executable from a given directory.
Args:
path: The name of file to load the exectuable from.
Returns:
The executable object loaded from the file.
.. code-block:: python
:linenos:
:caption: Save and load executable
import os, tempfile
def add(a, b):
return a + b
# doc: no-print-locals compiled_add executable_file
compiled_add = tp.compile(add, args=[tp.InputInfo(([1, 2, 3],), dtype=tp.float32), tp.InputInfo(([1, 2, 3],), dtype=tp.float32)])
with tempfile.TemporaryDirectory() as temp_dir:
executable_file = os.path.join(temp_dir, "executable.json")
compiled_add.save(executable_file)
assert os.path.exists(executable_file)
loaded_executable = tp.Executable.load(executable_file)
"""
return json_utils.load(path)


@json_utils.Encoder.register(Executable)
def encode_executable(executable):
15 changes: 2 additions & 13 deletions tripy/tripy/backend/api/input_info.py
Original file line number Diff line number Diff line change
@@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


from tripy import export
from tripy.common.exception import raise_error
from tripy.common.shape_bounds import ShapeBounds
@@ -48,7 +46,8 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tr
:linenos:
:caption: Dynamic Dimensions
# The first dimension will support values in the range [1, 3], optimizing for a size of 2.
# The first dimension will support values in the range [1, 3],
# optimizing for a size of 2.
inp = tp.InputInfo(((1, 2, 3), 4), dtype=tp.float32)
assert inp.shape_bounds.min == (1, 4)
assert inp.shape_bounds.opt == (2, 4)
@@ -90,13 +89,3 @@ def __init__(self, shape: Sequence[Union[int, Tuple[int, int, int]]], dtype: "tr

def __str__(self) -> str:
return f"InputInfo(min={self.shape_bounds.min}, opt={self.shape_bounds.opt}, max={self.shape_bounds.max}, dtype={self.dtype})"


# TODO(MLIR-TRT #923): Can generalize `InputInfo` and drop this class.
@export.public_api(document_under="compiling_code")
@dataclass
class ArgInfo:
shape_bounds: Sequence[Tuple[int, int]]
"""A sequence of tuple(min, max) indicating the bounds of each dimension"""
dtype: "tripy.dtype"
"""The datatype of the argument"""

0 comments on commit efaf830

Please sign in to comment.