加载中 donut/model.py +43 −32 原始行号 差异行号 差异行 import warnings from functools import partial import tensorflow as tf from tensorflow import keras as K from tfsnippet.distributions import Normal from tfsnippet.modules import VAE, Sequential, DictMapper, Module from tfsnippet.modules import VAE, Lambda, Module from tfsnippet.stochastic import validate_n_samples from tfsnippet.utils import (VarScopeObject, reopen_variable_scope, 加载中 @@ -15,6 +15,19 @@ from .reconstruction import iterative_masked_reconstruct __all__ = ['Donut'] def softplus_std(inputs, units, epsilon, name): return tf.nn.softplus(tf.layers.dense(inputs, units, name=name)) + epsilon def wrap_params_net(inputs, h_for_dist, mean_layer, std_layer): with tf.variable_scope('hidden'): h = h_for_dist(inputs) return { 'mean': mean_layer(h), 'std': std_layer(h), } class Donut(VarScopeObject): """ Class for constructing Donut model. 加载中 加载中 @@ -56,36 +69,34 @@ class Donut(VarScopeObject): p_z=Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims])), p_x_given_z=Normal, q_z_given_x=Normal, h_for_p_x=Sequential([ h_for_p_x, DictMapper( { 'mean': K.layers.Dense(x_dims), 'std': lambda x: ( std_epsilon + K.layers.Dense( x_dims, activation=tf.nn.softplus )(x) h_for_p_x=Lambda( partial( wrap_params_net, h_for_dist=h_for_p_x, mean_layer=partial( tf.layers.dense, units=x_dims, name='x_mean' ), std_layer=partial( softplus_std, units=x_dims, epsilon=std_epsilon, name='x_std' ) }, ), name='p_x_given_z' ), h_for_q_z=Lambda( partial( wrap_params_net, h_for_dist=h_for_q_z, mean_layer=partial( tf.layers.dense, units=z_dims, name='z_mean' ), std_layer=partial( softplus_std, units=z_dims, epsilon=std_epsilon, name='z_std' ) ]), h_for_q_z=Sequential([ h_for_q_z, DictMapper( { 'mean': K.layers.Dense(z_dims), 'std': lambda z: ( std_epsilon + K.layers.Dense( z_dims, activation=tf.nn.softplus )(z) ) }, ), name='q_z_given_x' ) ]), ) self._x_dims = x_dims self._z_dims = z_dims 加载中 donut/training.py +5 −3 原始行号 差异行号 差异行 import six import numpy as np import tensorflow as tf from tfsnippet.scaffold import train_loop, TrainLoop from tfsnippet.scaffold import TrainLoop from tfsnippet.utils import (VarScopeObject, reopen_variable_scope, get_default_session_or_error, 加载中 加载中 @@ -158,6 +158,8 @@ class DonutTrainer(VarScopeObject): grad_vars.append((grad, var)) # build the training op with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)): self._train_op = self._optimizer.apply_gradients( grad_vars, global_step=self._global_step) 加载中 tests/test_model.py +10 −8 原始行号 差异行号 差异行 加载中 @@ -40,17 +40,19 @@ class ModelTestCase(tf.test.TestCase): x = tf.reshape(tf.range(20, dtype=tf.float32), [4, 5]) _ = donut1.get_score(x) _ = donut1.get_score(x) _ = donut2.get_score(x) _ = donut2.get_score(x) self.assertListEqual( sorted(get_variables_as_dict()), ['get_donut/donut/p_x_given_z/mean/dense/bias', 'get_donut/donut/p_x_given_z/mean/dense/kernel', 'get_donut/donut/p_x_given_z/std/dense/bias', 'get_donut/donut/p_x_given_z/std/dense/kernel', 'get_donut/donut/q_z_given_x/mean/dense/bias', 'get_donut/donut/q_z_given_x/mean/dense/kernel', 'get_donut/donut/q_z_given_x/std/dense/bias', 'get_donut/donut/q_z_given_x/std/dense/kernel'] ['get_donut/donut/p_x_given_z/x_mean/bias', 'get_donut/donut/p_x_given_z/x_mean/kernel', 'get_donut/donut/p_x_given_z/x_std/bias', 'get_donut/donut/p_x_given_z/x_std/kernel', 'get_donut/donut/q_z_given_x/z_mean/bias', 'get_donut/donut/q_z_given_x/z_mean/kernel', 'get_donut/donut/q_z_given_x/z_std/bias', 'get_donut/donut/q_z_given_x/z_std/kernel'] ) def test_error_construction(self): 加载中 加载中
donut/model.py +43 −32 原始行号 差异行号 差异行 import warnings from functools import partial import tensorflow as tf from tensorflow import keras as K from tfsnippet.distributions import Normal from tfsnippet.modules import VAE, Sequential, DictMapper, Module from tfsnippet.modules import VAE, Lambda, Module from tfsnippet.stochastic import validate_n_samples from tfsnippet.utils import (VarScopeObject, reopen_variable_scope, 加载中 @@ -15,6 +15,19 @@ from .reconstruction import iterative_masked_reconstruct __all__ = ['Donut'] def softplus_std(inputs, units, epsilon, name): return tf.nn.softplus(tf.layers.dense(inputs, units, name=name)) + epsilon def wrap_params_net(inputs, h_for_dist, mean_layer, std_layer): with tf.variable_scope('hidden'): h = h_for_dist(inputs) return { 'mean': mean_layer(h), 'std': std_layer(h), } class Donut(VarScopeObject): """ Class for constructing Donut model. 加载中 加载中 @@ -56,36 +69,34 @@ class Donut(VarScopeObject): p_z=Normal(mean=tf.zeros([z_dims]), std=tf.ones([z_dims])), p_x_given_z=Normal, q_z_given_x=Normal, h_for_p_x=Sequential([ h_for_p_x, DictMapper( { 'mean': K.layers.Dense(x_dims), 'std': lambda x: ( std_epsilon + K.layers.Dense( x_dims, activation=tf.nn.softplus )(x) h_for_p_x=Lambda( partial( wrap_params_net, h_for_dist=h_for_p_x, mean_layer=partial( tf.layers.dense, units=x_dims, name='x_mean' ), std_layer=partial( softplus_std, units=x_dims, epsilon=std_epsilon, name='x_std' ) }, ), name='p_x_given_z' ), h_for_q_z=Lambda( partial( wrap_params_net, h_for_dist=h_for_q_z, mean_layer=partial( tf.layers.dense, units=z_dims, name='z_mean' ), std_layer=partial( softplus_std, units=z_dims, epsilon=std_epsilon, name='z_std' ) ]), h_for_q_z=Sequential([ h_for_q_z, DictMapper( { 'mean': K.layers.Dense(z_dims), 'std': lambda z: ( std_epsilon + K.layers.Dense( z_dims, activation=tf.nn.softplus )(z) ) }, ), name='q_z_given_x' ) ]), ) self._x_dims = x_dims self._z_dims = z_dims 加载中
donut/training.py +5 −3 原始行号 差异行号 差异行 import six import numpy as np import tensorflow as tf from tfsnippet.scaffold import train_loop, TrainLoop from tfsnippet.scaffold import TrainLoop from tfsnippet.utils import (VarScopeObject, reopen_variable_scope, get_default_session_or_error, 加载中 加载中 @@ -158,6 +158,8 @@ class DonutTrainer(VarScopeObject): grad_vars.append((grad, var)) # build the training op with tf.control_dependencies( tf.get_collection(tf.GraphKeys.UPDATE_OPS)): self._train_op = self._optimizer.apply_gradients( grad_vars, global_step=self._global_step) 加载中
tests/test_model.py +10 −8 原始行号 差异行号 差异行 加载中 @@ -40,17 +40,19 @@ class ModelTestCase(tf.test.TestCase): x = tf.reshape(tf.range(20, dtype=tf.float32), [4, 5]) _ = donut1.get_score(x) _ = donut1.get_score(x) _ = donut2.get_score(x) _ = donut2.get_score(x) self.assertListEqual( sorted(get_variables_as_dict()), ['get_donut/donut/p_x_given_z/mean/dense/bias', 'get_donut/donut/p_x_given_z/mean/dense/kernel', 'get_donut/donut/p_x_given_z/std/dense/bias', 'get_donut/donut/p_x_given_z/std/dense/kernel', 'get_donut/donut/q_z_given_x/mean/dense/bias', 'get_donut/donut/q_z_given_x/mean/dense/kernel', 'get_donut/donut/q_z_given_x/std/dense/bias', 'get_donut/donut/q_z_given_x/std/dense/kernel'] ['get_donut/donut/p_x_given_z/x_mean/bias', 'get_donut/donut/p_x_given_z/x_mean/kernel', 'get_donut/donut/p_x_given_z/x_std/bias', 'get_donut/donut/p_x_given_z/x_std/kernel', 'get_donut/donut/q_z_given_x/z_mean/bias', 'get_donut/donut/q_z_given_x/z_mean/kernel', 'get_donut/donut/q_z_given_x/z_std/bias', 'get_donut/donut/q_z_given_x/z_std/kernel'] ) def test_error_construction(self): 加载中