Skip to content

Subset of existing, approved PR #9204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion test/test_mat_mul_precision.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Numeric tests for default precision of mat mul."""

import logging
import unittest

import torch
Expand All @@ -11,6 +12,9 @@

class TestMatMulPrecision(unittest.TestCase):

def setUp(self):
    """Cache the backends logger name used by the assertLogs checks below."""
    self.logger_name = torch_xla.backends.logger.name

def _make_input(self):
eye = torch.eye(1024, device='cpu', dtype=torch.float64)
rand_ = torch.testing.make_tensor((1024, 1024),
Expand Down Expand Up @@ -53,7 +57,8 @@ def test_default(self):
# DO NOT add epsilons to this test. These tests must be numerically exact.
def _test_parameterized(self, precision, bits):
# Arrange
torch_xla.backends.set_mat_mul_precision(precision)
with self.assertLogs(self.logger_name, level=logging.WARNING):
torch_xla.backends.set_mat_mul_precision(precision)

# Diagonal matrices force mat mul through MXU
# but require only one non-zero accumulation.
Expand Down
33 changes: 25 additions & 8 deletions test/test_mat_mul_precision_get_and_set.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp"""

import logging
import sys
import unittest

Expand All @@ -11,24 +12,38 @@
class TestMatMulPrecisionGetAndSet(unittest.TestCase):

def setUp(self):
    """Record the current precision and logger so each test can verify logs
    and tearDown can restore the original setting."""
    backends = torch_xla.backends
    self.logger_name = backends.logger.name
    self._original_precision = backends.get_mat_mul_precision()
    # Print at full precision so numeric mismatches are visible in failures.
    torch.set_printoptions(precision=20)
    torch_xla.sync()

def tearDown(self):
    """Restore the precision captured in setUp and reset print options."""
    # Restoring the precision re-triggers the repeated-set warning, so the
    # call is wrapped to keep the test output clean.
    with self.assertLogs(self.logger_name, level=logging.WARNING):
        torch_xla.backends.set_mat_mul_precision(self._original_precision)
    torch.set_printoptions(profile="default")
    torch_xla.sync()

def test_set_mat_mul_precision_warning(self):
    """Every call to set_mat_mul_precision emits the module-level warning."""
    # Arrange
    expected_record = (
        f"WARNING:{self.logger_name}:{torch_xla.backends._WARNING_MESSAGE}")

    # Act
    with self.assertLogs(self.logger_name, level=logging.WARNING) as logs:
        torch_xla.backends.set_mat_mul_precision('default')

    # Assert
    self.assertEqual([expected_record], logs.output)

def test_set_mat_mul_precision_error(self):
    """An invalid precision string must raise ValueError.

    Note: no assertLogs wrapper here on purpose. set_mat_mul_precision
    validates its argument and raises *before* it emits the repeated-set
    warning, and assertLogs does not check any records when an exception
    propagates out of its context — so wrapping this call would assert
    nothing while implying a warning is expected.
    """
    # Act / Assert
    with self.assertRaises(ValueError):
        torch_xla.backends.set_mat_mul_precision('BAD VALUE')

def test_get_and_set_mat_mul_precision_default(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('default')
with self.assertLogs(self.logger_name, level=logging.WARNING):
torch_xla.backends.set_mat_mul_precision('default')

# Act
status = torch_xla.backends.get_mat_mul_precision()
Expand All @@ -38,7 +53,8 @@ def test_get_and_set_mat_mul_precision_default(self):

def test_get_and_set_mat_mul_precision_high(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('high')
with self.assertLogs(self.logger_name, level=logging.WARNING):
torch_xla.backends.set_mat_mul_precision('high')

# Act
status = torch_xla.backends.get_mat_mul_precision()
Expand All @@ -48,7 +64,8 @@ def test_get_and_set_mat_mul_precision_high(self):

def test_get_and_set_mat_mul_precision_highest(self):
# Arrange
torch_xla.backends.set_mat_mul_precision('highest')
with self.assertLogs(self.logger_name, level=logging.WARNING):
torch_xla.backends.set_mat_mul_precision('highest')

# Act
status = torch_xla.backends.get_mat_mul_precision()
Expand Down
41 changes: 30 additions & 11 deletions torch_xla/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@

# Literal is available from Python 3.8,
# matching the Python versions for PyTorch and PyTorch/XLA.
import logging
from typing import Final, Literal, TypeAlias

import torch_xla

# TODO: Refactor logging in torch_xla package https://github.com/pytorch/xla/issues/9142
logger = logging.getLogger(__name__)
_WARNING_MESSAGE: Final[str] = (
'Setting mat mul precision multiple times is not '
'recommended. If you need to do so, please empirically '
'verify that the precision setting is behaving as expected.')

__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"]

# Valid values for get_mat_mul_precision/set_mat_mul_precision
Expand All @@ -30,43 +38,54 @@


# Some of this description adapted from Jax documentation.
# TODO: Once the numerics tutorial is released, link from this docstring.
def set_mat_mul_precision(precision: _PrecisionType) -> None:
  """Control the default mat mul and conv precision for 32bit inputs.

  Platforms such as TPU expose configurable precision levels for matrix
  multiplication and convolution computations, trading off accuracy for
  speed. This option sets the default level used for computations involved
  in matrix multiplication and convolutions on 32bit inputs. The levels
  describe the precision at which scalar products are computed.

  On a TPU:

  `default` is the fastest and least precise; it downcasts an FP32 to
  BF16 before multiplying.

  `high` takes three passes and generates approximately 14 bits of
  precision.

  `highest` is the most precise and the slowest; it takes six passes and
  generates approximately 22 bits of precision.

  See the [precision tutorial](../../tutorials/precision_tutorial.html)
  for more information about the precision levels.

  Note: Setting mat mul precision multiple times is not recommended.
  If you need to do so, please empirically verify that the precision
  setting is behaving as expected.

  Args:
      precision (str): The precision to set for matrix multiplication.
          Must be one of 'default', 'high', or 'highest'.

  Raises:
      ValueError: If ``precision`` is not one of the accepted values.
  """
  if precision not in (_DEFAULT, _HIGH, _HIGHEST):
    raise ValueError(f"Invalid precision: {precision}. "
                     f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")

  # Warn on every call: repeated precision changes are discouraged and
  # should be empirically verified by the caller.
  logger.warning(_WARNING_MESSAGE)

  torch_xla._XLAC._xla_set_mat_mul_precision(precision)


def get_mat_mul_precision() -> _PrecisionType:
"""Get the current mat mul precision for 32bit inputs.

See the [precision tutorial](../../tutorials/precision_tutorial.html)
for more information about the precision levels.

Returns:
str: The current precision setting for matrix multiplication,
one of 'default', 'high', or 'highest'.
Expand Down
Loading