未验证 提交 3423318e 编辑于 作者: Toshihiro NAKAE's avatar Toshihiro NAKAE 提交者: GitHub
浏览文件

Merge pull request #2 from tnakae/ModelDumpLoad

Add save/restore methods
加载中
加载中
加载中
加载中
+99 −38
原始行号 差异行号 差异行
@@ -4,6 +4,9 @@ from dagmm.compression_net import CompressionNet
from dagmm.estimation_net import EstimationNet
from dagmm.gmm import GMM

from os import makedirs
from os.path import exists, join

class DAGMM:
    """ Deep Autoencoding Gaussian Mixture Model.

@@ -12,6 +15,9 @@ class DAGMM:
    for Unsupervised Anomaly Detection, ICLR 2018
    (this is UNOFFICIAL implementation)
    """

    MODEL_FILENAME = "DAGMM_model"

    def __init__(self, comp_hiddens, comp_activation,
            est_hiddens, est_activation, est_dropout_ratio=0.5,
            minibatch_size=1024, epoch_size=100,
@@ -62,8 +68,8 @@ class DAGMM:
        self.lambda1 = lambda1
        self.lambda2 = lambda2

        # Create tensorflow session
        self.sess = tf.InteractiveSession()
        self.graph = None
        self.sess = None

    def __del__(self):
        if self.sess is not None:
@@ -79,6 +85,9 @@ class DAGMM:
        """
        n_samples, n_features = x.shape

        with tf.Graph().as_default() as graph:
            self.graph = graph

            # Create Placeholder
            self.input = input = tf.placeholder(
                dtype=tf.float32, shape=[None, n_features])
@@ -105,6 +114,8 @@ class DAGMM:

            # Create tensorflow session and initilize
            init = tf.global_variables_initializer()

            self.sess = tf.Session(graph=graph)
            self.sess.run(init)

            # Training
@@ -126,6 +137,11 @@ class DAGMM:
            self.sess.run(fix, feed_dict={input:x, drop:0})
            self.energy = self.gmm.energy(z)

            tf.add_to_collection("save", self.input)
            tf.add_to_collection("save", self.energy)

            self.saver = tf.train.Saver()

    def predict(self, x):
        """ Calculate anormaly scores (sample energy) on samples in X.

@@ -140,5 +156,50 @@ class DAGMM:
        energies : array-like, shape (n_samples)
            Calculated sample energies.
        """
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        energies = self.sess.run(self.energy, feed_dict={self.input:x})
        return energies

    def save(self, fdir):
        """ Save trained model to designated directory.
        This method have to be called after training.
        (If not, throw an exception)

        Parameters
        ----------
        fdir : str
            Path of directory trained model is saved.
            If not exists, it is created automatically.
        """
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        if not exists(fdir):
            makedirs(fdir)

        model_path = join(fdir, self.MODEL_FILENAME)
        self.saver.save(self.sess, model_path)

    def restore(self, fdir):
        """ Restore trained model from designated directory.

        Parameters
        ----------
        fdir : str
            Path of directory trained model is saved.
        """
        if not exists(fdir):
            raise Exception("Model directory does not exist.")

        model_path = join(fdir, self.MODEL_FILENAME)
        meta_path = model_path + ".meta"

        with tf.Graph().as_default() as graph:
            self.graph = graph
            self.sess = tf.Session(graph=graph)
            self.saver = tf.train.import_meta_graph(meta_path)
            self.saver.restore(self.sess, model_path)

            self.input, self.energy = tf.get_collection("save")