提交 bef55755 编辑于 作者: Haowen Xu's avatar Haowen Xu
浏览文件

fix bug in variable reusing; add support for tf.GraphKeys.UPDATE_OPS

上级 07a445ac
加载中
加载中
加载中
加载中
+43 −32
原始行号 差异行号 差异行
import warnings
from functools import partial

import tensorflow as tf
from tensorflow import keras as K
from tfsnippet.distributions import Normal
from tfsnippet.modules import VAE, Sequential, DictMapper, Module
from tfsnippet.modules import VAE, Lambda, Module
from tfsnippet.stochastic import validate_n_samples
from tfsnippet.utils import (VarScopeObject,
                             reopen_variable_scope,
@@ -15,6 +15,19 @@ from .reconstruction import iterative_masked_reconstruct
__all__ = ['Donut']


def softplus_std(inputs, units, epsilon, name):
    """Produce a strictly positive std tensor from `inputs`.

    Applies a dense layer of `units` outputs (variables scoped under
    `name`), passes the result through softplus, and shifts it by
    `epsilon` so the returned std can never collapse to zero.
    """
    pre_activation = tf.layers.dense(inputs, units, name=name)
    positive_std = tf.nn.softplus(pre_activation) + epsilon
    return positive_std


def wrap_params_net(inputs, h_for_dist, mean_layer, std_layer):
    """Build the distribution parameter dict from a shared hidden net.

    Runs `h_for_dist` on `inputs` inside a 'hidden' variable scope (so
    its variables are reusable under a stable name), then feeds the
    shared hidden output to `mean_layer` and `std_layer`.

    Returns a dict with keys 'mean' and 'std'.
    """
    with tf.variable_scope('hidden'):
        hidden = h_for_dist(inputs)
    params = {
        'mean': mean_layer(hidden),
        'std': std_layer(hidden),
    }
    return params


class Donut(VarScopeObject):
    """
    Class for constructing Donut model.
@@ -56,36 +69,34 @@ class Donut(VarScopeObject):
                p_z=Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims])),
                p_x_given_z=Normal,
                q_z_given_x=Normal,
                h_for_p_x=Sequential([
                    h_for_p_x,
                    DictMapper(
                        {
                            'mean': K.layers.Dense(x_dims),
                            'std': lambda x: (
                                std_epsilon + K.layers.Dense(
                                    x_dims,
                                    activation=tf.nn.softplus
                                )(x)
                h_for_p_x=Lambda(
                    partial(
                        wrap_params_net,
                        h_for_dist=h_for_p_x,
                        mean_layer=partial(
                            tf.layers.dense, units=x_dims, name='x_mean'
                        ),
                        std_layer=partial(
                            softplus_std, units=x_dims, epsilon=std_epsilon,
                            name='x_std'
                        )
                        },
                    ),
                    name='p_x_given_z'
                ),
                h_for_q_z=Lambda(
                    partial(
                        wrap_params_net,
                        h_for_dist=h_for_q_z,
                        mean_layer=partial(
                            tf.layers.dense, units=z_dims, name='z_mean'
                        ),
                        std_layer=partial(
                            softplus_std, units=z_dims, epsilon=std_epsilon,
                            name='z_std'
                        )
                ]),
                h_for_q_z=Sequential([
                    h_for_q_z,
                    DictMapper(
                        {
                            'mean': K.layers.Dense(z_dims),
                            'std': lambda z: (
                                std_epsilon + K.layers.Dense(
                                    z_dims,
                                    activation=tf.nn.softplus
                                )(z)
                            )
                        },
                    ),
                    name='q_z_given_x'
                )
                ]),
            )
        self._x_dims = x_dims
        self._z_dims = z_dims
+5 −3
原始行号 差异行号 差异行
import six
import numpy as np
import tensorflow as tf
from tfsnippet.scaffold import train_loop, TrainLoop
from tfsnippet.scaffold import TrainLoop
from tfsnippet.utils import (VarScopeObject,
                             reopen_variable_scope,
                             get_default_session_or_error,
@@ -158,6 +158,8 @@ class DonutTrainer(VarScopeObject):
                    grad_vars.append((grad, var))

            # build the training op
            with tf.control_dependencies(
                    tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                self._train_op = self._optimizer.apply_gradients(
                    grad_vars, global_step=self._global_step)

+10 −8
原始行号 差异行号 差异行
@@ -40,17 +40,19 @@ class ModelTestCase(tf.test.TestCase):

        x = tf.reshape(tf.range(20, dtype=tf.float32), [4, 5])
        _ = donut1.get_score(x)
        _ = donut1.get_score(x)
        _ = donut2.get_score(x)
        _ = donut2.get_score(x)
        self.assertListEqual(
            sorted(get_variables_as_dict()),
            ['get_donut/donut/p_x_given_z/mean/dense/bias',
             'get_donut/donut/p_x_given_z/mean/dense/kernel',
             'get_donut/donut/p_x_given_z/std/dense/bias',
             'get_donut/donut/p_x_given_z/std/dense/kernel',
             'get_donut/donut/q_z_given_x/mean/dense/bias',
             'get_donut/donut/q_z_given_x/mean/dense/kernel',
             'get_donut/donut/q_z_given_x/std/dense/bias',
             'get_donut/donut/q_z_given_x/std/dense/kernel']
            ['get_donut/donut/p_x_given_z/x_mean/bias',
             'get_donut/donut/p_x_given_z/x_mean/kernel',
             'get_donut/donut/p_x_given_z/x_std/bias',
             'get_donut/donut/p_x_given_z/x_std/kernel',
             'get_donut/donut/q_z_given_x/z_mean/bias',
             'get_donut/donut/q_z_given_x/z_mean/kernel',
             'get_donut/donut/q_z_given_x/z_std/bias',
             'get_donut/donut/q_z_given_x/z_std/kernel']
        )

    def test_error_construction(self):