Skip to content

Commit e71d5d5

Browse files
Merge pull request #28899 from jenriver:absl_logging_fix
PiperOrigin-RevId: 762099728
2 parents fdf6e1f + 7cf4f35 commit e71d5d5

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

docs/export/export.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -710,10 +710,7 @@ total 32
710710
-rw-rw-r--@ 1 necula wheel 2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir
711711
```
712712

713-
Inside Google, you can turn on logging by using the `--vmodule` argument to
714-
specify the logging levels for different modules,
715-
e.g., `--vmodule=_export=3`.
716-
713+
Set [`JAX_DEBUG_LOG_MODULES=jax._src.export`](https://docs.jax.dev/en/latest/config_options.html#jax_debug_log_modules) to enable extra debugging logging.
717714

718715
(export_ensuring_compat)=
719716
### Ensuring forward and backward compatibility

jax/_src/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def get_compile_options(
241241
else:
242242
compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE
243243
if backend is None:
244-
logging.info("get_compile_options: no backend supplied; "
244+
logger.info("get_compile_options: no backend supplied; "
245245
"disabling XLA-AutoFDO profile")
246246
else:
247247
fdo_profile_version = get_latest_profile_version(backend)
@@ -369,7 +369,7 @@ def compile_or_get_cached(
369369
module_name = ir.StringAttr(sym_name).value
370370

371371
if dumped_to := mlir.dump_module_to_file(computation, "compile"):
372-
logging.info("Dumped the module to %s.", dumped_to)
372+
logger.info("Dumped the module to %s.", dumped_to)
373373

374374
is_multi_process = (
375375
len({device.process_index for device in devices.flatten()}) > 1
@@ -514,7 +514,7 @@ def _resolve_compilation_strategy(
514514
# The compilation cache is enabled and AutoPGLE is enabled/expected
515515
if _is_executable_in_cache(backend, pgle_optimized_cache_key):
516516
if config.compilation_cache_expect_pgle.value:
517-
logging.info(f"PGLE-optimized {module_name} loaded from compilation cache")
517+
logger.info(f"PGLE-optimized {module_name} loaded from compilation cache")
518518
# No need to record N profiles in this case
519519
if pgle_profiler is not None:
520520
pgle_profiler.disable()

jax/_src/export/_export.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import re
2626
from typing import Any, Protocol, TypeVar, Union, cast
2727

28-
from absl import logging
28+
import logging
2929
import numpy as np
3030

3131
import jax
@@ -55,6 +55,8 @@
5555

5656
from jax._src.export import shape_poly
5757

58+
logger = logging.getLogger(__name__)
59+
5860
map = util.safe_map
5961
zip = util.safe_zip
6062

@@ -704,16 +706,15 @@ def _export_lowered(
704706
out_avals_flat = lowered.compile_args["out_avals"] # type: ignore
705707

706708
# Log and then check the module.
707-
if logging.vlog_is_on(3):
708-
logmsg = (f"fun_name={fun_name} version={version} "
709-
f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error]
710-
f"disabled_checks={disabled_checks}")
711-
logging.info("Exported JAX function: %s\n", logmsg)
712-
logging.info(mlir.dump_module_message(mlir_module, "export"))
713-
logging.info(
714-
"Size of mlir_module_serialized: %d byte",
715-
len(mlir_module_serialized),
716-
)
709+
logmsg = (f"fun_name={fun_name} version={version} "
710+
f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error]
711+
f"disabled_checks={disabled_checks}")
712+
logger.debug("Exported JAX function: %s\n", logmsg)
713+
logger.debug(mlir.dump_module_message(mlir_module, "export"))
714+
logger.debug(
715+
"Size of mlir_module_serialized: %d byte",
716+
len(mlir_module_serialized),
717+
)
717718

718719
_check_module(mlir_module,
719720
disabled_checks=disabled_checks,

0 commit comments

Comments
 (0)