Skip to content

Commit 526c489

Browse files
authored
support all tests in layers to tpu (#170)
* support all tests in layers to tpu * fix jax test * add jax strategy print * support tpu for losses tests and some metrics tests * support all tpu tests in metrics and format * update tpu test workflow * update actions.yml * fix cpu tf error on run with strategy taking kwargs and format * format * fix import * fix test errors * format * fix type * ignore long runnign tpu test * update ignore * clean up * revert unnecessary tpu strategy for eager * revert more unnecessary changes and resolve comments * remove venv and reformat * use a shared strategy in conftest.py * format * format conftest Added type hint for prime_shared_tpu_strategy function. * format import * format * resolve comments * clean gitignore * format mypy * resolve comments * address new comments
1 parent 5d7f18a commit 526c489

19 files changed

+244
-110
lines changed

.github/workflows/actions.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,13 @@ jobs:
8989
if: ${{ matrix.backend == 'jax'}}
9090
run: python3 -c "import jax; print('JAX devices:', jax.devices())"
9191

92-
- name: Test with pytest
93-
run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py
92+
- name: Test with pytest (TensorFlow)
93+
if: ${{ matrix.backend == 'tensorflow' }}
94+
run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax
95+
96+
- name: Test with pytest (JAX)
97+
if: ${{ matrix.backend == 'jax' }}
98+
run: pytest keras_rs/ --ignore=keras_rs/src/layers/embedding/jax/distributed_embedding_test.py
9499

95100
check_format:
96101
name: Check the code format

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 48 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import functools
32
import math
43
import os
@@ -14,6 +13,7 @@
1413
from keras_rs.src import types
1514
from keras_rs.src.layers.embedding import distributed_embedding
1615
from keras_rs.src.layers.embedding import distributed_embedding_config as config
16+
from keras_rs.src.utils import tpu_test_utils
1717

1818
try:
1919
import jax
@@ -30,28 +30,6 @@
3030
SEQUENCE_LENGTH = 13
3131

3232

33-
class DummyStrategy:
34-
def scope(self):
35-
return contextlib.nullcontext()
36-
37-
@property
38-
def num_replicas_in_sync(self):
39-
return 1
40-
41-
def run(self, fn, args):
42-
return fn(*args)
43-
44-
def experimental_distribute_dataset(self, dataset, options=None):
45-
del options
46-
return dataset
47-
48-
49-
class JaxDummyStrategy(DummyStrategy):
50-
@property
51-
def num_replicas_in_sync(self):
52-
return jax.device_count("tpu")
53-
54-
5533
def ragged_bool_true(self):
5634
return True
5735

@@ -74,46 +52,10 @@ def setUp(self):
7452
# FLAGS.xla_sparse_core_max_ids_per_partition_per_sample = 16
7553
# FLAGS.xla_sparse_core_max_unique_ids_per_partition_per_sample = 16
7654

77-
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
78-
tf.config.experimental_connect_to_cluster(resolver)
79-
80-
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
81-
tpu_metadata = resolver.get_tpu_system_metadata()
82-
83-
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
84-
topology, num_replicas=tpu_metadata.num_hosts
85-
)
86-
self._strategy = tf.distribute.TPUStrategy(
87-
resolver, experimental_device_assignment=device_assignment
88-
)
89-
print("### num_replicas", self._strategy.num_replicas_in_sync)
90-
self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
91-
elif keras.backend.backend() == "jax" and self.on_tpu:
92-
self._strategy = JaxDummyStrategy()
93-
else:
94-
self._strategy = DummyStrategy()
95-
9655
self.batch_size = (
97-
BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
56+
BATCH_SIZE_PER_CORE * self.strategy.num_replicas_in_sync
9857
)
9958

100-
def run_with_strategy(self, fn, *args, jit_compile=False):
101-
"""Wrapper for running a function under a strategy."""
102-
103-
if keras.backend.backend() == "tensorflow":
104-
105-
@tf.function(jit_compile=jit_compile)
106-
def tf_function_wrapper(*tf_function_args):
107-
def strategy_fn(*strategy_fn_args):
108-
return fn(*strategy_fn_args)
109-
110-
return self._strategy.run(strategy_fn, args=tf_function_args)
111-
112-
return tf_function_wrapper(*args)
113-
else:
114-
self.assertFalse(jit_compile)
115-
return fn(*args)
116-
11759
def get_embedding_config(self, input_type, placement):
11860
sequence_length = 1 if input_type == "dense" else SEQUENCE_LENGTH
11961

@@ -252,18 +194,18 @@ def test_basics(self, input_type, placement):
252194

253195
if placement == "sparsecore" and not self.on_tpu:
254196
with self.assertRaisesRegex(Exception, "sparsecore"):
255-
with self._strategy.scope():
256-
distributed_embedding.DistributedEmbedding(feature_configs)
197+
distributed_embedding.DistributedEmbedding(feature_configs)
257198
return
258199

