diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 05adbd624c..34d4541a98 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -12,6 +12,8 @@ from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +_logger = logging.getLogger(__name__) + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -36,7 +38,7 @@ def _load_library(): if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( + _logger.info( "Could not find package %s. Install transformer-engine using 'pip" " install transformer-engine[jax]==VERSION'", module_name, diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py index 50cf2186d6..78912b0cce 100644 --- a/transformer_engine/paddle/__init__.py +++ b/transformer_engine/paddle/__init__.py @@ -11,6 +11,8 @@ from transformer_engine.common import is_package_installed +_logger = logging.getLogger(__name__) + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -35,7 +37,7 @@ def _load_library(): if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( + _logger.info( "Could not find package %s. Install transformer-engine using 'pip" " install transformer-engine[paddle]==VERSION'", module_name, diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 781f9d42fd..8c0d032eb5 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -16,6 +16,8 @@ from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +_logger = logging.getLogger(__name__) + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -40,7 +42,7 @@ def _load_library(): if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( + _logger.info( "Could not find package %s. Install transformer-engine using 'pip" " install transformer-engine[pytorch]==VERSION'", module_name, diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9268b9636e..2f65e579d9 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -97,7 +97,7 @@ _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") _stream_handler = logging.StreamHandler() _stream_handler.setFormatter(_formatter) -fa_logger = logging.getLogger() +fa_logger = logging.getLogger(__name__) fa_logger.setLevel(_log_level) if not fa_logger.hasHandlers(): fa_logger.addHandler(_stream_handler)