diff --git a/test/test_mat_mul_precision.py b/test/test_mat_mul_precision.py index 6da3aec3f259..f3a32e3e7206 100644 --- a/test/test_mat_mul_precision.py +++ b/test/test_mat_mul_precision.py @@ -1,5 +1,6 @@ """Numeric tests for default precision of mat mul.""" +import logging import unittest import torch @@ -11,6 +12,9 @@ class TestMatMulPrecision(unittest.TestCase): + def setUp(self): + self.logger_name = torch_xla.backends.logger.name + def _make_input(self): eye = torch.eye(1024, device='cpu', dtype=torch.float64) rand_ = torch.testing.make_tensor((1024, 1024), @@ -53,7 +57,8 @@ def test_default(self): # DO NOT add epsilons to this test. These tests must be numerically exact. def _test_parameterized(self, precision, bits): # Arrange - torch_xla.backends.set_mat_mul_precision(precision) + with self.assertLogs(self.logger_name, level=logging.WARNING): + torch_xla.backends.set_mat_mul_precision(precision) # Diagonal matrices force mat mul through MXU # but require only one non-zero accumulation. diff --git a/test/test_mat_mul_precision_get_and_set.py b/test/test_mat_mul_precision_get_and_set.py index ad47a9cc0e60..e383ae097bec 100644 --- a/test/test_mat_mul_precision_get_and_set.py +++ b/test/test_mat_mul_precision_get_and_set.py @@ -1,5 +1,6 @@ """Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" +import logging import sys import unittest @@ -11,24 +12,38 @@ class TestMatMulPrecisionGetAndSet(unittest.TestCase): def setUp(self): - self._original = torch_xla.backends.get_mat_mul_precision() + self.logger_name = torch_xla.backends.logger.name + self._original_precision = torch_xla.backends.get_mat_mul_precision() torch.set_printoptions(precision=20) - torch_xla.sync() def tearDown(self): - torch_xla.backends.set_mat_mul_precision(self._original) + with self.assertLogs(self.logger_name, level=logging.WARNING): + torch_xla.backends.set_mat_mul_precision(self._original_precision) torch.set_printoptions(profile="default") - torch_xla.sync() + + def test_set_mat_mul_precision_warning(self): + # Arrange + expected = [(f"WARNING:{self.logger_name}:" + f"{torch_xla.backends._WARNING_MESSAGE}")] + + # Act + with self.assertLogs(self.logger_name, level=logging.WARNING) as cm: + torch_xla.backends.set_mat_mul_precision('default') + + # Assert + self.assertEqual(expected, cm.output) def test_set_mat_mul_precision_error(self): # Assert with self.assertRaises(ValueError): # Act - torch_xla.backends.set_mat_mul_precision('BAD VALUE') + with self.assertLogs(self.logger_name, level=logging.WARNING): + torch_xla.backends.set_mat_mul_precision('BAD VALUE') def test_get_and_set_mat_mul_precision_default(self): # Arrange - torch_xla.backends.set_mat_mul_precision('default') + with self.assertLogs(self.logger_name, level=logging.WARNING): + torch_xla.backends.set_mat_mul_precision('default') # Act status = torch_xla.backends.get_mat_mul_precision() @@ -38,7 +53,8 @@ def test_get_and_set_mat_mul_precision_default(self): def test_get_and_set_mat_mul_precision_high(self): # Arrange - torch_xla.backends.set_mat_mul_precision('high') + with self.assertLogs(self.logger_name, level=logging.WARNING): + torch_xla.backends.set_mat_mul_precision('high') # Act status = torch_xla.backends.get_mat_mul_precision() @@ -48,7 +64,8 @@ def test_get_and_set_mat_mul_precision_high(self): def test_get_and_set_mat_mul_precision_highest(self): # Arrange - torch_xla.backends.set_mat_mul_precision('highest') + with self.assertLogs(self.logger_name, level=logging.WARNING): + torch_xla.backends.set_mat_mul_precision('highest') # Act status = torch_xla.backends.get_mat_mul_precision() diff --git a/torch_xla/backends/__init__.py b/torch_xla/backends/__init__.py index 7256c33bc6bf..f3ee803e0c62 100644 --- a/torch_xla/backends/__init__.py +++ b/torch_xla/backends/__init__.py @@ -10,10 +10,18 @@ # Literal is available from Python 3.8, # matching the Python versions for PyTorch and PyTorch/XLA. +import logging from typing import Final, Literal, TypeAlias import torch_xla +# TODO: Refactor logging in torch_xla package https://github.com/pytorch/xla/issues/9142 +logger = logging.getLogger(__name__) +_WARNING_MESSAGE: Final[str] = ( + 'Setting mat mul precision multiple times is not ' + 'recommended. If you need to do so, please empirically ' + 'verify that the precision setting is behaving as expected.') + __all__ = ["set_mat_mul_precision", "get_mat_mul_precision"] # Valid values for get_mat_mul_precision/set_mat_mul_precision @@ -30,43 +38,54 @@ # Some of this description adapted from Jax documentation. -# TODO: Once the numerics tutorial is released, link from this docstring. def set_mat_mul_precision(precision: _PrecisionType) -> None: - """Control the default matmul and conv precision for 32bit inputs. + """Control the default mat mul and conv precision for 32bit inputs. Some platforms, like TPU, offer configurable precision levels for matrix multiplication and convolution computations, trading off accuracy for speed. This option controls the default precision level for - computations involved in matrix multiplication and convolution on + computations involved in matrix multiplication and convolutions on 32bit inputs. The levels describe the precision at which scalar products are computed. On a TPU: - * `default` is the fastest and least precise, - downcasting an FP32 to BF16 before multiplying. + `default` is the fastest and least precise, + downcasting an FP32 to BF16 before multiplying. + + `high` takes three passes and generates approximately 14 bits of + precision. + + `highest` is the most precise, and the slowest. It takes six + passes and generates approximately 22 bits of precision. - * `high` takes three passes and generates approximately 14 bits of - precision. + See the [precision tutorial](../../tutorials/precision_tutorial.html) + for more information about the precision levels. - * `highest` is the most precise, and the slowest. It takes six - passes and generates approximately 22 bits of precision. + Note: Setting mat mul precision multiple times is not recommended. + If you need to do so, please empirically verify that the precision + setting is behaving as expected. Args: - precision (str): The precision to set for matrix multiplication. - Must be one of 'default', 'high', or 'highest'. + precision (str): The precision to set for matrix multiplication. + Must be one of 'default', 'high', or 'highest'. """ if precision not in [_DEFAULT, _HIGH, _HIGHEST]: raise ValueError(f"Invalid precision: {precision}. " f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") + logger.warning(_WARNING_MESSAGE) + torch_xla._XLAC._xla_set_mat_mul_precision(precision) def get_mat_mul_precision() -> _PrecisionType: """Get the current mat mul precision for 32bit inputs. + See the [precision tutorial](../../tutorials/precision_tutorial.html) + for more information about the precision levels. + Returns: str: The current precision setting for matrix multiplication, one of 'default', 'high', or 'highest'.