259-
with self._strategy.scope():
260-
layer = distributed_embedding.DistributedEmbedding(feature_configs)
200+
layer = distributed_embedding.DistributedEmbedding(feature_configs)
261201

262202
if keras.backend.backend() == "jax":
263203
preprocessed_inputs = layer.preprocess(inputs, weights)
264204
res = layer(preprocessed_inputs)
265205
else:
266-
res = self.run_with_strategy(layer.__call__, inputs, weights)
206+
res = tpu_test_utils.run_with_strategy(
207+
self.strategy, layer.__call__, inputs, weights
208+
)
267209

268210
if placement == "default_device" or not self.on_tpu:
269211
# verify sublayers and variables are tracked
@@ -332,8 +274,7 @@ def test_model_fit(self, input_type, use_weights):
332274
(test_model_inputs, test_labels)
333275
)
334276

335-
with self._strategy.scope():
336-
layer = distributed_embedding.DistributedEmbedding(feature_configs)
277+
layer = distributed_embedding.DistributedEmbedding(feature_configs)
337278

338279
def _create_keras_input(
339280
feature_config: config.FeatureConfig, dtype: types.DType
@@ -403,7 +344,7 @@ def test_dataset_generator():
403344
# New preprocessed data removes the `weights` component.
404345
dataset_has_weights = False
405346
else:
406-
train_dataset = self._strategy.experimental_distribute_dataset(
347+
train_dataset = self.strategy.experimental_distribute_dataset(
407348
train_dataset,
408349
options=tf.distribute.InputOptions(
409350
experimental_fetch_to_device=False
@@ -418,19 +359,18 @@ def test_dataset_generator():
418359
inputs=keras_model_inputs, outputs=keras_model_outputs
419360
)
420361

421-
with self._strategy.scope():
422-
model.compile(optimizer="adam", loss="mse")
362+
model.compile(optimizer="adam", loss="mse")
423363

424-
model_inputs, _ = next(iter(test_dataset))
425-
test_output_before = self.run_with_strategy(
426-
model.__call__, model_inputs
427-
)
364+
model_inputs, _ = next(iter(test_dataset))
365+
test_output_before = tpu_test_utils.run_with_strategy(
366+
self.strategy, model.__call__, model_inputs
367+
)
428368

429-
model.fit(train_dataset, steps_per_epoch=1, epochs=1)
369+
model.fit(train_dataset, steps_per_epoch=1, epochs=1)
430370

431-
test_output_after = self.run_with_strategy(
432-
model.__call__, model_inputs
433-
)
371+
test_output_after = tpu_test_utils.run_with_strategy(
372+
self.strategy, model.__call__, model_inputs
373+
)
434374

435375
# Verify that the embedding has actually trained.
436376
for before, after in zip(
@@ -567,8 +507,7 @@ def test_correctness(
567507
if not use_weights:
568508
weights = None
569509

570-
with self._strategy.scope():
571-
layer = distributed_embedding.DistributedEmbedding(feature_config)
510+
layer = distributed_embedding.DistributedEmbedding(feature_config)
572511

573512
if keras.backend.backend() == "jax":
574513
preprocessed = layer.preprocess(inputs, weights)
@@ -610,16 +549,21 @@ def test_correctness(
610549
preprocessed,
611550
)
612551
else:
613-
res = self.run_with_strategy(layer.__call__, preprocessed)
552+
res = tpu_test_utils.run_with_strategy(
553+
self.strategy, layer.__call__, preprocessed
554+
)
614555
else:
615-
res = self.run_with_strategy(
616-
layer.__call__, inputs, weights, jit_compile=jit_compile
556+
res = tpu_test_utils.run_with_strategy(
557+
self.strategy,
558+
layer.__call__,
559+
inputs,
560+
weights,
561+
jit_compile=jit_compile,
617562
)
618563

619564
self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))
620565

621-
with self._strategy.scope():
622-
tables = layer.get_embedding_tables()
566+
tables = layer.get_embedding_tables()
623567

624568
emb = tables["table"]
625569

@@ -683,10 +627,11 @@ def test_shared_table(self):
683627
"dense", embedding_config
684628
)
685629

686-
with self._strategy.scope():
687-
layer = distributed_embedding.DistributedEmbedding(embedding_config)
630+
layer = distributed_embedding.DistributedEmbedding(embedding_config)
688631

689-
res = self.run_with_strategy(layer.__call__, inputs)
632+
res = tpu_test_utils.run_with_strategy(
633+
self.strategy, layer.__call__, inputs
634+
)
690635

691636
if self.placement == "default_device":
692637
self.assertLen(layer._flatten_layers(include_self=False), 1)
@@ -757,10 +702,11 @@ def test_mixed_placement(self):
757702
"dense", embedding_config
758703
)
759704

760-
with self._strategy.scope():
761-
layer = distributed_embedding.DistributedEmbedding(embedding_config)
705+
layer = distributed_embedding.DistributedEmbedding(embedding_config)
762706

763-
res = self.run_with_strategy(layer.__call__, inputs)
707+
res = tpu_test_utils.run_with_strategy(
708+
self.strategy, layer.__call__, inputs
709+
)
764710

765711
self.assertEqual(
766712
res["feature1"].shape, (self.batch_size, embedding_output_dim1)
@@ -786,21 +732,19 @@ def test_save_load_model(self):
786732
with tempfile.TemporaryDirectory() as temp_dir:
787733
path = os.path.join(temp_dir, "model.keras")
788734

789-
with self._strategy.scope():
790-
layer = distributed_embedding.DistributedEmbedding(
791-
feature_configs
792-
)
793-
keras_outputs = layer(keras_inputs)
794-
model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)
735+
layer = distributed_embedding.DistributedEmbedding(feature_configs)
736+
keras_outputs = layer(keras_inputs)
737+
model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)
795738

796-
output_before = self.run_with_strategy(model.__call__, inputs)
797-
model.save(path)
739+
output_before = tpu_test_utils.run_with_strategy(
740+
self.strategy, model.__call__, inputs
741+
)
742+
model.save(path)
798743

799-
with self._strategy.scope():
800-
reloaded_model = keras.models.load_model(path)
801-
output_after = self.run_with_strategy(
802-
reloaded_model.__call__, inputs
803-
)
744+
reloaded_model = keras.models.load_model(path)
745+
output_after = tpu_test_utils.run_with_strategy(
746+
self.strategy, reloaded_model.__call__, inputs
747+
)
804748

805749
if self.placement == "sparsecore":
806750
self.skipTest("TODO table reloading.")

keras_rs/src/layers/embedding/tensorflow/config_conversion_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import keras
2+
import pytest
23
import tensorflow as tf
34
from absl.testing import parameterized
45

@@ -7,6 +8,10 @@
78
from keras_rs.src.layers.embedding.tensorflow import config_conversion
89

910

11+
@pytest.mark.skipif(
12+
keras.backend.backend() != "tensorflow",
13+
reason="Tensorflow specific test",
14+
)
1015
class ConfigConversionTest(testing.TestCase, parameterized.TestCase):
1116
@parameterized.named_parameters(
1217
(

keras_rs/src/layers/feature_interaction/dot_interaction_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
class DotInteractionTest(testing.TestCase, parameterized.TestCase):
1414
def setUp(self):
15+
super().setUp()
16+
1517
self.input = [
1618
ops.array([[0.1, -4.3, 0.2, 1.1, 0.3]]),
1719
ops.array([[2.0, 3.2, -1.0, 0.0, 1.0]]),
@@ -81,7 +83,12 @@ def test_call(self, self_interaction, skip_gather, exp_output_idx):
8183
self_interaction=self_interaction, skip_gather=skip_gather
8284
)
8385
output = layer(self.input)
84-
self.assertAllClose(output, self.exp_outputs[exp_output_idx])
86+
self.assertAllClose(
87+
output,
88+
self.exp_outputs[exp_output_idx],
89+
tpu_atol=1e-2,
90+
tpu_rtol=1e-2,
91+
)
8592

8693
def test_invalid_input_rank(self):
8794
rank_1_input = [ops.ones((3,)), ops.ones((3,))]

keras_rs/src/layers/feature_interaction/feature_cross_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
class FeatureCrossTest(testing.TestCase, parameterized.TestCase):
1212
def setUp(self):
13+
super().setUp()
14+
1315
self.x0 = ops.array([[0.1, 0.2, 0.3]], dtype="float32")
1416
self.x = ops.array([[0.4, 0.5, 0.6]], dtype="float32")
1517
self.exp_output = ops.array([[0.55, 0.8, 1.05]])

keras_rs/src/losses/list_mle_loss_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class ListMLELossTest(testing.TestCase, parameterized.TestCase):
1212
def setUp(self):
13+
super().setUp()
1314
self.unbatched_scores = ops.array(
1415
[1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32"
1516
)

keras_rs/src/losses/pairwise_hinge_loss_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class PairwiseHingeLossTest(testing.TestCase, parameterized.TestCase):
1212
def setUp(self):
13+
super().setUp()
1314
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
1415
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
1516

keras_rs/src/losses/pairwise_logistic_loss_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class PairwiseLogisticLossTest(testing.TestCase, parameterized.TestCase):
1212
def setUp(self):
13+
super().setUp()
1314
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
1415
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
1516

keras_rs/src/losses/pairwise_mean_squared_error_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
class PairwiseMeanSquaredErrorTest(testing.TestCase, parameterized.TestCase):
1414
def setUp(self):
15+
super().setUp()
1516
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
1617
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
1718

keras_rs/src/losses/pairwise_soft_zero_one_loss_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
class PairwiseSoftZeroOneLossTest(testing.TestCase, parameterized.TestCase):
1414
def setUp(self):
15+
super().setUp()
1516
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
1617
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
1718

0 commit comments

Comments
 (0)