未验证 提交 6cff38f7 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update tester.py

上级 86d296bb
加载中
加载中
加载中
加载中
+4 −0
原始行号 差异行号 差异行
@@ -81,10 +81,13 @@ class Tester(object):

    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
        batch_size = original_seq.size(0)
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) 
                                         + 2 * recon_seq_logvar.float() 
                                         + np.log(np.pi*2))
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
        kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (6) for details.
        d_post_var = torch.exp(d_post_logvar)
        d_prior_var = torch.exp(d_prior_logvar)
        kld_d = 0.5 * torch.sum(d_prior_logvar - d_post_logvar + 
@@ -94,6 +97,7 @@ class Tester(object):

    
    def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logvar):
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
        llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logvar.float())), 2) 
              + 2 * recon_x_logvar.float() 
              + np.log(np.pi*2))