提交 1d0a743e 编辑于 作者: Haowen Xu's avatar Haowen Xu
浏览文件

rename "get_training_objective" to "get_training_loss".

上级 3812a6e0
加载中
加载中
加载中
加载中
+1 −1
原始行号 差异行号 差异行
@@ -123,7 +123,7 @@ To save and restore a trained model:
        # Remember to get the model variables after the birth of a
        # `predictor` or a `trainer`.  The :class:`Donut` instances
        # does not build the graph until :meth:`Donut.get_score` or
        # :meth:`Donut.get_training_objective` is called, which is
        # :meth:`Donut.get_training_loss` is called, which is
        # done in the `predictor` or the `trainer`.
        var_dict = get_variables_as_dict(model_vs)

+14 −7
原始行号 差异行号 差异行
import warnings

import tensorflow as tf
from tensorflow import keras as K
from tfsnippet.distributions import Normal
@@ -17,13 +19,13 @@ class Donut(VarScopeObject):
    """
    Class for constructing Donut model.

    This class provides :meth:`get_training_objective` for deriving the
    This class provides :meth:`get_training_loss` for deriving the
    training loss :class:`tf.Tensor`, and :meth:`get_score` for obtaining
    the reconstruction probability :class:`tf.Tensor`.

    Note:
        :class:`Donut` instances will not build the computation graph
        until :meth:`get_training_objective` or :meth:`get_score` is
        until :meth:`get_training_loss` or :meth:`get_score` is
        called.  This suggests that a :class:`donut.DonutTrainer` or
        a :class:`donut.DonutPredictor` must have been constructed
        before saving or restoring the model parameters.
@@ -108,9 +110,9 @@ class Donut(VarScopeObject):
        """
        return self._vae

    def get_training_objective(self, x, y, n_z=None):
    def get_training_loss(self, x, y, n_z=None):
        """
        Get the training objective for `x` and `y`.
        Get the training loss for `x` and `y`.

        Args:
            x (tf.Tensor): 2-D `float32` :class:`tf.Tensor`, the windows of
@@ -122,10 +124,10 @@ class Donut(VarScopeObject):
                dimension)

        Returns:
            tf.Tensor: The training objective, which can be optimized by
                gradient descent algorithms.
            tf.Tensor: 0-d tensor, the training loss, which can be optimized
                by gradient descent algorithms.
        """
        with tf.name_scope('Donut.training_objective'):
        with tf.name_scope('Donut.training_loss'):
            chain = self.vae.chain(x, n_z=n_z)
            x_log_prob = chain.model['x'].log_prob(group_ndims=0)
            alpha = tf.cast(1 - y, dtype=tf.float32)
@@ -142,6 +144,11 @@ class Donut(VarScopeObject):
            loss = tf.reduce_mean(vi.training.sgvb())
            return loss

    def get_training_objective(self, *args, **kwargs):  # pragma: no cover
        """Deprecated alias; forwards all arguments to :meth:`get_training_loss`.

        Emits a :class:`DeprecationWarning` on every call, then delegates
        unchanged to the renamed method so existing callers keep working.
        """
        message = ('`get_training_objective` is deprecated, use '
                   '`get_training_loss` instead.')
        warnings.warn(message, DeprecationWarning)
        return self.get_training_loss(*args, **kwargs)

    def get_score(self, x, y=None, n_z=None, mcmc_iteration=None,
                  last_point_only=True):
        """
+2 −2
原始行号 差异行号 差异行
@@ -37,7 +37,7 @@ class DonutTrainer(VarScopeObject):
            (default 0.01)
        use_regularization_loss (bool): Whether or not to add regularization
            loss from `tf.GraphKeys.REGULARIZATION_LOSSES` to the training
            objective? (default :obj:`True`)
            loss? (default :obj:`True`)
        max_epoch (int or None): Maximum epochs to run.  If :obj:`None`,
            will not stop at any particular epoch. (default 256)
        max_step (int or None): Maximum steps to run.  If :obj:`None`,
@@ -122,7 +122,7 @@ class DonutTrainer(VarScopeObject):

            # compose the training loss
            with tf.name_scope('loss'):
                loss = model.get_training_objective(
                loss = model.get_training_loss(
                    x=self._input_x, y=self._input_y, n_z=n_z)
                if use_regularization_loss:
                    loss += tf.losses.get_regularization_loss()
+6 −6
原始行号 差异行号 差异行
@@ -67,7 +67,7 @@ class ModelTestCase(tf.test.TestCase):
                ValueError, match='`z_dims` must be a positive integer'):
            _ = Donut(lambda x: x, lambda x: x, x_dims=1, z_dims=object())

    def test_training_objective(self):
    def test_training_loss(self):
        class Capture(object):
            def __init__(self, vae):
                self._vae = vae
@@ -105,13 +105,13 @@ class ModelTestCase(tf.test.TestCase):
            std_epsilon=0.125,
        )
        capture = Capture(donut.vae)
        _ = donut.get_training_objective(x, y)  # ensure model is built
        _ = donut.get_training_loss(x, y)  # ensure model is built

        # training objective with n_z is None
        # training loss with n_z is None
        with self.test_session() as sess:
            ensure_variables_initialized()

            loss = donut.get_training_objective(x, y)
            loss = donut.get_training_loss(x, y)
            np.testing.assert_equal(capture.q_net['z'].eval(),
                                    np.arange(12).reshape([4, 3]))
            p_net = donut.vae.model(z=capture.q_net['z'], x=x)
@@ -125,11 +125,11 @@ class ModelTestCase(tf.test.TestCase):
            loss2 = -tf.reduce_mean(sgvb)
            np.testing.assert_allclose(*sess.run([loss, loss2]))

        # training objective with n_z > 1
        # training loss with n_z > 1
        with self.test_session() as sess:
            ensure_variables_initialized()

            loss = donut.get_training_objective(x, y, n_z=7)
            loss = donut.get_training_loss(x, y, n_z=7)
            np.testing.assert_equal(capture.q_net['z'].eval(),
                                    np.arange(84).reshape([7, 4, 3]))
            p_net = donut.vae.model(z=capture.q_net['z'], x=x, n_z=7)