@@ -1,4 +1,3 @@
-import contextlib
 import functools
 import math
 import os
@@ -14,6 +13,7 @@
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import distributed_embedding
 from keras_rs.src.layers.embedding import distributed_embedding_config as config
+from keras_rs.src.utils import tpu_test_utils

 try:
     import jax
@@ -30,28 +30,6 @@
 SEQUENCE_LENGTH = 13


-class DummyStrategy:
-    def scope(self):
-        return contextlib.nullcontext()
-
-    @property
-    def num_replicas_in_sync(self):
-        return 1
-
-    def run(self, fn, args):
-        return fn(*args)
-
-    def experimental_distribute_dataset(self, dataset, options=None):
-        del options
-        return dataset
-
-
-class JaxDummyStrategy(DummyStrategy):
-    @property
-    def num_replicas_in_sync(self):
-        return jax.device_count("tpu")
-
-
 def ragged_bool_true(self):
     return True

@@ -74,46 +52,10 @@ def setUp(self):
         # FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
         # FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16

-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
-            tf.config.experimental_connect_to_cluster(resolver)
-
-            topology = tf.tpu.experimental.initialize_tpu_system(resolver)
-            tpu_metadata = resolver.get_tpu_system_metadata()
-
-            device_assignment = tf.tpu.experimental.DeviceAssignment.build(
-                topology, num_replicas=tpu_metadata.num_hosts
-            )
-            self._strategy = tf.distribute.TPUStrategy(
-                resolver, experimental_device_assignment=device_assignment
-            )
-            print("### num_replicas", self._strategy.num_replicas_in_sync)
-            self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
-        elif keras.backend.backend() == "jax" and self.on_tpu:
-            self._strategy = JaxDummyStrategy()
-        else:
-            self._strategy = DummyStrategy()
-
         self.batch_size = (
-            BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
+            BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync
         )

-    def run_with_strategy(self, fn, *args, jit_compile=False):
-        """Wrapper for running a function under a strategy."""
-
-        if keras.backend.backend() == "tensorflow":
-
-            @tf.function(jit_compile=jit_compile)
-            def tf_function_wrapper(*tf_function_args):
-                def strategy_fn(*strategy_fn_args):
-                    return fn(*strategy_fn_args)
-
-                return self._strategy.run(strategy_fn, args=tf_function_args)
-
-            return tf_function_wrapper(*args)
-        else:
-            self.assertFalse(jit_compile)
-            return fn(*args)
-
     def get_embedding_config(self, input_type, placement):
         sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH

@@ -252,18 +194,18 @@ def test_basics(self, input_type, placement):

         if placement == "sparsecore" and not self.on_tpu:
             with self.assertRaisesRegex(Exception, "sparsecore"):
-                with self._strategy.scope():
-                    distributed_embedding.DistributedEmbedding(feature_configs)
+                distributed_embedding.DistributedEmbedding(feature_configs)
             return

-        with self._strategy.scope():
-            layer = distributed_embedding.DistributedEmbedding(feature_configs)
+        layer = distributed_embedding.DistributedEmbedding(feature_configs)

         if keras.backend.backend() == "jax":
             preprocessed_inputs = layer.preprocess(inputs, weights)
             res = layer(preprocessed_inputs)
         else:
-            res = self.run_with_strategy(layer.__call__, inputs, weights)
+            res = tpu_test_utils.run_with_strategy(
+                self.strategy, layer.__call__, inputs, weights
+            )

         if placement == "default_device" or not self.on_tpu:
             # verify sublayers and variables are tracked
@@ -332,8 +274,7 @@ def test_model_fit(self, input_type, use_weights):
             (test_model_inputs, test_labels)
         )

-        with self._strategy.scope():
-            layer = distributed_embedding.DistributedEmbedding(feature_configs)
+        layer = distributed_embedding.DistributedEmbedding(feature_configs)

         def _create_keras_input(
             feature_config: config.FeatureConfig, dtype: types.DType
@@ -403,7 +344,7 @@ def test_dataset_generator():
             # New preprocessed data removes the `weights` component.
             dataset_has_weights = False
         else:
-            train_dataset = self._strategy.experimental_distribute_dataset(
+            train_dataset = self.strategy.experimental_distribute_dataset(
                 train_dataset,
                 options=tf.distribute.InputOptions(
                     experimental_fetch_to_device=False
@@ -418,19 +359,18 @@ def test_dataset_generator():
             inputs=keras_model_inputs, outputs=keras_model_outputs
         )

-        with self._strategy.scope():
-            model.compile(optimizer="adam", loss="mse")
+        model.compile(optimizer="adam", loss="mse")

-            model_inputs, _ = next(iter(test_dataset))
-            test_output_before = self.run_with_strategy(
-                model.__call__, model_inputs
-            )
+        model_inputs, _ = next(iter(test_dataset))
+        test_output_before = tpu_test_utils.run_with_strategy(
+            self.strategy, model.__call__, model_inputs
+        )

-            model.fit(train_dataset, steps_per_epoch=1, epochs=1)
+        model.fit(train_dataset, steps_per_epoch=1, epochs=1)

-            test_output_after = self.run_with_strategy(
-                model.__call__, model_inputs
-            )
+        test_output_after = tpu_test_utils.run_with_strategy(
+            self.strategy, model.__call__, model_inputs
+        )

         # Verify that the embedding has actually trained.
         for before, after in zip(
@@ -567,8 +507,7 @@ def test_correctness(
         if not use_weights:
             weights = None

-        with self._strategy.scope():
-            layer = distributed_embedding.DistributedEmbedding(feature_config)
+        layer = distributed_embedding.DistributedEmbedding(feature_config)

         if keras.backend.backend() == "jax":
             preprocessed = layer.preprocess(inputs, weights)
@@ -610,16 +549,21 @@ def test_correctness(
                    preprocessed,
                )
            else:
-                res = self.run_with_strategy(layer.__call__, preprocessed)
+                res = tpu_test_utils.run_with_strategy(
+                    self.strategy, layer.__call__, preprocessed
+                )
         else:
-            res = self.run_with_strategy(
-                layer.__call__, inputs, weights, jit_compile=jit_compile
+            res = tpu_test_utils.run_with_strategy(
+                self.strategy,
+                layer.__call__,
+                inputs,
+                weights,
+                jit_compile=jit_compile,
             )

         self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))

-        with self._strategy.scope():
-            tables = layer.get_embedding_tables()
+        tables = layer.get_embedding_tables()

         emb = tables["table"]

@@ -683,10 +627,11 @@ def test_shared_table(self):
683627 "dense" , embedding_config
684628 )
685629
686- with self ._strategy .scope ():
687- layer = distributed_embedding .DistributedEmbedding (embedding_config )
630+ layer = distributed_embedding .DistributedEmbedding (embedding_config )
688631
689- res = self .run_with_strategy (layer .__call__ , inputs )
632+ res = tpu_test_utils .run_with_strategy (
633+ self .strategy , layer .__call__ , inputs
634+ )
690635
691636 if self .placement == "default_device" :
692637 self .assertLen (layer ._flatten_layers (include_self = False ), 1 )
@@ -757,10 +702,11 @@ def test_mixed_placement(self):
757702 "dense" , embedding_config
758703 )
759704
760- with self ._strategy .scope ():
761- layer = distributed_embedding .DistributedEmbedding (embedding_config )
705+ layer = distributed_embedding .DistributedEmbedding (embedding_config )
762706
763- res = self .run_with_strategy (layer .__call__ , inputs )
707+ res = tpu_test_utils .run_with_strategy (
708+ self .strategy , layer .__call__ , inputs
709+ )
764710
765711 self .assertEqual (
766712 res ["feature1" ].shape , (self .batch_size , embedding_output_dim1 )
@@ -786,21 +732,19 @@ def test_save_load_model(self):
         with tempfile.TemporaryDirectory() as temp_dir:
             path = os.path.join(temp_dir, "model.keras")

-            with self._strategy.scope():
-                layer = distributed_embedding.DistributedEmbedding(
-                    feature_configs
-                )
-                keras_outputs = layer(keras_inputs)
-                model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)
+            layer = distributed_embedding.DistributedEmbedding(feature_configs)
+            keras_outputs = layer(keras_inputs)
+            model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)

-            output_before = self.run_with_strategy(model.__call__, inputs)
-            model.save(path)
+            output_before = tpu_test_utils.run_with_strategy(
+                self.strategy, model.__call__, inputs
+            )
+            model.save(path)

-            with self._strategy.scope():
-                reloaded_model = keras.models.load_model(path)
-                output_after = self.run_with_strategy(
-                    reloaded_model.__call__, inputs
-                )
+            reloaded_model = keras.models.load_model(path)
+            output_after = tpu_test_utils.run_with_strategy(
+                self.strategy, reloaded_model.__call__, inputs
+            )

             if self.placement == "sparsecore":
                 self.skipTest("TODO table reloading.")
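
For reference, the new `keras_rs.src.utils.tpu_test_utils` module presumably hosts the helpers this diff deletes from the test file. Below is a minimal sketch assembled from the removed code; the actual module may differ (for example, in how the TPU strategy is built in setUp and exposed as `self.strategy`, which this diff does not show). Since `run_with_strategy` is no longer a test method, the old `self.assertFalse(jit_compile)` check is replaced here with a plain assertion.

# Hypothetical sketch of keras_rs/src/utils/tpu_test_utils.py, reconstructed
# from the code removed above; not the verified contents of that module.
import contextlib

import keras

try:
    import jax
except ImportError:
    jax = None

try:
    import tensorflow as tf
except ImportError:
    tf = None


class DummyStrategy:
    """No-op stand-in for a tf.distribute strategy on non-TPU setups."""

    def scope(self):
        return contextlib.nullcontext()

    @property
    def num_replicas_in_sync(self):
        return 1

    def run(self, fn, args):
        return fn(*args)

    def experimental_distribute_dataset(self, dataset, options=None):
        del options
        return dataset


class JaxDummyStrategy(DummyStrategy):
    """Reports the real TPU device count so batch sizes scale correctly."""

    @property
    def num_replicas_in_sync(self):
        return jax.device_count("tpu")


def run_with_strategy(strategy, fn, *args, jit_compile=False):
    """Runs `fn` under `strategy`, as a tf.function on TensorFlow."""
    if keras.backend.backend() == "tensorflow":

        @tf.function(jit_compile=jit_compile)
        def tf_function_wrapper(*tf_function_args):
            def strategy_fn(*strategy_fn_args):
                return fn(*strategy_fn_args)

            return strategy.run(strategy_fn, args=tf_function_args)

        return tf_function_wrapper(*args)
    # On other backends there is nothing to jit-compile here.
    assert not jit_compile
    return fn(*args)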