# Literal is available from Python 3.8,
# matching the Python versions for PyTorch and PyTorch/XLA.
import logging
from typing import Final, Literal, TypeAlias

import torch_xla

# Module-level logger; configuration is left to the application.
# TODO: Refactor logging in torch_xla package https://github.com/pytorch/xla/issues/9142
logger = logging.getLogger(__name__)
# Emitted by set_mat_mul_precision to discourage repeated reconfiguration.
_WARNING_MESSAGE: Final[str] = (
    'Setting mat mul precision multiple times is not '
    'recommended. If you need to do so, please empirically '
    'verify that the precision setting is behaving as expected.')

# Public API of this module.
__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"]

# Valid values for get_mat_mul_precision/set_mat_mul_precision
# Some of this description adapted from Jax documentation.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
  """Control the default mat mul and conv precision for 32bit inputs.

  Some platforms, like TPU, offer configurable precision levels for
  matrix multiplication and convolution computations,
  trading off accuracy for speed.

  This option controls the default precision level for
  computations involved in matrix multiplication and convolutions on
  32bit inputs. The levels describe the precision at
  which scalar products are computed.

  On a TPU:

  `default` is the fastest and least precise,
  downcasting an FP32 to BF16 before multiplying.

  `high` takes three passes and generates approximately 14 bits of
  precision.

  `highest` is the most precise, and the slowest. It takes six
  passes and generates approximately 22 bits of precision.

  See the [precision tutorial](../../tutorials/precision_tutorial.html)
  for more information about the precision levels.

  Note: Setting mat mul precision multiple times is not recommended.
  If you need to do so, please empirically verify that the precision
  setting is behaving as expected.

  Args:
    precision (str): The precision to set for matrix multiplication.
        Must be one of 'default', 'high', or 'highest'.

  Raises:
    ValueError: If ``precision`` is not one of the accepted values.
  """
  # Validate before doing anything observable so an invalid value
  # neither warns nor changes state.
  if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
    raise ValueError(f"Invalid precision: {precision}. "
                     f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")

  # The warning is about *repeated* configuration, so stay silent on the
  # first successful call and only warn on subsequent ones. The flag is
  # kept on the function object to avoid extra module-level state.
  if getattr(set_mat_mul_precision, "_already_set", False):
    logger.warning(_WARNING_MESSAGE)

  torch_xla._XLAC._xla_set_mat_mul_precision(precision)
  # Flip the flag only after the backend call succeeds, so a failed
  # attempt does not count as a prior setting.
  set_mat_mul_precision._already_set = True
|
65 | 81 |
|
66 | 82 |
|
67 | 83 | def get_mat_mul_precision() -> _PrecisionType:
|
68 | 84 | """Get the current mat mul precision for 32bit inputs.
|
69 | 85 |
|
| 86 | + See the [precision tutorial](../../tutorials/precision_tutorial.html) |
| 87 | + for more information about the precision levels. |
| 88 | +
|
70 | 89 | Returns:
|
71 | 90 | str: The current precision setting for matrix multiplication,
|
72 | 91 | one of 'default', 'high', or 'highest'.
|
|
0 commit comments