提交 33331ba8 编辑于 作者: Toshihiro Nakae's avatar Toshihiro Nakae
浏览文件

add random seed setting (fixed #9)

上级 adb67814
加载中
加载中
加载中
加载中
+5 −1
原始行号 差异行号 差异行
@@ -25,7 +25,7 @@ class DAGMM:
            est_hiddens, est_activation, est_dropout_ratio=0.5,
            minibatch_size=1024, epoch_size=100,
            learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
            normalize=True):
            normalize=True, random_seed=123):
        """
        Parameters
        ----------
@@ -61,6 +61,8 @@ class DAGMM:
        normalize : bool (optional)
            specify whether input data need to be normalized.
            by default, input data is normalized.
        random_seed : int (optional)
            random seed used when fit() is called.
        """
        self.comp_net = CompressionNet(comp_hiddens, comp_activation)
        self.est_net = EstimationNet(est_hiddens, est_activation)
@@ -77,6 +79,7 @@ class DAGMM:

        self.normalize = normalize
        self.scaler = None
        self.seed = random_seed

        self.graph = None
        self.sess = None
@@ -101,6 +104,7 @@ class DAGMM:

        with tf.Graph().as_default() as graph:
            self.graph = graph
            tf.set_random_seed(self.seed)

            # Create Placeholder
            self.input = input = tf.placeholder(