提交 18477e53 编辑于 作者: Toshihiro Nakae
浏览文件

implement input data normalization #3

上级 3423318e
加载中
加载中
加载中
加载中
+26 −1
原始行号 差异行号 差异行
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib

from dagmm.compression_net import CompressionNet
from dagmm.estimation_net import EstimationNet
@@ -17,11 +19,13 @@ class DAGMM:
    """

    MODEL_FILENAME = "DAGMM_model"
    SCALER_FILENAME = "DAGMM_scaler"

    def __init__(self, comp_hiddens, comp_activation,
            est_hiddens, est_activation, est_dropout_ratio=0.5,
            minibatch_size=1024, epoch_size=100,
            learning_rate=0.0001, lambda1=0.1, lambda2=0.005):
            learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
            normalize=True):
        """
        Parameters
        ----------
@@ -54,6 +58,9 @@ class DAGMM:
        lambda2 : float (optional)
            a parameter of loss function
            (for sum of diagonal elements of covariance)
        normalize : bool (optional)
            specify whether input data needs to be normalized.
            by default, input data is normalized.
        """
        self.comp_net = CompressionNet(comp_hiddens, comp_activation)
        self.est_net = EstimationNet(est_hiddens, est_activation)
@@ -68,6 +75,9 @@ class DAGMM:
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        self.normalize = normalize
        self.scaler = None

        self.graph = None
        self.sess = None

@@ -85,6 +95,10 @@ class DAGMM:
        """
        n_samples, n_features = x.shape

        if self.normalize:
            self.scaler = scaler = StandardScaler()
            x = scaler.fit_transform(x)

        with tf.Graph().as_default() as graph:
            self.graph = graph

@@ -159,6 +173,9 @@ class DAGMM:
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        if self.normalize:
            x = self.scaler.transform(x)

        energies = self.sess.run(self.energy, feed_dict={self.input:x})
        return energies

@@ -182,6 +199,10 @@ class DAGMM:
        model_path = join(fdir, self.MODEL_FILENAME)
        self.saver.save(self.sess, model_path)

        if self.normalize:
            scaler_path = join(fdir, self.SCALER_FILENAME)
            joblib.dump(self.scaler, scaler_path)

    def restore(self, fdir):
        """ Restore trained model from designated directory.

@@ -203,3 +224,7 @@ class DAGMM:
            self.saver.restore(self.sess, model_path)

            self.input, self.energy = tf.get_collection("save")

        if self.normalize:
            scaler_path = join(fdir, self.SCALER_FILENAME)
            self.scaler = joblib.load(scaler_path)