Skip to content

Commit

Permalink
Reduce OSS CI errors + fix downstream errors
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636541458
  • Loading branch information
xingyousong authored and copybara-github committed May 25, 2024
1 parent 0f2bbe5 commit fd1745f
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 173 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/core_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
Expand All @@ -33,6 +33,6 @@ jobs:
- name: Print installed dependencies
run: |
pip freeze
- name: Test with pytest # (TODO: Automate Iris installation)
- name: Test with pytest # TODO(team): Fix tensorflow version conflict)
run: |
# pytest -n auto iris
97 changes: 56 additions & 41 deletions iris/policies/keras_cnn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Policy class that computes action by running convolutional neural network."""

from typing import Dict, Optional, Sequence, Union

import gym
Expand All @@ -27,22 +28,24 @@
class KerasCNNPolicy(keras_policy.KerasPolicy):
"""Policy class, computes action by running convolutional neural network."""

def __init__(self, ob_space: gym.Space, ac_space: gym.Space,
**kwargs) -> None:
def __init__(
self, ob_space: gym.Space, ac_space: gym.Space, **kwargs
) -> None:
"""Initializes a keras CNN policy. See the base class for more details."""
self._rnn_state = None
super().__init__(ob_space=ob_space, ac_space=ac_space, **kwargs)

def _create_vision_input_layers(self):
vision_input_layers = []
for image_label in self._image_input_labels:
image_size = self._ob_space[image_label].shape
vision_input_layers.append(
tf.keras.layers.Input(
batch_input_shape=(1, image_size[0], image_size[1],
image_size[2]),
dtype="float",
name="vision_input" + image_label))
shape=self._ob_space[image_label].shape,
batch_size=1,
dtype="float32",
name="vision_input" + image_label,
)
)
return vision_input_layers

def _create_other_input_layer(self):
Expand All @@ -53,9 +56,11 @@ def _create_other_input_layer(self):
self._other_ob_dim = utils.flatdim(self._other_ob_space)
if self._other_ob_dim > 0:
return tf.keras.layers.Input(
batch_input_shape=(1, self._other_ob_dim),
dtype="float",
name="other_input")
shape=(self._other_ob_dim,),
batch_size=1,
dtype="float32",
name="other_input",
)
return None

def _create_vision_processing_layers(
Expand All @@ -67,7 +72,8 @@ def _create_vision_processing_layers(
pool_sizes: Optional[Sequence[int]] = None,
pool_strides: Optional[Sequence[int]] = None,
final_vision_activation: str = "relu",
use_spatial_softmax: bool = False) -> tf.keras.layers.Layer:
use_spatial_softmax: bool = False,
) -> tf.keras.layers.Layer:
"""Create keras layers for CNN image processing.
Args:
Expand All @@ -94,17 +100,18 @@ def _create_vision_processing_layers(
pool_strides = [None] * len(conv_filter_sizes)

for filter_size, kernel_size, pool_size, pool_stride in zip(
conv_filter_sizes, conv_kernel_sizes, pool_sizes, pool_strides):
conv_filter_sizes, conv_kernel_sizes, pool_sizes, pool_strides
):
x = tf.keras.layers.Conv2D(
filter_size,
kernel_size=kernel_size,
padding="valid",
activation=final_vision_activation)(
x)
activation=final_vision_activation,
)(x)
if pool_size is not None:
x = tf.keras.layers.MaxPool2D(
pool_size=pool_size, strides=pool_stride)(
x)
x = tf.keras.layers.MaxPool2D(pool_size=pool_size, strides=pool_stride)(
x
)

# Flattening or spatial softmax on image feature map.
if use_spatial_softmax:
Expand All @@ -114,36 +121,41 @@ def _create_vision_processing_layers(

# Encoding image into a feature vector.
return tf.keras.layers.Dense(
image_feature_length, activation=final_vision_activation)(
x)
image_feature_length, activation=final_vision_activation
)(x)

def _create_rnn_layers(self, x, inputs):
"""By default, creates an LSTM."""
lstm_h_state_input = tf.keras.layers.Input(
batch_input_shape=(1, self._rnn_units),
dtype="float",
name="lstm_h_state_input")
shape=(self._rnn_units,),
batch_size=1,
dtype="float32",
name="lstm_h_state_input",
)
lstm_c_state_input = tf.keras.layers.Input(
batch_input_shape=(1, self._rnn_units),
dtype="float",
name="lstm_c_state_input")
shape=(self._rnn_units,),
batch_size=1,
dtype="float32",
name="lstm_c_state_input",
)
inputs.append(lstm_h_state_input)
inputs.append(lstm_c_state_input)
h_state = lstm_h_state_input
c_state = lstm_c_state_input
x = tf.keras.layers.Reshape((1, -1))(x)
x, h_state, c_state = tf.keras.layers.LSTM(
units=self._rnn_units, return_state=True, stateful=True)(
x, initial_state=[lstm_h_state_input, lstm_c_state_input])
units=self._rnn_units, return_state=True, stateful=True
)(x, initial_state=[lstm_h_state_input, lstm_c_state_input])
return x, [h_state, c_state]

def _build_model(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
fc_layer_sizes: Sequence[int],
use_rnn: bool = False,
rnn_units: int = 32,
image_input_label: Union[Sequence[str], str] = "vision",
final_layer_init: str = "glorot_uniform",
**kwargs) -> None:
# pytype: disable=signature-mismatch # overriding-parameter-count-checks
def _build_model(
self,
fc_layer_sizes: Sequence[int],
use_rnn: bool = False,
rnn_units: int = 32,
image_input_label: Union[Sequence[str], str] = "vision",
final_layer_init: str = "glorot_uniform",
**kwargs
) -> None:
"""Constructs a keras CNN to process vision and other sensor data.
Args:
Expand Down Expand Up @@ -171,7 +183,8 @@ def _build_model(self, # pytype: disable=signature-mismatch # overriding-param
vision_outputs = []
for vision_input in inputs:
vision_outputs.append(
self._create_vision_processing_layers(x=vision_input, **kwargs))
self._create_vision_processing_layers(x=vision_input, **kwargs)
)
vision_output = tf.keras.layers.concatenate(vision_outputs)

if self._use_rnn:
Expand All @@ -190,16 +203,18 @@ def _build_model(self, # pytype: disable=signature-mismatch # overriding-param
for fc_layer_size in fc_layer_sizes:
x = tf.keras.layers.Dense(fc_layer_size, activation="tanh")(x)
action_output = tf.keras.layers.Dense(
self._ac_dim, activation="tanh", kernel_initializer=final_layer_init)(
x)
self._ac_dim, activation="tanh", kernel_initializer=final_layer_init
)(x)
outputs.append(action_output)

self.model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

# pytype: enable=signature-mismatch # overriding-parameter-count-checks

def reset(self) -> None:
"""Resets the policy's internal state (default LSTM)."""
lstm_h_state = np.zeros(shape=(1, self._rnn_units), dtype="float")
lstm_c_state = np.zeros(shape=(1, self._rnn_units), dtype="float")
lstm_h_state = np.zeros(shape=(1, self._rnn_units), dtype="float32")
lstm_c_state = np.zeros(shape=(1, self._rnn_units), dtype="float32")
self._rnn_state = [lstm_h_state, lstm_c_state]

def act(
Expand Down
31 changes: 19 additions & 12 deletions iris/policies/keras_nn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
class KerasNNPolicy(keras_policy.KerasPolicy):
"""Policy class that computes action by running feed fwd neural network."""

def _build_model(self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
hidden_layer_sizes: Sequence[int],
activation: str = "tanh",
use_bias: bool = False,
kernel_initializer: str = "zeros") -> None:
# pytype: disable=signature-mismatch # overriding-parameter-count-checks
def _build_model(
self,
hidden_layer_sizes: Sequence[int],
activation: str = "tanh",
use_bias: bool = False,
kernel_initializer: str = "zeros",
) -> None:
"""Constructs a keras feed forward neural network model.
Args:
Expand All @@ -37,20 +40,24 @@ def _build_model(self, # pytype: disable=signature-mismatch # overriding-param
"""
# Creates model.
input_layer = tf.keras.layers.Input(
batch_input_shape=(1, self._ob_dim), dtype="float", name="input")
shape=(self._ob_dim,), batch_size=1, dtype="float32", name="input"
)
x = input_layer
for layer_size in hidden_layer_sizes:
x = tf.keras.layers.Dense(
layer_size,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer)(
x)
kernel_initializer=kernel_initializer,
)(x)
output_layer = tf.keras.layers.Dense(
self._ac_dim,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer)(
x)
self.model = tf.keras.models.Model(inputs=[input_layer],
outputs=[output_layer])
kernel_initializer=kernel_initializer,
)(x)
self.model = tf.keras.models.Model(
inputs=[input_layer], outputs=[output_layer]
)

# pytype: enable=signature-mismatch # overriding-parameter-count-checks
Loading

0 comments on commit fd1745f

Please sign in to comment.