Skip to content

Commit 0a970e4

Browse files
Jake VanderPlasGoogle-ML-Automation
authored andcommitted
Move _src/custom_transpose.py into its own BUILD rule
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. PiperOrigin-RevId: 762414305
1 parent 12966b5 commit 0a970e4

2 files changed

Lines changed: 21 additions & 1 deletion

File tree

jax/BUILD

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ py_library_providing_imports_info(
307307
"_src/custom_derivatives.py",
308308
"_src/custom_partitioning.py",
309309
"_src/custom_partitioning_sharding_rule.py",
310-
"_src/custom_transpose.py",
311310
"_src/debugging.py",
312311
"_src/dispatch.py",
313312
"_src/dlpack.py",
@@ -391,6 +390,7 @@ py_library_providing_imports_info(
391390
":config",
392391
":core",
393392
":custom_api_util",
393+
":custom_transpose",
394394
":deprecations",
395395
":dtypes",
396396
":effects",
@@ -595,6 +595,25 @@ pytype_strict_library(
595595
srcs = ["_src/custom_api_util.py"],
596596
)
597597

598+
pytype_strict_library(
599+
name = "custom_transpose",
600+
srcs = ["_src/custom_transpose.py"],
601+
deps = [
602+
":ad",
603+
":ad_util",
604+
":api_util",
605+
":core",
606+
":custom_api_util",
607+
":mlir",
608+
":partial_eval",
609+
":source_info_util",
610+
":traceback_util",
611+
":tree_util",
612+
":util",
613+
":xla",
614+
],
615+
)
616+
598617
pytype_strict_library(
599618
name = "deprecations",
600619
srcs = ["_src/deprecations.py"],

tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ jax_multiplatform_test(
6161
srcs = ["debug_info_test.py"],
6262
enable_configs = ["tpu_v3_x4"],
6363
deps = [
64+
"//jax:custom_transpose",
6465
"//jax:experimental",
6566
"//jax:pallas",
6667
"//jax:pallas_gpu",

0 commit comments

Comments
 (0)