DSN Updates

parent 0dbc90d4
......@@ -4,17 +4,10 @@
## Introduction
This code is the code used for the "Domain Separation Networks" paper
by Bousmalis K., Trigeorgis G., et al. which was presented at NIPS 2016. The
<<<<<<< HEAD
paper can be found here: https://arxiv.org/abs/1608.06019
## Contact
This code was open-sourced by Konstantinos Bousmalis ([email protected], github:bousmalis)
=======
paper can be found here: https://arxiv.org/abs/1608.06019.
## Contact
This code was open-sourced by [Konstantinos Bousmalis](https://github.com/bousmalis) ([email protected]).
>>>>>>> d6bee2c713c6aed6522ab32c34b57412d0216d95
## Installation
You will need to have the following installed on your machine before trying out the DSN code.
......@@ -26,35 +19,27 @@ You will need to have the following installed on your machine before trying out
Although we are making the code available, you are only able to use the MNIST
provider for now. We will soon provide a script to download and convert MNIST-M
as well. Check back here in a few weeks or wait for a relevant announcement from
<<<<<<< HEAD
Twitter @bousmalis.
=======
[@bousmalis](https://twitter.com/bousmalis).
>>>>>>> d6bee2c713c6aed6522ab32c34b57412d0216d95
## Running the code for adapting MNIST to MNIST-M
In order to run the MNIST to MNIST-M experiments with DANNs and/or DANNs with
domain separation (DSNs) you will need to set the directory you used to download
<<<<<<< HEAD
MNIST and MNIST-M:\
=======
MNIST and MNIST-M:
>>>>>>> d6bee2c713c6aed6522ab32c34b57412d0216d95
```
$ export DSN_DATA_DIR=/your/dir
```
Then you need to build the binaries with Bazel:
Add models and models/slim to your `$PYTHONPATH`:
```
$ bazel build -c opt domain_adaptation/domain_separation/...
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
```
Add models and models/slim to your `$PYTHONPATH`:
Then you need to build the binaries with Bazel:
```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
$ bazel build -c opt domain_adaptation/domain_separation/...
```
You can then train with the following command:
......
......@@ -14,22 +14,7 @@
# ==============================================================================
# pylint: disable=line-too-long
r"""Evaluation for Domain Separation Networks (DSNs).
To build locally for CPU:
blaze build -c opt --copt=-mavx \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
To build locally for GPU:
blaze build -c opt --copt=-mavx --config=cuda_clang \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
To run locally:
$
./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval
\
--alsologtostderr
"""
"""Evaluation for Domain Separation Networks (DSNs)."""
# pylint: enable=line-too-long
import math
......
......@@ -13,30 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
r"""Training for Domain Separation Networks (DSNs).
-- Compile:
$ blaze build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_train
-- Run:
$
./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_train
\
--similarity_loss=dann \
--basic_tower=dsn_cropped_linemod \
--source_dataset=pose_synthetic \
--target_dataset=pose_real \
--learning_rate=0.012 \
--alpha_weight=0.26 \
--gamma_weight=0.0115 \
--weight_decay=4e-5 \
--layers_to_regularize=fc3 \
--use_separation \
--alsologtostderr
"""
# pylint: enable=line-too-long
"""Training for Domain Separation Networks (DSNs)."""
from __future__ import division
import tensorflow as tf
......@@ -59,7 +36,7 @@ tf.app.flags.DEFINE_string('target_dataset', 'pose_real',
tf.app.flags.DEFINE_string('target_labeled_dataset', 'none',
'Target dataset to train on.')
tf.app.flags.DEFINE_string('dataset_dir', '/cns/ok-d/home/konstantinos/cad_learning/',
tf.app.flags.DEFINE_string('dataset_dir', None,
'The directory where the dataset files are stored.')
tf.app.flags.DEFINE_string('master', '',
......
......@@ -178,16 +178,14 @@ def dann_loss(source_samples, target_samples, weight, scope=None):
assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
with tf.control_dependencies([assert_op]):
tag_loss = 'losses/Domain Loss'
tag_accuracy = 'losses/Domain Accuracy'
tag_loss = 'losses/domain_loss'
tag_accuracy = 'losses/domain_accuracy'
if scope:
tag_loss = scope + tag_loss
tag_accuracy = scope + tag_accuracy
tf.summary.scalar(
tag_loss, domain_loss, name='domain_loss_summary')
tf.summary.scalar(
tag_accuracy, domain_accuracy, name='domain_accuracy_summary')
tf.summary.scalar(tag_loss, domain_loss)
tf.summary.scalar(tag_accuracy, domain_accuracy)
return domain_loss
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment