diff --git a/docs/sphinxext/jax_list_config_options.py b/docs/sphinxext/jax_list_config_options.py index 54f7f6eebe85..643b7123a5ff 100644 --- a/docs/sphinxext/jax_list_config_options.py +++ b/docs/sphinxext/jax_list_config_options.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from operator import itemgetter from typing import Any, List @@ -21,11 +22,37 @@ logger = logging.getLogger(__name__) -_deprecations = ( +# Please add justification for why the option is to be hidden and/or when +# it should be revealed, e.g. when the option is deprecated or when it is +# no longer experimental. +_hidden_config_options = ( 'jax_default_dtype_bits', # an experiment that we never documented, but we can't remove it because Keras depends on its existing broken behavior - 'jax_serialization_version' + 'jax_serialization_version', + 'check_rep', # internal implementation detail of shard_map, DO NOT USE‰ ) +def config_option_to_title_case(name: str) -> str: + """Converts a config option name to title case, with special rules. + + Args: + name: The configuration option name (e.g., "jax_default_dtype_bits"). + capitalization_rules: An optional function that takes the name as input + and returns the title-cased name. If None, defaults to a basic + title-casing. + """ + + # Define capitalization rules as a list of (string, replacement) tuples + capitalization_rules_list = [ + ("jax", "JAX"), ("xla", "XLA"), ("pgle", "PGLE"), ("cuda", "CUDA"), ("vjp", "VJP"), ("jvp", "JVP"), + ("pjrt", "PJRT"), ("gpu", "GPU"), ("tpu", "TPU"), ("prng", "PRNG"), ("rocm", "ROCm"), ("spmd", "SPMD"), + ("bcoo", "BCOO"), ("jit", "JIT"), ("cpu", "CPU"), ("cusparse", "cuSPARSE"), ("ir", "IR"), + ("pprint", "PPrint"), ("x64", "x64"), ("nan", "NaN") + ] + name = name.replace("jax_", "").replace("_", " ").title() + for find, replace in capitalization_rules_list: + name = re.sub(rf"\b{find}\b", replace, name, flags=re.IGNORECASE) + return name + def create_field_item(label, content): """Create a field list item with a label and content side by side. @@ -69,7 +96,7 @@ def run(self) -> List[nodes.Node]: result = [] for name, (opt_type, meta_args, meta_kwargs) in config_options: - if name in _deprecations: + if name.startswith("_") or name in _hidden_config_options: continue holder = jax_config._value_holders[name] @@ -87,7 +114,7 @@ def run(self) -> List[nodes.Node]: # Create a title with the option name (important for TOC) title = nodes.title() title['classes'] = ['h4'] - title += nodes.Text(name.replace("jax_", "").replace("_", " ").title()) + title += nodes.Text(config_option_to_title_case(name)) option_section += title # Create a field list for side-by-side display diff --git a/jax/_src/config.py b/jax/_src/config.py index a4d9b5582566..ec2cd7045ed7 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1040,7 +1040,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: random_seed_offset = int_state( name='jax_random_seed_offset', default=0, - help=('Offset to all random seeds (e.g. argument to jax.random.key()).'), + help=('Offset to all random seeds (e.g. argument to `jax.random.key`).'), include_in_jit_key=True, ) @@ -1049,7 +1049,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: enum_values=['allow', 'warn', 'error'], default='allow', help=('Specify the behavior when raw PRNG keys are passed to ' - 'jax.random APIs.') + '`jax.random` APIs.') ) enable_custom_prng = bool_state( @@ -1097,7 +1097,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: default=True, help=('Adds varying manual axes to ShapedArray to track which mesh axes the' ' array is varying over. This will help to remove the efficient' - ' transpose rewrite machinery in shard_map'), + ' transpose rewrite machinery in `shard_map`'), include_in_jit_key=True) # TODO make it so people don't use this, this is internal... @@ -1220,7 +1220,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: name='jax_enable_pgle', default=False, help=( - 'If set to True and the property jax_pgle_profiling_runs is set to ' + 'If set to `True` and the property `"jax_pgle_profiling_runs"` is set to ' 'greater than 0, the modules will be recompiled after running specified ' 'number times with collected data provided to the profile guided latency ' 'estimator.' @@ -1251,7 +1251,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: help=('If set to False, the compilation cache will be disabled regardless ' 'of whether set_cache_dir() was called. If set to True, the ' 'path could be set to a default value or via a call to ' - 'set_cache_dir().'), + '`set_cache_dir`.'), ) compilation_cache_dir = optional_string_state( @@ -1259,7 +1259,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]: default=None, help=('Path for the cache. ' 'Precedence: ' - '1. A call to compilation_cache.set_cache_dir(). ' + '1. A call to `compilation_cache.set_cache_dir`. ' '2. The value of this flag set in the command line or by default.'), ) @@ -1478,14 +1478,14 @@ def _update_disable_jit_thread_local(val): "auto"], default="auto", help="Controls how JAX filters internal frames out of tracebacks. Valid values are:\n" - "- ``off``: disables traceback filtering.\n" - "- ``auto``: use ``tracebackhide`` if running under a sufficiently " + "* ``off``: disables traceback filtering.\n" + "* ``auto``: use ``tracebackhide`` if running under a sufficiently " "new IPython, or ``remove_frames`` otherwise.\n" - "- ``tracebackhide``: adds ``__tracebackhide__`` annotations to " + "* ``tracebackhide``: adds ``__tracebackhide__`` annotations to " "hidden stack frames, which some traceback printers support.\n" - "- ``remove_frames``: removes hidden frames from tracebacks, and adds " + "* ``remove_frames``: removes hidden frames from tracebacks, and adds " "the unfiltered traceback as a ``__cause__`` of the exception.\n" - "- ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds " + "* ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds " "a brief message (to the ``__cause__`` of the exception) describing that this has " "happened.\n\n") @@ -1765,7 +1765,7 @@ def _update_garbage_collection_guard(state, key, val): help=( 'Whether to lower to Shardy. Shardy is a new open sourced propagation ' 'framework for MLIR. Currently Shardy is experimental in JAX. See ' - 'www.github.com/openxla/shardy' + '`Shardy _`.' ), include_in_jit_key=True, )