Commit c7ef439

Subset of existing, approved PR (#9204)

Approved by Zhanyong to merge despite the GPU CI/CD test failure, given the non-required, flaky status of the GPU tests.

1 parent c0f1e62 · commit c7ef439

File tree

3 files changed: +61 −20 lines

test/test_mat_mul_precision.py

Lines changed: 6 additions & 1 deletion

@@ -1,5 +1,6 @@
 """Numeric tests for default precision of mat mul."""

+import logging
 import unittest

 import torch
@@ -11,6 +12,9 @@

 class TestMatMulPrecision(unittest.TestCase):

+  def setUp(self):
+    self.logger_name = torch_xla.backends.logger.name
+
   def _make_input(self):
     eye = torch.eye(1024, device='cpu', dtype=torch.float64)
     rand_ = torch.testing.make_tensor((1024, 1024),
@@ -53,7 +57,8 @@ def test_default(self):
   # DO NOT add epsilons to this test. These tests must be numerically exact.
   def _test_parameterized(self, precision, bits):
     # Arrange
-    torch_xla.backends.set_mat_mul_precision(precision)
+    with self.assertLogs(self.logger_name, level=logging.WARNING):
+      torch_xla.backends.set_mat_mul_precision(precision)

     # Diagonal matrices force mat mul through MXU
     # but require only one non-zero accumulation.
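Because set_mat_mul_precision now logs a warning on every call, each existing call site in the tests is wrapped in unittest's assertLogs context manager. A minimal, self-contained sketch of that pattern follows; the logger name and the set_precision function are illustrative stand-ins, not torch_xla APIs.

# Minimal sketch of the assertLogs wrapping pattern used in this diff.
# "example.backends" and set_precision are hypothetical stand-ins.
import logging
import unittest

logger = logging.getLogger("example.backends")


def set_precision(precision):
  # Stand-in for a setter that warns on every call before applying the setting.
  logger.warning("Setting precision multiple times is not recommended.")
  return precision


class TestWarningPattern(unittest.TestCase):

  def test_warns_on_set(self):
    # assertLogs fails the test unless at least one WARNING (or higher)
    # record is emitted on the named logger inside the with-block.
    with self.assertLogs(logger.name, level=logging.WARNING):
      set_precision("default")


if __name__ == "__main__":
  unittest.main()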

test/test_mat_mul_precision_get_and_set.py

Lines changed: 25 additions & 8 deletions

@@ -1,5 +1,6 @@
 """Tests for get/set_mat_mul_precision from init_python_bindings.cpp"""

+import logging
 import sys
 import unittest

@@ -11,24 +12,38 @@
 class TestMatMulPrecisionGetAndSet(unittest.TestCase):

   def setUp(self):
-    self._original = torch_xla.backends.get_mat_mul_precision()
+    self.logger_name = torch_xla.backends.logger.name
+    self._original_precision = torch_xla.backends.get_mat_mul_precision()
     torch.set_printoptions(precision=20)
-    torch_xla.sync()

   def tearDown(self):
-    torch_xla.backends.set_mat_mul_precision(self._original)
+    with self.assertLogs(self.logger_name, level=logging.WARNING):
+      torch_xla.backends.set_mat_mul_precision(self._original_precision)
     torch.set_printoptions(profile="default")
-    torch_xla.sync()
+
+  def test_set_mat_mul_precision_warning(self):
+    # Arrange
+    expected = [(f"WARNING:{self.logger_name}:"
+                 f"{torch_xla.backends._WARNING_MESSAGE}")]
+
+    # Act
+    with self.assertLogs(self.logger_name, level=logging.WARNING) as cm:
+      torch_xla.backends.set_mat_mul_precision('default')
+
+    # Assert
+    self.assertEqual(expected, cm.output)

   def test_set_mat_mul_precision_error(self):
     # Assert
     with self.assertRaises(ValueError):
       # Act
-      torch_xla.backends.set_mat_mul_precision('BAD VALUE')
+      with self.assertLogs(self.logger_name, level=logging.WARNING):
+        torch_xla.backends.set_mat_mul_precision('BAD VALUE')

   def test_get_and_set_mat_mul_precision_default(self):
     # Arrange
-    torch_xla.backends.set_mat_mul_precision('default')
+    with self.assertLogs(self.logger_name, level=logging.WARNING):
+      torch_xla.backends.set_mat_mul_precision('default')

     # Act
     status = torch_xla.backends.get_mat_mul_precision()
@@ -38,7 +53,8 @@ def test_get_and_set_mat_mul_precision_default(self):

   def test_get_and_set_mat_mul_precision_high(self):
     # Arrange
-    torch_xla.backends.set_mat_mul_precision('high')
+    with self.assertLogs(self.logger_name, level=logging.WARNING):
+      torch_xla.backends.set_mat_mul_precision('high')

     # Act
     status = torch_xla.backends.get_mat_mul_precision()
@@ -48,7 +64,8 @@ def test_get_and_set_mat_mul_precision_high(self):

   def test_get_and_set_mat_mul_precision_highest(self):
     # Arrange
-    torch_xla.backends.set_mat_mul_precision('highest')
+    with self.assertLogs(self.logger_name, level=logging.WARNING):
+      torch_xla.backends.set_mat_mul_precision('highest')

     # Act
     status = torch_xla.backends.get_mat_mul_precision()
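The new test_set_mat_mul_precision_warning compares against cm.output, which unittest formats as "LEVEL:logger_name:message" strings. Below is a standalone sketch of that captured-output format, using a placeholder logger name and message rather than the real torch_xla values.

# Sketch of the assertLogs captured-output format relied on above;
# "example.backends" and the message text are placeholders.
import logging
import unittest


class TestCapturedOutputFormat(unittest.TestCase):

  def test_output_format(self):
    logger = logging.getLogger("example.backends")
    with self.assertLogs(logger.name, level=logging.WARNING) as cm:
      logger.warning("do not set precision twice")
    # cm.output is a list with one "LEVEL:logger_name:message" string
    # per captured record.
    self.assertEqual(
        cm.output,
        ["WARNING:example.backends:do not set precision twice"])


if __name__ == "__main__":
  unittest.main()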

torch_xla/backends/__init__.py

Lines changed: 30 additions & 11 deletions

@@ -10,10 +10,18 @@

 # Literal is available from Python 3.8,
 # matching the Python versions for PyTorch and PyTorch/XLA.
+import logging
 from typing import Final, Literal, TypeAlias

 import torch_xla

+# TODO: Refactor logging in torch_xla package https://github.com/pytorch/xla/issues/9142
+logger = logging.getLogger(__name__)
+_WARNING_MESSAGE: Final[str] = (
+    'Setting mat mul precision multiple times is not '
+    'recommended. If you need to do so, please empirically '
+    'verify that the precision setting is behaving as expected.')
+
 __all__ = ["set_mat_mul_precision", "get_mat_mul_precision"]

 # Valid values for get_mat_mul_precision/set_mat_mul_precision
@@ -30,43 +38,54 @@


 # Some of this description adapted from Jax documentation.
-# TODO: Once the numerics tutorial is released, link from this docstring.
 def set_mat_mul_precision(precision: _PrecisionType) -> None:
-  """Control the default matmul and conv precision for 32bit inputs.
+  """Control the default mat mul and conv precision for 32bit inputs.

   Some platforms, like TPU, offer configurable precision levels for
   matrix multiplication and convolution computations,
   trading off accuracy for speed.

   This option controls the default precision level for
-  computations involved in matrix multiplication and convolution on
+  computations involved in matrix multiplication and convolutions on
   32bit inputs. The levels describe the precision at
   which scalar products are computed.

   On a TPU:
-  * `default` is the fastest and least precise,
-    downcasting an FP32 to BF16 before multiplying.
+  `default` is the fastest and least precise,
+  downcasting an FP32 to BF16 before multiplying.
+
+  `high` takes three passes and generates approximately 14 bits of
+  precision.
+
+  `highest` is the most precise, and the slowest. It takes six
+  passes and generates approximately 22 bits of precision.

-  * `high` takes three passes and generates approximately 14 bits of
-    precision.
+  See the [precision tutorial](../../tutorials/precision_tutorial.html)
+  for more information about the precision levels.

-  * `highest` is the most precise, and the slowest. It takes six
-    passes and generates approximately 22 bits of precision.
+  Note: Setting mat mul precision multiple times is not recommended.
+  If you need to do so, please empirically verify that the precision
+  setting is behaving as expected.

   Args:
-    precision (str): The precision to set for matrix multiplication.
-      Must be one of 'default', 'high', or 'highest'.
+    precision (str): The precision to set for matrix multiplication.
+      Must be one of 'default', 'high', or 'highest'.
   """
   if precision not in [_DEFAULT, _HIGH, _HIGHEST]:
     raise ValueError(f"Invalid precision: {precision}. "
                      f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")

+  logger.warning(_WARNING_MESSAGE)
+
   torch_xla._XLAC._xla_set_mat_mul_precision(precision)


 def get_mat_mul_precision() -> _PrecisionType:
   """Get the current mat mul precision for 32bit inputs.

+  See the [precision tutorial](../../tutorials/precision_tutorial.html)
+  for more information about the precision levels.
+
   Returns:
     str: The current precision setting for matrix multiplication,
       one of 'default', 'high', or 'highest'.
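For reference, a short usage sketch of the API touched in this file. It assumes a working torch_xla installation (the precision setting only takes effect on hardware, such as TPU, with configurable mat mul precision); the printed values are illustrative.

# Usage sketch for the set/get API in torch_xla.backends; requires torch_xla.
import logging

import torch_xla.backends

logging.basicConfig(level=logging.WARNING)

# Each call logs _WARNING_MESSAGE about setting the precision multiple times.
torch_xla.backends.set_mat_mul_precision("high")

# Returns one of 'default', 'high', or 'highest'.
print(torch_xla.backends.get_mat_mul_precision())  # expected: 'high'

# Anything outside the three accepted values raises ValueError.
try:
  torch_xla.backends.set_mat_mul_precision("BAD VALUE")
except ValueError as err:
  print(err)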
