Commit 2d7bd1d5, authored by Alex Lee and committed by Neal Wu

Fixes for compatibility with TensorFlow 1.0 and Python 3.

parent 5e38011f
......@@ -38,17 +38,11 @@ def init_state(inputs,
if inputs is not None:
# Handle both the dynamic shape as well as the inferred shape.
inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0]
batch_size = tf.shape(inputs)[0]
dtype = inputs.dtype
else:
inferred_batch_size = 0
batch_size = 0
initial_state = state_initializer(
tf.stack([batch_size] + state_shape),
dtype=dtype)
initial_state.set_shape([inferred_batch_size] + state_shape)
[inferred_batch_size] + state_shape, dtype=dtype)
return initial_state
......
......@@ -103,21 +103,24 @@ class Model(object):
actions=None,
states=None,
sequence_length=None,
reuse_scope=None):
reuse_scope=None,
prefix=None):
if sequence_length is None:
sequence_length = FLAGS.sequence_length
self.prefix = prefix = tf.placeholder(tf.string, [])
if prefix is None:
prefix = tf.placeholder(tf.string, [])
self.prefix = prefix
self.iter_num = tf.placeholder(tf.float32, [])
summaries = []
# Split into timesteps.
actions = tf.split(axis=1, num_or_size_splits=actions.get_shape()[1], value=actions)
actions = tf.split(axis=1, num_or_size_splits=int(actions.get_shape()[1]), value=actions)
actions = [tf.squeeze(act) for act in actions]
states = tf.split(axis=1, num_or_size_splits=states.get_shape()[1], value=states)
states = tf.split(axis=1, num_or_size_splits=int(states.get_shape()[1]), value=states)
states = [tf.squeeze(st) for st in states]
images = tf.split(axis=1, num_or_size_splits=images.get_shape()[1], value=images)
images = tf.split(axis=1, num_or_size_splits=int(images.get_shape()[1]), value=images)
images = [tf.squeeze(img) for img in images]
if reuse_scope is None:
......@@ -183,17 +186,18 @@ class Model(object):
def main(unused_argv):
print 'Constructing models and inputs.'
print('Constructing models and inputs.')
with tf.variable_scope('model', reuse=None) as training_scope:
images, actions, states = build_tfrecord_input(training=True)
model = Model(images, actions, states, FLAGS.sequence_length)
model = Model(images, actions, states, FLAGS.sequence_length,
prefix='train')
with tf.variable_scope('val_model', reuse=None):
val_images, val_actions, val_states = build_tfrecord_input(training=False)
val_model = Model(val_images, val_actions, val_states,
FLAGS.sequence_length, training_scope)
FLAGS.sequence_length, training_scope, prefix='val')
print 'Constructing saver.'
print('Constructing saver.')
# Make saver.
saver = tf.train.Saver(
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
......@@ -214,8 +218,7 @@ def main(unused_argv):
# Run training.
for itr in range(FLAGS.num_iterations):
# Generate new batch of data.
feed_dict = {model.prefix: 'train',
model.iter_num: np.float32(itr),
feed_dict = {model.iter_num: np.float32(itr),
model.lr: FLAGS.learning_rate}
cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
feed_dict)
......@@ -226,7 +229,6 @@ def main(unused_argv):
if (itr) % VAL_INTERVAL == 2:
# Run through validation set.
feed_dict = {val_model.lr: 0.0,
val_model.prefix: 'val',
val_model.iter_num: np.float32(itr)}
_, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
feed_dict)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment