提交 9fb8f962 编辑于 作者: Haowen Xu's avatar Haowen Xu
浏览文件

minor update

上级 9e9bad43
加载中
加载中
加载中
加载中
+30 −0
原始行号 差异行号 差异行
@@ -130,3 +130,33 @@ To save and restore a trained model:
        saver = VariableSaver(get_variables_as_dict(model_vs), save_dir)
        saver.restore()


If you need more advanced outputs from the model, you may derive the outputs
by using `model.vae` directly, for example:

.. code-block:: python

    from donut import iterative_masked_reconstruct

    # Obtain the reconstructed `x`, with MCMC missing data imputation.
    # See also:
    #   :meth:`donut.Donut.get_score`
    #   :func:`donut.iterative_masked_reconstruct`
    #   :meth:`tfsnippet.modules.VAE.reconstruct`
    input_x = ...  # 2-D `float32` :class:`tf.Tensor`, input `x` windows
    input_y = ...  # 2-D `int32` :class:`tf.Tensor`, missing point indicators
                   # for the `x` windows
    x_r = model.vae.reconstruct(
        iterative_masked_reconstruct(
            reconstruct=model.vae.reconstruct,
            x=input_x,
            mask=input_y,
            iter_count=mcmc_iteration,
            back_prop=False
        )
    )
    # `x_r` is a :class:`tfsnippet.stochastic.StochasticTensor`, from which
    # you may derive many useful outputs, for example:
    x_r.tensor  # the `x` samples
    x_r.log_prob(group_ndims=0)  # element-wise log p(x|z)
    x_r.distribution.mean, x_r.distribution.std  # mean and std of p(x|z)
+6 −1
原始行号 差异行号 差异行
@@ -162,7 +162,12 @@ class Donut(VarScopeObject):
                (default :obj:`True`)

        Returns:
            tf.Tensor: The reconstruction probability.
            tf.Tensor: The reconstruction probability, with the shape
                ``(len(x) - self.x_dims + 1,)`` if `last_point_only` is
                :obj:`True`, or ``(len(x) - self.x_dims + 1, self.x_dims)``
                if `last_point_only` is :obj:`False`.  This is because the
                first ``self.x_dims - 1`` points are not the last point of
                any window.
        """
        with tf.name_scope('Donut.get_score'):
            # MCMC missing data imputation