Unverified commit 526796e6, authored by dlagul, committed by GitHub

Update tester.py

Parent 84fe054e
+13 −15
@@ -46,8 +46,8 @@ class Tester(object):
        for i, dataitem in enumerate(self.testloader,1):
            timestamps,labels,data = dataitem
            data = data.to(self.device)
-            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar = self.forward_test(data)
-            avg_loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logvar, s_mean, s_logvar, 
+            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma = self.forward_test(data)
+            avg_loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logsigma, s_mean, s_logvar, 
                                      d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar) 
            last_timestamp = timestamps[-1,-1,-1,-1]
            label_last_timestamp_tensor = labels[-1,-1,-1,-1]
@@ -60,7 +60,7 @@ class Tester(object):
                isanomaly = "Normaly"
            llh_last_timestamp = self.loglikelihood_last_timestamp(data[-1,-1,-1,:,-1], 
                                                                   recon_x_mu[-1,-1,-1,:,-1],
-                                                                   recon_x_logvar[-1,-1,-1,:,-1])
+                                                                   recon_x_logsigma[-1,-1,-1,:,-1])
            
            self.loss['Last_timestamp'] = last_timestamp.item()
            self.loss['Avg_loss'] = avg_loss.item()
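The hunk above scores a window by the Gaussian log-likelihood of its last timestamp, and the (elided) lines just before it set `isanomaly` accordingly. As a minimal, self-contained sketch of that pattern, assuming a hypothetical `threshold` cutoff (the actual decision rule sits outside this hunk):

```python
import math
import torch

def flag_last_timestamp(x, recon_x_mu, recon_x_logsigma, threshold=-100.0):
    # Gaussian log-likelihood of the last timestamp, parameterized by
    # log(sigma), mirroring loglikelihood_last_timestamp below.
    llh = -0.5 * torch.sum(((x - recon_x_mu) / torch.exp(recon_x_logsigma)) ** 2
                           + 2 * recon_x_logsigma
                           + math.log(2 * math.pi))
    # `threshold` is a hypothetical cutoff, not a value from the repo;
    # a low likelihood marks the point as anomalous ("Normaly" is the
    # repository's spelling).
    return "Anomaly" if llh.item() < threshold else "Normaly"
```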
@@ -74,12 +74,12 @@ class Tester(object):

    def forward_test(self, data):
        with torch.no_grad():
-            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar= self.model(data)
-            return s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar
+            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma= self.model(data)
+            return s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma



-    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
+    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logsigma, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
        batch_size = original_seq.size(0)
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
        # The constant terms in the loss function (not the coefficients) can be any number here, or even omitted
@@ -87,10 +87,9 @@ class Tester(object):
        # log(N(x|mu,sigma^2))
        # = log{1/(sqrt(2*pi)*sigma)exp{-(x-mu)^2/(2*sigma^2)}} 
        # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
-        # Note that var = sigma^2, i.e., log(var) = 2*log(sigma),
-        # so the "recon_seq_logvar" here is more appropriate to be called "recon_seq_logsigma", but the name does not matter
-        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) 
-                                         + 2 * recon_seq_logvar.float() 
+        # Note that var = sigma^2, i.e., log(var) = 2*log(sigma)
+        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logsigma.float())), 2) 
+                                         + 2 * recon_seq_logsigma.float() 
                                         + np.log(np.pi*2))
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
        kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
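The rename is the substance of this hunk: the decoder head emits log(sigma), not log(var), so sigma is recovered as exp(recon_seq_logsigma); a log-variance head would need exp(0.5 * logvar). A quick self-contained check of the hand-written term against torch.distributions.Normal.log_prob (a sketch with arbitrary shapes, not repo code):

```python
import math
import torch
from torch.distributions import Normal

x, mu, logsigma = torch.randn(8), torch.randn(8), torch.randn(8)

# Hand-written: log N(x | mu, sigma^2)
# = -0.5 * [((x - mu) / sigma)^2 + 2*log(sigma) + log(2*pi)]
llh = -0.5 * (((x - mu) / torch.exp(logsigma)) ** 2
              + 2 * logsigma
              + math.log(2 * math.pi))

# Normal expects sigma itself, i.e. exp(logsigma); a log-variance
# parameterization would pass exp(0.5 * logvar) instead.
assert torch.allclose(llh, Normal(mu, torch.exp(logsigma)).log_prob(x), atol=1e-5)
```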
@@ -103,7 +102,7 @@ class Tester(object):
        return (-loglikelihood + kld_s + kld_d)/batch_size, loglikelihood/batch_size, kld_s/batch_size, kld_d/batch_size

    
-    def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logvar):
+    def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logsigma):
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
        # The constant terms in the loss function (not the coefficients) can be any number here, or even omitted,
        # as they have no impact on gradient propagation during training.
@@ -111,10 +110,9 @@ class Tester(object):
        # log(N(x|mu,sigma^2))
        # = log{1/(sqrt(2*pi)*sigma)exp{-(x-mu)^2/(2*sigma^2)}} 
        # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
-        # Note that var = sigma^2, i.e., log(var) = 2*log(sigma),
-        # so the "recon_seq_logvar" here is more appropriate to be called "recon_seq_logsigma", but the name does not matter
-        llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logvar.float())), 2) 
-              + 2 * recon_x_logvar.float() 
+        # Note that var = sigma^2, i.e., log(var) = 2*log(sigma)
+        llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logsigma.float())), 2) 
+              + 2 * recon_x_logsigma.float() 
              + np.log(np.pi*2))
        return llh
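For completeness: the KL terms above (kld_s, kld_d) keep their log-variance parameterization, matching Equation (7) of arXiv:1606.05908; only the reconstruction head uses log(sigma). A small sketch (mine, not repo code) verifying the standard-normal-prior closed form against torch.distributions:

```python
import torch
from torch.distributions import Normal, kl_divergence

mu, logvar = torch.randn(8), torch.randn(8)

# Closed form of KL(N(mu, sigma^2) || N(0, 1)) with logvar = log(sigma^2),
# as used for kld_s above.
kld = -0.5 * torch.sum(1 + logvar - torch.pow(mu, 2) - torch.exp(logvar))

# Reference via torch.distributions; sigma = exp(0.5 * logvar).
ref = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)), Normal(0.0, 1.0)).sum()
assert torch.allclose(kld, ref, atol=1e-5)
```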