Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions docs/sphinxext/jax_list_config_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from operator import itemgetter
from typing import Any, List

Expand All @@ -21,11 +22,37 @@

logger = logging.getLogger(__name__)

_deprecations = (
# Please add justification for why the option is to be hidden and/or when
# it should be revealed, e.g. when the option is deprecated or when it is
# no longer experimental.
_hidden_config_options = (
'jax_default_dtype_bits', # an experiment that we never documented, but we can't remove it because Keras depends on its existing broken behavior
'jax_serialization_version'
'jax_serialization_version',
'check_rep', # internal implementation detail of shard_map, DO NOT USE‰
)

def config_option_to_title_case(name: str) -> str:
"""Converts a config option name to title case, with special rules.

Args:
name: The configuration option name (e.g., "jax_default_dtype_bits").
capitalization_rules: An optional function that takes the name as input
and returns the title-cased name. If None, defaults to a basic
title-casing.
"""

# Define capitalization rules as a list of (string, replacement) tuples
capitalization_rules_list = [
("jax", "JAX"), ("xla", "XLA"), ("pgle", "PGLE"), ("cuda", "CUDA"), ("vjp", "VJP"), ("jvp", "JVP"),
("pjrt", "PjRT"), ("gpu", "GPU"), ("tpu", "TPU"), ("prng", "PRNG"), ("rocm", "ROCm"), ("spmd", "SPMD"),
("bcoo", "BCOO"), ("jit", "JIT"), ("cpu", "CPU"), ("cusparse", "cuSPARSE"), ("ir", "IR"), ("dtype", "DType"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jit should be lowercaps.

roocm is misspelled? rocm

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed rocm. These capitalization rules just get used to build the titles in the doc, so as an acronym jit I'm leaving jit capitalised.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PjRT -> PJRT

What about jit -> jit? Since it probably refers to the JAX API, not "Just In Time" (jax's API is kind of a misuse of the generic term, the verb is really "JIT compile")

Also numpy uses dtype, not Dtype, even when it's not dtype. There are examples in https://numpy.org/doc/2.1/reference/arrays.dtypes.html.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see Dtype because it's all Title Case. This is really yak shaving now but I personally prefer just capitalizing the first word, I think they look funny as Full Titles

Also nan -> NaN

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

("pprint", "PPrint"), ("x64", "x64")
]
name = name.replace("jax_", "").replace("_", " ").title()
for find, replace in capitalization_rules_list:
name = re.sub(rf"\b{find}\b", replace, name, flags=re.IGNORECASE)
return name

def create_field_item(label, content):
"""Create a field list item with a label and content side by side.

Expand Down Expand Up @@ -69,7 +96,7 @@ def run(self) -> List[nodes.Node]:
result = []

for name, (opt_type, meta_args, meta_kwargs) in config_options:
if name in _deprecations:
if name in _hidden_config_options:
continue

holder = jax_config._value_holders[name]
Expand All @@ -87,7 +114,7 @@ def run(self) -> List[nodes.Node]:
# Create a title with the option name (important for TOC)
title = nodes.title()
title['classes'] = ['h4']
title += nodes.Text(name.replace("jax_", "").replace("_", " ").title())
title += nodes.Text(config_option_to_title_case(name))
option_section += title

# Create a field list for side-by-side display
Expand Down
24 changes: 12 additions & 12 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
random_seed_offset = int_state(
name='jax_random_seed_offset',
default=0,
help=('Offset to all random seeds (e.g. argument to jax.random.key()).'),
help=('Offset to all random seeds (e.g. argument to `jax.random.key`).'),
include_in_jit_key=True,
)

Expand All @@ -1049,7 +1049,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
enum_values=['allow', 'warn', 'error'],
default='allow',
help=('Specify the behavior when raw PRNG keys are passed to '
'jax.random APIs.')
'`jax.random` APIs.')
)

enable_custom_prng = bool_state(
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
default=True,
help=('Adds varying manual axes to ShapedArray to track which mesh axes the'
' array is varying over. This will help to remove the efficient'
' transpose rewrite machinery in shard_map'),
' transpose rewrite machinery in `shard_map`'),
include_in_jit_key=True)

# TODO make it so people don't use this, this is internal...
Expand Down Expand Up @@ -1220,7 +1220,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
name='jax_enable_pgle',
default=False,
help=(
'If set to True and the property jax_pgle_profiling_runs is set to '
'If set to `True` and the property `"jax_pgle_profiling_runs"` is set to '
'greater than 0, the modules will be recompiled after running specified '
'number times with collected data provided to the profile guided latency '
'estimator.'
Expand Down Expand Up @@ -1251,15 +1251,15 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
help=('If set to False, the compilation cache will be disabled regardless '
'of whether set_cache_dir() was called. If set to True, the '
'path could be set to a default value or via a call to '
'set_cache_dir().'),
'`set_cache_dir`.'),
)

compilation_cache_dir = optional_string_state(
name='jax_compilation_cache_dir',
default=None,
help=('Path for the cache. '
'Precedence: '
'1. A call to compilation_cache.set_cache_dir(). '
'1. A call to `compilation_cache.set_cache_dir`. '
'2. The value of this flag set in the command line or by default.'),
)

Expand Down Expand Up @@ -1478,14 +1478,14 @@ def _update_disable_jit_thread_local(val):
"auto"],
default="auto",
help="Controls how JAX filters internal frames out of tracebacks. Valid values are:\n"
"- ``off``: disables traceback filtering.\n"
"- ``auto``: use ``tracebackhide`` if running under a sufficiently "
"* ``off``: disables traceback filtering.\n"
"* ``auto``: use ``tracebackhide`` if running under a sufficiently "
"new IPython, or ``remove_frames`` otherwise.\n"
"- ``tracebackhide``: adds ``__tracebackhide__`` annotations to "
"* ``tracebackhide``: adds ``__tracebackhide__`` annotations to "
"hidden stack frames, which some traceback printers support.\n"
"- ``remove_frames``: removes hidden frames from tracebacks, and adds "
"* ``remove_frames``: removes hidden frames from tracebacks, and adds "
"the unfiltered traceback as a ``__cause__`` of the exception.\n"
"- ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds "
"* ``quiet_remove_frames``: removes hidden frames from tracebacks, and adds "
"a brief message (to the ``__cause__`` of the exception) describing that this has "
"happened.\n\n")

Expand Down Expand Up @@ -1765,7 +1765,7 @@ def _update_garbage_collection_guard(state, key, val):
help=(
'Whether to lower to Shardy. Shardy is a new open sourced propagation '
'framework for MLIR. Currently Shardy is experimental in JAX. See '
'www.github.com/openxla/shardy'
'`Shardy <www.github.com/openxla/shardy>_`.'
),
include_in_jit_key=True,
)
Expand Down
Loading