sdfvae/trainer.py (+3 −1), hunk `@@ -56,11 +56,13 @@ class Trainer(object):`

```python
    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, s_logvar,
                d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
        batch_size = original_seq.size(0)
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
        loglikelihood = -0.5 * torch.sum(
            torch.pow((original_seq.float() - recon_seq_mu.float())
                      / torch.exp(recon_seq_logvar.float()), 2)
            + 2 * recon_seq_logvar.float()
            + np.log(np.pi * 2))
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
        kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
        # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (6) for details.
        d_post_var = torch.exp(d_post_logvar)
        d_prior_var = torch.exp(d_prior_logvar)
        kld_d = 0.5 * torch.sum(d_prior_logvar - d_post_logvar
```
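One point worth noting in the reconstruction term: it divides the residual by `exp(recon_seq_logvar)` and adds `2 * recon_seq_logvar` to the normalizer, which matches the Gaussian log-density only if `recon_seq_logvar` holds log standard deviations (`sigma = exp(recon_seq_logvar)`); `kld_s`, by contrast, treats `s_logvar` as a log-variance. A minimal sketch verifying the log-std reading against `torch.distributions` (shapes and names here are illustrative, not taken from the repository):

```python
# Sanity check of the reconstruction term above, under the assumption that
# recon_seq_logvar holds log standard deviations (sigma = exp(recon_seq_logvar)).
import numpy as np
import torch

torch.manual_seed(0)
x, mu, log_std = torch.randn(3, 5), torch.randn(3, 5), torch.randn(3, 5)

# The term as written in the diff:
# -0.5 * sum((x - mu)^2 / sigma^2 + 2*log(sigma) + log(2*pi))
term = -0.5 * torch.sum(torch.pow((x - mu) / torch.exp(log_std), 2)
                        + 2 * log_std + np.log(np.pi * 2))

# Reference: elementwise Gaussian log-density summed over the tensor.
ref = torch.distributions.Normal(mu, torch.exp(log_std)).log_prob(x).sum()
assert torch.allclose(term, ref, atol=1e-5)
```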
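The diff is cut off midway through the `kld_d` expression. For reference, here is a self-contained sketch of the standard closed-form KL divergence between two diagonal Gaussians, KL(q || p) with q = N(d_post_mean, exp(d_post_logvar)) and p = N(d_prior_mean, exp(d_prior_logvar)), which is what an expression beginning `0.5 * torch.sum(d_prior_logvar - d_post_logvar ...` conventionally computes. This is the textbook formula, not the repository's actual continuation:

```python
# Closed-form KL( N(post_mean, exp(post_logvar)) || N(prior_mean, exp(prior_logvar)) ),
# summed over all elements. Textbook formula only -- not the repository's
# continuation of the truncated kld_d line.
import torch


def kl_diag_gaussians(post_mean, post_logvar, prior_mean, prior_logvar):
    post_var = torch.exp(post_logvar)
    prior_var = torch.exp(prior_logvar)
    return 0.5 * torch.sum(
        prior_logvar - post_logvar                              # log(var_p / var_q)
        + (post_var + torch.pow(post_mean - prior_mean, 2)) / prior_var
        - 1)


# Cross-check against torch.distributions on random tensors.
torch.manual_seed(0)
mu_q, lv_q, mu_p, lv_p = (torch.randn(4, 8) for _ in range(4))
q = torch.distributions.Normal(mu_q, torch.exp(0.5 * lv_q))
p = torch.distributions.Normal(mu_p, torch.exp(0.5 * lv_p))
ref = torch.distributions.kl_divergence(q, p).sum()
assert torch.allclose(kl_diag_gaussians(mu_q, lv_q, mu_p, lv_p), ref, atol=1e-4)
```

Reducing with `torch.sum` over every element matches the reductions used for `loglikelihood` and `kld_s` above; `batch_size`, computed but unused in the hunk shown, is presumably applied later to normalize the total loss.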