加载中 dagmm/dagmm.py +26 −1 原始行号 差异行号 差异行 import tensorflow as tf from sklearn.preprocessing import StandardScaler from sklearn.externals import joblib from dagmm.compression_net import CompressionNet from dagmm.estimation_net import EstimationNet 加载中 @@ -17,11 +19,13 @@ class DAGMM: """ MODEL_FILENAME = "DAGMM_model" SCALER_FILENAME = "DAGMM_scaler" def __init__(self, comp_hiddens, comp_activation, est_hiddens, est_activation, est_dropout_ratio=0.5, minibatch_size=1024, epoch_size=100, learning_rate=0.0001, lambda1=0.1, lambda2=0.005): learning_rate=0.0001, lambda1=0.1, lambda2=0.005, normalize=True): """ Parameters ---------- 加载中 加载中 @@ -54,6 +58,9 @@ class DAGMM: lambda2 : float (optional) a parameter of loss function (for sum of diagonal elements of covariance) normalize : bool (optional) specify whether input data need to be normalized. by default, input data is normalized. """ self.comp_net = CompressionNet(comp_hiddens, comp_activation) self.est_net = EstimationNet(est_hiddens, est_activation) 加载中 @@ -68,6 +75,9 @@ class DAGMM: self.lambda1 = lambda1 self.lambda2 = lambda2 self.normalize = normalize self.scaler = None self.graph = None self.sess = None 加载中 @@ -85,6 +95,10 @@ class DAGMM: """ n_samples, n_features = x.shape if self.normalize: self.scaler = scaler = StandardScaler() x = scaler.fit_transform(x) with tf.Graph().as_default() as graph: self.graph = graph 加载中 加载中 @@ -159,6 +173,9 @@ class DAGMM: if self.sess is None: raise Exception("Trained model does not exist.") if self.normalize: x = self.scaler.transform(x) energies = self.sess.run(self.energy, feed_dict={self.input:x}) return energies 加载中 @@ -182,6 +199,10 @@ class DAGMM: model_path = join(fdir, self.MODEL_FILENAME) self.saver.save(self.sess, model_path) if self.normalize: scaler_path = join(fdir, self.SCALER_FILENAME) joblib.dump(self.scaler, scaler_path) def restore(self, fdir): """ Restore trained model from designated directory. 加载中 @@ -203,3 +224,7 @@ class DAGMM: self.saver.restore(self.sess, model_path) self.input, self.energy = tf.get_collection("save") if self.normalize: scaler_path = join(fdir, self.SCALER_FILENAME) self.scaler = joblib.load(scaler_path) 加载中
dagmm/dagmm.py +26 −1 原始行号 差异行号 差异行 import tensorflow as tf from sklearn.preprocessing import StandardScaler from sklearn.externals import joblib from dagmm.compression_net import CompressionNet from dagmm.estimation_net import EstimationNet 加载中 @@ -17,11 +19,13 @@ class DAGMM: """ MODEL_FILENAME = "DAGMM_model" SCALER_FILENAME = "DAGMM_scaler" def __init__(self, comp_hiddens, comp_activation, est_hiddens, est_activation, est_dropout_ratio=0.5, minibatch_size=1024, epoch_size=100, learning_rate=0.0001, lambda1=0.1, lambda2=0.005): learning_rate=0.0001, lambda1=0.1, lambda2=0.005, normalize=True): """ Parameters ---------- 加载中 加载中 @@ -54,6 +58,9 @@ class DAGMM: lambda2 : float (optional) a parameter of loss function (for sum of diagonal elements of covariance) normalize : bool (optional) specify whether input data need to be normalized. by default, input data is normalized. """ self.comp_net = CompressionNet(comp_hiddens, comp_activation) self.est_net = EstimationNet(est_hiddens, est_activation) 加载中 @@ -68,6 +75,9 @@ class DAGMM: self.lambda1 = lambda1 self.lambda2 = lambda2 self.normalize = normalize self.scaler = None self.graph = None self.sess = None 加载中 @@ -85,6 +95,10 @@ class DAGMM: """ n_samples, n_features = x.shape if self.normalize: self.scaler = scaler = StandardScaler() x = scaler.fit_transform(x) with tf.Graph().as_default() as graph: self.graph = graph 加载中 加载中 @@ -159,6 +173,9 @@ class DAGMM: if self.sess is None: raise Exception("Trained model does not exist.") if self.normalize: x = self.scaler.transform(x) energies = self.sess.run(self.energy, feed_dict={self.input:x}) return energies 加载中 @@ -182,6 +199,10 @@ class DAGMM: model_path = join(fdir, self.MODEL_FILENAME) self.saver.save(self.sess, model_path) if self.normalize: scaler_path = join(fdir, self.SCALER_FILENAME) joblib.dump(self.scaler, scaler_path) def restore(self, fdir): """ Restore trained model from designated directory. 加载中 @@ -203,3 +224,7 @@ class DAGMM: self.saver.restore(self.sess, model_path) self.input, self.energy = tf.get_collection("save") if self.normalize: scaler_path = join(fdir, self.SCALER_FILENAME) self.scaler = joblib.load(scaler_path)