Commit 2770f00b authored by vonclites's avatar vonclites Committed by Sergio Guadarrama

Fixed model_deploy to correctly assign weights to variables_device (#992)

* Fixed model_deploy to correctly assign weights to variables_device

* Adding test for network_fn's arg_scope

* Style fix (double blank line)

* Add WORKSPACE file to models/slim so that imports work properly
parent 9681f3fc
...@@ -97,10 +97,10 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): ...@@ -97,10 +97,10 @@ def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
""" """
if name not in networks_map: if name not in networks_map:
raise ValueError('Name of network unknown %s' % name) raise ValueError('Name of network unknown %s' % name)
arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
func = networks_map[name] func = networks_map[name]
@functools.wraps(func) @functools.wraps(func)
def network_fn(images): def network_fn(images):
arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
with slim.arg_scope(arg_scope): with slim.arg_scope(arg_scope):
return func(images, num_classes, is_training=is_training) return func(images, num_classes, is_training=is_training)
if hasattr(func, 'default_image_size'): if hasattr(func, 'default_image_size'):
......
...@@ -19,11 +19,12 @@ from __future__ import absolute_import ...@@ -19,11 +19,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from nets import nets_factory from nets import nets_factory
slim = tf.contrib.slim
class NetworksTest(tf.test.TestCase): class NetworksTest(tf.test.TestCase):
...@@ -42,5 +43,19 @@ class NetworksTest(tf.test.TestCase): ...@@ -42,5 +43,19 @@ class NetworksTest(tf.test.TestCase):
self.assertEqual(logits.get_shape().as_list()[0], batch_size) self.assertEqual(logits.get_shape().as_list()[0], batch_size)
self.assertEqual(logits.get_shape().as_list()[-1], num_classes) self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
def testGetNetworkFnArgScope(self):
batch_size = 5
num_classes = 10
net = 'cifarnet'
with self.test_session(use_gpu=True):
net_fn = nets_factory.get_network_fn(net, num_classes)
image_size = getattr(net_fn, 'default_image_size', 224)
with slim.arg_scope([slim.model_variable, slim.variable],
device='/CPU:0'):
inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
net_fn(inputs)
weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'CifarNet/conv1')[0]
self.assertDeviceEqual('/CPU:0', weights.device)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment