-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathpolicy_net.py
52 lines (38 loc) · 2.15 KB
/
policy_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gym
import numpy as np
import tensorflow as tf
class Policy_net:
    def __init__(self, name: str, env, temp=0.1):
        """
        Build a policy head and a value head that share one observation placeholder.

        :param name: string, outer variable-scope name (lets several nets coexist
                     in one graph, e.g. policy vs. old-policy in PPO)
        :param env: gym env; uses env.observation_space.shape and
                    env.action_space.n (discrete actions)
        :param temp: temperature of the Boltzmann (softmax) distribution over
                     actions; lower values make the stochastic policy greedier
        """
        ob_space = env.observation_space
        act_space = env.action_space
        with tf.variable_scope(name):
            # One placeholder feeds both heads; leading None = batch dimension.
            self.obs = tf.placeholder(dtype=tf.float32, shape=[None] + list(ob_space.shape), name='obs')
            with tf.variable_scope('policy_net'):
                layer_1 = tf.layers.dense(inputs=self.obs, units=20, activation=tf.tanh)
                layer_2 = tf.layers.dense(inputs=layer_1, units=20, activation=tf.tanh)
                layer_3 = tf.layers.dense(inputs=layer_2, units=act_space.n, activation=tf.tanh)
                # BUG FIX: a Boltzmann distribution is softmax(logits / temp).
                # The original code divided by `temp` *before* a trainable dense
                # layer, so the learned weights could absorb the scaling and the
                # temperature had no well-defined effect on the action
                # distribution. Compute raw logits first, then temperature-scale
                # them directly inside the softmax. Layer count and variable
                # scopes are unchanged.
                logits = tf.layers.dense(inputs=layer_3, units=act_space.n, activation=None)
                self.act_probs = tf.nn.softmax(tf.divide(logits, temp))
            with tf.variable_scope('value_net'):
                layer_1 = tf.layers.dense(inputs=self.obs, units=20, activation=tf.tanh)
                layer_2 = tf.layers.dense(inputs=layer_1, units=20, activation=tf.tanh)
                # Scalar state-value estimate V(s); shape [batch, 1].
                self.v_preds = tf.layers.dense(inputs=layer_2, units=1, activation=None)
            # Stochastic action: sample from the categorical distribution
            # (tf.multinomial expects log-probabilities), flattened to [batch].
            self.act_stochastic = tf.multinomial(tf.log(self.act_probs), num_samples=1)
            self.act_stochastic = tf.reshape(self.act_stochastic, shape=[-1])
            # Deterministic action: argmax over action probabilities.
            self.act_deterministic = tf.argmax(self.act_probs, axis=1)
            # Remember the fully-qualified scope so the get_*_variables helpers
            # can filter the global collections to this network only.
            self.scope = tf.get_variable_scope().name

    def act(self, obs, stochastic=True):
        """Return (actions, value predictions) for a batch of observations.

        :param obs: batch of observations matching the placeholder shape
        :param stochastic: sample from the policy if True, else take the argmax
        """
        if stochastic:
            return tf.get_default_session().run([self.act_stochastic, self.v_preds], feed_dict={self.obs: obs})
        else:
            return tf.get_default_session().run([self.act_deterministic, self.v_preds], feed_dict={self.obs: obs})

    def get_action_prob(self, obs):
        """Return the [batch, n_actions] action-probability matrix for obs."""
        return tf.get_default_session().run(self.act_probs, feed_dict={self.obs: obs})

    def get_variables(self):
        """All variables (trainable or not) created under this network's scope."""
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, self.scope)

    def get_trainable_variables(self):
        """Only the trainable variables under this network's scope."""
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope)