Commit 41c52d60 authored by David Dao's avatar David Dao

Spatial Transformer model

Shorten STN summary in README

relinked to data files

adding license header, editing AUTHORS file

adding tensorflow version
parent d51fdd20
......@@ -7,3 +7,4 @@
# The email address is not required for organizations.
Google Inc.
David Dao <[email protected]>
# Spatial Transformer Network
The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
<div align="center">
<img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
</div>
### API
A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2].
#### How to use
<div align="center">
<img src="http://i.imgur.com/gfqLV3f.png"><br><br>
</div>
```python
transformer(U, theta, downsample_factor=1)
```
#### Parameters
U : float
The output of a convolutional net should have the
shape [num_batch, height, width, num_channels].
theta: float
The output of the
localisation network should be [num_batch, 6].
downsample_factor : float
A value of 1 will keep the original size of the image
Values larger than 1 will downsample the image.
Values below 1 will upsample the image
example image: height = 100, width = 200
downsample_factor = 2
output image will then be 50, 100
#### Notes
To initialize the network to the identity transform init ``theta`` to :
```python
identity = np.array([[1., 0., 0.],
[0., 1., 0.]])
identity = identity.flatten()
theta = tf.Variable(initial_value=identity)
```
#### Experiments
<div align="center">
<img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
</div>
We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN.
All experiments were run in Tensorflow 0.7.
### References
[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot
# %% Load data
mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
X_train = mnist_cluttered['X_train']
y_train = mnist_cluttered['y_train']
X_valid = mnist_cluttered['X_valid']
y_valid = mnist_cluttered['y_valid']
X_test = mnist_cluttered['X_test']
y_test = mnist_cluttered['y_test']
# % turn from dense to one hot representation
Y_train = dense_to_one_hot(y_train, n_classes=10)
Y_valid = dense_to_one_hot(y_valid, n_classes=10)
Y_test = dense_to_one_hot(y_test, n_classes=10)
# %% Graph representation of our network
# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])
# %% Since x is currently [batch, height*width], we need to reshape to a
# 4-D tensor to use it in a convolutional graph. If one component of
# `shape` is the special value -1, the size of that dimension is
# computed so that the total size remains constant. Since we haven't
# defined the batch dimension's shape yet, we use -1 to denote this
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])
# %% We'll setup the two-layer localisation network to figure out the parameters for an affine transformation of the input
# %% Create variables for fully connected layer
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])
W_fc_loc2 = weight_variable([20, 6])
initial = np.array([[1.,0, 0],[0,1.,0]]) # Use identity transformation as starting point
initial = initial.astype('float32')
initial = initial.flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
# %% Define the two layer localisation network
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
# %% We can add dropout for regularizing and to reduce overfitting like so:
keep_prob = tf.placeholder(tf.float32)
h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
# %% We'll create a spatial transformer module to identify discriminative patches
h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
# %% We'll setup the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
filter_size = 3
n_filters_1 = 16
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
# %% Bias is [output_channels]
b_conv1 = bias_variable([n_filters_1])
# %% Now we can build a graph which does the first layer of convolution:
# we define our stride as batch x height x width x channels
# instead of pooling, we use strides of 2 and more layers
# with smaller filters.
h_conv1 = tf.nn.relu(
tf.nn.conv2d(input=h_trans,
filter=W_conv1,
strides=[1, 2, 2, 1],
padding='SAME') +
b_conv1)
# %% And just like the first layer, add additional layers to create
# a deep net
n_filters_2 = 16
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
b_conv2 = bias_variable([n_filters_2])
h_conv2 = tf.nn.relu(
tf.nn.conv2d(input=h_conv1,
filter=W_conv2,
strides=[1, 2, 2, 1],
padding='SAME') +
b_conv2)
# %% We'll now reshape so we can connect to a fully-connected layer:
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
# %% Create a fully-connected layer:
n_fc = 1024
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
b_fc1 = bias_variable([n_fc])
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
# %% Define loss/eval/training functions
cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
opt = tf.train.AdamOptimizer()
optimizer = opt.minimize(cross_entropy)
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
# %% We now create a new session to actually perform the initialization the
# variables:
sess = tf.Session()
sess.run(tf.initialize_all_variables())
# %% We'll now train in minibatches and report accuracy, loss:
iter_per_epoch = 100
n_epochs = 500
train_size = 10000
indices = np.linspace(0,10000 - 1,iter_per_epoch)
indices = indices.astype('int')
for epoch_i in range(n_epochs):
for iter_i in range(iter_per_epoch - 1):
batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
if iter_i % 10 == 0:
loss = sess.run(cross_entropy,
feed_dict={
x: batch_xs,
y: batch_ys,
keep_prob: 1.0
})
print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
sess.run(optimizer, feed_dict={
x: batch_xs, y: batch_ys, keep_prob: 0.8})
print('Accuracy: ' + str(sess.run(accuracy,
feed_dict={
x: X_valid,
y: Y_valid,
keep_prob: 1.0
})))
#theta = sess.run(h_fc_loc2, feed_dict={
# x: batch_xs, keep_prob: 1.0})
#print(theta[0])
### How to get the data
#### Cluttered MNIST
The cluttered MNIST dataset can be found here [1] or can be generated via [2].
Settings used for `cluttered_mnist.py` :
```python
ORG_SHP = [28, 28]
OUT_SHP = [40, 40]
NUM_DISTORTIONS = 8
dist_size = (5, 5)
```
[1] https://github.com/daviddao/spatial-transformer-tensorflow
[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py
\ No newline at end of file
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from spatial_transformer import transformer
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
from tf_utils import conv2d, linear, weight_variable, bias_variable
# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
im = ndimage.imread('cat.jpg')
im = im / 255.
im = im.reshape(1, 1200, 1600, 3)
im = im.astype('float32')
# %% Simulate batch
batch = np.append(im, im, axis=0)
batch = np.append(batch, im, axis=0)
num_batch = 3
x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
x = tf.cast(batch,'float32')
# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):
# %% Create a fully-connected layer with 6 output nodes
n_fc = 6
W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
# %% Zoom into the image
initial = np.array([[0.5,0, 0],[0,0.5,0]])
initial = initial.astype('float32')
initial = initial.flatten()
b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1
h_trans = transformer(x, h_fc1, downsample_factor=2)
# %% Run session
sess = tf.Session()
sess.run(tf.initialize_all_variables())
y = sess.run(h_trans, feed_dict={x: batch})
# plt.imshow(y[0])
\ No newline at end of file
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
"""Spatial Transformer Layer
Implements a spatial transformer layer as described in [1]_.
Based on [2]_ and edited by David Dao for Tensorflow.
Parameters
----------
U : float
The output of a convolutional net should have the
shape [num_batch, height, width, num_channels].
theta: float
The output of the
localisation network should be [num_batch, 6].
downsample_factor : float
A value of 1 will keep the original size of the image
Values larger than 1 will downsample the image.
Values below 1 will upsample the image
example image: height = 100, width = 200
downsample_factor = 2
output image will then be 50, 100
References
----------
.. [1] Spatial Transformer Networks
Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
Submitted on 5 Jun 2015
.. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
Notes
-----
To initialize the network to the identity transform init
``theta`` to :
identity = np.array([[1., 0., 0.],
[0., 1., 0.]])
identity = identity.flatten()
theta = tf.Variable(initial_value=identity)
"""
def _repeat(x, n_repeats):
with tf.variable_scope('_repeat'):
rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.pack([n_repeats,])),1),[1,0])
rep = tf.cast(rep, 'int32')
x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
return tf.reshape(x,[-1])
def _interpolate(im, x, y, downsample_factor):
with tf.variable_scope('_interpolate'):
# constants
num_batch = tf.shape(im)[0]
height = tf.shape(im)[1]
width = tf.shape(im)[2]
channels = tf.shape(im)[3]
x = tf.cast(x, 'float32')
y = tf.cast(y, 'float32')
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = tf.cast(height_f // downsample_factor, 'int32')
out_width = tf.cast(width_f // downsample_factor, 'int32')
zero = tf.zeros([], dtype='int32')
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
# scale indices from [-1, 1] to [0, width/height]
x = (x + 1.0)*(width_f) / 2.0
y = (y + 1.0)*(height_f) / 2.0
# do sampling
x0 = tf.cast(tf.floor(x), 'int32')
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), 'int32')
y1 = y0 + 1
x0 = tf.clip_by_value(x0, zero, max_x)
x1 = tf.clip_by_value(x1, zero, max_x)
y0 = tf.clip_by_value(y0, zero, max_y)
y1 = tf.clip_by_value(y1, zero, max_y)
dim2 = width
dim1 = width*height
base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
base_y0 = base + y0*dim2
base_y1 = base + y1*dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# use indices to lookup pixels in the flat image and restore channels dim
im_flat = tf.reshape(im,tf.pack([-1, channels]))
im_flat = tf.cast(im_flat, 'float32')
Ia = tf.gather(im_flat, idx_a)
Ib = tf.gather(im_flat, idx_b)
Ic = tf.gather(im_flat, idx_c)
Id = tf.gather(im_flat, idx_d)
# and finally calculate interpolated values
x0_f = tf.cast(x0, 'float32')
x1_f = tf.cast(x1, 'float32')
y0_f = tf.cast(y0, 'float32')
y1_f = tf.cast(y1, 'float32')
wa = tf.expand_dims(((x1_f-x) * (y1_f-y)),1)
wb = tf.expand_dims(((x1_f-x) * (y-y0_f)),1)
wc = tf.expand_dims(((x-x0_f) * (y1_f-y)),1)
wd = tf.expand_dims(((x-x0_f) * (y-y0_f)),1)
output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
return output
def _meshgrid(height, width):
with tf.variable_scope('_meshgrid'):
# This should be equivalent to:
# x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
# np.linspace(-1, 1, height))
# ones = np.ones(np.prod(x_t.shape))
# grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width),1),[1,0]))
y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height),1),
tf.ones(shape=tf.pack([1, width])))
x_t_flat = tf.reshape(x_t,(1, -1))
y_t_flat = tf.reshape(y_t,(1, -1))
ones = tf.ones_like(x_t_flat)
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
return grid
def _transform(theta, input_dim, downsample_factor):
with tf.variable_scope('_transform'):
num_batch = tf.shape(input_dim)[0]
height = tf.shape(input_dim)[1]
width = tf.shape(input_dim)[2]
num_channels = tf.shape(input_dim)[3]
theta = tf.reshape(theta, (-1, 2, 3))
theta = tf.cast(theta, 'float32')
# grid of (x_t, y_t, 1), eq (1) in ref [1]
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = tf.cast(height_f // downsample_factor, 'int32')
out_width = tf.cast(width_f // downsample_factor, 'int32')
grid = _meshgrid(out_height, out_width)
grid = tf.expand_dims(grid,0)
grid = tf.reshape(grid,[-1])
grid = tf.tile(grid,tf.pack([num_batch]))
grid = tf.reshape(grid,tf.pack([num_batch, 3, -1]))
# Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
T_g = tf.batch_matmul(theta, grid)
x_s = tf.slice(T_g, [0,0,0], [-1,1,-1])
y_s = tf.slice(T_g, [0,1,0], [-1,1,-1])
x_s_flat = tf.reshape(x_s,[-1])
y_s_flat = tf.reshape(y_s,[-1])
input_transformed = _interpolate(
input_dim, x_s_flat, y_s_flat,
downsample_factor)
output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
return output
with tf.variable_scope(name):
output = _transform(theta, U, downsample_factor)
return output
\ No newline at end of file
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
import tensorflow as tf
import numpy as np
def conv2d(x, n_filters,
k_h=5, k_w=5,
stride_h=2, stride_w=2,
stddev=0.02,
activation=lambda x: x,
bias=True,
padding='SAME',
name="Conv2D"):
"""2D Convolution with options for kernel size, stride, and init deviation.
Parameters
----------
x : Tensor
Input tensor to convolve.
n_filters : int
Number of filters to apply.
k_h : int, optional
Kernel height.
k_w : int, optional
Kernel width.
stride_h : int, optional
Stride in rows.
stride_w : int, optional
Stride in cols.
stddev : float, optional
Initialization's standard deviation.
activation : arguments, optional
Function which applies a nonlinearity
padding : str, optional
'SAME' or 'VALID'
name : str, optional
Variable scope to use.
Returns
-------
x : Tensor
Convolved input.
"""
with tf.variable_scope(name):
w = tf.get_variable(
'w', [k_h, k_w, x.get_shape()[-1], n_filters],
initializer=tf.truncated_normal_initializer(stddev=stddev))
conv = tf.nn.conv2d(
x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
if bias:
b = tf.get_variable(
'b', [n_filters],
initializer=tf.truncated_normal_initializer(stddev=stddev))
conv = conv + b
return conv
def linear(x, n_units, scope=None, stddev=0.02,
activation=lambda x: x):
"""Fully-connected network.
Parameters
----------
x : Tensor
Input tensor to the network.
n_units : int
Number of units to connect to.
scope : str, optional
Variable scope to use.
stddev : float, optional
Initialization's standard deviation.
activation : arguments, optional
Function which applies a nonlinearity
Returns
-------
x : Tensor
Fully-connected output.
"""
shape = x.get_shape().as_list()
with tf.variable_scope(scope or "Linear"):
matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
tf.random_normal_initializer(stddev=stddev))
return activation(tf.matmul(x, matrix))
# %%
def weight_variable(shape):
'''Helper function to create a weight variable initialized with
a normal distribution
Parameters
----------
shape : list
Size of weight variable
'''
#initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
initial = tf.zeros(shape)
return tf.Variable(initial)
# %%
def bias_variable(shape):
'''Helper function to create a bias variable initialized with
a constant value.
Parameters
----------
shape : list
Size of weight variable
'''
initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
return tf.Variable(initial)
# %%
def dense_to_one_hot(labels, n_classes=2):
"""Convert class labels from scalars to one-hot vectors."""
labels = np.array(labels)
n_labels = labels.shape[0]
index_offset = np.arange(n_labels) * n_classes
labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
labels_one_hot.flat[index_offset + labels.ravel()] = 1
return labels_one_hot
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment