提交 cebf5384 编辑于 作者: Toshihiro Nakae's avatar Toshihiro Nakae
浏览文件

implement save(), restore()

上级 797b4a05
加载中
加载中
加载中
加载中
+71 −38
原始行号 差异行号 差异行
@@ -4,6 +4,9 @@ from dagmm.compression_net import CompressionNet
from dagmm.estimation_net import EstimationNet
from dagmm.gmm import GMM

from os import makedirs
from os.path import exists, join

class DAGMM:
    """ Deep Autoencoding Gaussian Mixture Model.

@@ -12,6 +15,9 @@ class DAGMM:
    for Unsupervised Anomaly Detection, ICLR 2018
    (this is UNOFFICIAL implementation)
    """

    MODEL_FILENAME = "DAGMM_model"

    def __init__(self, comp_hiddens, comp_activation,
            est_hiddens, est_activation, est_dropout_ratio=0.5,
            minibatch_size=1024, epoch_size=100,
@@ -62,8 +68,8 @@ class DAGMM:
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        # Create tensorflow session
        self.sess = tf.InteractiveSession()
        self.graph = None
        self.sess = None

    def __del__(self):
        if self.sess is not None:
@@ -79,6 +85,9 @@ class DAGMM:
        """
        n_samples, n_features = x.shape

        with tf.Graph().as_default() as graph:
            self.graph = graph

            # Create Placeholder
            self.input = input = tf.placeholder(
                dtype=tf.float32, shape=[None, n_features])
@@ -105,6 +114,8 @@ class DAGMM:

            # Create tensorflow session and initilize
            init = tf.global_variables_initializer()

            self.sess = tf.Session(graph=graph)
            self.sess.run(init)

            # Training
@@ -126,6 +137,11 @@ class DAGMM:
            self.sess.run(fix, feed_dict={input:x, drop:0})
            self.energy = self.gmm.energy(z)

            tf.add_to_collection("save", self.input)
            tf.add_to_collection("save", self.energy)

            self.saver = tf.train.Saver()

    def predict(self, x):
        """ Calculate anormaly scores (sample energy) on samples in X.

@@ -142,3 +158,20 @@ class DAGMM:
        """
        energies = self.sess.run(self.energy, feed_dict={self.input:x})
        return energies

    def save(self, fdir):
        if not exists(fdir):
            makedirs(fdir)

        model_path = join(fdir, self.MODEL_FILENAME)
        self.saver.save(self.sess, model_path)

    def restore(self, fdir):
        model_path = join(fdir, self.MODEL_FILENAME)
        meta_path = model_path + ".meta"

        self.sess = tf.Session()
        self.saver = tf.train.import_meta_graph(meta_path)
        self.saver.restore(self.sess, model_path)

        self.input, self.energy = tf.get_collection("save")