未验证 提交 be429054 编辑于 作者: Toshihiro NAKAE's avatar Toshihiro NAKAE 提交者: GitHub
浏览文件

Merge pull request #11 from tnakae/ShuffleTrainingData

shuffle training data (fixed #8)
加载中
加载中
加载中
加载中
+6 −1
原始行号 差异行号 差异行
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib

@@ -105,6 +106,7 @@ class DAGMM:
        with tf.Graph().as_default() as graph:
            self.graph = graph
            tf.set_random_seed(self.seed)
            np.random.seed(seed=self.seed)

            # Create Placeholder
            self.input = input = tf.placeholder(
@@ -137,11 +139,14 @@ class DAGMM:
            self.sess.run(init)

            # Training
            idx = np.arange(x.shape[0])
            np.random.shuffle(idx)

            for epoch in range(self.epoch_size):
                for batch in range(n_batch):
                    i_start = batch * self.minibatch_size
                    i_end = (batch + 1) * self.minibatch_size
                    x_batch = x[i_start:i_end]
                    x_batch = x[idx[i_start:i_end]]

                    self.sess.run(minimizer, feed_dict={
                        input:x_batch, drop:self.est_dropout_ratio})